In [4]:
%load_ext autoreload
%autoreload 2

In [19]:
from dataclasses import dataclass
import io
import json
import os
from pathlib import Path
from pprint import pprint
import re
import requests
import sys
from typing import Optional

if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
import numpy as np
import pandas as pd
from pydantic_yaml import parse_yaml_file_as, to_yaml_file
import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer, AutoTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions

from mllm.config.model import GenmixTrainDsType, TokensAggType, GenmixembBertCfg, copy_override_genmixemb_bert_cfg, \
    gen_prefpostfix_genmixemb_bert
from mllm.exp.args import GENMIXEMB_BERT_MODEL_CFG_FNAME, create_bool_str_field, is_arg_true
from mllm.model.genmixemb_bert import GenmixembBert
from mllm.train.mask_utils import MaskCfg
from mllm.train.utils import find_create_train_path, log_weights_grads_stats, SumTuple, QnaTuple
from mllm.data.wiki.itwiki import WikiItem, get_wiki_batch_iterators, WikiBatch
from mllm.utils.utils import rethrow


# BERT Generator model inference
## Configs and paths

In [58]:
# --- Paths and run selection -------------------------------------------------
DATA_PATH = Path(os.path.expandvars('$HOME')) / 'data'

bert_model_name = 'bert-base-uncased'
random_seed = 111
# Apply the seed to the global numpy / torch RNGs so that stochastic steps
# (e.g. mask sampling) are repeatable after "Restart Kernel -> Run All".
# NOTE(review): which RNG the masking code draws from is not visible in this
# notebook; if it uses Python's `random`, that module must be seeded as well.
np.random.seed(random_seed)
torch.manual_seed(random_seed)

train_genmixemb_bert_path = DATA_PATH / 'train_mllm_genmixemb_bert'
# Training run to load. Other runs are kept commented out; uncomment one to switch.
# genmixemb_subdir = 'genmixemb-20250713_202718-bertbaseuncased-d768-mxo50-aggBrt-sub0-dsWki-tmax100-tragF'
# genmixemb_subdir = 'genmixemb-20250715_035750-bertbaseuncased-d768-mxo50-aggBrt-sub2-dsWki-tmax100-tragT'
genmixemb_subdir = 'genmixemb-20250716_213252-bertbaseuncased-d768-mxo50-aggBrt-sub0-dsWki-tmax100-msk_sep_0.5/0.15_seq_0.5/0.2/20-tragF'

# The mask part of the run name contains '/', so joining it onto a Path yields
# nested directories below `train_genmixemb_bert_path`.
genmixemb_train_path = train_genmixemb_bert_path / genmixemb_subdir
genmixemb_snapshot_fpath = genmixemb_train_path / 'best.pth'

# --- Device -------------------------------------------------------------------
device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print(device)

# --- Recover data parameters encoded in the run name --------------------------
# The run name is a '-'-separated list of tags. Two of them are needed here:
#   tmax{N}                                   -> n_toks_max
#   msk_sep_{sep_freq}/{sep_frac}_seq_{seq_freq}/{seq_max_frac}/{seq_max_len} -> mask_cfg
parts = genmixemb_subdir.split('-')
n_toks_max = 0      # stays 0 when the run name has no 'tmax' tag
mask_cfg = None     # stays None when the run name has no 'msk_' tag
for part in parts:
    if part.startswith('tmax'):
        n_toks_max = int(part[4:])
    elif part.startswith('msk_'):
        # '_'-split gives ['msk', 'sep', '<sep values>', 'seq', '<seq values>'].
        subparts = part.split('_')
        sep_part, seq_part = subparts[2], subparts[4]
        sep_freq, sep_frac = sep_part.split('/')
        seq_freq, seq_max_frac, seq_max_len = seq_part.split('/')
        sep_freq, sep_frac = float(sep_freq), float(sep_frac)
        seq_freq, seq_max_frac, seq_max_len = float(seq_freq), float(seq_max_frac), int(seq_max_len)
        mask_cfg = MaskCfg(
            sep_freq=sep_freq, sep_frac=sep_frac, seq_freq=seq_freq, seq_max_frac=seq_max_frac,
            seq_max_len=seq_max_len,
        )

print(f'n_toks_max = {n_toks_max}')
print(mask_cfg)

# Number of Wikipedia documents per inference batch.
batch_size = 5


cpu
n_toks_max = 100
MaskCfg(sep_freq=0.5, sep_frac=0.15, seq_freq=0.5, seq_max_frac=0.2, seq_max_len=20)


In [59]:
# Read the model config YAML stored inside the selected training run directory.
model_cfg_fpath = genmixemb_train_path / GENMIXEMB_BERT_MODEL_CFG_FNAME
model_cfg = parse_yaml_file_as(GenmixembBertCfg, model_cfg_fpath)
model_cfg

GenmixembBertCfg(bert_model_name='bert-base-uncased', d_model=768, max_out_toks=50, toks_agg_type=<TokensAggType.Bert: 'brt'>, bert_agg_n_subseq_toks=0, pyr_agg_n_levels=0, pyr_agg_n_layers_per_level=0, train_agg_model=False)

## Load models and dataset
### Model

In [60]:
# Instantiate the model from the loaded config on the selected device.
model = GenmixembBert(model_cfg, device=device)
# Tokenizer exposed by the model; used below to build wiki items and decode outputs.
tkz = model.tkz

You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
Some weights of BertGenerationDecoder were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossatte

In [61]:
print(f'Load {genmixemb_snapshot_fpath}')
# NOTE(review): depending on the torch version, `torch.load` may unpickle arbitrary
# objects -- only load snapshots that come from a trusted source.
snapshot = torch.load(genmixemb_snapshot_fpath, map_location=device)
# strict=True: the keys stored under 'model' must match the model's state dict exactly.
model.load_state_dict(snapshot['model'], strict=True)
# Drop the loaded snapshot object now that its weights were copied into the model.
del snapshot
# Switch to evaluation mode; the trailing ';' suppresses the cell output.
model.eval();

Load /home/misha/data/train_mllm_genmixemb_bert/genmixemb-20250716_213252-bertbaseuncased-d768-mxo50-aggBrt-sub0-dsWki-tmax100-msk_sep_0.5/0.15_seq_0.5/0.2/20-tragF/best.pth


## Wiki dataset
### Loading

In [18]:
# First (path) and second (name) positional arguments passed to `load_dataset`.
wiki_ds_subdir = 'wikipedia'
wiki_ds_name = '20220301.en'
# Datasets are cached under DATA_PATH.
dss_wiki = load_dataset(wiki_ds_subdir, wiki_ds_name, cache_dir=str(DATA_PATH))
# Only the 'train' split is used in this notebook.
ds_wiki = dss_wiki['train']
n_docs = len(ds_wiki)
print(f'Wikipedia {wiki_ds_name} docs: {n_docs}')

Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

Wikipedia 20220301.en docs: 6458670


In [62]:
def get_wiki_batch(i_batch: int, batch_size: int = batch_size) -> WikiBatch:
    """Build the `i_batch`-th batch of consecutive rows from `ds_wiki`.

    Rows with indices `i_batch * batch_size` up to (excluding)
    `(i_batch + 1) * batch_size` are each wrapped into a `WikiItem` using the
    notebook-level `tkz`, `n_toks_max` and `mask_cfg`, then collected into a
    `WikiBatch` placed on `device`.

    Args:
        i_batch: Zero-based batch index.
        batch_size: Number of rows per batch; the default is the value of the
            notebook-level `batch_size` at the time this function is defined.

    Returns:
        A `WikiBatch` holding `batch_size` items.
    """
    def make_item(ind: int) -> WikiItem:
        # One dataset lookup per row, same as indexing inside a plain loop.
        row = ds_wiki[ind]
        return WikiItem(
            tkz=tkz, ind=ind, title=row['title'], text=row['text'], max_len=n_toks_max, mask_cfg=mask_cfg,
        )

    first = i_batch * batch_size
    last = first + batch_size
    return WikiBatch(items=[make_item(ind) for ind in range(first, last)], device=device)

### Wiki inference

In [63]:
# Index of the batch of consecutive wiki documents to run inference on.
i_batch = 0
batch = get_wiki_batch(i_batch)
# Going by the names: original tokens, masked tokens and the mask itself -- TODO confirm
# against WikiBatch.get_tensors. Shape note kept from the original author:
# [n_batch, max_len]
b_toks, b_masked_toks, b_mask = batch.get_tensors()

Token indices sequence length is longer than the specified maximum sequence length for this model (8349 > 512). Running this sequence through the model will result in indexing errors


In [66]:
# Feed the masked tokens to the model (True) or the original, unmasked tokens (False).
# with_mask = False
with_mask = True

# Pure inference: disable autograd so no graph is recorded during generation
# (harmless if `gen_on_wiki` already disables it internally).
with torch.no_grad():
    # Iterate over the rows actually present in the batch tensors instead of relying on the
    # global `batch_size`. `mask` is not used inside the loop; it is still unpacked so the
    # same notebook-level names exist after the cell has run.
    for i, (toks, masked_toks, mask) in enumerate(zip(b_toks, b_masked_toks, b_mask)):
        toks_inp = masked_toks if with_mask else toks
        toks_out = model.gen_on_wiki(toks=toks_inp)
        toks_out = toks_out.squeeze(0)
        s_inp = tkz.decode(toks_inp)
        s_out = tkz.decode(toks_out)
        # Printed per document: source text (masked mode only), model input, model output.
        if with_mask:
            s_src = tkz.decode(toks)
            print(f'{i:03d}. {s_src}')
        print(f'{i:03d}. {s_inp}')
        print(f'{i:03d}. {s_out}')

000. anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy. anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful. as a historically left - wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian marxism as the libertarian wing ( libertarian socialism ) of the socialist movement, and has a strong
000. anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] to be unnecessary, undesirable, and harmful. as a historically left - wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian marxism as the libertarian 

In [67]:
# Total number of positions in the masked batch that hold the tokenizer's mask token.
(b_masked_toks == tkz.mask_token_id).sum()

tensor(100)