## Install Dependencies & Imports

In [None]:
!pip install --upgrade pip --quiet
!pip install torch --quiet
!pip install transformers==4.49.* --quiet
!pip install datasets --quiet
!pip install scikit-learn --quiet
!pip install ir_measures tqdm --quiet
!pip install tira ir_datasets_longeval --quiet
!pip install ragatouille==0.0.9 --quiet
!pip install faiss-gpu-cu12 --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m20.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m78.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m162.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m182.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m50.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m36.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m142.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m151.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

Thu May 15 04:55:37 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   31C    P0             43W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
import os
import json
from pathlib import Path
import collections
from typing import Dict, List, Tuple, Any
import sys
import time

import torch
from google.colab import drive
import ir_measures
from tqdm import tqdm
from ragatouille import RAGPretrainedModel

# --- Input locations (on the mounted Google Drive) ---
# Project folder on Drive; the drive itself is mounted at /content/drive.
COLAB_DRIVE_ROOT_PATH = "/content/drive/MyDrive/AIR_Project/"
# Name of the dataset folder inside the project folder.
DATA_DIR_NAME = "longeval_sci_training_2025_abstract"
DATA_DIR = Path(COLAB_DRIVE_ROOT_PATH) / DATA_DIR_NAME
# Parsed by load_queries as one "<query_id>\t<query_text>" record per line.
QUERIES_FILE = DATA_DIR / "queries.txt"
# Parsed by the qrels loaders as four whitespace-separated columns per line.
QRELS_FILE = DATA_DIR / "qrels.txt"
# Directory that is scanned for *.jsonl document files.
DOCUMENTS_DIR = DATA_DIR / "documents"

# --- Output locations (relative to the current working directory) ---
# Receives the metrics file, the TREC run file and the index root directory.
OUTPUT_DIR = Path("./longeval_output_colbert_abstract_final")
# Sub-directory of OUTPUT_DIR handed to RAGatouille as its index_root.
RAGATOUILLE_INDEXES_ROOT_SUBDIR = "colbert_indexes"
# Name under which the index is created.
INDEX_NAME = "longeval_abstract_final_idx"

# Checkpoint name passed to RAGPretrainedModel.from_pretrained.
COLBERT_MODEL_NAME = "colbert-ir/colbertv2.0"
# Run tag written as the last column of the TREC run file; also its file stem.
TREC_RUN_NAME = "CLEF-ColBERT-Abstract-Final-v1"

# Sequence budget; indexing passes MAX_SEQ_LENGTH - 30 as max_document_length.
MAX_SEQ_LENGTH = 512

def mount_drive_and_verify_paths(data_dir_path, queries_file_path, qrels_file_path, docs_dir_path):
    """Mount Google Drive and confirm that every input location is present.

    The drive is (re)mounted at ``/content/drive`` first. The documents
    location must be a directory; the other three only need to exist.
    One ERROR line is printed per missing location.

    Args:
        data_dir_path: Path object of the dataset directory.
        queries_file_path: Path object of the queries file.
        qrels_file_path: Path object of the qrels file.
        docs_dir_path: Path object of the documents directory.

    Returns:
        True when all four locations pass their check, otherwise False.
    """
    drive.mount('/content/drive', force_remount=True)

    # (label used in messages, location, whether it has to be a directory)
    checks = (
        ("Dataset directory", data_dir_path, False),
        ("Queries file", queries_file_path, False),
        ("Qrels file", qrels_file_path, False),
        ("Documents directory", docs_dir_path, True),
    )

    missing_labels = []
    for label, location, must_be_dir in checks:
        present = location.is_dir() if must_be_dir else location.exists()
        if not present:
            print(f"ERROR: {label} not found at: {location}")
            missing_labels.append(label)

    if missing_labels:
        return False

    print("All required paths verified successfully.")
    return True

def load_queries(file_path):
    """Read a tab-separated queries file into a ``{query_id: query_text}`` dict.

    Each line is stripped and cut at its first tab. Lines that contain no tab
    after stripping are reported with a WARNING and skipped. Any exception
    raised while reading is reported with an ERROR line; whatever was parsed
    up to that point is still returned.

    Args:
        file_path: Location of the queries file (UTF-8 text).

    Returns:
        Dict mapping query id to query text; empty if nothing could be read.
    """
    queries = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            for raw_line in handle:
                record = raw_line.strip()
                query_id, tab, query_text = record.partition('\t')
                if not tab:
                    # No separator left after stripping: not an id/text pair.
                    print(f"WARNING: Skipping malformed line in queries file: {record}")
                    continue
                queries[query_id] = query_text
        print(f"Loaded {len(queries)} queries.")
    except Exception as e:
        print(f"ERROR: Error loading queries: {e}")
    return queries

def load_qrels_for_ir_measures(file_path):
    """Load relevance judgments as ``{query_id: {doc_id: int_relevance}}``.

    Every non-blank line must hold four whitespace-separated columns:
    query id, an ignored column, document id and relevance. The relevance is
    parsed as a float and truncated to ``int``.

    Bad lines are skipped one at a time with a WARNING (wrong column count,
    or a relevance value that is not a finite number), so a single bad record
    no longer discards every judgment that follows it. Blank lines are
    skipped silently. I/O and decoding errors are reported with an ERROR line
    and whatever was parsed so far is returned.

    Args:
        file_path: Location of the qrels file (UTF-8 text).

    Returns:
        ``collections.defaultdict(dict)`` keyed by query id.
    """
    qrels_dict = collections.defaultdict(dict)
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.split()
                if not parts:
                    # Blank or whitespace-only line (e.g. trailing newline).
                    continue
                if len(parts) != 4:
                    print(f"WARNING: Skipping malformed line in qrels file: {line.strip()}")
                    continue

                query_id, _, doc_id, relevance_score_str = parts
                try:
                    relevance = int(float(relevance_score_str))
                except (ValueError, OverflowError):
                    # ValueError: not a number / NaN; OverflowError: infinity.
                    print(f"WARNING: Skipping qrels line with invalid relevance value: {line.strip()}")
                    continue
                qrels_dict[query_id][doc_id] = relevance
        print(f"Loaded qrels for {len(qrels_dict)} queries for evaluation (scores as int).")
    except (OSError, UnicodeError) as e:
        print(f"ERROR: Error loading qrels for evaluation: {e}")
    return qrels_dict

def load_raw_qrels_data(file_path):
    """Load relevance judgments as a flat list of ``(query_id, doc_id, score)``.

    Every non-blank line must hold four whitespace-separated columns:
    query id, an ignored column, document id and relevance. The relevance is
    kept as a ``float``.

    Bad lines are skipped one at a time with a WARNING (wrong column count,
    or a relevance value that cannot be parsed as a float), so a single bad
    record no longer discards every judgment that follows it. Blank lines are
    skipped silently. I/O and decoding errors are reported with an ERROR line
    and whatever was parsed so far is returned.

    Args:
        file_path: Location of the qrels file (UTF-8 text).

    Returns:
        List of ``(query_id, doc_id, relevance)`` tuples in file order.
    """
    raw_qrels = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.split()
                if not parts:
                    # Blank or whitespace-only line (e.g. trailing newline).
                    continue
                if len(parts) != 4:
                    print(f"WARNING: Skipping malformed line in qrels file: {line.strip()}")
                    continue

                query_id, _, doc_id, relevance_score = parts
                try:
                    relevance = float(relevance_score)
                except ValueError:
                    print(f"WARNING: Skipping qrels line with invalid relevance value: {line.strip()}")
                    continue
                raw_qrels.append((query_id, doc_id, relevance))
        print(f"Loaded {len(raw_qrels)} raw relevance judgments (scores as float).")
    except (OSError, UnicodeError) as e:
        print(f"ERROR: Error loading raw qrels: {e}")
    return raw_qrels

def load_and_prepare_documents_for_indexing(docs_dir, relevant_doc_ids):
    """Collect the text of every wanted document from the ``*.jsonl`` files.

    Each JSON line is looked up by its stringified ``id``. For a wanted id
    seen for the first time, the indexed text is built from the title, the
    names of up to the first three authors joined by ``"; "``, and the
    abstract; non-empty parts are joined by ``" [SEP] "``. Scanning stops as
    soon as the number of collected documents equals
    ``len(relevant_doc_ids)``.

    Lines that are not valid JSON are skipped silently. Other per-line errors
    and per-file errors are reported with a WARNING and scanning continues.

    Args:
        docs_dir: Path object of the directory holding the ``*.jsonl`` files.
        relevant_doc_ids: Sized container of document-id strings to collect.

    Returns:
        ``(texts, ids)``: two parallel lists in the order the documents were
        found.
    """
    wanted_total = len(relevant_doc_ids)
    texts_by_id = {}
    found = 0

    shard_paths = list(docs_dir.glob('*.jsonl'))
    shard_count = len(shard_paths)
    print(f"Scanning {shard_count} document files to find {wanted_total} relevant documents...")

    for shard_path in tqdm(shard_paths, total=shard_count, desc="Scanning document files"):
        try:
            with open(shard_path, 'r', encoding='utf-8') as shard:
                for raw_line in shard:
                    try:
                        record = json.loads(raw_line)
                        record_id = str(record.get("id"))
                        if record_id not in relevant_doc_ids or record_id in texts_by_id:
                            # Not requested, or already collected earlier.
                            continue

                        authors_field = record.get("authors", [])
                        if authors_field:
                            leading_names = (entry.get("name") for entry in authors_field[:3])
                            authors_text = "; ".join(name for name in leading_names if name)
                        else:
                            authors_text = ""

                        fields = (record.get("title", ""), authors_text, record.get("abstract", ""))
                        kept = [field for field in fields if field and field.strip()]
                        texts_by_id[record_id] = " [SEP] ".join(kept).strip()

                        found += 1
                        if found == wanted_total:
                            print(f"Successfully loaded all {len(texts_by_id)} required documents for indexing.")
                            break
                    except json.JSONDecodeError:
                        continue
                    except Exception as line_error:
                        print(f"WARNING: Error processing a document line in {shard_path}: {str(line_error)}")
                # Everything collected: no need to open the remaining files.
                if found == wanted_total:
                    break
        except Exception as file_error:
            print(f"WARNING: Error opening or reading file {shard_path}: {str(file_error)}")

    if len(texts_by_id) < wanted_total:
        print(f"WARNING: Could only load {len(texts_by_id)} out of {wanted_total} required documents for indexing.")

    # Dicts keep insertion order, so keys and values stay aligned.
    doc_ids_list = list(texts_by_id)
    doc_texts_list = list(texts_by_id.values())

    print(f"Prepared {len(doc_texts_list)} documents for ColBERT indexing.")
    return doc_texts_list, doc_ids_list

def generate_trec_run_file(run_data, output_file, run_name):
    """Write retrieval results to ``output_file`` in TREC run format.

    One line is written per (query, document) pair:
    ``<query_id> Q0 <doc_id> <rank> <score> <run_name>``, with the score
    printed to six decimals. Within a query, documents are ordered by
    descending score and ranks start at 1.

    Args:
        run_data: Mapping ``{query_id: {doc_id: score}}``.
        output_file: Destination path; overwritten if it exists.
        run_name: Run tag placed in the last column.
    """
    def _ranked_lines():
        for q_id, doc_scores in run_data.items():
            ordered_ids = sorted(doc_scores, key=doc_scores.get, reverse=True)
            for position, doc_id in enumerate(ordered_ids, start=1):
                yield f"{q_id} Q0 {doc_id} {position} {doc_scores[doc_id]:.6f} {run_name}\n"

    with open(output_file, 'w') as f_out:
        f_out.writelines(_ranked_lines())
    print(f"TREC run file saved to {output_file}")

def _extract_doc_id_and_score(result_item):
    """Return ``(doc_id, score)`` for one search hit; either may be ``None``.

    Supports dict hits (``document_id``, falling back to ``docid``, plus
    ``score``) and objects exposing ``metadata`` and ``score`` attributes.
    """
    doc_id = None
    score = None
    if isinstance(result_item, dict):
        doc_id = result_item.get('document_id')
        if not doc_id:
            doc_id = result_item.get('docid')
        score = result_item.get('score')
    elif hasattr(result_item, 'metadata') and hasattr(result_item, 'score'):
        doc_id = result_item.metadata.get('document_id') if result_item.metadata else None
        score = result_item.score
    return doc_id, score


def main_colbert_pipeline():
    """Run the ColBERT retrieval experiment end to end.

    Steps:
      1. Create the output directories, mount the drive and verify inputs.
      2. Load the queries and the qrels.
      3. Reuse the index on disk if one can be loaded; otherwise load the
         judged documents and build a new index.
      4. Retrieve the top 100 passages for every judged query and collapse
         them to one score per document (the best passage score).
      5. Compute aggregate IR metrics, then write the metrics file and the
         TREC run file to ``OUTPUT_DIR``.

    The function returns ``None``; it stops early, after printing an ERROR
    line, when required inputs are missing or empty.
    """
    import traceback  # only needed to report a failing search call

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    ragatouille_indexes_root_path = OUTPUT_DIR / RAGATOUILLE_INDEXES_ROOT_SUBDIR
    ragatouille_indexes_root_path.mkdir(parents=True, exist_ok=True)

    # RAGatouille stores an index at <index_root>/colbert/indexes/<index_name>,
    # not directly below <index_root>. Looking at the real location is what
    # allows a previously built index to be found and reused.
    specific_index_path = ragatouille_indexes_root_path / "colbert" / "indexes" / INDEX_NAME

    if not mount_drive_and_verify_paths(DATA_DIR, QUERIES_FILE, QRELS_FILE, DOCUMENTS_DIR):
        return

    print("Loading queries and qrels...")
    queries = load_queries(QUERIES_FILE)
    qrels_for_eval_metrics = load_qrels_for_ir_measures(QRELS_FILE)
    raw_qrels = load_raw_qrels_data(QRELS_FILE)

    if not queries or not qrels_for_eval_metrics or not raw_qrels:
        print("ERROR: Failed to load queries or qrels. Halting.")
        return

    unique_doc_ids_in_qrels = {doc_id for _, doc_id, _ in raw_qrels}
    if not unique_doc_ids_in_qrels:
        print("ERROR: No document IDs found in qrels to index. Halting.")
        return

    print(f"\nInitializing ColBERT model: {COLBERT_MODEL_NAME}")

    rag_model = None
    if specific_index_path.is_dir() and any(specific_index_path.iterdir()):
        print(f"Attempting to load model and index from: {specific_index_path}")
        try:
            rag_model = RAGPretrainedModel.from_index(index_path=str(specific_index_path))
            print("Model and index loaded successfully.")
        except Exception as e:
            print(f"Error loading existing index from {specific_index_path} (Reason: {e}). Will re-index.")
            rag_model = None

    if rag_model is None:
        # The document scan is slow, so it only runs when an index has to be
        # built; a reused index already contains the document texts.
        document_texts_list, document_ids_list = load_and_prepare_documents_for_indexing(
            DOCUMENTS_DIR, unique_doc_ids_in_qrels
        )
        if not document_texts_list:
            print("ERROR: No document texts could be loaded for indexing. Halting.")
            return

        print(f"No valid existing index found or loading failed. Initializing base model for new index: {INDEX_NAME}")
        rag_model = RAGPretrainedModel.from_pretrained(COLBERT_MODEL_NAME, index_root=str(ragatouille_indexes_root_path))

        print(f"Creating new index named: {INDEX_NAME} within root: {ragatouille_indexes_root_path}.")
        print(f"Indexing {len(document_texts_list)} documents. This may take some time...")
        start_time = time.time()

        actual_created_index_path = rag_model.index(
            collection=document_texts_list,
            document_ids=document_ids_list,
            index_name=INDEX_NAME,
            max_document_length=MAX_SEQ_LENGTH - 30,
            split_documents=True,
            overwrite_index=True
        )
        end_time = time.time()
        print(f"Indexing completed in {end_time - start_time:.2f} seconds.")
        print(f"Index created at: {actual_created_index_path}")

    print("\nPerforming search for evaluation queries...")
    run_for_eval_metrics = collections.defaultdict(dict)

    # Sorted so that the run file lists queries in a reproducible order.
    eval_query_ids_present_in_qrels = sorted(qrels_for_eval_metrics.keys())

    for query_id in tqdm(eval_query_ids_present_in_qrels, desc="Searching queries"):
        query_text = queries.get(query_id)
        if not query_text:
            print(f"WARNING: Query text for ID {query_id} not found. Skipping.")
            continue

        try:
            results = rag_model.search(query=query_text, k=100)
        except Exception as e:
            print(f"ERROR during search for query_id {query_id}: {e}")
            traceback.print_exc()
            continue

        query_key = str(query_id)
        for result_item in results:
            doc_id, score = _extract_doc_id_and_score(result_item)
            if doc_id is None or score is None:
                print(f"WARNING: Could not parse result item for query {query_id}: {result_item}")
                continue

            # Documents are split into passages at indexing time, so the same
            # document id can occur several times in one result list. Keep the
            # best passage score instead of whichever passage came last.
            doc_key = str(doc_id)
            score_value = float(score)
            best_so_far = run_for_eval_metrics[query_key].get(doc_key)
            if best_so_far is None or score_value > best_so_far:
                run_for_eval_metrics[query_key][doc_key] = score_value

    print("\nCalculating IR evaluation metrics...")
    if run_for_eval_metrics and qrels_for_eval_metrics:
        # Only queries that produced at least one parsed result are evaluated.
        qrels_to_evaluate_with = {
            qid: docs for qid, docs in qrels_for_eval_metrics.items()
            if qid in run_for_eval_metrics
        }

        if not qrels_to_evaluate_with:
            print("WARNING: No overlapping queries between run file and qrels after processing. Cannot evaluate.")
        else:
            measures = [
                ir_measures.nDCG@5, ir_measures.nDCG@10, ir_measures.nDCG@20,
                ir_measures.P@5, ir_measures.P@10, ir_measures.P@20,
                ir_measures.Recall@10, ir_measures.Recall@20, ir_measures.Recall@100,
                ir_measures.MRR, ir_measures.MAP
            ]
            eval_results = ir_measures.calc_aggregate(measures, qrels_to_evaluate_with, run_for_eval_metrics)

            metrics_file_path = OUTPUT_DIR / "colbert_evaluation_metrics.txt"
            with open(metrics_file_path, 'w', encoding='utf-8') as f:
                f.write("ColBERT IR EVALUATION METRICS (Abstracts)\n=========================================\n\n")
                for measure_obj, value in eval_results.items():
                    f.write(f"{str(measure_obj)}: {value:.4f}\n")
                    print(f"{str(measure_obj)}: {value:.4f}")
            print(f"Metrics saved to {metrics_file_path}")

            trec_run_file_path = OUTPUT_DIR / f"{TREC_RUN_NAME}.txt"
            generate_trec_run_file(run_for_eval_metrics, trec_run_file_path, TREC_RUN_NAME)
    else:
        print("WARNING: Not enough data for IR metric calculation (run or qrels empty/mismatched).")

    print(f"\nAll ColBERT processing completed! Output directory: {OUTPUT_DIR}")
    if specific_index_path.exists():
        print(f"Index location: {specific_index_path}")
    else:
        print(f"Index may be in RAGatouille's default cache if not found at {specific_index_path}.")

if __name__ == '__main__':
    # Report which device will be used before starting the (GPU-heavy) run.
    if not torch.cuda.is_available():
        print("WARNING: CUDA not available. ColBERT indexing and search will be very slow on CPU.")
        print("Please enable a GPU runtime in Colab (Runtime > Change runtime type > GPU).")
    else:
        print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")

    # The pipeline is started in both cases; without CUDA it is only slower.
    main_colbert_pipeline()

CUDA is available. Using GPU: Tesla T4
Mounted at /content/drive
All required paths verified successfully.
Loading queries and qrels...
Loaded 393 queries.
Loaded qrels for 393 queries for evaluation (scores as int).
Loaded 4262 raw relevance judgments (scores as float).
Scanning 21 document files to find 4238 relevant documents...


Scanning document files:  90%|█████████ | 19/21 [01:30<00:09,  4.74s/it]

Successfully loaded all 4238 required documents for indexing.
Prepared 4238 documents for ColBERT indexing.

Initializing ColBERT model: colbert-ir/colbertv2.0
No valid existing index found or loading failed. Initializing base model for new index: longeval_abstract_final_idx



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


artifact.metadata:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/405 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

  self.scaler = torch.cuda.amp.GradScaler()


Creating new index named: longeval_abstract_final_idx within root: longeval_output_colbert_abstract_final/colbert_indexes.
Indexing 4238 documents. This may take some time...
This is a behaviour change from RAGatouille 0.8.0 onwards.
This works fine for most users and smallish datasets, but can be considerably slower than FAISS and could cause worse results in some situations.
If you're confident with FAISS working on your machine, pass use_faiss=True to revert to the FAISS-using behaviour.
--------------------


[May 14, 19:47:31] #> Creating directory longeval_output_colbert_abstract_final/colbert_indexes/colbert/indexes/longeval_abstract_final_idx 


[May 14, 19:47:32] [0] 		 #> Encoding 4714 passages..


  return torch.cuda.amp.autocast() if self.activated else NullContextManager()


[May 14, 19:48:09] [0] 		 avg_doclen_est = 222.47369384765625 	 len(local_sample) = 4,714
[May 14, 19:48:09] [0] 		 Creating 16,384 partitions.
[May 14, 19:48:09] [0] 		 *Estimated* 1,048,740 embeddings.
[May 14, 19:48:09] [0] 		 #> Saving the indexing plan to longeval_output_colbert_abstract_final/colbert_indexes/colbert/indexes/longeval_abstract_final_idx/plan.json ..
used 20 iterations (2.1495s) to cluster 998741 items into 16384 clusters
[May 14, 19:48:13] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


[May 14, 19:49:31] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


[0.039, 0.04, 0.039, 0.036, 0.037, 0.041, 0.04, 0.036, 0.036, 0.039, 0.038, 0.038, 0.038, 0.041, 0.037, 0.04, 0.035, 0.038, 0.037, 0.038, 0.04, 0.04, 0.038, 0.04, 0.037, 0.037, 0.04, 0.039, 0.037, 0.041, 0.037, 0.043, 0.038, 0.038, 0.039, 0.035, 0.04, 0.039, 0.038, 0.045, 0.039, 0.038, 0.038, 0.039, 0.037, 0.037, 0.04, 0.043, 0.039, 0.037, 0.036, 0.04, 0.039, 0.04, 0.038, 0.039, 0.047, 0.039, 0.045, 0.038, 0.036, 0.042, 0.04, 0.041, 0.04, 0.039, 0.04, 0.039, 0.036, 0.038, 0.04, 0.036, 0.039, 0.041, 0.038, 0.04, 0.04, 0.039, 0.04, 0.042, 0.04, 0.039, 0.04, 0.04, 0.037, 0.038, 0.039, 0.04, 0.036, 0.041, 0.038, 0.042, 0.038, 0.039, 0.039, 0.04, 0.042, 0.038, 0.039, 0.039, 0.041, 0.043, 0.04, 0.038, 0.04, 0.038, 0.038, 0.038, 0.039, 0.037, 0.039, 0.041, 0.041, 0.036, 0.041, 0.037, 0.039, 0.039, 0.038, 0.039, 0.038, 0.038, 0.039, 0.041, 0.036, 0.042, 0.039, 0.037]


0it [00:00, ?it/s]

[May 14, 19:50:48] [0] 		 #> Encoding 4714 passages..


  return torch.cuda.amp.autocast() if self.activated else NullContextManager()
1it [00:40, 40.46s/it]
100%|██████████| 1/1 [00:00<00:00, 115.64it/s]


[May 14, 19:51:29] #> Optimizing IVF to store map from centroids to list of pids..
[May 14, 19:51:29] #> Building the emb2pid mapping..
[May 14, 19:51:29] len(emb2pid) = 1048741


100%|██████████| 16384/16384 [00:00<00:00, 49516.42it/s]

[May 14, 19:51:29] #> Saved optimized IVF to longeval_output_colbert_abstract_final/colbert_indexes/colbert/indexes/longeval_abstract_final_idx/ivf.pid.pt





Done indexing!
Indexing completed in 243.87 seconds.
Index created at: longeval_output_colbert_abstract_final/colbert_indexes/colbert/indexes/longeval_abstract_final_idx

Performing search for evaluation queries...


Searching queries:   0%|          | 0/393 [00:00<?, ?it/s]

Loading searcher for index longeval_abstract_final_idx for the first time... This may take a few seconds
[May 14, 19:51:31] #> Loading codec...
[May 14, 19:51:31] #> Loading IVF...
[May 14, 19:51:31] #> Loading doclens...


  self.scaler = torch.cuda.amp.GradScaler()

100%|██████████| 1/1 [00:00<00:00, 1476.87it/s]

[May 14, 19:51:31] #> Loading codes and residuals...




100%|██████████| 1/1 [00:00<00:00, 11.45it/s]


Searcher loaded!

#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==
#> Input: cuticle, 		 True, 		 None
#> Output IDs: torch.Size([32]), tensor([  101,     1,  3013, 25128,   102,   103,   103,   103,   103,   103,
          103,   103,   103,   103,   103,   103,   103,   103,   103,   103,
          103,   103,   103,   103,   103,   103,   103,   103,   103,   103,
          103,   103], device='cuda:0')
#> Output Mask: torch.Size([32]), tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')



Searching queries: 100%|██████████| 393/393 [00:16<00:00, 23.47it/s]



Calculating IR evaluation metrics...
nDCG@20: 0.6670
R@10: 0.7268
RR: 0.7928
nDCG@5: 0.5499
P@20: 0.3188
R@20: 0.8067
R@100: 0.8706
AP: 0.6458
P@5: 0.6575
P@10: 0.5410
nDCG@10: 0.6363
Metrics saved to longeval_output_colbert_abstract_final/colbert_evaluation_metrics.txt
TREC run file saved to longeval_output_colbert_abstract_final/CLEF-ColBERT-Abstract-Final-v1.txt

All ColBERT processing completed! Output directory: longeval_output_colbert_abstract_final
Index may be in RAGatouille's default cache if not found at longeval_output_colbert_abstract_final/colbert_indexes/longeval_abstract_final_idx.
