In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [16]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import TOKENIZER_CFG_FNAME, ENCDEC_MODEL_CFG_FNAME, RANKER_MODEL_CFG_FNAME
from mllm.model.mllm_encdec import MllmEncdecLevel
from mllm.model.mllm_ranker import MllmRankerLevel
from mllm.config.model import TokenizerCfg, MllmEncdecCfg, MllmRankerCfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens, ChunkTokenizer, tokenizer_from_config



In [3]:
# Root data directory under the current user's home directory.
# Path.home() is used instead of expandvars('$HOME'): the latter silently
# produces a literal '$HOME' path component when the variable is not set.
DATA_PATH = Path.home() / 'data'
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

# Earlier training runs, kept for reference:
# TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
# TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qrels'
# encdec_subdir = 'encdec-20240718_221554-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240903_215749-msmarco-fever'

# Training runs inspected in this notebook.
TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec_0'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qrels_0'
encdec_subdir = 'encdec-20241018_092135-wiki_20200501_en-ch_100_fixed'
ranker_subdir = 'ranker-20241018_230250-msmarco-fever'

# Per-run directories and the files read from each of them:
# the 'best.pth' snapshot, the tokenizer config and the model config.
encdec_train_path = TRAIN_ENCDEC_PATH / encdec_subdir
ranker_train_path = TRAIN_RANKER_PATH / ranker_subdir
encdec_snapshot_fpath = encdec_train_path / 'best.pth'
ranker_snapshot_fpath = ranker_train_path / 'best.pth'
encdec_tkz_cfg_fpath = encdec_train_path / TOKENIZER_CFG_FNAME
ranker_tkz_cfg_fpath = ranker_train_path / TOKENIZER_CFG_FNAME
encdec_model_cfg_fpath = encdec_train_path / ENCDEC_MODEL_CFG_FNAME
ranker_model_cfg_fpath = ranker_train_path / RANKER_MODEL_CFG_FNAME

In [4]:
# Batch shape and model level used by the loader and models below.
docs_batch_size, max_chunks_per_doc = 5, 3
model_level = 0

# Execution device; replace 'cpu' with 'cuda' to run on GPU.
device = torch.device('cpu')
print(device)

cpu


In [5]:
# Read the tokenizer config stored with each training run; the two models are
# only comparable on the same inputs if those configs are equal.
encdec_tkz_cfg, ranker_tkz_cfg = (
    parse_yaml_file_as(TokenizerCfg, cfg_fpath)
    for cfg_fpath in (encdec_tkz_cfg_fpath, ranker_tkz_cfg_fpath)
)
assert encdec_tkz_cfg == ranker_tkz_cfg

tokenizer = tokenizer_from_config(encdec_tkz_cfg)
tok_dict = encdec_tkz_cfg.custom_tokens
# Indices of the special tokens used for padding and for wrapping queries.
pad_tok, qbeg_tok, qend_tok = (
    tok_dict[tok_name].ind for tok_name in ('pad', 'query_begin', 'query_end')
)

In [6]:
# Wikipedia chunk dataset loader configured with the special token indices.
ds_loader = WikiDsLoader(
    ds_path=DS_DIR_PATH,
    docs_batch_size=docs_batch_size,
    max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok,
    qbeg_tok=qbeg_tok,
    qend_tok=qend_tok,
    device=device,
)
# Shuffle the train split first, then the non-train split.
for is_train in (True, False):
    ds_loader.shuffle(train=is_train)

# Length of one input chunk: the loader's chunk size for fixed-size datasets,
# otherwise the value derived from it by calc_max_inp_size.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [7]:
def tokten_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids into text, ignoring padding.

    The tensor is flattened first, so any shape is accepted. Every entry equal
    to the module-level `pad_tok` is dropped before the remaining ids are
    passed to the module-level `tokenizer`.

    Args:
        tokens: Tensor of token ids.

    Returns:
        The decoded string.
    """
    flat_ids = tokens.flatten()
    kept_ids = flat_ids[flat_ids != pad_tok]
    return tokenizer.decode(list(kept_ids))

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    """Compare two vectors.

    Args:
        x: First vector.
        y: Second vector, same shape as `x`.
        cosine: Selects the measure, see Returns.

    Returns:
        With `cosine=False`, the Euclidean norm of `x - y` (smaller means
        closer). With `cosine=True`, the cosine *similarity*
        `sum(x * y) / (|x| * |y|)` (larger means closer), despite the
        function name.
    """
    if cosine:
        dot_product = np.sum(x * y)
        norm_product = np.linalg.norm(x) * np.linalg.norm(y)
        return dot_product / norm_product
    return np.linalg.norm(x - y)

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize text and lay it out as padded rows of `inp_len` tokens.

    Uses the module-level `tokenizer`, `inp_len`, `pad_tok` and `device`.

    Args:
        s: Text to tokenize.
        qbeg_tok: When given, the token ids are wrapped as
            `[qbeg_tok, ..., qend_tok]`; `qend_tok` must then be given too.
        qend_tok: Closing token id, used only together with `qbeg_tok`.

    Returns:
        An int32 tensor on `device` of shape `(n_rows, inp_len)`, where
        `n_rows` is the token count divided by `inp_len`, rounded up. Unused
        trailing positions hold `pad_tok`.
    """
    token_ids = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        assert qend_tok is not None
        token_ids = [qbeg_tok] + list(token_ids) + [qend_tok]

    n_tokens = len(token_ids)
    # Ceiling division: a partially filled last row still needs a full row.
    n_rows, leftover = divmod(n_tokens, inp_len)
    if leftover > 0:
        n_rows += 1

    padded = np.full((n_rows * inp_len, ), pad_tok, dtype=np.int32)
    padded[:n_tokens] = token_ids
    return torch.from_numpy(padded).to(device).reshape(n_rows, inp_len)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    """Print the measure between every document embedding and every target embedding.

    One line is printed per row of `docs_embs`: the value returned by
    `distance` for each row of `target_embs` (6 decimals, space separated),
    followed by 'T' or 'F' taken from the matching entry of `target_mask`.

    Args:
        target_embs: Target embeddings, iterated along the first dimension.
        docs_embs: Document embeddings, iterated along the first dimension.
        target_mask: Indexable by document row; truthy entries are marked 'T'.
        cosine: Forwarded to `distance` to select the measure.
    """
    # Tensor.numpy() only works for CPU tensors, so move the data to host
    # memory first; .cpu() returns the tensor unchanged if it is already there.
    # The target array is converted once instead of once per document row.
    targets_np = target_embs.detach().cpu().numpy()
    docs_np = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs_np):
        for target_emb in targets_np:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [9]:
# Build the encoder-decoder model from its saved config and restore the
# weights stored under the 'model' key of the 'best.pth' snapshot.
model_encdec_cfg = parse_yaml_file_as(MllmEncdecCfg, encdec_model_cfg_fpath)
model_encdec = MllmEncdecLevel(model_encdec_cfg, model_level).to(device)
# map_location loads the tensors straight onto `device`, so a snapshot saved
# from a CUDA run can also be opened in a CPU-only session.
# NOTE: torch.load unpickles the file; only open snapshots from a trusted source.
checkpoint_encdec = torch.load(encdec_snapshot_fpath, map_location=device)
# strict=False: keys missing from or unexpected in the snapshot do not raise.
model_encdec.load_state_dict(checkpoint_encdec['model'], strict=False)
model_encdec.eval()
# A trailing None keeps the cell from displaying the model repr.
None

vocab_encoder.src_word_emb.weight (50271, 256) -0.010897168 2.833078e-06 0.010897173
vocab_encoder.layer_norm.weight (256,) -0.09988137 -0.0042997296 0.09879484
vocab_encoder.layer_norm.bias (256,) -0.098634355 0.004373301 0.09782844
vocab_decoder.word_prj.weight (50271, 256) -0.010897171 -3.3309432e-06 0.010897171
encoder.a_em () 0.045746114 0.045746114 0.045746114
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10825307 0.00037836193 0.10825286
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.108251706 8.872396e-05 0.10824464
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.108242646 0.0004454744 0.1082503
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10825124 -0.00040385238 0.108252384
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09982707 0.0031327584 0.09918846
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09960126 0.0014409341 0.099286936
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.068465166 5.2775606e-05 0.068

In [10]:
# Sanity check: tokenize a short string without query markers; the result is
# padded with pad_tok up to a whole number of inp_len-sized rows.
text_to_tokens('Hello there')

tensor([[15496,   612, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267]],
       dtype=torch.int32)

In [11]:
# Take training batch number 10 and build its tensors: document chunks, the
# target chunk(s) and a boolean mask with one entry per document chunk.
# Judging by the outputs below, True marks chunks of the target's own document.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([9, 100]),
 torch.Size([1, 100]),
 tensor([False, False, False,  True,  True, False, False, False, False]))

In [12]:
# Decode the target chunk back to text (padding tokens are dropped).
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|> Janów, Krotoszyn CountyJanów  () is a village in the administrative district of Gmina Krotoszyn, within Krotoszyn County, Greater Poland Voivodeship, in west-central Poland. It lies approximately  east of Krotoszyn and  south-east of the regional capital Poznań.\n\nReferences\n\nCategory:Villages in Krotoszyn County <|query_end|>'

In [13]:
# Show the first 200 decoded characters of each document chunk on one line,
# with newlines rendered as a literal backslash-n.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 4018364 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Wide Sargasso Sea (disambiguation) <|doc_title_end|> <|doc_body_begin|> Wide Sargass
<|doc_begin|> <|doc_id_begin|> 4018364 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  song on the Stevie Nicks album In Your Dreams\n\nSee also\n Sargasso Sea (disambiguation) <|doc_body_end|
<|doc_begin|> <|doc_id_begin|> 1500907 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Anarita <|doc_title_end|> <|doc_body_begin|> Anarita () is a village in the Paphos D
<|doc_begin|> <|doc_id_begin|> 3392514 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Janów, Krotoszyn County <|doc_title_end|> <|doc_body_begin|> Janów  () is a village 
<|doc_begin|> <|doc_id_begin|> 3392514 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  in Krotoszyn County <|doc_body_end|> <|doc_end|>
<|doc_begin|> <|doc_id_begin|> 5624235 <|doc_i

In [14]:
# Embed the target and the document chunks with the encoder-decoder model.
# NOTE(review): no torch.no_grad() is used here; print_dist detaches the
# results before converting them to numpy.
target_embs = model_encdec.run_enc_emb(target_chunks)
docs_embs = model_encdec.run_enc_emb(docs_chunks)

In [15]:
# Compare the target with every document chunk. Set `cosine` to False to
# print Euclidean distances instead of cosine similarities.
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.100320 F
0.161749 F
0.295927 F
0.241414 T
0.087471 T
0.023370 F
0.035144 F
0.209107 F
0.225067 F


In [23]:
# Recompute the input chunk length so this section can be run on its own.
inp_len = ds_loader.emb_chunk_size if ds_loader.fixed_size else calc_max_inp_size(ds_loader.emb_chunk_size)
# Build the ranker model from its saved config and restore the weights stored
# under the 'model' key of the 'best.pth' snapshot.
model_ranker_cfg = parse_yaml_file_as(MllmRankerCfg, ranker_model_cfg_fpath)
model_ranker = MllmRankerLevel(model_ranker_cfg, model_level).to(device)
# map_location loads the tensors straight onto `device`, so a snapshot saved
# from a CUDA run can also be opened in a CPU-only session.
# NOTE: torch.load unpickles the file; only open snapshots from a trusted source.
checkpoint_ranker = torch.load(ranker_snapshot_fpath, map_location=device)
model_ranker.load_state_dict(checkpoint_ranker['model'])
model_ranker.eval()
# A trailing None keeps the cell from displaying the model repr.
None

vocab_encoder.src_word_emb.weight (50271, 256) -0.010897174 -1.5446053e-06 0.010897171
vocab_encoder.layer_norm.weight (256,) -0.0996469 0.0028574266 0.09866418
vocab_encoder.layer_norm.bias (256,) -0.09998738 -0.006327389 0.099993505
encoder.a_em () 0.035666432 0.035666432 0.035666432
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10825018 0.00018848208 0.10824925
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825097 -9.785341e-05 0.10824983
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.108253 -0.00036103325 0.10825155
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108247496 0.00013491296 0.10825311
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09811924 0.001974782 0.098531544
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09807801 0.0072465986 0.09969182
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846328 1.9215368e-05 0.068465255
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09963577 0.00142563 0.099675

In [39]:
# Take training batch number 10 again for the ranker experiments and rebuild
# its tensors: document chunks, target chunk(s) and the per-chunk boolean mask.
# NOTE(review): the recorded target and mask differ from the earlier run of the
# same code on the same batch index — presumably gen_tensors (or the unseeded
# shuffle) picks the target randomly; confirm.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([9, 100]),
 torch.Size([1, 100]),
 tensor([False, False, False, False, False, False, False, False,  True]))

In [40]:
# Decode the target chunk back to text (padding tokens are dropped).
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|> Alanbaşı, VezirköprüAlanbaşı is a village in the District of Vezirköprü, Samsun Province, Turkey.\n\nReferences\n\nCategory:Populated places in Samsun Province\nCategory:Villages in Turkey\nCategory:Vezirköprü <|query_end|>'

In [41]:
# Show the first 200 decoded characters of each document chunk on one line,
# with newlines rendered as a literal backslash-n.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 4018364 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Wide Sargasso Sea (disambiguation) <|doc_title_end|> <|doc_body_begin|> Wide Sargass
<|doc_begin|> <|doc_id_begin|> 4018364 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  song on the Stevie Nicks album In Your Dreams\n\nSee also\n Sargasso Sea (disambiguation) <|doc_body_end|
<|doc_begin|> <|doc_id_begin|> 1500907 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Anarita <|doc_title_end|> <|doc_body_begin|> Anarita () is a village in the Paphos D
<|doc_begin|> <|doc_id_begin|> 3392514 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Janów, Krotoszyn County <|doc_title_end|> <|doc_body_begin|> Janów  () is a village 
<|doc_begin|> <|doc_id_begin|> 3392514 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  in Krotoszyn County <|doc_body_end|> <|doc_end|>
<|doc_begin|> <|doc_id_begin|> 5624235 <|doc_i

In [42]:
# Embed the target and the document chunks with the ranker model. This
# overwrites the embeddings computed earlier with the encoder-decoder model.
target_embs = model_ranker.run_enc_emb(target_chunks)
docs_embs = model_ranker.run_enc_emb(docs_chunks)

In [43]:
# Compare the target with every document chunk. Set `cosine` to False to
# print Euclidean distances instead of cosine similarities.
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.390766 F
0.271421 F
0.589987 F
0.563101 F
0.399155 F
0.518974 F
0.411905 F
0.450434 F
0.733288 T


In [66]:
# Free-text query compared against the current document chunks.
# Queries tried before this one:
#   'Hong Kong 1987', 'Wide Sargasso Sea', 'Alanbaşı, Vezirköprü', 'Lauren Reynolds'
txt = 'Lauren Raynolds'

# txt = 'The graph sandwich problem for property Π is defined as follows:'
cosine = True

# Wrap the query in the query begin/end tokens and pad it to full rows.
txt_tokens = text_to_tokens(txt, qbeg_tok=qbeg_tok, qend_tok=qend_tok)
# txt_tokens = txt_tokens.repeat(3, 1)
print(txt_tokens.shape)

# Embed the query with the ranker and compare it with the document embeddings.
txt_embs = model_ranker.run_enc_emb(txt_tokens)
print_dist(txt_embs, docs_embs, target_mask, cosine=cosine)


torch.Size([1, 100])
0.276059 F
0.437344 F
0.310923 F
0.430023 F
0.474101 F
0.658779 F
0.497670 F
0.450341 F
0.394595 T


In [67]:

# Score the document chunks against the free-text query with the ranker's
# run_qs_infer and format every score with 6 decimals for display.
rank = model_ranker.run_qs_infer(docs_chunks, txt_tokens)
rank_str = [f'{r:.06f}' for r in rank.flatten()]
rank_str

['0.000222',
 '0.005809',
 '0.000487',
 '0.002463',
 '0.005683',
 '0.957277',
 '0.011900',
 '0.014258',
 '0.001901']

442500

In [31]:
# Call the ranker's forward pass directly on the free-text query.
# NOTE(review): arguments are passed as (query, docs) here but as (docs, query)
# to run_qs_infer above — confirm the expected order against MllmRankerLevel.
# NOTE(review): the recorded output has more entries than the 9 document chunks
# shown above, so it presumably comes from an earlier kernel state.
rank = model_ranker(txt_tokens, docs_chunks)
rank

tensor([[0.8737, 0.8907, 0.7371, 0.5201, 0.4101, 0.8237, 0.7342, 0.8183, 0.8694,
         0.6263, 0.6608, 0.6087, 0.2565, 0.3135, 0.3250]],
       grad_fn=<SigmoidBackward0>)

In [35]:
# Call the ranker's forward pass directly on the dataset target chunk.
# NOTE(review): arguments are passed as (query, docs) here but as (docs, query)
# to run_qs_infer above — confirm the expected order against MllmRankerLevel.
rank = model_ranker(target_chunks, docs_chunks)
rank

tensor([[2.6538e-21, 3.7294e-11, 2.4412e-19, 3.4020e-23, 7.5630e-20, 3.8663e-25,
         1.3986e-18, 7.2614e-22, 1.0000e+00, 1.0000e+00, 1.0000e+00, 4.7166e-09,
         1.9997e-11, 1.3684e-08]], grad_fn=<SigmoidBackward0>)

In [31]:
# Scratch check: draw one random integer from 0..5 (the upper bound 6 is exclusive).
np.random.randint(0, 6)

3

In [2]:
import torch

In [4]:
# Scratch: three device objects, two of them CPU, for the equality checks below.
device1, device2, device3 = (
    torch.device(dev_name) for dev_name in ('cpu', 'cuda', 'cpu')
)

In [5]:
# torch.device objects compare by value: only the two CPU devices are equal.
print(
    device1 == device2,
    device2 == device3,
    device1 == device3,
    sep='\n',
)


False
False
True


In [7]:
# The device kind is exposed as a plain string via the `type` attribute.
device1.type, type(device1.type)

('cpu', str)

In [11]:
import numpy as np
# Scratch: three length-3 rows (two float, one int) stacked into a 3x3 array.
# Note that this rebinds `toks`, which the decoding loops above also use.
row_half = np.ones(3) * 0.5
row_1_7 = np.ones(3) * 1.7
row_ints = np.array([-1, 7, 33])
toks = [row_half, row_1_7, row_ints]
np.stack(toks, axis=0)

array([[ 0.5,  0.5,  0.5],
       [ 1.7,  1.7,  1.7],
       [-1. ,  7. , 33. ]])

In [12]:
# Building the array directly from the list of rows gives the same 3x3 result.
np.array(toks)

array([[ 0.5,  0.5,  0.5],
       [ 1.7,  1.7,  1.7],
       [-1. ,  7. , 33. ]])

In [13]:
# Confirm that np.array and np.stack(axis=0) agree for equally sized rows.
np.allclose(np.array(toks), np.stack(toks, axis=0))

True