In [1]:
# Reload edited project modules (e.g. mllm.*) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import TOKENIZER_CFG_FNAME, ENCDEC_MODEL_CFG_FNAME, RANKER_MODEL_CFG_FNAME
from mllm.model.mllm_encdec import MllmEncdecLevel
from mllm.model.mllm_ranker import MllmRankerLevel
from mllm.config.model import TokenizerCfg, MllmEncdecCfg, MllmRankerCfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens, ChunkTokenizer, tokenizer_from_config

In [4]:
# Root data directory. Path.home() is used instead of expanding '$HOME' by hand:
# os.path.expandvars leaves a literal '$HOME' path component when the variable is unset.
DATA_PATH = Path.home() / 'data'
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

# Training-run directories for the level-0 and level-1 encoder-decoder models.
TRAIN_ENCDEC_0_PATH = DATA_PATH / 'train_mllm_encdec_0'
TRAIN_ENCDEC_1_PATH = DATA_PATH / 'train_mllm_encdec_1'
encdec_0_subdir = 'encdec-20241018_092135-wiki_20200501_en-ch_100_fixed'
encdec_1_subdir = 'encdec-lvl1-20241022_224217-msmarco-fever-enc-lrs2-embmatTrue-d256-h8-dec-lrs2-seqlen100-d256-h8'

encdec_0_train_path = TRAIN_ENCDEC_0_PATH / encdec_0_subdir
encdec_1_train_path = TRAIN_ENCDEC_1_PATH / encdec_1_subdir
# Checkpoint and config files saved inside each training run.
encdec_0_snapshot_fpath = encdec_0_train_path / 'best.pth'
encdec_1_snapshot_fpath = encdec_1_train_path / 'best.pth'
encdec_0_tkz_cfg_fpath = encdec_0_train_path / TOKENIZER_CFG_FNAME
encdec_0_model_cfg_fpath = encdec_0_train_path / ENCDEC_MODEL_CFG_FNAME
encdec_1_model_cfg_fpath = encdec_1_train_path / ENCDEC_MODEL_CFG_FNAME
encdec_1_model_cfg_fpath = encdec_1_train_path / ENCDEC_MODEL_CFG_FNAME

In [5]:
# Tokenizer config stored in the level-0 encoder-decoder training run directory.
encdec_tkz_cfg = parse_yaml_file_as(TokenizerCfg, encdec_0_tkz_cfg_fpath)
tokenizer = tokenizer_from_config(encdec_tkz_cfg)
# Integer ids of the custom tokens used for padding and for delimiting queries.
tok_dict = encdec_tkz_cfg.custom_tokens
pad_tok, qbeg_tok, qend_tok = (tok_dict[name].ind for name in ('pad', 'query_begin', 'query_end'))

In [7]:
# Batch sampling parameters passed to the Wiki dataset loader.
docs_batch_size = 5
max_chunks_per_doc = 3
# NOTE(review): not referenced by the cells visible in this notebook.
model_level = 0

# Set to 'cuda' to run on GPU.
device_name = 'cpu'
device = torch.device(device_name)
print(device)

cpu


In [8]:
ds_loader = WikiDsLoader(
    ds_path=DS_DIR_PATH,
    docs_batch_size=docs_batch_size,
    max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok,
    qbeg_tok=qbeg_tok,
    qend_tok=qend_tok,
    device=device,
)
# NOTE(review): no seed is set in this notebook, so batch contents may differ between
# runs unless shuffle() seeds itself — confirm if reproducibility matters.
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)
# Model input length: the chunk size itself for fixed-size chunks, otherwise the
# maximum input size derived from it.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [20]:
def tokens_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids into text, dropping padding tokens.

    Uses the notebook-level ``tokenizer`` and ``pad_tok``.

    Args:
        tokens: tensor of token ids of any shape; it is flattened before decoding.
            Integer or float tensors are accepted (ids are converted to ``int``).

    Returns:
        The decoded string.
    """
    flat = tokens.flatten()
    flat = flat[flat != pad_tok]
    # Convert to plain Python ints: iterating a tensor yields 0-d tensors, and a float
    # tensor (e.g. one built with torch.Tensor([...])) would otherwise give float ids.
    # tolist() also copies off the GPU if needed.
    ids = [int(t) for t in flat.tolist()]
    return tokenizer.decode(ids)

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False) -> float:
    """Compare two embedding arrays.

    Args:
        x, y: arrays of the same shape.
        cosine: if False, return the Euclidean norm ``||x - y||``.
            If True, return the cosine *similarity* (higher means closer) — note this
            is a similarity, not a distance, despite the function name.

    Returns:
        The Euclidean distance, or the cosine similarity. A cosine similarity involving
        an all-zero array is reported as 0.0 instead of NaN.
    """
    if not cosine:
        return np.linalg.norm(x - y)
    norm_prod = np.linalg.norm(x) * np.linalg.norm(y)
    # Avoid 0/0 -> NaN (and a RuntimeWarning) for zero vectors.
    if norm_prod == 0:
        return 0.0
    return np.sum(x * y) / norm_prod

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize ``s`` into a padded 2-D tensor of ids.

    The id sequence is optionally wrapped in ``qbeg_tok`` ... ``qend_tok``, then split
    into rows of length ``inp_len`` (notebook-level); the last row is right-padded with
    ``pad_tok``. The result is an int32 tensor of shape (n_rows, inp_len) on ``device``.

    Raises:
        AssertionError: if ``qbeg_tok`` is given without ``qend_tok``.
    """
    ids = list(tokenizer(s)['input_ids'])
    if qbeg_tok is not None:
        assert qend_tok is not None
        ids = [qbeg_tok] + ids + [qend_tok]
    # Ceiling division: number of rows needed to hold every id.
    n_rows = -(-len(ids) // inp_len)
    flat = torch.full((n_rows * inp_len,), pad_tok, dtype=torch.int32)
    flat[:len(ids)] = torch.tensor(ids, dtype=torch.int32)
    return flat.view(n_rows, inp_len).to(device)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    """Print a docs x targets table of embedding comparisons.

    Prints one line per document embedding: the ``distance`` value against every target
    embedding, then 'T' if ``target_mask`` marks that document, else 'F'.
    With ``cosine=True`` (the default) ``distance`` returns cosine similarities.

    Args:
        target_embs: target embeddings, iterated along the first dimension.
        docs_embs: document embeddings, iterated along the first dimension.
        target_mask: indexable by document position; truthy entries print 'T'.
        cosine: forwarded to ``distance``.
    """
    # .cpu() so this also works when the model ran on CUDA; convert targets once
    # instead of once per document row.
    targets_np = target_embs.detach().cpu().numpy()
    docs_np = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs_np):
        for target_emb in targets_np:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [10]:
model_encdec_0_cfg = parse_yaml_file_as(MllmEncdecCfg, encdec_0_model_cfg_fpath)
model_encdec_0 = MllmEncdecLevel(model_encdec_0_cfg, 0).to(device)
# map_location lets a checkpoint saved on GPU be loaded when `device` is CPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint_encdec_0 = torch.load(encdec_0_snapshot_fpath, map_location=device)
# strict=False silently ignores mismatched keys; report them so a partial load is noticed.
load_res_encdec_0 = model_encdec_0.load_state_dict(checkpoint_encdec_0['model'], strict=False)
if load_res_encdec_0.missing_keys or load_res_encdec_0.unexpected_keys:
    print(f'encdec_0 missing keys: {load_res_encdec_0.missing_keys}')
    print(f'encdec_0 unexpected keys: {load_res_encdec_0.unexpected_keys}')
model_encdec_0.eval();

vocab_encoder.src_word_emb.weight (50271, 256) -0.010897174 1.0290058e-06 0.010897167
vocab_encoder.layer_norm.weight (256,) -0.09767299 0.0045641935 0.098901965
vocab_encoder.layer_norm.bias (256,) -0.099590324 0.0021672586 0.099857084
vocab_decoder.word_prj.weight (50271, 256) -0.010897174 9.409382e-07 0.010897167
encoder.a_em () -0.08215385 -0.08215385 -0.08215385
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10825258 -5.4940843e-05 0.10824862
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825103 0.00015384433 0.108227864
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10825148 -5.2423693e-05 0.10825284
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10825139 -3.4555997e-05 0.10825178
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09948003 0.0055500614 0.099358074
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09889706 -0.005347998 0.09974853
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846524 -8.4352e-06 0.06

In [11]:
model_encdec_1_cfg = parse_yaml_file_as(MllmEncdecCfg, encdec_1_model_cfg_fpath)
model_encdec_1 = MllmEncdecLevel(model_encdec_1_cfg, 1).to(device)
# map_location lets a checkpoint saved on GPU be loaded when `device` is CPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint_encdec_1 = torch.load(encdec_1_snapshot_fpath, map_location=device)
# strict=False silently ignores mismatched keys; report them so a partial load is noticed.
load_res_encdec_1 = model_encdec_1.load_state_dict(checkpoint_encdec_1['model'], strict=False)
if load_res_encdec_1.missing_keys or load_res_encdec_1.unexpected_keys:
    print(f'encdec_1 missing keys: {load_res_encdec_1.missing_keys}')
    print(f'encdec_1 unexpected keys: {load_res_encdec_1.unexpected_keys}')
model_encdec_1.eval();

encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.108239725 1.862934e-05 0.10824919
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825277 0.00013587641 0.10824961
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.1082463 -0.00013216451 0.1082527
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108250424 -2.480892e-05 0.10824795
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09999209 -0.002995777 0.09966638
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09989907 -0.004461446 0.099620536
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846472 -4.0456533e-05 0.06846482
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.099666536 -0.0029983278 0.09997972
encoder.layer_stack.0.pos_ffn.w_2.weight (256, 1024) -0.06846484 9.350548e-05 0.06846313
encoder.layer_stack.0.pos_ffn.w_2.bias (256,) -0.09855949 -0.0028014882 0.09998498
encoder.layer_stack.0.pos_ffn.layer_norm.weight (256,) -0.09982752 0.004085768 0.09990708
encoder.layer_sta

In [13]:
# Fetch training batch number i and turn it into tensors. In the saved output,
# target_mask has one entry per docs chunk (14 rows) and the target chunks decode
# to a query wrapped in query-begin/end tokens.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([14, 100]),
 torch.Size([3, 100]),
 tensor([False, False, False, False, False, False, False, False,  True,  True,
          True, False, False, False]))

In [14]:
# Decode all target chunks (padding removed) back to text to inspect the query.
s_target = tokens_to_text(target_chunks)
s_target

'<|query_begin|> North Caucasian Legion / Mountain-Caucasian LegionThe North Caucasian Legion (Germ.Legion Nordkaukasien) and the Mountain-Caucasian Legion (Germ.Bergkaukasien Legion) legions were created in accordance with the order of 19 February 1942. Initially, its soldiers, recruited from the camps of prisoners of war, deserters, and partly from representatives of emigration were included in the Caucasian-Mohammedan Legion (Germ. Kaukasisch-Mohammedanische Legion). On 2 August 1942, in accordance with the order of 19 February 1942, all fighters of Muslim North Caucasian and Mountain Caucasian (both Muslims and Christians) origin were separated from the Caucasian-Mohamedan Legion into separate North Caucasian / Mountain-Caucasian legions. These Legions consisted of Abkhazians, Circassians, Kabardians, Balkars, Karachais, Chechens, Ingushes, and the peoples of Daghestan. The Kurds, Talyshs and North Ossetians appeared later. According to the researcher Traho R. "The total number of 

In [15]:
# Show the first 200 characters of each docs chunk, with newlines escaped so every
# chunk stays on a single output line.
for toks in docs_chunks:
    s = tokens_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 1920586 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Shrewsbury Flower Show <|doc_title_end|> <|doc_body_begin|> The Shrewsbury Flower Sh
<|doc_begin|> <|doc_id_begin|> 1920586 <|doc_id_end|> <|doc_offset_begin|> 92 <|doc_offset_end|>  in the United Kingdom. It is also one of the longest-running shows in the country and featured in the 
<|doc_begin|> <|doc_id_begin|> 1920586 <|doc_id_end|> <|doc_offset_begin|> 184 <|doc_offset_end|>  show jumping, various forms of music and entertainment, which includes a large firework display on bo
<|doc_begin|> <|doc_id_begin|> 137392 <|doc_id_end|> <|doc_offset_begin|> 92 <|doc_offset_end|>  there are a large number of hiking destinations nearby, including Mount Abel, the West Bowl, and East B
<|doc_begin|> <|doc_id_begin|> 137392 <|doc_id_end|> <|doc_offset_begin|> 184 <|doc_offset_end|>  directors changes every year. Directors resign and are voted on each year at the Annual General Me

In [39]:
# Inference only: no_grad avoids building an autograd graph for a large output
# (14 x 100 x 50271 floats in the saved output).
with torch.no_grad():
    docs_chunks_logits = model_encdec_0(docs_chunks)
# NOTE(review): the probe cell further down applies softmax to this model's output.
# Sigmoid and softmax give the same argmax, but only softmax is a distribution over the vocabulary.
docs_chunks_pred = torch.sigmoid(docs_chunks_logits)
docs_chunks_pred.shape, docs_chunks_pred.dtype

(torch.Size([14, 100, 50271]), torch.float32)

In [40]:
# Most likely token id at every position: argmax over the last dimension
# (50271 wide in the saved output).
dc_toks_pred = docs_chunks_pred.argmax(dim=-1)
dc_toks_pred.shape

torch.Size([14, 100])

In [41]:
# For each reconstructed chunk print: number of token positions, length of the decoded
# text, and the text with newlines escaped.
for toks in dc_toks_pred:
    s = tokens_to_text(toks)
    s = s.replace('\n', '\\n')
    print(len(toks), len(s), s)


100 478 <|doc_id_begin|> <|doc_id_begin|> 2727 <|doc_id_end|> <|doc_id_begin|> <|doc_id_end|> <|doc_id_end|> <|doc_title_begin|> <|doc_body_begin|> ton ( " Good <|doc_title_end|> <|doc_title_end|> <|doc_title_end|> <|doc_body_begin|> <|doc_body_begin|>  (,, <|doc_title_end|> <|doc_title_end|>  is is is a a,\n\n\n. the the the,, the the the the,,, the the the,\n,,., the the the the the\n- St,,, in the the the, Lon.\n\n\n the the a the the,\n St L---,,\n\n the is a a features the the
100 392 <|doc_id_begin|> <|doc_id_begin|> 2727 <|doc_id_end|> <|doc_id_end|> <|doc_id_end|> <|doc_id_end|> \n\n\n\n\n\n the is the the the the theised, is the the the the the the the the (,,\n theThe the the the a a the the the the the the the thebody, the\n\n the is is the the the the the <|doc_begin|>,,. the the the the " "a., is the is the a a a is a a and of a is the the a rock,, the the music
100 339 <|doc_begin|> <|doc_id_begin|> 2727 <|doc_id_end|> <|doc_id_end|> <|doc_id_end|> <|doc_id_end|>  a a, th

In [25]:
# Short hand-written sentences used to probe the level-0 encoder-decoder's reconstruction.
txts = [
    "Hello, my name is Mikhail",
    "Malaga is a city in Spain",
    "LLM stands for Large Language Model",
    "You'd better learn new modeling approaches first, Mikhail from Malaga, Spain!",
]

In [27]:
# Tokenize every probe sentence into padded rows, then stack them into one batch.
token_rows = []
for txt in txts:
    toks = text_to_tokens(txt)
    print(toks.shape)
    token_rows.append(toks)
chunks = torch.concat(token_rows)

# Round-trip check: decoding each row should reproduce the original sentence.
for toks in chunks:
    s = tokens_to_text(toks)
    print(s)


torch.Size([1, 100])
torch.Size([1, 100])
torch.Size([1, 100])
torch.Size([1, 100])
Hello, my name is Mikhail
Malaga is a city in Spain
LLM stands for Large Language Model
You'd better learn new modeling approaches first, Mikhail from Malaga, Spain!


In [44]:
# Inference only: no_grad avoids building an autograd graph for the (4, 100, vocab) output.
with torch.no_grad():
    chunks_logits = model_encdec_0(chunks)
# Per-position distribution over the vocabulary.
chunks_pred = torch.softmax(chunks_logits, dim=-1)
print(chunks_pred.shape, chunks_pred.dtype)
toks_pred = torch.argmax(chunks_pred, dim=-1)
print(toks_pred.shape)

torch.Size([4, 100, 50271]) torch.float32
torch.Size([4, 100])


In [45]:
# Decode the argmax reconstruction of each probe sentence (newlines escaped).
for toks in toks_pred:
    s = tokens_to_text(toks)
    s = s.replace('\n', '\\n')
    print(s)


 of <|query_end|> <|query_end|> <|query_end|> CategoryM <|doc_body_end|> <|query_end|> <|doc_body_end|>
igi <|doc_body_end|> <|doc_body_end|> <|doc_body_end|> <|query_end|> <|doc_body_end|> <|query_end|> <|doc_body_end|> <|doc_body_end|>
\n\n <|doc_body_end|> -\n <|doc_body_end|> Category <|query_end|> Category
 of the of of the the the of\n\nhia <|query_end|> <|query_end|> <|query_end|>


In [51]:
# List the vocabulary tokens whose predicted probability at the first position of the
# first probe sentence reaches prob_threshold.
# NOTE(review): the saved output below shows probabilities of ~0.01-0.06, all below 0.95,
# so it was produced with a lower threshold than the one set here.
prob_threshold = 0.95
probs = chunks_pred[0, 0].detach()
inds = torch.arange(len(probs))
probs_mask = probs >= prob_threshold
print(probs[probs_mask])
toks = inds[probs_mask]
strs = []
for tok in toks:
    # Keep the id as an integer tensor; torch.Tensor([tok]) would produce a float id.
    s = tokens_to_text(tok.reshape(1))
    strs.append(s)
print(strs)


tensor([0.0121, 0.0132, 0.0157, 0.0234, 0.0613, 0.0251, 0.0228, 0.0181],
       grad_fn=<IndexBackward0>)
[',', '-', '\n', ' the', ' of', ' in', ' and', ' (']
