In [1]:
%load_ext autoreload
%autoreload 2
import torch
from alphatoe import plot, game
from transformer_lens import HookedTransformer, HookedTransformerConfig
import json
import einops
import circuitsvis as cv
from plotly.subplots import make_subplots
import plotly.graph_objects as go

In [2]:
weights = torch.load("../scripts/models/prob all 8 layer control-20230718-185339.pt")
with open("../scripts/models/prob all 8 layer control-20230718-185339.json", "r") as f:
    args= json.load(f)

In [3]:
model_cfg = HookedTransformerConfig(
        n_layers=args["n_layers"],
        n_heads=args["n_heads"],
        d_model=args["d_model"],
        d_head=args["d_head"],
        d_mlp=args["d_mlp"],
        act_fn=args["act_fn"],
        normalization_type=args["normalization_type"],
        d_vocab=11,
        d_vocab_out=10,
        n_ctx=10,
        init_weights=True,
        device=args["device"],
        seed=args["seed"])

In [4]:
model =  HookedTransformer(model_cfg)
model.load_state_dict(weights)

<All keys matched successfully>

In [5]:
model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): Identity()
      (ln2): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_mid): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (unembed): Unembed()
)

In [28]:
plot.imshow(model.W_E)

In [32]:
emb = model.embed.W_E
vec_count = emb.shape[0]
vec_dim = emb.shape[1]
print(f"The embedding shape is {emb.shape}, so our vectors of length {emb.shape[1]}")

dot_products = einops.einsum(emb, emb, "v2 embs, v1 emb -> v1 v2")

The embedding shape is torch.Size([11, 128]), so our vectors of length 128


In [34]:
print(dot_products.shape)
plot.imshow_div(dot_products)

torch.Size([11, 11])


softmax(x W_Q @ W_K.T x.T)

In [45]:
tokens = [10,0, 1,2,3,3,5, 6, 9, 9]
# tokens = ([10] * 5) + [1,2,5,8,7]
str_tokens = [str(token) for token in tokens]
logits, cache = model.run_with_cache(torch.tensor(tokens).to('cuda'), remove_batch_dim=True)

print(type(cache))
attention_pattern = cache["pattern", 0, "attn"]
print(attention_pattern.shape)
cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([8, 10, 10])


In [50]:
plot.imshow(model.W_U)

In [53]:

emb = model.W_U.T
vec_count = emb.shape[0]
vec_dim = emb.shape[1]
print(f"The embedding shape is {emb.shape}, so our vectors of length {emb.shape[1]}")

dot_products = einops.einsum(emb, emb, "v2 embs, v1 emb -> v1 v2")

The embedding shape is torch.Size([10, 128]), so our vectors of length 128


In [54]:
print(dot_products.shape)
plot.imshow_div(dot_products)

torch.Size([10, 10])


# How does it understand a draw?
Use logit attributions to figure out the contribution from each head, and hopefully track the contributions through the mlp

Direct Logit Attributions

In [21]:
model.cfg.use_attn_result = True

In [22]:
seq = [10,1,2,3,4,5,6,7,8]
logits, cache = model.run_with_cache(torch.tensor(seq), remove_batch_dim=True)

In [85]:
for item in cache:
    print(item)


hook_embed
hook_pos_embed
blocks.0.hook_resid_pre
blocks.0.attn.hook_q
blocks.0.attn.hook_k
blocks.0.attn.hook_v
blocks.0.attn.hook_attn_scores
blocks.0.attn.hook_pattern
blocks.0.attn.hook_z
blocks.0.attn.hook_result
blocks.0.hook_attn_out
blocks.0.hook_resid_mid
blocks.0.mlp.hook_pre
blocks.0.mlp.hook_post
blocks.0.hook_mlp_out
blocks.0.hook_resid_post


In [19]:
cache["pattern",0,"attention"]

tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4471, 0.5529, 0.0000, 0.0000],
         [0.2887, 0.3693, 0.3419, 0.0000],
         [0.2493, 0.2845, 0.2453, 0.2210]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.3770, 0.6230, 0.0000, 0.0000],
         [0.2704, 0.2652, 0.4644, 0.0000],
         [0.2161, 0.2451, 0.2813, 0.2575]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.3797, 0.6203, 0.0000, 0.0000],
         [0.2492, 0.3357, 0.4151, 0.0000],
         [0.1884, 0.3139, 0.2600, 0.2377]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4115, 0.5885, 0.0000, 0.0000],
         [0.2640, 0.3268, 0.4092, 0.0000],
         [0.1972, 0.3124, 0.2590, 0.2315]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4845, 0.5155, 0.0000, 0.0000],
         [0.3164, 0.2945, 0.3891, 0.0000],
         [0.2372, 0.2530, 0.2600, 0.2498]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.3987, 0.6013, 0.0000, 0.0000],
         [0.3000, 0.3060, 0.3941, 0.0000],
 

scores are pre softmax
pattern is post softmax

In [50]:
h_out = cache["result", 0]
h_contrib = h_out @ model.W_U
mlp_out = cache["mlp_out", 0]
mlp_contrib = mlp_out @ model.W_U

In [95]:
# Seq x Head count x Dict size
h_contrib.shape

torch.Size([4, 8, 10])

In [46]:
for i in range(h_contrib.shape[1]):
    plot.imshow(h_contrib[:, i, :])



In [64]:
h_list = [h_contrib[0, -1, i, :] for i in range(h_contrib.shape[2])]
mlp_list = [mlp_contrib[0,-1]]
logit_list =  [logits[0, -1]]

In [66]:
full_list = h_list + mlp_list + logit_list

In [68]:
plot.imshow(torch.stack(full_list), show=False)

batch x seq x residual_dimension
[10, 1, 2, 3]

In [63]:
h_contrib.shape

torch.Size([1, 9, 8, 10])

In [65]:
print(h_list[0])
print(mlp_list[0])
print(logit_list[0])

tensor([ 0.1526, -0.1598,  0.0920, -0.0806, -0.0475, -0.0949, -0.0338,  0.0199,
        -0.0413,  0.0444], device='cuda:0', grad_fn=<SliceBackward0>)
tensor([ 0.8358,  1.1299,  0.4231,  2.1205,  2.9491,  3.8625,  0.7948,  0.3575,
        -3.4958, -6.8154], device='cuda:0', grad_fn=<SelectBackward0>)
tensor([ 0.8394,  1.2214,  0.4588,  2.4614,  3.4222,  4.0933,  1.0142,  0.2843,
        -3.5575, -7.6659], device='cuda:0')


In [58]:
mlp_contrib.shape

torch.Size([1, 9, 10])

In [59]:
logits.shape

torch.Size([1, 9, 10])

In [53]:
ls = torch.stack(t)

RuntimeError: stack expects each tensor to be equal size, but got [8, 10] at entry 0 and [1, 9, 10] at entry 9

In [49]:
t.shape

AttributeError: 'list' object has no attribute 'shape'

In [48]:
plot.imshow(t, show=False)

TypeError: squeeze(): argument 'input' (position 1) must be Tensor, not list

In [191]:
sum(t[:-1]) / t[-1]

tensor([0.9918, 1.0166, 0.9905, 0.9961, 1.0052, 1.0038, 0.9874, 1.0080, 0.9875,
        0.9891], device='cuda:0', grad_fn=<DivBackward0>)

In [183]:
mlp_out = cache["mlp_out", 0]

In [186]:
mlp_out = cache["mlp_out", 0]
mlp_contrib = mlp_out @ model.W_U

In [187]:
mlp_contrib.shape

torch.Size([4, 10])

Indirect Logit Attribution

Learning to use hooks

In [13]:
def print_stuff(module, input, output):
    print(f"Inside {module.__class__.__name__} forward")
    try:
        input = input[0].shape
    except:
        input = "None"
    print(f"Input size: {input}")
    print(f"Output size: {output.shape}")
    print()


In [19]:
for handle in handles:
    handle.remove()

In [15]:
handles = []
for name, module in model.named_modules():
    print(name, module.__class__.__name__)
    handle = module.register_forward_hook(print_stuff)
    handles.append(handle)

 HookedTransformer
embed Embed
hook_embed HookPoint
pos_embed PosEmbed
hook_pos_embed HookPoint
blocks ModuleList
blocks.0 TransformerBlock
blocks.0.ln1 Identity
blocks.0.ln2 Identity
blocks.0.attn Attention
blocks.0.attn.hook_k HookPoint
blocks.0.attn.hook_q HookPoint
blocks.0.attn.hook_v HookPoint
blocks.0.attn.hook_z HookPoint
blocks.0.attn.hook_attn_scores HookPoint
blocks.0.attn.hook_pattern HookPoint
blocks.0.attn.hook_result HookPoint
blocks.0.mlp MLP
blocks.0.mlp.hook_pre HookPoint
blocks.0.mlp.hook_post HookPoint
blocks.0.hook_q_input HookPoint
blocks.0.hook_k_input HookPoint
blocks.0.hook_v_input HookPoint
blocks.0.hook_attn_out HookPoint
blocks.0.hook_mlp_out HookPoint
blocks.0.hook_resid_pre HookPoint
blocks.0.hook_resid_mid HookPoint
blocks.0.hook_resid_post HookPoint
unembed Unembed


In [16]:
model(torch.tensor([[1]]))

Inside Embed forward
Input size: torch.Size([1, 1])
Output size: torch.Size([1, 1, 128])

Inside HookPoint forward
Input size: torch.Size([1, 1, 128])
Output size: torch.Size([1, 1, 128])

Inside PosEmbed forward
Input size: torch.Size([1, 1])
Output size: torch.Size([1, 1, 128])

Inside HookPoint forward
Input size: torch.Size([1, 1, 128])
Output size: torch.Size([1, 1, 128])

Inside HookPoint forward
Input size: torch.Size([1, 1, 128])
Output size: torch.Size([1, 1, 128])

Inside Identity forward
Input size: torch.Size([1, 1, 128])
Output size: torch.Size([1, 1, 128])

Inside Identity forward
Input size: torch.Size([1, 1, 128])
Output size: torch.Size([1, 1, 128])

Inside Identity forward
Input size: torch.Size([1, 1, 128])
Output size: torch.Size([1, 1, 128])

Inside HookPoint forward
Input size: torch.Size([1, 1, 8, 16])
Output size: torch.Size([1, 1, 8, 16])

Inside HookPoint forward
Input size: torch.Size([1, 1, 8, 16])
Output size: torch.Size([1, 1, 8, 16])

Inside HookPoint for

tensor([[[  12.9202, -111.7838,   15.6010,   12.1924,   13.7957,    6.0624,
            13.6661,   10.4289,   10.8842,   10.0236]]], device='cuda:0',
       grad_fn=<AddBackward0>)

In [36]:
def ablate_output(module, input, output):
    return torch.zeros_like(output)


In [37]:
handle = model.blocks[0].attn.register_forward_hook(ablate_output)

In [43]:
with torch.no_grad():
    logits, cache = model.run_with_cache(torch.tensor(seq))

In [39]:
logits

tensor([[[  1.2947,   3.1262,   2.9081,   2.2827,   2.2714,   0.4744,   2.3996,
            3.6774,   0.3567, -14.5408],
         [  1.7396,  -6.3523,   0.5021,   2.0853,   0.5907,   1.3986,   1.2605,
            2.4660,   2.7258,  -5.0504],
         [  1.7339,   3.0651,  -6.1480,   1.4941,   1.2052,   2.8589,   1.2549,
            2.7622,   2.3173,  -7.6952],
         [  1.8461,   2.7114,   1.6128,  -5.5627,   1.6523,   3.5241,   1.1422,
            2.6375,   3.3177, -11.2693],
         [  1.2697,   2.7014,   0.6421,   0.6968,  -5.8422,   3.1260,   0.3715,
            2.1019,   2.6573,  -7.6685],
         [  1.5819,   2.8753,   0.6986,   1.9004,   1.0083,  -4.3793,   1.2243,
            2.2668,   2.5124,  -8.9321],
         [  0.9159,   1.6511,   0.9972,   0.9426,   1.2930,   3.4563,  -9.3058,
            1.2813,   1.1888,  -2.9022],
         [  1.2890,   2.2091,   0.9294,   1.6820,   1.5716,   2.8159,   1.4430,
           -6.7558,   2.1559,  -7.6582],
         [  0.8394,   1.2214,   

Manually Computing QK-circuit matrices 

In [73]:
dir(model.blocks[0].attn)
print(model.blocks[0].attn.W_Q.shape)

torch.Size([8, 128, 16])


In [75]:
model.W_E.shape

torch.Size([11, 128])

In [76]:
model.W_pos.shape

torch.Size([10, 128])

In [130]:
pos_emb = model.pos_embed(torch.zeros(1, 10, 128), 0)[0]

In [132]:
QK = einops.einsum(model.blocks[0].attn.W_Q,model.blocks[0].attn.W_K, "h_ind r_dim h_dim, h_ind r_dim2 h_dim -> h_ind r_dim r_dim2" )
QK_circuit = einops.einsum(model.W_E, QK, model.W_E, "dict_size1 r_dim1, h_ind r_dim1 r_dim2, dict_size2 r_dim2 -> h_ind dict_size1 dict_size2")

QK_circuit_pos = torch.tril(einops.einsum(pos_emb , QK, pos_emb, "dict_size1 r_dim1, h_ind r_dim1 r_dim2, dict_size2 r_dim2 -> h_ind dict_size1 dict_size2"))

QK_circuit_pos_emb = einops.einsum(pos_emb, QK, model.W_E, "dict_size1 r_dim1, h_ind r_dim1 r_dim2, dict_size2 r_dim2 -> h_ind dict_size1 dict_size2")
QK_circuit_emb_pos = einops.einsum(model.W_E, QK, pos_emb, "dict_size1 r_dim1, h_ind r_dim1 r_dim2, dict_size2 r_dim2 -> h_ind dict_size1 dict_size2")

In [133]:
zmax = torch.max(QK_circuit_pos).item()
zmin = torch.min(QK_circuit_pos).item()

In [142]:
i =5
plot.imshow_div(QK_circuit[i], show=True, yaxis="Q", xaxis="K", zmax=zmax, zmin=zmin)
plot.imshow_div(QK_circuit_pos[i], show=True,yaxis="Q pos", xaxis="K pos", zmax=zmax, zmin=zmin)
plot.imshow_div(QK_circuit_pos_emb[i], show=True,yaxis="Q pos", xaxis="K", zmax=zmax, zmin=zmin)
plot.imshow_div(QK_circuit_emb_pos[i], show=True,yaxis="Q", xaxis="K pos", zmax=zmax, zmin=zmin)

In [149]:
sm_qk = einops.einsum(QK_circuit, "heads ... -> ...")
sm_qk_pos = einops.einsum(QK_circuit_pos, "heads ... -> ...")
sm_qk_pos_emb = einops.einsum(QK_circuit_pos_emb, "heads ... -> ...")
sm_qk_emb_pos = einops.einsum(QK_circuit_emb_pos, "heads ... -> ...")
zmax = max([torch.max(t).item() for t in [sm_qk,sm_qk_pos,sm_qk_pos_emb,sm_qk_emb_pos]])
zmin= min([torch.min(t).item() for t in [sm_qk,sm_qk_pos,sm_qk_pos_emb,sm_qk_emb_pos]])
plot.imshow_div(sm_qk, show=True, yaxis="Q", xaxis="K", zmax=zmax, zmin=zmin)
plot.imshow_div(sm_qk_pos, show=True,yaxis="Q pos", xaxis="K pos", zmax=zmax, zmin=zmin)
plot.imshow_div(sm_qk_pos_emb, show=True,yaxis="Q pos", xaxis="K", zmax=zmax, zmin=zmin)
plot.imshow_div( sm_qk_emb_pos, show=True,yaxis="Q", xaxis="K pos", zmax=zmax, zmin=zmin)

In [92]:
plot.imshow(QK_circuit_pos[0], show=True)