In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to the project code are picked up.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.model.mllm_encdec import MllmEncdec
from mllm.model.mllm_ranker import MllmRanker
from mllm.exp.cfg import create_mllm_encdec_cfg, create_mllm_ranker_cfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens



In [3]:
# Root directory for datasets and training artifacts. Path.home() is used
# instead of os.path.expandvars('$HOME'): expandvars leaves a literal '$HOME'
# path component when the HOME variable is not set in the environment.
DATA_PATH = Path.home() / 'data'
TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
# TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qs'
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

# Training-run subdirectories. Commented-out entries are other ranker runs
# kept for quick switching.
encdec_subdir = 'encdec-20240718_221554-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240722_225232-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240724_230827-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240726_232850-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240730_213328-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240806_221913-msmarco'
# ranker_subdir = 'ranker-20240815_180317-msmarco'
ranker_subdir = 'ranker-20240903_215749-msmarco-fever'
encdec_train_path = TRAIN_ENCDEC_PATH / encdec_subdir
# ranker_train_path = TRAIN_RANKER_PATH / ranker_subdir
# NOTE: the active ranker run is read from 'train_mllm_ranker_qrels', not from
# TRAIN_RANKER_PATH defined above.
ranker_train_path = DATA_PATH / 'train_mllm_ranker_qrels' / ranker_subdir
# Snapshot files loaded into the two models further down.
encdec_snapshot_path = encdec_train_path / 'best.pth'
ranker_snapshot_path = ranker_train_path / 'best.pth'

In [4]:
# Batch geometry passed to the dataset loader.
docs_batch_size, max_chunks_per_doc = 5, 3

# Torch device shared by the loader and the models; use the commented line
# instead to run on a GPU.
device = torch.device('cpu')
# device = torch.device('cuda')
print(device)

cpu


In [5]:
# GPT-2 tokenizer with model_max_length set to 100000.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', model_max_length=100000)
# Project helper that builds the token table; each entry exposes its id as `.ind`.
tok_dict = gen_all_tokens(tokenizer)
# Ids of the padding, query-begin and query-end tokens used throughout the notebook.
pad_tok, qbeg_tok, qend_tok = tok_dict['pad'].ind, tok_dict['query_begin'].ind, tok_dict['query_end'].ind
# Dataset loader over DS_DIR_PATH, configured with the batch geometry, the
# special token ids and the target device.
ds_loader = WikiDsLoader(
    ds_path=DS_DIR_PATH, docs_batch_size=docs_batch_size, max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok, qbeg_tok=qbeg_tok, qend_tok=qend_tok, device=device,
)
# Shuffle both splits (train=True and train=False).
# NOTE(review): no seed is set in this notebook, so batch contents presumably
# differ between kernel runs - confirm against WikiDsLoader.shuffle.
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)
# Model input length: the loader's chunk size when chunks are fixed-size,
# otherwise whatever calc_max_inp_size derives from that chunk size.
inp_len = ds_loader.emb_chunk_size if ds_loader.fixed_size else calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [6]:
def tokten_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids into text, skipping padding tokens.

    Args:
        tokens: Tensor of token ids of any shape; it is flattened before
            decoding, so multiple chunks are decoded as one sequence.

    Returns:
        The string produced by the notebook-level `tokenizer` for all ids
        that differ from the notebook-level `pad_tok`.
    """
    flat = tokens.flatten()
    # Keep only positions that do not hold the padding id.
    kept = flat[flat != pad_tok]
    # .tolist() yields plain Python ints in a single call; list(tensor) would
    # create one 0-d tensor per token and leave the conversion to the tokenizer.
    return tokenizer.decode(kept.tolist())

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    """Compare two vectors.

    Args:
        x: First array.
        y: Second array, same shape as `x`.
        cosine: Selects the measure. When False (default) the Euclidean norm
            of `x - y` is returned. When True the cosine *similarity* is
            returned (higher means more similar), not a distance.

    Returns:
        A scalar score as described above.
    """
    if cosine:
        # Product of the two vector norms normalises the dot product.
        norm_product = np.linalg.norm(x) * np.linalg.norm(y)
        return np.sum(x * y) / norm_product
    return np.linalg.norm(x - y)

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize text and lay it out as padded fixed-length chunks.

    Args:
        s: Text to tokenize with the notebook-level `tokenizer`.
        qbeg_tok: Optional id placed before the text tokens. When given,
            `qend_tok` must be given too (checked with an assert).
        qend_tok: Optional id placed after the text tokens; only used when
            `qbeg_tok` is provided.

    Returns:
        An int32 tensor of shape (n_chunks, inp_len) on the notebook-level
        `device`; unused trailing positions hold `pad_tok`.
    """
    token_ids = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        assert qend_tok is not None
        token_ids = [qbeg_tok] + list(token_ids) + [qend_tok]

    n_tokens = len(token_ids)
    # Number of inp_len-sized chunks needed to hold every token (ceil division).
    n_chunks, leftover = divmod(n_tokens, inp_len)
    if leftover > 0:
        n_chunks += 1

    # Start from an all-padding buffer and copy the real tokens to its front.
    padded = np.full((n_chunks * inp_len, ), pad_tok, dtype=np.int32)
    padded[:n_tokens] = token_ids
    return torch.from_numpy(padded).to(device).reshape(n_chunks, inp_len)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    """Print a table of scores between document and target embeddings.

    One line is printed per row of `docs_embs`. Each line holds one score per
    row of `target_embs` (computed by `distance`) followed by 'T' or 'F'
    according to the matching entry of `target_mask`.

    Args:
        target_embs: Embeddings forming the columns of the table.
        docs_embs: Embeddings forming the rows of the table.
        target_mask: Indexable of truthy/falsy flags, one per row of `docs_embs`.
        cosine: Forwarded to `distance`; True prints cosine similarity,
            False prints Euclidean distance.
    """
    # .cpu() is required before .numpy() when the tensors live on a CUDA
    # device. Both conversions are done once instead of once per document row.
    targets_np = target_embs.detach().cpu().numpy()
    docs_np = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs_np):
        for target_emb in targets_np:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [7]:
# Build the encoder-decoder model configuration: vocabulary size from the
# tokenizer, input length from the loader, one encoder and one decoder layer.
# NOTE(review): these hyper-parameters presumably have to match the ones used
# to train the snapshot loaded afterwards - confirm against the training run.
model_encdec_cfg = create_mllm_encdec_cfg(
    n_vocab=len(tokenizer), d_word_wec=256, inp_len=inp_len,
    enc_n_layers=1, dec_n_layers=1,
    n_heads=8, d_model=256, d_inner=1024,
    pad_idx=pad_tok, dropout_rate=0.1, enc_with_emb_mat=True,
)
# Instantiate the model and move it to the selected device.
model_encdec = MllmEncdec(model_encdec_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897282 6.494323e-07 0.010897274
vocab_encoder.layer_norm.weight (256,) -0.09865731 -0.001520379 0.099232934
vocab_encoder.layer_norm.bias (256,) -0.0992744 -0.00053189404 0.09931554
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.1082497 9.433067e-05 0.10825228
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.108243585 0.0002877548 0.10825116
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10825165 -8.560575e-05 0.10825129
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.1082527 0.00017292064 0.10824915
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.0987946 0.0008891132 0.09902761
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09978663 -0.009286782 0.098430514
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846424 -4.3240772e-05 0.068464406
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.099947274 0.0028740857 0.09994185
encoder.layer_stack.0.pos_ffn.w_2.weight (256, 1

In [8]:
# map_location remaps tensors stored from another device onto `device`, so a
# snapshot saved during GPU training can be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file - only load snapshots from trusted sources.
checkpoint_encdec = torch.load(encdec_snapshot_path, map_location=device)
# NOTE(review): strict=False tolerates missing/unexpected keys without raising;
# confirm this is intentional for the encdec snapshot.
model_encdec.load_state_dict(checkpoint_encdec['model'], strict=False)
# Switch to evaluation mode before running inference.
model_encdec.eval()
# Release the checkpoint dictionary once the weights are copied into the model.
del checkpoint_encdec

In [9]:
# Sanity check: a short text without query markers becomes a single padded chunk.
text_to_tokens('Hello there')

tensor([[15496,   612, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267]],
       dtype=torch.int32)

In [10]:
# Fetch training batch number 10 and materialize its tensors.
i = 10
batch = ds_loader.get_batch(i, train=True)
# NOTE(review): target_mask presumably flags the rows of docs_chunks that
# belong to the target document - confirm against gen_tensors.
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Display the tensor shapes together with the mask.
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([14, 100]),
 torch.Size([3, 100]),
 tensor([False, False, False,  True,  True,  True, False, False, False, False,
         False, False, False, False]))

In [11]:
# Decode all target chunks back into one string (padding removed) for inspection.
s_target = tokten_to_text(target_chunks)
s_target

"<|query_begin|> California's 16th State Assembly districtCalifornia's 16th State Assembly district is one of 80 California State Assembly districts. It is currently represented by Democrat Rebecca Bauer-Kahan of Orinda.\n\nDistrict profile \nThe district is located in the East Bay. It consists of the primarily affluent suburbs east of the Berkeley Hills, including Lamorinda and the Tri-Valley. During Catharine Baker's time in office, it was the most Democratic seat held by a Republican in the Assembly.\n\nAlameda County – 13.3% of Alameda County population\n Dublin\n Livermore\n Pleasanton\n\nContra Costa County – 25.3% of Contra Costa County population\n Alamo\n Blackhawk\n Danville\n Diablo\n Lafayette\n Moraga\n Orinda\n Saranap\n San Ramon\n Walnut Creek – 82.5% of Walnut Creek population included\n\nElection results from statewide races\n\nElection results\n\n2020\n\n2018\n\n2016\n\n2014\n\n2012\n\nSee also \n California State Assembly\n California State Assembly districts\n Dist

In [12]:
# Preview every document chunk: first 200 decoded characters, with newlines
# shown as a literal '\n' so each chunk stays on one output line.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 3357836 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  The show also follows Willie's attempts to win back his former fiancé from his nemesis. Many of Bushw
<|doc_begin|> <|doc_id_begin|> 3357836 <|doc_id_end|> <|doc_offset_begin|> 273 <|doc_offset_end|>  predominantly Latino, though a much larger audience will relate to the themes in the series. The show
<|doc_begin|> <|doc_id_begin|> 3357836 <|doc_id_end|> <|doc_offset_begin|> 364 <|doc_offset_end|>  Award for "Best Web Series: Comedy." East Willy B was also highlighted in IndieWire as one of the thi
<|doc_begin|> <|doc_id_begin|> 427145 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> California's 16th State Assembly district <|doc_title_end|> <|doc_body_begin|> Califo
<|doc_begin|> <|doc_id_begin|> 427145 <|doc_id_end|> <|doc_offset_begin|> 92 <|doc_offset_end|>  most Democratic seat held by a Republican in the Assembly.\n\nAlameda County – 13.3% of Alameda Cou

In [13]:
# Inference only: no_grad avoids recording an autograd graph for the embeddings.
with torch.no_grad():
    target_embs = model_encdec.run_enc_emb(target_chunks)
    docs_embs = model_encdec.run_enc_emb(docs_chunks)

In [14]:
# Score mode for print_dist: True prints cosine similarity, False prints
# Euclidean distance. (The former `cosine = False` line was a dead store.)
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.659873 0.625968 0.588829 F
0.649063 0.515043 0.551877 F
0.638165 0.640675 0.606706 F
0.761945 0.688867 0.711992 T
0.708158 0.748468 0.746569 T
0.672655 0.724277 0.768461 T
0.675588 0.682079 0.598821 F
0.630862 0.613829 0.549941 F
0.631237 0.662260 0.596649 F
0.603324 0.614365 0.643579 F
0.643235 0.683757 0.673601 F
0.633847 0.679635 0.626069 F
0.638340 0.628420 0.610957 F
0.573513 0.547358 0.578394 F


In [15]:
# Recompute the model input length with the same expression used right after
# the dataset loader is created.
inp_len = ds_loader.emb_chunk_size if ds_loader.fixed_size else calc_max_inp_size(ds_loader.emb_chunk_size)
# Build the ranker configuration: a single level with one encoder and one
# decoder layer.
# NOTE(review): these hyper-parameters presumably have to match the ones used
# to train the ranker snapshot loaded afterwards - confirm against the training run.
model_ranker_cfg = create_mllm_ranker_cfg(
    n_vocab=len(tokenizer), inp_len=inp_len, d_word_wec=256,
    n_levels=1, enc_n_layers=1, dec_n_layers=1,
    n_heads=8, d_k=32, d_v=32, d_model=256, d_inner=1024,
    pad_idx=pad_tok, dropout_rate=0.1, enc_with_emb_mat=True,
)

# Instantiate the ranker and move it to the selected device.
model_ranker = MllmRanker(model_ranker_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897281 -2.6339873e-07 0.010897276
vocab_encoder.layer_norm.weight (256,) -0.09994521 -0.0027807415 0.09999938
vocab_encoder.layer_norm.bias (256,) -0.09996406 0.003614617 0.09834387
encoders.0.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.1082469 -0.00036873628 0.10824658
encoders.0.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825051 -0.0001931882 0.10825238
encoders.0.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10824983 -0.00018752666 0.108235635
encoders.0.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10825132 -0.0007193911 0.1082494
encoders.0.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09955348 0.0025702321 0.09993328
encoders.0.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09987169 -0.00085548125 0.09968227
encoders.0.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846416 -7.183825e-05 0.06846529
encoders.0.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09973278 0.0007764242 0.09967058
encoders.0.layer_stac

In [16]:
# map_location remaps tensors stored from another device onto `device`, so a
# snapshot saved during GPU training can be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file - only load snapshots from trusted sources.
checkpoint_ranker = torch.load(ranker_snapshot_path, map_location=device)
# Strict loading (the default): every key must match the model definition.
model_ranker.load_state_dict(checkpoint_ranker['model'])
# Switch to evaluation mode before running inference.
model_ranker.eval()
# Release the checkpoint dictionary once the weights are copied into the model.
del checkpoint_ranker

In [17]:
# Fetch training batch number 10 again and materialize its tensors; this
# overwrites docs_chunks / target_chunks / target_mask from the earlier cell.
# NOTE(review): the recorded mask differs from the earlier call with the same
# index, so gen_tensors (or the loader) is presumably not deterministic - confirm.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Display the tensor shapes together with the mask.
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([14, 100]),
 torch.Size([3, 100]),
 tensor([ True,  True,  True, False, False, False, False, False, False, False,
         False, False, False, False]))

In [18]:
# Decode all target chunks back into one string (padding removed) for inspection.
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|> East Willy BEast Willy B is a comedic Web series that profiles Latinos in the gentrifying neighborhood of Bushwick in Brooklyn, New York City. Described as a Puerto Rican Cheers, its episodes generally run six to nine minutes online and tackle themes such as love, race, gentrification, entrepreneurship, and more. The series aims to promote Latino voices in media outlets.\n\nThe series\nThe series, created by Julia Ahumada Grob and Yamin Segal, follows the life of Willie Jr. and his bartender friend, Ceci Rivera. The show documents the process where a Latino neighborhood starts to change, with young hipsters moving to Bushwick in search of the next hot Brooklyn neighborhood. As rents increase in Bushwick, Brooklyn (or "East Williamsburg") due to gentrification, Ceci and Willie organize their neighborhood to save Willie\'s bar. The show also follows Willie\'s attempts to win back his former fiancé from his nemesis. Many of Bushwick\'s "quirky characters" are included in 

In [19]:
# Preview every document chunk: first 200 decoded characters, with newlines
# shown as a literal '\n' so each chunk stays on one output line.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 3357836 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> East Willy B <|doc_title_end|> <|doc_body_begin|> East Willy B is a comedic Web seri
<|doc_begin|> <|doc_id_begin|> 3357836 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  Julia Ahumada Grob and Yamin Segal, follows the life of Willie Jr. and his bartender friend, Ceci Rive
<|doc_begin|> <|doc_id_begin|> 3357836 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  The show also follows Willie's attempts to win back his former fiancé from his nemesis. Many of Bushw
<|doc_begin|> <|doc_id_begin|> 427145 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> California's 16th State Assembly district <|doc_title_end|> <|doc_body_begin|> Califo
<|doc_begin|> <|doc_id_begin|> 427145 <|doc_id_end|> <|doc_offset_begin|> 92 <|doc_offset_end|>  most Democratic seat held by a Republican in the Assembly.\n\nAlameda County – 13.3% of Alameda Cou

In [20]:
# Inference only: no_grad avoids recording an autograd graph for the embeddings.
with torch.no_grad():
    target_embs = model_ranker.run_enc_emb(target_chunks)
    docs_embs = model_ranker.run_enc_emb(docs_chunks)

In [21]:
# Score mode for print_dist: True prints cosine similarity, False prints
# Euclidean distance. (The former `cosine = False` line was a dead store.)
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.896406 0.764625 0.658545 T
0.804114 0.923876 0.671374 T
0.765415 0.808461 0.869088 T
0.715746 0.657727 0.467861 F
0.601611 0.647713 0.413840 F
0.516053 0.485636 0.232942 F
0.399742 0.406179 0.354653 F
0.463552 0.525545 0.435180 F
0.569390 0.681343 0.548763 F
0.549657 0.498152 0.543873 F
0.717221 0.709256 0.709675 F
0.445960 0.375591 0.521954 F
0.273253 0.298420 0.504144 F
0.076357 0.108437 0.267569 F


In [44]:
# Free-text queries tried so far. They used to be a chain of `txt = ...`
# assignments where only the last one took effect; keeping them in a list
# makes the selection explicit. Commented entries were disabled in the original.
query_candidates = [
    'Hong Kong 1987',
    'Climate Classification system, Mays Landing has a humid subtropical climate, abbreviated "Cfa"',
    # 'Climate Classification system',
    'War and Peace',
    'Bandar Express, Ichhamati Express and Benapole Express',
    'Rick Anderson',
    'Makangarawe Temeke ward',
    # 'graph sandwich',
    'james barry',
    'erigeron',
    'Dillon Gabriel america',
    'East Willy B',
    'Prataprao Gujar',
    # 'The graph sandwich problem for property Π is defined as follows:',
]
# The last candidate is the active query (same value the original ended with).
txt = query_candidates[-1]
cosine = True
# Wrap the query in query-begin / query-end markers and pad it into chunks.
txt_tokens = text_to_tokens(txt, qbeg_tok=qbeg_tok, qend_tok=qend_tok)
# txt_tokens = txt_tokens.repeat(3, 1)
print(txt_tokens.shape)
# Inference only: no_grad avoids recording an autograd graph.
with torch.no_grad():
    txt_embs = model_ranker.run_enc_emb(txt_tokens)
# Compare the query embedding with the document embeddings computed earlier.
print_dist(txt_embs, docs_embs, target_mask, cosine=cosine)


torch.Size([1, 100])
0.009376 T
0.085239 T
0.211096 T
0.133755 F
0.143394 F
0.122204 F
0.227482 F
0.067510 F
0.046468 F
0.230950 F
0.220598 F
0.303955 F
0.670864 F
0.703972 F


In [45]:

# Inference only: without no_grad the returned tensor carries a grad_fn and
# keeps the whole autograd graph alive.
with torch.no_grad():
    rank = model_ranker.run_qs_infer(docs_chunks, txt_tokens)
rank

tensor([[1.8170e-05, 8.9663e-05, 1.4852e-03, 9.0462e-04, 4.0314e-04, 2.9787e-04,
         7.4527e-03, 1.6140e-04, 1.0724e-04, 2.7201e-03, 2.8551e-04, 1.0182e-02,
         9.8398e-01, 9.8504e-01]], grad_fn=<SigmoidBackward0>)

In [48]:
# Back-of-the-envelope count: epochs x steps per epoch x batch size.
n_epochs, train_steps, batch_size = 59, 500, 15

batch_size * train_steps * n_epochs


442500

In [31]:
# Direct forward call of the ranker with the free-text query tokens.
# NOTE(review): the argument order (query first, documents second) is the
# reverse of the run_qs_infer call elsewhere in this notebook. The recorded
# output has 15 scores and an out-of-sequence execution count, so it
# presumably comes from an earlier kernel state - re-run to confirm.
rank = model_ranker(txt_tokens, docs_chunks)
rank

tensor([[0.8737, 0.8907, 0.7371, 0.5201, 0.4101, 0.8237, 0.7342, 0.8183, 0.8694,
         0.6263, 0.6608, 0.6087, 0.2565, 0.3135, 0.3250]],
       grad_fn=<SigmoidBackward0>)

In [35]:
# Direct forward call of the ranker using the batch's own target chunks as the query.
# NOTE(review): the execution count is out of sequence; the recorded output
# presumably belongs to a different batch than the one shown above - confirm.
rank = model_ranker(target_chunks, docs_chunks)
rank

tensor([[2.6538e-21, 3.7294e-11, 2.4412e-19, 3.4020e-23, 7.5630e-20, 3.8663e-25,
         1.3986e-18, 7.2614e-22, 1.0000e+00, 1.0000e+00, 1.0000e+00, 4.7166e-09,
         1.9997e-11, 1.3684e-08]], grad_fn=<SigmoidBackward0>)

In [31]:
# Scratch check of np.random.randint: the upper bound (6) is exclusive.
np.random.randint(0, 6)

3

In [2]:
import torch

In [4]:
# Three device handles for the equality experiment that follows: two CPU
# devices and one CUDA device.
device1, device2, device3 = (torch.device(kind) for kind in ('cpu', 'cuda', 'cpu'))

In [5]:
# torch.device objects compare by value: print the three pairwise comparisons,
# one result per line.
print(device1 == device2, device2 == device3, device1 == device3, sep='\n')


False
False
True


In [7]:
# Inspect the device's `type` attribute and the Python type of that attribute.
device1.type, type(device1.type)

('cpu', str)

In [11]:
import numpy as np
# Three length-3 rows (two float, one integer) stacked into a single 2-D array
# along a new leading axis.
toks = [
    np.full(3, 0.5),
    np.full(3, 1.7),
    np.array([-1, 7, 33]),
]
np.stack(toks, axis=0)

array([[ 0.5,  0.5,  0.5],
       [ 1.7,  1.7,  1.7],
       [-1. ,  7. , 33. ]])

In [12]:
# Build the same 2-D array directly from the list of rows.
np.array(toks)

array([[ 0.5,  0.5,  0.5],
       [ 1.7,  1.7,  1.7],
       [-1. ,  7. , 33. ]])

In [13]:
# Confirm that np.array and np.stack(axis=0) give numerically equal results here.
np.allclose(np.array(toks), np.stack(toks, axis=0))

True