In [None]:
# Load IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell is executed, so edits to project modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

## Install libraries

```bash
conda create -n edu4 python=3.11 jupyter matplotlib
```

```bash
pip install -U -r requirements.txt
```

```bash
pip install -U numpy
pip install -U scikit-learn
```

## Update repository

In [None]:
# Shell escape: runs `git pull` from the notebook's working directory.
# NOTE(review): this can change files on disk on every run — confirm it should stay in the notebook.
! git pull

## Add import path

In [None]:
import os
import sys
import gc

In [None]:
# Put the notebook's parent directory on the import path, skipping the append
# when that directory is already listed.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
# Put the directory two levels above the notebook on the import path,
# skipping the append when that directory is already listed.
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir))
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
# Drop the temporary name once sys.path has been extended, so it does not
# linger in the notebook namespace.
del module_path

## Organize imports

In [None]:
from transformers import (
    AutoModelForCausalLM, 
    BitsAndBytesConfig, 
    AutoTokenizer, 
    GPT2Tokenizer, 
    GPT2Model
)
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch

In [None]:
from sae_lens import SAE
from transformer_lens.utils import tokenize_and_concatenate
from transformer_lens import HookedTransformer

In [None]:
from datasets import load_dataset

In [None]:
import multiprocessing
from pathlib import Path

In [None]:
from tqdm import tqdm

In [None]:
import seaborn as sns

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
import plotly.express as px

In [None]:
from src.lattmc.fca.utils import *
from src.lattmc.fca.data_utils import *
from src.lattmc.fca.image_utils import *
from src.lattmc.fca.models import *
from src.lattmc.fca.fca_utils import *
from src.lattmc.fca.image_gens import *

#### Number of CPU cores

In [None]:
# Number of CPUs reported by the system; the bare `workers` expression below
# shows the value as the cell output.
# NOTE(review): cpu_count() reports system CPUs, which may be more than this
# process is actually allowed to use.
workers = multiprocessing.cpu_count()
workers

In [None]:
# Seed constant, presumably meant for reproducible runs.
# NOTE(review): none of the visible cells passes SEED to torch / numpy — confirm it is applied further down.
SEED = 2024

In [None]:
# Pick the compute device in priority order: Apple MPS, then CUDA, then CPU.
device = "cpu"
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"

print(f"Device: {device}")

## Initialize Path

In [None]:
# Root data folder, relative to the notebook's working directory.
PATH = Path('data')

# SAE checkpoints: make sure the folder exists, then name the two checkpoint files.
checkpoint_dir = PATH.joinpath('saes')
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path1, checkpoint_path2 = (
    checkpoint_dir / file_name
    for file_name in ('best-checkpoint-v1.ckpt', 'best-checkpoint.ckpt')
)

# Image assets; unlike the checkpoint folder, this folder is not created here.
image_dir = PATH.joinpath('images')
image_path = image_dir.joinpath('1024.png')

## Initialize simple dataset

In [None]:
# Example texts: three short probes (encoded data, academic prose, and a mix
# of both) that the later cells run through GPT-2 and the sparse autoencoder.
texts = [
    "The encoded string is U29mdHdhcmUgRW5naW5lZXJpbmc=",  # Base64-encoded text (the payload decodes to "Software Engineering")
    "Recent advancements in deep learning have revolutionized artificial intelligence.",  # Academic language
    "Implementing machine learning algorithms to decode base64 strings enhances data processing efficiency."  # Combination of topics
]

## Initialize model

In [None]:
# Hugging Face checkpoint id for GPT-2 small
# (the TransformerLens name for the same model is 'gpt2-small').
model_name = 'gpt2'

# Tokenizer: reuse the EOS token for padding so that batches of
# unequal-length texts can be padded.
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Plain Hugging Face GPT-2 backbone, moved to the selected device.
# (TransformerLens alternative: HookedTransformer.from_pretrained('gpt2-small', device=device))
gpt2_model = GPT2Model.from_pretrained(model_name)
gpt2_model = gpt2_model.to(device)

# Pre-trained sparse autoencoder for the residual stream entering block 8 of
# GPT-2 small. The call result is unpacked as a 3-tuple and only the SAE is kept.
# NOTE(review): newer sae_lens releases may return only the SAE object —
# confirm against the installed version.
sae, _, _ = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device=device,
)

In [None]:
# Tokenize all example texts as one padded batch and move it to the device.
inputs = tokenizer(
    texts,
    padding=True,
    truncation=True,
    return_tensors='pt',
).to(device)

# One gradient-free pass: run GPT-2 with all hidden states exposed, take entry 8,
# and encode it with the sparse autoencoder.
# NOTE(review): in Hugging Face outputs hidden_states[0] is the embedding output,
# so index 8 should be the input to block 8 (matching blocks.8.hook_resid_pre) —
# confirm, and confirm the SAE expects raw Hugging Face activations rather than
# TransformerLens-processed ones.
with torch.no_grad():
    outputs = gpt2_model(**inputs, output_hidden_states=True)
    hidden_states = outputs.hidden_states[8]
    encoded_features = sae.encode(hidden_states)

In [None]:
# Inspect the tokenized batch (shown as the cell output).
inputs

In [None]:
# inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).to(device)

# sae.eval()  # prevents error if we're expecting a dead neuron mask for who grads
# with torch.no_grad():
#     # activation store can give us tokens.
#     _, cache = gpt2_model.run_with_cache(inputs, prepend_bos=True)

#     # Use the SAE
#     feature_acts = sae.encode(cache[sae.cfg.hook_name])

#     # save some room
#     del cache

In [None]:
# Shape of the token-id tensor — presumably (batch, padded sequence length); verify.
inputs['input_ids'].shape

In [None]:
# Shape of the SAE activations — presumably (batch, sequence length, number of SAE features); verify.
encoded_features.shape

In [None]:
# NOTE(review): this repeats the shape check from the previous cell — likely a leftover that can be removed.
encoded_features.shape

In [None]:
# 20 largest SAE activations for the first text, taken over the whole flattened tensor.
# The returned indices point into the flattened tensor, not directly at feature ids
# (assuming a (seq_len, n_features) input, index = position * n_features + feature — TODO confirm).
# NOTE(review): padded positions are not excluded here.
torch.topk(torch.flatten(encoded_features[0]), k=20)

In [None]:
# 20 largest SAE activations for the second text, taken over the whole flattened tensor.
# The returned indices point into the flattened tensor, not directly at feature ids
# (assuming a (seq_len, n_features) input, index = position * n_features + feature — TODO confirm).
# NOTE(review): padded positions are not excluded here.
torch.topk(torch.flatten(encoded_features[1]), k=20)

In [None]:
# 20 largest SAE activations for the third text, taken over the whole flattened tensor.
# The returned indices point into the flattened tensor, not directly at feature ids
# (assuming a (seq_len, n_features) input, index = position * n_features + feature — TODO confirm).
# NOTE(review): padded positions are not excluded here.
torch.topk(torch.flatten(encoded_features[2]), k=20)

In [None]:
# Function to identify active neurons
def get_active_neurons(encoded_tensor, threshold=0.1):
    """Return the feature (neuron) indices whose activation exceeds ``threshold``.

    Parameters
    ----------
    encoded_tensor : torch.Tensor
        Activations with the feature axis last. Any rank >= 1 is accepted,
        e.g. ``(n_features,)``, ``(seq_len, n_features)`` or
        ``(batch, seq_len, n_features)``.
    threshold : float, default 0.1
        An element counts as active when it is strictly greater than this value.

    Returns
    -------
    list[int]
        The last-axis index of every active element, in row-major order.
        A feature that is active at several positions appears once per
        position; wrap the result in ``set`` if unique ids are needed.
    """
    # Index the coordinate tuple with -1 (the feature axis) rather than a fixed 1:
    # identical for 2-D input, but also valid for 1-D input and still the feature
    # axis (not the position axis) for batched 3-D input.
    return (encoded_tensor > threshold).nonzero(as_tuple=True)[-1].tolist()

# Analyze activations for each text
for i, text in enumerate(texts):
    # The batch was padded to a common length, so shorter texts carry padding
    # positions. Keep only the rows whose attention-mask entry is 1 so that
    # activations on padding tokens are not reported as belonging to the text.
    real_token_mask = inputs['attention_mask'][i].bool()
    active_neurons = get_active_neurons(encoded_features[i][real_token_mask])
    print(f"Text: {text}")
    print(f"Active Neurons: {active_neurons}\n")
    print()