In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports edited modules before each cell runs,
# so changes made to the imported project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [6]:
from dataclasses import dataclass
import io
import json
import os
from pathlib import Path
from pprint import pprint
import sys
from typing import Optional

if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
import numpy as np
import pandas as pd
import requests
import torch
from torch import nn
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer, AutoTokenizer
from transformers import BertGenerationConfig
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions

from mllm.model.embgen_bert import EncoderEmbDecoderModel, EncEmbExpansionType
from mllm.data.qna import get_hotpotqa
from mllm.train.embgen_bert import get_sq_batch, get_sq_df, QuesInp



# BERT Generator model inference
## Configs and paths

In [18]:
# Root folder with datasets and training artefacts. Path.home() is used instead of expanding
# '$HOME' by hand: when the variable is unset the expansion leaves a literal '$HOME' folder name.
DATA_PATH = Path.home() / 'data'

bert_model_name = 'bert-base-uncased'
random_seed = 111
# Input length in tokens; passed both to the model (enc_inp_len) and to get_sq_batch.
inp_len = 128
train_eed_bert_path = DATA_PATH / 'train_mllm_eed_bert_qna'

# Training run to load. Keep exactly one assignment active: the folder name also encodes
# the model options that get_params() parses out of it.
# eed_subdir = 'eedbert-20250316_195907-bert_base_uncased-d768-emp_f-qi_enc'
# eed_subdir = 'eedbert-20250317_100519-bert_base_uncased-d768-emp_f-qi_enc-chkpt_encdecbert_20250131_223521'
# eed_subdir = 'eedbert-20250317_223145-bert_base_uncased-d768-emp_f-qi_dec'
eed_subdir = 'eedbert-20250319_221739-bert_base_uncased-d768-emp_f-qi_dec-chkpt_encdecbert_20250131_223521'
# eed_subdir = 'eedbert-20250323_180203-bert_base_uncased-d768-emp_f-qi_enc-exp_mat_b-bt_6-chkpt_encdecbert_20250131_223521'

eed_train_path = train_eed_bert_path / eed_subdir
# Snapshot file inside the run folder that is loaded for inference.
eed_snapshot_fpath = eed_train_path / 'best.pth'

device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print(device)

cpu


In [19]:
@dataclass
class EedParams:
    """Model and batching options recovered from the name of a training run folder.

    Instances are produced by `get_params`; the notebook passes the fields on to the
    `EncoderEmbDecoderModel` constructor and to the batch construction code.
    """
    # Question input type, taken from the `qi_<value>` part of the run name.
    ques_inp: QuesInp
    # Encoder embedding expansion type, taken from the `exp_<type>` part of the run name.
    exp_type: EncEmbExpansionType
    # True when the `exp_` part carries the trailing `_b` marker.
    exp_bias: bool
    # Batch size derived from `enc_batch_size` by `get_params`.
    batch_size: int
    # Encoder input batch size, taken from the `bt_<int>` part of the run name (0 when absent).
    enc_batch_size: int

def get_params(subdir: str) -> EedParams:
    """Parse model options out of a training run folder name.

    The name is split on '-' and these parts are recognised (all others are ignored):

    * `qi_<value>` - question input type; `<value>` is passed to `QuesInp`. Mandatory.
    * `exp_<type>` or `exp_<type>_b` - encoder embedding expansion type; the `_b` suffix
      switches the expansion bias on. Defaults to `EncEmbExpansionType.Emb` without bias.
    * `bt_<int>` - encoder input batch size. Defaults to 0.

    Args:
        subdir: Run folder name, e.g. `'eedbert-...-qi_enc-exp_mat_b-bt_6'`.

    Returns:
        EedParams with the parsed values. `batch_size` equals `enc_batch_size` for
        `QuesInp.Dec` and `enc_batch_size - 1` otherwise, never going below 0 so that an
        unspecified size stays falsy for callers using `batch_size or <default>`.

    Raises:
        AssertionError: If no `qi_` part is present, or a three-component `exp_` part
            does not end with `b`.
        ValueError: If the `bt_` value is not an integer (enum lookups presumably raise
            the same for unknown values - verify against the enum definitions).
    """
    ques_inp, exp_type, exp_bias, enc_batch_size = None, EncEmbExpansionType.Emb, False, 0
    for part in subdir.split('-'):
        if part.startswith('qi_'):
            ques_inp = QuesInp(part[3:])
        elif part.startswith('exp_'):
            subparts = part.split('_')
            exp_type = EncEmbExpansionType(subparts[1])
            if len(subparts) == 3:
                assert subparts[-1] == 'b', \
                    f'"{part}" is expected to end with \'b\' when it has three \'_\'-separated components'
                exp_bias = True
        elif part.startswith('bt_'):
            enc_batch_size = int(part[3:])
    assert ques_inp is not None, f'Cannot find part `qi_QUESINP` where QUESINP is one of: {[qi.value for qi in QuesInp]}'

    if ques_inp == QuesInp.Dec:
        batch_size = enc_batch_size
    else:
        # One slot is subtracted for non-Dec runs; clamp at 0 because a missing `bt_` part
        # (enc_batch_size == 0) would otherwise produce a truthy -1.
        batch_size = max(enc_batch_size - 1, 0)
    return EedParams(
        ques_inp=ques_inp, exp_type=exp_type, exp_bias=exp_bias, batch_size=batch_size, enc_batch_size=enc_batch_size,
    )

# Decode the options of the selected run from its folder name and display the result.
eed_params = get_params(subdir=eed_subdir)
eed_params

EedParams(ques_inp=<QuesInp.Dec: 'dec'>, exp_type=<EncEmbExpansionType.Emb: 'emb'>, exp_bias=False, batch_size=0, enc_batch_size=0)

## Load models and dataset
### Model

In [20]:
tkz = BertTokenizer.from_pretrained(bert_model_name)
print(tkz)
# BERT's [CLS] token is used as BOS and its [SEP] token as EOS. The ids are read from the
# tokenizer (101 / 102 for bert-base-uncased) so they follow `bert_model_name`.
bos_token_id, eos_token_id = tkz.cls_token_id, tkz.sep_token_id
enc_model: BertGenerationEncoder = BertGenerationEncoder.from_pretrained(
    bert_model_name, bos_token_id=bos_token_id, eos_token_id=eos_token_id,
)
# The decoder additionally gets cross-attention layers to attend to the encoder output.
dec_model: BertGenerationDecoder = BertGenerationDecoder.from_pretrained(
    bert_model_name, add_cross_attention=True, is_decoder=True, bos_token_id=bos_token_id, eos_token_id=eos_token_id,
)
# Expansion options and encoder batch size come from the run name parsed into `eed_params`.
model = EncoderEmbDecoderModel(
    encoder=enc_model, decoder=dec_model, enc_emb_exp_type=eed_params.exp_type, enc_emb_exp_bias=eed_params.exp_bias,
    enc_inp_len=inp_len, enc_inp_batch_size=eed_params.enc_batch_size,
).to(device)

BertTokenizer(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}


You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
Some weights of BertGenerationDecoder were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossatte

In [21]:
# Inspect the encoder configuration; `name_or_path` holds the name the encoder was created from.
# BertGenerationConfig is imported in the notebook's top import cell.
cfg: BertGenerationConfig = enc_model.config
cfg.name_or_path

'bert-base-uncased'

In [22]:
print(f'Load {eed_snapshot_fpath}')
# NOTE(review): torch.load unpickles the file - only load snapshots that come from a trusted source.
checkpoint = torch.load(eed_snapshot_fpath, map_location=device)
# The snapshot stores the weights under the 'model' key; strict=True demands an exact key match.
model_state = checkpoint['model']
model.load_state_dict(model_state, strict=True)
# Switch to evaluation mode; the trailing semicolon keeps the module repr out of the cell output.
model.eval();

Load /home/misha/data/train_mllm_eed_bert_qna/eedbert-20250319_221739-bert_base_uncased-d768-emp_f-qi_dec-chkpt_encdecbert_20250131_223521/best.pth


### SQuAD v2 QnA dataset

In [23]:
np.random.seed(random_seed)
# Toggle to keep / drop the questions that have no answer.
# exclude_empty_answers = False
exclude_empty_answers = True
# Pass the flag itself: a hard-coded True here would silently ignore the toggle above.
df_sq = get_sq_df(exclude_empty_answers=exclude_empty_answers)

Reusing dataset squad_v2 (/home/misha/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)


  0%|          | 0/2 [00:00<?, ?it/s]

Remove empty answers from dataset squad_v2. Size: 142192 --> 92749


## Inference

In [24]:
# Batch size used when the run name does not fix one, and the index of the batch to inspect.
default_batch_size = 5
batch_ind = 1

# Only a positive value from the run parameters is usable; 0 (unspecified) or a negative
# value falls back to the default instead of producing an empty index range.
batch_size = eed_params.batch_size if eed_params.batch_size > 0 else default_batch_size
# Row indices of the `batch_ind`-th consecutive block of `batch_size` rows.
inds = np.arange(batch_size) + batch_size * batch_ind
batch = get_sq_batch(tkz=tkz, df_sq=df_sq, inds=inds, inp_len=inp_len, device=device, ques_inp=eed_params.ques_inp)
# Show the beginning of every context in the batch.
for ctx in batch.contexts:
    print(ctx[:400])

Contexts: [1 1 1 1 1]. (5, 128)
QAs: [24 20 19 33 18]. 114. 2750
Qs: [16 15 15 27 14]. 87. 1631
As: [7 4 3 5 3]. 22. 108
Context1. Traditionally a carnival feast was the last opportunity to eat well before the time of food shortage at the end of the winter during which one was limited to the minimum necessary. On what nowadays is called vastenavond (the days before fasting) all the remaining winter stores of lard, butter and meat which were left would be eaten, for it would soon start to rot and decay. The selected 
Context2. DNA replication is for the most part extremely accurate, however errors (mutations) do occur.:7.6 The error rate in eukaryotic cells can be as low as 10−8 per nucleotide per replication, whereas for some RNA viruses it can be as high as 10−3. This means that each generation, each human genome accumulates 1–2 new mutations. Small mutations can be caused by DNA replication and the aftermat
Context3. Like other American research universities, Northwestern was transfo

In [25]:
# List every question of the batch together with its reference answer.
for qa in batch.qas:
    q, a = qa
    print(f'Q: {q}. A: {a}')

Q: Context5. Question: MEG of the brain is an abbreviation of what?. A: magnetoencephalography
Q: Context1. Question: What was one limited to during the winter?. A: the minimum necessary
Q: Context4. Question: What were small panel mosaics known as?. A: emblemata
Q: Context3. Question: Between 1939 and 1949, how many military officers and personnel were trained on the Evanston and Chicago campuses?. A: nearly 50,000
Q: Context2. Question: What can small mutations be caused by?. A: DNA replication


In [26]:
ctxs_toks, other_toks = batch.gen_tensors()
# Positive token ids are attended to; id 0 is [PAD] in this tokenizer.
ctxs_mask = (ctxs_toks > 0).to(batch.device)
# Inference only: no_grad avoids building the autograd graph for the encoder pass.
with torch.no_grad():
    ctx_enc_out: BaseModelOutputWithPastAndCrossAttentions = model.encoder(input_ids=ctxs_toks, attention_mask=ctxs_mask)
# Keep the hidden state at position 0 of every context and add a leading axis,
# giving one embedding per context: (1, n_contexts, hidden_size).
ctx_emb = ctx_enc_out.last_hidden_state[:, 0].unsqueeze(0)
ctx_emb.shape

torch.Size([1, 5, 768])

In [39]:
# Index (within the batch) of the question/answer pair to run inference for.
qa_ind = 3

if batch.ques_inp == QuesInp.Enc:
    # The question goes through the encoder; its position-0 embedding is appended to the
    # context embeddings along the sequence axis.
    q_toks_l, a_toks_l, a_att_masks_l, a_tgt_masks_l = other_toks
    n_ans = len(a_toks_l)
    q_toks, a_toks, a_att_mask, a_tgt_mask = q_toks_l[qa_ind], a_toks_l[qa_ind], a_att_masks_l[qa_ind], a_tgt_masks_l[qa_ind]
    q_toks = q_toks.unsqueeze(0)
    q_mask = (q_toks > 0).to(batch.device)
    # Inference only: skip autograd bookkeeping for the encoder pass.
    with torch.no_grad():
        q_enc_out: BaseModelOutputWithPastAndCrossAttentions = model.encoder(input_ids=q_toks, attention_mask=q_mask)
    q_emb = q_enc_out.last_hidden_state[:, 0].unsqueeze(0)
    ctxq_emb = torch.concatenate([ctx_emb, q_emb], dim=1)
    print(f'ctx_emb: {ctx_emb.shape}. q_emb: {q_emb.shape}. ctxq_emb: {ctxq_emb.shape}')

elif batch.ques_inp == QuesInp.Dec:
    # The question is part of the decoder input, so only the token tensors are picked here.
    qa_toks_l, qa_att_masks_l, qa_tgt_masks_l = other_toks
    n_qas = len(qa_toks_l)
    qa_toks, qa_att_mask, qa_tgt_mask = qa_toks_l[qa_ind].unsqueeze(0), qa_att_masks_l[qa_ind], qa_tgt_masks_l[qa_ind]

else:
    # `batch.ques_inp` is the value tested above; a bare `ques_inp` is not defined in the notebook.
    raise ValueError(f'Unsupported Question input type: {batch.ques_inp}')


In [None]:
def predict(
        model: EncoderEmbDecoderModel, enc_emb: torch.Tensor, toks: torch.Tensor, max_len: int = 10,
        mask_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
) -> list[int]:
    """Greedily generate tokens with the model's decoder, one token per step.

    The last element of `toks` is treated as the slot to fill (callers put the mask token
    there). On every step the decoder sees the current sequence, the most likely token at
    the last position replaces the slot, and a new mask slot is appended.

    Args:
        model: Model whose `decoder` is called with `input_ids`, `attention_mask`,
            `encoder_hidden_states` and `use_cache=False`, and returns an object with `logits`.
        enc_emb: Encoder embeddings passed to the decoder as `encoder_hidden_states`.
        toks: 1-D tensor with the prompt token ids; must not be empty. It is not modified.
        max_len: Maximum number of tokens to generate.
        mask_token_id: Id of the slot token. Defaults to `tkz.mask_token_id` of the
            notebook-level tokenizer.
        eos_token_id: Id of the token that stops generation. Defaults to
            `tkz.sep_token_id` of the notebook-level tokenizer.

    Returns:
        The prompt tokens followed by the generated ones, without the stop token and
        without a trailing mask slot.

    Raises:
        ValueError: If `toks` is empty.
    """
    if toks.numel() == 0:
        raise ValueError('`toks` must contain at least one token: the slot to predict')
    if mask_token_id is None:
        mask_token_id = tkz.mask_token_id
    if eos_token_id is None:
        eos_token_id = tkz.sep_token_id

    toks_cur = toks.tolist()
    # Inference only: no autograd graph is needed for generation.
    with torch.no_grad():
        for _ in range(max_len):
            inp_ids = torch.tensor(toks_cur, dtype=toks.dtype, device=toks.device).unsqueeze(0)
            # Every position is attended to, the mask slot included.
            att_mask = torch.ones_like(inp_ids)
            dec_out: CausalLMOutputWithCrossAttentions = model.decoder(
                input_ids=inp_ids, attention_mask=att_mask, encoder_hidden_states=enc_emb, use_cache=False,
            )
            # Softmax is monotonic, so the argmax of the logits is the most probable token.
            tok = dec_out.logits[0, -1].argmax(dim=-1).item()
            if tok == eos_token_id:
                break
            toks_cur[-1] = tok
            toks_cur.append(mask_token_id)
    return toks_cur if toks_cur[-1] != mask_token_id else toks_cur[:-1]

q, a = batch.qas[qa_ind]
print(f'{q} - {a}')
if eed_params.ques_inp == QuesInp.Enc:
    # The question is already part of the encoder embeddings: the decoder starts from a lone mask slot.
    q_toks = torch.tensor([tkz.mask_token_id], dtype=torch.int64, device=device)
    toks_out = predict(model, ctxq_emb, q_toks, max_len=20)
else:
    q_toks = qa_toks.squeeze().tolist()
    print(q_toks)
    # Keep the prompt up to and including the first [SEP] and put the mask slot right after it.
    # Appending (rather than overwriting the next token) also works when [SEP] is the last token.
    if tkz.sep_token_id in q_toks:
        sep_pos = q_toks.index(tkz.sep_token_id)
        q_toks = q_toks[:sep_pos + 1] + [tkz.mask_token_id]
    q_toks = torch.tensor(q_toks, dtype=torch.int64, device=device)
    print(q_toks)
    toks_out = predict(model, ctx_emb, q_toks, max_len=20)
print(tkz.decode(toks_out))

Context3. Question: Between 1939 and 1949, how many military officers and personnel were trained on the Evanston and Chicago campuses? - nearly 50,000
[101, 6123, 2509, 1012, 3160, 1024, 2090, 3912, 1998, 4085, 1010, 2129, 2116, 2510, 3738, 1998, 5073, 2020, 4738, 2006, 1996, 6473, 2669, 1998, 3190, 13696, 1029, 102, 3053, 2753, 1010, 2199, 102]
tensor([  101,  6123,  2509,  1012,  3160,  1024,  2090,  3912,  1998,  4085,
         1010,  2129,  2116,  2510,  3738,  1998,  5073,  2020,  4738,  2006,
         1996,  6473,  2669,  1998,  3190, 13696,  1029,   102,   103])
[CLS] context3. question : between 1939 and 1949, how many military officers and personnel were trained on the evanston and chicago campuses? [SEP] 30, 000


## Tensor ops

In [138]:
# Seed NumPy's global RNG so the random draws in the cells below are reproducible.
np.random.seed(20)

In [144]:
# n x n boolean matrix that is True on and below the main diagonal.
n = 4
t = np.tri(n, dtype=bool)
t

array([[ True, False, False, False],
       [ True,  True, False, False],
       [ True,  True,  True, False],
       [ True,  True,  True,  True]])

In [145]:
# Draw n stand-in token ids from [0, 100).
a = np.random.randint(0, 100, size=n)
a

array([28, 90,  9, 20])

In [150]:
# Stack n copies of the vector as rows of an (n, len(a)) matrix.
b = np.repeat(a[np.newaxis, :], repeats=n, axis=0)
b

array([[28, 90,  9, 20],
       [28, 90,  9, 20],
       [28, 90,  9, 20],
       [28, 90,  9, 20]])

In [152]:
mask_token_id = 123
# Row i keeps the first i values (strictly below the diagonal); everything else becomes 0.
bb = np.tril(b, -1)
# The diagonal marks, for every row, the position to be filled with the mask token.
mask = np.eye(n, dtype=bool)
bb[mask] = mask_token_id
bb

array([[123,   0,   0,   0],
       [ 28, 123,   0,   0],
       [ 28,  90, 123,   0],
       [ 28,  90,   9, 123]])

In [153]:
# Same vector as a torch tensor (torch.tensor copies the NumPy data).
at = torch.tensor(a)
at

tensor([28, 90,  9, 20])

In [156]:
# Repeat the vector as n rows, then zero everything above the main diagonal.
atn = torch.tril(at.repeat(n, 1))
atn

tensor([[28,  0,  0,  0],
        [28, 90,  0,  0],
        [28, 90,  9,  0],
        [28, 90,  9, 20]])

In [157]:
# Torch version of the boolean diagonal mask built with NumPy above.
maskt = torch.tensor(mask)
maskt

tensor([[ True, False, False, False],
        [False,  True, False, False],
        [False, False,  True, False],
        [False, False, False,  True]])

In [158]:
# Overwrite the diagonal in place with the tokenizer's [MASK] id (103 for this tokenizer),
# leaving in each row the preceding tokens followed by the mask slot.
atn[maskt] = tkz.mask_token_id
atn

tensor([[103,   0,   0,   0],
        [ 28, 103,   0,   0],
        [ 28,  90, 103,   0],
        [ 28,  90,   9, 103]])