# Load XLM-RoBERTa Ensemble and Predict

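This notebook loads the best per-fold models saved during cross-validation for a specific training run (set `RUN_ID_TO_LOAD` below). It combines them into an ensemble weighted by each fold's evaluation score (read from the run's `fold_metrics.json`) and uses that ensemble to predict on the test set.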

In [None]:
import sys
import os
import json
import numpy as np
import torch
from tqdm.auto import tqdm

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
from src.utils import load_cleaned_data, load_config, metrics, plot_confusion_matrix
from src.weighted_ensemble_predict import weighted_ensemble_predict

from transformers import (
    XLMRobertaTokenizerFast,
    XLMRobertaForSequenceClassification,
    XLMRobertaConfig,
    set_seed,
)

import matplotlib.pyplot as plt

## Configuration and Setup

In [None]:
# ID of the training run to load: a sub-directory of config["training"]["output_dir"]
# that holds fold_metrics.json and the xlm_roberta_fold_<k> checkpoints.
RUN_ID_TO_LOAD = ""
CONFIG_PATH = "../cfg/xlm_roberta.json"

# Fail fast: with an empty run id, os.path.join below silently resolves to the
# base output directory and the later metrics/checkpoint loads fail with
# confusing file-not-found errors instead of pointing at the real cause.
if not RUN_ID_TO_LOAD:
    raise ValueError(
        "RUN_ID_TO_LOAD is empty; set it to the id of the training run whose fold models should be loaded."
    )

config = load_config(CONFIG_PATH)

# Reproducibility: seed RNGs via transformers.set_seed and request deterministic cuDNN kernels.
set_seed(config["seed"])
torch.backends.cudnn.deterministic = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    # Release cached blocks left over from earlier cells before loading several models.
    torch.cuda.empty_cache()

# The tokenizer comes from the base checkpoint named in the config.
model_name = config["model"]["base_model"]
tokenizer = XLMRobertaTokenizerFast.from_pretrained(model_name)
print(f"Tokenizer '{model_name}' loaded.")

# Base model config, overridden with the classifier settings used for fine-tuning;
# it is passed to from_pretrained when the fold checkpoints are loaded.
xlm_roberta_config = XLMRobertaConfig.from_pretrained(model_name)
xlm_roberta_config.classifier_dropout = config["model"]["classifier_dropout"]
xlm_roberta_config.num_labels = config["model"]["num_labels"]

# Locations of this run's artifacts.
model_base_path = config['training']['output_dir']
run_output_dir = os.path.join(model_base_path, RUN_ID_TO_LOAD)
metrics_path = os.path.join(run_output_dir, "fold_metrics.json")
n_splits = config["cross_validation"]["n_splits"]

## Load Fold Metrics and Models

In [None]:
# Per-fold evaluation results written by the training run.
# NOTE(review): assumed to be a list with one dict per fold, ordered fold 1..n_splits
# (entry i belongs to xlm_roberta_fold_{i+1}) — confirm against the training code.
with open(metrics_path, "r", encoding="utf-8") as f:
    fold_results = json.load(f)

# HF Trainer reports eval metrics under an "eval_" prefix. Only add the prefix when
# the configured name lacks it, as Trainer itself does, so "f1" and "eval_f1" both work.
metric_name = config["training"]["metric_for_best_model"]
metric_key = metric_name if metric_name.startswith("eval_") else f"eval_{metric_name}"
all_fold_scores = [result[metric_key] for result in fold_results]
print(f"Loaded metrics for {len(all_fold_scores)} folds from {metrics_path}")
print(f"Fold F1 scores: {all_fold_scores}")
if len(all_fold_scores) != n_splits:
    print(f"Warning: metrics file has {len(all_fold_scores)} entries but n_splits is {n_splits}.")

# fold_models and fold_f1_scores are built together so that fold_f1_scores[i] is
# always the weight of fold_models[i], even when some folds fail to load.
fold_models = []
fold_f1_scores = []
print(f"\nLoading {n_splits} fold models...")
for fold in range(1, n_splits + 1):
    model_path = os.path.join(run_output_dir, f"xlm_roberta_fold_{fold}")
    if fold > len(all_fold_scores):
        # Without a recorded score this fold has no weight; skip it rather than misalign.
        print(f"No '{metric_key}' recorded for fold {fold}; skipping {model_path}.")
        continue
    print(f"Attempting to load model from: {model_path}")
    try:
        # The explicit config overrides the one saved with the checkpoint.
        model = XLMRobertaForSequenceClassification.from_pretrained(
            model_path,
            config=xlm_roberta_config,
        )
    except OSError as e:
        print(f"Error loading model for fold {fold} from {model_path}: {e}")
        print("Ensemble will be incomplete.")
        continue
    model.to(device)
    model.eval()  # disable dropout for inference
    fold_models.append(model)
    fold_f1_scores.append(all_fold_scores[fold - 1])
    print(f"Loaded model for fold {fold}.")

if not fold_models:
    raise RuntimeError(f"No fold models could be loaded from {run_output_dir}; cannot build the ensemble.")
if len(fold_models) != n_splits:
    print(f"Warning: Expected {n_splits} models, but only loaded {len(fold_models)}. Ensemble results may be affected.")
    # For simplicity now, we'll assume the ensemble function can handle potentially fewer models if needed.

## Load Test Data and Perform Predictions

In [None]:
test_data = load_cleaned_data(config["data"]["test_data_path"])
X_test = test_data["full_text"]
y_test = test_data["label"]
print(f"Loaded test data: {len(X_test)} samples.")

# Tokenizer settings are the same for every sample, so read them from the config once.
tokenizer_kwargs = dict(
    max_length=config["data"]["max_length"],
    truncation=config["tokenizer"]["truncation"],
    padding=config["tokenizer"]["padding"],
    add_special_tokens=config["tokenizer"]["add_special_tokens"],
    # The result is moved with .to(device) below, so this must produce torch tensors ("pt").
    return_tensors=config["tokenizer"]["return_tensors"],
)

y_pred = []        # predicted label (argmax over the weighted ensemble output) per sample
y_pred_proba = []  # ensemble score for class 1 per sample; only meaningful for binary labels

print("\nStarting ensemble predictions on test data...")
# Inference only: disabling autograd avoids keeping a computation graph for every
# fold's forward pass; predicted values are unaffected.
with torch.no_grad():
    for text in tqdm(X_test, desc="Predicting"):
        inputs = tokenizer(text, **tokenizer_kwargs).to(device)
        weighted_probs = weighted_ensemble_predict(inputs, fold_models, fold_f1_scores, device)
        # NOTE(review): weighted_probs presumably has shape (1, num_labels) — confirm.
        y_pred.append(torch.argmax(weighted_probs, dim=1).item())
        y_pred_proba.append(weighted_probs[0][1].item())

if device.type == "cuda":
    torch.cuda.empty_cache()
y_pred = np.array(y_pred)
y_pred_proba = np.array(y_pred_proba)
print("Ensemble predictions finished.")

## Evaluate Ensemble Performance

In [None]:
# Figure title identifying which training run produced these ensemble results.
matrix_title = f'Ensemble Confusion Matrix (Run ID: {RUN_ID_TO_LOAD})'

print("\nEnsemble Model Evaluation on Test Set:")
# Report the metrics computed from the true labels, predicted labels and class-1 scores.
metrics(y_test, y_pred, y_pred_proba, print_metrics=True)

# Draw the confusion matrix, then add a figure-level title placed just above the axes.
plot_confusion_matrix(y_test, y_pred)
plt.suptitle(matrix_title, y=1.02)
plt.show()