In [3]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
from dataclasses import dataclass
import io
import json
import os
from pathlib import Path
from pprint import pprint
import requests
import sys
from typing import Optional

if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
import numpy as np
import pandas as pd
from pydantic_yaml import parse_yaml_file_as, to_yaml_file
import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer, AutoTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions

from mllm.config.model import GenmixBertCfg
from mllm.model.inference import BeamSearch
from mllm.exp.args import GENMIX_BERT_MODEL_CFG_FNAME
from mllm.model.genmix import GenmixBert
from mllm.train.utils import get_squadv2_df, get_squadv2_batch, QnaQuesInp, get_billsum_df, WordToks

# BERT Generator model inference
## Configs and paths

In [5]:
DATA_PATH = Path(os.path.expandvars('$HOME')) / 'data'

bert_model_name = 'bert-base-uncased'
random_seed = 111
inp_len = 128
train_genmix_bert_path = DATA_PATH / 'train_mllm_genmix_bert'
# genmix_subdir = 'genmixbert-20250510_112004-bert-base-uncased-d768-inp128'
# genmix_subdir = 'genmixbert-20250514_214424-bert-base-uncased-d768-inp128'
# genmix_subdir = 'genmixbert-20250515_223449-bert-base-uncased-d768-inp128-ds_sum-maxi10-maxo50'
genmix_subdir = 'genmixbert-20250517_105055-bert-base-uncased-d768-inp128-ds_sum-maxi10-maxo50'
genmix_subdir = 'genmixbert-20250624_184929-bert-base-uncased-d768-inp128-dsWki-msktgtT-msklf0.2-mskl10-mxi1-mxo50-nfem1-nsem128-emagFst-emexMat'

genmix_train_path = train_genmix_bert_path / genmix_subdir
genmix_snapshot_fpath = genmix_train_path / 'best.pth'

device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print(device)

cpu


In [23]:
model_cfg = parse_yaml_file_as(GenmixBertCfg, genmix_train_path / GENMIX_BERT_MODEL_CFG_FNAME)
model_cfg

GenmixBertCfg(inp_len=128, d_model=768, pretrained_model_name='bert-base-uncased', tokenizer_name='bert-base-uncased', max_inp_chunks=1, max_out_toks=50, n_first_embs=1, n_second_embs=128, emb_agg_type=<GenmixEmbAggType.Fst: 'fst'>, emb_exp_type=<GenmixEmbExpType.Mat: 'mat'>)

## Load models and dataset
### Model

In [24]:
model = GenmixBert(model_cfg, device=device)
tkz = model.tkz

You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
Some weights of BertGenerationDecoder were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossatte

In [25]:
print(f'Load {genmix_snapshot_fpath}')
checkpoint = torch.load(genmix_snapshot_fpath, map_location=device)
model.load_state_dict(checkpoint['model'], strict=True)
del checkpoint
model.eval()
None

Load /home/misha/data/train_mllm_genmix_bert/genmixbert-20250624_184929-bert-base-uncased-d768-inp128-dsWki-msktgtT-msklf0.2-mskl10-mxi1-mxo50-nfem1-nsem128-emagFst-emexMat/best.pth


### Squad v2 Qna dataset

In [49]:
np.random.seed(random_seed)
# exclude_empty_answers = False
exclude_empty_answers = True
df_sq = get_squadv2_df(exclude_empty_answers=True)

README.md:   0%|          | 0.00/8.92k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/16.4M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/1.35M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/130319 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11873 [00:00<?, ? examples/s]

Remove empty answers from dataset squad_v2. Size: 142192 --> 92749


## Inference

In [50]:
def predict_beam(model: GenmixBert, enc_emb: torch.Tensor, num_beams: int = 5, max_len: int = 10,
                 temperature: float = 1) -> list[int]:
    beam_search = BeamSearch(
        num_beams=num_beams, max_len=max_len, temperature=temperature, next_token_id=tkz.cls_token_id,
        last_token_id=tkz.sep_token_id, device=device, append_next_token_id=False,
    )
    # toks_inp: [n_active_beams, beam_seq_len] -> [n_active_beams, vocab_size]
    def run_inference(beam_seq_batch: torch.Tensor) -> torch.Tensor:
        n_active_beams = beam_seq_batch.shape[0]
        dec_out: CausalLMOutputWithCrossAttentions = model(
            inputs_embeds=enc_emb, decoder_input_ids=beam_seq_batch,
        )
        return dec_out.logits[:, -1, :]

    beams = beam_search.run(run_inference)
    for beam in beams:
        print(tkz.decode(beam.tokens_cur))
    return beams[0].tokens_cur


In [51]:
i = 5
row = df_sq.iloc[i]
context, question, answers = row['context'], row['question'], row['answers']['text']
print(f'Context: {context}')
print(f'Q: {question}')
for answer in answers:
    print(f'A: {answer}')

Context: Traditionally a carnival feast was the last opportunity to eat well before the time of food shortage at the end of the winter during which one was limited to the minimum necessary. On what nowadays is called vastenavond (the days before fasting) all the remaining winter stores of lard, butter and meat which were left would be eaten, for it would soon start to rot and decay. The selected livestock had in fact already been slaughtered in November and the meat would be no longer preservable. All the food that had survived the winter had to be eaten to assure that everyone was fed enough to survive until the coming spring would provide new food sources.
Q: What was one limited to during the winter?
A: the minimum necessary


In [52]:
# [1, n_cq, d_model]
emb = model.context_question_to_emb(context, question)
target_ids = torch.tensor([[tkz.cls_token_id]], device=device)
# target_ids = torch.tensor([[2491]], device=device)
gen_out: Seq2SeqLMOutput = model.gen(inputs_embeds=emb, decoder_input_ids=target_ids, use_cache=False)
# [1, tgt_len, n_vocab]
gen_logits = gen_out.logits

# [tgt_len, n_vocab]
logits = gen_logits.view(-1, model.gen.decoder.config.vocab_size)
probs = torch.softmax(logits[-1], dim=-1)
out_tok = torch.argmax(probs)
print(out_tok)
print(tkz.decode([out_tok]))


tensor(4493)
existing


In [53]:
tkz(answer)

{'input_ids': [101, 1996, 6263, 4072, 102], 'token_type_ids': [0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1]}

In [54]:
out_toks = model.gen_on_qna_txt(context, question)
out_ans = tkz.decode(out_toks.flatten())
print(out_ans)

[CLS] existing law provides for the licensure and regulation of certain persons who are convicted of a crime.


In [55]:
enc_emb = model.context_question_to_emb(context=context, question=question)
out_toks = predict_beam(model.gen, enc_emb)
out_str = tkz.decode(out_toks)
print(out_str)

[CLS] table of contents : title i : amendments to
[CLS] table of contents : title i : miscellaneous provisions
[CLS] table of contents : title i : supplemental appropriations
[CLS] directs the secretary of the interior to provide for
[CLS] table of contents : title i : supplemental provisions
[CLS] table of contents : title i : amendments to


### Bilsum Summarization dataset

In [56]:
df_bs = get_billsum_df()
df_bs

Unnamed: 0,text,summary,title
0,SECTION 1. LIABILITY OF BUSINESS ENTITIES PROV...,Shields a business entity from civil liability...,A bill to limit the civil liability of busines...
1,SECTION 1. SHORT TITLE.\n\n This Act may be...,Human Rights Information Act - Requires certai...,Human Rights Information Act
2,SECTION 1. SHORT TITLE.\n\n This Act may be...,Jackie Robinson Commemorative Coin Act - Direc...,Jackie Robinson Commemorative Coin Act
3,SECTION 1. NONRECOGNITION OF GAIN WHERE ROLLOV...,Amends the Internal Revenue Code to provide (t...,To amend the Internal Revenue Code to provide ...
4,SECTION 1. SHORT TITLE.\n\n This Act may be...,Native American Energy Act - (Sec. 3) Amends t...,Native American Energy Act
...,...,...,...
1232,The people of the State of California do enact...,"Existing law, the Carpenter-Presley-Tanner Haz...",An act to amend Sections 25173.7 and 25205.6 o...
1233,The people of the State of California do enact...,(1) The Hazardous Waste Control Law authorizes...,"An act to amend Sections 25185.6, 25358.1, 253..."
1234,The people of the State of California do enact...,"Under existing law, any employer or other pers...",An act to amend Section 1197.1 of the Labor Co...
1235,The people of the State of California do enact...,Existing law requires the Director of the Depa...,An act to amend Section 14838 of the Governmen...


In [62]:
i = 10
row = df_bs.iloc[i]
text, summary, title = row['text'], row['summary'], row['title']
print('Title:', title)
print('Text:', text[:400])
print('Summary:', summary)

Title: A bill to amend title XVIII of the Social Security Act to provide coverage for kidney disease education services under the medicare program, and for other purposes.
Text: SECTION 1. SHORT TITLE.

    This Act may be cited as the ``Kidney Disease Educational Benefits 
Act of 2002''.

SEC. 2. MEDICARE COVERAGE OF KIDNEY DISEASE EDUCATION SERVICES.

    (a) Coverage of Kidney Disease Education Services.--
            (1) In general.--Section 1861 of the Social Security Act 
        (42 U.S.C. 1395x), as amended by section 105 of the Medicare, 
        Medicaid, and SC
Summary: Kidney Disease Educational Benefits Act of 2002 - Amends title XVIII (Medicare) of the Social Security Act, as amended by the Medicare, Medicaid, and SCHIP Benefits Improvement and Protection Act of 2000, to provide coverage for kidney disease education services furnished, upon the managing physician's referral, to an individual with kidney disease who will require dialysis or a kidney transplant. Requires su

In [63]:
out_toks = model.gen_on_sum_txt(text=text, title=title)
out_ans = tkz.decode(out_toks.flatten())
print(out_ans)

[CLS] medicare medicare prescription drug access act of 2003 - amends title xviii ( medicare ) of the social


In [64]:
enc_emb = model.text_title_to_emb(text=text, title=title)
out_toks = predict_beam(model.gen, enc_emb, max_len=20)
out_str = tkz.decode(out_toks)
print(out_str)

[CLS] amends title xviii ( medicare ) of the social security act to require the secretary of health
[CLS] amends title xviii ( medicare ) of the social security act ( ssa ) to provide
[CLS] amends title xviii ( medicare ) of the social security act ( ssa ) to require
[CLS] amends title xix ( medicaid ) of the social security act to require the secretary
[CLS] amends title xviii ( medicare ) of the social security act to : ( 1 ) provide
[CLS] amends title xviii ( medicare ) of the social security act to require the secretary of health


In [38]:
df_bs['text'].str.len().mean()

10243.284331698998

### Wiki dataset

In [41]:
tkz = model.tkz
max_len = 10

In [27]:
def parse_subdir(subdir: str) -> tuple[bool, float, int]:
    mask_tgt, max_tgt_len_freq, max_tgt_len = None, None, None
    parts = subdir.split('-')
    for part in parts:
        if part.startswith('msktgt'):
            c = part[-1]
            assert c in ('T', 'F')
            mask_tgt = c == 'T'
        elif part.startswith('msklf'):
            max_tgt_len_freq = float(part[5:])
        elif part.startswith('mskl'):
            max_tgt_len = int(part[4:])
    assert mask_tgt is not None
    assert max_tgt_len_freq is not None
    assert max_tgt_len is not None
    return mask_tgt, max_tgt_len_freq, max_tgt_len


mask_tgt, max_tgt_len_freq, max_tgt_len = parse_subdir(genmix_subdir)
print(f'mask_tgt = {mask_tgt}. max_tgt_len_freq = {max_tgt_len_freq}. max_tgt_len = {max_tgt_len}')

mask_tgt = True. max_tgt_len_freq = 0.2. max_tgt_len = 10


In [12]:
wiki_ds_name, wiki_ds_subdir = '20220301.en', 'wikipedia'
dss = load_dataset(wiki_ds_subdir, wiki_ds_name, cache_dir=str(DATA_PATH))
ds = dss['train']
n_docs = len(ds)
print(f'Wikipedia {wiki_ds_name} docs: {n_docs}')


Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

Wikipedia 20220301.en docs: 6458670


In [43]:
i = 1
item = ds[i]
item

{'id': '25',
 'url': 'https://en.wikipedia.org/wiki/Autism',
 'title': 'Autism',
 'text': 'Autism is a neurodevelopmental disorder characterized by difficulties with social interaction and communication, and by restricted and repetitive behavior. Parents often notice signs during the first three years of their child\'s life. These signs often develop gradually, though some autistic children experience regression in their communication and social skills after reaching developmental milestones at a normal pace.\n\nAutism is associated with a combination of genetic and environmental factors. Risk factors during pregnancy include certain infections, such as rubella, toxins including valproic acid, alcohol, cocaine, pesticides, lead, and air pollution, fetal growth restriction, and autoimmune diseases. Controversies surround other proposed environmental causes; for example, the vaccine hypothesis, which has been disproven. Autism affects information processing in the brain and how nerve cel

In [44]:
text = item['text']
len(text)

46773

In [45]:
max_toks = 0
if model.cfg.max_inp_chunks > 0:
    max_toks = (model.cfg.inp_len - 2) * model.cfg.max_inp_chunks - 17 - 14 - 5
wt = WordToks(
    tkz=tkz, s=text, max_tgt_len_freq=max_tgt_len_freq, max_tgt_len=max_tgt_len, max_toks=max_toks,
)
# tags_list_str = ', '.join(wt.tags_names)
tags_list_str = ', '.join(wt.tags_dict.values())
inp_str = wt.inp_masked_str if mask_tgt else wt.inp_str
prompt = f'Cite the text between the tags: {tags_list_str}. Text: {inp_str}'
emb = model.prompt_to_emb(prompt=prompt)
print(emb.shape)

torch.Size([1, 128, 768])


In [46]:
print(prompt)
print(wt.inp_str)
print(wt.inp_masked_str)
print(wt.tgt_str)

Cite the text between the tags: <|cite_begin|>, <|cite_end|>. Text: autism is a neurodevelopmental disorder characterized by difficulties with social interaction and communication, and by restricted and repetitive behavior. parents often notice signs during the first three years of their child's life. these signs often < | cite _ begin | > [MASK] [MASK] [MASK] [MASK] [MASK] < | cite _ end | > autistic children experience regression in their communication and social skills after reaching developmental milestones at a normal pace. autism is associated with a combination of genetic and environmental factors. risk factors during pregnancy include certain
autism is a neurodevelopmental disorder characterized by difficulties with social interaction and communication, and by restricted and repetitive behavior. parents often notice signs during the first three years of their child's life. these signs often < | cite _ begin | > develop gradually, though some < | cite _ end | > autistic children

In [47]:
out_toks = model.gen.generate(inputs_embeds=emb, decoder_start_token_id=tkz.cls_token_id, max_length=max_len)

out_ans = tkz.decode(out_toks.flatten())
print(out_ans)

[CLS] the first [SEP] the [SEP] the [SEP] the [SEP]
