In [None]:
# Enable the autoreload extension; mode 2 reloads imported modules before each
# cell runs, so edits to project modules are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

## Install libraries

```bash
conda create -n edu4 python=3.11 jupyter matplotlib
conda activate edu4
```

```bash
pip install -U -r requirements.txt
```

```bash
pip install -U numpy
pip install -U scikit-learn
```

## Update repository

In [None]:
# ! git pull

## Add import path

In [None]:
import os
import sys
import gc

In [None]:
# Make the parent directory importable (only added once per kernel).
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
# Make the grandparent directory importable as well (only added once per kernel).
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir))
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
# Drop the scratch variable used by the sys.path cells so it does not linger in
# the notebook namespace.
del module_path

## Organize imports

In [None]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch
from torch import nn

In [None]:
from sae_lens import SAE
from transformer_lens.utils import tokenize_and_concatenate

In [None]:
from transformer_lens import HookedTransformer

In [None]:
from datasets import load_dataset

In [None]:
import multiprocessing
from pathlib import Path

In [None]:
from tqdm import tqdm

In [None]:
import seaborn as sns

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from scipy.sparse import csr_matrix

In [None]:
import plotly.express as px

In [None]:
from src.lattmc.fca.utils import *
from src.lattmc.fca.data_utils import *
from src.lattmc.fca.image_utils import *
from src.lattmc.fca.models import *
from src.lattmc.fca.fca_utils import *
from src.lattmc.fca.image_gens import *

#### Number of CPU cores

In [None]:
# Number of CPUs reported by multiprocessing; displayed as the cell output.
# NOTE(review): `workers` is not referenced by any other cell visible here.
workers = multiprocessing.cpu_count()
workers

In [None]:
# Seed constant.
# NOTE(review): no visible cell passes SEED to numpy/torch/random — confirm
# whether seeding is still needed or the constant can be removed.
SEED = 2024

In [None]:
# Pick the compute device: Apple MPS when available, otherwise CUDA,
# otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

## Initialize Path

In [None]:
# Output locations, all rooted at ./data.
PATH = Path('data')

# SAE artefacts: one sparse activation matrix per document (matrix_dir) and the
# stacked per-document vectors (vectors_path).
checkpoint_dir = PATH / 'saes'
vectors_dir = checkpoint_dir / 'gpt2_small_8_vecs'
matrix_dir = checkpoint_dir / 'gpt2_small_8_mats'
for _out_dir in (checkpoint_dir, vectors_dir, matrix_dir):
    _out_dir.mkdir(exist_ok=True, parents=True)
vectors_path = checkpoint_dir / 'gpt2_small_8_vecs.joblib'

# Image locations (the directory is not created here).
image_dir = PATH / 'images'
image_path = image_dir / '1024.png'

## Initialize simple dataset

In [None]:
# Load the 'train' split of NeelNanda/pile-10k as a regular (non-streaming) dataset.
dataset = load_dataset('NeelNanda/pile-10k', split='train', streaming=False)

## Initialize model

In [None]:
# NOTE(review): `layer` is not referenced by any visible cell, and the SAE loaded
# below targets 'blocks.8.hook_mlp_out' — confirm whether 6 is stale.
layer = 6

In [None]:
class Text2Latent(object):
    """Map text to SAE latents and back to reconstructed activations.

    The wrapped model is run with activation caching, the activation stored
    under ``hook_point`` is read from the cache and fed to the SAE encoder.

    Args:
        model: Language model exposing ``to_tokens``, ``to_string`` and
            ``run_with_cache``.
        sae: Sparse autoencoder exposing ``encode``, ``decode`` and
            ``cfg.hook_name``.
        hook_point: Cache key of the activation passed to the SAE. When
            omitted, ``sae.cfg.hook_name`` is used.
    """

    def __init__(self, model: nn.Module, sae: nn.Module, hook_point=None):
        self.model = model.eval()
        self.sae = sae.eval()
        # Resolve the hook name here instead of reading a notebook-level
        # `hook_point` global at call time: the class then works regardless of
        # the order in which cells were executed.
        self.hook_point = (
            self.sae.cfg.hook_name if hook_point is None else hook_point
        )

    def tokenize(self, text):
        """Return ``model.to_tokens(text)``."""
        return self.model.to_tokens(text)

    def to_string(self, tokens):
        """Return ``model.to_string(tokens)``."""
        return self.model.to_string(tokens)

    @torch.inference_mode()
    def encode(self, text):
        """Run the model on ``text`` (with BOS prepended) and SAE-encode the hooked activation.

        Returns:
            The result of ``sae.encode`` applied to the cached activation at
            ``self.hook_point``.
        """
        # Only the activation consumed by the SAE is cached; caching every hook
        # is wasted memory when encoding many documents.
        _, cache = self.model.run_with_cache(
            text,
            prepend_bos=True,
            names_filter=self.hook_point,
        )
        return self.sae.encode(cache[self.hook_point])

    @torch.inference_mode()
    def decode(self, z):
        """Return ``sae.decode(z)``.

        Runs in inference mode like ``encode``: latents produced by ``encode``
        are inference tensors and cannot take part in autograd-tracked ops.
        """
        return self.sae.decode(z)

    @torch.inference_mode()
    def forward(self, text):
        """Encode ``text`` and decode the latents again (SAE reconstruction)."""
        return self.decode(self.encode(text))

In [None]:
# Identifiers of the language model and of the pretrained SAE to load.
model_name = 'gpt2-small'
release = 'gpt2-small-mlp-tm'  # see other options in sae_lens/pretrained_saes.yaml
sae_id = 'blocks.8.hook_mlp_out'  # won't always be a hook point

# from_pretrained is unpacked into the SAE, its config dict and a sparsity object.
sae, cfg_dict, sparsity = SAE.from_pretrained(release=release, sae_id=sae_id, device=device)

# Name of the model hook whose activation the SAE was trained on.
hook_point = sae.cfg.hook_name
print(hook_point)

In [None]:
# Load the pretrained language model named by `model_name` onto the selected device.
model = HookedTransformer.from_pretrained(model_name, device=device)

In [None]:
# Text -> SAE-latent pipeline built from the model and SAE loaded above.
net = Text2Latent(model, sae)

## Generate V Lattice

In [None]:
# Run a garbage-collection pass before the document-encoding loop.
gc.collect()

In [None]:
# Encode every document once and store its SAE activations (one row per token
# position) as a sparse matrix at matrix_dir/<index>.joblib.
# Files that already exist are skipped, so re-running the notebook does not
# repeat the forward passes. Delete matrix_dir to force a recompute (e.g. after
# changing the model or the SAE).
# NOTE(review): `joblib` is not imported explicitly in this notebook; presumably
# it arrives through one of the `src.lattmc.fca` star imports — confirm.
for idx, d in enumerate(tqdm(dataset)):
    out_path = matrix_dir / f'{idx}.joblib'
    if out_path.exists():
        continue
    v = net.encode(d['text'])
    v_sparse = csr_matrix(v.to('cpu').detach().numpy()[0])
    # Write to a temp file and rename so an interrupted run never leaves a
    # truncated file that the exists() guard would later accept.
    tmp_path = out_path.with_suffix('.tmp')
    joblib.dump(v_sparse, tmp_path)
    tmp_path.replace(out_path)

In [None]:
# Build (or load from cache) V: one row per dataset document, holding the
# per-feature maximum of the SAE activations over the document's token positions.
if vectors_path.exists():
    V = joblib.load(vectors_path)
else:
    v_paths = list(matrix_dir.glob('*.joblib'))
    V_dict = {}
    for v_path in tqdm(v_paths):
        v_sparse = joblib.load(v_path)
        # Row 0 is dropped: it is the position of the BOS token that
        # Text2Latent.encode prepends (prepend_bos=True).
        vs = v_sparse.toarray()[1:]
        if vs.shape[0] == 0:
            # No token rows left: a max over zero rows would raise, so use an
            # all-zero feature vector instead.
            v = np.zeros(vs.shape[1], dtype=vs.dtype)
        else:
            v = np.maximum.reduce(vs)
        V_dict[int(v_path.stem)] = v

    # Rows of V must line up with dataset indices, so require every index.
    n_docs = len(dataset)
    missing = [k for k in range(n_docs) if k not in V_dict]
    if missing:
        raise KeyError(
            f'{len(missing)} document(s) have no activation file in {matrix_dir} '
            f'(first missing index: {missing[0]})'
        )
    V = np.stack([V_dict[k] for k in range(n_docs)])
    joblib.dump(V, vectors_path)

## Generate Context and Analyze

In [None]:
# Run a garbage-collection pass before building the FCA context.
gc.collect()

In [None]:
# Build the FCA object from the document-by-feature matrix V.
# (FCA comes from one of the `src.lattmc.fca` star imports.)
fca = FCA(V)

In [None]:
# Encode one sample prompt and reduce it to a single feature vector: the
# per-feature maximum over all positions of batch element 0.
# `v` is used by later cells as a shape/dtype template for query vectors.
text1 = " The Golden Gate Bridge"
z = net.encode(text1)
zs = z[0].detach().cpu().numpy()
v = zs.max(axis=0)

In [None]:
# Shape of the sample activations with the first row removed.
zs[1:].shape

In [None]:
# Shape of the full sample activation matrix.
zs.shape

## Shuttle Feature from Neuronscope

In [None]:
# Query vector with a single active feature: index 19962 at activation 3.2.
v_shuttle = np.zeros(v.shape, dtype=v.dtype)
v_shuttle[19962] = 3.2
v_shuttle.shape

In [None]:
# Query the FCA object with the single-feature vector and display the result.
# G_FG is a project helper; presumably it returns the concept generated by the
# query vector — TODO confirm.
concept = fca.G_FG(v_shuttle)
concept

In [None]:
# Show the dataset row whose index is the third entry of concept.A.
dataset[concept.A[2].item()]

## "for" Feature from Neuronscope

In [None]:
# Query vector with a single active feature: index 8 at activation 6.2.
v_for = np.zeros(v.shape, dtype=v.dtype)
v_for[8] = 6.2
v_for.shape

In [None]:
# Query the FCA object with the single-feature vector and display the result
# (overwrites the previous `concept`).
concept = fca.G_FG(v_for)
concept

In [None]:
# Show the dataset row whose index is the first entry of concept.A.
dataset[concept.A[0].item()]

## Time Feature from Neuronscope

In [None]:
# Query vector with a single active feature: index 2 at activation 3.
v_date = np.zeros(v.shape, dtype=v.dtype)
v_date[2] = 3
v_date.shape

In [None]:
# Query the FCA object with the single-feature vector and display the result
# (overwrites the previous `concept`).
concept = fca.G_FG(v_date)
concept

In [None]:
# Show the dataset row whose index is the first entry of concept.A.
dataset[concept.A[0].item()]

## Recipe Feature from Neuronscope

In [None]:
# Query vector with a single active feature: index 7 at activation 8.4.
v_recipie = np.zeros(v.shape, dtype=v.dtype)
v_recipie[7] = 8.4
v_recipie.shape

In [None]:
# Query the FCA object with the single-feature vector and display the result
# (overwrites the previous `concept`).
concept = fca.G_FG(v_recipie)
concept

In [None]:
# Show the dataset row whose index is the first entry of concept.A.
dataset[concept.A[0].item()]

## Add Examples

In [None]:
# Re-encode the sample prompt and print its 12 largest activations along the
# last dimension, followed by the full activation tensor.
# NOTE(review): `tokens` is assigned here but not used in this cell.
text1 = " The Golden Gate Bridge"
z = net.encode(text1)
tokens = net.tokenize(text1)
print(torch.topk(z, 12))
print()
print(z)

In [None]:
# Encode a single word and list the 50 largest activations at position 1 of
# batch element 0 (position 0 is the prepended BOS token).
r = net.encode('Golden')
torch.topk(r[0, 1], k=50)

In [None]:
# Shape of the index tensor of non-zero activations at position 1; its first
# dimension is the number of active features.
torch.nonzero(r[0][1]).shape

In [None]:
# Example phrases encoded together as one batch below.
texts = [
    "Golden Gate Bridge",
    "New York City",
    "Silicon Valley",
    "The White House",
    "Apple Inc."
]

In [None]:
# Encode all example phrases in one batch.
# NOTE(review): phrases tokenize to different lengths, so the batch is presumably
# padded — confirm padding positions do not affect the rows used below.
vs = net.encode(texts)

In [None]:
# Shape of the activations for the first three phrases.
vs[:3].shape

In [None]:
# Activation vector at position 1 (the first token after BOS) for each of the
# first three phrases, kept as a list of tensors.
v_t = list(vs[:3, 1])
v_t[0].shape

In [None]:
# Element-wise minimum over the selected activation vectors: a feature stays
# active in v_A only to the level reached in every one of them.
# Written as a single reduction so no top-level loop variable overwrites the
# notebook-level `v` that the `np.zeros_like(v)` cells use as a template.
v_A = torch.stack(v_t).amin(dim=0)

In [None]:
# The 30 largest entries of the shared-activation vector and their feature indices.
torch.topk(v_A, 30)

In [None]:
# Apply the project helper topK to the first row of V with 30 as second
# argument (presumably the top-30 features of document 0 — TODO confirm).
tk = topK(V[0], 30)
tk

In [None]:
# Query the FCA object with the shared-activation vector.
# NOTE(review): v_A is a torch tensor (possibly on the GPU/MPS device) whereas
# the earlier G_FG calls pass numpy arrays — confirm G_FG accepts both.
concept = fca.G_FG(v_A)

In [None]:
# Display the concept returned for v_A.
concept

In [None]:
# Size of concept.A, size of V, and how many rows of V are not in concept.A.
concept.A.shape, V.shape, V.shape[0] - concept.A.shape[0]

In [None]:
# First entry of concept.A.
concept.A[0]

In [None]:
# Show the dataset row whose index is the second entry of concept.A.
dataset[concept.A[1].item()]

## Analyze Tokens

In [None]:
# Number of documents in the dataset.
len(dataset)

In [None]:
# Sanity check: for every document, the number of tokens produced by the
# tokenizer must equal the number of rows in its stored activation matrix.
checks = []
for idx in tqdm(range(len(dataset))):
    tokens = net.tokenize(dataset[idx]['text'])[0]
    vs = joblib.load(matrix_dir / f'{idx}.joblib')
    n_tokens, n_rows = tokens.shape[0], vs.shape[0]
    checks.append(n_tokens == n_rows)

In [None]:
# True only if every document passed the token-count / row-count check.
np.all(checks)

In [None]:
from IPython.display import clear_output

# For every document index in concept.A, load its per-token activation matrix
# and apply the project helper find_G_x against v_A. Non-empty results are kept
# in T_dict, keyed by the document index.
# (find_G_x presumably returns the token positions matching v_A; later cells
# use the stored values to index the token sequence — TODO confirm.)
T_dict = {}
for idx in tqdm(concept.A):
    # Use the plain Python int for both the file name and the dict key, so the
    # path is correct whether concept.A holds numpy scalars or 0-d tensors.
    doc_id = idx.item()
    vs = joblib.load(matrix_dir / f'{doc_id}.joblib').toarray()
    G_x = find_G_x(vs, v_A)
    if G_x.shape[0] > 0:
        T_dict[doc_id] = G_x
    # Clear the cell output after each document (keeps any helper output from
    # accumulating); wait=True defers the clear until new output arrives.
    clear_output(wait=True)

In [None]:
# Number of documents with a non-empty find_G_x result.
len(T_dict)

In [None]:
# Stored result for document 5.
# NOTE(review): raises KeyError if document 5 is not in T_dict.
T_dict[5]

In [None]:
# Tokenize document 5; the result is indexed as tokens[0][...] below.
tokens = net.tokenize(dataset[5]['text'])

In [None]:
# Text of the token at position 100 of document 5.
net.to_string(tokens[0][100])

In [None]:
# Stored result for document 5, shown again next to the token lookup below.
T_dict[5]

In [None]:
# Text of the first token of document 5, and text of the tokens selected by T_dict[5].
net.to_string(tokens[0][0]), net.to_string(tokens[0][T_dict[5]])

In [None]:
# Map each document index in T_dict to the text of its selected tokens.
# A comprehension keeps the per-document scratch values local, so the
# notebook-level `v` and `tokens` are no longer overwritten by this cell.
W_dict = {
    doc_id: net.to_string(net.tokenize(dataset[doc_id]['text'])[0][positions])
    for doc_id, positions in T_dict.items()
}

In [None]:
# Display the selected token text for every document in T_dict.
W_dict