In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional
# Put the parent directory on the import path (once) so the project-local
# `mllm` imports below can be resolved when running from this folder.
if '..' not in sys.path:
    sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.model.mllm_encdec import MllmEncdecLevel
from mllm.model.mllm_ranker import MllmRanker
from mllm.config.model import MllmEncdecCfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens
from mllm.utils.utils import read_tsv, write_tsv



In [3]:
# Root of all local data, located under the current user's home directory.
# Path.home() is used instead of expanding '$HOME' by hand: expandvars leaves
# the literal text '$HOME' in place when the variable is not defined (e.g. on
# Windows), which would silently produce a bogus relative path.
DATA_PATH = Path.home() / 'data'
# Directory names for the encoder-decoder and ranker training runs.
TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
# An alternative ranker directory is 'train_mllm_ranker'; the '_qs' one is active.
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qs'
# Wikipedia dataset directory (the 'ch_100_fixed' variant of the 2020-05-01 English dump).
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'


In [19]:
# Directory with the ranker embeddings artifacts (presumably computed on the
# MS MARCO and FEVER datasets, judging by the directory name — TODO confirm).
embs_dpath = DATA_PATH.joinpath('ranker_embs_msmarco_fever')

# TSV files holding the dataset id / item id pairs for documents and queries.
docs_ids_fpath, qs_ids_fpath = embs_dpath / 'docs_ids.tsv', embs_dpath / 'qs_ids.tsv'

# Load both id tables; the documents table is shown as the cell output.
df_docs_ids, df_qs_ids = read_tsv(docs_ids_fpath), read_tsv(qs_ids_fpath)
df_docs_ids

Unnamed: 0,ds_ids,ds_doc_ids
0,1,0
1,1,0
2,1,0
3,1,0
4,1,0
...,...,...
69767920,2,8630399
69767921,2,8630400
69767922,2,8630401
69767923,2,8630402


In [21]:
# Switch the id columns from plural to singular names. Both frames are
# modified in place, so no copy of the tables is made.
docs_col_renames = {'ds_ids': 'ds_id', 'ds_doc_ids': 'ds_doc_id'}
qs_col_renames = {'ds_ids': 'ds_id', 'ds_query_ids': 'ds_query_id'}
df_docs_ids.rename(columns=docs_col_renames, inplace=True)
df_qs_ids.rename(columns=qs_col_renames, inplace=True)
df_qs_ids

Unnamed: 0,ds_id,ds_query_id
0,1,0
1,1,1
2,1,2
3,1,3
4,1,4
...,...,...
495368,2,495343
495369,2,495344
495370,2,495345
495371,2,495346


In [22]:
# Number the rows of each table consecutively from 0. Presumably this is the
# row position in the matching embeddings array — TODO confirm.
df_docs_ids['doc_emb_id'] = np.arange(df_docs_ids.shape[0])
df_qs_ids['query_emb_id'] = np.arange(df_qs_ids.shape[0])
df_docs_ids

Unnamed: 0,ds_id,ds_doc_id,doc_emb_id
0,1,0,0
1,1,0,1
2,1,0,2
3,1,0,3
4,1,0,4
...,...,...,...
69767920,2,8630399,69767920
69767921,2,8630400,69767921
69767922,2,8630401,69767922
69767923,2,8630402,69767923


In [26]:
# Write both id tables back to disk, logging the row count and destination.
# NOTE: the destinations are the same paths the tables were originally read
# from, so running this cell overwrites the input files with the renamed
# columns and the added *_emb_id column.
print(f'Save {len(df_docs_ids)} docs ids in {docs_ids_fpath}')
write_tsv(df_docs_ids, docs_ids_fpath)
print(f'Save {len(df_qs_ids)} qs ids in {qs_ids_fpath}')
write_tsv(df_qs_ids, qs_ids_fpath)

Save 69767925 docs ids in /home/misha/data/ranker_embs_msmarco_fever/docs_ids.tsv
Save 495373 qs ids in /home/misha/data/ranker_embs_msmarco_fever/qs_ids.tsv


In [17]:
# Level of the encoder-decoder to work with; also used to index `model_cfg.encoders`.
model_level = 1
model_cfg_fpath = Path('../mllm/config/cfg/encdec_model_cfg_02.yaml')
model_cfg = parse_yaml_file_as(MllmEncdecCfg, model_cfg_fpath)
# Force `with_emb_mat` off for this level's encoder, whatever the YAML file says.
model_cfg.encoders[model_level].with_emb_mat = False

# Training run whose best checkpoint is restored. An earlier run was
# 'encdec-l1-20240918_063547-msmarco-fever'.
train_subdir = 'encdec-l1-20241005_175446-msmarco-fever'
train_dir_path = DATA_PATH / 'train_mllm_encdec_1' / train_subdir
checkpoint_path = train_dir_path / 'best.pth'
# Deserialize onto the CPU: without map_location, tensors are restored to the
# device they were saved from, which fails on a machine lacking that device
# and otherwise occupies GPU memory just to hold the raw checkpoint.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

In [18]:

# Instantiate the selected level from the config, then restore the weights
# stored under the checkpoint's 'model' key. The result of load_state_dict is
# left as the cell output so the key-matching report is displayed.
model = MllmEncdecLevel(model_cfg, model_level)
state_dict = checkpoint['model']
model.load_state_dict(state_dict)

encoder.a_em () 0.09967066 0.09967066 0.09967066
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10825093 -0.00042184087 0.108251125
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.108249776 7.992712e-05 0.10824896
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10824746 0.00020865718 0.108252384
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108241595 -0.00026560333 0.10825174
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09946153 -0.0003807206 0.09928878
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.098331094 0.0021152585 0.099065565
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846515 2.9833169e-05 0.06846509
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09996591 0.0009740598 0.09976818
encoder.layer_stack.0.pos_ffn.w_2.weight (256, 1024) -0.06846526 3.0548663e-06 0.06846491
encoder.layer_stack.0.pos_ffn.w_2.bias (256,) -0.09884465 0.0037791403 0.099840425
encoder.layer_stack.0.pos_ffn.layer_norm.weight (256,

<All keys matched successfully>

In [None]:
# Pull the weight and bias out of the encoder's `w_em` attribute for inspection.
w_em = model.encoder.w_em
w, b = w_em.weight, w_em.bias
print(b)

In [11]:
# Show the dimensions of the `w_em` weight.
print(w.size())

torch.Size([1, 100])


In [3]:
from transformers import pipeline

# Image-to-text pipeline using the BLIP base image-captioning checkpoint.
captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

# Sample images to caption. Previously each URL was assigned to `img_url` in
# turn, so the first two assignments were dead stores; keeping them in a tuple
# makes the choice explicit and lets another image be selected by index.
img_urls = (
    'https://huggingface.co/datasets/Narsil/image_dummy/resolve/main/parrots.png',
    'https://media.gettyimages.com/id/1245486932/photo/lusail-city-qatar-lionel-messi-of-argentina-has-a-shot-at-goal-from-the-free-kick-during-the.jpg?s=612x612&w=gi&k=20&c=_VOAohhrtnB__my2cBZza_ohApNdmO9vhesozECG5X0=',
    'https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSkCqcoHGJvI-tl_P5xabXjeCjRWKkH4fJqbQ&s',
)
# The last entry is the one that was effectively used before.
img_url = img_urls[-1]
captioner(img_url)


Device set to use cuda:0


[{'generated_text': 'two humpbacks swimming in the ocean'}]

In [11]:
# Build a str.format template in two stages: %-formatting first bakes the
# float precision into the template, leaving the {k} and {x} fields to be
# filled in later by str.format.
prec = 8
k, x = 'key', 44 / 27
template = '{k}: {x:.%df}'
s = template % (prec,)
s

'{k}: {x:.8f}'

In [12]:
# Fill the remaining {k} and {x} fields of the template.
s.format_map({'k': k, 'x': x})

'key: 1.62962963'