## Init

In [1]:
import torch
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from tabulate import tabulate
from tqdm import tqdm, trange
from copy import deepcopy
import numpy as np
from collections import Counter

## Helper Functions

In [2]:
ALNUM_CHARSET = set('aąbcćdeęfghijklmnńoópqrsśtuvwxyzźżAĄBCĆDEĘFGHIJKLMNŃOÓPQRSŚTUVWXYZŹŻ0123456789')

def convert_to_tokens(indices, tokenizer, extended=False, extra_values_pos=None, strip=True):
    if extended:
        res = [tokenizer.convert_ids_to_tokens([idx])[0] if idx < len(tokenizer) else 
               (f"[pos{idx-len(tokenizer)}]" if idx < extra_values_pos else f"[val{idx-extra_values_pos}]") 
               for idx in indices]
    else:
        res = tokenizer.convert_ids_to_tokens(indices)
    if strip:
        res = list(map(lambda x: x[1:] if x[0] == 'Ġ' else "#" + x, res))
    return res


def top_tokens(v, k=100, tokenizer=None, only_alnum=False, only_ascii=True, with_values=False, 
               exclude_brackets=False, extended=True, extra_values=None, only_from_list=None):
    if tokenizer is None:
        tokenizer = my_tokenizer
    v = deepcopy(v)
    ignored_indices = []
    if only_ascii:
        ignored_indices.extend([key for val, key in tokenizer.vocab.items() if not val.strip('Ġ▁').isascii()])
    if only_alnum: 
        ignored_indices.extend([key for val, key in tokenizer.vocab.items() if not (set(val.strip('Ġ▁[] ')) <= ALNUM_CHARSET)])
    if only_from_list:
        ignored_indices.extend([key for val, key in tokenizer.vocab.items() if val.strip('Ġ▁ ').lower() not in only_from_list])
    if exclude_brackets:
        ignored_indices = set(ignored_indices).intersection(
            {key for val, key in tokenizer.vocab.items() if not (val.isascii() and val.isalnum())})
        ignored_indices = list(ignored_indices)
        
    ignored_indices = list(set(ignored_indices))
    v[ignored_indices] = -np.inf
    extra_values_pos = len(v)
    if extra_values is not None:
        v = torch.cat([v, extra_values])
    values, indices = torch.topk(v, k=k)
    res = convert_to_tokens(indices, tokenizer, extended=extended, extra_values_pos=extra_values_pos)
    if with_values:
        res = list(zip(res, values.cpu().numpy()))
    return res

## Extract Weights

In [3]:
model = AutoModel.from_pretrained("prajjwal1/bert-medium")
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-medium")

# Extract embedding weights
emb = model.get_input_embeddings().weight.data.T.detach()

# Get model configuration
num_layers = model.config.num_hidden_layers
num_heads = model.config.num_attention_heads
hidden_dim = model.config.hidden_size
head_size = hidden_dim // num_heads

# Extract key and value projection matrices from the feed-forward layers
K = torch.cat([model.get_parameter(f"encoder.layer.{j}.intermediate.dense.weight").T
               for j in range(num_layers)]).detach()
V = torch.cat([model.get_parameter(f"encoder.layer.{j}.output.dense.weight")
               for j in range(num_layers)]).detach()

# Extract self-attention parameters
W_Q = torch.cat([model.get_parameter(f"encoder.layer.{j}.attention.self.query.weight")
                 for j in range(num_layers)]).detach()
W_K = torch.cat([model.get_parameter(f"encoder.layer.{j}.attention.self.key.weight")
                 for j in range(num_layers)]).detach()
W_V = torch.cat([model.get_parameter(f"encoder.layer.{j}.attention.self.value.weight")
                 for j in range(num_layers)]).detach()
W_O = torch.cat([model.get_parameter(f"encoder.layer.{j}.attention.output.dense.weight")
                 for j in range(num_layers)]).detach()


In [4]:
K_heads = K.reshape(num_layers, -1, hidden_dim)
V_heads = V.reshape(num_layers, -1, hidden_dim)
d_int = K_heads.shape[1]

W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)

In [5]:
emb_inv = emb.T

## Interpretation

#### Alternative I: No Token List

In [6]:
tokens_list = set()

#### Alternative II: Can Load Token List from IMDB

In [7]:
from datasets import load_dataset

In [8]:
imdb = load_dataset(path = 'json', data_files = '../task_1/data/full_text_classification.jsonl')['train']['text']
# imdb = load_dataset('imdb')['train']['text'][:10]

In [9]:
max_tokens_num = None

In [10]:
if max_tokens_num is None:
    tokens_list = set()
    for txt in tqdm(imdb):
        tokens_list = tokens_list.union(set(tokenizer.tokenize(txt)))
else:
    tokens_list = Counter()
    for txt in tqdm(imdb):
        tokens_list.update(set(tokenizer.tokenize(txt)))
    tokens_list = map(lambda x: x[0], tokens_list.most_common(max_tokens_num))
    

100%|██████████| 4441/4441 [00:00<00:00, 6235.98it/s]


In [11]:
tokens_list = set([*map(lambda x: x.strip('Ġ▁').lower(), tokens_list)])

### FF Keys & Values

In [12]:
# i1, i2 = 11, 907
# # i1, i2 = np.random.randint(num_layers), np.random.randint(d_int)

# print(i1, i2)
# print(tabulate([*zip(
#     top_tokens((K_heads[i1, i2]) @ emb, k=30, only_from_list=tokens_list, only_alnum=False),
#     top_tokens((V_heads[i1, i2]) @ emb, k=30, only_from_list=tokens_list, only_alnum=False),
#     # top_tokens((-K_heads[i1, i2]) @ emb, k=200, only_from_list=tokens_list),
#     # top_tokens((-V_heads[i1, i2]) @ emb, k=200, only_from_list=tokens_list),
# )], headers=['K', 'V', '-K', '-V']))

### Attention Weights Interpretation

In [13]:
def approx_topk(mat, min_k=500, max_k=250_000, th0=10, max_iters=10, verbose=False):
    _get_actual_k = lambda th, th_max: torch.nonzero((mat > th) & (mat < th_max)).shape[0]
    th_max = np.inf
    left, right = 0, th0 
    while True:
        actual_k = _get_actual_k(right, th_max)
        if verbose:
            print(f"one more iteration. {actual_k}")
        if actual_k <= max_k:
            break
        left, right = right, right * 2
    if min_k <= actual_k <= max_k:
        th = right
    else:
        for _ in range(max_iters):
            mid = (left + right) / 2
            actual_k = _get_actual_k(mid, th_max)
            if verbose:
                print(f"one more iteration. {actual_k}")
            if min_k <= actual_k <= max_k:
                break
            if actual_k > max_k:
                left = mid
            else:
                right = mid
        th = mid
    return torch.nonzero((mat > th) & (mat < th_max)).tolist()

# def get_top_entries(tmp, all_high_pos, only_ascii=False, only_alnum=False, exclude_same=False, exclude_fuzzy=False, tokens_list=None):
#     remaining_pos = all_high_pos
#     if only_ascii:
#         remaining_pos = [*filter(
#             lambda x: (tokenizer.decode(x[0]).strip('Ġ▁').isascii() and tokenizer.decode(x[1]).strip('Ġ▁').isascii()), 
#             remaining_pos)]
#     if only_alnum:
#         remaining_pos = [*filter(
#             lambda x: (tokenizer.decode(x[0]).strip('Ġ▁ ').isalnum() and tokenizer.decode(x[1]).strip('Ġ▁ ').isalnum()), 
#             remaining_pos)]
#     if exclude_same:
#         remaining_pos = [*filter(
#             lambda x: tokenizer.decode(x[0]).lower().strip() != tokenizer.decode(x[1]).lower().strip(), 
#             remaining_pos)]
#     if exclude_fuzzy:
#         remaining_pos = [*filter(
#             lambda x: not _fuzzy_eq(tokenizer.decode(x[0]).lower().strip(), tokenizer.decode(x[1]).lower().strip()), 
#             remaining_pos)]
#     if tokens_list:
#         remaining_pos = [*filter(
#             lambda x: ((tokenizer.decode(x[0]).strip('Ġ▁').lower().strip() in tokens_list) and 
#                        (tokenizer.decode(x[1]).strip('Ġ▁').lower().strip() in tokens_list)), 
#             remaining_pos)]

#     pos_val = tmp[[*zip(*remaining_pos)]]
#     good_cells = [*map(lambda x: (tokenizer.decode(x[0]), tokenizer.decode(x[1])), remaining_pos)]
#     good_tokens = list(map(lambda x: Counter(x).most_common(), zip(*good_cells)))
#     remaining_pos_best = np.array(remaining_pos)[torch.argsort(pos_val if reverse_list else -pos_val)[:50]]
#     good_cells_best = [*map(lambda x: (tokenizer.decode(x[0]), tokenizer.decode(x[1])), remaining_pos_best)]
#     # good_cells[:100]
#     # list(zip(good_tokens[0], good_tokens[1]))
#     return good_cells_best



def get_top_entries(tmp, all_high_pos, only_ascii=False, only_alnum=False, exclude_same=False, exclude_fuzzy=False, tokens_list=None):
    remaining_pos = all_high_pos

    # Filtrowanie tokenów ASCII
    if only_ascii:
        remaining_pos = [*filter(
            lambda x: (tokenizer.decode([x[0]]).strip('##').isascii() and tokenizer.decode([x[1]]).strip('##').strip('Ġ▁ ').isascii()), 
            remaining_pos)]
    
    # Filtrowanie tokenów alfanumerycznych
    if only_alnum:
        remaining_pos = [*filter(
            lambda x: (tokenizer.decode([x[0]]).strip('##').isalnum() and tokenizer.decode([x[1]]).strip('##').strip('Ġ▁ ').isalnum()), 
            remaining_pos)]
    
    # Wykluczanie identycznych tokenów
    if exclude_same:
        remaining_pos = [*filter(
            lambda x: tokenizer.decode([x[0]]).lower().strip() != tokenizer.decode([x[1]]).lower().strip(), 
            remaining_pos)]
    
    # Wykluczanie tokenów o podobnym brzmieniu
    if exclude_fuzzy:
        remaining_pos = [*filter(
            lambda x: not _fuzzy_eq(tokenizer.decode([x[0]]).lower().strip(), tokenizer.decode([x[1]]).lower().strip()), 
            remaining_pos)]
    
    # Ograniczenie do wybranej listy tokenów
    if tokens_list:
        remaining_pos = [*filter(
            lambda x: ((tokenizer.decode([x[0]]).strip('##').lower().strip() in tokens_list) and 
                       (tokenizer.decode([x[1]]).strip('##').lower().strip() in tokens_list)), 
            remaining_pos)]

    # Uzyskanie wartości z macierzy `tmp`
    pos_val = tmp[[*zip(*remaining_pos)]]
    
    # Dekodowanie najlepszych par tokenów
    good_cells = [*map(lambda x: (tokenizer.decode([x[0]]), tokenizer.decode([x[1]])), remaining_pos)]
    
    # Podsumowanie najczęściej występujących tokenów
    good_tokens = list(map(lambda x: Counter(x).most_common(), zip(*good_cells)))
    
    # Wybranie najlepszych pozycji
    remaining_pos_best = np.array(remaining_pos)[torch.argsort(pos_val if reverse_list else -pos_val)[:50]]
    good_cells_best = [*map(lambda x: (tokenizer.decode([x[0]]), tokenizer.decode([x[1]])), remaining_pos_best)]
    
    return good_cells_best


#### $W_{VO}$ Interpretation

Choose **layer** and **head** here:

In [14]:
# import json
# for i in range(8):
#     for j in range(8):
#         W_V_tmp, W_O_tmp = W_V_heads[i, j, :], W_O_heads[i, j]
#         tmp = (emb_inv @ (W_V_tmp @ W_O_tmp) @ emb)
#         all_high_pos = approx_topk(tmp, th0=1, verbose=True)
        
#         exclude_same = False
#         reverse_list = False
#         only_ascii = True
#         only_alnum = False
        
#         W_VO = get_top_entries(tmp, all_high_pos, only_ascii=only_ascii, only_alnum=only_alnum, exclude_same=exclude_same, 
#                 tokens_list=tokens_list)
#         with open("WVO_BERT/WVO.jsonl", "a") as final:
#             json.dump({'layer': i, 'head': j, 'W_VO': W_VO}, final)
#             final.write('\n')

#### $W_{QK}$ Interpretation

Choose **layer** and **head** here:

In [15]:
import json
for i in range(8):
    for j in range(8):
        W_Q_tmp, W_K_tmp = W_Q_heads[i, j, :], W_K_heads[i, j, :]
        tmp = (emb_inv @ (W_Q_tmp @ W_K_tmp.T) @ emb_inv.T)
        all_high_pos = approx_topk(tmp, th0=1, verbose=True)
        
        exclude_same = False
        reverse_list = False
        only_ascii = True
        only_alnum = True
        try:
            W_QK = get_top_entries(tmp, all_high_pos, only_ascii=only_ascii, only_alnum=only_alnum, exclude_same=exclude_same, 
                tokens_list=tokens_list)
        except:
            continue
        with open("WVO_BERT/WQK.jsonl", "a") as final:
            json.dump({'layer': i, 'head': j, 'W_QK': W_QK}, final)
            final.write('\n')

one more iteration. 0
one more iteration. 0
one more iteration. 0
one more iteration. 101
one more iteration. 525496
one more iteration. 6556
one more iteration. 0
one more iteration. 0
one more iteration. 0
one more iteration. 910
one more iteration. 0
one more iteration. 0
one more iteration. 0
one more iteration. 2
one more iteration. 65002
one more iteration. 0
one more iteration. 0
one more iteration. 0
one more iteration. 0
one more iteration. 96624
one more iteration. 0
one more iteration. 0
one more iteration. 0
one more iteration. 9
one more iteration. 203402
one more iteration. 0
one more iteration. 0
one more iteration. 0
one more iteration. 0
one more iteration. 51449
one more iteration. 0
one more iteration. 0
one more iteration. 0
one more iteration. 510
one more iteration. 0
one more iteration. 0
one more iteration. 0
one more iteration. 36
one more iteration. 329655
one more iteration. 3945
one more iteration. 0
one more iteration. 0
one more iteration. 0
one more itera

In [16]:
print(tmp)

tensor([[ 0.1008,  0.0485,  0.0314,  ...,  0.0912,  0.0708,  0.0260],
        [ 0.0852,  0.0711,  0.0352,  ...,  0.0757,  0.0629,  0.0012],
        [ 0.0564,  0.0203,  0.0391,  ...,  0.0524,  0.0491,  0.0267],
        ...,
        [ 0.1170,  0.0562,  0.0232,  ...,  0.1018,  0.0542,  0.0288],
        [ 0.0741,  0.0257,  0.0464,  ...,  0.0397,  0.0542,  0.0065],
        [ 0.0105, -0.0156,  0.0190,  ...,  0.0107,  0.0262,  0.0832]])


In [17]:
print(all_high_pos)

[[0, 16377], [0, 18341], [0, 27573], [0, 28592], [0, 29060], [0, 29707], [1, 20217], [1, 20350], [1, 20499], [1, 21556], [1, 25731], [1, 29707], [9, 8787], [9, 18163], [9, 18994], [11, 8082], [11, 16571], [11, 19025], [11, 20499], [11, 22499], [16, 29707], [17, 17499], [17, 23633], [17, 29707], [18, 24133], [18, 24991], [21, 23195], [25, 16026], [25, 21272], [25, 23924], [25, 28592], [26, 20554], [26, 24409], [27, 20771], [27, 22049], [27, 26292], [33, 8787], [33, 14204], [33, 15396], [33, 16026], [33, 18060], [33, 19531], [33, 24991], [33, 28592], [33, 29501], [33, 29707], [36, 23195], [36, 29707], [38, 5730], [38, 21556], [38, 27098], [38, 27752], [41, 24607], [42, 29707], [45, 23633], [58, 16026], [58, 20554], [58, 20917], [58, 21272], [58, 24118], [58, 24985], [58, 24991], [58, 27742], [58, 29707], [61, 29707], [65, 20917], [70, 18331], [70, 29501], [70, 29707], [71, 29707], [71, 30086], [77, 24610], [82, 16026], [82, 16236], [82, 25587], [90, 28832], [94, 24607], [96, 29707], [107

## Plots

*We thank Ohad Rubin for the idea of providing plots for better visualizations!*

In [18]:
i1, i2 = 6, 2152

In [19]:
from sklearn.manifold import TSNE
import pandas as pd
import plotly.express as px

In [20]:
def _calc_df(vector, k, coef, normalized, tokenizer):
    mat = emb
    if normalized:
        mat = F.normalize(mat, dim=-1)
    dot = vector @ mat
    sol = torch.topk(dot * coef, k=k).indices # np.argsort(dot * coef)[-k:]
    pattern = mat[:, sol].T
    scores = coef * dot[sol]
    # labels = tokenizer.batch_decode(sol)
    labels = convert_to_tokens(sol, tokenizer=tokenizer)
    X_embedded = TSNE(n_components=3,
                  learning_rate=10,
                   init='pca',
                   perplexity=3).fit_transform(pattern)

    df = pd.DataFrame(dict(x=X_embedded.T[0], y=X_embedded.T[1], z=X_embedded.T[2], label=labels, score=scores))
    return df


def plot_embedding_space(vector, is_3d=False, add_text=False, k=100, coef=1, normalized=False, tokenizer=None):
    df = _calc_df(vector, k=k, coef=coef, normalized=normalized, tokenizer=tokenizer)
    kwargs = {}
    scatter_fn = px.scatter
    if add_text:
        kwargs.update({'text': 'label'})
    if is_3d:
        scatter_fn = px.scatter_3d
        kwargs.update({'z': 'z'})
    fig = scatter_fn(
        data_frame=df, 
        x='x', 
        y='y',
        custom_data=["label", "score"],
        color="score", size_max=1, **kwargs)

    fig.update_traces(
        hovertemplate="<br>".join([
            "ColX: %{x}",
            "ColY: %{y}",
            "label: %{customdata[0]}",
            "score: %{customdata[1]}"
        ])
    )
    
    if add_text:
        fig.update_traces(textposition='middle right')
    fig.show()

In [21]:
# plot_embedding_space(K_heads[i1][i2], tokenizer=tokenizer, normalized=False)