In [29]:
# Automatically reload edited imported modules (e.g. the local mllm package) before each cell runs.
# Re-running this cell prints an "already loaded" notice; that notice is harmless.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [30]:
import os
from pathlib import Path
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, AddedToken, PreTrainedTokenizer

from mllm.data.wiki.dswiki import WikiDsLoader
from mllm.model.mllm_encdec import MllmEncdecLevel
from mllm.model.mllm_ranker import MllmRanker
from mllm.config.model import create_mllm_encdec_cfg, create_mllm_ranker_cfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens
from mllm.utils.utils import read_tsv, write_tsv

In [17]:
# Root data directory plus the sub-directories this notebook works with.
# Path.home() is used instead of Path(os.path.expandvars('$HOME')): expandvars
# leaves the literal text '$HOME' in place when the variable is unset, which
# silently produced a bogus relative path ('$HOME/data').
DATA_PATH = Path.home() / 'data'
TRAIN_ENCDEC_PATH = DATA_PATH / 'train_mllm_encdec'
# Alternative ranker run directory: DATA_PATH / 'train_mllm_ranker'
TRAIN_RANKER_PATH = DATA_PATH / 'train_mllm_ranker_qs'
DS_DIR_PATH = DATA_PATH / 'wiki_20200501_en' / 'ch_100_fixed'


In [19]:
# Load the document-id and query-id tables stored next to the
# msmarco+fever ranker embeddings.
embs_dpath = DATA_PATH / 'ranker_embs_msmarco_fever'
docs_ids_fpath, qs_ids_fpath = embs_dpath / 'docs_ids.tsv', embs_dpath / 'qs_ids.tsv'
df_docs_ids, df_qs_ids = read_tsv(docs_ids_fpath), read_tsv(qs_ids_fpath)
df_docs_ids

Unnamed: 0,ds_ids,ds_doc_ids
0,1,0
1,1,0
2,1,0
3,1,0
4,1,0
...,...,...
69767920,2,8630399
69767921,2,8630400
69767922,2,8630401
69767923,2,8630402


In [21]:
# Normalize the plural column names read from the TSV files to singular ones.
# rename() ignores keys that are absent (errors='ignore' is the default), so this
# is a no-op if the files already carry the new names, e.g. after the save cell
# below has rewritten them.
df_docs_ids = df_docs_ids.rename(columns={'ds_ids': 'ds_id', 'ds_doc_ids': 'ds_doc_id'})
df_qs_ids = df_qs_ids.rename(columns={'ds_ids': 'ds_id', 'ds_query_ids': 'ds_query_id'})
# Because missing keys are silently ignored, check the resulting schema explicitly.
assert {'ds_id', 'ds_doc_id'} <= set(df_docs_ids.columns), list(df_docs_ids.columns)
assert {'ds_id', 'ds_query_id'} <= set(df_qs_ids.columns), list(df_qs_ids.columns)
df_qs_ids

Unnamed: 0,ds_id,ds_query_id
0,1,0
1,1,1
2,1,2
3,1,3
4,1,4
...,...,...
495368,2,495343
495369,2,495344
495370,2,495345
495371,2,495346


In [22]:
# Tag every row with a sequential embedding id equal to its row position (0..n-1).
df_docs_ids = df_docs_ids.assign(doc_emb_id=np.arange(df_docs_ids.shape[0]))
df_qs_ids = df_qs_ids.assign(query_emb_id=np.arange(df_qs_ids.shape[0]))
df_docs_ids

Unnamed: 0,ds_id,ds_doc_id,doc_emb_id
0,1,0,0
1,1,0,1
2,1,0,2
3,1,0,3
4,1,0,4
...,...,...,...
69767920,2,8630399,69767920
69767921,2,8630400,69767921
69767922,2,8630401,69767922
69767923,2,8630402,69767923


In [26]:
# Write the id tables (now including the *_emb_id columns) back over the files
# they were read from. Before overwriting the inputs, verify that each emb id
# column is still exactly 0..n-1 in row order. If rows were added or removed
# after the ids were assigned (the recorded run printed one more row than the
# 0..n-1 range shown when the ids were created), the saved ids would no longer
# match row positions.
# NOTE(review): presumably *_emb_id indexes rows of the stored embeddings — confirm.
for df_ids, emb_col in ((df_docs_ids, 'doc_emb_id'), (df_qs_ids, 'query_emb_id')):
    assert emb_col in df_ids.columns, f'missing column {emb_col}'
    assert np.array_equal(df_ids[emb_col].to_numpy(), np.arange(len(df_ids))), \
        f'{emb_col} is not a contiguous 0..n-1 range'

print(f'Save {len(df_docs_ids)} docs ids in {docs_ids_fpath}')
write_tsv(df_docs_ids, docs_ids_fpath)
print(f'Save {len(df_qs_ids)} qs ids in {qs_ids_fpath}')
write_tsv(df_qs_ids, qs_ids_fpath)

Save 69767925 docs ids in /home/misha/data/ranker_embs_msmarco_fever/docs_ids.tsv
Save 495373 qs ids in /home/misha/data/ranker_embs_msmarco_fever/qs_ids.tsv


In [32]:
# Sanity check of F.cosine_similarity: row-wise similarity between two (10, 12)
# batches, reshaped into a (10, 1) column. A dedicated seeded generator makes the
# displayed values reproducible without touching the global RNG state.
cos_gen = torch.Generator().manual_seed(0)
input1 = torch.randn(10, 12, generator=cos_gen)
input2 = torch.randn(10, 12, generator=cos_gen)
# dim=1 compares along the feature axis -> shape (10,); unsqueeze(-1) -> (10, 1).
output = F.cosine_similarity(input1, input2, dim=1).unsqueeze(-1)
output

tensor([[ 0.0160],
        [-0.2776],
        [-0.2787],
        [-0.2542],
        [-0.3560],
        [ 0.0086],
        [-0.3117],
        [-0.1231],
        [ 0.4930],
        [ 0.0494]])


In [33]:
# Success probability of the random mask drawn below.
p = 0.3
# probs is stored as a float32 tensor, which is why the repr shows
# 0.30000001192092896 instead of exactly 0.3.
bernoulli = torch.distributions.Bernoulli(probs=torch.tensor(p))
bernoulli

Bernoulli(probs: 0.30000001192092896)

In [43]:
# Draw a 2x3 batch of 0/1 samples from Bernoulli(p) and convert it to a boolean
# mask (each entry is True with probability p).
m = bernoulli.sample(torch.Size([2, 3])).bool()
m

tensor([[False, False,  True],
        [ True, False, False]])

In [44]:
# Small integer tensor (values 1..9, shape (2, 3, 2)) used to demonstrate
# boolean-mask assignment below. A seeded generator keeps it reproducible.
t_gen = torch.Generator().manual_seed(0)
t = torch.randint(1, 10, (2, 3, 2), generator=t_gen)
t

tensor([[[2, 3],
         [4, 4],
         [2, 6]],

        [[1, 7],
         [4, 9],
         [4, 9]]])

In [45]:
# Boolean-mask assignment: m (shape (2, 3)) selects positions along the two
# leading dims of t (shape (2, 3, 2)), so each selected length-2 row is zeroed
# in place. Broadcasting the mask over the last dim with masked_fill_ is
# equivalent to `t[m] = 0`.
t.masked_fill_(m.unsqueeze(-1), 0)
t

tensor([[[2, 3],
         [4, 4],
         [0, 0]],

        [[0, 0],
         [4, 9],
         [4, 9]]])