##### Imports

In [82]:
import os
import warnings

import numpy as np
import pandas as pd
import plotly.express as px
import torch
from datasets import load_dataset
from IPython.display import IFrame
from torch import Tensor
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.utils import test_prompt, tokenize_and_concatenate

from e2e_sae import SAETransformer
from e2e_sae.data import create_data_loader

In [83]:
torch.set_grad_enabled(False)

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

Device: cpu


In [84]:
class DotDict(dict):
    """A dictionary that supports dot notation."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(f"'DotDict' object has no attribute '{key}'")

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError:
            raise AttributeError(f"'DotDict' object has no attribute '{key}'")

##### Load Model

In [85]:
with warnings.catch_warnings(action="ignore"):
    model = SAETransformer.from_wandb("sparsify/gpt2/tvj2owza")

transformer = model.tlens_model
saes_dict = model.saes

sae_pos = model.raw_sae_positions[0]
sae = saes_dict["blocks-6-hook_resid_pre"]
d_sae = sae.encoder[0].out_features

Loaded pretrained model gpt2-small into HookedTransformer


##### Neuronpedia Dashboard

In [86]:
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"


def get_dashboard_html(sae_release="gpt2-small", sae_id="7-res-jb", feature_idx=0):
    return html_template.format(sae_release, sae_id, feature_idx)

##### Basic Config

In [87]:
dataset_config = DotDict(
    {
        "dataset_name": "NeelNanda/pile-10k",
        "is_tokenized": False,
        "tokenizer_name": "gpt2",
        "streaming": True,
        "split": "train",
        "n_ctx": 1024,
        "seed": 0,
    }
)

In [88]:
dataloader, _ = create_data_loader(dataset_config=dataset_config, batch_size=8)


`clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884



##### Basic Test Prompt

In [89]:
prompt = "The next person in line is singer Johnny"
answer = "Cash"

# Show that the model can confidently predict the next token.
test_prompt(prompt, answer, transformer)

Tokenized prompt: ['<|endoftext|>', 'The', ' next', ' person', ' in', ' line', ' is', ' singer', ' Johnny']
Tokenized answer: [' Cash']


Top 0th token. Logit: 17.33 Prob: 38.65% Token: | Cash|
Top 1th token. Logit: 15.72 Prob:  7.80% Token: | De|
Top 2th token. Logit: 14.47 Prob:  2.21% Token: | B|
Top 3th token. Logit: 14.34 Prob:  1.95% Token: | Mar|
Top 4th token. Logit: 13.67 Prob:  1.00% Token: | "|
Top 5th token. Logit: 13.66 Prob:  0.99% Token: | H|
Top 6th token. Logit: 13.62 Prob:  0.95% Token: | G|
Top 7th token. Logit: 13.43 Prob:  0.79% Token: | R|
Top 8th token. Logit: 13.43 Prob:  0.78% Token: | Carson|
Top 9th token. Logit: 13.40 Prob:  0.77% Token: | Mercer|


##### Max Activating Feature

In [90]:
out, cache = model.forward(prompt, sae_positions=model.raw_sae_positions, cache_positions=None)

In [91]:
px.line(
    cache[sae_pos].output[0, -1, :].cpu().numpy(),
    title="Feature activations at the final token position",
    labels={"index": "Feature", "value": "Activation"},
).show()

# let's print the top 5 features and how much they fired
vals, inds = torch.topk(cache[sae_pos].output[0, -1, :], 5)
for val, ind in zip(vals, inds):
    print(f"Feature {ind} fired {val:.2f}")

Feature 447 fired 17.32
Feature 326 fired 9.92
Feature 266 fired 7.84
Feature 288 fired 7.00
Feature 481 fired 6.27


##### Max Activating Token

In [92]:
def list_flatten(nested_list):
    return [x for y in nested_list for x in y]


# A very handy function Neel wrote to get context around a feature activation
def make_token_df(tokens=Tensor, model=HookedTransformer, len_prefix=5, len_suffix=3):
    str_tokens = [model.to_str_tokens(t) for t in tokens]
    unique_token = [[f"{s}/{i}" for i, s in enumerate(str_tok)] for str_tok in str_tokens]

    context = []
    prompt = []
    pos = []
    label = []
    for b in range(tokens.shape[0]):
        for p in range(tokens.shape[1]):
            prefix = "".join(str_tokens[b][max(0, p - len_prefix) : p])
            if p == tokens.shape[1] - 1:
                suffix = ""
            else:
                suffix = "".join(
                    str_tokens[b][p + 1 : min(tokens.shape[1] - 1, p + 1 + len_suffix)]
                )
            current = str_tokens[b][p]
            context.append(f"{prefix}|{current}|{suffix}")
            prompt.append(b)
            pos.append(p)
            label.append(f"{b}/{p}")
    # print(len(batch), len(pos), len(context), len(label))
    return pd.DataFrame(
        dict(
            str_tokens=list_flatten(str_tokens),
            unique_token=list_flatten(unique_token),
            context=context,
            prompt=prompt,
            pos=pos,
            label=label,
        )
    )


In [101]:
feature_list = torch.randint(0, d_sae, (100,))
examples_found = 0
all_fired_tokens = []
all_feature_acts = []
all_reconstructions = []
all_token_dfs = []

total_batches = 10
dl_iter = iter(dataloader)
pbar = tqdm(range(total_batches))
for i in pbar:
    tokens = next(dl_iter)["input_ids"]
    tokens_df = make_token_df(tokens=tokens, model=transformer)
    tokens_df["batch"] = i

    flat_tokens = tokens.flatten()

    logits, activations = model.forward(tokens=tokens, sae_positions=model.raw_sae_positions)
    feature_acts = activations[sae_pos].c
    feature_acts = feature_acts.flatten(0, 1)
    fired_mask = (feature_acts[:, feature_list]).sum(dim=-1) > 0
    fired_tokens = transformer.to_str_tokens(flat_tokens[fired_mask])
    reconstruction = (
        feature_acts[fired_mask][:, feature_list] @ sae.decoder.weight[:, feature_list].T
    )

    token_df = tokens_df.iloc[fired_mask.cpu().nonzero().flatten().numpy()]
    all_token_dfs.append(token_df)
    all_feature_acts.append(feature_acts[fired_mask][:, feature_list])
    all_fired_tokens.append(fired_tokens)
    all_reconstructions.append(reconstruction)

    examples_found += len(fired_tokens)
    # print(f"Examples found: {examples_found}")
    # update description
    pbar.set_description(f"Examples found: {examples_found}")

# flatten the list of lists
all_token_dfs = pd.concat(all_token_dfs)
all_fired_tokens = list_flatten(all_fired_tokens)
all_reconstructions = torch.cat(all_reconstructions)
all_feature_acts = torch.cat(all_feature_acts)

Examples found: 3931: 100%|██████████| 10/10 [02:24<00:00, 14.46s/it]


In [102]:
feature_acts_df = pd.DataFrame(
    all_feature_acts.detach().cpu().numpy(), columns=[f"feature_{i}" for i in feature_list]
)
feature_acts_df.shape

(3931, 100)

In [110]:
feature_idx = 1
# get non-zero activations

all_positive_acts = all_feature_acts[all_feature_acts[:, feature_idx] > 0][:, feature_idx].detach()
prop_positive_activations = 100 * len(all_positive_acts) / (total_batches * 8 * 1024)

px.histogram(
    all_positive_acts.cpu(),
    nbins=50,
    title=f"Histogram of positive activations - {prop_positive_activations:.3f}% of activations were positive",
    labels={"value": "Activation"},
    width=800,
)

In [111]:
print(f"Feature: {feature_list[feature_idx]}")
top_10_activations = feature_acts_df.sort_values(
    f"feature_{feature_list[feature_idx]}", ascending=False
).head(10)
all_token_dfs.iloc[top_10_activations.index]  # TODO: double check this is working correctly

Feature: 20127


Unnamed: 0,str_tokens,unique_token,context,prompt,pos,label,batch
1113,a,a/89,the others were arrested during| a| nude danc...,1,89,1/89,1
846,in,in/846,for the nude ban passed| in| City Hall.,0,846,0/846,1
5137,for,for/17,read that he was terminated| for|\r\nun,5,17,5/17,1
6877,of,of/733,that Moore was terminated because| of| speakin...,6,733,6/733,1
1408,near,near/384,"at the Standing Rock camp| near| Cannon Ball,",1,384,1/384,1
6841,for,for/697,leading industry-sponsored portal| for| the P...,6,697,6/697,6
869,a,a/869,"January, a judge dismissed| a| lawsuit filed by",0,869,0/869,1
5085,for,for/989,stated that Moore was fired| for| sexual hara...,4,989,4/989,1
848,Hall,Hall/848,nude ban passed in City| Hall|.\n\n,0,848,0/848,1
121,a,a/121,er or third party in| a| work-related,0,121,0/121,2


In [112]:
print(f"Shape of the decoder weights {sae.decoder.weight.shape})")
print(f"Shape of the model unembed {transformer.W_U.shape}")
projection_matrix = sae.decoder.weight.T @ transformer.W_U
print(f"Shape of the projection matrix {projection_matrix.shape}")

# then we take the top_k tokens per feature and decode them
top_k = 10
# let's do this for 100 random features
_, top_k_tokens = torch.topk(projection_matrix[feature_list], top_k, dim=1)


feature_df = pd.DataFrame(
    top_k_tokens.cpu().numpy(), index=[f"feature_{i}" for i in feature_list]
).T
feature_df.index = [f"token_{i}" for i in range(top_k)]
feature_df.applymap(lambda x: transformer.tokenizer.decode(x))

Shape of the decoder weights torch.Size([768, 46080]))
Shape of the model unembed torch.Size([768, 50257])
Shape of the projection matrix torch.Size([46080, 50257])



DataFrame.applymap has been deprecated. Use DataFrame.map instead.



Unnamed: 0,feature_34980,feature_20127,feature_23779,feature_2858,feature_24234,feature_14867,feature_46016,feature_36611,feature_27392,feature_7104,...,feature_24603,feature_34176,feature_33568,feature_29158,feature_45506,feature_39286,feature_6077,feature_45360,feature_14211,feature_8945
token_0,龍�,crackdown,belonged,amiya,berman,mail,EStream,amped,liga,doms,...,pillar,slot,externalActionCode,requisite,Cola,batches,*/(,storms,net,guiActiveUn
token_1,esan,­,belongs,eon,issance,catentry,MU,attered,segreg,Ways,...,yielding,utherford,Zeal,largest,Era,eters,DragonMagazine,words,hazards,srfAttach
token_2,RAFT,counterterrorism,reserved,Lumpur,ente,peak,rss,ogged,zik,Shards,...,iors,eligible,*=-,aforementioned,anyl,oom,""":[{""",laureate,ntil,作
token_3,soDeliveryDate,surveillance,Lap,Yuk,ggles,reply,rg,inkle,elsen,Privacy,...,worldly,league,atis,same,uilt,demolition,tn,isms,valve,ointment
token_4,andra,government,belong,mentioned,uberty,reason,TAG,acks,putable,rh,...,ixie,rius,Mah,latest,OTE,camps,WATCHED,poetic,safety,Mous
token_5,Izan,protest,Cherokee,camouflage,cess,=-=-=-=-=-=-=-=-,Attach,amps,haus,quest,...,venants,aband,naire,ses,odi,increments,""":[",descriptions,Reviewer,Courts
token_6,keyes,redevelopment,Ming,tein,kus,akings,--------------------,icity,igned,atari,...,craw,uild,sov,strongest,reon,graffiti,ь,istically,ailability,��
token_7,Compat,renovations,Dynasty,recomp,Pearce,manufact,olen,atter,bart,views,...,DRAGON,itaire,whel,following,oppable,eding,guiName,ラン,councill,Casino
token_8,egal,federal,holder,enta,zie,crow,\',irled,bris,HIP,...,Spider,perty,lord,latter,ategory,udes,px,writer,goggles,ilver
token_9,esian,migrant,belonging,wikipedia,allery,Crate,Proxy,acked,iders,�,...,Abyss,gain,holy,utmost,axis,ic,soDeliveryDate,opener,diving,masks
