In [1]:
# Enable the autoreload extension so that edits to imported modules are picked
# up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [112]:
import os
from pathlib import Path
import sys
from typing import Optional, cast
# Put the parent directory on the import path (only once); presumably this is
# where the local `mllm` package imported below lives.
if '..' not in sys.path:
    sys.path.append('..')

from datasets import load_dataset
from datasets.arrow_dataset import Dataset
import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer, AutoTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import ENCMIX_BERT_MODEL_CFG_FNAME
from mllm.train.utils import EedWikiIterator, QnaQuesInp, get_squadv2_df, get_squadv2_batch
from mllm.model.encmix import EncmixBert
from mllm.config.model import EncmixBertCfg, EncmixTrainDsType


# EncmixBert inference
## Config

In [None]:
# --- Filesystem layout ---
DATA_PATH = Path(os.path.expandvars('$HOME'), 'data')
WIKI_DS_NAME = '20200501.en'
TRAIN_ENCDEC_BERT_PATH = DATA_PATH.joinpath('train_mllm_encmix_bert')

# --- Training run to load ---
# Alternative runs that were previously selected here:
#   'encmixbert-20250408_115220-bert-base-uncased-d768-inp128-oemb_inp'
#   'encmixbert-20250408_225609-bert-base-uncased-d768-inp128-oemb_new'
encdec_subdir = 'encmixbert-20250410_155235-bert-base-uncased-d768-inp128-oemb_new-ds_qna'

# Snapshot and model config both live inside the selected run directory.
encmix_train_path = TRAIN_ENCDEC_BERT_PATH.joinpath(encdec_subdir)
encmix_snapshot_fpath = encmix_train_path.joinpath('best.pth')
encmix_model_cfg_fpath = encmix_train_path.joinpath(ENCMIX_BERT_MODEL_CFG_FNAME)

# --- Runtime ---
device_name = 'cpu'  # switch to 'cuda' to run on a GPU
device = torch.device(device_name)
print(device)

# Number of items taken per batch in the cells below.
batch_size = 5

cpu


In [107]:
# The run directory name may carry a `ds_<type>` component; when it does, the
# last such component selects the dataset type, otherwise fall back to Msk.
ds_types_found = [
    EncmixTrainDsType(name_part[3:])
    for name_part in encdec_subdir.split('-')
    if name_part.startswith('ds_')
]
ds_type = ds_types_found[-1] if ds_types_found else EncmixTrainDsType.Msk
print(ds_type)

EncmixTrainDsType.Qna


## Load model

In [108]:
# Model configuration is read from the YAML file stored in the run directory.
model_cfg = parse_yaml_file_as(EncmixBertCfg, encmix_model_cfg_fpath)
# Tokenizer is the one named in the config; `cast` only narrows the static type.
tkz = cast(PreTrainedTokenizer, AutoTokenizer.from_pretrained(model_cfg.tokenizer_name))
print(model_cfg)
print(tkz)

inp_len=128 d_model=768 pretrained_model_name='bert-base-uncased' tokenizer_name='bert-base-uncased' out_embs_type=<EncmixOutEmbsType.Inp: 'inp'>
BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False

In [109]:
# NOTE(review): torch.load unpickles the file - only load snapshots you trust.
chkpt = torch.load(encmix_snapshot_fpath, map_location=device)
model = EncmixBert(cfg=model_cfg, tkz=tkz, device=device)

# Set to False to tolerate key mismatches between the snapshot and the model.
strict = True
model.load_state_dict(chkpt['model'], strict=strict)

# The checkpoint dict is no longer needed once the weights are copied in.
del chkpt

# Switch to evaluation mode; the trailing semicolon hides the module repr.
model.eval();

## Wikipedia mask prediction

In [81]:
# Reuse the dataset name from the config cell instead of repeating the literal,
# so the two cannot drift apart.
wiki_ds_name = WIKI_DS_NAME
wiki_ds_subdir = 'wikipedia'
print(f'Loading Wikipedia dataset: {wiki_ds_name}')
dss = load_dataset(wiki_ds_subdir, wiki_ds_name, beam_runner='DirectRunner', cache_dir=str(DATA_PATH))
# Only the 'train' split is used below.
ds = cast(Dataset, dss['train'])
n_docs = len(ds)
print(f'Wikipedia {wiki_ds_name} docs: {n_docs}')

Loading Wikipedia dataset: 20200501.en


Reusing dataset wikipedia (/home/misha/data/wikipedia/20200501.en/1.0.0/009f923d9b6dd00c00c8cdc7f408f2b47f45dd4f5fb7982a21f9448f4afbe475)


  0%|          | 0/1 [00:00<?, ?it/s]

Wikipedia 20200501.en docs: 6078422


In [83]:
# Index every document of the dataset in its stored order.
doc_inds = np.arange(len(ds))

# Iterator that turns documents into token chunks of the model's input length.
ds_it = EedWikiIterator(
    ds=ds,
    inds=doc_inds,
    inp_len=model_cfg.inp_len,
    tkz=tkz,
    docs_batch_size=batch_size,
    device=device,
    preserve_edge_tokens=True,
)

In [95]:
# Select the `i_batch`-th consecutive slice of `batch_size` document indices.
i_batch = 0
i1 = i_batch * batch_size
i2 = i1 + batch_size
batch_inds = doc_inds[i1:i2]

# Augmented input tokens and the corresponding target tokens for this batch.
docs_toks_aug_t, docs_toks_tgt_t = ds_it.get_batch_tokens(batch_inds)

In [91]:
# Show the raw text of each document in the batch on a single line.
for i, doc_ind in enumerate(batch_inds):
    # `.item()` turns the NumPy integer into a plain int before indexing `ds`;
    # newlines are escaped so every document stays on one output line.
    text = ds[doc_ind.item()]['text'].replace('\n', '\\n')
    print(f'{i:03d} {text}')

000 Yangliuqing () is a market town in Xiqing District, in the western suburbs of Tianjin, People's Republic of China. Despite its relatively small size, it has been named since 2006 in the "famous historical and cultural market towns in China".\n\nIt is best known in China for creating nianhua or Yangliuqing nianhua. For more than 400 years, Yangliuqing has in effect specialised in the creation of these woodcuts for the New Year.  wood block prints using vivid colourschemes to portray traditional scenes of children's games often interwoven with auspiciouse objects.\n\n, it had 27 residential communities () and 25 villages under its administration.\n\nShi Family Grand Courtyard\n\nShi Family Grand Courtyard (Tiānjīn Shí Jiā Dà Yuàn, 天津石家大院) is situated in Yangliuqing Town of Xiqing District, which is the former residence of wealthy merchant Shi Yuanshi - the 4th son of Shi Wancheng, one of the eight great masters in Tianjin. First built in 1875, it covers over 6,000 square meters, incl

In [96]:
# Decode each augmented chunk; when it contains mask tokens, also print the
# decoded target tokens.
mask_tok_id = tkz.mask_token_id
for i, doc_ind in enumerate(batch_inds):
    toks_aug = docs_toks_aug_t[i]
    txt_aug = tkz.decode(toks_aug)
    print(f'{doc_ind:03d} {txt_aug}')
    has_mask = (toks_aug == mask_tok_id).any()
    if has_mask:
        # NOTE(review): this decodes the whole target tensor, not a per-row slice.
        txt_tgt = tkz.decode(docs_toks_tgt_t)
        print(f'{doc_ind:03d} {txt_tgt}')

000 [CLS] ( ) and 25 villages under its administration. shi family grand courtyard shi family grand courtyard ( tianjin shi jia da yuan, 天 [UNK] 石 家 大 [UNK] ) is situated in yangliuqing town of xiqing district, which is the former residence of wealthy merchant shi yuanshi - the 4th son of shi wancheng, one of the eight great masters in tianjin. first built in 1875, it covers over 6, 000 square meters, including large and small yards and over 200 folk houses, a theater and over 275 rooms that served as apartments and places of business and worship for this powerful family. shifu garden, which finished its [SEP]
001 [CLS] first disability service organisations to achieve quality accreditation. the services and products they offer are : packaging assembly sewing collating & mailing furniture - retail furniture – manufacture for commercial market worm farming work crews pet & grain – retail in 2018, after 65 years of bettering people ’ s lives, orana identified a community need and expande

In [97]:
# Inference only: disable autograd so no computation graph is kept for the
# forward pass.
with torch.no_grad():
    out_logits = model.run_chunks_plain_seq(chunk_toks=docs_toks_aug_t, target_toks=docs_toks_tgt_t)
out_logits.shape

torch.Size([18, 30522])

In [98]:
# Greedy prediction: most probable vocabulary entry at every target position.
probs_pred = torch.softmax(out_logits, dim=-1)
toks_pred = torch.argmax(probs_pred, dim=-1)

# Decode the target here instead of relying on a `txt_tgt` left over from an
# earlier cell, where it is only assigned when a chunk contains mask tokens.
txt_tgt = tkz.decode(docs_toks_tgt_t)
print(txt_tgt)
txt_pred = tkz.decode(toks_pred)
print(txt_pred)

church and the construction of the new church. only parts of the old medieval church remained [SEP]
church, the church of the church church. the the of the [SEP] [SEP] [SEP] [SEP] [SEP]


## QnA prediction

In [117]:
# How the question is fed to the model when building batches.
ques_inp = QnaQuesInp.Enc

# Load SQuAD v2, optionally dropping rows with empty answers.
exclude_empty_answers = True
df_sq = get_squadv2_df(exclude_empty_answers=exclude_empty_answers)

# Positional indices over all loaded rows.
sq_inds = np.arange(len(df_sq))

Reusing dataset squad_v2 (/home/misha/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)


  0%|          | 0/2 [00:00<?, ?it/s]

Remove empty answers from dataset squad_v2. Size: 142192 --> 92749


In [119]:
# Select the `i_batch`-th consecutive slice of `batch_size` SQuAD rows.
i_batch = 0
i1 = i_batch * batch_size
i2 = i1 + batch_size
batch_inds = sq_inds[i1:i2]

sq_batch = get_squadv2_batch(
    tkz=tkz,
    df_sq=df_sq,
    inds=batch_inds,
    inp_len=model_cfg.inp_len,
    device=device,
    ques_inp=ques_inp,
)

# Context tokens plus a 4-tuple of question/answer token and mask collections.
ctx_toks_t, qa_tensors = sq_batch.gen_tensors()
q_toks_t, a_toks_t, a_att_masks_t, a_tgt_masks_t = qa_tensors
print(ctx_toks_t.shape, len(q_toks_t), len(a_toks_t))


torch.Size([5, 128]) 4 4


In [131]:
# Decode and print every context row of the batch.
for ctx_toks in ctx_toks_t:
    ctx_txt = tkz.decode(ctx_toks)
    print(
        ctx_txt
    )

[CLS] context1. there are several technologies aimed to provide better experience to passengers suffering from claustrophobia, anthropophobia or social anxiety. israeli startup digigage uses motion sensors to scroll the pre - rendered images, building and floor - specific content on a screen embedded into the wall as the cab moves up and down. british company lifteye provides a virtual window technology to turn common elevator into panoramic. it creates 3d video panorama using live feed from cameras placed vertically along the facade and synchronizes it with cab movement. the video is projected on a wall - sized screens making it look like the walls are
[CLS] context2. sleep - and - charge usb ports can be used to charge electronic devices even when the computer is switched off. normally, when a computer is powered off the usb ports are powered down, preventing phones and other devices from charging. sleep - and - charge usb ports remain powered even when the computer is off. on laptop

In [136]:
# For every question in the batch: run the model on the shared contexts plus
# the question, then print the question, the reference answer and the model's
# greedy (argmax) prediction.
for i in range(len(q_toks_t)):
    q_toks, a_toks = q_toks_t[i], a_toks_t[i]
    # Drop padding so only the real question tokens are passed to the model.
    q_toks = q_toks[q_toks != tkz.pad_token_id]

    # Inference only: disable autograd for the forward pass.
    with torch.no_grad():
        out_logits = model.run_chunks_plain_seq(chunk_toks=ctx_toks_t, plain_toks=q_toks, target_toks=a_toks)

    q_txt, a_txt = tkz.decode(q_toks), tkz.decode(a_toks)
    print(f'{i:02d}. Q: {q_txt}')
    print(f'{i:02d}. A: {a_txt}')

    probs_pred = torch.softmax(out_logits, dim=-1)
    toks_pred = torch.argmax(probs_pred, dim=-1)
    txt_pred = tkz.decode(toks_pred)
    print(f'{i:02d}. M: {txt_pred}')
    print('-' * 50)

00. Q: [CLS] context5. question : what is sacks'career?
00. A: neurologist [SEP]
00. M: [SEP]uter [SEP] [SEP]
--------------------------------------------------
01. Q: [CLS] context1. question : lifteye uses virtual window technology for what?
01. A: to turn common elevator into panoramic [SEP]
01. M: [SEP] a [SEP] [SEP] [SEP] [SEP] - [SEP] [SEP]
--------------------------------------------------
02. Q: [CLS] context4. question : what does'yoh - hu'mean?
02. A: music people [SEP]
02. M: [SEP] [SEP] [SEP]
--------------------------------------------------
03. Q: [CLS] context3. question : what common terminological problem can sometimes lead to confusion surrounding a treaty?
03. A: named something other than a treaty, [SEP]
03. M: treaty [SEP] [SEP] treaty the treaty of and
--------------------------------------------------
