In [1]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell runs, so edits to the project package are picked up live.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import TOKENIZER_CFG_FNAME, ENCDEC_MODEL_CFG_FNAME, RANKER_MODEL_CFG_FNAME
from mllm.model.mllm_encdec import MllmEncdecLevel
from mllm.model.mllm_ranker import MllmRankerLevel
from mllm.config.model import TokenizerCfg, MllmEncdecCfg, MllmRankerCfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens, ChunkTokenizer, tokenizer_from_config



In [14]:
# Root directory holding datasets and training runs. Path.home() resolves the
# home directory on every platform, whereas expanding '$HOME' silently leaves
# the literal string in place when the variable is not set.
DATA_PATH = Path.home() / 'data'
# Dataset directory handed to the dataset loader.
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

# Previous training runs, kept for quick switching:
# TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
# TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qrels'
# encdec_subdir = 'encdec-20240718_221554-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240903_215749-msmarco-fever'

# Training run directories currently in use.
TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec_0'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qrels_0'
encdec_subdir = 'encdec-lvl0-20241029_140645-wiki_20200501_en-ch_100_fixed-enc-lrs3-embmatFalse-d256-h8-dec-lrs3-seqlen100-d256-h8-vocdecTrue'
ranker_subdir = 'ranker-lvl0-20241030_230226-msmarco-fever-enc-lrs3-embmatFalse-d256-h8-dec-lrs0-d256-h8'

# Per-run artifacts: weights snapshot, tokenizer config and model config.
encdec_train_path = TRAIN_ENCDEC_PATH / encdec_subdir
ranker_train_path = TRAIN_RANKER_PATH / ranker_subdir
encdec_snapshot_fpath = encdec_train_path / 'best.pth'
ranker_snapshot_fpath = ranker_train_path / 'best.pth'
encdec_tkz_cfg_fpath = encdec_train_path / TOKENIZER_CFG_FNAME
ranker_tkz_cfg_fpath = ranker_train_path / TOKENIZER_CFG_FNAME
encdec_model_cfg_fpath = encdec_train_path / ENCDEC_MODEL_CFG_FNAME
ranker_model_cfg_fpath = ranker_train_path / RANKER_MODEL_CFG_FNAME

In [15]:
# Batch geometry passed to the dataset loader.
docs_batch_size, max_chunks_per_doc = 5, 3
# Level argument passed to both model constructors.
model_level = 0

# Change to 'cuda' to run on a GPU.
device_name = 'cpu'
device = torch.device(device_name)
print(device)

cpu


In [16]:
# Load the tokenizer config saved with each training run; both runs must agree,
# because a single tokenizer is shared by the two models below.
encdec_tkz_cfg = parse_yaml_file_as(TokenizerCfg, encdec_tkz_cfg_fpath)
ranker_tkz_cfg = parse_yaml_file_as(TokenizerCfg, ranker_tkz_cfg_fpath)
assert encdec_tkz_cfg == ranker_tkz_cfg

tokenizer = tokenizer_from_config(encdec_tkz_cfg)
tok_dict = encdec_tkz_cfg.custom_tokens
# Indices of the special tokens used for padding and for framing queries.
pad_tok = tok_dict['pad'].ind
qbeg_tok = tok_dict['query_begin'].ind
qend_tok = tok_dict['query_end'].ind

In [17]:
# Build the dataset loader, using the special-token indices from the tokenizer config.
ds_loader = WikiDsLoader(
    ds_path=DS_DIR_PATH,
    docs_batch_size=docs_batch_size,
    max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok,
    qbeg_tok=qbeg_tok,
    qend_tok=qend_tok,
    device=device,
)
# Shuffle both splits.
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)

# Row length of token tensors: the chunk size itself for fixed-size datasets,
# otherwise whatever calc_max_inp_size derives from it.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [18]:
def tokens_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids into text, ignoring padding.

    The tensor is flattened first, so a multi-row input is decoded as one
    continuous sequence. Relies on the notebook-level `pad_tok` and `tokenizer`.

    Args:
        tokens: Token ids of any shape.

    Returns:
        The decoded string.
    """
    flat = tokens.flatten()
    # Boolean mask keeps every id that is not the padding token.
    kept = flat[flat != pad_tok]
    return tokenizer.decode(list(kept))

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False) -> float:
    """Compare two vectors by Euclidean distance or cosine similarity.

    Args:
        x: First vector.
        y: Second vector, broadcast-compatible with `x`.
        cosine: When False (default), return the Euclidean distance
            `||x - y||` (smaller means closer). When True, return the cosine
            *similarity* `x·y / (||x|| ||y||)` (larger means closer), despite
            the function's name.

    Returns:
        The requested measure. In cosine mode, 0.0 is returned when either
        vector has zero norm, instead of dividing by zero and producing NaN.
    """
    if not cosine:
        return np.linalg.norm(x - y)
    denom = np.linalg.norm(x) * np.linalg.norm(y)
    if denom == 0:
        # Cosine similarity is undefined for a zero vector; report "no similarity".
        return 0.0
    return np.sum(x * y) / denom

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize a string into padded rows of `inp_len` token ids.

    Relies on the notebook-level `tokenizer`, `inp_len`, `pad_tok` and `device`.

    Args:
        s: Text to tokenize.
        qbeg_tok: Optional id placed before the tokens. When given,
            `qend_tok` must be given as well.
        qend_tok: Optional id placed after the tokens; only used together
            with `qbeg_tok`.

    Returns:
        An int32 tensor of shape `(n_rows, inp_len)` on `device`, where
        `n_rows` is the number of rows needed to hold all tokens; unused
        trailing positions hold `pad_tok`.
    """
    ids = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        assert qend_tok is not None
        ids = [qbeg_tok] + list(ids) + [qend_tok]

    total = len(ids)
    # Ceiling division: a partially filled last row still counts as a row.
    n_rows, remainder = divmod(total, inp_len)
    if remainder > 0:
        n_rows += 1

    padded = np.full((n_rows * inp_len, ), pad_tok, dtype=np.int32)
    padded[:total] = ids
    return torch.from_numpy(padded).to(device).reshape(n_rows, inp_len)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True) -> None:
    """Print a table comparing every document embedding with every target embedding.

    One output line per row of `docs_embs`: the `distance(...)` value against
    each row of `target_embs` (six decimals, space separated), followed by
    'T' or 'F' taken from `target_mask` at the document's index.

    Args:
        target_embs: Target embeddings, iterated along the first dimension.
        docs_embs: Document embeddings, iterated along the first dimension.
        target_mask: Indexable by document position; truthy entries print 'T'.
        cosine: Forwarded to `distance`; True selects cosine similarity.
    """
    # .cpu() is required before .numpy() when the tensors live on a CUDA
    # device; it is a no-op for CPU tensors. Converted once, outside the loops.
    targets_np = target_embs.detach().cpu().numpy()
    docs_np = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs_np):
        for target_emb in targets_np:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [19]:
# Build the encoder-decoder model from its saved config and restore the best snapshot.
model_encdec_cfg = parse_yaml_file_as(MllmEncdecCfg, encdec_model_cfg_fpath)
model_encdec = MllmEncdecLevel(model_encdec_cfg, model_level).to(device)
# map_location remaps stored tensors onto the selected device, so a checkpoint
# saved on a GPU also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint_encdec = torch.load(encdec_snapshot_fpath, map_location=device)
# NOTE(review): strict=False ignores missing/unexpected keys without raising — confirm this is intended.
model_encdec.load_state_dict(checkpoint_encdec['model'], strict=False)
model_encdec.eval()
# Trailing None keeps the cell from displaying the model repr.
None

vocab_encoder.src_word_emb.weight (50271, 256) -0.010897174 1.1518982e-09 0.010897171
vocab_encoder.layer_norm.weight (256,) -0.09979006 -0.004591978 0.09982232
vocab_encoder.layer_norm.bias (256,) -0.0992771 0.0019012922 0.099865995
vocab_decoder.word_prj.weight (50271, 256) -0.010897173 -5.5364016e-08 0.010897173
encoder.a_em () 0.07837709 0.07837709 0.07837709
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.1082476 0.00034785352 0.10825298
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825254 -0.00014614759 0.10825209
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10825053 -0.00016267816 0.108252786
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.1082332 0.00021137326 0.10824482
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09671755 0.003842628 0.09890232
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09996575 -0.008670668 0.09937097
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.068464905 -6.7586116e-06 0.0684650

In [20]:
# Sanity check: without query markers, a short string becomes one row padded with pad_tok.
text_to_tokens('Hello there')

tensor([[15496,   612, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267]],
       dtype=torch.int32)

In [21]:
# Fetch one training batch by index and unpack its tensors.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Display the tensor shapes and the mask marking which document chunks belong to the target.
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([12, 100]),
 torch.Size([1, 100]),
 tensor([False, False, False, False, False, False, False, False,  True, False,
         False, False]))

In [22]:
# Decode the target chunks (padding stripped) to read the query text.
s_target = tokens_to_text(target_chunks)
s_target

'<|query_begin|> OsterbergerOsterberger is a surname. Notable people with the surname include:\n\nAndré Osterberger (1920–2009), French hammer thrower\nKenneth Osterberger (1930–2016), American politician <|query_end|>'

In [23]:
# Preview each document chunk: first 200 decoded characters, with newlines
# escaped so every chunk stays on a single output line.
for toks in docs_chunks:
    s = tokens_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 3827117 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Clark Township, Atchison County, Missouri <|doc_title_end|> <|doc_body_begin|> Clark
<|doc_begin|> <|doc_id_begin|> 3827117 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  and Pleasant Ridge.\n\nThe streams of Cow Branch, Old Channel Nishnabotna River and Rock Creek run throu
<|doc_begin|> <|doc_id_begin|> 3827117 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  in Missouri <|doc_body_end|> <|doc_end|>
<|doc_begin|> <|doc_id_begin|> 3207145 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  Tokyo. The club rapidly grew, climbing the Japanese football pyramid in five years.\n\nThe 2011 Tohoku 
<|doc_begin|> <|doc_id_begin|> 3207145 <|doc_id_end|> <|doc_offset_begin|> 273 <|doc_offset_end|>  of using their own stadium, which was used as a shelter, Cobaltore players - especially the Japanese 
<|doc_begin|> <|doc_id_begin|> 3207145 <|doc_id_end|

In [24]:
# Embed target and document chunks via model_encdec.run_enc_emb.
# Inference only, so no_grad avoids building an autograd graph.
with torch.no_grad():
    target_embs = model_encdec.run_enc_emb(target_chunks)
    docs_embs = model_encdec.run_enc_emb(docs_chunks)

In [25]:
# Set to False to print Euclidean distances instead of cosine similarities.
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.228719 F
0.216974 F
0.450167 F
0.309689 F
0.013322 F
0.088775 F
0.185869 F
0.514710 F
0.360943 T
0.245026 F
0.129973 F
0.220127 F


In [26]:
# Same derivation of the token-row length as in the dataset-loader cell.
inp_len = ds_loader.emb_chunk_size if ds_loader.fixed_size else calc_max_inp_size(ds_loader.emb_chunk_size)
# Build the ranker model from its saved config and restore the best snapshot.
model_ranker_cfg = parse_yaml_file_as(MllmRankerCfg, ranker_model_cfg_fpath)
model_ranker = MllmRankerLevel(model_ranker_cfg, model_level).to(device)
# map_location remaps stored tensors onto the selected device, so a checkpoint
# saved on a GPU also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint_ranker = torch.load(ranker_snapshot_fpath, map_location=device)
model_ranker.load_state_dict(checkpoint_ranker['model'])
model_ranker.eval()
# Trailing None keeps the cell from displaying the model repr.
None

vocab_encoder.src_word_emb.weight (50271, 256) -0.010897174 1.0176591e-06 0.010897173
vocab_encoder.layer_norm.weight (256,) -0.099841975 -0.0012195851 0.09988719
vocab_encoder.layer_norm.bias (256,) -0.099575795 0.0024385962 0.09852242
encoder.a_em () 0.031294096 0.031294096 0.031294096
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10825192 -0.0002583892 0.10825303
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10824997 0.00010487764 0.10825076
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10825076 -0.0001962418 0.10825241
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108247586 5.2482348e-05 0.10825072
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09981437 -0.0020973412 0.09952064
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09966959 -0.006852148 0.09864878
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.068464585 1.5866732e-05 0.068464905
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09997828 -0.00064508966

In [48]:
# Fetch one training batch by index and unpack its tensors.
# NOTE(review): same code as the earlier batch cell, yet the saved output differs
# (3 target chunks here vs. 1 there) — presumably the loader state changed in between; confirm.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Display the tensor shapes and the mask marking which document chunks belong to the target.
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([12, 100]),
 torch.Size([3, 100]),
 tensor([False, False, False, False, False, False, False, False, False,  True,
          True,  True]))

In [49]:
# Decode the target chunks (padding stripped) to read the query text.
s_target = tokens_to_text(target_chunks)
s_target

'<|query_begin|> Ryan McCurdyRyan McCurdy (born 23 May 1991) is a professional footballer who plays as a centre-back.\n\nClub career\n\nKingston FC\nIn 2012, McCurdy signed with Kingston FC in the Canadian Soccer League, and in 2013 featured in the CSL Championship final.\n\nKingston Clippers\nIn 2015 and 2016, McCurdy played for League1 Ontario side Kingston Clippers. In 2016, he made eight appearances for Kingston in league play.\n\nVictoria Higlanders\nIn 2017, McCurdy captained PDL side Victoria Highlanders, making fourteen appearances.\n\nPacific FC\nOn 17 April 2019, McCurdy signed with Canadian Premier League side Pacific FC. On 29 April 2019, McCurdy made his professional debut for Pacific as a substitute in a 1–0 win over HFX Wanderers. He made a total of ten league appearances for Pacific that season. On 4 November 2019, the club announced it would not be offering McCurdy a new contract for the following season.\n\nReferences\n\nExternal links\n\nCategory:1991 births\nCategor

In [57]:
# Preview each document chunk: first 200 decoded characters, with newlines
# escaped so every chunk stays on a single output line.
for toks in docs_chunks:
    s = tokens_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 3827117 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Clark Township, Atchison County, Missouri <|doc_title_end|> <|doc_body_begin|> Clark
<|doc_begin|> <|doc_id_begin|> 3827117 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  and Pleasant Ridge.\n\nThe streams of Cow Branch, Old Channel Nishnabotna River and Rock Creek run throu
<|doc_begin|> <|doc_id_begin|> 3827117 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  in Missouri <|doc_body_end|> <|doc_end|>
<|doc_begin|> <|doc_id_begin|> 3207145 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  Tokyo. The club rapidly grew, climbing the Japanese football pyramid in five years.\n\nThe 2011 Tohoku 
<|doc_begin|> <|doc_id_begin|> 3207145 <|doc_id_end|> <|doc_offset_begin|> 273 <|doc_offset_end|>  of using their own stadium, which was used as a shelter, Cobaltore players - especially the Japanese 
<|doc_begin|> <|doc_id_begin|> 3207145 <|doc_id_end|

In [58]:
# Embed target and document chunks via model_ranker.run_enc_emb.
# Inference only, so no_grad avoids building an autograd graph.
with torch.no_grad():
    target_embs = model_ranker.run_enc_emb(target_chunks)
    docs_embs = model_ranker.run_enc_emb(docs_chunks)

In [59]:
# Set to False to print Euclidean distances instead of cosine similarities.
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.259929 0.198473 0.211599 F
0.214228 0.204655 0.272802 F
0.147970 0.253623 0.364448 F
0.325724 0.514725 0.540035 F
0.350869 0.493751 0.501141 F
0.576240 0.687341 0.609541 F
0.253276 0.273880 0.550044 F
0.249219 0.347751 0.539483 F
0.490651 0.303308 0.554951 F
0.807451 0.773970 0.613117 T
0.656386 0.777317 0.622798 T
0.631666 0.702860 0.766746 T


In [66]:
# Free-text query to compare against the document embeddings. Only one
# assignment is active; the other strings tried are kept for quick switching.
# txt = 'El Charco del Cura Reservoir'
# txt = 'Théodore Eugène César Ruyssen'
# txt = 'Théodore Ruyssen'
# txt = 'orders in certain situations, and both required trade-offs'
# txt = 'Tritordeum hybrid crop'
# txt = 'etienne Lecroart ddd'
# txt = 'Ryan McCurdy'
# txt = 'Osterberger'
# txt = 'The graph sandwich problem for property Π is defined as follows:'
txt = 'Inge Glashörster'

cosine = True
# Frame the query with the query-begin/end tokens before padding it into rows.
txt_tokens = text_to_tokens(txt, qbeg_tok=qbeg_tok, qend_tok=qend_tok)
# txt_tokens = txt_tokens.repeat(3, 1)
print(txt_tokens.shape)
# Inference only, so no_grad avoids building an autograd graph.
with torch.no_grad():
    txt_embs = model_ranker.run_enc_emb(txt_tokens)
print_dist(txt_embs, docs_embs, target_mask, cosine=cosine)


torch.Size([1, 100])
0.126053 F
0.104639 F
0.176890 F
0.075464 F
0.085869 F
0.099848 F
0.568349 F
0.253633 F
0.432618 F
0.131461 T
0.046272 T
0.021291 T


In [67]:

# Run the ranker's query inference on the document chunks and the query tokens
# (presumably one score per document chunk — the saved output lists 12 values; confirm).
rank = model_ranker.run_qs_infer(docs_chunks, txt_tokens)
# Fixed six-decimal formatting keeps very small scores readable.
rank_str = [f'{r:.06f}' for r in rank.flatten()]
rank_str

['0.000310',
 '0.000183',
 '0.002241',
 '0.000254',
 '0.000322',
 '0.000682',
 '0.693973',
 '0.005934',
 '0.294509',
 '0.001169',
 '0.000261',
 '0.000161']

In [50]:
# NOTE(review): this cell's execution count (50) is lower than the preceding cell's (67), and
# the saved shape [14, 100] disagrees with the 12-chunk batch shown earlier — stale output from another run.
docs_chunks.shape

torch.Size([14, 100])

In [31]:
# Forward pass of the ranker with the free-text query against the document chunks.
# Inference only: no_grad avoids building an autograd graph (the previously
# saved output carried a grad_fn).
with torch.no_grad():
    rank = model_ranker(txt_tokens, docs_chunks)
rank

tensor([[0.8737, 0.8907, 0.7371, 0.5201, 0.4101, 0.8237, 0.7342, 0.8183, 0.8694,
         0.6263, 0.6608, 0.6087, 0.2565, 0.3135, 0.3250]],
       grad_fn=<SigmoidBackward0>)

In [35]:
# Forward pass of the ranker with the batch's own target chunks as the query.
# Inference only: no_grad avoids building an autograd graph (the previously
# saved output carried a grad_fn).
with torch.no_grad():
    rank = model_ranker(target_chunks, docs_chunks)
rank

tensor([[2.6538e-21, 3.7294e-11, 2.4412e-19, 3.4020e-23, 7.5630e-20, 3.8663e-25,
         1.3986e-18, 7.2614e-22, 1.0000e+00, 1.0000e+00, 1.0000e+00, 4.7166e-09,
         1.9997e-11, 1.3684e-08]], grad_fn=<SigmoidBackward0>)

In [31]:
# Scratch check of randint bounds: the upper bound is exclusive, so values fall in 0..5.
np.random.randint(0, 6)

3

In [2]:
import torch

In [4]:
# Two CPU device handles and one CUDA handle, used to probe torch.device equality.
device1, device2, device3 = (torch.device(kind) for kind in ('cpu', 'cuda', 'cpu'))

In [5]:
# Devices compare by value: the cpu/cuda pairs differ, the two cpu handles are equal.
print(device1 == device2, device2 == device3, device1 == device3, sep='\n')


False
False
True


In [7]:
# The device kind is exposed on `.type` as a plain string.
device1.type, type(device1.type)

('cpu', str)

In [11]:
import numpy as np
# Three equal-length 1-D arrays: two constant float rows and one integer row.
toks = [
    np.full(3, 0.5),
    np.full(3, 1.7),
    np.array([-1, 7, 33]),
]
# Stack them as the rows of a 2-D array.
np.stack(toks, axis=0)

array([[ 0.5,  0.5,  0.5],
       [ 1.7,  1.7,  1.7],
       [-1. ,  7. , 33. ]])

In [12]:
# np.array on a list of equal-length 1-D arrays yields the same 2-D array as np.stack(..., axis=0).
np.array(toks)

array([[ 0.5,  0.5,  0.5],
       [ 1.7,  1.7,  1.7],
       [-1. ,  7. , 33. ]])

In [13]:
# Confirm numerically that both constructions agree.
np.allclose(np.array(toks), np.stack(toks, axis=0))

True