In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.config.model import MllmRankerCfg, MllmEncdecCfg, TokenizerCfg
from mllm.data.dsqrels_embs import DsQrelsEmbs, QrelsEmbsBatch
from mllm.data.utils import load_qrels_datasets
from mllm.exp.args import ENCDEC_MODEL_CFG_FNAME, RANKER_MODEL_CFG_FNAME
from mllm.model.mllm_encdec import MllmEncdecLevel
from mllm.model.mllm_ranker import RankProbLoss, MllmRanker, MllmRankerLevel
from mllm.tokenization.chunk_tokenizer import gen_all_tokens, ChunkTokenizer, tokenizer_from_config
from mllm.train.utils import find_create_train_path, calc_print_batches



# Ranker level 1 inference
## Config and paths

In [12]:
# Root of all datasets and training artifacts. Path.home() falls back to the
# password database when $HOME is unset, whereas os.path.expandvars('$HOME')
# would silently return the literal string '$HOME'.
DATA_PATH = Path.home() / 'data'
DS_MSMARCO_DIR_PATH = DATA_PATH / 'msmarco'
DS_FEVER_DIR_PATH = DATA_PATH / 'fever'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qrels_0'
TRAIN_RANKER_EMBS_PATH = DATA_PATH / 'train_mllm_ranker_qrels_1'
DS_WIKI_DIR_PATH = DATA_PATH / 'wiki_20200501_en/ch_100_fixed'
DS_EMBS_DIR_PATH = DATA_PATH / 'ranker_embs_msmarco_fever'
# Assumes the notebook runs from a directory whose parent is the repo root.
CFG_DIR_PATH = Path.cwd().parent / 'mllm/config/cfg'

tokenizer_cfg_fpath = CFG_DIR_PATH / 'tokenizer_cfg_02.yaml'

# Level-0 ranker: trained on MS MARCO + FEVER.
ranker0_subdir = 'ranker-20241021_062053-msmarco-fever'
ranker0_train_path = TRAIN_RANKER_PATH / ranker0_subdir
ranker0_snapshot_fpath = ranker0_train_path / 'best.pth'

# Level-1 ranker: trained on embeddings (lives under TRAIN_RANKER_EMBS_PATH).
ranker1_subdir = 'ranker-lvl1-20241023_220614-enc-lrs2-embmatTrue-d256-h8-dec-lrs2-d256-h8-encdec-20241022_224217'
ranker1_train_path = TRAIN_RANKER_EMBS_PATH / ranker1_subdir
ranker1_snapshot_fpath = ranker1_train_path / 'best.pth'

ranker0_model_cfg_fpath = ranker0_train_path / RANKER_MODEL_CFG_FNAME
ranker1_model_cfg_fpath = ranker1_train_path / RANKER_MODEL_CFG_FNAME
# Label each path so the two existence checks can be told apart.
for ranker_label, ranker_cfg_fpath in (('Ranker0', ranker0_model_cfg_fpath), ('Ranker1', ranker1_model_cfg_fpath)):
    print(f'{ranker_label} cfg fpath: {ranker_cfg_fpath}. Exists: {ranker_cfg_fpath.exists()}')

Ranker cfg fpath: /home/misha/data/train_mllm_ranker_qrels_0/ranker-20241021_062053-msmarco-fever/ranker_model_cfg.yaml. Exists: True
Ranker cfg fpath: /home/misha/data/train_mllm_ranker_qrels_1/ranker-lvl1-20241023_220614-enc-lrs2-embmatTrue-d256-h8-dec-lrs2-d256-h8-encdec-20241022_224217/ranker_model_cfg.yaml. Exists: True


In [13]:
# Chunk sizes of the text qrels dataset (emb_chunk_size) and of the
# precomputed-embeddings dataset (embs_chunk_size); both are 100 here.
emb_chunk_size = embs_chunk_size = 100
docs_batch_size = 10
chunk_size = 100
max_docs_embs = 10
docs_per_chunk = chunk_size // max_docs_embs

# Use torch.device('cuda') to run on GPU.
device = torch.device('cpu')
print(device)

ranker0_model_cfg: MllmRankerCfg = parse_yaml_file_as(MllmRankerCfg, ranker0_model_cfg_fpath)
ranker1_model_cfg: MllmRankerCfg = parse_yaml_file_as(MllmRankerCfg, ranker1_model_cfg_fpath)
# Level-1 encoder config; its d_model is used as the embedding width of ds_embs.
enc_cfg_1 = ranker1_model_cfg.encoders[1]


cpu


## Load models and datasets

In [14]:
# Tokenizer used for the qrels datasets and by the text helpers at the end of the notebook.
tkz_cfg: TokenizerCfg = parse_yaml_file_as(TokenizerCfg, tokenizer_cfg_fpath)
ch_tkz = tokenizer_from_config(tkz_cfg)

In [15]:
# Load and join the MS MARCO and FEVER qrels datasets (text side).
# emb_chunk_size is presumably the number of tokens per chunk — confirm in load_qrels_datasets.
ds_qrels = load_qrels_datasets([DS_MSMARCO_DIR_PATH, DS_FEVER_DIR_PATH], ch_tkz, emb_chunk_size, device)

Join datasets:
   Msmarco. Queries: 372206. Docs: 3213835. QueryDocRels: 372206
   Fever. Queries: 123142. Docs: 5416568. QueryDocRels: 156101


In [16]:
# Precomputed embeddings dataset; one keyword argument per line for readability.
ds_embs = DsQrelsEmbs(
    ds_dir_path=DS_EMBS_DIR_PATH,
    chunk_size=embs_chunk_size,
    emb_size=enc_cfg_1.d_model,  # embedding width = level-1 encoder d_model
    emb_dtype=np.float32,
    doc_id_driven=True,
    max_docs_embs=max_docs_embs,
    device=device,
)

In [21]:
# Build the level-0 ranker and load its best snapshot (keyword `level` matches the ranker-1 cell).
model_ranker_0 = MllmRankerLevel(ranker0_model_cfg, level=0).to(device)
print(f'Loading model weights from {ranker0_snapshot_fpath}')
checkpoint = torch.load(ranker0_snapshot_fpath, map_location=device)
model_ranker_0.load_state_dict(checkpoint['model'])
# Only the weights are needed; drop the rest of the checkpoint dict so it is not kept in memory.
del checkpoint
model_ranker_0.eval();

vocab_encoder.src_word_emb.weight (50271, 256) -0.010897174 3.2431778e-06 0.010897168
vocab_encoder.layer_norm.weight (256,) -0.09924302 -0.0019709559 0.099181
vocab_encoder.layer_norm.bias (256,) -0.09945156 -0.003854694 0.09921739
encoder.a_em () 0.027928842 0.027928842 0.027928842
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10824769 -3.121439e-05 0.10824582
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10824748 -6.0666946e-05 0.10825218
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.1082516 -2.6242898e-05 0.10824809
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10825094 0.0003963641 0.10825312
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.099667266 0.0017616183 0.09951814
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.098205104 -0.00443795 0.09944582
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846493 3.4025892e-05 0.06846483
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.0999995 -0.002172865 0.0999667

In [20]:
# Build the level-1 ranker and load its best snapshot.
model_ranker_1 = MllmRankerLevel(ranker1_model_cfg, level=1).to(device)
print(f'Loading model weights from {ranker1_snapshot_fpath}')
checkpoint = torch.load(ranker1_snapshot_fpath, map_location=device)
model_ranker_1.load_state_dict(checkpoint['model'])
# Only the weights are needed; drop the rest of the checkpoint dict so it is not kept in memory.
del checkpoint
model_ranker_1.eval();

encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10825261 -5.5802517e-05 0.10825173
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825261 0.00031761464 0.108252
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.108244516 0.00034677965 0.108246915
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108252995 0.00018068234 0.108252764
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.099787556 0.0006585007 0.09933175
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09794102 0.0020022742 0.09789349
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846473 -0.00012958856 0.06846424
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.099623226 0.0017477174 0.09999601
encoder.layer_stack.0.pos_ffn.w_2.weight (256, 1024) -0.06846489 -1.03440125e-05 0.068465084
encoder.layer_stack.0.pos_ffn.w_2.bias (256,) -0.099925615 0.0061488203 0.09897119
encoder.layer_stack.0.pos_ffn.layer_norm.weight (256,) -0.099364296 0.0023797723 0.09973752
encoder.l

### Using dataset queries

In [22]:
# Each batch holds docs_batch_size * docs_per_chunk doc ids; they are later
# reshaped to (docs_batch_size, docs_per_chunk).
embs_view = ds_embs.get_embs_view(batch_size=docs_batch_size * docs_per_chunk)

In [23]:
# Select batch number `i_batch` (0-based) from a fresh iterator over batches with queries.
i_batch = 0
embs_batch_it = embs_view.get_batch_iterator(with_queries=True)
for _ in range(i_batch):
    next(embs_batch_it)  # discard the batches preceding the one we want
embs_batch = next(embs_batch_it)

In [24]:
# The iterator was created with with_queries=True, so query embeddings must be present.
assert embs_batch.qs_embs is not None and embs_batch.qs_ind_len is not None
print(f'Batch {i_batch}. Docs embs: {embs_batch.docs_embs.shape}. Queries embs: {embs_batch.qs_embs.shape} {len(embs_batch.qs_ind_len)}')


Batch 0. Docs embs: (10, 100, 256). Queries embs: (149, 256) 149


In [25]:
# Row counts of the batch's doc-id, qrels and query-id frames.
# NOTE(review): df_docs_ids can have more rows than embs_batch.ids — presumably it
# covers all docs referenced by the batch's queries; confirm in DsQrelsEmbs.
len(embs_batch.df_docs_ids), len(embs_batch.df_qrels), len(embs_batch.df_qs_ids)

(769, 149, 149)

In [26]:
docs_embs_t = embs_batch.get_docs_embs_tensor()
qs_embs_t, qs_masks_t = embs_batch.get_qs_tensors()
# Pure inference: no_grad stops autograd from recording a graph (outputs would
# otherwise carry grad_fn and keep intermediate activations alive).
with torch.no_grad():
    out_rank = model_ranker_1.run_qs_embs(docs_embs_t, qs_embs_t, embs_batch.qs_ind_len)
print(f'out_rank. min, mean, max: {out_rank.min():0.4f}, {out_rank.mean():0.4f}, {out_rank.max():0.4f}')

out_rank. min, mean, max: 0.0889, 0.1000, 0.1130


In [27]:
# (n_queries, n_doc_chunks): one rank score per query per doc chunk of the batch.
print(out_rank.shape)

torch.Size([149, 10])


In [50]:
query_ind = 3
# The first field of the query's qs_ind_len entry is its dataset query id.
dsqid = embs_batch.qs_ind_len[query_ind][0]
print(f"Query {query_ind} {dsqid}: {ds_qrels.df_qs.loc[dsqid, 'query']}")


Query 3 11791: +where does the pacific walrus ive


In [51]:
# Rank scores of the selected query against each doc chunk of the batch.
out_rank[query_ind]

tensor([0.1035, 0.0996, 0.1007, 0.0982, 0.0997, 0.0989, 0.1043, 0.1006, 0.1004,
        0.0940], grad_fn=<SelectBackward0>)

In [52]:
# Target mask of the selected query: True presumably marks the doc chunk holding its relevant doc.
qs_masks_t[query_ind]

tensor([False, False, False, False, False, False, False,  True, False, False])

In [53]:
# Doc ids laid out with one row per doc chunk of the batch (presumably matching out_rank's columns).
dsdids = embs_batch.ids.reshape((docs_batch_size, docs_per_chunk))
# Take the first chunk flagged by the selected query's target mask, instead of a
# hard-coded index that silently goes stale when query_ind changes.
i_doc = int(qs_masks_t[query_ind].nonzero()[0, 0])
dsdids[i_doc]

array([1369469, 1901654,  632008, 2724055, 2305197, 1019421, 2345255,
       2805343, 3182944,  655573])

In [54]:
# Preview (title, start of text) of every doc in the selected chunk, numbered from 01.
for pos, dsdid in enumerate(dsdids[i_doc], start=1):
    title, text = ds_qrels.get_doc(dsdid)
    print(f'{pos:02d}. Doc {dsdid}: {title[:100]}. {text[:200]}')

01. Doc 1369469: Morse Code & the Telegraph. "Early Forms of Long-Distance Communication Before the development of the electric telegraph in the 19th century revolutionized how information was transmitted across long distances, ancient civilizat
02. Doc 1901654: When Did Mardi Gras Start?. "Listen H urray! It's Tuesday! How often do you hear that? If you've had a bad Monday, then maybe Tuesday rolling around might brighten your spirits. Otherwise, Tuesday usually doesn't get too much at
03. Doc 632008: Walrus. "From Wikipedia, the free encyclopedianavigation search For other uses, see Walrus (disambiguation). Walrus Temporal range: Pleistocene to Recent Male Female with young Conservation status Vulnerable 
04. Doc 2724055: Where Is Bactria?. "Humanities ›History & Culture Where Is Bactria? Share Flipboard Email Printvia Wikipediaby Kallie Szczepanski Updated December 27, 2017Bactria is an ancient region of Central Asia, between the Hindu 
05. Doc 2305197: Find the Routing Number on a 

In [36]:
# Show the selected query's row and its qrel record (the relevant doc ids).
# NOTE(review): the saved output of this cell is stale — it shows dsqid 9174, while
# In [50] later set dsqid to 11791. Restart and run in order to refresh it.
print(ds_qrels.df_qs.loc[dsqid])
qrel = ds_qrels.df_qrels.loc[dsqid]
print(qrel)

qid                           265
query    +what is fascia or facia
dsid                            1
dsqid                        9174
Name: 9174, dtype: object
qid         265
did       97881
dsid          1
dsqid      9174
dsdid    508390
Name: 9174, dtype: int64


In [38]:
# Preview the qrel's relevant document (title plus the first 400 characters of text).
title, text = ds_qrels.get_doc(qrel.dsdid)
print(f'{title[:100]}. {text[:400]}')

fascia. "fascia Also found in: Thesaurus, Medical, Legal, Encyclopedia, Wikipedia. Related to fascia: Colles fasciafas·cia (făsh′ə, fä′shə)n. pl. fas·ci·ae (făsh′ē-ē′, fä′shē-ē′)1. Anatomya. A sheet or band of fibrous connective tissue enveloping, separating, or bindingtogether muscles, organs, and other soft structures of the body.b. The tissue of which such a sheet or band is composed.2. Biology A broad


In [39]:
# Sanity check: the qrel's relevant doc is among this batch's doc ids.
qrel.dsdid in embs_batch.ids

True

In [35]:
# The qrels and queries frames are indexed by the same set of values.
set(ds_qrels.df_qrels.index) == set(ds_qrels.df_qs.index)

True

In [40]:
# Flat doc ids of the batch (docs_batch_size * docs_per_chunk entries).
embs_batch.ids

array([ 895028, 3085932, 1616445, 2396834,  465048,  679478, 1282538,
       2558828, 1918243, 1184395, 1931964,  991350, 1141254, 2247720,
       2726870, 1461296, 2434362,  127982, 1479578, 1000903,  509488,
       1161217, 3026873,  109872, 1446407, 1152228,  598469, 2971199,
       2951585, 1365772, 2635634, 2588840, 2495189,   38604, 1537726,
        699367, 2268633,  517965,  796713,  416362,  299777, 1742059,
        544317, 2520629, 1436264, 2858298,  412377,  456861,  688057,
        826758,  294285, 3081347,  650865, 2410836, 1107353, 3195243,
       1505331, 1627906, 2145371,  739799, 1234207,  508390,  198804,
       1895725,  239631, 1388082, 2056134,   51362,  760521,   38219,
       1369469, 1901654,  632008, 2724055, 2305197, 1019421, 2345255,
       2805343, 3182944,  655573,  308315,  580376, 1664242, 3115493,
       2165184, 2810761,   58275,  525091,  919745, 1647851,  406509,
       1163151, 2521349,  207467, 1206529,  803368,  597726,  874445,
        314351, 1573

In [41]:
# Position of the qrel's relevant doc inside the batch's flat id array. Uses
# qrel.dsdid rather than a hard-coded id, so it follows the selected query.
relevant_pos = np.flatnonzero(embs_batch.ids == qrel.dsdid)
if relevant_pos.size:
    print(relevant_pos[0])

61


In [42]:
# Doc offsets table (indexed by dsdid); show its first five rows.
df_off = ds_qrels.df_off
df_off.head()

Unnamed: 0_level_0,did,offset,dsid,dsdid
dsdid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,1555982,0,1,0
1,301595,1852,1,1
2,1359209,7973,1,2
3,2147834,23656,1,3
4,1568809,31104,1,4


In [43]:
# .loc with a list of labels returns rows in the requested order, not in index order.
df_off.loc[[1,0,3,2,4]]

Unnamed: 0_level_0,did,offset,dsid,dsdid
dsdid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,301595,1852,1,1
0,1555982,0,1,0
3,2147834,23656,1,3
2,1359209,7973,1,2
4,1568809,31104,1,4


In [44]:
# After reordering, the dsdid index and the did column stay aligned row by row.
df = df_off.loc[[1,0,3,2,4]]
df.index, df.did

(Index([1, 0, 3, 2, 4], dtype='int64', name='dsdid'),
 dsdid
 1     301595
 0    1555982
 3    2147834
 2    1359209
 4    1568809
 Name: did, dtype: int64)

In [None]:
# Padding token: stripped when decoding, used as filler when encoding.
pad_tok = tkz_cfg.custom_tokens['pad']
# Row length of token inputs built by text_to_tokens.
# NOTE(review): read from the level-1 ranker cfg; token inputs go through a vocab
# encoder, which the level-0 ranker clearly has — confirm both cfgs share inp_len.
inp_len = ranker1_model_cfg.vocab_encoder.inp_len

def tokten_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids into text, dropping padding tokens.

    The tensor is flattened first, so a (n_rows, inp_len) tensor produced by
    `text_to_tokens` decodes back into a single string.

    Args:
        tokens: Tensor of token ids, any shape.

    Returns:
        The string returned by `ch_tkz.decode` for the non-padding ids.
    """
    flat = tokens.flatten()
    # Padding is removed wherever it occurs, not only at the tail.
    flat = flat[flat != pad_tok]
    # .tolist() yields plain Python ints; list(tensor) would yield 0-d tensors.
    return ch_tkz.decode(flat.tolist())

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    """Compare two embedding arrays.

    With cosine=False, return the Euclidean norm of x - y (0 means identical).
    With cosine=True, return the cosine *similarity* sum(x*y) / (|x| * |y|):
    despite the function name, larger values mean closer (1 = same direction).
    """
    if cosine:
        denom = np.linalg.norm(x) * np.linalg.norm(y)
        return np.sum(x * y) / denom
    return np.linalg.norm(x - y)

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize `s` into a padded int32 tensor of shape (n_rows, inp_len) on `device`.

    Args:
        s: Text to tokenize with `ch_tkz`.
        qbeg_tok: Optional token id prepended to the tokens. Must be given
            together with `qend_tok`.
        qend_tok: Optional token id appended to the tokens. Must be given
            together with `qbeg_tok`.

    Returns:
        The tokens laid out row-major, with the last row filled with `pad_tok`.
        n_rows is ceil(n_tokens / inp_len), so zero tokens give a (0, inp_len) tensor.

    Raises:
        AssertionError: if exactly one of `qbeg_tok` / `qend_tok` is given.
    """
    # The markers come as a pair; a lone qend_tok used to be silently ignored.
    assert (qbeg_tok is None) == (qend_tok is None), 'qbeg_tok and qend_tok must be given together'
    tokens = ch_tkz(s)['input_ids']
    if qbeg_tok is not None:
        tokens = [qbeg_tok, *tokens, qend_tok]
    n_tokens = len(tokens)
    # Ceiling division: the number of inp_len-sized rows needed to hold every token.
    n_rows = -(-n_tokens // inp_len)
    padded = np.full(n_rows * inp_len, pad_tok, dtype=np.int32)
    padded[:n_tokens] = tokens
    return torch.from_numpy(padded).reshape(n_rows, inp_len).to(device)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    """Print one line per document embedding comparing it to every target embedding.

    Each line holds `distance(target_emb, docs_emb, cosine)` for all target
    embeddings (6 decimals, space-separated), followed by 'T' if
    `target_mask[i]` is truthy for that document row, else 'F'. With the default
    cosine=True the values are cosine similarities (higher = closer).

    Args:
        target_embs: Target embeddings; iterated along the first dimension.
        docs_embs: Document embeddings; iterated along the first dimension.
        target_mask: Indexed by document row position; truthiness selects 'T'/'F'.
        cosine: Passed through to `distance`.
    """
    # .cpu() lets CUDA tensors through (.numpy() alone raises on them). The targets
    # are converted once here instead of once per document row.
    targets_np = target_embs.detach().cpu().numpy()
    docs_np = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs_np):
        for target_emb in targets_np:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        print('T' if target_mask[i] else 'F')