In [1]:
# Enable IPython's autoreload extension in mode 2, so edited project modules
# are re-imported before each cell runs and no kernel restart is needed.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional, cast
if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
from datasets.arrow_dataset import Dataset
import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer, AutoTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import ENCMIX_BERT_MODEL_CFG_FNAME
from mllm.train.utils import EedWikiIterator, QnaQuesInp, get_squadv2_df, get_squadv2_batch
from mllm.model.encmix import EncmixBert
from mllm.config.model import EncmixBertCfg, EncmixTrainDsType




# EncmixBert inference
## Config

In [17]:
# Root directory for datasets and training artifacts: <home>/data.
# Path.home() resolves the home directory even when $HOME is not set (e.g. on
# Windows); os.path.expandvars('$HOME') would leave the literal text '$HOME'
# in the path in that case and every later lookup would silently miss.
DATA_PATH = Path.home() / 'data'
WIKI_DS_NAME = '20200501.en'

# Training run to load. Exactly one `encdec_subdir` assignment should be active;
# the commented-out lines are alternative runs kept for quick switching.
TRAIN_ENCDEC_BERT_PATH = DATA_PATH / 'train_mllm_encmix_bert'
encdec_subdir = 'encmixbert-20250416_225131-bert-base-uncased-d768-inp256-oemb_inp-tte_t-ds_msk'
# encdec_subdir = 'encmixbert-20250410_161455-bert-base-uncased-d768-inp128-oemb_new-ds_qna'
# encdec_subdir = 'encmixbert-20250411_100815-bert-base-uncased-d768-inp256-oemb_new-ds_qna'
# encdec_subdir = 'encmixbert-20250411_231241-bert-large-uncased-d1024-inp256-oemb_new-ds_qna'

# Files inside the selected run directory: weights snapshot and model config.
encmix_train_path = TRAIN_ENCDEC_BERT_PATH / encdec_subdir
encmix_snapshot_fpath = encmix_train_path / 'best.pth'
encmix_model_cfg_fpath = encmix_train_path / ENCMIX_BERT_MODEL_CFG_FNAME

# Device used for the checkpoint, the model and all batches below.
device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print(device)

# Number of documents / questions taken per batch in the sections below.
batch_size = 1

cpu


In [18]:
# Infer the training dataset type from the run directory name. Every
# dash-separated component that starts with `ds_` has its remainder passed to
# EncmixTrainDsType. If several components match, the last one wins; if none
# does, the type falls back to Msk.
ds_types_found = [
    EncmixTrainDsType(part[3:])
    for part in encdec_subdir.split('-')
    if part.startswith('ds_')
]
ds_type = ds_types_found[-1] if ds_types_found else EncmixTrainDsType.Msk
print(ds_type)

EncmixTrainDsType.Msk


## Load model

In [19]:
# Read the model configuration stored next to the snapshot and build the
# tokenizer it names. The cast only narrows the static type for type checkers.
model_cfg = parse_yaml_file_as(EncmixBertCfg, encmix_model_cfg_fpath)
tkz = cast(
    PreTrainedTokenizer,
    AutoTokenizer.from_pretrained(model_cfg.tokenizer_name),
)
for loaded_obj in (model_cfg, tkz):
    print(loaded_obj)

inp_len=256 d_model=768 pretrained_model_name='bert-base-uncased' tokenizer_name='bert-base-uncased' out_embs_type=<EncmixOutEmbsType.Inp: 'inp'> token_types_for_embs=True
BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_wo

In [20]:
# Set to False to tolerate missing/unexpected keys when loading the weights.
strict = True
# strict = False

# NOTE(review): torch.load unpickles the file, so only open snapshots that come
# from a trusted source. Passing weights_only=True may be possible - confirm
# what else the checkpoint stores besides the 'model' entry.
chkpt = torch.load(encmix_snapshot_fpath, map_location=device)
model = EncmixBert(cfg=model_cfg, tkz=tkz, device=device)
model.load_state_dict(chkpt['model'], strict=strict)
# The checkpoint dict is not needed once the weights are copied into the model.
del chkpt
# print(model)
# Switch to evaluation mode; the trailing semicolon keeps the module repr out
# of the cell output.
model.eval();

BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.42.4",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}



## Wikipedia mask prediction

In [21]:
# Load the Wikipedia dump into `ds` (its 'train' split).
# The dataset name comes from the WIKI_DS_NAME constant of the Config section
# instead of a duplicated literal, so there is a single place to change it.
wiki_ds_name = WIKI_DS_NAME
print(f'Loading Wikipedia dataset: {wiki_ds_name}')
wiki_ds_subdir = 'wikipedia'
dss = load_dataset(wiki_ds_subdir, wiki_ds_name, beam_runner='DirectRunner', cache_dir=str(DATA_PATH))
# The cast only narrows the static type; it performs no conversion at runtime.
ds = cast(Dataset, dss['train'])
n_docs = len(ds)
print(f'Wikipedia {wiki_ds_name} docs: {n_docs}')

Loading Wikipedia dataset: 20200501.en


Reusing dataset wikipedia (/home/misha/data/wikipedia/20200501.en/1.0.0/009f923d9b6dd00c00c8cdc7f408f2b47f45dd4f5fb7982a21f9448f4afbe475)


  0%|          | 0/1 [00:00<?, ?it/s]

Wikipedia 20200501.en docs: 6078422


In [22]:
# One index per Wikipedia document, in dataset order.
doc_inds = np.arange(len(ds))

# Iterator used below to turn document indices into token batches.
ds_it = EedWikiIterator(
    ds=ds,
    inds=doc_inds,
    inp_len=model_cfg.inp_len,
    tkz=tkz,
    docs_batch_size=batch_size,
    device=device,
    preserve_edge_tokens=True,
)

In [37]:
# Take the i_batch-th consecutive window of `batch_size` document indices.
i_batch = 2
i1 = i_batch * batch_size
i2 = i1 + batch_size
batch_inds = doc_inds[i1:i2]

# Augmented input tokens and their targets for the selected documents.
docs_toks_aug_t, docs_toks_tgt_t = ds_it.get_batch_tokens(batch_inds)

In [34]:
# Show the raw text of every document in the batch on a single line each,
# with real newlines rendered as the two characters `\n`.
for i, doc_ind in enumerate(batch_inds):
    doc = ds[doc_ind.item()]
    title = doc['title']
    text = doc['text'].replace('\n', '\\n')
    print(f'{i:03d} {text}')

000 The St. Mary's Church is a church owned by the Church of Denmark in Sønderborg, Denmark and the church of the parish with the same name. Thanks to its location on a hill, the church building is very iconic for the city.\n\nHistory \nIn the Middle Ages there was a leper colony on a hill just outside the city. It was named after Saint George and around 1300 the chapel of this leper colony stood in the place of the present St. Mary's Church. After the old parish church of the city, the St. Nicholas Church, was demolished around 1530, the Saint-George chapel became the new main church. Towards the end of the 16th century, John II, Duke of Schleswig-Holstein-Sonderburg commissioned the enlargement of the building in order to make it suitable for the function of the parish church of his city.\n\nThe current St. Mary's Church \nIn 1595 a start was made on the partial demolition of the old church and the construction of the new church. Only parts of the old medieval church remained. From t

In [38]:
# Decode the augmented token row of every document in the batch; when the row
# contains at least one [MASK] token, print the decoded target as well.
for i, doc_ind in enumerate(batch_inds):
    toks_aug = docs_toks_aug_t[i]
    txt_aug = tkz.decode(toks_aug)
    print(f'{doc_ind:03d} {txt_aug}')
    if (toks_aug == tkz.mask_token_id).any():
        # NOTE(review): the target tensor is decoded as a whole, not per row i.
        # Presumably fine for batch_size == 1 - confirm its shape before
        # using a larger batch.
        txt_tgt = tkz.decode(docs_toks_tgt_t)
        print(f'{doc_ind:03d} {txt_tgt}')

002 [CLS], the church building is very iconic for the city. history in the middle ages there was a leper colony on a hill just outside the city. it was named after saint george and around 1300 the chapel of this leper colony stood in the place of the present st. mary'[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]. nicholas church, was demolished around 1530, the saint - george chapel became the new main church. towards the end of the 16th century, john ii, duke of schleswig - holstein - sonderburg commissioned the enlargement of the building in order to make it suitable for the function of the parish church of his city. the current st. mary's church in 1595 a start was made on the partial demolition of the old church and the construction of the new church. only parts of the old medieval church remained. from the medieval church, a medieval wooden wall cupboard dating from about 1400 remained. the solemn inauguration of the new parish c

In [39]:
# Run the model on the augmented (masked) chunk and decode its prediction.
# Gradient tracking is switched off around the call: nothing here calls
# backward(), so recording an autograd graph would only cost memory and time.
# This is harmless if predict() already disables gradients itself.
with torch.no_grad():
    toks_pred = model.predict(chunk_toks=docs_toks_aug_t)
txt_pred = tkz.decode(toks_pred)
print(txt_pred)

s church in the town of st. john, [SEP]


## Qna prediction

In [53]:
# How the question is fed to the model when building SQuAD v2 batches.
ques_inp = QnaQuesInp.Enc

# Load SQuAD v2 as a dataframe; the flag is forwarded to the loader to drop
# entries with empty answers.
exclude_empty_answers = True
df_sq = get_squadv2_df(exclude_empty_answers=exclude_empty_answers)

# One index per dataframe row, in order.
sq_inds = np.arange(len(df_sq))

Reusing dataset squad_v2 (/home/misha/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)


  0%|          | 0/2 [00:00<?, ?it/s]

Remove empty answers from dataset squad_v2. Size: 142192 --> 92749


In [67]:
# Take the i_batch-th consecutive window of `batch_size` SQuAD rows.
batch_size = 1
i_batch = 3
i1 = i_batch * batch_size
i2 = i1 + batch_size
batch_inds = sq_inds[i1:i2]

# Build the batch and split it into context tokens plus the question/answer
# tensors (questions, answers, answer attention masks, answer target masks).
sq_batch = get_squadv2_batch(
    tkz=tkz,
    df_sq=df_sq,
    inds=batch_inds,
    inp_len=model_cfg.inp_len,
    device=device,
    ques_inp=ques_inp,
)
ctx_toks_t, qa_tensors = sq_batch.gen_tensors()
q_toks_t, a_toks_t, a_att_masks_t, a_tgt_masks_t = qa_tensors
print(ctx_toks_t.shape, len(q_toks_t), len(a_toks_t))


torch.Size([1, 256]) 1 1


In [68]:
# Decode and print every context row of the batch.
for ctx_toks in ctx_toks_t:
    print(tkz.decode(ctx_toks))

[CLS] context1. further studies, e. g. jerome ravetz 1971 scientific knowledge and its social problems referred to the role of the scientific community, as a social construct, in accepting or rejecting ( objective ) scientific knowledge. the science wars of the 1990 were about the influence of especially french philosophers, which denied the objectivity of science in general or seemed to do so. they described as well differences between the idealized model of a pure science and the actual scientific practice ; while scientism, a revival of the positivism approach, saw in precise measurement and rigorous calculation the basis for finally settling enduring metaphysical and moral controversies. however, more recently some of the leading critical theorists have recognized that their postmodern deconstructions have at times been counter - productive, and are providing intellectual ammunition for reactionary interests. bruno latour noted that " dangerous extremists are using the very same ar

In [69]:
# For every question in the batch print the question (Q), the reference
# answer (A) and the model's prediction (M).
for i in range(len(q_toks_t)):
    q_toks, a_toks = q_toks_t[i], a_toks_t[i]
    # Drop padding tokens so only the real question tokens reach the model.
    q_toks = q_toks[q_toks != tkz.pad_token_id]
    # Gradient tracking is switched off around the call: nothing here calls
    # backward(), so recording an autograd graph would only cost memory and
    # time. This is harmless if predict() already disables gradients itself.
    with torch.no_grad():
        toks_pred = model.predict(chunk_toks=ctx_toks_t, plain_toks=q_toks)
    q_txt, a_txt = tkz.decode(q_toks), tkz.decode(a_toks)
    print(f'{i:02d}. Q: {q_txt}')
    print(f'{i:02d}. A: {a_txt}')
    txt_pred = tkz.decode(toks_pred)
    print(f'{i:02d}. M: {txt_pred}')
    print('-' * 50)

00. Q: [CLS] question : what did scientific knowledge and its social problems describe the scientific community as?. answer :
00. A: a social construct [SEP]
00. M: social and social [SEP]
--------------------------------------------------
