In [1]:
# Auto-reload edited project modules (e.g. mllm.*) before each cell executes.
%load_ext autoreload
%autoreload 2

In [3]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.model.mllm_encdec import MllmEncdec
from mllm.model.mllm_ranker import MllmRanker
from mllm.config.model import create_mllm_encdec_cfg, create_mllm_ranker_cfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens



In [30]:
# Root for datasets and training runs. Path.home() is used instead of expanding
# '$HOME' by hand: os.path.expandvars leaves a literal '$HOME' when the variable
# is unset, while Path.home() falls back to the platform's user-home lookup.
DATA_PATH = Path.home() / 'data'
TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
# Older alternative root: DATA_PATH / 'train_mllm_ranker'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qs'
# Root of the ranker runs trained on qrels data (the msmarco / fever runs below).
TRAIN_RANKER_QRELS_PATH = DATA_PATH / 'train_mllm_ranker_qrels'
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'

encdec_subdir = 'encdec-20240718_221554-wiki_20200501_en-ch_100_fixed'

# Previously evaluated ranker runs, kept for reference. The wiki runs were
# presumably loaded from TRAIN_RANKER_PATH rather than TRAIN_RANKER_QRELS_PATH.
#   ranker-20240722_225232-wiki_20200501_en-ch_100_fixed
#   ranker-20240724_230827-wiki_20200501_en-ch_100_fixed
#   ranker-20240726_232850-wiki_20200501_en-ch_100_fixed
#   ranker-20240730_213328-wiki_20200501_en-ch_100_fixed
#   ranker-20240806_221913-msmarco
#   ranker-20240815_180317-msmarco
#   ranker-20240905_213413-msmarco-fever
ranker_subdir = 'ranker-20240903_215749-msmarco-fever'

encdec_train_path = TRAIN_ENCDEC_PATH / encdec_subdir
ranker_train_path = TRAIN_RANKER_QRELS_PATH / ranker_subdir
encdec_snapshot_path = encdec_train_path / 'best.pth'
ranker_snapshot_path = ranker_train_path / 'best.pth'

In [5]:
# Batching parameters handed to WikiDsLoader below.
docs_batch_size = 5
max_chunks_per_doc = 3

# Change 'cpu' to 'cuda' to run on GPU.
device = torch.device('cpu')
print(device)

cpu


In [7]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', model_max_length=100000)
tok_dict = gen_all_tokens(tokenizer)
# Special-token ids used for padding and for framing queries.
pad_tok = tok_dict['pad'].ind
qbeg_tok = tok_dict['query_begin'].ind
qend_tok = tok_dict['query_end'].ind

ds_loader = WikiDsLoader(
    ds_path=DS_DIR_PATH,
    docs_batch_size=docs_batch_size,
    max_chunks_per_doc=max_chunks_per_doc,
    pad_tok=pad_tok,
    qbeg_tok=qbeg_tok,
    qend_tok=qend_tok,
    device=device,
)
# Shuffle the train split, then the validation split.
ds_loader.shuffle(train=True)
ds_loader.shuffle(train=False)

# Fixed-size datasets use the chunk size as the model input length directly;
# otherwise calc_max_inp_size derives it from the chunk size.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)

Loading cache from /home/misha/data/wiki_20200501_en/ch_100_fixed/.mllm/ds.csv
Loaded dataset size: 50989207


In [8]:
def tokten_to_text(tokens: torch.Tensor) -> str:
    """Decode a tensor of token ids back into text, dropping padding.

    Args:
        tokens: Tensor of token ids of any shape. It is flattened first, so a
            multi-row batch decodes into one concatenated string.

    Returns:
        The decoded string with every ``pad_tok`` id removed.
    """
    ids = tokens.flatten()
    ids = ids[ids != pad_tok]
    # .tolist() yields plain Python ints. list(tensor) would instead produce a list
    # of 0-d tensors that the tokenizer has to convert one at a time.
    return tokenizer.decode(ids.tolist())

def distance(x: np.ndarray, y: np.ndarray, cosine: bool = False):
    """Score how close two embedding arrays are.

    The two modes point in opposite directions:

    * ``cosine=False``: Euclidean (L2) distance ``||x - y||``. Smaller means closer.
    * ``cosine=True``: cosine *similarity* ``<x, y> / (||x|| * ||y||)``. Larger
      means closer, despite the function's name. A zero vector gives 0/0 = NaN.
    """
    if cosine:
        dot = (x * y).sum()
        norms_prod = np.linalg.norm(x) * np.linalg.norm(y)
        return dot / norms_prod
    return np.linalg.norm(x - y)

def text_to_tokens(s: str, qbeg_tok: Optional[int] = None, qend_tok: Optional[int] = None) -> torch.Tensor:
    """Tokenize ``s`` and pack the ids into rows of ``inp_len`` tokens.

    Args:
        s: Text to tokenize with the module-level ``tokenizer``.
        qbeg_tok: Optional query-begin token id.
        qend_tok: Optional query-end token id. Pass both ids to wrap the sequence
            as ``[qbeg_tok, *ids, qend_tok]``, or pass neither.

    Returns:
        ``torch.int32`` tensor of shape ``(ceil(n_tokens / inp_len), inp_len)``
        on ``device``. The unused tail of the last row is filled with
        ``pad_tok``. Text that yields no tokens gives a ``(0, inp_len)`` tensor.

    Raises:
        AssertionError: if only one of ``qbeg_tok`` / ``qend_tok`` is given.
    """
    # Require both framing tokens or neither. Previously, passing only qend_tok
    # was silently ignored and produced an unframed query.
    assert (qbeg_tok is None) == (qend_tok is None), 'qbeg_tok and qend_tok must be given together'
    tokens = tokenizer(s)['input_ids']
    if qbeg_tok is not None:
        tokens = [qbeg_tok, *tokens, qend_tok]
    n_tokens = len(tokens)
    # Ceiling division gives the number of inp_len-wide rows needed for all tokens.
    n_rows = -(-n_tokens // inp_len)
    res = np.full((n_rows * inp_len,), pad_tok, dtype=np.int32)
    res[:n_tokens] = tokens
    return torch.from_numpy(res).to(device).reshape(n_rows, inp_len)

def print_dist(target_embs: torch.Tensor, docs_embs: torch.Tensor, target_mask: torch.Tensor, cosine: bool = True):
    """Print a doc-by-target table of embedding scores.

    Each output line corresponds to one row of ``docs_embs``. The line holds
    ``distance(target, doc, cosine)`` for every row of ``target_embs``, formatted
    with 6 decimals and separated by spaces. It ends with ``T`` if
    ``target_mask[i]`` is truthy for that doc row, else ``F``.

    With ``cosine=True`` the scores are cosine similarities (higher = closer).
    With ``cosine=False`` they are Euclidean distances (lower = closer).
    """
    # .cpu() keeps this working for embeddings on a CUDA device, because
    # Tensor.numpy() raises for non-CPU tensors. The targets are also converted
    # once here instead of once per doc row.
    targets = target_embs.detach().cpu().numpy()
    docs = docs_embs.detach().cpu().numpy()
    for i, docs_emb in enumerate(docs):
        for target_emb in targets:
            dist = distance(target_emb, docs_emb, cosine)
            print(f'{dist:0.6f} ', end='')
        sfx = 'T' if target_mask[i] else 'F'
        print(sfx)


In [9]:
# Encoder-decoder config. Parameter shapes must match the checkpoint loaded
# below, since load_state_dict rejects shape mismatches even with strict=False.
model_encdec_cfg = create_mllm_encdec_cfg(
    n_vocab=len(tokenizer),
    inp_len=inp_len,
    pad_idx=pad_tok,
    d_word_wec=256,
    d_model=256,
    d_inner=1024,
    n_heads=8,
    enc_n_layers=1,
    dec_n_layers=1,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)
model_encdec = MllmEncdec(model_encdec_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897278 -1.335788e-06 0.010897281
vocab_encoder.layer_norm.weight (256,) -0.09893657 0.0066219117 0.09997964
vocab_encoder.layer_norm.bias (256,) -0.099779785 0.0021610437 0.099599935
encoder.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.10825171 -7.75688e-05 0.10825041
encoder.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825286 6.4800486e-05 0.10825026
encoder.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10825027 0.00010499777 0.10825214
encoder.layer_stack.0.slf_attn.fc.weight (256, 256) -0.10825227 -6.173113e-05 0.108251885
encoder.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.098802604 -0.0031716428 0.09854559
encoder.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09966463 0.005480206 0.099370494
encoder.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846521 2.8875238e-05 0.06846477
encoder.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09987471 -0.000680316 0.09982319
encoder.layer_stack.0.pos_ffn.w_2.weight (256

In [10]:
# map_location places the tensors on the configured device. Without it, a
# checkpoint saved on GPU fails to load on a CPU-only machine.
checkpoint_encdec = torch.load(encdec_snapshot_path, map_location=device)
# strict=False tolerates key mismatches. Report them instead of dropping them silently.
incompatible = model_encdec.load_state_dict(checkpoint_encdec['model'], strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(f'encdec checkpoint: missing keys {incompatible.missing_keys}, unexpected keys {incompatible.unexpected_keys}')
model_encdec.eval()
del checkpoint_encdec

In [11]:
# Sanity check: a short string fits in one row, right-padded with pad_tok.
text_to_tokens(s='Hello there')

tensor([[15496,   612, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267,
         50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267, 50267]],
       dtype=torch.int32)

In [12]:
# Fetch train batch #10 and unpack it into doc chunks, target (query) chunks and
# a per-doc-chunk boolean mask (presumably marking chunks of the target doc).
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
(docs_chunks.shape, target_chunks.shape, target_mask)

(torch.Size([13, 100]),
 torch.Size([3, 100]),
 tensor([False, False, False, False, False, False, False, False,  True,  True,
          True, False, False]))

In [13]:
# Decode the target chunks back into text (padding stripped).
s_target = tokten_to_text(tokens=target_chunks)
s_target

'<|query_begin|> Suden uniSuden uni ("Wolf\'s Dream") is the first full-length album by Finnish pagan metal band Moonsorrow. It was originally released in 2001, and then re-released in 2003 with one bonus track (a Finnish lyrics version of the traditional Swedish song "Kom nu gubbar"), different cover art, and a 40-minute DVD.\n\nTrack listing\n\nPersonnel\n Ville Sorvali - vocals, bass, handclaps, choir\n Marko Tarvonen - drums, timpani, 12-string, vocals (backing), handclaps, choir\n Henri Sorvali - choir, guitars, keyboards, vocals (clean), accordion, mouth harp, handclaps\n\nGuest musicians\n Robert Lejon - handclaps on "Tulkaapa äijät!"\n Stefan Lejon - handclaps on "Tulkaapa äijät!"\n Blastmor - handclaps\n Avather - handclaps\n Janne Perttilä - choir, handclaps\n\nProduction\n Mika Jussila - remastering\n Ahti "Pirtu" Kortelainen - recording, mixing, mastering\n Niklas Sundin - reissue cover art\n\nCategory: <|query_end|>'

In [14]:
# Preview each doc chunk: the first 200 characters, with newlines escaped so
# every chunk stays on a single line.
for chunk in docs_chunks:
    preview = tokten_to_text(chunk)[:200]
    print(preview.replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 4043219 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> The Hamilton Spectator <|doc_title_end|> <|doc_body_begin|> The Hamilton Spectator, 
<|doc_begin|> <|doc_id_begin|> 4043219 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  the Spectator the first of the chain. The Southam chain was sold in 1998 to Conrad Black, who in turn 
<|doc_begin|> <|doc_id_begin|> 4043219 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  the surrounding communities of Grimsby and Beamsville. It also serves Brant County and Haldimand Coun
<|doc_begin|> <|doc_id_begin|> 5623618 <|doc_id_end|> <|doc_offset_begin|> 4330 <|doc_offset_end|>.\n Sport Available starting from the 2017 model year, the inline-4 is replaced by the 2.7 Ecoboost 325
<|doc_begin|> <|doc_id_begin|> 5623618 <|doc_id_end|> <|doc_offset_begin|> 4420 <|doc_offset_end|>  Fancy Sport Chrome Mesh Grille, Nappa Leather Seating Surfaces, and Heated and Cooled Front Sea

In [15]:
# Inference only: the embeddings are compared, never backpropagated, so skip
# building the autograd graph.
with torch.no_grad():
    target_embs = model_encdec.run_enc_emb(target_chunks)
    docs_embs = model_encdec.run_enc_emb(docs_chunks)

In [16]:
# True: cosine similarity (higher = closer); False: Euclidean distance (lower = closer).
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.586205 0.576020 0.477831 F
0.576322 0.482000 0.543113 F
0.520205 0.474481 0.514628 F
0.614104 0.664908 0.666830 F
0.597907 0.575063 0.665326 F
0.590845 0.581053 0.620674 F
0.558637 0.512605 0.530244 F
0.597204 0.632336 0.617147 F
0.700274 0.638536 0.582114 T
0.698899 0.835854 0.719015 T
0.671124 0.757178 0.785849 T
0.604881 0.599266 0.567953 F
0.538754 0.573610 0.576608 F


In [31]:
# Ranker input length, derived from the loader the same way as for the encdec model.
if ds_loader.fixed_size:
    inp_len = ds_loader.emb_chunk_size
else:
    inp_len = calc_max_inp_size(ds_loader.emb_chunk_size)
model_ranker_cfg = create_mllm_ranker_cfg(
    n_vocab=len(tokenizer),
    inp_len=inp_len,
    pad_idx=pad_tok,
    d_word_wec=256,
    d_model=256,
    d_inner=1024,
    d_k=32,
    d_v=32,
    n_heads=8,
    n_levels=1,
    enc_n_layers=1,
    dec_n_layers=1,
    dropout_rate=0.1,
    enc_with_emb_mat=True,
)

model_ranker = MllmRanker(model_ranker_cfg).to(device)

vocab_encoder.src_word_emb.weight (50270, 256) -0.010897282 -9.1012805e-07 0.010897279
vocab_encoder.layer_norm.weight (256,) -0.09958132 0.00054555107 0.099090874
vocab_encoder.layer_norm.bias (256,) -0.099580586 -0.0005646632 0.09982468
encoders.0.layer_stack.0.slf_attn.w_qs.weight (256, 256) -0.108248346 -4.838856e-05 0.10825213
encoders.0.layer_stack.0.slf_attn.w_ks.weight (256, 256) -0.10825298 -0.00034150513 0.10825191
encoders.0.layer_stack.0.slf_attn.w_vs.weight (256, 256) -0.10824595 -0.00038487592 0.10825164
encoders.0.layer_stack.0.slf_attn.fc.weight (256, 256) -0.108249485 3.074905e-05 0.108251706
encoders.0.layer_stack.0.slf_attn.layer_norm.weight (256,) -0.09877449 -0.0065938733 0.099637486
encoders.0.layer_stack.0.slf_attn.layer_norm.bias (256,) -0.09863682 0.009567137 0.09970514
encoders.0.layer_stack.0.pos_ffn.w_1.weight (1024, 256) -0.06846418 1.3598855e-05 0.068465136
encoders.0.layer_stack.0.pos_ffn.w_1.bias (1024,) -0.09990684 -0.0018554593 0.099997476
encoders.0.l

In [32]:
# map_location places the tensors on the configured device. Without it, a
# checkpoint saved on GPU fails to load on a CPU-only machine.
checkpoint_ranker = torch.load(ranker_snapshot_path, map_location=device)
model_ranker.load_state_dict(checkpoint_ranker['model'])
model_ranker.eval()
del checkpoint_ranker

In [19]:
# Re-fetch train batch #10 for the ranker: doc chunks, target (query) chunks and
# a per-doc-chunk boolean mask (presumably marking chunks of the target doc).
i = 10
batch = ds_loader.get_batch(i, train=True)
docs_chunks, target_chunks, target_mask = batch.gen_tensors()
(docs_chunks.shape, target_chunks.shape, target_mask)

(torch.Size([13, 100]),
 torch.Size([1, 100]),
 tensor([False, False, False, False, False, False,  True,  True, False, False,
         False, False, False]))

In [20]:
# Decode the target chunks back into text (padding stripped).
s_target = tokten_to_text(tokens=target_chunks)
s_target

'<|query_begin|> Buddy Baumann (American football)Carl Buddy Baumann (August 4, 1900 – April 27, 1951) was an American football player in the National Football League. He played with the Racine Legion during the 1922 NFL season.\n\nReferences\n\nCategory:Sportspeople from Racine, Wisconsin\nCategory:Players of American football from Wisconsin\nCategory:Racine Legion players\nCategory:1900 births\nCategory:1951 deaths <|query_end|>'

In [55]:
# Preview each doc chunk: the first 200 characters, with newlines escaped so
# every chunk stays on a single line.
for chunk in docs_chunks:
    preview = tokten_to_text(chunk)[:200]
    print(preview.replace('\n', '\\n'))

<|doc_begin|> <|doc_id_begin|> 4043219 <|doc_id_end|> <|doc_offset_begin|> 0 <|doc_offset_end|> <|doc_title_begin|> The Hamilton Spectator <|doc_title_end|> <|doc_body_begin|> The Hamilton Spectator, 
<|doc_begin|> <|doc_id_begin|> 4043219 <|doc_id_end|> <|doc_offset_begin|> 91 <|doc_offset_end|>  the Spectator the first of the chain. The Southam chain was sold in 1998 to Conrad Black, who in turn 
<|doc_begin|> <|doc_id_begin|> 4043219 <|doc_id_end|> <|doc_offset_begin|> 182 <|doc_offset_end|>  the surrounding communities of Grimsby and Beamsville. It also serves Brant County and Haldimand Coun
<|doc_begin|> <|doc_id_begin|> 5623618 <|doc_id_end|> <|doc_offset_begin|> 5590 <|doc_offset_end|>  is powered by the 2.0L EcoBoost turbocharged I4 gasoline engine. All Fusion Hybrid and Fusion Energi
<|doc_begin|> <|doc_id_begin|> 5623618 <|doc_id_end|> <|doc_offset_begin|> 5680 <|doc_offset_end|>  wheel drive (FWD), while select gasoline-only Fusion models are offered with all wheel drive. \n

In [33]:
# Inference only: the embeddings are compared, never backpropagated, so skip
# building the autograd graph.
with torch.no_grad():
    target_embs = model_ranker.run_enc_emb(target_chunks)
    docs_embs = model_ranker.run_enc_emb(docs_chunks)

In [34]:
# True: cosine similarity (higher = closer); False: Euclidean distance (lower = closer).
cosine = True
print_dist(target_embs, docs_embs, target_mask, cosine=cosine)

0.443827 F
0.403274 F
0.597387 F
0.276130 F
0.164219 F
0.261762 F
0.911145 T
0.207537 T
0.228399 F
0.270754 F
0.173670 F
0.248695 F
0.355535 F


In [60]:
# Ad-hoc query compared against the current batch's doc chunks. Previously only
# the last of many successive `txt = ...` assignments took effect; the earlier
# queries are kept here as comments.
#   'Hong Kong 1987'
#   'Climate Classification system, Mays Landing has a humid subtropical climate, abbreviated "Cfa"'
#   'Climate Classification system'
#   'War and Peace'
#   'Bandar Express, Ichhamati Express and Benapole Express'
#   'Rick Anderson'
#   'Makangarawe Temeke ward'
#   'graph sandwich'
#   'james barry'
#   'erigeron'
#   'Dillon Gabriel america'
#   'East Willy B'
#   'The graph sandwich problem for property Π is defined as follows:'
txt = 'Fusion Hybrid EcoBoost Gasoline'
# True: cosine similarity (higher = closer); False: Euclidean distance (lower = closer).
cosine = True
txt_tokens = text_to_tokens(txt, qbeg_tok=qbeg_tok, qend_tok=qend_tok)
# txt_tokens = txt_tokens.repeat(3, 1)
print(txt_tokens.shape)
with torch.no_grad():
    txt_embs = model_ranker.run_enc_emb(txt_tokens)
# NOTE: the T/F column comes from the batch's target_mask. It marks the batch
# target's chunks, not chunks relevant to `txt`.
print_dist(txt_embs, docs_embs, target_mask, cosine=cosine)


torch.Size([1, 100])
0.374339 F
0.465199 F
0.231861 F
0.682488 F
0.482204 F
0.496427 F
0.243842 T
0.264784 T
0.379053 F
0.345734 F
0.411226 F
0.174130 F
0.137890 F


In [61]:

# Ranker scores of each doc chunk for the ad-hoc query, shown with 6 decimals.
rank = model_ranker.run_qs_infer(docs_chunks, txt_tokens)
rank_str = [f'{score:.06f}' for score in rank.flatten()]
rank_str

['0.010951',
 '0.058801',
 '0.001228',
 '0.615539',
 '0.022415',
 '0.051561',
 '0.004969',
 '0.230985',
 '0.000338',
 '0.000157',
 '0.001883',
 '0.000541',
 '0.000633']

In [48]:
# FIXME: `np.ones(1).un` is not an ndarray attribute and raises AttributeError on
# a fresh kernel. The saved output under this cell presumably came from an
# earlier, different expression. Disabled until the intended check is known.
# np.ones(1).un

442500

In [31]:
# Inference only: no autograd graph is needed for displaying scores.
with torch.no_grad():
    rank = model_ranker(txt_tokens, docs_chunks)
rank

tensor([[0.8737, 0.8907, 0.7371, 0.5201, 0.4101, 0.8237, 0.7342, 0.8183, 0.8694,
         0.6263, 0.6608, 0.6087, 0.2565, 0.3135, 0.3250]],
       grad_fn=<SigmoidBackward0>)

In [35]:
# Inference only: no autograd graph is needed for displaying scores.
with torch.no_grad():
    rank = model_ranker(target_chunks, docs_chunks)
rank

tensor([[2.6538e-21, 3.7294e-11, 2.4412e-19, 3.4020e-23, 7.5630e-20, 3.8663e-25,
         1.3986e-18, 7.2614e-22, 1.0000e+00, 1.0000e+00, 1.0000e+00, 4.7166e-09,
         1.9997e-11, 1.3684e-08]], grad_fn=<SigmoidBackward0>)

In [31]:
# Unseeded random int in [0, 6), from the global NumPy RNG.
np.random.randint(low=0, high=6)

3

In [2]:
import torch

In [4]:
# Three device handles for the equality check below. Constructing a 'cuda'
# device object does not itself touch the GPU.
device1, device2, device3 = torch.device('cpu'), torch.device('cuda'), torch.device('cpu')

In [5]:
# Compare the device pairs (1,2), (2,3), (1,3) in that order.
for lhs, rhs in ((device1, device2), (device2, device3), (device1, device3)):
    print(lhs == rhs)


False
False
True


In [7]:
# torch.device.type is a plain string.
(device1.type, type(device1.type))

('cpu', str)

In [11]:
import numpy as np
# Three length-3 rows: two constant float rows and one int row.
toks = [np.full(3, 0.5), np.full(3, 1.7), np.array([-1, 7, 33])]
np.stack(toks, axis=0)

array([[ 0.5,  0.5,  0.5],
       [ 1.7,  1.7,  1.7],
       [-1. ,  7. , 33. ]])

In [12]:
# Building an array from the list of rows gives the same 2-D result as np.stack.
np.asarray(toks)

array([[ 0.5,  0.5,  0.5],
       [ 1.7,  1.7,  1.7],
       [-1. ,  7. , 33. ]])

In [13]:
# Confirm that np.array and np.stack (default axis=0) agree on the list of rows.
np.allclose(np.array(toks), np.stack(toks))

True