In [None]:
from tqdm import tqdm
import pandas as pd 
import torch 
from datasets import load_dataset  
from transformer_lens import HookedTransformer
from sae_lens import SAE
import numpy as np
import plotly_express as px 

In [None]:
torch.set_grad_enabled(False)  # inference only; re-enabled later for gradient-based attribution

# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Seed the global RNGs: pandas .sample() and sklearn's train_test_split draw from
# NumPy's global RNG when no random_state is passed, so this makes re-runs repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

model = HookedTransformer.from_pretrained("gemma-2b", device = device, dtype = torch.bfloat16)

# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities which are stored in HF for convenience. 
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-2b-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = "blocks.0.hook_resid_post", # won't always be a hook point
    device = device
)
# NOTE(review): presumably rescales so decoder rows have unit norm (encoder compensated);
# see sae_lens SAE.fold_W_dec_norm for the exact semantics.
sae.fold_W_dec_norm()

In [None]:
# Unpack the architecture hyperparameters used throughout the notebook.
n_layers, d_model, n_heads, d_head, d_mlp, d_vocab = (
    model.cfg.n_layers,
    model.cfg.d_model,
    model.cfg.n_heads,
    model.cfg.d_head,
    model.cfg.d_mlp,
    model.cfg.d_vocab,
)

In [None]:
# One row per vocabulary id, with simple string features used to select "clean" word tokens.
vocab_df = pd.DataFrame(
    {"token": np.arange(d_vocab), "string": model.to_str_tokens(np.arange(d_vocab))}
).assign(
    # optional leading space followed only by lowercase letters
    is_alpha=lambda df: df["string"].str.match(r"^( ?)[a-z]+$"),
    # leading space followed only by lowercase letters
    is_word=lambda df: df["string"].str.match(r"^ [a-z]+$"),
    # lowercase letters with no leading space
    is_fragment=lambda df: df["string"].str.match(r"^[a-z]+$"),
    # leading space followed by letters of either case
    has_space=lambda df: df["string"].str.match(r"^ [A-Za-z]+$"),
    # length ignoring surrounding whitespace
    num_chars=lambda df: df["string"].map(lambda s: len(s.strip())),
)
vocab_df

In [None]:
MAX_LETTERS = 20  # only the first 20 characters of each token are recorded
N_LETTER_COLUMNS = 6  # expose let0 .. let5 as vocab_df columns
alphabet = "abcdefghijklmnopqrstuvwxyz"

# letters_array[pos, row] = alphabet index of character `pos` of the stripped token
# string in vocab_df row `row`, or -1 when the token is not is_alpha or is too short.
letters_array = np.full((MAX_LETTERS, len(vocab_df)), -1, dtype=np.int32)
for row_idx, (string, is_alpha) in enumerate(
    tqdm(zip(vocab_df.string, vocab_df.is_alpha), total=len(vocab_df))
):
    if not is_alpha:
        continue
    for pos, char in enumerate(string.strip()[:MAX_LETTERS]):
        letters_array[pos, row_idx] = alphabet.index(char)

# Sanity check: number of tokens that have a letter at each position.
display((letters_array != -1).sum(-1))

for pos in range(N_LETTER_COLUMNS):
    vocab_df[f"let{pos}"] = letters_array[pos]
vocab_df

In [None]:
# Alphabetic tokens with at least 4 characters. .copy() makes this an independent frame,
# so adding columns here (and later in the notebook) does not raise SettingWithCopyWarning.
sub_vocab_df = vocab_df.query("is_alpha & num_chars>=4").copy()
sub_vocab_df["let0_string"] = sub_vocab_df.let0.apply(lambda n: alphabet[n] if n != -1 else "")
print(sub_vocab_df.shape)
sub_vocab_df.sample(10)

In [None]:
# Same row selection as sub_vocab_df (the identical query string), as a NumPy bool mask.
mask = vocab_df.eval("is_alpha & num_chars>=4").to_numpy()
embed_masked = model.W_E[mask]
# "Effective embedding": the token embedding plus layer-0's MLP applied to it.
# NOTE(review): layer-0 attention is ignored; [None] adds a leading batch dim.
eff_embed_masked = embed_masked + model.blocks[0].mlp(
    model.blocks[0].ln2(embed_masked[None])
)

In [None]:
eff_embed_masked.size()

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Probe target: the letter at position `char_index` of each token (column let{char_index}).
char_index = 0
col_label = "let" + str(char_index)
# Probe features: effective embeddings as a float32 NumPy array, batch dim squeezed away.
X = eff_embed_masked.float().squeeze().cpu().numpy()

In [None]:
# Integer letter labels (index into `alphabet`), aligned row-for-row with X.
y = sub_vocab_df[col_label].to_numpy()
y.shape

In [None]:
# Fixed random_state so the split (and hence the probe and every downstream plot) is reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
# NOTE(review): max_iter=100 may stop lbfgs before convergence on d_model-dimensional
# features (sklearn then emits ConvergenceWarning); raise it if that warning appears.
probe = LogisticRegression(max_iter=100)
probe.fit(X_train, y_train)
probe.score(X_test, y_test)

In [None]:
lp_test = probe.predict_log_proba(X_test)
# predict_log_proba's columns follow probe.classes_ (sorted labels seen in training),
# not raw letter ids: a letter absent from y_train would shift every later column.
# Map each label to its column explicitly; labels never seen in training get -inf.
col_test = np.clip(np.searchsorted(probe.classes_, y_test), 0, len(probe.classes_) - 1)
clp_test = np.where(
    probe.classes_[col_test] == y_test,
    lp_test[np.arange(X_test.shape[0]), col_test],
    -np.inf,
)

In [None]:
# Correct-class log-probability for each held-out token.
px.line(
    x=np.arange(len(clp_test)),
    y=clp_test,
    labels=dict(x="sample", y="log prob"),
)

In [None]:
lp = probe.predict_log_proba(X)
# Columns of lp follow probe.classes_, not raw letter ids; look each label's column up
# explicitly (labels never seen in training get -inf instead of a wrong column).
col_all = np.clip(np.searchsorted(probe.classes_, y), 0, len(probe.classes_) - 1)
clp = np.where(
    probe.classes_[col_all] == y,
    lp[np.arange(X.shape[0]), col_all],
    -np.inf,
)

In [None]:
# .assign returns a new frame, so this never triggers SettingWithCopyWarning even if
# sub_vocab_df is still a slice of vocab_df.
sub_vocab_df = sub_vocab_df.assign(correct_class_log_prob=clp)

In [None]:
sub_vocab_df.head(n=5)

In [None]:
# Random subset to keep the strip plot light. Capped at the frame size so .sample()
# does not raise when fewer than 2000 rows exist, and seeded for reproducibility.
px.strip(
    sub_vocab_df.sample(n=min(2000, len(sub_vocab_df)), random_state=0).reset_index(),
    x="token",
    y="correct_class_log_prob",
    hover_data=["string"],
)

In [None]:
# Probe weights as a torch tensor: one row per class in probe.classes_, one column per input feature.
probe_coef = probe.coef_
first_letter_probe_pars = torch.tensor(probe_coef)
first_letter_probe_pars.size()

In [None]:
# Virtual weights between every SAE decoder direction and every probe class:
# W_dec @ probe.T -> one row per SAE feature, one column per probe class.
# The probe is moved/cast to wherever the SAE lives instead of hard-coding "cuda"/float32.
probe_feature_virtual_weights = sae.W_dec @ first_letter_probe_pars.T.to(
    device=sae.W_dec.device, dtype=sae.W_dec.dtype
)
# Column 1 is class probe.classes_[1] ('b' when every letter occurs in training).
# detach() keeps this cell working if it is re-run after gradients are enabled.
px.line(
    probe_feature_virtual_weights.T[1].detach().cpu().numpy(),
    title="Decoder-direction virtual weights for probe class 1",
)

In [None]:
# Same projection, but onto the SAE encoder directions: W_enc.T @ probe.T gives one row
# per SAE feature. The probe follows the SAE's device/dtype rather than a hard-coded "cuda".
probe_feature_virtual_weights = sae.W_enc.T @ first_letter_probe_pars.T.to(
    device=sae.W_enc.device, dtype=sae.W_enc.dtype
).detach()
# Column 1 is class probe.classes_[1] ('b' when every letter occurs in training).
px.line(
    probe_feature_virtual_weights.T[1].detach().cpu().numpy(),
    title="Encoder-direction virtual weights for probe class 1",
)

In [None]:
# Top-5 SAE features by virtual weight to probe class 1.
vals, inds = probe_feature_virtual_weights[:, 1].topk(k=5)
inds

In [None]:
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

# Open a Neuronpedia quick list for the probe-aligned features of the layer-0 residual SAE.
get_neuronpedia_quick_list(
    model="gemma-2b",
    dataset="res-jb",
    layer=0,
    features=inds.tolist(),
)

# Spelling task + attribution

In [None]:
from functools import partial
import circuitsvis as cv


def reconstr_hook(activations, hook, sae_out):
    """Forward hook that replaces the hooked activations with a precomputed SAE output."""
    # NOTE(review): sae_out is returned as-is; if the SAE runs in a different dtype than
    # the model (float32 vs bfloat16), confirm downstream layers accept it.
    return sae_out


def zero_abl_hook(mlp_out, hook):
    """Forward hook that zero-ablates the hooked activations (same shape, dtype and device)."""
    return torch.zeros_like(mlp_out)


# Few-shot spelling prompt: two worked examples, then " <word>:" and an optional spelling.
# Fill with TEMPLATE.format(word, spelling), where spelling is e.g. " B A C O N" or "".
TEMPLATE = """ string: S T R I N G
 heaven: H E A V E N
 {}:{}"""


def get_random_word(random_state=None):
    """Return a random stripped token string from sub_vocab_df.

    random_state is forwarded to Series.sample; None (the default) uses NumPy's global RNG.
    """
    return sub_vocab_df.string.sample(random_state=random_state).values[0].strip()


def spell_word(word):
    """Spell `word` as space-separated capitals with a leading space, e.g. "cat" -> " C A T"."""
    return " " + " ".join(word.upper())


def get_filled_template(word):
    """Prompt whose final line is the lowercased word followed by its spelling."""
    return TEMPLATE.format(word.lower(), spell_word(word))


def get_unfilled_template(word):
    """Prompt ending in " <word>:", i.e. just before the first spelled letter."""
    return TEMPLATE.format(word.lower(), "")


word = get_random_word()
print(get_filled_template(word))
print(get_unfilled_template(word))

In [None]:
from functools import partial
import circuitsvis as cv


def show_token_log_probs(prompt, logits=None):
    """Render per-token log-probs for `prompt`; runs `model` (with any active hooks) unless logits are given."""
    if logits is None:
        logits = model(prompt)
    display(cv.logits.token_log_probs(model.to_tokens(prompt), logits[0].log_softmax(dim=-1), model.to_string))


# reconstr_hook / zero_abl_hook are the hooks defined in the spelling-template cell above.
prompt = get_filled_template(get_random_word())
logits, cache = model.run_with_cache(prompt)
# SAE reconstruction of this prompt's activations at the SAE hook point, spliced back in below.
sae_out = sae(cache[sae.cfg.hook_name])

print("positive control")
# Unhooked run: reuse the logits run_with_cache already produced.
show_token_log_probs(prompt, logits)

print("test group")
with model.hooks(fwd_hooks=[(sae.cfg.hook_name, partial(reconstr_hook, sae_out=sae_out))]):
    show_token_log_probs(prompt)

print("negative control")
with model.hooks(fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)]):
    show_token_log_probs(prompt)


In [None]:
from functools import partial
import circuitsvis as cv


prompt = get_filled_template("bacon")
logits, cache = model.run_with_cache(prompt)
# Reuse the logits from run_with_cache instead of running the model a second time.
display(cv.logits.token_log_probs(model.to_tokens(prompt), logits[0].log_softmax(dim=-1), model.to_string))

# Gradient-based attribution

In [None]:
from typing import List, Union, Optional, Callable
from transformer_lens import ActivationCache

def filter_resid_only(name: str) -> bool:
    """Hook-name filter that keeps only hook points whose name contains "resid"."""
    return "resid" in name

def get_cache_fwd_and_bwd(model, tokens, metric, filter = filter_resid_only):
    """Run one forward + backward pass, caching activations and their gradients.

    Args:
        model: HookedTransformer; all existing (non-permanent) hooks are reset first.
        tokens: model input (token tensor or prompt string) passed to ``model(...)``.
        metric: callable mapping logits to a scalar tensor that is back-propagated.
        filter: hook-name predicate selecting which hook points to cache.

    Returns:
        (logits detached from the graph, metric value as a Python float,
         ActivationCache of activations, ActivationCache of gradients).

    Gradients must be enabled (torch.set_grad_enabled(True)) for backward() to work.
    """
    model.reset_hooks()
    cache = {}
    grad_cache = {}

    def forward_cache_hook(act, hook):
        cache[hook.name] = act.detach()

    def backward_cache_hook(act, hook):
        grad_cache[hook.name] = act.detach()

    model.add_hook(filter, forward_cache_hook, "fwd")
    model.add_hook(filter, backward_cache_hook, "bwd")
    try:
        logits = model(tokens)
        value = metric(logits)
        value.backward()
    finally:
        # Always remove our hooks, even if the forward/backward pass raises.
        model.reset_hooks()
        # Only the hook gradients are wanted; drop parameter .grad buffers so they do not
        # accumulate (and hold memory) across repeated calls.
        model.zero_grad(set_to_none=True)
    return logits.detach(), value.item(), ActivationCache(cache, model), ActivationCache(grad_cache, model)

def get_logit_diff_metric(pos_token: str, neg_token: str, model: "HookedTransformer"):
    """Build a metric: logit of `pos_token` minus logit of `neg_token` at the last position.

    Both token strings are resolved to ids once, here, so any error raised by
    model.to_single_token (e.g. for a multi-token string) surfaces when the metric is
    built, and the lookups are not repeated on every call.

    Returns:
        A callable taking logits indexed as [batch, position, vocab] (only batch 0 and the
        final position are read) and returning a 0-d tensor, so it can be back-propagated.
    """
    positive_token_id = model.to_single_token(pos_token)
    negative_token_id = model.to_single_token(neg_token)

    def logit_diff_metric(logits: torch.Tensor) -> torch.Tensor:
        return logits[0, -1, positive_token_id] - logits[0, -1, negative_token_id]

    return logit_diff_metric

# Gradients are needed from here on for the backward pass in get_cache_fwd_and_bwd.
torch.set_grad_enabled(True)

# Measure the logit diff on the *unfilled* prompt: its last position is where the model
# must predict the first spelled letter. (The global `prompt` at this point is the filled
# template, whose last position comes after the full spelling.)
clean_prompt = get_unfilled_template("bacon")
clean_tokens = model.to_tokens(clean_prompt)
pos_token = " B"  # correct first letter of "bacon"
neg_token = " A"  # contrast letter
logit_diff_metric = get_logit_diff_metric(pos_token, neg_token, model)


def attn_mlp_resid_filter(name: str) -> bool:
    """Keep resid/attn/mlp hook points, excluding names containing "result" or "_in"."""
    return (
        (("resid" in name) or ("attn" in name) or ("mlp" in name))
        and ("result" not in name)
        and ("_in" not in name)
    )


logits, clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(
    model, clean_tokens, logit_diff_metric, attn_mlp_resid_filter
)
print("Clean Value:", clean_value)
print("Clean Activations Cached:", len(clean_cache))
print("Clean Gradients Cached:", len(clean_grad_cache))

In [None]:
# get_logit_diff_metric is defined once, in the gradient-attribution cell above.


def get_sae_out_all_layers(cache, sae_dict):
    """Encode and reconstruct cached activations with each SAE.

    Args:
        cache: mapping (dict or ActivationCache) from hook-point name to activations.
        sae_dict: mapping from hook-point name to an object exposing encode()/decode().

    Returns:
        (sae_outs, feature_actss): dicts keyed by hook point holding the float32
        reconstructions and the float32 feature activations respectively.
    """
    sae_outs = {}
    feature_actss = {}
    # Analysis only: avoid building an autograd graph through the SAE weights, since grad
    # mode may be globally enabled (torch.set_grad_enabled(True)) when this is called.
    with torch.no_grad():
        for hook_point, sparse_autoencoder in sae_dict.items():
            feature_acts = sparse_autoencoder.encode(cache[hook_point])
            sae_outs[hook_point] = sparse_autoencoder.decode(feature_acts).float()
            feature_actss[hook_point] = feature_acts.float()
    return sae_outs, feature_actss

# Hook point -> SAE; the attribution code iterates over this mapping.
saes = {}
saes[sae.cfg.hook_name] = sae

In [None]:
from transformer_lens import utils
import re 

def gradient_based_attributation_all_layers(
    model: HookedTransformer,
    sparse_autoencoders: dict[str, SAE],
    prompt: str = "John and Mary went to the store and then John said to",
    metric: Callable[[torch.Tensor], float] = None,
    position: int = 2,
    test_prompt=False,
):
    """Attribute a scalar metric to the SAE features active at one token position.

    For each (hook point -> SAE) pair, the score of every feature firing at `position` is
    (activation * decoder row) . d(metric)/d(activation at the hook point) -- a first-order
    gradient-times-activation estimate.

    Args:
        model: HookedTransformer to run on `prompt`.
        sparse_autoencoders: mapping from hook-point name to the SAE trained on it.
        prompt: input text (BOS is prepended by the model, so positions count BOS).
        metric: callable mapping logits to a scalar tensor. If None, falls back to the logit
            difference between the notebook-level globals `pos_token` and `neg_token`.
        position: token index to attribute.
        test_prompt: if True, first run transformer_lens utils.test_prompt with `pos_token`.

    Returns:
        DataFrame with one row per firing feature and columns feature, activation,
        attribution, layer, layer_idx, position, unique_token.
    """
    if test_prompt:
        utils.test_prompt(prompt, pos_token, model, prepend_bos=True)

    if metric is None:
        metric = get_logit_diff_metric(pos_token, neg_token, model)

    # One forward/backward pass on THIS prompt, caching only the SAE hook points, so the
    # activations and gradients below always come from the same run.
    _, _, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(
        model, prompt, metric, lambda name: name in sparse_autoencoders
    )
    _, feature_actss = get_sae_out_all_layers(clean_cache, sparse_autoencoders)

    attribution_dfs = []
    for hook_point, sparse_autoencoder in sparse_autoencoders.items():
        feature_acts = feature_actss[hook_point][0, position, :]
        # squeeze(-1) (not squeeze()) keeps a 1-D index tensor even when exactly one feature fires.
        fired = (feature_acts > 0).nonzero().squeeze(-1)
        activations = feature_acts[fired]
        contributions = activations[:, None] * sparse_autoencoder.W_dec[fired]
        metric_grad = clean_grad_cache[hook_point][0, position].float()
        attribution_scores = contributions @ metric_grad

        hook_name = sparse_autoencoder.cfg.hook_name
        # Block index, shifted by one for *_post hooks so resid_post of block i == resid_pre of i+1.
        layer_idx = int(re.search(r"blocks\.(\d+)\.hook_", hook_name).group(1)) + int("post" in hook_name)

        attribution_df = pd.DataFrame(
            {
                "feature": fired.detach().cpu().numpy(),
                "activation": activations.detach().cpu().numpy(),
                "attribution": attribution_scores.detach().cpu().numpy(),
            }
        )
        attribution_df["layer"] = hook_name
        attribution_df["layer_idx"] = layer_idx
        attribution_df["position"] = position
        attribution_dfs.append(attribution_df)

    attribution_df = pd.concat(attribution_dfs)
    attribution_df["feature"] = attribution_df.feature.astype(str)
    attribution_df["layer"] = attribution_df.layer.astype("category")

    tokens = model.to_str_tokens(prompt)
    unique_tokens = [f"{i}/{tokens[i]}" for i in range(len(tokens))]
    attribution_df["unique_token"] = attribution_df["position"].apply(lambda x: unique_tokens[x])

    return attribution_df



# Attribute on the unfilled prompt: its last position is where the first letter (" B") is predicted.
prompt = get_unfilled_template("bacon")

In [None]:

logit_diff_metric = get_logit_diff_metric(pos_token, neg_token, model)

# One attribution table per token position of `prompt` (BOS included), stacked together.
n_tokens = len(model.to_str_tokens(prompt))
attribution_dfs = [
    gradient_based_attributation_all_layers(
        model,
        saes,
        prompt,
        metric=logit_diff_metric,
        position=position,
        test_prompt=False,
    )
    for position in tqdm(range(n_tokens))
]

attribution_df = pd.concat(attribution_dfs)
attribution_df

In [None]:
# NOTE(review): position 19 is presumably the " bacon" token; verify with model.to_str_tokens(prompt).
attribution_df.query("position == 19").sort_values(by="attribution", ascending=False)

In [None]:
# Ten highest-attribution feature ids at position 19, as ints.
top_features = (
    attribution_df.loc[attribution_df["position"] == 19]
    .sort_values("attribution", ascending=False)
    .head(10)["feature"]
    .astype(int)
    .tolist()
)
top_features

In [None]:
# Use the features computed in the previous cell rather than a hard-coded list copied from
# an earlier run, so this stays in sync when the prompt, SAE or attribution changes.
get_neuronpedia_quick_list(
    features=top_features,
    model="gemma-2b",
    dataset="res-jb",
    layer=0,
)