In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported modules
# (such as the local `mllm` package) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import TOKENIZER_CFG_FNAME, ENCDEC_MODEL_CFG_FNAME, RANKER_MODEL_CFG_FNAME
from mllm.model.mllm_encdec import MllmEncdecLevel
from mllm.model.mllm_ranker import MllmRankerLevel
from mllm.config.model import TokenizerCfg, MllmEncdecCfg, MllmRankerCfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens, ChunkTokenizer, tokenizer_from_config



In [4]:
# Root directory holding both the datasets and the training run outputs.
# Path.home() resolves the home directory on every platform, whereas
# expandvars('$HOME') silently yields the literal '$HOME' when the variable is unset.
DATA_PATH = Path.home() / 'data'
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

# Previously used pair of training runs, kept for reference:
# TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
# TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qrels'
# encdec_subdir = 'encdec-20240718_221554-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20241021_062053-msmarco-fever'

# Training runs inspected by this notebook.
TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec_0'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qrels_0'
encdec_subdir = 'encdec-20241018_092135-wiki_20200501_en-ch_100_fixed'
ranker_subdir = 'ranker-20241021_062053-msmarco-fever'

# Each run directory contains the best checkpoint plus the tokenizer and
# model configuration files it was trained with.
encdec_train_path = TRAIN_ENCDEC_PATH / encdec_subdir
ranker_train_path = TRAIN_RANKER_PATH / ranker_subdir
encdec_snapshot_fpath = encdec_train_path / 'best.pth'
ranker_snapshot_fpath = ranker_train_path / 'best.pth'
encdec_tkz_cfg_fpath = encdec_train_path / TOKENIZER_CFG_FNAME
ranker_tkz_cfg_fpath = ranker_train_path / TOKENIZER_CFG_FNAME
encdec_model_cfg_fpath = encdec_train_path / ENCDEC_MODEL_CFG_FNAME
ranker_model_cfg_fpath = ranker_train_path / RANKER_MODEL_CFG_FNAME

In [68]:
# Batching parameters for the dataset loader and the model level to instantiate.
docs_batch_size, max_chunks_per_doc = 5, 3
model_level = 0

# Pass 'cuda' instead of 'cpu' to run the models on a GPU.
device = torch.device('cpu')
print(device)

cpu


In [8]:
# The notebook uses a single tokenizer for both models, so the tokenizer
# configurations saved by the two training runs must be equal.
encdec_tkz_cfg = parse_yaml_file_as(TokenizerCfg, encdec_tkz_cfg_fpath)
ranker_tkz_cfg = parse_yaml_file_as(TokenizerCfg, ranker_tkz_cfg_fpath)
assert encdec_tkz_cfg == ranker_tkz_cfg
tokenizer = tokenizer_from_config(encdec_tkz_cfg)

# Ids of the custom tokens used for padding and for wrapping queries.
tok_dict = encdec_tkz_cfg.custom_tokens
pad_tok = tok_dict['pad'].ind
qbeg_tok = tok_dict['query_begin'].ind
qend_tok = tok_dict['query_end'].ind

In [70]:
# Loader over the chunked wiki dataset, configured with the special token ids
# taken from the tokenizer config.
ds_loader = WikiDsLoader(
    ds_path=DS_DIR_PATH,
    docs_batch_size=docs_batch_size,
    max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok,
    qbeg_tok=qbeg_tok,
    qend_tok=qend_tok,
    device=device,
)
# NOTE(review): no random seed is set in this notebook; if shuffle() draws from
# a global RNG the batches below differ between runs — confirm.
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)

# Row length used by text_to_tokens() when padding token sequences.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [71]:
def tokten_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids into text, dropping padding tokens.

    The tensor is flattened first, so a multi-row chunk tensor is decoded as
    one continuous sequence. Uses the notebook globals `tokenizer` and `pad_tok`.

    Args:
        tokens: Tensor of token ids of any shape.

    Returns:
        The decoded string.
    """
    flat = tokens.flatten()
    # tolist() yields plain Python ints rather than a list of 0-dim tensors.
    token_ids = flat[flat != pad_tok].tolist()
    return tokenizer.decode(token_ids)

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    """Compare two arrays by Euclidean distance or by cosine similarity.

    Args:
        x: First array.
        y: Second array.
        cosine: When False, return the norm of `x - y` (smaller means closer).
            When True, return the cosine of the angle between `x` and `y`
            (a similarity: larger means closer).

    Returns:
        A scalar with the requested measure.
    """
    if cosine:
        dot = (x * y).sum()
        return dot / (np.linalg.norm(x) * np.linalg.norm(y))
    return np.linalg.norm(x - y)

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize text and pad it into rows of `inp_len` token ids.

    Uses the notebook globals `tokenizer`, `inp_len`, `pad_tok` and `device`.

    Args:
        s: Text to tokenize.
        qbeg_tok: Optional token id prepended to the sequence. When given,
            `qend_tok` must be given as well.
        qend_tok: Optional token id appended to the sequence. Ignored when
            `qbeg_tok` is None.

    Returns:
        An int32 tensor of shape (n_rows, inp_len) on `device`, where unused
        trailing positions hold `pad_tok`.
    """
    token_ids = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        assert qend_tok is not None
        token_ids = [qbeg_tok, *token_ids, qend_tok]
    n_tokens = len(token_ids)
    # Ceiling division: rows of `inp_len` tokens needed to hold the sequence.
    n_rows, remainder = divmod(n_tokens, inp_len)
    if remainder > 0:
        n_rows += 1
    padded = np.full((n_rows * inp_len,), pad_tok, dtype=np.int32)
    padded[:n_tokens] = token_ids
    return torch.from_numpy(padded).to(device).reshape(n_rows, inp_len)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    """Print the pairwise measure between document and target embeddings.

    One output line is printed per row of `docs_embs`; it holds one value per
    row of `target_embs`, followed by 'T' or 'F' taken from `target_mask`.

    Args:
        target_embs: Target embeddings, iterated row by row.
        docs_embs: Document embeddings, iterated row by row.
        target_mask: Per-document-row flags, indexed by the row position.
        cosine: Forwarded to `distance` to select cosine similarity (True)
            or Euclidean distance (False).
    """
    # Tensor.numpy() only supports CPU tensors; .cpu() keeps this working when
    # the notebook runs with device = 'cuda'. Converting once up front also
    # avoids repeating the target conversion for every document row.
    target_arr = target_embs.detach().cpu().numpy()
    docs_arr = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs_arr):
        for target_emb in target_arr:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [72]:
# Build the encoder-decoder model from its saved config and restore the best checkpoint.
model_encdec_cfg = parse_yaml_file_as(MllmEncdecCfg, encdec_model_cfg_fpath)
model_encdec = MllmEncdecLevel(model_encdec_cfg, model_level).to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can still be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint_encdec = torch.load(encdec_snapshot_fpath, map_location=device)
# strict=False: keys that do not match between checkpoint and model are not treated as errors.
model_encdec.load_state_dict(checkpoint_encdec['model'], strict=False)
model_encdec.eval()
# Trailing None keeps the cell from echoing the return value of eval().
None

vocab_encoder.src_word_emb.weight (50271, 256) -0.010897173 -3.2294865e-06 0.010897173
vocab_encoder.layer_norm.weight (256,) -0.09595125 0.00065567554 0.09957045
vocab_encoder.layer_norm.bias (256,) -0.09988814 0.0075067556 0.09993924
vocab_decoder.word_prj.weight (50271, 256) -0.010897174 3.4891085e-07 0.010897171
encoder.a_em () 0.04009992 0.04009992 0.04009992
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10824596 -0.00040247082 0.108248524
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10824942 0.00023723408 0.10825191
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.108234644 0.00016568921 0.10825205
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108250834 -0.00017125583 0.1082469
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.098759234 0.0039213006 0.09965485
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09977838 0.00050682866 0.09982135
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.068464346 -5.541658e-05 0.0

In [73]:
# Smoke test: a short string gives a single row padded with `pad_tok` up to `inp_len`.
text_to_tokens('Hello there')

tensor([[15496,   612, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267]],
       dtype=torch.int32)

In [74]:
# Fetch training batch 10 (presumably a batch index — confirm against WikiDsLoader)
# and materialise its tensors.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Show the chunk tensor shapes and the mask that flags which document chunks are targets.
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([18, 100]),
 torch.Size([2, 100]),
 tensor([ True,  True, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False]))

In [75]:
# Decode all target chunks into one string to see which text this batch targets.
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|> Global CosmopolitansGlobal Cosmopolitans refers to "a talented population of highly educated multilingual people that have lived, worked and studied for extensive periods in different cultures. While their international identities have diverse starting points and experiences, their views of the world and themselves are profoundly affected by both the realities of living in different cultures and their manner of coping with the challenges that emerge.".\n\nThe term was developed by Linda Brimm, Professor of Organizational Behavior at INSEAD and further explored in her book Global Cosmopolitans: The Creative Edge of Difference.\n\nSee also \n Cosmopolitanism\n Global citizenship\n World Citizen\n\nReferences\n\nCategory:Global citizenship\nCategory:Globalization\nCategory:Cosmopolitanism <|query_end|>'

In [76]:
# Preview every document chunk: decode it, keep the first 200 characters and
# escape newlines so each chunk is printed on a single line.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 3926504 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Global Cosmopolitans <|doc_title_end|> <|doc_body_begin|> Global Cosmopolitans refer
<|doc_begin|> <|doc_id_begin|> 3926504 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>, Professor of Organizational Behavior at INSEAD and further explored in her book Global Cosmopolitans: 
<|doc_begin|> <|doc_id_begin|> 5225710 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Thoracic spinal nerve 3 <|doc_title_end|> <|doc_body_begin|> The thoracic spinal ner
<|doc_begin|> <|doc_id_begin|> 3966441 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Fred Miller (American football, born 1973) <|doc_title_end|> <|doc_body_begin|> Fred
<|doc_begin|> <|doc_id_begin|> 3966441 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  Bowl XXXIV, having given up several sacks to Kearse in the regular season match up between the tea

In [77]:
# Embed the target and document chunks with the encoder-decoder model.
# This is inference only, so autograd bookkeeping is switched off.
with torch.no_grad():
    target_embs = model_encdec.run_enc_emb(target_chunks)
    docs_embs = model_encdec.run_enc_emb(docs_chunks)

In [78]:
# True compares with cosine similarity; set to False for Euclidean distance.
cosine = True
# One line per document chunk, one value per target chunk, then the chunk's T/F target flag.
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.210349 -0.023194 T
-0.038159 0.102338 T
0.246970 0.046396 F
0.010592 -0.048031 F
0.081924 -0.014397 F
-0.038597 -0.080679 F
0.043957 0.119103 F
-0.153411 -0.039547 F
-0.146363 -0.027516 F
-0.047171 0.202399 F
0.004024 0.006097 F
0.037217 0.085870 F
-0.057347 0.073293 F
0.241241 -0.033839 F
0.032375 0.033223 F
0.041402 -0.134142 F
-0.037370 0.080175 F
-0.112170 0.257837 F


In [79]:
# Build the ranker model from its saved config and restore the best checkpoint.
# (`inp_len` is already computed right after the dataset loader is created,
# so it is not recomputed here.)
model_ranker_cfg = parse_yaml_file_as(MllmRankerCfg, ranker_model_cfg_fpath)
model_ranker = MllmRankerLevel(model_ranker_cfg, model_level).to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can still be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint_ranker = torch.load(ranker_snapshot_fpath, map_location=device)
model_ranker.load_state_dict(checkpoint_ranker['model'])
model_ranker.eval()
# Trailing None keeps the cell from echoing the return value of eval().
None

vocab_encoder.src_word_emb.weight (50271, 256) -0.010897174 -8.0230694e-07 0.010897171
vocab_encoder.layer_norm.weight (256,) -0.09996011 0.002452353 0.0998353
vocab_encoder.layer_norm.bias (256,) -0.09894606 -0.0014543043 0.098527074
encoder.a_em () 0.09004872 0.09004872 0.09004872
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10825315 0.00013853457 0.108247146
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.108252436 0.00011188915 0.108251445
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10824739 0.00022007307 0.10824797
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10824961 0.00014514917 0.10825284
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09991062 0.006778831 0.09973943
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09960178 -0.0046932297 0.09947876
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.0684649 5.3948297e-05 0.068465285
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09996933 0.00040751696 0.0999

In [80]:
# Fetch training batch 12 (presumably a batch index — confirm against WikiDsLoader)
# and materialise its tensors; this replaces the batch used for the encdec model.
i = 12
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
# Show the chunk tensor shapes and the mask that flags which document chunks are targets.
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([25, 100]),
 torch.Size([3, 100]),
 tensor([ True,  True,  True, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False]))

In [81]:
# Decode all target chunks into one string to see which text this batch targets.
s_target = tokten_to_text(target_chunks)
s_target

'<|query_begin|> TritordeumTritordeum is a hybrid crop, obtained by crossing durum wheat with the wild barley Hordeum chilense. It has less gliadin (gluten) than wheat, but still performs well in breads, both in terms of dough rising and texture qualities, and in taste-testing, where it substantially outperformed gluten-free breads. It has ten times more lutein, more oleic acid, and more fiber than wheat, giving products made from it a yellower hue and a pleasant flavor profile.\n\nUnder development by the Spanish National Research Council since 1977, it was launched onto the market in April 2013 by the start-up Agrasys company created under the auspices of the University of Barcelona to commercialize the cereal. It is planted on about 1300 ha in Portugal, Spain, France, Italy and Turkey. It does better in hotter and drier growing conditions than wheat, using less water. Because of this water-saving feature, it won first prize for a Sustainable Ingredient in the 2018 Sustainable Food A

In [82]:
# Preview every document chunk: decode it, keep the first 200 characters and
# escape newlines so each chunk is printed on a single line.
for toks in docs_chunks:
    s = tokten_to_text(toks)
    print(s[:200].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 3382600 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Tritordeum <|doc_title_end|> <|doc_body_begin|> Tritordeum is a hybrid crop, obtaine
<|doc_begin|> <|doc_id_begin|> 3382600 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|> leic acid, and more fiber than wheat, giving products made from it a yellower hue and a pleasant flavor
<|doc_begin|> <|doc_id_begin|> 3382600 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  It does better in hotter and drier growing conditions than wheat, using less water. Because of this w
<|doc_begin|> <|doc_id_begin|> 5770866 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Birami <|doc_title_end|> <|doc_body_begin|> Birami is a panchayat village in Rajasth
<|doc_begin|> <|doc_id_begin|> 5770866 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|> \nGeography\nThe village of Birami is 32 km by road southeast of the city of Jodhpur, located betwe

In [97]:
# Embed the target and document chunks with the ranker model.
# This is inference only, so autograd bookkeeping is switched off.
with torch.no_grad():
    target_embs = model_ranker.run_enc_emb(target_chunks)
    docs_embs = model_ranker.run_enc_emb(docs_chunks)

In [98]:
# True compares with cosine similarity; set to False for Euclidean distance.
cosine = True
# One line per document chunk, one value per target chunk, then the chunk's T/F target flag.
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.851829 0.768325 0.717993 T
0.710824 0.864115 0.791755 T
0.751394 0.784505 0.901214 T
0.601972 0.765515 0.664485 F
0.506602 0.649703 0.602839 F
0.574463 0.730713 0.660725 F
0.385619 0.497914 0.482378 F
0.592718 0.615067 0.464479 F
0.415337 0.564398 0.493007 F
0.590724 0.656549 0.577092 F
0.516982 0.525026 0.389098 F
0.691276 0.658016 0.652711 F
0.466441 0.544165 0.483684 F
0.521704 0.572658 0.464098 F
0.583230 0.539854 0.519573 F
0.464891 0.593930 0.624339 F
0.446136 0.458801 0.597052 F
0.523654 0.535276 0.620607 F
0.393278 0.383729 0.532924 F
0.473186 0.469718 0.547572 F
0.421835 0.542141 0.402089 F
0.443597 0.613942 0.547976 F
0.482703 0.653324 0.531004 F
0.495839 0.644287 0.531666 F
0.453870 0.573715 0.463971 F


In [105]:
# Free-text query compared against the document chunks of the current batch.
# Only one query is active at a time; the others tried so far are:
# txt = 'El Charco del Cura Reservoir'
# txt = 'Théodore Eugène César Ruyssen'
# txt = 'Théodore Ruyssen'
# txt = 'orders in certain situations, and both required trade-offs'
# txt = 'Tritordeum hybrid crop'
# txt = 'etienne Lecroart ddd'
# txt = 'The graph sandwich problem for property Π is defined as follows:'
txt = 'Ну приветики ну погоди'

cosine = True
# Wrap the query in the query begin/end tokens, as the dataset's target chunks are.
txt_tokens = text_to_tokens(txt, qbeg_tok=qbeg_tok, qend_tok=qend_tok)
# txt_tokens = txt_tokens.repeat(3, 1)
print(txt_tokens.shape)
# Inference only, so autograd bookkeeping is switched off.
with torch.no_grad():
    txt_embs = model_ranker.run_enc_emb(txt_tokens)
print_dist(txt_embs, docs_embs, target_mask, cosine=cosine)


torch.Size([1, 100])
0.281365 T
0.378958 T
0.301095 T
0.381230 F
0.428091 F
0.460663 F
0.501278 F
0.475151 F
0.510579 F
0.420649 F
0.636904 F
0.370556 F
0.539754 F
0.610080 F
0.480029 F
0.353345 F
0.226356 F
0.358335 F
0.394531 F
0.479233 F
0.359256 F
0.500798 F
0.417691 F
0.460942 F
0.405442 F


In [106]:

# Score the document chunks against the query tokens with the ranker's run_qs_infer().
rank = model_ranker.run_qs_infer(docs_chunks, txt_tokens)
# Fixed six-decimal formatting keeps small scores readable (no scientific notation).
rank_str = [f'{r:.06f}' for r in rank.flatten()]
rank_str

['0.001259',
 '0.003427',
 '0.001251',
 '0.003243',
 '0.007756',
 '0.014418',
 '0.033158',
 '0.038146',
 '0.018492',
 '0.008177',
 '0.388008',
 '0.003869',
 '0.064414',
 '0.174521',
 '0.030543',
 '0.006355',
 '0.000525',
 '0.003603',
 '0.003947',
 '0.080367',
 '0.003919',
 '0.048465',
 '0.010669',
 '0.039282',
 '0.012188']

In [50]:
# NOTE(review): the recorded output (14 rows, execution count 50) does not match
# the 25-row batch loaded above; it comes from an earlier run — re-run to refresh.
docs_chunks.shape

torch.Size([14, 100])

In [31]:
# Call the ranker model object directly with the free-text query.
# NOTE(review): arguments are (query, docs) here but (docs, query) in the
# run_qs_infer() call — confirm the expected order for each entry point.
# NOTE(review): execution count 31 is out of sequence, so the recorded output
# comes from an earlier kernel state.
rank = model_ranker(txt_tokens, docs_chunks)
rank

tensor([[0.8737, 0.8907, 0.7371, 0.5201, 0.4101, 0.8237, 0.7342, 0.8183, 0.8694,
         0.6263, 0.6608, 0.6087, 0.2565, 0.3135, 0.3250]],
       grad_fn=<SigmoidBackward0>)

In [35]:
# Call the ranker model object directly with the batch's own target chunks as the query.
# NOTE(review): execution count 35 is out of sequence, so the recorded output
# comes from an earlier kernel state.
rank = model_ranker(target_chunks, docs_chunks)
rank

tensor([[2.6538e-21, 3.7294e-11, 2.4412e-19, 3.4020e-23, 7.5630e-20, 3.8663e-25,
         1.3986e-18, 7.2614e-22, 1.0000e+00, 1.0000e+00, 1.0000e+00, 4.7166e-09,
         1.9997e-11, 1.3684e-08]], grad_fn=<SigmoidBackward0>)

In [31]:
# Scratch check: randint's upper bound is exclusive, so this draws an integer from 0..5.
np.random.randint(0, 6)

3

In [2]:
import torch

In [4]:
# Scratch: three device objects (two CPU, one CUDA) for checking how torch.device compares.
device1, device2, device3 = (torch.device(kind) for kind in ('cpu', 'cuda', 'cpu'))

In [5]:
# torch.device objects compare by value: cpu vs cuda, cuda vs cpu, cpu vs cpu.
print(device1 == device2, device2 == device3, device1 == device3, sep='\n')


False
False
True


In [7]:
# The device kind is exposed as a plain string in `.type`.
device1.type, type(device1.type)

('cpu', str)

In [11]:
import numpy as np

# Scratch: three equal-length rows (two constant float rows, one integer row)
# stacked into a single 2-D array.
toks = [
    np.full(3, 0.5),
    np.full(3, 1.7),
    np.array([-1, 7, 33]),
]
np.stack(toks, axis=0)

array([[ 0.5,  0.5,  0.5],
       [ 1.7,  1.7,  1.7],
       [-1. ,  7. , 33. ]])

In [12]:
# np.array on the list of equal-length rows builds the same 2-D array as np.stack(axis=0).
np.array(toks)

array([[ 0.5,  0.5,  0.5],
       [ 1.7,  1.7,  1.7],
       [-1. ,  7. , 33. ]])

In [13]:
# Confirm numerically that both ways of building the 2-D array agree.
np.allclose(np.array(toks), np.stack(toks, axis=0))

True