In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [49]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
from datasets.arrow_dataset import Dataset
import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import TOKENIZER_CFG_FNAME, ENCDEC_HG_MODEL_CFG_FNAME
from mllm.model.encdec_hg import EncdecHg
from mllm.config.model import TokenizerCfg, EncdecHgCfg
from mllm.tokenization.chunk_tokenizer import tokenizer_from_config


In [50]:
# --- Data locations ---------------------------------------------------------
# Path.home() resolves the user's home directory even when $HOME is not set;
# os.path.expandvars('$HOME') would leave the literal '$HOME' in the path in
# that case and silently produce a relative directory.
DATA_PATH = Path.home() / 'data'
WIKI_DS_NAME = '20200501.en'

TRAIN_ENCDEC_HG_PATH = DATA_PATH / 'train_mllm_encdec_hg'

# --- Training run to inspect ------------------------------------------------
# Keep exactly one assignment active; earlier runs stay commented out for
# reference instead of being assigned and immediately overwritten.
# encdec_subdir = 'encdechg-20241202_225258-ilen256-lrs8-step2-d256-h8'
# encdec_subdir = 'encdechg-20241203_233923-ilen256-lrs8-step2-d256-h8'
# encdec_subdir = 'encdechg-20241204_214436-ilen256-lrs8x2-step2-d256-h8'
encdec_subdir = 'encdechg-20241205_213010-ilen128-lrs7x1-step2-d256-h8'

# Files expected inside the selected run directory: the model snapshot plus
# the model and tokenizer configuration files.
encdec_train_path = TRAIN_ENCDEC_HG_PATH / encdec_subdir
encdec_snapshot_fpath = encdec_train_path / 'best.pth'
encdec_model_cfg_fpath = encdec_train_path / ENCDEC_HG_MODEL_CFG_FNAME
encdec_tkz_cfg_fpath = encdec_train_path / TOKENIZER_CFG_FNAME

# --- Compute device ---------------------------------------------------------
device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print(device)

cpu


In [46]:
# Load the Wikipedia dump named by WIKI_DS_NAME, caching it under DATA_PATH,
# and keep a handle on its 'train' split for the cells below.
dss = load_dataset(
    'wikipedia',
    WIKI_DS_NAME,
    beam_runner='DirectRunner',
    cache_dir=str(DATA_PATH),
)
ds: Dataset = dss['train']

# Report how many documents the split contains.
n_docs = len(ds)
print(f'Wikipedia {WIKI_DS_NAME} docs: {n_docs}')

Reusing dataset wikipedia (/home/misha/data/wikipedia/20200501.en/1.0.0/009f923d9b6dd00c00c8cdc7f408f2b47f45dd4f5fb7982a21f9448f4afbe475)


  0%|          | 0/1 [00:00<?, ?it/s]

Wikipedia 20200501.en docs: 6078422


In [47]:
# Parse the tokenizer and model configurations stored next to the snapshot.
tkz_cfg = parse_yaml_file_as(TokenizerCfg, encdec_tkz_cfg_fpath)
model_cfg = parse_yaml_file_as(EncdecHgCfg, encdec_model_cfg_fpath)

# Values derived from the configs and used by the cells below: the tokenizer
# itself, the padding token index and the encoder input length.
tkz = tokenizer_from_config(tkz_cfg)
pad_tok = tkz_cfg.custom_tokens['pad'].ind
inp_len = model_cfg.enc_pyr.inp_len

In [52]:
# Read the snapshot file onto the selected device and load its 'model' entry
# into a freshly constructed EncdecHg.
# NOTE(review): torch.load relies on pickle unless weights-only loading is in
# effect -- only open checkpoints that come from a trusted source.
chkpt = torch.load(encdec_snapshot_fpath, map_location=device)
model = EncdecHg(model_cfg).to(device)
# strict=True makes load_state_dict raise on any missing or unexpected key;
# the commented alternative below relaxes that check.
strict = True
# strict = False
model.load_state_dict(chkpt['model'], strict=strict)
# Switch to evaluation mode; being the cell's last expression, the returned
# module is also displayed as the cell output.
model.eval()

EncdecHg(
  (enc_pyr): EncoderPyramid(
    (vocab_encoder): VocabEncoder(
      (src_word_emb): Embedding(50271, 256, padding_idx=50267)
      (position_enc): PositionalEncoding()
      (dropout): Dropout(p=0.0, inplace=False)
      (layer_norm): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
    )
    (enc_layers): ModuleList(
      (0-6): 7 x EncoderLayer(
        (slf_attn): MultiHeadAttention(
          (w_qs): Linear(in_features=256, out_features=256, bias=False)
          (w_ks): Linear(in_features=256, out_features=256, bias=False)
          (w_vs): Linear(in_features=256, out_features=256, bias=False)
          (fc): Linear(in_features=256, out_features=256, bias=False)
          (attention): ScaledDotProductAttention(
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (layer_norm): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
        )
        (pos_ffn): PositionwiseFeedForward(
        

In [53]:
def get_batch_tokens(doc_inds: list[int], randomize: bool = False) -> torch.Tensor:
    """Tokenize the given documents into one padded batch tensor.

    For every index in ``doc_inds`` the document's title and text are joined
    with a single space and tokenized with the notebook-level ``tkz``. Each
    row of the result holds at most ``inp_len`` tokens; shorter documents are
    right-padded with ``pad_tok``.

    Args:
        doc_inds: Indices into the notebook-level dataset ``ds``.
        randomize: When a document has more than ``inp_len`` tokens, take a
            window of ``inp_len`` tokens at a random offset instead of the
            first ``inp_len`` tokens.

    Returns:
        Tensor of shape ``(len(doc_inds), inp_len)`` placed on ``device``.
    """
    # Start from an all-padding batch and overwrite the leading part of each row.
    batch = np.full((len(doc_inds), inp_len), pad_tok)
    for row, doc_ind in enumerate(doc_inds):
        record = ds[doc_ind]
        title = record['title']
        text = record['text']
        token_ids: list[int] = tkz(f'{title} {text}')['input_ids']
        # Number of tokens that do not fit into a single row.
        surplus = len(token_ids) - inp_len
        if surplus > 0:
            start = np.random.randint(surplus + 1) if randomize else 0
            token_ids = token_ids[start:start + inp_len]
        batch[row, :len(token_ids)] = token_ids
    return torch.from_numpy(batch).to(device)


In [54]:
# Select the first five documents (as plain Python ints) and preview each one:
# zero-padded index, quoted title, and the first 300 characters of the text
# with newlines shown as a literal "\n" so every document stays on one line.
doc_inds = np.arange(5).tolist()
for doc_ind in doc_inds:
    doc = ds[doc_ind]
    title = doc['title']
    text = doc['text'].replace('\n', '\\n')
    print(f'{doc_ind:03d} "{title}" {text[:300]}')

000 "Yangliuqing" Yangliuqing () is a market town in Xiqing District, in the western suburbs of Tianjin, People's Republic of China. Despite its relatively small size, it has been named since 2006 in the "famous historical and cultural market towns in China".\n\nIt is best known in China for creating nianhua or Yangl
001 "Orana Australia Ltd" Orana Australia Ltd is a not-for-profit organisation that provides a diverse range of training and support services to over 650 people with disabilities and their families in South Australia.\n\nHistory\nThe Mentally Retarded Children’s Society of SA Inc. was established in 1950 by a group of parent
002 "St. Mary's Church, Sønderborg" The St. Mary's Church is a church owned by the Church of Denmark in Sønderborg, Denmark and the church of the parish with the same name. Thanks to its location on a hill, the church building is very iconic for the city.\n\nHistory \nIn the Middle Ages there was a leper colony on a hill just outside 
003 "Kalitta" Kal

In [55]:
# Tokenize the selected documents and run them through the model.
docs_toks_in = get_batch_tokens(doc_inds)

# Inference only: with autograd disabled no computation graph is retained for
# the large per-token vocabulary logits (see the shape printed below).
with torch.no_grad():
    logits_pred = model(docs_toks_in)
    probs_pred = torch.softmax(logits_pred, dim=-1)
# probs_pred = torch.sigmoid(logits_pred)
print(probs_pred.shape)

# Most probable token id at every position.
docs_toks_out = torch.argmax(probs_pred, dim=-1)
print(docs_toks_out.shape)

torch.Size([5, 128, 50271])
torch.Size([5, 128])


In [56]:
# Decode the predicted token ids of each document back to text and print it on
# a single line, with newlines shown as a literal "\n".
for i, doc_ind in enumerate(doc_inds):
    s = tkz.decode(docs_toks_out[i]).replace('\n', '\\n')
    print(f'{doc_ind:03d} {s}')


000 arazi de dechzari () is a small located in Harian District, in the first city of 18ia and University's Republic of India. As her the to time, and has been been be such to the firstth century and other other due for India.\n\nOn was located is in use for anthides of Schaariarii. As the to other there, Aaidi is also as the first area for the time of otherth such in the sames.\n. this works to " dachi to be other but of which the people
001 cals 'cl"", the -to-based, is be a this that of which and be due such to be other to information, and used in 18s.\n\nHistory\n\nAs ofi" (-�s" of 18i.also is to well as a number of " States Statess and California, an services for the time in the United due in a own to two to services in the to as as well.\n\n. "-�ss is been be that and be due for the to others, to " $s, and to other by other\n.
002  St. She, 2011, Kbu County"". As the time of the first season of the University of located in Wban County, California of the end of the end of the same 