In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import TOKENIZER_CFG_FNAME, ENCDEC_MODEL_CFG_FNAME, RANKER_MODEL_CFG_FNAME
from mllm.model.mllm_encdec import MllmEncdecLevel
from mllm.model.mllm_ranker import MllmRankerLevel
from mllm.config.model import TokenizerCfg, MllmEncdecCfg, MllmRankerCfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens, ChunkTokenizer, tokenizer_from_config

In [3]:
# Root of all local data. Path.home() replaces the manual expansion of '$HOME':
# expandvars leaves the literal text '$HOME' in the path when the variable is not set.
DATA_PATH = Path.home() / 'data'
# Dataset directory handed to WikiDsLoader.
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

# Parent directories and run sub-directories of the two encoder-decoder training runs.
TRAIN_ENCDEC_0_PATH = DATA_PATH / 'train_mllm_encdec_0'
TRAIN_ENCDEC_1_PATH = DATA_PATH / 'train_mllm_encdec_1'
encdec_0_subdir = 'encdec-lvl0-20241026_120743-wiki_20200501_en-ch_100_fixed-enc-lrs2-embmatFalse-d256-h8-dec-lrs2-seqlen100-d256-h8'
encdec_1_subdir = 'encdec-lvl1-20241022_224217-msmarco-fever-enc-lrs2-embmatTrue-d256-h8-dec-lrs2-seqlen100-d256-h8'

encdec_0_train_path = TRAIN_ENCDEC_0_PATH / encdec_0_subdir
encdec_1_train_path = TRAIN_ENCDEC_1_PATH / encdec_1_subdir
# Snapshot files loaded into the models below.
encdec_0_snapshot_fpath = encdec_0_train_path / 'best.pth'
encdec_1_snapshot_fpath = encdec_1_train_path / 'best.pth'
# Config files stored next to the snapshots; the tokenizer config is taken from the level-0 run.
encdec_0_tkz_cfg_fpath = encdec_0_train_path / TOKENIZER_CFG_FNAME
encdec_0_model_cfg_fpath = encdec_0_train_path / ENCDEC_MODEL_CFG_FNAME
encdec_1_model_cfg_fpath = encdec_1_train_path / ENCDEC_MODEL_CFG_FNAME

In [4]:
# Tokenizer setup: read the tokenizer config stored with the level-0 run, build the
# tokenizer from it and pick out the ids of the special tokens used further down.
encdec_tkz_cfg = parse_yaml_file_as(TokenizerCfg, encdec_0_tkz_cfg_fpath)
tokenizer = tokenizer_from_config(encdec_tkz_cfg)
tok_dict = encdec_tkz_cfg.custom_tokens
pad_tok = tok_dict['pad'].ind
qbeg_tok = tok_dict['query_begin'].ind
qend_tok = tok_dict['query_end'].ind

In [5]:
# Loader and device settings shared by the cells below.
docs_batch_size, max_chunks_per_doc = 5, 3
model_level = 0
device_name = 'cpu'  # set to 'cuda' to run on a GPU

device = torch.device(device_name)
print(device)

cpu


In [6]:
# Build the wiki dataset loader, shuffle both the train and the non-train part, and derive
# the input length (in tokens) that text_to_tokens pads to.
ds_loader = WikiDsLoader(
    ds_path=DS_DIR_PATH,
    docs_batch_size=docs_batch_size,
    max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok,
    qbeg_tok=qbeg_tok,
    qend_tok=qend_tok,
    device=device,
)
for is_train in (True, False):
    ds_loader.shuffle(train=is_train)

# Fixed-size datasets use the chunk size as is; otherwise it is passed through calc_max_inp_size.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [7]:
def tokens_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids into text, skipping padding.

    Args:
        tokens: Tensor of token ids of any shape; it is flattened before decoding.

    Returns:
        The string the module-level `tokenizer` produces for all ids that differ
        from the module-level `pad_tok`.
    """
    flat = tokens.flatten()
    # tolist() gives the tokenizer plain Python numbers rather than a list of 0-d tensors.
    ids = flat[flat != pad_tok].tolist()
    return tokenizer.decode(ids)

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    """Compare two arrays of the same shape.

    Args:
        x: First array.
        y: Second array.
        cosine: Selects the measure, see Returns.

    Returns:
        With `cosine=False`: the Euclidean norm of `x - y` (smaller means closer).
        With `cosine=True`: the cosine *similarity* `sum(x * y) / (|x| * |y|)`
        (larger means closer, despite the function name). If either array has zero
        norm the similarity is undefined and 0.0 is returned instead of nan.
    """
    if not cosine:
        return np.linalg.norm(x - y)
    denom = np.linalg.norm(x) * np.linalg.norm(y)
    if denom == 0:
        # Avoid 0 / 0 (nan plus a RuntimeWarning) for all-zero vectors.
        return 0.0
    return np.sum(x * y) / denom

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize a string and pad it into rows of `inp_len` token ids.

    Args:
        s: Text to tokenize with the module-level `tokenizer`.
        qbeg_tok: Optional id placed before the tokens. When given, `qend_tok` must be given too.
        qend_tok: Optional id placed after the tokens; only used together with `qbeg_tok`.

    Returns:
        An int32 tensor on `device` of shape (n_rows, inp_len), where n_rows is the number of
        rows needed to hold all ids. Unused positions hold `pad_tok`.
    """
    ids = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        assert qend_tok is not None
        ids = [qbeg_tok] + list(ids) + [qend_tok]

    n_ids = len(ids)
    # Ceiling division: a partially filled last row still needs a full row.
    n_rows, remainder = divmod(n_ids, inp_len)
    if remainder > 0:
        n_rows += 1

    padded = np.full((n_rows * inp_len,), pad_tok, dtype=np.int32)
    padded[:n_ids] = ids
    return torch.from_numpy(padded).to(device).reshape(n_rows, inp_len)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    """Print a table comparing every document embedding with every target embedding.

    One line is printed per entry of `docs_embs`: the `distance` value against each entry
    of `target_embs` (6 decimals), followed by 'T' or 'F' taken from `target_mask` at the
    same index.

    Args:
        target_embs: Tensor iterated over its first dimension; one column per entry.
        docs_embs: Tensor iterated over its first dimension; one printed line per entry.
        target_mask: Indexable by the row index of `docs_embs`; its truthiness selects the suffix.
        cosine: Forwarded to `distance`.
    """
    # .cpu() is a no-op for CPU tensors and lets .numpy() work for tensors that live on a GPU.
    # Both conversions are done once, outside the loops.
    targets_np = target_embs.detach().cpu().numpy()
    docs_np = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs_np):
        for target_emb in targets_np:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [8]:
# Level-0 encoder-decoder: build the model from its saved config, restore the snapshot
# weights and switch to evaluation mode.
model_encdec_0_cfg = parse_yaml_file_as(MllmEncdecCfg, encdec_0_model_cfg_fpath)
model_encdec_0 = MllmEncdecLevel(model_encdec_0_cfg, 0).to(device)
# map_location remaps tensors saved from another device (e.g. a GPU training run) onto
# `device`, so the snapshot also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load snapshots from a trusted source.
checkpoint_encdec_0 = torch.load(encdec_0_snapshot_fpath, map_location=device)
# NOTE(review): strict=False does not raise on missing or unexpected keys; confirm that a
# partial load is intended here.
model_encdec_0.load_state_dict(checkpoint_encdec_0['model'], strict=False)
model_encdec_0.eval()
# Trailing None keeps the cell from displaying the value returned by eval().
None

vocab_encoder.src_word_emb.weight (50271, 256) -0.010897174 -3.6519843e-06 0.010897173
vocab_encoder.layer_norm.weight (256,) -0.0998892 -0.0028097667 0.09690992
vocab_encoder.layer_norm.bias (256,) -0.099947885 -0.0012904839 0.099928394
vocab_decoder.word_prj.weight (50271, 256) -0.010897174 -1.8544521e-06 0.010897173
encoder.a_em () 0.0035125257 0.0035125257 0.0035125257
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10824708 -0.00033453054 0.108253054
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10824891 -0.00026277496 0.108251795
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.1082525 -0.00027601025 0.108245134
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108252764 0.00030779542 0.108250566
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.099997304 6.118789e-05 0.09898521
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09913106 -0.0061583123 0.09971646
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.068464845 4.254

In [9]:
# Level-1 encoder-decoder: build the model from its saved config, restore the snapshot
# weights and switch to evaluation mode.
model_encdec_1_cfg = parse_yaml_file_as(MllmEncdecCfg, encdec_1_model_cfg_fpath)
model_encdec_1 = MllmEncdecLevel(model_encdec_1_cfg, 1).to(device)
# map_location remaps tensors saved from another device (e.g. a GPU training run) onto
# `device`, so the snapshot also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load snapshots from a trusted source.
checkpoint_encdec_1 = torch.load(encdec_1_snapshot_fpath, map_location=device)
# NOTE(review): strict=False does not raise on missing or unexpected keys; confirm that a
# partial load is intended here.
model_encdec_1.load_state_dict(checkpoint_encdec_1['model'], strict=False)
model_encdec_1.eval()
# Trailing None keeps the cell from displaying the value returned by eval().
None

encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10824737 -0.00018348327 0.10825293
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825174 0.0004385211 0.10823614
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10824423 -0.00033476908 0.10825088
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10824534 0.00011559089 0.10825205
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09864409 0.0010340018 0.0991658
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09685143 0.0035557617 0.09987175
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846469 7.936111e-05 0.06846529
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09958398 6.612661e-05 0.099899314
encoder.layer_stack.0.pos_ffn.w_2.weight (256, 1024) -0.06846531 -3.3801393e-06 0.06846486
encoder.layer_stack.0.pos_ffn.w_2.bias (256,) -0.09832839 -0.0058244728 0.09848913
encoder.layer_stack.0.pos_ffn.layer_norm.weight (256,) -0.09979723 -0.004286126 0.09886774
encoder.layer_stac

In [10]:
# Fetch the training batch with index 10 and convert it to tensors.
i = 10
batch = ds_loader.get_batch(i, train=True)
# NOTE(review): target_mask presumably flags the document chunks that belong to the
# target document; confirm against WikiDsLoader.
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Display the two chunk tensor shapes and the mask itself.
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([13, 100]),
 torch.Size([1, 100]),
 tensor([ True,  True, False, False, False, False, False, False, False, False,
         False, False, False]))

In [15]:
# Decode the target chunk back to text (padding removed) and display it.
s_target = tokens_to_text(target_chunks)
s_target

'<|query_begin|> KirenskyKirensky (masculine), Kirenskaya (feminine), or Kirenskoye (neuter) may refer to:\nKirensky District, a district of Irkutsk Oblast, Russia\nKirenskoye Urban Settlement, a municipal formation which the town of Kirensk and nine rural localities in Kirensky District of Irkutsk Oblast, Russia are incorporated as <|query_end|>'

In [21]:
# Print each document chunk as text: at most the first 600 characters, with line breaks
# shown as a literal \n so that every chunk stays on one output line.
for toks in docs_chunks:
    s = tokens_to_text(toks)
    preview = s[:600]
    print(preview.replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 3435884 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Kirensky <|doc_title_end|> <|doc_body_begin|> Kirensky (masculine), Kirenskaya (feminine), or Kirenskoye (neuter) may refer to:\nKirensky District, a district of Irkutsk Oblast, Russia\nKirenskoye Urban Settlement, a municipal formation which the town of Kirensk and nine rural localities in Kirensky District of Irkutsk Oblast, Russia are
<|doc_begin|> <|doc_id_begin|> 3435884 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  incorporated as <|doc_body_end|> <|doc_end|>
<|doc_begin|> <|doc_id_begin|> 3568924 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Aleksandr Averin (cyclist) <|doc_title_end|> <|doc_body_begin|> Aleksandr Dmitriyevich Averin (; born 11 April 1954) is a retired Soviet cyclist. He competed at the 1976 Summer Olympics in the road race and finished in 17th place. He won the multistage Peace Race individually in 1978 

In [17]:
# Run the level-0 model on the document chunks and normalise its output over the last
# dimension with softmax. This is inference only, so autograd is switched off: no graph
# is recorded for the large output tensor.
with torch.no_grad():
    docs_chunks_pred = torch.softmax(model_encdec_0(docs_chunks), dim=-1)
docs_chunks_pred.shape, docs_chunks_pred.dtype

(torch.Size([13, 100, 50271]), torch.float32)

In [18]:
# Take the index of the largest value along the last dimension as the predicted token id.
dc_toks_pred = docs_chunks_pred.argmax(dim=-1)
dc_toks_pred.shape

torch.Size([13, 100])

In [19]:
# Print every predicted chunk on one line: number of token ids, length of the decoded
# text, and the text itself with line breaks shown as a literal \n.
for toks in dc_toks_pred:
    s = tokens_to_text(toks).replace('\n', '\\n')
    print(len(toks), len(s), s)


100 407 <|doc_begin|> <|doc_id_begin|> 736460 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Konky <|doc_title_end|> <|doc_body_begin|> Koyky (oratan, Kekaya (feminine), as Koyoyoye (ruguter) may refer to:\nKhangkyky, a locality of Selkkk Oblast, Russia\nKkkoye Rural Settlement, the local station in the border of Kokky, a rural railwayities in Karkky, in Khkodk Oblast, Russia covers
100 135 <|doc_begin|> <|doc_id_begin|> 706432 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  served to <|doc_body_end|> <|doc_end|>
100 471 <|doc_begin|> <|doc_id_begin|> 733834 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Aleksandr Kfst (cyclist) <|doc_title_end|> <|doc_body_begin|> Aleksandrrovrjrevev ( (born; born 25 January 1974) is a a swim cyclist. He competed at the 1992 Summer Olympics in the silver team and finished in 18th place. He won the bronzeistist Team Prix, in 1988 and with the Olympic title in 1976–08. \n\nHe started

In [41]:
# Hand-written probe texts fed through the level-0 model below: one with document markup
# tokens, two plain sentences, and one containing a <|mask|> token.
txts = [
    '<|doc_begin|> 20 <|doc_id_begin|> 733860 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>Hello, my name is Mikhail<|doc_end|>',
    'Malaga is a city in Spain',
    'LLM stands for Large <|mask|> Model',
    'You\'d better learn new modeling approaches first, Mikhail from Malaga, Spain!',
]

In [42]:
# Tokenize every probe text into padded rows, printing the shape of each result.
chunk_parts = []
for txt in txts:
    toks = text_to_tokens(txt)
    print(toks.shape)
    chunk_parts.append(toks)

# Stack all rows into one tensor and decode each row back to text as a round-trip check.
chunks = torch.concat(chunk_parts)
for toks in chunks:
    s = tokens_to_text(toks)
    print(s)


torch.Size([1, 100])
torch.Size([1, 100])
torch.Size([1, 100])
torch.Size([1, 100])
<|doc_begin|>  20  <|doc_id_begin|>  733860  <|doc_id_end|>   <|doc_offset_begin|>  91  <|doc_offset_end|> Hello, my name is Mikhail <|doc_end|>
Malaga is a city in Spain
LLM stands for Large  <|mask|>  Model
You'd better learn new modeling approaches first, Mikhail from Malaga, Spain!


In [43]:
# Run the level-0 model on the probe chunks and normalise its output over the last
# dimension with softmax. This is inference only, so autograd is switched off: no graph
# is recorded for the output tensor.
with torch.no_grad():
    chunks_pred = torch.softmax(model_encdec_0(chunks), dim=-1)
print(chunks_pred.shape, chunks_pred.dtype)
# Predicted token id per position: index of the largest value along the last dimension.
toks_pred = torch.argmax(chunks_pred, dim=-1)
print(toks_pred.shape)

torch.Size([4, 100, 50271]) torch.float32
torch.Size([4, 100])


In [44]:
# Decode each predicted probe chunk and print it on one line, with line breaks shown as
# a literal \n.
for toks in toks_pred:
    s = tokens_to_text(toks).replace('\n', '\\n')
    print(s)


<|doc_begin|> <|doc_id_begin|> 7:: <|doc_offset_begin|> 9291\n    \nst, a we is:\n
aaaa in the
--'s
 a a of a the the the, Bosnia by from Marisi, the <|query_end|> <|query_end|>


In [28]:
# Look at the first position of the first probe chunk and list the vocabulary entries
# whose value reaches the threshold.
probs = chunks_pred[0][0]
# The index tensor is created on the same device as probs so the boolean mask can index it.
inds = torch.arange(len(probs), device=probs.device)
# Only this threshold is in effect; a preceding `prob_thres = 0.95` was overwritten
# immediately and has been removed.
prob_thres = 0.1
probs_mask = probs >= prob_thres
print(probs[probs_mask])
toks = inds[probs_mask]
# Decode every selected id on its own. The ids stay integer tensors rather than being
# converted to float through torch.Tensor([...]).
strs = [tokens_to_text(tok) for tok in toks]
print(strs)


tensor([], grad_fn=<IndexBackward0>)
[]


In [4]:
# Scratch calculation. NOTE(review): 2.71 looks like a rounded e, which would make this an
# approximation of exp(-0.01); confirm the intent or remove the cell.
1 / 2.71**0.01

0.9900800442452646