In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from dataclasses import dataclass
import io
import json
import os
from pathlib import Path
from pprint import pprint
import requests
import sys
from typing import Optional

if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
import numpy as np
import pandas as pd
import torch
from torch import nn
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer, AutoTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions

from mllm.model.embgen_bert import EncoderEmbDecoderModel
from mllm.data.qna import get_hotpotqa
from mllm.train.embgen_bert import get_sq_batch, get_sq_df



# BERT Generator model inference
## Configs and paths

In [3]:
DATA_PATH = Path(os.path.expandvars('$HOME')) / 'data'

bert_model_name = 'bert-base-uncased'
random_seed = 111
inp_len = 128
train_eed_bert_path = DATA_PATH / 'train_mllm_eed_bert_qna'
# eed_subdir = 'eedbert-20250303_220645-bert_base_uncased-d768'
# eed_subdir = 'eedbert-20250306_220643-bert_base_uncased-d768'
# eed_subdir = 'eedbert-20250307_232430-bert_base_uncased-d768'
eed_subdir = 'eedbert-20250308_130751-bert_base_uncased-d768'
eed_subdir = 'eedbert-20250309_224040-bert_base_uncased-d768'

eed_train_path = train_eed_bert_path / eed_subdir
eed_snapshot_fpath = eed_train_path / 'best.pth'

device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print(device)

cpu


## Load models and dataset
### Model

In [4]:
tkz = BertTokenizer.from_pretrained(bert_model_name)
print(tkz)
enc_model: BertGenerationEncoder = BertGenerationEncoder.from_pretrained(bert_model_name, bos_token_id=101, eos_token_id=102)
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
dec_model: BertGenerationDecoder = BertGenerationDecoder.from_pretrained(
    bert_model_name, add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102
)
model = EncoderEmbDecoderModel(encoder=enc_model, decoder=dec_model).to(device)

BertTokenizer(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}


You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
Some weights of BertGenerationDecoder were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossatte

In [5]:
print(f'Load {eed_snapshot_fpath}')
checkpoint = torch.load(eed_snapshot_fpath)
model.load_state_dict(checkpoint['model'], strict=True)

Load /home/misha/data/train_mllm_eed_bert_qna/eedbert-20250309_224040-bert_base_uncased-d768/best.pth


<All keys matched successfully>

### Squad v2 Qna dataset

In [6]:
np.random.seed(random_seed)
# exclude_empty_answers = False
exclude_empty_answers = True
df_sq = get_sq_df(exclude_empty_answers=True)

Reusing dataset squad_v2 (/home/misha/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)


  0%|          | 0/2 [00:00<?, ?it/s]

Remove empty answers from dataset squad_v2. Size: 142192 --> 92749


## Inference

In [24]:
batch_size = 5
inds1 = np.arange(batch_size)
inds1 += batch_size * 2
batch1 = get_sq_batch(tkz=tkz, df_sq=df_sq, inds=inds1, inp_len=inp_len, device=device)
for ctx in batch1.contexts:
    print(ctx[:400])

Contexts: [1 1 1 1 1]. (5, 128)
QAs: [18 34 19]. 71. 1841
Qs: [15 18 16]. 49. 805
As: [ 2 15  2]. 19. 233
Context1. Older than The Game by 23 years, the Harvard-Yale Regatta was the original source of the athletic rivalry between the two schools. It is held annually in June on the Thames River in eastern Connecticut. The Harvard crew is typically considered to be one of the top teams in the country in rowing. Today, Harvard fields top teams in several other sports, such as the Harvard Crimson men's ic
Context2. When aspirated consonants are doubled or geminated, the stop is held longer and then has an aspirated release. An aspirated affricate consists of a stop, fricative, and aspirated release. A doubled aspirated affricate has a longer hold in the stop portion and then has a release consisting of the fricative and aspiration.
Context3. After being lit at the birthplace of the Olympic Games in Olympia, Greece on March 24, the torch traveled to the Panathinaiko Stadium in Athens, and t

In [25]:
for q, a in batch1.qas:
    print(f'Q: {q}. A: {a}')

Q: Context5. How many Khitan Tumens were there?. A: three
Q: Context2. What happens when an aspirated consonant is doubled or geminated?. A: the stop is held longer and then has an aspirated release.
Q: Context3. Where did the Olympics originate?. A: Olympia, Greece
Q: Context1. Who is the primary rival of the Harvard Crimson hockey team?. A: Cornell
Q: Context1. Who is the primary rival of the Harvard Crimson hockey team?. A: strong rivalry against Cornell


In [26]:
qas_toks, qa_att_masks, qa_tgt_masks, ctxs_toks = batch1.gen_tensors()
ctxs_mask = (ctxs_toks > 0).to(device)
enc_out: BaseModelOutputWithPastAndCrossAttentions = model.encoder(input_ids=ctxs_toks, attention_mask=ctxs_mask)
enc_emb = enc_out.last_hidden_state[:, 0].unsqueeze(0)

In [32]:
qa_ind = 2
qa_toks, qa_att_mask, qa_tgt_mask = qas_toks[qa_ind].unsqueeze(0), qa_att_masks[qa_ind], qa_tgt_masks[qa_ind]
qa_toks = qa_toks.repeat(len(qa_att_mask), 1)
qa_toks_inp = qa_toks * qa_att_mask
dec_out: CausalLMOutputWithCrossAttentions = model.decoder(
    input_ids=qa_toks_inp, attention_mask=qa_att_mask, encoder_hidden_states=enc_emb
)

n = 0
for i in range(qa_toks.shape[1]):
    if qa_att_mask[0, i] == 0:
        n = i
        break
q_toks = qa_toks[0, :n + 1].clone()
q_toks[-1] = 0

tkz.decode(q_toks)
# q_toks = qa_toks[0, :n].clone()


'[CLS] context1. who is the primary rival of the harvard crimson hockey team? [SEP] [PAD]'

In [33]:
def predict(model: EncoderEmbDecoderModel, enc_emb: torch.Tensor, toks: torch.Tensor, max_len: int = 10) -> list[int]:
    i, toks_cur, toks_out = 0, toks.tolist(), []
    inp_ids = toks.unsqueeze(0)
    while i < max_len:
        att_mask = inp_ids > 0
        dec_out: CausalLMOutputWithCrossAttentions = model.decoder(
            input_ids=inp_ids, attention_mask=att_mask, encoder_hidden_states=enc_emb,
        )
        print(dec_out.logits.shape)
        probs_pred = torch.softmax(dec_out.logits[0, -1], dim=-1)
        print(probs_pred.shape)
        tok_out = torch.argmax(probs_pred, dim=-1)
        print(tok_out.item())
        tok = tok_out.item()
        if tok == 102:
            break
        toks_cur[-1] = tok
        toks_cur.append(0)
        toks_out.append(tok)
        inp_ids = torch.tensor(toks_cur, dtype=toks.dtype, device=toks.device).unsqueeze(0)
        i =+ 1
    return toks_out

print(tkz.decode(q_toks.tolist()))
toks_out = predict(model, enc_emb, q_toks)
print(tkz.decode(toks_out))

[CLS] context1. who is the primary rival of the harvard crimson hockey team? [SEP] [PAD]
torch.Size([1, 18, 30522])
torch.Size([30522])
5765
torch.Size([1, 19, 30522])
torch.Size([30522])
102
harvard


In [14]:
l = [17, 22, 30, 21, 16]
l = [16, 23, 21, 26, 30]
a = np.array(l)
(a * a).sum()

2802

In [35]:
a = [1, 2, 3]
b = [-1, -2, -3]
list(zip(a, b))

[(1, -1), (2, -2), (3, -3)]