In [1]:
# Mount Google Drive so the project data under /content/drive/MyDrive is reachable from the Colab VM.
import google.colab.drive as drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [2]:
##%% md
## Install Dependencies & Imports
##%%
!pip install --upgrade pip --quiet
!pip install torch --quiet
!pip install transformers==4.49.* --quiet
!pip install datasets --quiet
!pip install scikit-learn --quiet
!pip install ir_measures tqdm --quiet
!pip install ragatouille==0.0.9 --quiet
!pip install rank-bm25 nltk --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m21.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m61.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m148.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m155.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m47.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m45.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m79.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m142.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
from pipeline.task_1 import config
from pipeline.task_1.pipeline import RetrievalPipeline
import os
import json
from pathlib import Path
import collections
from typing import Dict, List, Tuple, Any
import sys
import time
import gzip
import shutil
import torch
import pickle
from tqdm import tqdm
from ragatouille import RAGPretrainedModel
from rank_bm25 import BM25Okapi
import re
import nltk
import ir_measures

# Fetch the NLTK data used by the BM25 tokenizer further down.
# 'punkt_tab' is a regular NLTK data package: recent NLTK releases load the Punkt
# model used by `word_tokenize` from 'punkt_tab' instead of the pickled 'punkt',
# so both are downloaded to cover old and new NLTK versions.
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)
nltk.download('punkt_tab', quiet=True)

# --- Runtime environment ---------------------------------------------------------
USE_GOOGLE_DRIVE = True  # True: read/write under the mounted Google Drive; False: use the CWD

# --- Directory names (turned into full paths further down) -------------------------
DATA_DIR_NAME, RAGATOUILLE_COLBERT_MODEL_CACHE_SUBDIR = (
    "longeval_sci_training_2025_abstract",
    "colbert_model_cache",
)

# --- BM25 first stage ----------------------------------------------------------------
# NOTE(review): BM25_K1, BM25_B and BM25_TOP_K_CANDIDATES are not read by the code shown;
# the query loop takes its BM25 depth from config.BM25_TOP_K instead.
BM25_K1, BM25_B = 1.5, 0.75
BM25_REMOVE_STOPWORDS = BM25_ENABLE_STEMMING = True  # consumed by bm25_tokenizer and NLTK setup
BM25_TOP_K_CANDIDATES = 200

# --- ColBERT re-ranking stage ----------------------------------------------------
COLBERT_MODEL_NAME = "colbert-ir/colbertv2.0"
MAX_SEQ_LENGTH = 512       # NOTE(review): not referenced by the code shown
MAX_AUTHORS_IN_INPUT = 3   # at most this many author names go into each document text
K_RERANK_RETRIEVAL = 100   # upper bound on ColBERT results kept per query

# Register ~/nltk_data as an NLTK search location, creating the folder if it is absent.
nltk_data_path_user = os.path.join(os.path.expanduser("~"), "nltk_data")
os.makedirs(nltk_data_path_user, exist_ok=True)
if nltk_data_path_user not in nltk.data.path:
    nltk.data.path.append(nltk_data_path_user)

def ensure_nltk_resource(resource_name_in_find: str, download_package_name: str):
    """Make sure an NLTK data resource is available, downloading it when missing.

    Args:
        resource_name_in_find: resource path passed to ``nltk.data.find``
            (e.g. ``'tokenizers/punkt'``).
        download_package_name: package id passed to ``nltk.download``
            (e.g. ``'punkt'``).

    Exits the process with status 1 (``sys.exit(1)``) if the resource still cannot
    be found after a download attempt.
    """
    try:
        nltk.data.find(resource_name_in_find)
        return
    except LookupError:
        # nltk.data.find signals a missing resource with LookupError; only that is
        # caught here, so any other error surfaces instead of being masked.
        pass

    try:
        nltk.download(download_package_name, quiet=True)
        # Re-check instead of trusting the download call: this is what confirms the
        # resource is now actually discoverable.
        nltk.data.find(resource_name_in_find)
    except Exception as e_download:
        print(f"ERROR: Failed to download or verify NLTK resource '{download_package_name}' for '{resource_name_in_find}': {e_download}")
        sys.exit(1)

# NLTK objects consumed by bm25_tokenizer: Punkt always; stopwords and the Porter
# stemmer only when the corresponding BM25_* flag is on.
ensure_nltk_resource('tokenizers/punkt', 'punkt')

english_stopwords_global = []
stemmer_global = None

if BM25_REMOVE_STOPWORDS:
    ensure_nltk_resource('corpora/stopwords', 'stopwords')
    from nltk.corpus import stopwords as nltk_stopwords_import
    english_stopwords_global = nltk_stopwords_import.words('english')

if BM25_ENABLE_STEMMING:
    from nltk.stem import PorterStemmer
    stemmer_global = PorterStemmer()

# Resolve every filesystem location once, depending on where the notebook runs.
if USE_GOOGLE_DRIVE:
    # Fail fast outside Colab: the Drive helper is only importable there.
    try:
        from google.colab import drive
    except ImportError:
        print("ERROR: google.colab.drive module not found. Set USE_GOOGLE_DRIVE to False or run in a Colab environment.")
        sys.exit(1)
    GOOGLE_DRIVE_PROJECT_ROOT = Path("/content/drive/MyDrive/AIR_Project/")
    DATA_DIR = GOOGLE_DRIVE_PROJECT_ROOT / DATA_DIR_NAME
    EVALUATION_OUTPUT_ROOT = GOOGLE_DRIVE_PROJECT_ROOT / "reranking_evaluation_output"
else:
    SCRIPT_DIR = Path.cwd()
    DATA_DIR = SCRIPT_DIR / DATA_DIR_NAME
    EVALUATION_OUTPUT_ROOT = Path("./reranking_evaluation_output")

# The ColBERT cache sits under the output root in both environments.
COLBERT_MODEL_CACHE_PATH = EVALUATION_OUTPUT_ROOT / RAGATOUILLE_COLBERT_MODEL_CACHE_SUBDIR

# Expected layout inside DATA_DIR.
QUERIES_FILE, QRELS_FILE, DOCUMENTS_DIR = (
    DATA_DIR / "queries.txt",
    DATA_DIR / "qrels.txt",
    DATA_DIR / "documents",
)


def bm25_tokenizer(text: str) -> List[str]:
    """Turn raw text into BM25 terms.

    Lower-cases the text, deletes every character that is neither a word character
    nor whitespace, tokenizes with ``nltk.word_tokenize``, then optionally drops
    English stopwords (``BM25_REMOVE_STOPWORDS``) and stems with the global stemmer
    (``BM25_ENABLE_STEMMING`` and a non-None ``stemmer_global``).
    """
    cleaned = re.sub(r'[^\w\s]', '', text.lower())
    terms = nltk.word_tokenize(cleaned)
    if BM25_REMOVE_STOPWORDS:
        terms = list(filter(lambda term: term not in english_stopwords_global, terms))
    if BM25_ENABLE_STEMMING and stemmer_global:
        terms = list(map(stemmer_global.stem, terms))
    return terms

def mount_drive_and_verify_paths(data_dir_path: Path, queries_file_path: Path,
                                 qrels_file_path: Path, docs_dir_path: Path) -> bool:
    """(Re)mount Google Drive at /content/drive, then check the dataset layout on it.

    Returns False if mounting raises; otherwise the result of
    ``verify_common_paths`` with the "Google Drive" label.
    """
    try:
        drive.mount('/content/drive', force_remount=True)
        print("INFO: Google Drive mounted successfully.")
    except Exception as mount_error:
        print(f"ERROR: Failed to mount Google Drive: {mount_error}")
        return False
    required_paths = (data_dir_path, queries_file_path, qrels_file_path, docs_dir_path)
    return verify_common_paths("Google Drive", *required_paths)

def verify_local_paths(data_dir_path: Path, queries_file_path: Path,
                       qrels_file_path: Path, docs_dir_path: Path) -> bool:
    """Check the dataset layout on the local filesystem.

    Thin wrapper: announces the check, then delegates to ``verify_common_paths``
    with the "Local" label and returns its result.
    """
    print("INFO: Verifying local paths...")
    required_paths = (data_dir_path, queries_file_path, qrels_file_path, docs_dir_path)
    return verify_common_paths("Local", *required_paths)

def verify_common_paths(env_type: str, data_dir_path: Path, queries_file_path: Path,
                        qrels_file_path: Path, docs_dir_path: Path) -> bool:
    """Check that the dataset directory, queries/qrels files and documents directory exist.

    Every missing entry is reported (not only the first) before returning.

    Args:
        env_type: label prefixed to each message, e.g. ``"Local"`` or ``"Google Drive"``.
        data_dir_path: dataset root; must be a directory.
        queries_file_path: queries file; must be a regular file.
        qrels_file_path: qrels file; must be a regular file.
        docs_dir_path: documents directory; must be a directory.

    Returns:
        True if every path exists with the expected kind, False otherwise.
    """
    # (label, path, must_be_directory): the expected kind is stored explicitly
    # instead of being inferred from the label text.
    checks = [
        (f"{env_type} Dataset directory", data_dir_path, True),
        (f"{env_type} Queries file", queries_file_path, False),
        (f"{env_type} Qrels file", qrels_file_path, False),
        (f"{env_type} Documents directory", docs_dir_path, True),
    ]
    all_exist = True
    for name, path_val, must_be_dir in checks:
        # Drive paths are shown verbatim; local paths are made absolute for clarity.
        display_path = path_val if str(path_val).startswith("/content/drive") else path_val.resolve()
        # is_file() (not exists()) so a directory is not accepted where a file is read later.
        present = path_val.is_dir() if must_be_dir else path_val.is_file()
        if not present:
            # The label already ends in "directory"/"file"; don't repeat the kind.
            print(f"ERROR: {name} not found at: {display_path}")
            all_exist = False

    if all_exist:
        print(f"INFO: All required {env_type.lower()} paths verified successfully.")
    else:
        print(f"ERROR: One or more required {env_type.lower()} paths are missing. Base data directory expected at: {data_dir_path}")
    return all_exist

def load_queries(file_path: Path) -> Dict[str, str]:
    """Load a tab-separated ``query_id<TAB>query_text`` file into ``{query_id: query_text}``.

    Only the first tab separates id from text, so the text may itself contain tabs.
    Blank lines are ignored; other lines without a tab are skipped with a warning.
    Later duplicates of a query id overwrite earlier ones. On any read error, the
    error is printed and the queries parsed so far are returned (``{}`` if the file
    is missing).
    """
    queries = {}
    display_path = file_path.resolve() if not str(file_path).startswith("/content/drive") else file_path
    print(f"INFO: Attempting to load queries from {display_path}...")
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                stripped = line.strip()
                if not stripped:
                    # An empty line carries no query; it is not a malformed entry.
                    continue
                parts = stripped.split('\t', 1)
                if len(parts) == 2:
                    query_id, query_text = parts
                    queries[query_id] = query_text
                else:
                    print(f"WARNING: Skipping malformed line #{i+1} in queries file ({file_path.name}): {stripped}")
        print(f"INFO: Successfully loaded {len(queries)} queries from {file_path.name}.")
    except FileNotFoundError:
        print(f"ERROR: Queries file not found: {display_path}")
    except Exception as e:
        print(f"ERROR: Error loading queries from {display_path}: {e}")
    return queries

def load_qrels_for_ir_measures(file_path: Path) -> Dict[str, Dict[str, int]]:
    """Load a whitespace-separated qrels file into ``{query_id: {doc_id: relevance}}``.

    Each data line must have exactly four fields: query id, an ignored second column,
    doc id and relevance. Relevance is parsed via ``float`` and truncated to ``int``.
    Blank lines are ignored. A line with the wrong field count or a non-numeric /
    non-finite relevance is skipped with a warning; the rest of the file is still loaded.
    On a file-level error, the error is printed and whatever was parsed is returned.

    Returns:
        A ``collections.defaultdict(dict)`` keyed by query id (the format accepted by
        ``ir_measures``).
    """
    qrels_dict = collections.defaultdict(dict)
    print(f"INFO: Attempting to load qrels from {file_path}...")
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue
                parts = stripped.split()
                if len(parts) != 4:
                    print(f"WARNING: Skipping malformed line in qrels file: {stripped}")
                    continue
                query_id, _, doc_id, relevance_score_str = parts
                # Parse per line so one bad value cannot abort loading of the remaining qrels.
                try:
                    relevance = int(float(relevance_score_str))
                except (ValueError, OverflowError):
                    print(f"WARNING: Skipping malformed line in qrels file: {stripped}")
                    continue
                qrels_dict[query_id][doc_id] = relevance
        print(f"INFO: Loaded qrels for {len(qrels_dict)} queries for evaluation (scores as int).")
    except Exception as e:
        print(f"ERROR: Error loading qrels for evaluation: {e}")
    return qrels_dict

def _compose_document_text(doc_data: Dict[str, Any]) -> str:
    """Build the model input ``title [SEP] authors [SEP] abstract`` for one JSON record.

    Only the first MAX_AUTHORS_IN_INPUT entries of ``authors`` are considered, and of
    those only ones with a non-empty ``name``. Empty or whitespace-only parts are
    dropped before joining, so the result may be ``""``.
    """
    title = doc_data.get("title", "")
    abstract = doc_data.get("abstract", "")

    author_names_str = ""
    authors_list = doc_data.get("authors", [])
    if authors_list:
        # assumes each author entry is a dict with a "name" key — TODO confirm against the corpus schema
        author_names = [author.get("name", "") for author in authors_list[:MAX_AUTHORS_IN_INPUT] if author.get("name")]
        author_names_str = "; ".join(author_names)

    doc_parts = [title, author_names_str, abstract]
    return " [SEP] ".join(part for part in doc_parts if part and part.strip()).strip()


def load_and_prepare_documents(docs_dir_path: Path, batch_size_info_log: int = 10000) -> Tuple[Dict[str, str], List[str], List[str]]:
    """Read every document record under *docs_dir_path* and build its text input.

    Reads ``*.jsonl`` files or, only when there are none, ``*.jsonl.gz`` files
    (decompressed on the fly), in sorted path order. Records with no ``id``, invalid
    JSON, or no textual content are skipped with a warning; blank lines are ignored.
    When an id occurs more than once, the first occurrence wins.

    Args:
        docs_dir_path: directory containing the document files.
        batch_size_info_log: print a progress line every this many accepted documents
            (no progress lines if it is not positive).

    Returns:
        ``(doc_id_to_text_map, corpus_texts_for_bm25, corpus_doc_ids_for_bm25)``; the two
        lists are parallel and in load order. All three are empty when the directory or
        its document files are missing.
    """
    doc_id_to_text_map: Dict[str, str] = {}
    corpus_texts_for_bm25: List[str] = []
    corpus_doc_ids_for_bm25: List[str] = []

    display_docs_dir_path = docs_dir_path.resolve() if not str(docs_dir_path).startswith("/content/drive") else docs_dir_path
    if not docs_dir_path.is_dir():
        print(f"ERROR: Documents directory not found: {display_docs_dir_path}")
        return {}, [], []

    # Glob order depends on the filesystem; sorting makes the corpus order and the
    # choice among duplicate ids reproducible across runs.
    jsonl_files = sorted(docs_dir_path.glob('*.jsonl'))
    if not jsonl_files:
        jsonl_files = sorted(docs_dir_path.glob('*.jsonl.gz'))
        if jsonl_files:
            print("INFO: Found .jsonl.gz files, will decompress on the fly.")
        else:
            print(f"WARNING: No .jsonl or .jsonl.gz files found in document directory: {display_docs_dir_path}")
            return {}, [], []

    print(f"INFO: Preparing documents from {len(jsonl_files)} files in {display_docs_dir_path}...")
    total_docs_processed_in_files = 0

    for data_file in tqdm(jsonl_files, desc="Loading document files", unit="file"):
        open_func = gzip.open if data_file.name.endswith(".gz") else open
        try:
            with open_func(data_file, 'rt', encoding='utf-8') as f:
                for line_idx, line in enumerate(f):
                    if not line.strip():
                        continue  # blank line: nothing to parse
                    try:
                        doc_data = json.loads(line)
                        raw_id = doc_data.get("id")
                        # Check the raw value: str(None) == "None" is truthy and would
                        # otherwise let every id-less record through under the id "None".
                        if raw_id is None or not str(raw_id).strip():
                            print(f"WARNING: Document in {data_file.name} line {line_idx+1} has no ID. Skipping.")
                            continue
                        doc_id = str(raw_id)

                        document_text_input = _compose_document_text(doc_data)
                        if not document_text_input:
                            print(f"WARNING: Document ID {doc_id} in {data_file.name} has no content after processing. Skipping.")
                            continue
                        if doc_id in doc_id_to_text_map:
                            continue  # first occurrence wins

                        doc_id_to_text_map[doc_id] = document_text_input
                        corpus_texts_for_bm25.append(document_text_input)
                        corpus_doc_ids_for_bm25.append(doc_id)
                        total_docs_processed_in_files += 1
                        if batch_size_info_log > 0 and total_docs_processed_in_files % batch_size_info_log == 0:
                            print(f"INFO: Loaded and prepared {total_docs_processed_in_files} documents so far...")

                    except json.JSONDecodeError:
                        print(f"WARNING: Skipping malformed JSON line in {data_file.name} (line {line_idx+1})")
                        continue
                    except Exception as e_doc:
                        print(f"WARNING: Error processing a document in {data_file.name} (line {line_idx+1}): {e_doc}")
        except Exception as e_file:
            print(f"ERROR: Error reading or processing file {data_file}: {e_file}")

    if corpus_texts_for_bm25:
        print(f"INFO: Successfully loaded and prepared a total of {len(corpus_texts_for_bm25)} unique documents.")
    else:
        print("WARNING: No documents were loaded. The corpus is empty.")
    return doc_id_to_text_map, corpus_texts_for_bm25, corpus_doc_ids_for_bm25

def _collect_bm25_candidates(bm25_scores, doc_id_to_text_map: Dict[str, str]) -> Tuple[List[str], List[str], int]:
    """Pair each BM25 hit with its prepared document text.

    Returns ``(candidate_ids, candidate_texts, n_missing)``: two parallel lists in
    retriever order, plus the number of hits whose id has no entry in
    *doc_id_to_text_map* (those hits are dropped instead of failing the whole query).
    """
    candidate_ids: List[str] = []
    candidate_texts: List[str] = []
    n_missing = 0
    # assumes the retriever returns a mapping {doc_id: score} — TODO confirm against RetrievalPipeline
    for doc_idx, _score in bm25_scores.items():
        doc_text = doc_id_to_text_map.get(str(doc_idx))  # map keys are str ids
        if doc_text is None:
            n_missing += 1
            continue
        candidate_ids.append(str(doc_idx))
        candidate_texts.append(doc_text)
    return candidate_ids, candidate_texts, n_missing


def _rerank_candidates(colbert_reranker, q_id, q_text: str,
                       candidate_ids: List[str], candidate_texts: List[str]) -> Dict[str, float]:
    """Re-rank one query's candidates with ColBERT and return ``{doc_id: score}``.

    At most K_RERANK_RETRIEVAL results are kept. Returns ``{}`` when there are no
    candidates, when the effective k is not positive, or when ``rerank`` returns None.
    """
    if not candidate_texts:
        return {}
    effective_k_for_rerank = min(K_RERANK_RETRIEVAL, len(candidate_texts))
    if effective_k_for_rerank <= 0:
        return {}

    colbert_reranked_results = colbert_reranker.rerank(
        query=q_text,
        documents=candidate_texts,
        k=effective_k_for_rerank
    )
    if colbert_reranked_results is None:
        print(f"WARNING: colbert_reranker.rerank returned None for QID {q_id} with k={effective_k_for_rerank} and {len(candidate_texts)} candidates. Skipping.")
        return {}

    # RAGatouille rerank result_index is the index in the input 'documents' list
    return {
        str(candidate_ids[res['result_index']]): float(res['score'])
        for res in colbert_reranked_results
    }


def _report_ir_metrics(run_for_eval_metrics: Dict[str, Dict[str, float]],
                       qrels_for_eval_metrics: Dict[str, Dict[str, int]]) -> None:
    """Compute and print aggregate ir_measures metrics for the run against the qrels.

    Qrels are restricted to query ids present in the run, so only processed queries
    are scored.
    """
    print("\nCalculating IR evaluation metrics...")
    if not (run_for_eval_metrics and qrels_for_eval_metrics):
        print("WARNING: Not enough data for IR metric calculation (run or qrels empty/mismatched).")
        return

    qrels_to_evaluate_with = {
        qid: docs for qid, docs in qrels_for_eval_metrics.items()
        if qid in run_for_eval_metrics
    }
    if not qrels_to_evaluate_with:
        print("WARNING: No overlapping queries between run results and qrels after processing. Cannot evaluate.")
        return

    measures = [
        ir_measures.nDCG@5, ir_measures.nDCG@10, ir_measures.nDCG@20,
        ir_measures.P@5, ir_measures.P@10, ir_measures.P@20,
        ir_measures.Recall@10, ir_measures.Recall@20, ir_measures.Recall@100,
        ir_measures.MRR, ir_measures.MAP
    ]
    eval_results = ir_measures.calc_aggregate(measures, qrels_to_evaluate_with, run_for_eval_metrics)

    print("\nBM25 + ColBERT Re-ranking IR EVALUATION METRICS (Abstracts)\n======================================================")
    for measure_obj, value in eval_results.items():
        print(f"{str(measure_obj)}: {value:.4f}")


def main_bm25_colbert_rerank_pipeline():
    """Run BM25 retrieval + ColBERT re-ranking for every query and print IR metrics.

    Stages: verify paths (mounting Drive when USE_GOOGLE_DRIVE) -> load queries, qrels
    and document texts -> BM25 candidates from RetrievalPipeline -> ColBERT re-ranking
    of each query's candidates -> aggregate ir_measures metrics.

    Returns None. Halts early, after printing a CRITICAL/ERROR line, when paths,
    queries, qrels, documents or the ColBERT model cannot be loaded. A failure on an
    individual query is printed with its traceback and that query is left out of the run.
    """
    import traceback

    print(f"INFO: Starting BM25 + ColBERT Re-ranking pipeline. Using Google Drive: {USE_GOOGLE_DRIVE}")
    display_data_dir = DATA_DIR.resolve() if not str(DATA_DIR).startswith("/content/drive") else DATA_DIR
    display_output_dir = EVALUATION_OUTPUT_ROOT.resolve() if not str(EVALUATION_OUTPUT_ROOT).startswith("/content/drive") else EVALUATION_OUTPUT_ROOT
    print(f"INFO: Data is expected in: {display_data_dir}")
    print(f"INFO: Evaluation output (like model cache) will use: {display_output_dir}")

    EVALUATION_OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
    COLBERT_MODEL_CACHE_PATH.mkdir(parents=True, exist_ok=True)
    print(f"INFO: ColBERT model cache path: {COLBERT_MODEL_CACHE_PATH}")

    if USE_GOOGLE_DRIVE:
        if 'drive' not in globals():
            print("CRITICAL: Google Drive module ('drive') not available. Halting pipeline.")
            return
        if not mount_drive_and_verify_paths(DATA_DIR, QUERIES_FILE, QRELS_FILE, DOCUMENTS_DIR):
            print("CRITICAL: Google Drive Path verification or mount failed. Halting pipeline.")
            return
    else:
        if not verify_local_paths(DATA_DIR, QUERIES_FILE, QRELS_FILE, DOCUMENTS_DIR):
            print("CRITICAL: Local Path verification failed. Halting pipeline.")
            return

    print("INFO: Loading queries and qrels...")
    queries = load_queries(QUERIES_FILE)
    qrels_for_eval_metrics = load_qrels_for_ir_measures(QRELS_FILE)

    if not queries:
        print("ERROR: Failed to load queries. Halting.")
        return
    if not qrels_for_eval_metrics:
        print("ERROR: Failed to load qrels. Halting.")
        return

    print("INFO: Loading and preparing all documents...")
    doc_id_to_text_map, corpus_texts_for_bm25, corpus_doc_ids_for_bm25 = load_and_prepare_documents(DOCUMENTS_DIR, batch_size_info_log=50000)
    if not doc_id_to_text_map:
        # Without document texts every query would end up with no candidates.
        print("ERROR: No documents were loaded. Halting.")
        return

    # TODO: first-stage retrieval is delegated to RetrievalPipeline, whose load_data()
    # (per its call-site comment) also loads documents, queries and qrels, so the corpus
    # is read twice. corpus_texts_for_bm25 / corpus_doc_ids_for_bm25 are not used below;
    # only doc_id_to_text_map supplies the texts handed to ColBERT.
    pipeline = RetrievalPipeline()
    pipeline.load_data()            # Load documents, queries, and qrels
    pipeline.setup_preprocessing()  # Initialize the preprocessor
    pipeline.setup_bm25_rank()

    print(f"INFO: Initializing ColBERT model for re-ranking: {COLBERT_MODEL_NAME}")
    colbert_reranker = None
    try:
        colbert_reranker = RAGPretrainedModel.from_pretrained(
            COLBERT_MODEL_NAME,
            index_root=str(COLBERT_MODEL_CACHE_PATH)
        )
        if torch.cuda.is_available():
            device = torch.device("cuda")
            # Reaches into RAGatouille internals; guarded because the attribute layout may differ.
            if hasattr(colbert_reranker, 'model') and hasattr(colbert_reranker.model, 'model') and isinstance(colbert_reranker.model.model, torch.nn.Module):
                colbert_reranker.model.model.to(device)
                print("INFO: ColBERT re-ranker model moved to CUDA.")
            else:
                print("WARNING: Could not move ColBERT model to CUDA. Structure might have changed or model not loaded correctly.")
        else:
            print("INFO: CUDA not available, ColBERT model will run on CPU.")

    except Exception as e:
        print(f"ERROR: Failed to load ColBERT model for re-ranking: {e}")
        traceback.print_exc()
        return
    if colbert_reranker is None:
        print("ERROR: colbert_reranker is None after attempting to load. Halting.")
        return
    print(f"INFO: ColBERT model {COLBERT_MODEL_NAME} initialized for re-ranking.")

    run_for_eval_metrics = collections.defaultdict(dict)
    num_processed_queries = 0
    total_search_time_ms = 0

    print(f"\nINFO: Processing {len(queries)} queries for re-ranking and evaluation...")
    for q_id, q_text in tqdm(queries.items(), desc="Processing queries", unit="query"):
        query_start_time = time.time()
        try:
            bm25_scores = pipeline.bm25_rank_retriever.search(q_text, config.BM25_TOP_K)
            candidate_ids, candidate_texts, n_missing = _collect_bm25_candidates(bm25_scores, doc_id_to_text_map)
            if n_missing:
                print(f"WARNING: {n_missing} BM25 candidate(s) for QID {q_id} have no loaded text and were skipped.")
            run_for_eval_metrics[str(q_id)] = _rerank_candidates(
                colbert_reranker, q_id, q_text, candidate_ids, candidate_texts
            )
        except Exception as e:
            print(f"ERROR: Error processing query ID {q_id} ('{q_text[:50]}...'): {e}")
            traceback.print_exc()
            continue

        num_processed_queries += 1
        # Timed for every processed query (including ones with no candidates) so the
        # average below is taken over the same set of queries as the count.
        total_search_time_ms += (time.time() - query_start_time) * 1000

        if num_processed_queries % 20 == 0 and num_processed_queries < len(queries):
            avg_time_per_query = total_search_time_ms / num_processed_queries
            print(f"INFO: Processed {num_processed_queries}/{len(queries)} queries. Avg time/query: {avg_time_per_query:.2f} ms.")

    avg_time_per_query_final = total_search_time_ms / num_processed_queries if num_processed_queries > 0 else 0
    print(f"INFO: Finished processing queries. Processed {num_processed_queries} queries. Avg time/query: {avg_time_per_query_final:.2f} ms.")

    _report_ir_metrics(run_for_eval_metrics, qrels_for_eval_metrics)

    print(f"\nINFO: Two-stage re-ranking and evaluation pipeline finished.")
    print(f"INFO: Model cache (if used by RAGatouille) is at: {COLBERT_MODEL_CACHE_PATH}")


if __name__ == '__main__':
    # Drive mode requires the Colab drive helper to have been imported already.
    if USE_GOOGLE_DRIVE and 'google.colab.drive' not in sys.modules and 'drive' not in globals():
        print("ERROR: USE_GOOGLE_DRIVE is True, but the 'google.colab.drive' module is not available.")
        print("INFO: Please ensure you are in a Colab environment or set USE_GOOGLE_DRIVE to False.")
        sys.exit(1)

    # Local mode only warns (does not stop) when the dataset folder is not under the CWD.
    if not USE_GOOGLE_DRIVE and not (Path.cwd() / DATA_DIR_NAME).exists():
        print(f"WARNING: The data directory '{DATA_DIR_NAME}' was not found in the current working directory: {Path.cwd()}")
        print(f"WARNING: Please ensure the script is run from the directory containing '{DATA_DIR_NAME}' or adjust DATA_DIR_NAME path if it's elsewhere.")

    if not torch.cuda.is_available():
        print("WARNING: CUDA not available. ColBERT re-ranking will be very slow on CPU.")
        print("Consider enabling a GPU runtime in Colab (Runtime > Change runtime type > GPU).")
    else:
        print(f"INFO: CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")

    pipeline_start_time = time.time()
    main_bm25_colbert_rerank_pipeline()
    pipeline_end_time = time.time()
    print(f"INFO: Total pipeline execution time: {pipeline_end_time - pipeline_start_time:.2f} seconds.")