In [1]:
# Pick up edits to imported project modules (mllm.*) without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import TOKENIZER_CFG_FNAME, ENCDEC_MODEL_CFG_FNAME, RANKER_MODEL_CFG_FNAME
from mllm.model.mllm_encdec import MllmEncdecLevel
from mllm.model.mllm_ranker import MllmRankerLevel
from mllm.config.model import TokenizerCfg, MllmEncdecCfg, MllmRankerCfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens, ChunkTokenizer, tokenizer_from_config



In [3]:
# Path.home() is robust when $HOME is unset (expandvars would leave a literal '$HOME').
DATA_PATH = Path.home() / 'data'
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

# Earlier training runs (directories without the `_0` suffix), kept for reference:
# TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
# TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qrels'
# encdec_subdir = 'encdec-20240718_221554-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240903_215749-msmarco-fever'

# Level-0 runs evaluated in this notebook.
TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec_0'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qrels_0'
encdec_subdir = 'encdec-lvl0-20241029_140645-wiki_20200501_en-ch_100_fixed-enc-lrs3-embmatFalse-d256-h8-dec-lrs3-seqlen100-d256-h8-vocdecTrue'
# Previous ranker run ('dec-lrs0' variant), superseded by the one below:
# ranker_subdir = 'ranker-lvl0-20241030_230226-msmarco-fever-enc-lrs3-embmatFalse-d256-h8-dec-lrs0-d256-h8'
ranker_subdir = 'ranker-lvl0-20241031_215643-msmarco-fever-enc-lrs3-embmatFalse-d256-h8-dec-lrs3-d256-h8'

# Checkpoint and config files inside each run directory.
encdec_train_path = TRAIN_ENCDEC_PATH / encdec_subdir
ranker_train_path = TRAIN_RANKER_PATH / ranker_subdir
encdec_snapshot_fpath = encdec_train_path / 'best.pth'
ranker_snapshot_fpath = ranker_train_path / 'best.pth'
encdec_tkz_cfg_fpath = encdec_train_path / TOKENIZER_CFG_FNAME
ranker_tkz_cfg_fpath = ranker_train_path / TOKENIZER_CFG_FNAME
encdec_model_cfg_fpath = encdec_train_path / ENCDEC_MODEL_CFG_FNAME
ranker_model_cfg_fpath = ranker_train_path / RANKER_MODEL_CFG_FNAME

In [4]:
# Batch composition for the dataset loader.
docs_batch_size, max_chunks_per_doc = 5, 3
# Level index passed to MllmEncdecLevel / MllmRankerLevel below.
model_level = 0
# Set to 'cuda' to run on a GPU.
device_name = 'cpu'
device = torch.device(device_name)
print(device)

cpu


In [5]:
# The encdec and ranker runs must share one tokenizer configuration.
encdec_tkz_cfg = parse_yaml_file_as(TokenizerCfg, encdec_tkz_cfg_fpath)
ranker_tkz_cfg = parse_yaml_file_as(TokenizerCfg, ranker_tkz_cfg_fpath)
assert encdec_tkz_cfg == ranker_tkz_cfg
tokenizer = tokenizer_from_config(encdec_tkz_cfg)
tok_dict = encdec_tkz_cfg.custom_tokens
# Ids of the padding and query-delimiter custom tokens.
pad_tok, qbeg_tok, qend_tok = (tok_dict[tok_name].ind for tok_name in ('pad', 'query_begin', 'query_end'))

In [6]:
ds_loader = WikiDsLoader(
    ds_path=DS_DIR_PATH,
    docs_batch_size=docs_batch_size,
    max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok,
    qbeg_tok=qbeg_tok,
    qend_tok=qend_tok,
    device=device,
)
# Shuffle the train split, then the validation split.
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)
# Row length of token tensors fed to the models.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [7]:
def tokens_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids back into text, dropping padding.

    The tensor is flattened first, so a (n_chunks, inp_len) batch is decoded
    as one concatenated string.

    Args:
        tokens: Tensor of token ids of any shape.

    Returns:
        The text produced by the global `tokenizer` after removing every
        `pad_tok` id.
    """
    tokens = tokens.flatten()
    tokens = tokens[tokens != pad_tok]
    # .tolist() yields plain Python ints; list(tensor) would yield 0-d tensors
    # and rely on the tokenizer converting them.
    return tokenizer.decode(tokens.tolist())

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    """Compare two embedding arrays.

    The two modes point in opposite directions:
      - ``cosine=False``: Euclidean distance ``||x - y||`` (lower = closer).
      - ``cosine=True``: cosine *similarity* ``<x, y> / (||x|| * ||y||)``
        (higher = closer); despite the function name this is not a distance.

    Args:
        x, y: Arrays of matching shape; all elements take part in the norms and
            the dot product.
        cosine: Return cosine similarity instead of Euclidean distance.

    Returns:
        A scalar. When either array has zero norm, cosine similarity is
        undefined and 0.0 is returned instead of NaN.
    """
    if not cosine:
        return np.linalg.norm(x - y)
    denom = np.linalg.norm(x) * np.linalg.norm(y)
    if denom == 0:
        # Avoid a 0/0 division (NaN plus RuntimeWarning) for zero vectors.
        return 0.0
    return np.sum(x * y) / denom

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize `s` into a pad-filled (n_rows, inp_len) int32 tensor on `device`.

    If `qbeg_tok` is given, the ids are wrapped as [qbeg_tok, *ids, qend_tok]
    and `qend_tok` must be given as well; `qend_tok` alone is ignored. The id
    sequence is right-padded with `pad_tok` to the next multiple of the global
    `inp_len`, then split into rows of `inp_len` ids.
    """
    ids = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        assert qend_tok is not None
        ids = [qbeg_tok, *ids, qend_tok]
    n_ids = len(ids)
    n_rows = -(-n_ids // inp_len)  # ceil(n_ids / inp_len)
    padded = np.full(n_rows * inp_len, pad_tok, dtype=np.int32)
    padded[:n_ids] = ids
    return torch.from_numpy(padded).to(device).reshape(n_rows, inp_len)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    """Print a docs-by-targets table of `distance` values.

    One line is printed per row of `docs_embs`. It holds the `distance` value
    against each row of `target_embs`, followed by 'T' if `target_mask[i]` is
    true and 'F' otherwise. With ``cosine=True`` the values are cosine
    similarities (higher = closer).
    """
    # .cpu() is required before .numpy() when the embeddings live on a CUDA device;
    # convert once instead of once per doc row.
    targets_np = target_embs.detach().cpu().numpy()
    docs_np = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs_np):
        for target_emb in targets_np:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [8]:
model_encdec_cfg = parse_yaml_file_as(MllmEncdecCfg, encdec_model_cfg_fpath)
model_encdec = MllmEncdecLevel(model_encdec_cfg, model_level).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
checkpoint_encdec = torch.load(encdec_snapshot_fpath, map_location=device)
# strict=False tolerates key mismatches; surface them instead of ignoring them silently.
encdec_load_res = model_encdec.load_state_dict(checkpoint_encdec['model'], strict=False)
if encdec_load_res.missing_keys or encdec_load_res.unexpected_keys:
    print(f'encdec checkpoint: missing keys {encdec_load_res.missing_keys}, '
          f'unexpected keys {encdec_load_res.unexpected_keys}')
model_encdec.eval();

vocab_encoder.src_word_emb.weight (50271, 256) -0.010897174 2.1067253e-06 0.010897173
vocab_encoder.layer_norm.weight (256,) -0.09786662 -0.0013290788 0.099208556
vocab_encoder.layer_norm.bias (256,) -0.0994437 -0.011387913 0.09963775
vocab_decoder.word_prj.weight (50271, 256) -0.010897174 -1.922783e-07 0.010897171
encoder.a_em () 0.0899288 0.0899288 0.0899288
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.108246244 5.1258652e-05 0.10824751
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10824494 0.00030634573 0.10825305
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10824328 6.980845e-05 0.1082527
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10825218 1.6250924e-05 0.10824487
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.099198155 0.0035320162 0.09855495
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09984912 0.0031056185 0.09970801
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846509 9.866831e-05 0.068465255
enc

In [9]:
# Sanity check: a short string occupies one row, padded with pad_tok.
text_to_tokens(s='Hello there')

tensor([[15496,   612, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267]],
       dtype=torch.int32)

In [10]:
# Fetch training batch #i and build its doc/target tensors.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
(docs_chunks.shape, target_chunks.shape, target_mask)

(torch.Size([12, 100]),
 torch.Size([2, 100]),
 tensor([False, False, False, False, False,  True,  True, False, False, False,
         False, False]))

In [11]:
# Decode the target chunks back to text.
s_target = tokens_to_text(tokens=target_chunks)
s_target

'<|query_begin|> Virginie ClaesVirginie Claes (born 17 December 1982 in Herk-de-Stad, Limburg) is television and radio presenter and a former beauty pageant title-holder.\n\nBiography \nClaes was Miss Limburg 2006 and later that year she was crowned Miss Belgium.\n\nShe later became a television and radio presenter, including for French-language channels RTL-TVI and Bel RTL.\n\nReferences\n\nExternal links \nMiss Limburg - Virginie Claes\n\nCategory:Living people\nCategory:People from Limburg (Belgium)\nCategory:Belgian female models\nCategory:1982 births\nCategory:Miss World 2006 delegates\nCategory:Belgian beauty pageant winners\nCategory:Miss Belgium winners\nCategory:Flemish models <|query_end|>'

In [12]:
# One-line preview (first 200 characters, newlines escaped) of every doc chunk.
for toks in docs_chunks:
    s = tokens_to_text(tokens=toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 2177006 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> 2012 European Karate Championships <|doc_title_end|> <|doc_body_begin|> The 2012 Eur
<|doc_begin|> <|doc_id_begin|> 2177006 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  competitions hosted by Spain\nEuropean Karate Championships\nEuropean championships in 2012\nCategory:Spo
<|doc_begin|> <|doc_id_begin|> 5514810 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Chetogaster <|doc_title_end|> <|doc_body_begin|> Chetogaster is a genus of bristle f
<|doc_begin|> <|doc_id_begin|> 5514810 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|> berrae Paramonov, 1954 c g\n Chetogaster oblonga (Macquart, 1847) c g\n Chetogaster pellucida Paramonov, 
<|doc_begin|> <|doc_id_begin|> 5514810 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  b = Bugguide.net\n\nReferences\n\nFurther reading\n\nExternal links \n\n \n \n\nCategory:Tac

In [13]:
# Inference only: do not build an autograd graph for the embeddings.
with torch.no_grad():
    target_embs = model_encdec.run_enc_emb(target_chunks)
    docs_embs = model_encdec.run_enc_emb(docs_chunks)

In [14]:
# True: cosine similarity (higher = closer); False: Euclidean distance (lower = closer).
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.184208 0.079882 F
0.094283 0.323757 F
0.165659 0.239028 F
0.217627 0.339339 F
0.165336 0.404311 F
0.102995 0.225116 T
0.140231 0.469443 T
0.079364 -0.090584 F
0.263610 0.251187 F
0.203657 0.113284 F
0.118158 0.168529 F
0.220931 0.330329 F


In [15]:
# inp_len is already derived from ds_loader in the data-loader cell; no need to recompute it.
model_ranker_cfg = parse_yaml_file_as(MllmRankerCfg, ranker_model_cfg_fpath)
model_ranker = MllmRankerLevel(model_ranker_cfg, model_level).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
checkpoint_ranker = torch.load(ranker_snapshot_fpath, map_location=device)
model_ranker.load_state_dict(checkpoint_ranker['model'])
model_ranker.eval();

vocab_encoder.src_word_emb.weight (50271, 256) -0.010897168 -1.2185656e-06 0.01089717
vocab_encoder.layer_norm.weight (256,) -0.099739194 0.00296096 0.0996237
vocab_encoder.layer_norm.bias (256,) -0.099467896 0.001347108 0.09899004
encoder.a_em () -0.038321506 -0.038321506 -0.038321506
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.108252116 0.00020532531 0.10824986
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10824924 -1.2689488e-05 0.108252436
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10825268 -0.00025201027 0.10825282
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10825143 8.3124265e-05 0.10824674
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09940179 -0.0054651625 0.09948772
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09809735 0.0043678153 0.09915371
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846482 -1.8771465e-05 0.068464346
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.099980436 0.0006258688

In [53]:
# Fetch training batch #i and build its doc/target tensors.
i = 11
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
(docs_chunks.shape, target_chunks.shape, target_mask)

(torch.Size([13, 100]),
 torch.Size([3, 100]),
 tensor([False, False, False, False, False, False, False,  True,  True,  True,
         False, False, False]))

In [54]:
# Decode the target chunks back to text.
s_target = tokens_to_text(tokens=target_chunks)
s_target

'<|query_begin|>  methods and seeks to contain a strong sense of narrative. Her work seeks to highlight the absurdities and inequalities of wealth inequalities among a small percentage of society. \n\nIn the 1980s, her work was described as "strongly distinctive", and were originally monochromatic before including the flat application of bright primary colours. Her methods for oil painting sat in contrast to customary methods, instead using small strokes of the brush while keeping the rest of the canvas largely colorless.\n\nIn 1990, she received the Young Artists Award from the Bangladesh Shilpakala Academy.\n\nHer known works include a series titled "Which Can Be Knotted", which was shown at the Fukuoka Asian Art Museum in 2002, representing both the wish to form strong bonds between people, as well as the suffocation created by strong relationships. She has also experimented with mixed media installations, for example at the 2010 Asian Biennale when she created an installation focus

In [55]:
# One-line preview (first 200 characters, newlines escaped) of every doc chunk.
for toks in docs_chunks:
    s = tokens_to_text(tokens=toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 1879814 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Mazzara Airfield <|doc_title_end|> <|doc_body_begin|> Mazara Airfield is an abandone
<|doc_begin|> <|doc_id_begin|> 1879814 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|> elfth Air Force 316th Troop Carrier Group  between 1 September and 18 October 1943  The unit had three 
<|doc_begin|> <|doc_id_begin|> 1879814 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|> \n\nThe airfield was not used during Operation Husky (invasion of Sicily). It appears to have been a sta
<|doc_begin|> <|doc_id_begin|> 4409149 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Palm IIIx <|doc_title_end|> <|doc_body_begin|> The Palm IIIx is a PDA from Palm Comp
<|doc_begin|> <|doc_id_begin|> 4409149 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  along with claims to be more efficient, than the 16 MHz Motorola DragonBall CPU found in all pre

In [56]:
# Inference only: do not build an autograd graph for the embeddings.
with torch.no_grad():
    target_embs = model_ranker.run_enc_emb(target_chunks)
    docs_embs = model_ranker.run_enc_emb(docs_chunks)

In [57]:
# True: cosine similarity (higher = closer); False: Euclidean distance (lower = closer).
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.303889 0.394754 0.218053 F
0.144708 0.106359 -0.003139 F
0.303046 0.219121 0.042181 F
0.292043 0.390917 0.428751 F
0.366955 0.339398 0.279463 F
0.332836 0.310158 0.278422 F
0.062432 0.049842 0.042358 F
0.730377 0.612713 0.457807 T
0.640144 0.715191 0.573683 T
0.623566 0.749315 0.604172 T
0.310417 0.362377 0.229236 F
0.367458 0.450242 0.314680 F
0.282556 0.361620 0.234114 F


In [64]:
# Free-text probe query. Other probes tried: '1995 Fed Cup Europe', ' Kitchener Centre',
# 'Virginie Claes', 'The graph sandwich problem for property Π is defined as follows:'.
txt = 'Asian Biennale when she created an installation focused on the sound of birds'
cosine = True
txt_tokens = text_to_tokens(txt, qbeg_tok=qbeg_tok, qend_tok=qend_tok)
print(txt_tokens.shape)
# Inference only: do not build an autograd graph.
with torch.no_grad():
    txt_embs = model_ranker.run_enc_emb(txt_tokens)
print_dist(txt_embs, docs_embs, target_mask, cosine=cosine)


torch.Size([1, 100])
0.272010 F
0.071873 F
0.156574 F
-0.150676 F
-0.052874 F
-0.086484 F
0.507175 F
0.345749 T
0.273679 T
0.383693 T
0.181840 F
0.213475 F
0.130052 F


In [65]:

rank = model_ranker.run_qs_infer(docs_chunks, txt_tokens)
# Format every score with 6 decimals for easy side-by-side reading.
rank_str = [f'{r:.06f}' for r in rank.flatten().tolist()]
rank_str

['0.000225',
 '0.000030',
 '0.000252',
 '0.000000',
 '0.000263',
 '0.000169',
 '0.968853',
 '0.007174',
 '0.009793',
 '0.010131',
 '0.000708',
 '0.001147',
 '0.001255']

In [50]:
docs_chunks.size()

torch.Size([14, 100])

In [31]:
# Inference only: do not build an autograd graph.
with torch.no_grad():
    rank = model_ranker(txt_tokens, docs_chunks)
rank

tensor([[0.8737, 0.8907, 0.7371, 0.5201, 0.4101, 0.8237, 0.7342, 0.8183, 0.8694,
         0.6263, 0.6608, 0.6087, 0.2565, 0.3135, 0.3250]],
       grad_fn=<SigmoidBackward0>)

In [35]:
# Inference only: do not build an autograd graph.
with torch.no_grad():
    rank = model_ranker(target_chunks, docs_chunks)
rank

tensor([[2.6538e-21, 3.7294e-11, 2.4412e-19, 3.4020e-23, 7.5630e-20, 3.8663e-25,
         1.3986e-18, 7.2614e-22, 1.0000e+00, 1.0000e+00, 1.0000e+00, 4.7166e-09,
         1.9997e-11, 1.3684e-08]], grad_fn=<SigmoidBackward0>)

In [31]:
# Draw from the global (unseeded) NumPy RNG; integer in [0, 6).
np.random.randint(low=0, high=6)

3

In [2]:
import torch

In [4]:
# Constructing a 'cuda' device object does not require a GPU.
device1, device2, device3 = torch.device('cpu'), torch.device('cuda'), torch.device('cpu')

In [5]:
# Equality of device objects, pair by pair.
for dev_a, dev_b in ((device1, device2), (device2, device3), (device1, device3)):
    print(dev_a == dev_b)


False
False
True


In [7]:
(device1.type, type(device1.type))

('cpu', str)

In [11]:
import numpy as np
# Two constant float rows plus one integer row; stacking promotes all to float64.
toks = [np.full(3, 0.5), np.full(3, 1.7), np.array([-1, 7, 33])]
np.vstack(toks)

array([[ 0.5,  0.5,  0.5],
       [ 1.7,  1.7,  1.7],
       [-1. ,  7. , 33. ]])

In [12]:
np.asarray(toks)

array([[ 0.5,  0.5,  0.5],
       [ 1.7,  1.7,  1.7],
       [-1. ,  7. , 33. ]])

In [13]:
np.allclose(a=np.array(toks), b=np.stack(toks))

True