In [2]:
%load_ext autoreload
%autoreload 2

In [21]:
import os
from pathlib import Path
import sys

if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.dsfixed import DsLoader
from mllm.model.mllm_encdec import MllmEncdec
from mllm.model.mllm_ranker import MllmRanker
from mllm.model.config import create_mllm_encdec_cfg, create_mllm_ranker_cfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens



In [4]:
# Root folder for datasets and training artifacts (~/data).
# Path.home() is used instead of os.path.expandvars('$HOME'): expandvars leaves the
# literal string '$HOME' in place when the variable is unset (e.g. on Windows),
# which would silently produce a bogus relative path.
DATA_PATH = Path.home() / 'data'
TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker'
# Pre-chunked Wikipedia dump (fixed-size chunks, judging by the 'ch_100_fixed' name).
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

# Training-run subdirectories whose best checkpoints are evaluated in this notebook.
encdec_subdir = 'encdec-20240718_221554-wiki_20200501_en-ch_100_fixed'
ranker_subdir = 'ranker-20240722_225232-wiki_20200501_en-ch_100_fixed'
encdec_train_path = TRAIN_ENCDEC_PATH / encdec_subdir
ranker_train_path = TRAIN_RANKER_PATH / ranker_subdir
encdec_snapshot_path = encdec_train_path / 'best.pth'
ranker_snapshot_path = ranker_train_path / 'best.pth'

In [5]:
# Batch geometry used by the dataset loader below.
docs_batch_size, max_chunks_per_doc = 5, 3

# Compute device; swap to torch.device('cuda') to run on a GPU.
device = torch.device('cpu')
print(device)

cpu


In [6]:
# GPT-2 tokenizer with a large max length so long documents are not truncated on encode.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', model_max_length=100000)

# Project-specific special tokens (pad, query delimiters, doc markers, ...).
tok_dict = gen_all_tokens(tokenizer)
pad_tok = tok_dict['pad'].ind
qbeg_tok = tok_dict['query_begin'].ind
qend_tok = tok_dict['query_end'].ind

ds_loader = DsLoader(
    ds_dir_path=DS_DIR_PATH,
    docs_batch_size=docs_batch_size,
    max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok,
    qbeg_tok=qbeg_tok,
    qend_tok=qend_tok,
    device=device,
)
# Shuffle both the training and validation splits.
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [7]:
# Model input length: the chunk size itself for fixed-size chunks, otherwise the
# maximum input size derived from it.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

# Hyperparameters presumably must match the training run of the checkpoint loaded below,
# otherwise load_state_dict will reject the weights.
model_encdec_cfg = create_mllm_encdec_cfg(
    n_vocab=len(tokenizer),
    d_word_wec=256,
    inp_len=inp_len,
    enc_n_layers=1,
    dec_n_layers=1,
    n_heads=8,
    d_model=256,
    d_inner=1024,
    pad_idx=pad_tok,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)
model_encdec = MllmEncdec(model_encdec_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897282 3.297105e-06 0.010897281
vocab_encoder.layer_norm.weight (256,) -0.09975823 -0.0025114394 0.09987755
vocab_encoder.layer_norm.bias (256,) -0.09988537 -0.007381388 0.099962406
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10825245 -0.00040648147 0.10825277
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10824651 -0.0002117044 0.108251125
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10824577 -0.00016405826 0.108251214
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10825165 0.00010650148 0.1082506
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09909498 0.0031277358 0.098389916
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09852611 0.00021516567 0.09963594
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.068464994 -1.0003656e-05 0.06846488
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.099880375 0.0015438542 0.099917404
encoder.layer_stack.0.pos_ffn.w_2.weigh

In [8]:
# Load the best encoder-decoder weights and switch to inference mode.
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a GPU
# can be loaded on a CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint_encdec = torch.load(encdec_snapshot_path, map_location=device)
model_encdec.load_state_dict(checkpoint_encdec['model'])
model_encdec.eval()
del checkpoint_encdec  # release the raw checkpoint dict once weights are copied in

In [39]:
def tokten_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids back to text, dropping padding.

    The tensor is flattened first, so any shape is accepted (e.g. several chunks
    stacked as rows are decoded as one concatenated string). Every occurrence of the
    global ``pad_tok`` is removed, wherever it appears, before decoding with the
    global ``tokenizer``.

    Args:
        tokens: Tensor of token ids.

    Returns:
        The decoded string.
    """
    flat = tokens.flatten()
    flat = flat[flat != pad_tok]
    # tolist() converts to plain Python ints in one call, instead of list(tensor),
    # which materializes one 0-d tensor per element.
    return tokenizer.decode(flat.tolist())

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    """Compare two embedding vectors.

    Args:
        x, y: Arrays of the same shape.
        cosine: Selects the metric.

    Returns:
        If ``cosine`` is False, the Euclidean (L2) distance ``||x - y||``, where
        smaller means closer.
        If ``cosine`` is True, the cosine *similarity* ``<x, y> / (||x|| * ||y||)``,
        in [-1, 1], where larger means closer. Despite the function name, this
        branch is a similarity, not a distance.

    Note:
        In the cosine branch, a zero-norm input makes the denominator zero, so
        numpy yields nan/inf with a RuntimeWarning.
    """
    if cosine:
        denom = np.linalg.norm(x) * np.linalg.norm(y)
        return np.sum(x * y) / denom
    diff = x - y
    return np.linalg.norm(diff)

def text_to_tokens(s: str) -> torch.Tensor:
    """Tokenize free text into a padded 2-D tensor of model-sized rows.

    Encodes ``s`` with the global ``tokenizer``. The ids are right-padded with the
    global ``pad_tok`` up to the next multiple of the global ``inp_len``, then
    reshaped to ``(ceil(n_tokens / inp_len), inp_len)``. The result is an int32
    tensor on the global ``device``. An empty string gives shape ``(0, inp_len)``.

    NOTE(review): unlike the dataset's target chunks, which decode as
    '<|query_begin|> ... <|query_end|>', no query delimiter tokens are added here.
    Confirm whether the encoder expects them.
    """
    token_ids = tokenizer(s)['input_ids']
    count = len(token_ids)
    n_rows, remainder = divmod(count, inp_len)
    if remainder:
        n_rows += 1  # partial trailing row gets padded
    flat = np.full(n_rows * inp_len, pad_tok, dtype=np.int32)
    flat[:count] = token_ids
    return torch.from_numpy(flat).to(device).reshape(n_rows, inp_len)


In [40]:
# Sanity check: a short phrase becomes a single row of real ids followed by pad_tok padding.
text_to_tokens('Hello there')

tensor([[15496,   612, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267]],
       dtype=torch.int32)

In [10]:
# Pull one training batch (index i) and materialize its tensors.
# NOTE: `i` is reused as a loop variable by later cells.
i = 10
batch = ds_loader.get_batch(i, train=True)
# docs_chunks: one row of token ids per document chunk; target_chunks: the query chunk(s);
# target_mask: one bool per doc chunk. In the run shown, True marks the chunks of the document
# the query text came from — presumably its general meaning; confirm in DsLoader.
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([9, 100]),
 torch.Size([2, 100]),
 tensor([ True,  True, False, False, False, False, False, False, False]))

In [11]:
# Decode the (padding-stripped) query/target chunks back to text for inspection.
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|> 1987 in Hong KongEvents in the year 1987 in British Hong Kong.\n\nIncumbents\n Monarch – Elizabeth II\n Governor – Sir David Akers-Jones (until 9 April), Sir David Wilson (starting 9 April)\n\nEvents\n\nJanuary\n\nFebruary\n\nMarch\n\nApril\n\nMay\n\nJune\n\nJuly\n\nAugust\n\nSeptember\n\nOctober\n\nNovember\n\nDecember\n\nReferences\n\nCategory:1987 in Hong Kong\nCategory:Years of the 20th century in Hong Kong\nHong Kong\nHong Kong\nCategory:1987 in British Overseas Territories <|query_end|>'

In [12]:
# Show the first 200 characters of every document chunk, one line each
# (newlines escaped so each chunk stays on a single output line).
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 1991677 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> 1987 in Hong Kong <|doc_title_end|> <|doc_body_begin|> Events in the year 1987 in Br
<|doc_begin|> <|doc_id_begin|> 1991677 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|> \n\nCategory:1987 in Hong Kong\nCategory:Years of the 20th century in Hong Kong\nHong Kong\nHong Kong\nCatego
<|doc_begin|> <|doc_id_begin|> 1219385 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Charles Boles House <|doc_title_end|> <|doc_body_begin|> The Charles Boles House, lo
<|doc_begin|> <|doc_id_begin|> 1219385 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  designed more than 80 residences, churches, and commercial buildings in Kalispell.\n\nReferences\n\nCatego
<|doc_begin|> <|doc_id_begin|> 3140247 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Stebnik <|doc_title_end|> <|doc_body_begin|> Stebnik may refer to:\n\n

In [13]:
# Encoder embeddings for query and document chunks (one row per chunk).
# Inference only: no_grad avoids building and holding an autograd graph.
with torch.no_grad():
    target_embs = model_encdec.run_enc_emb(target_chunks)
    docs_embs = model_encdec.run_enc_emb(docs_chunks)

In [14]:
# One printed row per document chunk: cosine similarity against each target chunk
# (higher = more similar), then 'T'/'F' from target_mask for that doc chunk.
# .cpu() is required before .numpy() when the embeddings live on a CUDA device.
target_arr = target_embs.detach().cpu().numpy()  # loop-invariant, convert once
for i, docs_emb in enumerate(docs_embs.detach().cpu().numpy()):
    for target_emb in target_arr:
        dist = distance(target_emb, docs_emb, True)
        print(f'{dist:0.6f} ', end='')
    sfx = 'T' if target_mask[i] else 'F'
    print(sfx)

0.806555 0.197400 T
0.357577 0.731144 T
0.306005 0.101844 F
0.201823 0.351509 F
0.121644 0.385238 F
0.244798 0.153832 F
0.309911 0.141900 F
0.223536 0.158956 F
0.134927 0.299271 F


In [15]:
# Same input-length rule as for the encoder-decoder model.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

# Hyperparameters presumably must match the ranker training run whose checkpoint is loaded below.
model_ranker_cfg = create_mllm_ranker_cfg(
    n_vocab=len(tokenizer),
    inp_len=inp_len,
    d_word_wec=256,
    n_levels=1,
    enc_n_layers=1,
    dec_n_layers=1,
    n_heads=8,
    d_k=32,
    d_v=32,
    d_model=256,
    d_inner=1024,
    pad_idx=pad_tok,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)

model_ranker = MllmRanker(model_ranker_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897282 -1.7279738e-06 0.010897279
vocab_encoder.layer_norm.weight (256,) -0.098713435 -0.00068409566 0.09903578
vocab_encoder.layer_norm.bias (256,) -0.099790215 -0.0012416485 0.09938752
encoders.0.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10825065 0.0003846487 0.108250104
encoders.0.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825002 -4.3549368e-05 0.10824833
encoders.0.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.108249344 -0.00020081543 0.108252294
encoders.0.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108252816 -4.5218905e-05 0.1082463
encoders.0.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09881319 0.00011143202 0.09914114
encoders.0.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09990596 0.0003516497 0.09964675
encoders.0.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846519 -3.5115994e-05 0.0684649
encoders.0.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09875158 -0.0010093136 0.09944109
encoders.0.l

In [16]:
# Load the best ranker weights and switch to inference mode.
# map_location remaps stored tensors onto `device`, so a GPU-saved checkpoint
# loads on a CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint_ranker = torch.load(ranker_snapshot_path, map_location=device)
model_ranker.load_state_dict(checkpoint_ranker['model'])
model_ranker.eval()
del checkpoint_ranker  # release the raw checkpoint dict once weights are copied in

In [17]:
# Ranker-encoder embeddings for query and document chunks (overwrites the encdec ones).
# Inference only: no_grad avoids building and holding an autograd graph.
with torch.no_grad():
    target_embs = model_ranker.run_enc_emb(target_chunks)
    docs_embs = model_ranker.run_enc_emb(docs_chunks)

In [48]:
def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, cosine: bool = True,
               mask: 'torch.Tensor | None' = None):
    """Print a doc-chunk x target-chunk table of embedding comparisons.

    Prints one line per row of ``docs_embs``. Each line holds ``distance(target, doc, cosine)``
    for every row of ``target_embs``, formatted with 6 decimals, followed by 'T' or 'F'
    taken from ``mask`` for that doc row. With ``cosine=True`` the values are cosine
    similarities (higher = closer); otherwise they are Euclidean distances.

    Args:
        target_embs: Query/target embeddings, one per row.
        docs_embs: Document-chunk embeddings, one per row.
        cosine: Passed through to ``distance``.
        mask: Per-doc-row truthy flags. Defaults to the notebook-global
            ``target_mask`` (the original behaviour); pass it explicitly to avoid
            depending on whichever batch was loaded last.
    """
    if mask is None:
        mask = target_mask
    # .cpu() is required before .numpy() for CUDA tensors; convert the targets once.
    target_arr = target_embs.detach().cpu().numpy()
    docs_arr = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs_arr):
        for target_emb in target_arr:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if mask[i] else 'F'
        print(sfx)

# cosine=True prints cosine similarity (higher = closer); set False for Euclidean distance.
cosine = True
print_dist(target_embs, docs_embs, cosine)

0.480150 0.234837 T
0.233881 0.560774 T
0.056342 -0.095653 F
0.016592 0.186500 F
-0.069437 0.286522 F
-0.053729 -0.298380 F
-0.054713 -0.335745 F
-0.122766 -0.349096 F
-0.281377 0.023078 F


In [67]:
# Compare a hand-written query against the current batch's document chunks.
# Alternative queries to try:
txt = 'Hong Kong 1987'
# txt = 'Events in the year 1987'
# txt = 'Events in the year 1987 Hong Kong'
cosine = True
txt_tokens = text_to_tokens(txt)
# Inference only: no_grad avoids building and holding an autograd graph.
with torch.no_grad():
    txt_embs = model_ranker.run_enc_emb(txt_tokens)
print_dist(txt_embs, docs_embs, cosine=cosine)


0.049998 T
0.411652 T
-0.035347 F
0.134199 F
0.358041 F
-0.158853 F
-0.138117 F
-0.066518 F
0.190424 F


In [68]:
# Full ranker forward pass: query tokens vs. document chunks. In the run shown the
# output has shape (1, n_doc_chunks, 1), presumably one relevance score per chunk.
with torch.no_grad():
    rank = model_ranker(txt_tokens, docs_chunks)
rank

tensor([[[3.0222e-05],
         [5.7394e-05],
         [2.1583e-05],
         [1.7689e-05],
         [7.0402e-01],
         [3.1666e-06],
         [3.6296e-06],
         [4.0126e-06],
         [6.7034e-05]]], grad_fn=<SliceBackward0>)