In [32]:
# Re-import edited modules (e.g. the local `mllm` package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [63]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.dsfixed import DsLoader
from mllm.model.mllm_encdec import MllmEncdec
from mllm.model.mllm_ranker import MllmRanker
from mllm.model.config import create_mllm_encdec_cfg, create_mllm_ranker_cfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens



In [34]:
# Root for datasets and training runs. Path.home() also resolves when $HOME is unset,
# whereas os.path.expandvars('$HOME') would silently leave the literal '$HOME' in the path.
DATA_PATH = Path.home() / 'data'
TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker'
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

# Training-run subdirectories whose best checkpoints are evaluated below.
encdec_subdir = 'encdec-20240718_221554-wiki_20200501_en-ch_100_fixed'
ranker_subdir = 'ranker-20240722_225232-wiki_20200501_en-ch_100_fixed'
# Alternative ranker run: 'ranker-20240724_230827-wiki_20200501_en-ch_100_fixed'
encdec_train_path = TRAIN_ENCDEC_PATH / encdec_subdir
ranker_train_path = TRAIN_RANKER_PATH / ranker_subdir
encdec_snapshot_path = encdec_train_path / 'best.pth'
ranker_snapshot_path = ranker_train_path / 'best.pth'

In [35]:
# Data-loader sizing and the compute device used throughout the notebook.
docs_batch_size, max_chunks_per_doc = 5, 3
device = torch.device('cpu')  # use torch.device('cuda') to run on a GPU
print(device)

cpu


In [36]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', model_max_length=100000)
tok_dict = gen_all_tokens(tokenizer)
# Special-token ids: padding, and the markers that wrap a query.
pad_tok = tok_dict['pad'].ind
qbeg_tok = tok_dict['query_begin'].ind
qend_tok = tok_dict['query_end'].ind
ds_loader = DsLoader(
    ds_dir_path=DS_DIR_PATH,
    docs_batch_size=docs_batch_size,
    max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok,
    qbeg_tok=qbeg_tok,
    qend_tok=qend_tok,
    device=device,
)
# Shuffle the train=True split first, then the train=False split.
for is_train in (True, False):
    ds_loader.shuffle(train=is_train)
# Model input length: the chunk size itself for fixed-size chunks, otherwise derived from it.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [64]:
def tokten_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids back to text, skipping padding.

    Relies on the notebook-level ``tokenizer`` and ``pad_tok``. The input is flattened
    first, so a multi-row (chunked) tensor decodes into one concatenated string.

    Args:
        tokens: Tensor of token ids, any shape.

    Returns:
        The decoded text with every ``pad_tok`` id removed.
    """
    flat = tokens.flatten()
    # Pad ids are dropped wherever they occur, not only at the tail.
    # .tolist() yields plain Python ints in one call (and works for tensors on any device),
    # instead of list(tensor), which creates a separate 0-d tensor object per token.
    ids = flat[flat != pad_tok].tolist()
    return tokenizer.decode(ids)

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    """Compare two embedding vectors.

    Despite the name, the two modes rank in opposite directions:

    * ``cosine=False``: Euclidean distance ``||x - y||``; smaller means closer.
    * ``cosine=True``: cosine *similarity* ``sum(x * y) / (||x|| * ||y||)``; larger
      means closer. A zero-norm input yields nan (0 / 0).
    """
    if cosine:
        dot = np.sum(x * y)
        return dot / (np.linalg.norm(x) * np.linalg.norm(y))
    return np.linalg.norm(x - y)

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize ``s`` into a padded ``(n_rows, inp_len)`` int32 tensor on ``device``.

    Relies on the notebook-level ``tokenizer``, ``inp_len``, ``pad_tok`` and ``device``.

    Args:
        s: Text to tokenize.
        qbeg_tok: Optional query-begin token id, prepended to the token sequence.
        qend_tok: Optional query-end token id, appended to the token sequence.
            Pass both ``qbeg_tok`` and ``qend_tok``, or neither.

    Returns:
        The token sequence split row-wise into ``inp_len``-long rows, with the tail of
        the last row filled with ``pad_tok``. Empty input yields zero rows.

    Raises:
        ValueError: if exactly one of ``qbeg_tok`` / ``qend_tok`` is given. Previously a
            lone ``qend_tok`` was silently ignored.
    """
    if (qbeg_tok is None) != (qend_tok is None):
        raise ValueError('qbeg_tok and qend_tok must be given together')
    tokens = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        tokens = [qbeg_tok, *tokens, qend_tok]
    n_tokens = len(tokens)
    # Ceil division: number of inp_len-long rows needed to hold all tokens.
    n_rows = -(-n_tokens // inp_len)
    res = np.full((n_rows * inp_len,), pad_tok, dtype=np.int32)
    res[:n_tokens] = tokens
    return torch.from_numpy(res).to(device).reshape(n_rows, inp_len)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    """Print one line per document embedding, comparing it with every target embedding.

    Each line holds ``distance(target_emb, doc_emb, cosine)`` for each target, formatted
    to 6 decimals, followed by ``T`` if ``target_mask[i]`` is truthy else ``F``.

    Args:
        target_embs: Target embeddings, iterated along the first dimension.
        docs_embs: Document-chunk embeddings; one output line per entry along the first dim.
        target_mask: Per-document flags indexed like ``docs_embs``.
        cosine: Forwarded to ``distance``. True prints cosine similarity (higher = closer);
            False prints Euclidean distance (lower = closer).
    """
    # Tensor.numpy() only supports CPU tensors; .cpu() makes this work when the
    # embeddings were computed on CUDA. Convert once instead of per document row.
    targets_np = target_embs.detach().cpu().numpy()
    docs_np = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs_np):
        for target_emb in targets_np:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [38]:
# Encoder-decoder hyperparameters; presumably these must match the checkpoint loaded next.
model_encdec_cfg = create_mllm_encdec_cfg(
    n_vocab=len(tokenizer),
    d_word_wec=256,
    inp_len=inp_len,
    enc_n_layers=1,
    dec_n_layers=1,
    n_heads=8,
    d_model=256,
    d_inner=1024,
    pad_idx=pad_tok,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)
model_encdec = MllmEncdec(model_encdec_cfg)
model_encdec = model_encdec.to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897282 1.2500475e-06 0.010897281
vocab_encoder.layer_norm.weight (256,) -0.099734835 -3.7707156e-05 0.09949229
vocab_encoder.layer_norm.bias (256,) -0.09868218 0.0024448738 0.099951915
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10824596 -8.811242e-05 0.10825054
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.108249344 -0.00025192086 0.1082518
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10825072 -0.00028878736 0.108250424
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10824778 -8.361798e-05 0.108233996
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09921829 -0.0034752465 0.09923847
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09971982 -0.0018919911 0.09789826
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846406 -9.791839e-06 0.0684653
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09992825 0.001125851 0.09986128
encoder.layer_stack.0.pos_ffn.w_2.weight 

In [39]:
# map_location remaps tensors that were saved from another device (e.g. CUDA) onto `device`,
# so the checkpoint also loads on a CPU-only machine. torch.load unpickles: trusted files only.
checkpoint_encdec = torch.load(encdec_snapshot_path, map_location=device)
model_encdec.load_state_dict(checkpoint_encdec['model'])
model_encdec.eval()
# Free the raw checkpoint dict once the weights are copied into the model.
del checkpoint_encdec

In [40]:
text_to_tokens('Hello there', qbeg_tok=None, qend_tok=None)

tensor([[15496,   612, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267]],
       dtype=torch.int32)

In [43]:
# Pull training batch #10 and inspect the chunk tensor shapes and the target mask.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
(docs_chunks.shape, target_chunks.shape, target_mask)

(torch.Size([12, 100]),
 torch.Size([1, 100]),
 tensor([False, False, False, False, False, False,  True, False, False, False,
         False, False]))

In [44]:
# Decoded target (query) text, padding stripped.
s_target = tokten_to_text(tokens=target_chunks)
s_target

"<|query_begin|> Track and field at the 2015 Military World Games – Men's discus throwThe men's discus throw event at the 2015 Military World Games was held on 8 October at the KAFAC Sports Complex.\n\nRecords\nPrior to this competition, the existing world and CISM record were as follows:\n\nSchedule\n\nMedalists\n\nResults\n\nFinal\n\nReferences\n\ndiscus throw <|query_end|>"

In [45]:
# One line per document chunk: first 200 chars with newlines escaped.
for chunk_toks in docs_chunks:
    preview = tokten_to_text(chunk_toks)[:200]
    print(preview.replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 2723747 <|doc_id_end|> <|doc_offset_begin|> 818 <|doc_offset_end|> oo, Michigan, October 30–November 25, 2008\n\n2007\nTwo Years, Whitney Museum of American Art, New York, 
<|doc_begin|> <|doc_id_begin|> 2723747 <|doc_id_end|> <|doc_offset_begin|> 908 <|doc_offset_end|> \n\n2005\nThe Painted World, MoMA PS1, New York\n\n2001\nAs Painting: Division and Displacement, curated by 
<|doc_begin|> <|doc_id_begin|> 2723747 <|doc_id_end|> <|doc_offset_begin|> 998 <|doc_offset_end|>  York, curated by Robert Nickas\n\n1999\nThe Stroke: An Overview of Contemporary Painting, curated by Ros
<|doc_begin|> <|doc_id_begin|> 756349 <|doc_id_end|> <|doc_offset_begin|> 92 <|doc_offset_end|>  native speakers as of 2001.\n\nKannauji shares many structural and functional differences from other dial
<|doc_begin|> <|doc_id_begin|> 756349 <|doc_id_end|> <|doc_offset_begin|> 184 <|doc_offset_end|>  distribution\nKannauji is not a standard dialect of Hindi and can be assumed to be t

In [46]:
# Inference only: the embeddings are just inspected, so skip building the autograd graph.
with torch.no_grad():
    target_embs = model_encdec.run_enc_emb(target_chunks)
    docs_embs = model_encdec.run_enc_emb(docs_chunks)

In [47]:
# True: cosine similarity (higher = closer); False: Euclidean distance (lower = closer).
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.157678 F
0.200629 F
0.147184 F
0.143557 F
0.017989 F
-0.021412 F
0.397230 T
0.011577 F
0.085528 F
0.188297 F
0.340801 F
0.156483 F


In [48]:
# Input length, using the same formula as the data-loading cell.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)
# Ranker hyperparameters; presumably these must match the checkpoint loaded next.
model_ranker_cfg = create_mllm_ranker_cfg(
    n_vocab=len(tokenizer),
    inp_len=inp_len,
    d_word_wec=256,
    n_levels=1,
    enc_n_layers=1,
    dec_n_layers=1,
    n_heads=8,
    d_k=32,
    d_v=32,
    d_model=256,
    d_inner=1024,
    pad_idx=pad_tok,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)
model_ranker = MllmRanker(model_ranker_cfg)
model_ranker = model_ranker.to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897282 -1.5456724e-06 0.010897281
vocab_encoder.layer_norm.weight (256,) -0.09994151 -0.003787417 0.09852063
vocab_encoder.layer_norm.bias (256,) -0.09938567 0.002720798 0.09817394
encoders.0.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10825065 0.0003224965 0.10825294
encoders.0.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825259 0.00016920731 0.1082493
encoders.0.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10824857 -0.00027747374 0.10825134
encoders.0.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10824688 5.7866688e-05 0.108252995
encoders.0.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.099243365 -0.0010800387 0.09971624
encoders.0.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09992782 0.00231219 0.09992212
encoders.0.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846407 1.4100352e-05 0.06846438
encoders.0.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09947839 -3.29467e-06 0.0999241
encoders.0.layer_stack.0.p

In [49]:
# map_location remaps tensors that were saved from another device (e.g. CUDA) onto `device`,
# so the checkpoint also loads on a CPU-only machine. torch.load unpickles: trusted files only.
checkpoint_ranker = torch.load(ranker_snapshot_path, map_location=device)
model_ranker.load_state_dict(checkpoint_ranker['model'])
model_ranker.eval()
# Free the raw checkpoint dict once the weights are copied into the model.
del checkpoint_ranker

In [51]:
# Pull training batch #10 and inspect the chunk tensor shapes and the target mask.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
(docs_chunks.shape, target_chunks.shape, target_mask)

(torch.Size([12, 100]),
 torch.Size([3, 100]),
 tensor([False, False, False, False, False, False, False, False, False,  True,
          True,  True]))

In [52]:
# Decoded target (query) text, padding stripped.
s_target = tokten_to_text(tokens=target_chunks)
s_target

"<|query_begin|>  1st prize in drawing at the Royal Academy of Painting and Sculpture.\nFrom 1668 to 1671 he lived in Italy at the Medici Villa.\nHe won the Prix de Rome for drawing in 1668 for his work Première conquête de la Franche-Comté,\nand again in 1671 for Le Roi donnant la paix à l'Europe.\nHe was admitted to the Academy in 1678.\nBy 1684 he was a Professor.\nHe exhibited at the Salon of 1704.\nHe died in Paris in 1730.\n\nŒuvres \nMany of his drawings and paintings are at the Musée du Louvre.  They include:\n Saint Paulin de Nole, drawing in red with white chalk highlights, 36.5\xa0cm  x 28.5\xa0cm\nLe triomphe de la religion, oil on canvas, 53x43cm, Paris, Musée du Louvre.\nLa Chute des anges rebelles, oil on canvas, 163x134cm, Paris, Musée du Louvre.\nA number of his drawings on religious subjects were made into prints by the engraver Nicolas-Henri Tardieu and are now held in the Museum of Fine Arts of Nancy.\n\nGallery\n\n <|query_end|>"

In [53]:
# One line per document chunk: first 200 chars with newlines escaped.
for chunk_toks in docs_chunks:
    preview = tokten_to_text(chunk_toks)[:200]
    print(preview.replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 2723747 <|doc_id_end|> <|doc_offset_begin|> 1089 <|doc_offset_end|> : Aspects of Abstract Painting since 1970, curated by Lily Wei; Visual Arts Gallery, New York\n\n1996\nS
<|doc_begin|> <|doc_id_begin|> 2723747 <|doc_id_end|> <|doc_offset_begin|> 1179 <|doc_offset_end|>  by Saul Ostrow; Usdan Gallery, Bennington College, Bennington, Vermont\nNatural Process, Center Galle
<|doc_begin|> <|doc_id_begin|> 2723747 <|doc_id_end|> <|doc_offset_begin|> 1269 <|doc_offset_end|> 1994\nNew York Abstract Painting, Salvatore Ala Gallery, New York\n\n1993\nItalia – America: L’astrazione
<|doc_begin|> <|doc_id_begin|> 756349 <|doc_id_end|> <|doc_offset_begin|> 2195 <|doc_offset_end|>  after word tasla so it is an example of eco formation process. Some other examples are\n\n haldi-waldi\n
<|doc_begin|> <|doc_id_begin|> 756349 <|doc_id_end|> <|doc_offset_begin|> 2286 <|doc_offset_end|> \ndama:d  → sarka:r ko dama:d\n\ndama:d is a person who is preferred very much in his/h

In [54]:
# Inference only: the embeddings are just inspected, so skip building the autograd graph.
with torch.no_grad():
    target_embs = model_ranker.run_enc_emb(target_chunks)
    docs_embs = model_ranker.run_enc_emb(docs_chunks)

In [55]:
# True: cosine similarity (higher = closer); False: Euclidean distance (lower = closer).
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.229338 0.503317 0.487222 F
0.098186 0.371476 0.392254 F
0.184128 0.496609 0.498269 F
-0.049690 0.064302 0.018029 F
0.061651 0.016034 -0.064635 F
-0.014851 0.081973 0.004711 F
-0.261946 -0.156148 -0.126469 F
0.365198 0.179673 0.196634 F
-0.120499 -0.023255 0.088105 F
0.744169 0.534230 0.453429 T
0.650417 0.657011 0.579204 T
0.528779 0.635975 0.575344 T


In [65]:
# Free-text query probed against the current batch's document chunks.
# Earlier probes (only the last assignment used to take effect):
#   'Hong Kong 1987'
#   'Climate Classification system, Mays Landing has a humid subtropical climate, abbreviated "Cfa"'
#   'Climate Classification system'
#   'War and Peace'
txt = '1st prize in drawing 1668 Italy'
cosine = True
txt_tokens = text_to_tokens(txt, qbeg_tok=qbeg_tok, qend_tok=qend_tok)
# Inference only: skip building the autograd graph.
with torch.no_grad():
    txt_embs = model_ranker.run_enc_emb(txt_tokens)
# NOTE: the T/F suffixes come from the batch's target_mask, i.e. they mark chunks
# relevant to target_chunks, not to `txt`.
print_dist(txt_embs, docs_embs, target_mask, cosine=cosine)


0.113991 F
0.003470 F
0.094656 F
-0.167752 F
-0.219987 F
-0.117217 F
0.085706 F
0.089210 F
0.511148 F
0.122731 T
0.121838 T
0.089691 T


In [66]:
# Ranker scores for the free-text query; the output holds one value per row of docs_chunks.
# Inference only, so no autograd graph is built.
with torch.no_grad():
    rank = model_ranker(txt_tokens, docs_chunks)
rank

tensor([[[1.0095e-05],
         [1.2888e-05],
         [1.0771e-05],
         [2.1788e-06],
         [3.8103e-06],
         [3.2414e-06],
         [6.5914e-05],
         [2.2676e-03],
         [1.7381e-04],
         [9.8576e-05],
         [8.9405e-06],
         [5.7049e-06]]], grad_fn=<SliceBackward0>)

In [67]:
# Ranker scores for the batch's own target; compare against target_mask.
# Inference only, so no autograd graph is built.
with torch.no_grad():
    rank = model_ranker(target_chunks, docs_chunks)
rank

tensor([[[0.0127],
         [0.0119],
         [0.0085],
         [0.0038],
         [0.0075],
         [0.0041],
         [0.0006],
         [0.0013],
         [0.0007],
         [0.1171],
         [0.5373],
         [0.2596]]], grad_fn=<SliceBackward0>)

In [6]:
# Sanity check: randint's `high` bound is exclusive, so every draw from [1, 2) is 1.
np.random.randint(low=1, high=2, size=20)

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [31]:
# A single unseeded draw from [0, 6); the value changes between runs.
np.random.randint(low=0, high=6)

3