In [1]:
# Automatically re-import edited modules (e.g. the project's mllm.* packages) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.dsfixed import DsLoader
from mllm.model.mllm_encdec import MllmEncdec
from mllm.model.mllm_ranker import MllmRanker
from mllm.model.config import create_mllm_encdec_cfg, create_mllm_ranker_cfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens



In [19]:
# Root directory for datasets and training artifacts.
# Path.home() is used instead of os.path.expandvars('$HOME'): expandvars leaves the
# literal text '$HOME' in the path when the variable is unset.
DATA_PATH = Path.home() / 'data'
TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker'
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

# Training-run subdirectories whose 'best.pth' snapshots are evaluated below.
encdec_subdir = 'encdec-20240718_221554-wiki_20200501_en-ch_100_fixed'
# Earlier ranker runs that were evaluated before:
#   ranker-20240722_225232-wiki_20200501_en-ch_100_fixed
#   ranker-20240724_230827-wiki_20200501_en-ch_100_fixed
#   ranker-20240726_232850-wiki_20200501_en-ch_100_fixed
ranker_subdir = 'ranker-20240730_213328-wiki_20200501_en-ch_100_fixed'
encdec_train_path = TRAIN_ENCDEC_PATH / encdec_subdir
ranker_train_path = TRAIN_RANKER_PATH / ranker_subdir
encdec_snapshot_path = encdec_train_path / 'best.pth'
ranker_snapshot_path = ranker_train_path / 'best.pth'

In [5]:
# Batching parameters passed to DsLoader.
docs_batch_size, max_chunks_per_doc = 5, 3

# Compute device; change the string to 'cuda' to run on GPU.
device = torch.device('cpu')
print(device)

cpu


In [7]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', model_max_length=100000)
tok_dict = gen_all_tokens(tokenizer)

# Ids of the special tokens used for padding and for wrapping query chunks.
pad_tok = tok_dict['pad'].ind
qbeg_tok = tok_dict['query_begin'].ind
qend_tok = tok_dict['query_end'].ind

ds_loader = DsLoader(
    ds_path=DS_DIR_PATH,
    docs_batch_size=docs_batch_size,
    max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok,
    qbeg_tok=qbeg_tok,
    qend_tok=qend_tok,
    device=device,
)
# NOTE(review): shuffling is unseeded here, so batch contents differ between kernel restarts.
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)

# Model input width: the chunk size itself for fixed-size chunks, otherwise derived from it.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [8]:
def tokten_to_text(tokens: torch.Tensor, *, tok=None, pad: Optional[int] = None) -> str:
    """Decode a tensor of token ids back into text, dropping padding tokens.

    The input is flattened first, so a 2-D batch of chunks decodes into a
    single concatenated string.

    Args:
        tokens: tensor (or numpy array) of token ids, any shape.
        tok: object with a ``decode(list_of_ints) -> str`` method. Defaults to
            the notebook-level ``tokenizer``.
        pad: padding token id to drop. Defaults to the notebook-level ``pad_tok``.

    Returns:
        The decoded text.
    """
    if tok is None:
        tok = tokenizer
    if pad is None:
        pad = pad_tok
    flat = tokens.flatten()
    # .tolist() yields plain Python ints; list(tensor) would yield 0-d tensors.
    ids = flat[flat != pad].tolist()
    return tok.decode(ids)

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    """Compare two embedding vectors.

    Args:
        x, y: arrays of the same shape.
        cosine: if False, return the Euclidean distance ``||x - y||`` (lower means
            closer). If True, return the cosine *similarity*
            ``sum(x * y) / (||x|| * ||y||)`` (higher means closer, range [-1, 1]);
            note the inverted sense despite the function name.

    Returns:
        A scalar. In cosine mode 0.0 is returned when either vector has zero norm,
        instead of the NaN (plus RuntimeWarning) the raw formula would produce.
    """
    if not cosine:
        return np.linalg.norm(x - y)
    norm_prod = np.linalg.norm(x) * np.linalg.norm(y)
    if norm_prod == 0:
        # Cosine similarity is undefined for a zero vector; treat it as orthogonal.
        return 0.0
    return np.sum(x * y) / norm_prod

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize ``s`` and pack the ids into fixed-width rows filled with padding.

    Uses the notebook-level ``tokenizer``, ``inp_len``, ``pad_tok`` and ``device``.

    Args:
        s: text to tokenize.
        qbeg_tok, qend_tok: if given, the token ids are wrapped as
            ``[qbeg_tok, *ids, qend_tok]``. They must be given together.

    Returns:
        An int32 tensor of shape ``(ceil(n_tokens / inp_len), inp_len)`` on
        ``device``. The tail of the last row is filled with ``pad_tok``. An
        empty token sequence yields zero rows.

    Raises:
        ValueError: if exactly one of ``qbeg_tok`` / ``qend_tok`` is given.
    """
    # Previously a lone qend_tok was silently ignored, and the check used `assert`,
    # which disappears under `python -O`.
    if (qbeg_tok is None) != (qend_tok is None):
        raise ValueError(
            f'qbeg_tok and qend_tok must be given together, got qbeg_tok={qbeg_tok}, qend_tok={qend_tok}'
        )
    tokens = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        tokens = [qbeg_tok, *tokens, qend_tok]
    n_tokens = len(tokens)
    n_rows = -(-n_tokens // inp_len)  # ceil division
    res = np.full((n_rows * inp_len,), pad_tok, dtype=np.int32)
    res[:n_tokens] = tokens
    return torch.from_numpy(res.reshape(n_rows, inp_len)).to(device)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    """Print a docs-by-targets table of ``distance`` values.

    Prints one line per row of ``docs_embs``. Each line holds the value against
    each row of ``target_embs`` (6 decimals), followed by 'T' if
    ``target_mask[i]`` is truthy for that doc row, else 'F'.

    Args:
        target_embs: target/query embeddings, iterated row by row.
        docs_embs: document-chunk embeddings, iterated row by row.
        target_mask: per-doc-row flags, indexed by doc row position.
        cosine: forwarded to ``distance``. True prints cosine similarity
            (higher = closer); False prints Euclidean distance.
    """
    # .cpu() is required before .numpy() when tensors live on a GPU (the notebook
    # allows device = 'cuda'); it is a no-op for CPU tensors. The target rows are
    # converted once here instead of once per doc row.
    target_np = target_embs.detach().cpu().numpy()
    docs_np = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs_np):
        cells = ''.join(f'{distance(target_emb, docs_emb, cosine):0.6f} ' for target_emb in target_np)
        sfx = 'T' if target_mask[i] else 'F'
        print(cells + sfx)


In [9]:
# Encoder-decoder architecture; must match the configuration the snapshot was trained with.
model_encdec_cfg = create_mllm_encdec_cfg(
    n_vocab=len(tokenizer),
    d_word_wec=256,
    inp_len=inp_len,
    enc_n_layers=1,
    dec_n_layers=1,
    n_heads=8,
    d_model=256,
    d_inner=1024,
    pad_idx=pad_tok,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)
model_encdec = MllmEncdec(model_encdec_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897279 -9.7442e-07 0.010897281
vocab_encoder.layer_norm.weight (256,) -0.099923074 -0.001466922 0.09995844
vocab_encoder.layer_norm.bias (256,) -0.09973998 -0.0070997463 0.09975698
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10824979 -0.00017184466 0.10825187
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825093 -0.0006119554 0.108244695
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10825137 -0.00023912507 0.10825263
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10824232 -0.0001345489 0.108251765
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.097872056 -0.0016283693 0.099423945
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09898015 0.002898208 0.099626236
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846511 -9.5868134e-05 0.068464756
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09986945 0.00046189176 0.099934734
encoder.layer_stack.0.pos_ffn.w_2.weig

In [10]:
# Restore the trained encoder-decoder weights.
# map_location remaps tensors saved on another device (e.g. a GPU training run) onto
# the notebook's `device`, so the snapshot also loads on a CPU-only machine.
# NOTE(review): torch.load is pickle-based; only load checkpoints from trusted sources.
checkpoint_encdec = torch.load(encdec_snapshot_path, map_location=device)
model_encdec.load_state_dict(checkpoint_encdec['model'])
model_encdec.eval()
# Drop the checkpoint dict; load_state_dict has already copied the weights into the model.
del checkpoint_encdec

In [11]:
# Sanity check: tokenize a short string into pad-filled rows of inp_len ids.
text_to_tokens(s='Hello there')

tensor([[15496,   612, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267]],
       dtype=torch.int32)

In [12]:
# Fetch training batch #10 and materialize its tensors: the doc chunks, the target
# (query) chunks, and target_mask, which has one flag per doc chunk
# (presumably marking the chunks the target was taken from; TODO confirm in DsLoader).
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([13, 100]),
 torch.Size([3, 100]),
 tensor([False, False, False, False, False, False, False, False, False,  True,
          True,  True, False]))

In [13]:
# Decode the target chunks (padding removed) to read the query text.
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|>  Extensive property destruction was caused.\n A magnitude 6.3 earthquake struck off the east coast of Taiwan on April 17 at a depth of 41.6\xa0km.\n A magnitude 6.0 earthquake struck Tonga on April 21 at a depth of 85.0\xa0km.\n A magnitude 7.0 earthquake struck off the east coast of Taiwan on April 24 at a depth of 22.2\xa0km. The shock had a maximum intensity of VII (Very strong). 4 people were killed and 11 were injured. Some damage was reported.\n A magnitude 7.5 earthquake struck Mindoro, Philippines on April 25 at a depth of 25.0\xa0km. The shock had a maximum intensity of VIII (Severe). Some damage was reported.\n A magnitude 6.0 aftershock struck off the north coast of Mindoro, Philippines on April 27 at a depth of 25.0\xa0km.\n A magnitude 7.2 earthquake struck west of Bougainville Island, Papua New Guinea on April 28 at a depth of 409.9\xa0km.\n A magnitude 6.0 aftershock struck off the north coast of Mindoro, Philippines on April 30 at a depth of 25.0\xa0km.

In [14]:
# Preview every doc chunk: first 200 characters, with newlines escaped so each chunk prints on one line.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 2690797 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|> atsby. The hotel is part of the Hilton Hotels & Resorts chain.\n\nHistory\n\n1869–1924\nLouis Seelbach and 
<|doc_begin|> <|doc_id_begin|> 2690797 <|doc_id_end|> <|doc_offset_begin|> 273 <|doc_offset_end|>  but after turning 22 in 1874, he realized that he had greater ambitions. He opened the Seelbach Bar &
<|doc_begin|> <|doc_id_begin|> 2690797 <|doc_id_end|> <|doc_offset_begin|> 364 <|doc_offset_end|>  intent on building Louisville's first grand hotel: a hotel reflecting the opulence of European hotels
<|doc_begin|> <|doc_id_begin|> 848926 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Hoàng Thúc Hào <|doc_title_end|> <|doc_body_begin|> Hoang Thuc Hao (born 1971 in Hano
<|doc_begin|> <|doc_id_begin|> 848926 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  Hanoi University for Civil Engineering. He is also executive member of the Association of Viet

In [15]:
# Encoder embeddings of the target and doc chunks. This is inspection only, so
# autograd tracking is disabled to avoid building (and holding) a computation graph.
with torch.no_grad():
    target_embs = model_encdec.run_enc_emb(target_chunks)
    docs_embs = model_encdec.run_enc_emb(docs_chunks)

In [16]:
# True: print cosine similarity (higher = closer); False: Euclidean distance (lower = closer).
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.248495 0.181886 0.157517 F
0.286773 0.272802 0.233572 F
0.247082 0.229371 0.248524 F
0.169080 0.180084 0.225005 F
0.208775 0.255168 0.226181 F
0.209460 0.226597 0.229089 F
0.398413 0.479122 0.362724 F
0.389103 0.477254 0.432144 F
0.473069 0.408184 0.372819 F
0.541386 0.480275 0.534072 T
0.512896 0.454261 0.482845 T
0.497160 0.549742 0.588810 T
0.090437 0.068219 0.241406 F


In [20]:
# Model input width, recomputed here so this cell is self-contained.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

# Ranker architecture; must match the configuration the snapshot was trained with.
model_ranker_cfg = create_mllm_ranker_cfg(
    n_vocab=len(tokenizer),
    inp_len=inp_len,
    d_word_wec=256,
    n_levels=1,
    enc_n_layers=1,
    dec_n_layers=1,
    n_heads=8,
    d_k=32,
    d_v=32,
    d_model=256,
    d_inner=1024,
    pad_idx=pad_tok,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)

model_ranker = MllmRanker(model_ranker_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897282 -5.154602e-07 0.010897281
vocab_encoder.layer_norm.weight (256,) -0.099380784 -0.0002469452 0.09851217
vocab_encoder.layer_norm.bias (256,) -0.0997442 -0.0034882382 0.09990348
encoders.0.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.108252466 -0.00020913366 0.10825272
encoders.0.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825106 -0.000111831825 0.10825004
encoders.0.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10825122 -6.3578846e-05 0.10824707
encoders.0.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10825268 0.00020638958 0.10824436
encoders.0.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09917112 0.0011425647 0.096491575
encoders.0.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09920192 -0.0027629337 0.09965416
encoders.0.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.068464875 -8.016823e-06 0.06846483
encoders.0.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.0999907 -0.0017177719 0.09998728
encoders.0.laye

In [21]:
# Restore the trained ranker weights.
# map_location remaps tensors saved on another device (e.g. a GPU training run) onto
# the notebook's `device`, so the snapshot also loads on a CPU-only machine.
# NOTE(review): torch.load is pickle-based; only load checkpoints from trusted sources.
checkpoint_ranker = torch.load(ranker_snapshot_path, map_location=device)
model_ranker.load_state_dict(checkpoint_ranker['model'])
model_ranker.eval()
# Drop the checkpoint dict; load_state_dict has already copied the weights into the model.
del checkpoint_ranker

In [22]:
# Fetch training batch #11 for the ranker experiments and materialize its tensors:
# the doc chunks, the target (query) chunks, and target_mask, which has one flag per
# doc chunk (presumably marking the chunks the target was taken from; TODO confirm in DsLoader).
i = 11
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([14, 100]),
 torch.Size([3, 100]),
 tensor([False, False, False, False, False, False, False, False,  True,  True,
          True, False, False, False]))

In [23]:
# Decode the target chunks (padding removed) to read the query text of this batch.
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|> Roslin HouseRoslin House in Haverford, Pennsylvania, United States, was built in 1911 for Horace B. Forman Jr. and his wife, Elizabeth Chandlee Forman. The Philadelphia Quaker architect William L. Price designed and built the house from sketches provided by the Formans. Dr. H. Chandlee Forman, son of Horace and Elizabeth, donated the house to Haverford College in 1948.\n\nSince then, it has served as Haverford College\'s La Casa Hispánica, "a Special Interest House which supports the endeavors of students actively engaged in organizing programs concerned with the cultures and civilizations of the Spanish-speaking world." \n\nBuilt of cut stone, the three-story Gothic Revival house is designed after Roslin Castle and Chapel in Scotland. Its font (east) facade features a pair of extended gabled pavilions. A two-story turret containing the main staircase, and a hooded entry porch project from the southern extended gable. Extending south perpendicularly from this gable is 

In [24]:
# Preview every doc chunk of this batch: first 200 characters, with newlines escaped so each chunk prints on one line.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 2666740 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Yeh Raat Phir Na Aayegi (1966 film) <|doc_title_end|> <|doc_body_begin|> Yeh Raat Ph
<|doc_begin|> <|doc_id_begin|> 2666740 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|> ir Na Aayegi is a rivetting film about a woman's two thousand year old skeleton mysteriously coming to 
<|doc_begin|> <|doc_id_begin|> 2666740 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  is presently engaged to Reeta as her fiancé.\n\nCast\n\nBiswajit Chatterjee as Suraj\nSharmila Tagore as K
<|doc_begin|> <|doc_id_begin|> 1400412 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Joke de Korte <|doc_title_end|> <|doc_body_begin|> Johanna Catharina "Joke" de Korte
<|doc_begin|> <|doc_id_begin|> 1400412 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|> Category:Female backstroke swimmers\nCategory:Female freestyle swimmers\nCategory:Olympic swim

In [25]:
# Ranker-encoder embeddings of the target and doc chunks. This is inspection only,
# so autograd tracking is disabled to avoid building (and holding) a computation graph.
with torch.no_grad():
    target_embs = model_ranker.run_enc_emb(target_chunks)
    docs_embs = model_ranker.run_enc_emb(docs_chunks)

In [26]:
# True: print cosine similarity (higher = closer); False: Euclidean distance (lower = closer).
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.138374 0.130108 -0.014807 F
0.071871 0.398005 0.294738 F
-0.003740 0.095608 0.033381 F
0.167234 -0.002258 0.111514 F
0.066184 -0.018406 0.229574 F
-0.016917 -0.159324 0.008091 F
0.147080 -0.058094 0.110920 F
0.107667 -0.092308 0.062023 F
0.991219 0.575976 0.610025 T
0.607465 0.957985 0.603230 T
0.620492 0.755222 0.952388 T
0.258288 0.393751 0.197679 F
0.359407 0.244877 0.236895 F
0.354192 0.477126 0.392710 F


In [43]:
# Ad-hoc text queries tried against the current ranker batch. Only the last entry is
# evaluated; earlier ones are kept for reference instead of as dead reassignments.
query_candidates = [
    'Hong Kong 1987',
    'Climate Classification system, Mays Landing has a humid subtropical climate, abbreviated "Cfa"',
    # 'Climate Classification system',
    'War and Peace',
    'Bandar Express, Ichhamati Express and Benapole Express',
    'Truus Kerkmeester',
    'The Brothers Karamazov',
    'Roslin House in Haverford, Pennsylvania, United States, was built in 1911 for Horace B. Forman Jr.',
    'Roslin House Haverford Pennsylvania',
]
txt = query_candidates[-1]
cosine = True
# Wrap the query in query-begin/end tokens, like the target chunks produced by the dataset.
txt_tokens = text_to_tokens(txt, qbeg_tok=qbeg_tok, qend_tok=qend_tok)
# txt_tokens = txt_tokens.repeat(3, 1)
print(txt_tokens.shape)
# Inference only: no autograd graph needed.
with torch.no_grad():
    txt_embs = model_ranker.run_enc_emb(txt_tokens)
# Compare the query embedding with the ranker doc-chunk embeddings of the current batch.
print_dist(txt_embs, docs_embs, target_mask, cosine=cosine)


torch.Size([1, 100])
0.070620 F
0.040052 F
-0.014940 F
0.231204 F
0.192953 F
0.008095 F
0.015416 F
0.134448 F
0.527549 T
0.347985 T
0.420489 T
0.145632 F
0.069274 F
0.087963 F


In [44]:
# Ranker outputs (sigmoid values, one per doc chunk) for the ad-hoc text query.
# Inference only, so autograd tracking is disabled.
with torch.no_grad():
    rank = model_ranker(txt_tokens, docs_chunks)
rank

tensor([[8.1955e-15, 6.4947e-16, 7.8026e-19, 1.9788e-13, 3.8804e-10, 2.0001e-21,
         8.6649e-20, 5.8725e-16, 4.4542e-02, 2.0519e-08, 1.7612e-05, 2.6792e-09,
         6.1470e-14, 4.2783e-17]], grad_fn=<SigmoidBackward0>)

In [35]:
# Ranker outputs (sigmoid values, one per doc chunk) for the batch's own target chunks.
# Inference only, so autograd tracking is disabled.
with torch.no_grad():
    rank = model_ranker(target_chunks, docs_chunks)
rank

tensor([[2.6538e-21, 3.7294e-11, 2.4412e-19, 3.4020e-23, 7.5630e-20, 3.8663e-25,
         1.3986e-18, 7.2614e-22, 1.0000e+00, 1.0000e+00, 1.0000e+00, 4.7166e-09,
         1.9997e-11, 1.3684e-08]], grad_fn=<SigmoidBackward0>)

In [31]:
# NOTE(review): leftover scratch cell; its result is not used anywhere else.
# Draw from a seeded generator so Restart & Run All reproduces the value.
rand_val = np.random.default_rng(0).integers(0, 6)
rand_val

3