In [1]:
import heapq

from datasets import load_dataset

In [2]:
# Pull the corpus, query and relevance-judgment subsets of one BEIR dataset
# from the Hugging Face hub repository named below.
dataset_name = "trec-covid"
hub_repo = "Cohere/beir-embed-english-v3"
docs = load_dataset(hub_repo, f"{dataset_name}-corpus", split="train")
queries = load_dataset(hub_repo, f"{dataset_name}-queries", split="test")
qrels = load_dataset(hub_repo, f"{dataset_name}-qrels", split="test")

In [None]:
# Embedding dimensionality, read off the first query vector, followed by a
# quick size summary of the loaded splits.
m = len(queries[0]["emb"])
for label, value in (
    ("dimension:", m),
    ("# of docs:", len(docs)),
    ("# of queries:", len(queries)),
):
    print(label, value)

dimension: 1024
# of docs: 171332
# of queries: 50


In [None]:
# Fixed-point scaling factor (2 to the 32nd power) that float embedding
# components are multiplied by before being truncated to integers.
scale = 1 << 32

In [None]:
# Convert every document embedding to fixed-point integers: each float
# component is multiplied by `scale` and truncated toward zero by int().
int_docs: list[dict[str, str | list[int]]] = []
i = 0
for i, doc in enumerate(docs, start=1):
    int_docs.append(
        {"_id": doc["_id"], "emb": [int(x * scale) for x in doc["emb"]]}
    )
    # Progress heartbeat every 10k documents.
    if i % 10000 == 0:
        print(f"Processed {i} docs")
print("doc 0:", int_docs[0])

Processed 10000 docs
Processed 20000 docs
Processed 30000 docs
Processed 40000 docs
Processed 50000 docs
Processed 60000 docs
Processed 70000 docs
Processed 80000 docs
Processed 90000 docs
Processed 100000 docs
Processed 110000 docs
Processed 120000 docs
Processed 130000 docs
Processed 140000 docs
Processed 150000 docs
Processed 160000 docs
Processed 170000 docs
doc 0: {'_id': 'ug7v899j', 'emb': [73138176, -13017088, -198443008, -209453056, -2287616, -145227776, -117899264, -65699840, -178388992, 21381120, -101777408, 155713536, -133103616, -84541440, 85917696, -177471488, 73138176, -58228736, 29622272, -35487744, -7749632, -4030464, 88735744, 91095040, -2121728, -115146752, 146145280, -50397184, -63471616, -57049088, 12337152, 105709568, 134742016, 102891520, -116195328, -61177856, 45613056, -13795328, -115408896, 261619712, 19873792, 115277824, 677904384, -225312768, 268042240, 3774464, -3420160, -297795584, -132055040, -191365120, 73662464, -157548544, 120586240, -50626560, 1487872,

In [None]:
# Same fixed-point conversion as for the documents, applied to each query
# embedding (multiply by `scale`, truncate toward zero).
int_queries: list[dict[str, str | list[int]]] = [
    {"_id": query["_id"], "emb": [int(x * scale) for x in query["emb"]]}
    for query in queries
]
print("query 0:", int_queries[0])

query 0: {'_id': '6', 'emb': [-146276352, -69992448, -50528256, -164757504, -92995584, 36208640, -217841664, 57802752, -132972544, 139853824, 179568640, 87228416, -11403264, -107413504, -60456960, 27885568, 70385664, -6619136, -7901184, -92733440, -40534016, -92602368, 219152384, 226885632, 16842752, -8060928, -85196800, -149684224, -108658688, -101580800, -74252288, 167510016, 198311936, 72613888, -244842496, 194510848, -207618048, 137363456, -61997056, 87162880, 28049408, -80019456, 346292224, -138280960, 220200960, -106102784, 25788416, -448528384, -97124352, 248119296, 85196800, -98500608, 70254592, -9273344, 31162368, -61374464, 41517056, -101842944, 133234688, -116654080, -144441344, -20250624, 165937152, 61210624, -248381440, 38600704, 99680256, -20201472, 295698432, -221642752, -124452864, 301989888, 27017216, -135921664, -147718144, -73138176, 221249536, -1501184, 35684352, -106037248, 164495360, 97124352, 277610496, 80609280, -43581440, 106299392, -80150528, -118685696, 60522

In [7]:
def distance(lhs, rhs, shift=32):
    """Fixed-point dot product of two equal-length integer vectors.

    Each element-wise product is arithmetically right-shifted by ``shift``
    bits before being accumulated. The shift floors toward negative infinity,
    so a small negative product contributes -1 rather than 0.

    Despite the name, this is a similarity score: larger means closer.

    Args:
        lhs: Sequence of ints.
        rhs: Sequence of ints with the same length as ``lhs``.
        shift: Number of bits dropped from every product (default 32).

    Returns:
        The sum of the shifted products as an int; 0 for empty vectors.

    Raises:
        ValueError: If the two vectors differ in length.
    """
    # The vector length comes from the arguments themselves rather than from
    # a notebook-level global, so the function does not rely on hidden state.
    if len(lhs) != len(rhs):
        raise ValueError(
            f"vector length mismatch: {len(lhs)} != {len(rhs)}"
        )
    return sum((a * b) >> shift for a, b in zip(lhs, rhs))

In [22]:
def query_topk(query_emb, k):
    """Return the ids of the ``k`` best-scoring documents in ``int_docs``.

    Every document in the module-level ``int_docs`` list is scored against
    ``query_emb`` with ``distance`` (higher score = better match).

    Args:
        query_emb: Integer query embedding passed through to ``distance``.
        k: Maximum number of document ids to return.

    Returns:
        Document ids ordered from best to worst score; ties are ordered by
        descending doc id. Contains fewer than ``k`` ids when there are fewer
        documents, and is empty when ``k <= 0``.
    """
    if k <= 0:
        return []

    # Min-heap of (score, doc_id): heap[0] is always the weakest kept hit, so
    # each document costs O(log k) instead of a re-sort of the whole list.
    heap: list[tuple[int, str]] = []
    for doc in int_docs:
        entry = (distance(query_emb, doc["emb"]), doc["_id"])
        if len(heap) < k:
            heapq.heappush(heap, entry)
        elif entry[0] > heap[0][0]:
            # Strictly better than the weakest hit: evict it. A tie with the
            # weakest score keeps the document that was seen first.
            heapq.heapreplace(heap, entry)
    return [doc_id for _, doc_id in sorted(heap, reverse=True)]

In [None]:
# Highest relevance grade that occurs in the judgments.
print("max score:", max(qrel["score"] for qrel in qrels))

max score: 2


In [10]:
# qrel_map = {}
# for qrel in qrels:
#     query_id = qrel['query_id']
#     doc_id = qrel['corpus_id']
#     score = qrel['score']

#     if query_id not in qrel_map:
#         qrel_map[query_id] = {}
#     if score > 0 and score not in qrel_map[query_id]:
#         qrel_map[query_id][score] = []
#     if score > 0:
#         qrel_map[query_id][score].append(doc_id)
# print('qrel for query 1:', qrel_map['1'])

In [11]:
# for query_id, info in qrel_map.items():
#     print('q', query_id, 'score 2:', len(info.get(2, [])), 'score 1:', len(info.get(1, [])))

In [None]:
# Index the relevance judgments as qrel_idx[query_id][doc_id] -> score,
# keeping only judgments with a positive score.
qrel_idx = {}
for qrel in qrels:
    query_id, doc_id, score = qrel["query_id"], qrel["corpus_id"], qrel["score"]
    # Every query that appears gets an entry, even if none of its judgments
    # turn out to be positive.
    judged = qrel_idx.setdefault(query_id, {})
    if score > 0:
        judged[doc_id] = score
print("qrel idx for query 1 doc 005b2j4b:", qrel_idx["1"]["005b2j4b"])

qrel idx for query 1 doc 005b2j4b: 2


In [23]:
def query_float_topk(query_emb, k):
    """Return the ids of the ``k`` best-scoring documents in ``docs``.

    Every document in the module-level ``docs`` collection is scored by the
    plain dot product of its ``"emb"`` vector with ``query_emb`` (higher
    score = better match).

    Args:
        query_emb: Query embedding as a sequence of numbers.
        k: Maximum number of document ids to return.

    Returns:
        Document ids ordered from best to worst score; ties are ordered by
        descending doc id. Contains fewer than ``k`` ids when there are fewer
        documents, and is empty when ``k <= 0``.
    """
    if k <= 0:
        return []

    # Min-heap of (score, doc_id): heap[0] is always the weakest kept hit, so
    # each document costs O(log k) instead of a re-sort of the whole list.
    heap: list[tuple[float, str]] = []
    for doc in docs:
        score = sum(qe * de for qe, de in zip(query_emb, doc["emb"]))
        entry = (score, doc["_id"])
        if len(heap) < k:
            heapq.heappush(heap, entry)
        elif score > heap[0][0]:
            # Strictly better than the weakest hit: evict it. A tie with the
            # weakest score keeps the document that was seen first.
            heapq.heapreplace(heap, entry)
    return [doc_id for _, doc_id in sorted(heap, reverse=True)]

In [None]:
from multiprocessing import Pool

# One task per (query, k) pair for the first seven queries, laid out
# query-major: (integer embedding, float embedding, k).
args_list = [
    (int_queries[i]["emb"], queries[i]["emb"], k)
    for i in range(7)
    for k in [16, 32, 64, 128, 256, 512, 1024]
]


def f(x):
    """Overlap between the integer and float top-k results for one task.

    ``x`` is an ``(int_emb, emb, k)`` tuple. Returns the number of document
    ids found by both ``query_topk`` and ``query_float_topk``, divided by k.
    """
    int_emb, emb, k = x
    int_hits = set(query_topk(int_emb, k))
    float_hits = set(query_float_topk(emb, k))
    return len(int_hits.intersection(float_hits)) / k


# Evaluate all tasks on a pool of 20 worker processes; map() returns the
# results in the same order as args_list.
with Pool(20) as workers:
    recall_res = workers.map(f, args_list)

# Results are query-major: seven consecutive entries (one per k) per query.
for i, start in enumerate(range(0, 49, 7)):
    print(f"Query idx {i}: {recall_res[start : start + 7]}")

Query idx 0: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
Query idx 1: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
Query idx 2: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
Query idx 3: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
Query idx 4: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
Query idx 5: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
Query idx 6: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]


In [None]:
from multiprocessing import Pool

# One task per (query, k) pair for the first seven queries, laid out
# query-major: (query id, integer embedding, k).
args_list = [
    (int_queries[i]["_id"], int_queries[i]["emb"], k)
    for i in range(7)
    for k in [16, 32, 64, 128, 256, 512, 1024]
]


def f(x):
    """Mean relevance score of the integer top-k results for one task.

    ``x`` is a ``(query_id, int_emb, k)`` tuple. Each retrieved document
    contributes its score from ``qrel_idx[query_id]`` (nothing if it has no
    entry there); the total is divided by k.
    """
    query_id, int_emb, k = x
    retrieved = query_topk(int_emb, k)
    gains = (qrel_idx[query_id].get(doc_id, 0) for doc_id in retrieved)
    return sum(gains) / k


# Evaluate all tasks on a pool of 20 worker processes; map() returns the
# results in the same order as args_list.
with Pool(20) as workers:
    rel_score_res = workers.map(f, args_list)

# Results are query-major: seven consecutive entries (one per k) per query.
for i, start in enumerate(range(0, 49, 7)):
    print(f"Query idx {i}: {rel_score_res[start : start + 7]}")

Query idx 0: [1.9375, 1.71875, 1.546875, 1.390625, 1.20703125, 0.982421875, 0.6923828125]
Query idx 1: [1.9375, 1.84375, 1.71875, 1.5859375, 1.27734375, 0.91015625, 0.578125]
Query idx 2: [1.25, 1.28125, 1.15625, 0.8984375, 0.89453125, 0.748046875, 0.505859375]
Query idx 3: [1.875, 1.84375, 1.515625, 1.2578125, 0.84375, 0.552734375, 0.3798828125]
Query idx 4: [1.75, 1.65625, 1.421875, 1.0703125, 0.8359375, 0.533203125, 0.30078125]
Query idx 5: [1.625, 1.75, 1.609375, 1.140625, 0.7890625, 0.486328125, 0.2705078125]
Query idx 6: [2.0, 1.96875, 1.890625, 1.8359375, 1.4921875, 0.92578125, 0.494140625]


In [26]:
import csv

In [29]:
# Ids of the first seven queries, in dataset order.
query_ids = [queries[pos]["_id"] for pos in range(7)]

In [30]:
# Recall
# Write one row per (query, k) pair. recall_res is query-major with seven
# entries (one per k) per query, hence the i * 7 + j index.
# newline="" is what the csv module requires of files handed to csv.writer;
# without it, platforms that translate "\n" end up with blank rows. The
# handle is not called `f` so it does not clobber the worker function `f`.
with open("recall.csv", "w", newline="") as csv_file:
    writer = csv.writer(csv_file)
    writer.writerow(["query_id", "k", "recall"])
    for i in range(7):
        for j, k in enumerate([16, 32, 64, 128, 256, 512, 1024]):
            recall = recall_res[i * 7 + j]
            writer.writerow([query_ids[i], k, recall])

In [31]:
# Rel scores
# Write one row per (query, k) pair. rel_score_res is query-major with seven
# entries (one per k) per query, hence the i * 7 + j index.
# newline="" is what the csv module requires of files handed to csv.writer;
# without it, platforms that translate "\n" end up with blank rows. The
# handle is not called `f` so it does not clobber the worker function `f`.
with open("rel_scores.csv", "w", newline="") as csv_file:
    writer = csv.writer(csv_file)
    writer.writerow(["query_id", "k", "rel_score"])
    for i in range(7):
        for j, k in enumerate([16, 32, 64, 128, 256, 512, 1024]):
            rel_score = rel_score_res[i * 7 + j]
            writer.writerow([query_ids[i], k, rel_score])