In [None]:
%load_ext autoreload
%autoreload 2

In [16]:
from collections import defaultdict
import csv
from dataclasses import dataclass
import gzip
import itertools
import os
from pathlib import Path
from pprint import pprint
import re
import shutil
import sys
from typing import Optional
if '..' not in sys.path: sys.path.append('..')

import numpy as np
import pandas as pd
from pydantic_yaml import to_yaml_file, parse_yaml_file_as
from tqdm import tqdm

from mllm.utils.utils import write_tsv, read_tsv
from s_04_run_01_ranker_embs import RunInfo

In [8]:
DATA_PATH = Path(os.path.expandvars('$HOME')) / 'data'
EMBS_DS_PATH = DATA_PATH / 'ranker_embs_msmarco_fever'
DST_PATH = DATA_PATH / 'ranker_embs_msmarco_fever_part'
DST_PATH.mkdir(exist_ok=True)

In [13]:
run_info_fpath = EMBS_DS_PATH / 'run_info.yaml'
docs_ids_fpath = EMBS_DS_PATH / 'docs_ids.tsv'
qs_ids_fpath = EMBS_DS_PATH / 'qs_ids.tsv'
qrels_fpath = EMBS_DS_PATH / 'qrels.tsv'
docs_embs_fpath = EMBS_DS_PATH / 'docs_embs.npy'
qs_embs_fpath = EMBS_DS_PATH / 'qs_embs.npy'

In [15]:
df_docs_ids = read_tsv(docs_ids_fpath)
n_docs = len(df_docs_ids)
n_docs_part = int(n_docs * 0.05)
df_docs_ids = df_docs_ids.iloc[:n_docs_part].copy()
df_docs_ids

Unnamed: 0,ds_id,ds_doc_id,doc_emb_id
0,1,0,0
1,1,0,1
2,1,0,2
3,1,0,3
4,1,0,4
...,...,...,...
3488391,1,188306,3488391
3488392,1,188306,3488392
3488393,1,188306,3488393
3488394,1,188306,3488394


In [20]:
ds_docs_ids = np.unique(df_docs_ids['ds_doc_id'])

In [19]:
df_qrels = read_tsv(qrels_fpath)
df_qrels.set_index('dsqid', inplace=True)
df_qrels = df_qrels.loc[ds_doc_ids].copy()
df_qrels

Unnamed: 0_level_0,qid,did,dsid,dsdid
dsqid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,1185869,59219,1,2376038
1,1185868,59235,1,100673
2,1183785,576811,1,1201976
3,645590,576840,1,1232741
4,186154,114789,1,403227
...,...,...,...,...
188302,848127,1053563,1,2589733
188303,950430,2127827,1,1115121
188304,255570,2127829,1,938911
188305,986614,183106,1,79644


In [21]:
df_qrels.reset_index(drop=False, inplace=True)
ds_qs_ids = np.unique(df_qrels['dsqid'])

In [29]:
df_qs_ids = read_tsv(qs_ids_fpath)
df_qs_ids.set_index('ds_query_id', inplace=True)
df_qs_ids = df_qs_ids.loc[ds_qs_ids].copy()
df_qs_ids.reset_index(drop=False, inplace=True)
df_qs_ids

Unnamed: 0,ds_query_id,ds_id,query_emb_id
0,0,1,0
1,1,1,1
2,2,1,2
3,3,1,3
4,4,1,4
...,...,...,...
188302,188302,1,188302
188303,188303,1,188303
188304,188304,1,188304
188305,188305,1,188305


In [30]:
run_info = parse_yaml_file_as(RunInfo, run_info_fpath)
pprint(run_info.dict())

{'ds_dir_paths': [PosixPath('/home/misha/data/msmarco'),
                  PosixPath('/home/misha/data/fever')],
 'emb_chunk_size': 100,
 'model_fpath': PosixPath('/home/misha/data/train_mllm_ranker_qrels/ranker-20240903_215749-msmarco-fever/best.pth'),
 'n_docs': 8630403,
 'n_docs_chunks': 69767925,
 'n_qs': 495348,
 'n_qs_chunks': 495373}


In [31]:
run_info.n_docs = len(df_docs_ids['ds_doc_id'].unique())
run_info.n_docs_chunks = len(df_docs_ids)
run_info.n_qs = len(df_qs_ids['ds_query_id'].unique())
run_info.n_qs_chunks = len(df_qs_ids)
pprint(run_info.dict())

{'ds_dir_paths': [PosixPath('/home/misha/data/msmarco'),
                  PosixPath('/home/misha/data/fever')],
 'emb_chunk_size': 100,
 'model_fpath': PosixPath('/home/misha/data/train_mllm_ranker_qrels/ranker-20240903_215749-msmarco-fever/best.pth'),
 'n_docs': 188307,
 'n_docs_chunks': 3488396,
 'n_qs': 188307,
 'n_qs_chunks': 188307}


In [26]:
len(df_docs_ids['ds_doc_id'].unique())

188307

In [32]:
write_tsv(df_docs_ids, DST_PATH / docs_ids_fpath.name)
write_tsv(df_qs_ids, DST_PATH / qs_ids_fpath.name)
write_tsv(df_qrels, DST_PATH / qrels_fpath.name)

In [33]:
shutil.copy(qs_embs_fpath, DST_PATH)

'/home/misha/data/ranker_embs_msmarco_fever_part/qs_embs.npy'

In [34]:
to_yaml_file(DST_PATH / run_info_fpath.name, run_info)

In [37]:
n_bytes = len(df_docs_ids) * 256 * 4
n_chunk = 1024 * 1024
n_read = 0
with open(docs_embs_fpath, 'rb') as f_src, open(DST_PATH / docs_embs_fpath.name, 'wb') as f_dst:
    n_to_read = n_chunk
    while n_to_read > 0:
        n_to_read = min(n_bytes - n_read, n_chunk)
        bts = f_src.read(n_to_read)
        f_dst.write(bts)
        n_read += n_to_read
