In [1]:
# Auto-reload edited project modules (e.g. mllm.*) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import TOKENIZER_CFG_FNAME, ENCDEC_MODEL_CFG_FNAME, RANKER_MODEL_CFG_FNAME
from mllm.model.mllm_encdec import MllmEncdecLevel
from mllm.model.mllm_ranker import MllmRankerLevel
from mllm.config.model import TokenizerCfg, MllmEncdecCfg, MllmRankerCfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens, ChunkTokenizer, tokenizer_from_config



In [3]:
# Root of all datasets and training runs: ~/data.
# Path.home() is used instead of os.path.expandvars('$HOME'), which would silently
# produce a literal '$HOME/data' path if the HOME variable were not set.
DATA_PATH = Path.home() / 'data'
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

# Alternative training runs:
# TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
# TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qrels'
# encdec_subdir = 'encdec-20240718_221554-wiki_20200501_en-ch_100_fixed'
# ranker_subdir = 'ranker-20240903_215749-msmarco-fever'

TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec_0'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qrels_0'
encdec_subdir = 'encdec-lvl0-20241029_140645-wiki_20200501_en-ch_100_fixed-enc-lrs3-embmatFalse-d256-h8-dec-lrs3-seqlen100-d256-h8-vocdecTrue'
# Earlier ranker runs, kept for reference:
#   ranker-lvl0-20241030_230226-msmarco-fever-enc-lrs3-embmatFalse-d256-h8-dec-lrs0-d256-h8
#   ranker-lvl0-20241031_215643-msmarco-fever-enc-lrs3-embmatFalse-d256-h8-dec-lrs3-d256-h8
ranker_subdir = 'ranker-lvl0-20241103_130114-msmarco-fever-enc-lrs3-embmatFalse-d256-h8-dec-lrs3-d256-h8'

# Checkpoint, tokenizer-config and model-config locations inside each run directory.
encdec_train_path = TRAIN_ENCDEC_PATH / encdec_subdir
ranker_train_path = TRAIN_RANKER_PATH / ranker_subdir
encdec_snapshot_fpath = encdec_train_path / 'best.pth'
ranker_snapshot_fpath = ranker_train_path / 'best.pth'
encdec_tkz_cfg_fpath = encdec_train_path / TOKENIZER_CFG_FNAME
ranker_tkz_cfg_fpath = ranker_train_path / TOKENIZER_CFG_FNAME
encdec_model_cfg_fpath = encdec_train_path / ENCDEC_MODEL_CFG_FNAME
ranker_model_cfg_fpath = ranker_train_path / RANKER_MODEL_CFG_FNAME

In [4]:
# Batch composition and model level used throughout the notebook.
docs_batch_size, max_chunks_per_doc = 5, 3
model_level = 0

# Compute device; change to 'cuda' to run on GPU.
device_name = 'cpu'
device = torch.device(device_name)
print(device)

cpu


In [5]:
# Load both runs' tokenizer configs; they must match so token ids mean the same thing to both models.
encdec_tkz_cfg = parse_yaml_file_as(TokenizerCfg, encdec_tkz_cfg_fpath)
ranker_tkz_cfg = parse_yaml_file_as(TokenizerCfg, ranker_tkz_cfg_fpath)
assert encdec_tkz_cfg == ranker_tkz_cfg
tokenizer = tokenizer_from_config(encdec_tkz_cfg)

# Ids of the custom special tokens used for padding and query delimiting.
tok_dict = encdec_tkz_cfg.custom_tokens
pad_tok = tok_dict['pad'].ind
qbeg_tok = tok_dict['query_begin'].ind
qend_tok = tok_dict['query_end'].ind

In [6]:
ds_loader = WikiDsLoader(
    ds_path=DS_DIR_PATH,
    docs_batch_size=docs_batch_size,
    max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok,
    qbeg_tok=qbeg_tok,
    qend_tok=qend_tok,
    device=device,
)
# Shuffle both splits.
# NOTE(review): no seed is set in this notebook, so `get_batch(i)` may return different
# documents on each run — TODO confirm whether WikiDsLoader.shuffle is deterministic.
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)

# Token-row length fed to the models: the chunk size itself for fixed-size chunks,
# otherwise the maximal input size derived from it.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [7]:
def tokens_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids back to text, dropping padding tokens.

    Relies on the module-level `tokenizer` and `pad_tok`.

    Args:
        tokens: integer tensor of token ids of any shape. It is flattened first,
            so a multi-row batch decodes to one concatenated string.

    Returns:
        The decoded string.
    """
    flat = tokens.flatten()
    # .tolist() yields plain Python ints (and works for tensors on any device),
    # instead of a list of 0-d tensors that the tokenizer would have to convert itself.
    ids = flat[flat != pad_tok].tolist()
    return tokenizer.decode(ids)

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    """Compare two embedding arrays.

    Args:
        x, y: arrays of the same shape.
        cosine: if False (default), return the Euclidean (L2) distance ||x - y||.
            If True, return the cosine *similarity* sum(x * y) / (||x|| * ||y||).

    Note: despite the function name, with cosine=True larger values mean MORE
    similar inputs (1.0 for parallel vectors). This is the opposite ordering to
    the Euclidean mode. Zero-norm inputs give nan in cosine mode.
    """
    if cosine:
        denom = np.linalg.norm(x) * np.linalg.norm(y)
        return np.sum(x * y) / denom
    return np.linalg.norm(x - y)

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize `s` into a padded int32 tensor of shape (n_rows, inp_len) on `device`.

    Relies on the module-level `tokenizer`, `inp_len`, `pad_tok` and `device`.

    Args:
        s: text to tokenize.
        qbeg_tok, qend_tok: if `qbeg_tok` is given, the token ids are wrapped as
            [qbeg_tok, ..., qend_tok]; `qend_tok` must then be given as well.
            A `qend_tok` passed without `qbeg_tok` is ignored.

    Returns:
        The ids split into rows of `inp_len`, with the last row right-padded with
        `pad_tok`. n_rows = ceil(n_ids / inp_len), so an empty id list gives 0 rows.
    """
    ids = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        assert qend_tok is not None
        ids = [qbeg_tok, *ids, qend_tok]

    n_ids = len(ids)
    n_rows = (n_ids + inp_len - 1) // inp_len  # ceil(n_ids / inp_len)
    padded = np.full(n_rows * inp_len, pad_tok, dtype=np.int32)
    padded[:n_ids] = ids
    return torch.from_numpy(padded).to(device).reshape(n_rows, inp_len)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    """Print a table of `distance` scores between doc-chunk and target embeddings.

    Prints one line per doc-chunk embedding. Each line holds that chunk's score
    against every target embedding, followed by 'T' if target_mask[i] is truthy
    and 'F' otherwise.

    Args:
        target_embs: target embeddings, iterated row by row (one column each).
        docs_embs: doc-chunk embeddings, iterated row by row (one line each).
        target_mask: per-doc-chunk flags; row i is labeled with target_mask[i].
        cosine: forwarded to `distance`. True prints cosine similarity
            (higher = closer); False prints Euclidean distance.
    """
    # .cpu() is needed before .numpy(); without it this raises for tensors on a CUDA device.
    # Converting once up front also avoids re-converting the targets for every doc row.
    targets_np = target_embs.detach().cpu().numpy()
    docs_np = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs_np):
        for target_emb in targets_np:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [8]:
model_encdec_cfg = parse_yaml_file_as(MllmEncdecCfg, encdec_model_cfg_fpath)
model_encdec = MllmEncdecLevel(model_encdec_cfg, model_level).to(device)
# map_location restores the tensors directly onto `device`, so a checkpoint saved on
# GPU can still be loaded when device_name is 'cpu' (e.g. on a CUDA-less machine).
checkpoint_encdec = torch.load(encdec_snapshot_fpath, map_location=device)
# strict=False tolerates key mismatches. Report them, so a partially loaded model is noticed
# instead of being evaluated with silently random weights.
encdec_load_res = model_encdec.load_state_dict(checkpoint_encdec['model'], strict=False)
if encdec_load_res.missing_keys or encdec_load_res.unexpected_keys:
    print(f'encdec checkpoint: missing keys {encdec_load_res.missing_keys}, '
          f'unexpected keys {encdec_load_res.unexpected_keys}')
model_encdec.eval()
None  # suppress the module repr as the cell's displayed value

vocab_encoder.src_word_emb.weight (50271, 256) -0.010897174 1.3329893e-06 0.010897173
vocab_encoder.layer_norm.weight (256,) -0.098980986 0.0022841117 0.09967321
vocab_encoder.layer_norm.bias (256,) -0.099942826 -0.0019740022 0.09916983
vocab_decoder.word_prj.weight (50271, 256) -0.010897174 -4.0607728e-07 0.010897173
encoder.a_em () 0.020802975 0.020802975 0.020802975
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10825261 0.00022600047 0.10825103
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10824928 0.00038556458 0.10824711
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.108251125 0.0004757878 0.10825291
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.1082521 0.00055551186 0.108248554
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09971508 0.002642855 0.099901095
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09900289 0.0011596174 0.099956356
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846522 -7.1032286e-05 0.0

In [9]:
# Sanity check: short text becomes a single row of inp_len ids, right-padded with pad_tok.
text_to_tokens('Hello there')

tensor([[15496,   612, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267]],
       dtype=torch.int32)

In [10]:
# Pull training batch #10. Yields the doc-chunk token rows, the target (query) token rows,
# and a boolean mask over doc chunks. In the output below, the True rows are the chunks
# of the target's own document.
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([14, 100]),
 torch.Size([2, 100]),
 tensor([False, False, False,  True,  True, False, False, False, False, False,
         False, False, False, False]))

In [11]:
# Decode the target rows to inspect the query text (delimited by query_begin/query_end tokens).
s_target = tokens_to_text(target_chunks)
s_target

'<|query_begin|> 1945 Paris–RoubaixThe 1945 Paris–Roubaix was the 43rd\xa0edition of the Paris–Roubaix, a classic one-day cycle race in France. The single day event was held on 9 April 1945 and stretched  from Paris to the finish at Roubaix Velodrome. The winner was Paul Maye from France.\n\nResults\n\nReferences\n\nCategory:Paris–Roubaix\nCategory:1945 in road cycling\nCategory:1945 in French sport <|query_end|>'

In [12]:
# Preview the first 200 characters of each doc chunk, with newlines escaped so each chunk fits on one line.
for chunk_toks in docs_chunks:
    chunk_txt = tokens_to_text(chunk_toks)[:200]
    print(chunk_txt.replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 4838950 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Maine State Route 149 <|doc_title_end|> <|doc_body_begin|> State Route 149 (SR 149) 
<|doc_begin|> <|doc_id_begin|> 4838950 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  begins at an intersection in the Fairbanks neighborhood of Farmington at SR 4. The road, named South S
<|doc_begin|> <|doc_id_begin|> 4838950 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|> on Hill Road). SR 149 and SR 234 form a concurrency and travel into downtown Strong along Norton Hill 
<|doc_begin|> <|doc_id_begin|> 2020592 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> 1945 Paris–Roubaix <|doc_title_end|> <|doc_body_begin|> The 1945 Paris–Roubaix was t
<|doc_begin|> <|doc_id_begin|> 2020592 <|doc_id_end|> <|doc_offset_begin|> 92 <|doc_offset_end|> baix\nCategory:1945 in road cycling\nCategory:1945 in French sport <|doc_body_end|> <|doc_end|>
<|d

In [13]:
# Embed target and doc-chunk rows with model_encdec.run_enc_emb.
# no_grad: this is inference only, so skip building the autograd graph (saves memory).
with torch.no_grad():
    target_embs = model_encdec.run_enc_emb(target_chunks)
    docs_embs = model_encdec.run_enc_emb(docs_chunks)

In [14]:
# True: cosine similarity (higher = closer). False: Euclidean distance.
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.050723 0.140819 F
0.177963 0.191794 F
0.007825 0.115876 F
0.217006 0.237375 T
0.230547 0.664481 T
0.028282 0.243856 F
0.138972 0.150037 F
0.256995 0.220254 F
0.127542 0.185235 F
0.113792 0.236069 F
0.121711 0.241804 F
0.225511 0.197613 F
0.193322 0.239168 F
0.136008 0.277026 F


In [59]:
# (inp_len is already computed, with the same expression, in the dataset-loader cell.)
model_ranker_cfg = parse_yaml_file_as(MllmRankerCfg, ranker_model_cfg_fpath)
model_ranker = MllmRankerLevel(model_ranker_cfg, model_level).to(device)
# map_location restores the tensors directly onto `device`, so a GPU-saved checkpoint
# also loads when running on CPU.
checkpoint_ranker = torch.load(ranker_snapshot_fpath, map_location=device)
model_ranker.load_state_dict(checkpoint_ranker['model'])
model_ranker.eval()
None  # suppress the module repr as the cell's displayed value

vocab_encoder.src_word_emb.weight (50271, 256) -0.010897174 2.8711395e-07 0.010897168
vocab_encoder.layer_norm.weight (256,) -0.09975773 -0.0026611888 0.099174134
vocab_encoder.layer_norm.bias (256,) -0.099611916 -0.003419092 0.0988902
encoder.a_em () -0.06363992 -0.06363992 -0.06363992
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.108250596 -6.590552e-05 0.10825305
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.1082521 6.1234125e-05 0.108250014
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10825094 0.00028304648 0.10825305
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108244695 -0.00019191732 0.10824863
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09965586 0.0043836436 0.09980502
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09824039 0.007093214 0.0999532
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846482 -5.6530007e-05 0.06846518
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.099701345 -0.0015187723 0.

In [60]:
# Pull training batch #12 for the ranker experiments. Yields the doc-chunk rows, the target
# (query) rows, and a boolean mask over doc chunks marking the target's document.
i = 12
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([14, 100]),
 torch.Size([1, 100]),
 tensor([False, False, False, False, False, False,  True,  True, False, False,
         False, False, False, False]))

In [61]:
# Decode the target rows of batch #12 to inspect the query text.
s_target = tokens_to_text(target_chunks)
s_target

'<|query_begin|> Austrocidaris spinulosaAustrocidaris spinulosa is a species of sea urchins of the family Cidaridae. Their armour is covered with spines. Austrocidaris spinulosa was first scientifically described in 1910 by Ole Mortensen.\n\nReferences\n\nCategory:Sea Urchins described in 1910\nCategory:Cidaridae\nCategory:Taxa named by Ole Theodor Jensen Mortensen <|query_end|>'

In [62]:
# Preview the first 200 characters of each doc chunk in batch #12, newlines escaped.
for chunk_toks in docs_chunks:
    chunk_txt = tokens_to_text(chunk_toks)[:200]
    print(chunk_txt.replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 4795098 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  around the world.  The destinations depicted are usually the result of product placement. In 2015, the
<|doc_begin|> <|doc_id_begin|> 4795098 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  based on Richard Gordon's novel "The Captain's Table".\n\nPremise \nThe premise of the series resembled 
<|doc_begin|> <|doc_id_begin|> 4795098 <|doc_id_end|> <|doc_offset_begin|> 273 <|doc_offset_end|>  happy-end in each episode.\n\nActors \nCaptain Heinz Hansen was played by Heinz Weiss who in younger yea
<|doc_begin|> <|doc_id_begin|> 3463886 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> 1988 Western Soccer Alliance <|doc_title_end|> <|doc_body_begin|> Final league stand
<|doc_begin|> <|doc_id_begin|> 3463886 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>, Arturo Velazco, Jeff Stock\nMidfielders: Peter Hattrup, Billy Thompson, Tomás Boy\nForwards:

In [63]:
# Embed target and doc-chunk rows with the ranker's encoder (model_ranker.run_enc_emb).
# no_grad: inference only, so no autograd graph is needed.
with torch.no_grad():
    target_embs = model_ranker.run_enc_emb(target_chunks)
    docs_embs = model_ranker.run_enc_emb(docs_chunks)

In [64]:
# True: cosine similarity (higher = closer). False: Euclidean distance.
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.201634 F
0.051609 F
0.047279 F
0.121715 F
0.023739 F
0.243091 F
0.921977 T
0.025204 T
0.752494 F
0.503484 F
0.547126 F
0.451889 F
0.235818 F
0.269741 F


In [67]:
# Free-text probe: embed an ad-hoc query and compare it with the current batch's doc chunks.
# Other probes tried: "The Captain's Table", '1988 Western Soccer Alliance',
# 'Austrocidaris spinulosa', 'Gambierdiscus ruetzleri',
# 'The graph sandwich problem for property Π is defined as follows:'.
txt = 'Urusbiy'
cosine = True
txt_tokens = text_to_tokens(txt, qbeg_tok=qbeg_tok, qend_tok=qend_tok)
print(txt_tokens.shape)
with torch.no_grad():  # inference only
    txt_embs = model_ranker.run_enc_emb(txt_tokens)
# NOTE: the T/F column is the batch's target_mask, i.e. it marks the batch target's
# document, not documents relevant to `txt`.
print_dist(txt_embs, docs_embs, target_mask, cosine=cosine)


torch.Size([1, 100])
0.382868 F
0.062052 F
0.169623 F
0.087439 F
0.249196 F
0.237066 F
0.325658 T
0.359878 T
0.590362 F
0.504442 F
0.494654 F
0.664267 F
0.238684 F
0.430113 F


In [68]:

# Ranker scores of each doc chunk for the free-text query `txt_tokens`.
with torch.no_grad():  # inference only
    rank = model_ranker.run_qs_infer(docs_chunks, txt_tokens)
# .tolist() gives plain floats; the formatted strings are the same as formatting 0-d tensors.
rank_str = [f'{r:.06f}' for r in rank.flatten().tolist()]
rank_str

['0.050682',
 '0.006449',
 '0.044166',
 '0.000333',
 '0.041805',
 '0.001856',
 '0.040835',
 '0.084292',
 '0.241429',
 '0.164106',
 '0.266632',
 '0.973829',
 '0.019062',
 '0.037045']

In [50]:
# Shape of the current batch's doc-chunk tensor.
docs_chunks.shape

torch.Size([14, 100])

In [31]:
# Full ranker forward pass: score every doc chunk against the free-text query.
# NOTE(review): the saved output below has 15 scores while docs_chunks has 14 rows,
# so it is stale (from an earlier batch / out-of-order run). Re-run to refresh it.
with torch.no_grad():  # inference only
    rank = model_ranker(txt_tokens, docs_chunks)
rank

tensor([[0.8737, 0.8907, 0.7371, 0.5201, 0.4101, 0.8237, 0.7342, 0.8183, 0.8694,
         0.6263, 0.6608, 0.6087, 0.2565, 0.3135, 0.3250]],
       grad_fn=<SigmoidBackward0>)

In [35]:
# Full ranker forward pass: score every doc chunk against the batch's target rows.
with torch.no_grad():  # inference only
    rank = model_ranker(target_chunks, docs_chunks)
rank

tensor([[2.6538e-21, 3.7294e-11, 2.4412e-19, 3.4020e-23, 7.5630e-20, 3.8663e-25,
         1.3986e-18, 7.2614e-22, 1.0000e+00, 1.0000e+00, 1.0000e+00, 4.7166e-09,
         1.9997e-11, 1.3684e-08]], grad_fn=<SigmoidBackward0>)

In [31]:
# Scratch check: randint's upper bound (6) is exclusive. Unseeded, so the value varies per run.
np.random.randint(0, 6)

3

In [2]:
import torch

In [4]:
# Scratch: build device objects to check torch.device equality semantics.
# (Constructing a 'cuda' device does not require a GPU.)
device1, device2, device3 = map(torch.device, ('cpu', 'cuda', 'cpu'))

In [5]:
# Devices compare equal only when type (and index) match.
for lhs, rhs in ((device1, device2), (device2, device3), (device1, device3)):
    print(lhs == rhs)


False
False
True


In [7]:
# torch.device.type is a plain str ('cpu' here), so it can be compared to string literals.
device1.type, type(device1.type)

('cpu', str)

In [11]:
import numpy as np
# Scratch: three 1-D float rows of equal length, stacked into a 3x3 array.
toks = [np.full(3, 0.5), np.full(3, 1.7), np.array([-1, 7, 33])]
np.stack(toks, axis=0)

array([[ 0.5,  0.5,  0.5],
       [ 1.7,  1.7,  1.7],
       [-1. ,  7. , 33. ]])

In [12]:
# np.array on a list of equal-length 1-D arrays builds the same 2-D array as np.stack(..., axis=0).
np.array(toks)

array([[ 0.5,  0.5,  0.5],
       [ 1.7,  1.7,  1.7],
       [-1. ,  7. , 33. ]])

In [13]:
# Confirm the two constructions agree element-wise.
np.allclose(np.array(toks), np.stack(toks, axis=0))

True