# 1. Setup

In [2]:
!pip install faiss-gpu -qq

In [4]:
!git -C ColBERT/ pull || git clone https://github.com/stanford-futuredata/ColBERT.git
!git clone https://github.com/ghanahmada/tp4-tbi-be.git

fatal: cannot change to 'ColBERT/': No such file or directory
Cloning into 'ColBERT'...
remote: Enumerating objects: 2835, done.[K
remote: Counting objects: 100% (1169/1169), done.[K
remote: Compressing objects: 100% (353/353), done.[K
remote: Total 2835 (delta 948), reused 817 (delta 816), pack-reused 1666 (from 2)[K
Receiving objects: 100% (2835/2835), 2.07 MiB | 24.09 MiB/s, done.
Resolving deltas: 100% (1785/1785), done.
Cloning into 'tp4-tbi-be'...
remote: Enumerating objects: 33, done.[K
remote: Counting objects: 100% (5/5), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 33 (delta 0), reused 4 (delta 0), pack-reused 28 (from 1)[K
Unpacking objects: 100% (33/33), 114.31 MiB | 14.14 MiB/s, done.
Updating files: 100% (17/17), done.


In [113]:
import os
import sys
import json
import gdown
import random
import zipfile
import logging
from tqdm import tqdm
from typing import List, Any, Dict, Tuple
from genericpath import isdir

import numpy as np
import pandas as pd

sys.path.insert(0, 'ColBERT/')
import colbert
from colbert import Trainer
from colbert.data import Queries
from colbert import Indexer, Searcher
from colbert.data import Queries, Collection
from colbert.utils.utils import print_message
from colbert.data.collection import Collection
from colbert.modeling.checkpoint import Checkpoint
from colbert.indexing.index_saver import IndexSaver
from colbert.search.index_storage import IndexScorer
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert.infra.launcher import Launcher, print_memory_stats
from colbert.indexing.collection_encoder import CollectionEncoder
from colbert.indexing.collection_indexer import CollectionIndexer

import torch

import faiss
assert faiss.get_num_gpus() > 0

import os
from langdetect import detect
import torch.multiprocessing as mp

In [6]:
try:
    import google.colab
    !pip install -U pip
    !pip install -e ColBERT/['faiss-gpu','torch']
except Exception:
  import sys; sys.path.insert(0, 'ColBERT/')
  try:
    from colbert import Indexer, Searcher
  except Exception:
    print("If you're running outside Colab, please make sure you install ColBERT in conda following the instructions in our README. You can also install (as above) with pip but it may install slower or less stable faiss or torch dependencies. Conda is recommended.")
    assert False

In [7]:
train_config = {
    'triples_path': '/kaggle/working/triples.train.small.jsonl',
    'queries_path': '/kaggle/working/queries.train.small.tsv',
    'collection_path':  '/kaggle/working/collection.tsv',
    'root_path': '/kaggle/working/experiments',
    'experiment_name': 'msmarco',
    'model_checkpoint': 'google-bert/bert-base-cased',
    'checkpoint_path': None,
    'nranks': 1
}

In [8]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  
    
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed = 42
set_seed(seed)

# 2. Dataset

In [10]:
def escape_tsv(value):
    """Escape tabs and newlines in a string for safe TSV writing."""
    return value.replace('\t', '\\t').replace('\n', '\\n')

def unescape_tsv(value):
    """Unescape tabs and newlines in a string after reading from TSV."""
    return value.replace('\\t', '\t').replace('\\n', '\n')

def get_unique_ordered_list(original_list):
    seen = set()
    unique_list = []
    for item in original_list:
        if item not in seen:
            unique_list.append(item)
            seen.add(item)
    return unique_list

def get_example_data(n_sample: int):
    qrels = pd.read_parquet('/kaggle/working/tp4-tbi-be/experiment/data/train_qrels.parquet')
    qrels.rename(columns={"document": "answer"}, inplace=True)
    
    if "is_train" in qrels.columns:
        sampled_data = qrels[qrels["is_train"] == True][["answer", "query", "label"]].copy()
    else:
        sampled_data = qrels[["answer", "query", "label"]]
    sampled_data["answer"] = sampled_data["answer"].apply(lambda row: row.replace("\t", ""))
    sampled_data["answer"] = sampled_data["answer"].apply(lambda row: " ".join(row.split()[:500]))
    display(sampled_data)
    print(f"loaded data with {len(sampled_data)} rows")

    questions = [{"qid":i, "query": item} for i, item in 
                 enumerate(get_unique_ordered_list(sampled_data["query"].tolist()))]
    inv_questions = {item:i for i, item in 
                 enumerate(get_unique_ordered_list(sampled_data["query"].tolist()))}
    passages = [{"pid":i, "passage": item} for i, item in 
                 enumerate(get_unique_ordered_list(sampled_data["answer"].tolist()))]
    inv_passages = {item:i for i, item in 
                 enumerate(get_unique_ordered_list(sampled_data["answer"].tolist()))}
    
    triples = []
    for q in sampled_data["query"].unique():
        curr_df = sampled_data[sampled_data["query"] == q]

        for idx in range(1, 10):
            tup = [inv_questions[q], 
                   inv_passages[curr_df["answer"].iloc[0]], 
                   inv_passages[curr_df["answer"].iloc[idx]]]
            triples.append(tup)
    
    labels = sampled_data["label"].tolist()
    
    return questions, passages, triples

In [11]:
def setup_training(n_sample, triples_path, queries_path, collection_path, root_path):
    questions, passages, triples = get_example_data(n_sample=n_sample)
    
    with open('/kaggle/working/triples.train.small.jsonl', 'w') as f:
        for item in triples:
            f.write(json.dumps(item) + '\n')

    with open('/kaggle/working/queries.train.small.tsv', 'w') as f:
        for item in questions:
            f.write(f"{item['qid']}\t{escape_tsv(item['query'])}\n")

    with open('/kaggle/working/collection.tsv', 'w') as f:
        for item in passages:
            f.write(f"{str(item['pid'])}\t{escape_tsv(item['passage'])}\n")

# 3. Train ColBERT

In [15]:
!mkdir -p {root_path}/checkpoint

def train_colbert(triples_path, 
                  queries_path, 
                  collection_path, 
                  root_path, 
                  experiment_name,
                  model_checkpoint,
                  n_sample=10_000):
    
    setup_training(n_sample, 
                   triples_path, 
                   queries_path, 
                   collection_path,
                   root_path)
    
    with Run().context(RunConfig(nranks=1, experiment=experiment_name)):

        colbert_config = ColBERTConfig(
            bsize=8,
            query_maxlen=64,
            doc_maxlen=512, 
            dim=256, 
            root=root_path,
        )
        
        trainer = Trainer(
            triples=triples_path,
            queries=queries_path,
            collection=collection_path,
            config=colbert_config,
        )

        checkpoint_path = trainer.train(checkpoint=model_checkpoint)

        print(f"Saved checkpoint to {checkpoint_path}...")
        
    
train_colbert(triples_path=train_config['triples_path'], 
              queries_path=train_config['queries_path'], 
              collection_path=train_config['collection_path'], 
              root_path=train_config['root_path'], 
              experiment_name=train_config['experiment_name'],
              model_checkpoint=train_config['model_checkpoint'],
              n_sample=60_000)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  sampled_data["answer"] = sampled_data["answer"].apply(lambda row: row.replace("\t", ""))
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  sampled_data["answer"] = sampled_data["answer"].apply(lambda row: " ".join(row.split()[:500]))


Unnamed: 0,answer,query,label
0,"When deciding whether to buy, sell, or hold a ...",How does Zacks Rank compare to average brokera...,1
1,Looking today at week-over-week shares outstan...,How does Zacks Rank compare to average brokera...,0
2,"For a Reuters live blog on U.S., UK and Europe...",How does Zacks Rank compare to average brokera...,0
3,(RTTNews) - Reata Pharmaceuticals Inc. (RETA) ...,How does Zacks Rank compare to average brokera...,0
4,"CHENNAI/BENGALURU, Oct 31 (Reuters) - India's ...",How does Zacks Rank compare to average brokera...,0
...,...,...,...
59995,"Updates with closing prices TOKYO, Dec 5 (Reut...",How are institutional investors changing their...,0
59996,By Jeffrey Dastin Nov 20 (Reuters) - OpenAI's ...,How are institutional investors changing their...,0
59997,(RTTNews) - Paychex Inc. (PAYX) will host a co...,How are institutional investors changing their...,0
59998,(RTTNews) - Adobe Inc. (ADBE) will host a conf...,How are institutional investors changing their...,0


loaded data with 60000 rows
#> Starting...




nranks = 1 	 num_gpus = 1 	 device=0
{
    "query_token_id": "[unused0]",
    "doc_token_id": "[unused1]",
    "query_token": "[Q]",
    "doc_token": "[D]",
    "ncells": null,
    "centroid_score_threshold": null,
    "ndocs": null,
    "load_index_with_mmap": false,
    "index_path": null,
    "index_bsize": 64,
    "nbits": 1,
    "kmeans_niters": 4,
    "resume": false,
    "pool_factor": 1,
    "clustering_mode": "hierarchical",
    "protected_tokens": 0,
    "similarity": "cosine",
    "bsize": 8,
    "accumsteps": 1,
    "lr": 3e-6,
    "maxsteps": 500000,
    "save_every": null,
    "warmup": null,
    "warmup_bert": null,
    "relu": false,
    "nway": 2,
    "use_ib_negatives": false,
    "reranker": false,
    "distillation_alpha": 1.0,
    "ignore_scores": false,
    "model_name": null,
    "query_maxlen": 64,
    "attend_to_mask_tokens": false,
    "interaction": "colbert",
    "dim": 256,
    "doc_maxlen": 512,
    "mask_punctuation": true,
    "checkpoint": "google-bert\

Some weights of HF_ColBERT were not initialized from the model checkpoint at google-bert/bert-base-cased and are newly initialized: ['linear.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.




#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==
#> Input: How does Zacks Rank compare to average brokerage recommendations for stocks?, 		 True, 		 None
#> Output IDs: torch.Size([64]), tensor([  101,   100,  1731,  1674, 14064,  1116, 25949, 14133,  1106,  1903,
        24535,  2553, 11859,  1111, 17901,   136,   102,   103,   103,   103,
          103,   103,   103,   103,   103,   103,   103,   103,   103,   103,
          103,   103,   103,   103,   103,   103,   103,   103,   103,   103,
          103,   103,   103,   103,   103,   103,   103,   103,   103,   103,
          103,   103,   103,   103,   103,   103,   103,   103,   103,   103,
          103,   103,   103,   103], device='cuda:0')
#> Output Mask: torch.Size([64]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:

# 4. Indexing

In [18]:
index_config = {
    'triples_path': '/kaggle/working/triples.train.small.jsonl',
    'queries_path': '/kaggle/working/queries.train.small.tsv',
    'collection_path':  '/kaggle/working/collection.tsv',
    'root_path': '/kaggle/working/experiments',
    'experiment_name': 'msmarco',
    'model_checkpoint': 'google-bert/bert-base-cased',
    'checkpoint_path': 'ghanahmada/stock-colbert', # trained colbert
    'index_name': 'ir',
    'nranks': 1
}

In [55]:
def load_collection():
    doc = pd.read_parquet("/kaggle/working/tp4-tbi-be/experiment/data/document.parquet")
    non_duplicate_doc = doc.drop_duplicates(subset="Article")["Article"].tolist()
    return non_duplicate_doc

def init_index(indexer, collection, config, verbose=3):
    with Run().context(RunConfig(nranks=config["nranks"], experiment=config["experiment_name"])):
        checkpoint_path = config["checkpoint_path"]

        colbert_config = ColBERTConfig(
            nbits=2,
            root=config["root_path"],
        )
        indexer = Indexer(checkpoint=checkpoint_path, config=colbert_config, verbose=verbose)
        indexer.index(name=f"{config['index_name']}", 
                      collection=collection, overwrite=True)

In [58]:
init_index(indexer=Indexer, 
           collection=load_collection(),
           config=index_config,
           verbose=1)



[Dec 22, 10:02:09] #> Note: Output directory /kaggle/working/experiments/msmarco/indexes/ir already exists


#> Starting...




nranks = 1 	 num_gpus = 1 	 device=0
[Dec 22, 10:02:15] [0] 		 #> Encoding 16918 passages..
[Dec 22, 10:07:04] [0] 		 avg_doclen_est = 410.19140625 	 len(local_sample) = 16,918
[Dec 22, 10:07:08] [0] 		 Creating 32,768 partitions.
[Dec 22, 10:07:08] [0] 		 *Estimated* 6,939,618 embeddings.
[Dec 22, 10:07:08] [0] 		 #> Saving the indexing plan to /kaggle/working/experiments/msmarco/indexes/ir/plan.json ..
Clustering 6889618 points in 256D to 32768 clusters, redo 1 times, 4 iterations
  Preprocessing in 1.77 s
[Dec 22, 10:08:36] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Dec 22, 10:08:36] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[0.024, 0.022, 0.022, 0.021, 0.022, 0.023, 0.022, 0.025, 0.022, 0.023, 0.024, 0.022, 0.022, 0.023, 0.023, 0.024, 0.022, 0.023, 0.022, 0.022, 0.024, 0.024, 0.021, 0.024, 0.023, 0.023, 0.022, 0.023, 0.023, 0.025, 0.023, 0.024, 0.022, 0.023, 0.025

1it [05:10, 310.37s/it]
100%|██████████| 1/1 [00:00<00:00, 86.58it/s]


[Dec 22, 10:13:48] #> Optimizing IVF to store map from centroids to list of pids..
[Dec 22, 10:13:48] #> Building the emb2pid mapping..
[Dec 22, 10:13:48] len(emb2pid) = 6939618


100%|██████████| 32768/32768 [00:00<00:00, 33411.67it/s]


[Dec 22, 10:13:49] #> Saved optimized IVF to /kaggle/working/experiments/msmarco/indexes/ir/ivf.pid.pt

#> Joined...


# 5. Inference on Trained ColBERT

In [129]:
class ColBERTSearcher:
    def __init__(self, config, verbose, searcher=None):
        self.config = config
        self.verbose = verbose
        self.collection = self.load_collection()
        self.searcher = self.init_searcher(searcher)
        
    def load_collection(self):
        doc = pd.read_parquet("/kaggle/working/tp4-tbi-be/experiment/data/document.parquet")
        passages = [{"pid":i, "passage": item} for i, item in 
                     enumerate(get_unique_ordered_list(doc["Article"].tolist()))]
    
        return passages
    
    def init_searcher(self, searcher=None):
        if searcher is None:
            searcher = Searcher
            
        with Run().context(RunConfig(experiment=self.config["experiment_name"])):
            searcher = searcher(index=f"{self.config['index_name']}", 
                                collection=self.collection,
                                verbose=self.verbose)
        return searcher

    def infer(self, query, k=3):
        results = self.searcher.search(query, k=k)
        
        data = []
        for passage_id, passage_rank, passage_score in zip(*results):
            passage_text = self.searcher.collection[passage_id]
            if self.verbose > 0:
                print(f"{passage_rank:<5} {passage_score:<10.4f} {passage_id:<15} {passage_text}")
                
            data.append([passage_rank, passage_score, passage_id, passage_text, passage_score])

        df = pd.DataFrame(data, columns=["rank", "score", "passage_id", "passage_text", "passage_score"])        
        return df
    
    
def add_inference(_eval_df, searcher, k=10):
    K = k
    eval_df = _eval_df.copy()

    for i in range(1, K + 1):
        eval_df[f"context_{i}"] = None
        eval_df[f"id_context_{i}"] = None
        eval_df[f"score_context_{i}"] = None
    
    query_df = eval_df.iloc[[i for i in range(len(eval_df)) if i % 10 == 0]]

    for idx, row in tqdm(query_df.iterrows(), total=len(query_df)):
        query = row["query"]
        result = searcher.infer(query, k=K)
        passage_text = result["passage_text"].tolist()
        passage_id = result["passage_id"].tolist()
        passage_score = result["passage_score"].tolist()

        
        for i in range(K):
            if i < len(passage_text):   
                eval_df.at[idx, f"context_{i+1}"] = passage_text[i]
                eval_df.at[idx, f"id_context_{i+1}"] = passage_id[i]
                eval_df.at[idx, f"score_context_{i+1}"] = passage_score[i]
            else:  
                eval_df.at[idx, f"context_{i+1}"] = "-"
                eval_df.at[idx, f"id_context_{i+1}"] = -1
                eval_df.at[idx, f"score_context_{i+1}"] = -1

    
    eval_df = eval_df.fillna(method="ffill")
    return eval_df

In [93]:
test_df = pd.read_parquet("/kaggle/working/tp4-tbi-be/experiment/data/test_qrels.parquet")

searcher = ColBERTSearcher(index_config, -1)
result_df = add_inference(test_df, searcher, k=30)



[Dec 22, 10:54:13] #> Loading codec...
[Dec 22, 10:54:13] #> Loading IVF...
[Dec 22, 10:54:13] #> Loading doclens...


100%|██████████| 1/1 [00:00<00:00, 839.20it/s]

[Dec 22, 10:54:13] #> Loading codes and residuals...



100%|██████████| 1/1 [00:00<00:00,  2.06it/s]
100%|██████████| 2000/2000 [01:22<00:00, 24.34it/s]
  eval_df = eval_df.fillna(method="ffill")
  eval_df = eval_df.fillna(method="ffill")


In [146]:
result_df.head()

Unnamed: 0,qid,query,docno,document,label,context_1,id_context_1,score_context_1,context_2,id_context_2,...,score_context_27,context_28,id_context_28,score_context_28,context_29,id_context_29,score_context_29,context_30,id_context_30,score_context_30
0,2,What lessons can investors learn from holding ...,0,"After an absolute disaster of a year in 2022, ...",1,"{'pid': 13092, 'passage': 'From 1977 to 1990, ...",13092,41.21875,"{'pid': 160, 'passage': 'Electronics giant App...",160,...,39.40625,"{'pid': 192, 'passage': 'Apple (NASDAQ:AAPL) s...",192,39.40625,"{'pid': 9024, 'passage': 'Apple (NASDAQ: AAPL)...",9024,39.375,"{'pid': 326, 'passage': 'Investing always carr...",326,39.375
1,2,What lessons can investors learn from holding ...,8188,"As the saying goes, there are many possible re...",0,"{'pid': 13092, 'passage': 'From 1977 to 1990, ...",13092,41.21875,"{'pid': 160, 'passage': 'Electronics giant App...",160,...,39.40625,"{'pid': 192, 'passage': 'Apple (NASDAQ:AAPL) s...",192,39.40625,"{'pid': 9024, 'passage': 'Apple (NASDAQ: AAPL)...",9024,39.375,"{'pid': 326, 'passage': 'Investing always carr...",326,39.375
2,2,What lessons can investors learn from holding ...,12951,"By Caroline Valetkevitch\nNEW YORK, Nov 14 (Re...",0,"{'pid': 13092, 'passage': 'From 1977 to 1990, ...",13092,41.21875,"{'pid': 160, 'passage': 'Electronics giant App...",160,...,39.40625,"{'pid': 192, 'passage': 'Apple (NASDAQ:AAPL) s...",192,39.40625,"{'pid': 9024, 'passage': 'Apple (NASDAQ: AAPL)...",9024,39.375,"{'pid': 326, 'passage': 'Investing always carr...",326,39.375
3,2,What lessons can investors learn from holding ...,4160,By Shashwat Chauhan\nAug 22 (Reuters) - Europe...,0,"{'pid': 13092, 'passage': 'From 1977 to 1990, ...",13092,41.21875,"{'pid': 160, 'passage': 'Electronics giant App...",160,...,39.40625,"{'pid': 192, 'passage': 'Apple (NASDAQ:AAPL) s...",192,39.40625,"{'pid': 9024, 'passage': 'Apple (NASDAQ: AAPL)...",9024,39.375,"{'pid': 326, 'passage': 'Investing always carr...",326,39.375
4,2,What lessons can investors learn from holding ...,18700,Tech stocks were higher Thursday afternoon wit...,0,"{'pid': 13092, 'passage': 'From 1977 to 1990, ...",13092,41.21875,"{'pid': 160, 'passage': 'Electronics giant App...",160,...,39.40625,"{'pid': 192, 'passage': 'Apple (NASDAQ:AAPL) s...",192,39.40625,"{'pid': 9024, 'passage': 'Apple (NASDAQ: AAPL)...",9024,39.375,"{'pid': 326, 'passage': 'Investing always carr...",326,39.375


# 6. Evaluation

In [141]:
def mean_reciprocal_rank(true_labels: List[int], predicted_lists: List[List[int]]) -> float:
    def reciprocal_rank(true_label, predicted_list):
        for i, predicted_id in enumerate(predicted_list, start=1):
            if predicted_id == true_label:
                return 1 / i
        return 0

    total_queries = len(true_labels)
    if total_queries == 0:
        return 0 
    
    rr_sum = sum(reciprocal_rank(true_label, predicted_list) 
                 for true_label, predicted_list in zip(true_labels, predicted_lists))
    
    return rr_sum / total_queries

def mean_average_precision(true_labels: List[int], predicted_lists: List[List[int]]) -> float:
    def average_precision(true_label, predicted_list):
        precision_sum = 0
        relevant_count = 0
        
        for i, predicted_id in enumerate(predicted_list, start=1):
            if predicted_id == true_label:
                relevant_count += 1
                precision_sum += relevant_count / i  # Precision at rank i
                
        return precision_sum / relevant_count if relevant_count > 0 else 0

    total_queries = len(true_labels)
    if total_queries == 0:
        return 0

    ap_sum = sum(average_precision(true_label, predicted_list)
                 for true_label, predicted_list in zip(true_labels, predicted_lists))
    
    return ap_sum / total_queries

def precision_at_k(true_labels: List[int], predicted_lists: List[List[int]], k: int) -> float:
    correct = 0
    total = len(predicted_lists) * k

    for true_label, predicted_list in zip(true_labels, predicted_lists):
        correct += sum(1 for pred in predicted_list[:k] if pred == true_label)

    return correct / total if total > 0 else 0

def recall_at_k(true_labels: List[int], predicted_lists: List[List[int]], k: int) -> float:
    correct = 0
    total_relevant = len(true_labels)

    for true_label, predicted_list in zip(true_labels, predicted_lists):
        if true_label in predicted_list[:k]:
            correct += 1

    return correct / total_relevant if total_relevant > 0 else 0

def load_prediction(_eval_df: pd.DataFrame, collection, topk: int):
    eval_df = _eval_df.copy()
    mapping = {}

    for _, row in eval_df.iterrows():
        raw_idx = collection[row["document"]]
        processed_idx = row["docno"]

        mapping[processed_idx] = raw_idx

    true_labels = [mapping[i] for i in eval_df["docno"].values]

    prediction = eval_df[[f"id_context_{i}" for i in range(1, topk+1)]].astype(int).values[::10].tolist()
    eval_df["label"] = true_labels
    
    return true_labels, prediction

def run_evaluation(result_df, collection, topk: int):
    print(f"=====TOP-{topk}=====")
    true_labels, pred = load_prediction(result_df, collection, topk=topk)

    mrr = mean_reciprocal_rank(np.array(true_labels)[::10], pred)
    map_score = mean_average_precision(np.array(true_labels)[::10], pred)
    precision = precision_at_k(np.array(true_labels)[::10], pred, k=topk)
    recall = recall_at_k(np.array(true_labels)[::10], pred, k=topk)

    print(f"MRR@{topk}: {mrr:.4f}")
    print(f"MAP@{topk}: {map_score:.4f}")
    print(f"Precision@{topk}: {precision:.4f}")
    print(f"Recall@{topk}: {recall:.4f}")


In [142]:
DOC_PATH = "/kaggle/working/tp4-tbi-be/experiment/data/document.parquet"
QREL_PATH = "/kaggle/working/tp4-tbi-be/experiment/data/test_qrels.parquet"

document = pd.read_parquet(DOC_PATH)
qrels = pd.read_parquet(QREL_PATH)

In [143]:
passages = load_collection()
    
collection = {item:i for i, item in enumerate(passages)}

In [145]:
for topk in [10, 20, 30]:
    run_evaluation(result_df, collection, topk=topk)

=====TOP-10=====
MRR@10: 0.3826
MAP@10: 0.3826
Precision@10: 0.0587
Recall@10: 0.5870
=====TOP-20=====
MRR@20: 0.3896
MAP@20: 0.3896
Precision@20: 0.0344
Recall@20: 0.6885
=====TOP-30=====
MRR@30: 0.3913
MAP@30: 0.3913
Precision@30: 0.0244
Recall@30: 0.7305
