In [28]:
import os
import json
import numpy as np
import torch
from tqdm.notebook import tqdm
from collections import defaultdict, Counter
from pathlib import Path
from dataclasses import dataclass
from lss_func.coil import Coil
from beir.reranking.models import CrossEncoder
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from retriever.retriever_ext import scatter as c_scatter

In [2]:
@dataclass
class ModelArg:
    token_dim: int = 768
    cls_dim: int = 768
    pooler_mode: str ="ave"
    window_size: int = 5
    cls_norm_after: bool = False
    token_norm_after: bool = False
    token_rep_relu: bool = False

In [3]:
query_path = "/groups/gcb50243/iida.h/BEIR/dataset/trec-robust04-title/queries.jsonl"
queries = []

with open(query_path) as f:
    for line in f:
        jquery = json.loads(line)["text"]
        queries.append(jquery)
        
sample_query = queries[0]

In [4]:
model_name_or_path = "/groups/gcb50243/iida.h/BEIR/model/output/microsoft/mpnet-base-v3-msmarco"

In [5]:
model_args = ModelArg()
model = Coil(model_name_or_path, model_args)
model.eval()

In [6]:
shards_dir = sorted(list(Path("/groups/gcb50243/iida.h/TREC/robust04/lss/index/doc").glob("shard*")))
share_dist_shards_dir = shards_dir[-2:]
share_dist_shards_dir

[PosixPath('/groups/gcb50243/iida.h/TREC/robust04/lss/index/doc/shard_08'),
 PosixPath('/groups/gcb50243/iida.h/TREC/robust04/lss/index/doc/shard_09')]

In [13]:
shards_dir = sorted(list(Path("/local/9229643.1.mem/doc").glob("shard*")))
ssd_shards_dir = shards_dir[:-2]
ssd_shards_dir

[PosixPath('/local/9229643.1.mem/doc/shard_00'),
 PosixPath('/local/9229643.1.mem/doc/shard_01'),
 PosixPath('/local/9229643.1.mem/doc/shard_02'),
 PosixPath('/local/9229643.1.mem/doc/shard_03'),
 PosixPath('/local/9229643.1.mem/doc/shard_04'),
 PosixPath('/local/9229643.1.mem/doc/shard_05'),
 PosixPath('/local/9229643.1.mem/doc/shard_06'),
 PosixPath('/local/9229643.1.mem/doc/shard_07')]

In [8]:
def dict_2_float(dd):
    for k in dd:
        dd[k] = dd[k].float()

In [20]:
def calc_idf_and_doclen(corpus, tokenizer, sep):
    doc_lens = []
    df = Counter()
    for cid in tqdm(corpus.keys()):
        text = corpus[cid]["title"] + sep + corpus[cid]["text"]
        input_ids = tokenizer(text)["input_ids"]
        doc_lens.append(len(input_ids))
        df.update(list(set(input_ids)))

    idf = defaultdict(float)
    N = len(corpus)
    for w, v in df.items():
        idf[w] = np.log(N / v)

    doc_len_ave = np.mean(doc_lens)
    return idf, doc_len_ave, doc_lens

In [30]:
data_path = "/groups/gcb50243/iida.h/BEIR/dataset/trec-robust04-title/"
corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")

  0%|          | 0/528155 [00:00<?, ?it/s]

In [31]:
tokenizer = model.q_tokenizer
idf, doc_len_ave, doc_lens = calc_idf_and_doclen(corpus, tokenizer, sep)

  0%|          | 0/528155 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1340 > 512). Running this sequence through the model will result in indexing errors


In [14]:
shards_ivl = []
shards_shard_map = []
shards_tok_id_2_reps = []
shards_cls_ex_ids = []
for doc_shard in tqdm(ssd_shards_dir):
    all_ivl_scatter_maps = torch.load(os.path.join(doc_shard, "ivl_scatter_maps.pt"))
    all_shard_scatter_maps = torch.load(os.path.join(doc_shard, "shard_scatter_maps.pt"))
    tok_id_2_reps = torch.load(os.path.join(doc_shard, "tok_reps.pt"))
    cls_ex_ids = torch.load(os.path.join(doc_shard, "cls_ex_ids.pt"))
    dict_2_float(tok_id_2_reps)
    shards_ivl.append(all_ivl_scatter_maps)
    shards_shard_map.append(all_shard_scatter_maps)
    shards_tok_id_2_reps.append(tok_id_2_reps)
    shards_cls_ex_ids.append(cls_ex_ids)

for doc_shard in tqdm(share_dist_shards_dir):
    all_ivl_scatter_maps = torch.load(os.path.join(doc_shard, "ivl_scatter_maps.pt"))
    all_shard_scatter_maps = torch.load(os.path.join(doc_shard, "shard_scatter_maps.pt"))
    tok_id_2_reps = torch.load(os.path.join(doc_shard, "tok_reps.pt"))
    cls_ex_ids = torch.load(os.path.join(doc_shard, "cls_ex_ids.pt"))
    dict_2_float(tok_id_2_reps)
    shards_ivl.append(all_ivl_scatter_maps)
    shards_shard_map.append(all_shard_scatter_maps)
    shards_tok_id_2_reps.append(tok_id_2_reps)
    shards_cls_ex_ids.append(cls_ex_ids)


  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

In [33]:
len(shards_cls_ex_ids[0])

53330

In [46]:
shard_idx = {}
for i_shard, cls_ex_ids in enumerate(shards_cls_ex_ids):
    for idx_shard, ex_id in enumerate(cls_ex_ids):
        shard_idx[ex_id] = (i_shard, idx_shard)

In [61]:
cls_ex_ids

array(['LA082390-0058', 'LA082390-0059', 'LA082390-0060', ...,
       'LA123190-0132', 'LA123190-0133', 'LA123190-0134'], dtype='<U13')

In [60]:
doc_lens = []
df = Counter()
shard_tf_idx = {}
shard_idx_doclen = {}
for i in range(len(shards_cls_ex_ids)):
    shard_tf_idx[i] = defaultdict(list)
    shard_idx_doclen[i] = {}

for cid in tqdm(corpus.keys()):
    try:
        i_shard, idx_shard = shard_idx[cid]
    except KeyError:
        continue
    text = corpus[cid]["title"] + sep + corpus[cid]["text"]
    input_ids = tokenizer(text)["input_ids"]
    doc_lens.append(len(input_ids))
    df.update(list(set(input_ids)))
    tf_d = Counter(input_ids)
    shard_idx_doclen[i_shard][idx_shard] =len(input_ids)
    for tok, freq in tf_d.items():
        shard_tf_idx[i_shard][tok].append((idx_shard, freq))
        

idf = defaultdict(float)
N = len(corpus)
for w, v in df.items():
    idf[w] = np.log(N / v)

doc_len_ave = np.mean(doc_lens)

  0%|          | 0/528155 [00:00<?, ?it/s]

In [72]:
# def bm25_score(docs_tf, idf, doc_lens, doc_len_ave):
#     ivl_bm25_scores = torch.zeros(len(doc_lens))
#     k1 = 0.9
#     b = 0.4
#     for idx, tf in docs_tf:
# #         ivl_bm25_scores[idx] = tf * (1 + k1) / (tf + k1 * (1 - b + b * doc_lens[idx] / doc_len_ave))
        
#     return ivl_bm25_scores

In [None]:
k1 = 0.9
b = 0.4

shard_bm25_shard_idx = []
shard_bm25 = []
for i in range(len(shard_tf_idx)):
    bm25_shard_idx = defaultdict(list)
    bm25_index = defaultdict(list)
    doc_lens = shard_idx_doclen[i]
    for tok, infos in shard_tf_idx[i].items():
        for idx, tf in infos:
            bm25_shard_idx[tok].append(idx)
            bm25_index[tok].append(tf * (1 + k1) / (tf + k1 * (1 - b + b * doc_lens[idx] / doc_len_ave)) * idf[tok])

    for tok in bm25_shard_idx:
        bm25_shard_idx[tok] = torch.tensor(bm25_shard_idx[tok])
        bm25_index[tok] = torch.tensor(bm25_index[tok]).to(torch.float)
        
    shard_bm25_shard_idx.append(bm25_shard_idx)
    shard_bm25.append(bm25_index)
    

In [122]:
top_k = 100
all_query_match_scores = defaultdict(float)

In [123]:
%%time

q_input = model.q_tokenizer(sample_query, return_tensors="pt")
qtok_ids = q_input["input_ids"][0].tolist()[1:-1]
with torch.no_grad():
    q_tok_reps = model.encode_corpus_raw(q_input)[1].squeeze()
    
    
for i_shard, (all_ivl_scatter_maps, all_shard_scatter_maps, tok_id_2_reps, cls_ex_ids, bm25_shard_idx, bm25_index) in enumerate(zip(shards_ivl, shards_shard_map, shards_tok_id_2_reps, shards_cls_ex_ids, shard_bm25_shard_idx, shard_bm25)):
    match_scores = torch.zeros(len(cls_ex_ids))
    tok_match_scores = torch.empty(len(cls_ex_ids))
    bm25_match_scores = torch.empty(len(cls_ex_ids))
    batched_tok_scores = []
    for q_tok_id, q_tok_rep in zip(qtok_ids, q_tok_reps):
        if q_tok_id in tok_id_2_reps:
            tok_reps = tok_id_2_reps[q_tok_id]
        else:
            continue
        tok_scores = torch.matmul(q_tok_rep, tok_reps.transpose(0, 1)).relu_()  # Bt * Ds
        batched_tok_scores.append((q_tok_id, tok_scores))

    for q_tok_id, tok_scores in batched_tok_scores:
        tok_match_scores.zero_()
        bm25_match_scores.zero_()
        
        ivl_scatter_map = all_ivl_scatter_maps[q_tok_id]
        shard_scatter_map = all_shard_scatter_maps[q_tok_id]
        
        
        bm25_scores = bm25_index[q_tok_id]
        shard_bm25_scatter_map = bm25_shard_idx[q_tok_id]
        
        ivl_maxed_scores = torch.empty(len(shard_scatter_map))
        ivl_maxed_scores.zero_()
        c_scatter.scatter_max(tok_scores.numpy(), ivl_scatter_map.numpy(), ivl_maxed_scores.numpy())
        tok_match_scores.scatter_add_(0, shard_scatter_map, ivl_maxed_scores)
        bm25_match_scores.scatter_add_(0, shard_bm25_scatter_map, bm25_scores)
        
        match_scores += tok_match_scores * bm25_match_scores

        top_scores, top_iids = match_scores.topk(top_k)
        for iid, score in zip(top_iids, top_scores):
            all_query_match_scores[cls_ex_ids[iid.item()]] += float(score.item())

top_result = sorted(all_query_match_scores.items(), key=lambda x: -x[1])

CPU times: user 1.1 s, sys: 0 ns, total: 1.1 s
Wall time: 70.4 ms


In [85]:
# %%time
# for i_shard, (all_ivl_scatter_maps, all_shard_scatter_maps, tok_id_2_reps, cls_ex_ids) in enumerate(zip(shards_ivl, shards_shard_map, shards_tok_id_2_reps, shards_cls_ex_ids)):
#     match_scores = torch.zeros(len(cls_ex_ids))
#     tok_match_scores = torch.empty(len(cls_ex_ids))
#     batched_q_tok_id = []
#     for q_tok_id, q_tok_rep in zip(qtok_ids, q_tok_reps):
#         if q_tok_id in tok_id_2_reps:
#             batched_q_tok_id.append(q_tok_id)
        
#     for q_tok_id in batched_q_tok_id:
#         tok_match_scores.zero_()
        
#         ivl_scatter_map = all_ivl_scatter_maps[q_tok_id]
#         shard_scatter_map = all_shard_scatter_maps[q_tok_id]
        
#         docs_tf = shard_tf_idx[i_shard][q_tok_id]
#         doc_lens = shard_idx_doclen[i_shard]
#         tok_idf = idf[q_tok_id]
#         ivl_bm25_score = bm25_score(docs_tf, tok_idf, doc_lens, doc_len_ave)
#         tok_match_scores *= ivl_bm25_score
#         match_scores += tok_match_scores

CPU times: user 5.89 s, sys: 2.94 ms, total: 5.89 s
Wall time: 825 ms


In [94]:
top_result

[(19195, 3.170194685459137),
 (26355, 3.01444411277771),
 (26410, 2.990734577178955),
 (41239, 2.9897618293762207),
 (374, 2.803445965051651),
 (52029, 2.793766051530838),
 (50219, 2.7628308534622192),
 (22009, 2.7500077188014984),
 (9800, 2.74575611948967),
 (52193, 2.653989404439926),
 (21765, 2.59348201751709),
 (41379, 2.5853949189186096),
 (31261, 2.5679527521133423),
 (19642, 2.5425510108470917),
 (21956, 2.53891122341156),
 (5409, 2.51525154709816),
 (41390, 2.4984427094459534),
 (48409, 2.4830699265003204),
 (46593, 2.4389002323150635),
 (23981, 2.4345113337039948),
 (19416, 2.4258944392204285),
 (366, 2.4141756296157837),
 (39875, 2.391512632369995),
 (27767, 2.3912524580955505),
 (1878, 2.374084383249283),
 (27455, 2.3670108914375305),
 (21760, 2.3480354845523834),
 (22716, 2.3474926948547363),
 (9825, 2.3443615436553955),
 (24320, 2.3373753428459167),
 (17752, 2.297724574804306),
 (21774, 2.2869341373443604),
 (1631, 2.2856321930885315),
 (51780, 2.237980991601944),
 (18503,

In [None]:
del shards_ivl
del shards_shard_map
del shards_tok_id_2_reps
del shards_cls_ex_ids

In [124]:
dense_index_path = "/groups/gcb50243/iida.h/TREC/robust04/lss/index/dense_index.npy"
dense_index = np.load(dense_index_path)

In [125]:
dense_model_path = "/groups/gcb50243/iida.h/BEIR/model/output/microsoft/mpnet-base-v3-msmarco/"
model = models.SentenceBERT(dense_model_path)

In [128]:
%%timeit

q_rep = model.encode_queries(sample_query)

scores = np.dot(dense_index, q_rep)
top_iids = np.argsort(-scores)[:top_k]

136 ms ± 9.04 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
del dense_index

In [129]:
ce_model = CrossEncoder('cross-encoder/ms-marco-electra-base')

In [135]:
list_corpus = []
all_ids = []
with open(corpus_path) as f:
    for line in f:
        jline = json.loads(line)
        list_corpus.append(jline["title"] + " " + jline["text"])
        all_ids.append(jline["_id"])

In [None]:
%%time

input_ce = []
for c in tqdm(corpus):
    input_ce.append((sample_query, c))
    
scores = ce_model.predict(input_ce)
top_iids = np.argsort(-scores)[:top_k]

  0%|          | 0/528155 [00:00<?, ?it/s]

Batches:   0%|          | 0/16505 [00:00<?, ?it/s]