In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from dataclasses import dataclass
import io
import json
import os
from pathlib import Path
from pprint import pprint
import re
import requests
import sys
from typing import Optional

if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
import numpy as np
import pandas as pd
from pydantic_yaml import parse_yaml_file_as, to_yaml_file
import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer, AutoTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions

from mllm.config.model import GenmixTrainDsType, TokensAggType, GenmixembCfg, copy_override_genmixemb_cfg, \
    gen_prefpostfix_genmixemb
from mllm.exp.args import GENMIXEMB_BERT_MODEL_CFG_FNAME, create_bool_str_field, is_arg_true
from mllm.model.genmixemb import Genmixemb
from mllm.train.mask_utils import MaskCfg
from mllm.train.utils import find_create_train_path, log_weights_grads_stats, SumTuple, QnaTuple
from mllm.data.wiki.itwiki import WikiItem, get_wiki_batch_iterators, WikiBatch
from mllm.utils.utils import rethrow
from mllm.data.utils import get_squadv2_df


# BERT Generator model inference
## Configs and paths

In [27]:
from mllm.config.model import SelfSuperviseType
from mllm.model import genmixemb


DATA_PATH = Path(os.path.expandvars('$HOME')) / 'data'

model_name = 'bert-base-uncased'
random_seed = 111
train_genmixemb_path = DATA_PATH / 'train_mllm_genmixemb_wki'
genmixemb_subdir = 'genmixemb-20250713_202718-bertbaseuncased-d768-mxo50-aggBrt-sub0-dsWki-tmax100-tragF'
genmixemb_subdir = 'genmixemb-20250715_035750-bertbaseuncased-d768-mxo50-aggBrt-sub2-dsWki-tmax100-tragT'
genmixemb_subdir = 'genmixemb-20250716_213252-bertbaseuncased-d768-mxo50-aggBrt-sub0-dsWki-tmax100-msk_sep_0.5/0.15_seq_0.5/0.2/20-tragF'
genmixemb_subdir = 'genmixemb-20250718_220105-bertbaseuncased-d768-mxo50-aggBrt-sub2-dsWki-tmax100-tragT'
genmixemb_subdir = 'genmixemb-20250721_083250-bertbaseuncased-d768-mxo50-aggPyr-agtDecim-stp0-lvl1-lrs2-dsWki-tmax256-tragF-nxtsnt'
genmixemb_subdir = 'genmixemb-20250721_212402-bertbaseuncased-d768-mxo50-aggPyr-agtDecim-stp2-lvl2-lrs2-dsWki-tmax512-tragT-nxtsnt'
genmixemb_subdir = 'genmixemb-20250722_213424-bertbaseuncased-d768-mxo50-aggPyr-agtDecim-stp2-lvl3-lrs2-dsWki-tmax512-tragT-nxtsnt'
genmixemb_subdir = 'genmixemb-20250820_093126-bertbaseuncased-d768-mxi384-mxo50-dsWki-nxtsnt'
model_name = 'gpt2'
genmixemb_subdir = 'genmixemb-20250825_213542-gpt2-d768-mxi448-mxo36-dsWki-sstNxttok'

# train_genmixemb_path = DATA_PATH / 'train_mllm_genmixemb_qna'
# genmixemb_subdir = 'genmixemb-20250726_122548-bertbaseuncased-d768-mxi384-mxo50-dsQna'
# genmixemb_subdir = 'genmixemb-20250809_234548-pre_genmixemb20250726122548-bertbaseuncased-d768-mxi384-mxo50-aggPyr-agtTopdot-stp2-lvl1-lrs2-dsQna-tragT-shemT-ttidF-jcqF'
# genmixemb_subdir = 'genmixemb-20250810_125920-pre_genmixemb20250726122548-bertbaseuncased-d768-mxi384-mxo50-aggBrt-sub2-agtTopdot-dsQna-tragT-shemT-ttidF-jcqF'
# genmixemb_subdir = 'genmixemb-20250813_234929-pre_genmixemb20250726122548-bertbaseuncased-d768-mxi384-mxo50-aggPyr-agtTopdot-stp2-lvl1-lrs2-dsQna-tragT-shemT-ttidF-cqprCq'
# genmixemb_subdir = 'genmixemb-20250815_220237-pre_genmixemb20250726122548-bertbaseuncased-d768-mxi384-mxo50-aggPyr-agtMxpl-stp2-lvl1-lrs2-dsQna-tragT-shemT-ttidF-cqprCq'
# genmixemb_subdir = 'genmixemb-20250817_201509-pre_genmixemb20250726122548-bertbaseuncased-d768-mxi384-mxo50-aggCnv-lvl1-lrs1-cksz3-pksz2-pst2-dsQna-tragT-ttidF-cqprCq'

genmixemb_train_path = train_genmixemb_path / genmixemb_subdir
genmixemb_snapshot_fpath = genmixemb_train_path / 'best.pth'
# genmixemb_snapshot_fpath = genmixemb_train_path / 'last.pth'

# device_name = 'cpu'
device_name = 'cuda'

device = torch.device(device_name)
print('Device:', device)

parts = genmixemb_subdir.split('-')
n_toks_max = 0
mask_cfg = None
pred_next_sent = False
self_supervise_type = SelfSuperviseType.Input
for part in parts:
    if part.startswith('mxi'):
        n_toks_max = int(part[3:])
    elif part.startswith('msk_'):
        subparts = part.split('_')
        # postfix_parts.append(f'msk_sep_{sep_freq}/{sep_frac}_seq_{seq_freq}/{seq_max_frac}/{mask_cfg.seq_max_len}')
        sep_part, seq_part = subparts[2], subparts[4]
        sep_freq, sep_frac = sep_part.split('/')
        seq_freq, seq_max_frac, seq_max_len = seq_part.split('/')
        sep_freq, sep_frac = float(sep_freq), float(sep_frac)
        seq_freq, seq_max_frac, seq_max_len = float(seq_freq), float(seq_max_frac), int(seq_max_len)
        mask_cfg = MaskCfg(
            sep_freq=sep_freq, sep_frac=sep_frac, seq_freq=seq_freq, seq_max_frac=seq_max_frac,
            seq_max_len=seq_max_len,
        )
    elif part.startswith('sst'):
        self_supervise_type = SelfSuperviseType(part[3:].lower())
        

print(f'n_toks_max = {n_toks_max}')
print('Mask cfg:', mask_cfg)
print(f'pred_next_sent = {pred_next_sent}')
print(f'self_supervise_type = {self_supervise_type}')

batch_size = 5


Device: cuda
n_toks_max = 448
Mask cfg: None
pred_next_sent = False
self_supervise_type = nxttok


In [28]:
model_cfg = parse_yaml_file_as(GenmixembCfg, genmixemb_train_path / GENMIXEMB_BERT_MODEL_CFG_FNAME)
pprint(model_cfg.dict())

{'add_token_type_ids': False,
 'bert_agg_n_subseq_toks': 0,
 'bert_agg_type': <BertAggType.Topdot: 'topdot'>,
 'cnv_conv_kernel_size': 3,
 'cnv_n_layers_per_level': 1,
 'cnv_n_levels': 0,
 'cnv_pool_kernel_size': 2,
 'cnv_pool_stride': 2,
 'cnv_share_layer_weights': False,
 'ctx_que_prompt_type': <CtxQuePromptType.Cq: 'cq'>,
 'd_model': 768,
 'join_ctx_que_agg': False,
 'max_inp_toks': 448,
 'max_out_toks': 36,
 'model_name': 'gpt2',
 'pyr_agg_n_layers_per_level': 2,
 'pyr_agg_n_levels': 2,
 'pyr_agg_step': 2,
 'pyr_agg_type': <HgReductType.MaxPool: 'mxpl'>,
 'pyr_share_layer_weights': True,
 'share_agg_enc_token_embeds': True,
 'toks_agg_type': <TokensAggType.Conv: 'cnv'>,
 'train_agg_model': True}


/tmp/ipykernel_104703/449381411.py:2: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  pprint(model_cfg.dict())


## Load models and dataset
### Model

In [29]:
model = Genmixemb(model_cfg, device=device)
tkz = model.tkz

In [30]:
print(f'Load {genmixemb_snapshot_fpath}')
checkpoint = torch.load(genmixemb_snapshot_fpath, map_location=device)
model.load_state_dict(checkpoint['model'], strict=True)
del checkpoint
model.eval()
None

Load /home/misha/data/train_mllm_genmixemb_wki/genmixemb-20250825_213542-gpt2-d768-mxi448-mxo36-dsWki-sstNxttok/best.pth


## Wiki dataset
### Loading

In [8]:
wiki_ds_name, wiki_ds_subdir = '20220301.en', 'wikipedia'
dss_wiki = load_dataset(wiki_ds_subdir, wiki_ds_name, cache_dir=str(DATA_PATH))
ds_wiki = dss_wiki['train']
n_docs = len(ds_wiki)
print(f'Wikipedia {wiki_ds_name} docs: {n_docs}')

Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

Wikipedia 20220301.en docs: 6458670


In [9]:
def get_wiki_batch(i_batch: int, batch_size: int = batch_size) -> WikiBatch:
    i1 = i_batch * batch_size
    i2 = i1 + batch_size
    items = []
    for i in range(i1, i2):
        row = ds_wiki[i]
        wiki_item = WikiItem(
            tkz=tkz, ind=i, title=row['title'], text=row['text'], max_len=n_toks_max, mask_cfg=mask_cfg,
            self_supervise_type=self_supervise_type, max_pred_len=model_cfg.max_out_toks,
        )
        items.append(wiki_item)
    batch = WikiBatch(items=items, device=device)
    return batch

### Wiki inference

In [10]:
i_batch = 0
batch = get_wiki_batch(i_batch)
# [n_batch, max_len]
b_toks, b_masked_toks, b_mask, b_tgt_toks = batch.get_tensors()

Token indices sequence length is longer than the specified maximum sequence length for this model (8287 > 1024). Running this sequence through the model will result in indexing errors


In [60]:
with_mask = False
# with_mask = True
for i in range(batch_size):
    toks, masked_toks, mask = b_toks[i], b_masked_toks[i], b_mask[i]
    tgt_toks = None
    if b_tgt_toks is not None:
        tgt_toks = b_tgt_toks[i]
    toks_inp = masked_toks if with_mask else toks
    # toks_inp = toks_inp[:100]
    toks_out = model.gen_on_wiki(toks=toks_inp)
    toks_out = toks_out.squeeze(0)
    toks_out = toks_out[len(toks_inp)-2:]  # remove prompt
    s_inp = tkz.decode(toks_inp)
    s_out = tkz.decode(toks_out)
    if with_mask:
        s_src = tkz.decode(toks)
        print(f'{i:03d}. Msk: {s_src}')
    s_inp = s_inp.replace('\n', '\\n')
    print(f'{i:03d}. Inp: {s_inp}')
    if tgt_toks is not None:
        s_tgt = tkz.decode(tgt_toks)
        s_tgt = s_tgt.replace('\n', '\\n')
        print(f'{i:03d}. Tgt: {s_tgt}')
    s_out = s_out.replace('\n', '\\n')
    print(f'{i:03d}. Out: {s_out}')
    # break

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


000. Inp: oy and Herbert Read stated that the border between the artist and the non-artist, what separates art from a daily act, is a construct produced by the alienation<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


001. Inp:  based on observation of children. The Diagnostic interview for social and communication disorders (DISCO) may also be used.\n\nA pediatrician commonly performs a preliminary investigation by taking developmental history and physically examining the child. If warranted, diagnosis and evaluations are conducted with help from ASD specialists, observing and assessing cognitive, communication, family, and other factors using standardized tools, and taking into account any associated medical conditions. A pediatric neuropsychologist is often asked to assess behavior and cognitive skills, both to aid diagnosis and to help recommend educational interventions. A differential diagnosis for ASD at this stage might also consider intellectual disability, hearing impairment, and a specific language impairment such as Landau–Kleffner syndrome. The presence of autism can make it harder to diagnose coexisting psychiatric disorders such as depression.\n\nClinical genetics evaluations are ofte

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


002. Inp:  Earth's surface. These factors vary with atmospheric composition, geographic location, and time (see position of the Sun). While bi-hem<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endofte

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


003. Inp:  form resembling the Greek letter tau in the hands of medieval Irish and English writers. The Roman form is used in most printed material; it consists of a small loop with an arc over it ("a"). Both derive from the majuscule (capital) form. In Greek handwriting, it was common to join the left leg and horizontal stroke into a single loop, as demonstrated by the uncial version shown. Many fonts then made the right leg vertical. In some of these, the serif that began the right leg stroke developed into an arc, resulting in the printed form, while in others it was dropped, resulting in the modern handwritten form. Graphic designers refer to the Italic and Roman forms as "single decker a" and "double decker a" respectively.\n\nItalic type is commonly used to mark emphasis or more generally to distinguish one part of a text from the rest (set in Roman type). There are some other cases aside from italic type where script a ("ɑ"), also called Latin alpha, is used in contrast with Lat

In [64]:
from torch import mode, norm
def norm_str(s: str) -> str:
    s = s.replace('\n', ' ').replace('\r', ' ')
    return s

for i in range(batch_size):
    item = batch.items[i]
    toks_in = item.src_toks[:model_cfg.max_inp_toks]
    s_inp = tkz.decode(toks_in)
    print(f'{i:03d}. Inp: {norm_str(s_inp)}')
    tooks_tgt = item.src_toks[model_cfg.max_inp_toks:model_cfg.max_inp_toks + model_cfg.max_out_toks]
    s_tgt = tkz.decode(tooks_tgt)
    print(f'{i:03d}. Tgt: {norm_str(s_tgt)}')
    toks_in = torch.tensor(toks_in, dtype=torch.long, device=device)
    toks_out = model.gen_on_wiki(toks_in)
    toks_out = toks_out.squeeze()[len(toks_in)-2:]
    s_out = tkz.decode(toks_out)
    print(f'{i:03d}. Out: {norm_str(s_out)}')


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


000. Inp: Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy. Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful. As a historically left-wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian Marxism as the libertarian wing (libertarian socialism) of the socialist movement, and has a strong historical association with anti-capitalism and socialism.  Humans lived in societies without formal hierarchies long before the establishment of formal states, realms, or empires. With the rise of organised hierarchical bodies, scepticism toward authority also rose. Although traces of anarchist thought are found throughout history, modern anarchism emerged from the Enlightenment. During the latter half of the 19th and the first decades of the 20th century, the anarchist movement flou

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


000. Out:  The suffix arkhos is a Greek word that means "leader" or "ruler" in Greek and Latin.  In popular culture  The character Arthos, played by
001. Inp: Autism is a neurodevelopmental disorder characterized by difficulties with social interaction and communication, and by restricted and repetitive behavior. Parents often notice signs during the first three years of their child's life. These signs often develop gradually, though some autistic children experience regression in their communication and social skills after reaching developmental milestones at a normal pace.  Autism is associated with a combination of genetic and environmental factors. Risk factors during pregnancy include certain infections, such as rubella, toxins including valproic acid, alcohol, cocaine, pesticides, lead, and air pollution, fetal growth restriction, and autoimmune diseases. Controversies surround other proposed environmental causes; for example, the vaccine hypothesis, which has been disproven. Aut

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


001. Out:  condition.  References  Autism Diagnosis Autism spectrum disorder Diagnosis Autism spectrum disorders Autistic personality disorder Autism Autism in
002. Inp: Albedo (; ) is the measure of the diffuse reflection of solar radiation out of the total solar radiation and measured on a scale from 0, corresponding to a black body that absorbs all incident radiation, to 1, corresponding to a body that reflects all incident radiation.  Surface albedo is defined as the ratio of radiosity Je to the irradiance Ee (flux per unit area) received by a surface. The proportion reflected is not only determined by properties of the surface itself, but also by the spectral and angular distribution of solar radiation reaching the Earth's surface. These factors vary with atmospheric composition, geographic location, and time (see position of the Sun). While bi-hemispherical reflectance is calculated for a single angle of incidence (i.e., for a given position of the Sun), albedo is the directional

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


002. Out:  Any surface albedo that can be measured on the scale from 0, 0, 0, and 1 corresponds to a terrestrial albedo, and a terrestrial albedo that is measured
003. Inp: A, or a, is the first letter and the first vowel of the modern English alphabet and the ISO basic Latin alphabet. Its name in English is a (pronounced ), plural aes. It is similar in shape to the Ancient Greek letter alpha, from which it derives. The uppercase version consists of the two slanting sides of a triangle, crossed in the middle by a horizontal bar. The lowercase version can be written in two forms: the double-storey a and single-storey ɑ. The latter is commonly used in handwriting and fonts based on it, especially fonts intended to be read by children, and is also found in italic type.  In the English grammar, "a", and its variant "an", are indefinite articles.  History  The earliest certain ancestor of "A" is aleph (also written 'aleph), the first letter of the Phoenician alphabet, which consisted entire

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


003. Out: can alphabet, but it was replaced by an alternative, the Roman one.  See also  References  External links  A A  Ancient Greek letters Latin letters
004. Inp: Alabama () is a state in the Southeastern region of the United States, bordered by Tennessee to the north; Georgia to the east; Florida and the Gulf of Mexico to the south; and Mississippi to the west. Alabama is the 30th largest by area and the 24th-most populous of the U.S. states. With a total of  of inland waterways, Alabama has among the most of any state.  Alabama is nicknamed the Yellowhammer State, after the state bird. Alabama is also known as the "Heart of Dixie" and the "Cotton State". The state tree is the longleaf pine, and the state flower is the camellia. Alabama's capital is Montgomery, and its largest city by population and area is Huntsville. Its oldest city is Mobile, founded by French colonists in 1702 as the capital of French Louisiana. Greater Birmingham is Alabama's largest metropolitan area and it

### Custom text

In [39]:
context = '''
Lithuania,[b] officially the Republic of Lithuania,[c] is a country in the Baltic region of Europe.[d] It is one of three Baltic states and lies on the eastern shore of the Baltic Sea, bordered by Latvia to the north, Belarus to the east and south, Poland to the south, and the Russian semi-exclave of Kaliningrad Oblast to the southwest, with a maritime border with Sweden to the west. Lithuania covers an area of 65,300 km2 (25,200 sq mi), with a population of 2.89 million. Its capital and largest city is Vilnius; other major cities include Kaunas, Klaipėda, Šiauliai and Panevėžys. Lithuanians who are the titular nation and form the majority of the country's population, belong to the ethnolinguistic group of Balts and speak Lithuanian. For millennia, the southeastern shores of the Baltic Sea were inhabited by various Baltic tribes. In the 1230s, Lithuanian lands were united for the first time by Mindaugas, who formed the Kingdom of Lithuania on 6 July 1253. Subsequent expansion and consolidation resulted in the Grand Duchy of Lithuania, which by the 14th century was the largest country in Europe. In 1386, the Grand Duchy entered into a de facto personal union with the Crown of the Kingdom of Poland. The two realms were united into the bi-confederal Polish-Lithuanian Commonwealth in 1569, forming one of the largest and most prosperous states in Europe. The Commonwealth lasted more than two centuries, until neighbouring countries gradually dismantled it between 1772 and 1795, with the Russian Empire annexing most of Lithuania's territory. Towards the end of World War I, Lithuania declared independence in 1918, founding the modern Republic of Lithuania. In World War II, Lithuania was occupied by the Soviet Union, then by Nazi Germany, before being reoccupied by the Soviets in 1944. Lithuanian armed resistance to the Soviet occupation lasted until the early 1950s. On 11 March 1990, a year before the formal dissolution of the Soviet Union, Lithuania became the first Soviet republic to break away when it proclaimed the restoration of its independence. Lithuania is a developed country with a high-income and an advanced economy ranking very high in Human Development Index.'''
context2 = 'Lithuania ranks highly in digital infrastructure,[25][26] press freedom and happiness.[27] Lithuania is a member of the United Nations, the European Union, the Council of Europe, the Council of the Baltic Sea States, the Eurozone, the Nordic Investment Bank, the International Monetary Fund, the Schengen Agreement, NATO, OECD and the World Trade Organization. It also participates in the Nordic-Baltic Eight (NB8) regional co-operation format.'

question = 'When Lithuania declared its independence and from whom?'
# question = 'When Lithuania declared its independence?'
question = 'When Kingdom of Lithuania was formed?'
# question = 'What is the capital of Lithuania?'
text = f'''
Context: {context.strip()}. Question: {question}. Answer: 
'''
# text = f'''
# <context>{context.strip()}</context> <question>{question}</question> <answer> The capital of Lithuania is
# '''
text = text.strip()
title = 'Lithuania'

In [46]:
toks_in = tkz(text).input_ids
print(len(toks_in), toks_in)
toks_in = torch.tensor(toks_in, dtype=torch.long, device=device)
toks_out = model.gen_on_wiki(toks_in)
toks_out = toks_out.squeeze()[len(toks_in)-2:]
s_out = tkz.decode(toks_out)
s_out

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


492 [21947, 25, 31295, 17414, 65, 60, 8720, 262, 2066, 286, 31295, 17414, 66, 60, 318, 257, 1499, 287, 262, 32882, 3814, 286, 2031, 3693, 67, 60, 632, 318, 530, 286, 1115, 32882, 2585, 290, 7363, 319, 262, 10183, 15191, 286, 262, 32882, 6896, 11, 275, 24071, 416, 35794, 284, 262, 5093, 11, 33368, 284, 262, 7627, 290, 5366, 11, 12873, 284, 262, 5366, 11, 290, 262, 3394, 10663, 12, 1069, 44281, 286, 12612, 3191, 6335, 1835, 12957, 284, 262, 26283, 11, 351, 257, 30017, 4865, 351, 10710, 284, 262, 7421, 13, 31295, 8698, 281, 1989, 286, 6135, 11, 6200, 10571, 17, 357, 1495, 11, 2167, 19862, 21504, 828, 351, 257, 3265, 286, 362, 13, 4531, 1510, 13, 6363, 3139, 290, 4387, 1748, 318, 34037, 77, 3754, 26, 584, 1688, 4736, 2291, 509, 1942, 292, 11, 509, 5031, 541, 128, 245, 6814, 11, 25370, 254, 544, 43640, 72, 290, 350, 1531, 85, 128, 245, 129, 122, 893, 13, 49033, 1547, 508, 389, 262, 5259, 934, 3277, 290, 1296, 262, 3741, 286, 262, 1499, 338, 3265, 11, 5594, 284, 262, 33961, 349, 6680, 2569, 

' Answer: In the early 1990s, Lithuania was a member of the European Union.\n\nReferences\n\nExternal links\nOfficial website\n\nLithuania\nLithuania\nLithuania\nLithuania\nLithuania\nLithuania'

In [36]:
print(len(toks_in), len(toks_out))

372 0


In [31]:
tkz.decode(toks_in)

"Context: Lithuania,[b] officially the Republic of Lithuania,[c] is a country in the Baltic region of Europe.[d] It is one of three Baltic states and lies on the eastern shore of the Baltic Sea, bordered by Latvia to the north, Belarus to the east and south, Poland to the south, and the Russian semi-exclave of Kaliningrad Oblast to the southwest, with a maritime border with Sweden to the west. Lithuania covers an area of 65,300 km2 (25,200 sq mi), with a population of 2.89 million. Its capital and largest city is Vilnius; other major cities include Kaunas, Klaipėda, Šiauliai and Panevėžys. Lithuanians who are the titular nation and form the majority of the country's population, belong to the ethnolinguistic group of Balts and speak Lithuanian. For millennia, the southeastern shores of the Baltic Sea were inhabited by various Baltic tribes. In the 1230s, Lithuanian lands were united for the first time by Mindaugas, who formed the Kingdom of Lithuania on 6 July 1253. Subsequent expansion

### Try nltk for sentence splitting

In [8]:
import nltk
from nltk.tokenize import sent_tokenize
from timeit import default_timer as timer

nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /home/misha/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [101]:
item = ds_wiki[0]
txt = item['text']
s = txt.split('\n')[0]
print(len(txt), len(s), s)

43985 559 Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy. Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful. As a historically left-wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian Marxism as the libertarian wing (libertarian socialism) of the socialist movement, and has a strong historical association with anti-capitalism and socialism.


In [97]:
sents = sent_tokenize(s)
for sent in sents:
    print(sent)

Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy.
Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful.
As a historically left-wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian Marxism as the libertarian wing (libertarian socialism) of the socialist movement, and has a strong historical association with anti-capitalism and socialism.


In [100]:
N = 1000
t1 = timer()
for i in range(N):
    sents = sent_tokenize(txt)
t2 = timer()
delta = (t2 - t1) / N
print(f'sent_tokenize avg time: {delta:0.6f}')


sent_tokenize avg time: 0.005308


## Squad v2 dataset
### Loading

In [9]:
df_squad = get_squadv2_df(exclude_empty_answers=True)
print(len(df_squad))
print(df_squad.head())

Remove empty answers from dataset squad_v2. Size: 142192 --> 92749
92749
                         id    title  \
0  56be85543aeaaa14008c9063  Beyoncé   
1  56be85543aeaaa14008c9065  Beyoncé   
2  56be85543aeaaa14008c9066  Beyoncé   
3  56bf6b0f3aeaaa14008c9601  Beyoncé   
4  56bf6b0f3aeaaa14008c9602  Beyoncé   

                                             context  \
0  Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...   
1  Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...   
2  Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...   
3  Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...   
4  Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...   

                                            question  \
0           When did Beyonce start becoming popular?   
1  What areas did Beyonce compete in when she was...   
2  When did Beyonce leave Destiny's Child and bec...   
3      In what city and state did Beyonce  grow up?    
4         In which decade did Beyonce become famous?   

            

In [10]:
from mllm.data.itsquadv2 import QnaItemV2, QnaBatchV2


def get_qna_batch(i_batch: int, batch_size: int = batch_size) -> QnaBatchV2:
    i1 = i_batch * batch_size
    i2 = i1 + batch_size
    items = []
    for i in range(i1, i2):
        row = df_squad.iloc[i]
        answer = row.answers['text'][0]
        qna_item = QnaItemV2(
            tkz=tkz, ind=i, context=row.context, question=row.question, answer=answer, max_ctx_toks=n_toks_max, max_que_toks=model_cfg.max_inp_toks,
            max_ans_toks=model_cfg.max_out_toks,
        )
        items.append(qna_item)
    batch = QnaBatchV2(
        items=items, max_inp_len=n_toks_max, max_out_len=model_cfg.max_out_toks, device=device,
    )
    return batch

### QnA inference

In [15]:
i_batch = 33
batch = get_qna_batch(i_batch)
# [n_batch, max_len]
b_ctx_toks, b_que_toks, b_ans_toks, b_cq_toks = batch.get_tensors()
b_ctx_toks.shape, b_que_toks.shape, b_ans_toks.shape, b_cq_toks.shape

(torch.Size([5, 198]),
 torch.Size([5, 18]),
 torch.Size([5, 6]),
 torch.Size([5, 219]))

In [16]:
for i in range(batch_size):
    ctx_toks, que_toks, ans_toks, cq_toks = b_ctx_toks[i], b_que_toks[i], b_ans_toks[i], b_cq_toks[i]
    tgt_toks = ans_toks
    toks_inp = cq_toks
    toks_out = model.gen_on_qna(ctx_toks=ctx_toks, que_toks=que_toks, cq_toks=cq_toks)
    toks_out = toks_out.squeeze(0)
    s_inp = tkz.decode(toks_inp)
    s_out = tkz.decode(toks_out)
    s_tgt = tkz.decode(tgt_toks)
    print(f'{i:03d}. Inp: {s_inp}')
    print(f'{i:03d}. Tgt: {s_tgt}')
    print(f'{i:03d}. Out: {s_out}')
    # break

000. Inp: [CLS] her first acting role of 2006 was in the comedy film the pink panther starring opposite steve martin, grossing $ 158. 8 million at the box office worldwide. her second film dreamgirls, the film version of the 1981 broadway musical loosely based on the supremes, received acclaim from critics and grossed $ 154 million internationally. in it, she starred opposite jennifer hudson, jamie foxx, and eddie murphy playing a pop singer based on diana ross. to promote the film, beyonce released " listen " as the lead single from the soundtrack album. in april 2007, beyonce embarked on the beyonce experience, her first worldwide concert tour, visiting 97 venues and grossed over $ 24 million. [ note 1 ] beyonce conducted pre - concert food donation drives during six major stops in conjunction with her pastor at st. john's and america's second harvest. at the same time, b'day was re - released with five additional songs, including her duet with shakira " beautiful liar ". [SEP] how m

### Custom text

In [20]:
context = '''
Lithuania,[b] officially the Republic of Lithuania,[c] is a country in the Baltic region of Europe.[d] It is one of three Baltic states and lies on the eastern shore of the Baltic Sea, bordered by Latvia to the north, Belarus to the east and south, Poland to the south, and the Russian semi-exclave of Kaliningrad Oblast to the southwest, with a maritime border with Sweden to the west. Lithuania covers an area of 65,300 km2 (25,200 sq mi), with a population of 2.89 million. Its capital and largest city is Vilnius; other major cities include Kaunas, Klaipėda, Šiauliai and Panevėžys. Lithuanians who are the titular nation and form the majority of the country's population, belong to the ethnolinguistic group of Balts and speak Lithuanian. For millennia, the southeastern shores of the Baltic Sea were inhabited by various Baltic tribes. In the 1230s, Lithuanian lands were united for the first time by Mindaugas, who formed the Kingdom of Lithuania on 6 July 1253. Subsequent expansion and consolidation resulted in the Grand Duchy of Lithuania, which by the 14th century was the largest country in Europe. In 1386, the Grand Duchy entered into a de facto personal union with the Crown of the Kingdom of Poland. The two realms were united into the bi-confederal Polish-Lithuanian Commonwealth in 1569, forming one of the largest and most prosperous states in Europe. The Commonwealth lasted more than two centuries, until neighbouring countries gradually dismantled it between 1772 and 1795, with the Russian Empire annexing most of Lithuania's territory. Towards the end of World War I, Lithuania declared independence in 1918, founding the modern Republic of Lithuania. In World War II, Lithuania was occupied by the Soviet Union, then by Nazi Germany, before being reoccupied by the Soviets in 1944. Lithuanian armed resistance to the Soviet occupation lasted until the early 1950s. On 11 March 1990, a year before the formal dissolution of the Soviet Union, Lithuania became the first Soviet republic to break away when it proclaimed the restoration of its independence. Lithuania is a developed country with a high-income and an advanced economy ranking very high in Human Development Index.'''
context2 = 'Lithuania ranks highly in digital infrastructure,[25][26] press freedom and happiness.[27] Lithuania is a member of the United Nations, the European Union, the Council of Europe, the Council of the Baltic Sea States, the Eurozone, the Nordic Investment Bank, the International Monetary Fund, the Schengen Agreement, NATO, OECD and the World Trade Organization. It also participates in the Nordic-Baltic Eight (NB8) regional co-operation format.'

question = 'When Lithuania declared its independence?'
question = 'What is an area of Lithuania?'
# question = 'What is a population of Lithuania?'
question = 'What is a capital of Lithuania?'
text = f'''
[CLS] {context.strip()} [SEP] {question} [SEP]: 
'''
text = text.strip()

In [None]:
toks_in = tkz(text).input_ids
toks_in = torch.tensor(toks_in, dtype=torch.long, device=device)
ctx_toks = tkz(context, return_tensors='pt', add_special_tokens=False).input_ids
que_toks = tkz(question, return_tensors='pt', add_special_tokens=False).input_ids
toks_out = model.gen_on_qna(ctx_toks=ctx_toks, que_toks=que_toks)
s_out = tkz.decode(toks_out.squeeze())
s_out

In [20]:
len(toks_in), len(toks_out)

(472, 1)

In [21]:
torch.rand((3, 2))

tensor([[0.0345, 0.3600],
        [0.7277, 0.2924],
        [0.9627, 0.5792]])