In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from dataclasses import dataclass
import io
import json
import os
from pathlib import Path
from pprint import pprint
import requests
import sys
from typing import Optional

if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
import numpy as np
import pandas as pd
import torch
from torch import nn
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer, AutoTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions

from mllm.model.embgen_bert import EncoderEmbDecoderModel
from mllm.data.qna import get_hotpotqa
from mllm.train.embgen_bert import get_sq_batch



# BERT Generator model inference
## Configs and paths

In [3]:
DATA_PATH = Path(os.path.expandvars('$HOME')) / 'data'

bert_model_name = 'bert-base-uncased'
random_seed = 111
inp_len = 128
train_eed_bert_path = DATA_PATH / 'train_mllm_eed_bert_qna'
# eed_subdir = 'eedbert-20250303_220645-bert_base_uncased-d768'
eed_subdir = 'eedbert-20250306_220643-bert_base_uncased-d768'

eed_train_path = train_eed_bert_path / eed_subdir
eed_snapshot_fpath = eed_train_path / 'best.pth'

device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print(device)

cpu


## Load models and dataset
### Model

In [4]:
tkz = BertTokenizer.from_pretrained(bert_model_name)
print(tkz)
enc_model: BertGenerationEncoder = BertGenerationEncoder.from_pretrained(bert_model_name, bos_token_id=101, eos_token_id=102)
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
dec_model: BertGenerationDecoder = BertGenerationDecoder.from_pretrained(
    bert_model_name, add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102
)
model = EncoderEmbDecoderModel(encoder=enc_model, decoder=dec_model).to(device)

BertTokenizer(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}


You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
Some weights of BertGenerationDecoder were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossatte

In [5]:
print(f'Load {eed_snapshot_fpath}')
checkpoint = torch.load(eed_snapshot_fpath)
model.load_state_dict(checkpoint['model'], strict=True)

Load /home/misha/data/train_mllm_eed_bert_qna/eedbert-20250306_220643-bert_base_uncased-d768/best.pth


<All keys matched successfully>

### Squad v2 Qna dataset

In [6]:
np.random.seed(random_seed)
ds_sq = load_dataset('squad_v2')
df_sq = pd.concat([ds_sq['train'].to_pandas(), ds_sq['validation'].to_pandas()], axis=0)
n_total = len(df_sq)
df_sq = df_sq.sample(n_total)
val_ratio = 0.05
n_val = int(n_total * val_ratio)
n_train = n_total - n_val
df_sq_t, df_sq_v = df_sq.iloc[:n_train], df_sq.iloc[n_train:]


Reusing dataset squad_v2 (/home/misha/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)


  0%|          | 0/2 [00:00<?, ?it/s]

## Inference

In [7]:
batch_size = 5
inds1 = np.arange(batch_size)
inds1 += batch_size * 1
batch1 = get_sq_batch(tkz=tkz, df_sq=df_sq, inds=inds1, inp_len=inp_len, device=device)
for ctx in batch1.contexts:
    print(ctx[:200])

Context1. Traditionally a carnival feast was the last opportunity to eat well before the time of food shortage at the end of the winter during which one was limited to the minimum necessary. On what n
Context2. DNA replication is for the most part extremely accurate, however errors (mutations) do occur.:7.6 The error rate in eukaryotic cells can be as low as 10−8 per nucleotide per replication, whe
Context3. Like other American research universities, Northwestern was transformed by World War II. Franklyn B. Snyder led the university from 1939 to 1949, when nearly 50,000 military officers and per
Context4. There were two main techniques in Greco-Roman mosaic: opus vermiculatum used tiny tesserae, typically cubes of 4 millimeters or less, and was produced in workshops in relatively small panels
Context5. As a side effect of the electrochemical processes used by neurons for signaling, brain tissue generates electric fields when it is active. When large numbers of neurons show synchronized

In [8]:
for q, a in batch1.qas:
    print(f'Q: {q}. A: {a}')

Q: Context1. What was one limited to during the winter?. A: the minimum necessary
Q: Context4. What were small panel mosaics known as?. A: emblemata
Q: Context5. MEG of the brain is an abbreviation of what?. A: magnetoencephalography
Q: Context2. What can small mutations be caused by?. A: DNA replication
Q: Context3. Between 1939 and 1949, how many military officers and personnel were trained on the Evanston and Chicago campuses?. A: nearly 50,000


In [9]:
qas_toks, qa_att_masks, qa_tgt_masks, ctxs_toks = batch1.gen_tensors()
ctxs_mask = (ctxs_toks > 0).to(device)
enc_out: BaseModelOutputWithPastAndCrossAttentions = model.encoder(input_ids=ctxs_toks, attention_mask=ctxs_mask)
enc_emb = enc_out.last_hidden_state[:, 0].unsqueeze(0)

In [13]:
qa_ind = 0
qa_toks, qa_att_mask, qa_tgt_mask = qas_toks[qa_ind].unsqueeze(0), qa_att_masks[qa_ind], qa_tgt_masks[qa_ind]
qa_toks = qa_toks.repeat(len(qa_att_mask), 1)
qa_toks_inp = qa_toks * qa_att_mask
dec_out: CausalLMOutputWithCrossAttentions = model.decoder(
    input_ids=qa_toks_inp, attention_mask=qa_att_mask, encoder_hidden_states=enc_emb
)

n = 0
for i in range(qa_toks.shape[1]):
    if qa_att_mask[0, i] == 0:
        n = i
        break
q_toks = qa_toks[0, :n + 1].clone()
q_toks[-1] = 0

# q_toks = qa_toks[0, :n].clone()


In [14]:
def predict(model: EncoderEmbDecoderModel, enc_emb: torch.Tensor, toks: torch.Tensor, max_len: int = 10) -> list[int]:
    i, toks_cur, toks_out = 0, toks.tolist(), []
    inp_ids = toks.unsqueeze(0)
    while i < max_len:
        att_mask = inp_ids > 0
        dec_out: CausalLMOutputWithCrossAttentions = model.decoder(
            input_ids=inp_ids, attention_mask=att_mask, encoder_hidden_states=enc_emb,
        )
        print(dec_out.logits.shape)
        probs_pred = torch.softmax(dec_out.logits[0, -1], dim=-1)
        print(probs_pred.shape)
        tok_out = torch.argmax(probs_pred, dim=-1)
        print(tok_out.item())
        tok = tok_out.item()
        if tok == 102:
            break
        toks_cur[-1] = tok
        toks_cur.append(0)
        toks_out.append(tok)
        inp_ids = torch.tensor(toks_cur, dtype=toks.dtype, device=toks.device).unsqueeze(0)
        i =+ 1
    return toks_out

print(tkz.decode(q_toks.tolist()))
toks_out = predict(model, enc_emb, q_toks)
print(tkz.decode(toks_out))

[CLS] context1. what was one limited to during the winter? [SEP] [PAD]
torch.Size([1, 15, 30522])
torch.Size([30522])
1011
torch.Size([1, 16, 30522])
torch.Size([30522])
102
-


In [46]:
type(q_toks.tolist())

list