In [1]:
import pandas as pd
import zarr
from pathlib import Path
import math
from tqdm import tqdm
import numpy as np
from scipy.sparse import csr_matrix, vstack
import gc

In [13]:
# Zarr key of the embedding to evaluate, written as "<emb_type>/<emb_name>".
emb_path = "result/bert_768_wta_2048_0.05_recover-task_8192_0.0005"
# emb_path = "dense/bert"
# Raw string: in a normal literal "\D" and "\m" are invalid escape sequences,
# which Python reports with a DeprecationWarning/SyntaxWarning.
data_dir = Path(r"E:\Data\msmarco-passages")

# Load Data

## Query

In [14]:
# The query file is tab-separated with two unnamed columns (id, query);
# only the id column is kept, as a plain array.
queries_df = pd.read_csv(
    data_dir / "runs/queries.dev.small.tsv",
    sep="\t",
    names=["id", "query"],
)
q_ids = queries_df["id"].values

In [15]:
# Open the query-side store and read the embedding selected by `emb_path`.
z = zarr.open(str(data_dir / "queries.eval.zarr"))
# Paths containing "result" are converted to a scipy CSR matrix; every other
# path is read fully into memory with `[:]`.
q_embs = csr_matrix(z[emb_path]) if "result" in emb_path else z[emb_path][:]

In [5]:
# z = zarr.open('../msmarco-passages/queries.eval.zarr')
# q_ids_all = z.id[:]
# sorter = np.argsort(q_ids_all)
# indices = sorter[np.searchsorted(q_ids_all, q_ids, sorter=sorter)]
# q_embs = z.dense["bert"].oindex[indices.tolist(), :][:]

## Passages

In [39]:
# Number of passage rows read from the store and converted to CSR per step.
CHUNK_SIZE = 8192 * 10

z = zarr.open(str(data_dir / "docs.eval.zarr"))
if "result" in emb_path:
    z_emb = z[emb_path]
    n_chunks = math.ceil(z_emb.shape[0] / CHUNK_SIZE)

    # Collect the converted chunks and stack them once at the end. Calling
    # vstack inside the loop copies the ever-growing matrix on every
    # iteration, which makes the total work quadratic in the number of chunks.
    p_chunks = []
    for i in tqdm(range(n_chunks)):
        p_chunks.append(csr_matrix(z_emb[i * CHUNK_SIZE : (i + 1) * CHUNK_SIZE]))

    if p_chunks:
        p_embs = vstack(p_chunks, format="csr")
    else:
        # vstack rejects an empty list; an empty store yields an empty matrix.
        p_embs = csr_matrix(z_emb.shape, dtype=z_emb.dtype)
    del p_chunks  # the per-chunk matrices are no longer needed once stacked
else:
    p_embs = z[emb_path][:]
p_ids = z.id[:]
p_text = z.text

100%|██████████████████████████████████████████████████████████████████████████████████| 47/47 [01:27<00:00,  1.87s/it]


# Retrieve

In [22]:
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import sklearn.preprocessing as pp


def sim_binary(q, docs):
    """Count, per row, the bits that are set in both `q` and `docs`.

    The element-wise AND of the two operands is summed along axis 1, so the
    result holds one overlap count per row of the AND-ed array.
    """
    overlap = q & docs
    return overlap.sum(axis=1)


def sim_sparse(q, docs):
    """Cosine similarity between the rows of two sparse matrices.

    Both operands are normalised row-wise with `sklearn.preprocessing.normalize`
    and then multiplied, so entry (i, j) of the result is the similarity of
    row i of `q` and row j of `docs`.

    Args:
        q: 2-D sparse matrix of query vectors, one per row.
        docs: 2-D sparse matrix of document vectors, one per row.

    Returns:
        Dense array produced by `.toarray()` on the sparse product, with one
        row per row of `q` and one column per row of `docs`.
    """
    q = pp.normalize(q, axis=1)
    docs = pp.normalize(docs, axis=1)
    # `@` is always the matrix product. `*` only means that for the legacy
    # scipy `spmatrix` classes and is element-wise for the sparse array classes.
    return (q @ docs.T).toarray()


def sim_dense(q, docs):
    """Cosine similarity of every row of `docs` against a single query vector.

    `q` is reshaped into one row before being handed to sklearn's
    `cosine_similarity`, so the result has one row per document and one column.
    """
    query_row = q.reshape(1, -1)
    return cosine_similarity(docs, query_row)

In [None]:
# `emb_path` must contain exactly one "/": the part before it picks the
# similarity function, the part after it names the embedding.
emb_type, emb_name = emb_path.split("/")
sim_by_type = dict(
    binary=sim_binary,
    sparse=sim_sparse,
    result=sim_sparse,
    dense=sim_dense,
)
# An unknown embedding type raises KeyError here.
sim_func = sim_by_type[emb_type]


def search(q_emb, p_embs, p_ids, topk=1000, sim=None):
    """Return the ids of the `topk` passages most similar to one query.

    Args:
        q_emb: embedding of a single query.
        p_embs: passage embeddings, one passage per row.
        p_ids: array of passage ids, indexable by an integer index array and
            aligned row-for-row with `p_embs`.
        topk: maximum number of ids to return.
        sim: similarity callable taking `(q, docs)`; defaults to the
            module-level `sim_func`.

    Returns:
        The entries of `p_ids` ordered by decreasing similarity, at most
        `topk` of them.
    """
    if sim is None:
        sim = sim_func
    # The similarity functions are declared as (q, docs), so the query goes
    # first. Passing the passages first breaks the dense path, which reshapes
    # its first argument into a single row.
    scores = sim(q_emb, p_embs)
    # Flatten (n, 1), (1, n) or np.matrix results to 1-D. Unlike squeeze(),
    # ravel() never produces a 0-d array when there is only one passage.
    scores = np.asarray(scores).ravel()
    indices = np.argsort(scores)[::-1][:topk]
    return p_ids[indices]


# Run file: one "<query id>\t<passage id>\t<rank>" line per retrieved passage.
out_path = Path(f"./runs/run.msmarco-passage.dev.small.{emb_type}.{emb_name}.tsv")
# open() does not create missing directories, so make sure ./runs exists.
out_path.parent.mkdir(parents=True, exist_ok=True)

# NOTE(review): zip pairs q_ids with the rows of q_embs purely by position —
# confirm that both come in the same order and have the same length.
# `shape[0]` is used because len() is not defined for scipy sparse matrices.
n_queries = min(len(q_ids), q_embs.shape[0])

with open(out_path, "w", encoding="utf-8") as f:
    # Passing the total lets tqdm show a percentage and an ETA.
    for q_id, q_emb in tqdm(zip(q_ids, q_embs), total=n_queries):
        topk = search(q_emb, p_embs, p_ids)
        # Ranks are 1-based.
        f.writelines(
            f"{q_id}\t{tk}\t{rank}\n" for rank, tk in enumerate(topk, start=1)
        )

417it [19:39,  2.31s/it]