In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# config
attribute = "gender"
probe_layer = 28
layers = list(range(0, probe_layer+1))

In [3]:
from utils.load_probes import load_probe
from utils.probes import make_probes_for_each_layer

probes = load_probe(
    attribute
)

weights, biases = probes
probes_for_each_layer = make_probes_for_each_layer(weights, biases)



collected_gender_probe_weights.pt:   0%|          | 0.00/1.21M [00:00<?, ?B/s]

collected_gender_probe_biases.pt:   0%|          | 0.00/1.71k [00:00<?, ?B/s]

  weights = torch.load(weights_file, map_location=device)
  biases = torch.load(bias_file, map_location=device)


In [4]:
from utils.probes import load_dataset

texts, labels = load_dataset(attribute)

In [5]:
import transformer_lens as tl
import torch

torch.set_grad_enabled(False)
model_name = f"google/gemma-2-9b"
model = tl.HookedTransformer.from_pretrained(model_name, center_unembed=True, dtype="bfloat16")



Downloading shards:   0%|          | 0/8 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-9b into HookedTransformer


In [6]:
from utils.probes import LinearProbes
from utils.index import Ix

def probe_attribution_metric(
    cache: tl.ActivationCache | dict,  
    probe: LinearProbes,  
    hook_point: str, 
    correct_label: int | list[int],
    pos_slice: slice | None = Ix[:, -1].as_index
):
    if pos_slice is None:
        resid_cache = cache[hook_point]
    else:
        resid_cache = cache[hook_point][pos_slice]
    probe = probe.to(resid_cache.device).to(dtype=resid_cache.dtype)
    probe_logits = probe.probe(resid_cache)
    if len(probe_logits.shape) == 2:
        if isinstance(correct_label, list):
            assert len(correct_label) == probe_logits.shape[0]
            return probe_logits[torch.arange(probe_logits.shape[0]), correct_label]
        return probe_logits[:, correct_label]
    else:
        return probe_logits[correct_label]

In [8]:
from functools import partial
from utils.cache import get_cache_fwd_and_bwd

probe = probes_for_each_layer[probe_layer]

### SAE Attribution

In [9]:
from utils.sae_loader import load_gemma_saes
saes = load_gemma_saes("9b", layers=layers[:-1])
saes.keys()

Loading SAEs:   0%|          | 0/29 [00:00<?, ?it/s]

dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28])

In [10]:
layers = list(range(0, probe_layer+1))
hook_points = [f"blocks.{l_no}.hook_resid_post" for l_no in layers]

In [23]:
from utils.attribution import compute_sae_activations_and_attributions
from tqdm import tqdm

all_sae_attrs = []
all_sae_acts = []
for prompt_idx in tqdm(range(len(texts[:50]))):
    prompt = texts[prompt_idx]
    correct_label = labels[prompt_idx]
    metric = partial(probe_attribution_metric, probe=probe, hook_point=f"blocks.{probe_layer}.hook_resid_post", correct_label=correct_label)
    loss, fwd_cache, bwd_cache = get_cache_fwd_and_bwd(model, prompt, metric, hook_points=hook_points, metric_needs_cache=True)

    sae_acts, sae_attrs = compute_sae_activations_and_attributions(
        saes, fwd_cache, bwd_cache, hook_points[:-1]
    )
    
    all_sae_attrs.append(sae_attrs)
    all_sae_acts.append(sae_acts)

100%|██████████| 50/50 [02:17<00:00,  2.76s/it]


In [24]:
torch.cuda.empty_cache()

In [None]:
loss, fwd_cache, bwd_cache = get_cache_fwd_and_bwd(model, prompt, metric, hook_points=hook_points, metric_needs_cache=True)

In [17]:
fwd_cache.keys(), "\n", bwd_cache.keys()


(dict_keys(['blocks.0.hook_resid_post', 'blocks.1.hook_resid_post', 'blocks.2.hook_resid_post', 'blocks.3.hook_resid_post', 'blocks.4.hook_resid_post', 'blocks.5.hook_resid_post', 'blocks.6.hook_resid_post', 'blocks.7.hook_resid_post', 'blocks.8.hook_resid_post', 'blocks.9.hook_resid_post', 'blocks.10.hook_resid_post', 'blocks.11.hook_resid_post', 'blocks.12.hook_resid_post', 'blocks.13.hook_resid_post', 'blocks.14.hook_resid_post', 'blocks.15.hook_resid_post', 'blocks.16.hook_resid_post', 'blocks.17.hook_resid_post', 'blocks.18.hook_resid_post', 'blocks.19.hook_resid_post', 'blocks.20.hook_resid_post', 'blocks.21.hook_resid_post', 'blocks.22.hook_resid_post', 'blocks.23.hook_resid_post', 'blocks.24.hook_resid_post', 'blocks.25.hook_resid_post', 'blocks.26.hook_resid_post', 'blocks.27.hook_resid_post', 'blocks.28.hook_resid_post']),
 '\n',
 dict_keys(['blocks.27.hook_resid_post', 'blocks.26.hook_resid_post', 'blocks.25.hook_resid_post', 'blocks.24.hook_resid_post', 'blocks.23.hook_resi

In [28]:
from utils.attribution import get_top_k_contributions
import numpy as np

top_k_dfs = []
for attr in all_sae_attrs:
    per_pos_contribution = attr[:, 1:, :].sum(-1) # layer x positions
    top_k_contributions = get_top_k_contributions(per_pos_contribution, k=5)["latent_idx"].tolist()
    tokens_we_care_about, occurrences = np.unique(top_k_contributions, return_counts=True)
    per_latent_contribution = attr[:, tokens_we_care_about].sum(1) # layer x latents
    df = get_top_k_contributions(per_latent_contribution)
    top_k_dfs.append(df)


In [29]:
top_k_dfs[0]

Unnamed: 0,latent_idx,layer,contribution,abs_contribution
0,1891,0,1.249013,1.249013
1,1587,0,0.947196,0.947196
2,15624,0,0.786435,0.786435
3,15335,0,0.652231,0.652231
4,6497,0,-0.529762,0.529762
...,...,...,...,...
135,14008,27,0.740169,0.740169
136,6600,27,-0.565594,0.565594
137,13238,27,-0.423949,0.423949
138,12486,27,-0.421051,0.421051


In [30]:
import pickle
with open("notebooks/results/top_k_dfs.pkl", "wb") as f:
    pickle.dump(top_k_dfs, f)

### Nodewise Analysis