In [3]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
from dataclasses import dataclass
import io
import json
import os
from pathlib import Path
from pprint import pprint
import requests
import sys
from typing import Optional

if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
import numpy as np
import pandas as pd
from pydantic_yaml import parse_yaml_file_as, to_yaml_file
import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer, AutoTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions

from mllm.config.model import GenmixBertCfg
from mllm.exp.args import GENMIX_BERT_MODEL_CFG_FNAME
from mllm.model.genmix import GenmixBert
from mllm.train.utils import get_squadv2_df, get_squadv2_batch


# BERT Generator model inference
## Configs and paths

In [5]:
DATA_PATH = Path(os.path.expandvars('$HOME')) / 'data'

bert_model_name = 'bert-base-uncased'
random_seed = 111
inp_len = 128
train_genmix_bert_path = DATA_PATH / 'train_mllm_genmix_bert'
genmix_subdir = 'genmixbert-20250508_221933-bert-base-uncased-d768-inp128'

genmix_train_path = train_genmix_bert_path / genmix_subdir
genmix_snapshot_fpath = genmix_train_path / 'best.pth'

device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print(device)

cpu


In [10]:
model_cfg = parse_yaml_file_as(GenmixBertCfg, genmix_train_path / GENMIX_BERT_MODEL_CFG_FNAME)
model_cfg


GenmixBertCfg(inp_len=128, d_model=768, pretrained_model_name='bert-base-uncased', tokenizer_name='bert-base-uncased')

## Load models and dataset
### Model

In [None]:
model = GenmixBert(model_cfg, device=device)


tkz = BertTokenizer.from_pretrained(bert_model_name)
print(tkz)
genmix_model: BertGenerationEncoder = BertGenerationEncoder.from_pretrained(bert_model_name, bos_token_id=101, eos_token_id=102)
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
dec_model: BertGenerationDecoder = BertGenerationDecoder.from_pretrained(
    bert_model_name, add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102
)
model = EncoderEmbDecoderModel(
    encoder=enc_model, decoder=dec_model, enc_emb_exp_type=genmix_params.exp_type, enc_emb_exp_bias=genmix_params.exp_bias,
    enc_inp_len=inp_len, enc_inp_batch_size=genmix_params.enc_batch_size,
).to(device)

BertTokenizer(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}


You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
Some weights of BertGenerationDecoder were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossatte

In [10]:
from transformers import BertGenerationConfig
cfg: BertGenerationConfig = enc_model.config
cfg.name_or_path

'bert-base-uncased'

In [None]:
print(f'Load {genmix_snapshot_fpath}')
checkpoint = torch.load(genmix_snapshot_fpath, map_location=device)
model.load_state_dict(checkpoint['model'], strict=True)
del checkpoint
model.eval()
None

Load /home/misha/data/train_mllm_eed_bert_qna/eedbert-20250503_172002-bert_base_uncased-d768-emp_f-qi_enc-exp_emb_b-bt_11-chkpt_none/best.pth


### Squad v2 Qna dataset

In [12]:
np.random.seed(random_seed)
# exclude_empty_answers = False
exclude_empty_answers = True
df_sq = get_squadv2_df(exclude_empty_answers=True)

Reusing dataset squad_v2 (/home/misha/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)


  0%|          | 0/2 [00:00<?, ?it/s]

Remove empty answers from dataset squad_v2. Size: 142192 --> 92749


## Inference

In [34]:
batch_size = eed_params.batch_size or 5
inds = np.arange(batch_size)
inds += batch_size * 1
batch = get_squadv2_batch(tkz=tkz, df_sq=df_sq, inds=inds, inp_len=inp_len, device=device, ques_inp=eed_params.ques_inp)
for ctx in batch.contexts:
    print(ctx[:600])

Context1. Older than The Game by 23 years, the Harvard-Yale Regatta was the original source of the athletic rivalry between the two schools. It is held annually in June on the Thames River in eastern Connecticut. The Harvard crew is typically considered to be one of the top teams in the country in rowing. Today, Harvard fields top teams in several other sports, such as the Harvard Crimson men's ice hockey team (with a strong rivalry against Cornell), squash, and even recently won NCAA titles in Men's and Women's Fencing. Harvard also won the Intercollegiate Sailing Association National Champio
Context2. When aspirated consonants are doubled or geminated, the stop is held longer and then has an aspirated release. An aspirated affricate consists of a stop, fricative, and aspirated release. A doubled aspirated affricate has a longer hold in the stop portion and then has a release consisting of the fricative and aspiration.
Context3. After being lit at the birthplace of the Olympic Games i

In [14]:
for q, a in batch.qas:
    print(f'Q: {q}. A: {a}')

Q: What as the Paris Region's economy shifted towards?. A: high-value-added service industries
Q: Who is the primary rival of the Harvard Crimson hockey team?. A: Cornell
Q: What happens when an aspirated consonant is doubled or geminated?. A: the stop is held longer and then has an aspirated release.
Q: What political party is strongest in Melbourne's working class suburbs?. A: Australian Labor Party
Q: What si the teacher-student ratio for Tuvalu schools?. A: 1:18
Q: How many Khitan Tumens were there?. A: three
Q: How many Khitan Tumens were there?. A: 3
Q: What are Poes?. A: enemy ghosts
Q: Where did the Olympics originate?. A: Olympia, Greece
Q: Who is the primary rival of the Harvard Crimson hockey team?. A: strong rivalry against Cornell


In [24]:
ctxs_toks, other_toks = batch.gen_tensors()
ctxs_mask = (ctxs_toks > 0).to(batch.device)
ctx_enc_out: BaseModelOutputWithPastAndCrossAttentions = model.encoder(input_ids=ctxs_toks, attention_mask=ctxs_mask)
ctx_lhs = ctx_enc_out.last_hidden_state


In [25]:
ctx_lhs = torch.rand_like(ctx_lhs)

In [52]:
qa_ind = 5

if batch.ques_inp == QnaQuesInp.Enc:
    q_toks_l, a_toks_l, a_att_masks_l, a_tgt_masks_l = other_toks
    n_ans = len(a_toks_l)
    q_toks, a_toks, a_att_mask, a_tgt_mask = q_toks_l[qa_ind], a_toks_l[qa_ind], a_att_masks_l[qa_ind], a_tgt_masks_l[qa_ind]
    q_toks = q_toks.unsqueeze(0)
    q_mask = (q_toks > 0).to(batch.device)
    q_enc_out: BaseModelOutputWithPastAndCrossAttentions = model.encoder(input_ids=q_toks, attention_mask=q_mask)
    ctxq_lhs = torch.concatenate([ctx_lhs, q_enc_out.last_hidden_state], dim=0)
    ctxq_emb = model.run_expansion(ctxq_lhs)
    # ctxq_emb = model.run_expansion(ctx_lhs)
    print(f'ctx_lhs: {ctx_lhs.shape}. q_lhs: {q_enc_out.last_hidden_state.shape}. ctxq_emb: {ctxq_emb.shape}')
    # a_toks = a_toks.repeat(len(a_att_mask), 1)
    # # a_toks_inp = a_toks * a_att_mask
    # a_toks_inp = a_toks
    # a_dec_out: CausalLMOutputWithCrossAttentions = model.decoder(
    #     input_ids=a_toks_inp, attention_mask=a_att_mask, encoder_hidden_states=ctxq_emb,
    # )

elif batch.ques_inp == QnaQuesInp.Dec:
    qa_toks_l, qa_att_masks_l, qa_tgt_masks_l = other_toks
    n_qas = len(qa_toks_l)
    qa_toks, qa_att_mask, qa_tgt_mask = qa_toks_l[qa_ind].unsqueeze(0), qa_att_masks_l[qa_ind], qa_tgt_masks_l[qa_ind]
    # qa_toks = qa_toks.repeat(len(qa_att_mask), 1)
    # qa_toks_inp = qa_toks * qa_att_mask
    # dec_out: CausalLMOutputWithCrossAttentions = model.decoder(
    #     input_ids=qa_toks_inp, attention_mask=qa_att_mask, encoder_hidden_states=ctx_emb
    # )
    # n = 0
    # for i in range(qa_toks.shape[1]):
    #     if qa_att_mask[0, i] == 0:
    #         n = i
    #         break
    # q_toks = qa_toks[0, :n + 1].clone()
    # q_toks[-1] = 0

else:
    raise Exception(f'Unsupported Question input type: {batch.ques_inp}')


ctx_lhs: torch.Size([10, 128, 768]). q_lhs: torch.Size([1, 128, 768]). ctxq_emb: torch.Size([1, 11, 768])


In [None]:
def predict(model: EncoderEmbDecoderModel, enc_emb: torch.Tensor, toks: torch.Tensor, max_len: int = 10) -> list[int]:
    i, toks_cur, toks_out = 0, toks.tolist(), []
    inp_ids = toks.unsqueeze(0)
    while i < max_len:
        print(tkz.decode(toks_cur))
        # att_mask = (inp_ids > 0).to(torch.int32)
        dec_out: CausalLMOutputWithCrossAttentions = model.decoder(
            input_ids=inp_ids, encoder_hidden_states=enc_emb, use_cache=False,
        )
        # print(dec_out.logits.shape)
        probs_pred = torch.softmax(dec_out.logits[0, -1], dim=-1)
        # print(probs_pred.shape)
        tok_out = torch.argmax(probs_pred, dim=-1)
        # print(tok_out.item())
        tok = tok_out.item()
        if tok == 102:
            break
        toks_cur.append(tok)
        inp_ids = torch.tensor(toks_cur, dtype=toks.dtype, device=toks.device).unsqueeze(0)

        i += 1
    return toks_cur if toks_cur[-1] != tkz.cls_token_id else toks_cur[:-1]

def predict_beam(model: EncoderEmbDecoderModel, enc_emb: torch.Tensor, toks: torch.Tensor, num_beams: int = 5, max_len: int = 10,
                 temperature: float = 1) -> list[int]:
    beam_search = BeamSearch(
        num_beams=num_beams, max_len=max_len, temperature=temperature, next_token_id=tkz.cls_token_id,
        last_token_id=tkz.sep_token_id, device=device, append_next_token_id=False,
    )
    # toks_inp: [n_active_beams, beam_seq_len] -> [n_active_beams, vocab_size]
    def run_inference(beam_seq_batch: torch.Tensor) -> torch.Tensor:
        n_active_beams = beam_seq_batch.shape[0]
        dec_out: CausalLMOutputWithCrossAttentions = model.decoder(
            input_ids=beam_seq_batch, encoder_hidden_states=enc_emb,
        )
        return dec_out.logits[:, -1, :]

    beams = beam_search.run(run_inference)
    for beam in beams:
        print(tkz.decode(beam.tokens_cur))
    return beams[0].tokens_cur

q, a = batch.qas[qa_ind]
print(f'{q} - {a}')
if eed_params.ques_inp == QnaQuesInp.Enc:
    q_toks = [tkz.cls_token_id]
    q_toks = torch.tensor(q_toks, dtype=torch.int64, device=device)
    # toks_out = predict(model, ctxq_emb, q_toks, max_len=20)
    toks_out = predict_beam(model, ctxq_emb, q_toks, max_len=20)
else:
    raise Exception(f'{eed_params.ques_inp} qna question input type is not supported.')
    q_toks = qa_toks.squeeze().tolist()
    print(q_toks)
    for i, q_tok in enumerate(q_toks):
        # print(i, q_tok, q_tok == tkz.sep_token_id)
        if q_tok == tkz.sep_token_id:
            q_toks = q_toks[:i + 2]
            q_toks[i + 1] = tkz.cls_token_id
            break
    q_toks = torch.tensor(q_toks, dtype=torch.int64, device=device)
    print(q_toks)
    ctx_emb = model.run_expansion(ctx_lhs)
    toks_out = predict(model, ctx_emb, q_toks, max_len=20)
print(tkz.decode(toks_out))

What political party is strongest in Melbourne's working class suburbs? - Labor
[CLS] [SEP]
[CLS] one [SEP]
[CLS] five [SEP]
[CLS] third [SEP]
[CLS] - [SEP]
[CLS] [SEP]


## Tensor ops

In [5]:
inds1 = np.random.randint(0, 10, 5)
inds2 = np.random.randint(0, 10, 5)
inds1, inds2

(array([4, 0, 4, 5, 6]), array([1, 9, 6, 8, 3]))

In [6]:
np.concatenate([inds1, inds2])

array([4, 0, 4, 5, 6, 1, 9, 6, 8, 3])

In [23]:
np.random.randint(10)

3

In [144]:
n = 4
t = np.tril(np.ones((n, n), dtype=bool), k=0)
t

array([[ True, False, False, False],
       [ True,  True, False, False],
       [ True,  True,  True, False],
       [ True,  True,  True,  True]])

In [145]:
a = np.random.randint(100, size=n)
a

array([28, 90,  9, 20])

In [150]:
b = np.repeat(a[None], n, axis=0)
b

array([[28, 90,  9, 20],
       [28, 90,  9, 20],
       [28, 90,  9, 20],
       [28, 90,  9, 20]])

In [152]:
mask_token_id = 123
mask = np.eye(n, dtype=bool)
bb = np.tril(b, k=-1)
bb[mask] = mask_token_id
bb

array([[123,   0,   0,   0],
       [ 28, 123,   0,   0],
       [ 28,  90, 123,   0],
       [ 28,  90,   9, 123]])

In [30]:
at = torch.tensor([1, 7, 22, -1])
at

[autoreload of mllm.train.embgen_bert failed: Traceback (most recent call last):
  File "/home/misha/miniconda3/envs/mllm/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/home/misha/miniconda3/envs/mllm/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 475, in superreload
    module = reload(module)
  File "/home/misha/miniconda3/envs/mllm/lib/python3.10/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 619, in _exec
  File "<frozen importlib._bootstrap_external>", line 879, in exec_module
  File "<frozen importlib._bootstrap_external>", line 1017, in get_code
  File "<frozen importlib._bootstrap_external>", line 947, in source_to_code
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/home/misha/prog/mllm/notebooks/../mllm/train/embgen_bert.py", line 477
    loss = qna_l

tensor([ 1,  7, 22, -1])

In [29]:
n = len(at)
atn = at.repeat(n, 1)
atn = torch.tril(atn)
atn

tensor([[ 1,  0,  0,  0],
        [ 1,  7,  0,  0],
        [ 1,  7, 22,  0],
        [ 1,  7, 22, -1]])

In [157]:
maskt = torch.tensor(mask)
maskt

tensor([[ True, False, False, False],
        [False,  True, False, False],
        [False, False,  True, False],
        [False, False, False,  True]])

In [158]:
atn[maskt] = tkz.mask_token_id
atn

tensor([[103,   0,   0,   0],
        [ 28, 103,   0,   0],
        [ 28,  90, 103,   0],
        [ 28,  90,   9, 103]])