# In [None]:
import json
import math
import os
import re

import torch
from sklearn.metrics import ndcg_score
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the trained model and tokenizer.
# Python does not expand shell-style variables inside string literals, so the
# JOB_NAME placeholder is expanded from the environment here. If a job script
# already substituted it textually, expandvars is a harmless no-op.
model_path = os.path.expandvars("/scratch-shared/ir2-less/out/${JOB_NAME}")
if "$" in model_path:
    # expandvars leaves unknown variables untouched; fail loudly instead of
    # letting from_pretrained search for a literal placeholder directory.
    raise RuntimeError(
        f"Unexpanded variable in model path {model_path!r}; "
        "set the JOB_NAME environment variable."
    )
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Load evaluation dataset (JSON Lines: one JSON object per line).
evaluation_data_path = "evaluation_dataset.jsonl"  # Path to the dataset
# Explicit UTF-8 so decoding does not depend on the platform locale; blank
# lines (e.g. extra newlines at EOF) are skipped instead of crashing json.loads.
with open(evaluation_data_path, "r", encoding="utf-8") as file:
    evaluation_data = [json.loads(line) for line in file if line.strip()]

# Helper function to compute rankings
def get_rankings(model, tokenizer, prompt, max_new_tokens=256):
    """Generate a ranking for ``prompt`` with greedy decoding and parse it.

    The expected answer format is ``"[2] > [0] > [4] > [5] > [1] > [3]"``.
    Only the newly generated tokens are decoded, because ``generate`` on a
    decoder-only model returns the prompt tokens followed by the continuation.
    Each ``>``-separated segment contributes one id: its first bracketed
    integer, or failing that its first bare integer. Repeated ids keep their
    first position.

    NOTE(review): the prompt is tokenized as raw text, without
    ``tokenizer.apply_chat_template``; confirm this matches how the model
    was trained.

    Args:
        model: Hugging Face causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        prompt: Ranking prompt text.
        max_new_tokens: Maximum number of tokens to generate, independent of
            the prompt length. The default of 256 is an assumed budget for
            the ranking string; raise it for larger candidate lists.

    Returns:
        Predicted candidate ids, best first. Segments without an integer are
        skipped, so malformed output yields a shorter (possibly empty) list
        rather than raising.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]
    # max_new_tokens (unlike max_length) does not count the prompt, so long
    # prompts still leave room for the answer.
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Drop the echoed prompt so candidate ids mentioned in it are not parsed.
    generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    ranking = []
    seen = set()
    for segment in generated_text.split(">"):
        match = re.search(r"\[(\d+)\]", segment) or re.search(r"(\d+)", segment)
        if match is None:
            continue
        doc_id = int(match.group(1))
        if doc_id not in seen:
            seen.add(doc_id)
            ranking.append(doc_id)
    return ranking

# Evaluation metrics
def _ndcg(true_ids, predicted_ids):
    """Return nDCG of ``predicted_ids`` against the gold order ``true_ids``.

    Graded relevance follows the gold order. With ``n`` gold candidates the
    first gets relevance ``n``, the second ``n - 1``, ..., and the last ``1``.
    Gains are linear with a ``log2(position + 2)`` discount (0-based
    position), the same form ``sklearn.metrics.ndcg_score`` uses.

    Predicted ids that are not gold candidates gain 0 but still occupy a
    position, and repeated predictions count only at their first position.
    The ideal DCG is taken over *all* gold candidates, so omitted candidates
    lower the score.
    """
    n = len(true_ids)
    if n == 0:
        return 0.0
    gain = {doc_id: n - rank for rank, doc_id in enumerate(true_ids)}
    ranked = list(dict.fromkeys(predicted_ids))  # order-preserving dedupe
    dcg = sum(gain.get(doc_id, 0) / math.log2(pos + 2) for pos, doc_id in enumerate(ranked))
    ideal_gains = sorted(gain.values(), reverse=True)
    ideal_dcg = sum(g / math.log2(pos + 2) for pos, g in enumerate(ideal_gains))
    return dcg / ideal_dcg


def _reciprocal_rank(true_ids, predicted_ids):
    """Return 1 / (1-based predicted position of the gold top candidate), or 0 if absent."""
    if not true_ids:
        return 0.0
    try:
        return 1.0 / (predicted_ids.index(true_ids[0]) + 1)
    except ValueError:
        return 0.0


ndcg_scores = []
mrr_scores = []
unparseable = 0  # examples whose model output could not be parsed

# Process each example in the dataset
for example in evaluation_data:
    # NOTE(review): assumes messages[0] is the ranking prompt and messages[1]
    # the gold ranking (i.e. no leading system message) — confirm against the data.
    prompt = example["messages"][0]["content"]
    true_ranking = example["messages"][1]["content"]
    true_ranking_ids = [int(item.strip("[]")) for item in true_ranking.split(" > ")]

    # Get model ranking; malformed output is scored as an empty ranking
    # instead of aborting the whole evaluation run.
    try:
        predicted_ranking_ids = get_rankings(model, tokenizer, prompt)
    except ValueError:
        predicted_ranking_ids = []
        unparseable += 1

    ndcg_scores.append(_ndcg(true_ranking_ids, predicted_ranking_ids))
    mrr_scores.append(_reciprocal_rank(true_ranking_ids, predicted_ranking_ids))

if not ndcg_scores:
    raise ValueError(f"No examples found in {evaluation_data_path!r}; nothing to evaluate.")

# Aggregate metrics
average_ndcg = sum(ndcg_scores) / len(ndcg_scores)
average_mrr = sum(mrr_scores) / len(mrr_scores)

# Print results
print(f"Average nDCG: {average_ndcg:.4f}")
print(f"Average MRR: {average_mrr:.4f}")
if unparseable:
    print(f"Unparseable model outputs (scored as 0): {unparseable}/{len(ndcg_scores)}")
