In [2]:
# Enable IPython's autoreload extension in mode 2 (reload all modules before executing code),
# so that edits to imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.dsfixed import DsLoader
from mllm.model.mllm_encdec import MllmEncdec
from mllm.model.mllm_ranker import MllmRanker
from mllm.model.config import create_mllm_encdec_cfg, create_mllm_ranker_cfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens



In [3]:
# Root of all local data; '$HOME' is expanded from the environment.
DATA_PATH = Path(os.path.expandvars('$HOME'), 'data')
TRAIN_ENCDEC_PATH = DATA_PATH.joinpath('train_mllm_encdec')
TRAIN_RANKER_PATH = DATA_PATH.joinpath('train_mllm_ranker')
DS_DIR_PATH = DATA_PATH.joinpath('wiki_20200501_en', 'ch_100_fixed')

# Training-run directories whose snapshots are loaded further down.
encdec_subdir = 'encdec-20240718_221554-wiki_20200501_en-ch_100_fixed'
# Other ranker runs that can be swapped in:
# ranker_subdir = 'ranker-20240724_230827-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240726_232850-wiki_20200501_en-ch_100_fixed'
ranker_subdir = 'ranker-20240722_225232-wiki_20200501_en-ch_100_fixed'

encdec_train_path = TRAIN_ENCDEC_PATH.joinpath(encdec_subdir)
ranker_train_path = TRAIN_RANKER_PATH.joinpath(ranker_subdir)

# Each run directory stores its snapshot as 'best.pth'.
encdec_snapshot_path = encdec_train_path.joinpath('best.pth')
ranker_snapshot_path = ranker_train_path.joinpath('best.pth')

In [4]:
# Batch composition passed to the dataset loader below.
docs_batch_size, max_chunks_per_doc = 5, 3

# Everything in this notebook runs on this device; use the commented line for GPU instead.
# device = torch.device('cuda')
device = torch.device('cpu')
print(device)

cpu


In [5]:
# GPT-2 tokenizer with model_max_length raised to 100000
# (presumably to allow tokenizing long texts without length warnings — confirm).
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', model_max_length=100000)

# Project-specific special tokens; only the indices of the padding and query markers are needed here.
tok_dict = gen_all_tokens(tokenizer)
pad_tok = tok_dict['pad'].ind
qbeg_tok = tok_dict['query_begin'].ind
qend_tok = tok_dict['query_end'].ind

ds_loader = DsLoader(
    ds_dir_path=DS_DIR_PATH,
    docs_batch_size=docs_batch_size,
    max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok,
    qbeg_tok=qbeg_tok,
    qend_tok=qend_tok,
    device=device,
)

# Shuffle the training split first, then the non-training split.
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)

# Length of one model input row: the raw chunk size for fixed-size datasets,
# otherwise whatever calc_max_inp_size derives from it.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [6]:
def tokten_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids into text, dropping padding tokens.

    The tensor is flattened first, so a multi-row tensor is decoded as one
    continuous string. Uses the notebook-level `tokenizer` and `pad_tok`.

    Args:
        tokens: Tensor of token ids of any shape.

    Returns:
        The decoded string.
    """
    tokens = tokens.flatten()
    tokens = tokens[tokens != pad_tok]
    # `tolist()` produces plain Python ints in a single call, whereas `list(tensor)`
    # creates one 0-dim tensor object per token and passes those to the tokenizer.
    return tokenizer.decode(tokens.tolist())

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    """Compare two arrays of the same shape.

    Args:
        x: First array.
        y: Second array.
        cosine: Selects the measure. False (default) returns `np.linalg.norm(x - y)`
            (the Euclidean distance for vectors). True returns the cosine *similarity*
            `sum(x * y) / (|x| * |y|)`; note that larger then means closer.

    Returns:
        A numpy scalar with the selected measure.
    """
    if cosine:
        dot = np.sum(x * y)
        scale = np.linalg.norm(x) * np.linalg.norm(y)
        return dot / scale
    return np.linalg.norm(x - y)

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize text and pack it into rows of `inp_len` token ids.

    Uses the notebook-level `tokenizer`, `inp_len`, `pad_tok` and `device`.

    Args:
        s: Text to tokenize.
        qbeg_tok: Optional id placed before the tokens. When given, `qend_tok`
            must be given too (checked with an assert).
        qend_tok: Optional id placed after the tokens. It is ignored when
            `qbeg_tok` is None.

    Returns:
        An int32 tensor on `device` of shape (n_rows, inp_len), where n_rows is the
        token count divided by `inp_len`, rounded up; unused positions hold `pad_tok`.
    """
    ids = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        assert qend_tok is not None
        ids = [qbeg_tok] + list(ids) + [qend_tok]

    # Number of rows needed to hold all ids (ceil division).
    n_ids = len(ids)
    n_rows, remainder = divmod(n_ids, inp_len)
    if remainder:
        n_rows += 1

    # Start from an all-padding flat buffer and overwrite its head with the ids.
    flat = np.full((n_rows * inp_len, ), pad_tok, dtype=np.int32)
    flat[:n_ids] = ids
    return torch.from_numpy(flat).to(device).reshape(n_rows, inp_len)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    """Print a table comparing every document embedding with every target embedding.

    One line is printed per row of `docs_embs`. It holds one value per row of
    `target_embs` (computed by the notebook-level `distance`, formatted with six
    decimals) followed by 'T' or 'F' taken from `target_mask` at that row index.

    Args:
        target_embs: Target (query) embeddings, iterated along the first dimension.
        docs_embs: Document embeddings, iterated along the first dimension.
        target_mask: Indexable of truthy/falsy flags, one per row of `docs_embs`.
        cosine: Forwarded to `distance` to select the measure.
    """
    # Convert once, outside the loops. `.cpu()` is a no-op for CPU tensors and is
    # required before `.numpy()` when the embeddings live on a CUDA device.
    target_np = target_embs.detach().cpu().numpy()
    docs_np = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs_np):
        for target_emb in target_np:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [7]:
# Configuration of the encoder-decoder model: a single encoder layer and a single decoder
# layer, 256-dim embeddings/model width, 8 heads, 1024-dim inner layer.
model_encdec_cfg = create_mllm_encdec_cfg(
    n_vocab=len(tokenizer),
    d_word_wec=256,
    inp_len=inp_len,
    enc_n_layers=1,
    dec_n_layers=1,
    n_heads=8,
    d_model=256,
    d_inner=1024,
    pad_idx=pad_tok,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)
# Instantiate the model and move it to the selected device.
model_encdec = MllmEncdec(model_encdec_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897278 1.5340246e-06 0.010897281
vocab_encoder.layer_norm.weight (256,) -0.099695794 0.003109554 0.099689685
vocab_encoder.layer_norm.bias (256,) -0.09886067 -0.00428751 0.09581536
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.108250685 -0.00019431567 0.10824986
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.108249344 -0.00041340568 0.10825185
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.108251214 0.00022752491 0.108246066
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10825214 6.3473555e-05 0.10824952
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09977064 -0.0030699032 0.09987781
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09984564 -0.00022852229 0.099203825
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846523 -6.5795057e-06 0.068464
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09990223 0.0012580123 0.09982888
encoder.layer_stack.0.pos_ffn.w_2.weight

In [8]:
# Restore the trained encoder-decoder weights.
# `map_location=device` remaps the stored tensors onto the configured device, so a
# checkpoint that was saved from a GPU also loads on a CPU-only machine.
checkpoint_encdec = torch.load(encdec_snapshot_path, map_location=device)
model_encdec.load_state_dict(checkpoint_encdec['model'])
# Evaluation mode (e.g. dropout is disabled) since the model is only used for inference here.
model_encdec.eval()
# The checkpoint dict is no longer needed once the weights are copied into the model.
del checkpoint_encdec

In [9]:
# Smoke test of the helper: without query markers, a short string gives a single row
# whose tail is filled with `pad_tok`.
text_to_tokens('Hello there')

tensor([[15496,   612, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267]],
       dtype=torch.int32)

In [10]:
# Take training batch #10 and materialize it as tensors: document chunks, target (query)
# chunks and a boolean mask with one entry per document chunk.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Show the shapes and the mask (True entries presumably mark the chunks of the document
# the target was taken from — confirm against DsLoader).
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([14, 100]),
 torch.Size([3, 100]),
 tensor([False, False, False, False, False, False, False, False, False,  True,
          True,  True, False, False]))

In [11]:
# Decode the target chunks back to text; all rows are flattened into one string with padding removed.
s_target = tokten_to_text(target_chunks)
s_target

"<|query_begin|>  club is best known for its men's and women's football teams.\n\nKigwancha's men presently play in the DPR Korea Premier Football League, and won several championships in the late 1990s. The club finished third in 2006 season. They have taken part in continental competition once, finishing second in its group in the group stage of the 2017 AFC Cup.\n\nRivalries\nSharing the Yanggakdo Stadium with them, Kigwancha have a rivalry with Sobaeksu.\n\nCurrent squad\n\nKnown players (including former players)\n\n Lee Chang-myung\n Pak Chol-ryong\n Pak Kwang-ryong\n\nManagers\n Ku Jong-nam (before 2014)\n Han Won-chol (since 2014)\n\nContinental history\n\nAFC clubs ranking\n\nAchievements\nDPR Korea League: 9\n 1996, 1997, 1998, 1999, 2000, 2016\n 1995, 2006, 2012\n\nHwaebul Cup: 2\n 2015\n4th 2017\n\nMan'gyŏngdae Prize: 5\n 2004, 2005\n 2015, 2016\n 2014\n\nPaektusan Prize: 1\n 2012\n\nPoch'ŏnbo Torch Prize: 3\n 2007\n 2010, 2016\n\nOther Sports <|query_end|>"

In [12]:
# Print a one-line preview of every document chunk in the batch:
# the first 200 decoded characters with newlines shown as a literal '\n'.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 4020987 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> South Carolina Highway 381 <|doc_title_end|> <|doc_body_begin|> South Carolina Highw
<|doc_begin|> <|doc_id_begin|> 4020987 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  rural highway, it connects the towns of Clio and McColl.\n\nHistory\nThe highway was established in 1930 
<|doc_begin|> <|doc_id_begin|> 4020987 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  extension was dropped while it extended north again to the North Carolina state line. In 1940, SC 381
<|doc_begin|> <|doc_id_begin|> 3927047 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> 2011 Roma Open <|doc_title_end|> <|doc_body_begin|> The 2011 Roma Open was a profess
<|doc_begin|> <|doc_id_begin|> 3927047 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  Simone Bolelli\n  Nicolás Massú\n  Matteo Trevisan\n  Simone Vagnozzi\n\nThe following players 

In [13]:
# Encode target and document chunks with the encoder-decoder model.
# The embeddings are only inspected, never back-propagated through, so autograd
# bookkeeping is switched off for the forward passes.
with torch.no_grad():
    target_embs = model_encdec.run_enc_emb(target_chunks)
    docs_embs = model_encdec.run_enc_emb(docs_chunks)

In [14]:
# Measure used by print_dist: True -> cosine similarity, False -> Euclidean distance.
# cosine = False
cosine = True
# One row per document chunk, one column per target chunk, then the mask flag (T/F).
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.228798 0.331661 0.253925 F
0.248907 0.178724 0.101498 F
0.166845 0.238102 0.259646 F
0.253971 0.188307 0.211456 F
0.242095 0.332769 0.314004 F
0.284843 0.367039 0.321652 F
0.116524 0.161117 0.189159 F
0.081947 0.117586 0.165844 F
0.177944 0.138636 0.161296 F
0.481087 0.391401 0.321746 T
0.216637 0.510886 0.401046 T
0.364164 0.500878 0.547548 T
0.110911 0.140458 0.211321 F
0.171793 0.139304 0.205401 F


In [15]:
# Length of one model input row: the raw chunk size for fixed-size datasets,
# otherwise whatever calc_max_inp_size derives from it.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

# Configuration of the ranker model: one level, a single encoder layer and a single decoder
# layer, 256-dim embeddings/model width, 8 heads with 32-dim keys and values, 1024-dim inner layer.
model_ranker_cfg = create_mllm_ranker_cfg(
    n_vocab=len(tokenizer),
    inp_len=inp_len,
    d_word_wec=256,
    n_levels=1,
    enc_n_layers=1,
    dec_n_layers=1,
    n_heads=8,
    d_k=32,
    d_v=32,
    d_model=256,
    d_inner=1024,
    pad_idx=pad_tok,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)

# Instantiate the model and move it to the selected device.
model_ranker = MllmRanker(model_ranker_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897281 6.0555874e-07 0.010897281
vocab_encoder.layer_norm.weight (256,) -0.0975135 0.0016215969 0.09963961
vocab_encoder.layer_norm.bias (256,) -0.09937843 -0.0011398041 0.09944623
encoders.0.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.108243056 -0.0001204217 0.10825152
encoders.0.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.108252995 1.6006292e-05 0.10825282
encoders.0.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.108252876 -0.000107773216 0.10825194
encoders.0.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108245894 0.00018707474 0.10824753
encoders.0.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09980541 -0.0032008472 0.09960397
encoders.0.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09979291 -0.00087650074 0.09982568
encoders.0.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846457 -2.4778845e-05 0.068464525
encoders.0.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09989893 -0.0024799684 0.09989349
encoders.0.la

In [16]:
# Restore the trained ranker weights.
# `map_location=device` remaps the stored tensors onto the configured device, so a
# checkpoint that was saved from a GPU also loads on a CPU-only machine.
checkpoint_ranker = torch.load(ranker_snapshot_path, map_location=device)
model_ranker.load_state_dict(checkpoint_ranker['model'])
# Evaluation mode (e.g. dropout is disabled) since the model is only used for inference here.
model_ranker.eval()
# The checkpoint dict is no longer needed once the weights are copied into the model.
del checkpoint_ranker

In [30]:
# Take training batch #11 for the ranker experiments and materialize it as tensors:
# document chunks, target (query) chunks and a boolean mask with one entry per document chunk.
i = 11
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Show the shapes and the mask (True entries presumably mark the chunks of the document
# the target was taken from — confirm against DsLoader).
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([15, 100]),
 torch.Size([3, 100]),
 tensor([False, False, False, False, False, False, False, False, False, False,
         False, False,  True,  True,  True]))

In [31]:
# Decode the target chunks back to text; all rows are flattened into one string with padding removed.
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|> The Brothers Karamazov (1921 film)The Brothers Karamazov () is a 1921 German silent drama film directed by Carl Froelich and an uncredited Dimitri Buchowetzki and starring Fritz Kortner, Bernhard Goetzke, and Emil Jannings. It is an adaptation of the 1880 novel The Brothers Karamazov by Fyodor Dostoevsky.\n\nCast\n\nReferences\n\nBibliography\n\nExternal links\n\nCategory:1921 films\nCategory:German films\nCategory:Films of the Weimar Republic\nCategory:German silent feature films\nCategory:German historical drama films\nCategory:1920s historical drama films\nCategory:Films directed by Carl Froelich\nCategory:Films directed by Dimitri Buchowetzki\nCategory:Films based on The Brothers Karamazov\nCategory:Films set in Russia\nCategory:Films set in the 19th century\nCategory:UFA films\nCategory:German black-and-white films <|query_end|>'

In [19]:
# Print a one-line preview of every document chunk in the current batch:
# the first 200 decoded characters with newlines shown as a literal '\n'.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 4020987 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> South Carolina Highway 381 <|doc_title_end|> <|doc_body_begin|> South Carolina Highw
<|doc_begin|> <|doc_id_begin|> 4020987 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  rural highway, it connects the towns of Clio and McColl.\n\nHistory\nThe highway was established in 1930 
<|doc_begin|> <|doc_id_begin|> 4020987 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  extension was dropped while it extended north again to the North Carolina state line. In 1940, SC 381
<|doc_begin|> <|doc_id_begin|> 3927047 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> 2011 Roma Open <|doc_title_end|> <|doc_body_begin|> The 2011 Roma Open was a profess
<|doc_begin|> <|doc_id_begin|> 3927047 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  Simone Bolelli\n  Nicolás Massú\n  Matteo Trevisan\n  Simone Vagnozzi\n\nThe following players 

In [32]:
# Encode target and document chunks with the ranker model.
# The embeddings are only inspected, never back-propagated through, so autograd
# bookkeeping is switched off for the forward passes.
with torch.no_grad():
    target_embs = model_ranker.run_enc_emb(target_chunks)
    docs_embs = model_ranker.run_enc_emb(docs_chunks)

In [33]:
# Measure used by print_dist: True -> cosine similarity, False -> Euclidean distance.
# cosine = False
cosine = True
# One row per document chunk, one column per target chunk, then the mask flag (T/F).
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

-0.024947 -0.161539 -0.289531 F
-0.018159 -0.082144 -0.175686 F
-0.014066 -0.031177 -0.111339 F
0.118753 -0.006749 -0.223778 F
0.053832 -0.056969 -0.280484 F
0.066218 0.087920 0.025193 F
-0.099301 -0.166709 -0.241161 F
-0.053571 -0.133196 -0.170445 F
-0.094052 -0.194352 -0.287038 F
0.120644 0.061044 -0.034860 F
-0.079752 -0.088093 -0.194901 F
0.096795 0.097061 0.144068 F
0.633291 0.509644 0.414697 T
0.638870 0.638089 0.629546 T
0.408674 0.488750 0.693010 T


In [41]:
# Free-text query compared against the document embeddings of the current batch.
# Exactly one `txt` is active; the others are alternatives to try.
# txt = 'Hong Kong 1987'
# txt = 'Climate Classification system, Mays Landing has a humid subtropical climate, abbreviated "Cfa"'
# txt = 'Climate Classification system'
# txt = 'War and Peace'
# txt = 'Bandar Express, Ichhamati Express and Benapole Express'
# txt = 'Truus Kerkmeester'
txt = 'The Brothers Karamazov'
cosine = True
# Wrap the query in the query begin/end marker tokens and pad it to full rows.
txt_tokens = text_to_tokens(txt, qbeg_tok=qbeg_tok, qend_tok=qend_tok)
# txt_tokens = txt_tokens.repeat(3, 1)
print(txt_tokens.shape)
# Inference only: no autograd graph is needed for the query embedding.
with torch.no_grad():
    txt_embs = model_ranker.run_enc_emb(txt_tokens)
# `docs_embs` and `target_mask` are the ones produced for the current batch in earlier cells.
print_dist(txt_embs, docs_embs, target_mask, cosine=cosine)


torch.Size([1, 100])
-0.207215 F
-0.109966 F
-0.107605 F
-0.067777 F
-0.095617 F
0.099357 F
-0.137719 F
-0.102855 F
-0.213599 F
-0.240572 F
-0.364570 F
-0.110474 F
0.080862 T
0.127579 T
0.349849 T


In [42]:
# NOTE(review): a cell displays only its last expression, so the two decoded strings below
# are computed and then discarded; wrap them in print()/display() if they are meant to be shown.
tokten_to_text(docs_chunks[-2])
tokten_to_text(txt_tokens)
# Shape of the tokenized free-text query.
txt_tokens.shape

torch.Size([1, 100])

In [43]:
# Run the full ranker on the free-text query against all document chunks of the batch
# (presumably one relevance score per document chunk — confirm in MllmRanker.forward).
# Inference only, so no autograd graph is built.
with torch.no_grad():
    rank = model_ranker(txt_tokens, docs_chunks)
rank

tensor([[[1.5967e-05],
         [1.0371e-05],
         [4.9786e-06],
         [1.5882e-05],
         [5.9931e-06],
         [6.3268e-06],
         [4.7893e-06],
         [5.6740e-06],
         [2.9828e-06],
         [3.9060e-06],
         [3.5584e-06],
         [8.1185e-06],
         [5.9206e-04],
         [2.4064e-05],
         [1.9815e-05]]], grad_fn=<SliceBackward0>)

In [40]:
# Run the full ranker on the batch's own target chunks against all document chunks
# (presumably one relevance score per document chunk — confirm in MllmRanker.forward).
# Inference only, so no autograd graph is built.
with torch.no_grad():
    rank = model_ranker(target_chunks, docs_chunks)
rank

tensor([[[6.1772e-05],
         [4.4597e-04],
         [3.3278e-04],
         [3.0358e-04],
         [3.3248e-04],
         [2.7358e-04],
         [1.8146e-04],
         [1.9439e-04],
         [1.9108e-04],
         [2.1113e-03],
         [2.8011e-04],
         [1.3459e-03],
         [1.3094e-02],
         [8.0283e-01],
         [1.5070e-01]]], grad_fn=<SliceBackward0>)

In [6]:
# Scratch check of np.random.randint: `high` is exclusive, so randint(1, 2, ...) can only yield 1.
# NOTE(review): numpy is already imported in the imports cell at the top; this re-import is redundant.
import numpy as np
np.random.randint(1, 2, size=20)

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [31]:
# Scratch: draw one random integer from [0, 6) (upper bound exclusive).
# No seed is set, so the value differs between runs.
np.random.randint(0, 6)

3