## Init

In [1]:
import torch
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from tabulate import tabulate
from tqdm import tqdm, trange
from copy import deepcopy
import numpy as np
from collections import Counter

## Helper Functions

In [2]:
def convert_to_tokens(indices, tokenizer, extended, extra_values_pos=None, strip=True):
    """Map token ids to readable token strings.

    Args:
        indices: iterable of ids (Python ints, or a 1-D integer tensor / array).
        tokenizer: object providing ``convert_ids_to_tokens`` (and ``__len__`` when ``extended``).
        extended: if True, ids at or beyond ``len(tokenizer)`` are rendered as placeholder
            labels instead of being looked up: ids in ``[len(tokenizer), extra_values_pos)``
            become ``"[pos<i>]"`` and ids ``>= extra_values_pos`` become ``"[val<i>]"``,
            with ``i`` counted from the start of each range.
        extra_values_pos: first id of the "extra values" range. Defaults to
            ``len(tokenizer)``, i.e. every out-of-vocabulary id is labelled ``[val<i>]``.
        strip: if True, tokens starting with 'Ġ' (presumably GPT-2's leading-space marker)
            lose that character and every other token is prefixed with '#'.

    Returns:
        list of str, one entry per id.
    """
    # Plain ints: tensor elements would otherwise be formatted as "tensor(3)" inside the labels.
    ids = [int(idx) for idx in indices]
    if extended:
        vocab_size = len(tokenizer)
        if extra_values_pos is None:
            extra_values_pos = vocab_size
        res = []
        for idx in ids:
            if idx < vocab_size:
                res.append(tokenizer.convert_ids_to_tokens([idx])[0])
            elif idx < extra_values_pos:
                res.append(f"[pos{idx - vocab_size}]")
            else:
                res.append(f"[val{idx - extra_values_pos}]")
    else:
        res = tokenizer.convert_ids_to_tokens(ids)
    if strip:
        # startswith instead of tok[0] so an empty token string does not raise IndexError
        res = [tok[1:] if tok.startswith('Ġ') else "#" + tok for tok in res]
    return res


def top_tokens(v, k=100, tokenizer=None, only_english=False, only_ascii=True, with_values=False, 
               exclude_brackets=False, extended=True, extra_values=None):
    """Return the highest-scoring tokens of a score vector indexed by token id.

    Args:
        v: 1-D score tensor, position ``i`` holding the score of token id ``i``. Not modified.
        k: number of tokens to return (clamped to the number of available scores).
        tokenizer: tokenizer used for filtering and decoding; defaults to the global ``my_tokenizer``.
        only_english: ignore tokens that are not ASCII alphanumeric once 'Ġ' and
            surrounding brackets are stripped; replaces the ``only_ascii`` filter.
        only_ascii: ignore tokens whose text (without 'Ġ') is not pure ASCII.
        with_values: if True, return ``(token, score)`` pairs instead of bare tokens.
        exclude_brackets: additionally ignore tokens containing '[' or ']'.
        extended: forwarded to ``convert_to_tokens``.
        extra_values: optional 1-D tensor appended after ``v``; with ``extended`` its
            entries are labelled ``[val<i>]``.

    Returns:
        list of token strings, or of ``(token, score)`` tuples when ``with_values``.
    """
    if tokenizer is None:
        tokenizer = my_tokenizer
    vocab = tokenizer.vocab  # built once; it was rebuilt for every filter before
    ignored = set()
    if only_ascii:
        ignored = {key for val, key in vocab.items() if not val.strip('Ġ').isascii()}
    if only_english:
        # Stricter than (and replacing) the ASCII filter; brackets around a word are tolerated.
        ignored = {key for val, key in vocab.items()
                   if not (val.strip('Ġ').isascii() and val.strip('Ġ[]').isalnum())}
    if exclude_brackets:
        # This used to *intersect* with a set that already contained every ignored id, which made
        # the flag a no-op; excluding bracket tokens has to add them to the ignored set.
        ignored |= {key for val, key in vocab.items() if '[' in val or ']' in val}
    v = deepcopy(v)  # never mutate the caller's tensor
    if ignored:
        v[sorted(ignored)] = -np.inf
    extra_values_pos = len(v)
    if extra_values is not None:
        v = torch.cat([v, extra_values])
    values, indices = torch.topk(v, k=min(k, len(v)))
    # Plain ints so out-of-vocabulary labels are not rendered as tensor reprs.
    res = convert_to_tokens(indices.tolist(), tokenizer, extended=extended, extra_values_pos=extra_values_pos)
    if with_values:
        res = list(zip(res, values.detach().cpu().numpy()))
    return res


def top_matrix_tokens(mat, k=100, tokenizer=None, rel_thresh=None, thresh=None, 
                      sample_entries=10000, alphabetical=True, only_english=False,
                      exclude_brackets=False, with_values=True, extended=True):
    """List token pairs for selected entries of a square token-by-token score matrix.

    Args:
        mat: 2-D score tensor whose row and column indices are token ids. Not modified.
        k: not used by this function (kept for signature compatibility).
        tokenizer: tokenizer used for filtering and decoding; defaults to the global ``my_tokenizer``.
        rel_thresh: if truthy, keep entries greater than ``max(mat) * rel_thresh``.
        thresh: if truthy, keep entries greater than ``thresh``.
        sample_entries: if truthy and more entries survive the thresholds, keep a random
            subset of this size (drawn without replacement with NumPy's global RNG).
        alphabetical: if True, order by row token id (not literally alphabetically);
            otherwise by descending score.
        only_english: ignore tokens that are not ASCII alphanumeric once 'Ġ' and
            surrounding brackets are stripped (same rule as ``top_tokens``).
        exclude_brackets: additionally ignore tokens containing '[' or ']'.
        with_values: if True, return ``(row_token, col_token, score)`` tuples,
            otherwise ``[row_token, col_token]`` lists.
        extended: forwarded to ``convert_to_tokens``.

    Returns:
        list with one item per kept entry, in the order described above.
    """
    if tokenizer is None:
        tokenizer = my_tokenizer
    vocab = tokenizer.vocab
    ignored = set()
    if only_english:
        # Strip 'Ġ' before the ASCII test, as top_tokens does; otherwise every space-prefixed
        # word counts as non-ASCII and is ignored.
        ignored = {key for val, key in vocab.items()
                   if not (val.strip('Ġ').isascii() and val.strip('Ġ[]').isalnum())}
    if exclude_brackets:
        # Must add to the ignored set; the former intersection made this flag a no-op.
        ignored |= {key for val, key in vocab.items() if '[' in val or ']' in val}

    cond = torch.ones_like(mat, dtype=torch.bool)
    if ignored:
        ignored_indices = sorted(ignored)
        mat = deepcopy(mat)  # mask a copy; the caller's matrix stays untouched
        mat[ignored_indices, :] = -np.inf
        mat[:, ignored_indices] = -np.inf
        # Never report masked entries, even when no threshold is given.
        cond[ignored_indices, :] = False
        cond[:, ignored_indices] = False
    if rel_thresh:
        cond &= mat > torch.max(mat) * rel_thresh
    if thresh:
        cond &= mat > thresh

    entries = torch.nonzero(cond)
    if sample_entries and len(entries) > sample_entries:
        # Without replacement: the old randint draw produced duplicates and raised
        # ValueError when no entry survived the thresholds.
        keep = np.random.permutation(len(entries))[:sample_entries]
        entries = entries[torch.as_tensor(keep, device=entries.device)]

    scores = mat[entries[:, 0], entries[:, 1]].tolist()
    triples = [(i, j, s) for (i, j), s in zip(entries.tolist(), scores)]
    # list.sort is stable, so equal row ids keep their nonzero/sampling order as before.
    triples.sort(key=(lambda t: t[0]) if alphabetical else (lambda t: -t[2]))

    # convert_to_tokens requires extra_values_pos; a plain matrix has no appended extra values.
    n_labels = max(mat.shape)
    row_toks = convert_to_tokens([t[0] for t in triples], tokenizer,
                                 extended=extended, extra_values_pos=n_labels)
    col_toks = convert_to_tokens([t[1] for t in triples], tokenizer,
                                 extended=extended, extra_values_pos=n_labels)
    if with_values:
        return [(r, c, t[2]) for r, c, t in zip(row_toks, col_toks, triples)]
    return [[r, c] for r, c in zip(row_toks, col_toks)]

## Extract Weights

In [3]:
gpt = AutoModelForCausalLM.from_pretrained("gpt2-medium")
tokenizer = my_tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
# Transposed output-embedding matrix: `hidden_vector @ emb` yields one score per vocabulary token.
emb = gpt.get_output_embeddings().weight.data.T.detach()

num_layers, num_heads, hidden_dim = gpt.config.n_layer, gpt.config.n_head, gpt.config.n_embd
head_size = hidden_dim // num_heads

# FF keys (c_fc, transposed) and FF values (c_proj), concatenated layer after layer along dim 0.
K = torch.cat([gpt.get_parameter(f"transformer.h.{layer}.mlp.c_fc.weight").T
               for layer in range(num_layers)]).detach()
V = torch.cat([gpt.get_parameter(f"transformer.h.{layer}.mlp.c_proj.weight")
               for layer in range(num_layers)]).detach()

# c_attn holds the Q, K and V projections side by side in its last dimension.
W_Q, W_K, W_V = torch.cat([gpt.get_parameter(f"transformer.h.{layer}.attn.c_attn.weight")
                           for layer in range(num_layers)]).detach().chunk(3, dim=-1)
W_O = torch.cat([gpt.get_parameter(f"transformer.h.{layer}.attn.c_proj.weight")
                 for layer in range(num_layers)]).detach()

# Per-layer views of shape (num_layers, d_int, hidden_dim).
K_heads = K.reshape(num_layers, -1, hidden_dim)
V_heads = V.reshape(num_layers, -1, hidden_dim)
d_int = K_heads.shape[1]

# Per-head views: (num_layers, num_heads, hidden_dim, head_size) for the Q/K/V projections ...
W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).transpose(1, 2)
W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).transpose(1, 2)
W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).transpose(1, 2)
# ... and (num_layers, num_heads, head_size, hidden_dim) for the output projection.
W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)

## Interpretation

### FF Keys & Values

In [4]:
i1, i2 = 21, 7  # layer, and row of K_heads / V_heads within that layer
print(i1, i2)
# Left column: top tokens of the FF key vector; right column: top tokens of the FF value vector.
print(tabulate(list(zip(top_tokens(K_heads[i1, i2] @ emb, k=200),
                        top_tokens(V_heads[i1, i2] @ emb, k=200)))))

21 7
----------------  ----------------
#problem          issues
Problems          issue
#Problem          problems
problem           problem
problems          Issues
Problem           #issues
woes              Problems
Trouble           Problem
trouble           #Problem
difficulties      trouble
#issues           Issue
dilemma           #problem
deficiencies      #issue
#blems            difficulties
drawback          troubles
shortcomings      #Issue
Issues            dilemma
defic             woes
proble            #Iss
#undrum           concerns
objection         controversy
failings          confusion
weakness          flaw
weaknesses        dispute
disagreements     Trouble
issues            disputes
objections        dile
pitfalls          proble
predicament       deficiencies
troubles          #blems
#errors           deficiency
dissatisfaction   concern
disagreement      shortcomings
worries           conflicts
disadvantages     pain
#Fail             bug
glitches          co

### Attention Weights Interpretation

#### $W_{VO}$ Interpretation

Choose the **layer** (`i1`) and **head** (`i2`) to inspect here:

In [5]:
# Layer (i1) and head (i2) to inspect; uncomment the next line for a random pick.
# i1, i2 = np.random.randint(num_layers), np.random.randint(num_heads)
i1 = 18
i2 = 2

Then project the head's OV circuit onto the vocabulary, computing $E^\top W_V W_O E$ (where $E$ is `emb`):

In [6]:
W_V_tmp = W_V_heads[i1, i2, :]
W_O_tmp = W_O_heads[i1, i2]
# Token-to-token OV scores: entry (a, b) equals emb[:, a] @ (W_V W_O) @ emb[:, b].
tmp = emb.T @ (W_V_tmp @ W_O_tmp) @ emb

**Compute interpretation**

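`th` is a lower bound on the scores, and `th_max` (next cell) is an upper bound: only entries strictly between them are kept. This way the filtering and sorting below run over a small subset instead of the entire vocabulary-by-vocabulary matrix.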

In [7]:
# Only entries of `tmp` strictly between th and th_max are considered below.
th, th_max = 25, np.inf

Check how many entries fall between `th` and `th_max`. Keep this number small to limit the filtering and sorting work below, but no smaller than the number of pairs displayed at the end (50).

In [8]:
((tmp > th) & (tmp < th_max)).nonzero().shape

torch.Size([45, 2])

In [9]:
# Optional filters: drop pairs of identical tokens / fuzzily equal tokens.
exclude_same = exclude_fuzzy = False

In [10]:
# reverse_list: list lowest scores first; only_ascii: keep only pairs of ASCII tokens.
reverse_list, only_ascii = False, True

In [11]:
# Candidate (row, col) token-id pairs of `tmp` whose score lies strictly between th and th_max.
remaining_pos = torch.nonzero((tmp > th) & (tmp < th_max)).tolist()
if only_ascii:
    # Drop pairs in which either decoded token contains non-ASCII characters.
    remaining_pos = [*filter(
        lambda x: (tokenizer.decode(x[0]).strip('Ġ').isascii() and tokenizer.decode(x[1]).strip('Ġ').isascii()), 
        remaining_pos)]
if exclude_same:
    # Drop pairs whose tokens are equal up to case and surrounding whitespace.
    remaining_pos = [*filter(
        lambda x: tokenizer.decode(x[0]).lower().strip() != tokenizer.decode(x[1]).lower().strip(), 
        remaining_pos)]
if exclude_fuzzy:
    # NOTE(review): `_fuzzy_eq` is not defined anywhere in this notebook, so this branch raises
    # NameError; define it before setting `exclude_fuzzy = True`.
    remaining_pos = [*filter(
        lambda x: not _fuzzy_eq(tokenizer.decode(x[0]).lower().strip(), tokenizer.decode(x[1]).lower().strip()), 
        remaining_pos)]

# Gather scores with an explicit (rows, cols) tuple of index tensors. `tmp[[*zip(...)]]` relied on the
# legacy rule that a list of sequences is read as a tuple index (NumPy has since dropped that rule).
pos_val = tmp[tuple(torch.tensor(remaining_pos, dtype=torch.long, device=tmp.device).reshape(-1, 2).T)]
good_cells = [*map(lambda x: (tokenizer.decode(x[0]), tokenizer.decode(x[1])), remaining_pos)]
# For each side of the pair, how often each decoded token occurs among the kept pairs.
good_tokens = list(map(lambda x: Counter(x).most_common(), zip(*good_cells)))
# The 50 highest-scoring pairs (lowest-scoring if reverse_list). The order is moved to the CPU so it
# can index the NumPy array even when `tmp` lives on a GPU.
remaining_pos_best = np.array(remaining_pos, dtype=np.int64).reshape(-1, 2)[
    torch.argsort(pos_val if reverse_list else -pos_val)[:50].cpu().numpy()]
good_cells_best = [*map(lambda x: (tokenizer.decode(x[0]), tokenizer.decode(x[1])), remaining_pos_best)]
# good_cells[:100]
# list(zip(good_tokens[0], good_tokens[1]))

In [12]:
# Up to 50 (row token, column token) pairs kept above, highest score first unless reverse_list is set.
good_cells_best

[(' herself', ' her'),
 (' herself', ' herself'),
 (' herself', ' she'),
 (' herself', ' hers'),
 (' herself', 'She'),
 (' herself', 'she'),
 (' herself', 'Her'),
 (' himself', 'his'),
 (' herself', ' She'),
 (' itself', 'Its'),
 (' herself', ' SHE'),
 (' himself', 'His'),
 (' himself', ' his'),
 (' themselves', 'their'),
 (' herself', ' HER'),
 (' herself', 'her'),
 (' herself', ' Her'),
 (' themselves', 'Their'),
 (' itself', ' Its'),
 (' himself', ' he'),
 (' themselves', ' THEIR'),
 (' himself', ' himself'),
 (' himself', 'He'),
 (' themselves', ' Their'),
 (' himself', 'him'),
 (' itself', ' its'),
 (' themselves', 'They'),
 (' himself', ' HIS'),
 (' Himself', 'His'),
 (' Himself', 'his'),
 (' themselves', 'they'),
 (' themselves', ' their'),
 (' yourselves', ' yourselves'),
 ('uel', 'Its'),
 ('his', 'his'),
 (' himself', ' him'),
 (' themselves', ' themselves'),
 (' themselves', ' theirs'),
 (' himself', ' His'),
 (' themselves', ' They'),
 ('MH', 'Its'),
 (' ITS', 'Its'),
 ('his

#### $W_{QK}$ Interpretation

Choose the **layer** (`i1`) and **head** (`i2`) to inspect here:

In [13]:
# Layer (i1) and head (i2) to inspect; uncomment the next line for a random pick.
# i1, i2 = np.random.randint(num_layers), np.random.randint(num_heads)
i1 = 20
i2 = 3

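Then project the head's QK circuit onto the vocabulary, computing $E^\top W_Q W_K^\top E$ (where $E$ is `emb`):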

In [14]:
W_Q_tmp = W_Q_heads[i1, i2, :]
W_K_tmp = W_K_heads[i1, i2, :]
# Token-to-token QK scores: entry (a, b) equals emb[:, a] @ (W_Q W_K^T) @ emb[:, b].
tmp2 = emb.T @ (W_Q_tmp @ W_K_tmp.T) @ emb

**Compute interpretation**

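As above, `th2` (lower bound) and `th_max2` (upper bound) restrict the analysis to entries strictly between them, so only a small subset of the matrix has to be filtered and sorted.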

In [17]:
# Only entries of `tmp2` strictly between th2 and th_max2 are considered below.
th2, th_max2 = 2, np.inf

In [18]:
((tmp2 > th2) & (tmp2 < th_max2)).nonzero().shape

torch.Size([46087, 2])

In [19]:
# Optional filters: drop pairs of identical tokens / fuzzily equal tokens.
exclude_same = exclude_fuzzy = False

In [20]:
# reverse_list: list lowest scores first; only_ascii: keep only pairs of ASCII tokens.
reverse_list, only_ascii = False, True

In [21]:
# Candidate (row, col) token-id pairs of `tmp2` whose score lies strictly between th2 and th_max2.
remaining_pos = torch.nonzero((tmp2 > th2) & (tmp2 < th_max2)).tolist()
if only_ascii:
    # Drop pairs in which either decoded token contains non-ASCII characters.
    remaining_pos = [*filter(
        lambda x: (tokenizer.decode(x[0]).strip('Ġ').isascii() and tokenizer.decode(x[1]).strip('Ġ').isascii()), 
        remaining_pos)]
if exclude_same:
    # Drop pairs whose tokens are equal up to case and surrounding whitespace.
    remaining_pos = [*filter(
        lambda x: tokenizer.decode(x[0]).lower().strip() != tokenizer.decode(x[1]).lower().strip(), 
        remaining_pos)]
if exclude_fuzzy:
    # NOTE(review): `_fuzzy_eq` is not defined anywhere in this notebook, so this branch raises
    # NameError; define it before setting `exclude_fuzzy = True`.
    remaining_pos = [*filter(
        lambda x: not _fuzzy_eq(tokenizer.decode(x[0]).lower().strip(), tokenizer.decode(x[1]).lower().strip()), 
        remaining_pos)]

# Gather scores with an explicit (rows, cols) tuple of index tensors. `tmp2[[*zip(...)]]` relied on the
# legacy rule that a list of sequences is read as a tuple index (NumPy has since dropped that rule).
pos_val = tmp2[tuple(torch.tensor(remaining_pos, dtype=torch.long, device=tmp2.device).reshape(-1, 2).T)]
good_cells = [*map(lambda x: (tokenizer.decode(x[0]), tokenizer.decode(x[1])), remaining_pos)]
# For each side of the pair, how often each decoded token occurs among the kept pairs.
good_tokens = list(map(lambda x: Counter(x).most_common(), zip(*good_cells)))
# The 50 highest-scoring pairs (lowest-scoring if reverse_list). The order is moved to the CPU so it
# can index the NumPy array even when `tmp2` lives on a GPU.
remaining_pos_best = np.array(remaining_pos, dtype=np.int64).reshape(-1, 2)[
    torch.argsort(pos_val if reverse_list else -pos_val)[:50].cpu().numpy()]
good_cells_best = [*map(lambda x: (tokenizer.decode(x[0]), tokenizer.decode(x[1])), remaining_pos_best)]

In [22]:
# Up to 50 (query-side token, key-side token) pairs kept above, highest score first unless reverse_list is set.
good_cells_best

[('..."', '..."'),
 (' [', ' ['),
 ('...', '..."'),
 ('..."', '..."'),
 ('..."', ' ['),
 (' [', '..."'),
 ('...', ' ['),
 ('..."', ' ['),
 (' [...]', ' ['),
 ('...', ' [...]'),
 ('..."', ' [...]'),
 (' [', '[/'),
 ('...', ')",'),
 (' [', '%"'),
 (' [', ' [...]'),
 (' [', ')."'),
 (' [', "['"),
 ('...', '%"'),
 (' [', " ['"),
 (' [...]', '..."'),
 ('...', '..."'),
 ('."', '..."'),
 ('...', '..."'),
 ('."', ' ['),
 ('."', '..."'),
 (' [', '..."'),
 ('...', '[/'),
 ('..."', '."'),
 (' [', '\'"'),
 (' [', '!"'),
 ('..."', ' [...]'),
 ('!"', ' ['),
 (' [', ')",'),
 (' [...]', ' [...]'),
 ('...', ')",'),
 (' [', '."['),
 (' [', '.",'),
 (' [', '";'),
 ('ind', 'iph'),
 ('...', '\'"'),
 (' [', '":'),
 (',"', ' ['),
 (' [', ',"'),
 (' [', '),"'),
 (' [', '\',"'),
 ('..."', ')."'),
 ('..."', '..."'),
 ('..."', '..."'),
 ('...', '";'),
 ('...', ' </')]