In [28]:
import einops
import torch
from transformer_lens import HookedTransformer
from transformer_lens import utils

In [3]:
# Inference-only notebook: disable autograd globally so forward passes don't record gradients.
torch.set_grad_enabled(False)
# Use the GPU when torch can see one, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [136]:
# Load the pretrained attention-only "attn-only-2l" model onto `device`.
# fold_ln=True asks TransformerLens to fold LayerNorm weights into the adjacent layers.
model = HookedTransformer.from_pretrained("attn-only-2l", device=device, fold_ln=True)

Loaded pretrained model attn-only-2l into HookedTransformer


In [8]:
# Run one prompt and cache every intermediate activation for later inspection.
text = "Social security is a government program that produces exuberant daisies"
tokens = model.to_tokens(text)  # (batch=1, pos) token ids
logits, cache = model.run_with_cache(tokens)  # logits: (batch, pos, d_vocab)
print(logits.shape)

torch.Size([1, 15, 48262])


In [11]:
# Greedy (argmax) next-token prediction at every position of the single prompt;
# predicted_tokens[i] is the model's guess for the token at position i + 1.
predicted_tokens = logits[0].argmax(dim=-1)
predicted_text = model.to_string(predicted_tokens)
print(predicted_text)

. Media is a great- that is a-ance andemonies.


In [15]:
# Compare what the model predicted at position 1 with the token actually found at
# position 2 (the token that follows position 1 in the prompt).
media_token = predicted_tokens[1]
security_token = tokens[0, 2]

print(media_token, model.to_string(media_token))
print(security_token, model.to_string(security_token))

tensor(10967, device='cuda:0')  Media
tensor(3860, device='cuda:0')  security


In [22]:
# Every attention head's output written to the residual stream, stacked along a
# leading head axis; apply_ln=True applies the cached final-LayerNorm scaling.
head_results = cache.stack_head_results(apply_ln=True)
head_results.shape # head batch pos res

torch.Size([16, 1, 15, 512])

In [32]:
# Full embed -> unembed ("direct path") matrix: W_EU[i, j] is the logit for output
# token j contributed directly by input token i's embedding (no attention, no LayerNorm, no b_U).
# NOTE: this is d_vocab x d_vocab (48262^2 ≈ 2.33e9 entries, ~9.3 GB if float32) — very memory-hungry.
W_EU = einops.einsum(model.W_E, model.W_U, "d_vocab_in d_model, d_model d_vocab_out -> d_vocab_in d_vocab_out")
print(W_EU.shape)

torch.Size([48262, 48262])


In [34]:
# Back-of-envelope size of a ~50k x 50k matrix, in *billions of entries* (not GB);
# at 4 bytes/entry (float32) this is ~10 GB.
50000*50000 / 1e9

2.5

In [58]:
# Row lookup: bias-free direct-path logits for every prompt token, shape (batch, pos, d_vocab).
bigram_predictions = W_EU[tokens, :]
print(bigram_predictions.shape)

torch.Size([1, 15, 48262])


In [65]:
# Top-10 direct-path predictions at every position; display those for position 2
# (the " security" token, see security_token above).
bigram_topk, bigram_topk_indices = torch.topk(bigram_predictions, k=10, dim=-1)
model.tokenizer.convert_ids_to_tokens(bigram_topk_indices[0, 2])

['oupe',
 'aders',
 'fully',
 'igu',
 'ange',
 'ools',
 'aded',
 'Ġdeposit',
 'Ġforces',
 'iels']

In [61]:
# Raw top-10 bigram token ids: one row per prompt position.
bigram_topk_indices[0]

tensor([[42933, 42706,  6396, 29264,  2335, 18724, 12907, 22803, 15039, 12393],
        [  266,   778, 15732,  7887,  1322,   899, 10975, 38135, 24299,   382],
        [36407, 15760,  2830, 35365,   904, 22944,  7708, 18292,  5457, 38301],
        [ 4546, 37559, 35463,  7595, 47593, 35606,  7283,  5093, 22697, 26466],
        [34311, 23979, 46119, 37725, 22440, 13762, 20373,  3221, 26422, 43532],
        [25990,  4223, 13037, 16736, 35665, 19246, 10609,  6166,  8081,  1174],
        [32625, 32540, 35056,  5214, 15999,  6766,   273, 44044, 17458, 47409],
        [ 7492,  5161, 13210,   349,  2082,  2875,   891,  2697,  4325,  3138],
        [18367, 31026,  2631, 18750, 43757, 10890, 38699,   739,  2151, 45001],
        [28147,  7592, 19694,   738, 43026, 20174, 47058, 35792, 37221, 42023],
        [10818,  1048,   527,   592,  2083, 19249,  1915,  5785,   582,  9265],
        [ 9767, 33845,  4527, 46829,  2329, 17228, 46828,   627,  2959, 13314],
        [14706, 17464, 15770, 23979,  20

In [57]:
# Tokens with the largest unembedding bias: b_U is added identically at every
# position, so these are boosted independently of context.
topk, topk_indices = torch.topk(model.b_U, k=10)
model.tokenizer.convert_ids_to_tokens(topk_indices)

['Ġand', ',', 'Ġthe', 'Ċ', '.', 'Ġa', 'Ġin', ')', 'Ġto', 'Ġnot']

In [46]:
# Argmax of the bias-free direct-path logits at each prompt position.
model.tokenizer.convert_ids_to_tokens(bigram_predictions[0].argmax(dim=-1))


['Ġand',
 'Ġand',
 'Ġand',
 'Ġthe',
 'Ċ',
 'Ġand',
 'Ġand',
 'Ġthe',
 'Ġthe',
 ',',
 'Ġand',
 'Ġand',
 'Ġand',
 'Ġand',
 'Ġand']

In [81]:
# Re-run the same prompt so `tokens`, `logits` and `cache` are (re)defined for the
# loss comparison below; identical to the first prompt cell.
text = "Social security is a government program that produces exuberant daisies"
tokens = model.to_tokens(text)
logits, cache = model.run_with_cache(tokens)
print(logits.shape)

torch.Size([1, 15, 48262])


In [82]:
def get_log_probs(logits, tokens):
    """Log-probability assigned to each actual next token.

    Args:
        logits: (batch, pos, d_vocab) scores; position i predicts token i + 1.
        tokens: (batch, pos) integer token ids matching `logits`.

    Returns:
        (batch, pos - 1) tensor where entry [b, i] is log p(tokens[b, i + 1]).
    """
    # The final position predicts a token beyond the sequence, so drop it up front.
    preceding_log_probs = torch.log_softmax(logits[:, :-1], dim=-1)
    # Target for position i is the token at i + 1; add a trailing axis for gather.
    next_tokens = tokens[:, 1:, None]
    return torch.gather(preceding_log_probs, -1, next_tokens)[..., 0]

In [83]:
# Zero-ablate one layer's attention output and compare the loss against the intact
# model and against a pure direct-path (W_E @ W_U) bigram baseline.
def zero_ablation_hook(value, hook):
    """Forward hook: replace the hooked activation with zeros of the same shape/dtype/device."""
    return torch.zeros_like(value)

# Backwards-compatible alias for the old, misleading name: the hook is attached to
# `hook_attn_out` of an attention-only model, not to an MLP.
ablate_mlp_hook = zero_ablation_hook

layer_to_ablate = 0

original_loss = model(tokens, return_type="loss")

ablated_loss = model.run_with_hooks(
    tokens,
    return_type="loss",
    # Build the hook name from layer_to_ablate so changing it changes the ablated layer.
    fwd_hooks=[(f"blocks.{layer_to_ablate}.hook_attn_out", zero_ablation_hook)],
)

# Direct embed -> unembed path only: no attention, no final LayerNorm, no b_U.
bigram_pred = W_EU[tokens, :]
bigram_loss = - get_log_probs(bigram_pred, tokens).mean()

print(f"Original model loss: {original_loss:.6f}")
print(f"Ablated loss: {ablated_loss:.6f}")
print(f"Bigram loss {bigram_loss}")

Original model loss: 5.388212
Ablated loss: 7.431926
Bigram loss 9.814844131469727


In [85]:
# Compare each head's contribution to the final logits.
# Re-run the prompt and stack per-head outputs (LN-scaled) from the fresh cache.

text = "Social security is a government program that produces exuberant daisies"
tokens = model.to_tokens(text)
logits, cache = model.run_with_cache(tokens)
print(logits.shape)

head_results = cache.stack_head_results(apply_ln=True)
print(head_results.shape) # head batch pos res

torch.Size([1, 15, 48262])
Tried to stack head results when they weren't cached. Computing head results now
torch.Size([16, 1, 15, 512])


In [87]:
# Attribute the model's own top-1 prediction at each position to each head: dot each
# head's output with the residual-stream direction of that predicted token, then sum
# over batch and positions -> one scalar per head.
model_tokens = logits.argmax(dim=-1)
directions = model.tokens_to_residual_directions(model_tokens)

head_attribution = einops.einsum(head_results, directions, "head batch pos d_model, batch pos d_model -> head")
print(head_attribution)

tensor([ -2.2576,  -4.5832,  -3.0620,   2.2735,   2.3002,   9.1988,  -2.8497,
         -1.3467,  -1.7043,   1.4677,  19.0915,   5.8268,  13.0674,   4.6253,
          8.9530, -35.5903], device='cuda:0')


In [90]:
# Shapes feeding the attribution einsum: (batch, pos, d_model) vs (head, batch, pos, d_model).
print(directions.shape)
print(head_results.shape)

torch.Size([1, 15, 512])
torch.Size([16, 1, 15, 512])


In [97]:
# Cosine similarity between each head's output and the predicted-token direction,
# per head and position; (pos, d_model) broadcasts against (head, pos, d_model).
cos = torch.nn.CosineSimilarity(dim=-1)
similarities = cos(directions[0], head_results[:, 0])
print(similarities.shape) # head position
# All heads at position 2.
similarities[:, 2]


torch.Size([16, 15])


tensor([-0.0223, -0.0876, -0.0189, -0.0100,  0.0426,  0.0146, -0.0363,  0.0008,
        -0.0137,  0.1088, -0.0167,  0.1116,  0.0736, -0.0306,  0.0404, -0.0601],
       device='cuda:0')

In [98]:
# Largest similarity over all heads and positions.
torch.max(similarities)

tensor(0.2165, device='cuda:0')

In [100]:
# Baseline: cosine similarity between each position's input embedding and the
# direction of the token predicted at that same position.
# NOTE: overwrites the per-head `similarities` from the cells above.
embed = model.W_E[tokens, :]
similarities = cos(directions[0], embed[0])
print(similarities.shape)
torch.max(similarities)

torch.Size([15])


tensor(0.2665, device='cuda:0')

In [144]:
# Repeated-name prompt: if the model copies from context, the final " Leonard"
# should be followed by " Potter". (Alternative prompt tried: "Harry Potter is great. Harry".)
text = "Leonard Potter is great. Leonard"
tokens = model.to_tokens(text)
logits, cache = model.run_with_cache(tokens)

In [145]:
# Residual-stream direction for the single token " Potter" (no BOS prepended).
potter_token = model.to_tokens(" Potter", prepend_bos=False)
print(potter_token)
potter_direction = model.tokens_to_residual_directions(potter_token)
print("Potter direction shape", potter_direction.shape)


head_results = cache.stack_head_results(apply_ln=True)
print("Head result shape", head_results.shape) # head batch pos res

cos = torch.nn.CosineSimilarity(dim=-1)
similarities = cos(potter_direction, head_results[:, 0])
print("Cosine similarity shape", similarities.shape) # head position
similarities[:, -1] # Each head's similarity to " Potter" at the final position (the last " Leonard")

tensor([[27754]], device='cuda:0')
Potter direction shape torch.Size([512])
Tried to stack head results when they weren't cached. Computing head results now
Head result shape torch.Size([16, 1, 8, 512])
Cosine similarity shape torch.Size([16, 8])


tensor([ 0.0334, -0.0452,  0.0025, -0.0315, -0.0109, -0.0741,  0.0194, -0.0182,
         0.0218,  0.0739,  0.0436, -0.0171,  0.0950, -0.0613,  0.1445,  0.1164],
       device='cuda:0')

In [148]:
def get_bigram_logits(tokens, model: "HookedTransformer", add_bias=False):
    """Direct-path ("bigram") logits: embed each token and unembed it immediately.

    Skips every attention layer, so each position's logits depend only on the
    token at that position.

    Args:
        tokens: Integer token ids of any shape, e.g. (batch, pos) as returned by
            `model.to_tokens`, or a 1-D (pos,) tensor.
        model: Provides `W_E` (d_vocab, d_model), `W_U` (d_model, d_vocab_out)
            and, when `add_bias` is set, `b_U` (d_vocab_out,).
        add_bias: If True, add the unembedding bias `b_U` at every position.

    Returns:
        Tensor of shape `tokens.shape + (d_vocab_out,)`.

    NOTE(review): no final LayerNorm scaling is applied, so magnitudes (and the
    relative weight of b_U when add_bias=True) may differ from the model's real logits.
    """
    embed = model.W_E[tokens]  # tokens.shape + (d_model,)
    # Matmul broadcasts over any leading dims, unlike a fixed "batch pos" einsum pattern.
    unembed = embed @ model.W_U
    if add_bias:
        unembed = unembed + model.b_U
    return unembed

def get_topk_words(logits, pos=-1, k=10, tokenizer=None):
    """Return the k highest-scoring token strings at one position of a batch-1 logits tensor.

    Args:
        logits: (batch, pos, d_vocab) scores; only batch element 0 is read.
        pos: Position to inspect (default: last).
        k: Number of tokens to return, highest score first.
        tokenizer: Object with `convert_ids_to_tokens`; defaults to the
            notebook-global `model.tokenizer` for backward compatibility.

    Returns:
        List of k raw token strings as produced by the tokenizer (e.g. 'Ġ' marks a leading space).
    """
    if tokenizer is None:
        tokenizer = model.tokenizer
    top_ids = torch.topk(logits[0, pos], k=k).indices
    # Plain ints work with any tokenizer implementation, not only ones that accept tensors.
    return tokenizer.convert_ids_to_tokens(top_ids.tolist())

In [149]:
# The full model's top-10 next-token predictions at the final position.
get_topk_words(logits)

['Ġis',
 'ĠCohen',
 'ĠBernstein',
 'Ġhas',
 'ĠPotter',
 'ĠNim',
 'Ġwas',
 'Ġand',
 ',',
 'âĢĻ']

In [150]:
# Same question using only the bias-free direct embed -> unembed path.
bigram_logits = get_bigram_logits(tokens, model)
print(get_topk_words(bigram_logits, pos=-1))

['town', 'sson', 'ounce', 'stown', 'esque', 'nier', 'ĠNim', 'ĠCohen', 'ounced', 'ette']


In [143]:
# Where does " Potter" rank after the final repeated name? Reuse `text` from the prompt
# cell so this check stays in sync if that prompt is swapped.
utils.test_prompt(text, " Potter", model, prepend_bos=True)

Tokenized prompt: ['<|BOS|>', 'Leon', 'ard', ' Potter', ' is', ' great', '.', ' Leonard']
Tokenized answer: [' Potter']


Top 0th token. Logit: 14.11 Prob:  9.62% Token: | is|
Top 1th token. Logit: 13.53 Prob:  5.38% Token: | Cohen|
Top 2th token. Logit: 13.43 Prob:  4.86% Token: | Bernstein|
Top 3th token. Logit: 13.34 Prob:  4.45% Token: | has|
Top 4th token. Logit: 13.33 Prob:  4.40% Token: | Potter|
Top 5th token. Logit: 12.83 Prob:  2.65% Token: | Nim|
Top 6th token. Logit: 12.69 Prob:  2.31% Token: | was|
Top 7th token. Logit: 12.64 Prob:  2.21% Token: | and|
Top 8th token. Logit: 12.61 Prob:  2.14% Token: |,|
Top 9th token. Logit: 12.58 Prob:  2.07% Token: |’|
