In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell runs, so edits to the local `mllm` package are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional, cast
if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
from datasets.arrow_dataset import Dataset
import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer, AutoTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import ENCMIX_BERT_MODEL_CFG_FNAME
from mllm.train.utils import EedWikiIterator, QnaQuesInp, get_squadv2_df, get_squadv2_batch
from mllm.model.encmix import EncmixBert
from mllm.config.model import EncmixBertCfg, EncmixTrainDsType




# EncmixBert inference
## Config

In [None]:
# Root directory holding datasets and training artifacts: <home>/data.
# Path.home() is used instead of expanding '$HOME' because expandvars leaves the
# literal '$HOME' in place when the variable is unset (e.g. on Windows), which
# would silently produce a bogus relative path.
DATA_PATH = Path.home() / 'data'
WIKI_DS_NAME = '20200501.en'

# Directory containing EncmixBert training runs, one subdirectory per run.
TRAIN_ENCDEC_BERT_PATH = DATA_PATH / 'train_mllm_encmix_bert'
# Run to inspect. The trailing `ds_<type>` part of the name is parsed in a later
# cell to detect which dataset type the model was trained on.
# Alternative run (name ends with `ds_msk`):
# encdec_subdir = 'encmixbert-20250416_225131-bert-base-uncased-d768-inp256-oemb_inp-tte_t-ds_msk'
encdec_subdir = 'encmixbert-20250417_221441-bert-base-uncased-d768-inp256-oemb_inp-tte_t-ds_qna'

# Files expected inside the selected run directory.
encmix_train_path = TRAIN_ENCDEC_BERT_PATH / encdec_subdir
encmix_snapshot_fpath = encmix_train_path / 'best.pth'
encmix_model_cfg_fpath = encmix_train_path / ENCMIX_BERT_MODEL_CFG_FNAME

# Device used for both checkpoint loading and inference.
device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print(device)

# Number of items per batch in the cells below.
batch_size = 1

cpu


In [12]:
# Derive the training dataset type from the run directory name.
# Name parts are separated by '-'; a part of the form `ds_<value>` carries the
# dataset type value. Falls back to Msk when no such part is present.
ds_prefix = 'ds_'
ds_type = EncmixTrainDsType.Msk
for name_part in encdec_subdir.split('-'):
    if not name_part.startswith(ds_prefix):
        continue
    # Strip the prefix and let the enum validate the remaining value.
    ds_type = EncmixTrainDsType(name_part[len(ds_prefix):])
print(ds_type)

EncmixTrainDsType.Qna


## Load model

In [13]:
# Read the model configuration stored next to the training snapshot.
model_cfg = parse_yaml_file_as(EncmixBertCfg, encmix_model_cfg_fpath)
# Tokenizer is chosen by the name recorded in the config. `cast` only narrows
# the static type for type checkers; it performs no runtime conversion.
tkz = cast(PreTrainedTokenizer, AutoTokenizer.from_pretrained(model_cfg.tokenizer_name))
print(model_cfg, tkz, sep='\n')

inp_len=256 d_model=768 pretrained_model_name='bert-base-uncased' tokenizer_name='bert-base-uncased' out_embs_type=<EncmixOutEmbsType.Inp: 'inp'> token_types_for_embs=True
BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_wo

In [14]:
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load snapshots from a trusted source.
checkpoint = torch.load(encmix_snapshot_fpath, map_location=device)
model = EncmixBert(cfg=model_cfg, tkz=tkz, device=device)
# Keep strict loading on so any mismatch between the snapshot and the model's
# state dict keys raises instead of being silently ignored; set to False to relax.
strict = True
model.load_state_dict(checkpoint['model'], strict=strict)
# The weights now live in the model; release the loaded snapshot dict.
del checkpoint
# Switch to evaluation mode; the trailing semicolon suppresses the module repr.
model.eval();

BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.42.4",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}



## Wikipedia mask prediction

In [48]:
# Dataset builder name and version; the version string repeats the value of
# WIKI_DS_NAME from the config cell.
wiki_ds_subdir = 'wikipedia'
wiki_ds_name = '20200501.en'
print(f'Loading Wikipedia dataset: {wiki_ds_name}')
dss = load_dataset(
    wiki_ds_subdir,
    wiki_ds_name,
    beam_runner='DirectRunner',
    cache_dir=str(DATA_PATH),
)
# Only the 'train' split is used below.
ds = cast(Dataset, dss['train'])
n_docs = len(ds)
print(f'Wikipedia {wiki_ds_name} docs: {n_docs}')

Loading Wikipedia dataset: 20200501.en


Reusing dataset wikipedia (/home/misha/data/wikipedia/20200501.en/1.0.0/009f923d9b6dd00c00c8cdc7f408f2b47f45dd4f5fb7982a21f9448f4afbe475)


  0%|          | 0/1 [00:00<?, ?it/s]

Wikipedia 20200501.en docs: 6078422


In [49]:
# One index per document, in dataset order.
doc_inds = np.arange(len(ds))
# Iterator that turns document indices into token batches for the model.
ds_it = EedWikiIterator(
    ds=ds,
    inds=doc_inds,
    inp_len=model_cfg.inp_len,
    tkz=tkz,
    docs_batch_size=batch_size,
    device=device,
    preserve_edge_tokens=True,
)

In [59]:
# Pick the documents of batch number `i_batch`: half-open index range [i1, i2).
i_batch = 5
i1 = i_batch * batch_size
i2 = i1 + batch_size
batch_inds = doc_inds[i1:i2]

# Tokenize the selected documents. Judging by the names and the decoded output,
# the first tensor holds the augmented (masked) inputs and the second the
# prediction targets -- TODO confirm against EedWikiIterator.
docs_toks_aug_t, docs_toks_tgt_t = ds_it.get_batch_tokens(batch_inds)

Token indices sequence length is longer than the specified maximum sequence length for this model (2931 > 512). Running this sequence through the model will result in indexing errors


In [60]:
# Show the raw text of every document in the batch on a single line each.
for i, doc_ind in enumerate(batch_inds):
    # `.item()` converts the numpy integer to a plain int for dataset indexing.
    doc = ds[doc_ind.item()]
    title = doc['title']
    # Escape newlines so that one document occupies exactly one output line.
    text = doc['text'].replace('\n', '\\n')
    print(f'{i:03d} {text}')

000 Latin liturgical rites, or Western liturgical rites, are Catholic liturgical rites employed by the Latin Church, the largest particular church sui iuris of the Catholic Church, that originated in Europe where the Latin language once dominated. Its language is now known as Ecclesiastical Latin. The most used rite is the Roman Rite.\n\nThe Latin rites were for many centuries no less numerous than the liturgical rites of the Eastern autonomous particular Churches. Their number is now much reduced. In the aftermath of the Council of Trent, in 1568 and 1570 Pope Pius V suppressed the Breviaries and Missals that could not be shown to have an antiquity of at least two centuries (see Tridentine Mass and Roman Missal). Many local rites that remained legitimate even after this decree were abandoned voluntarily, especially in the 19th century. In the second half of the 20th century, most of the religious orders that had a distinct liturgical rite chose to adopt in its place the Roman Rite as 

In [61]:
# Show the decoded model input of every document in the batch and, when the
# input contains mask tokens, the decoded target as well.
for i, doc_ind in enumerate(batch_inds):
    toks_aug = docs_toks_aug_t[i]
    txt_aug = tkz.decode(toks_aug)
    print(f'{doc_ind:03d} {txt_aug}')
    has_mask = (toks_aug == tkz.mask_token_id).any()
    if not has_mask:
        continue
    # NOTE(review): the whole target tensor is decoded here rather than row `i`.
    # That is only equivalent if the targets are not batched per document --
    # confirm against EedWikiIterator.get_batch_tokens before using batch_size > 1.
    txt_tgt = tkz.decode(docs_toks_tgt_t)
    print(f'{doc_ind:03d} {txt_tgt}')

005 [CLS] common prayer, following the break from the roman church under the previous monarch henry viii. in the united states, under a pastoral provision in 1980, personal parishes were established that introduced adapted anglican traditions to [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] as an exception and on a case by case basis, the ordination of married former episcopal ministers as catholic priests. as personal parishes, these parishes were formerly part of the local roman catholic diocese, but accepted as members any former anglican who wished to make use of the provision. on 9 november 2009, pope benedict xvi established a worldwide provision for anglicans who joined the church. this process set up personal ordinariates for former anglicans and other persons entering the full communion of the catholic church. these ordinariates would be similar to dioceses, but encompassing entire regions or nations. parishes belongin

In [62]:
# Predict the masked span of the batch with beam search and print the decoded result.
# Alternative entry point: model.predict(chunk_toks=docs_toks_aug_t)
# Inference only: `model.eval()` does not turn off autograd, so disable gradient
# tracking explicitly to avoid building a computation graph.
with torch.no_grad():
    toks_pred = model.predict_beam(chunk_toks=docs_toks_aug_t)
txt_pred = tkz.decode(toks_pred)
print(txt_pred)

catholic church [SEP]
church of england [SEP]
the church of england [SEP]
catholic church of england [SEP]
church of england and the church of england [SEP]
catholic church [SEP]


## QnA prediction (SQuAD v2)

In [15]:
# Question input mode; forwarded to get_squadv2_batch when batches are built.
ques_inp = QnaQuesInp.Enc
# Drop SQuAD v2 records that have no answer.
exclude_empty_answers = True
df_sq = get_squadv2_df(exclude_empty_answers=exclude_empty_answers)
# One index per remaining record, in dataframe order.
sq_inds = np.arange(len(df_sq))

Reusing dataset squad_v2 (/home/misha/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)


  0%|          | 0/2 [00:00<?, ?it/s]

Remove empty answers from dataset squad_v2. Size: 142192 --> 92749


In [40]:
# Pick the records of batch number `i_batch`: half-open index range [i1, i2).
batch_size = 1
i_batch = 3
i1 = i_batch * batch_size
i2 = i1 + batch_size
batch_inds = sq_inds[i1:i2]
sq_batch = get_squadv2_batch(
    tkz=tkz,
    df_sq=df_sq,
    inds=batch_inds,
    inp_len=model_cfg.inp_len,
    device=device,
    ques_inp=ques_inp,
)
# gen_tensors returns the context tokens plus a 4-tuple of question tokens,
# answer tokens, answer attention masks and answer target masks.
ctx_toks_t, qna_tensors = sq_batch.gen_tensors()
q_toks_t, a_toks_t, a_att_masks_t, a_tgt_masks_t = qna_tensors
print(ctx_toks_t.shape, len(q_toks_t), len(a_toks_t))


torch.Size([1, 256]) 1 1


In [43]:
# Decode and print every context row of the batch (padding tokens included).
for ctx_toks in ctx_toks_t:
    print(tkz.decode(ctx_toks))

[CLS] context1. villa and carranza had different political goals causing villa to become an enemy of carranza. after carranza took control in 1914, villa and other revolutionaries who opposed him met at what was called the convention of aguascalientes. the convention deposed carranza in favor of eulalio gutierrez. in the winter of 1914 villa's and zapata's troops entered and occupied mexico city. villa was forced from the city in early 1915 and attacked the forces of gen. obregon at the battle of celaya and was badly defeated in the bloodiest battle of the revolution, with thousands dead. with the defeat of villa, carranza seized power. a short time later the united states recognized carranza as president of mexico. even though villa's forces were badly depleted by his loss at celaya, he continued his fight against the carranza government. finally, in 1920, obregon — who had defeated him at celaya — finally reached an agreement with villa end his rebellion. [PAD] [PAD] [PAD] [PAD] [PAD

In [47]:
# For every question of the batch: run beam-search prediction and print the
# question (Q), the reference answer (A) and the model's answer (M).
# Inference only: `model.eval()` does not turn off autograd, so disable gradient
# tracking explicitly to avoid building a computation graph.
with torch.no_grad():
    for i, (q_toks, a_toks) in enumerate(zip(q_toks_t, a_toks_t)):
        # Strip padding so that only real question tokens are passed to the model.
        q_toks = q_toks[q_toks != tkz.pad_token_id]
        # NOTE(review): the full context batch is passed for every question; this
        # pairs correctly only while the batch holds a single context -- confirm
        # before raising batch_size.
        # Alternative entry point: model.predict(chunk_toks=ctx_toks_t, plain_toks=q_toks)
        toks_pred = model.predict_beam(chunk_toks=ctx_toks_t, plain_toks=q_toks, temperature=1)
        q_txt, a_txt = tkz.decode(q_toks), tkz.decode(a_toks)
        txt_pred = tkz.decode(toks_pred)
        print(f'{i:02d}. Q: {q_txt}')
        print(f'{i:02d}. A: {a_txt}')
        print(f'{i:02d}. M: {txt_pred}')
        print('-' * 50)

Beams out: 5
napoleon iii [SEP]
napoleon ii [SEP]
napoleon i [SEP]
napoleon v [SEP]
napoleon iii of france [SEP]
00. Q: [CLS] question : villa became an enemy of whom?. answer :
00. A: carranza [SEP]
00. M: napoleon iii [SEP]
--------------------------------------------------
