In [2]:
import gc
import json
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch
from beir.reranking.models import CrossEncoder
from beir.retrieval import models
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

from lss_func.coil import Coil
# TODO(review): `c_scatter` and `dict_2_float` are used below but never imported —
# add their project imports here.

In [None]:
@dataclass
class ModelArg:
    """Hyper-parameters handed to ``Coil(model_name_or_path, model_args)``.

    The meaning of each field is defined by ``lss_func.coil``; the notes below
    are inferred from the field names only — TODO confirm against that module.
    """

    # Presumably the width of the per-token representations.
    token_dim: int = 768
    # Presumably the width of the [CLS] representation.
    cls_dim: int = 768
    # Pooling strategy identifier understood by Coil.
    pooler_mode: str = "ave"
    window_size: int = 5
    # Normalisation / activation switches, all disabled by default.
    cls_norm_after: bool = False
    token_norm_after: bool = False
    token_rep_relu: bool = False

In [None]:
# JSON-lines query file: one JSON object per line, each with a "text" field.
query_path = ""  # TODO: set before running
queries = []

# JSON-lines files are UTF-8; don't depend on the platform's default encoding.
with open(query_path, encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank / trailing lines instead of crashing json.loads
        queries.append(json.loads(line)["text"])

if not queries:
    raise ValueError(f"no queries found in {query_path!r}")
# The timing cells below all use this single query.
sample_query = queries[0]

In [None]:
model_name_or_path: str = ""  # TODO: checkpoint location passed to Coil below

In [None]:
# Build the COIL encoder from `model_name_or_path` with default ModelArg settings.
model_args = ModelArg()
model = Coil(model_name_or_path, model_args)
# NOTE(review): presumably a torch nn.Module, so eval() disables training-only
# layers such as dropout and returns the module; the trailing `;` stops the
# notebook from dumping the whole module repr as this cell's output.
model.eval();

In [None]:
# Load every COIL index shard found under `shards_root`.
# A shard directory is identified by its ivl_scatter_maps.pt file. Globbing "**/"
# would also yield the root itself and unrelated sub-directories, and torch.load
# would then fail on them. Sorting makes the shard order reproducible.
shards_root = Path("")  # TODO: set to the index directory
shards_dir = sorted(p.parent for p in shards_root.glob("**/ivl_scatter_maps.pt"))
if not shards_dir:
    raise FileNotFoundError(f"no COIL shards (ivl_scatter_maps.pt) found under {shards_root.resolve()}")

shards_ivl = []
shards_shard_map = []
shards_tok_id_2_reps = []
shards_cls_ex_ids = []
for doc_shard in shards_dir:
    # map_location="cpu": the search cell calls `.numpy()` on these tensors,
    # which only works for CPU tensors.
    # SECURITY: torch.load unpickles its input — only load trusted index files.
    all_ivl_scatter_maps = torch.load(doc_shard / "ivl_scatter_maps.pt", map_location="cpu")
    all_shard_scatter_maps = torch.load(doc_shard / "shard_scatter_maps.pt", map_location="cpu")
    tok_id_2_reps = torch.load(doc_shard / "tok_reps.pt", map_location="cpu")
    cls_ex_ids = torch.load(doc_shard / "cls_ex_ids.pt", map_location="cpu")
    # NOTE(review): dict_2_float is not imported in this notebook — TODO add its import.
    tok_id_2_reps = dict_2_float(tok_id_2_reps)
    shards_ivl.append(all_ivl_scatter_maps)
    shards_shard_map.append(all_shard_scatter_maps)
    shards_tok_id_2_reps.append(tok_id_2_reps)
    shards_cls_ex_ids.append(cls_ex_ids)


In [None]:
# Number of hits kept per shard / reported by each retriever in this notebook.
top_k: int = 100
# Maps a retrieved document key to its summed COIL match score.
all_query_match_scores: defaultdict = defaultdict(float)

In [None]:
%time

q_input = model.tokenizer(sample_query, retern_tensor="pt")
qtok_ids = q_input["input_ids"]
q_tok_reps = model.encode_corpus_raw(q_input)

for all_ivl_scatter_maps, all_shard_scatter_maps, tok_id_2_reps, cls_ex_ids in zip(shards_ivl, shards_shard_map, shards_tok_id_2_reps, shards_cls_ex_ids):
    match_scores = torch.zeros(len(cls_ex_ids))
    batched_tok_scores = []
    for q_tok_id, q_tok_rep in zip(qtok_ids, q_tok_reps):
        tok_reps = tok_id_2_reps[q_tok_id]
        tok_scores = torch.matmul(q_tok_rep, tok_reps.transpose(0, 1)).relu_()  # Bt * Ds
        batched_tok_scores.append(tok_scores)

    for i, q_tok_id in enumerate(qtok_ids):
        ivl_scatter_map = all_ivl_scatter_maps[q_tok_id]
        shard_scatter_map = all_shard_scatter_maps[q_tok_id]

        tok_scores = batched_tok_scores[i]
        ivl_maxed_scores = torch.empty(len(shard_scatter_map))

        for j in range(tok_scores.size(0)):
            ivl_maxed_scores.zero_()
            c_scatter.scatter_max(tok_scores[j].numpy(), ivl_scatter_map.numpy(), ivl_maxed_scores.numpy())
            match_scores.scatter_add_(0, shard_scatter_map, ivl_maxed_scores)

        top_scores, top_iids = match_scores.topk(top_k)
        for iid, score in zip(top_iids, top_scores):
            all_query_match_scores[iid] += score        

top_result = sorted(all_query_match_scores.items(), key=lambda x: -x[1])

In [None]:
# Release the COIL index before loading the dense index.
# Loop variables from the loading and search cells still reference the last
# shard, so they are dropped too. `pop(name, None)` keeps this cell re-runnable:
# a plain `del` raises NameError on the second run.
for _name in (
    "shards_ivl", "shards_shard_map", "shards_tok_id_2_reps", "shards_cls_ex_ids",
    "all_ivl_scatter_maps", "all_shard_scatter_maps", "tok_id_2_reps", "cls_ex_ids",
    "tok_reps", "ivl_scatter_map", "shard_scatter_map", "batched_tok_scores",
    "match_scores", "ivl_maxed_scores",
):
    globals().pop(_name, None)
del _name
gc.collect()

In [None]:
dense_index_path: str = ""  # TODO: location of the saved dense index (loaded with np.load)
# Presumably a (num_docs, dim) matrix of document embeddings — TODO confirm.
dense_index = np.load(file=dense_index_path)

In [None]:
# Checkpoint location for beir's SentenceBERT wrapper — TODO fill in.
dense_model_path = ""
# NOTE(review): rebinding `model` discards the COIL encoder built earlier; the
# dense-search cell below refers to this SentenceBERT model through `model`.
model = models.SentenceBERT(dense_model_path)

In [None]:
%time

q_input = model.tokenizer(sample_query, retern_tensor="pt")
qtok_ids = q_input["input_ids"]
q_rep = model.encode_queries(q_input)

scores = np.dot(dense_indes, q_rep)
top_iids = np.argsort(-scores)[:top_k]

In [None]:
# Release the dense index. `pop(name, None)` keeps this cell re-runnable
# (a plain `del` raises NameError on the second run).
globals().pop("dense_index", None)
gc.collect()

In [None]:
# Cross-encoder checkpoint used for brute-force scoring below.
CE_MODEL_NAME = 'cross-encoder/ms-marco-electra-base'
ce_model = CrossEncoder(CE_MODEL_NAME)

In [None]:
# JSON-lines corpus: one object per line with "_id", "text" and optionally "title".
corpus_path = ""  # TODO: set before running
corpus = []
all_ids = []
with open(corpus_path, encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank / trailing lines
        jline = json.loads(line)
        # str.join takes a single iterable. Empty/missing titles are skipped so
        # the passage does not start with a stray space.
        corpus.append(" ".join(part for part in (jline.get("title", ""), jline["text"]) if part))
        all_ids.append(jline["_id"])

In [None]:
% time

input_ce = []
for c in corpus:
    input_ce.append((sample_query, c))
    
scores = ce_model.predict(inpput_ce)
top_iids = np.argsort(-scores)[:top_k]