In [1]:
# Automatically reload imported modules (e.g. the project's mllm.* code) before each cell runs,
# so edits to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
from pprint import pprint
import sys
from typing import Optional

if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
from datasets.arrow_dataset import Dataset
import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.dsqrels import QrelsPlainBatch
from mllm.data.utils import load_qrels_datasets
from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import TOKENIZER_CFG_FNAME, RANKER_HG_MODEL_CFG_FNAME
from mllm.model.encdec_ranker_hg import RankerHg
from mllm.config.model import TokenizerCfg, RankerHgCfg
from mllm.tokenization.chunk_tokenizer import tokenizer_from_config, ChunkTokenizer
from s_03_06_train_ranker_hg_qrels import RankerCosEmbLoss





# RankerHg inference
## Config and paths

In [3]:
# Root directory for datasets and training outputs. Path.home() is used instead of
# os.path.expandvars('$HOME'), which silently returns the literal string '$HOME' when the
# variable is unset (e.g. on Windows) and would yield a bogus relative path.
DATA_PATH = Path.home() / 'data'
WIKI_DS_NAME = '20200501.en'
DS_MSMARCO_DIR_PATH = DATA_PATH / 'msmarco'
DS_FEVER_DIR_PATH = DATA_PATH / 'fever'

# RankerHg training runs; `ranker_subdir` selects the run to evaluate.
TRAIN_RANKER_HG_PATH = DATA_PATH / 'train_mllm_ranker_hg_qrels'
# Alternative run:
# ranker_subdir = 'rankerhg-20250112_163410-msmarco-fever-inp128-pos_emb-lrs7x1-rdc_avg-step2-d768-h12-t0-dmlp_1024b_tanh_768b_tanh-tdo_t'
ranker_subdir = 'rankerhg-20250112_232417-msmarco-fever-inp128-pos_emb-lrs7x1-rdc_avg-step2-d512-h8-t0-dmlp_512-tdo_f'

# Files saved by the training run: the 'best.pth' checkpoint, model config and tokenizer config.
ranker_train_path = TRAIN_RANKER_HG_PATH / ranker_subdir
ranker_snapshot_fpath = ranker_train_path / 'best.pth'
ranker_model_cfg_fpath = ranker_train_path / RANKER_HG_MODEL_CFG_FNAME
ranker_tkz_cfg_fpath = ranker_train_path / TOKENIZER_CFG_FNAME

# Set to 'cuda' to run on GPU.
device_name = 'cpu'

device = torch.device(device_name)
print(device)

cpu


## Load models and datasets

In [4]:
# Load the model and tokenizer configs stored in the selected training-run directory.
model_cfg: RankerHgCfg = parse_yaml_file_as(RankerHgCfg, ranker_model_cfg_fpath)
pprint(model_cfg.dict())
# Encoder input length in tokens (128 for this run, see the printed config).
inp_len = model_cfg.enc_pyr.inp_len

tkz_cfg: TokenizerCfg = parse_yaml_file_as(TokenizerCfg, ranker_tkz_cfg_fpath)
tkz = tokenizer_from_config(tkz_cfg)
# Chunk tokenizer configured with n_emb_tokens=inp_len and fixed_size=True
# (presumably yields fixed-length inputs of `inp_len` tokens — confirm in ChunkTokenizer).
ch_tkz = ChunkTokenizer(tkz_cfg.custom_tokens, tkz, n_emb_tokens=inp_len, fixed_size=True)
# Index of the custom padding token.
pad_tok = tkz_cfg.custom_tokens['pad'].ind

{'dec_rank': {'d_model': 512, 'mlp_layers': '512'},
 'enc_pyr': {'d_inner': 2048,
             'd_k': 64,
             'd_model': 512,
             'd_v': 64,
             'dropout_rate': 0.0,
             'inp_len': 128,
             'n_heads': 8,
             'n_layers': 7,
             'n_similar_layers': 1,
             'pad_idx': 50267,
             'reduct_type': <HgReductType.Avg: 'avg'>,
             'step': 2,
             'temperature': 0.0,
             'vocab_encoder': {'d_model': 512,
                               'd_word_vec': 512,
                               'dropout_rate': 0.0,
                               'inp_len': 128,
                               'n_vocab': 50271,
                               'pad_idx': 50267,
                               'pos_enc_type': <PosEncType.Emb: 'emb'>}}}


In [36]:
# Build RankerHg from its config, restore the trained weights from the checkpoint and switch to eval mode.
# NOTE(review): torch.load unpickles the file; if the checkpoint holds only tensors and plain
# containers, consider weights_only=True — verify what else is stored in it.
chkpt = torch.load(ranker_snapshot_fpath, map_location=device)
model = RankerHg(model_cfg).to(device)
# strict=True makes load_state_dict fail on missing/unexpected keys.
strict = True
# strict = False
model.load_state_dict(chkpt['model'], strict=strict)
model.eval()

RankerHg(
  (enc_pyr): EncoderPyramid(
    (vocab_encoder): VocabEncoder(
      (src_word_emb): Embedding(50271, 512, padding_idx=50267)
      (position_enc): Embedding(128, 512)
      (dropout): Dropout(p=0.0, inplace=False)
      (layer_norm): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
    )
    (enc_layers): ModuleList(
      (0-6): 7 x EncoderLayer(
        (slf_attn): MultiHeadAttention(
          (w_qs): Linear(in_features=512, out_features=512, bias=False)
          (w_ks): Linear(in_features=512, out_features=512, bias=False)
          (w_vs): Linear(in_features=512, out_features=512, bias=False)
          (fc): Linear(in_features=512, out_features=512, bias=False)
          (attention): ScaledDotProductAttention(
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (layer_norm): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        )
        (pos_ffn): PositionwiseFeedForward(
         

In [37]:
# Load and join the MS MARCO and FEVER qrels datasets. The chunk tokenizer and `inp_len` are
# passed in (presumably to tokenize queries/docs to fixed-length inputs — confirm).
ds_qrels = load_qrels_datasets([DS_MSMARCO_DIR_PATH, DS_FEVER_DIR_PATH], ch_tkz, inp_len, device)
# Query ids across the joined datasets; batches below are sliced from this column.
dsqids = ds_qrels.df_qs['dsqid']

Join datasets:
   Msmarco. Queries: 372206. Docs: 3213835. QueryDocRels: 372206
   Fever. Queries: 123142. Docs: 5416568. QueryDocRels: 156101


## Inference on a qrels batch

In [38]:
# Loss class from the training script, reused here to evaluate the model on a batch.
loss_fn = RankerCosEmbLoss()

In [39]:
# Take the ib-th consecutive slice of `batch_size` query ids and build a batch from them
# (queries, their related docs and the relevance masks — see the cells below).
batch_size = 10
ib = 0
batch_off = ib * batch_size
dsqids_batch = dsqids.iloc[batch_off:batch_off + batch_size]

batch: QrelsPlainBatch = ds_qrels.get_batch_plain_qids(dsqids_batch)

Token indices sequence length is longer than the specified maximum sequence length for this model (13645 > 10000). Running this sequence through the model will result in indexing errors


In [40]:
# Queries selected for the batch.
batch.df_qs

Unnamed: 0_level_0,qid,query,dsid,dsqid
dsqid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,1185869,)what was the immediate impact of the success ...,1,0
1,1185868,_________ justice is designed to repair the ha...,1,1
2,1183785,elegxo meaning,1,2
3,645590,what does physical medicine do,1,3
4,186154,feeding rice cereal how many times per day,1,4
5,457407,most dependable affordable cars,1,5
6,441383,lithophile definition,1,6
7,683408,what is a flail chest,1,7
8,484187,put yourself on child support in texas,1,8
9,666321,what happens in a wrist sprain,1,9


In [41]:
# Documents referenced by the batch's qrels (indexed by dsdid).
batch.df_docs

Unnamed: 0_level_0,did,offset,dsid,dsdid,text,title
dsdid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2376038,59219,16949087845,1,2376038,The Manhattan Project Introduction Importance ...,Introduction
100673,59235,722556427,1,100673,"""From Wikipedia, the free encyclopedianavigati...",Restorative justice
1201976,576811,8572285503,1,1201976,"John, Sermons 2001The Ministry of the Holy Spi...",The Ministry of the Holy Spirit
1232741,576840,8792350091,1,1232741,"""Looking for a Physiatrist? Find a Physiatrist...",What Is a Physiatrist?
403227,114789,2878807930,1,403227,"""FEEDING GUIDELINES AGES 4 TO 6 MONTHSSTARTING...",FEEDING GUIDELINES AGES 4 TO 6 MONTHS
958172,389790,6827227963,1,958172,"""Surmount the snow with 7 of the best used all...",Surmount the snow with 7 of the best used all-...
2470237,576852,17619697389,1,2470237,"""From Wikipedia, the free encyclopedianavigati...",Goldschmidt classification
1380756,576861,9848986903,1,1380756,Flail chest describes a situation in which a p...,Flail Chest
277403,275258,1973432619,1,277403,Welcome!Notice ×THE TEXAS OAG CHILD SUPPORT WE...,Welcome!
2660349,576889,18969142686,1,2660349,Expert Reviewed How to Look After a Sprained W...,How to Look After a Sprained Wrist


In [42]:
# Query/document relevance pairs of the batch, linking dsqid to dsdid.
batch.df_qrels

Unnamed: 0_level_0,qid,did,dsid,dsqid,dsdid
dsqid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,1185869,59219,1,0,2376038
1,1185868,59235,1,1,100673
2,1183785,576811,1,2,1201976
3,645590,576840,1,3,1232741
4,186154,114789,1,4,403227
5,457407,389790,1,5,958172
6,441383,576852,1,6,2470237
7,683408,576861,1,7,1380756
8,484187,275258,1,8,277403
9,666321,576889,1,9,2660349


In [43]:
# Print full query texts (the DataFrame view above truncates long ones).
for query in batch.df_qs['query']:
    print(query)

)what was the immediate impact of the success of the manhattan project?
_________ justice is designed to repair the harm to victim, the community and the offender caused by the offender criminal act. question 19 options:
elegxo meaning
what does physical medicine do
feeding rice cereal how many times per day
most dependable affordable cars
lithophile definition
what is a flail chest
put yourself on child support in texas
what happens in a wrist sprain


In [44]:
# Boolean relevance matrix for the batch (identity here: one relevant doc per query).
# Presumably rows are queries and columns are docs, both in batch order — confirm.
batch.qrels_masks

array([[ True, False, False, False, False, False, False, False, False,
        False],
       [False,  True, False, False, False, False, False, False, False,
        False],
       [False, False,  True, False, False, False, False, False, False,
        False],
       [False, False, False,  True, False, False, False, False, False,
        False],
       [False, False, False, False,  True, False, False, False, False,
        False],
       [False, False, False, False, False,  True, False, False, False,
        False],
       [False, False, False, False, False, False,  True, False, False,
        False],
       [False, False, False, False, False, False, False,  True, False,
        False],
       [False, False, False, False, False, False, False, False,  True,
        False],
       [False, False, False, False, False, False, False, False, False,
         True]])

In [49]:
# Score the batch's documents against its queries. Inference only, so no autograd graph is built.
qs_toks, qs_masks, docs_toks, docs_masks, qrels_masks = batch.gen_tensors()
with torch.no_grad():
    out_rank = model(docs_toks, qs_toks)
# Keep `out_rank` as the raw score tensor (the loss cell below consumes it) and display a
# separate percent-scaled, rounded integer copy instead of overwriting it.
out_rank_pct = np.round(100 * out_rank.cpu().numpy()).astype(int)
print(out_rank_pct)

[[ 72  61   1  16 -32 -24   9 -24 -12 -23]
 [ 62  86  38  60   0 -55  33  -3  -5   0]
 [ 36  26   7 -11  -8 -22   4  -8 -12  -6]
 [  9  43  38  94 -15 -51  22   0  40   7]
 [ -1   7 -31   4   0 -12 -36  98   4  97]
 [-12 -37 -44 -69 -14  83 -38  -2 -24  -6]
 [-20 -35   5 -23 -16 -19  42 -46 -34 -49]
 [ -6   2 -22   1  -8 -11 -31 100  -3  99]
 [-13  -4 -32  56  -2  25 -48  -8  99  -6]
 [ -6   2 -22   2 -10 -10 -31 100  -1  99]]


In [34]:
# Loss of the batch scores against the batch relevance mask.
# NOTE(review): expects `out_rank` to be the raw score tensor returned by `model(...)`, not a
# rounded/int numpy copy of it. The three returned values are presumably the total loss and
# its components — confirm against RankerCosEmbLoss.
loss_fn(out_rank, qrels_masks)

(tensor(0.3631, grad_fn=<DivBackward0>),
 tensor(0.5339, grad_fn=<DivBackward0>),
 tensor(0.1923, grad_fn=<DivBackward0>))

In [83]:
import re
from typing import Union

# Check whether the underlying tokenizer defines its own mask token.
# NOTE(review): no output was recorded, which suggests this is None (the custom '<|mask|>'
# token is registered with special=False) — confirm. The masking cells below use the
# '<|mask|>' custom token string.
tkz.mask_token

In [84]:
# Splitters used by `mask_random_words`: runs of CR/LF separate lines, runs of whitespace separate words.
# (The former re.M flag was dropped: the pattern has no ^/$ anchors, so it had no effect.)
NEWLINE_PAT = re.compile(r'[\n\r]+')
STR_DELIM_PAT = re.compile(r'\s+')


def mask_random_words(s: str, mask_tok_str: str, rem_ratio: float = 0.33, rem_prob: float = 0.15,
                      rem_conseq_ratio: float = 0.33, rem_conseq_prob: float = 0.3,
                      min_words: int = 5, rng: Optional[np.random.RandomState] = None) -> Optional[str]:
    """Randomly replace words of `s` with `mask_tok_str`.

    A single uniform draw selects one of three outcomes:
      - with probability ``1 - (rem_ratio + rem_conseq_ratio)``: no masking, return None;
      - with probability ``rem_ratio``: each word is masked independently with probability
        ``rem_prob`` (this may mask no words at all);
      - with probability ``rem_conseq_ratio``: one contiguous run of words is masked. Its length
        is drawn uniformly from ``[2, max(2, int(n_words * rem_conseq_prob))]``, capped at ``n_words``.

    Lines are split on runs of CR/LF and words on runs of whitespace, and empty lines are dropped.
    The result rejoins words with single spaces and lines with newlines, so the original
    whitespace layout is not preserved.

    Args:
        s: Input text.
        mask_tok_str: String substituted for every masked word.
        rem_ratio: Probability of the independent per-word masking mode.
        rem_prob: Per-word masking probability in that mode.
        rem_conseq_ratio: Probability of the contiguous-run masking mode.
        rem_conseq_prob: Upper bound of the run length as a fraction of the word count.
        min_words: Return None when `s` has fewer words than this.
        rng: Random source with the legacy ``rand``/``randint`` API, e.g.
            ``np.random.RandomState(seed)`` for reproducible output. Defaults to numpy's global
            random state.

    Returns:
        The masked text, or None when masking was skipped or `s` has too few words.
    """
    rs = np.random if rng is None else rng
    rv = rs.rand()
    if rv < 1 - (rem_ratio + rem_conseq_ratio):
        return None

    # Non-empty lines, each as a list of non-empty words.
    lines_words = []
    for line in NEWLINE_PAT.split(s):
        words = [w for w in STR_DELIM_PAT.split(line) if w]
        if words:
            lines_words.append(words)
    n_total = sum(len(words) for words in lines_words)

    if n_total < min_words:
        return None

    if rv < 1 - rem_conseq_ratio:
        # Independent mode. rand() is uniform on [0, 1), so `<` masks with probability exactly rem_prob.
        mask = rs.rand(n_total) < rem_prob
    else:
        # Contiguous mode. Clamp both bounds to n_total so that rem_conseq_prob > 1 or a small
        # min_words cannot request a run longer than the text (randint would raise).
        n_rem_lo = min(2, n_total)
        n_rem_hi = min(max(int(n_total * rem_conseq_prob), 2), n_total)
        n_rem = rs.randint(n_rem_lo, n_rem_hi + 1)
        start = rs.randint(n_total - n_rem + 1)
        mask = np.zeros(n_total, dtype=bool)
        mask[start:start + n_rem] = True

    # Rebuild the text line by line; `offset` is the flat index of each line's first word in `mask`.
    out_lines = []
    offset = 0
    for words in lines_words:
        out_lines.append(' '.join(
            mask_tok_str if mask[offset + iw] else word for iw, word in enumerate(words)))
        offset += len(words)
    return '\n'.join(out_lines)
    


In [85]:
# Sample multi-line text (a .gitignore excerpt) for trying out mask_random_words.
s = '''
# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore

#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
'''

In [90]:
# Read the mask token string from the tokenizer config instead of hard-coding it, so the masked
# text always uses the token that `tkz` was built with.
mask_tok_str = tkz_cfg.custom_tokens['mask'].repr
# The result is random (numpy global RNG) and is None when masking is skipped.
s1 = mask_random_words(s, mask_tok_str)
print(s1)

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# <|mask|> <|mask|> <|mask|> <|mask|> <|mask|> <|mask|> global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/


In [81]:
# Tokenize the masked text to check how the mask token is encoded (index 50270 per the config printed below).
# NOTE(review): will likely fail if `s1` is None (masking skipped). The recorded output comes from an
# earlier `s1`: this cell ran as In[81], before the masking cell's In[90].
toks = tkz(s1)['input_ids']
print(toks)

[2, 9485, 1925, 1670, 198, 50270, 19013, 9414, 1299, 2176, 11055, 318, 9456, 287, 257, 220, 50270, 19013, 9414, 1299, 13, 18300, 46430, 326, 460, 198, 2, 220, 50270, 1043, 379, 3740, 1378, 12567, 13, 785, 14, 12567, 14, 18300, 46430, 14, 2436, 672, 14, 12417, 14, 22289, 14, 42273, 9414, 1299, 13, 18300, 46430, 198, 50270, 290, 460, 307, 2087, 284, 262, 3298, 17606, 46430, 220, 50270, 23791, 656, 428, 2393, 13, 1114, 257, 517, 4523, 198, 2, 3038, 357, 1662, 7151, 8, 345, 460, 8820, 434, 262, 1708, 284, 8856, 220, 50270, 2104, 2126, 9483, 13, 198, 13, 485, 64, 14]


In [78]:
# Custom mask token definition from the tokenizer config (name, string repr, special flag, index).
print(tkz_cfg.custom_tokens['mask'])

name='mask' repr='<|mask|>' special=False ind=50270


In [82]:
# String form of the mask token, i.e. the text to substitute for masked words.
tkz_cfg.custom_tokens['mask'].repr

'<|mask|>'