In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from dataclasses import dataclass
import io
import json
import os
from pathlib import Path
from pprint import pprint
import re
import requests
import sys
from typing import Optional

if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
import numpy as np
import pandas as pd
from pydantic_yaml import parse_yaml_file_as, to_yaml_file
import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer, AutoTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions

from mllm.config.model import GenmixTrainDsType, TokensAggType, GenmixembCfg, copy_override_genmixemb_cfg, \
    gen_prefpostfix_genmixemb
from mllm.exp.args import GENMIXEMB_BERT_MODEL_CFG_FNAME, create_bool_str_field, is_arg_true
from mllm.model.genmixemb import Genmixemb
from mllm.train.mask_utils import MaskCfg
from mllm.train.utils import find_create_train_path, log_weights_grads_stats, SumTuple, QnaTuple
from mllm.data.itsquadv2 import get_squadv2_batch_iterators_v2, QnaBatchV2
from mllm.data.wiki.itwiki import WikiItem, get_wiki_batch_iterators, WikiBatch
from mllm.utils.utils import rethrow
from mllm.data.utils import get_squadv2_df


🚨 `emb_off` is part of GPT2Model.forward's signature, but not documented. Make sure to add it to the docstring of the function in /home/misha/prog/mllm/notebooks/../mllm/model/gpt2/modeling_gpt2.py.
🚨 `emb_off` is part of GPT2LMHeadModel.forward's signature, but not documented. Make sure to add it to the docstring of the function in /home/misha/prog/mllm/notebooks/../mllm/model/gpt2/modeling_gpt2.py.


# BERT Generator model inference
## Configs and paths

In [15]:
from mllm.config.model import SelfSuperviseType
from mllm.model import genmixemb


DATA_PATH = Path(os.path.expandvars('$HOME')) / 'data'

model_name = 'bert-base-uncased'
random_seed = 111
train_genmixemb_path = DATA_PATH / 'train_mllm_genmixemb_wki'
genmixemb_subdir = 'genmixemb-20250713_202718-bertbaseuncased-d768-mxo50-aggBrt-sub0-dsWki-tmax100-tragF'
genmixemb_subdir = 'genmixemb-20250715_035750-bertbaseuncased-d768-mxo50-aggBrt-sub2-dsWki-tmax100-tragT'
genmixemb_subdir = 'genmixemb-20250716_213252-bertbaseuncased-d768-mxo50-aggBrt-sub0-dsWki-tmax100-msk_sep_0.5/0.15_seq_0.5/0.2/20-tragF'
genmixemb_subdir = 'genmixemb-20250718_220105-bertbaseuncased-d768-mxo50-aggBrt-sub2-dsWki-tmax100-tragT'
genmixemb_subdir = 'genmixemb-20250721_083250-bertbaseuncased-d768-mxo50-aggPyr-agtDecim-stp0-lvl1-lrs2-dsWki-tmax256-tragF-nxtsnt'
genmixemb_subdir = 'genmixemb-20250721_212402-bertbaseuncased-d768-mxo50-aggPyr-agtDecim-stp2-lvl2-lrs2-dsWki-tmax512-tragT-nxtsnt'
genmixemb_subdir = 'genmixemb-20250722_213424-bertbaseuncased-d768-mxo50-aggPyr-agtDecim-stp2-lvl3-lrs2-dsWki-tmax512-tragT-nxtsnt'
genmixemb_subdir = 'genmixemb-20250820_093126-bertbaseuncased-d768-mxi384-mxo50-dsWki-nxtsnt'
model_name = 'gpt2'
genmixemb_subdir = 'genmixemb-20250825_213542-gpt2-d768-mxi448-mxo36-dsWki-sstNxttok'
# genmixemb_subdir = 'genmixemb-20250904_212609-gpt2-d768-mxi1024-mxo256-aggCnv-lvl3-lrs1-cksz3-pksz2-pst2-shlF-dsWki-tragT-sstNxttok'

train_genmixemb_path = DATA_PATH / 'train_mllm_genmixemb_qna'
# genmixemb_subdir = 'genmixemb-20250726_122548-bertbaseuncased-d768-mxi384-mxo50-dsQna'
# genmixemb_subdir = 'genmixemb-20250810_125920-pre_genmixemb20250726122548-bertbaseuncased-d768-mxi384-mxo50-aggBrt-sub2-agtTopdot-dsQna-tragT-shemT-ttidF-jcqF'
# genmixemb_subdir = 'genmixemb-20250914_112502-gpt2-d768-dp0.1-mxi512-mxo128-dsQna-ttidF'
genmixemb_subdir = 'genmixemb-20251013_212430-pre_encdecbert20251004224422-bertbaseuncased-d768-mdlDec-dp0.1-mxi1024-mxo50-aggBrt-sub128-agtSep-dsQna-tragT-shemT-ttidF-cqprCq'


genmixemb_train_path = train_genmixemb_path / genmixemb_subdir
genmixemb_snapshot_fpath = genmixemb_train_path / 'best.pth'
# genmixemb_snapshot_fpath = genmixemb_train_path / 'last.pth'

device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print('Device:', device)

parts = genmixemb_subdir.split('-')
n_toks_max = 0
mask_cfg = None
pred_next_sent = False
self_supervise_type = SelfSuperviseType.Input
for part in parts:
    if part.startswith('mxi'):
        n_toks_max = int(part[3:])
    elif part.startswith('msk_'):
        subparts = part.split('_')
        # postfix_parts.append(f'msk_sep_{sep_freq}/{sep_frac}_seq_{seq_freq}/{seq_max_frac}/{mask_cfg.seq_max_len}')
        sep_part, seq_part = subparts[2], subparts[4]
        sep_freq, sep_frac = sep_part.split('/')
        seq_freq, seq_max_frac, seq_max_len = seq_part.split('/')
        sep_freq, sep_frac = float(sep_freq), float(sep_frac)
        seq_freq, seq_max_frac, seq_max_len = float(seq_freq), float(seq_max_frac), int(seq_max_len)
        mask_cfg = MaskCfg(
            sep_freq=sep_freq, sep_frac=sep_frac, seq_freq=seq_freq, seq_max_frac=seq_max_frac,
            seq_max_len=seq_max_len,
        )
    elif part.startswith('sst'):
        self_supervise_type = SelfSuperviseType(part[3:].lower())
        

print(f'n_toks_max = {n_toks_max}')
print('Mask cfg:', mask_cfg)
print(f'pred_next_sent = {pred_next_sent}')
print(f'self_supervise_type = {self_supervise_type}')

batch_size = 5


Device: cpu
n_toks_max = 1024
Mask cfg: None
pred_next_sent = False
self_supervise_type = inp


In [16]:
model_cfg = parse_yaml_file_as(GenmixembCfg, genmixemb_train_path / GENMIXEMB_BERT_MODEL_CFG_FNAME)
pprint(model_cfg.dict())

{'add_token_type_ids': False,
 'bert_agg_model_name': 'bert-base-uncased',
 'bert_agg_n_subseq_toks': 128,
 'bert_agg_type': <BertAggType.Sep: 'sep'>,
 'bert_attention_prob_dropout_prob': 0.1,
 'bert_hidden_dropout_prob': 0.1,
 'bert_model_type': <BertModelType.Dec: 'dec'>,
 'cnv_conv_kernel_size': 3,
 'cnv_n_layers_per_level': 1,
 'cnv_n_levels': 6,
 'cnv_pool_kernel_size': 2,
 'cnv_pool_stride': 2,
 'cnv_share_layer_weights': True,
 'ctx_que_prompt_type': <CtxQuePromptType.Cq: 'cq'>,
 'd_model': 768,
 'dec_expert_type': <DecExpertType.Non: 'non'>,
 'dp_prob': 0.1,
 'gpt2_attn_pdrop': 0.1,
 'gpt2_embd_pdrop': 0.1,
 'gpt2_resid_pdrop': 0.1,
 'join_ctx_que_agg': False,
 'max_inp_toks': 1024,
 'max_out_toks': 50,
 'model_name': 'bert-base-uncased',
 'moe_experts_num': 0,
 'moe_topk': 0,
 'pyr_agg_n_layers_per_level': 2,
 'pyr_agg_n_levels': 2,
 'pyr_agg_step': 2,
 'pyr_agg_type': <HgReductType.MaxPool: 'mxpl'>,
 'pyr_share_layer_weights': True,
 'share_agg_enc_token_embeds': True,
 'toks

## Load models and dataset
### Model

In [17]:
model = Genmixemb(model_cfg, device=device)
tkz = model.tkz

You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
Some weights of BertGenerationDecoder were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['lm_head.bias', 'lm_head.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
print(f'Load {genmixemb_snapshot_fpath}')
checkpoint = torch.load(genmixemb_snapshot_fpath, map_location=device)
model.load_state_dict(checkpoint['model'], strict=True)
del checkpoint
model.eval()
None

Load /home/misha/data/train_mllm_genmixemb_qna/genmixemb-20251013_212430-pre_encdecbert20251004224422-bertbaseuncased-d768-mdlDec-dp0.1-mxi1024-mxo50-aggBrt-sub128-agtSep-dsQna-tragT-shemT-ttidF-cqprCq/best.pth


## Wiki dataset
### Loading

In [7]:
wiki_ds_name, wiki_ds_subdir = '20220301.en', 'wikipedia'
dss_wiki = load_dataset(wiki_ds_subdir, wiki_ds_name, cache_dir=str(DATA_PATH))
ds_wiki = dss_wiki['train']
n_docs = len(ds_wiki)
print(f'Wikipedia {wiki_ds_name} docs: {n_docs}')

Using the latest cached version of the dataset since wikipedia couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration '20220301.en' at /home/misha/data/wikipedia/20220301.en/2.0.0/d41137e149b2ea90eead07e7e3f805119a8c22dd1d5b61651af8e3e3ee736001 (last modified on Sat Jun 21 20:19:29 2025).


Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

Wikipedia 20220301.en docs: 6458670


In [8]:
def get_wiki_batch(i_batch: int, batch_size: int = batch_size) -> WikiBatch:
    i1 = i_batch * batch_size
    i2 = i1 + batch_size
    items = []
    for i in range(i1, i2):
        row = ds_wiki[i]
        wiki_item = WikiItem(
            tkz=tkz, ind=i, title=row['title'], text=row['text'], max_len=n_toks_max, mask_cfg=mask_cfg,
            self_supervise_type=self_supervise_type, max_pred_len=model_cfg.max_out_toks,
        )
        items.append(wiki_item)
    batch = WikiBatch(items=items, device=device)
    return batch

### Wiki inference

In [9]:
i_batch = 0
batch = get_wiki_batch(i_batch)
# [n_batch, max_len]
b_toks, b_masked_toks, b_mask, b_tgt_toks = batch.get_tensors()

Token indices sequence length is longer than the specified maximum sequence length for this model (8287 > 1024). Running this sequence through the model will result in indexing errors


In [10]:
with_mask = False
# with_mask = True
for i in range(batch_size):
    toks, masked_toks, mask = b_toks[i], b_masked_toks[i], b_mask[i]
    tgt_toks = None
    if b_tgt_toks is not None:
        tgt_toks = b_tgt_toks[i]
    toks_inp = masked_toks if with_mask else toks
    # toks_inp = toks_inp[:100]
    toks_out = model.gen_on_wiki(toks=toks_inp)
    toks_out = toks_out.squeeze(0)
    # toks_out = toks_out[len(toks_inp)-2:]  # remove prompt
    s_inp = tkz.decode(toks_inp)
    s_out = tkz.decode(toks_out)
    print(f'toks_inp: {len(toks_inp)} toks_out: {len(toks_out)}')
    if with_mask:
        s_src = tkz.decode(toks)
        print(f'{i:03d}. Msk: {s_src}')
    s_inp = s_inp.replace('\n', '\\n')
    print(f'{i:03d}. Inp: {s_inp}')
    if tgt_toks is not None:
        s_tgt = tkz.decode(tgt_toks)
        s_tgt = s_tgt.replace('\n', '\\n')
        print(f'{i:03d}. Tgt: {s_tgt}')
    s_out = s_out.replace('\n', '\\n')
    print(f'{i:03d}. Out: {s_out}')
    # break

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


toks_inp: 448 toks_out: 484
000. Inp:  role. Anarchists have employed various methods in order to build a rough consensus among members of their group without the need of a leader or a leading group. One way is for an individual from the group to play the role of facilitator to help achieve a consensus without taking part in the discussion themselves or promoting a specific point. Minorities usually accept rough consensus, except when they feel the proposal contradicts anarchist ethics, goals and values. Anarchists usually form small groups (5–20 individuals) to enhance autonomy and friendships among their members. These kinds of groups more often than not interconnect with each other, forming larger networks. Anarchists still support and participate in strikes, especially wildcat strikes as these are leaderless strikes not organised centrally by a syndicate.\n\nAs in the past, newspapers and journals are used, and anarchists have gone online in the World Wide Web to spread their messa

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


toks_inp: 448 toks_out: 484
001. Inp:  high functioning individuals who may have active but distinctly odd social approaches, narrowly focused interests, and verbose, pedantic communication. Because the behavior spectrum is continuous, boundaries between diagnostic categories are necessarily somewhat arbitrary.\n\nScreening\nAbout half of parents of children with ASD notice their child's unusual behaviors by age 18 months, and about four-fifths notice by age 24 months. According to an article, failure to meet any of the following milestones "is an absolute indication to proceed with further evaluations. Delay in referral for such testing may delay early diagnosis and treatment and affect the long-term outcome".\n No response to name (or eye-to-eye gaze) by 6 months. \n No babbling by 12 months.\n No gesturing (pointing, waving, etc.) by 12 months.\n No single words by 16 months.\n No two-word (spontaneous, not just echolalic) phrases by 24 months.\n Loss of any language or social skill

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


toks_inp: 448 toks_out: 484
002. Inp: being water or ground which is darker color) and reflects less heat back into space. This feedback loop results in a reduced albedo effect.\n\nClimate and weather\nAlbedo affects climate by determining how much radiation a planet absorbs. The uneven heating of Earth from albedo variations between land, ice, or ocean surfaces can drive weather.\n\nAlbedo–temperature feedback\nWhen an area's albedo changes due to snowfall, a snow–temperature feedback results. A layer of snowfall increases local albedo, reflecting away sunlight, leading to local cooling. In principle, if no outside temperature change affects this area (e.g., a warm air mass), the raised albedo and lower temperature would maintain the current snow and invite further snowfall, deepening the snow–temperature feedback. However, because local weather is dynamic due to the change of seasons, eventually warm air masses and a more direct angle of sunlight (higher insolation) cause melting. Wh

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


toks_inp: 448 toks_out: 484
003. Inp:  The lowercase version can be written in two forms: the double-storey a and single-storey ɑ. The latter is commonly used in handwriting and fonts based on it, especially fonts intended to be read by children, and is also found in italic type.\n\nIn the English grammar, "a", and its variant "an", are indefinite articles.\n\nHistory\n\nThe earliest certain ancestor of "A" is aleph (also written 'aleph), the first letter of the Phoenician alphabet, which consisted entirely of consonants (for that reason, it is also called an abjad to distinguish it from a true alphabet). In turn, the ancestor of aleph may have been a pictogram of an ox head in proto-Sinaitic script influenced by Egyptian hieroglyphs, styled as a triangular head with two horns extended.\n\nWhen the ancient Greeks adopted the alphabet, they had no use for a letter to represent the glottal stop—the consonant sound that the letter denoted in Phoenician and other Semitic languages, and tha

In [11]:
from torch import mode, norm
def norm_str(s: str) -> str:
    s = s.replace('\n', ' ').replace('\r', ' ')
    return s

for i in range(batch_size):
    item = batch.items[i]
    toks_in = item.src_toks[:model_cfg.max_inp_toks]
    s_inp = tkz.decode(toks_in)
    print(f'{i:03d}. Inp: {norm_str(s_inp)}')
    tooks_tgt = item.src_toks[model_cfg.max_inp_toks:model_cfg.max_inp_toks + model_cfg.max_out_toks]
    s_tgt = tkz.decode(tooks_tgt)
    print(f'{i:03d}. Tgt: {norm_str(s_tgt)}')
    toks_in = torch.tensor(toks_in, dtype=torch.long, device=device)
    toks_out = model.gen_on_wiki(toks_in)
    toks_out = toks_out.squeeze()
    # toks_out = toks_out[len(toks_in)-2:]
    s_out = tkz.decode(toks_out)
    print(f'{i:03d}. Out: {norm_str(s_out)}')


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


000. Inp: Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy. Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful. As a historically left-wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian Marxism as the libertarian wing (libertarian socialism) of the socialist movement, and has a strong historical association with anti-capitalism and socialism.  Humans lived in societies without formal hierarchies long before the establishment of formal states, realms, or empires. With the rise of organised hierarchical bodies, scepticism toward authority also rose. Although traces of anarchist thought are found throughout history, modern anarchism emerged from the Enlightenment. During the latter half of the 19th and the first decades of the 20th century, the anarchist movement flou

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


000. Out: Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy. Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful. As a historically left-wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian Marxism as the libertarian wing (libertarian socialism) of the socialist movement, and has a strong historical association with anti-capitalism and socialism.  Humans lived in societies without formal hierarchies long before the establishment of formal states, realms, or empires. With the rise of organised hierarchical bodies, scepticism toward authority also rose. Although traces of anarchist thought are found throughout history, modern anarchism emerged from the Enlightenment. During the latter half of the 19th and the first decades of the 20th century, the anarchist movement flou

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


001. Out: Autism is a neurodevelopmental disorder characterized by difficulties with social interaction and communication, and by restricted and repetitive behavior. Parents often notice signs during the first three years of their child's life. These signs often develop gradually, though some autistic children experience regression in their communication and social skills after reaching developmental milestones at a normal pace.  Autism is associated with a combination of genetic and environmental factors. Risk factors during pregnancy include certain infections, such as rubella, toxins including valproic acid, alcohol, cocaine, pesticides, lead, and air pollution, fetal growth restriction, and autoimmune diseases. Controversies surround other proposed environmental causes; for example, the vaccine hypothesis, which has been disproven. Autism affects information processing in the brain and how nerve cells and their synapses connect and organize; how this occurs is not well understood. 

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


002. Out: Albedo (; ) is the measure of the diffuse reflection of solar radiation out of the total solar radiation and measured on a scale from 0, corresponding to a black body that absorbs all incident radiation, to 1, corresponding to a body that reflects all incident radiation.  Surface albedo is defined as the ratio of radiosity Je to the irradiance Ee (flux per unit area) received by a surface. The proportion reflected is not only determined by properties of the surface itself, but also by the spectral and angular distribution of solar radiation reaching the Earth's surface. These factors vary with atmospheric composition, geographic location, and time (see position of the Sun). While bi-hemispherical reflectance is calculated for a single angle of incidence (i.e., for a given position of the Sun), albedo is the directional integration of reflectance over all solar angles in a given period. The temporal resolution may range from seconds (as obtained from flux measurements) to dail

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


003. Out: A, or a, is the first letter and the first vowel of the modern English alphabet and the ISO basic Latin alphabet. Its name in English is a (pronounced ), plural aes. It is similar in shape to the Ancient Greek letter alpha, from which it derives. The uppercase version consists of the two slanting sides of a triangle, crossed in the middle by a horizontal bar. The lowercase version can be written in two forms: the double-storey a and single-storey ɑ. The latter is commonly used in handwriting and fonts based on it, especially fonts intended to be read by children, and is also found in italic type.  In the English grammar, "a", and its variant "an", are indefinite articles.  History  The earliest certain ancestor of "A" is aleph (also written 'aleph), the first letter of the Phoenician alphabet, which consisted entirely of consonants (for that reason, it is also called an abjad to distinguish it from a true alphabet). In turn, the ancestor of aleph may have been a pictogram of 

### Custom text

In [14]:
context = '''
Lithuania,[b] officially the Republic of Lithuania,[c] is a country in the Baltic region of Europe.[d] It is one of three Baltic states and lies on the eastern shore of the Baltic Sea, bordered by Latvia to the north, Belarus to the east and south, Poland to the south, and the Russian semi-exclave of Kaliningrad Oblast to the southwest, with a maritime border with Sweden to the west. Lithuania covers an area of 65,300 km2 (25,200 sq mi), with a population of 2.89 million. Its capital and largest city is Vilnius; other major cities include Kaunas, Klaipėda, Šiauliai and Panevėžys. Lithuanians who are the titular nation and form the majority of the country's population, belong to the ethnolinguistic group of Balts and speak Lithuanian. For millennia, the southeastern shores of the Baltic Sea were inhabited by various Baltic tribes. In the 1230s, Lithuanian lands were united for the first time by Mindaugas, who formed the Kingdom of Lithuania on 6 July 1253. Subsequent expansion and consolidation resulted in the Grand Duchy of Lithuania, which by the 14th century was the largest country in Europe. In 1386, the Grand Duchy entered into a de facto personal union with the Crown of the Kingdom of Poland. The two realms were united into the bi-confederal Polish-Lithuanian Commonwealth in 1569, forming one of the largest and most prosperous states in Europe. The Commonwealth lasted more than two centuries, until neighbouring countries gradually dismantled it between 1772 and 1795, with the Russian Empire annexing most of Lithuania's territory. Towards the end of World War I, Lithuania declared independence in 1918, founding the modern Republic of Lithuania. In World War II, Lithuania was occupied by the Soviet Union, then by Nazi Germany, before being reoccupied by the Soviets in 1944. Lithuanian armed resistance to the Soviet occupation lasted until the early 1950s. On 11 March 1990, a year before the formal dissolution of the Soviet Union, Lithuania became the first Soviet republic to break away when it proclaimed the restoration of its independence. Lithuania is a developed country with a high-income and an advanced economy ranking very high in Human Development Index.'''
context2 = 'Lithuania ranks highly in digital infrastructure,[25][26] press freedom and happiness.[27] Lithuania is a member of the United Nations, the European Union, the Council of Europe, the Council of the Baltic Sea States, the Eurozone, the Nordic Investment Bank, the International Monetary Fund, the Schengen Agreement, NATO, OECD and the World Trade Organization. It also participates in the Nordic-Baltic Eight (NB8) regional co-operation format.'

question = 'When Lithuania declared its independence and from whom?'
# question = 'When Lithuania declared its independence?'
question = 'When Kingdom of Lithuania was formed?'
# question = 'What is the capital of Lithuania?'
text = f'''
Context: {context.strip()}. Question: {question}. Answer: 
'''
# text = f'''
# <context>{context.strip()}</context> <question>{question}</question> <answer> The capital of Lithuania is
# '''
text = text.strip()
title = 'Lithuania'

In [15]:
toks_in = tkz(text).input_ids
print(len(toks_in), toks_in)
toks_in = torch.tensor(toks_in, dtype=torch.long, device=device)
toks_out = model.gen_on_wiki(toks_in)
toks_out = toks_out.squeeze()
# toks_out = toks_out[len(toks_in)-2:]
s_out = tkz.decode(toks_out)
s_out

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


492 [21947, 25, 31295, 17414, 65, 60, 8720, 262, 2066, 286, 31295, 17414, 66, 60, 318, 257, 1499, 287, 262, 32882, 3814, 286, 2031, 3693, 67, 60, 632, 318, 530, 286, 1115, 32882, 2585, 290, 7363, 319, 262, 10183, 15191, 286, 262, 32882, 6896, 11, 275, 24071, 416, 35794, 284, 262, 5093, 11, 33368, 284, 262, 7627, 290, 5366, 11, 12873, 284, 262, 5366, 11, 290, 262, 3394, 10663, 12, 1069, 44281, 286, 12612, 3191, 6335, 1835, 12957, 284, 262, 26283, 11, 351, 257, 30017, 4865, 351, 10710, 284, 262, 7421, 13, 31295, 8698, 281, 1989, 286, 6135, 11, 6200, 10571, 17, 357, 1495, 11, 2167, 19862, 21504, 828, 351, 257, 3265, 286, 362, 13, 4531, 1510, 13, 6363, 3139, 290, 4387, 1748, 318, 34037, 77, 3754, 26, 584, 1688, 4736, 2291, 509, 1942, 292, 11, 509, 5031, 541, 128, 245, 6814, 11, 25370, 254, 544, 43640, 72, 290, 350, 1531, 85, 128, 245, 129, 122, 893, 13, 49033, 1547, 508, 389, 262, 5259, 934, 3277, 290, 1296, 262, 3741, 286, 262, 1499, 338, 3265, 11, 5594, 284, 262, 33961, 349, 6680, 2569, 

"Context: Lithuania,[b] officially the Republic of Lithuania,[c] is a country in the Baltic region of Europe.[d] It is one of three Baltic states and lies on the eastern shore of the Baltic Sea, bordered by Latvia to the north, Belarus to the east and south, Poland to the south, and the Russian semi-exclave of Kaliningrad Oblast to the southwest, with a maritime border with Sweden to the west. Lithuania covers an area of 65,300 km2 (25,200 sq mi), with a population of 2.89 million. Its capital and largest city is Vilnius; other major cities include Kaunas, Klaipėda, Šiauliai and Panevėžys. Lithuanians who are the titular nation and form the majority of the country's population, belong to the ethnolinguistic group of Balts and speak Lithuanian. For millennia, the southeastern shores of the Baltic Sea were inhabited by various Baltic tribes. In the 1230s, Lithuanian lands were united for the first time by Mindaugas, who formed the Kingdom of Lithuania on 6 July 1253. Subsequent expansion

In [28]:
print(len(toks_in), len(toks_out))

492 256


In [31]:
tkz.decode(toks_in)

"Context: Lithuania,[b] officially the Republic of Lithuania,[c] is a country in the Baltic region of Europe.[d] It is one of three Baltic states and lies on the eastern shore of the Baltic Sea, bordered by Latvia to the north, Belarus to the east and south, Poland to the south, and the Russian semi-exclave of Kaliningrad Oblast to the southwest, with a maritime border with Sweden to the west. Lithuania covers an area of 65,300 km2 (25,200 sq mi), with a population of 2.89 million. Its capital and largest city is Vilnius; other major cities include Kaunas, Klaipėda, Šiauliai and Panevėžys. Lithuanians who are the titular nation and form the majority of the country's population, belong to the ethnolinguistic group of Balts and speak Lithuanian. For millennia, the southeastern shores of the Baltic Sea were inhabited by various Baltic tribes. In the 1230s, Lithuanian lands were united for the first time by Mindaugas, who formed the Kingdom of Lithuania on 6 July 1253. Subsequent expansion

### Try nltk for sentence splitting

In [8]:
import nltk
from nltk.tokenize import sent_tokenize
from timeit import default_timer as timer

nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /home/misha/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [101]:
item = ds_wiki[0]
txt = item['text']
s = txt.split('\n')[0]
print(len(txt), len(s), s)

43985 559 Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy. Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful. As a historically left-wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian Marxism as the libertarian wing (libertarian socialism) of the socialist movement, and has a strong historical association with anti-capitalism and socialism.


In [97]:
sents = sent_tokenize(s)
for sent in sents:
    print(sent)

Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy.
Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful.
As a historically left-wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian Marxism as the libertarian wing (libertarian socialism) of the socialist movement, and has a strong historical association with anti-capitalism and socialism.


In [100]:
N = 1000
t1 = timer()
for i in range(N):
    sents = sent_tokenize(txt)
t2 = timer()
delta = (t2 - t1) / N
print(f'sent_tokenize avg time: {delta:0.6f}')


sent_tokenize avg time: 0.005308


## Squad v2 dataset
### Loading

In [7]:
df_squad = get_squadv2_df(exclude_empty_answers=True)
print(len(df_squad))
df_squad.head()

Remove empty answers from dataset squad_v2. Size: 142192 --> 92749
92749


Unnamed: 0,id,title,context,question,answers
0,56be85543aeaaa14008c9063,Beyoncé,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,When did Beyonce start becoming popular?,"{'text': ['in the late 1990s'], 'answer_start'..."
1,56be85543aeaaa14008c9065,Beyoncé,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,What areas did Beyonce compete in when she was...,"{'text': ['singing and dancing'], 'answer_star..."
2,56be85543aeaaa14008c9066,Beyoncé,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,When did Beyonce leave Destiny's Child and bec...,"{'text': ['2003'], 'answer_start': [526]}"
3,56bf6b0f3aeaaa14008c9601,Beyoncé,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,In what city and state did Beyonce grow up?,"{'text': ['Houston, Texas'], 'answer_start': [..."
4,56bf6b0f3aeaaa14008c9602,Beyoncé,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,In which decade did Beyonce become famous?,"{'text': ['late 1990s'], 'answer_start': [276]}"


In [8]:
from mllm.data.itsquadv2 import QnaItemV2, QnaBatchV2


def get_qna_batch(i_batch: int, batch_size: int = batch_size) -> QnaBatchV2:
    i1 = i_batch * batch_size
    i2 = i1 + batch_size
    items = []
    for i in range(i1, i2):
        row = df_squad.iloc[i]
        answer = row.answers['text'][0]
        qna_item = QnaItemV2(
            tkz=tkz, ind=i, context=row.context, question=row.question, answer=answer, max_ctx_toks=n_toks_max, max_que_toks=model_cfg.max_inp_toks,
            max_ans_toks=model_cfg.max_out_toks,
        )
        items.append(qna_item)
    batch = QnaBatchV2(
        items=items, max_inp_len=n_toks_max, max_out_len=model_cfg.max_out_toks, device=device,
    )
    return batch

### QnA inference

In [19]:
i_batch = 2
batch = get_qna_batch(i_batch)
# [n_batch, max_len]
b_ctx_toks, b_que_toks, b_ans_toks, b_cq_toks = batch.get_tensors()
b_ctx_toks.shape, b_que_toks.shape, b_ans_toks.shape, b_cq_toks.shape

(torch.Size([5, 164]),
 torch.Size([5, 13]),
 torch.Size([5, 5]),
 torch.Size([5, 180]))

In [20]:
train_it, val_it = get_squadv2_batch_iterators_v2(
    batch_size=batch_size, exclude_empty_answers=True, tkz=model.tkz,
    max_inp_len=model_cfg.max_inp_toks, max_out_len=model_cfg.max_out_toks,
    device=device,
)

Remove empty answers from dataset squad_v2. Size: 142192 --> 92749
Squad v2 n_total = 92749. n_train = 88112. n_val = 4637


In [21]:
train_batch, val_batch = next(train_it), next(val_it)

In [22]:
batch = train_batch
batch = val_batch
# [n_batch, max_len]
b_ctx_toks, b_que_toks, b_ans_toks, b_cq_toks = batch.get_tensors()
b_ctx_toks.shape, b_que_toks.shape, b_ans_toks.shape, b_cq_toks.shape

(torch.Size([5, 225]),
 torch.Size([5, 16]),
 torch.Size([5, 14]),
 torch.Size([5, 244]))

In [None]:
for i in range(batch_size):
    ctx_toks, que_toks, ans_toks, cq_toks = b_ctx_toks[i], b_que_toks[i], b_ans_toks[i], b_cq_toks[i]
    tgt_toks = ans_toks
    toks_inp = cq_toks
    toks_out = model.gen_on_qna(ctx_toks=ctx_toks, que_toks=que_toks, cq_toks=cq_toks)
    toks_out = toks_out.squeeze(0)
    s_inp = tkz.decode(toks_inp)
    s_out = tkz.decode(toks_out)
    s_tgt = tkz.decode(tgt_toks)
    print(f'{i:03d}. Inp: {s_inp}')
    print(f'{i:03d}. Tgt: {s_tgt}')
    print(f'{i:03d}. Out: {s_out}')
    # break

### Custom text

In [34]:
context = '''
Lithuania,[b] officially the Republic of Lithuania,[c] is a country in the Baltic region of Europe.[d] It is one of three Baltic states and lies on the eastern shore of the Baltic Sea, bordered by Latvia to the north, Belarus to the east and south, Poland to the south, and the Russian semi-exclave of Kaliningrad Oblast to the southwest, with a maritime border with Sweden to the west. Lithuania covers an area of 65,300 km2 (25,200 sq mi), with a population of 2.89 million. Its capital and largest city is Vilnius; other major cities include Kaunas, Klaipėda, Šiauliai and Panevėžys. Lithuanians who are the titular nation and form the majority of the country's population, belong to the ethnolinguistic group of Balts and speak Lithuanian. For millennia, the southeastern shores of the Baltic Sea were inhabited by various Baltic tribes. In the 1230s, Lithuanian lands were united for the first time by Mindaugas, who formed the Kingdom of Lithuania on 6 July 1253. Subsequent expansion and consolidation resulted in the Grand Duchy of Lithuania, which by the 14th century was the largest country in Europe. In 1386, the Grand Duchy entered into a de facto personal union with the Crown of the Kingdom of Poland. The two realms were united into the bi-confederal Polish-Lithuanian Commonwealth in 1569, forming one of the largest and most prosperous states in Europe. The Commonwealth lasted more than two centuries, until neighbouring countries gradually dismantled it between 1772 and 1795, with the Russian Empire annexing most of Lithuania's territory. Towards the end of World War I, Lithuania declared independence in 1918, founding the modern Republic of Lithuania. In World War II, Lithuania was occupied by the Soviet Union, then by Nazi Germany, before being reoccupied by the Soviets in 1944. Lithuanian armed resistance to the Soviet occupation lasted until the early 1950s. On 11 March 1990, a year before the formal dissolution of the Soviet Union, Lithuania became the first Soviet republic to break away when it proclaimed the restoration of its independence. Lithuania is a developed country with a high-income and an advanced economy ranking very high in Human Development Index.'''
context2 = 'Lithuania ranks highly in digital infrastructure,[25][26] press freedom and happiness.[27] Lithuania is a member of the United Nations, the European Union, the Council of Europe, the Council of the Baltic Sea States, the Eurozone, the Nordic Investment Bank, the International Monetary Fund, the Schengen Agreement, NATO, OECD and the World Trade Organization. It also participates in the Nordic-Baltic Eight (NB8) regional co-operation format.'

question = 'When Lithuania declared its independence?'
# question = 'What is an area of Lithuania?'
# question = 'What is a population of Lithuania?'
# question = 'What is a capital of Lithuania?'
text = f'''
[CLS] {context.strip()} [SEP] {question} [SEP]: 
'''
text = text.strip()

In [35]:
toks_in = tkz(text).input_ids
toks_in = torch.tensor(toks_in, dtype=torch.long, device=device)
ctx_toks = tkz(context, return_tensors='pt', add_special_tokens=False).input_ids.to(device)
que_toks = tkz(question, return_tensors='pt', add_special_tokens=False).input_ids.to(device)
toks_out = model.gen_on_qna(ctx_toks=ctx_toks, que_toks=que_toks)
s_out = tkz.decode(toks_out.squeeze())
s_out

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


'9 March 1990<|endoftext|>'

In [20]:
len(toks_in), len(toks_out)

(472, 1)

In [21]:
torch.rand((3, 2))

tensor([[0.0345, 0.3600],
        [0.7277, 0.2924],
        [0.9627, 0.5792]])