In [2]:
# Enable IPython's autoreload extension in mode 2, so edits to imported modules
# (e.g. the local `mllm` package) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import os
from pathlib import Path
import sys
from typing import Optional, cast
if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
from datasets.arrow_dataset import Dataset
import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer, AutoTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import ENCMIX_BERT_MODEL_CFG_FNAME
from mllm.train.utils import EedWikiIterator, QnaQuesInp, get_squadv2_df, get_squadv2_batch
from mllm.model.encmix import EncmixBert
from mllm.config.model import EncmixBertCfg, EncmixTrainDsType




# EncmixBert inference
## Config

In [16]:
# Root directory for datasets and training artifacts. Path.home() resolves the
# user's home directory even when $HOME is not set, whereas
# os.path.expandvars('$HOME') would leave the literal text '$HOME' in the path.
DATA_PATH = Path.home() / 'data'
WIKI_DS_NAME = '20200501.en'

TRAIN_ENCDEC_BERT_PATH = DATA_PATH / 'train_mllm_encmix_bert'
# Known EncmixBert training runs, oldest first. The run to inspect is selected
# explicitly below instead of by repeatedly rebinding `encdec_subdir`.
ENCMIX_RUN_SUBDIRS = (
    'encmixbert-20250408_115220-bert-base-uncased-d768-inp128-oemb_inp',
    'encmixbert-20250408_225609-bert-base-uncased-d768-inp128-oemb_new',
    'encmixbert-20250410_161455-bert-base-uncased-d768-inp128-oemb_new-ds_qna',
    'encmixbert-20250411_100815-bert-base-uncased-d768-inp256-oemb_new-ds_qna',
)
encdec_subdir = ENCMIX_RUN_SUBDIRS[-1]

# Locations of the selected run's best snapshot and its saved model config.
encmix_train_path = TRAIN_ENCDEC_BERT_PATH / encdec_subdir
encmix_snapshot_fpath = encmix_train_path / 'best.pth'
encmix_model_cfg_fpath = encmix_train_path / ENCMIX_BERT_MODEL_CFG_FNAME

# Inference device; switch to 'cuda' to run on a GPU.
device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print(device)

# Default batch size; passed as `docs_batch_size` to the Wikipedia iterator.
batch_size = 5

cpu


In [17]:
# The training dataset type is encoded in the run directory name as a
# '-'-separated `ds_<type>` component. Names without such a component fall back
# to EncmixTrainDsType.Msk.
ds_types_found = [
    EncmixTrainDsType(name_part[3:])
    for name_part in encdec_subdir.split('-')
    if name_part.startswith('ds_')
]
# When several components match, the last one wins.
ds_type = ds_types_found[-1] if ds_types_found else EncmixTrainDsType.Msk
print(ds_type)

EncmixTrainDsType.Qna


## Load model

In [18]:
# Read the model configuration stored next to the training snapshot.
model_cfg = parse_yaml_file_as(EncmixBertCfg, encmix_model_cfg_fpath)
# Load the tokenizer named in the config; `cast` only narrows the static type
# and has no runtime effect.
tkz = cast(
    PreTrainedTokenizer,
    AutoTokenizer.from_pretrained(model_cfg.tokenizer_name),
)
print(model_cfg, tkz, sep='\n')

inp_len=256 d_model=768 pretrained_model_name='bert-base-uncased' tokenizer_name='bert-base-uncased' out_embs_type=<EncmixOutEmbsType.New: 'new'>
BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False

In [19]:
# Build the model from the loaded config, then restore its trained weights.
model = EncmixBert(cfg=model_cfg, tkz=tkz, device=device)
# NOTE(review): torch.load unpickles the file, so only open snapshots from a
# trusted source.
checkpoint = torch.load(encmix_snapshot_fpath, map_location=device)
# With strict=True the snapshot keys must match the model's state dict exactly;
# switch to False to tolerate missing/unexpected keys.
strict = True
# strict = False
model.load_state_dict(checkpoint['model'], strict=strict)
# The snapshot is no longer needed once its weights are loaded into the model.
del checkpoint
model.eval()
# print(model)
# A trailing None keeps the cell from echoing the repr returned by model.eval().
None

## Wikipedia mask prediction

In [165]:
# Reuse the dataset version declared in the config cell instead of repeating
# the literal here, so there is one place to change it.
wiki_ds_name = WIKI_DS_NAME
print(f'Loading Wikipedia dataset: {wiki_ds_name}')
wiki_ds_subdir = 'wikipedia'
# Load the dataset with DATA_PATH as the cache directory.
dss = load_dataset(wiki_ds_subdir, wiki_ds_name, beam_runner='DirectRunner', cache_dir=str(DATA_PATH))
# Only the train split is used below.
ds = cast(Dataset, dss['train'])
n_docs = len(ds)
print(f'Wikipedia {wiki_ds_name} docs: {n_docs}')

Loading Wikipedia dataset: 20200501.en


Reusing dataset wikipedia (/home/misha/data/wikipedia/20200501.en/1.0.0/009f923d9b6dd00c00c8cdc7f408f2b47f45dd4f5fb7982a21f9448f4afbe475)


  0%|          | 0/1 [00:00<?, ?it/s]

Wikipedia 20200501.en docs: 6078422


In [192]:
# Indices of every document in the Wikipedia train split.
doc_inds = np.arange(len(ds))
# Wikipedia batch iterator configured with the model's input length and tokenizer.
ds_it = EedWikiIterator(
    ds=ds,
    inds=doc_inds,
    inp_len=model_cfg.inp_len,
    tkz=tkz,
    docs_batch_size=batch_size,
    device=device,
    preserve_edge_tokens=True,
)

In [196]:
# Select the `i_batch`-th contiguous window of `batch_size` document indices.
i_batch = 1
i1 = i_batch * batch_size
i2 = i1 + batch_size
batch_inds = doc_inds[i1:i2]

# The iterator returns a pair of token tensors for the selected documents,
# named here as the augmented inputs and their targets.
docs_toks_aug_t, docs_toks_tgt_t = ds_it.get_batch_tokens(batch_inds)

In [194]:
# Print the raw text of each document in the batch, one document per line
# (newlines are shown as a literal "\n" so each document stays on one line).
for i, doc_ind in enumerate(batch_inds):
    # .item() turns the NumPy integer into a plain int for dataset indexing.
    doc = ds[doc_ind.item()]
    text = doc['text'].replace('\n', '\\n')
    print(f'{i:03d} {text}')

000 Latin liturgical rites, or Western liturgical rites, are Catholic liturgical rites employed by the Latin Church, the largest particular church sui iuris of the Catholic Church, that originated in Europe where the Latin language once dominated. Its language is now known as Ecclesiastical Latin. The most used rite is the Roman Rite.\n\nThe Latin rites were for many centuries no less numerous than the liturgical rites of the Eastern autonomous particular Churches. Their number is now much reduced. In the aftermath of the Council of Trent, in 1568 and 1570 Pope Pius V suppressed the Breviaries and Missals that could not be shown to have an antiquity of at least two centuries (see Tridentine Mass and Roman Missal). Many local rites that remained legitimate even after this decree were abandoned voluntarily, especially in the 19th century. In the second half of the 20th century, most of the religious orders that had a distinct liturgical rite chose to adopt in its place the Roman Rite as 

In [197]:
# Decode the augmented input tokens of each document in the batch and print
# them, prefixed with the document index.
for i, doc_ind in enumerate(batch_inds):
    toks_aug = docs_toks_aug_t[i]
    txt_aug = tkz.decode(toks_aug)
    print(f'{doc_ind:03d} {txt_aug}')
    # The target text is only shown for inputs that contain a mask token.
    has_mask = (toks_aug == tkz.mask_token_id).any()
    if has_mask:
        # NOTE(review): this decodes the whole `docs_toks_tgt_t`, not row `i` —
        # confirm the tensor returned by get_batch_tokens is not per-document.
        txt_tgt = tkz.decode(docs_toks_tgt_t)
        print(f'{doc_ind:03d} {txt_tgt}')

005 [CLS] the establishment of the personal ordinariates, parishes in the united states were called " anglican use " and used the book of divine worship, an adaptation of the book of common prayer. the book of divine worship has been replaced with the similar divine worship : the missal for use in the ordinariates worldwide. anglican liturgical rituals, whether those used in the ordinariates of the catholic church or in the various prayer books and missals of the anglican communion and other denominations trace their origin back to the sarum use, which was a variation of the roman rite used in england before introduction during the reign of edward [SEP]
006 [CLS] 18 flowers, the pedicels 4 – 6 mm. long ; bracts ovate, long ; calyx lobes ovate, acute or obtuse, 2 – 3 mm. long ; corolla white within, greenish outside. references leon, j., h. goldbach & j. engels, 1979 : die genetischen ressourcen der kulturpflanzen zentralamerikas., int. genbank catie / gtz in turrialba, costa rica, san 

In [198]:
# Run prediction on the augmented batch. Gradients are not needed here, so
# autograd is disabled explicitly (harmless if `predict` already does this).
with torch.no_grad():
    toks_pred = model.predict(chunk_toks=docs_toks_aug_t)
# Decode the predicted token ids back to text for inspection.
txt_pred = tkz.decode(toks_pred)
print(txt_pred)

, a roman catholic church in the city of st. john, [SEP] [MASK]


## Qna prediction

In [20]:
# Question input mode passed to the batch builder below.
ques_inp = QnaQuesInp.Enc
# Load SQuAD v2 with the entries that have no answer filtered out.
exclude_empty_answers = True
df_sq = get_squadv2_df(exclude_empty_answers=exclude_empty_answers)
# Positional indices of all remaining rows.
sq_inds = np.arange(len(df_sq))

Reusing dataset squad_v2 (/home/misha/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)


  0%|          | 0/2 [00:00<?, ?it/s]

Remove empty answers from dataset squad_v2. Size: 142192 --> 92749


In [44]:
# NOTE(review): this rebinds the notebook-wide `batch_size` set in the config
# cell; re-running the Wikipedia batch cell afterwards will use this value.
batch_size = 1
i_batch = 7
# Select the `i_batch`-th window of `batch_size` SQuAD row indices.
i1 = i_batch * batch_size
i2 = i1 + batch_size
batch_inds = sq_inds[i1:i2]
sq_batch = get_squadv2_batch(
    tkz=tkz,
    df_sq=df_sq,
    inds=batch_inds,
    inp_len=model_cfg.inp_len,
    device=device,
    ques_inp=ques_inp,
)
# gen_tensors() yields the context tokens plus a 4-tuple of question tokens,
# answer tokens and two answer masks.
ctx_toks_t, qa_tensors = sq_batch.gen_tensors()
q_toks_t, a_toks_t, a_att_masks_t, a_tgt_masks_t = qa_tensors
print(ctx_toks_t.shape, len(q_toks_t), len(a_toks_t))



torch.Size([1, 256]) 1 1


In [45]:
# Decode each context row of the batch back to text and print it.
for ctx_toks in ctx_toks_t.unbind(0):
    ctx_txt = tkz.decode(ctx_toks)
    print(ctx_txt)

[CLS] context1. in addition to a spoken standard and a closely related written standard, czech has several regional dialects primarily used in rural areas by speakers less proficient in other dialects or standard czech. during the second half of the twentieth century, czech dialect use began to weaken. by the early 1990s dialect use was stigmatized, associated with the shrinking lower class and used in literature or other media for comedic effect. increased travel and media availability to dialect - speaking populations has encouraged them to shift to ( or add to their own dialect ) standard czech. although czech has received considerable scholarly interest for a slavic language, this interest has focused primarily on modern standard czech and ancient texts rather than dialects. standard czech is still the norm for politicians, businesspeople and other czechs in formal situations, but common czech is gaining ground in journalism and the mass media. [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [

In [46]:
# For every question of the batch, predict an answer from the context and print
# the question (Q), the reference answer (A) and the model's prediction (M).
separator = '-' * 50
# Gradients are not needed for prediction, so autograd is disabled explicitly
# (harmless if `predict` already does this internally).
with torch.no_grad():
    for i, (q_toks, a_toks) in enumerate(zip(q_toks_t, a_toks_t)):
        # Drop padding so only the real question tokens are passed to the model.
        q_toks = q_toks[q_toks != tkz.pad_token_id]
        toks_pred = model.predict(chunk_toks=ctx_toks_t, plain_toks=q_toks)
        q_txt, a_txt = tkz.decode(q_toks), tkz.decode(a_toks)
        txt_pred = tkz.decode(toks_pred)
        print(f'{i:02d}. Q: {q_txt}')
        print(f'{i:02d}. A: {a_txt}')
        print(f'{i:02d}. M: {txt_pred}')
        print(separator)

00. Q: [CLS] question : where are dialects of czech commonly found?. answer :
00. A: rural areas [SEP]
00. M: the northern european region [SEP]
--------------------------------------------------
