In [None]:
# Load the autoreload extension and reload all modules before each cell runs,
# so edits to the project sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the cell output.
%matplotlib inline

## Install libraries

```bash
conda create -n edu4 python=3.11 jupyter matplotlib
```

```bash
pip install -U -r requirements.txt
```

```bash
pip install -U numpy
pip install -U scikit-learn
```

## Update repository

In [None]:
# ! git pull

## Add import path

In [None]:
import os
import sys
import gc

In [None]:
# Make the parent of the current working directory importable, adding it to
# sys.path only if it is not already listed.
module_path = os.path.abspath('..')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
# Make the directory two levels above the current working directory
# importable, adding it to sys.path only if it is not already listed.
module_path = os.path.abspath('../..')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
# Drop the temporary name used while extending sys.path so it does not linger
# in the notebook namespace.
del module_path

## Organize imports

In [None]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch

In [None]:
from sae_lens import SAE
from transformer_lens.utils import tokenize_and_concatenate

In [None]:
from transformer_lens import HookedTransformer

In [None]:
from datasets import load_dataset

In [None]:
import multiprocessing
from pathlib import Path

In [None]:
from tqdm import tqdm

In [None]:
import seaborn as sns

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
import plotly.express as px

In [None]:
from src.lattmc.fca.utils import *
from src.lattmc.fca.data_utils import *
from src.lattmc.fca.image_utils import *
from src.lattmc.fca.models import *
from src.lattmc.fca.fca_utils import *
from src.lattmc.fca.image_gens import *

#### Number of CPU cores

In [None]:
# Number of CPUs reported by multiprocessing; the bare expression on the last
# line displays the value as the cell output.
workers = multiprocessing.cpu_count()
workers

In [None]:
# Random-seed constant. NOTE(review): this cell only defines the value and does
# not seed any RNG itself — confirm where (or whether) it is applied.
SEED = 2024

In [None]:
# Pick the compute device in order of preference: Apple MPS, then CUDA,
# falling back to the CPU when neither accelerator is available.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

## Initialize Path

In [None]:
# Root data directory, relative to the current working directory.
PATH = Path('data')

# SAE checkpoint directory (created on demand, including missing parents)
# and the two checkpoint files expected inside it.
checkpoint_dir = PATH.joinpath('saes')
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path1, checkpoint_path2 = (
    checkpoint_dir.joinpath(file_name)
    for file_name in ('best-checkpoint-v1.ckpt', 'best-checkpoint.ckpt')
)

# Image directory and file; only the paths are built here, nothing is created.
image_dir = PATH.joinpath('images')
image_path = image_dir.joinpath('1024.png')

## Initialize simple dataset

## Initialize model

In [None]:
# Layer index. NOTE(review): in the visible cells this is only referenced by
# the commented-out gemma-2b configuration; the active gpt2-small setup
# hard-codes block 8 in `sae_id` — confirm whether the two should agree.
layer = 6

In [None]:
# model_name = 'gemma-2b'
# release = 'gemma-2b-res-jb'
# sae_id = f'blocks.{layer}.hook_resid_post'
# # get model
# model = HookedTransformer.from_pretrained(
#     model_name, 
#     device=device
# )

# # get the SAE for this layer
# sae, cfg_dict, _ = SAE.from_pretrained(
#     release=release,
#     sae_id=sae_id,
#     device=device
# )

# # get hook point
# hook_point = sae.cfg.hook_name
# print(hook_point)

In [None]:
# Identify the pretrained SAE to load (`release` / `sae_id`) and the name of
# the base model that the next cell loads with HookedTransformer.
model_name = 'gpt2-small'
release = 'gpt2-small-mlp-tm'
sae_id = 'blocks.8.hook_mlp_out'
# NOTE(review): unpacking three values assumes an SAELens version whose
# `SAE.from_pretrained` returns (sae, cfg_dict, sparsity) — confirm against
# the installed version.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release=release,  # see other options in sae_lens/pretrained_saes.yaml
    sae_id=sae_id,  # won't always be a hook point
    device=device,
)
# Hook name recorded in the SAE config; later cells use it to index the
# activation cache returned by the model.
hook_point = sae.cfg.hook_name
print(hook_point)

In [None]:
# Load the TransformerLens model named by `model_name` onto the selected device.
model = HookedTransformer.from_pretrained(model_name, device=device)

In [None]:
sv_prompt = " The Golden Gate Bridge"
sv_logits, cache = model.run_with_cache(sv_prompt, prepend_bos=True)
tokens = model.to_tokens(sv_prompt)
print(tokens)

# get the feature activations from our SAE
sv_feature_acts = sae.encode(cache[hook_point])

# get sae_out
sae_out = sae.decode(sv_feature_acts)

# print out the top activations, focus on the indices
print(torch.topk(sv_feature_acts, 12))

In [None]:
sv_prompt = "Golden"
sv_logits, cache = model.run_with_cache(sv_prompt, prepend_bos=True)
tokens = model.to_tokens(sv_prompt)
print(tokens)

# get the feature activations from our SAE
sv_feature_acts = sae.encode(cache[hook_point])

# get sae_out
sae_out = sae.decode(sv_feature_acts)

# print out the top activations, focus on the indices
print(torch.topk(sv_feature_acts, 12))

In [None]:
sv_prompt = "gate"
sv_logits, cache = model.run_with_cache(sv_prompt, prepend_bos=True)
tokens = model.to_tokens(sv_prompt)
print(tokens)

# get the feature activations from our SAE
sv_feature_acts = sae.encode(cache[hook_point])

# get sae_out
sae_out = sae.decode(sv_feature_acts)

# print out the top activations, focus on the indices
print(torch.topk(sv_feature_acts, 12))

In [None]:
# Inspect the SAE encoder weight matrix. SAELens registers the encoder weights
# as the `W_enc` parameter; there is no 'enc' entry in the state dict, so the
# previous lookup raised KeyError.
sae.state_dict()['W_enc'].detach()

In [None]:
# Shape of the SAE feature activations from the most recently encoded prompt.
sv_feature_acts.shape

In [None]:
# Decode the current token ids back into text.
model.to_string(tokens)

In [None]:
# First token of every row, kept as a 2-D slice (presumably the prepended BOS
# token — TODO confirm).
tokens[:, :1]

In [None]:
# Shape of the token tensor; the cells below iterate over dimension 1.
tokens.shape

In [None]:
# Show each token position's id tensor next to its decoded text.
for i in range(tokens.shape[1]):
    token_ids = tokens[:, i]
    decoded = model.to_string(token_ids)
    print(f'{token_ids} - {decoded}')

In [None]:
# Top 20 SAE activations (values and indices) at index [0][0] of the
# activation tensor, i.e. the first position of the first batch entry.
print(torch.topk(sv_feature_acts[0][0], 20))

In [None]:
# For every token position, collect the indices of the SAE features with a
# non-zero activation, then report how many there are at each position.
num_positions = tokens.shape[1]
active_neurons = [
    torch.nonzero(sv_feature_acts[0, position])
    for position in range(num_positions)
]
for neurons in active_neurons:
    print(neurons.shape)

In [None]:
# Recompute the per-position indices of non-zero SAE features (so this cell
# can run on its own) and print the index tensors themselves.
num_positions = tokens.shape[1]
active_neurons = [
    torch.nonzero(sv_feature_acts[0, position])
    for position in range(num_positions)
]
for neurons in active_neurons:
    print(neurons)

In [None]:
# Report SAE features that are active at more than one token position.
# Every ordered pair of distinct positions (k, p) is checked, so each overlap
# is printed twice (once per direction), exactly as before.
# Each entry of `active_neurons` is the (count, 1) index tensor that
# torch.nonzero returns for a 1-D activation vector.
for k, n1 in enumerate(active_neurons):
    for p, n2 in enumerate(active_neurons):
        if k != p:
            print(f'testing {k} {p} {n1.shape} {n2.shape}')
            # One broadcast comparison, (len(n1), 1) against (1, len(n2)),
            # replaces a Python-level tensor comparison per index pair.
            # torch.nonzero yields the matching (i, j) pairs in row-major
            # order, which is the order the nested loops visited them in.
            matches = torch.nonzero(n1 == n2.t())
            for i, j in matches.tolist():
                print(f'{n1[i]=} {n2[j]=}')