In [1]:
%load_ext autoreload
%autoreload 2

In [5]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.dswiki import WikiDsLoader
from mllm.model.mllm_encdec import MllmEncdec
from mllm.model.mllm_ranker import MllmRanker
from mllm.model.config import create_mllm_encdec_cfg, create_mllm_ranker_cfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens



In [6]:
DATA_PATH = Path(os.path.expandvars('$HOME')) / 'data'
TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qs'
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

encdec_subdir = 'encdec-20240718_221554-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240722_225232-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240724_230827-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240726_232850-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240722_225232-wiki_20200501_en-ch_100_fixed'
ranker_subdir = 'ranker-20240730_213328-wiki_20200501_en-ch_100_fixed'
ranker_subdir = 'ranker-20240806_221913-msmarco'
encdec_train_path = TRAIN_ENCDEC_PATH / encdec_subdir
ranker_train_path = TRAIN_RANKER_PATH / ranker_subdir
encdec_snapshot_path = encdec_train_path / 'best.pth'
ranker_snapshot_path = ranker_train_path / 'best.pth'

In [7]:
docs_batch_size = 5
max_chunks_per_doc = 3
device = 'cpu'
# device = 'cuda'

device = torch.device(device)
print(device)

cpu


In [8]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', model_max_length=100000)
tok_dict = gen_all_tokens(tokenizer)
pad_tok, qbeg_tok, qend_tok = tok_dict['pad'].ind, tok_dict['query_begin'].ind, tok_dict['query_end'].ind
ds_loader = WikiDsLoader(
    ds_path=DS_DIR_PATH, docs_batch_size=docs_batch_size, max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok, qbeg_tok=qbeg_tok, qend_tok=qend_tok, device=device,
)
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)
inp_len = ds_loader.emb_chunk_size if ds_loader.fixed_size else calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [11]:
def tokten_to_text(tokens: torch.Tensor) -> str:
    tokens = tokens.flatten()
    tokens = tokens[tokens != pad_tok]
    tokens = list(tokens)
    s = tokenizer.decode(tokens)
    return s

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    if not cosine:
        return np.linalg.norm(x - y)
    x_norm, y_norm = np.linalg.norm(x), np.linalg.norm(y)
    return np.sum(x * y) / (x_norm * y_norm)

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    tokens = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        assert qend_tok is not None
        tokens = [qbeg_tok, *tokens, qend_tok]
    n_tokens = len(tokens)
    n_padded = n_tokens // inp_len + (n_tokens % inp_len > 0)
    res = np.full((n_padded * inp_len, ), pad_tok, dtype=np.int32)
    res[:n_tokens] = tokens
    res = torch.from_numpy(res).to(device)
    res = res.reshape(n_padded, inp_len)
    return res

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    for i, docs_emb in enumerate(docs_embs.detach().numpy()):
        for target_emb in target_embs.detach().numpy():
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [12]:
model_encdec_cfg = create_mllm_encdec_cfg(
    n_vocab=len(tokenizer), d_word_wec=256, inp_len=inp_len,
    enc_n_layers=1, dec_n_layers=1,
    n_heads=8, d_model=256, d_inner=1024,
    pad_idx=pad_tok, dropout_rate=0.1, enc_with_emb_mat=True,
)
model_encdec = MllmEncdec(model_encdec_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897282 -1.6511602e-06 0.010897281
vocab_encoder.layer_norm.weight (256,) -0.099946104 -0.0034745648 0.099763446
vocab_encoder.layer_norm.bias (256,) -0.09894885 -0.0039836653 0.09960281
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10825194 4.181369e-05 0.10825289
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10824983 -0.00027774877 0.10825201
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10825258 -0.00053658854 0.108251445
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108252816 0.0003341165 0.10825214
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09852302 -0.000749683 0.098978974
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09951661 0.002116121 0.09881059
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.068464994 -2.3957578e-05 0.06846508
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09974899 -0.002452617 0.099912666
encoder.layer_stack.0.pos_ffn.w_2.weigh

In [14]:
checkpoint_encdec = torch.load(encdec_snapshot_path)
model_encdec.load_state_dict(checkpoint_encdec['model'], strict=False)
model_encdec.eval()
del checkpoint_encdec

In [15]:
text_to_tokens('Hello there')

tensor([[15496,   612, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267]],
       dtype=torch.int32)

In [16]:
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([14, 100]),
 torch.Size([3, 100]),
 tensor([False, False, False, False, False, False,  True,  True,  True, False,
         False, False, False, False]))

In [17]:
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|> t. Subhadramma and her husband Sri R. Venkatarao established Sri Venkateshwara Natya Mandali (Surabhi) in 1937 in Jimidipeta village of Srikakulam District of Andhra Pradesh. The wife and husband were assisted by their children Dasaradhirao and Bhojaraju. The theatre group started growing and is now one of the biggest surviving groups with 55 members. Smt. R. Subhadramma has specialised in doing male roles, particularly characters like Duryodhana in Mahabharata. She was awarded the title of \'Kala Praveena\' by Sangeeta Nataka Akademi of Andhra Pradesh. In addition, both the husband and wife received many honors from various organizations of the state. Since their death, their sons are now managing the theatre.\n\nUnder the guidance of Padma Shri B.V. Karanth, the organization learned three plays: Bhishma (1996), organized by the National School of Drama (New Delhi), Chandi Priya (1997) by Alarippu (New Delhi), and Basthi Devatha Yaadamma ("The Good Women of Setzuan" w

In [18]:
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 2025578 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>   Connecticut Route 127\n  Florida State Road 127 (former)\n  County Road 127 (Baker County, Florida)\n  G
<|doc_begin|> <|doc_id_begin|> 2025578 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  127\n  Massachusetts Route 127A\n  Minnesota State Highway 127 (former)\n  Missouri Route 127\n  New Hamp
<|doc_begin|> <|doc_id_begin|> 2025578 <|doc_id_end|> <|doc_offset_begin|> 273 <|doc_offset_end|>, New York)\n  County Route 127 (Herkimer County, New York)\n  County Route 127 (Monroe County, New York)
<|doc_begin|> <|doc_id_begin|> 1443391 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Alfred Chandler (politician) <|doc_title_end|> <|doc_body_begin|> Alfred Elliott Cha
<|doc_begin|> <|doc_id_begin|> 1443391 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|> arried on 27 August 1901 to Marie Intermann, with whom he had five children. He served on 

In [19]:
target_embs = model_encdec.run_enc_emb(target_chunks)
docs_embs = model_encdec.run_enc_emb(docs_chunks)

In [20]:
cosine = False
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.576479 0.540214 0.615277 F
0.509272 0.501767 0.612942 F
0.563227 0.503132 0.628640 F
0.543316 0.550497 0.614628 F
0.494153 0.609150 0.630751 F
0.515145 0.584256 0.653714 F
0.736585 0.688681 0.686261 T
0.707597 0.672014 0.650622 T
0.713348 0.683684 0.726898 T
0.657356 0.637471 0.661419 F
0.571665 0.600492 0.631132 F
0.537011 0.562579 0.624428 F
0.630608 0.579992 0.644119 F
0.610991 0.606930 0.554952 F


In [21]:
inp_len = ds_loader.emb_chunk_size if ds_loader.fixed_size else calc_max_inp_size(ds_loader.emb_chunk_size)
model_ranker_cfg = create_mllm_ranker_cfg(
    n_vocab=len(tokenizer), inp_len=inp_len, d_word_wec=256,
    n_levels=1, enc_n_layers=1, dec_n_layers=1,
    n_heads=8, d_k=32, d_v=32, d_model=256, d_inner=1024,
    pad_idx=pad_tok, dropout_rate=0.1, enc_with_emb_mat=True,
)

model_ranker = MllmRanker(model_ranker_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897276 2.9072498e-06 0.010897281
vocab_encoder.layer_norm.weight (256,) -0.0995786 0.005689818 0.098533414
vocab_encoder.layer_norm.bias (256,) -0.099617876 -0.0019110441 0.09960973
encoders.0.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10825247 5.3248332e-05 0.10825016
encoders.0.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10824718 -0.00014600865 0.10824365
encoders.0.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.108252734 0.00036321822 0.10823939
encoders.0.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108252786 -0.00020212395 0.10825228
encoders.0.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09909749 -0.0065478124 0.097683884
encoders.0.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09992903 0.0008402802 0.09760336
encoders.0.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846429 -3.0445288e-05 0.06846511
encoders.0.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09999754 0.0025605082 0.09994461
encoders.0.layer_

In [22]:
checkpoint_ranker = torch.load(ranker_snapshot_path)
model_ranker.load_state_dict(checkpoint_ranker['model'])
model_ranker.eval()
del checkpoint_ranker

In [35]:
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([14, 100]),
 torch.Size([3, 100]),
 tensor([False, False, False,  True,  True,  True, False, False, False, False,
         False, False, False, False]))

In [36]:
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|> Alfred Chandler (politician)Alfred Elliott Chandler (1 July 1873 – 12 February 1935) was an Australian politician.\n\nHe was born in Malvern to market gardener William Chandler and Kate Timewell. He attended state school and became a horticulturist, running a nursery in Boronia. On 24 May 1897 he married Elizabeth Ann Intermann, with whom he had one daughter; he remarried on 27 August 1901 to Marie Intermann, with whom he had five children. He served on Ferntree Gully Shire Council from 1901 to 1935, with four terms as president (1908–09, 1918–19, 1923–24, 1934–35). In 1919 he was elected to the Victorian Legislative Council as a Nationalist, representing South Eastern Province. He was Minister of Public Works and Mines from 1928 to 1929 and a minister without portfolio from 1932 to 1935. Chandler died in Boronia in 1935 and was succeeded by his son Gilbert.\n\nReferences\n\nCategory:1873 births\nCategory:1935 deaths\nCategory:Nationalist Party of Australia members of 

In [37]:
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 2025578 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  127\n  Massachusetts Route 127A\n  Minnesota State Highway 127 (former)\n  Missouri Route 127\n  New Hamp
<|doc_begin|> <|doc_id_begin|> 2025578 <|doc_id_end|> <|doc_offset_begin|> 273 <|doc_offset_end|>, New York)\n  County Route 127 (Herkimer County, New York)\n  County Route 127 (Monroe County, New York)
<|doc_begin|> <|doc_id_begin|> 2025578 <|doc_id_end|> <|doc_offset_begin|> 364 <|doc_offset_end|> Westchester County, New York)\n  North Carolina Highway 127\n  North Dakota Highway 127\n  Ohio State Rou
<|doc_begin|> <|doc_id_begin|> 1443391 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Alfred Chandler (politician) <|doc_title_end|> <|doc_body_begin|> Alfred Elliott Cha
<|doc_begin|> <|doc_id_begin|> 1443391 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|> arried on 27 August 1901 to Marie Intermann, with whom he had five children. He served on 

In [38]:
target_embs = model_ranker.run_enc_emb(target_chunks)
docs_embs = model_ranker.run_enc_emb(docs_chunks)

In [39]:
cosine = False
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.545416 0.500873 0.436685 F
0.572497 0.510113 0.436598 F
0.455635 0.528292 0.488499 F
0.933395 0.764793 0.648805 T
0.783655 0.923404 0.797592 T
0.672478 0.888736 0.935071 T
0.689806 0.706245 0.651298 F
0.664606 0.614295 0.537516 F
0.484607 0.459275 0.500243 F
0.526222 0.667321 0.719722 F
0.401807 0.624887 0.762381 F
0.676393 0.679892 0.564639 F
0.545654 0.594635 0.487569 F
0.598215 0.550506 0.405428 F


In [44]:
txt = 'Hong Kong 1987'
txt = 'Climate Classification system, Mays Landing has a humid subtropical climate, abbreviated "Cfa"'
# txt = 'Climate Classification system'
txt = 'War and Peace'
txt = 'Bandar Express, Ichhamati Express and Benapole Express'
txt = 'Truus Kerkmeester'
txt = 'The Brothers Karamazov'
txt = 'Roslin House in Haverford, Pennsylvania, United States, was built in 1911 for Horace B. Forman Jr.'
txt = 'Alfred Chandler (politician)'
txt = 'Alfred Elliott Chandler (1 July 1873 – 12 February 1935) was an'
cosine = True
txt_tokens = text_to_tokens(txt, qbeg_tok=qbeg_tok, qend_tok=qend_tok)
# txt_tokens = txt_tokens.repeat(3, 1)
print(txt_tokens.shape)
txt_embs = model_ranker.run_enc_emb(txt_tokens)
print_dist(txt_embs, docs_embs, target_mask, cosine=cosine)


torch.Size([1, 100])
0.292753 F
0.293142 F
0.406254 F
0.656668 T
0.711962 T
0.730172 T
0.592304 F
0.493509 F
0.446405 F
0.315714 F
0.406525 F
0.513142 F
0.480471 F
0.330674 F


In [45]:

rank = model_ranker.run_qs_infer(docs_chunks, txt_tokens)
rank

tensor([[0.0610, 0.0575, 0.1869, 0.6856, 0.8816, 0.9152, 0.7861, 0.3635, 0.6011,
         0.0871, 0.3947, 0.4667, 0.5001, 0.2087]], grad_fn=<SigmoidBackward0>)

In [31]:
rank = model_ranker(txt_tokens, docs_chunks)
rank

tensor([[0.8737, 0.8907, 0.7371, 0.5201, 0.4101, 0.8237, 0.7342, 0.8183, 0.8694,
         0.6263, 0.6608, 0.6087, 0.2565, 0.3135, 0.3250]],
       grad_fn=<SigmoidBackward0>)

In [35]:
rank = model_ranker(target_chunks, docs_chunks)
rank

tensor([[2.6538e-21, 3.7294e-11, 2.4412e-19, 3.4020e-23, 7.5630e-20, 3.8663e-25,
         1.3986e-18, 7.2614e-22, 1.0000e+00, 1.0000e+00, 1.0000e+00, 4.7166e-09,
         1.9997e-11, 1.3684e-08]], grad_fn=<SigmoidBackward0>)

In [31]:
np.random.randint(0, 6)

3