In [1]:
# Re-import edited project modules (e.g. mllm.*) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [4]:
import os
from pathlib import Path
import sys
from typing import Optional, cast
if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
from datasets.arrow_dataset import Dataset
import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer, AutoTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import ENCMIX_BERT_MODEL_CFG_FNAME
from mllm.train.utils import EedWikiIterator, QnaQuesInp, get_squadv2_df, get_squadv2_batch
from mllm.model.encmix import EncmixBertGan
from mllm.config.model import EncmixBertCfg, EncmixTrainDsType


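# EncmixBert inference
Load a trained `EncmixBertGan` checkpoint and inspect its predictions on Wikipedia mask prediction and SQuAD v2 question answering.
## Config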

In [5]:
# Root directory for datasets and training artifacts.
# Path.home() is used instead of expandvars('$HOME'): the latter silently leaves a
# literal '$HOME' in the path when the variable is unset (e.g. on Windows).
DATA_PATH = Path.home() / 'data'
WIKI_DS_NAME = '20200501.en'

# Each training run's artifacts live in a subdirectory of this path named after the run.
TRAIN_ENCMIX_BERT_PATH = DATA_PATH / 'train_mllm_encmix_bert_gan'
# Earlier runs, kept for quick switching:
# encmix_subdir = 'encmixbert-20250413_220133-bert-base-uncased-d768-inp256-oemb_inp-ds_qna'
# encmix_subdir = 'encmixbert-20250414_221310-bert-base-uncased-d768-inp256-oemb_inp-ds_qna'
encmix_subdir = 'encmixbert-20250420_134303-bert-base-uncased-d768-inp256-oemb_inp-tte_t-ds_qna'

encmix_train_path = TRAIN_ENCMIX_BERT_PATH / encmix_subdir
encmix_snapshot_fpath = encmix_train_path / 'best.pth'  # checkpoint loaded in "Load model"
encmix_model_cfg_fpath = encmix_train_path / ENCMIX_BERT_MODEL_CFG_FNAME

# Inference device; switch to 'cuda' to run on GPU.
device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print(device)

# Number of Wikipedia documents per batch in the mask-prediction section.
batch_size = 5

cpu


In [6]:
# The run directory name encodes the training dataset as a 'ds_<type>' segment
# (e.g. '...-ds_qna'). Convert every such segment (last one wins) and fall back
# to EncmixTrainDsType.Msk when the name has none.
ds_types = [
    EncmixTrainDsType(seg[len('ds_'):])
    for seg in encmix_subdir.split('-')
    if seg.startswith('ds_')
]
ds_type = ds_types[-1] if ds_types else EncmixTrainDsType.Msk
print(ds_type)

EncmixTrainDsType.Qna


## Load model

In [7]:
# Model hyperparameters saved alongside the checkpoint, and the tokenizer they name.
model_cfg = parse_yaml_file_as(EncmixBertCfg, encmix_model_cfg_fpath)
tkz = cast(PreTrainedTokenizer, AutoTokenizer.from_pretrained(model_cfg.tokenizer_name))
for obj in (model_cfg, tkz):
    print(obj)

inp_len=256 d_model=768 pretrained_model_name='bert-base-uncased' tokenizer_name='bert-base-uncased' out_embs_type=<EncmixOutEmbsType.Inp: 'inp'> token_types_for_embs=True
BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_wo

In [8]:
# Keep only the 'model' weights from the checkpoint file; the rest of the dict is dropped immediately.
model_state = torch.load(encmix_snapshot_fpath, map_location=device)['model']
model = EncmixBertGan(cfg=model_cfg, tkz=tkz, device=device)
strict = True  # set to False to tolerate missing/unexpected keys in the checkpoint
model.load_state_dict(model_state, strict=strict)
del model_state  # load_state_dict copied the tensors into the model's parameters
model.eval();  # inference mode; trailing ';' suppresses the module repr

## Wikipedia mask prediction

In [9]:
# Reuse the dataset name from the config cell so the two cannot drift apart.
wiki_ds_name = WIKI_DS_NAME
print(f'Loading Wikipedia dataset: {wiki_ds_name}')
wiki_ds_subdir = 'wikipedia'
dss = load_dataset(wiki_ds_subdir, wiki_ds_name, beam_runner='DirectRunner', cache_dir=str(DATA_PATH))
ds = cast(Dataset, dss['train'])
n_docs = len(ds)
print(f'Wikipedia {wiki_ds_name} docs: {n_docs}')

Loading Wikipedia dataset: 20200501.en


Reusing dataset wikipedia (/home/misha/data/wikipedia/20200501.en/1.0.0/009f923d9b6dd00c00c8cdc7f408f2b47f45dd4f5fb7982a21f9448f4afbe475)


  0%|          | 0/1 [00:00<?, ?it/s]

Wikipedia 20200501.en docs: 6078422


In [10]:
# Every document index, in dataset order (no shuffling).
doc_inds = np.arange(len(ds))
ds_it = EedWikiIterator(
    ds=ds,
    inds=doc_inds,
    tkz=tkz,
    inp_len=model_cfg.inp_len,
    docs_batch_size=batch_size,
    device=device,
    preserve_edge_tokens=True,
)

In [11]:
# Pick batch number `i_batch` of document indices and fetch its augmented/target token tensors.
i_batch = 0
i1 = i_batch * batch_size
i2 = i1 + batch_size
batch_inds = doc_inds[i1:i2]

docs_toks_aug_t, docs_toks_tgt_t = ds_it.get_batch_tokens(batch_inds)

Token indices sequence length is longer than the specified maximum sequence length for this model (730 > 512). Running this sequence through the model will result in indexing errors


In [12]:
# Print the raw text of each document in the batch, escaping newlines so each document stays on one line.
for i, doc_ind in enumerate(batch_inds):
    doc = ds[int(doc_ind)]
    title = doc['title']
    text = doc['text'].replace('\n', '\\n')
    print(f'{i:03d} {text}')

000 Yangliuqing () is a market town in Xiqing District, in the western suburbs of Tianjin, People's Republic of China. Despite its relatively small size, it has been named since 2006 in the "famous historical and cultural market towns in China".\n\nIt is best known in China for creating nianhua or Yangliuqing nianhua. For more than 400 years, Yangliuqing has in effect specialised in the creation of these woodcuts for the New Year.  wood block prints using vivid colourschemes to portray traditional scenes of children's games often interwoven with auspiciouse objects.\n\n, it had 27 residential communities () and 25 villages under its administration.\n\nShi Family Grand Courtyard\n\nShi Family Grand Courtyard (Tiānjīn Shí Jiā Dà Yuàn, 天津石家大院) is situated in Yangliuqing Town of Xiqing District, which is the former residence of wealthy merchant Shi Yuanshi - the 4th son of Shi Wancheng, one of the eight great masters in Tianjin. First built in 1875, it covers over 6,000 square meters, incl

In [13]:
# For each document, print its augmented token sequence; when it contains [MASK]
# tokens, also print the corresponding target sequence for comparison.
for i, doc_ind in enumerate(batch_inds):
    toks_aug = docs_toks_aug_t[i]
    txt_aug = tkz.decode(toks_aug)
    print(f'{doc_ind:03d} {txt_aug}')
    if (toks_aug == tkz.mask_token_id).any():
        # Decode only this document's target row — decoding the whole
        # docs_toks_tgt_t would dump every document in the batch.
        txt_tgt = tkz.decode(docs_toks_tgt_t[i])
        print(f'{doc_ind:03d} {txt_tgt}')

000 [CLS], covers 1, 200 square meters, incorporates the elegance of imperial garden and delicacy of south garden. now the courtyard of shi family covers about 10, 000 square meters, which is called the first mansion in north china. now it serves as the folk custom museum in yangliuqing, which has a large collection of folk custom museum in yanliuqing, which has a large collection of folk art pieces like yanliuqing new year pictures, brick sculpture. shi's ancestor came from dong'e county in shandong province, engaged in water transport of grain. as the wealth gradually accumulated, the shi family moved to yangliuqing and bought large tracts of land and set up their residence. shi yuanshi came from the fourth generation of the family, who was a successful businessman and a good household manager, and the residence was thus enlarged for several times until it acquired the present scale. it is believed to be the first mansion in the west of tianjin. the residence is symmetric based on th

In [None]:
# Predict tokens for the augmented Wikipedia batch and decode them.
# NOTE(review): this cell has never been executed (In [None]). The QnA cell below
# calls model.predict(context, question, max_out_toks=...) and unpacks a
# (tokens, string) tuple; confirm EncmixBertGan.predict accepts `chunk_toks` and
# returns token ids (and that a batched tensor can be passed to tkz.decode)
# before relying on this cell.
toks_pred = model.predict(chunk_toks=docs_toks_aug_t)
txt_pred = tkz.decode(toks_pred)
print(txt_pred)

## QnA prediction (SQuAD v2)

In [15]:
# SQuAD v2 examples, dropping those whose answer list is empty.
exclude_empty_answers = True
df_sq = get_squadv2_df(exclude_empty_answers=exclude_empty_answers)
sq_inds = np.arange(df_sq.shape[0])
# NOTE(review): ques_inp is not used by any later cell in this notebook.
ques_inp = QnaQuesInp.Enc

Reusing dataset squad_v2 (/home/misha/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)


  0%|          | 0/2 [00:00<?, ?it/s]

Remove empty answers from dataset squad_v2. Size: 142192 --> 92749


In [22]:
ind = 1  # positional row of df_sq to inspect
row = df_sq.iloc[ind]
context = row['context']
question = row['question']
answers = row['answers']['text']

In [23]:
# Context, question, then every reference answer, one per line.
for label, value in (('C:', context), ('Q:', question)):
    print(label, value)
for ans in answers:
    print('A:', ans)

C: On August 24, 2006, Apple and Creative announced a broad settlement to end their legal disputes. Apple will pay Creative US$100 million for a paid-up license, to use Creative's awarded patent in all Apple products. As part of the agreement, Apple will recoup part of its payment, if Creative is successful in licensing the patent. Creative then announced its intention to produce iPod accessories by joining the Made for iPod program.
Q: How much did Apple pay to Creative Technologies to settle their 2006 suit?
A: $100 million


In [24]:
# Ask the model to answer the selected SQuAD question from its context.
# NOTE(review): in the saved output every predicted id is 101 ([CLS]), i.e. the
# model does not produce a usable answer for this example yet.
n_ans_toks = 5
out_toks, out_str = model.predict(context, question, max_out_toks=n_ans_toks)
print(out_toks)
print('M:', out_str)

Ctx emb: [CLS]
torch.Size([1, 2, 768])
question : how much did apple pay to creative technologies to settle their 2006 suit? answer : [MASK] [MASK] [MASK] [MASK] [MASK] [SEP]
tensor([101, 101, 101, 101, 101])
M: [CLS] [CLS] [CLS] [CLS] [CLS]
