In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
from dataclasses import dataclass
import io
import json
import os
from pathlib import Path
from pprint import pprint
import re
import requests
import sys
from typing import Optional

if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
import numpy as np
import pandas as pd
from pydantic_yaml import parse_yaml_file_as, to_yaml_file
import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer, AutoTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions

from mllm.config.model import GenmixTrainDsType, TokensAggType, GenmixembBertCfg, copy_override_genmixemb_bert_cfg, \
    gen_prefpostfix_genmixemb_bert
from mllm.exp.args import GENMIXEMB_BERT_MODEL_CFG_FNAME, create_bool_str_field, is_arg_true
from mllm.model.genmixemb_bert import GenmixembBert
from mllm.train.mask_utils import MaskCfg
from mllm.train.utils import find_create_train_path, log_weights_grads_stats, SumTuple, QnaTuple
from mllm.data.wiki.itwiki import WikiItem, get_wiki_batch_iterators, WikiBatch
from mllm.utils.utils import rethrow
from mllm.data.utils import get_squadv2_df


# BERT Generator model inference
## Configs and paths

In [28]:
from mllm.model import genmixemb_bert


# Root directory for datasets and training artifacts (~/data).
# Path.home() is used instead of expanding '$HOME', which is left as the literal text '$HOME'
# when that environment variable is not set.
DATA_PATH = Path.home() / 'data'

bert_model_name = 'bert-base-uncased'
# NOTE(review): random_seed is not referenced by any other visible cell - confirm whether seeding is intended.
random_seed = 111
train_genmixemb_bert_path = DATA_PATH / 'train_mllm_genmixembbert_wki'
# Training run to load. Earlier runs are kept as commented-out alternatives (same convention as the
# QnA runs below) so that exactly one assignment is live.
# genmixemb_subdir = 'genmixemb-20250713_202718-bertbaseuncased-d768-mxo50-aggBrt-sub0-dsWki-tmax100-tragF'
# genmixemb_subdir = 'genmixemb-20250715_035750-bertbaseuncased-d768-mxo50-aggBrt-sub2-dsWki-tmax100-tragT'
# genmixemb_subdir = 'genmixemb-20250716_213252-bertbaseuncased-d768-mxo50-aggBrt-sub0-dsWki-tmax100-msk_sep_0.5/0.15_seq_0.5/0.2/20-tragF'
# genmixemb_subdir = 'genmixemb-20250718_220105-bertbaseuncased-d768-mxo50-aggBrt-sub2-dsWki-tmax100-tragT'
# genmixemb_subdir = 'genmixemb-20250721_083250-bertbaseuncased-d768-mxo50-aggPyr-agtDecim-stp0-lvl1-lrs2-dsWki-tmax256-tragF-nxtsnt'
# genmixemb_subdir = 'genmixemb-20250721_212402-bertbaseuncased-d768-mxo50-aggPyr-agtDecim-stp2-lvl2-lrs2-dsWki-tmax512-tragT-nxtsnt'
# genmixemb_subdir = 'genmixemb-20250722_213424-bertbaseuncased-d768-mxo50-aggPyr-agtDecim-stp2-lvl3-lrs2-dsWki-tmax512-tragT-nxtsnt'
genmixemb_subdir = 'genmixemb-20250820_093126-bertbaseuncased-d768-mxi384-mxo50-dsWki-nxtsnt'

# train_genmixemb_bert_path = DATA_PATH / 'train_mllm_genmixembbert_qna'
# genmixemb_subdir = 'genmixemb-20250726_122548-bertbaseuncased-d768-mxi384-mxo50-dsQna'
# genmixemb_subdir = 'genmixemb-20250809_234548-pre_genmixemb20250726122548-bertbaseuncased-d768-mxi384-mxo50-aggPyr-agtTopdot-stp2-lvl1-lrs2-dsQna-tragT-shemT-ttidF-jcqF'
# genmixemb_subdir = 'genmixemb-20250810_125920-pre_genmixemb20250726122548-bertbaseuncased-d768-mxi384-mxo50-aggBrt-sub2-agtTopdot-dsQna-tragT-shemT-ttidF-jcqF'
# genmixemb_subdir = 'genmixemb-20250813_234929-pre_genmixemb20250726122548-bertbaseuncased-d768-mxi384-mxo50-aggPyr-agtTopdot-stp2-lvl1-lrs2-dsQna-tragT-shemT-ttidF-cqprCq'
# genmixemb_subdir = 'genmixemb-20250815_220237-pre_genmixemb20250726122548-bertbaseuncased-d768-mxi384-mxo50-aggPyr-agtMxpl-stp2-lvl1-lrs2-dsQna-tragT-shemT-ttidF-cqprCq'
# genmixemb_subdir = 'genmixemb-20250817_201509-pre_genmixemb20250726122548-bertbaseuncased-d768-mxi384-mxo50-aggCnv-lvl1-lrs1-cksz3-pksz2-pst2-dsQna-tragT-ttidF-cqprCq'

# Directory of the selected run and the weights snapshot inside it.
genmixemb_train_path = train_genmixemb_bert_path / genmixemb_subdir
genmixemb_snapshot_fpath = genmixemb_train_path / 'best.pth'
# genmixemb_snapshot_fpath = genmixemb_train_path / 'last.pth'

# Device used for the model and for all tensors built in this notebook.
device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print('Device:', device)

# Recover run settings that are encoded in the dash-separated training sub-directory name.
# Defaults apply when the corresponding name part is absent.
n_toks_max = 0
mask_cfg = None
pred_next_sent = False
for name_part in genmixemb_subdir.split('-'):
    if name_part == 'nxtsnt':
        pred_next_sent = True
        continue
    if name_part.startswith('mxi'):
        # 'mxi<N>': the integer that follows the three-letter prefix.
        n_toks_max = int(name_part[3:])
        continue
    if not name_part.startswith('msk_'):
        continue
    # Layout: msk_sep_{sep_freq}/{sep_frac}_seq_{seq_freq}/{seq_max_frac}/{seq_max_len}
    fields = name_part.split('_')
    sep_spec, seq_spec = fields[2], fields[4]
    sep_freq_str, sep_frac_str = sep_spec.split('/')
    seq_freq_str, seq_max_frac_str, seq_max_len_str = seq_spec.split('/')
    mask_cfg = MaskCfg(
        sep_freq=float(sep_freq_str),
        sep_frac=float(sep_frac_str),
        seq_freq=float(seq_freq_str),
        seq_max_frac=float(seq_max_frac_str),
        seq_max_len=int(seq_max_len_str),
    )

print(f'n_toks_max = {n_toks_max}')
print('Mask cfg:', mask_cfg)
print(f'pred_next_sent = {pred_next_sent}')

# Number of items per batch in the inference cells below.
batch_size = 5


Device: cpu
n_toks_max = 384
Mask cfg: None
pred_next_sent = True


In [29]:
# Read the model configuration stored in the selected training directory and show its fields.
model_cfg_fpath = genmixemb_train_path / GENMIXEMB_BERT_MODEL_CFG_FNAME
model_cfg = parse_yaml_file_as(GenmixembBertCfg, model_cfg_fpath)
pprint(model_cfg.dict())

{'add_token_type_ids': False,
 'bert_agg_n_subseq_toks': 8,
 'bert_agg_type': <BertAggType.Topdot: 'topdot'>,
 'bert_model_name': 'bert-base-uncased',
 'cnv_conv_kernel_size': 3,
 'cnv_n_layers_per_level': 1,
 'cnv_n_levels': 0,
 'cnv_pool_kernel_size': 2,
 'cnv_pool_stride': 2,
 'cnv_share_layer_weights': False,
 'ctx_que_prompt_type': <CtxQuePromptType.Cq: 'cq'>,
 'd_model': 768,
 'join_ctx_que_agg': False,
 'max_inp_toks': 384,
 'max_out_toks': 50,
 'pyr_agg_n_layers_per_level': 2,
 'pyr_agg_n_levels': 2,
 'pyr_agg_step': 2,
 'pyr_agg_type': <HgReductType.MaxPool: 'mxpl'>,
 'pyr_share_layer_weights': True,
 'share_agg_enc_token_embeds': True,
 'toks_agg_type': <TokensAggType.Conv: 'cnv'>,
 'train_agg_model': True}


## Load models and dataset
### Model

In [30]:
# Build the model from the loaded config on the selected device.
model = GenmixembBert(model_cfg, device=device)
# Reuse the model's own tokenizer for all encoding/decoding below.
tkz = model.tkz

You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
Some weights of BertGenerationDecoder were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossatte

In [31]:
# Restore the trained weights from the selected snapshot and switch the model to eval mode.
print(f'Load {genmixemb_snapshot_fpath}')
model_state = torch.load(genmixemb_snapshot_fpath, map_location=device)['model']
model.load_state_dict(model_state, strict=True)
# Drop the loaded state once it has been copied into the model.
del model_state
# Trailing semicolon suppresses the cell output of eval().
model.eval();

Load /home/misha/data/train_mllm_genmixembbert_wki/genmixemb-20250820_093126-bertbaseuncased-d768-mxi384-mxo50-dsWki-nxtsnt/best.pth


## Wiki dataset
### Loading

In [32]:
# Wikipedia dump loaded through `datasets.load_dataset`, cached under DATA_PATH.
wiki_ds_subdir = 'wikipedia'
wiki_ds_name = '20220301.en'
dss_wiki = load_dataset(wiki_ds_subdir, wiki_ds_name, cache_dir=str(DATA_PATH))
# Only the 'train' split is used in this notebook.
ds_wiki = dss_wiki['train']
n_docs = len(ds_wiki)
print(f'Wikipedia {wiki_ds_name} docs: {n_docs}')

Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

Wikipedia 20220301.en docs: 6458670


In [33]:
def get_wiki_batch(i_batch: int, batch_size: int = batch_size) -> WikiBatch:
    """Build the `i_batch`-th batch of consecutive Wikipedia articles.

    Rows `i_batch * batch_size` up to (but excluding) `(i_batch + 1) * batch_size` of `ds_wiki`
    are wrapped into `WikiItem`s using the notebook-level tokenizer, token limit, mask config
    and next-sentence flag, and collected into a `WikiBatch` on `device`.

    The default `batch_size` is the value the global had when this cell was executed.
    """
    def make_item(ind: int) -> WikiItem:
        # One dataset row -> one WikiItem; `ind` is the row index in ds_wiki.
        row = ds_wiki[ind]
        return WikiItem(
            tkz=tkz, ind=ind, title=row['title'], text=row['text'], max_len=n_toks_max,
            mask_cfg=mask_cfg, pred_next_sent=pred_next_sent, max_pred_len=model_cfg.max_out_toks,
        )

    start = i_batch * batch_size
    return WikiBatch(
        items=[make_item(ind) for ind in range(start, start + batch_size)],
        device=device,
    )

### Wiki inference

In [34]:
# Take the first batch of articles and unpack its tensors:
# source tokens, masked tokens, mask and (possibly None) target tokens.
i_batch = 0
batch = get_wiki_batch(i_batch)
# [n_batch, max_len]
b_toks, b_masked_toks, b_mask, b_tgt_toks = batch.get_tensors()

In [35]:
# Generate a continuation for every item of the current wiki batch and print
# the input, the target (when available) and the model output side by side.
with_mask = False
# with_mask = True
# Generation needs no gradients: no_grad keeps autograd from recording the forward passes
# (harmless if the model method already disables it).
with torch.no_grad():
    # Iterate over the rows actually present in the batch tensor rather than the global
    # batch_size, which may differ from the size the batch was built with.
    for i in range(len(b_toks)):
        toks, masked_toks = b_toks[i], b_masked_toks[i]
        tgt_toks = b_tgt_toks[i] if b_tgt_toks is not None else None
        # Feed either the masked or the original tokens, depending on the switch above.
        toks_inp = masked_toks if with_mask else toks
        toks_out = model.gen_on_wiki(toks=toks_inp).squeeze(0)
        s_inp = tkz.decode(toks_inp)
        s_out = tkz.decode(toks_out)
        if with_mask:
            # The 'Msk' line shows the original (unmasked) tokens; 'Inp' then shows the masked input.
            s_src = tkz.decode(toks)
            print(f'{i:03d}. Msk: {s_src}')
        print(f'{i:03d}. Inp: {s_inp}')
        if tgt_toks is not None:
            s_tgt = tkz.decode(tgt_toks)
            print(f'{i:03d}. Tgt: {s_tgt}')
        print(f'{i:03d}. Out: {s_out}')

000. Inp: its citizens without the presence of a state. in medieval europe, there was no anarchistic activity except some ascetic religious movements. these, and other muslim movements, later gave birth to religious anarchism. in the sasanian empire, mazdak called for an egalitarian society and the abolition of monarchy, only to be soon executed by emperor kavad i. in basra, religious sects preached against the state. in europe, various sects developed anti - state and libertarian tendencies. renewed interest in antiquity during the renaissance and in private judgment during the reformation restored elements of anti - authoritarian secularism, particularly in france. enlightenment challenges to intellectual authority ( secular and religious ) and the revolutions of the 1790s and 1848 all spurred the ideological development of what became the era of classical anarchism. modern era during the french revolution, partisan groups such as the enrages and the saw a turning point in the fermen

In [37]:
# Show the `text` attribute of the batch item at index 4 (the last one when batch_size is 5).
print(batch.items[4].text)

Alabama () is a state in the Southeastern region of the United States, bordered by Tennessee to the north; Georgia to the east; Florida and the Gulf of Mexico to the south; and Mississippi to the west. Alabama is the 30th largest by area and the 24th-most populous of the U.S. states. With a total of  of inland waterways, Alabama has among the most of any state.

Alabama is nicknamed the Yellowhammer State, after the state bird. Alabama is also known as the "Heart of Dixie" and the "Cotton State". The state tree is the longleaf pine, and the state flower is the camellia. Alabama's capital is Montgomery, and its largest city by population and area is Huntsville. Its oldest city is Mobile, founded by French colonists in 1702 as the capital of French Louisiana. Greater Birmingham is Alabama's largest metropolitan area and its economic center.

Originally home to many native tribes, present-day Alabama was a Spanish territory beginning in the sixteenth century until the French acquired it i

In [11]:
# Count how many positions of the masked batch tensor hold the tokenizer's mask token id.
(b_masked_toks == tkz.mask_token_id).sum()

tensor(0)

### Custom text

In [172]:
context = '''
Lithuania,[b] officially the Republic of Lithuania,[c] is a country in the Baltic region of Europe.[d] It is one of three Baltic states and lies on the eastern shore of the Baltic Sea, bordered by Latvia to the north, Belarus to the east and south, Poland to the south, and the Russian semi-exclave of Kaliningrad Oblast to the southwest, with a maritime border with Sweden to the west. Lithuania covers an area of 65,300 km2 (25,200 sq mi), with a population of 2.89 million. Its capital and largest city is Vilnius; other major cities include Kaunas, Klaipėda, Šiauliai and Panevėžys. Lithuanians who are the titular nation and form the majority of the country's population, belong to the ethnolinguistic group of Balts and speak Lithuanian. For millennia, the southeastern shores of the Baltic Sea were inhabited by various Baltic tribes. In the 1230s, Lithuanian lands were united for the first time by Mindaugas, who formed the Kingdom of Lithuania on 6 July 1253. Subsequent expansion and consolidation resulted in the Grand Duchy of Lithuania, which by the 14th century was the largest country in Europe. In 1386, the Grand Duchy entered into a de facto personal union with the Crown of the Kingdom of Poland. The two realms were united into the bi-confederal Polish-Lithuanian Commonwealth in 1569, forming one of the largest and most prosperous states in Europe. The Commonwealth lasted more than two centuries, until neighbouring countries gradually dismantled it between 1772 and 1795, with the Russian Empire annexing most of Lithuania's territory. Towards the end of World War I, Lithuania declared independence in 1918, founding the modern Republic of Lithuania. In World War II, Lithuania was occupied by the Soviet Union, then by Nazi Germany, before being reoccupied by the Soviets in 1944. Lithuanian armed resistance to the Soviet occupation lasted until the early 1950s. 
On 11 March 1990, a year before the formal dissolution of the Soviet Union, Lithuania became the first Soviet republic to break away when it proclaimed the restoration of its independence. Lithuania is a developed country with a high-income and an advanced economy ranking very high in Human Development Index.'''
context2 = 'Lithuania ranks highly in digital infrastructure,[25][26] press freedom and happiness.[27] Lithuania is a member of the United Nations, the European Union, the Council of Europe, the Council of the Baltic Sea States, the Eurozone, the Nordic Investment Bank, the International Monetary Fund, the Schengen Agreement, NATO, OECD and the World Trade Organization. It also participates in the Nordic-Baltic Eight (NB8) regional co-operation format.'

# Candidate questions; each assignment overrides the previous one, so only the last is used.
question = 'When Lithuania declared its independence and from whom?'
question = 'What is a capital of Lithuania?'
# First prompt variant ("Context / Question / Answer"); it is overwritten by the tagged variant right below.
text = f'''
Context: {context.strip()}. Question: {question}. Answer: 
'''
# Tagged prompt variant that ends with the beginning of the expected answer.
text = f'''
<context>{context.strip()}</context> <question>{question}</question> <answer> The capital of Lithuania is
'''
# Remove the leading/trailing newlines introduced by the triple-quoted literal.
text = text.strip()
title = 'Lithuania'

In [173]:
# Tokenize the custom prompt (tokenizer defaults), generate a continuation and decode it.
toks_in = torch.tensor(tkz(text).input_ids, dtype=torch.long, device=device)
# Generation needs no gradients: no_grad keeps autograd from recording the forward passes
# (harmless if the model method already disables it).
with torch.no_grad():
    toks_out = model.gen_on_wiki(toks_in)
s_out = tkz.decode(toks_out.squeeze())
s_out

'[CLS] in addition, the central economic development commission ( cei ) of the european union, which was established in 1946, is a separate independent state. these countries are not a separate state, but an entirely administrative unit of the voivodeship, the administrative region and'

In [171]:
# Decode the prompt tokens back to text to inspect what the tokenizer produced.
tkz.decode(toks_in)

"[CLS] < context > lithuania, [ b ] officially the republic of lithuania, [ c ] is a country in the baltic region of europe. [ d ] it is one of three baltic states and lies on the eastern shore of the baltic sea, bordered by latvia to the north, belarus to the east and south, poland to the south, and the russian semi - exclave of kaliningrad oblast to the southwest, with a maritime border with sweden to the west. lithuania covers an area of 65, 300 km2 ( 25, 200 sq mi ), with a population of 2. 89 million. its capital and largest city is vilnius ; other major cities include kaunas, klaipeda, siauliai and panevezys. lithuanians who are the titular nation and form the majority of the country's population, belong to the ethnolinguistic group of balts and speak lithuanian. for millennia, the southeastern shores of the baltic sea were inhabited by various baltic tribes. in the 1230s, lithuanian lands were united for the first time by mindaugas, who formed the kingdom of lithuania on 6 july 

### Try nltk for sentence splitting

In [8]:
import nltk
from nltk.tokenize import sent_tokenize
from timeit import default_timer as timer

# Fetch the 'punkt_tab' NLTK package (presumably the data sent_tokenize relies on - confirm for the installed nltk version).
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /home/misha/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [101]:
# Take the first Wikipedia article and cut out its text up to the first newline.
item = ds_wiki[0]
txt = item['text']
s = txt.partition('\n')[0]
# Full article length, first-paragraph length, and the paragraph itself.
print(len(txt), len(s), s)

43985 559 Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy. Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful. As a historically left-wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian Marxism as the libertarian wing (libertarian socialism) of the socialist movement, and has a strong historical association with anti-capitalism and socialism.


In [97]:
# Split the first paragraph into sentences and print one sentence per line.
sents = sent_tokenize(s)
for sent in sents:
    print(sent)

Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy.
Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful.
As a historically left-wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian Marxism as the libertarian wing (libertarian socialism) of the socialist movement, and has a strong historical association with anti-capitalism and socialism.


In [100]:
# Rough per-call cost of sentence-splitting the whole article, averaged over N calls.
N = 1000
t_start = timer()
for i_run in range(N):
    sents = sent_tokenize(txt)
t_stop = timer()
# Average wall-clock seconds per sent_tokenize call.
delta = (t_stop - t_start) / N
print(f'sent_tokenize avg time: {delta:0.6f}')


sent_tokenize avg time: 0.005308


## Squad v2 dataset
### Loading

In [9]:
# Load SQuAD v2 as a DataFrame, with rows that have empty answers excluded, and preview it.
df_squad = get_squadv2_df(exclude_empty_answers=True)
print(len(df_squad))
print(df_squad.head())

Remove empty answers from dataset squad_v2. Size: 142192 --> 92749
92749
                         id    title  \
0  56be85543aeaaa14008c9063  Beyoncé   
1  56be85543aeaaa14008c9065  Beyoncé   
2  56be85543aeaaa14008c9066  Beyoncé   
3  56bf6b0f3aeaaa14008c9601  Beyoncé   
4  56bf6b0f3aeaaa14008c9602  Beyoncé   

                                             context  \
0  Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...   
1  Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...   
2  Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...   
3  Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...   
4  Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...   

                                            question  \
0           When did Beyonce start becoming popular?   
1  What areas did Beyonce compete in when she was...   
2  When did Beyonce leave Destiny's Child and bec...   
3      In what city and state did Beyonce  grow up?    
4         In which decade did Beyonce become famous?   

            

In [10]:
from mllm.data.itsquadv2 import QnaItemV2, QnaBatchV2


def get_qna_batch(i_batch: int, batch_size: int = batch_size) -> QnaBatchV2:
    """Build the `i_batch`-th batch of consecutive SQuAD v2 question/answer items.

    Rows `i_batch * batch_size` up to (but excluding) `(i_batch + 1) * batch_size` of `df_squad`
    are turned into `QnaItemV2`s; the first entry of `answers['text']` is used as the answer.
    The items are collected into a `QnaBatchV2` on `device`.

    The default `batch_size` is the value the global had when this cell was executed.
    """
    start = i_batch * batch_size
    # Lazily pair every positional index with its DataFrame row.
    indexed_rows = ((ind, df_squad.iloc[ind]) for ind in range(start, start + batch_size))
    # NOTE(review): the question limit reuses model_cfg.max_inp_toks - confirm this is intended.
    items = [
        QnaItemV2(
            tkz=tkz, ind=ind, context=row.context, question=row.question, answer=row.answers['text'][0],
            max_ctx_toks=n_toks_max, max_que_toks=model_cfg.max_inp_toks, max_ans_toks=model_cfg.max_out_toks,
        )
        for ind, row in indexed_rows
    ]
    return QnaBatchV2(
        items=items, max_inp_len=n_toks_max, max_out_len=model_cfg.max_out_toks, device=device,
    )

### QnA inference

In [15]:
# Take one QnA batch and unpack its tensors: context, question, answer and the cq_toks tensor
# (presumably context and question combined - confirm in QnaBatchV2).
i_batch = 33
batch = get_qna_batch(i_batch)
# [n_batch, max_len]
b_ctx_toks, b_que_toks, b_ans_toks, b_cq_toks = batch.get_tensors()
# Display the shapes of the four tensors.
b_ctx_toks.shape, b_que_toks.shape, b_ans_toks.shape, b_cq_toks.shape

(torch.Size([5, 198]),
 torch.Size([5, 18]),
 torch.Size([5, 6]),
 torch.Size([5, 219]))

In [16]:
# Generate an answer for every item of the current QnA batch and print
# the input, the reference answer and the model output side by side.
# Generation needs no gradients: no_grad keeps autograd from recording the forward passes
# (harmless if the model method already disables it).
with torch.no_grad():
    # Iterate over the rows actually present in the batch tensor rather than the global
    # batch_size, which may differ from the size the batch was built with.
    for i in range(len(b_ctx_toks)):
        ctx_toks, que_toks, ans_toks, cq_toks = b_ctx_toks[i], b_que_toks[i], b_ans_toks[i], b_cq_toks[i]
        # The reference answer is the target; cq_toks is what gets printed as the input.
        tgt_toks = ans_toks
        toks_inp = cq_toks
        toks_out = model.gen_on_qna(ctx_toks=ctx_toks, que_toks=que_toks, cq_toks=cq_toks).squeeze(0)
        s_inp = tkz.decode(toks_inp)
        s_out = tkz.decode(toks_out)
        s_tgt = tkz.decode(tgt_toks)
        print(f'{i:03d}. Inp: {s_inp}')
        print(f'{i:03d}. Tgt: {s_tgt}')
        print(f'{i:03d}. Out: {s_out}')

000. Inp: [CLS] her first acting role of 2006 was in the comedy film the pink panther starring opposite steve martin, grossing $ 158. 8 million at the box office worldwide. her second film dreamgirls, the film version of the 1981 broadway musical loosely based on the supremes, received acclaim from critics and grossed $ 154 million internationally. in it, she starred opposite jennifer hudson, jamie foxx, and eddie murphy playing a pop singer based on diana ross. to promote the film, beyonce released " listen " as the lead single from the soundtrack album. in april 2007, beyonce embarked on the beyonce experience, her first worldwide concert tour, visiting 97 venues and grossed over $ 24 million. [ note 1 ] beyonce conducted pre - concert food donation drives during six major stops in conjunction with her pastor at st. john's and america's second harvest. at the same time, b'day was re - released with five additional songs, including her duet with shakira " beautiful liar ". [SEP] how m

### Custom text

In [22]:
context = '''
Lithuania,[b] officially the Republic of Lithuania,[c] is a country in the Baltic region of Europe.[d] It is one of three Baltic states and lies on the eastern shore of the Baltic Sea, bordered by Latvia to the north, Belarus to the east and south, Poland to the south, and the Russian semi-exclave of Kaliningrad Oblast to the southwest, with a maritime border with Sweden to the west. Lithuania covers an area of 65,300 km2 (25,200 sq mi), with a population of 2.89 million. Its capital and largest city is Vilnius; other major cities include Kaunas, Klaipėda, Šiauliai and Panevėžys. Lithuanians who are the titular nation and form the majority of the country's population, belong to the ethnolinguistic group of Balts and speak Lithuanian. For millennia, the southeastern shores of the Baltic Sea were inhabited by various Baltic tribes. In the 1230s, Lithuanian lands were united for the first time by Mindaugas, who formed the Kingdom of Lithuania on 6 July 1253. Subsequent expansion and consolidation resulted in the Grand Duchy of Lithuania, which by the 14th century was the largest country in Europe. In 1386, the Grand Duchy entered into a de facto personal union with the Crown of the Kingdom of Poland. The two realms were united into the bi-confederal Polish-Lithuanian Commonwealth in 1569, forming one of the largest and most prosperous states in Europe. The Commonwealth lasted more than two centuries, until neighbouring countries gradually dismantled it between 1772 and 1795, with the Russian Empire annexing most of Lithuania's territory. Towards the end of World War I, Lithuania declared independence in 1918, founding the modern Republic of Lithuania. In World War II, Lithuania was occupied by the Soviet Union, then by Nazi Germany, before being reoccupied by the Soviets in 1944. Lithuanian armed resistance to the Soviet occupation lasted until the early 1950s. 
On 11 March 1990, a year before the formal dissolution of the Soviet Union, Lithuania became the first Soviet republic to break away when it proclaimed the restoration of its independence. Lithuania is a developed country with a high-income and an advanced economy ranking very high in Human Development Index.'''
context2 = 'Lithuania ranks highly in digital infrastructure,[25][26] press freedom and happiness.[27] Lithuania is a member of the United Nations, the European Union, the Council of Europe, the Council of the Baltic Sea States, the Eurozone, the Nordic Investment Bank, the International Monetary Fund, the Schengen Agreement, NATO, OECD and the World Trade Organization. It also participates in the Nordic-Baltic Eight (NB8) regional co-operation format.'

# Candidate questions; only the last uncommented assignment takes effect.
question = 'When Lithuania declared its independence?'
question = 'What is an area of Lithuania?'
# question = 'What is a population of Lithuania?'
question = 'What is a capital of Lithuania?'
# Prompt with the [CLS]/[SEP] markers written out literally inside the string.
text = f'''
[CLS] {context.strip()} [SEP] {question} [SEP]: 
'''
# Remove the leading/trailing newlines introduced by the triple-quoted literal.
text = text.strip()

In [23]:
# Sanity check: with add_special_tokens=True, is the first id the CLS token and the last id the SEP token?
toks = tkz('hi', add_special_tokens=True).input_ids
toks, toks[0] == tkz.cls_token_id, toks[-1] == tkz.sep_token_id

([101, 7632, 102], True, True)

In [24]:
# Tokens of the whole prompt string; not passed to the model in this cell, kept for inspection.
toks_in = torch.tensor(tkz(text).input_ids, dtype=torch.long, device=device)
# Context and question are tokenized separately, without special tokens. The tokenizer returns
# its tensors without any device argument, so move them to the notebook's device explicitly;
# this keeps the cell working when device_name is switched to 'cuda'.
ctx_toks = tkz(context, return_tensors='pt', add_special_tokens=False).input_ids.to(device)
que_toks = tkz(question, return_tensors='pt', add_special_tokens=False).input_ids.to(device)
# Generation needs no gradients: no_grad keeps autograd from recording the forward passes
# (harmless if the model method already disables it).
with torch.no_grad():
    toks_out = model.gen_on_qna(ctx_toks=ctx_toks, que_toks=que_toks)
s_out = tkz.decode(toks_out.squeeze())
s_out

'[CLS] vilnius, [SEP]'

In [20]:
# Compare input and output sizes.
# NOTE(review): toks_out is squeeze()d before decoding elsewhere, so it appears to carry a leading
# dimension; len(toks_out) would then report that dimension, not the number of generated tokens - confirm.
len(toks_in), len(toks_out)

(472, 1)

In [21]:
# Scratch cell: sample a 3x2 tensor of random values; unrelated to the model inference above.
torch.rand((3, 2))

tensor([[0.0345, 0.3600],
        [0.7277, 0.2924],
        [0.9627, 0.5792]])