In [2]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell is executed so edits to imported code are picked up.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
from datasets.arrow_dataset import Dataset
import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer, AutoTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import ENCDEC_BERT_MODEL_CFG_FNAME
from mllm.model.encdec_ranker_hg import EncdecBert
from mllm.config.model import EncdecBertCfg


In [81]:
# --- Paths and run configuration ---------------------------------------------
# Path.home() resolves the home directory even when $HOME is not set; the
# previous os.path.expandvars('$HOME') would then yield the literal '$HOME'.
DATA_PATH = Path.home() / 'data'
# WIKI_DS_NAME = '20200501.en'
WIKI_DS_NAME = '20220301.en'

TRAIN_ENCDEC_BERT_PATH = DATA_PATH / 'train_mllm_encdec_bert'
# Training run directories tried so far. Only the last (uncommented) one is
# active; earlier ones are kept commented out instead of being assigned and
# immediately overwritten.
# encdec_subdir = 'encdecbert-20250131_223521-bert-base-uncased-d768-emb_cls-inp128-lrs7x1-enh_mmbb-step2-h12-dp0-t0.0'
# encdec_subdir = 'encdecbert-20250629_222704-bert-base-uncased-d768-emb_cls-inp128-lrs7x1-enh_mmbb-step2-h12-tgt_all-dp0-t0.0'
# encdec_subdir = 'encdecbert-20250701_225013-bert-base-uncased-d768-emb_cls-inp128-lrs7x1-enh_mmbb-step2-h12-tgt_allmsk-dp0-t0.0'
# encdec_subdir = 'encdecbert-20250703_225845-bert-base-uncased-d768-emb_cls-inp128-lrs7x1-enh_mmbb-step2-h12-tgt_mskseq-dp0-t0.0'
encdec_subdir = 'encdecbert-20250704_213735-bert-base-uncased-d768-emb_cls-inp128-lrs7x1-enh_mmbb-step2-h12-tgt_mskseq-dp0-t0.0'

# Files read from the selected training run directory.
encdec_train_path = TRAIN_ENCDEC_BERT_PATH / encdec_subdir
encdec_snapshot_fpath = encdec_train_path / 'best.pth'
encdec_model_cfg_fpath = encdec_train_path / ENCDEC_BERT_MODEL_CFG_FNAME

# Device toggle: swap the commented line to run on GPU.
device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print(device)

cpu


In [82]:
# Load the 'wikipedia' dataset configuration WIKI_DS_NAME, caching it under DATA_PATH.
# dss = load_dataset('wikipedia', WIKI_DS_NAME, beam_runner='DirectRunner', cache_dir=str(DATA_PATH))
dss = load_dataset('wikipedia', WIKI_DS_NAME, cache_dir=str(DATA_PATH))
# Only the 'train' split is used by the cells below.
ds: Dataset = dss['train']
n_docs = len(ds)
print(f'Wikipedia {WIKI_DS_NAME} docs: {n_docs}')

Loading dataset shards:   0%|          | 0/41 [00:02<?, ?it/s]

Wikipedia 20220301.en docs: 6458670


In [83]:
# Parse the YAML model config stored in the training run directory and build
# the tokenizer from the pretrained model name recorded in that config.
model_cfg = parse_yaml_file_as(EncdecBertCfg, encdec_model_cfg_fpath)
tkz = AutoTokenizer.from_pretrained(model_cfg.enc_bert.pretrained_model_name)
print(model_cfg)
print(tkz)

enc_bert=EncBertCfg(inp_len=128, d_model=768, pad_token_id=0, pretrained_model_name='bert-base-uncased', tokenizer_name='bert-base-uncased', emb_type=<BertEmbType.Cls: 'cls'>, emb2_tok_name='') dec_pyr=DecPyrCfg(d_model=768, n_heads=12, d_k=64, d_v=64, d_inner=3072, inp_len=128, step=2, n_layers=7, dropout_rate=0.0, n_vocab=30522, n_similar_layers=1, enhance_type=<HgEnhanceType.MatmulBeginBias: 'mmbb'>, temperature=0.0)
BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]",

In [84]:
# NOTE(review): torch.load unpickles the file; only load snapshots from a trusted source.
chkpt = torch.load(encdec_snapshot_fpath, map_location=device)
model = EncdecBert(model_cfg).to(device)
# strict=True makes load_state_dict raise on missing or unexpected keys.
strict = True
# strict = False
model.load_state_dict(chkpt['model'], strict=strict)
# The checkpoint dict is no longer needed once the weights are copied into `model`.
del chkpt
model.eval()
# Bare `None` as the last expression keeps the cell from echoing the value of model.eval().
None

In [85]:
# Fixed input length (in tokens) taken from the encoder section of the model config.
inp_len = model_cfg.enc_bert.inp_len
print('inp_len:', inp_len)

def get_batch_tokens(doc_inds: list[int], randomize: bool = False, mask_toks_len_max: int = 5) -> torch.Tensor:
    """Tokenize documents from `ds`, mask one random token span in each and pad them into a batch.

    Relies on the notebook globals `ds`, `tkz`, `inp_len` and `device`.

    Args:
        doc_inds: Indices of the documents in `ds` to put into the batch.
        randomize: When a document has more than `inp_len` tokens, pick the kept
            window at a random offset instead of taking it from the start.
        mask_toks_len_max: Upper bound (inclusive) for the randomly drawn length
            of the masked span.

    Returns:
        Integer tensor of shape (len(doc_inds), inp_len) on `device`, padded with
        `tkz.pad_token_id` and with one span per row set to `tkz.mask_token_id`.

    Side effects:
        Prints the decoded text of each masked span. Uses the global NumPy RNG,
        so results are only reproducible if `np.random.seed` is set by the caller.
    """
    docs_toks = np.full((len(doc_inds), inp_len), tkz.pad_token_id)
    for i, doc_ind in enumerate(doc_inds):
        # Only the lower-cased article text is used; the title is not part of the input.
        doc_txt = ds[doc_ind]['text'].lower()
        # Always work on an ndarray copy: assigning a scalar to a slice of a plain
        # Python list raises TypeError, which broke documents with <= inp_len tokens.
        doc_toks = np.asarray(tkz(doc_txt)['input_ids'], dtype=docs_toks.dtype)
        n_toks = len(doc_toks)
        if n_toks > inp_len:
            # Keep the first and last token and a window of inp_len - 2 tokens in between.
            # Valid window starts are 1 .. n_toks - inp_len + 1 (randint's upper bound is exclusive).
            i_off = np.random.randint(1, n_toks - inp_len + 2) if randomize else 1
            doc_toks = np.concatenate([doc_toks[:1], doc_toks[i_off:i_off + inp_len - 2], doc_toks[-1:]])

        # Clamp the span length so documents shorter than the drawn length do not
        # produce a non-positive bound in the offset draw below.
        mask_toks_len = min(np.random.randint(1, mask_toks_len_max + 1), len(doc_toks))
        # NOTE(review): the span may cover the first/last (special) tokens — confirm this is intended.
        mask_toks_off = np.random.randint(len(doc_toks) - mask_toks_len + 1)
        mask_end = mask_toks_off + mask_toks_len
        # Decode before overwriting so the printed text is the original span.
        masked_substr = tkz.decode(doc_toks[mask_toks_off:mask_end])
        print(f'{doc_ind:03d}. {masked_substr}')
        doc_toks[mask_toks_off:mask_end] = tkz.mask_token_id

        docs_toks[i, :len(doc_toks)] = doc_toks
    return torch.from_numpy(docs_toks).to(device)


inp_len: 128


In [86]:
# Select the first five documents and preview each one: index, title and the
# first 400 characters of the text with newlines shown as a literal "\n".
doc_inds = np.arange(5)
# doc_inds += 5
doc_inds = doc_inds.tolist()
for doc_ind in doc_inds:
    doc = ds[doc_ind]
    title = doc['title']
    text = doc['text'].replace('\n', '\\n')
    print(f'{doc_ind:03d} "{title}" {text[:400]}')

000 "Anarchism" Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy. Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful. As a historically left-wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian
001 "Autism" Autism is a neurodevelopmental disorder characterized by difficulties with social interaction and communication, and by restricted and repetitive behavior. Parents often notice signs during the first three years of their child's life. These signs often develop gradually, though some autistic children experience regression in their communication and social skills after reaching developmental milest
002 "Albedo" Albedo (; ) is the measure of the diffuse reflection of solar radiation out of the total solar radiation and measured on a scale from 0, corresponding to a 

In [92]:
# Build a masked batch and show the first row and the batch shape.
docs_toks_in = get_batch_tokens(doc_inds)
print(docs_toks_in[0])
print(docs_toks_in.shape)
# Inference only: no_grad avoids building an autograd graph over the logits.
with torch.no_grad():
    # Second argument is the attention mask: True wherever the token is not padding.
    logits_pred = model(docs_toks_in, docs_toks_in != tkz.pad_token_id)
    probs_pred = torch.softmax(logits_pred, dim=-1)
# probs_pred = torch.sigmoid(logits_pred)
print(probs_pred.shape)
# Most probable vocabulary id at each position.
docs_toks_out = torch.argmax(probs_pred, dim=-1)
print(docs_toks_out.shape)

000. formal hierarch
001. first
002. albedo (
003. written
004. yellowhammer state,
tensor([  101,  9617, 11140,  2964,  2003,  1037,  2576,  4695,  1998,  2929,
         2008,  2003,  8040, 23606,  7476,  1997,  3691,  1998, 19164,  2035,
        26097,  1010, 24873, 11890,  3512,  3596,  1997, 12571,  1012,  9617,
        11140,  2964,  4455,  2005,  1996, 15766,  1997,  1996,  2110,  1010,
         2029,  2009,  4324,  2000,  2022, 14203,  1010,  6151,  2229,  7895,
         3468,  1010,  1998, 17631,  1012,  2004,  1037,  7145,  2187,  1011,
         3358,  2929,  1010,  2872,  2006,  1996,  2521, 20515,  2187,  1997,
         1996,  2576,  8674,  1010,  2009,  2003,  2788,  2649,  4077, 15029,
         2964,  1998, 19297, 27255,  2004,  1996, 19297,  3358,  1006, 19297,
        14649,  1007,  1997,  1996,  6102,  2929,  1010,  1998,  2038,  1037,
         2844,  3439,  2523,  2007,  3424,  1011, 16498,  1998, 14649,  1012,
         4286,  2973,  1999,  8384,  2302,   103,   103,  

In [93]:
# Show the masked model inputs as text, one line per document.
for doc_ind, toks_row in zip(doc_inds, docs_toks_in):
    decoded = tkz.decode(toks_row).replace('\n', '\\n')
    print(f'{doc_ind:03d} {decoded}')

000 [CLS] anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy. anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful. as a historically left - wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian marxism as the libertarian wing ( libertarian socialism ) of the socialist movement, and has a strong historical association with anti - capitalism and socialism. humans lived in societies without [MASK] [MASK] [MASK] [MASK]ies long before the establishment of formal states [SEP]
001 [CLS] autism is a neurodevelopmental disorder characterized by difficulties with social interaction and communication, and by restricted and repetitive behavior. parents often notice signs during the [MASK] three years of their child ' s life. these signs often develop gradually, though some autistic 

In [94]:
# Show the model's argmax predictions as text, one line per document.
for doc_ind, toks_row in zip(doc_inds, docs_toks_out):
    decoded = tkz.decode(toks_row).replace('\n', '\\n')
    print(f'{doc_ind:03d} {decoded}')

000 the, andrel [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
001 past two [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

## Encoder embedding evaluation

In [16]:
def get_tokens(txts: list[str]) -> torch.Tensor:
    """Tokenize texts into a padded batch of token ids.

    Relies on the notebook globals `tkz`, `inp_len` and `device`.

    Args:
        txts: Texts to tokenize, one batch row per text.

    Returns:
        Integer tensor of shape (len(txts), inp_len) on `device`. Rows shorter
        than `inp_len` are right-padded with `tkz.pad_token_id`; longer rows are
        cropped to a window of `inp_len` tokens at a random offset (global NumPy RNG).
    """
    padded = np.full((len(txts), inp_len), tkz.pad_token_id)
    for row, txt in zip(padded, txts):
        ids = tkz(txt)['input_ids']
        excess = len(ids) - inp_len
        if excess > 0:
            # NOTE(review): unlike get_batch_tokens, this crop does not keep the
            # first/last token of the sequence — confirm that is intended.
            start = np.random.randint(excess + 1)
            ids = ids[start:start + inp_len]
        # `row` is a view into `padded`, so this writes into the batch array.
        row[:len(ids)] = ids
    return torch.from_numpy(padded).to(device)

# Put the model in evaluation mode; the trailing `None` keeps the cell from
# echoing the value returned by model.eval().
model.eval()
None

In [17]:
# txts[0] is the reference passage; the remaining entries are short strings
# whose embeddings are compared against it in the next cell.
txts = [
    '"Orana Australia Ltd" Orana Australia Ltd is a not-for-profit organisation that provides a diverse range of training and support services to over 650 people with disabilities and their families in South Australia.\n\nHistory\nThe Mentally Retarded Children’s Society of SA Inc. was established in 1950 by a group of parent',
    'Australia',
    'Orana Australia Ltd',
    'Hello Kitty',
]
batch_toks = get_tokens(txts)
# Inference only: no_grad avoids building an autograd graph for the encoder pass.
with torch.no_grad():
    embs = model.enc_bert(batch_toks, batch_toks != tkz.pad_token_id)
# embs = embs.detach().cpu().numpy()
embs = embs.detach().cpu()
print(embs.shape)

torch.Size([4, 768])


In [18]:
# Compare every embedding with the first one: cosine similarity (higher means
# closer) and Euclidean distance (lower means closer).
anchor_emb = embs[0]
for txt, emb in zip(txts[1:], embs[1:]):
    cos_sim = F.cosine_similarity(anchor_emb.unsqueeze(0), emb.unsqueeze(0))
    l2_dist = torch.norm(anchor_emb - emb)
    print(txt, cos_sim.numpy(), l2_dist)

Australia [0.1348315] tensor(9.7678)
Orana Australia Ltd [0.14778207] tensor(10.4330)
Hello Kitty [0.08750954] tensor(10.7561)


## Answering questions

In [40]:
# Build a question-answering style prompt: a context passage followed by a
# question whose answer slot is filled with [MASK] tokens.
ctx = 'Context3. Like other American research universities, Northwestern was transformed by World War II. Franklyn B. Snyder led the university from 1939 to 1949, when nearly 50,000 military officers and personnel were trained on the Evanston and Chicago campuses. After the war, surging enrollments under the G.I. Bill drove drastic expansion of both campuses. In 1948 prominent anthropologist Melville J. '
q = 'Question: Between 1939 and 1949, how many military officers and personnel were trained on the Evanston and Chicago campuses? Answer: [MASK] [MASK] [MASK] [MASK]'
s = f'{ctx} {q}'
toks = tkz(s).input_ids
# Token count of the full prompt, displayed as the cell output.
len(toks)

109

In [41]:
# Add a leading batch dimension of size 1 to the prompt tokens.
toks_in = torch.tensor(toks).to(device).unsqueeze(0)
# Inference only: no_grad avoids building an autograd graph over the logits.
with torch.no_grad():
    logits_pred = model(toks_in, toks_in != tkz.pad_token_id)
print(logits_pred.shape)
# Drop the batch dimension, then take the most probable vocabulary id per position.
probs_pred = torch.softmax(logits_pred[0], dim=-1)
toks_out = torch.argmax(probs_pred, dim=-1)
toks_out

torch.Size([1, 128, 30522])


tensor([  101,  6123,  2509,  1012,  2066,  2060,  2137,  2470,  5534,  1010,
         7855,  2001,  8590,  2011,  2088,  2162,  2462,  1012, 19597,  2078,
         1038,  1012, 17840,  2419,  1996,  2118,  2013,  3912,  2000,  4085,
         1010,  2043,  3053,  2753,  1010,  2199,  2510,  3738,  1998,  5073,
         2020,  4738,  2006,  1996,  6473,  2669,  1998,  3190, 13696,  1012,
         2044,  1996,  2162,  1010,  7505,  4726, 10316,  2015,  2104,  1996,
         1043,  1012,  1045,  1012,  3021,  6303, 20851,  4935,  1997,  2119,
        13696,  1012,  1999,  3882,  4069, 21571, 20154,  1046,  1012,  3160,
         1024,  2090,  3912,  1998,  4085,  1010,  2129,  2116,  2510,  3738,
         1998,  5073,  2020,  4738,  2006,  1996,  6473,  2669,  1998,  3190,
        13696,  1029,  3437,  1024,  1000,  1024,  2137,  2015,   102,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0])

In [42]:
# Decode the predicted token ids back to text.
tkz.decode(toks_out)

'[CLS] context3. like other american research universities, northwestern was transformed by world war ii. franklyn b. snyder led the university from 1939 to 1949, when nearly 50, 000 military officers and personnel were trained on the evanston and chicago campuses. after the war, surging enrollments under the g. i. bill demanded ample expansion of both campuses. in 1948 prominent anthropologist melville j. question : between 1939 and 1949, how many military officers and personnel were trained on the evanston and chicago campuses? answer : " : americans [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [34]:
# NOTE(review): this cell's execution count (34) is lower than the cells above it,
# so the saved output may not correspond to the current `probs_pred` — re-run to confirm.
probs_pred[:10]

tensor([1.0000e+00, 1.1139e-12, 6.2658e-13, 1.4770e-12, 1.5230e-12, 1.4352e-12,
        1.0220e-12, 1.1165e-12, 1.0452e-12, 1.1000e-12],
       grad_fn=<SliceBackward0>)

In [31]:
# NOTE(review): this cell's execution count (31) is lower than the cells above it,
# so the saved shape may come from an earlier `probs_pred` — re-run to confirm.
probs_pred.shape

torch.Size([5, 128, 30522])

In [19]:
# Scratch: random tensor of shape (2, 3, 4) used to check softmax/gather semantics below.
t = torch.rand(2, 3, 4)
t

tensor([[[0.6792, 0.0907, 0.6140, 0.9981],
         [0.2409, 0.8605, 0.0702, 0.2925],
         [0.1105, 0.1167, 0.4092, 0.3401]],

        [[0.3772, 0.2479, 0.2412, 0.0046],
         [0.8316, 0.8350, 0.4125, 0.5759],
         [0.0096, 0.6448, 0.9033, 0.0755]]])

In [22]:
# Softmax along the last axis (dim=2); the shape stays (2, 3, 4).
t1 = torch.softmax(t, dim=2)
print(t1)
t1.shape

tensor([[[0.2586, 0.1435, 0.2422, 0.3557],
         [0.2103, 0.3908, 0.1773, 0.2215],
         [0.2168, 0.2182, 0.2923, 0.2728]],

        [[0.2906, 0.2554, 0.2537, 0.2002],
         [0.2911, 0.2921, 0.1914, 0.2254],
         [0.1563, 0.2949, 0.3819, 0.1669]]])


torch.Size([2, 3, 4])

In [27]:
# Random integer indices in [0, 4) with shape (2, 3, 1): one index per length-4 row of t1.
t2 = torch.randint(0, 4, (2, 3, 1))
t2

tensor([[[1],
         [1],
         [3]],

        [[2],
         [2],
         [1]]])

In [28]:
# Pick, along dim=2, the t1 value at each index in t2; the result has t2's shape (2, 3, 1).
torch.gather(t1, dim=2, index=t2)

tensor([[[0.1435],
         [0.3908],
         [0.2728]],

        [[0.2537],
         [0.1914],
         [0.2949]]])