## Install Dependencies & Imports

In [None]:
# Install dependencies in a fixed order; the stated intent is to avoid binary
# incompatibilities between numpy and packages compiled against it.
# NOTE(review): pip itself is upgraded only *after* numpy is installed —
# consider moving the pip upgrade first so every install uses the new resolver.
!pip install numpy --quiet
!pip install --upgrade pip --quiet
!pip install torch transformers datasets --quiet
# --force-reinstall reinstalls scikit-learn even if present; --no-deps skips
# its dependencies so the numpy installed above is left untouched.
# Presumably a workaround for a numpy/scikit-learn ABI mismatch — TODO confirm
# it is still needed with current package versions.
!pip install scikit-learn --force-reinstall --no-deps --quiet
!pip install ir_measures tqdm --quiet


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m21.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m59.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m140.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m150.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m48.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m40.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m129.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m138.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import os
import json
from pathlib import Path
import collections
from typing import Dict, List, Tuple, Any
import sys

import torch
import numpy as np
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DefaultDataCollator
)
from google.colab import drive
import ir_measures
from tqdm import tqdm

# --- Configuration Constants ---
# Dataset location on the mounted Google Drive and the files inside it.
COLAB_DRIVE_ROOT_PATH = "/content/drive/MyDrive/AIR_Project/"
DATA_DIR_NAME = "longeval_sci_training_2025_abstract"
DATA_DIR = Path(COLAB_DRIVE_ROOT_PATH, DATA_DIR_NAME)
QUERIES_FILE = DATA_DIR.joinpath("queries.txt")
QRELS_FILE = DATA_DIR.joinpath("qrels.txt")
DOCUMENTS_DIR = DATA_DIR.joinpath("documents")

# Outputs go to a path relative to the working directory (not under Drive).
OUTPUT_DIR = Path("./longeval_output_colab_abstract_fp16")

# Model checkpoint to fine-tune and the tag written into the TREC run file.
PRETRAINED_MODEL_NAME = "allenai/scibert_scivocab_uncased"
TREC_RUN_NAME = "CLEF-Bert-Run-FP16"

# Training / tokenization hyper-parameters.
NUM_TRAIN_EPOCHS = 5
PER_DEVICE_TRAIN_BATCH_SIZE = PER_DEVICE_EVAL_BATCH_SIZE = 128
LEARNING_RATE = 3e-5
MAX_SEQ_LENGTH = 512
EVAL_SPLIT_SIZE = 0.1  # fraction of examples held out for evaluation
# --- End Configuration Constants ---

def mount_drive_and_verify_paths(data_dir_path, queries_file_path, qrels_file_path, docs_dir_path):
    """Mount Google Drive and check that every dataset input is present.

    Directories must be directories and files must be regular files; a bare
    ``exists()`` check would accept e.g. a directory called ``queries.txt``
    and only fail later when it is opened.

    Args:
        data_dir_path: Root directory of the dataset.
        queries_file_path: Queries file.
        qrels_file_path: Qrels file.
        docs_dir_path: Directory holding the document ``*.jsonl`` shards.

    Returns:
        True if every path exists with the expected kind, False otherwise.
        One ERROR line is printed per missing or mismatched path.
    """
    drive.mount('/content/drive', force_remount=True)
    # (label, path, predicate that the path must satisfy)
    checks = [
        ("Dataset directory", data_dir_path, Path.is_dir),
        ("Queries file", queries_file_path, Path.is_file),
        ("Qrels file", qrels_file_path, Path.is_file),
        ("Documents directory", docs_dir_path, Path.is_dir),
    ]
    all_exist = True
    for name, path_val, has_expected_kind in checks:
        if not has_expected_kind(Path(path_val)):
            print(f"ERROR: {name} not found at: {path_val}")
            all_exist = False
    if all_exist:
        print("All required paths verified successfully.")
    return all_exist

def load_queries(file_path):
    """Load a tab-separated queries file into ``{query_id: query_text}``.

    Each non-blank line must be ``<query_id>\\t<query text>``. Only the first
    tab splits, so the text itself may contain tabs. Blank lines are skipped
    silently; other malformed lines are reported and skipped.

    Args:
        file_path: Path of the queries file (read as UTF-8).

    Returns:
        The parsed queries. On an I/O or decoding error a message is printed
        and whatever was parsed before the error is returned (possibly empty).
    """
    queries = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue  # blank separator lines are not malformed input
                parts = stripped.split('\t', 1)
                if len(parts) == 2:
                    query_id, query_text = parts
                    queries[query_id] = query_text
                else:
                    print(f"WARNING: Skipping malformed line in queries file: {stripped}")
        print(f"Loaded {len(queries)} queries.")
    except (OSError, ValueError) as e:  # ValueError covers UnicodeDecodeError
        print(f"ERROR: Error loading queries: {e}")
    return queries

def load_qrels_for_ir_measures(file_path):
    """Load a TREC qrels file as ``{query_id: {doc_id: relevance}}``.

    Expected line format: ``<query_id> <iteration> <doc_id> <relevance>``
    (whitespace separated; the iteration column is ignored). Relevance is
    parsed as a float and truncated to ``int``.

    Blank lines are skipped silently. Lines with the wrong number of columns
    or a non-numeric relevance are reported and skipped individually, so one
    bad line does not abort parsing of the rest of the file.

    Args:
        file_path: Path of the qrels file (read as UTF-8).

    Returns:
        A ``collections.defaultdict(dict)``. On an I/O or decoding error a
        message is printed and whatever was parsed so far is returned.
    """
    qrels_dict = collections.defaultdict(dict)
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue
                parts = stripped.split()
                if len(parts) != 4:
                    print(f"WARNING: Skipping malformed line in qrels file: {stripped}")
                    continue
                query_id, _, doc_id, relevance_score_str = parts
                try:
                    relevance = int(float(relevance_score_str))
                except (ValueError, OverflowError):
                    # Handled per line: letting this reach the outer handler
                    # would silently drop every remaining judgment.
                    print(f"WARNING: Skipping malformed line in qrels file: {stripped}")
                    continue
                qrels_dict[query_id][doc_id] = relevance
        print(f"Loaded qrels for {len(qrels_dict)} queries for evaluation (scores as int).")
    except (OSError, UnicodeDecodeError) as e:
        print(f"ERROR: Error loading qrels for evaluation: {e}")
    return qrels_dict

def load_raw_qrels_data(file_path):
    """Load a TREC qrels file as a list of ``(query_id, doc_id, relevance)``.

    Expected line format: ``<query_id> <iteration> <doc_id> <relevance>``
    (whitespace separated; the iteration column is ignored). Relevance is
    kept as a float, and judgments stay in file order.

    Blank lines are skipped silently. Lines with the wrong number of columns
    or a non-numeric relevance are reported and skipped individually, so one
    bad line does not abort parsing of the rest of the file.

    Args:
        file_path: Path of the qrels file (read as UTF-8).

    Returns:
        The list of judgments. On an I/O or decoding error a message is
        printed and whatever was parsed so far is returned.
    """
    raw_qrels = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue
                parts = stripped.split()
                if len(parts) != 4:
                    print(f"WARNING: Skipping malformed line in qrels file: {stripped}")
                    continue
                query_id, _, doc_id, relevance_score = parts
                try:
                    relevance = float(relevance_score)
                except ValueError:
                    # Handled per line: letting this reach the outer handler
                    # would silently drop every remaining judgment.
                    print(f"WARNING: Skipping malformed line in qrels file: {stripped}")
                    continue
                raw_qrels.append((query_id, doc_id, relevance))
        print(f"Loaded {len(raw_qrels)} raw relevance judgments (scores as float).")
    except (OSError, UnicodeDecodeError) as e:
        print(f"ERROR: Error loading raw qrels: {e}")
    return raw_qrels

def load_documents_for_ids(docs_dir, required_doc_ids):
    """Load ``"title [SEP] abstract"`` text for the requested document ids.

    Scans every ``*.jsonl`` file directly inside ``docs_dir``, in sorted order
    so runs are reproducible. Each line is parsed as one JSON object and
    matched on its ``"id"`` field (compared as a string). Scanning stops as
    soon as every requested id has been found.

    Args:
        docs_dir: Directory containing the JSONL shards (``Path`` or str).
        required_doc_ids: Iterable of document ids; ids are compared as strings.

    Returns:
        ``{doc_id: "title [SEP] abstract"}`` for the ids that were found.
        Lines that are not valid JSON are skipped silently. Other per-line
        and per-file errors are printed and skipped.
    """
    documents = {}
    # A set gives O(1) membership even if the caller passes a list/iterable.
    required = {str(doc_id) for doc_id in required_doc_ids}
    if not required:
        return documents

    jsonl_files = sorted(Path(docs_dir).glob('*.jsonl'))

    for jsonl_file in tqdm(jsonl_files, total=len(jsonl_files), desc="Scanning document files"):
        try:
            with open(jsonl_file, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        doc_data = json.loads(line)
                        doc_id = str(doc_data.get("id"))
                        if doc_id in required and doc_id not in documents:
                            # `or ""` also covers explicit JSON nulls, which
                            # .get(key, "") would return as None -> "None".
                            title = doc_data.get("title") or ""
                            abstract = doc_data.get("abstract") or ""
                            documents[doc_id] = f"{title} [SEP] {abstract}".strip()
                            if len(documents) == len(required):
                                print(f"Successfully loaded all {len(documents)} required documents.")
                                return documents
                    except json.JSONDecodeError:
                        continue
                    except Exception as e_doc:
                        print(f"WARNING: Error processing a document line in {jsonl_file}: {str(e_doc)}")
        except Exception as e_file:
            print(f"WARNING: Error opening or reading file {jsonl_file}: {str(e_file)}")

    if len(documents) < len(required):
        print(f"WARNING: Could only load {len(documents)} out of {len(required)} required documents.")
    else:
        print(f"Loaded {len(documents)} documents.")
    return documents

def generate_trec_run_file(run_data, output_file, run_name):
    """Write ``run_data`` to ``output_file`` in TREC run format.

    Each line is ``<qid> Q0 <doc_id> <rank> <score> <run_name>``, with scores
    printed to six decimals. Within a query, documents are ranked from 1 in
    descending score order; tied scores keep their insertion order (stable
    sort). Queries are written in ``run_data`` iteration order.

    Args:
        run_data: Mapping ``{query_id: {doc_id: score}}``.
        output_file: Destination path; an existing file is overwritten.
        run_name: Run tag written into the last column.
    """
    # Explicit UTF-8 so non-ASCII ids do not depend on the platform locale.
    with open(output_file, 'w', encoding='utf-8') as f_out:
        for q_id, doc_scores in run_data.items():
            ranked = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)
            f_out.writelines(
                f"{q_id} Q0 {doc_id} {rank} {score:.6f} {run_name}\n"
                for rank, (doc_id, score) in enumerate(ranked, 1)
            )
    print(f"TREC run file saved to {output_file}")

def load_initial_data(queries_fp, qrels_fp, docs_dp):
    """Load queries, both qrels representations, and the judged documents.

    Returns:
        A 4-tuple ``(queries, raw_qrels_list, qrels_for_eval_metrics,
        documents_content)``:
          * all four are None if queries or either qrels view came back empty;
          * only ``documents_content`` is None if no document text was loaded
            for the judged document ids.
    """
    queries = load_queries(queries_fp)
    raw_qrels_list = load_raw_qrels_data(qrels_fp)
    qrels_for_eval_metrics = load_qrels_for_ir_measures(qrels_fp)

    if not all((queries, raw_qrels_list, qrels_for_eval_metrics)):
        print("ERROR: Failed to load critical data (queries or qrels). Returning None for all.")
        return None, None, None, None

    # Only documents that carry at least one relevance judgment are needed.
    judged_doc_ids = {doc_id for _query_id, doc_id, _score in raw_qrels_list}
    documents_content = load_documents_for_ids(docs_dp, judged_doc_ids)

    if judged_doc_ids and not documents_content:
        print("ERROR: Failed to load required documents. Returning loaded data, but documents_content is None.")
        documents_content = None

    return queries, raw_qrels_list, qrels_for_eval_metrics, documents_content

def prepare_and_split_dataset(
    raw_qrels,
    queries_map,
    docs_map,
    split_size
):
    """Join judgments with query/document text and split into train/eval.

    Every ``(query_id, doc_id, relevance)`` judgment whose query text and
    document text are both available (and non-empty) becomes an example dict
    with keys ``query_id``, ``doc_id``, ``query_text``, ``document_text`` and
    ``label`` (the relevance value). Other judgments are dropped.

    The split is stratified by ``query_id`` only when scikit-learn can satisfy
    that: every query has at least two examples, and both splits are at least
    as large as the number of queries. Otherwise it falls back to a plain
    random split instead of raising. ``random_state=42`` keeps it reproducible.
    NOTE: stratifying by query places examples of the same query on both
    sides, so evaluation queries are also seen during training.

    Args:
        raw_qrels: Iterable of ``(query_id, doc_id, relevance)`` tuples.
        queries_map: ``{query_id: query_text}``.
        docs_map: ``{doc_id: document_text}``.
        split_size: Fraction (float) or absolute count (int) of examples to
            hold out, as for ``train_test_split(test_size=...)``.

    Returns:
        ``(train_items, eval_items)``. Returns ``([], [])`` if no example could
        be built, and ``(all_items, [])`` if there are too few examples to
        hold any out.
    """
    import math  # local: only needed for the split-size arithmetic below

    dataset_items = []
    for query_id, doc_id, relevance_score_float in raw_qrels:
        query_text = queries_map.get(query_id)
        doc_text = docs_map.get(doc_id)
        if query_text and doc_text:
            dataset_items.append({
                "query_id": query_id, "doc_id": doc_id,
                "query_text": query_text, "document_text": doc_text,
                "label": relevance_score_float
            })

    if not dataset_items:
        print("ERROR: No data items prepared after combining sources. Returning empty lists.")
        return [], []
    print(f"Prepared {len(dataset_items)} examples for model training/evaluation.")

    n_items = len(dataset_items)
    # Same sizing rule as scikit-learn: a float test_size means ceil(frac * n).
    n_eval = math.ceil(split_size * n_items) if isinstance(split_size, float) else int(split_size)
    n_train = n_items - n_eval
    if n_eval < 1 or n_train < 1:
        # train_test_split would raise here (e.g. a single example).
        print(f"WARNING: Only {n_items} examples; not enough to hold out an evaluation split. Using all for training.")
        return dataset_items, []

    query_counts = collections.Counter(item["query_id"] for item in dataset_items)
    n_queries = len(query_counts)
    wants_stratify = 1 < n_queries < n_items
    # Stratified splitting raises ValueError unless every class has >= 2
    # members and each split can hold at least one example per class.
    can_stratify = (
        wants_stratify
        and min(query_counts.values()) >= 2
        and n_eval >= n_queries
        and n_train >= n_queries
    )
    if wants_stratify and not can_stratify:
        print("WARNING: Cannot stratify the split by query_id with this data; using a non-stratified split.")

    stratify_param = [item["query_id"] for item in dataset_items] if can_stratify else None
    train_items, eval_items = train_test_split(
        dataset_items, test_size=split_size, random_state=42, stratify=stratify_param
    )
    print(f"Split data into {len(train_items)} training and {len(eval_items)} evaluation examples.")
    return train_items, eval_items

def initialize_model_and_tokenizer(model_name):
    """Load the tokenizer and a one-output sequence-classification model.

    The model is created with ``num_labels=1``, i.e. one output value per
    input pair.

    Args:
        model_name: Hugging Face model id or local path.

    Returns:
        ``(tokenizer, model)``, or ``(None, None)`` if either fails to load
        (the error is printed).
    """
    try:
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
        loaded_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)
    except Exception as err:
        print(f"ERROR: Error initializing model/tokenizer: {err}")
        return None, None
    return loaded_tokenizer, loaded_model

def tokenize_datasets(
    train_items,
    eval_items,
    tokenizer
):
    """Convert train/eval example dicts into tokenized ``Dataset`` objects.

    Each example's ``(query_text, document_text)`` is encoded as one sequence
    pair, padded and truncated to ``MAX_SEQ_LENGTH``. The id/text columns are
    dropped and ``label`` is renamed to ``labels``.

    Returns:
        ``(tokenized_train, tokenized_eval, eval_items_list)``.
        ``eval_items_list`` is a list copy of ``eval_items`` in the same order
        as ``tokenized_eval``. When there are no evaluation items,
        ``tokenized_eval`` is None and ``eval_items_list`` is ``[]``.
    """
    def encode_pairs(batch):
        return tokenizer(
            batch["query_text"], batch["document_text"],
            padding="max_length", truncation=True, max_length=MAX_SEQ_LENGTH
        )

    def to_tokenized(items):
        raw = Dataset.from_list(items)
        encoded = raw.map(
            encode_pairs, batched=True,
            remove_columns=["query_id", "doc_id", "query_text", "document_text"]
        )
        return encoded.rename_column("label", "labels")

    tokenized_train = to_tokenized(train_items)

    eval_list = list(eval_items)
    if not eval_list:
        print("WARNING: No evaluation data to tokenize for the trainer.")
        return tokenized_train, None, []

    return tokenized_train, to_tokenized(eval_list), eval_list


def train_model(model, train_data, eval_data, output_dp, training_config):
    """Fine-tune ``model`` with the Hugging Face ``Trainer`` and save it.

    Args:
        model: Model to fine-tune.
        train_data: Tokenized training dataset.
        eval_data: Tokenized evaluation dataset, or None/empty to train
            without per-epoch evaluation and best-model selection.
        output_dp: ``pathlib.Path`` output root. Under it, checkpoints go to
            ``training_results/``, logs to ``logs/`` and the final model to
            ``final_model/``.
        training_config: Dict with required keys ``num_train_epochs``,
            ``per_device_train_batch_size``, ``per_device_eval_batch_size``
            and ``learning_rate``, plus these optional keys:
            ``gradient_accumulation_steps`` (default 2) and ``fp16``
            (default: True only when CUDA is available).

    Returns:
        The ``Trainer``. Errors raised during training are printed rather
        than raised, so the caller can still evaluate with whatever weights
        the model holds.
    """
    has_eval = eval_data is not None and len(eval_data) > 0
    args = TrainingArguments(
        output_dir=str(output_dp / "training_results"),
        num_train_epochs=training_config['num_train_epochs'],
        per_device_train_batch_size=training_config['per_device_train_batch_size'],
        per_device_eval_batch_size=training_config['per_device_eval_batch_size'],
        learning_rate=training_config['learning_rate'],
        eval_strategy="epoch" if has_eval else "no",
        save_strategy="epoch",
        logging_dir=str(output_dp / 'logs'),
        logging_steps=50,
        # Best-model tracking needs an eval set; eval/save strategies match.
        load_best_model_at_end=has_eval,
        metric_for_best_model="loss" if has_eval else None,
        greater_is_better=False,
        report_to="none",
        # fp16 mixed precision needs a GPU; TrainingArguments rejects it on
        # CPU-only runtimes, so only enable it when CUDA is present.
        fp16=training_config.get('fp16', torch.cuda.is_available()),
        dataloader_num_workers=2,
        gradient_accumulation_steps=training_config.get('gradient_accumulation_steps', 2),
        optim="adamw_torch"
    )
    trainer = Trainer(
        model=model, args=args, train_dataset=train_data,
        eval_dataset=eval_data, data_collator=DefaultDataCollator()
    )
    print("Starting training...")
    try:
        trainer.train()
        print("Training finished.")
        model_save_path = output_dp / "final_model"
        trainer.save_model(str(model_save_path))
        print(f"Model saved to {model_save_path}")
    except Exception as e:
        print(f"ERROR: Error during training: {e}")
        import traceback
        traceback.print_exc()
        print("ERROR: Attempting to continue with evaluation.")
    return trainer

def evaluate_model_ir(
    trainer_instance,
    eval_dataset_tokenized,
    original_eval_items,
    qrels_map,
    output_dp,
    run_name_prefix
):
    """Score held-out pairs, compute IR metrics and write a TREC run file.

    ``trainer_instance.predict`` scores ``eval_dataset_tokenized``. The
    predictions are matched positionally to ``original_eval_items``, which
    must be in the same order, and grouped into a ``{qid: {doc_id: score}}``
    run.

    Metrics are computed with ``ir_measures`` against ``qrels_map``,
    restricted to the documents that were actually scored. Judged documents
    that are not in ``original_eval_items`` are never ranked here, so keeping
    them would cap recall/MAP regardless of model quality.

    Writes ``evaluation_metrics.txt`` and ``<run_name_prefix>.txt`` into
    ``output_dp``. Errors during prediction or metric computation are
    printed, not raised.
    """
    print(f"DEBUG PRINT: original_eval_items contains {len(original_eval_items) if original_eval_items is not None else 'None'} items.")
    print(f"DEBUG PRINT: eval_dataset_tokenized is {'present and has items' if eval_dataset_tokenized and len(eval_dataset_tokenized) > 0 else 'None or empty'}")

    if not original_eval_items:
        print("DEBUG PRINT: Skipping IR evaluation because original_eval_items is empty or None.")
        return

    print("DEBUG PRINT: Entered main else block for IR evaluation (original_eval_items is not empty).")
    trainer_instance.model.eval()

    if not eval_dataset_tokenized or len(eval_dataset_tokenized) == 0:
        print("DEBUG PRINT ERROR: Tokenized evaluation data for trainer is None or empty. Skipping IR prediction.")
        return

    print("DEBUG PRINT: Proceeding with predictions using trainer.predict().")
    try:
        predictions_output = trainer_instance.predict(eval_dataset_tokenized)
        # Expect one score per evaluated pair; flatten e.g. (N, 1) -> (N,).
        scores = np.asarray(predictions_output.predictions).reshape(-1)
        if len(scores) != len(original_eval_items):
            print(f"ERROR: Got {len(scores)} predictions for {len(original_eval_items)} evaluation items; cannot align them. Skipping IR evaluation.")
            return

        run_for_eval_metrics = collections.defaultdict(dict)
        for item, score in zip(original_eval_items, scores):
            run_for_eval_metrics[item["query_id"]][item["doc_id"]] = float(score)

        # Judge only the documents that were candidates in this run.
        filtered_qrels_for_eval = {}
        for qid, scored_docs in run_for_eval_metrics.items():
            judged = {
                doc_id: rel
                for doc_id, rel in qrels_map.get(qid, {}).items()
                if doc_id in scored_docs
            }
            if judged:
                filtered_qrels_for_eval[qid] = judged

        print(f"DEBUG PRINT: run_for_eval_metrics has {len(run_for_eval_metrics)} queries.")
        print(f"DEBUG PRINT: filtered_qrels_for_eval has {len(filtered_qrels_for_eval)} queries.")

        if run_for_eval_metrics and filtered_qrels_for_eval:
            measures = [
                ir_measures.nDCG@5, ir_measures.nDCG@10, ir_measures.nDCG@20,
                ir_measures.P@5, ir_measures.P@10, ir_measures.P@20,
                ir_measures.Recall@10, ir_measures.Recall@20, ir_measures.Recall@100,
                ir_measures.MRR, ir_measures.MAP
            ]
            print("Calculating IR evaluation metrics with integer qrels...")
            results = ir_measures.calc_aggregate(measures, filtered_qrels_for_eval, run_for_eval_metrics)

            metrics_file_path = output_dp / "evaluation_metrics.txt"
            with open(metrics_file_path, 'w', encoding='utf-8') as f:
                f.write("IR EVALUATION METRICS\n====================\n\n")
                for measure_obj, value in results.items():
                    f.write(f"{str(measure_obj)}: {value:.4f}\n")
                    print(f"{str(measure_obj)}: {value:.4f}")
            print(f"Metrics saved to {metrics_file_path}")

            trec_run_file_path = output_dp / f"{run_name_prefix}.txt"
            generate_trec_run_file(run_for_eval_metrics, trec_run_file_path, run_name_prefix)
        else:
            print("WARNING: DEBUG: Not enough data for IR metric calculation (run or qrels empty/mismatched after filtering).")
    except Exception as e:
        print(f"ERROR: Error during IR prediction or metric calculation: {e}")
        import traceback
        traceback.print_exc()

def main():
    """Run the full pipeline: verify inputs, load, split, tokenize, train, evaluate.

    Each stage's failure ends the run early. Every stage except path
    verification prints an ERROR line when it fails; path verification
    reports its own errors.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    paths_ok = mount_drive_and_verify_paths(DATA_DIR, QUERIES_FILE, QRELS_FILE, DOCUMENTS_DIR)
    if not paths_ok:
        return

    loaded = load_initial_data(QUERIES_FILE, QRELS_FILE, DOCUMENTS_DIR)
    if not all(loaded):
        print("ERROR: Halting due to failure in initial data loading.")
        return
    queries, raw_qrels, qrels_eval, documents = loaded

    train_items, eval_items = prepare_and_split_dataset(raw_qrels, queries, documents, EVAL_SPLIT_SIZE)
    if not train_items:
        print("ERROR: No training items after split. Halting.")
        return

    tokenizer, model = initialize_model_and_tokenizer(PRETRAINED_MODEL_NAME)
    if not (tokenizer and model):
        print("ERROR: Halting due to failure in model/tokenizer initialization.")
        return

    tokenized_train, tokenized_eval, final_eval_items = tokenize_datasets(train_items, eval_items, tokenizer)
    if not tokenized_train:
        print("ERROR: Training data tokenization failed. Halting.")
        return

    training_config_params = dict(
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
        learning_rate=LEARNING_RATE,
    )

    trained_trainer = train_model(model, tokenized_train, tokenized_eval, OUTPUT_DIR, training_config_params)

    # Both the raw items and their tokenized form are needed to evaluate.
    if final_eval_items and tokenized_eval:
        evaluate_model_ir(trained_trainer, tokenized_eval, final_eval_items, qrels_eval, OUTPUT_DIR, TREC_RUN_NAME)
    else:
        print("Skipping final IR evaluation as there are no evaluation items or tokenized evaluation data.")

    print(f"All processing completed! Output directory: {OUTPUT_DIR}")


if __name__ == '__main__':
    main()

Mounted at /content/drive
ERROR: Dataset directory not found at: /content/drive/MyDrive/AIR_Project/longeval_sci_training_2025_fulltext
ERROR: Queries file not found at: /content/drive/MyDrive/AIR_Project/longeval_sci_training_2025_fulltext/queries.txt
ERROR: Qrels file not found at: /content/drive/MyDrive/AIR_Project/longeval_sci_training_2025_fulltext/qrels.txt
ERROR: Documents directory not found at: /content/drive/MyDrive/AIR_Project/longeval_sci_training_2025_fulltext/documents
