In [3]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [6]:
import os
from pathlib import Path
from pprint import pprint
import sys
from typing import Optional

if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
from datasets.arrow_dataset import Dataset
import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer, AutoTokenizer

from mllm.data.dsqrels import QrelsPlainBatch
from mllm.data.utils import load_qrels_datasets
from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import RANKER_BERT_MODEL_CFG_FNAME
from mllm.model.encdec_ranker_hg import RankerBert
from mllm.config.model import TokenizerCfg, RankerBertCfg
from mllm.tokenization.chunk_tokenizer import ChunkTokenizer, gen_all_tokens
from mllm.model.losses import RankerCosEmbLoss



# RankerHg inference
## Config and paths

In [9]:
# Root directory for datasets and training artifacts.
# Path.home() is used instead of expanding '$HOME' by hand: expandvars() leaves
# the literal text '$HOME' in place when the variable is unset, which would
# silently produce a bogus relative path.
DATA_PATH = Path.home() / 'data'
WIKI_DS_NAME = '20200501.en'

# Query-relevance dataset directories handed to the dataset loader.
DS_MSMARCO_DIR_PATH = DATA_PATH / 'msmarco'
DS_FEVER_DIR_PATH = DATA_PATH / 'fever'
ds_dir_paths = [DS_MSMARCO_DIR_PATH, DS_FEVER_DIR_PATH]

# Training run whose snapshot and model config are loaded for inference.
TRAIN_RANKER_BERT_PATH = DATA_PATH / 'train_mllm_ranker_bert_qrels'
ranker_subdir = 'rankerbert-20250203_234515-msmarco-fever-bert-base-uncased-inp128-d768-emb_cls-dmlp_none-tdo_f'

ranker_train_path = TRAIN_RANKER_BERT_PATH / ranker_subdir
ranker_snapshot_fpath = ranker_train_path / 'best.pth'
ranker_model_cfg_fpath = ranker_train_path / RANKER_BERT_MODEL_CFG_FNAME

# Device used for the model and dataset tensors.
device_name = 'cpu'  # change to 'cuda' to run on a GPU

device = torch.device(device_name)
print(device)

cpu


## Load models and datasets

In [14]:
# Build the ranker from the YAML config saved with the training run, then
# restore the trained weights from the snapshot file.
model_cfg: RankerBertCfg = parse_yaml_file_as(RankerBertCfg, ranker_model_cfg_fpath)
inp_len = model_cfg.enc_bert.inp_len
model = RankerBert(model_cfg).to(device)
print(f'Load model from {ranker_snapshot_fpath}')
# map_location remaps the stored tensors onto `device`; without it a snapshot
# whose tensors were saved from a CUDA device cannot be loaded on a CPU-only host.
checkpoint = torch.load(ranker_snapshot_fpath, map_location=device)
# strict=True: fail loudly if the snapshot keys do not match the model exactly.
model.load_state_dict(checkpoint['model'], strict=True)
# Switch to evaluation mode; as the last expression of the cell this also
# displays the module tree.
model.eval()

Load model from /home/misha/data/train_mllm_ranker_bert_qrels/rankerbert-20250203_234515-msmarco-fever-bert-base-uncased-inp128-d768-emb_cls-dmlp_none-tdo_f/best.pth


RankerBert(
  (enc_bert): EncoderBert(
    (bert_model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSdpaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=Tru

In [12]:
# Base tokenizer: the pretrained checkpoint named in the encoder config.
tkz = AutoTokenizer.from_pretrained(
    model_cfg.enc_bert.pretrained_model_name,
)
print(tkz)

# Project chunk tokenizer wrapping the base tokenizer, configured with
# n_emb_tokens=inp_len and fixed_size=True.
custom_tokens = gen_all_tokens()
ch_tkz = ChunkTokenizer(
    custom_tokens,
    tkz,
    n_emb_tokens=inp_len,
    fixed_size=True,
)

# Load the qrels datasets from the configured directories (join=False) and
# print a summary line for each one.
dss = load_qrels_datasets(
    ds_dir_paths,
    ch_tkz,
    inp_len,
    device,
    join=False,
)
for ds in dss:
    print(ds)

BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
Msmarco. Queries: 372206. Docs: 3213835. QueryDocRels: 372206
Fever. Queries: 123142. Docs: 5416568. QueryDocRels: 156101


In [13]:
# Name the first two loaded datasets individually and display the first one.
ds1 = dss[0]
ds2 = dss[1]
ds1

Msmarco. Queries: 372206. Docs: 3213835. QueryDocRels: 372206

## Inference on Qrels batch

In [16]:
# Work on the first dataset; `dsqids` is the 'dsqid' column of its queries frame.
ds = ds1
dsqids = ds.df_qs['dsqid']

# Loss object used further down to evaluate the ranker output.
loss_fn = RankerCosEmbLoss()

In [17]:
# Pick the `ib`-th window of `batch_size` query ids (positional slice).
batch_size = 10
ib = 0
batch_off = ib * batch_size
batch_end = batch_off + batch_size
dsqids_batch = dsqids.iloc[batch_off:batch_end]

# Build a plain qrels batch from the selected query ids.
batch: QrelsPlainBatch = ds.get_batch_plain_qids(dsqids_batch)

Token indices sequence length is longer than the specified maximum sequence length for this model (5237 > 512). Running this sequence through the model will result in indexing errors


In [18]:
# Display the queries frame of the batch.
batch.df_qs

Unnamed: 0_level_0,qid,query,dsid,dsqid
dsqid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,1185869,)what was the immediate impact of the success ...,1,0
1,1185868,_________ justice is designed to repair the ha...,1,1
2,1183785,elegxo meaning,1,2
3,645590,what does physical medicine do,1,3
4,186154,feeding rice cereal how many times per day,1,4
5,457407,most dependable affordable cars,1,5
6,441383,lithophile definition,1,6
7,683408,what is a flail chest,1,7
8,484187,put yourself on child support in texas,1,8
9,666321,what happens in a wrist sprain,1,9


In [19]:
# Display the documents frame of the batch (includes 'text' and 'title' columns).
batch.df_docs

Unnamed: 0_level_0,did,offset,dsid,dsdid,text,title
dsdid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2376038,59219,16949087845,1,2376038,The Manhattan Project Introduction Importance ...,Introduction
100673,59235,722556427,1,100673,"""From Wikipedia, the free encyclopedianavigati...",Restorative justice
1201976,576811,8572285503,1,1201976,"John, Sermons 2001The Ministry of the Holy Spi...",The Ministry of the Holy Spirit
1232741,576840,8792350091,1,1232741,"""Looking for a Physiatrist? Find a Physiatrist...",What Is a Physiatrist?
403227,114789,2878807930,1,403227,"""FEEDING GUIDELINES AGES 4 TO 6 MONTHSSTARTING...",FEEDING GUIDELINES AGES 4 TO 6 MONTHS
958172,389790,6827227963,1,958172,"""Surmount the snow with 7 of the best used all...",Surmount the snow with 7 of the best used all-...
2470237,576852,17619697389,1,2470237,"""From Wikipedia, the free encyclopedianavigati...",Goldschmidt classification
1380756,576861,9848986903,1,1380756,Flail chest describes a situation in which a p...,Flail Chest
277403,275258,1973432619,1,277403,Welcome!Notice ×THE TEXAS OAG CHILD SUPPORT WE...,Welcome!
2660349,576889,18969142686,1,2660349,Expert Reviewed How to Look After a Sprained W...,How to Look After a Sprained Wrist


In [20]:
# Display the qrels frame of the batch (query ids paired with document ids).
batch.df_qrels

Unnamed: 0_level_0,qid,did,dsid,dsqid,dsdid
dsqid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,1185869,59219,1,0,2376038
1,1185868,59235,1,1,100673
2,1183785,576811,1,2,1201976
3,645590,576840,1,3,1232741
4,186154,114789,1,4,403227
5,457407,389790,1,5,958172
6,441383,576852,1,6,2470237
7,683408,576861,1,7,1380756
8,484187,275258,1,8,277403
9,666321,576889,1,9,2660349


In [21]:
# Print the full text of each query; the frame display above truncates long ones.
queries = batch.df_qs['query']
for query in queries:
    print(query)

)what was the immediate impact of the success of the manhattan project?
_________ justice is designed to repair the harm to victim, the community and the offender caused by the offender criminal act. question 19 options:
elegxo meaning
what does physical medicine do
feeding rice cereal how many times per day
most dependable affordable cars
lithophile definition
what is a flail chest
put yourself on child support in texas
what happens in a wrist sprain


In [22]:
# Display the batch's boolean mask matrix; in the recorded output only the
# diagonal entries are True.
batch.qrels_masks

array([[ True, False, False, False, False, False, False, False, False,
        False],
       [False,  True, False, False, False, False, False, False, False,
        False],
       [False, False,  True, False, False, False, False, False, False,
        False],
       [False, False, False,  True, False, False, False, False, False,
        False],
       [False, False, False, False,  True, False, False, False, False,
        False],
       [False, False, False, False, False,  True, False, False, False,
        False],
       [False, False, False, False, False, False,  True, False, False,
        False],
       [False, False, False, False, False, False, False,  True, False,
        False],
       [False, False, False, False, False, False, False, False,  True,
        False],
       [False, False, False, False, False, False, False, False, False,
         True]])

In [23]:
# Turn the batch into model inputs and run the ranker on it.
qs_toks, qs_masks, docs_toks, docs_masks, qrels_masks = batch.gen_tensors()
# Inference only: do not build an autograd graph for the forward pass.
with torch.no_grad():
    out_rank = model(docs_toks, qs_toks)
# `out_rank` stays the raw model output because the loss cell below passes it to
# `loss_fn`. Overwriting it with the rounded integer array (as was done before)
# made that cell depend on out-of-order execution. The scaled-by-100 integer
# copy is kept under a separate name and used only for a readable printout.
out_rank_pct = np.round(100 * out_rank.detach().cpu().numpy()).astype(int)
print(out_rank_pct)

[[ 63  -2  -9  66  -3 -11  -6 -22 -12 -11]
 [ 44  94  -2  -2  -4 -10   1  -2  10  -6]
 [ 38   6  10  -3  -1 -32   8 -16 -22 -10]
 [ -8  -7 -13  35 -24 -14   0  49  -9  81]
 [-12  -3 -16  -2  95  -8 -12 -19  -1  -9]
 [ -4 -15  -4 -11  -4  96 -34   0 -10  -2]
 [  2  -5 -15  -3  -8 -19  56 -14 -12 -27]
 [-13   5 -14 -15 -17   1 -22 100  -7  61]
 [-23  30  -5 -12  -2   8 -13  -5  98  -3]
 [ -7   0 -13  -9  -6  -3 -40  63  -3  99]]


In [34]:
# Evaluate the loss on the ranker output; the recorded output is a tuple of
# three scalar tensors.
# NOTE(review): presumably `out_rank` must be the raw torch output of the model
# here, not a numpy array - confirm when running the notebook top to bottom.
loss_fn(out_rank, qrels_masks)

(tensor(0.3631, grad_fn=<DivBackward0>),
 tensor(0.5339, grad_fn=<DivBackward0>),
 tensor(0.1923, grad_fn=<DivBackward0>))

In [25]:
# Show the title and full text of one document of the batch, chosen by position.
doc_ind = 3
doc_row = batch.df_docs.iloc[doc_ind]
print(doc_row['title'])
print(doc_row['text'])

What Is a Physiatrist?
"Looking for a Physiatrist? Find a Physiatrist in Spine-health's growing Doctor Directory. A physiatrist practices in the field of physiatry - also called physical medicine and rehabilitation - which is a branch of medicine that specializes in diagnosis, treatment, and management of disease primarily using ""physical"" means, such as physical therapy and medications. Essentially, physiatrists specialize in a wide variety of treatments for the musculoskeletal system - the muscles, bones, and associated nerves, ligaments, tendons, and other structures - and the musculoskeletal disorders that cause pain and/or difficulty with functioning. Physiatrists do not perform surgery. Physiatry for Back Pain Video A physiatrist's treatment focuses on helping the patient become as functional and pain-free as possible in order to participate in and enjoy life as fully as possible. A physiatrist can be either a medical doctor (MD) or a doctor of osteopathic medicine (DO). A phys