In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.model.mllm_encdec import MllmEncdec
from mllm.model.mllm_ranker import MllmRanker
from mllm.exp.cfg import create_mllm_encdec_cfg, create_mllm_ranker_cfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens



In [3]:
# Root of all local data. Path.home() is used instead of expanding '$HOME' so the
# notebook also works when the HOME environment variable is not set.
DATA_PATH = Path.home() / 'data'
TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
# TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qs'
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

# Training-run sub-directories. Exactly one ranker run is active; the others are
# kept as commented-out alternatives.
encdec_subdir = 'encdec-20240718_221554-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240722_225232-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240724_230827-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240726_232850-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240722_225232-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240730_213328-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240806_221913-msmarco'
ranker_subdir = 'ranker-20240815_180317-msmarco'

# Snapshot files loaded by the checkpoint cells further down.
encdec_train_path = TRAIN_ENCDEC_PATH / encdec_subdir
ranker_train_path = TRAIN_RANKER_PATH / ranker_subdir
encdec_snapshot_path = encdec_train_path / 'best.pth'
ranker_snapshot_path = ranker_train_path / 'best.pth'

In [4]:
# Batch shape passed to the dataset loader.
docs_batch_size, max_chunks_per_doc = 5, 3

# Device used for every tensor and model in this notebook.
# device = torch.device('cuda')
device = torch.device('cpu')
print(device)

cpu


In [5]:
# GPT-2 tokenizer plus the project's special-token table built on top of it.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', model_max_length=100000)
tok_dict = gen_all_tokens(tokenizer)

# Ids of the padding token and of the query begin/end markers.
pad_tok = tok_dict['pad'].ind
qbeg_tok = tok_dict['query_begin'].ind
qend_tok = tok_dict['query_end'].ind

ds_loader = WikiDsLoader(
    ds_path=DS_DIR_PATH,
    docs_batch_size=docs_batch_size,
    max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok,
    qbeg_tok=qbeg_tok,
    qend_tok=qend_tok,
    device=device,
)
# Shuffle the train split first, then the non-train split (same call order as before).
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)

# Width of one model input row: the chunk size itself for fixed-size datasets,
# otherwise whatever calc_max_inp_size derives from it.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [6]:
def tokten_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids back into text, ignoring padding.

    Args:
        tokens: Tensor of token ids of any shape; it is flattened before decoding.

    Returns:
        The string produced by the notebook-level ``tokenizer`` for all ids
        that differ from the notebook-level ``pad_tok``.
    """
    flat = tokens.flatten()
    kept = flat[flat != pad_tok]
    # .tolist() yields plain Python ints; list(tensor) would yield 0-d tensors
    # that the tokenizer has to unwrap one by one.
    return tokenizer.decode(kept.tolist())

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    """Compare two vectors by Euclidean distance or by cosine similarity.

    Note that the two modes have opposite orientation: with ``cosine=False``
    a smaller value means "closer", with ``cosine=True`` a larger value does.

    Args:
        x: First vector.
        y: Second vector, same shape as ``x``.
        cosine: When False (default) return the Euclidean norm of ``x - y``.
            When True return the cosine similarity of ``x`` and ``y``.

    Returns:
        A scalar. For ``cosine=True`` with a zero-norm input the similarity is
        undefined; ``0.0`` is returned instead of dividing by zero.
    """
    if not cosine:
        return np.linalg.norm(x - y)
    denom = np.linalg.norm(x) * np.linalg.norm(y)
    if denom == 0:
        # Avoid the NaN (and RuntimeWarning) a 0/0 division would produce.
        return 0.0
    return np.sum(x * y) / denom

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize a string into pad-filled rows of ``inp_len`` token ids.

    Args:
        s: Text to tokenize with the notebook-level ``tokenizer``.
        qbeg_tok: Optional id placed before the tokens. When given,
            ``qend_tok`` must be given too.
        qend_tok: Optional id placed after the tokens. Only used when
            ``qbeg_tok`` is given.

    Returns:
        An int32 tensor on ``device`` of shape ``(n_rows, inp_len)``, where
        ``n_rows`` is the token count divided by ``inp_len``, rounded up.
        Unused trailing positions hold ``pad_tok``.
    """
    ids = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        assert qend_tok is not None
        ids = [qbeg_tok] + list(ids) + [qend_tok]

    total = len(ids)
    # Ceiling division: one extra row whenever the last row is only partly filled.
    full_rows, remainder = divmod(total, inp_len)
    n_rows = full_rows + (remainder > 0)

    # Start from an all-padding grid and write the ids through a flat view of it.
    grid = np.full((n_rows, inp_len), pad_tok, dtype=np.int32)
    grid.reshape(-1)[:total] = ids
    return torch.from_numpy(grid).to(device)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    """Print a table of scores between document and target embeddings.

    One line is printed per row of ``docs_embs``. The line holds one score per
    row of ``target_embs`` (as computed by ``distance``), followed by ``T`` or
    ``F`` taken from the matching entry of ``target_mask``.

    Args:
        target_embs: Target embeddings, iterated along the first dimension.
        docs_embs: Document embeddings, iterated along the first dimension.
        target_mask: Indexable of truthy/falsy flags, one per document row.
        cosine: Forwarded to ``distance``; True prints cosine similarity,
            False prints Euclidean distance.
    """
    # .cpu() makes this work for tensors living on a GPU as well; both
    # conversions are done once, outside the loops.
    targets_np = target_embs.detach().cpu().numpy()
    docs_np = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs_np):
        for target_emb in targets_np:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [7]:
# Encoder-decoder configuration.
# NOTE(review): presumably these values must match the run that produced the
# checkpoint loaded in the next cell; confirm against the training config.
model_encdec_cfg = create_mllm_encdec_cfg(
    n_vocab=len(tokenizer),
    d_word_wec=256,
    inp_len=inp_len,
    enc_n_layers=1,
    dec_n_layers=1,
    n_heads=8,
    d_model=256,
    d_inner=1024,
    pad_idx=pad_tok,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)
# Build the model, then move it to the selected device.
model_encdec = MllmEncdec(model_encdec_cfg)
model_encdec = model_encdec.to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897282 -2.312324e-06 0.010897279
vocab_encoder.layer_norm.weight (256,) -0.09936088 0.0025157412 0.09930128
vocab_encoder.layer_norm.bias (256,) -0.099341646 -0.0038342725 0.09996939
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10824653 0.00019045657 0.10824982
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.108253084 0.00037009074 0.1082508
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10824686 -0.00019107715 0.10825228
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10825217 1.51973e-05 0.10824655
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.099464975 -0.0007957523 0.09971613
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09999391 0.00015633227 0.09630709
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.068465225 -6.698758e-05 0.06846504
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09969604 0.0031469457 0.0999882
encoder.layer_stack.0.pos_ffn.w_2.weight (256

In [8]:
# map_location remaps the stored tensors onto the notebook's device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
checkpoint_encdec = torch.load(encdec_snapshot_path, map_location=device)
# NOTE(review): strict=False silently ignores missing/unexpected keys. Inspect
# the value returned by load_state_dict if the weights look wrong.
model_encdec.load_state_dict(checkpoint_encdec['model'], strict=False)
# Switch to inference mode (e.g. disables dropout).
model_encdec.eval()
# The checkpoint dict is no longer needed once the weights are copied.
del checkpoint_encdec

In [9]:
# Sanity check: tokenize a short string without query markers and display the
# resulting pad-filled tensor.
text_to_tokens('Hello there')

tensor([[15496,   612, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267]],
       dtype=torch.int32)

In [10]:
# Fetch training batch number i and turn it into tensors.
i = 10
batch = ds_loader.get_batch(i, train=True)
# presumably target_mask flags the doc chunks the target was drawn from; confirm in WikiDsLoader
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Display the two chunk-tensor shapes together with the mask.
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([13, 100]),
 torch.Size([1, 100]),
 tensor([False, False, False,  True, False, False, False, False, False, False,
         False, False, False]))

In [11]:
# Decode the target chunk back to text (padding removed) and display it.
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|> MakangaraweMakangarawe is an administrative ward in the Temeke district of the Dar es Salaam Region of Tanzania. According to the 2002 census, the ward has a total population of 42,332.\n\nReferences\n\nCategory:Temeke District\nCategory:Wards of Tanzania\nCategory:Populated places in Dar es Salaam Region <|query_end|>'

In [12]:
# Print a one-line preview of every document chunk in the batch.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    # Keep the first 200 characters and show newlines as a literal '\n' so each chunk stays on one line.
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 1720712 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Simon Sjödin <|doc_title_end|> <|doc_body_begin|> Simon Sjödin (born 4 October 1986)
<|doc_begin|> <|doc_id_begin|> 1720712 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  the 200 meter butterfly.\n\nIn 2007, he changed swim club from Södertörns SS to SK Neptun.\n\nIn 2008, Sim
<|doc_begin|> <|doc_id_begin|> 1720712 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  butterfly and 200 m individual medley. Simon reached the semi-final in 200 m butterfly with a time of
<|doc_begin|> <|doc_id_begin|> 3821945 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Makangarawe <|doc_title_end|> <|doc_body_begin|> Makangarawe is an administrative wa
<|doc_begin|> <|doc_id_begin|> 4317597 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> California State Route 28 <|doc_title_end|> <|doc_body_begin|> State Route 2

In [15]:
# Embed the target and document chunks with the encoder-decoder model.
# This is inference only, so gradient tracking is switched off to avoid
# building an autograd graph.
with torch.no_grad():
    target_embs = model_encdec.run_enc_emb(target_chunks)
    docs_embs = model_encdec.run_enc_emb(docs_chunks)

In [16]:
# True prints cosine similarity (higher = closer); set to False for Euclidean distance.
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.635558 F
0.663355 F
0.620507 F
0.770461 T
0.674028 F
0.673462 F
0.689574 F
0.629326 F
0.631240 F
0.648918 F
0.635148 F
0.602765 F
0.631762 F


In [17]:
# Recompute the model input length so this cell does not depend on the earlier one.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

# Ranker configuration.
# NOTE(review): presumably these values must match the run that produced the
# ranker checkpoint loaded in the next cell; confirm against the training config.
model_ranker_cfg = create_mllm_ranker_cfg(
    n_vocab=len(tokenizer),
    inp_len=inp_len,
    d_word_wec=256,
    n_levels=1,
    enc_n_layers=1,
    dec_n_layers=1,
    n_heads=8,
    d_k=32,
    d_v=32,
    d_model=256,
    d_inner=1024,
    pad_idx=pad_tok,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)

# Build the model, then move it to the selected device.
model_ranker = MllmRanker(model_ranker_cfg)
model_ranker = model_ranker.to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897281 -6.14697e-07 0.010897281
vocab_encoder.layer_norm.weight (256,) -0.097218536 0.0008541888 0.09964408
vocab_encoder.layer_norm.bias (256,) -0.09858362 0.0009617999 0.09993384
encoders.0.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.1082485 -0.00027505943 0.10824111
encoders.0.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.108251765 -0.00022737503 0.108252436
encoders.0.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10824233 0.00028861655 0.10825155
encoders.0.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108250685 -0.0003508202 0.10824441
encoders.0.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09998933 -0.0011108869 0.09964647
encoders.0.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09951138 0.0024151234 0.099080265
encoders.0.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846395 -5.2154814e-05 0.068464555
encoders.0.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.099506676 -0.0013511884 0.09999602
encoders.0.laye

In [18]:
# map_location remaps the stored tensors onto the notebook's device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
checkpoint_ranker = torch.load(ranker_snapshot_path, map_location=device)
# Default strict loading: any key mismatch with the model raises an error.
model_ranker.load_state_dict(checkpoint_ranker['model'])
# Switch to inference mode (e.g. disables dropout).
model_ranker.eval()
# The checkpoint dict is no longer needed once the weights are copied.
del checkpoint_ranker

In [19]:
# Fetch training batch number i again (same index as in the encoder-decoder
# section) so the ranker is evaluated on the same data.
i = 10
batch = ds_loader.get_batch(i, train=True)
# presumably target_mask flags the doc chunks the target was drawn from; confirm in WikiDsLoader
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Display the two chunk-tensor shapes together with the mask.
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([13, 100]),
 torch.Size([1, 100]),
 tensor([False, False, False,  True, False, False, False, False, False, False,
         False, False, False]))

In [20]:
# Decode the target chunk back to text (padding removed) and display it.
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|> MakangaraweMakangarawe is an administrative ward in the Temeke district of the Dar es Salaam Region of Tanzania. According to the 2002 census, the ward has a total population of 42,332.\n\nReferences\n\nCategory:Temeke District\nCategory:Wards of Tanzania\nCategory:Populated places in Dar es Salaam Region <|query_end|>'

In [21]:
# Print a one-line preview of every document chunk in the batch.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    # Keep the first 200 characters and show newlines as a literal '\n' so each chunk stays on one line.
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 1720712 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Simon Sjödin <|doc_title_end|> <|doc_body_begin|> Simon Sjödin (born 4 October 1986)
<|doc_begin|> <|doc_id_begin|> 1720712 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  the 200 meter butterfly.\n\nIn 2007, he changed swim club from Södertörns SS to SK Neptun.\n\nIn 2008, Sim
<|doc_begin|> <|doc_id_begin|> 1720712 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  butterfly and 200 m individual medley. Simon reached the semi-final in 200 m butterfly with a time of
<|doc_begin|> <|doc_id_begin|> 3821945 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Makangarawe <|doc_title_end|> <|doc_body_begin|> Makangarawe is an administrative wa
<|doc_begin|> <|doc_id_begin|> 4317597 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> California State Route 28 <|doc_title_end|> <|doc_body_begin|> State Route 2

In [22]:
# Embed the target and document chunks with the ranker model.
# This is inference only, so gradient tracking is switched off to avoid
# building an autograd graph.
with torch.no_grad():
    target_embs = model_ranker.run_enc_emb(target_chunks)
    docs_embs = model_ranker.run_enc_emb(docs_chunks)

In [23]:
# True prints cosine similarity (higher = closer); set to False for Euclidean distance.
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.616676 F
0.549515 F
0.533021 F
0.962482 T
0.689291 F
0.588012 F
0.586342 F
0.421391 F
0.427308 F
0.470581 F
0.349605 F
0.233505 F
0.138538 F


In [29]:
# Free-text query scored against the document embeddings of the current batch.
# Only the last assignment ever took effect, so the earlier candidates are kept
# as commented-out alternatives instead of dead assignments.
# txt = 'Hong Kong 1987'
# txt = 'Climate Classification system, Mays Landing has a humid subtropical climate, abbreviated "Cfa"'
# txt = 'Climate Classification system'
# txt = 'War and Peace'
# txt = 'Bandar Express, Ichhamati Express and Benapole Express'
# txt = 'Rick Anderson'
txt = 'Makangarawe Temeke ward'
cosine = True
# Wrap the query in the query begin/end marker tokens before embedding it.
txt_tokens = text_to_tokens(txt, qbeg_tok=qbeg_tok, qend_tok=qend_tok)
# txt_tokens = txt_tokens.repeat(3, 1)
print(txt_tokens.shape)
# Inference only: no autograd graph is needed for the query embedding.
with torch.no_grad():
    txt_embs = model_ranker.run_enc_emb(txt_tokens)
print_dist(txt_embs, docs_embs, target_mask, cosine=cosine)


torch.Size([1, 100])
0.652693 F
0.655255 F
0.652386 F
0.845256 T
0.639311 F
0.544554 F
0.556776 F
0.431144 F
0.456301 F
0.458844 F
0.432921 F
0.280780 F
0.343342 F


In [30]:

# Score the document chunks against the free-text query via run_qs_infer
# (argument order here: documents first, query second) and display the result.
rank = model_ranker.run_qs_infer(docs_chunks, txt_tokens)
rank

tensor([[0.9422, 0.8513, 0.7140, 0.9978, 0.8926, 0.6926, 0.7297, 0.1455, 0.2862,
         0.2185, 0.1052, 0.0027, 0.0214]], grad_fn=<SigmoidBackward0>)

In [31]:
# Call the ranker's forward pass directly (argument order here: query first,
# documents second, i.e. the reverse of run_qs_infer) and display the result.
# NOTE(review): the recorded output has more entries than there are doc chunks;
# confirm what forward() returns before comparing it with run_qs_infer.
rank = model_ranker(txt_tokens, docs_chunks)
rank

tensor([[0.8737, 0.8907, 0.7371, 0.5201, 0.4101, 0.8237, 0.7342, 0.8183, 0.8694,
         0.6263, 0.6608, 0.6087, 0.2565, 0.3135, 0.3250]],
       grad_fn=<SigmoidBackward0>)

In [35]:
# Same forward call as above, but with the dataset's own target chunk as the query.
rank = model_ranker(target_chunks, docs_chunks)
rank

tensor([[2.6538e-21, 3.7294e-11, 2.4412e-19, 3.4020e-23, 7.5630e-20, 3.8663e-25,
         1.3986e-18, 7.2614e-22, 1.0000e+00, 1.0000e+00, 1.0000e+00, 4.7166e-09,
         1.9997e-11, 1.3684e-08]], grad_fn=<SigmoidBackward0>)

In [31]:
# Scratch cell: draws one unseeded random integer from the half-open range [0, 6).
# NOTE(review): nothing else in the notebook uses it; looks like a leftover.
np.random.randint(0, 6)

3