# In [1]:
import json
import os

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from sklearn.metrics import (
    accuracy_score,
    auc,
    classification_report,
    confusion_matrix,
    precision_recall_fscore_support,
    roc_curve,
)

# In [2]:
class ModelEvaluator:
    """Run a PyTorch classifier over a DataLoader and collect classification metrics.

    The metrics from the most recent :meth:`evaluate` call are kept in
    ``self.results`` and can be printed, plotted, or saved as JSON.
    """

    def __init__(self, model: torch.nn.Module, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        """Move ``model`` to ``device`` and start with an empty results dict.

        Args:
            model: The classifier to evaluate.
            device: Torch device string. The default is chosen once, when the
                class is defined: ``"cuda"`` if available, otherwise ``"cpu"``.
        """
        self.model = model.to(device)
        self.device = device
        self.results = {}

    def evaluate(self, dataloader, num_labels=2, label_names=None):
        """Evaluate the model on every batch of ``dataloader``.

        Each batch is indexed by the keys ``"input_ids"``, ``"attention_mask"``
        and ``"label"``, whose values are tensors. The predicted class is the
        argmax over dim 1 of the model output.

        Args:
            dataloader: Iterable of batches as described above.
            num_labels: Number of classes. Classes are taken to be the integers
                ``0 .. num_labels - 1``.
            label_names: Optional display names, one per class, in class order.

        Returns:
            The metrics dict, which is also stored in ``self.results``.

        Raises:
            ValueError: If ``dataloader`` yields no samples.
        """
        self.model.eval()
        y_true, y_pred = [], []

        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)

                # NOTE(review): assumes the model's forward accepts the mask as
                # ``attn_mask`` and returns a (batch, num_labels) logits tensor — confirm.
                logits = self.model(input_ids, attn_mask=attention_mask)
                preds = torch.argmax(logits, dim=1)

                # Labels are only consumed on the host (by sklearn), so there is
                # no need to copy them to the device and back.
                y_true.extend(batch["label"].cpu().tolist())
                y_pred.extend(preds.cpu().tolist())

        if not y_true:
            raise ValueError("dataloader yielded no samples; cannot compute metrics")

        self._compute_metrics(y_true, y_pred, num_labels=num_labels, label_names=label_names)
        return self.results

    def _compute_metrics(self, y_true, y_pred, num_labels=2, label_names=None):
        """Compute metrics for ``y_true`` vs ``y_pred`` and store them in ``self.results``.

        Every metric uses the fixed label set ``0 .. num_labels - 1``. This keeps
        the per-class report and the confusion matrix at one row per class, and
        aligned with ``label_names``, even when a class is absent from this split.
        Undefined precision or recall (no predicted or true samples for a class)
        is reported as 0.
        """
        labels = list(range(num_labels))
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average="weighted", labels=labels, zero_division=0
        )

        self.results = {
            # Plain floats keep the dict JSON-friendly.
            "accuracy": float(accuracy_score(y_true, y_pred)),
            "precision": float(precision),
            "recall": float(recall),
            "f1_score": float(f1),
            "report": classification_report(
                y_true,
                y_pred,
                labels=labels,
                target_names=label_names,
                output_dict=True,
                zero_division=0,
            ),
            "confusion_matrix": confusion_matrix(y_true, y_pred, labels=labels).tolist(),
        }

    def plot_auc_roc(self, y_true, y_scores):
        """Plot the ROC curve, with its AUC in the legend, for a binary problem.

        Args:
            y_true: True binary labels, one per sample.
            y_scores: Score for the positive class, one per sample.
        """
        fpr, tpr, _ = roc_curve(y_true, y_scores)
        roc_auc = auc(fpr, tpr)

        plt.figure()
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
        # Diagonal reference line: the curve of a random classifier.
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic')
        plt.legend(loc='lower right')
        plt.grid(True)
        plt.show()

    def save_results(self, filepath="evaluation_metrics/results.json"):
        """Write ``self.results`` to ``filepath`` as indented JSON.

        Missing parent directories are created. A bare filename is written to
        the current working directory.
        """
        directory = os.path.dirname(filepath)
        if directory:  # os.makedirs("") raises, so skip it for bare filenames
            os.makedirs(directory, exist_ok=True)
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(self.results, f, indent=4)
        print(f"Saved evaluation results to {filepath}")

    def print_summary(self):
        """Print the headline metrics. Missing metrics are shown as 0."""
        print("=== Evaluation Summary ===")
        print(f"Accuracy:  {self.results.get('accuracy', 0):.4f}")
        print(f"Precision: {self.results.get('precision', 0):.4f}")
        print(f"Recall:    {self.results.get('recall', 0):.4f}")
        print(f"F1 Score:  {self.results.get('f1_score', 0):.4f}")
        print("==========================")

    def plot_confusion_matrix(self, label_names=None):
        """Draw the stored confusion matrix as an annotated heatmap.

        Args:
            label_names: Optional tick labels, one per class. When omitted,
                seaborn's automatic index labels are used.
        """
        cm = self.results.get("confusion_matrix")
        if cm is None:
            print("Confusion matrix not available.")
            return
        # seaborn's heatmap calls len() on the tick labels, so None would
        # raise TypeError; "auto" is seaborn's own default.
        ticks = "auto" if label_names is None else label_names
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt="d", xticklabels=ticks, yticklabels=ticks, cmap="Blues")
        plt.xlabel("Predicted")
        plt.ylabel("True")
        plt.title("Confusion Matrix")
        plt.show()
