In [1]:
# %%
import os
# os.environ["TRANSFORMERS_CACHE"] = "/workspace/cache/"
# %%
from neel.imports import *
from neel_plotly import *

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.set_grad_enabled(False)

model = HookedTransformer.from_pretrained("pythia-2.8b")

n_layers = model.cfg.n_layers
d_model = model.cfg.d_model
n_heads = model.cfg.n_heads
d_head = model.cfg.d_head
d_mlp = model.cfg.d_mlp
d_vocab = model.cfg.d_vocab
# %%
evals.sanity_check(model)
# %%
import transformer_lens
from transformer_lens import HookedTransformer, utils
import torch
import numpy as np
# import gradio as gr
import pprint
import json
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from huggingface_hub import HfApi
from IPython.display import HTML
from functools import partial
import tqdm.notebook as tqdm
import plotly.express as px
import pandas as pd



torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.


torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.


torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.



In IPython
In IPython
Set autoreload
Imported everything!
Loaded pretrained model pythia-2.8b into HookedTransformer


In [4]:
example_prompt = \
"""
Cat: C A T
Dog: D O G
Bird: B I R D
Fish: F I"""
example_answer = " S"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True, prepend_space_to_answer=False)

Tokenized prompt: ['<|endoftext|>', '\n', 'Cat', ':', ' C', ' A', ' T', '\n', 'Dog', ':', ' D', ' O', ' G', '\n', 'B', 'ird', ':', ' B', ' I', ' R', ' D', '\n', 'Fish', ':', ' F', ' I']
Tokenized answer: [' S']


Top 0th token. Logit: 24.59 Prob: 98.55% Token: | S|
Top 1th token. Logit: 18.58 Prob:  0.24% Token: | X|
Top 2th token. Logit: 18.45 Prob:  0.21% Token: | C|
Top 3th token. Logit: 18.31 Prob:  0.19% Token: | L|
Top 4th token. Logit: 18.18 Prob:  0.16% Token: | SH|
Top 5th token. Logit: 17.49 Prob:  0.08% Token: | D|
Top 6th token. Logit: 17.40 Prob:  0.07% Token: | Z|
Top 7th token. Logit: 17.22 Prob:  0.06% Token: | N|
Top 8th token. Logit: 17.06 Prob:  0.05% Token: | T|
Top 9th token. Logit: 17.01 Prob:  0.05% Token: | I|


In [11]:
example_prompt = \
"""
Cat: C A T
Dog: D O G
Bird: B I R D
Fish: F I S H
Apple:"""
example_answer = " A"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True, prepend_space_to_answer=False)

Tokenized prompt: ['<|endoftext|>', '\n', 'Cat', ':', ' C', ' A', ' T', '\n', 'Dog', ':', ' D', ' O', ' G', '\n', 'B', 'ird', ':', ' B', ' I', ' R', ' D', '\n', 'Fish', ':', ' F', ' I', ' S', ' H', '\n', 'Apple', ':']
Tokenized answer: [' A']


Top 0th token. Logit: 21.78 Prob: 96.88% Token: | A|
Top 1th token. Logit: 16.23 Prob:  0.38% Token: | a|
Top 2th token. Logit: 15.72 Prob:  0.23% Token: |  |
Top 3th token. Logit: 15.55 Prob:  0.19% Token: | I|
Top 4th token. Logit: 15.45 Prob:  0.17% Token: | S|
Top 5th token. Logit: 15.41 Prob:  0.17% Token: | O|
Top 6th token. Logit: 15.24 Prob:  0.14% Token: | P|
Top 7th token. Logit: 15.15 Prob:  0.13% Token: | Apple|
Top 8th token. Logit: 15.03 Prob:  0.11% Token: | H|
Top 9th token. Logit: 14.90 Prob:  0.10% Token: | M|


to do:
- DLA
- attention

# DLA

In [6]:
logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)

In [12]:
a_token = model.to_single_token(" A")
e_token = model.to_single_token(" I")

a_unembed = model.W_U[:, a_token].detach().cpu()
e_unembed = model.W_U[:, e_token].detach().cpu()

dif_a_e = a_unembed - e_unembed
print(dif_a_e.shape)

torch.Size([2560])


In [13]:
decomp, labels = cache.get_full_resid_decomposition(expand_neurons=False, return_labels=True)
decomp = decomp[:,0,-1].detach().cpu()

In [14]:
decomp_df = pd.DataFrame({
    "labels": labels,
    "a_proj": decomp @ a_unembed,
    "e_proj": decomp @ e_unembed,
    "a_e_proj": decomp @ dif_a_e,}
)

# decomp_df.sort_values('a_e_proj').head(20).style.background_gradient(cmap='RdBu', axis=0)
px.line(decomp_df, x="labels", y=["a_proj", "e_proj", "a_e_proj"], title="Projection of Residuals onto A and E Embeddings")

Interesting heads:
- L11H19
- L20H29

In [110]:
from transformer_lens.utils import get_act_name

get_act_name("attn_score", 10)

'blocks.10.hook_attn_score'

In [111]:
from circuitsvis.attention import attention_patterns

cache[get_act_name("attn_scores", 10)].shape


torch.Size([1, 32, 33, 33])

# Attention

In [133]:
fig = px.imshow(cache[get_act_name("pattern", 6)][0].detach().cpu(), animation_frame=0, color_continuous_midpoint=0, color_continuous_scale="RdBu")

# set xticks and yticks
str_tokens = model.to_str_tokens(example_prompt)
fig.update_layout(
    xaxis=dict(
        tickmode='array',
        tickvals=list(range(len(str_tokens))),
        ticktext=str_tokens,
    ),
    yaxis=dict(
        tickmode='array',
        tickvals=list(range(len(str_tokens))),
        ticktext=str_tokens,
    )
)

fig.show()

: 

In [17]:
def plot_attn_pattern(cache, layer, head):
    

    fig = px.imshow(cache[get_act_name("pattern", layer)][0,head].detach().cpu(), color_continuous_midpoint=0, color_continuous_scale="RdBu")

    # set xticks and yticks
    str_tokens = model.to_str_tokens(example_prompt)
    fig.update_layout(
        xaxis=dict(
            tickmode='array',
            tickvals=list(range(len(str_tokens))),
            ticktext=str_tokens,
        ),
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(len(str_tokens))),
            ticktext=str_tokens,
        )
    )

    fig.show()
    
plot_attn_pattern(cache, 11, 19)

In [122]:
plot_attn_pattern(cache, 18, 4)

In [123]:
plot_attn_pattern(cache, 14, 10) # attends to other positions in the output after the first example. 

In [115]:
example_prompt

'\nCat: C A T\nDog: D O G\nBird: B I R D\nFish: F I S H\nRabbit: R'

In [79]:
str_tokens = model.to_str_tokens(example_prompt)
attention_patterns(attention =cache[get_act_name("attn_scores", 10)].detach().cpu()[0,-1], tokens=str_tokens)