In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to the project sources take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.model.mllm_encdec import MllmEncdec
from mllm.model.mllm_ranker import MllmRanker
from mllm.exp.cfg import create_mllm_encdec_cfg, create_mllm_ranker_cfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens



In [18]:
# Root folders for the dataset and the training artifacts.
# Path.home() resolves the home directory even when the HOME environment
# variable is not set; expanding '$HOME' would leave the literal text '$HOME'
# in the path in that case.
DATA_PATH = Path.home() / 'data'
TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
# TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qs'
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

# Training-run subdirectories. The commented-out names are earlier ranker runs
# kept for reference; only the last uncommented assignment is used.
encdec_subdir = 'encdec-20240718_221554-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240722_225232-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240724_230827-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240726_232850-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240730_213328-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240806_221913-msmarco'
# ranker_subdir = 'ranker-20240815_180317-msmarco'
ranker_subdir = 'ranker-20240830_232515-msmarco-fever'

encdec_train_path = TRAIN_ENCDEC_PATH / encdec_subdir
# ranker_train_path = TRAIN_RANKER_PATH / ranker_subdir
# NOTE: the selected ranker run is read from 'train_mllm_ranker_qrels', not
# from TRAIN_RANKER_PATH.
ranker_train_path = DATA_PATH / 'train_mllm_ranker_qrels' / ranker_subdir

# Snapshot files loaded by the checkpoint cells further down.
encdec_snapshot_path = encdec_train_path / 'best.pth'
ranker_snapshot_path = ranker_train_path / 'best.pth'

In [4]:
# Batch layout passed to the dataset loader.
docs_batch_size, max_chunks_per_doc = 5, 3

# Device used for the models and tensors; use 'cuda' here to run on a GPU.
device = torch.device('cpu')
print(device)

cpu


In [5]:
# GPT-2 tokenizer with model_max_length set to 100000, plus the project's
# special tokens produced by gen_all_tokens.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', model_max_length=100000)
tok_dict = gen_all_tokens(tokenizer)
pad_tok = tok_dict['pad'].ind
qbeg_tok = tok_dict['query_begin'].ind
qend_tok = tok_dict['query_end'].ind

# Dataset loader over the pre-chunked wiki dataset; both splits are shuffled.
ds_loader = WikiDsLoader(
    ds_path=DS_DIR_PATH,
    docs_batch_size=docs_batch_size,
    max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok,
    qbeg_tok=qbeg_tok,
    qend_tok=qend_tok,
    device=device,
)
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)

# Model input length: the chunk size itself for fixed-size chunks, otherwise
# whatever calc_max_inp_size derives from it.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [6]:
def tokten_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids into text, ignoring padding.

    The tensor is flattened first, so a multi-row chunk tensor is decoded as
    one continuous string. Uses the notebook-level `tokenizer` and `pad_tok`.

    Args:
        tokens: tensor of token ids of any shape.

    Returns:
        The decoded string with every `pad_tok` entry removed.
    """
    flat = tokens.flatten()
    keep = flat != pad_tok
    return tokenizer.decode(list(flat[keep]))

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    """Compare two vectors.

    Args:
        x: first vector.
        y: second vector.
        cosine: selects the measure (see Returns).

    Returns:
        With `cosine=False`, the Euclidean norm of `x - y` (smaller means
        closer). With `cosine=True`, the cosine similarity
        `sum(x * y) / (|x| * |y|)`, where larger means closer, so despite the
        function name this branch is a similarity rather than a distance.
    """
    if cosine:
        dot = np.sum(x * y)
        return dot / (np.linalg.norm(x) * np.linalg.norm(y))
    return np.linalg.norm(x - y)

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize a string and pack it into rows of `inp_len` token ids.

    Uses the notebook-level `tokenizer`, `inp_len`, `pad_tok` and `device`.

    Args:
        s: text to tokenize.
        qbeg_tok: optional token id prepended to the sequence. When given,
            `qend_tok` must be given as well (checked with an assert).
        qend_tok: optional token id appended to the sequence; only used when
            `qbeg_tok` is provided.

    Returns:
        An int32 tensor on `device` of shape `(n_rows, inp_len)`, where
        `n_rows` is the token count divided by `inp_len`, rounded up. Unused
        trailing positions hold `pad_tok`.
    """
    ids = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        assert qend_tok is not None
        ids = [qbeg_tok] + list(ids) + [qend_tok]

    total = len(ids)
    # Ceiling division: a partially filled last row still counts as a row.
    n_rows, remainder = divmod(total, inp_len)
    if remainder:
        n_rows += 1

    padded = np.full((n_rows * inp_len,), pad_tok, dtype=np.int32)
    padded[:total] = ids
    return torch.from_numpy(padded).to(device).reshape(n_rows, inp_len)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    """Print a table comparing every document embedding with every target embedding.

    One line is printed per row of `docs_embs`. It contains one value per row
    of `target_embs`, formatted with six decimals, followed by `T` or `F`
    taken from the matching entry of `target_mask`.

    Args:
        target_embs: target embeddings, iterated row by row.
        docs_embs: document embeddings, iterated row by row.
        target_mask: indexable of truthy/falsy flags, one per row of `docs_embs`.
        cosine: forwarded to `distance`; selects cosine similarity (True) or
            Euclidean distance (False).
    """
    # Tensor.numpy() only works for CPU tensors, so move the data to the host
    # first; this keeps the helper usable when the notebook runs on CUDA.
    # Both conversions are done once, outside the loops.
    targets_np = target_embs.detach().cpu().numpy()
    docs_np = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs_np):
        for target_emb in targets_np:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [7]:
# Encoder-decoder model: build the config, then instantiate it on `device`.
model_encdec_cfg = create_mllm_encdec_cfg(
    n_vocab=len(tokenizer),
    d_word_wec=256,
    inp_len=inp_len,
    enc_n_layers=1,
    dec_n_layers=1,
    n_heads=8,
    d_model=256,
    d_inner=1024,
    pad_idx=pad_tok,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)
model_encdec = MllmEncdec(model_encdec_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897282 -7.307192e-07 0.010897281
vocab_encoder.layer_norm.weight (256,) -0.09903387 -0.00037897477 0.099014
vocab_encoder.layer_norm.bias (256,) -0.09947115 0.0054700524 0.0999892
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10825222 3.0783922e-06 0.108249314
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.108250104 0.00014931659 0.10824773
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.108249344 0.00015622671 0.10824616
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10825161 -0.00036529166 0.10825222
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.099452436 -0.0011238996 0.09880356
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09775269 -0.0021109863 0.09953707
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846505 -0.00024030457 0.06846458
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09998661 1.6074628e-06 0.099938095
encoder.layer_stack.0.pos_ffn.w_2.weight

In [8]:
# Restore the trained encoder-decoder weights and switch to eval mode.
# NOTE: torch.load unpickles the file, so only load snapshots from a trusted source.
# map_location remaps the stored tensors onto the notebook's device, so a
# snapshot saved from a CUDA run can still be loaded on a CPU-only machine.
checkpoint_encdec = torch.load(encdec_snapshot_path, map_location=device)
# strict=False tolerates keys that do not match between snapshot and model.
model_encdec.load_state_dict(checkpoint_encdec['model'], strict=False)
model_encdec.eval()
# The checkpoint dict is no longer needed once the weights are copied.
del checkpoint_encdec

In [9]:
# Smoke check of text_to_tokens: a short string without query markers is
# expected to come back as a single row of inp_len ids padded with pad_tok.
text_to_tokens('Hello there')

tensor([[15496,   612, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267]],
       dtype=torch.int32)

In [10]:
# Fetch training batch number `i` and turn it into tensors: document chunks,
# target chunks and a boolean mask with one entry per document chunk.
# NOTE(review): the mask presumably flags the chunks that belong to the target
# document - confirm against WikiDsLoader.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Display the shapes and the mask as the cell result.
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([14, 100]),
 torch.Size([2, 100]),
 tensor([False, False, False, False, False, False, False, False, False, False,
         False, False,  True,  True]))

In [11]:
# Decode all target chunks back into one string (padding removed) for inspection.
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|> Coleophora squamosellaColeophora squamosella is a moth of the family Coleophoridae. It is found in Europe (from Great Britain to Poland and Hungary and from Fennoscandia to France, Italy and Austria), the Baltic states, the Caucasus, Russia (Baikal and Altai) and Turkey.\n\nThe wingspan is 11–13\xa0mm. Adults are on wing in June and July.\n\nThe larvae feed in a case on Erigeron species, including Erigeron acer.\n\nReferences\n\nExternal links\n Lepiforum.de\n\nsquamosella\nCategory:Moths described in 1856\nCategory:Moths of Asia\nCategory:Moths of Europe\nCategory:Taxa named by Henry Tibbats Stainton <|query_end|>'

In [13]:
# Preview each document chunk: decode it and print the first 200 characters
# on a single line, with newlines shown as the two characters '\n'.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 363907 <|doc_id_end|> <|doc_offset_begin|> 818 <|doc_offset_end|>  the combined efforts of a strengthened Graystripe and newly trained Millie is enough to fight them off
<|doc_begin|> <|doc_id_begin|> 363907 <|doc_id_end|> <|doc_offset_begin|> 908 <|doc_offset_end|>  series.\n\nCritical Reaction\nThe Lost Warrior was praised by Publishers Weekly, which felt that "Many li
<|doc_begin|> <|doc_id_begin|> 363907 <|doc_id_end|> <|doc_offset_begin|> 998 <|doc_offset_end|> seaux praised both the writing and artwork of the book: "Writer Dan Jolley hits the ground running, sho
<|doc_begin|> <|doc_id_begin|> 1353355 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|> (V, E2) if \nE1 ⊆ E ⊆ E2.\nThe graph sandwich problem for property Π is defined as follows:\n\nGraph Sandw
<|doc_begin|> <|doc_id_begin|> 1353355 <|doc_id_end|> <|doc_offset_begin|> 273 <|doc_offset_end|>  ⊆ E ⊆ E2 and G satisfies property Π?\n\nThe recognition problem for a class of graphs (tho

In [14]:
# Embed target and document chunks with the encoder-decoder model.
# This is inference only, so no_grad skips building the autograd graph.
with torch.no_grad():
    target_embs = model_encdec.run_enc_emb(target_chunks)
    docs_embs = model_encdec.run_enc_emb(docs_chunks)

In [15]:
# Compare every document chunk with every target chunk.
# cosine=True prints cosine similarity; set it to False for Euclidean distance.
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.685041 0.596315 F
0.645687 0.553201 F
0.640477 0.568395 F
0.708415 0.589032 F
0.700093 0.562377 F
0.671513 0.563533 F
0.651670 0.515344 F
0.662804 0.521132 F
0.600119 0.680734 F
0.632610 0.550342 F
0.599716 0.575529 F
0.701156 0.560992 F
0.835970 0.634110 T
0.710499 0.725314 T


In [16]:
# Recompute the model input length from the dataset loader, so this cell does
# not rely on the value set in an earlier cell.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

# Ranker model: build the config, then instantiate it on `device`.
model_ranker_cfg = create_mllm_ranker_cfg(
    n_vocab=len(tokenizer),
    inp_len=inp_len,
    d_word_wec=256,
    n_levels=1,
    enc_n_layers=1,
    dec_n_layers=1,
    n_heads=8,
    d_k=32,
    d_v=32,
    d_model=256,
    d_inner=1024,
    pad_idx=pad_tok,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)

model_ranker = MllmRanker(model_ranker_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897278 1.1813366e-06 0.010897279
vocab_encoder.layer_norm.weight (256,) -0.099263765 -0.0056994637 0.09828945
vocab_encoder.layer_norm.bias (256,) -0.09996396 -0.0025367192 0.09967184
encoders.0.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10825035 -0.00019270678 0.10825103
encoders.0.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.108252995 0.0003052362 0.10825147
encoders.0.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10824431 0.00021790384 0.10825293
encoders.0.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108249806 -0.0001041878 0.10825263
encoders.0.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09909047 0.002910266 0.09988958
encoders.0.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09803982 0.00083962997 0.09796279
encoders.0.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846532 -5.16994e-05 0.06846457
encoders.0.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09976538 0.0021414077 0.0999813
encoders.0.layer_stack

In [19]:
# Restore the trained ranker weights and switch to eval mode.
# NOTE: torch.load unpickles the file, so only load snapshots from a trusted source.
# map_location remaps the stored tensors onto the notebook's device, so a
# snapshot saved from a CUDA run can still be loaded on a CPU-only machine.
checkpoint_ranker = torch.load(ranker_snapshot_path, map_location=device)
# Default strict loading: snapshot keys must match the model exactly.
model_ranker.load_state_dict(checkpoint_ranker['model'])
model_ranker.eval()
# The checkpoint dict is no longer needed once the weights are copied.
del checkpoint_ranker

In [20]:
# Fetch training batch number `i` again for the ranker experiments and unpack
# it into document chunks, target chunks and a per-document-chunk boolean mask.
# NOTE(review): the stored output differs from the earlier cell that requested
# the same index, so batch contents are presumably not fixed per index - confirm.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Display the shapes and the mask as the cell result.
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([14, 100]),
 torch.Size([3, 100]),
 tensor([False, False, False, False, False, False, False, False, False,  True,
          True,  True, False, False]))

In [21]:
# Decode all target chunks of the new batch into one string (padding removed).
s_target = tokten_to_text(target_chunks)
s_target

"<|query_begin|>  only a couple of months to help him physically prepare for the role. He first went well over the weight required and created concern over whether he would look right for the part. Bale recognized that his large physique was not appropriate for Batman, who relies on speed and strategy. He lost the excess weight by the time filming began. Bale trained in Wing Chun Kung Fu under Eric Oram in preparation for the movie. Child actor Gus Lewis portrays an 8-year-old Bruce at the beginning of the film.\n\nBale reprised the role of Batman in the sequel The Dark Knight, released on July 18, 2008. He trained in the Keysi Fighting Method, and performed many of his own stunts. He reprised the role again for the sequel The Dark Knight Rises, released on July 20, 2012. Bale became the actor to have portrayed Batman on film for the lengthiest period. Following the shooting at a midnight showing of The Dark Knight Rises, he visited survivors of the movie theater in an Aurora, Colorado

In [22]:
# Preview each document chunk of the new batch: decode it and print the first
# 200 characters on a single line, with newlines shown as '\n'.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 363907 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  James Barry.\n\nPlot summary\nThe Lost Warrior opens with narration from Graystripe, a warrior who was sep
<|doc_begin|> <|doc_id_begin|> 363907 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  attempt to flee but gets lost in Twolegplace and battles with a kittypet named Duke. After being force
<|doc_begin|> <|doc_id_begin|> 363907 <|doc_id_end|> <|doc_offset_begin|> 273 <|doc_offset_end|>  Twolegplace. She then shows it to Graystripe and asks him to teach her how to hunt and fight after lea
<|doc_begin|> <|doc_id_begin|> 1353355 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|> (V, E2) if \nE1 ⊆ E ⊆ E2.\nThe graph sandwich problem for property Π is defined as follows:\n\nGraph Sandw
<|doc_begin|> <|doc_id_begin|> 1353355 <|doc_id_end|> <|doc_offset_begin|> 273 <|doc_offset_end|>  ⊆ E ⊆ E2 and G satisfies property Π?\n\nThe recognition problem for a class of graphs (tho

In [23]:
# Embed target and document chunks with the ranker's encoder.
# This is inference only, so no_grad skips building the autograd graph.
with torch.no_grad():
    target_embs = model_ranker.run_enc_emb(target_chunks)
    docs_embs = model_ranker.run_enc_emb(docs_chunks)

In [24]:
# Compare every document chunk with every target chunk using ranker embeddings.
# cosine=True prints cosine similarity; set it to False for Euclidean distance.
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.719577 0.795197 0.721252 F
0.601030 0.647313 0.544734 F
0.799284 0.755250 0.678450 F
0.044161 0.152418 0.302960 F
0.076853 0.116556 0.247412 F
0.068192 0.091181 0.179821 F
0.531020 0.374339 0.441042 F
0.572577 0.367769 0.423692 F
0.377396 0.291212 0.371668 F
0.969273 0.852169 0.780423 T
0.839205 0.974415 0.878834 T
0.804989 0.911748 0.950284 T
0.234581 0.394426 0.307070 F
0.100861 0.176553 0.153062 F


In [46]:
# Free-text query compared against the document chunks of the current batch.
# Only one query is active; the others were tried earlier and are kept as
# comments instead of being assigned and immediately overwritten.
# txt = 'Hong Kong 1987'
# txt = 'Climate Classification system, Mays Landing has a humid subtropical climate, abbreviated "Cfa"'
# txt = 'Climate Classification system'
# txt = 'War and Peace'
# txt = 'Bandar Express, Ichhamati Express and Benapole Express'
# txt = 'Rick Anderson'
# txt = 'Makangarawe Temeke ward'
# txt = 'james barry'
# txt = 'erigeron'
# txt = 'Dillon Gabriel america'
# txt = 'The graph sandwich problem for property Π is defined as follows:'
txt = 'graph sandwich'
cosine = True

# Wrap the query in the query-begin / query-end markers and pad it to inp_len.
txt_tokens = text_to_tokens(txt, qbeg_tok=qbeg_tok, qend_tok=qend_tok)
# txt_tokens = txt_tokens.repeat(3, 1)
print(txt_tokens.shape)

# Inference only, so no_grad skips building the autograd graph.
with torch.no_grad():
    txt_embs = model_ranker.run_enc_emb(txt_tokens)
print_dist(txt_embs, docs_embs, target_mask, cosine=cosine)


torch.Size([1, 100])
0.198856 F
0.254959 F
0.152381 F
0.598274 F
0.535477 F
0.416071 F
0.066260 F
0.030352 F
0.191550 F
0.132326 T
0.280754 T
0.288930 T
0.282488 F
0.348249 F


In [47]:

# Score every document chunk against the free-text query.
# Inference only, so no_grad skips building the autograd graph.
with torch.no_grad():
    rank = model_ranker.run_qs_infer(docs_chunks, txt_tokens)
rank

tensor([[0.0357, 0.1303, 0.0144, 0.9366, 0.8794, 0.5010, 0.0042, 0.0024, 0.0354,
         0.0126, 0.0750, 0.0531, 0.0980, 0.1661]], grad_fn=<SigmoidBackward0>)

In [48]:
# Back-of-the-envelope count: epochs x steps per epoch x batch size.
n_epochs, batch_size, train_steps = 59, 15, 500

n_epochs * train_steps * batch_size


442500

In [31]:
# Run the ranker's forward pass on the free-text query and the document chunks.
# Inference only, so no_grad skips building the autograd graph.
# NOTE(review): the stored output below has 15 scores while docs_chunks had 14
# rows in the cells above, so it looks like it comes from an earlier run -
# re-run the notebook top to bottom to confirm.
with torch.no_grad():
    rank = model_ranker(txt_tokens, docs_chunks)
rank

tensor([[0.8737, 0.8907, 0.7371, 0.5201, 0.4101, 0.8237, 0.7342, 0.8183, 0.8694,
         0.6263, 0.6608, 0.6087, 0.2565, 0.3135, 0.3250]],
       grad_fn=<SigmoidBackward0>)

In [35]:
# Run the ranker's forward pass on the batch's own target chunks and the
# document chunks.
# Inference only, so no_grad skips building the autograd graph.
with torch.no_grad():
    rank = model_ranker(target_chunks, docs_chunks)
rank

tensor([[2.6538e-21, 3.7294e-11, 2.4412e-19, 3.4020e-23, 7.5630e-20, 3.8663e-25,
         1.3986e-18, 7.2614e-22, 1.0000e+00, 1.0000e+00, 1.0000e+00, 4.7166e-09,
         1.9997e-11, 1.3684e-08]], grad_fn=<SigmoidBackward0>)

In [31]:
# Scratch cell: draw one random integer from 0..5 (the upper bound 6 is exclusive).
np.random.randint(0, 6)

3