In [1]:
# Enable IPython's autoreload extension so edits to imported project modules
# are picked up without restarting the kernel (mode 2 = reload all modules).
%load_ext autoreload
%autoreload 2

In [2]:
from dataclasses import dataclass
import io
import json
import os
from pathlib import Path
from pprint import pprint
import requests
import sys
from typing import Optional

if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
import numpy as np
import pandas as pd
from pydantic_yaml import parse_yaml_file_as, to_yaml_file
import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer, AutoTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions

from mllm.config.model import GenmixBertCfg
from mllm.model.inference import BeamSearch
from mllm.exp.args import GENMIX_BERT_MODEL_CFG_FNAME
from mllm.model.genmix import GenmixBert
from mllm.train.utils import get_squadv2_df, get_squadv2_batch, QnaQuesInp
from mllm.train.encmix_bert import get_squadv2_txt_iterator



# BERT Generator model inference
## Configs and paths

In [3]:
# Root directory for datasets and training artifacts.
# Path.home() still resolves the home directory when $HOME is unset, whereas
# os.path.expandvars('$HOME') would leave the literal string '$HOME' in the path.
DATA_PATH = Path.home() / 'data'

bert_model_name = 'bert-base-uncased'
random_seed = 111
inp_len = 128

# Training run whose snapshot and model config are loaded below.
train_genmix_bert_path = DATA_PATH / 'train_mllm_genmix_bert'
genmix_subdir = 'genmixbert-20250508_224518-bert-base-uncased-d768-inp128'

genmix_train_path = train_genmix_bert_path / genmix_subdir
genmix_snapshot_fpath = genmix_train_path / 'best.pth'

# Use the GPU when one is available and fall back to CPU otherwise, instead of
# assigning 'cpu' and immediately overwriting it with a hard-coded 'cuda'.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'

device = torch.device(device_name)
print(device)

cuda


In [4]:
# The model config YAML lives in the same run directory as the snapshot.
model_cfg_fpath = genmix_train_path / GENMIX_BERT_MODEL_CFG_FNAME
model_cfg = parse_yaml_file_as(GenmixBertCfg, model_cfg_fpath)
model_cfg

GenmixBertCfg(inp_len=128, d_model=768, pretrained_model_name='bert-base-uncased', tokenizer_name='bert-base-uncased')

## Load models and dataset
### Model

In [5]:
# Build the GenmixBert model from the loaded config on the selected device.
model = GenmixBert(model_cfg, device=device)
# Keep a short alias to the model's tokenizer; later cells use it for
# special-token ids, encoding and decoding.
tkz = model.tkz

You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
Some weights of BertGenerationDecoder were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossatte

In [6]:
# Restore the trained weights from the run's `best.pth` snapshot.
# NOTE(review): torch.load unpickles the file (unless weights_only is in effect),
# so only load snapshots that come from a trusted source.
print(f'Load {genmix_snapshot_fpath}')
snapshot = torch.load(genmix_snapshot_fpath, map_location=device)
model_state = snapshot['model']
# strict=True: fail loudly on any missing or unexpected parameter name.
model.load_state_dict(model_state, strict=True)
# Drop the references so the loaded copy of the weights can be freed.
del snapshot, model_state
model.eval()
# Trailing None suppresses the module repr that model.eval() would display.
None

Load /home/misha/data/train_mllm_genmix_bert/genmixbert-20250508_224518-bert-base-uncased-d768-inp128/best.pth


### SQuAD v2 QnA dataset

In [7]:
np.random.seed(random_seed)
# Set to False to keep the rows with empty answers in the dataframe.
exclude_empty_answers = True
# Pass the flag through instead of a hard-coded True so that changing the
# variable above actually takes effect.
df_sq = get_squadv2_df(exclude_empty_answers=exclude_empty_answers)

Reusing dataset squad_v2 (/home/misha/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)


  0%|          | 0/2 [00:00<?, ?it/s]

Remove empty answers from dataset squad_v2. Size: 142192 --> 92749


## Inference

In [8]:
def predict_beam(model: GenmixBert, enc_emb: torch.Tensor, num_beams: int = 5, max_len: int = 10,
                 temperature: float = 1) -> list[int]:
    """Run beam search decoding conditioned on `enc_emb` and return the first beam's tokens.

    Every beam returned by the search is decoded with the tokenizer and printed.

    Relies on the notebook globals `tkz` (special-token ids and decoding) and
    `device` (passed to `BeamSearch`).

    NOTE(review): the notebook calls this with `model.gen`, not with the
    `GenmixBert` wrapper itself; the `model` annotation may need updating.

    Args:
        model: Callable accepting `inputs_embeds` and `decoder_input_ids` keyword
            arguments and returning an object with a `logits` attribute.
        enc_emb: Embeddings passed to the model as `inputs_embeds` on every step.
        num_beams: Forwarded to `BeamSearch`.
        max_len: Forwarded to `BeamSearch`.
        temperature: Forwarded to `BeamSearch`.

    Returns:
        `tokens_cur` of the first beam returned by `BeamSearch.run`.
    """
    beam_search = BeamSearch(
        num_beams=num_beams, max_len=max_len, temperature=temperature, next_token_id=tkz.cls_token_id,
        last_token_id=tkz.sep_token_id, device=device, append_next_token_id=False,
    )

    # beam_seq_batch: [n_active_beams, beam_seq_len] -> [n_active_beams, vocab_size]
    def run_inference(beam_seq_batch: torch.Tensor) -> torch.Tensor:
        # Pure inference: skip autograd bookkeeping so no graph is kept alive
        # for each beam step.
        with torch.no_grad():
            dec_out: CausalLMOutputWithCrossAttentions = model(
                inputs_embeds=enc_emb, decoder_input_ids=beam_seq_batch,
            )
        # Only the logits at the last decoder position are needed to pick the next token.
        return dec_out.logits[:, -1, :]

    beams = beam_search.run(run_inference)
    for beam in beams:
        print(tkz.decode(beam.tokens_cur))
    return beams[0].tokens_cur


In [20]:
# Pick one QnA sample by position and show it.
i = 1
row = df_sq.iloc[i]
context = row['context']
question = row['question']
answers = row['answers']['text']

print(f'Context: {context}')
print(f'Q: {question}')
# Explicit loop kept on purpose: the loop variable `answer` stays defined
# after the cell runs, holding the last reference answer.
for answer in answers:
    print(f'A: {answer}')

Context: Every dollar ($1) that is spent on pesticides for crops yields four dollars ($4) in crops saved. This means based that, on the amount of money spent per year on pesticides, $10 billion, there is an additional $40 billion savings in crop that would be lost due to damage by insects and weeds. In general, farmers benefit from having an increase in crop yield and from being able to grow a variety of crops throughout the year. Consumers of agricultural products also benefit from being able to afford the vast quantities of produce available year-round. The general public also benefits from the use of pesticides for the control of insect-borne diseases and illnesses, such as malaria. The use of pesticides creates a large job market within the agrichemical sector.
Q: How is the health of the general publis affected by pesticides?
A: control of insect-borne diseases and illnesses


In [27]:
# [1, n_cq, d_model]
emb = model.context_question_to_emb(context, question)
target_ids = torch.tensor([[tkz.cls_token_id]], device=device)
target_ids = torch.tensor([[2491]], device=device)
gen_out: Seq2SeqLMOutput = model.gen(inputs_embeds=emb, decoder_input_ids=target_ids)
# [1, tgt_len, n_vocab]
gen_logits = gen_out.logits

# [tgt_len, n_vocab]
logits = gen_logits.view(-1, model.gen.decoder.config.vocab_size)
probs = torch.softmax(logits[-1], dim=-1)
out_tok = torch.argmax(probs)
print(out_tok)


tensor(2491, device='cuda:0')


In [25]:
# Tokenize the last reference answer of the selected sample. Index `answers`
# explicitly rather than relying on a loop variable leaked from another cell.
tkz(answers[-1])

{'input_ids': [101, 2491, 1997, 14211, 1011, 15356, 7870, 1998, 24757, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [23]:
# Generate answer tokens for the selected context/question with the model's own
# generation helper.
out_toks = model.gen_on_qna_txt(context, question)
# Collapse the token output to 1-D before decoding it back to text.
out_ans = tkz.decode(out_toks.flatten())
print(out_ans)



[CLS] [SEP] t [SEP] t [SEP] t [SEP] t [SEP] t [SEP] t [SEP] t [SEP] t [SEP] t [SEP]


In [24]:
enc_emb = model.gen_emb_on_qna_txt(context=context, question=question)
out_toks = predict_beam(model.gen, enc_emb)
out_str = tkz.decode(out_toks)
print(out_str)

[CLS] [SEP]
[CLS] " [SEP]
[CLS]s [SEP]
[CLS] the [SEP]
[CLS] of [SEP]
[CLS] [SEP]
