In [9]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.config.model import MllmRankerCfg, MllmEncdecCfg, TokenizerCfg
from mllm.data.dsqrels_embs import DsQrelsEmbs, QrelsEmbsBatch
from mllm.data.utils import load_qrels_datasets
from mllm.model.mllm_encdec import MllmEncdecLevel
from mllm.model.mllm_ranker import RankProbLoss, MllmRankerLevel
from mllm.tokenization.chunk_tokenizer import gen_all_tokens, ChunkTokenizer, tokenizer_from_config
from mllm.train.utils import find_create_train_path, calc_print_batches



In [14]:
# Data / config locations.
# Path.home() is used rather than expanding the literal '$HOME': os.path.expandvars
# leaves '$HOME' untouched when the variable is unset (e.g. on Windows), which would
# silently yield a relative './$HOME/data' path.
DATA_PATH = Path.home() / 'data'
DS_MSMARCO_DIR_PATH = DATA_PATH / 'msmarco'
DS_FEVER_DIR_PATH = DATA_PATH / 'fever'
TRAIN_RANKER_EMBS_PATH = DATA_PATH / 'train_mllm_ranker_qrels_1'
DS_WIKI_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'
DS_EMBS_DIR_PATH = DATA_PATH / 'ranker_embs_msmarco_fever'
# Resolved relative to the working directory: assumes the notebook runs from a
# subdirectory of the repo root (consistent with '..' being added to sys.path).
CFG_DIR_PATH = Path.cwd().parent / 'mllm' / 'config' / 'cfg'

tokenizer_cfg_fpath = CFG_DIR_PATH / 'tokenizer_cfg_01.yaml'

# Trained ranker run to load; the 'best.pth' snapshot inside that run dir is used.
ranker_subdir = 'ranker-l1-20241012_102220-encdec-l1-20241005_175446-msmarco-fever'
ranker_train_path = TRAIN_RANKER_EMBS_PATH / ranker_subdir
ranker_snapshot_path = ranker_train_path / 'best.pth'

ranker_model_cfg_fpath = CFG_DIR_PATH / 'ranker_model_cfg_02.yaml'
print(f'Ranker cfg fpath: {ranker_model_cfg_fpath}. Exists: {ranker_model_cfg_fpath.exists()}')

Ranker cfg fpath: /home/misha/prog/mllm/mllm/config/cfg/ranker_model_cfg_02.yaml. Exists: True


In [15]:
# Run settings.
emb_chunk_size = 100   # chunk size given to load_qrels_datasets
embs_chunk_size = 100  # chunk_size given to DsQrelsEmbs
docs_batch_size = 5
max_chunks_per_doc = 3

# Keep the device *name* separate from the torch.device object; switch to 'cuda' for GPU.
device_name = 'cpu'
device = torch.device(device_name)
print(device)

# Ranker model config; `model_level` selects which encoder config is used.
model_level = 1
ranker_model_cfg: MllmRankerCfg = parse_yaml_file_as(MllmRankerCfg, ranker_model_cfg_fpath)
enc_cfg = ranker_model_cfg.encoders[model_level]
# NOTE(review): passed to DsQrelsEmbs as max_docs_embs; presumably 0 means "no limit" — confirm.
max_docs_embs = 0


cpu


In [16]:
# Tokenizer: parse its YAML config, then build the chunk tokenizer from it.
tkz_cfg: TokenizerCfg = parse_yaml_file_as(TokenizerCfg, tokenizer_cfg_fpath)
ch_tkz = tokenizer_from_config(tkz_cfg)

In [17]:
# Join the MS MARCO and FEVER query/doc relevance datasets into one.
qrels_ds_dir_paths = [DS_MSMARCO_DIR_PATH, DS_FEVER_DIR_PATH]
ds_qrels = load_qrels_datasets(qrels_ds_dir_paths, ch_tkz, emb_chunk_size, device)

Join datasets:
   Msmarco. Queries: 372206. Docs: 3213835. QueryDocRels: 372206
   Fever. Queries: 123142. Docs: 5416568. QueryDocRels: 156101


In [18]:
# Embeddings dataset read from DS_EMBS_DIR_PATH; the embedding size is taken
# from the selected encoder level's d_model so the two stay in sync.
ds_embs = DsQrelsEmbs(
    ds_dir_path=DS_EMBS_DIR_PATH,
    chunk_size=embs_chunk_size,
    emb_size=enc_cfg.d_model,
    emb_dtype=np.float32,
    doc_id_driven=True,
    max_docs_embs=max_docs_embs,
    device=device,
)

In [19]:
# Build the ranker for `model_level` and restore its trained weights from the snapshot.
model_ranker = MllmRankerLevel(ranker_model_cfg, model_level)
model_ranker = model_ranker.to(device)
print(f'Loading model weights from {ranker_snapshot_path}')
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(ranker_snapshot_path, map_location=device)
model_state = checkpoint['model']
model_ranker.load_state_dict(model_state)

encoder.a_em () -0.04032854 -0.04032854 -0.04032854
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10824747 -0.00036556713 0.10825267
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.108251445 -0.00040734917 0.108250014
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10825122 -0.00033577072 0.10825282
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108250886 -0.00018787659 0.10824925
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.099928476 -0.0032927117 0.098024465
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09862503 0.0036262968 0.09957059
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846341 2.1765818e-05 0.068464585
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09999468 -0.00026246576 0.09977813
encoder.layer_stack.0.pos_ffn.w_2.weight (256, 1024) -0.06846482 5.7863457e-05 0.06846513
encoder.layer_stack.0.pos_ffn.w_2.bias (256,) -0.09902563 0.00016542478 0.09985527
encoder.layer_stack.0.pos_ffn.layer_norm.weig

<All keys matched successfully>