In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.model.mllm_encdec import MllmEncdec
from mllm.model.mllm_ranker import MllmRanker
from mllm.exp.cfg import create_mllm_encdec_cfg, create_mllm_ranker_cfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens



In [3]:
# Root of all local data; resolved from the HOME environment variable.
DATA_PATH = Path(os.path.expandvars('$HOME')) / 'data'
# Directories that hold training runs (one sub-directory per run).
TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
# TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qs'
# Pre-chunked Wikipedia dataset directory.
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

encdec_subdir = 'encdec-20240718_221554-wiki_20200501_en-ch_100_fixed'
# Ranker run selection: exactly one assignment is active. The alternatives are
# kept as comments so a different run can be selected by swapping lines.
# ranker_subdir = 'ranker-20240722_225232-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240724_230827-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240726_232850-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240730_213328-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240806_221913-msmarco'
ranker_subdir = 'ranker-20240814_212415-msmarco'

# Run directories and the 'best.pth' snapshot inside each of them.
encdec_train_path = TRAIN_ENCDEC_PATH / encdec_subdir
ranker_train_path = TRAIN_RANKER_PATH / ranker_subdir
encdec_snapshot_path = encdec_train_path / 'best.pth'
ranker_snapshot_path = ranker_train_path / 'best.pth'

In [4]:
# Batch geometry passed to the dataset loader.
docs_batch_size, max_chunks_per_doc = 5, 3

# Compute device for models and tensors; swap the lines below to use the GPU.
device = torch.device('cpu')
# device = torch.device('cuda')
print(device)

cpu


In [5]:
# GPT-2 tokenizer with a very large model_max_length
# (presumably to avoid length warnings on long documents — TODO confirm).
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', model_max_length=100000)

# Project-specific special tokens; only the ids of the padding and
# query-delimiter tokens are needed in this notebook.
tok_dict = gen_all_tokens(tokenizer)
pad_tok = tok_dict['pad'].ind
qbeg_tok = tok_dict['query_begin'].ind
qend_tok = tok_dict['query_end'].ind

ds_loader = WikiDsLoader(
    ds_path=DS_DIR_PATH,
    docs_batch_size=docs_batch_size,
    max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok,
    qbeg_tok=qbeg_tok,
    qend_tok=qend_tok,
    device=device,
)
# NOTE(review): no seed is set in this notebook, so the shuffles (and therefore
# the batch contents below) may not be reproducible across runs — confirm.
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)

# Model input length: the loader's chunk size for fixed-size chunks,
# otherwise whatever calc_max_inp_size derives from it.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [6]:
def tokten_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids into text, skipping padding tokens.

    Args:
        tokens: Token-id tensor of any shape; it is read in flattened
            (row-major) order.

    Returns:
        The string produced by the notebook-level ``tokenizer`` for all ids
        that differ from the notebook-level ``pad_tok``.
    """
    # masked_select yields the non-padding ids as a 1-D tensor in row-major order.
    kept = torch.masked_select(tokens, tokens != pad_tok)
    return tokenizer.decode(list(kept))

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    """Compare two vectors by Euclidean distance or by cosine similarity.

    Despite the function name, ``cosine=True`` returns a *similarity*
    (larger means closer), not a distance.

    Args:
        x: First array.
        y: Second array, broadcast-compatible with ``x``.
        cosine: If False (default), return the norm of ``x - y``.
            If True, return ``sum(x * y) / (|x| * |y|)``.

    Returns:
        A scalar. In cosine mode, 0.0 is returned when either input has zero
        norm, instead of dividing by zero and producing NaN.
    """
    if not cosine:
        return np.linalg.norm(x - y)
    denom = np.linalg.norm(x) * np.linalg.norm(y)
    if denom == 0:
        # Cosine similarity is undefined for a zero vector; report "no similarity".
        return 0.0
    return np.sum(x * y) / denom

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize text and lay it out as padded rows of ``inp_len`` tokens.

    Args:
        s: Text to tokenize with the notebook-level ``tokenizer``.
        qbeg_tok: Optional token id placed before the text tokens.
        qend_tok: Token id placed after the text tokens; required whenever
            ``qbeg_tok`` is given (checked with an assertion) and ignored
            otherwise.

    Returns:
        An int32 tensor of shape ``(n_chunks, inp_len)`` on ``device``, where
        ``n_chunks`` is the number of rows needed to hold all tokens. Unused
        trailing positions hold ``pad_tok``.
    """
    token_ids = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        assert qend_tok is not None
        token_ids = [qbeg_tok] + list(token_ids) + [qend_tok]

    total = len(token_ids)
    # Ceiling division: number of inp_len-sized rows needed for `total` tokens.
    n_chunks = (total + inp_len - 1) // inp_len

    padded = np.full((n_chunks, inp_len), pad_tok, dtype=np.int32)
    # Fill through a flat view so tokens continue from one row into the next.
    padded.reshape(-1)[:total] = token_ids
    return torch.from_numpy(padded).to(device)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    """Print the score of every target embedding against every doc embedding.

    One line is printed per row of ``docs_embs``: the scores against each row
    of ``target_embs`` (six decimals, space separated), followed by ``T`` or
    ``F`` according to ``target_mask`` for that doc row.

    Args:
        target_embs: Target/query embeddings, iterated along the first axis.
        docs_embs: Document-chunk embeddings, iterated along the first axis.
        target_mask: Per-doc-row truth flags, indexed by doc row.
        cosine: Forwarded to ``distance``; selects cosine vs. Euclidean mode.
    """
    # .cpu() is needed because Tensor.numpy() only works for CPU tensors; it is
    # a no-op when the tensors already live on the CPU. Both conversions are
    # done once here instead of once per doc row.
    targets_np = target_embs.detach().cpu().numpy()
    docs_np = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs_np):
        for target_emb in targets_np:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [7]:
# Encoder-decoder configuration. These values have to match the architecture
# of the checkpoint that is loaded into this model afterwards.
model_encdec_cfg = create_mllm_encdec_cfg(
    n_vocab=len(tokenizer),
    d_word_wec=256,
    inp_len=inp_len,
    enc_n_layers=1,
    dec_n_layers=1,
    n_heads=8,
    d_model=256,
    d_inner=1024,
    pad_idx=pad_tok,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)
# Build the model and move it to the selected device.
model_encdec = MllmEncdec(model_encdec_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897281 -1.8666622e-06 0.010897279
vocab_encoder.layer_norm.weight (256,) -0.09827324 0.005363074 0.09857943
vocab_encoder.layer_norm.bias (256,) -0.099614 -0.00033356063 0.09982528
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10823792 0.0002399791 0.108251676
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10824498 3.197349e-05 0.10824982
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10824938 -0.00019202907 0.1082476
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108250074 3.539113e-05 0.10824882
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09914669 -0.0065606707 0.09979097
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.099993445 -0.00055581884 0.09999251
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.068464994 -2.0417589e-05 0.06846526
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.099808134 -0.0007240103 0.09992351
encoder.layer_stack.0.pos_ffn.w_2.weight (

In [8]:
# map_location remaps stored tensors onto the configured device, so a snapshot
# saved on a GPU can still be loaded when running on the CPU.
# NOTE: torch.load unpickles the file; only load snapshots from trusted sources.
checkpoint_encdec = torch.load(encdec_snapshot_path, map_location=device)
# strict=False tolerates missing/unexpected keys; the key report returned by
# load_state_dict is discarded here, so mismatches go unnoticed.
model_encdec.load_state_dict(checkpoint_encdec['model'], strict=False)
# Switch to evaluation mode for inference.
model_encdec.eval()
# Release the checkpoint dict; the weights now live in the model.
del checkpoint_encdec

In [9]:
# Smoke check of text_to_tokens: without query delimiters, a short text yields
# a single row padded with pad_tok up to inp_len.
text_to_tokens('Hello there')

tensor([[15496,   612, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267]],
       dtype=torch.int32)

In [10]:
# Fetch training batch number `i` and expand it into tensors: token chunks of
# the batch documents, token chunks of the target, and a boolean mask over the
# doc chunks.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Display the shapes and the mask as the cell result.
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([12, 100]),
 torch.Size([1, 100]),
 tensor([False, False, False, False,  True,  True, False, False, False, False,
         False, False]))

In [11]:
# Decode the target chunk(s) back to text (padding removed) for inspection.
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|> John Williams HarrisJohn Williams Harris (1808–4 February 1872) was a New Zealand trader, whaler and farmer. He was born in Cornwall, United Kingdom, in 1808.\n\nReferences\n\nCategory:1808 births\nCategory:1872 deaths\nCategory:New Zealand farmers\nCategory:New Zealand traders\nCategory:New Zealand whalers\nCategory:Cornish emigrants to New Zealand\nCategory:Cornish farmers <|query_end|>'

In [15]:
# Print the first 200 characters of every doc chunk on a single line each;
# real newlines are shown as a literal "\n" to keep one chunk per line.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 2510265 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Oswaldo Larriva <|doc_title_end|> <|doc_body_begin|> Óscar Oswaldo Larriva Alvarado 
<|doc_begin|> <|doc_id_begin|> 2510265 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  of Azuay Province in 2019.\n\nOn 6 January 2020, Larriva died from cancer at age 74.\n\nReferences\n\nCatego
<|doc_begin|> <|doc_id_begin|> 4439084 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Daulatpur Thikriya <|doc_title_end|> <|doc_body_begin|> Daulatpur Thikriya is a vill
<|doc_begin|> <|doc_id_begin|> 4439084 <|doc_id_end|> <|doc_offset_begin|> 90 <|doc_offset_end|>  rate of population excluding children aged 6 and below) is 60.33%.\n\nReferences \n\nCategory:Villages in 
<|doc_begin|> <|doc_id_begin|> 4480279 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> John Williams Harris <|doc_title_end|> <|doc_body_begin|> John William

In [16]:
# Inference only: no_grad avoids building an autograd graph for the
# encoder passes (the embeddings are only printed afterwards).
with torch.no_grad():
    target_embs = model_encdec.run_enc_emb(target_chunks)
    docs_embs = model_encdec.run_enc_emb(docs_chunks)

In [17]:
# Scoring mode for the comparison below: True -> cosine, False -> Euclidean.
# cosine = False
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.614007 F
0.673856 F
0.625156 F
0.629489 F
0.763358 T
0.573496 T
0.632314 F
0.644705 F
0.639496 F
0.644607 F
0.652777 F
0.596580 F


In [18]:
# Recompute the model input length from the loader (same rule as at loader setup).
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

# Ranker configuration. These values have to match the architecture of the
# checkpoint that is loaded into this model afterwards.
model_ranker_cfg = create_mllm_ranker_cfg(
    n_vocab=len(tokenizer),
    inp_len=inp_len,
    d_word_wec=256,
    n_levels=1,
    enc_n_layers=1,
    dec_n_layers=1,
    n_heads=8,
    d_k=32,
    d_v=32,
    d_model=256,
    d_inner=1024,
    pad_idx=pad_tok,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)

# Build the model and move it to the selected device.
model_ranker = MllmRanker(model_ranker_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897281 -6.751595e-07 0.010897281
vocab_encoder.layer_norm.weight (256,) -0.09936471 -0.003953435 0.09939728
vocab_encoder.layer_norm.bias (256,) -0.099776365 0.0010321124 0.09996175
encoders.0.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.108252876 -6.340468e-05 0.108252995
encoders.0.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825134 0.00023663577 0.10824148
encoders.0.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10825018 0.00010743411 0.10825088
encoders.0.layer_stack.0.slf_attn.fc.weight (256, 256) -0.1082491 7.4542084e-05 0.10824958
encoders.0.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09864228 -0.00070906454 0.098641984
encoders.0.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09989204 -0.0014073192 0.09974086
encoders.0.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846465 9.7233125e-05 0.06846488
encoders.0.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09999335 0.0012871977 0.09993503
encoders.0.layer_st

In [19]:
# map_location remaps stored tensors onto the configured device, so a snapshot
# saved on a GPU can still be loaded when running on the CPU.
# NOTE: torch.load unpickles the file; only load snapshots from trusted sources.
checkpoint_ranker = torch.load(ranker_snapshot_path, map_location=device)
# Default strict loading: the checkpoint keys must match the model exactly.
model_ranker.load_state_dict(checkpoint_ranker['model'])
# Switch to evaluation mode for inference.
model_ranker.eval()
# Release the checkpoint dict; the weights now live in the model.
del checkpoint_ranker

In [20]:
# Fetch training batch number `i` again (same index as used for the
# encoder-decoder check) so the ranker is evaluated on the same kind of input.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Display the shapes and the mask as the cell result.
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([12, 100]),
 torch.Size([1, 100]),
 tensor([False, False, False, False,  True,  True, False, False, False, False,
         False, False]))

In [21]:
# Decode the target chunk(s) back to text (padding removed) for inspection.
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|> John Williams HarrisJohn Williams Harris (1808–4 February 1872) was a New Zealand trader, whaler and farmer. He was born in Cornwall, United Kingdom, in 1808.\n\nReferences\n\nCategory:1808 births\nCategory:1872 deaths\nCategory:New Zealand farmers\nCategory:New Zealand traders\nCategory:New Zealand whalers\nCategory:Cornish emigrants to New Zealand\nCategory:Cornish farmers <|query_end|>'

In [22]:
# Print the first 200 characters of every doc chunk on a single line each;
# real newlines are shown as a literal "\n" to keep one chunk per line.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 2510265 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Oswaldo Larriva <|doc_title_end|> <|doc_body_begin|> Óscar Oswaldo Larriva Alvarado 
<|doc_begin|> <|doc_id_begin|> 2510265 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  of Azuay Province in 2019.\n\nOn 6 January 2020, Larriva died from cancer at age 74.\n\nReferences\n\nCatego
<|doc_begin|> <|doc_id_begin|> 4439084 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Daulatpur Thikriya <|doc_title_end|> <|doc_body_begin|> Daulatpur Thikriya is a vill
<|doc_begin|> <|doc_id_begin|> 4439084 <|doc_id_end|> <|doc_offset_begin|> 90 <|doc_offset_end|>  rate of population excluding children aged 6 and below) is 60.33%.\n\nReferences \n\nCategory:Villages in 
<|doc_begin|> <|doc_id_begin|> 4480279 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> John Williams Harris <|doc_title_end|> <|doc_body_begin|> John William

In [23]:
# Inference only: no_grad avoids building an autograd graph for the
# encoder passes (the embeddings are only printed afterwards).
# Note: this overwrites target_embs / docs_embs from the encoder-decoder model.
with torch.no_grad():
    target_embs = model_ranker.run_enc_emb(target_chunks)
    docs_embs = model_ranker.run_enc_emb(docs_chunks)

In [24]:
# Scoring mode for the comparison below: True -> cosine, False -> Euclidean.
# cosine = False
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.569388 F
0.464496 F
0.548942 F
0.535388 F
0.899940 T
0.519560 T
0.145180 F
0.463652 F
0.204229 F
0.175676 F
0.272292 F
0.343683 F


In [29]:
# Free-text query to compare against the doc embeddings. Exactly one
# assignment is active; uncomment another line to try a different query.
# txt = 'Hong Kong 1987'
# txt = 'Climate Classification system, Mays Landing has a humid subtropical climate, abbreviated "Cfa"'
# txt = 'Climate Classification system'
# txt = 'War and Peace'
# txt = 'Bandar Express, Ichhamati Express and Benapole Express'
# txt = 'Rick Anderson'
txt = 'John Williams Harris 1808'
cosine = True
# Wrap the text in the query begin/end tokens and pad to rows of inp_len.
txt_tokens = text_to_tokens(txt, qbeg_tok=qbeg_tok, qend_tok=qend_tok)
# txt_tokens = txt_tokens.repeat(3, 1)
print(txt_tokens.shape)
# Inference only: no_grad avoids building an autograd graph.
with torch.no_grad():
    txt_embs = model_ranker.run_enc_emb(txt_tokens)
print_dist(txt_embs, docs_embs, target_mask, cosine=cosine)


torch.Size([1, 100])
0.424732 F
0.314930 F
0.380358 F
0.516397 F
0.366578 T
0.645302 T
0.283477 F
0.387487 F
0.208015 F
0.346727 F
0.279248 F
0.332401 F


In [30]:

# Score the doc chunks against the free-text query.
# Inference only: no_grad avoids building an autograd graph, so the result
# carries no grad_fn.
with torch.no_grad():
    rank = model_ranker.run_qs_infer(docs_chunks, txt_tokens)
rank

tensor([[0.6312, 0.4497, 0.7059, 0.6515, 0.6643, 0.6516, 0.2519, 0.7166, 0.2107,
         0.3248, 0.3066, 0.4003]], grad_fn=<SigmoidBackward0>)

In [31]:
# Run the ranker's forward pass on the free-text query and the doc chunks.
# NOTE(review): the argument order here (query first, docs second) is the
# opposite of the run_qs_infer(docs_chunks, txt_tokens) call — confirm against
# the MllmRanker.forward signature.
# Inference only: no_grad avoids building an autograd graph.
with torch.no_grad():
    rank = model_ranker(txt_tokens, docs_chunks)
rank

tensor([[0.8737, 0.8907, 0.7371, 0.5201, 0.4101, 0.8237, 0.7342, 0.8183, 0.8694,
         0.6263, 0.6608, 0.6087, 0.2565, 0.3135, 0.3250]],
       grad_fn=<SigmoidBackward0>)

In [35]:
# Run the ranker's forward pass on the batch target and the doc chunks.
# NOTE(review): the argument order (target first, docs second) is the opposite
# of the run_qs_infer(docs_chunks, ...) call — confirm against the
# MllmRanker.forward signature.
# Inference only: no_grad avoids building an autograd graph.
with torch.no_grad():
    rank = model_ranker(target_chunks, docs_chunks)
rank

tensor([[2.6538e-21, 3.7294e-11, 2.4412e-19, 3.4020e-23, 7.5630e-20, 3.8663e-25,
         1.3986e-18, 7.2614e-22, 1.0000e+00, 1.0000e+00, 1.0000e+00, 4.7166e-09,
         1.9997e-11, 1.3684e-08]], grad_fn=<SigmoidBackward0>)

In [31]:
# Scratch cell: draw one random integer from [0, 6) (upper bound exclusive).
# No seed is set, so the value differs between runs.
np.random.randint(0, 6)

3