In [9]:
import heapq

import numpy as np
import pandas as pd
from src.preprocess.tf_idf_vectors import load_vectorizer as load_tfidf
from src.preprocess.tf_idf_vectors import read_jsonl
from src.preprocess.fasttext_vectors import load_fasttext_vectors, tokenize
from src.config.paths import (
    poleval2022_questions_path,
    poleval2022_subdataset_dir,
)

In [10]:
# Saved TF-IDF artifacts for the wiki-trivia subset; later cells index the result
# by "vectorizer", "matrix" and "passage_ids".
TFIDF_INDEX_NAME = "wiki-trivia"
tf = load_tfidf(TFIDF_INDEX_NAME)

In [11]:
def topk_cosine_sparse_rows(q_vec, X, topk=5, chunk_size=50_000, assume_l2_normalized=True, eps=1e-12):
    """
    q_vec: (1, D) sparse
    X:     (N, D) sparse
    Returns: list[(score, row_index)] sorted desc
    """
    heap = []

    # only needed if not normalized
    if not assume_l2_normalized:
        q_norm = np.sqrt(q_vec.multiply(q_vec).sum()) + eps

    N = X.shape[0]
    for start in range(0, N, chunk_size):
        end = min(start + chunk_size, N)
        Xb = X[start:end]

        dots = Xb @ q_vec.T
        sims = dots.toarray().flatten()

        if not assume_l2_normalized:
            Xb_norm = np.sqrt(Xb.multiply(Xb).sum(axis=1)).A1 + eps
            sims = sims / (Xb_norm * q_norm)

        k_local = min(topk, sims.size)
        if k_local == 0:
            continue

        idx_local = np.argpartition(-sims, k_local - 1)[:k_local]
        for i in idx_local:
            item = (float(sims[i]), start + int(i))
            if len(heap) < topk:
                heapq.heappush(heap, item)
            else:
                heapq.heappushpop(heap, item)

    heap.sort(reverse=True)
    return heap


def retrieve_topk_tfidf(question_text, vectorizer, passages_matrix, k=5, chunk_size=50_000):
    """
    Retrieve the `k` passages most cosine-similar to a single question.

    Parameters
    ----------
    question_text : str
        Raw question text, vectorized via `vectorizer.transform([question_text])`.
    vectorizer : fitted vectorizer
        Must expose `.transform`. Presumably the sklearn TfidfVectorizer that produced
        `passages_matrix` (TODO confirm).
    passages_matrix : sparse (N, D)
        Passage vectors in the same feature space as `vectorizer`.
    k : int
        Number of passages to return.
    chunk_size : int
        Rows scored per chunk (forwarded to `topk_cosine_sparse_rows`).

    Returns
    -------
    pd.Series
        float64 scores sorted descending, indexed by int64 row position in
        `passages_matrix`. An empty result keeps these dtypes, so
        `index.to_numpy()` stays usable for positional lookups.
    """
    q_vec = vectorizer.transform([question_text])
    # Skip explicit normalization only when the vectorizer reports L2 output norms
    # (TfidfVectorizer's default). Assumes passages_matrix was built with the same
    # normalization; TODO confirm if it comes from a different pipeline.
    assume_norm = getattr(vectorizer, "norm", None) == "l2"

    top = topk_cosine_sparse_rows(
        q_vec=q_vec,
        X=passages_matrix,
        topk=k,
        chunk_size=chunk_size,
        assume_l2_normalized=assume_norm,
    )

    scores = [score for score, _ in top]
    rows = [row for _, row in top]

    # Explicit dtypes: with no hits, pandas would otherwise produce an object-dtype
    # index, which cannot be used to fancy-index a NumPy array.
    return pd.Series(scores, index=pd.Index(rows, dtype=np.int64), dtype=np.float64)

In [12]:
# --- Interactive spot-check configuration -----------------------------------
dataset_id = "piotr-rybak__poleval2022-passage-retrieval-dataset"
subdataset = "wiki-trivia"
split = "train"  # one of: "train", "test"

# --- Resolve paths and load questions keyed by their id ---------------------
subdataset_dir = poleval2022_subdataset_dir(dataset_id, subdataset)
questions_path = poleval2022_questions_path(dataset_id, subdataset, split)
questions_df = (
    read_jsonl(questions_path)
    .set_index("id")
)

In [13]:
# Spot-check retrieval for a single question.
k = 10
qid = 12  # arbitrary question id chosen for inspection

question_text = questions_df.loc[qid, "text"]
scores_by_row = retrieve_topk_tfidf(
    question_text,
    tf["vectorizer"],
    tf["matrix"],
    k=k,
    chunk_size=50_000,
)

In [14]:
# Translate matrix row positions into passage identifiers. Assumes tf["passage_ids"]
# is an array aligned with tf["matrix"] rows that supports fancy indexing (TODO confirm).
row_positions = scores_by_row.index.to_numpy()
top_passage_ids = tf["passage_ids"][row_positions]

In [15]:
# Display the scores (as float32) relabelled by passage id.
scores_by_row.astype(np.float32).set_axis(top_passage_ids)

5406-19       0.412167
2224431-0     0.406479
445532-3      0.390036
1352589-29    0.335980
3265537-6     0.334461
3031599-30    0.331957
4468-0        0.322450
3873232-7     0.312494
5045784-0     0.311857
3031599-29    0.307448
dtype: float32

In [None]:
# Full-split TF-IDF retrieval: writes a submission file and, when relevance pairs
# are available, reports Hits@k / MRR@k. The retriever is passed in as a callable
# so another backend (e.g. FAISS) can be plugged in later.
from src.eval.retrieval_eval import evaluate_and_write_submission, retrieve_tfidf_topk

dataset_id = "piotr-rybak__poleval2022-passage-retrieval-dataset"
subdataset = "wiki-trivia"
split = "test"
k = 10


def tfidf_retriever(texts, k):
    """Retriever callback: the top-`k` TF-IDF passages for each query in `texts`."""
    return retrieve_tfidf_topk(
        vectorizer=tf["vectorizer"],
        passages_matrix=tf["matrix"],
        passage_ids=tf["passage_ids"],
        query_texts=texts,
        k=k,
        chunk_size=10_000,
    )


result = evaluate_and_write_submission(
    dataset_id=dataset_id,
    subdataset=subdataset,
    questions_split=split,
    pairs_split=split,  # set None if you want "submission only"
    k=k,
    retriever=tfidf_retriever,
)

print("Wrote:", result.out_path)
# Guard each metric separately. Presumably both are None when no relevance pairs
# are evaluated, but formatting a None with :.4f would raise TypeError.
if result.hits_at_k is not None:
    print(f"Hits@{k}: {result.hits_at_k:.4f}")
if result.mrr_at_k is not None:
    print(f"MRR@{k}:  {result.mrr_at_k:.4f}")

result

TF-IDF test retrieval chunks: 6639839/6639839
Wrote: /home/mateusz/dev/inl_pjatk_project/.cache/submissions/tfidf_wiki-trivia_questions-test.tsv
Hits@10: 0.3757 (485/1291)
MRR@10:  0.1866


PosixPath('/home/mateusz/dev/inl_pjatk_project/.cache/submissions/tfidf_wiki-trivia_questions-test.tsv')