## Init

In [23]:
import numpy as np
from tabulate import tabulate
import torch
from torch import nn
from copy import deepcopy
from transformers import (AutoModelForCausalLM, AutoTokenizer)
from utils import top_tokens, top_matrix_tokens

## Extract Parameters

In [8]:
# Load the pretrained model and its tokenizer (both names are kept for later cells).
gpt = AutoModelForCausalLM.from_pretrained("gpt2-medium")
gpt_tokenizer = tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
# Output-embedding weight, transposed; later cells right-multiply hidden vectors
# by it (`vec @ emb`) to score tokens. Presumably (hidden_dim, vocab) -- TODO confirm.
emb = gpt.get_output_embeddings().weight.data.T.detach()

num_layers = gpt.config.n_layer
num_heads = gpt.config.n_head
# Width of the hidden state. This has to be `n_embd`: `n_ctx` is the context
# length (size of the causal mask) and only coincides with the hidden width for
# some checkpoints (it does for gpt2-medium: both are 1024), so reading it here
# silently mis-shapes every reshape below on other model sizes.
hidden_dim = gpt.config.n_embd
head_size = hidden_dim // num_heads

# Feed-forward weights of every layer, concatenated along dim 0.
# c_fc is transposed so that its last dimension is hidden_dim, matching c_proj
# (see the reshape to (num_layers, -1, hidden_dim) below).
K = torch.cat([gpt.get_parameter(f"transformer.h.{j}.mlp.c_fc.weight").T
                           for j in range(num_layers)]).detach()
V = torch.cat([gpt.get_parameter(f"transformer.h.{j}.mlp.c_proj.weight")
                           for j in range(num_layers)]).detach()
# c_attn holds the fused query/key/value projection; split it into three equal
# chunks along the last dimension.
W_Q, W_K, W_V = torch.cat([gpt.get_parameter(f"transformer.h.{j}.attn.c_attn.weight")
                           for j in range(num_layers)]).detach().chunk(3, dim=-1)
W_O = torch.cat([gpt.get_parameter(f"transformer.h.{j}.attn.c_proj.weight")
                           for j in range(num_layers)]).detach()

# Per-layer view of the FF weights: (num_layers, d_int, hidden_dim).
K_heads = K.reshape(num_layers, -1, hidden_dim)
V_heads = V.reshape(num_layers, -1, hidden_dim)
d_int = K_heads.shape[1]  # number of FF rows per layer

# Per-head views of the attention weights.
# Q/K/V: (num_layers, num_heads, hidden_dim, head_size) after the permute.
# O:     (num_layers, num_heads, head_size, hidden_dim).
W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)
W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)

## Interpret Parameters

### FF

In [120]:
# Matrices used to project parameter vectors onto the vocabulary.
# E2 is currently just the transpose of the embedding matrix; a regularised
# pseudo-inverse was also tried: (E1 + 3 * torch.eye(*E1.shape)).pinverse()
E1, E2 = emb, emb.T

In [164]:
# Pick a random feed-forward row: layer index i1, row index i2 within that layer.
# The bounds come from the loaded model (num_layers, d_int) instead of the
# literals 24 / 4096, so the cell keeps sampling the full valid range when a
# different checkpoint is loaded.
i1, i2 = np.random.randint(num_layers), np.random.randint(d_int)
print(i1, i2)
# Optional preprocessing of the vector before projection; identity for now
# (mean-centering is left disabled).
f = lambda x: x # - x.mean()
key_vec = f(K_heads[i1, i2])
value_vec = f(V_heads[i1, i2])
# One column per projection: the key and value vectors and their negations,
# each passed through top_tokens (presumably the k best-scoring tokens -- see utils).
print(tabulate([*zip(
    top_tokens(key_vec @ E2.T, k=200, tokenizer=tokenizer),
    top_tokens(value_vec @ E1, k=200, tokenizer=tokenizer),
    top_tokens(-key_vec @ E2.T, k=200, tokenizer=tokenizer),
    top_tokens(-value_vec @ E1, k=200, tokenizer=tokenizer),
)], headers=['K', 'V', '-K', '-V']))

0 835
K               V                   -K                  -V
--------------  ------------------  ------------------  --------------
Views           #qus                Dum                 #odied
Hobby           #JM                 wiret               Â®
#agement        Pyr                 #milo               #False
#Own            Jav                 mustard             #ãĥķãĤ©
views           #agall              extrad              redistributed
OWN             #hra                theorem             #locking
#own            Nab                 #aundering          ventured
#Community      markers             laundering          #Economic
#Activity       #beck               Rib                 #rosis
hobby           #FUL                hypot               scarce
#Self           #azaki              swe                 poppy
Owner           #tale               #htar               #bane
#ession         drag                tuberculosis        #argo
Owners          rally               d

### Attention Parameters