In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import TOKENIZER_CFG_FNAME, ENCDEC_MODEL_CFG_FNAME, RANKER_MODEL_CFG_FNAME
from mllm.model.mllm_encdec import MllmEncdecLevel
from mllm.model.mllm_ranker import MllmRankerLevel
from mllm.config.model import TokenizerCfg, MllmEncdecCfg, MllmRankerCfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens, ChunkTokenizer, tokenizer_from_config

In [4]:
DATA_PATH = Path(os.path.expandvars('$HOME')) / 'data'
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

TRAIN_ENCDEC_0_PATH = DATA_PATH / 'train_mllm_encdec_0'
TRAIN_ENCDEC_1_PATH = DATA_PATH / 'train_mllm_encdec_1'
# encdec_0_subdir = 'encdec-lvl0-20241028_093737-wiki_20200501_en-ch_100_fixed-enc-lrs3-embmatFalse-d256-h8-dec-lrs3-seqlen100-d256-h8-vocdecTrue'
# encdec_0_subdir = 'encdec-lvl0-20241028_212210-wiki_20200501_en-ch_100_fixed-enc-lrs2-embmatFalse-d256-h8-dec-lrs2-seqlen100-d256-h8-vocdecTrue'
# encdec_0_subdir = 'encdec-lvl0-20241029_140645-wiki_20200501_en-ch_100_fixed-enc-lrs3-embmatFalse-d256-h8-dec-lrs3-seqlen100-d256-h8-vocdecTrue'
encdec_0_subdir = 'encdec-lvl0-20241030_090802-wiki_20200501_en-ch_100_fixed-enc-lrs4-embmatFalse-d256-h8-dec-lrs4-seqlen100-d256-h8-vocdecTrue'
encdec_1_subdir = 'encdec-lvl1-20241022_224217-msmarco-fever-enc-lrs2-embmatTrue-d256-h8-dec-lrs2-seqlen100-d256-h8'

encdec_0_train_path = TRAIN_ENCDEC_0_PATH / encdec_0_subdir
encdec_1_train_path = TRAIN_ENCDEC_1_PATH / encdec_1_subdir
encdec_0_snapshot_fpath = encdec_0_train_path / 'last.pth'
encdec_1_snapshot_fpath = encdec_1_train_path / 'best.pth'
encdec_0_tkz_cfg_fpath = encdec_0_train_path / TOKENIZER_CFG_FNAME
encdec_0_model_cfg_fpath = encdec_0_train_path / ENCDEC_MODEL_CFG_FNAME
encdec_1_model_cfg_fpath = encdec_1_train_path / ENCDEC_MODEL_CFG_FNAME

In [5]:
encdec_tkz_cfg = parse_yaml_file_as(TokenizerCfg, encdec_0_tkz_cfg_fpath)
tokenizer = tokenizer_from_config(encdec_tkz_cfg)
tok_dict = encdec_tkz_cfg.custom_tokens
pad_tok, qbeg_tok, qend_tok = tok_dict['pad'].ind, tok_dict['query_begin'].ind, tok_dict['query_end'].ind

In [6]:
docs_batch_size = 5
max_chunks_per_doc = 3
model_level = 0
device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print(device)

cpu


In [7]:
ds_loader = WikiDsLoader(
    ds_path=DS_DIR_PATH, docs_batch_size=docs_batch_size, max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok, qbeg_tok=qbeg_tok, qend_tok=qend_tok, device=device,
)
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)
inp_len = ds_loader.emb_chunk_size if ds_loader.fixed_size else calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [8]:
def tokens_to_text(tokens: torch.Tensor) -> str:
    tokens = tokens.flatten()
    tokens = tokens[tokens != pad_tok]
    tokens = list(tokens)
    s = tokenizer.decode(tokens)
    return s

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    if not cosine:
        return np.linalg.norm(x - y)
    x_norm, y_norm = np.linalg.norm(x), np.linalg.norm(y)
    return np.sum(x * y) / (x_norm * y_norm)

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    tokens = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        assert qend_tok is not None
        tokens = [qbeg_tok, *tokens, qend_tok]
    n_tokens = len(tokens)
    n_padded = n_tokens // inp_len + (n_tokens % inp_len > 0)
    res = np.full((n_padded * inp_len, ), pad_tok, dtype=np.int32)
    res[:n_tokens] = tokens
    res = torch.from_numpy(res).to(device)
    res = res.reshape(n_padded, inp_len)
    return res

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    for i, docs_emb in enumerate(docs_embs.detach().numpy()):
        for target_emb in target_embs.detach().numpy():
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [9]:
model_encdec_0_cfg = parse_yaml_file_as(MllmEncdecCfg, encdec_0_model_cfg_fpath)
model_encdec_0 = MllmEncdecLevel(model_encdec_0_cfg, 0).to(device)
checkpoint_encdec_0 = torch.load(encdec_0_snapshot_fpath)
model_encdec_0.load_state_dict(checkpoint_encdec_0['model'], strict=False)
model_encdec_0.eval()
None

vocab_encoder.src_word_emb.weight (50271, 256) -0.010897174 1.916379e-06 0.010897173
vocab_encoder.layer_norm.weight (256,) -0.09989282 -0.001876442 0.09918808
vocab_encoder.layer_norm.bias (256,) -0.09993392 -0.0039835246 0.09995335
vocab_decoder.word_prj.weight (50271, 256) -0.010897173 2.8820418e-06 0.01089717
encoder.a_em () 0.070763186 0.070763186 0.070763186
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.108251855 0.00016167221 0.10825132
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.108241715 -0.00020803144 0.108245455
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10825175 -0.0003009955 0.108245976
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10825185 -3.9326944e-05 0.10825074
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09912782 -0.001627505 0.09945027
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09983226 0.002700204 0.099729896
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846522 1.5313053e-05 0.06

In [10]:
model_encdec_1_cfg = parse_yaml_file_as(MllmEncdecCfg, encdec_1_model_cfg_fpath)
model_encdec_1 = MllmEncdecLevel(model_encdec_1_cfg, 1).to(device)
checkpoint_encdec_1 = torch.load(encdec_1_snapshot_fpath)
model_encdec_1.load_state_dict(checkpoint_encdec_1['model'], strict=False)
model_encdec_1.eval()
None

ValidationError: 1 validation error for ParsingModel[MllmEncdecCfg]
__root__ -> with_vocab_decoder
  field required (type=value_error.missing)

In [10]:
i = 10
batch = ds_loader.get_batch(i, train=False)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([13, 100]),
 torch.Size([3, 100]),
 tensor([False, False, False, False,  True,  True,  True, False, False, False,
         False, False, False]))

In [11]:
s_target = tokens_to_text(target_chunks)
s_target

'<|query_begin|> San Juan de Dios ParkSan Juan de Dios Park  (in Spanish Colonias San Juan de Dios) (also known as San Juan) is a zone situated in the south of Mexico City, in the delegation Tlalpan. Has his origin in the Inland revenue of San Juan of God "The Big". The zone is conformed by the colonies Hacienda de San Juan, Villa Lázaro Cárdenas,Ex Hacienda San Juan de Dios, Arboledas Del Sur, Hacienda de San Juan 2nd Section, Chimalli, The Colorines, Guadeloupe Tlalpan and the colony AMSA. The zone houses big number of parks scattered in all his colonies as well as it also has commercial squares like the shopping centre Paseo Acoxpa. Also the zone of San Juan basin with different urban services like transport, educational and of health. San Juan of God is a residential zone mostly, has different private residentials and also private or private streets. Many of his colonies also belong to the zone of Coapa.\n\nIt is one of the south exclusive zones of the Mexico City. It limits with t

In [13]:
for toks in docs_chunks:
    s = tokens_to_text(toks)
    print(s[:600].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 6053886 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> List of vice-admirals of Sussex <|doc_title_end|> <|doc_body_begin|> This is a list of people who have served as Vice-Admiral of Sussex.\n\nSir William More 1559–1600\nvacant\nCharles Howard, 2nd Earl of Nottingham 1608–1642\nFrancis Lennard, 14th Baron Dacre bef. 1647–1650\nAnthony Stapley 1651–1655 (Parliamentary)
<|doc_begin|> <|doc_id_begin|> 6053886 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|> \nvacant\nSir John Pelham, 3rd Baronet 1660–1703\nCharles Goring 1703–1705\nThomas Pelham, 1st Baron Pelham 1705–1712\nJohn Ashburnham, 3rd Baron Ashburnham 1712–1715\nThomas Pelham-Holles, 1st Duke of Newcastle 1715–1768\nvacant\nJohn Ashburnham, 2nd
<|doc_begin|> <|doc_id_begin|> 6053886 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  Earl of Ashburnham 1770–1812\nCharles Lennox, 4th Duke of Richmond 1812–1819\nGeorge Wyndham, 3rd Earl of Egremont 182

In [14]:
docs_chunks_pred = model_encdec_0(docs_chunks)
# docs_chunks_pred = torch.sigmoid(docs_chunks_pred)
docs_chunks_pred = torch.softmax(docs_chunks_pred, dim=-1)
docs_chunks_pred.shape, docs_chunks_pred.dtype

(torch.Size([13, 100, 50271]), torch.float32)

In [15]:
dc_toks_pred = torch.argmax(docs_chunks_pred, dim=-1)
dc_toks_pred.shape

torch.Size([13, 100])

In [17]:
for toks in dc_toks_pred:
    s = tokens_to_text(toks)
    s = s.replace('\n', '\\n')
    print(len(toks), len(s), s)


100 429 <|doc_begin|> <|doc_id_begin|> 172775 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> List of Member-Wreatals of Lords <|doc_title_end|> <|doc_body_begin|> This is a list of people who have held as Vice-Admiral of Norfolk.\n\nSir William Baronst90–1600\nHambell\nCharles III, 2nd Earl of Hunting 1648–1642\nFrancis Castard, 14th Baron D Baron with d. 1641–1670\nWilliam Lennen 1606–1655 (Adliamentus)
100 340 <|doc_begin|> <|doc_id_begin|> 172449 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|> \nHumbons\nSir John Bingham, 3rd Baron 14 1660–1797\nCharles A99 1703–1705\nThomas Dunham, 1st Baron Egham 1705–1714\nJohn Dundeham, 3rd Baron Dundehamb29–1722\nThomas Dunham-Hollage, 1st Duke of Newcastle 17et–1768\nHumbons\nJohn Dundeham, 2nd
100 395 <|doc_begin|> <|doc_id_begin|> 172799 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  Earl of Edfordham 1775–1814\nCharles Fitzay, 4th Duke of Richmond 1829–18)\nGeorge Wynders, 3rd Earl of Cam

In [18]:
txts = [
    '<|doc_begin|> 20 <|doc_id_begin|> 733860 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>Hello, my name is Mikhail<|doc_end|>',
    'Malaga is a city in Spain',
    'LLM stands for Large <|mask|> Model',
    'You\'d better learn new modeling approaches first, Mikhail from Malaga, Spain!',
]

In [19]:
chunks = []
for txt in txts:
    toks = text_to_tokens(txt)
    print(toks.shape)
    chunks.append(toks)

chunks = torch.concat(chunks)
for toks in chunks:
    s = tokens_to_text(toks)
    print(s)


torch.Size([1, 100])
torch.Size([1, 100])
torch.Size([1, 100])
torch.Size([1, 100])
<|doc_begin|>  20  <|doc_id_begin|>  733860  <|doc_id_end|>   <|doc_offset_begin|>  91  <|doc_offset_end|> Hello, my name is Mikhail <|doc_end|>
Malaga is a city in Spain
LLM stands for Large  <|mask|>  Model
You'd better learn new modeling approaches first, Mikhail from Malaga, Spain!


In [20]:
chunks_pred = model_encdec_0(chunks)
# chunks_pred = torch.sigmoid(chunks_pred)
chunks_pred = torch.softmax(chunks_pred, dim=-1)
print(chunks_pred.shape, chunks_pred.dtype)
toks_pred = torch.argmax(chunks_pred, dim=-1)
print(toks_pred.shape)

torch.Size([4, 100, 50271]) torch.float32
torch.Size([4, 100])


In [21]:
for toks in toks_pred:
    s = tokens_to_text(toks)
    s = s.replace('\n', '\\n')
    print(s)


alusedis <|doc_id_begin|>  airlines ISO airlinesusedulated <|doc_offset_begin|> 0ountulated"is <|doc_id_end|>  linkssk generic defunct also0 also is was <|doc_id_end|> <|doc_id_end|> ile generic A0 Comm design <|doc_id_begin|>  also also defunct airlines <|doc_id_end|> <|doc_id_end|> <|doc_id_end|> <|query_end|>  the <|doc_offset_begin|> <|doc_id_end|> <|query_end|> <|query_end|> <|doc_id_end|> <|doc_id_end|>  It distribution to <|doc_begin|>  The, <|doc_id_end|> <|doc_offset_begin|> <|doc_offset_end|>  generic scrapped <|doc_offset_begin|>  airlines generic 2is no, 2 launch a <|doc_id_end|> \n sub design\n Black was <|doc_offset_begin|> <|doc_id_end|> <|doc_id_end|>  refer <|doc_offset_begin|> <|doc_id_end|>  airlines", a and,, <|doc_offset_begin|> <|doc_id_end|>  also <|doc_id_end|> <|doc_id_end|>,
or the, the Transportation the theulated links  <|doc_id_begin|>   \' marketers <|doc_offset_begin|> 0 <|doc_id_begin|>  plays, alsoard"is is CommItized Rep designis alsois, firstulatedume

In [22]:
probs = chunks_pred[0][8]
inds = torch.arange(len(probs))
prob_thres = 0.95
prob_thres = 0.1
probs_mask = probs >= prob_thres
print(probs[probs_mask])
toks = inds[probs_mask]
strs = []
for tok in toks:
    toks_ = torch.Tensor([tok])
    s = tokens_to_text(toks_)
    strs.append(s)
print(strs)


tensor([], grad_fn=<IndexBackward0>)
[]


In [23]:
model_encdec_0

MllmEncdecLevel(
  (vocab_encoder): VocabEncoder(
    (src_word_emb): Embedding(50271, 256, padding_idx=50267)
    (position_enc): PositionalEncoding()
    (dropout): Dropout(p=0.0, inplace=False)
    (layer_norm): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
  )
  (vocab_decoder): VocabDecoder(
    (word_prj): Linear(in_features=256, out_features=50271, bias=False)
  )
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layer_stack): ModuleList(
      (0-1): 2 x EncoderLayer(
        (slf_attn): MultiHeadAttention(
          (w_qs): Linear(in_features=256, out_features=256, bias=False)
          (w_ks): Linear(in_features=256, out_features=256, bias=False)
          (w_vs): Linear(in_features=256, out_features=256, bias=False)
          (fc): Linear(in_features=256, out_features=256, bias=False)
          (attention): ScaledDotProductAttention(
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (dropout): Dropout(p=0.0, inplace=Fals

In [24]:
model_encdec_0.encoder.n_layers

2

In [12]:
import re
DT_PAT_RE = r'\d{8}_\d{6}'
pat = re.compile(r'^[\w-]+?\-(%s)-.+$' % DT_PAT_RE)
print(DT_PAT_RE, pat)

\d{8}_\d{6} re.compile('^[\\w-]+?\\-(\\d{8}_\\d{6})-.+$')


In [13]:
paths = [
    'encdec-20241018_092135-wiki_20200501_en-ch_100_fixed',
    'encdec-lvl0-20241026_120743-wiki_20200501_en-ch_100_fixed-enc-lrs2-embmatFalse-d256-h8-dec-lrs2-seqlen100-d256-h8',
    'encdec-33337128_001122-20241018_092135-wiki_20200501_en-ch_100_fixed',
]
for p in paths:
    m = pat.match(p)
    dt = None
    if m:
        dt = m.group(1)
    print(p[:30], dt)

encdec-20241018_092135-wiki_20 20241018_092135
encdec-lvl0-20241026_120743-wi 20241026_120743
encdec-33337128_001122-2024101 33337128_001122
