In [1]:
import torch
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from tabulate import tabulate
from tqdm import tqdm, trange
from copy import deepcopy
import numpy as np
from collections import Counter

In [2]:
# Characters treated as "alphanumeric" by the only_alnum filter of top_tokens():
# ASCII lowercase letters, ASCII uppercase letters and decimal digits.
ALNUM_CHARSET = set(
    'abcdefghijklmnopqrstuvwxyz'
    'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    '0123456789'
)

def convert_to_tokens(indices, tokenizer, extended=False, extra_values_pos=None, strip=True):
    """Turn vocabulary indices into printable token strings.

    Args:
        indices: iterable of token ids.
        tokenizer: object exposing ``convert_ids_to_tokens`` (and ``__len__`` when
            ``extended`` is used).
        extended: when True, ids that lie past the tokenizer vocabulary are rendered
            as placeholders instead of being looked up: ``[posN]`` for ids below
            ``extra_values_pos`` and ``[valN]`` for ids at or above it.
        extra_values_pos: first id of the ``[val…]`` range. When left as ``None``
            every out-of-vocabulary id is rendered as ``[posN]``.
        strip: when True, a leading ``Ġ`` marker is removed and every other string
            is prefixed with ``#``. Placeholders are prefixed too, since they do
            not start with ``Ġ``.

    Returns:
        list of str, one entry per input index.
    """
    if extended:
        vocab_size = len(tokenizer)
        res = []
        for idx in indices:
            idx = int(idx)  # plain int for the comparisons and the arithmetic below
            if idx < vocab_size:
                res.append(tokenizer.convert_ids_to_tokens([idx])[0])
            elif extra_values_pos is None or idx < extra_values_pos:
                # No [val…] boundary given -> comparing against None would raise TypeError.
                res.append(f"[pos{idx-vocab_size}]")
            else:
                res.append(f"[val{idx-extra_values_pos}]")
    else:
        res = tokenizer.convert_ids_to_tokens(indices)
    if strip:
        # startswith() instead of x[0]: an empty token string must not raise IndexError.
        res = [tok[1:] if tok.startswith('Ġ') else "#" + tok for tok in res]
    return res


def top_tokens(v, k=100, tokenizer=None, only_alnum=False, only_ascii=True, with_values=False,
               exclude_brackets=False, extended=True, extra_values=None, only_from_list=None):
    """Return the ``k`` best-scoring tokens of a per-vocabulary score vector.

    Args:
        v: 1-D score tensor indexed by token id; it is copied, never modified.
        k: number of entries requested from ``torch.topk``.
        tokenizer: tokenizer whose ``vocab`` maps token string -> id; falls back to
            the module-level ``my_tokenizer`` when ``None``.
        only_alnum: mask tokens that contain characters outside ``ALNUM_CHARSET``
            (after stripping ``Ġ▁[]`` and spaces).
        only_ascii: mask tokens that are not ASCII (after stripping ``Ġ▁``).
        with_values: when True return ``(token, score)`` pairs instead of tokens.
        exclude_brackets: when True the mask built so far is *intersected* with the
            ids of tokens that are not purely ASCII-alphanumeric, i.e. it can only
            shrink the mask. NOTE(review): the name suggests a different intent;
            the behaviour is kept exactly as it was.
        extended: forwarded to ``convert_to_tokens``.
        extra_values: optional 1-D tensor appended after the vocabulary scores
            before the top-k selection.
        only_from_list: if truthy, mask every token whose stripped, lower-cased
            form is not contained in it.

    Returns:
        list of token strings, or list of ``(token, score)`` tuples.
    """
    tokenizer = my_tokenizer if tokenizer is None else tokenizer
    scores = deepcopy(v)

    def _ids_where(predicate):
        # Ids of every vocabulary entry whose token string satisfies `predicate`.
        return {token_id for token, token_id in tokenizer.vocab.items() if predicate(token)}

    masked = set()
    if only_ascii:
        masked |= _ids_where(lambda token: not token.strip('Ġ▁').isascii())
    if only_alnum:
        masked |= _ids_where(lambda token: not (set(token.strip('Ġ▁[] ')) <= ALNUM_CHARSET))
    if only_from_list:
        masked |= _ids_where(lambda token: token.strip('Ġ▁ ').lower() not in only_from_list)
    if exclude_brackets:
        masked &= _ids_where(lambda token: not (token.isascii() and token.isalnum()))

    # Masked tokens can never win the top-k selection (unless k exceeds the unmasked count).
    scores[list(masked)] = -np.inf

    # Ids at or beyond this position belong to `extra_values`, not to the vocabulary.
    extra_values_pos = len(scores)
    if extra_values is not None:
        scores = torch.cat([scores, extra_values])

    values, indices = torch.topk(scores, k=k)
    res = convert_to_tokens(indices, tokenizer, extended=extended, extra_values_pos=extra_values_pos)
    return list(zip(res, values.cpu().numpy())) if with_values else res

In [3]:
# Hugging Face Hub id of the checkpoint; model and tokenizer must come from the same one.
MODEL_NAME = "sdadas/polish-gpt2-medium"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
# `my_tokenizer` is the module-level default that top_tokens() falls back to.
tokenizer = my_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Transposed output-embedding weight, detached from autograd
# (presumably (n_embd, vocab_size) — TODO confirm against the model head).
emb = model.get_output_embeddings().weight.data.T.detach()
model.config

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/837 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.53G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.34M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

GPT2Config {
  "_attn_implementation_autoset": true,
  "_name_or_path": "sdadas/polish-gpt2-medium",
  "activation_function": "gelu_fast",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 0,
  "embd_pdrop": 0.1,
  "eos_token_id": 2,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_embd": 1024,
  "n_head": 16,
  "n_inner": 4096,
  "n_layer": 24,
  "n_positions": 2048,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "tokenizer_class": "GPT2TokenizerFast",
  "torch_dtype": "float32",
  "transformers_version": "4.46.2",
  "use_cache": true,
  "vocab_size": 51200
}

In [4]:

# Architecture sizes, read from the loaded model's config.
num_layers, num_heads, hidden_dim = (
    model.config.n_layer,
    model.config.n_head,
    model.config.n_embd,
)
head_size = hidden_dim // num_heads

# MLP weights of every block, concatenated along dim 0.
# c_fc is transposed before stacking; c_proj is stacked as stored.
K = torch.cat(tuple(
    model.get_parameter(f"transformer.h.{layer}.mlp.c_fc.weight").T
    for layer in range(num_layers)
)).detach()
V = torch.cat(tuple(
    model.get_parameter(f"transformer.h.{layer}.mlp.c_proj.weight")
    for layer in range(num_layers)
)).detach()

# c_attn holds the query, key and value projections side by side in its last dim:
# stack all layers, then cut that dim into three equal parts.
W_Q, W_K, W_V = torch.chunk(
    torch.cat(tuple(
        model.get_parameter(f"transformer.h.{layer}.attn.c_attn.weight")
        for layer in range(num_layers)
    )).detach(),
    3,
    dim=-1,
)
# Attention output projection of every block, concatenated along dim 0.
W_O = torch.cat(tuple(
    model.get_parameter(f"transformer.h.{layer}.attn.c_proj.weight")
    for layer in range(num_layers)
)).detach()

In [5]:
# MLP weights regrouped per layer: (num_layers, d_int, hidden_dim).
K_heads, V_heads = (w.reshape(num_layers, -1, hidden_dim) for w in (K, V))
d_int = K_heads.size(1)  # rows per layer in the stacked MLP weights

# Q/K/V projections regrouped per head: (num_layers, num_heads, hidden_dim, head_size).
W_Q_heads, W_K_heads, W_V_heads = (
    w.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
    for w in (W_Q, W_K, W_V)
)
# Output projection regrouped per head: (num_layers, num_heads, head_size, hidden_dim).
W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)

In [6]:
# Transpose of `emb`, used as the left-hand factor in `emb_inv @ (...) @ emb` further down.
# Despite the name this is a plain transpose, not a computed (pseudo-)inverse.
emb_inv = emb.T

In [7]:
# Indices into the first two dims of K_heads / V_heads (layer, row within the layer).
# The same pair is assigned again in the cell that prints the token table.
i1, i2 = 23, 907

In [8]:
# Echo the currently selected index pair.
print(i1, i2)

23 907


In [9]:
# Placeholder only: the token-collection cell below rebinds `tokens_list` in both of its branches.
tokens_list = set()

In [10]:
pip install datasets

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (

In [11]:
# Imported here rather than in the top import cell because `datasets` is only
# installed by the preceding cell.
from datasets import load_dataset

# The repository ships a custom loading script. trust_remote_code=True accepts it up
# front, so the cell no longer blocks on an interactive [y/N] prompt under
# "Restart & Run All".
# NOTE(review): this executes code from the dataset repository; inspect it at
# https://hf.co/datasets/clarin-knext/wsd_polish_datasets before trusting it.
dataset = load_dataset("clarin-knext/wsd_polish_datasets", trust_remote_code=True)

README.md:   0%|          | 0.00/11.0k [00:00<?, ?B/s]

wsd_polish_datasets.py:   0%|          | 0.00/7.94k [00:00<?, ?B/s]

The repository for clarin-knext/wsd_polish_datasets contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/clarin-knext/wsd_polish_datasets.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


sherlock_text.jsonl:   0%|          | 0.00/2.25M [00:00<?, ?B/s]

skladnica_text.jsonl:   0%|          | 0.00/29.0M [00:00<?, ?B/s]

wikiglex_text.jsonl:   0%|          | 0.00/12.0M [00:00<?, ?B/s]

emoglex_text.jsonl:   0%|          | 0.00/23.1M [00:00<?, ?B/s]

walenty_text.jsonl:   0%|          | 0.00/50.6M [00:00<?, ?B/s]

kpwr_text.jsonl:   0%|          | 0.00/57.4M [00:00<?, ?B/s]

kpwr-100_text.jsonl:   0%|          | 0.00/8.02M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [12]:
# The 'text' column of the train split; iterated text by text when collecting corpus tokens.
data = dataset['train']['text']

In [None]:
# None  -> keep every token that occurs in the corpus.
# an int -> keep only that many tokens, ranked by the number of texts that contain them.
max_tokens_num = None

In [None]:
# Collect the tokenizer pieces that actually occur in the corpus.
if max_tokens_num is None:
    # Keep everything. update() grows the set in place; rebuilding it with
    # union() for every text made this loop quadratic in the vocabulary size.
    tokens_list = set()
    for txt in tqdm(data):
        tokens_list.update(tokenizer.tokenize(txt))
else:
    # Keep only the most widespread tokens. Each text counts a token once
    # (document frequency), hence the set() around the tokenised text.
    token_doc_freq = Counter()
    for txt in tqdm(data):
        token_doc_freq.update(set(tokenizer.tokenize(txt)))
    # A real list rather than a one-shot map iterator, so the cell's result can
    # be inspected or reused more than once.
    tokens_list = [token for token, _ in token_doc_freq.most_common(max_tokens_num)]

100%|██████████| 7848/7848 [00:39<00:00, 196.74it/s]


In [13]:
# Normalise the collected tokens the same way top_tokens() does before its
# `only_from_list` membership test: drop word-boundary markers, lower-case.
tokens_list = {token.strip('Ġ▁').lower() for token in tokens_list}

In [13]:
 i1, i2 = 23, 907
# i1, i2 = np.random.randint(num_layers), np.random.randint(d_int)

print(i1, i2)
print(tabulate([*zip(
    top_tokens((K_heads[i1, i2]) @ emb, k=30, only_from_list=tokens_list, only_alnum=False),
    top_tokens((V_heads[i1, i2]) @ emb, k=30, only_from_list=tokens_list, only_alnum=False),
    top_tokens((-K_heads[i1, i2]) @ emb, k=200, only_from_list=tokens_list),
    top_tokens((-V_heads[i1, i2]) @ emb, k=200, only_from_list=tokens_list),
)], headers=['K', 'V', '-K', '-V']))

23 907
K          V              -K          -V
---------  -------------  ----------  ---------
przycu     kody           dotychczas  #bot
zalog      #gory          rodzi       #lot
wylegi     #ei            do          #ush
#walifik   #zmy           przebie     #dzista
przesp     Apo            gatunku     #remont
pochowany  apokali        przez       #lee
#cket      #128           #ja         #wan
#ppe       ludy           rodzin      #ette
wep        #ords          zrazu       #up
sfinans    Cezary         lokalnie    #bul
#iss       archa          porywa      #puszczam
#CS        przy           two         spu
#gny       akcy           jeszcze     zamyka
#zwol      Homo           #jaw        odstawi
#-).       litera         na          #mont
lock       #pka           wymaga      #beki
skonfisk   Cezar          ty          #laks
#post      Benedykt       Drze        tap
postoju    narodem        pod         poby
erek       polityki       Nie         posto
#ionu      symbolu        

In [14]:

def _iter_row_chunks(mat, chunk_size):
    """Yield ``(start_row, chunk)`` pairs that cover ``mat`` in consecutive row slices."""
    for start in range(0, mat.shape[0], chunk_size):
        yield start, mat[start:start + chunk_size]


def approx_topk(mat, min_k=500, max_k=50_000, th0=10, max_iters=10, verbose=False, chunk_size=1024):
    """Find the positions of roughly the largest entries of a 2-D tensor.

    Instead of sorting, a threshold is searched such that the number of entries
    strictly above it falls in ``[min_k, max_k]``: the threshold starts at
    ``th0`` and is doubled until at most ``max_k`` entries remain, then bisected
    for at most ``max_iters`` steps.

    Args:
        mat: 2-D tensor to scan (the row index of each hit is offset per chunk).
        min_k: smallest acceptable number of returned positions.
        max_k: largest acceptable number of returned positions.
        th0: initial threshold; must be positive if it may need doubling.
        max_iters: maximum number of bisection steps.
        verbose: print the entry count after every probe.
        chunk_size: number of rows compared at once, to bound peak memory.

    Returns:
        list of ``[row, col]`` pairs whose value is above the selected threshold.
        If bisection does not reach the ``[min_k, max_k]`` window, the lowest
        threshold known to give at most ``max_k`` entries is used, so the result
        may contain fewer than ``min_k`` entries but never more than ``max_k``.

    Raises:
        ValueError: if the threshold would have to be doubled while it is <= 0
            (doubling could never reduce the count, so the search would not end).
    """
    th_max = np.inf  # strict upper cut-off: +inf entries are never selected

    def _selected(chunk, th):
        return (chunk > th) & (chunk < th_max)

    def _count_above(th):
        # Chunked so that the boolean mask never covers the whole matrix at once.
        return sum(int(_selected(chunk, th).sum()) for _, chunk in _iter_row_chunks(mat, chunk_size))

    # Phase 1: grow the threshold until the selection is small enough.
    left, right = 0, th0
    while True:
        actual_k = _count_above(right)
        if verbose:
            print(f"one more iteration. {actual_k}")
        if actual_k <= max_k:
            break
        if right <= 0:
            raise ValueError("th0 must be positive when more than max_k entries exceed it")
        left, right = right, right * 2

    # `th` always holds a threshold whose count is known to be <= max_k.
    th = right
    if not (min_k <= actual_k <= max_k):
        # Phase 2: too few entries at `right`, too many at `left` -> bisect.
        for _ in range(max_iters):
            mid = (left + right) / 2
            actual_k = _count_above(mid)
            if verbose:
                print(f"one more iteration. {actual_k}")
            if min_k <= actual_k <= max_k:
                th = mid
                break
            if actual_k > max_k:
                left = mid
            else:
                right = mid
                th = mid

    # Collect the selected positions chunk by chunk.
    all_indices = []
    for start, chunk in _iter_row_chunks(mat, chunk_size):
        indices = torch.nonzero(_selected(chunk, th))
        # Row numbers are relative to the chunk; shift them back to `mat` rows.
        indices[:, 0] += start
        all_indices.extend(indices.tolist())

    return all_indices
def get_top_entries(tmp, all_high_pos, only_ascii=False, only_alnum=False, exclude_same=False, exclude_fuzzy=False,
                    tokens_list=None, reverse_list=None, top_n=50, tokenizer=None):
    """Decode the strongest (row token, column token) pairs of a score matrix.

    Args:
        tmp: 2-D tensor of scores, indexed ``[row_token_id, col_token_id]``.
        all_high_pos: candidate ``[row, col]`` positions, e.g. from ``approx_topk``.
        only_ascii: keep a pair only if both decoded tokens are ASCII.
        only_alnum: keep a pair only if both decoded tokens are alphanumeric.
        exclude_same: drop pairs whose tokens are equal after lower-casing and
            stripping whitespace.
        exclude_fuzzy: drop pairs for which ``_fuzzy_eq`` returns true.
            NOTE(review): ``_fuzzy_eq`` is not defined in the visible part of
            this notebook; enabling this flag requires defining it first.
        tokens_list: if truthy, keep a pair only if both normalised tokens are
            contained in it.
        reverse_list: sort ascending (weakest first) when true. ``None`` keeps the
            previous behaviour of reading the notebook-level ``reverse_list``
            variable (descending if that variable does not exist).
        top_n: maximum number of pairs returned.
        tokenizer: object with a ``decode`` method; ``None`` keeps the previous
            behaviour of using the notebook-level ``tokenizer`` variable.

    Returns:
        list of ``(row_token, col_token)`` string tuples, best first; empty when
        no candidate survives the filters.
    """
    if tokenizer is None:
        tokenizer = globals()["tokenizer"]
    if reverse_list is None:
        reverse_list = globals().get("reverse_list", False)

    # The same token ids recur in many pairs and in every filter; decode each once.
    decoded = {}

    def _decode(token_id):
        if token_id not in decoded:
            decoded[token_id] = tokenizer.decode(token_id)
        return decoded[token_id]

    def _bare(token_id):
        return _decode(token_id).strip('Ġ▁')

    remaining_pos = list(all_high_pos)
    if only_ascii:
        remaining_pos = [p for p in remaining_pos
                         if _bare(p[0]).isascii() and _bare(p[1]).isascii()]
    if only_alnum:
        remaining_pos = [p for p in remaining_pos
                         if _decode(p[0]).strip('Ġ▁ ').isalnum() and _decode(p[1]).strip('Ġ▁ ').isalnum()]
    if exclude_same:
        remaining_pos = [p for p in remaining_pos
                         if _decode(p[0]).lower().strip() != _decode(p[1]).lower().strip()]
    if exclude_fuzzy:
        remaining_pos = [p for p in remaining_pos
                         if not _fuzzy_eq(_decode(p[0]).lower().strip(), _decode(p[1]).lower().strip())]
    if tokens_list:
        remaining_pos = [p for p in remaining_pos
                         if _bare(p[0]).lower().strip() in tokens_list
                         and _bare(p[1]).lower().strip() in tokens_list]

    if not remaining_pos:
        # Nothing survived the filters; indexing `tmp` with empty index lists would fail below.
        return []

    # One index list per dimension, passed as two explicit indices, instead of
    # handing a single list of tuples to the indexing operator.
    rows, cols = zip(*remaining_pos)
    pos_val = tmp[list(rows), list(cols)]

    # argsort is ascending: negate the scores to get the strongest pairs first.
    order = torch.argsort(pos_val if reverse_list else -pos_val).cpu()[:top_n].tolist()
    return [(_decode(remaining_pos[i][0]), _decode(remaining_pos[i][1])) for i in order]

In [16]:
# (layer, head) to inspect. The random draw that used to precede this line was
# overwritten immediately and therefore had no effect; to sample a head instead,
# assign np.random.randint(num_layers), np.random.randint(num_heads).
i1, i2 = 21, 7
i1, i2

(21, 7)

In [17]:
# torch is already imported in the notebook's import cell; no re-import needed here.

# Check if CUDA is available
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA device:", device)
else:
    device = torch.device("cpu")
    print("CUDA not available, using CPU device:", device)

# Move the selected head's value and output projections, and both embedding
# factors, to the selected device.
W_V_tmp = W_V_heads[i1, i2].to(device)   # (hidden_dim, head_size)
W_O_tmp = W_O_heads[i1, i2].to(device)   # (head_size, hidden_dim)
emb_inv = emb_inv.to(device)
emb = emb.to(device)

# Token-to-token score matrix of this head's value->output path. The small
# (hidden_dim x hidden_dim) product is formed first; the final result has one
# row per entry of emb_inv and one column per entry of emb, so it is very large.
tmp = (emb_inv @ (W_V_tmp @ W_O_tmp) @ emb)

Using CUDA device: cuda


In [20]:
# Positions of the largest entries of `tmp`; the threshold search starts at 1 and
# verbose=True prints the number of selected entries after every probe.
all_high_pos = approx_topk(tmp, th0=1, verbose=True)

one more iteration. 0
one more iteration. 0
one more iteration. 10
one more iteration. 5182


In [21]:
# Options for get_top_entries().
exclude_same = False   # True: drop pairs whose two tokens are equal ignoring case/whitespace
reverse_list = False   # read as a global inside get_top_entries(): True sorts ascending
only_ascii = True      # keep only pairs where both decoded tokens are ASCII
only_alnum = False     # True: keep only pairs where both decoded tokens are alphanumeric

In [30]:
# Decode the strongest (row token, column token) pairs among the candidate positions.
# The sort direction comes from the notebook-level `reverse_list` set in the cell above.
get_top_entries(tmp, all_high_pos, only_ascii=only_ascii, only_alnum=only_alnum,
                exclude_same=exclude_same, tokens_list=None)

[(' ranem', ' nad'),
 ('ornie', ' przez'),
 ('datek', ' nad'),
 ('arl', ' przed'),
 (' pewno', ' na'),
 ('spodziewanie', ' nad'),
 (' razu', ' od'),
 ('miernie', ' nad'),
 (' okazji', ' przy'),
 (' wsk', ' na'),
 (' czele', ' na'),
 ('orne', ' przez'),
 ('godziny', ' nad'),
 ('sione', ' przed'),
 (' ranem', 'nad'),
 ('natural', ' nad'),
 ('przewo', ' nad'),
 (' Duna', ' nad'),
 ('ktory', ' do'),
 ('miar', ' nad'),
 (' dobra', ' dla'),
 (' Jezi', ' nad'),
 ('spodzie', ' nad'),
 (' niedawna', ' do'),
 ('pisie', ' pod'),
 (' wygody', ' dla'),
 ('arcie', ' przed'),
 (' sumie', ' w'),
 ('miernie', 'nad'),
 ('tek', ' pod'),
 ('wcze', ' przed'),
 ('mier', ' nad'),
 ('hala', ' pod'),
 ('ornie', ' przeze'),
 (' uboczu', ' na'),
 (' podstawie', ' na'),
 (' ranem', 'Nad'),
 ('czesne', ' do'),
 ('granicznych', ' nad'),
 ('wiska', ' przez'),
 (' barkach', ' na'),
 ('ornie', 'przez'),
 (' odmiany', ' dla'),
 ('niego', ' przed'),
 (' plecami', ' za'),
 (' dobi', ' na'),
 ('przewodni', ' nad'),
 (' ko

In [15]:
# Same analysis as for the previous head, repeated for layer 18, head 2.
# (A bare `i1, i2` expression used to follow; it was not the last statement of
# the cell, so it displayed nothing and has been dropped.)
i1, i2 = 18, 2

# Check if CUDA is available
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA device:", device)
else:
    device = torch.device("cpu")
    print("CUDA not available, using CPU device:", device)

# Move tensors to the selected device
W_V_tmp = W_V_heads[i1, i2].to(device)
W_O_tmp = W_O_heads[i1, i2].to(device)
emb_inv = emb_inv.to(device)
emb = emb.to(device)

# Token-to-token score matrix of this head's value->output path.
tmp = (emb_inv @ (W_V_tmp @ W_O_tmp) @ emb)

# Candidate positions of the largest entries; prints the count after every probe.
all_high_pos = approx_topk(tmp, th0=1, verbose=True)

# Options for get_top_entries(); `reverse_list` is read by it as a global.
exclude_same = False
reverse_list = False
only_ascii = True
only_alnum = False

get_top_entries(
    tmp,
    all_high_pos,
    only_ascii=only_ascii,
    only_alnum=only_alnum,
    exclude_same=exclude_same,
    tokens_list=None,
)

Using CUDA device: cuda
one more iteration. 0
one more iteration. 0
one more iteration. 1738


[('was', '-'),
 ('was', '--'),
 ('lla', '-'),
 ('heim', '-'),
 ('lle', '-'),
 ('gonie', '-'),
 ('lla', '--'),
 ('ye', '-'),
 ('osobowej', '-'),
 ('go', '-'),
 ('has', '-'),
 ('zym', '-'),
 ('tino', '-'),
 ('kowiec', '-'),
 (' Stanu', '-'),
 ('zji', '-'),
 ('lowie', '-'),
 ('procent', '-'),
 ('lle', '-.'),
 ('stanu', '-'),
 ('lle', '--'),
 ('123', '-'),
 ('gos', '-'),
 ('head', '-'),
 ('fil', '--'),
 ('lla', '-.'),
 ('kowca', '-'),
 ('lah', '-'),
 ('lah', '--'),
 ('was', '"-'),
 ('fil', '-'),
 ('krotnie', '-'),
 ('dno', '-'),
 ('top', '-'),
 ('lla', '->'),
 ('sbur', '-'),
 ('czycy', '-'),
 ('dzkiego', '-'),
 ('pii', '-'),
 ('bak', '-'),
 ('lowie', '--'),
 ('lski', '-'),
 ('lit', '-'),
 ('has', '--'),
 ('ben', '-'),
 ('lla', '-)'),
 ('ls', '-'),
 ('tino', '--'),
 ('lli', '-'),
 ('zji', '--')]