In [4]:
%load_ext autoreload
%autoreload 2

In [5]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
from pydantic_yaml import parse_yaml_file_as
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.exp.args import TOKENIZER_CFG_FNAME, ENCDEC_MODEL_CFG_FNAME, RANKER_MODEL_CFG_FNAME
from mllm.model.mllm_encdec import MllmEncdecLevel
from mllm.model.mllm_ranker import MllmRankerLevel
from mllm.config.model import TokenizerCfg, MllmEncdecCfg, MllmRankerCfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens, ChunkTokenizer, tokenizer_from_config

In [None]:
DATA_PATH = Path(os.path.expandvars('$HOME')) / 'data'
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

TRAIN_ENCDEC_0_PATH = DATA_PATH / 'train_mllm_encdec_0'
TRAIN_ENCDEC_1_PATH = DATA_PATH / 'train_mllm_encdec_1'
encdec_0_subdir = 'encdec-lvl0-20241028_093737-wiki_20200501_en-ch_100_fixed-enc-lrs3-embmatFalse-d256-h8-dec-lrs3-seqlen100-d256-h8-vocdecTrue'
encdec_1_subdir = 'encdec-lvl1-20241022_224217-msmarco-fever-enc-lrs2-embmatTrue-d256-h8-dec-lrs2-seqlen100-d256-h8'

encdec_0_train_path = TRAIN_ENCDEC_0_PATH / encdec_0_subdir
encdec_1_train_path = TRAIN_ENCDEC_1_PATH / encdec_1_subdir
encdec_0_snapshot_fpath = encdec_0_train_path / 'best.pth'
encdec_1_snapshot_fpath = encdec_1_train_path / 'best.pth'
encdec_0_tkz_cfg_fpath = encdec_0_train_path / TOKENIZER_CFG_FNAME
encdec_0_model_cfg_fpath = encdec_0_train_path / ENCDEC_MODEL_CFG_FNAME
encdec_1_model_cfg_fpath = encdec_1_train_path / ENCDEC_MODEL_CFG_FNAME

In [7]:
encdec_tkz_cfg = parse_yaml_file_as(TokenizerCfg, encdec_0_tkz_cfg_fpath)
tokenizer = tokenizer_from_config(encdec_tkz_cfg)
tok_dict = encdec_tkz_cfg.custom_tokens
pad_tok, qbeg_tok, qend_tok = tok_dict['pad'].ind, tok_dict['query_begin'].ind, tok_dict['query_end'].ind

In [8]:
docs_batch_size = 5
max_chunks_per_doc = 3
model_level = 0
device_name = 'cpu'
# device_name = 'cuda'

device = torch.device(device_name)
print(device)

cpu


In [9]:
ds_loader = WikiDsLoader(
    ds_path=DS_DIR_PATH, docs_batch_size=docs_batch_size, max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok, qbeg_tok=qbeg_tok, qend_tok=qend_tok, device=device,
)
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)
inp_len = ds_loader.emb_chunk_size if ds_loader.fixed_size else calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [10]:
def tokens_to_text(tokens: torch.Tensor) -> str:
    tokens = tokens.flatten()
    tokens = tokens[tokens != pad_tok]
    tokens = list(tokens)
    s = tokenizer.decode(tokens)
    return s

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    if not cosine:
        return np.linalg.norm(x - y)
    x_norm, y_norm = np.linalg.norm(x), np.linalg.norm(y)
    return np.sum(x * y) / (x_norm * y_norm)

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    tokens = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        assert qend_tok is not None
        tokens = [qbeg_tok, *tokens, qend_tok]
    n_tokens = len(tokens)
    n_padded = n_tokens // inp_len + (n_tokens % inp_len > 0)
    res = np.full((n_padded * inp_len, ), pad_tok, dtype=np.int32)
    res[:n_tokens] = tokens
    res = torch.from_numpy(res).to(device)
    res = res.reshape(n_padded, inp_len)
    return res

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    for i, docs_emb in enumerate(docs_embs.detach().numpy()):
        for target_emb in target_embs.detach().numpy():
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [11]:
model_encdec_0_cfg = parse_yaml_file_as(MllmEncdecCfg, encdec_0_model_cfg_fpath)
model_encdec_0 = MllmEncdecLevel(model_encdec_0_cfg, 0).to(device)
checkpoint_encdec_0 = torch.load(encdec_0_snapshot_fpath)
model_encdec_0.load_state_dict(checkpoint_encdec_0['model'], strict=False)
model_encdec_0.eval()
None

vocab_encoder.src_word_emb.weight (50271, 256) -0.010897174 -8.7604326e-07 0.010897173
vocab_encoder.layer_norm.weight (256,) -0.09980767 0.0057643503 0.09948851
vocab_encoder.layer_norm.bias (256,) -0.09884226 0.0011206635 0.098379485
vocab_decoder.word_prj.weight (50271, 256) -0.010897174 4.4291858e-07 0.010897167
encoder.a_em () -0.04112512 -0.04112512 -0.04112512
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10824676 0.00018859553 0.108249635
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825236 -9.765843e-05 0.10825259
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10825161 -8.1947765e-05 0.108248875
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108251005 1.5149781e-05 0.10825272
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.099994875 0.00016797625 0.09996452
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09980128 -0.0029777153 0.09829161
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846519 7.760219e-05 0

In [12]:
model_encdec_1_cfg = parse_yaml_file_as(MllmEncdecCfg, encdec_1_model_cfg_fpath)
model_encdec_1 = MllmEncdecLevel(model_encdec_1_cfg, 1).to(device)
checkpoint_encdec_1 = torch.load(encdec_1_snapshot_fpath)
model_encdec_1.load_state_dict(checkpoint_encdec_1['model'], strict=False)
model_encdec_1.eval()
None

ValidationError: 1 validation error for ParsingModel[MllmEncdecCfg]
__root__ -> with_vocab_decoder
  field required (type=value_error.missing)

In [13]:
i = 10
batch = ds_loader.get_batch(i, train=False)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
docs_chunks.shape, target_chunks.shape, target_mask

(torch.Size([11, 100]),
 torch.Size([3, 100]),
 tensor([False,  True,  True,  True, False, False, False, False, False, False,
         False]))

In [14]:
s_target = tokens_to_text(target_chunks)
s_target

'<|query_begin|> Benoni, GautengBenoni is a town in Ekurhuleni municipality, Gauteng, South Africa.\n\nBenoni was also the setting for the MTV-inspired movie Crazy Monkey: Straight Outta Benoni, released internationally in 2005.\n\nPeople from Benoni\n\nCharlene, Princess of Monaco, (née Charlene Wittstock), swimmer, and consort of Prince Albert II of Monaco\nBryan Habana, former Springboks rugby player\nPhilip Holiday, IBF World Champion Boxer\nMorris Kahn (born 1930), Israeli billionaire, founder and chairman of Aurec Group\nMildred Mangxola, singer and member of the Mahotella Queens\n Frith van der Merwe, schoolteacher at Benoni High and the most prolific female runner in the history of the Comrades Marathon\nGenevieve Morton, top model \nGrace Mugabe, former First Lady of Zimbabwe\n Bradley Player, cricketer\n Oliver Reginald Tambo, ANC, ANCYL and SACP hero during the Apartheid regime.\nCharlize Theron, Oscar-winning actress (Academy Awards: Best Actress Monster (2003 film))\nVic T

In [15]:
for toks in docs_chunks:
    s = tokens_to_text(toks)
    print(s[:600].replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 5853387 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Marivirga atlantica <|doc_title_end|> <|doc_body_begin|> Marivirga atlantica is a Gram-negative, aerobic and rod-shaped bacterium from the genus of Marivirga which has been isolated from seawater from the Atlantic Ocean.\n\nReferences\n\nCategory:Sphingobacteriia\nCategory:Bacteria described in 2015 <|doc_body_end|> <|doc_end|>
<|doc_begin|> <|doc_id_begin|> 6064729 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> Benoni, Gauteng <|doc_title_end|> <|doc_body_begin|> Benoni is a town in Ekurhuleni municipality, Gauteng, South Africa.\n\nBenoni was also the setting for the MTV-inspired movie Crazy Monkey: Straight Outta Benoni, released internationally in 2005.\n\nPeople from Benoni\n\nCharlene, Princess of Monaco, (née Charlene Wittstock), swimmer, and consort of Prince Albert II of
<|doc_begin|> <|doc_id_begin|> 6064729 <|doc_id_end|> <|doc_off

In [16]:
docs_chunks_pred = model_encdec_0(docs_chunks)
# docs_chunks_pred = torch.sigmoid(docs_chunks_pred)
docs_chunks_pred = torch.softmax(docs_chunks_pred, dim=-1)
docs_chunks_pred.shape, docs_chunks_pred.dtype

(torch.Size([11, 100, 50271]), torch.float32)

In [17]:
dc_toks_pred = torch.argmax(docs_chunks_pred, dim=-1)
dc_toks_pred.shape

torch.Size([11, 100])

In [18]:
for toks in dc_toks_pred:
    s = tokens_to_text(toks)
    s = s.replace('\n', '\\n')
    print(len(toks), len(s), s)


100 729 Debug <|doc_end|> Runtime JPM <|doc_body_end|>  EthereumforeseenCategory <|query_end|> Args Ibid Ibidusterityusterity <|query_end|> <|query_end|> HopefullyHopefully"} falsehood falsehood falsehood,",","," NEED NEED wedd],"foreseenforeseenforeseenforeseenforeseenforeseen!=!= unfocusedforeseenforeseenforeseenforeseenforeseenforeseenforeseenforeseenforeseenforeseen weddSettingsvPforeseenforeseenCONTforeseenSettings)\SettingsclaimerforeseenArgsArgsArgsArgsArgsArgsArgsArgsArgsArgsArgsArgsArgsArgsArgsArgsoths��� <|doc_id_begin|> <|doc_id_begin|> <|doc_id_begin|> <|doc_id_begin|> <|doc_id_begin|> <|doc_id_begin|> <|doc_id_begin|> <|doc_id_begin|> <|doc_id_begin|> <|doc_id_begin|> <|doc_begin|> <|doc_id_begin|> <|doc_id_begin|>
100 813 Debug <|doc_end|>  SLIforeseen <|doc_end|> <|query_end|> <|query_end|> <|query_end|> <|query_end|> Args"}"}"}"}"}"}"}"}"}"}"},","ombies,"],"],"],"],"],""} <|doc_end|> <|doc_end|> <|doc_end|> <|doc_end|> <|doc_end|> <|doc_end|> ]," <|doc_end|> <|doc_end|>

In [34]:
txts = [
    '<|doc_begin|> 20 <|doc_id_begin|> 733860 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>Hello, my name is Mikhail<|doc_end|>',
    'Malaga is a city in Spain',
    'LLM stands for Large <|mask|> Model',
    'You\'d better learn new modeling approaches first, Mikhail from Malaga, Spain!',
]

In [35]:
chunks = []
for txt in txts:
    toks = text_to_tokens(txt)
    print(toks.shape)
    chunks.append(toks)

chunks = torch.concat(chunks)
for toks in chunks:
    s = tokens_to_text(toks)
    print(s)


torch.Size([1, 100])
torch.Size([1, 100])
torch.Size([1, 100])
torch.Size([1, 100])
<|doc_begin|>  20  <|doc_id_begin|>  733860  <|doc_id_end|>   <|doc_offset_begin|>  91  <|doc_offset_end|> Hello, my name is Mikhail <|doc_end|>
Malaga is a city in Spain
LLM stands for Large  <|mask|>  Model
You'd better learn new modeling approaches first, Mikhail from Malaga, Spain!


In [36]:
chunks_pred = model_encdec_0(chunks)
# chunks_pred = torch.sigmoid(chunks_pred)
chunks_pred = torch.softmax(chunks_pred, dim=-1)
print(chunks_pred.shape, chunks_pred.dtype)
toks_pred = torch.argmax(chunks_pred, dim=-1)
print(toks_pred.shape)

torch.Size([4, 100, 50271]) torch.float32
torch.Size([4, 100])


In [37]:
for toks in toks_pred:
    s = tokens_to_text(toks)
    s = s.replace('\n', '\\n')
    print(s)


<|doc_begin|> <|doc_id_begin|> 111181, <|doc_offset_begin|> <|doc_offset_begin|> . ..\n., my name is <|doc_body_end|>  "
Moo of Santao, <|query_end|>
:Clms of by <|query_end|> <|query_end|>
,,, a the,, TV, singer of Maraga, Serbia <|query_end|>


In [41]:
probs = chunks_pred[0][8]
inds = torch.arange(len(probs))
prob_thres = 0.95
prob_thres = 0.1
probs_mask = probs >= prob_thres
print(probs[probs_mask])
toks = inds[probs_mask]
strs = []
for tok in toks:
    toks_ = torch.Tensor([tok])
    s = tokens_to_text(toks_)
    strs.append(s)
print(strs)


tensor([0.1149, 0.7758], grad_fn=<IndexBackward0>)
['.', '.']


In [12]:
import re
DT_PAT_RE = r'\d{8}_\d{6}'
pat = re.compile(r'^[\w-]+?\-(%s)-.+$' % DT_PAT_RE)
print(DT_PAT_RE, pat)

\d{8}_\d{6} re.compile('^[\\w-]+?\\-(\\d{8}_\\d{6})-.+$')


In [13]:
paths = [
    'encdec-20241018_092135-wiki_20200501_en-ch_100_fixed',
    'encdec-lvl0-20241026_120743-wiki_20200501_en-ch_100_fixed-enc-lrs2-embmatFalse-d256-h8-dec-lrs2-seqlen100-d256-h8',
    'encdec-33337128_001122-20241018_092135-wiki_20200501_en-ch_100_fixed',
]
for p in paths:
    m = pat.match(p)
    dt = None
    if m:
        dt = m.group(1)
    print(p[:30], dt)

encdec-20241018_092135-wiki_20 20241018_092135
encdec-lvl0-20241026_120743-wi 20241026_120743
encdec-33337128_001122-2024101 33337128_001122
