In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [36]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
from datasets.arrow_dataset import Dataset
import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer, AutoTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import ENCDEC_BERT_MODEL_CFG_FNAME
from mllm.model.encdec_ranker_hg import EncdecBert
from mllm.config.model import EncdecBertCfg


In [49]:
# Root directory for the dataset cache and training artefacts: <home>/data.
# Path.home() is used rather than expanding '$HOME' by hand: os.path.expandvars
# leaves the literal text '$HOME' in place when the variable is not set, which
# would silently point at a relative directory named '$HOME'.
DATA_PATH = Path.home() / 'data'
# Configuration name of the Wikipedia dump handed to `load_dataset`.
WIKI_DS_NAME = '20200501.en'

# Parent directory of the EncdecBert training runs; one subdirectory per run.
TRAIN_ENCDEC_BERT_PATH = DATA_PATH / 'train_mllm_encdec_bert'
# Run under evaluation. The earlier run is kept commented out for quick switching.
# encdec_subdir = 'encdecbert-20250126_212805-bert-base-uncased-d768-emb_cls-inp128-lrs7x1-enh_mmbb-step2-h12-dp0-t0.0'
encdec_subdir = 'encdecbert-20250131_223521-bert-base-uncased-d768-emb_cls-inp128-lrs7x1-enh_mmbb-step2-h12-dp0-t0.0'

# Files read from the selected run directory: the snapshot and the model config.
encdec_train_path = TRAIN_ENCDEC_BERT_PATH / encdec_subdir
encdec_snapshot_fpath = encdec_train_path / 'best.pth'
encdec_model_cfg_fpath = encdec_train_path / ENCDEC_BERT_MODEL_CFG_FNAME

# Device used for the model and all token tensors in this notebook.
device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print(device)

cpu


In [50]:
# Load the Wikipedia dump (cached under DATA_PATH) and keep its 'train' split,
# which is the document collection indexed by the helpers below.
dss = load_dataset(
    'wikipedia',
    WIKI_DS_NAME,
    beam_runner='DirectRunner',
    cache_dir=os.fspath(DATA_PATH),
)
ds: Dataset = dss['train']
n_docs = len(ds)
print(f'Wikipedia {WIKI_DS_NAME} docs: {n_docs}')

Reusing dataset wikipedia (/home/misha/data/wikipedia/20200501.en/1.0.0/009f923d9b6dd00c00c8cdc7f408f2b47f45dd4f5fb7982a21f9448f4afbe475)


  0%|          | 0/1 [00:00<?, ?it/s]

Wikipedia 20200501.en docs: 6078422


In [51]:
# Read the model configuration stored in the training-run directory and create
# the tokenizer named in the encoder section of that configuration.
model_cfg = parse_yaml_file_as(EncdecBertCfg, encdec_model_cfg_fpath)
tkz = AutoTokenizer.from_pretrained(model_cfg.enc_bert.pretrained_model_name)
# Single call, newline-separated: emits the same two lines as two prints would.
print(model_cfg, tkz, sep='\n')

enc_bert=EncBertCfg(inp_len=128, d_model=768, pretrained_model_name='bert-base-uncased', tokenizer_name='bert-base-uncased', emb_type=<BertEmbType.Cls: 'cls'>) dec_pyr=DecPyrCfg(d_model=768, n_heads=12, d_k=64, d_v=64, d_inner=3072, inp_len=128, step=2, n_layers=7, dropout_rate=0.0, n_vocab=30522, n_similar_layers=1, enhance_type=<HgEnhanceType.MatmulBeginBias: 'mmbb'>, temperature=0.0)
BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, sing

In [52]:
# Read the training snapshot, mapping its tensors onto `device`.
# NOTE(review): torch.load unpickles the file, so only open snapshots that come
# from a trusted source.
chkpt = torch.load(encdec_snapshot_fpath, map_location=device)
# Build the model from the configuration and move it to the same device.
model = EncdecBert(model_cfg).to(device)
# strict=True requires the snapshot's keys to match the model's exactly;
# switch to the commented alternative to tolerate missing/unexpected keys.
strict = True
# strict = False
# The weights are stored under the 'model' key of the snapshot.
model.load_state_dict(chkpt['model'], strict=strict)
# Evaluation mode (disables dropout). As the last expression of the cell this
# also displays the module tree.
model.eval()

EncdecBert(
  (enc_bert): EncoderBert(
    (bert_model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSdpaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=Tru

In [53]:
# Number of tokens in one encoder input, as fixed by the model configuration;
# every batch row built below has exactly this width.
inp_len = model_cfg.enc_bert.inp_len
print(f'inp_len: {inp_len}')

def get_batch_tokens(doc_inds: list[int], randomize: bool = False) -> torch.Tensor:
    """Tokenize dataset documents into a fixed-width batch of token ids.

    For every index in `doc_inds` the lowercased `text` field of the document
    is tokenized with the notebook-level tokenizer `tkz` (the title is not
    prepended). A sequence longer than `inp_len` is reduced to exactly
    `inp_len` tokens: the first and the last token produced by the tokenizer
    are kept and a contiguous window of `inp_len - 2` tokens is taken from
    between them. Shorter sequences are right-padded with `tkz.pad_token_id`.

    Args:
        doc_inds: Indices of documents in the notebook-level dataset `ds`.
        randomize: If True, the window position inside an over-long document
            is drawn uniformly at random from all valid positions; otherwise
            the window starts directly after the first token.

    Returns:
        int64 tensor of shape `(len(doc_inds), inp_len)` placed on `device`.
    """
    # Explicit int64 so the id dtype does not depend on the platform default.
    docs_toks = np.full((len(doc_inds), inp_len), tkz.pad_token_id, dtype=np.int64)
    for i, doc_ind in enumerate(doc_inds):
        doc_txt = ds[doc_ind]['text'].lower()
        doc_toks = list(tkz(doc_txt)['input_ids'])
        n_toks = len(doc_toks)
        if n_toks > inp_len:
            n_inner = inp_len - 2
            # Inner tokens occupy indices 1 .. n_toks - 2, so valid window
            # starts are 1 .. n_toks - inp_len + 1. randint's upper bound is
            # exclusive, hence the "+ 2" (a "+ 1" bound would never pick the
            # last window).
            i_off = np.random.randint(1, n_toks - inp_len + 2) if randomize else 1
            doc_toks = doc_toks[:1] + doc_toks[i_off:i_off + n_inner] + doc_toks[-1:]
        docs_toks[i, :len(doc_toks)] = doc_toks
    return torch.from_numpy(docs_toks).to(device)


inp_len: 128


In [54]:
# Indices of the documents inspected in the cells below, as plain Python ints.
# To look at another slice use e.g. np.arange(5, 10).tolist().
doc_inds = np.arange(5).tolist()
# Preview: index, title and the first 300 characters of each document, with
# newlines shown as a literal "\n" so every document stays on one line.
for doc_ind in doc_inds:
    doc = ds[doc_ind]
    title = doc['title']
    text = doc['text'].replace('\n', '\\n')
    print(f'{doc_ind:03d} "{title}" {text[:300]}')

000 "Yangliuqing" Yangliuqing () is a market town in Xiqing District, in the western suburbs of Tianjin, People's Republic of China. Despite its relatively small size, it has been named since 2006 in the "famous historical and cultural market towns in China".\n\nIt is best known in China for creating nianhua or Yangl
001 "Orana Australia Ltd" Orana Australia Ltd is a not-for-profit organisation that provides a diverse range of training and support services to over 650 people with disabilities and their families in South Australia.\n\nHistory\nThe Mentally Retarded Children’s Society of SA Inc. was established in 1950 by a group of parent
002 "St. Mary's Church, Sønderborg" The St. Mary's Church is a church owned by the Church of Denmark in Sønderborg, Denmark and the church of the parish with the same name. Thanks to its location on a hill, the church building is very iconic for the city.\n\nHistory \nIn the Middle Ages there was a leper colony on a hill just outside 
003 "Kalitta" Kal

In [55]:
# Encode the selected documents and let the model reconstruct them.
docs_toks_in = get_batch_tokens(doc_inds)
# Inference only: without no_grad() autograd would record a graph for the whole
# (batch, inp_len, vocab) logits tensor and keep the activations alive.
with torch.no_grad():
    # Second argument is the mask of non-padding positions.
    logits_pred = model(docs_toks_in, docs_toks_in != tkz.pad_token_id)
    probs_pred = torch.softmax(logits_pred, dim=-1)
print(probs_pred.shape)
# Most probable vocabulary id at every position.
docs_toks_out = torch.argmax(probs_pred, dim=-1)
print(docs_toks_out.shape)

Token indices sequence length is longer than the specified maximum sequence length for this model (730 > 512). Running this sequence through the model will result in indexing errors


torch.Size([5, 128, 30522])
torch.Size([5, 128])


In [58]:
# Decode the token batch that was fed to the model, one line per document
# (newlines are shown as a literal "\n").
for i, doc_ind in enumerate(doc_inds):
    s = tkz.decode(docs_toks_in[i]).replace('\n', '\\n')
    print(f'{doc_ind:03d} {s}')

000 [CLS] yangliuqing ( ) is a market town in xiqing district, in the western suburbs of tianjin, people's republic of china. despite its relatively small size, it has been named since 2006 in the " famous historical and cultural market towns in china ". it is best known in china for creating nianhua or yangliuqing nianhua. for more than 400 years, yangliuqing has in effect specialised in the creation of these woodcuts for the new year. wood block prints using vivid colourschemes to portray traditional scenes of children's games often interwoven with auspicious [SEP]
001 [CLS] orana australia ltd is a not - for - profit organisation that provides a diverse range of training and support services to over 650 people with disabilities and their families in south australia. history the mentally retarded children ’ s society of sa inc. was established in 1950 by a group of parents who wanted education, employment and accommodation opportunities for their children within the local community a

In [59]:
# Decode the model's argmax predictions, one line per document, for a visual
# comparison with the decoded inputs (newlines are shown as a literal "\n").
for i, doc_ind in enumerate(doc_inds):
    s = tkz.decode(docs_toks_out[i]).replace('\n', '\\n')
    print(f'{doc_ind:03d} {s}')


000 [CLS] yangliuqing ( ) is a market town in xiqing district, in the western suburbs of tianjin, people's republic of china. despite its relatively small size, it has been named since 2006 in the " famous historical and cultural market towns in china ". it is best known in china for creating nianhua or yangliuqing nianhua. for more than 400 years, yangliuqing has in effect specialised in the creation of these woodcuts for the new year. wood block prints using vivid colourschemes to portray traditional scenes of children's games often interwoven with auspicious [SEP]
001 [CLS] orana australia ltd is a not - for - profit organisation that provides a diverse range of training and support services to over 650 people with disabilities and their families in south australia. history the mentally retarded children ’ s society of sa inc. was established in 1950 by a group of parents who wanted education, employment and accommodation opportunities for their children within the local community a

## Encoder embedding evaluation

In [46]:
def get_tokens(txts: list[str], randomize: bool = True) -> torch.Tensor:
    """Tokenize free-form texts into a fixed-width batch of token ids.

    Each text is tokenized with the notebook-level tokenizer `tkz`. A sequence
    longer than `inp_len` is reduced to exactly `inp_len` tokens in the same
    way as `get_batch_tokens` does it: the first and the last token produced
    by the tokenizer are kept and a contiguous window of `inp_len - 2` tokens
    is taken from between them. Shorter sequences are right-padded with
    `tkz.pad_token_id`.

    Args:
        txts: Texts to tokenize; the text is passed to the tokenizer as is.
        randomize: If True (default), the window position inside an over-long
            text is drawn uniformly at random; if False the window starts
            directly after the first token, which makes the result
            reproducible.

    Returns:
        int64 tensor of shape `(len(txts), inp_len)` placed on `device`.
    """
    # Explicit int64 so the id dtype does not depend on the platform default.
    batch_toks = np.full((len(txts), inp_len), tkz.pad_token_id, dtype=np.int64)
    for i, txt in enumerate(txts):
        toks = list(tkz(txt)['input_ids'])
        n_toks = len(toks)
        if n_toks > inp_len:
            n_inner = inp_len - 2
            # Keep the tokenizer's leading/trailing tokens ([CLS]/[SEP] for the
            # BERT tokenizer printed in this notebook) instead of cutting a raw
            # window that may drop them.
            # NOTE(review): the printed config has emb_type=cls, so the encoder
            # presumably reads the first position - confirm.
            # Valid window starts are 1 .. n_toks - inp_len + 1 (randint's
            # upper bound is exclusive).
            i_off = np.random.randint(1, n_toks - inp_len + 2) if randomize else 1
            toks = toks[:1] + toks[i_off:i_off + n_inner] + toks[-1:]
        batch_toks[i, :len(toks)] = toks
    return torch.from_numpy(batch_toks).to(device)

# Make sure the model is in evaluation mode; the trailing semicolon keeps the
# notebook from displaying the module tree as the cell output.
model.eval();

In [47]:
# Texts to embed. The first entry is the reference document; the remaining
# entries are the short queries it is compared with in the next cell.
txts = [
    '"Orana Australia Ltd" Orana Australia Ltd is a not-for-profit organisation that provides a diverse range of training and support services to over 650 people with disabilities and their families in South Australia.\n\nHistory\nThe Mentally Retarded Children’s Society of SA Inc. was established in 1950 by a group of parent',
    'Australia',
    'Orana Australia Ltd',
    'Hello Kitty',
]
batch_toks = get_tokens(txts)
# Encoder-only forward pass. Run under no_grad() so no autograd graph is built
# for an evaluation that never calls backward; the result therefore needs no
# detach() before being moved to the CPU.
with torch.no_grad():
    # Second argument is the mask of non-padding positions.
    embs = model.enc_bert(batch_toks, batch_toks != tkz.pad_token_id)
embs = embs.cpu()
print(embs.shape)

torch.Size([4, 768])


In [48]:
# Compare the embedding of the first text with each of the others.
# NOTE: despite its name, `cos_dist` holds the cosine *similarity* (larger means
# closer); `norm_dist` is the Euclidean distance between the two embeddings.
anchor = embs[0:1]
for i in range(1, len(embs)):
    other = embs[i:i + 1]
    cos_dist = F.cosine_similarity(anchor, other)
    norm_dist = torch.norm(anchor[0] - other[0])
    print(txts[i], cos_dist.numpy(), norm_dist)

Australia [0.13148455] tensor(9.9159)
Orana Australia Ltd [0.14719056] tensor(10.5426)
Hello Kitty [0.08668104] tensor(10.8450)
