In [None]:
# Re-import edited project modules (e.g. src.lattmc...) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline (already the default in modern Jupyter).
%matplotlib inline

## Install libraries

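```bash
conda create -n edu4 python=3.11 jupyter matplotlib
conda activate edu4
```

Run the following in the activated environment (inside a notebook cell, use `%pip` instead of `pip`):

```bash
pip install -U -r requirements.txt
```

```bash
pip install -U numpy
pip install -U scikit-learn
```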

## Update repository

In [None]:
# NOTE(review): pulling here changes the code under a running kernel (and %autoreload will pick it up);
# restart the kernel and run all cells after a pull to get reproducible results.
! git pull

## Add import path

In [None]:
import os
import sys
import gc

In [None]:
# Add the parent and grandparent directories to sys.path so project packages such as
# `src.lattmc...` (imported below) can be resolved. Both paths are relative to the
# kernel's current working directory, presumably the notebook's folder.
# The original two copy-pasted cells are folded into one loop with identical effect.
for module_path in (os.path.abspath('..'), os.path.abspath(os.path.join('..', '..'))):
    # Append (not prepend) and skip duplicates so re-running this cell is idempotent.
    if module_path not in sys.path:
        sys.path.append(module_path)
del module_path  # keep the loop variable out of the notebook namespace

## Organize imports

In [None]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch

In [None]:
from sae_lens import SAE
from transformer_lens.utils import tokenize_and_concatenate

In [None]:
from transformer_lens import HookedTransformer

In [None]:
from datasets import load_dataset

In [None]:
import multiprocessing
from pathlib import Path

In [None]:
from tqdm import tqdm

In [None]:
import seaborn as sns

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
import plotly.express as px

In [None]:
from src.lattmc.fca.utils import *
from src.lattmc.fca.data_utils import *
from src.lattmc.fca.image_utils import *
from src.lattmc.fca.models import *
from src.lattmc.fca.fca_utils import *
from src.lattmc.fca.image_gens import *

#### Number of CPU cores

In [None]:
# Logical CPU count of this machine. Not used by the cells shown here; presumably intended
# for worker / num_proc settings — confirm.
workers = multiprocessing.cpu_count()
workers

In [None]:
# Global seed for the notebook's stochastic steps (e.g. DataLoader shuffling in prepare_data).
SEED = 2024
# Defining SEED alone does nothing; apply it to the global NumPy and PyTorch RNGs
# (torch.manual_seed seeds all devices) so a fresh Restart & Run All is reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Choose the compute device in order of preference: Apple Metal (MPS), then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

## Initialize paths

In [None]:
# Root of all local artifacts, relative to the current working directory.
PATH = Path('data')

# SAE checkpoint locations; the folder is created eagerly (no error if it already exists).
checkpoint_dir = PATH.joinpath('saes')
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path1 = checkpoint_dir.joinpath('best-checkpoint-v1.ckpt')
checkpoint_path2 = checkpoint_dir.joinpath('best-checkpoint.ckpt')

# Image locations (directory is not created here).
image_dir = PATH.joinpath('images')
image_path = image_dir.joinpath('1024.png')

## Define a simple dataset (MNIST) loader

In [None]:
# MNIST data loaders with standardization, flattening, and a held-out evaluation split.
def prepare_data(batch_size=128, root='./data'):
    """Build MNIST train/evaluation datasets and DataLoaders with flattened, standardized images.

    ``ToTensor`` scales pixels to [0, 1]; ``Normalize`` then standardizes them with the
    MNIST mean/std (0.1307, 0.3081), so values are roughly zero-mean / unit-variance and
    are NOT confined to [0, 1]. Each image is finally flattened to 1-D (28*28 = 784 values).

    Args:
        batch_size: Mini-batch size for both loaders.
        root: Directory where MNIST is stored / downloaded (default ``'./data'``).

    Returns:
        Tuple ``(train_dataset, train_loader, val_dataset, val_loader)``. The "validation"
        split is MNIST's official test split (``train=False``); only the training loader shuffles.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),                       # uint8 image -> float tensor in [0, 1]
        transforms.Normalize((0.1307,), (0.3081,)),  # standardize with MNIST mean / std
        transforms.Lambda(lambda x: x.view(-1))      # flatten (1, 28, 28) -> (784,)
    ])

    # Training split, reshuffled every epoch.
    train_dataset = datasets.MNIST(root=root, train=True, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Held-out split used for validation; fixed order for comparable evaluation runs.
    val_dataset = datasets.MNIST(root=root, train=False, transform=transform, download=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_dataset, train_loader, val_dataset, val_loader


## Initialize SAE and model

In [None]:
# Pretrained SAEs to choose from: name -> (release, sae_id). Both target layer 8 of GPT-2 small;
# see sae_lens/pretrained_saes.yaml for other options.
# Only the selected SAE is loaded. Loading both back-to-back would download the first and
# immediately overwrite it.
SAE_CHOICES = {
    'resid_pre': ('gpt2-small-res-jb', 'blocks.8.hook_resid_pre'),
    'mlp_out': ('gpt2-small-mlp-tm', 'blocks.8.hook_mlp_out'),
}
SAE_CHOICE = 'mlp_out'  # switch to 'resid_pre' for the residual-stream SAE

release, sae_id = SAE_CHOICES[SAE_CHOICE]
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release=release,
    sae_id=sae_id,  # won't always be a hook point
    device=device,
)

In [None]:
# GPT-2 small as a HookedTransformer, so activations can be cached at the SAE's hook point
# (sae.cfg.hook_name) via run_with_cache below.
model = HookedTransformer.from_pretrained('gpt2-small', device=device)

In [None]:
# Raw text: the train split of NeelNanda/pile-10k, loaded non-streaming (materialized locally).
dataset = load_dataset(
    path='NeelNanda/pile-10k',
    split='train',
    streaming=False,
)

# Tokenize with GPT-2's tokenizer, concatenate, and chunk into rows of the SAE's
# context_size, adding BOS according to the SAE config's prepend_bos.
# NOTE(review): streaming=True although `dataset` above is not streamed; in transformer_lens
# this flag appears to only disable multiprocessing in `.map` — confirm this is intended.
token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=model.tokenizer,  # type: ignore
    streaming=True,
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)

In [None]:
# Dev-only: show the source of tokenize_and_concatenate (IPython `??`); remove before sharing.
?? tokenize_and_concatenate

In [None]:
# Quick look at the raw and tokenized datasets.
dataset, token_dataset

In [None]:
# First raw document and the shape of the first token chunk (presumably (context_size,) — confirm).
dataset[0], token_dataset[0]['tokens'].shape

In [None]:
# Run one batch of Pile tokens through GPT-2 and the SAE, then measure SAE sparsity (L0).
N_BATCH = 32  # number of token rows from token_dataset to encode

sae.eval()  # eval mode: avoids training-time bookkeeping (e.g. dead-neuron masks for grads)

with torch.no_grad():
    batch_tokens = token_dataset[:N_BATCH]["tokens"]
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)

    # Encode the activations at the SAE's hook point, then reconstruct them.
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    sae_out = sae.decode(feature_acts)

    # run_with_cache keeps every hook point by default; free it as soon as possible.
    del cache

    # L0 = number of active SAE features per token, skipping position 0 (the BOS token).
    l0 = (feature_acts[:, 1:] > 0).float().sum(-1)

print("average l0", l0.mean().item())
px.histogram(
    x=l0.flatten().cpu().numpy(),
    title=f"SAE L0 per token ({sae.cfg.hook_name})",
    labels={"x": "active features per token"},
).show()

In [None]:
# Re-encode the first 32 token rows through GPT-2 + SAE (no L0 plot).
# NOTE(review): duplicates the computation of the previous cell; unlike that cell it keeps
# `cache` (all cached activations) alive — drop it with `del cache` if nothing below needs it.
sae.eval()  # prevents error if we're expecting a dead neuron mask for who grads

with torch.no_grad():
    # activation store can give us tokens.
    batch_tokens = token_dataset[:32]["tokens"]
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)

    # Use the SAE
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    sae_out = sae.decode(feature_acts)

In [None]:
# Dev-only: show the SAE object's source (IPython `??`); remove before sharing.
?? sae

In [None]:
# Shape of the SAE reconstruction; presumably (batch, seq, d_in) — confirm.
sae_out.shape

In [None]:
# Token row 10 and its SAE features (feature_acts row 10 was computed from token_dataset[10]).
token_dataset[10], feature_acts[10]

In [None]:
# Dataset (rows, columns) vs. encoded batch shape; feature_acts presumably (batch, seq, d_sae).
token_dataset.shape, feature_acts.shape

In [None]:
# Shapes of the encoded token batch, its features, and its reconstruction side by side.
token_dataset[:32]['tokens'].shape, feature_acts.shape, sae_out.shape

In [None]:
# Feature vector for batch row 0, position 0; presumably (d_sae,).
token_dataset[:32]['tokens'].shape, feature_acts[0][0].shape, sae_out.shape

In [None]:
# Largest SAE feature activation anywhere in the encoded batch.
torch.max(feature_acts)

In [None]:
# Top-10 activations over the fully flattened feature tensor.
# NOTE(review): returned indices are flat; recover (batch, position, feature) with
# np.unravel_index(result.indices.cpu().numpy(), feature_acts.shape).
torch.topk(torch.flatten(feature_acts), k=10)

In [None]:
# Batch row 0: the 2 strongest features at each position
# (dim=1 is presumably the feature axis of feature_acts[0] — confirm).
top_values, top_indices = torch.topk(feature_acts[0], k=2, dim=1, largest=True, sorted=True)

print("Top values:", top_values)
print("Top indices:", top_indices)

In [None]:
# Top-10 features for batch row 0, position 4.
torch.topk(torch.flatten(feature_acts[0][4]), k=10)

In [None]:
# Top-10 features for batch row 0, position 28.
torch.topk(torch.flatten(feature_acts[0][28]), k=10)