In [1]:
import torch
import json
from utils import ToyTransformer, ModelConfig

# we want to load in the config & weights 
model_name = "attention_only_1L"

# load in the config
with open(f'{model_name}_config.json', 'r') as f:
    config = json.load(f)
# turn config into a dot dict
config = ModelConfig(**config)
# load in the weights
model = ToyTransformer(config)
state_dict = torch.load(f'toy_transformer_{model_name}.pt')
model.load_state_dict(state_dict)

# we want to load in the data

<All keys matched successfully>

In [2]:
model

ToyTransformer(
  (embed): Embedding(10000, 512)
  (dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0): QuadraticAttention(
      (rotary): Rotary()
      (norm): RMSNorm((64,), eps=None, elementwise_affine=True)
      (mask): Mask()
      (q): Linear(in_features=512, out_features=512, bias=True)
      (k): Linear(in_features=512, out_features=512, bias=True)
      (v): Linear(in_features=512, out_features=512, bias=True)
      (o): Linear(in_features=512, out_features=512, bias=False)
    )
  )
  (head): Linear(in_features=512, out_features=10000, bias=False)
)

In [None]:
E = model.embed.weight.data
Q = model.layers[0].q.weight.data
K = model.layers[0].k.weight.data
V = model.layers[0].v.weight.data
O = model.layers[0].o.weight.data
U = model.head.weight.data

print(f"Shape of E: {E.shape}")
print(f"Shape of Q: {Q.shape}")
print(f"Shape of K: {K.shape}")
print(f"Shape of V: {V.shape}")
print(f"Shape of O: {O.shape}")

attn_head_idx = 0
d_head = model.layers[0].d_head

q = Q[attn_head_idx * d_head:(attn_head_idx + 1) * d_head]
k = K[attn_head_idx * d_head:(attn_head_idx + 1) * d_head]
v = V[attn_head_idx * d_head:(attn_head_idx + 1) * d_head]
o = O[attn_head_idx * d_head:(attn_head_idx + 1) * d_head]

qk = q.T @ k
ov = v.T @ o

qk_circuit = E @ qk @ E.T # Doesn't work in bilinear
ov_circuit = U @ ov @ E.T

Shape of E: torch.Size([10000, 512])
Shape of Q: torch.Size([512, 512])
Shape of K: torch.Size([512, 512])
Shape of V: torch.Size([512, 512])
Shape of O: torch.Size([512, 512])


In [51]:
ov_circuit = U @ ov.T @ E.T

In [52]:
# Find the rank of each token w/ itself
import numpy as np
rank_row = []
rank_col = []
row_is_higher = []
for i in range(1000):
    self_value = ov_circuit[i, i]
    row = ov_circuit[i]
    col = ov_circuit.T[i]
    rank_in_row = (row >= self_value).sum()
    rank_in_col = (col >= self_value).sum()
    rank_row.append(rank_in_row)
    rank_col.append(rank_in_col)
    row_is_higher.append(rank_in_row < rank_in_col)

print(rank_row)
print(rank_col)

print(f"Mean rank in row: {np.median(rank_row)}")
print(f"median rank in col: {np.median(rank_col)}")

print(f"Row is higher: {np.mean(row_is_higher)}")

[tensor(943), tensor(191), tensor(9), tensor(159), tensor(139), tensor(1548), tensor(9783), tensor(8607), tensor(9746), tensor(8972), tensor(9233), tensor(9774), tensor(7577), tensor(1914), tensor(208), tensor(8124), tensor(6896), tensor(34), tensor(3), tensor(2), tensor(47), tensor(2690), tensor(2088), tensor(9068), tensor(9948), tensor(5866), tensor(9762), tensor(2113), tensor(8422), tensor(1611), tensor(471), tensor(5), tensor(303), tensor(7150), tensor(4525), tensor(3417), tensor(3540), tensor(47), tensor(9673), tensor(9953), tensor(204), tensor(2901), tensor(8529), tensor(9930), tensor(7972), tensor(9861), tensor(9253), tensor(1295), tensor(9981), tensor(8434), tensor(163), tensor(1438), tensor(1186), tensor(7081), tensor(9959), tensor(8), tensor(9954), tensor(9735), tensor(1146), tensor(2599), tensor(6896), tensor(56), tensor(1704), tensor(9926), tensor(9955), tensor(9997), tensor(9718), tensor(9779), tensor(2001), tensor(130), tensor(8476), tensor(9992), tensor(9958), tensor(897

In [41]:
from tokenization.tokenization import  tokenizer
destination_token = 10
val, ind = ov_circuit[destination_token].topk(10)
print("Destination token: ", tokenizer.decode(destination_token))
for v, i in zip(val, ind):
    print(f"{v:.4f} | {tokenizer.decode(i.item())}")

val, ind = ov_circuit.T[destination_token].topk(10)
print("Destination token: ", tokenizer.decode(destination_token))
for v, i in zip(val, ind):
    print(f"{v:.4f} | {tokenizer.decode(i.item())}")


Destination token:   The
9.6782 |  agreed
9.3309 |  their
8.9546 |  voice
8.2828 |  bark
8.2175 |  do
8.1109 |  home
7.8079 |  idea
7.4755 |  smell
7.4372 |  little
7.2488 |  talk
Destination token:   The
8.5323 |  were
8.0428 |  looked
7.4233 |  look
6.9625 |  build
6.8243 |  threw
6.6537 |  throw
6.5878 |  looking
6.4921 |  all
6.0935 |  up
6.0557 |  bit


In [42]:
from tokenization.tokenization import  tokenizer
destination_token = 2
val, ind = ov_circuit[destination_token].topk(10)
print("Destination token: ", tokenizer.decode(destination_token))
for v, i in zip(val, ind):
    print(f"{v:.4f} | {tokenizer.decode(i.item())}")

val, ind = ov_circuit.T[destination_token].topk(10)
print("Destination token: ", tokenizer.decode(destination_token))
for v, i in zip(val, ind):
    print(f"{v:.4f} | {tokenizer.decode(i.item())}")

Destination token:   the
10.6877 | 's
10.4061 |  its
9.9053 |  his
9.8758 |  her
9.6531 |  agreed
9.3208 |  twin
8.9878 |  bark
8.9015 |  tried
8.8043 |  pur
8.2855 |  their
Destination token:   the
10.1610 | 's
10.0965 |  his
9.6766 |  your
9.2617 |  her
8.9085 |  His
8.5088 |  up
8.2135 |  Her
7.6320 |  my
7.2988 |  the
6.9803 |  their


In [38]:
from tokenization.tokenization import  tokenizer
destination_token = 5
bil_qk_circuit = qk_circuit.square()
val, ind = bil_qk_circuit.T[destination_token].topk(10)
# val, ind = bil_qk_circuit[destination_token].topk(10)
print("Destination token: ", tokenizer.decode(destination_token))
for v, i in zip(val, ind):
    o_val, o_ind = ov_circuit[i].topk(10)
    # o_val, o_ind = ov_circuit.T[i].topk(10)
    text_print = f"{v:.4f} | {tokenizer.decode(i.item())}"
    for o_v, o_i in zip(o_val, o_ind):
        text_print += f" {o_v:.4f} | {tokenizer.decode(o_i.item())}"
    print(text_print)

Destination token:   to
69.0686 | 's 14.4123 |  look 11.5252 |  mouth 10.4861 |  low 10.4163 |  back 10.3064 |  by 10.2232 |  smell 10.1610 |  the 9.9262 |  wing 9.8162 |  stuck 9.5438 |  engine
38.3394 |  other 10.5961 |  Every 9.3460 |  called 8.9469 |  enough 8.6660 |  asked 8.6202 | When 8.5336 | s 8.3339 |  allowed 7.1777 | As 7.1692 |  allow 6.9283 |  named
35.6107 |  all 10.1174 |  because 10.0460 | ed 9.7023 |  let 8.8875 |  never 8.8856 |  enjoyed 8.7785 |  spent 8.7045 |  all 8.4376 |  scared 8.3462 |  spend 8.2874 |  spring
34.4136 |  dad 6.8226 |  frightened 6.5416 |  world 6.0228 |  scared 6.0078 |  from 5.7346 |  nervous 5.5589 | um 5.4191 |  town 5.3892 |  but 5.1584 |  meeting 5.1267 |  kitchen
33.3804 |  back 6.6792 |  world 6.4928 | His 6.2387 |  sunglasses 5.9841 |  alone 5.7788 |  wrong 5.7467 |  storm 5.7402 |  from 5.7359 | OK 5.5455 | Mr 5.4005 |  Mr
33.1870 |  little 9.4671 | [BEGIN] 8.3228 |  first 8.1457 |  favorite 8.1279 |  best 7.2677 |  They 7.0401 |  borr

In [None]:
token = [0]
text = tokenizer.decode(token)
text

'.'

In [23]:
from tinymodel import TinyModel, tokenizer

lm = TinyModel()

# for inference
tok_ids, padding_mask = tokenizer(['Once upon a time', 'In the forest'])
logprobs = lm(tok_ids)

# Get SAE/transcoder acts
# See 'SAEs/Transcoders' section for more information.
feature_acts = lm['M1N123'](tok_ids)
all_feat_acts = lm['M2'](tok_ids)

# Generation
lm.generate('Once upon a time, Ada was happily walking through a magical forest with')

# To decode tok_ids you can use
tokenizer.decode(tok_ids)

ImportError: cannot import name 'TinyModel' from 'tinymodel' (/venv/main/lib/python3.10/site-packages/tinymodel/__init__.py)

In [None]:
from tokenization.tokenization import tokenizer

10000