## Init

In [28]:
import torch
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from tabulate import tabulate
from tqdm import tqdm, trange
from copy import deepcopy
import numpy as np
from collections import Counter

## Helper Functions

In [2]:
def convert_to_tokens(indices, tokenizer, extended, extra_values_pos, strip=True):
    if extended:
        res = [tokenizer.convert_ids_to_tokens([idx])[0] if idx < len(tokenizer) else 
               (f"[pos{idx-len(tokenizer)}]" if idx < extra_values_pos else f"[val{idx-extra_values_pos}]") 
               for idx in indices]
    else:
        res = tokenizer.convert_ids_to_tokens(indices)
    if strip:
        res = list(map(lambda x: x[1:] if x[0] == 'Ġ' else "#" + x, res))
    return res


def top_tokens(v, k=100, tokenizer=None, only_english=False, only_ascii=False, with_values=False, 
               exclude_brackets=False, extended=True, extra_values=None):
    if tokenizer is None:
        tokenizer = my_tokenizer
    v = deepcopy(v)
    ignored_indices = []
    if only_ascii:
        ignored_indices = [key for val, key in tokenizer.vocab.items() if not val.strip('Ġ').isascii()]
    if only_english: 
        ignored_indices =[key for val, key in tokenizer.vocab.items() if not (val.strip('Ġ').isascii() and val.strip('Ġ[]').isalnum())]
    if exclude_brackets:
        ignored_indices = set(ignored_indices).intersection(
            {key for val, key in tokenizer.vocab.items() if not (val.isascii() and val.isalnum())})
        ignored_indices = list(ignored_indices)
    v[ignored_indices] = -np.inf
    extra_values_pos = len(v)
    if extra_values is not None:
        v = torch.cat([v, extra_values])
    values, indices = torch.topk(v, k=k)
    res = convert_to_tokens(indices, tokenizer, extended=extended, extra_values_pos=extra_values_pos)
    if with_values:
        res = list(zip(res, values.cpu().numpy()))
    return res


def top_matrix_tokens(mat, k=100, tokenizer=None, rel_thresh=None, thresh=None, 
                      sample_entries=10000, alphabetical=False, only_english=False,
                      exclude_brackets=False, with_values=True, extended=True):
    if tokenizer is None:
        tokenizer = my_tokenizer
    mat = deepcopy(mat)
    ignored_indices = []
    if only_english:
        ignored_indices = [key for val, key in tokenizer.vocab.items() if not (val.isascii() and val.strip('[]').isalnum())]
    if exclude_brackets:
        ignored_indices = set(ignored_indices).intersection(
            {key for val, key in tokenizer.vocab.items() if not (val.isascii() and val.isalnum())})
        ignored_indices = list(ignored_indices)
    mat[ignored_indices, :] = -np.inf
    mat[:, ignored_indices] = -np.inf
    cond = torch.ones_like(mat).bool()
    if rel_thresh:
        cond &= (mat > torch.max(mat) * rel_thresh)
    if thresh:
        cond &= (mat > thresh)
    entries = torch.nonzero(cond)
    if sample_entries:
        entries = entries[np.random.randint(len(torch.nonzero(cond)), size=sample_entries)]
    res_indices = sorted(entries, 
                         key=lambda x: x[0] if alphabetical else -mat[x[0], x[1]])
    res = [*map(partial(convert_to_tokens, extended=extended, tokenizer=tokenizer), res_indices)]
            
    if with_values:
        res_ = []
        for (x1, x2), (i1, i2) in zip(res, res_indices):
            res_.append((x1, x2, mat[i1][i2].item()))
        res = res_    
    return res

## Extract Weights

In [5]:
gpt = AutoModelForCausalLM.from_pretrained("gpt2-medium")
tokenizer = my_tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
emb = gpt.get_output_embeddings().weight.data.T.detach()

num_layers = gpt.config.n_layer
num_heads = gpt.config.n_head
hidden_dim = gpt.config.n_ctx
head_size = hidden_dim // num_heads

K = torch.cat([gpt.get_parameter(f"transformer.h.{j}.mlp.c_fc.weight").T
                           for j in range(num_layers)]).detach()
V = torch.cat([gpt.get_parameter(f"transformer.h.{j}.mlp.c_proj.weight")
                           for j in range(num_layers)]).detach()

W_Q, W_K, W_V = torch.cat([gpt.get_parameter(f"transformer.h.{j}.attn.c_attn.weight") 
                           for j in range(num_layers)]).detach().chunk(3, dim=-1)
W_O = torch.cat([gpt.get_parameter(f"transformer.h.{j}.attn.c_proj.weight") 
                           for j in range(num_layers)]).detach()

K_heads = K.reshape(num_layers, -1, hidden_dim)
V_heads = V.reshape(num_layers, -1, hidden_dim)
d_int = K_heads.shape[1]

W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)
W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)

In [10]:
# to normalize the embeddings
K_freq = (F.relu((K_heads @ emb).reshape(-1, len(tokenizer))).mean(0) + 1e-9)
V_freq = (F.relu((V_heads @ emb).reshape(-1, len(tokenizer))).mean(0) + 1e-9)
W_O_freq = F.relu(abs(W_O_heads @ emb).mean([0, 1, 2]) + 1e-9)
W_V_freq = F.relu(abs(W_V_heads.transpose(-1, -2) @ emb).mean([0, 1, 2]) + 1e-9)
W_Q_freq = F.relu(abs(W_Q_heads.transpose(-1, -2) @ emb).mean([0, 1, 2]) + 1e-9)
W_K_freq = F.relu(abs(W_K_heads.transpose(-1, -2) @ emb).mean([0, 1, 2]) + 1e-9)

## Interpretation

### FF Keys & Values

In [17]:
i1, i2 = 19, 170
print(i1, i2)
print(tabulate([*zip(
    top_tokens((K_heads[i1, i2]) @ emb.float() / K_freq, k=200),
    top_tokens((V_heads[i1, i2]) @ emb.float() / V_freq, k=200),
)]))

19 170
---------------  --------------
anywhere         outdoors
everywhere       near
guidance         north
locally          HERE
bombs            south
wherever         supplies
plac             northeast
Belfast          here
instruction      southeast
refuge           east
shelter          nearby
jihad            southwest
elsewhere        Here
#jet             somewhere
bombed           indoors
roaming          northwest
in               #near
bomb             undecided
#sel             herein
bombing          #850
#opl             ultrasound
sitting          west
Bomb             anywhere
#cl              among
#acle            outside
forming          statewide
banners          elsewhere
arriving         northeastern
MB               stationary
Factory          #lifting
pra              #Near
#cil             locally
travelling       THERE
inspiration      upstream
assistance       instructors
occupying        Interstate
Select           concealed
camouflage       Somewhere
pea

### Attention Weights Interpretation

**Choose the layer and head here**:

In [42]:
i1, i2 = np.random.randint(num_layers), np.random.randint(num_heads)
i1, i2 = 23, 9

Then run

In [43]:
W_V_tmp, W_O_tmp = W_V_heads[i1, i2, :], W_O_heads[i1, i2]
tmp = ((emb / W_V_freq).T @ (W_V_tmp @ W_O_tmp) @ (emb / W_O_freq))

W_Q_tmp, W_K_tmp = W_Q_heads[i1, i2, :], W_K_heads[i1, i2, :]
tmp2 = ((emb / W_Q_freq).T @ (W_Q_tmp @ W_K_tmp.T) @ (emb / W_K_freq))

In [44]:
exclude_same = False
exclude_fuzzy = False
reverse_list = True

#### $W_{VO}$ Interpretation

`th` is the threshold to ease the work of the `torch.topk` operation. This is used to avoid computing the `topk` over the entire matrix.

In [47]:
th = 200
th_max = np.inf

Check how many relevant entries are above `th`. This is important to minimize work done by `topk`. The lower the better (as long as its less that the `k` of the `topk`)

In [48]:
torch.nonzero((tmp > th) & (tmp < th_max)).shape

torch.Size([136, 2])

In [49]:
remaining_pos = torch.nonzero((tmp > th) & (tmp < th_max)).tolist()
if exclude_same:
    remaining_pos = [*filter(
        lambda x: tokenizer.decode(x[0]).lower().strip() != tokenizer.decode(x[1]).lower().strip(), 
        remaining_pos)]
if exclude_fuzzy:
    remaining_pos = [*filter(
        lambda x: not _fuzzy_eq(tokenizer.decode(x[0]).lower().strip(), tokenizer.decode(x[1]).lower().strip()), 
        remaining_pos)]
    
pos_val = tmp2[[*zip(*remaining_pos)]]
good_cells = [*map(lambda x: (tokenizer.decode(x[0]), tokenizer.decode(x[1])), remaining_pos)]
good_tokens = list(map(lambda x: Counter(x).most_common(), zip(*good_cells)))
remaining_pos_best = np.array(remaining_pos)[torch.argsort(pos_val if reverse_list else -pos_val)[:100]]
good_cells_best = [*map(lambda x: (tokenizer.decode(x[0]), tokenizer.decode(x[1])), remaining_pos_best)]
# good_cells[:100]
# list(zip(good_tokens[0], good_tokens[1]))

In [50]:
good_cells_best

[('ortunately', ' inf'),
 ('hyde', ' De'),
 ('rored', ' Mir'),
 ('utsch', 'De'),
 ('onomous', ' aut'),
 ('KER', ' Ac'),
 ('hesion', 'Ad'),
 ('perors', 'Em'),
 (' facto', 'De'),
 ('izabeth', ' Sel'),
 ('elaide', ' Ad'),
 ('hesion', ' Ad'),
 ('obos', ' inf'),
 ('urnal', ' di'),
 ('forcer', ' en'),
 ('ograph', ' aut'),
 ('urnal', ' Di'),
 ('isbury', ' Sal'),
 ('cend', ' Des'),
 ('amous', ' inf'),
 ('ortunately', ' unf'),
 ('rollment', ' en'),
 (' Portug', ' Em'),
 ('oustic', ' Ac'),
 ('umn', ' aut'),
 ('isbury', 'Sal'),
 ('legates', ' De'),
 ('bris', ' Sem'),
 ('lishes', ' Ab'),
 ('izontal', ' Hor'),
 ('azar', ' Sal'),
 ('hematically', ' Dal'),
 ('opted', ' Ad'),
 ('escal', 'De'),
 (' instincts', ' Ab'),
 ('phasis', ' Em'),
 ('opsis', ' Syn'),
 ('leted', 'De'),
 ('duction', ' Ab'),
 ('requent', ' inf'),
 ('anoia', ' Par'),
 ('idelity', ' inf'),
 ('ideo', ' Ir'),
 ('legate', ' De'),
 ('aturated', ' Ter'),
 (' Ad', 'Ad'),
 ('mercial', ' Sy'),
 ('ificent', ' Mun'),
 ('Ad', ' Ad'),
 (' facto'

#### $W_{QK}$ Interpretation

Again, `th2` is the threshold to ease the work of the `torch.topk` operation.

In [51]:
th2 = 20
th_max2 = np.inf

Here again, we count the number of entries above `th2`

In [52]:
torch.nonzero((tmp2 > th2) & (tmp2 < th_max2)).shape

torch.Size([29769, 2])

In [53]:
remaining_pos = torch.nonzero((tmp2 > th2) & (tmp2 < th_max2)).tolist()
if exclude_same:
    remaining_pos = [*filter(
        lambda x: tokenizer.decode(x[0]).lower().strip() != tokenizer.decode(x[1]).lower().strip(), 
        remaining_pos)]
if exclude_fuzzy:
    remaining_pos = [*filter(
        lambda x: not _fuzzy_eq(tokenizer.decode(x[0]).lower().strip(), tokenizer.decode(x[1]).lower().strip()), 
        remaining_pos)]
    
pos_val = tmp2[[*zip(*remaining_pos)]]
good_cells = [*map(lambda x: (tokenizer.decode(x[0]), tokenizer.decode(x[1])), remaining_pos)]
good_tokens = list(map(lambda x: Counter(x).most_common(), zip(*good_cells)))
remaining_pos_best = np.array(remaining_pos)[torch.argsort(pos_val if reverse_list else -pos_val)[:100]]
good_cells_best = [*map(lambda x: (tokenizer.decode(x[0]), tokenizer.decode(x[1])), remaining_pos_best)]
# good_cells[:100]
# list(zip(good_tokens[0], good_tokens[1]))

In [54]:
good_cells_best

[('-', ' bicycle'),
 ('-', ' platinum'),
 ('-', ' Velocity'),
 ('-', 'estine'),
 ('-', 'acket'),
 ('-', ' grotesque'),
 ('-', ' expenditures'),
 ('-', ' Nah'),
 ('-', ' festivities'),
 ('-', ' scrutin'),
 ('-', ' =>'),
 ('-', '�'),
 ('-', ' penalties'),
 ('-', '789'),
 ('-', 'ozy'),
 ('-', ' Size'),
 ('-', 'Socket'),
 ('-', ' opening'),
 ('-', ' withdraw'),
 ('-', 'Commun'),
 ('-', 'Although'),
 ('-', ' lobster'),
 ('-', 'hp'),
 ('-', ' rhy'),
 ('-', 'Utah'),
 ('-', ' Jun'),
 ('\xad', 'transform'),
 ('-', 'Beh'),
 ('-', ' Hots'),
 ('-', ' dr'),
 ('-', 'Monitor'),
 ('-', '274'),
 ('-', ' morning'),
 ('-', ' Caller'),
 ('-', ' hunting'),
 ('-', ' config'),
 ('-', ' reunited'),
 ('-', 'neck'),
 (' and', ' interstate'),
 ('-', ' jeopardy'),
 ('-', ' 200'),
 ('-', '517'),
 ('-', ' justifies'),
 ('-', ' singles'),
 ('-', '672'),
 ('-', ' gunshot'),
 ('-', 'innacle'),
 ('-', 'On'),
 ('-', ' saves'),
 ('-', '129'),
 ('-', ' Cook'),
 ('-', ' overpowered'),
 ('-', ' Ten'),
 ('-', 'Ry'),
 ('-', '