In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.model.mllm_encdec import MllmEncdec
from mllm.model.mllm_ranker import MllmRanker
from mllm.exp.cfg import create_mllm_encdec_cfg, create_mllm_ranker_cfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens



In [4]:
# Root directory for datasets and training artefacts.
# Path.home() resolves the user's home directory portably, whereas
# os.path.expandvars('$HOME') silently leaves the literal text '$HOME' in the
# path when that environment variable is not set.
DATA_PATH = Path.home() / 'data'
TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
# TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qs'
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

encdec_subdir = 'encdec-20240718_221554-wiki_20200501_en-ch_100_fixed'
# Ranker runs inspected earlier; uncomment exactly one line to switch runs.
# ranker_subdir = 'ranker-20240722_225232-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240724_230827-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240726_232850-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240722_225232-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240730_213328-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240806_221913-msmarco'
# ranker_subdir = 'ranker-20240815_180317-msmarco'
# ranker_subdir = 'ranker-20240830_232515-msmarco-fever'
ranker_subdir = 'ranker-20240831_231551-msmarco-fever'
encdec_train_path = TRAIN_ENCDEC_PATH / encdec_subdir
# ranker_train_path = TRAIN_RANKER_PATH / ranker_subdir
# The active ranker run lives under 'train_mllm_ranker_qrels', not under
# TRAIN_RANKER_PATH.
ranker_train_path = DATA_PATH / 'train_mllm_ranker_qrels' / ranker_subdir
# Snapshot files restored by the checkpoint-loading cells.
encdec_snapshot_path = encdec_train_path / 'best.pth'
ranker_snapshot_path = ranker_train_path / 'best.pth'

In [5]:
# Batch geometry handed to the dataset loader.
docs_batch_size, max_chunks_per_doc = 5, 3

# Device on which tensors and models are placed; swap the two lines below to
# run on the GPU instead.
device = torch.device('cpu')
# device = torch.device('cuda')
print(device)

cpu


In [6]:
# GPT-2 tokenizer with a very large model_max_length
# (NOTE(review): presumably to avoid length warnings on long documents — confirm).
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', model_max_length=100000)
# Table of project tokens; each entry exposes its vocabulary index as `.ind`.
tok_dict = gen_all_tokens(tokenizer)
pad_tok, qbeg_tok, qend_tok = tok_dict['pad'].ind, tok_dict['query_begin'].ind, tok_dict['query_end'].ind
# Loader over the pre-chunked Wikipedia dataset stored in DS_DIR_PATH.
ds_loader = WikiDsLoader(
    ds_path=DS_DIR_PATH, docs_batch_size=docs_batch_size, max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok, qbeg_tok=qbeg_tok, qend_tok=qend_tok, device=device,
)
# Shuffle both splits (NOTE(review): train=False presumably selects the validation split — confirm).
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)
# Row length of every chunk tensor: the loader's chunk size when chunks are
# fixed-size, otherwise the value derived by calc_max_inp_size.
inp_len = ds_loader.emb_chunk_size if ds_loader.fixed_size else calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [7]:
def tokten_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids into a string, skipping padding.

    The tensor is flattened first, so a multi-row tensor of chunks is decoded
    as one continuous text. Ids equal to the notebook-level `pad_tok` are
    removed, and the rest are decoded with the notebook-level `tokenizer`.
    """
    flat_ids = tokens.flatten()
    kept_ids = flat_ids[flat_ids != pad_tok]
    return tokenizer.decode(list(kept_ids))

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    """Compare two arrays of the same shape.

    Args:
        x: first array.
        y: second array.
        cosine: selects the measure. When False (default) the Euclidean norm
            of `x - y` is returned. When True the cosine *similarity* of the
            two arrays is returned (dot product divided by the product of the
            norms), so larger means more alike.

    Returns:
        A NumPy floating-point scalar.
    """
    if cosine:
        dot = np.sum(x * y)
        scale = np.linalg.norm(x) * np.linalg.norm(y)
        return dot / scale
    return np.linalg.norm(x - y)

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize a string into a padded 2-D tensor of token ids.

    Args:
        s: text to encode with the notebook-level `tokenizer`.
        qbeg_tok: optional id placed before the encoded text. When it is
            given, `qend_tok` must be given too (checked with an assert).
        qend_tok: optional id placed after the encoded text; it is only used
            when `qbeg_tok` is provided.

    Returns:
        An int32 tensor on `device` of shape (n_rows, inp_len), where n_rows
        is the number of `inp_len`-sized rows needed to hold all ids. Unused
        trailing positions are filled with `pad_tok`.
    """
    ids = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        assert qend_tok is not None
        ids = [qbeg_tok] + list(ids) + [qend_tok]
    n_ids = len(ids)
    # Ceiling division: rows needed to fit every id.
    n_rows = -(-n_ids // inp_len)
    grid = np.full((n_rows, inp_len), pad_tok, dtype=np.int32)
    # Fill row by row through a flat view of the freshly allocated buffer.
    grid.reshape(-1)[:n_ids] = ids
    return torch.from_numpy(grid).to(device)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    """Print a table comparing every document embedding with every target embedding.

    One line is printed per row of `docs_embs`. It holds the value returned by
    `distance` for each row of `target_embs` (six decimals each), followed by
    'T' or 'F' according to the truthiness of `target_mask` at that row index.

    Args:
        target_embs: 2-D tensor with one embedding per row.
        docs_embs: 2-D tensor with one embedding per row.
        target_mask: indexable of per-document flags, aligned with `docs_embs`.
        cosine: forwarded to `distance`; True prints cosine similarity,
            False prints Euclidean distance.
    """
    # Tensor.numpy() only works for CPU tensors, so move the data to host
    # memory first; this keeps the helper usable when the notebook runs on CUDA.
    # Both conversions are loop-invariant and therefore done once up front.
    target_rows = target_embs.detach().cpu().numpy()
    doc_rows = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(doc_rows):
        for target_emb in target_rows:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [8]:
# Build the encoder-decoder configuration: a single encoder layer and a single
# decoder layer, 8 attention heads, model width 256, inner width 1024.
model_encdec_cfg = create_mllm_encdec_cfg(
    n_vocab=len(tokenizer), d_word_wec=256, inp_len=inp_len,
    enc_n_layers=1, dec_n_layers=1,
    n_heads=8, d_model=256, d_inner=1024,
    pad_idx=pad_tok, dropout_rate=0.1, enc_with_emb_mat=True,
)
# Instantiate the model on the selected device; the weights are still freshly
# initialised at this point (the recorded output lists per-parameter statistics).
model_encdec = MllmEncdec(model_encdec_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897282 -2.0029593e-06 0.010897278
vocab_encoder.layer_norm.weight (256,) -0.09956843 -0.0006340821 0.09903169
vocab_encoder.layer_norm.bias (256,) -0.098254435 0.000537434 0.0998828
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10824963 0.00023863433 0.10825174
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825088 -0.00010044141 0.108249225
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10824873 0.00037398466 0.10824713
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10825314 0.00017549611 0.108251445
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09987368 -0.0007624241 0.09937713
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09996825 0.005671611 0.09821061
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846527 1.2881355e-05 0.06846447
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09981292 -0.0021504965 0.09989592
encoder.layer_stack.0.pos_ffn.w_2.weight (25

In [9]:
# Restore the trained encoder-decoder weights from the snapshot file.
# map_location=device remaps the stored tensors onto the notebook's device, so
# a checkpoint written on a GPU machine also loads on a CPU-only one.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoint_encdec = torch.load(encdec_snapshot_path, map_location=device)
# strict=False tolerates keys that differ between the checkpoint and the model.
model_encdec.load_state_dict(checkpoint_encdec['model'], strict=False)
# Inference mode: disables dropout.
model_encdec.eval()
# Drop the checkpoint dict so its tensors can be garbage-collected.
del checkpoint_encdec

In [10]:
# Smoke check of the tokenizer helper: a short string yields one padded row of inp_len ids.
text_to_tokens('Hello there')

tensor([[15496,   612, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267]],
       dtype=torch.int32)

In [11]:
# Fetch training batch number `i` and materialise its tensors.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Display the chunk tensor shapes and the boolean mask
# (NOTE(review): the mask appears to flag the docs_chunks rows that belong to the target document — confirm).
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([13, 100]),
 torch.Size([2, 100]),
 tensor([ True,  True, False, False, False, False, False, False, False, False,
         False, False, False]))

In [12]:
# Decode all target chunks back into one string for visual inspection.
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|> Tremont station (Massachusetts)Tremont station (also known as West Wareham station) was located on Mill Street in West Wareham, Massachusetts. The station was located just east of the former junction of the Cape Cod Branch Railroad and the Fairhaven Branch Railroad.\n\nSee also\nWareham Village station\nOnset station\n\nReferences\n\nExternal links\n\nCategory:Buildings and structures in Wareham, Massachusetts\nCategory:Former railway stations in Massachusetts\nCategory:Stations along Old Colony Railroad lines <|query_end|>'

In [13]:
# Preview each document chunk: first 200 decoded characters, with newlines
# escaped so that every chunk stays on a single output line.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 596725 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Tremont station (Massachusetts) <|doc_title_end|> <|doc_body_begin|> Tremont station 
<|doc_begin|> <|doc_id_begin|> 596725 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  Massachusetts\nCategory:Former railway stations in Massachusetts\nCategory:Stations along Old Colony Rail
<|doc_begin|> <|doc_id_begin|> 5419405 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Vincenzo Marruocco <|doc_title_end|> <|doc_body_begin|> Vincenzo Marruocco (born 26 
<|doc_begin|> <|doc_id_begin|> 5419405 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|> azzetta dello Sport player profile \n Profile at AIC.Football.it \n\nCategory:1979 births\nCategory:Sportsp
<|doc_begin|> <|doc_id_begin|> 5419405 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>.S. Salernitana 1919 players\nCategory:S.S. Chieti Calcio players\nCategory:Calcio Foggia 192

In [14]:
# Embed the target chunks and the document chunks with the encoder-decoder model.
# NOTE(review): these calls run with autograd enabled; since the embeddings are
# only inspected, wrapping them in torch.no_grad() would avoid building a graph.
target_embs = model_encdec.run_enc_emb(target_chunks)
docs_embs = model_encdec.run_enc_emb(docs_chunks)

In [15]:
# Measure used for the comparison table: True prints cosine similarity,
# set to False to print Euclidean distance instead.
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.816061 0.632926 T
0.620468 0.952521 T
0.684363 0.559971 F
0.704369 0.581857 F
0.641972 0.551100 F
0.682536 0.475766 F
0.637498 0.520983 F
0.515539 0.858380 F
0.754775 0.533331 F
0.654432 0.770914 F
0.639969 0.569131 F
0.642570 0.606582 F
0.653763 0.593339 F


In [16]:
# Recompute the chunk row length (same expression as in the dataset-loading cell)
# so this cell does not depend on that value having been left untouched.
inp_len = ds_loader.emb_chunk_size if ds_loader.fixed_size else calc_max_inp_size(ds_loader.emb_chunk_size)
# Ranker configuration: one level, one encoder layer and one decoder layer,
# 8 heads with 32-dimensional keys and values, model width 256, inner width 1024.
model_ranker_cfg = create_mllm_ranker_cfg(
    n_vocab=len(tokenizer), inp_len=inp_len, d_word_wec=256,
    n_levels=1, enc_n_layers=1, dec_n_layers=1,
    n_heads=8, d_k=32, d_v=32, d_model=256, d_inner=1024,
    pad_idx=pad_tok, dropout_rate=0.1, enc_with_emb_mat=True,
)

# Instantiate the ranker on the selected device; weights are loaded in a later cell.
model_ranker = MllmRanker(model_ranker_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897282 -2.0654518e-06 0.010897281
vocab_encoder.layer_norm.weight (256,) -0.09975022 0.00761955 0.09980502
vocab_encoder.layer_norm.bias (256,) -0.098543584 0.0020898443 0.099739425
encoders.0.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10825247 -0.0002024658 0.10825217
encoders.0.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10824628 0.0005351984 0.10825148
encoders.0.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.108252965 0.00048539537 0.10825268
encoders.0.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10824868 -0.00010426881 0.10824799
encoders.0.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09997978 -0.0027381293 0.099536434
encoders.0.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.099808775 -0.0068486356 0.099283114
encoders.0.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846521 1.4299577e-05 0.06846511
encoders.0.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09997245 0.0017086698 0.09983479
encoders.0.layer_s

In [17]:
# Restore the trained ranker weights from the snapshot file.
# map_location=device remaps the stored tensors onto the notebook's device, so
# a checkpoint written on a GPU machine also loads on a CPU-only one.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoint_ranker = torch.load(ranker_snapshot_path, map_location=device)
# Default strict loading: every key must match the model exactly.
model_ranker.load_state_dict(checkpoint_ranker['model'])
# Inference mode: disables dropout.
model_ranker.eval()
# Drop the checkpoint dict so its tensors can be garbage-collected.
del checkpoint_ranker

In [18]:
# Fetch the same training batch again for the ranker experiments.
# NOTE(review): the recorded outputs of this cell and of the identical earlier
# cell differ (target shape and mask), so gen_tensors() is presumably
# randomised — confirm before comparing results across the two models.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([13, 100]),
 torch.Size([3, 100]),
 tensor([False, False, False, False, False, False, False, False, False, False,
          True,  True,  True]))

In [19]:
# Decode the freshly drawn target chunks into one string for visual inspection.
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|> FIBA Europe Cup Finalist -05\nFIBA Europe Cup Final Four All-Star Team -05\nFIBA EuroCup All-Star Game -06 (3-Points Contest Winner)\nTurkish Cup Semifinals -06\nULEB Cup Semifinals -07\nRussian Cup Finalist -07\nRussian A Superleague Regular Season Runner-Up -07\nRussian A Superleague Finalist -07\nRussian A Superleague All-Newcomers Team -07\nGreek League Semifinals -08, 09\nAll-Greek League Forward of the Year -08\nAll-Greek League 1st Team -08\nGreek A1 League All-Imports Team -08\nGreek Cup Semifinals -09\nAll-Greek A1 League Forward of the Year -09\nAll-Greek A1 League 1st Team -09\nGreek A1 League All-Imports Team -09\nAdriatic League Semifinals -10\nKorean KBL Regular Season Runner-Up -11\nKorean KBL Semifinals -11\nAll-Korean KBL Forward of the Year -11\nAll-Korean KBL 1st Team -11\nKorean KBL All-Domestic Players Team -11\n\nPersonal life\nMoon earned his South Korean citizenship in 2011, alongside his brother <|query_end|>'

In [20]:
# Preview each document chunk: first 200 decoded characters, with newlines
# escaped so that every chunk stays on a single output line.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 596725 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Tremont station (Massachusetts) <|doc_title_end|> <|doc_body_begin|> Tremont station 
<|doc_begin|> <|doc_id_begin|> 596725 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  Massachusetts\nCategory:Former railway stations in Massachusetts\nCategory:Stations along Old Colony Rail
<|doc_begin|> <|doc_id_begin|> 5419405 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Vincenzo Marruocco <|doc_title_end|> <|doc_body_begin|> Vincenzo Marruocco (born 26 
<|doc_begin|> <|doc_id_begin|> 5419405 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|> azzetta dello Sport player profile \n Profile at AIC.Football.it \n\nCategory:1979 births\nCategory:Sportsp
<|doc_begin|> <|doc_id_begin|> 5419405 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>.S. Salernitana 1919 players\nCategory:S.S. Chieti Calcio players\nCategory:Calcio Foggia 192

In [21]:
# Embed the target chunks and the document chunks with the ranker's encoder.
# This overwrites the embeddings computed earlier with the encoder-decoder model.
# NOTE(review): these calls run with autograd enabled; since the embeddings are
# only inspected, wrapping them in torch.no_grad() would avoid building a graph.
target_embs = model_ranker.run_enc_emb(target_chunks)
docs_embs = model_ranker.run_enc_emb(docs_chunks)

In [22]:
# Measure used for the comparison table: True prints cosine similarity,
# set to False to print Euclidean distance instead.
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.121347 0.211301 0.151374 F
0.106359 0.134949 0.083154 F
0.542664 0.546598 0.366258 F
0.649322 0.662912 0.389068 F
0.534542 0.549585 0.469288 F
0.623445 0.669031 0.601952 F
0.650921 0.668230 0.596647 F
0.477237 0.485762 0.637490 F
0.189576 0.346300 0.208659 F
0.207065 0.369143 0.313433 F
0.980426 0.889673 0.710889 T
0.903756 0.961016 0.713289 T
0.790911 0.779341 0.923147 T


In [30]:
# Free-text query compared against the document embeddings of the current batch.
# Queries tried earlier are kept as comments; uncomment one (and comment out
# the active line) to re-run it.
# txt = 'Hong Kong 1987'
# txt = 'Climate Classification system, Mays Landing has a humid subtropical climate, abbreviated "Cfa"'
# txt = 'Climate Classification system'
# txt = 'War and Peace'
# txt = 'Bandar Express, Ichhamati Express and Benapole Express'
# txt = 'Rick Anderson'
# txt = 'Makangarawe Temeke ward'
# txt = 'graph sandwich'
# txt = 'james barry'
# txt = 'erigeron'
# txt = 'Dillon Gabriel america'
# txt = 'The graph sandwich problem for property Π is defined as follows:'
txt = 'FIBA Europe'
cosine = True
# Wrap the query in the query-begin / query-end tokens, as the dataset targets are.
txt_tokens = text_to_tokens(txt, qbeg_tok=qbeg_tok, qend_tok=qend_tok)
# txt_tokens = txt_tokens.repeat(3, 1)
print(txt_tokens.shape)
txt_embs = model_ranker.run_enc_emb(txt_tokens)
# One line per document chunk: similarity to the query, then the target flag.
print_dist(txt_embs, docs_embs, target_mask, cosine=cosine)


torch.Size([1, 100])
0.023971 F
0.225677 F
0.394549 F
0.441247 F
0.328603 F
0.478442 F
0.493279 F
0.442178 F
0.226645 F
0.348451 F
0.641930 T
0.626250 T
0.467229 T


In [31]:

# Score every document chunk against the query with the ranker's inference entry point.
# Note the argument order here: documents first, query second.
rank = model_ranker.run_qs_infer(docs_chunks, txt_tokens)
rank

tensor([[7.3745e-04, 5.6199e-02, 1.6434e-01, 3.3226e-01, 1.1555e-01, 2.6521e-01,
         6.3644e-01, 6.2138e-01, 1.1367e-02, 9.2394e-02, 9.6314e-01, 9.2693e-01,
         2.9827e-01]], grad_fn=<SigmoidBackward0>)

In [48]:
# Scratch arithmetic: product of the three training settings below.
n_epochs, batch_size, train_steps = 59, 15, 500

n_epochs * train_steps * batch_size


442500

In [31]:
# Score the documents through the ranker's forward call; here the query comes
# first and the documents second (the reverse of run_qs_infer above).
# NOTE(review): the recorded output lists 15 scores although the batch shown
# earlier holds 13 chunks, so the saved output presumably predates the current
# batch — re-run to confirm.
rank = model_ranker(txt_tokens, docs_chunks)
rank

tensor([[0.8737, 0.8907, 0.7371, 0.5201, 0.4101, 0.8237, 0.7342, 0.8183, 0.8694,
         0.6263, 0.6608, 0.6087, 0.2565, 0.3135, 0.3250]],
       grad_fn=<SigmoidBackward0>)

In [35]:
# Same forward call, using the batch's own target chunks as the query.
# NOTE(review): the recorded output lists 14 scores, which does not match the
# 13-chunk batch shown earlier — presumably a stale output; re-run to confirm.
rank = model_ranker(target_chunks, docs_chunks)
rank

tensor([[2.6538e-21, 3.7294e-11, 2.4412e-19, 3.4020e-23, 7.5630e-20, 3.8663e-25,
         1.3986e-18, 7.2614e-22, 1.0000e+00, 1.0000e+00, 1.0000e+00, 4.7166e-09,
         1.9997e-11, 1.3684e-08]], grad_fn=<SigmoidBackward0>)

In [31]:
# Scratch check: draw one random integer from numpy's global (unseeded) generator.
np.random.randint(0, 6)

3

In [2]:
import torch

In [4]:
# Three device handles for the equality experiment below: cpu, cuda, cpu.
device1, device2, device3 = (torch.device(kind) for kind in ('cpu', 'cuda', 'cpu'))

In [5]:
# torch.device objects compare by value: print the three pairwise comparisons,
# one result per line.
print(device1 == device2, device2 == device3, device1 == device3, sep='\n')


False
False
True


In [7]:
# The device kind is exposed as a plain string on `.type`.
device1.type, type(device1.type)

('cpu', str)

In [11]:
import numpy as np
# Three length-3 rows (two float, one integer) used to compare np.stack with np.array.
toks = [
    np.full(3, 0.5),
    np.full(3, 1.7),
    np.array([-1, 7, 33]),
]
np.stack(toks, axis=0)

array([[ 0.5,  0.5,  0.5],
       [ 1.7,  1.7,  1.7],
       [-1. ,  7. , 33. ]])

In [12]:
# Build the same 2-D array by passing the list of rows directly to np.array.
np.array(toks)

array([[ 0.5,  0.5,  0.5],
       [ 1.7,  1.7,  1.7],
       [-1. ,  7. , 33. ]])

In [13]:
# Confirm that both construction routes give numerically equal arrays.
np.allclose(np.array(toks), np.stack(toks, axis=0))

True