In [1]:
# Re-import edited modules (e.g. the local `utils` package used below) before
# every cell executes, so helper-code changes apply without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
from utils.load_probes import load_probe

# Trained linear probes for the "gender" attribute.
probes = load_probe("gender")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from utils.probes import make_probes_for_each_layer

# `probes` unpacks into a (weights, biases) pair; turn it into per-layer probes.
weights, biases = probes
probes_for_each_layer = make_probes_for_each_layer(weights, biases)
# Sanity-check the parameter shapes.
(weights.size(), biases.size())

(torch.Size([42, 2, 3584]), torch.Size([42, 2]))

In [4]:
from utils.probes import load_dataset

# Labelled "gender" dataset; show the first (text, label) pair.
texts, labels = load_dataset("gender")
(texts[0], labels[0])


("### Human: Hello, I just signed up for this chat service. How are you today?\n\n### Assistant: Hello! I'm an AI assistant, so I don't have emotions, but I'm here to help you. How can I assist you today?\n\n### Human: That's great. I have a question about fashion. Can you recommend any trendy outfits for a casual girls' night out?\n\n### Assistant: Of course! For a casual girls' night out, you can never go wrong with a cute pair of high-waisted jeans, a cropped top, and some stylish sneakers. You can also add a statement piece like a chunky necklace or a fashionable handbag to complete the look. Just remember to choose colors and patterns that reflect your personal style!\n\n### Human: That sounds like a great idea! I love the idea of adding a statement piece. What other accessories do you think would pair well with the outfit?\n\n### Assistant: Adding a few more accessories can definitely elevate your outfit. How about some hoop earrings or a stack of delicate bracelets? You can also

In [5]:
import transformer_lens as tl
import torch

# Inference only: turn autograd off globally so no graphs are kept for the 9B model.
torch.set_grad_enabled(False)
# Plain string literal (the previous f-string had no placeholders).
model_name = "google/gemma-2-9b-it"
model = tl.HookedTransformer.from_pretrained(
    model_name,
    center_unembed=False,
    dtype="bfloat16",
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00,  6.35it/s]


Loaded pretrained model google/gemma-2-9b-it into HookedTransformer


In [6]:
import utils.neel_utils as nutils
logits, fwd_cache = model.run_with_cache(texts[0])
# Keep only the final position: (batch, pos, d_vocab) -> (batch, d_vocab).
logits = logits[:, -1, :]
# Top-15 next-token predictions for the single example.
nutils.show_df(
    nutils.create_vocab_df(logits[0], make_probs=True).head(15)
)

Unnamed: 0,token,logit,log_prob,prob
8914,·female,19.0,-0.212891,0.808594
5476,·likely,16.0,-3.21875,0.040283
15815,·assumed,15.375,-3.84375,0.021484
30870,·Female,15.25,-3.96875,0.019043
5231,·**,14.9375,-4.28125,0.013916
780,·not,14.75,-4.46875,0.011536
139,··,13.9375,-5.28125,0.005127
49332,·irrelevant,13.8125,-5.40625,0.004517
53012,·feminine,13.75,-5.46875,0.004242
2845,·important,13.625,-5.59375,0.003738


In [7]:
# Hand unoccupied blocks held by PyTorch's CUDA caching allocator back to the driver.
# Tensors that are still referenced (e.g. `model`, `fwd_cache`) are not freed by this.
torch.cuda.empty_cache()

In [8]:
from utils.probes import LABEL_MAPS
# LABEL_MAPS["gender"] maps token -> label; invert it to look up a label's token.
label_to_token = dict(
    zip(LABEL_MAPS["gender"].values(), LABEL_MAPS["gender"].keys())
)

label_to_token[labels[0]]

'female'

In [60]:
from utils.probes import LinearProbes

def get_accuracy_for_cache(cache, labels, probes: "list[LinearProbes]"):
    """Per-layer, per-token accuracy of linear probes on cached residual streams.

    For each layer ``i``, ``probes[i]`` is applied to
    ``cache[f"blocks.{i}.hook_resid_post"]``. The argmax over the probe output's
    last dimension is compared with the sequence's label at every position.

    Args:
        cache: Mapping from hook name to activations (e.g. a TransformerLens
            ActivationCache). Assumed to be (batch, seq, d_model) — TODO confirm
            for caches built by ``batched_fwd_cache``.
        labels: Integer class label(s), one per sequence in the batch. This can
            be a list, a tensor, or a single int. A single label is broadcast
            over the batch.
        probes: One probe per layer, in layer order. Each must be callable and
            expose ``.probe.weight`` (used only to choose the device).

    Returns:
        Float tensor of shape (batch, n_layers, seq). Each entry is 1.0 where
        the probe's prediction at that position equals the sequence label,
        else 0.0.

    Raises:
        ValueError: if ``probes`` is empty, or if the number of labels is
            neither 1 nor the batch size.
    """
    if len(probes) == 0:
        raise ValueError("probes must contain at least one probe")

    # An int becomes shape (1, 1) and a (batch,) vector becomes (batch, 1).
    # Either way it broadcasts against per-token predictions of shape (batch, seq).
    # as_tensor avoids the copy warning torch.tensor emits for tensor input.
    label_col = torch.as_tensor(labels).reshape(-1, 1)

    accs = []
    for layer_idx, probe in enumerate(probes):
        device = probe.probe.weight.device
        resid = cache[f"blocks.{layer_idx}.hook_resid_post"].to(device)
        preds = probe(resid).argmax(dim=-1)  # (batch, seq)
        if label_col.shape[0] not in (1, preds.shape[0]):
            raise ValueError(
                f"got {label_col.shape[0]} labels for a batch of {preds.shape[0]}"
            )
        accs.append((preds == label_col.to(device)).float())

    # Stacking on dim=1 yields (batch, n_layers, seq) directly. Calling .view()
    # on a (n_layers, seq, batch) stack would reinterpret memory rather than
    # move axes, and would mix up examples whenever batch > 1.
    return torch.stack(accs, dim=1)


In [10]:
# Probe accuracy for the single cached example (texts[0]).
accs = get_accuracy_for_cache(
    cache=fwd_cache,
    labels=labels[0],
    probes=probes_for_each_layer,
)
accs.size()

torch.Size([42, 487])

In [20]:
from utils.cache import batched_fwd_cache

# Cache activations for the first 10 texts, one text per forward pass.
all_fwd_cache = batched_fwd_cache(
    model,
    texts[:10],
    batch_size=1,
)
all_fwd_cache.keys()


  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [01:50<00:00, 11.05s/it]


dict_keys(['blocks.0.hook_resid_pre', 'blocks.0.hook_resid_mid', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_resid_mid', 'blocks.1.hook_resid_post', 'blocks.2.hook_resid_pre', 'blocks.2.hook_resid_mid', 'blocks.2.hook_resid_post', 'blocks.3.hook_resid_pre', 'blocks.3.hook_resid_mid', 'blocks.3.hook_resid_post', 'blocks.4.hook_resid_pre', 'blocks.4.hook_resid_mid', 'blocks.4.hook_resid_post', 'blocks.5.hook_resid_pre', 'blocks.5.hook_resid_mid', 'blocks.5.hook_resid_post', 'blocks.6.hook_resid_pre', 'blocks.6.hook_resid_mid', 'blocks.6.hook_resid_post', 'blocks.7.hook_resid_pre', 'blocks.7.hook_resid_mid', 'blocks.7.hook_resid_post', 'blocks.8.hook_resid_pre', 'blocks.8.hook_resid_mid', 'blocks.8.hook_resid_post', 'blocks.9.hook_resid_pre', 'blocks.9.hook_resid_mid', 'blocks.9.hook_resid_post', 'blocks.10.hook_resid_pre', 'blocks.10.hook_resid_mid', 'blocks.10.hook_resid_post', 'blocks.11.hook_resid_pre', 'blocks.11.hook_resid_mid', 'blocks.11.hook_resid_post',

In [61]:
# Probe accuracy for the 10 cached examples.
all_fwd_cache_accs = get_accuracy_for_cache(
    cache=all_fwd_cache,
    labels=labels[:10],
    probes=probes_for_each_layer,
)
all_fwd_cache_accs.size()

torch.Size([10, 42, 561])

In [70]:
from utils.neel_utils import line

# Average over examples -> (n_layers, seq), then plot the last position per layer.
# NOTE(review): if batched_fwd_cache right-pads sequences, position -1 is padding
# for every text shorter than the longest one — confirm before trusting this plot.
line(
    all_fwd_cache_accs.mean(dim=0)[:, -1],
    title="Average accuracy for each layer",
)
