In [4]:
import torch
# Check if CUDA is available
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA device:", device)
else:
    device = torch.device("cpu")
    print("CUDA not available, using CPU device:", device)

Using CUDA device: cuda


In [1]:
import torch
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from tabulate import tabulate
from tqdm import tqdm, trange
from copy import deepcopy
import numpy as np
from collections import Counter

In [2]:
ALNUM_CHARSET = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789')

def convert_to_tokens(indices, tokenizer, extended=False, extra_values_pos=None, strip=True):
    if extended:
        res = [tokenizer.convert_ids_to_tokens([idx])[0] if idx < len(tokenizer) else
               (f"[pos{idx-len(tokenizer)}]" if idx < extra_values_pos else f"[val{idx-extra_values_pos}]")
               for idx in indices]
    else:
        res = tokenizer.convert_ids_to_tokens(indices)
    if strip:
        res = list(map(lambda x: x[1:] if x[0] == 'Ġ' else "#" + x, res))
    return res


def top_tokens(v, k=100, tokenizer=None, only_alnum=False, only_ascii=True, with_values=False,
               exclude_brackets=False, extended=True, extra_values=None, only_from_list=None):
    if tokenizer is None:
        tokenizer = my_tokenizer
    v = deepcopy(v)
    ignored_indices = []
    if only_ascii:
        ignored_indices.extend([key for val, key in tokenizer.vocab.items() if not val.strip('Ġ▁').isascii()])
    if only_alnum:
        ignored_indices.extend([key for val, key in tokenizer.vocab.items() if not (set(val.strip('Ġ▁[] ')) <= ALNUM_CHARSET)])
    if only_from_list:
        ignored_indices.extend([key for val, key in tokenizer.vocab.items() if val.strip('Ġ▁ ').lower() not in only_from_list])
    if exclude_brackets:
        ignored_indices = set(ignored_indices).intersection(
            {key for val, key in tokenizer.vocab.items() if not (val.isascii() and val.isalnum())})
        ignored_indices = list(ignored_indices)

    ignored_indices = list(set(ignored_indices))
    v[ignored_indices] = -np.inf
    extra_values_pos = len(v)
    if extra_values is not None:
        v = torch.cat([v, extra_values])
    values, indices = torch.topk(v, k=k)
    res = convert_to_tokens(indices, tokenizer, extended=extended, extra_values_pos=extra_values_pos)
    if with_values:
        res = list(zip(res, values.cpu().numpy()))
    return res

In [3]:


def print_top_tokens(i1, i2, K_heads, V_heads, emb, tokens_list=None):

    print(f'Layer {i1}, Neuron {i2}')
    print(
        tabulate(
            [
                *zip(
                    top_tokens(
                        (K_heads[i1, i2]) @ emb,
                        k=30,
                        only_from_list=tokens_list,
                        only_alnum=False,
                    ),
                    top_tokens(
                        (V_heads[i1, i2]) @ emb,
                        k=30,
                        only_from_list=tokens_list,
                        only_alnum=False,
                    ),
                    top_tokens((-K_heads[i1, i2]) @ emb, k=30, only_from_list=tokens_list),
                    top_tokens((-V_heads[i1, i2]) @ emb, k=30, only_from_list=tokens_list),
                )
            ],
            headers=["K", "V", "-K", "-V"],
        )
    )

In [5]:
exclude_same = False
reverse_list = False
only_ascii = True
only_alnum = False

def approx_topk(mat, min_k=500, max_k=50_000, th0=10, max_iters=10, verbose=False):
    def _get_actual_k(th, th_max):
        # Split the computation into chunks to reduce memory usage
        chunk_size = 1024  # Adjust this value based on your GPU memory
        num_chunks = (mat.shape[0] + chunk_size - 1) // chunk_size
        actual_k = 0
        for i in range(num_chunks):
            start = i * chunk_size
            end = min(start + chunk_size, mat.shape[0])
            chunk = mat[start:end]
            actual_k += torch.nonzero((chunk > th) & (chunk < th_max)).shape[0]
        return actual_k

    th_max = np.inf
    left, right = 0, th0
    while True:
        actual_k = _get_actual_k(right, th_max)
        if verbose:
            print(f"one more iteration. {actual_k}")
        if actual_k <= max_k:
            break
        left, right = right, right * 2
    if min_k <= actual_k <= max_k:
        th = right
    else:
        for _ in range(max_iters):
            mid = (left + right) / 2
            actual_k = _get_actual_k(mid, th_max)
            if verbose:
                print(f"one more iteration. {actual_k}")
            if min_k <= actual_k <= max_k:
                break
            if actual_k > max_k:
                left = mid
            else:
                right = mid
        th = mid

    # Get the nonzero indices in chunks
    all_indices = []
    chunk_size = 1024  # Adjust this value based on your GPU memory
    num_chunks = (mat.shape[0] + chunk_size - 1) // chunk_size
    for i in range(num_chunks):
        start = i * chunk_size
        end = min(start + chunk_size, mat.shape[0])
        chunk = mat[start:end]
        indices = torch.nonzero((chunk > th) & (chunk < th_max))
        # Adjust indices to reflect the original matrix
        indices[:, 0] += start
        all_indices.extend(indices.tolist())

    return all_indices
def get_top_entries(tmp, all_high_pos, only_ascii=False, only_alnum=False, exclude_same=False, exclude_fuzzy=False, tokens_list=None):
    remaining_pos = all_high_pos
    if only_ascii:
        remaining_pos = [*filter(
            lambda x: (tokenizer.decode(x[0]).strip('Ġ▁').isascii() and tokenizer.decode(x[1]).strip('Ġ▁').isascii()),
            remaining_pos)]
    if only_alnum:
        remaining_pos = [*filter(
            lambda x: (tokenizer.decode(x[0]).strip('Ġ▁ ').isalnum() and tokenizer.decode(x[1]).strip('Ġ▁ ').isalnum()),
            remaining_pos)]
    if exclude_same:
        remaining_pos = [*filter(
            lambda x: tokenizer.decode(x[0]).lower().strip() != tokenizer.decode(x[1]).lower().strip(),
            remaining_pos)]
    if exclude_fuzzy:
        remaining_pos = [*filter(
            lambda x: not _fuzzy_eq(tokenizer.decode(x[0]).lower().strip(), tokenizer.decode(x[1]).lower().strip()),
            remaining_pos)]
    if tokens_list:
        remaining_pos = [*filter(
            lambda x: ((tokenizer.decode(x[0]).strip('Ġ▁').lower().strip() in tokens_list) and
                       (tokenizer.decode(x[1]).strip('Ġ▁').lower().strip() in tokens_list)),
            remaining_pos)]

    pos_val = tmp[[*zip(*remaining_pos)]]
    good_cells = [*map(lambda x: (tokenizer.decode(x[0]), tokenizer.decode(x[1])), remaining_pos)]
    good_tokens = list(map(lambda x: Counter(x).most_common(), zip(*good_cells)))
    remaining_pos_best = np.array(remaining_pos)[torch.argsort(pos_val if reverse_list else -pos_val).cpu()[:50]]
    good_cells_best = [*map(lambda x: (tokenizer.decode(x[0]), tokenizer.decode(x[1])), remaining_pos_best)]
    # good_cells[:100]
    # list(zip(good_tokens[0], good_tokens[1]))
    return good_cells_best

#GPT2

In [3]:
model = AutoModelForCausalLM.from_pretrained("sdadas/polish-gpt2-medium")
tokenizer = my_tokenizer = AutoTokenizer.from_pretrained("sdadas/polish-gpt2-medium")
emb = model.get_output_embeddings().weight.data.T.detach()
emb = model.get_output_embeddings().weight.data.T.detach()
model.config

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


GPT2Config {
  "_attn_implementation_autoset": true,
  "_name_or_path": "sdadas/polish-gpt2-medium",
  "activation_function": "gelu_fast",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 0,
  "embd_pdrop": 0.1,
  "eos_token_id": 2,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_embd": 1024,
  "n_head": 16,
  "n_inner": 4096,
  "n_layer": 24,
  "n_positions": 2048,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "tokenizer_class": "GPT2TokenizerFast",
  "torch_dtype": "float32",
  "transformers_version": "4.46.2",
  "use_cache": true,
  "vocab_size": 51200
}

## Ekstrakcja wag z modelu GPT2

In [4]:

num_layers = model.config.n_layer
num_heads = model.config.n_head
hidden_dim = model.config.n_embd
head_size = hidden_dim // num_heads

K = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.c_fc.weight").T
                           for j in range(num_layers)]).detach()
V = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.c_proj.weight")
                           for j in range(num_layers)]).detach()

W_Q, W_K, W_V = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_attn.weight")
                           for j in range(num_layers)]).detach().chunk(3, dim=-1)
W_O = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_proj.weight")
                           for j in range(num_layers)]).detach()

K_heads = K.reshape(num_layers, -1, hidden_dim)
V_heads = V.reshape(num_layers, -1, hidden_dim)
d_int = K_heads.shape[1]

W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)
emb_inv = emb.T
print(emb_inv.shape)
print(f"Layers in the model: {num_layers}")
print(f"Neurons in the model: {model.config.n_inner}")

torch.Size([51200, 1024])
Layers in the model: 24
Neurons in the model: 4096


In [11]:
print(K_heads.shape)
print(V_heads.shape)
print(W_Q_heads.shape)
print(W_K_heads.shape)
print(W_V_heads.shape)
print(W_O_heads.shape)

torch.Size([24, 4096, 1024])
torch.Size([24, 4096, 1024])
torch.Size([24, 16, 1024, 64])
torch.Size([24, 16, 1024, 64])
torch.Size([24, 16, 1024, 64])
torch.Size([24, 16, 64, 1024])


## Interpretacja wag modelu GPT2 na pustej liście tokenów

In [5]:
tokens_list = set()

In [None]:
# Przykład z notebook'a
ilayer = 23
ineuron = 907
print_top_tokens(ilayer, ineuron, K_heads, V_heads, emb, tokens_list)

Layer 23, Neuron 907
K          V              -K          -V
---------  -------------  ----------  ---------
przycu     kody           dotychczas  #bot
zalog      #gory          rodzi       #lot
wylegi     #ei            do          #dzista
#walifik   #zmy           przebie     #remont
przesp     Apo            gatunku     #lee
pochowany  apokali        przez       #wan
#ppe       #128           #ja         #ette
wep        ludy           rodzin      #up
sfinans    #ords          zrazu       #bul
#iss       Cezary         lokalnie    #puszczam
#CS        archa          porywa      spu
#gny       przy           two         zamyka
#zwol      Homo           jeszcze     odstawi
#-).       litera         #jaw        #mont
lock       #pka           na          #beki
skonfisk   Benedykt       wymaga      #laks
#post      narodem        ty          tap
postoju    polityki       Drze        poby
erek       symbolu        pod         posto
#ionu      #lachet        Nie         #wana
#ppo       

nauka

In [None]:
# ostatni neuron w ostatniej warstwie
ilayer = 23
ineuron = 4034
print_top_tokens(ilayer, ineuron, K_heads, V_heads, emb, tokens_list)

Layer 23, Neuron 4034
K             V             -K           -V
------------  ------------  -----------  ---------
matematyka    #fur          #rost        #znania
#dysk         #ki           #ariu        #rda
#pedia        Timo          #obro        instancji
lektura       #kowe         #sina        #zowie
przysm        #kowych       usta         #ida
#najmniej     #kov          #lat         #racja
lektury       obejmuje      #kal         #rzo
nauka         #gary         #hala        #rzysty
naukowy       #Zwie         wzniesi      #rzono
#resort       Bran          #klat        #jasz
#BN           Brand         #ari         #rab
pamie         #+            linii        #datek
najwybit      #hau          #atem        #tta
najciekaw     #ce           Rami         #rada
doktorat      #karzem       nieruchom    #rado
najciekawsze  #Kom          #atu         #alog
zjada         #Karo         #nio         #pierdol
#owaj         Christian     Go           #ira
naukowa       #com          

In [None]:
i1, i2 = 0, 0
print_top_tokens(i1, i2, K_heads, V_heads, emb, tokens_list)

Layer 0, Neuron 0
K        V            -K             -V
-------  -----------  -------------  ---------
#uel     #pus         Tod            #plek
#ul      #ont         #skom          #strzy
#pul     #onek        podzie         #kli
#gul     #sud         #dzie          #wat
#ann     #onych       #sce           #liwie
bel      #TK          #wie           #rgi
#jal     #STWO        #max           uszy
#cht     #oper        #stycznych     #RI
#oll     #oty         rezerwat       #las
#El      #BW          predy          #gre
#ulf     #ARS         #system        #dle
El       #poty        niespodzianki  #nisz
el       #ones        niespodzianka  #genera
dywiden  postoju      #styczne       #inga
#alu     #sbur        #kaza          rozdz
#atak    konspiracji  #rom           usposobie
#itt     #ono         ciekawostki    #rwi
#chol    #pes         #rick          #rodzi
#anga    ognisk       #dac           #wska
#al      #akami       przygod        #isa
#eman    #ematy       #zna           


Afganistan

In [None]:
ilayer = 5
ineuron = 2031
print_top_tokens(ilayer, ineuron, K_heads, V_heads, emb, tokens_list)

Layer 5, Neuron 2031
K               V          -K           -V
--------------  ---------  -----------  -----------
Afganistanu     Rzesz      #reb         emerytalne
Afga            #parcie    fair         publiczne
Jehowy          #tyz       #gle         #szto
Wojny           #stie      #cie         ui
Obywatelskich   #tz        #bal         #smo
zbiorowych      #kop       #patrz       Sub
etni            #tkami     #arter       #szard
getta           #zdro      #gl          publiczna
rdzenia         #lock      moimi        odprawy
Kinga           #tto       mym          powsze
etnicznych      #TS        #bki         #Ur
Afganistanie    #wcze      moim         autorskie
#zacje          #ieu       #pi          Anne
przegranej      uderzeniu  #de          #mion
emerytalnych    rzesz      #lar         integra
wojennych       #sci       pas          ur
Krajowych       #zej       biurko       publicznych
zbiorowe        #zno       #dgo         celne
uznanych        hamowania  #dalej      

ustroje, rządy

In [None]:
ilayer = 11
ineuron = 2076
print_top_tokens(ilayer, ineuron, K_heads, V_heads, emb, tokens_list)

Layer 11, Neuron 2076
K          V              -K          -V
---------  -------------  ----------  ------
Republi    #rowicza       #p          #zyn
monarch    #HO            #wodni      #Miesz
krajami    #hra           #alnie      #stin
Republika  #bela          #mieni      #ZY
Gut        generalnej     #alni       #spon
monarchii  #kacz          #program    Spi
ostra      Mazow          #lacyjny    #ntem
jurysdy    wilki          przy        skopi
poprzek    #szewskiego    #praw       #pet
manifest   #gha           #niczki     #logi
#rump      #ucha          #pla        #skor
#omato     #anowskiego    #Ph         #sne
zatar      Mazu           Ph          #imi
#litary    bliski         #cika       staran
zaciera    #gh            #ha         #LT
kanclerz   #kiewicza      #szlo       pobra
wete       #dztwa         #gr         pigu
republika  #cha           #Zdro       #fr
#omas      #chta          #mistrza    #jemu
Douglas    #szno          #dzki       #Syl
republi    #rwale       

### WVO Interpretation
 polega na analizie macierzy przejścia, identyfikując pary słów, które są silnie ze sobą powiązane.

prefiksy

In [None]:
i1, i2 = 21, 7

# Move tensors to the selected device
W_V_tmp = W_V_heads[i1, i2, :].to(device)
W_O_tmp = W_O_heads[i1, i2].to(device)
emb_inv = emb_inv.to(device)
emb = emb.to(device)

# Perform the computation on the device
tmp = (emb_inv @ (W_V_tmp @ W_O_tmp) @ emb)
all_high_pos = approx_topk(tmp, th0=1, verbose=True)
get_top_entries(tmp, all_high_pos, only_ascii=only_ascii, only_alnum=only_alnum,
                exclude_same=exclude_same, tokens_list=None)

[(' ranem', ' nad'),
 ('ornie', ' przez'),
 ('datek', ' nad'),
 ('arl', ' przed'),
 (' pewno', ' na'),
 ('spodziewanie', ' nad'),
 (' razu', ' od'),
 ('miernie', ' nad'),
 (' okazji', ' przy'),
 (' wsk', ' na'),
 (' czele', ' na'),
 ('orne', ' przez'),
 ('godziny', ' nad'),
 ('sione', ' przed'),
 (' ranem', 'nad'),
 ('natural', ' nad'),
 ('przewo', ' nad'),
 (' Duna', ' nad'),
 ('ktory', ' do'),
 ('miar', ' nad'),
 (' dobra', ' dla'),
 (' Jezi', ' nad'),
 ('spodzie', ' nad'),
 (' niedawna', ' do'),
 ('pisie', ' pod'),
 (' wygody', ' dla'),
 ('arcie', ' przed'),
 (' sumie', ' w'),
 ('miernie', 'nad'),
 ('tek', ' pod'),
 ('wcze', ' przed'),
 ('mier', ' nad'),
 ('hala', ' pod'),
 ('ornie', ' przeze'),
 (' uboczu', ' na'),
 (' podstawie', ' na'),
 (' ranem', 'Nad'),
 ('czesne', ' do'),
 ('granicznych', ' nad'),
 ('wiska', ' przez'),
 (' barkach', ' na'),
 ('ornie', 'przez'),
 (' odmiany', ' dla'),
 ('niego', ' przed'),
 (' plecami', ' za'),
 (' dobi', ' na'),
 ('przewodni', ' nad'),
 (' ko

In [15]:
i1, i2 = 20, 13

# Move tensors to the selected device
W_V_tmp = W_V_heads[i1, i2, :].to(device)
W_O_tmp = W_O_heads[i1, i2].to(device)
emb_inv = emb_inv.to(device)
emb = emb.to(device)

# Perform the computation on the device
tmp = (emb_inv @ (W_V_tmp @ W_O_tmp) @ emb)
all_high_pos = approx_topk(
    tmp, th0=1, verbose=True
)  # torch.nonzero((tmp > th) & (tmp < th_max)).tolist()

get_top_entries(
    tmp,
    all_high_pos,
    only_ascii=only_ascii,
    only_alnum=only_alnum,
    exclude_same=exclude_same,
    tokens_list=None,
)

one more iteration. 358
one more iteration. 11868


[(' jej', ' jej'),
 (' jej', ' ona'),
 (' Ciebie', ' Ciebie'),
 (' jej', ' Ona'),
 (' wjej', ' jej'),
 (' Ciebie', ' Tobie'),
 (' ciebie', ' ciebie'),
 (' Twoich', ' Ciebie'),
 (' ich', ' ich'),
 (' ciebie', ' tobie'),
 (' Ciebie', ' Twojej'),
 (' Twojej', ' Ciebie'),
 (' Ciebie', ' Twoich'),
 (' twoimi', ' ciebie'),
 (' wjej', ' ona'),
 (' Twoim', ' Ciebie'),
 (' Ciebie', ' Twoim'),
 (' twoich', ' ciebie'),
 (' Twoje', ' Ciebie'),
 (' twoje', ' ciebie'),
 (' jej', ' wjej'),
 (' Ciebie', ' Twoje'),
 (' Twoich', ' Tobie'),
 (' wjej', ' Ona'),
 (' Tobie', ' Ciebie'),
 (' Twoich', ' Twoich'),
 (' ciebie', ' twoich'),
 (' Twoich', ' Twojej'),
 (' Ci', ' Ciebie'),
 (' ciebie', ' twoimi'),
 (' twoim', ' ciebie'),
 (' Twoim', ' Twoich'),
 (' ciebie', ' twojej'),
 (' was', ' was'),
 (' ciebie', ' twoim'),
 (' twoimi', ' tobie'),
 (' jej', ' niej'),
 (' jej', ' Jej'),
 (' was', ' wami'),
 (' twoich', ' tobie'),
 (' tobie', ' ciebie'),
 (' Twoich', ' Twoje'),
 (' twoich', ' twoich'),
 (' Twojej'

### WQK Interpretation
Interpretacja Wqk polega na analizie macierzy, identyfikując pary słów, które silnie ze sobą współgrają w kontekście mechanizmu uwagi

imiona

In [8]:
i1, i2 = 20, 7

# Move tensors to the selected device
W_Q_tmp = W_Q_heads[i1, i2, :].to(device)
W_K_tmp = W_K_heads[i1, i2].to(device)
emb_inv = emb_inv.to(device)
emb = emb.to(device)

tmp = emb_inv @ (W_Q_tmp @ W_K_tmp.T) @ emb_inv.T

all_high_pos = approx_topk(tmp, th0=1, verbose=True)

get_top_entries(
    tmp,
    all_high_pos,
    only_ascii=only_ascii,
    only_alnum=only_alnum,
    exclude_same=exclude_same,
    tokens_list=None,
)

one more iteration. 0
one more iteration. 0
one more iteration. 580


[(' Clark', ' Clark'),
 (' Allen', ' Allen'),
 (' Marie', ' Marie'),
 (' Anderson', ' Anderson'),
 (' Haw', 'sau'),
 (' Taylor', ' Taylor'),
 (' Betty', ' Bell'),
 (' Betty', ' Brown'),
 (' Jones', ' Jones'),
 (' Lucas', ' Lucas'),
 (' Sally', 'yn'),
 (' Adams', ' Adams'),
 (' Thom', ' Thom'),
 (' Nicole', ' Anderson'),
 (' Bent', ' Bent'),
 (' Davi', ' Jones'),
 (' Claire', ' Claire'),
 (' Jone', ' Jones'),
 (' Morgan', ' Morgan'),
 (' Laura', 'nifer'),
 (' Abraham', ' Abraham'),
 (' Lee', ' Lee'),
 (' Hanna', ' Hanna'),
 (' Beck', ' Beck'),
 (' Blake', ' Blake'),
 (' Bec', ' Beck'),
 (' Jane', ' Sim'),
 (' Kowalskiego', ' Kowalski'),
 (' Bell', ' Bell'),
 (' Adams', ' Johnson'),
 (' Kaczmarek', ' Kaczmarek'),
 (' Perry', ' Johnson'),
 (' Betty', ' Miller'),
 (' Davis', ' Johnson'),
 (' Julie', ' Julie'),
 (' Anderson', ' Jones'),
 (' Betty', ' Case'),
 (' Daniel', ' Daniel'),
 (' Davis', ' Anderson'),
 (' Carter', ' Carter'),
 (' Hannah', ' Jackson'),
 (' Clark', ' Johnson'),
 (' Dan

pan, panu

In [8]:
i1, i2 = 23, 15

# Move tensors to the selected device
W_Q_tmp = W_Q_heads[i1, i2, :].to(device)
W_K_tmp = W_K_heads[i1, i2].to(device)
emb_inv = emb_inv.to(device)
emb = emb.to(device)

tmp = emb_inv @ (W_Q_tmp @ W_K_tmp.T) @ emb_inv.T

all_high_pos = approx_topk(tmp, th0=1, verbose=True)

get_top_entries(
    tmp,
    all_high_pos,
    only_ascii=only_ascii,
    only_alnum=only_alnum,
    exclude_same=exclude_same,
    tokens_list=None,
)

one more iteration. 0
one more iteration. 573


[(' panem', ' pan'),
 (' pana', ' pan'),
 (' pan', ' panu'),
 (' pana', ' panu'),
 (' panem', ' panu'),
 (' pan', ' pan'),
 (' panem', ' pana'),
 (' pana', ' pana'),
 (' pan', ' pana'),
 (' panu', ' panu'),
 (' Pani', ' Pani'),
 ('Pana', ' pan'),
 (' pana', ' Panu'),
 (' pan', ' Panu'),
 (' panu', ' pan'),
 ('panem', ' pan'),
 (' Panem', ' pan'),
 (' pana', ' Pan'),
 (' pani', ' Pani'),
 (' pana', 'pan'),
 (' pana', ' panowie'),
 (' panem', ' Pan'),
 (' pan', 'pan'),
 (' pan', ' panem'),
 (' panu', ' pana'),
 (' pan', ' Pan'),
 ('Pani', ' pani'),
 (' pana', ' Pana'),
 ('Pani', ' Pani'),
 ('panem', ' panu'),
 (' pan', ' Pana'),
 (' panem', ' Panu'),
 (' pana', ' panem'),
 (' pan', ' panom'),
 ('Pana', ' pana'),
 ('Pana', ' panu'),
 (' Pana', ' pan'),
 ('Pan', ' pan'),
 (' panom', ' pan'),
 ('Pan', ' pana'),
 (' Panem', ' panu'),
 (' Panem', ' pana'),
 (' panem', ' Pana'),
 (' pani', ' pani'),
 (' pana', ' panom'),
 (' panem', ' panem'),
 ('panem', ' pana'),
 ('Pan', ' panu'),
 (' panu',

# BERT

In [6]:
!pip install sacremoses



In [7]:
model = AutoModelForCausalLM.from_pretrained("allegro/herbert-base-cased")
tokenizer = my_tokenizer = AutoTokenizer.from_pretrained("allegro/herbert-base-cased")
emb = model.get_output_embeddings().weight.data.T.detach()
model.config

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`


BertConfig {
  "_attn_implementation_autoset": true,
  "_name_or_path": "allegro/herbert-base-cased",
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 514,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "tokenizer_class": "HerbertTokenizerFast",
  "transformers_version": "4.46.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 50000
}

## Ekstrakcja wag z modelu BERT


In [8]:
num_layers = model.config.num_hidden_layers
num_heads = model.config.num_attention_heads
hidden_dim = model.config.hidden_size
head_size = hidden_dim // num_heads

In [9]:
print(num_layers)
print(num_heads)
print(hidden_dim)
print(head_size)

12
12
768
64


In [9]:


K = torch.cat(
    [
        model.get_parameter(f"bert.encoder.layer.{j}.attention.self.key.weight").T
        for j in range(num_layers)
    ]
).detach()
V = torch.cat(
    [
        model.get_parameter(f"bert.encoder.layer.{j}.attention.self.value.weight")
        for j in range(num_layers)
    ]
).detach()

W_Q, W_K, W_V = (
    torch.cat(
        [
            model.get_parameter(f"bert.encoder.layer.{j}.attention.self.query.weight")
            for j in range(num_layers)
        ]
    )
    .detach()
    .chunk(3, dim=-1)
)
W_O = torch.cat(
    [
        model.get_parameter(f"bert.encoder.layer.{j}.attention.output.dense.weight")
        for j in range(num_layers)
    ]
).detach()

K_heads = K.reshape(num_layers, -1, hidden_dim)
V_heads = V.reshape(num_layers, -1, hidden_dim)
d_int = K_heads.shape[1]

W_Q_heads = [model.get_parameter(f"bert.encoder.layer.{j}.attention.self.query.weight").detach() for j in range(num_layers)]
W_Q_heads = [w.reshape(hidden_dim, num_heads, head_size).permute(1, 0, 2) for w in W_Q_heads]  # Reshape and permute each layer's weights separately

W_K_heads = [model.get_parameter(f"bert.encoder.layer.{j}.attention.self.key.weight").detach() for j in range(num_layers)]
W_K_heads = [w.reshape(hidden_dim, num_heads, head_size).permute(1, 0, 2) for w in W_K_heads]

W_V_heads = [model.get_parameter(f"bert.encoder.layer.{j}.attention.self.value.weight").detach() for j in range(num_layers)]
W_V_heads = [w.reshape(hidden_dim, num_heads, head_size).permute(1, 0, 2) for w in W_V_heads]

W_O_heads = [model.get_parameter(f"bert.encoder.layer.{j}.attention.output.dense.weight").detach() for j in range(num_layers)]
W_O_heads = [w.reshape(num_heads, head_size, hidden_dim) for w in W_O_heads]
emb_inv = emb.T
print(emb_inv.shape)
print(f"Layers in the model: {num_layers}")
print(f"Neurons in the model: {model.config.hidden_size}")

torch.Size([50000, 768])
Layers in the model: 12
Neurons in the model: 768


In [10]:
import torch

W_Q_heads = torch.stack([model.get_parameter(f"bert.encoder.layer.{j}.attention.self.query.weight").detach().reshape(hidden_dim, num_heads, head_size).permute(1, 0, 2) for j in range(num_layers)])
W_K_heads = torch.stack([model.get_parameter(f"bert.encoder.layer.{j}.attention.self.key.weight").detach().reshape(hidden_dim, num_heads, head_size).permute(1, 0, 2) for j in range(num_layers)])
W_V_heads = torch.stack([model.get_parameter(f"bert.encoder.layer.{j}.attention.self.value.weight").detach().reshape(hidden_dim, num_heads, head_size).permute(1, 0, 2) for j in range(num_layers)])
W_O_heads = torch.stack([model.get_parameter(f"bert.encoder.layer.{j}.attention.output.dense.weight").detach().reshape(num_heads, head_size, hidden_dim) for j in range(num_layers)])

In [12]:
print(K_heads.shape)
print(V_heads.shape)
print(W_Q_heads.shape)
print(W_K_heads.shape)
print(W_V_heads.shape)
print(W_O_heads.shape)

torch.Size([12, 768, 768])
torch.Size([12, 768, 768])
torch.Size([12, 12, 768, 64])
torch.Size([12, 12, 768, 64])
torch.Size([12, 12, 768, 64])
torch.Size([12, 12, 64, 768])


## Interpretacja wag modelu BERT na pustej liście tokenów

In [11]:
tokens_list = set()

In [16]:
ilayer = 10
ineuron = 55
print_top_tokens(ilayer, ineuron, K_heads, V_heads, emb, tokens_list)

Layer 10, Neuron 55
K                 V               -K                 -V
----------------  --------------  -----------------  ------------------
#targ             #tyn            #vio</w>           #repli
#komplet          #tyn</w>        #Dla</w>           #wyrobi
#targ</w>         #kontyn         #DLA</w>           #serie</w>
#table            #sztyn</w>      #ROZ               #018</w>
#komplet</w>      #cykl           #dobro</w>         #dora
#kura             #komitet</w>    #Ze</w>            #interpreta
#Ori              #zony</w>       #lizuje</w>        #suser</w>
#kin</w>          #iny</w>        #noszone</w>       #empi
#kongre           #Komitet</w>    #dla</w>           #007</w>
#Litera           #ranka</w>      #RIO</w>           #Serie</w>
#~                #tabu</w>       #Ade               #rece
#komp             #fotel</w>      #oddania</w>       #interpretacja</w>
#Kajetan</w>      #dzieckiem</w>  #wykorzystuje</w>  #lera</w>
#instan           #rytm</w>       #Dob

In [52]:
ilayer = 7
ineuron = 515
print_top_tokens(ilayer, ineuron, K_heads, V_heads, emb, tokens_list)

Layer 7, Neuron 515
K                V                  -K                -V
---------------  -----------------  ----------------  ------------
#norma</w>       #uczucie</w>       #rodu</w>         #|
#norma           #powstanie</w>     #nienia</w>       #monet</w>
#NHL</w>         #relacji</w>       #kow</w>          #MIGA</w>
#nota</w>        #swego</w>         #alternatyw       #nn
#oma</w>         #drugiego</w>      #prosimy</w>      #ALU
#smal            #dziewi            #gwiazd           #nal</w>
#dag</w>         #drugiej</w>       #spraw</w>        #tyn
#mist            #pierwszego</w>    #nian</w>         #Marti
#chorymi</w>     #drzwi             #laz              #IO</w>
#Hanna</w>       #tago              #czytaj</w>       #wyd</w>
#trol            #prv</w>           #lania</w>        #Journal</w>
#Tisch           #oznacza</w>       #wienia</w>       #ener
#nota            #modeli</w>        #plaz             #NOW
#Monte           #prawda</w>        #kowej</w>        #zio


### WVO Interpretation

liczby

In [57]:
i1, i2 = 11, 7

# Move tensors to the selected device
W_V_tmp = W_V_heads[i1, i2, :].to(device)
W_O_tmp = W_O_heads[i1, i2].to(device)
emb_inv = emb_inv.to(device)
emb = emb.to(device)

# Perform the computation on the device
tmp = (emb_inv @ (W_V_tmp @ W_O_tmp) @ emb)
all_high_pos = approx_topk(tmp, th0=1, verbose=True)
get_top_entries(tmp, all_high_pos, only_ascii=only_ascii, only_alnum=only_alnum,
                exclude_same=exclude_same, tokens_list=None)

one more iteration. 198449126
one more iteration. 1006820
one more iteration. 43
one more iteration. 6722


[('351', '355'),
 ('351', '351'),
 ('351', '349'),
 ('rale', 'rale'),
 ('Dort', 'Dort'),
 ('351', '353'),
 ('arium', 'arium'),
 ('355', '355'),
 ('353', '353'),
 ('389', '389'),
 ('nawi', 'nawi'),
 ('tlu', 'tlu'),
 ('367', '353'),
 ('367', '349'),
 ('DRA', 'DRA'),
 ('1908', '1910'),
 ('351', 'fanta'),
 ('389', '355'),
 ('363', '363'),
 ('367', '355'),
 ('353', '351'),
 ('1905', '1904'),
 ('351', '345'),
 ('367', '332'),
 ('367', '339'),
 ('435', '332'),
 ('367', '880'),
 ('368', '353'),
 ('435', '435'),
 ('334', '332'),
 ('353', '355'),
 ('439', '349'),
 ('363', '353'),
 ('351', '332'),
 ('387', '355'),
 ('435', '349'),
 ('1908', '1904'),
 ('scher', 'scher'),
 ('439', '332'),
 ('332', '332'),
 ('435', '355'),
 ('1899', '1949'),
 ('439', '353'),
 ('1905', '1910'),
 ('368', '332'),
 ('1908', '1925'),
 ('439', '435'),
 ('1890', '1904'),
 ('439', '647'),
 ('389', '349')]

In [18]:
i1, i2 = 4, 5

# Move tensors to the selected device
W_V_tmp = W_V_heads[i1, i2, :].to(device)
W_O_tmp = W_O_heads[i1, i2].to(device)
emb_inv = emb_inv.to(device)
emb = emb.to(device)

# Perform the computation on the device
tmp = (emb_inv @ (W_V_tmp @ W_O_tmp) @ emb)
all_high_pos = approx_topk(tmp, th0=1, verbose=True)
get_top_entries(tmp, all_high_pos, only_ascii=only_ascii, only_alnum=only_alnum,
                exclude_same=exclude_same, tokens_list=None)

one more iteration. 1614823723
one more iteration. 531785098
one more iteration. 8901907
one more iteration. 473
one more iteration. 67368
one more iteration. 5742


[('obrotach', 'remis'),
 ('obrotach', 'remis'),
 ('obrotach', 'DIE'),
 ('iera', 'Spo'),
 ('zaka', 'Gad'),
 ('dzo', 'blon'),
 ('iera', 'ATP'),
 ('dzo', 'jury'),
 ('sic', 'blon'),
 ('iera', 'WWW'),
 ('iera', 'Bou'),
 ('dzo', 'spro'),
 ('dych', 'remis'),
 ('iera', 'obie'),
 ('iera', 'ABC'),
 ('sic', 'blon'),
 ('iera', 'jury'),
 ('dych', 'remis'),
 ('obrotach', 'Spo'),
 ('iera', 'PIL'),
 ('trol', 'Sot'),
 ('nowa', 'rej'),
 ('obrotach', 'KAN'),
 ('dych', 'DIE'),
 ('zaka', 'remis'),
 ('dzo', 'kapu'),
 ('obrotach', 'ABC'),
 ('iera', 'Eva'),
 ('iera', 'plot'),
 ('iera', 'Cham'),
 ('dzo', 'rej'),
 ('dzo', 'blon'),
 ('Magiera', 'WF'),
 ('upiera', 'Spo'),
 ('dzo', 'Sot'),
 ('iera', 'EK'),
 ('dzo', 'Pot'),
 ('iera', 'VE'),
 ('iera', 'domu'),
 ('Magiera', 'WWW'),
 ('lew', 'remis'),
 ('iera', 'fax'),
 ('dych', 'rej'),
 ('iera', 'Spy'),
 ('dzo', 'ABC'),
 ('iera', 'niebie'),
 ('iera', 'formule'),
 ('upiera', 'WF'),
 ('lew', 'remis'),
 ('iera', 'bie')]

### WKQ Interpretation

In [15]:
i1, i2 = 2, 2

# Move tensors to the selected device
W_Q_tmp = W_Q_heads[i1, i2, :].to(device)
W_K_tmp = W_K_heads[i1, i2].to(device)
emb_inv = emb_inv.to(device)
emb = emb.to(device)

tmp = emb_inv @ (W_Q_tmp @ W_K_tmp.T) @ emb_inv.T

all_high_pos = approx_topk(tmp, th0=1, verbose=True)

get_top_entries(
    tmp,
    all_high_pos,
    only_ascii=only_ascii,
    only_alnum=only_alnum,
    exclude_same=exclude_same,
    tokens_list=None,
)

one more iteration. 799290475
one more iteration. 315156187
one more iteration. 21015350
one more iteration. 17873


[('Ferenc', 'napi'),
 ('Lista', 'NEJ'),
 ('Transport', 'Bom'),
 ('owali', 'Canal'),
 ('Hos', 'kow'),
 ('Natura', 'NEJ'),
 ('Ferenc', 'FN'),
 ('Wen', 'Bom'),
 ('Ferenc', 'Kru'),
 ('Lista', 'kow'),
 ('bra', 'NEJ'),
 ('Stel', 'Bom'),
 ('Kluby', 'kow'),
 ('Juli', 'Bom'),
 ('barba', 'KTU'),
 ('publicznego', 'napi'),
 ('Br', 'NEJ'),
 ('pry', 'NEJ'),
 ('Lista', 'Arme'),
 ('NEGO', 'KOV'),
 ('Ferenc', 'Google'),
 ('pobo', 'NEJ'),
 ('Pry', 'NEJ'),
 ('Ele', 'Tol'),
 ('Infra', 'Bom'),
 ('Andrew', 'NEJ'),
 ('Infra', 'Arme'),
 ('Administracyjnego', 'znie'),
 ('Dos', 'znacznej'),
 ('Ferenc', 'Canal'),
 ('Wen', 'NEJ'),
 ('Lesznie', 'znie'),
 ('Ferenc', 'Czech'),
 ('Juli', 'kow'),
 ('Indo', 'KOV'),
 ('Lista', 'Tol'),
 ('Kuba', 'NEJ'),
 ('Mario', 'Tol'),
 ('poda', 'kow'),
 ('Tol', 'doby'),
 ('Kuba', 'kow'),
 ('publicznego', 'kowy'),
 ('Kore', 'NEJ'),
 ('Instru', 'Arme'),
 ('Hos', 'cyjnych'),
 ('Klu', 'FN'),
 ('Fle', 'kow'),
 ('Ferenc', 'Hy'),
 ('Dos', 'kow'),
 ('bra', 'NY')]

In [13]:
i1, i2 = 8, 11

# Move tensors to the selected device
W_Q_tmp = W_Q_heads[i1, i2, :].to(device)
W_K_tmp = W_K_heads[i1, i2].to(device)
emb_inv = emb_inv.to(device)
emb = emb.to(device)

tmp = emb_inv @ (W_Q_tmp @ W_K_tmp.T) @ emb_inv.T

all_high_pos = approx_topk(tmp, th0=1, verbose=True)

get_top_entries(
    tmp,
    all_high_pos,
    only_ascii=only_ascii,
    only_alnum=only_alnum,
    exclude_same=exclude_same,
    tokens_list=None,
)

one more iteration. 464473585
one more iteration. 99335701
one more iteration. 1815857
one more iteration. 249
one more iteration. 21559


[('analogi', 'mentar'),
 ('Wiele', 'Wiele'),
 ('275', 'Wiele'),
 ('192', 'Red'),
 ('275', '649'),
 ('512', 'Wiele'),
 ('alne', '256'),
 ('192', '384'),
 ('Wiele', 'Tere'),
 ('kod', 'Pur'),
 ('alne', '384'),
 ('192', 'Leszczy'),
 ('275', '384'),
 ('299', 'Wiele'),
 ('384', '649'),
 ('wyraz', 'tenis'),
 ('299', '512'),
 ('512', 'Leszczy'),
 ('512', 'Red'),
 ('512', '384'),
 ('312', 'Wiele'),
 ('192', '512'),
 ('275', '249'),
 ('kod', 'Tere'),
 ('192', 'Lite'),
 ('wyraz', 'najem'),
 ('baj', '512'),
 ('Rzecz', 'Stel'),
 ('mno', '512'),
 ('384', '384'),
 ('Medal', 'statystyk'),
 ('wyraz', 'Interna'),
 ('zej', 'ction'),
 ('baj', '256'),
 ('portal', 'tional'),
 ('297', 'Wiele'),
 ('wyraz', 'cke'),
 ('256', '649'),
 ('249', '649'),
 ('512', 'Fle'),
 ('332', 'Wiele'),
 ('275', 'Fle'),
 ('512', '431'),
 ('254', '649'),
 ('godziny', 'banko'),
 ('299', 'Red'),
 ('najmniejszego', '384'),
 ('242', 'Red'),
 ('192', '192'),
 ('alne', 'Wiele')]