In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
from datasets.arrow_dataset import Dataset
import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import TOKENIZER_CFG_FNAME, ENCDEC_HG_MODEL_CFG_FNAME
from mllm.model.encdec_ranker_hg import EncdecHg
from mllm.config.model import TokenizerCfg, EncdecHgCfg
from mllm.tokenization.chunk_tokenizer import tokenizer_from_config




In [60]:
# Root directory for datasets and training artifacts. Path.home() resolves the
# home directory even when the HOME environment variable is not set, whereas
# os.path.expandvars('$HOME') would silently keep the literal '$HOME'.
DATA_PATH = Path.home() / 'data'
WIKI_DS_NAME = '20200501.en'

TRAIN_ENCDEC_HG_PATH = DATA_PATH / 'train_mllm_encdec_hg'
# Training run to inspect. Run name that used to be assigned here and was
# immediately overridden by the one below:
#   encdechg-20241216_224415-inp128-pos_emb-lrs7x1-rdc_avg-enh_mmbeg-step2-d512-h8-t1
encdec_subdir = 'encdechg-20250107_232630-inp128-pos_emb-lrs7x1-rdc_avg-enh_mmbeg-step2-d768-h12-dp0-t0'

# Files inside the run directory: weights snapshot, model config, tokenizer config.
encdec_train_path = TRAIN_ENCDEC_HG_PATH / encdec_subdir
encdec_snapshot_fpath = encdec_train_path / 'best.pth'
encdec_model_cfg_fpath = encdec_train_path / ENCDEC_HG_MODEL_CFG_FNAME
encdec_tkz_cfg_fpath = encdec_train_path / TOKENIZER_CFG_FNAME

# Torch device used by the cells below; set to 'cuda' to run on a GPU.
device_name = 'cpu'

device = torch.device(device_name)
print(device)

cpu


In [61]:
# Load the Wikipedia dump named by WIKI_DS_NAME through Hugging Face `datasets`,
# using DATA_PATH as the cache directory.
dss = load_dataset('wikipedia', WIKI_DS_NAME, beam_runner='DirectRunner', cache_dir=str(DATA_PATH))
# The 'train' split is the one indexed by the cells below.
ds: Dataset = dss['train']
n_docs = len(ds)
print(f'Wikipedia {WIKI_DS_NAME} docs: {n_docs}')

Reusing dataset wikipedia (/home/misha/data/wikipedia/20200501.en/1.0.0/009f923d9b6dd00c00c8cdc7f408f2b47f45dd4f5fb7982a21f9448f4afbe475)


  0%|          | 0/1 [00:00<?, ?it/s]

Wikipedia 20200501.en docs: 6078422


In [62]:
# Tokenizer and model configs are read from the YAML files stored in the run directory.
tkz_cfg = parse_yaml_file_as(TokenizerCfg, encdec_tkz_cfg_fpath)
tkz = tokenizer_from_config(tkz_cfg)
model_cfg = parse_yaml_file_as(EncdecHgCfg, encdec_model_cfg_fpath)
# Fixed length of the token rows built by the batch helpers below.
inp_len = model_cfg.enc_pyr.inp_len
# Index of the custom 'pad' token; the batch helpers fill unused positions with it.
pad_tok = tkz_cfg.custom_tokens['pad'].ind

In [63]:
# Restore the snapshot onto `device` and build the model from its config.
# NOTE(review): torch.load relies on pickle — only load snapshots from a trusted source.
chkpt = torch.load(encdec_snapshot_fpath, map_location=device)
model = EncdecHg(model_cfg).to(device)
# With strict=True load_state_dict raises on missing or unexpected keys.
strict = True
# strict = False
model.load_state_dict(chkpt['model'], strict=strict)
# Switch to evaluation mode; as the last expression of the cell this also displays the module tree.
model.eval()

EncdecHg(
  (enc_pyr): EncoderPyramid(
    (vocab_encoder): VocabEncoder(
      (src_word_emb): Embedding(50271, 768, padding_idx=50267)
      (position_enc): Embedding(128, 768)
      (dropout): Dropout(p=0.0, inplace=False)
      (layer_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    )
    (enc_layers): ModuleList(
      (0-6): 7 x EncoderLayer(
        (slf_attn): MultiHeadAttention(
          (w_qs): Linear(in_features=768, out_features=768, bias=False)
          (w_ks): Linear(in_features=768, out_features=768, bias=False)
          (w_vs): Linear(in_features=768, out_features=768, bias=False)
          (fc): Linear(in_features=768, out_features=768, bias=False)
          (attention): ScaledDotProductAttention(
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (layer_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        )
        (pos_ffn): PositionwiseFeedForward(
         

In [64]:
def get_batch_tokens(doc_inds: list[int], randomize: bool = False) -> torch.Tensor:
    """Tokenize the selected documents into a fixed-width, padded batch.

    Each document is rendered as ``"<title> <text>"`` and tokenized with `tkz`.
    Rows shorter than `inp_len` are right-padded with `pad_tok`; longer ones are
    cut down to a window of `inp_len` tokens.

    Args:
        doc_inds: Indices into the global dataset `ds`.
        randomize: When True the window of an over-long document starts at a
            random offset; otherwise it starts at the first token.

    Returns:
        Tensor of shape (len(doc_inds), inp_len) placed on `device`.
    """
    batch = np.full((len(doc_inds), inp_len), pad_tok)
    for row, doc_ind in enumerate(doc_inds):
        record = ds[doc_ind]
        toks: list[int] = tkz(f"{record['title']} {record['text']}")['input_ids']
        overflow = len(toks) - inp_len
        if overflow > 0:
            # Choose where the inp_len-wide window starts inside the document.
            start = np.random.randint(overflow + 1) if randomize else 0
            toks = toks[start:start + inp_len]
        batch[row, :len(toks)] = toks
    return torch.from_numpy(batch).to(device)


In [65]:
# Indices of the documents to preview, as plain Python ints
# (shift the range to look at other documents).
doc_inds = np.arange(5).tolist()
for doc_ind in doc_inds:
    doc = ds[doc_ind]
    title = doc['title']
    # Escape newlines so that every preview stays on a single output line.
    text = doc['text'].replace('\n', '\\n')
    print(f'{doc_ind:03d} "{title}" {text[:300]}')

000 "Yangliuqing" Yangliuqing () is a market town in Xiqing District, in the western suburbs of Tianjin, People's Republic of China. Despite its relatively small size, it has been named since 2006 in the "famous historical and cultural market towns in China".\n\nIt is best known in China for creating nianhua or Yangl
001 "Orana Australia Ltd" Orana Australia Ltd is a not-for-profit organisation that provides a diverse range of training and support services to over 650 people with disabilities and their families in South Australia.\n\nHistory\nThe Mentally Retarded Children’s Society of SA Inc. was established in 1950 by a group of parent
002 "St. Mary's Church, Sønderborg" The St. Mary's Church is a church owned by the Church of Denmark in Sønderborg, Denmark and the church of the parish with the same name. Thanks to its location on a hill, the church building is very iconic for the city.\n\nHistory \nIn the Middle Ages there was a leper colony on a hill just outside 
003 "Kalitta" Kal

In [66]:
docs_toks_in = get_batch_tokens(doc_inds)
# Inference only: disable autograd so no graph is kept for the
# (batch, inp_len, vocab)-sized logits.
with torch.no_grad():
    logits_pred = model(docs_toks_in)
    probs_pred = torch.softmax(logits_pred, dim=-1)
print(probs_pred.shape)
# Most probable token at every position.
docs_toks_out = torch.argmax(probs_pred, dim=-1)
print(docs_toks_out.shape)

torch.Size([5, 128, 50271])
torch.Size([5, 128])


In [67]:
for i, doc_ind in enumerate(doc_inds):
    # Decode the predicted token ids and escape newlines to keep one document per line.
    s = tkz.decode(docs_toks_out[i]).replace('\n', '\\n')
    print(f'{doc_ind:03d} {s}')


000 Yang Tianuqing Yangliuqing () is a market town in Tianqing District, in the western suburbs of Tianjin, People's Republic of China. Despite its relatively small size, it has been named since 2006 in the "largest historical and cultural cultural newspaper in China".\n\nIt is best known in China for about Lianhua or Yangliuqing Lianhua. For more than 400 years, Yangliuqing has in recent special success in the middle of some wood mines for the New countries.  Modern historical prints using colorful carscunges to support traditional pictures of children's paintings
001  Cookana Australia Singapore Orana Service Singapore is a not-for-profit organisation that provides a diverse range of training and support services to over cultural groups with disabilities and their families in South Australia.\n\nHistory\nThe Visally Outitable Children’s Office of Health Inc. was established in 1950 by a group of parents who provide education, community and cultural opportunities for their children wi

## Training gradients calculation

In [36]:
def encdec_prob_loss_softmax(logits_pred: torch.Tensor, tokens_gt: torch.Tensor) -> torch.Tensor:
    """Mean negative log-likelihood of the ground-truth tokens.

    Args:
        logits_pred: Unnormalised scores; the last dimension runs over the vocabulary.
        tokens_gt: Integer token ids shaped like `logits_pred` without its last dimension.

    Returns:
        Scalar tensor: the mean over all positions of ``-log softmax(logits_pred)[token]``.
    """
    tokens_gt = tokens_gt.to(torch.int64).unsqueeze(-1)
    # log_softmax works in log-space. Taking log(softmax(x)) instead underflows to
    # log(0) = -inf as soon as the target logit is far below the maximum one.
    log_probs_pred = torch.log_softmax(logits_pred, dim=-1)
    log_probs_gt = torch.gather(log_probs_pred, dim=-1, index=tokens_gt)
    return -torch.mean(log_probs_gt)

loss_fn = encdec_prob_loss_softmax

In [81]:
# Number of documents per batch; read as a global by get_batch().
docs_batch_size = 5


def get_batch_tokens(doc_inds: list[int], randomize: bool = True) -> torch.Tensor:
    """Tokenize the selected documents into a fixed-width, padded training batch.

    Each document is rendered as ``"<title> <text>"`` and tokenized with `tkz`.
    Rows shorter than `inp_len` are right-padded with `pad_tok`; longer ones are
    cut down to a window of `inp_len` tokens.

    This definition replaces the same-named helper defined earlier in the
    notebook, so it accepts the same `randomize` keyword. Here it defaults to
    True, which is the random-crop behaviour this definition has always had.

    Args:
        doc_inds: Indices into the global dataset `ds`.
        randomize: When True (default) the window of an over-long document
            starts at a random offset; when False it starts at the first token.

    Returns:
        Tensor of shape (len(doc_inds), inp_len) placed on `device`.
    """
    docs_toks = np.full((len(doc_inds), inp_len), pad_tok)
    for i, doc_ind in enumerate(doc_inds):
        doc = ds[doc_ind]
        title, text = doc['title'], doc['text']
        doc_txt = f'{title} {text}'
        doc_toks = tkz(doc_txt)['input_ids']
        n_toks = len(doc_toks)
        if n_toks > inp_len:
            # Choose where the inp_len-wide window starts inside the document.
            i_off = np.random.randint(n_toks - inp_len + 1) if randomize else 0
            doc_toks = doc_toks[i_off:i_off + inp_len]
        docs_toks[i, :len(doc_toks)] = doc_toks
    docs_toks_t = torch.from_numpy(docs_toks).to(device)
    return docs_toks_t

def get_batch(inds: list[int], i_batch: int) -> tuple[torch.Tensor, int]:
    """Build token batch number `i_batch` and report which batch comes next.

    The batch covers ``inds[i_batch * docs_batch_size : (i_batch + 1) * docs_batch_size]``.
    When that slice runs past the end of `inds`, it is completed with indices
    taken from the head of the list so the batch keeps `docs_batch_size` rows
    (provided `inds` holds at least that many entries).

    Args:
        inds: Document indices for one epoch. Shuffled IN PLACE once the end
            of the list has been reached.
        i_batch: Zero-based number of the batch to build.

    Returns:
        (batch_toks, i_next): the tokens produced by `get_batch_tokens` and the
        batch number to pass to the next call — `i_batch + 1`, or 0 when this
        call consumed the end of `inds`.
    """
    i1 = i_batch * docs_batch_size
    i2 = i1 + docs_batch_size
    batch_inds = inds[i1:i2]
    n_missing = docs_batch_size - len(batch_inds)
    if n_missing > 0:
        # Wrap around: borrow exactly the missing number of indices from the head.
        batch_inds = batch_inds + inds[:n_missing]
    if i2 >= len(inds):
        # End of the epoch: restart from batch 0 over a freshly shuffled order.
        # `batch_inds` is already a separate list, so the shuffle cannot alter it.
        i_next = 0
        np.random.shuffle(inds)
    else:
        i_next = i_batch + 1
    batch_toks = get_batch_tokens(batch_inds)
    return batch_toks, i_next

# Indices of all documents in `ds`, kept as a plain list because get_batch()
# slices it, concatenates slices and shuffles it in place.
docs_inds = list(range(len(ds)))


In [82]:
# Fetch batch 0; the batch counter returned as the second value is not needed here.
tokens_inp, _ = get_batch(docs_inds, 0)
print(tokens_inp.dtype, tokens_inp.shape)

torch.int64 torch.Size([5, 128])


In [17]:
model.train()
# backward() accumulates into .grad, so clear whatever an earlier run of this
# cell left behind; otherwise the gradient statistics below would be inflated.
model.zero_grad()
out_logits = model(tokens_inp)
# The input tokens also serve as the targets of the loss.
loss = loss_fn(out_logits, tokens_inp)
loss.backward()

In [23]:
def calc_params_grads_stats(params: torch.nn.Parameter) -> tuple[tuple[float, float], Optional[tuple[float, float]]]:
    """Summarise a parameter tensor and its gradient as (mean, std) pairs.

    Args:
        params: Parameter to inspect; its `.grad` may be None.

    Returns:
        ``(param_stats, grad_stats)`` where each item is ``(mean, std)`` as
        Python floats and `grad_stats` is None when no gradient is stored.
    """
    def _mean_std(values: torch.Tensor) -> tuple[float, float]:
        # Reduce first, then move the resulting scalars to the host.
        return values.mean().detach().cpu().item(), values.std().detach().cpu().item()

    param_stats = _mean_std(params)
    grad = params.grad
    grad_stats = None if grad is None else _mean_std(grad)
    return param_stats, grad_stats

# Print (mean, std) of every named parameter followed by (mean, std) of its
# gradient (None when the parameter has no gradient stored).
for pname, params in model.named_parameters():
    pms, gms = calc_params_grads_stats(params)
    print(pname, pms, gms)

enc_pyr.vocab_encoder.src_word_emb.weight (-2.7574906198424287e-06, 1.0002703666687012) (1.5332831592519735e-15, 2.1249656128929928e-05)
enc_pyr.vocab_encoder.layer_norm.weight (0.9618587493896484, 0.01780407875776291) (8.019902452360839e-06, 0.004944556392729282)
enc_pyr.vocab_encoder.layer_norm.bias (-0.0004760617157444358, 0.009908275678753853) (7.112050661817193e-05, 0.011260643601417542)
enc_pyr.enc_layers.0.slf_attn.w_qs.weight (3.763933273148723e-05, 0.03625566139817238) (4.7372861899930285e-07, 0.0033385928254574537)
enc_pyr.enc_layers.0.slf_attn.w_ks.weight (-0.0001073815074050799, 0.03647030144929886) (3.275928293078323e-07, 0.0034606929402798414)
enc_pyr.enc_layers.0.slf_attn.w_vs.weight (8.182486635632813e-05, 0.04236489534378052) (1.4093603795117815e-07, 0.002324572065845132)
enc_pyr.enc_layers.0.slf_attn.fc.weight (-7.319550786633044e-05, 0.04345853999257088) (2.2737367544323206e-13, 0.0012499691220000386)
enc_pyr.enc_layers.0.slf_attn.layer_norm.weight (1.001289129257202

## Encoder embedding evaluation

In [68]:
def get_tokens(txts: list[str]) -> torch.Tensor:
    """Tokenize raw texts into a fixed-width, padded batch.

    Rows shorter than `inp_len` are right-padded with `pad_tok`; a text with
    more than `inp_len` tokens is reduced to a window of `inp_len` tokens that
    starts at a random offset.

    Args:
        txts: Texts to tokenize with `tkz`.

    Returns:
        Tensor of shape (len(txts), inp_len) placed on `device`.
    """
    token_rows = np.full((len(txts), inp_len), pad_tok)
    for row, txt in enumerate(txts):
        ids: list[int] = tkz(txt)['input_ids']
        surplus = len(ids) - inp_len
        if surplus > 0:
            # Random start of the inp_len-wide window inside the text.
            start = np.random.randint(surplus + 1)
            ids = ids[start:start + inp_len]
        token_rows[row, :len(ids)] = ids
    return torch.from_numpy(token_rows).to(device)

# Put the model back into evaluation mode for the embedding checks that follow.
model.eval()
# A trailing `None` keeps the cell from echoing the module repr returned by eval().
None

In [81]:
# Probe texts: the first entry is the reference, the remaining ones are
# compared against it in the following cell.
txts = [
    '"Orana Australia Ltd" Orana Australia Ltd is a not-for-profit organisation that provides a diverse range of training and support services to over 650 people with disabilities and their families in South Australia.\n\nHistory\nThe Mentally Retarded Children’s Society of SA Inc. was established in 1950 by a group of parent',
    'Australia',
    'Orana Australia Ltd',
    'Hello Kitty',
]
batch_toks = get_tokens(txts)
# Inference only: disable autograd so the encoder pass keeps no graph.
with torch.no_grad():
    embs = model.enc_pyr(batch_toks)
# Keep the embeddings as CPU tensors for the distance computations.
embs = embs.detach().cpu()
print(embs.shape)

torch.Size([4, 1, 768])


In [82]:
# Compare every probe text with the reference text at index 0.
for i in range(1, len(embs)):
    # F.cosine_similarity yields a similarity (larger = closer), not a distance.
    cos_sim = F.cosine_similarity(embs[0], embs[i])
    # Euclidean norm of the difference between the two embeddings.
    norm_dist = torch.norm(embs[0] - embs[i])
    print(txts[i], cos_sim.numpy(), norm_dist)

Australia [0.22261377] tensor(32.4214)
Orana Australia Ltd [0.16103858] tensor(33.4885)
Hello Kitty [0.17609353] tensor(33.3606)
