In [16]:
import os

import circuitsvis as cv
import gdown
import torch
import transformer_lens
from einops import einsum
from transformer_lens import HookedTransformer, HookedTransformerConfig
from transformer_lens import utils

In [2]:
import seaborn as sns
import matplotlib.pyplot as plt

In [3]:
"""
- 2-layer model
- attention only (no MLP layers)
- no layer normalization and no biases
- positional embeddings are added to the query and key vectors in each attention layer (not to the token embeddings),
  so the values carry no positional information and hence neither does the residual stream
  [config line: positional_embedding_type="shortformer"]
"""

# Architecture of the pretrained 2-layer attention-only model; the keyword
# arguments are grouped by theme (their order does not affect the result).
cfg = HookedTransformerConfig(
    # --- size: 12 heads * 64 dims per head = 768 = d_model ---
    n_layers=2,
    n_heads=12,
    d_head=64,
    d_model=768,
    n_ctx=2048,
    # --- vocabulary / tokenizer ---
    d_vocab=50278,
    tokenizer_name="EleutherAI/gpt-neox-20b",
    # --- architecture switches ---
    attn_only=True,  # defaults to False
    attention_dir="causal",
    normalization_type=None,  # defaults to "LN", i.e. layernorm with weights & biases
    positional_embedding_type="shortformer",
    # --- bookkeeping ---
    use_attn_result=True,  # the recorded cache keys include blocks.*.attn.hook_result
    seed=398,
)

In [4]:
# Directory that holds the pretrained checkpoint. The default is the original
# machine-specific location; set MECH_INTERP_WEIGHTS_DIR to point somewhere
# else without editing the notebook. The value stays a plain str.
weights_dir = os.environ.get(
    "MECH_INTERP_WEIGHTS_DIR",
    "/home/happyuser/main/3m/jupyter/transformerlens-practice/learn-mech-interp/essentials",
)

In [5]:
# Build the architecture described by cfg, then overwrite its parameters with
# the pretrained checkpoint.
model = HookedTransformer(cfg)
pretrained_weights = torch.load(
    weights_dir + "/essentialsewk26ptl.part",
    # Map tensors onto the device the config resolved to instead of assuming
    # that CUDA is available on this machine.
    map_location=model.cfg.device,
    # The checkpoint is used purely as a state dict, so restrict unpickling to
    # tensors and plain containers rather than arbitrary Python objects.
    weights_only=True,
)
# Kept as the last expression so the cell still displays the key-matching report.
model.load_state_dict(pretrained_weights)

<All keys matched successfully>

In [6]:
# Run the model on the empty prompt, keeping both the output (bound to `logits`)
# and a cache of the activations recorded along the way.
# NOTE(review): with an empty prompt the input is presumably just a prepended
# BOS token — confirm; the later check against model.embed([0]) relies on
# position 0 holding token id 0.
logits, cache = model.run_with_cache("")

In [7]:
# Show the names of every activation stored in the cache by the run above.
cache.cache_dict.keys()

dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.attn.hook_result', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_post'])

In [15]:
# Rank the whole vocabulary by the logits at the last position of the first
# batch element, then print the ten highest-scoring tokens on one line
# (in the recorded output the top entry is the newline token).
ranked_token_ids = logits[0, -1].argsort(dim=-1, descending=True)
for token_id in ranked_token_ids[:10]:
    print(f"'{model.to_string(token_id)}'", end="   ")

'
'   ','   '.'   '-'   ' of'   ' and'   ' the'   ' a'   ' to'   ' ('   

In [21]:
# Buffer with one row per layer and d_head columns; the forward hook writes a
# single value vector into the row of the layer it fires on. It is allocated on
# the model's own device rather than a hard-coded "cuda", so the notebook also
# runs on machines without a GPU.
value_store = torch.zeros(
    (model.cfg.n_layers, model.cfg.d_head),
    device=model.cfg.device,
)

In [32]:
def value_hook(value_tensor, hook):
    """Record one value vector per layer into the module-level ``value_store``.

    Args:
        value_tensor: Activation handed to the hook. It is indexed as
            ``[0, 0, 0, :]`` (presumably batch 0, position 0, head 0 of a
            (batch, pos, head, d_head) tensor — TODO confirm).
        hook: Hook point the activation came from; ``hook.layer()`` picks the
            row of ``value_store`` that is overwritten.

    Returns:
        None. The activation itself is not modified.
    """
    # Detach before storing: copying a grad-tracking slice in place would make
    # value_store part of the autograd graph and keep the forward graph alive.
    value_store[hook.layer(), :] = value_tensor[0, 0, 0, :].detach()
    

In [33]:
# Run the empty prompt again with value_hook attached to the value activation
# of every layer. The hook names are generated from model.cfg.n_layers so they
# always cover the same layers value_store has rows for; for this 2-layer model
# that is exactly blocks.0 and blocks.1. return_type=None because only the
# hook's side effect is wanted, not the model output.
model.run_with_hooks(
    model.to_tokens(""),
    return_type=None,
    fwd_hooks=[
        (f"blocks.{layer}.attn.hook_v", value_hook)
        for layer in range(model.cfg.n_layers)
    ],
)

In [66]:
# Sanity check: recompute by hand the vectors the hook captured and compare
# them with value_store, using an absolute tolerance only.
# (W_V[0] is presumably the head-0 projection, matching the head index the
# hook reads — TODO confirm.)

# Layer 0: embedding of token id 0 pushed through layer 0's W_V[0].
layer0_value_proj = model.blocks[0].attn.W_V[0]
expected_value0 = (model.embed([0]) @ layer0_value_proj)[0]
torch.testing.assert_close(value_store[0], expected_value0, atol=1e-3, rtol=0)
print("value0 pass!")

# Layer 1: cached output of block 0 at batch 0, position 0, pushed through
# layer 1's W_V[0].
resid_after_block0 = cache["blocks.0.hook_resid_post"][0, 0]
expected_value1 = resid_after_block0 @ model.blocks[1].attn.W_V[0]
torch.testing.assert_close(value_store[1], expected_value1, atol=1e-3, rtol=0)
print("value1 pass!")

value0 pass!
value1 pass!


In [None]:
def value_contributions(value_store, W_U):
    