In [25]:
!pip install sae_vis --quiet

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [26]:
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.append("/root/sae")
from huggingface_hub import hf_hub_download
from tqdm import tqdm
import gc
import torch
from sae import Sae

# callum imports 
from IPython import get_ipython # type: ignore
ipython = get_ipython(); assert ipython is not None

# Standard imports
import torch
from datasets import load_dataset
import webbrowser
import os
from transformer_lens import utils, HookedTransformer
from datasets.arrow_dataset import Dataset
from huggingface_hub import hf_hub_download
import time

import pandas as pd
import numpy as np
import plotly.express as px

# Library imports
from sae_vis.utils_fns import get_device
from sae_vis.model_fns import AutoEncoder
from sae_vis.data_storing_fns import SaeVisData
from sae_vis.data_config_classes import SaeVisConfig

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
# Inference only: switch autograd off globally for every cell that follows.
torch.set_grad_enabled(False)

# Load pythia-160m onto the GPU and put it in eval mode.
model = HookedTransformer.from_pretrained_no_processing("pythia-160m", device="cuda").eval()

# Tokenize the train split with max_length=64, then shuffle with a fixed seed
# so the row order is reproducible.
data = load_dataset("stas/openwebtext-10k", split="train")
tokenized_data = utils.tokenize_and_concatenate(
    data, model.tokenizer, max_length=64
).shuffle(seed=22)

# Keep the whole token tensor on the GPU; later cells slice batches out of it.
tokens = tokenized_data["tokens"].to("cuda")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-160m into HookedTransformer


In [5]:
# One SAE per orthogonality coefficient, keyed by the coefficient exactly as
# written in the list (so 1 and 5 stay ints). The checkpoint name always uses
# one decimal place, e.g. "orthcoef=1.0_square".
sae_dic = {
    orthog_coeff: Sae.load_from_hub(
        "jacobcd52/orthog-sae",
        f"orthcoef={orthog_coeff:.1f}_square",
        device="cuda",
    )
    for orthog_coeff in [0.0, 0.1, 1, 5]
}

/root/.cache/huggingface/hub/models--jacobcd52--orthog-sae/snapshots/a96079ec59e658267c76e89f0d160cd8de92d37f/orthcoef=0.0_square.safetensors
/root/.cache/huggingface/hub/models--jacobcd52--orthog-sae/snapshots/a96079ec59e658267c76e89f0d160cd8de92d37f/orthcoef=0.1_square.safetensors
/root/.cache/huggingface/hub/models--jacobcd52--orthog-sae/snapshots/a96079ec59e658267c76e89f0d160cd8de92d37f/orthcoef=1.0_square.safetensors
/root/.cache/huggingface/hub/models--jacobcd52--orthog-sae/snapshots/a96079ec59e658267c76e89f0d160cd8de92d37f/orthcoef=5.0_square.safetensors


In [15]:
def decoder_gram(sae):
    """Return ``W_dec @ W_dec.T`` for ``sae``, moved to the CPU.

    NOTE(review): the entries are cosine similarities only if the rows of
    ``W_dec`` are unit-norm -- confirm against the SAE implementation.
    """
    w_dec = sae.W_dec
    return torch.matmul(w_dec, w_dec.T).cpu()


# Pairwise decoder-row products for the unregularised (0) and the 0.1 SAE.
cossims = decoder_gram(sae_dic[0])
reg_cossims = decoder_gram(sae_dic[0.1])

In [16]:
# Number of values to sample from each matrix for the histogram.
k = 10000  # You can adjust this value

total = cossims.numel()

# Draw k flat indices uniformly at random (with replacement) from a seeded
# generator, so the plot is reproducible and really shows random samples.
# The earlier strided slice was deterministic, could alias with the row length,
# and failed with a zero step whenever the matrix had fewer than k entries.
# Assumes reg_cossims has the same shape as cossims, since the same indices
# are applied to both -- TODO confirm if SAEs of different widths are compared.
sample_generator = torch.Generator().manual_seed(0)
sample_idx = torch.randint(total, (k,), generator=sample_generator)

# Read the same positions from both matrices so the two series are comparable.
cossims_sample = cossims.flatten()[sample_idx]
reg_cossims_sample = reg_cossims.flatten()[sample_idx]
length = cossims_sample.shape[0]

# Long-format frame for Plotly Express: one row per sampled value, labelled
# with the matrix it came from.
df = pd.DataFrame({
    'values': torch.cat([cossims_sample, reg_cossims_sample]).numpy(),
    'tensor': ['cossims'] * length + ['reg_cossims'] * length
})

# Overlaid, semi-transparent histograms of the two samples.
fig = px.histogram(df, x='values', color='tensor', barmode='overlay',
                   title='Histogram of Random Samples from cossims and reg_cossims',
                   labels={'values': 'Tensor Values', 'count': 'Frequency'},
                   opacity=0.7)

# Update layout for better readability
fig.update_layout(legend_title_text='Tensor')

# Show the plot
fig.show()

In [19]:
batch_size = 16
num_batches = 100


def make_splice_hook(sae):
    """Build a forward hook that replaces the hooked activation with ``sae``'s output."""
    def splice_hook(act, hook):
        return sae(act).sae_out
    return splice_hook


# Per-batch losses: the clean model, and the model with each SAE spliced in.
baseline_loss_list = []
losses = {key: [] for key in sae_dic}

for batch_idx in tqdm(range(num_batches)):
    start = batch_idx * batch_size
    batch = tokens[start : start + batch_size]

    # Clean forward pass on this batch.
    clean_loss = model(batch, return_type="loss")
    baseline_loss_list.append(clean_loss.item())

    # Same batch again, once per SAE, with the activation at
    # blocks.8.hook_resid_pre swapped for that SAE's output.
    for key, sae in sae_dic.items():
        spliced_loss = model.run_with_hooks(
            batch,
            return_type="loss",
            fwd_hooks=[('blocks.8.hook_resid_pre', make_splice_hook(sae))],
        )
        losses[key].append(spliced_loss.item())



  1%|          | 1/100 [00:00<00:18,  5.22it/s]

100%|██████████| 100/100 [00:18<00:00,  5.32it/s]


In [20]:
# Mean loss over all evaluated batches: the clean model first, then one line
# per SAE in the order the SAEs were loaded.
mean_baseline_loss = np.mean(baseline_loss_list)
print(f"Baseline Loss: {mean_baseline_loss}")

for key in losses:
    mean_spliced_loss = np.mean(losses[key])
    print(f"Orthog Coeff {key} Loss: {mean_spliced_loss}")


Baseline Loss: 4.046435465812683
Orthog Coeff 0.0 Loss: 4.796076235771179
Orthog Coeff 0.1 Loss: 5.360004420280457
Orthog Coeff 1 Loss: 7.527114043235779
Orthog Coeff 5 Loss: 8.1656098985672


In [40]:
from torch import nn

# Expose each SAE's encoder weight (transposed) and bias under the attribute
# names W_enc / b_enc.
# NOTE(review): presumably these are the names sae_vis reads -- confirm.
for sae in sae_dic.values():
    encoder = sae.encoder
    sae.W_enc = nn.Parameter(encoder.weight.data.T)
    sae.b_enc = encoder.bias


In [42]:
# Inspect the token tensor's shape; the second dimension matches the
# max_length=64 used when tokenizing.
tokens.shape

torch.Size([178047, 64])

In [51]:
# Inputs for the visualisation: the hook point the SAEs were spliced into,
# the feature indices to analyse, and the first 10,000 rows of `tokens`.
vis_hook_point = 'blocks.8.hook_resid_pre'
vis_features = range(64)
vis_tokens = tokens[:10_000, :]  # type: ignore

sae_vis_config = SaeVisConfig(
    hook_point=vis_hook_point,
    features=vis_features,
    verbose=True,
    minibatch_size_tokens=128,
)

# Gather the feature data for the SAE trained with orthogonality coefficient 0.1.
sae_vis_data = SaeVisData.create(
    encoder=sae_dic[0.1],
    encoder_B=None,
    model=model,
    tokens=vis_tokens,
    cfg=sae_vis_config,
)

# Write the feature-centric visualisation to an HTML file. This cell only
# saves the file; it does not open a browser.
# NOTE(review): feature_idx=8 presumably selects the feature shown first -- confirm.
filename = "_feature_vis_demo.html"
sae_vis_data.save_feature_centric_vis(filename, feature_idx=8)

Forward passes to cache data for vis:   0%|          | 0/79 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/64 [00:00<?, ?it/s]

Saving feature-centric vis:   0%|          | 0/64 [00:00<?, ?it/s]

: 

In [49]:
import gc
# Run a garbage-collection pass first, then ask PyTorch to release its
# cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()