In [25]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [26]:
from dataclasses import dataclass
import io
import json
import os
from pathlib import Path
from pprint import pprint
import requests
import sys
from typing import Optional

if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
import numpy as np
import pandas as pd
import torch
from torch import nn
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer, AutoTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions

from mllm.model.embgen_bert import EncoderEmbDecoderModel
from mllm.data.qna import get_hotpotqa
from mllm.train.embgen_bert import get_sq_batch, get_sq_df, QuesInp



# BERT Generator model inference
## Configs and paths

In [110]:
DATA_PATH = Path(os.path.expandvars('$HOME')) / 'data'

bert_model_name = 'bert-base-uncased'
random_seed = 111
inp_len = 128
train_eed_bert_path = DATA_PATH / 'train_mllm_eed_bert_qna'
eed_subdir = 'eedbert-20250315_140952-bert_base_uncased-d768-emp_f-qi_enc-chkpt_encdecbert_20250131_223521'
# eed_subdir = 'eedbert-20250315_150043-bert_base_uncased-d768-emp_f-qi_dec'

eed_train_path = train_eed_bert_path / eed_subdir
eed_snapshot_fpath = eed_train_path / 'best.pth'

device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print(device)

cpu


In [111]:
def get_ques_inp(subdir: str) -> QuesInp:
    parts = subdir.split('-')
    for part in parts:
        if part.startswith('qi_'):
            return QuesInp(part[3:])
    raise Exception(f'Cannot find part `qi_QUESINP` where QUESINP is one of: {[qi.value for qi in QuesInp]}')

ques_inp = get_ques_inp(eed_subdir)
ques_inp

<QuesInp.Enc: 'enc'>

## Load models and dataset
### Model

In [112]:
tkz = BertTokenizer.from_pretrained(bert_model_name)
print(tkz)
enc_model: BertGenerationEncoder = BertGenerationEncoder.from_pretrained(bert_model_name, bos_token_id=101, eos_token_id=102)
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
dec_model: BertGenerationDecoder = BertGenerationDecoder.from_pretrained(
    bert_model_name, add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102
)
model = EncoderEmbDecoderModel(encoder=enc_model, decoder=dec_model).to(device)

BertTokenizer(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}


You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
Some weights of BertGenerationDecoder were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossatte

In [113]:
from transformers import BertGenerationConfig
cfg: BertGenerationConfig = enc_model.config
cfg.name_or_path

'bert-base-uncased'

In [114]:
print(f'Load {eed_snapshot_fpath}')
checkpoint = torch.load(eed_snapshot_fpath, map_location=device)
model.load_state_dict(checkpoint['model'], strict=True)
model.eval()
None

Load /home/misha/data/train_mllm_eed_bert_qna/eedbert-20250315_140952-bert_base_uncased-d768-emp_f-qi_enc-chkpt_encdecbert_20250131_223521/best.pth


### Squad v2 Qna dataset

In [115]:
np.random.seed(random_seed)
# exclude_empty_answers = False
exclude_empty_answers = True
df_sq = get_sq_df(exclude_empty_answers=True)

Reusing dataset squad_v2 (/home/misha/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)


  0%|          | 0/2 [00:00<?, ?it/s]

Remove empty answers from dataset squad_v2. Size: 142192 --> 92749


## Inference

In [116]:
batch_size = 5
inds = np.arange(batch_size)
# inds += batch_size * 2
batch = get_sq_batch(tkz=tkz, df_sq=df_sq, inds=inds, inp_len=inp_len, device=device, ques_inp=ques_inp)
for ctx in batch.contexts:
    print(ctx[:400])

Contexts: [1 1 1 1 1]. (5, 128)
QAs: [30 23 24 20 22]. 119. 2889
Qs: [128 128 128 128 128]. 640. 81920
As: [9 4 2 2 2]. 19. 109
Context1. Armenia presently maintains good relations with almost every country in the world, with two major exceptions being its immediate neighbours, Turkey and Azerbaijan. Tensions were running high between Armenians and Azerbaijanis during the final years of the Soviet Union. The Nagorno-Karabakh War dominated the region's politics throughout the 1990s. The border between the two rival countrie
Context2. Every dollar ($1) that is spent on pesticides for crops yields four dollars ($4) in crops saved. This means based that, on the amount of money spent per year on pesticides, $10 billion, there is an additional $40 billion savings in crop that would be lost due to damage by insects and weeds. In general, farmers benefit from having an increase in crop yield and from being able to grow a va
Context3. Participation in the Premier League by some Scottish or Iris

In [117]:
for q, a in batch.qas:
    print(f'Q: {q}. A: {a}')

Q: Context2. Question: How is the health of the general publis affected by pesticides?. A: control of insect-borne diseases and illnesses
Q: Context5. Question: How much was the US government ordered to pay in damages?. A: $100 million
Q: Context4. Question: What year did the "discoveries are property" concept appear in French law?. A: 1791
Q: Context1. Question: Is the border between Armenia and Azerbaijan open or closed?. A: closed
Q: Context3. Question: In which year did a Premier League team consider relocating to Ireland?. A: 1998


In [118]:
ctxs_toks, other_toks = batch.gen_tensors()
ctxs_mask = (ctxs_toks > 0).to(batch.device)
ctx_enc_out: BaseModelOutputWithPastAndCrossAttentions = model.encoder(input_ids=ctxs_toks, attention_mask=ctxs_mask)
ctx_emb = ctx_enc_out.last_hidden_state[:, 0].unsqueeze(0)
ctx_emb.shape

torch.Size([1, 5, 768])

In [119]:
qa_ind = 1

if batch.ques_inp == QuesInp.Enc:
    q_toks_l, a_toks_l, a_att_masks_l, a_tgt_masks_l = other_toks
    n_ans = len(a_toks_l)
    q_toks, a_toks, a_att_mask, a_tgt_mask = q_toks_l[qa_ind], a_toks_l[qa_ind], a_att_masks_l[qa_ind], a_tgt_masks_l[qa_ind]
    q_toks = q_toks.unsqueeze(0)
    q_mask = (q_toks > 0).to(batch.device)
    q_enc_out: BaseModelOutputWithPastAndCrossAttentions = model.encoder(input_ids=q_toks, attention_mask=q_mask)
    q_emb = q_enc_out.last_hidden_state[:, 0].unsqueeze(0)
    ctxq_emb = torch.concatenate([ctx_emb, q_emb], dim=1)
    print(f'ctx_emb: {ctx_emb.shape}. q_emb: {q_emb.shape}. ctxq_emb: {ctxq_emb.shape}')
    a_toks = a_toks.repeat(len(a_att_mask), 1)
    # a_toks_inp = a_toks * a_att_mask
    a_toks_inp = a_toks
    a_dec_out: CausalLMOutputWithCrossAttentions = model.decoder(
        input_ids=a_toks_inp, attention_mask=a_att_mask, encoder_hidden_states=ctxq_emb,
    )

elif batch.ques_inp == QuesInp.Dec:
    qa_toks_l, qa_att_masks_l, qa_tgt_masks_l = other_toks
    n_qas = len(qa_toks_l)
    qa_toks, qa_att_mask, qa_tgt_mask = qa_toks_l[qa_ind].unsqueeze(0), qa_att_masks_l[qa_ind], qa_tgt_masks_l[qa_ind]
    qa_toks = qa_toks.repeat(len(qa_att_mask), 1)
    qa_toks_inp = qa_toks * qa_att_mask
    dec_out: CausalLMOutputWithCrossAttentions = model.decoder(
        input_ids=qa_toks_inp, attention_mask=qa_att_mask, encoder_hidden_states=ctx_emb
    )
    n = 0
    for i in range(qa_toks.shape[1]):
        if qa_att_mask[0, i] == 0:
            n = i
            break
    q_toks = qa_toks[0, :n + 1].clone()
    q_toks[-1] = 0

else:
    raise Exception(f'Unsupported Question input type: {ques_inp}')

print(q_toks.shape)

ctx_emb: torch.Size([1, 5, 768]). q_emb: torch.Size([1, 1, 768]). ctxq_emb: torch.Size([1, 6, 768])
torch.Size([1, 128])
torch.Size([1, 128])


In [127]:
def predict(model: EncoderEmbDecoderModel, enc_emb: torch.Tensor, toks: torch.Tensor, max_len: int = 10) -> list[int]:
    i, toks_cur, toks_out = 0, toks.tolist(), []
    inp_ids = toks.unsqueeze(0)
    while i < max_len:
        # att_mask = (inp_ids > 0).to(torch.int32)
        att_mask = torch.ones_like(inp_ids)
        dec_out: CausalLMOutputWithCrossAttentions = model.decoder(
            input_ids=inp_ids, attention_mask=att_mask, encoder_hidden_states=enc_emb, use_cache=False,
        )
        print(dec_out.logits.shape)
        probs_pred = torch.softmax(dec_out.logits[0, -1], dim=-1)
        print(probs_pred.shape)
        tok_out = torch.argmax(probs_pred, dim=-1)
        print(tok_out.item())
        tok = tok_out.item()
        if tok == 102:
            break
        toks_cur[-1] = tok
        toks_cur.append(0)
        toks_out.append(tok)
        inp_ids = torch.tensor(toks_cur, dtype=toks.dtype, device=toks.device).unsqueeze(0)
        i += 1
    return toks_out

# print(tkz.decode(q_toks.flatten().flatten().cpu().tolist()))
if ques_inp == QuesInp.Enc:
    q_toks = tkz('Answer: ').input_ids
    q_toks = torch.tensor([*q_toks, 0], dtype=torch.int64, device=device)
    toks_out = predict(model, ctxq_emb, q_toks)
else:
    toks_out = predict(model, ctx_emb, q_toks)
print(tkz.decode(toks_out))

torch.Size([1, 5, 30522])
torch.Size([30522])
23666
torch.Size([1, 6, 30522])
torch.Size([30522])
23666
torch.Size([1, 7, 30522])
torch.Size([30522])
1580
torch.Size([1, 8, 30522])
torch.Size([30522])
29531
torch.Size([1, 9, 30522])
torch.Size([30522])
29531
torch.Size([1, 10, 30522])
torch.Size([30522])
29531
torch.Size([1, 11, 30522])
torch.Size([30522])
29531
torch.Size([1, 12, 30522])
torch.Size([30522])
29531
torch.Size([1, 13, 30522])
torch.Size([30522])
29531
torch.Size([1, 14, 30522])
torch.Size([30522])
29531
bouts bouts ™pkinspkinspkinspkinspkinspkinspkins


In [121]:
ques_inp

<QuesInp.Enc: 'enc'>

In [80]:
qa_att_mask.shape, qa_att_mask.dtype

(torch.Size([9, 30]), torch.int64)

In [39]:
a_toks_l

[tensor([ 2491,  1997, 14211,  1011, 15356,  7870,  1998, 24757,   102]),
 tensor([1002, 2531, 2454,  102]),
 tensor([14362,   102]),
 tensor([2701,  102]),
 tensor([2687,  102])]

## Tensor ops

In [138]:
np.random.seed(20)

In [144]:
n = 4
t = np.tril(np.ones((n, n), dtype=bool), k=0)
t

array([[ True, False, False, False],
       [ True,  True, False, False],
       [ True,  True,  True, False],
       [ True,  True,  True,  True]])

In [145]:
a = np.random.randint(100, size=n)
a

array([28, 90,  9, 20])

In [150]:
b = np.repeat(a[None], n, axis=0)
b

array([[28, 90,  9, 20],
       [28, 90,  9, 20],
       [28, 90,  9, 20],
       [28, 90,  9, 20]])

In [152]:
mask_token_id = 123
mask = np.eye(n, dtype=bool)
bb = np.tril(b, k=-1)
bb[mask] = mask_token_id
bb

array([[123,   0,   0,   0],
       [ 28, 123,   0,   0],
       [ 28,  90, 123,   0],
       [ 28,  90,   9, 123]])

In [153]:
at = torch.tensor(a)
at

tensor([28, 90,  9, 20])

In [156]:
atn = at.repeat(n, 1)
atn = torch.tril(atn)
atn

tensor([[28,  0,  0,  0],
        [28, 90,  0,  0],
        [28, 90,  9,  0],
        [28, 90,  9, 20]])

In [157]:
maskt = torch.tensor(mask)
maskt

tensor([[ True, False, False, False],
        [False,  True, False, False],
        [False, False,  True, False],
        [False, False, False,  True]])

In [158]:
atn[maskt] = tkz.mask_token_id
atn

tensor([[103,   0,   0,   0],
        [ 28, 103,   0,   0],
        [ 28,  90, 103,   0],
        [ 28,  90,   9, 103]])