In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.dswiki import WikiDsLoader
from mllm.model.mllm_encdec import MllmEncdec
from mllm.model.mllm_ranker import MllmRanker
from mllm.model.config import create_mllm_encdec_cfg, create_mllm_ranker_cfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens



In [4]:
# Root directory holding datasets and training artifacts.
# Path.home() resolves the home directory even when $HOME is not set, whereas
# os.path.expandvars('$HOME') would silently leave the literal '$HOME' in the path.
DATA_PATH = Path.home() / 'data'
TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
# Ranker runs are read from the "_qs" directory. Alternative location:
# TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qs'
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

# Training-run subdirectories. Exactly one ranker run is active; the others are
# kept as commented-out alternatives.
encdec_subdir = 'encdec-20240718_221554-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240722_225232-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240724_230827-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240726_232850-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240730_213328-wiki_20200501_en-ch_100_fixed'
ranker_subdir = 'ranker-20240806_221913-msmarco'

# Snapshot files loaded into the models further down.
encdec_train_path = TRAIN_ENCDEC_PATH / encdec_subdir
ranker_train_path = TRAIN_RANKER_PATH / ranker_subdir
encdec_snapshot_path = encdec_train_path / 'best.pth'
ranker_snapshot_path = ranker_train_path / 'best.pth'

In [5]:
# Batch geometry handed to the dataset loader.
docs_batch_size, max_chunks_per_doc = 5, 3

# Device used for the loader tensors and the models; swap the lines to use a GPU.
device = torch.device('cpu')
# device = torch.device('cuda')
print(device)

cpu


In [6]:
# GPT-2 tokenizer with model_max_length raised to 100000
# (presumably so long documents do not trigger length warnings — TODO confirm).
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', model_max_length=100000)
# Project helper returning a table of special tokens; each entry exposes its token index as `.ind`.
# NOTE(review): it appears to also register these tokens with the tokenizer
# (the models below report a vocabulary of 50270 rows) — confirm in gen_all_tokens.
tok_dict = gen_all_tokens(tokenizer)
pad_tok, qbeg_tok, qend_tok = tok_dict['pad'].ind, tok_dict['query_begin'].ind, tok_dict['query_end'].ind
# Wikipedia chunk loader configured with the batch geometry, special tokens and device defined above.
ds_loader = WikiDsLoader(
    ds_path=DS_DIR_PATH, docs_batch_size=docs_batch_size, max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok, qbeg_tok=qbeg_tok, qend_tok=qend_tok, device=device,
)
# Shuffle both splits.
# NOTE(review): no seed is set in this notebook, so the batch at a given index may
# differ between kernel sessions — confirm whether shuffle() is seeded internally.
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)
# Model input length: the loader's chunk size when chunks are fixed-size,
# otherwise the value derived from it by calc_max_inp_size.
inp_len = ds_loader.emb_chunk_size if ds_loader.fixed_size else calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [7]:
def tokten_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids into text, skipping padding.

    Args:
        tokens: Tensor of token ids of any shape. It is flattened first, so
            several chunks are decoded as one continuous string.

    Returns:
        The decoded string with every `pad_tok` entry removed.
    """
    flat = tokens.flatten()
    # .tolist() yields plain Python ints (and works for tensors on any device);
    # list(tensor) would instead produce a list of 0-d tensors.
    token_ids = flat[flat != pad_tok].tolist()
    return tokenizer.decode(token_ids)

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    """Compare two vectors by Euclidean distance or cosine similarity.

    Args:
        x: First array.
        y: Second array, same shape as `x`.
        cosine: When False (default) return the Euclidean norm of `x - y`
            (lower means closer). When True return the cosine *similarity*
            of `x` and `y` (higher means closer), not a distance.

    Returns:
        A scalar. In cosine mode, 0.0 is returned when either vector has zero
        norm, instead of dividing by zero and producing NaN.
    """
    if not cosine:
        return np.linalg.norm(x - y)
    denom = np.linalg.norm(x) * np.linalg.norm(y)
    if denom == 0:
        # Cosine similarity is undefined for a zero vector; report "no similarity".
        return 0.0
    return np.sum(x * y) / denom

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize `s` and lay the ids out as padded rows of `inp_len` tokens.

    Args:
        s: Text to tokenize with the notebook-level `tokenizer`.
        qbeg_tok: Optional token id placed before the text. When given,
            `qend_tok` must be given as well (checked with an assert).
        qend_tok: Token id placed after the text; only used together with `qbeg_tok`.

    Returns:
        An int32 tensor on `device` of shape (n_rows, inp_len), where n_rows is
        the number of `inp_len`-sized rows needed to hold all tokens. Unused
        trailing positions hold `pad_tok`.
    """
    ids = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        assert qend_tok is not None
        ids = [qbeg_tok] + list(ids) + [qend_tok]

    count = len(ids)
    full_rows, remainder = divmod(count, inp_len)
    # One extra row when the tokens do not fill the last row exactly.
    n_rows = full_rows + (remainder > 0)

    padded = np.full((n_rows * inp_len, ), pad_tok, dtype=np.int32)
    padded[:count] = ids
    return torch.from_numpy(padded).to(device).reshape(n_rows, inp_len)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    """Print a table of scores between every document embedding and every target embedding.

    One line is printed per row of `docs_embs`: the scores against each row of
    `target_embs` (formatted with 6 decimals), followed by 'T' or 'F' taken from
    `target_mask` at that document index.

    Args:
        target_embs: Target embeddings, iterated along the first dimension.
        docs_embs: Document embeddings, iterated along the first dimension.
        target_mask: Indexable of truthy/falsy flags, one per row of `docs_embs`.
        cosine: Forwarded to `distance`; True prints cosine similarity,
            False prints Euclidean distance.
    """
    # .cpu() is needed before .numpy() when the tensors live on a CUDA device;
    # it is a no-op for CPU tensors. Convert once, outside the loops.
    targets_np = target_embs.detach().cpu().numpy()
    docs_np = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs_np):
        for target_emb in targets_np:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [8]:
# Encoder-decoder configuration: single-layer encoder and decoder, 256-dim model.
# NOTE(review): these values presumably have to match the run that produced the
# snapshot loaded afterwards — confirm against the training config.
model_encdec_cfg = create_mllm_encdec_cfg(
    n_vocab=len(tokenizer),
    d_word_wec=256,
    inp_len=inp_len,
    enc_n_layers=1,
    dec_n_layers=1,
    n_heads=8,
    d_model=256,
    d_inner=1024,
    pad_idx=pad_tok,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)
# Instantiate the model and move it to the selected device.
model_encdec = MllmEncdec(model_encdec_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897281 -5.371368e-07 0.010897278
vocab_encoder.layer_norm.weight (256,) -0.09767083 0.0018497845 0.098900214
vocab_encoder.layer_norm.bias (256,) -0.099482946 -0.007212233 0.0995147
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.1082476 0.00014912081 0.108250156
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825111 -0.0003242747 0.10823374
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.108251184 -0.0002698056 0.10825296
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10824782 -8.96823e-05 0.10825316
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09930195 -0.0023568799 0.099972546
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09888889 -0.0020617023 0.0983296
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846405 -2.1045988e-05 0.06846517
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09992778 0.0017288679 0.09995016
encoder.layer_stack.0.pos_ffn.w_2.weight (256

In [9]:
# Restore the encoder-decoder weights from the 'best.pth' snapshot.
# map_location=device lets a snapshot saved on a GPU be loaded on a CPU-only
# machine and avoids first materializing the tensors on the GPU.
# NOTE(review): torch.load unpickles the file — only load snapshots you trust.
checkpoint_encdec = torch.load(encdec_snapshot_path, map_location=device)
# strict=False: keys missing from or unexpected in the snapshot are tolerated.
model_encdec.load_state_dict(checkpoint_encdec['model'], strict=False)
# Switch to evaluation mode for inference.
model_encdec.eval()
# The checkpoint dict is no longer needed once the weights are copied.
del checkpoint_encdec

In [10]:
# Sanity check: without query markers, a short text becomes a single row of
# inp_len ids, with the positions after the text filled with pad_tok.
text_to_tokens('Hello there')

tensor([[15496,   612, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267]],
       dtype=torch.int32)

In [11]:
# Fetch batch number 10 from the training split and materialize its tensors.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Show the tensor shapes and the mask (one flag per row of docs_chunks, as seen in the output).
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([15, 100]),
 torch.Size([3, 100]),
 tensor([False, False, False, False, False, False, False, False, False,  True,
          True,  True, False, False, False]))

In [12]:
# Decode all target chunks into one string (padding removed) for inspection.
s_target = tokten_to_text(target_chunks)
s_target

"<|query_begin|> Doune HillclimbDoune Hillclimb, Carse of Cambus, near Doune in the district of Stirling, Scotland, is the home of the only round of the British Hill Climb Championship currently to be held in Scotland, (Bo'ness, Fintray and the Rest And Be Thankful have featured in the past). The course is 1,476 yards (1,350 m) in length (although when it was first constructed in 1968 it was around 33yd / 30 m longer) and meetings have been staged by the Lothian Car Club since 1968.\n\nPrior to 1968, Lothian Car Club ran rounds of the British Hill Climb Championship at the Bo'ness Hillclimb from 1948 until 1967, when a house estate was built over part of the Bo'ness track. In 1967 the hillclimb track at Doune was designed by Ray Fielding and built with the first event taking place in April 1968. \n\nThe current outright record holder is Scott Moran, who set a time of 34.76 seconds in June 2014 in his Gould GR61X. Video of a 35.05 second run by Jos Goodyear in his GWR Raptor can be seen

In [13]:
# Preview each document chunk: first 200 decoded characters, with newlines
# shown as a literal "\n" so every chunk stays on one output line.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 3290073 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Echo Party <|doc_title_end|> <|doc_body_begin|> Echo Party is a mixtape by American 
<|doc_begin|> <|doc_id_begin|> 3290073 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>. It was recorded using tape echo, guitar, and kazoo, among other instruments.\n\nRelease\nEcho Party was r
<|doc_begin|> <|doc_id_begin|> 3290073 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  an average score of 71, based on 10 reviews, indicating "generally favorable reviews".\n\nRick Anderson
<|doc_begin|> <|doc_id_begin|> 2574801 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  passed the exam as banker in 1969. From 1972 to 1974 she continued her education in evening classes an
<|doc_begin|> <|doc_id_begin|> 2574801 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  she received her PhD in 1987. APS She was then promoted to Docent. From 1992 to 97 she was D

In [14]:
# Embed the target chunks and the document chunks via the encoder-decoder's run_enc_emb.
# Inference only: no_grad avoids building an autograd graph for these forward passes.
with torch.no_grad():
    target_embs = model_encdec.run_enc_emb(target_chunks)
    docs_embs = model_encdec.run_enc_emb(docs_chunks)

In [15]:
# Score every document chunk against every target chunk.
# Set cosine = False to print Euclidean distances instead of cosine similarities.
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.660717 0.578585 0.606018 F
0.647171 0.610953 0.636123 F
0.630963 0.629010 0.721536 F
0.627274 0.657909 0.664354 F
0.728437 0.709906 0.706816 F
0.636042 0.686461 0.691334 F
0.673499 0.701636 0.627083 F
0.685973 0.685685 0.619705 F
0.608890 0.641603 0.542379 F
0.712978 0.683268 0.619175 T
0.673893 0.675812 0.678965 T
0.637482 0.633176 0.700422 T
0.652289 0.665160 0.629680 F
0.603251 0.673676 0.642585 F
0.643143 0.639594 0.577397 F


In [16]:
# Recompute the model input length so this cell does not rely on the earlier value.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

# Ranker configuration: one level, single-layer encoder and decoder, 256-dim model.
model_ranker_cfg = create_mllm_ranker_cfg(
    n_vocab=len(tokenizer),
    inp_len=inp_len,
    d_word_wec=256,
    n_levels=1,
    enc_n_layers=1,
    dec_n_layers=1,
    n_heads=8,
    d_k=32,
    d_v=32,
    d_model=256,
    d_inner=1024,
    pad_idx=pad_tok,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)

# Instantiate the ranker and move it to the selected device.
model_ranker = MllmRanker(model_ranker_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897279 -1.6627607e-06 0.010897281
vocab_encoder.layer_norm.weight (256,) -0.0986701 0.0019635216 0.09986668
vocab_encoder.layer_norm.bias (256,) -0.09950633 0.001656612 0.09917348
encoders.0.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.108237505 -0.00041527097 0.10825055
encoders.0.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825316 -0.0002554225 0.10825235
encoders.0.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.108251885 -0.00017302623 0.10824738
encoders.0.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10825226 -0.00021478075 0.10824859
encoders.0.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09978091 -0.0037236954 0.09753255
encoders.0.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09951987 0.001332393 0.0995569
encoders.0.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846455 2.6755595e-05 0.068465225
encoders.0.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09987401 -0.0004405306 0.09974468
encoders.0.layer_sta

In [17]:
# Restore the ranker weights from the 'best.pth' snapshot.
# map_location=device lets a snapshot saved on a GPU be loaded on a CPU-only
# machine and avoids first materializing the tensors on the GPU.
# NOTE(review): torch.load unpickles the file — only load snapshots you trust.
checkpoint_ranker = torch.load(ranker_snapshot_path, map_location=device)
# Strict loading (the default): every key must match the model exactly.
model_ranker.load_state_dict(checkpoint_ranker['model'])
# Switch to evaluation mode for inference.
model_ranker.eval()
# The checkpoint dict is no longer needed once the weights are copied.
del checkpoint_ranker

In [18]:
# Fetch batch number 10 from the training split again, this time for the ranker.
# NOTE(review): the recorded target_mask here differs from the one recorded for the
# same index earlier in the notebook, presumably because the loader was reshuffled
# between the two runs — confirm, and seed the shuffle if the batches should match.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Show the tensor shapes and the mask.
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([15, 100]),
 torch.Size([3, 100]),
 tensor([ True,  True,  True, False, False, False, False, False, False, False,
         False, False, False, False, False]))

In [19]:
# Decode all target chunks of the current batch into one string (padding removed).
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|>  an average score of 71, based on 10 reviews, indicating "generally favorable reviews".\n\nRick Anderson of AllMusic gave the album 4 stars out of 5, saying, "you\'ll hear vintage hip-hop basslines, 808 beats, and exuberant \'80s-style rapping interspersed with weirdness like chopped-up Latin rhythms and shout-outs to New York boroughs and zodiac signs." Adam Kennedy of BBC called it "a production album over mere mixtape, one for the breakdancers, as well as appreciators of both forward-thinking and back-in-the-day craft." Nate Patrin of Pitchfork gave the album a 6.8 out of 10, saying, "if you\'ve ever wanted to hear classic cuts from the dawn of hip hop turned into hallucinogenic setpieces that knock and clang like glitched-up King Tubby, Echo Party should justify whatever the hell it is Edan\'s been doing with his time over the past four years."\n\nDave Segal of The Stranger included it on the "2009\'s Top Overlooked Releases" list.\n\nTrack listing\n\nPersonnel\nCr

In [2]:
# Preview each document chunk: first 200 decoded characters, newlines shown as "\n".
# NOTE(review): the recorded output of this cell is a NameError for docs_chunks,
# presumably from running it in a fresh kernel before the batch was fetched —
# re-run it after the batch cell so the saved output matches the code.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

NameError: name 'docs_chunks' is not defined

In [21]:
# Embed the target chunks and the document chunks via the ranker's run_enc_emb.
# Inference only: no_grad avoids building an autograd graph for these forward passes.
with torch.no_grad():
    target_embs = model_ranker.run_enc_emb(target_chunks)
    docs_embs = model_ranker.run_enc_emb(docs_chunks)

In [22]:
# Score every document chunk against every target chunk using the ranker embeddings.
# Set cosine = False to print Euclidean distances instead of cosine similarities.
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.955599 0.844631 0.682803 T
0.860412 0.956955 0.735367 T
0.736258 0.745605 0.942338 T
0.198733 0.191638 0.264690 F
0.250881 0.231401 0.250573 F
0.020881 -0.025045 0.148455 F
0.258112 0.155107 0.196491 F
0.128318 0.113541 0.201714 F
0.377852 0.314580 0.287486 F
0.316507 0.256909 0.265288 F
0.250763 0.175146 0.239105 F
0.415557 0.330336 0.426571 F
-0.099260 -0.177954 -0.001805 F
0.030993 -0.092427 0.130320 F
0.148946 0.036831 0.214487 F


In [1]:
# Free-text query compared against the document embeddings of the current batch.
# Other queries tried earlier (swap one in to experiment):
# txt = 'Hong Kong 1987'
# txt = 'Climate Classification system, Mays Landing has a humid subtropical climate, abbreviated "Cfa"'
# txt = 'Climate Classification system'
# txt = 'War and Peace'
# txt = 'Bandar Express, Ichhamati Express and Benapole Express'
txt = 'Rick Anderson'
cosine = True
# Wrap the query in the query-begin / query-end markers before padding it to full rows.
txt_tokens = text_to_tokens(txt, qbeg_tok=qbeg_tok, qend_tok=qend_tok)
# txt_tokens = txt_tokens.repeat(3, 1)
print(txt_tokens.shape)
# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    txt_embs = model_ranker.run_enc_emb(txt_tokens)
# Uses docs_embs and target_mask computed for the current batch in earlier cells.
print_dist(txt_embs, docs_embs, target_mask, cosine=cosine)


NameError: name 'text_to_tokens' is not defined

In [45]:

# Rank the document chunks for the free-text query via run_qs_infer.
# NOTE(review): the arguments here are (docs, query), while the direct
# model_ranker(...) calls elsewhere in this notebook pass (query, docs) —
# confirm both orders against the MllmRanker signatures.
# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    rank = model_ranker.run_qs_infer(docs_chunks, txt_tokens)
rank

tensor([[0.0610, 0.0575, 0.1869, 0.6856, 0.8816, 0.9152, 0.7861, 0.3635, 0.6011,
         0.0871, 0.3947, 0.4667, 0.5001, 0.2087]], grad_fn=<SigmoidBackward0>)

In [31]:
# Score the document chunks for the free-text query with the ranker's forward call.
# NOTE(review): the arguments here are (query, docs), the reverse of the
# run_qs_infer(docs, query) call elsewhere in this notebook — confirm against
# the MllmRanker signatures.
# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    rank = model_ranker(txt_tokens, docs_chunks)
rank

tensor([[0.8737, 0.8907, 0.7371, 0.5201, 0.4101, 0.8237, 0.7342, 0.8183, 0.8694,
         0.6263, 0.6608, 0.6087, 0.2565, 0.3135, 0.3250]],
       grad_fn=<SigmoidBackward0>)

In [35]:
# Score the document chunks against the batch's own target chunks with the ranker's forward call.
# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    rank = model_ranker(target_chunks, docs_chunks)
rank

tensor([[2.6538e-21, 3.7294e-11, 2.4412e-19, 3.4020e-23, 7.5630e-20, 3.8663e-25,
         1.3986e-18, 7.2614e-22, 1.0000e+00, 1.0000e+00, 1.0000e+00, 4.7166e-09,
         1.9997e-11, 1.3684e-08]], grad_fn=<SigmoidBackward0>)

In [31]:
# Scratch cell: draws one random integer from 0..5 (the upper bound 6 is exclusive).
# No seed is set, so the value changes from run to run.
np.random.randint(0, 6)

3