# In [2]:
"""
DeBERTa v3 Comprehensive Model Testing Script
Tests saved model on test set with extensive metrics and visualizations
"""

import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from transformers import DebertaV2Model, DebertaV2Config
from sklearn.metrics import (
    f1_score, precision_score, recall_score, accuracy_score,
    confusion_matrix, classification_report, roc_auc_score,
    matthews_corrcoef, cohen_kappa_score, balanced_accuracy_score,
    roc_curve, auc, precision_recall_curve, average_precision_score
)
from tqdm.auto import tqdm
import numpy as np
import json
import time
import warnings
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import safetensors.torch
from datetime import datetime

# Suppress every warning category for the rest of the run.
warnings.filterwarnings('ignore')

# ==================== CONFIGURATION ====================
# Google Drive locations; the two data files live directly inside DATA_DIR.
DATA_DIR = '/content/drive/MyDrive/SuperEmotion/'
MODEL_DIR = '/content/drive/MyDrive/BestModelSave/best_model/'
RESULTS_DIR = '/content/drive/MyDrive/ModelTestResults/'
TEST_DATA_PATH = DATA_DIR + 'tokenized_test.pt'
METADATA_PATH = DATA_DIR + 'metadata.json'

# Inference settings.
BATCH_SIZE = 128
USE_MIXED_PRECISION = True
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Start-up banner followed by a short description of the selected device.
print(
    "=" * 90,
    " " * 20 + "üî¨ DeBERTa v3 Comprehensive Model Testing üî¨",
    " " * 25 + "SuperEmotion - 7 Emotions",
    "=" * 90,
    sep="\n",
)
print(f"\nüñ•Ô∏è  Device: {DEVICE}")
if torch.cuda.is_available():
    print(
        f"   GPU: {torch.cuda.get_device_name(0)}\n"
        f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB"
    )

# ==================== MOUNT DRIVE ====================
# Best-effort Google Drive mount: the script keeps going when it is not running
# in Colab or when mounting fails, but Ctrl-C / SystemExit are no longer
# swallowed (the original used a bare ``except:``).
try:
    from google.colab import drive
    if not os.path.exists('/content/drive'):
        print("\nüîó Mounting Google Drive...")
        drive.mount('/content/drive')
        print("‚úÖ Drive mounted!")
    else:
        print("\n‚úÖ Drive already mounted!")
except ImportError:
    # google.colab is only importable inside a Colab runtime.
    print("\n‚ö†Ô∏è  Not in Colab or Drive mounted")
except Exception as e:
    # Mounting itself failed; report why instead of hiding the cause.
    print("\n‚ö†Ô∏è  Not in Colab or Drive mounted")
    print(f"   Drive mount failed: {e}")

# Each run writes into its own timestamped sub-folder of the base results directory.
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
RESULTS_DIR = os.path.join(RESULTS_DIR, 'test_run_' + timestamp)
Path(RESULTS_DIR).mkdir(parents=True, exist_ok=True)
print(f"üìÅ Results will be saved to: {RESULTS_DIR}")

# ==================== MODEL CLASS ====================
class DeBERTaEmotionClassifier(nn.Module):
    """DeBERTa encoder followed by dropout and one linear classification layer.

    The hidden state at sequence position 0 is used as the sentence
    representation and mapped to ``num_labels`` raw logits.
    """

    def __init__(self, config, num_labels):
        super().__init__()
        self.deberta = DebertaV2Model(config)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return the classifier logits for a batch of token ids and their mask."""
        encoded = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of the last hidden state is the [CLS] token.
        first_token = encoded.last_hidden_state[:, 0, :]
        return self.classifier(self.dropout(first_token))

# ==================== LOAD METADATA ====================
def load_metadata():
    """Read the dataset metadata JSON stored at METADATA_PATH.

    Returns:
        A ``(metadata, emotion_classes, num_classes)`` tuple, where the last
        two items are the ``'emotion_classes'`` and ``'num_classes'`` entries
        of the loaded dictionary.

    Raises:
        Any exception hit while reading or parsing is printed and re-raised.
    """
    rule = "=" * 90
    print("\n" + rule)
    print("STEP 1: LOADING METADATA")
    print(rule)

    try:
        print(f"üìÇ Loading from: {METADATA_PATH}")
        with open(METADATA_PATH, 'r') as handle:
            metadata = json.load(handle)

        emotion_classes, num_classes = metadata['emotion_classes'], metadata['num_classes']

        print("‚úÖ Metadata loaded!")
        print(f"   Emotions: {', '.join(emotion_classes)}")
        print(f"   Classes: {num_classes}")
    except Exception as e:
        print(f"‚ùå ERROR loading metadata: {e}")
        raise

    return metadata, emotion_classes, num_classes

# ==================== LOAD MODEL ====================
def load_model(num_classes):
    """Rebuild the classifier from the files in MODEL_DIR and put it in eval mode.

    Expects ``config.json``, ``model.safetensors`` (encoder weights) and
    ``classifier.pt`` (classifier/dropout state dicts) inside MODEL_DIR.
    An optional ``metrics.json`` in the same folder is loaded as well.

    Args:
        num_classes: Output size of the classification layer.

    Returns:
        ``(model, training_metrics)``; ``training_metrics`` is ``None`` when
        no ``metrics.json`` exists.

    Raises:
        FileNotFoundError: If one of the three required files is missing.
        Any other loading error is printed with its traceback and re-raised.
    """
    rule = "=" * 90
    print("\n" + rule)
    print("STEP 2: LOADING MODEL")
    print(rule)

    try:
        # All artefacts must be present before anything is loaded.
        required_files = {
            filename: os.path.join(MODEL_DIR, filename)
            for filename in ('config.json', 'model.safetensors', 'classifier.pt')
        }
        for filename, location in required_files.items():
            if not os.path.exists(location):
                raise FileNotFoundError(f"{filename} not found at {location}")

        print(f"‚úÖ All required files found in: {MODEL_DIR}")

        # Architecture description.
        print("\nüìÇ Loading config...")
        with open(required_files['config.json'], 'r') as config_file:
            config = DebertaV2Config(**json.load(config_file))
        print("‚úÖ Config loaded")
        print(f"   Hidden size: {config.hidden_size}")
        print(f"   Num layers: {config.num_hidden_layers}")
        print(f"   Num heads: {config.num_attention_heads}")

        print("\nü§ñ Initializing model architecture...")
        model = DeBERTaEmotionClassifier(config, num_classes)

        # Encoder weights go straight into the ``deberta`` sub-module.
        print("\nüìÇ Loading DeBERTa weights from safetensors...")
        encoder_weights = safetensors.torch.load_file(required_files['model.safetensors'])
        model.deberta.load_state_dict(encoder_weights, strict=True)
        print("‚úÖ DeBERTa weights loaded successfully")

        # Classification head.
        # NOTE(review): weights_only=False unpickles arbitrary Python objects;
        # only point this at checkpoints from a trusted source.
        print("\nüìÇ Loading classifier head...")
        head_checkpoint = torch.load(
            required_files['classifier.pt'], map_location='cpu', weights_only=False
        )
        model.classifier.load_state_dict(head_checkpoint['classifier_state_dict'])
        model.dropout.load_state_dict(head_checkpoint['dropout_state_dict'])
        print("‚úÖ Classifier head loaded successfully")

        model.to(DEVICE)
        model.eval()

        # Parameter statistics; the size estimate assumes 4 bytes per parameter.
        param_info = [(p.numel(), p.requires_grad) for p in model.parameters()]
        total_params = sum(count for count, _ in param_info)
        trainable_params = sum(count for count, needs_grad in param_info if needs_grad)

        print("\n‚úÖ Model ready!")
        print(f"   Total parameters: {total_params:,}")
        print(f"   Trainable parameters: {trainable_params:,}")
        print(f"   Model size: ~{total_params * 4 / 1024**2:.2f} MB")
        print(f"   Device: {DEVICE}")

        # Validation metrics saved at training time are optional.
        training_metrics = None
        metrics_path = os.path.join(MODEL_DIR, 'metrics.json')
        if os.path.exists(metrics_path):
            with open(metrics_path, 'r') as metrics_file:
                training_metrics = json.load(metrics_file)
            print("\nüìä Training metrics found:")
            for key, caption in (('accuracy', 'Val Accuracy'), ('f1_macro', 'Val F1 Macro')):
                if key in training_metrics:
                    print(f"   {caption}: {training_metrics[key]:.4f}")

        return model, training_metrics

    except Exception as e:
        print(f"\n‚ùå ERROR loading model: {e}")
        import traceback
        traceback.print_exc()
        raise

# ==================== LOAD TEST DATA ====================
def load_test_data():
    """Load the pre-tokenized test split and wrap it in a DataLoader.

    The file at TEST_DATA_PATH must contain a mapping with ``'input_ids'``,
    ``'attention_mask'`` and ``'labels'`` tensors.

    Returns:
        ``(test_loader, test_data)``: a non-shuffling DataLoader with batch
        size BATCH_SIZE, and the raw loaded mapping.

    Raises:
        Any loading error is printed and re-raised.
    """
    print("\n" + "="*90)
    print("STEP 3: LOADING TEST DATA")
    print("="*90)

    try:
        print(f"üìÇ Loading from: {TEST_DATA_PATH}")
        # map_location='cpu' (as load_model already does) keeps the dataset in
        # host memory even if it was saved from GPU tensors: without it the
        # load fails on CPU-only hosts, and CUDA-resident tensors cannot be
        # pinned by the DataLoader below.
        # NOTE(review): weights_only=False unpickles arbitrary Python objects;
        # only load files from a trusted source.
        test_data = torch.load(TEST_DATA_PATH, map_location='cpu', weights_only=False)

        print(f"‚úÖ Test data loaded!")
        print(f"   Samples: {test_data['input_ids'].shape[0]:,}")
        print(f"   Max length: {test_data['input_ids'].shape[1]}")

        # Create dataset
        test_dataset = TensorDataset(
            test_data['input_ids'],
            test_data['attention_mask'],
            test_data['labels']
        )

        # Pinned memory only helps host-to-GPU copies, so request it only
        # when the model actually runs on CUDA.
        test_loader = DataLoader(
            test_dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            pin_memory=(DEVICE.type == 'cuda'),
            num_workers=0
        )

        print(f"‚úÖ DataLoader created!")
        print(f"   Batches: {len(test_loader):,}")
        print(f"   Batch size: {BATCH_SIZE}")

        return test_loader, test_data

    except Exception as e:
        print(f"‚ùå ERROR loading test data: {e}")
        raise

# ==================== COMPREHENSIVE EVALUATION ====================
def _run_inference(model, test_loader):
    """Run the model over every batch and collect predictions.

    Returns ``(preds, labels, probs, mean_loss, skipped_batches)`` where the
    first three are numpy arrays aligned sample-by-sample and ``mean_loss`` is
    the per-sample cross-entropy over the batches that succeeded.
    """
    # Summed (not averaged) loss so the final mean is weighted per sample and
    # stays correct for a short last batch or when batches are skipped.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    # Autocast is only useful on CUDA; torch.autocast replaces the deprecated
    # torch.cuda.amp.autocast entry point.
    amp_enabled = USE_MIXED_PRECISION and DEVICE.type == 'cuda'

    pred_chunks, label_chunks, prob_chunks = [], [], []
    loss_sum = 0.0
    n_samples = 0
    skipped_batches = 0

    progress_bar = tqdm(test_loader, desc="Testing", unit="batch", colour="blue")

    with torch.no_grad():
        for batch in progress_bar:
            try:
                input_ids = batch[0].to(DEVICE)
                attention_mask = batch[1].to(DEVICE)
                labels = batch[2].to(DEVICE)

                with torch.autocast(device_type=DEVICE.type, enabled=amp_enabled):
                    logits = model(input_ids, attention_mask)
                    loss = criterion(logits, labels)
                    probs = torch.softmax(logits, dim=-1)

                batch_loss = loss.item()
                batch_preds = torch.argmax(logits, dim=1).cpu().numpy()
                batch_labels = labels.cpu().numpy()
                # float32 keeps each row summing to 1 within sklearn's tolerance.
                batch_probs = probs.float().cpu().numpy()
            except Exception as e:
                # Best effort: report and skip the batch, but keep count so the
                # caller can tell the metrics cover only part of the test set.
                skipped_batches += 1
                print(f"\n‚ö†Ô∏è  Error during evaluation: {e}")
                continue

            # Commit results only after the whole batch succeeded so the three
            # collections can never get out of step with each other.
            loss_sum += batch_loss
            n_samples += len(batch_labels)
            pred_chunks.append(batch_preds)
            label_chunks.append(batch_labels)
            prob_chunks.append(batch_probs)

            progress_bar.set_postfix({'loss': f'{loss_sum / n_samples:.4f}'})

    if n_samples == 0:
        raise RuntimeError("No test samples were evaluated successfully; cannot compute metrics.")

    return (
        np.concatenate(pred_chunks),
        np.concatenate(label_chunks),
        np.concatenate(prob_chunks),
        loss_sum / n_samples,
        skipped_batches,
    )


def _compute_metrics(all_labels, all_preds, all_probs, emotion_classes):
    """Compute overall and per-class classification metrics as a dictionary."""
    # Fixing the label set keeps every per-class array and the confusion matrix
    # at len(emotion_classes) entries even when a class never shows up in the
    # test labels or predictions; downstream code indexes them by class id.
    class_ids = np.arange(len(emotion_classes))

    metrics = {
        'accuracy': accuracy_score(all_labels, all_preds),
        'balanced_accuracy': balanced_accuracy_score(all_labels, all_preds),
    }

    for name, scorer in (('f1', f1_score), ('precision', precision_score), ('recall', recall_score)):
        for average in ('macro', 'weighted', 'micro'):
            metrics[f'{name}_{average}'] = scorer(
                all_labels, all_preds, average=average, zero_division=0
            )
        metrics[f'{name}_per_class'] = scorer(
            all_labels, all_preds, labels=class_ids, average=None, zero_division=0
        )

    metrics['mcc'] = matthews_corrcoef(all_labels, all_preds)
    metrics['kappa'] = cohen_kappa_score(all_labels, all_preds)

    # AUC-ROC is optional: it cannot be computed when, for example, a class is
    # missing from the test labels. Keep the 0.0 fallback but say why.
    try:
        metrics['auc_roc_macro'] = roc_auc_score(
            all_labels, all_probs, multi_class='ovr', average='macro')
        metrics['auc_roc_weighted'] = roc_auc_score(
            all_labels, all_probs, multi_class='ovr', average='weighted')
    except Exception as e:
        print(f"   WARNING: AUC-ROC could not be computed ({e}); reporting 0.0")
        metrics['auc_roc_macro'] = 0.0
        metrics['auc_roc_weighted'] = 0.0

    # Per-class AUC is attempted separately so that a failure here does not
    # wipe out valid averaged scores.
    try:
        metrics['auc_roc_per_class'] = roc_auc_score(
            all_labels, all_probs, multi_class='ovr', average=None)
    except Exception as e:
        print(f"   WARNING: per-class AUC-ROC could not be computed ({e}); reporting 0.0")
        metrics['auc_roc_per_class'] = np.zeros(len(emotion_classes))

    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    # Rows with no true samples stay all-zero instead of becoming NaN.
    row_totals = cm.sum(axis=1, keepdims=True)
    metrics['confusion_matrix'] = cm
    metrics['confusion_matrix_normalized'] = np.divide(
        cm.astype(float), row_totals,
        out=np.zeros(cm.shape, dtype=float), where=row_totals != 0
    )

    metrics['classification_report'] = classification_report(
        all_labels, all_preds,
        labels=class_ids,
        target_names=emotion_classes,
        zero_division=0,
        output_dict=True
    )
    return metrics


def evaluate_model(model, test_loader, emotion_classes):
    """Evaluate the model on the test loader and compute all metrics.

    Args:
        model: Classifier called as ``model(input_ids, attention_mask)``.
        test_loader: Iterable of ``(input_ids, attention_mask, labels)`` batches.
        emotion_classes: Class names; position in the sequence is the class id.

    Returns:
        Dictionary with the sample-weighted mean ``loss``, overall scores
        (accuracy, F1/precision/recall averages, MCC, kappa, AUC-ROC),
        per-class score arrays, raw and row-normalised confusion matrices,
        the sklearn classification report, and the raw ``predictions``,
        ``labels`` and ``probabilities`` arrays.

    Raises:
        RuntimeError: If not a single batch could be evaluated.
    """
    print("\n" + "="*90)
    print("STEP 4: EVALUATING MODEL")
    print("="*90)

    model.eval()

    print("\nüîç Running inference on test set...")
    all_preds, all_labels, all_probs, avg_loss, skipped_batches = _run_inference(model, test_loader)

    print(f"\n‚úÖ Inference complete!")
    print(f"   Total samples: {len(all_preds):,}")
    print(f"   Average loss: {avg_loss:.4f}")
    if skipped_batches:
        print(f"   WARNING: {skipped_batches} batch(es) failed and are excluded from all metrics")

    print("\nüìä Computing metrics...")
    metrics = _compute_metrics(all_labels, all_preds, all_probs, emotion_classes)
    print("‚úÖ All metrics computed!")

    return {
        'loss': avg_loss,
        'accuracy': metrics['accuracy'],
        'balanced_accuracy': metrics['balanced_accuracy'],
        'f1_macro': metrics['f1_macro'],
        'f1_weighted': metrics['f1_weighted'],
        'f1_micro': metrics['f1_micro'],
        'f1_per_class': metrics['f1_per_class'],
        'precision_macro': metrics['precision_macro'],
        'precision_weighted': metrics['precision_weighted'],
        'precision_micro': metrics['precision_micro'],
        'precision_per_class': metrics['precision_per_class'],
        'recall_macro': metrics['recall_macro'],
        'recall_weighted': metrics['recall_weighted'],
        'recall_micro': metrics['recall_micro'],
        'recall_per_class': metrics['recall_per_class'],
        'mcc': metrics['mcc'],
        'kappa': metrics['kappa'],
        'auc_roc_macro': metrics['auc_roc_macro'],
        'auc_roc_weighted': metrics['auc_roc_weighted'],
        'auc_roc_per_class': metrics['auc_roc_per_class'],
        'confusion_matrix': metrics['confusion_matrix'],
        'confusion_matrix_normalized': metrics['confusion_matrix_normalized'],
        'classification_report': metrics['classification_report'],
        'predictions': all_preds,
        'labels': all_labels,
        'probabilities': all_probs
    }

# ==================== PRINT RESULTS ====================
def print_results(results, emotion_classes):
    """Write the evaluation results to stdout.

    Prints the overall scores, then one table row per entry of
    ``emotion_classes`` (F1, precision, recall, AUC-ROC), then the raw
    confusion matrix.

    Args:
        results: Mapping holding the overall scores, the ``*_per_class``
            sequences and ``'confusion_matrix'``.
        emotion_classes: Class names; position ``i`` selects element ``i`` of
            every per-class sequence.
    """
    rule = "=" * 90
    print("\n" + rule)
    print("TEST RESULTS")
    print(rule)

    # (caption, results key) in display order; captions are padded to 20 columns.
    overall_rows = (
        ("Loss:", 'loss'),
        ("Accuracy:", 'accuracy'),
        ("Balanced Accuracy:", 'balanced_accuracy'),
        ("F1 Macro:", 'f1_macro'),
        ("F1 Weighted:", 'f1_weighted'),
        ("F1 Micro:", 'f1_micro'),
        ("Precision Macro:", 'precision_macro'),
        ("Precision Weighted:", 'precision_weighted'),
        ("Recall Macro:", 'recall_macro'),
        ("Recall Weighted:", 'recall_weighted'),
        ("MCC:", 'mcc'),
        ("Cohen's Kappa:", 'kappa'),
        ("AUC-ROC Macro:", 'auc_roc_macro'),
        ("AUC-ROC Weighted:", 'auc_roc_weighted'),
    )
    print("\nüìä Overall Metrics:")
    for caption, key in overall_rows:
        print(f"   {caption:<20}{results[key]:.4f}")

    # (results key, format spec) for each numeric column of the per-class table.
    table_columns = (
        ('f1_per_class', '>6.4f'),
        ('precision_per_class', '>10.4f'),
        ('recall_per_class', '>8.4f'),
        ('auc_roc_per_class', '>8.4f'),
    )
    print("\nüìä Per-Class Metrics:")
    print(f"{'Emotion':<12} {'F1':>6} {'Precision':>10} {'Recall':>8} {'AUC-ROC':>8}")
    print("-" * 50)
    for idx, emotion in enumerate(emotion_classes):
        cells = [f"{emotion:<12}"]
        cells.extend(format(results[key][idx], spec) for key, spec in table_columns)
        print(" ".join(cells))

    print("\nüìä Confusion Matrix:")
    print(results['confusion_matrix'])

# ==================== SAVE RESULTS ====================
def save_results(results, emotion_classes, training_metrics=None):
    """Write the metrics JSON, the per-class CSV and all plots into RESULTS_DIR.

    Args:
        results: Dictionary with the overall scores, per-class arrays, both
            confusion matrices and the classification report.
        emotion_classes: Class names; position ``i`` selects element ``i`` of
            every per-class array.
        training_metrics: Optional metrics saved at training time; stored
            unchanged under ``'training_metrics'`` in the JSON file.
    """
    rule = "=" * 90
    print("\n" + rule)
    print("STEP 5: SAVING RESULTS")
    print(rule)

    # Scalar scores copied into the JSON as plain Python floats.
    scalar_keys = (
        'loss', 'accuracy', 'balanced_accuracy',
        'f1_macro', 'f1_weighted', 'f1_micro',
        'precision_macro', 'precision_weighted', 'precision_micro',
        'recall_macro', 'recall_weighted', 'recall_micro',
        'mcc', 'kappa', 'auc_roc_macro', 'auc_roc_weighted',
    )
    # (output field, results key) pairs for the per-class JSON section.
    json_fields = (
        ('f1', 'f1_per_class'),
        ('precision', 'precision_per_class'),
        ('recall', 'recall_per_class'),
        ('auc_roc', 'auc_roc_per_class'),
    )

    metrics_dict = {
        'test_metrics': {key: float(results[key]) for key in scalar_keys},
        'per_class_metrics': {
            emotion: {field: float(results[source][idx]) for field, source in json_fields}
            for idx, emotion in enumerate(emotion_classes)
        },
        'confusion_matrix': results['confusion_matrix'].tolist(),
        'confusion_matrix_normalized': results['confusion_matrix_normalized'].tolist(),
        'classification_report': results['classification_report'],
        'training_metrics': training_metrics
    }

    metrics_path = os.path.join(RESULTS_DIR, 'test_metrics.json')
    with open(metrics_path, 'w') as out_file:
        json.dump(metrics_dict, out_file, indent=2)
    print(f"‚úÖ Metrics saved: {metrics_path}")

    # One CSV row per class; support is the row sum of the confusion matrix.
    csv_columns = (
        ('f1_score', 'f1_per_class'),
        ('precision', 'precision_per_class'),
        ('recall', 'recall_per_class'),
        ('auc_roc', 'auc_roc_per_class'),
    )
    rows = []
    for idx, emotion in enumerate(emotion_classes):
        row = {'emotion': emotion}
        row.update((column, results[source][idx]) for column, source in csv_columns)
        row['support'] = results['confusion_matrix'][idx].sum()
        rows.append(row)

    csv_path = os.path.join(RESULTS_DIR, 'per_class_metrics.csv')
    pd.DataFrame(rows).to_csv(csv_path, index=False)
    print(f"‚úÖ CSV saved: {csv_path}")

    print("\nüìä Creating visualizations...")
    create_visualizations(results, emotion_classes, RESULTS_DIR)
    print("‚úÖ All visualizations saved!")

# ==================== CREATE VISUALIZATIONS ====================
def _class_subplot_grid(n_classes, ncols=4):
    """Create a subplot grid with one axis per class and blank out the spares.

    Returns ``(fig, axes)`` where ``axes`` is a flat array with at least
    ``n_classes`` entries.
    """
    # Size the grid from the class count (the original hard-coded 2x4, which
    # overflowed above 8 classes); 5-8 classes still give a 20x10 inch 2x4 grid.
    nrows = max(1, int(np.ceil(n_classes / ncols)))
    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows), squeeze=False)
    axes = axes.ravel()
    # Hide every unused axis, not just the last one.
    for spare in axes[n_classes:]:
        spare.axis('off')
    return fig, axes


def create_visualizations(results, emotion_classes, save_dir):
    """Render the evaluation plots and save them as PNG files in ``save_dir``.

    Files written: ``confusion_matrices.png``, ``per_class_metrics.png``,
    ``overall_metrics.png``, ``roc_curves.png``,
    ``precision_recall_curves.png``, ``error_analysis.png`` and
    ``class_accuracies.png``.

    Args:
        results: Dictionary with the overall scores, per-class arrays, both
            confusion matrices and the raw ``'labels'`` / ``'probabilities'``
            arrays.
        emotion_classes: Class names; position in the sequence is the class id.
        save_dir: Existing directory that receives the PNG files.
    """
    n_classes = len(emotion_classes)

    # Figure 1: Confusion Matrices
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # Raw confusion matrix
    sns.heatmap(results['confusion_matrix'], annot=True, fmt='d', cmap='Blues',
                xticklabels=emotion_classes, yticklabels=emotion_classes, ax=axes[0])
    axes[0].set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')
    axes[0].set_ylabel('True Label', fontsize=12)
    axes[0].set_xlabel('Predicted Label', fontsize=12)

    # Normalized confusion matrix
    sns.heatmap(results['confusion_matrix_normalized'], annot=True, fmt='.2%', cmap='Blues',
                xticklabels=emotion_classes, yticklabels=emotion_classes, ax=axes[1])
    axes[1].set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')
    axes[1].set_ylabel('True Label', fontsize=12)
    axes[1].set_xlabel('Predicted Label', fontsize=12)

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'confusion_matrices.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # Figure 2: Per-class metrics (four bars per class)
    fig, ax = plt.subplots(figsize=(12, 6))
    x = np.arange(n_classes)
    width = 0.2

    ax.bar(x - 1.5*width, results['f1_per_class'], width, label='F1 Score', alpha=0.8)
    ax.bar(x - 0.5*width, results['precision_per_class'], width, label='Precision', alpha=0.8)
    ax.bar(x + 0.5*width, results['recall_per_class'], width, label='Recall', alpha=0.8)
    ax.bar(x + 1.5*width, results['auc_roc_per_class'], width, label='AUC-ROC', alpha=0.8)

    ax.set_xlabel('Emotion', fontsize=12, fontweight='bold')
    ax.set_ylabel('Score', fontsize=12, fontweight='bold')
    ax.set_title('Per-Class Metrics Comparison', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(emotion_classes, rotation=45, ha='right')
    ax.legend(loc='lower right')
    ax.grid(axis='y', alpha=0.3)
    ax.set_ylim([0, 1.05])

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'per_class_metrics.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # Figure 3: Overall metrics
    fig, ax = plt.subplots(figsize=(10, 6))
    metrics = ['Accuracy', 'Balanced\nAccuracy', 'F1\nMacro', 'F1\nWeighted',
               'Precision\nMacro', 'Recall\nMacro', 'MCC', 'Kappa', 'AUC-ROC\nMacro']
    values = [results['accuracy'], results['balanced_accuracy'], results['f1_macro'],
              results['f1_weighted'], results['precision_macro'], results['recall_macro'],
              results['mcc'], results['kappa'], results['auc_roc_macro']]

    bars = ax.bar(metrics, values, color='skyblue', alpha=0.8, edgecolor='navy')
    ax.set_ylabel('Score', fontsize=12, fontweight='bold')
    ax.set_title('Overall Model Performance', fontsize=14, fontweight='bold')
    # MCC and kappa can be negative; extend the axis downward in that case so
    # those bars are not clipped. Unchanged ([0, 1.05]) when all values are >= 0.
    lower_limit = min(0.0, float(np.nanmin(values)) - 0.05)
    ax.set_ylim([lower_limit, 1.05])
    ax.grid(axis='y', alpha=0.3)

    # Add value labels
    for bar, value in zip(bars, values):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{value:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')

    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'overall_metrics.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # One-hot encoding of the true labels, shared by the one-vs-rest curves below.
    labels_bin = np.eye(n_classes)[results['labels']]

    # Figure 4: ROC Curves (one vs rest for each class)
    fig, axes = _class_subplot_grid(n_classes)

    for i, emotion in enumerate(emotion_classes):
        fpr, tpr, _ = roc_curve(labels_bin[:, i], results['probabilities'][:, i])
        roc_auc = auc(fpr, tpr)

        axes[i].plot(fpr, tpr, color='darkorange', lw=2,
                     label=f'ROC curve (AUC = {roc_auc:.3f})')
        axes[i].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
        axes[i].set_xlim([0.0, 1.0])
        axes[i].set_ylim([0.0, 1.05])
        axes[i].set_xlabel('False Positive Rate')
        axes[i].set_ylabel('True Positive Rate')
        axes[i].set_title(f'ROC - {emotion}')
        axes[i].legend(loc="lower right")
        axes[i].grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'roc_curves.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # Figure 5: Precision-Recall Curves
    fig, axes = _class_subplot_grid(n_classes)

    for i, emotion in enumerate(emotion_classes):
        precision, recall, _ = precision_recall_curve(labels_bin[:, i], results['probabilities'][:, i])
        avg_precision = average_precision_score(labels_bin[:, i], results['probabilities'][:, i])

        axes[i].plot(recall, precision, color='darkblue', lw=2,
                     label=f'AP = {avg_precision:.3f}')
        axes[i].set_xlim([0.0, 1.0])
        axes[i].set_ylim([0.0, 1.05])
        axes[i].set_xlabel('Recall')
        axes[i].set_ylabel('Precision')
        axes[i].set_title(f'Precision-Recall - {emotion}')
        axes[i].legend(loc="lower left")
        axes[i].grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'precision_recall_curves.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # Figure 6: Error Analysis - Misclassification patterns
    fig, ax = plt.subplots(figsize=(12, 10))

    # Work on a copy so the caller's confusion matrix keeps its diagonal.
    error_matrix = results['confusion_matrix'].copy()
    np.fill_diagonal(error_matrix, 0)  # Remove correct predictions

    sns.heatmap(error_matrix, annot=True, fmt='d', cmap='Reds',
                xticklabels=emotion_classes, yticklabels=emotion_classes, ax=ax)
    ax.set_title('Misclassification Pattern (Errors Only)', fontsize=14, fontweight='bold')
    ax.set_ylabel('True Label', fontsize=12)
    ax.set_xlabel('Predicted Label', fontsize=12)

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'error_analysis.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # Figure 7: Class-wise accuracy (diagonal of the row-normalized matrix)
    fig, ax = plt.subplots(figsize=(10, 6))

    class_accuracies = np.diag(results['confusion_matrix_normalized'])
    colors = ['green' if acc > 0.9 else 'orange' if acc > 0.8 else 'red' for acc in class_accuracies]

    bars = ax.barh(emotion_classes, class_accuracies, color=colors, alpha=0.7)
    ax.set_xlabel('Accuracy', fontsize=12, fontweight='bold')
    ax.set_title('Per-Class Accuracy', fontsize=14, fontweight='bold')
    ax.set_xlim([0, 1.0])
    ax.grid(axis='x', alpha=0.3)

    # Add value labels
    for bar, acc in zip(bars, class_accuracies):
        ax.text(acc + 0.01, bar.get_y() + bar.get_height()/2,
                f'{acc:.3f}', va='center', fontsize=10, fontweight='bold')

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'class_accuracies.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # Figure 8: Comparison with training metrics (if available)
    # This will be handled in the main function if training metrics exist

# ==================== CREATE COMPARISON PLOT ====================
def create_train_test_comparison(test_results, training_metrics, emotion_classes, save_dir):
    """Plot validation-time metrics against test metrics in a 2x2 figure.

    Panels: overall scores, per-class F1, per-class F1 difference
    (test minus validation) and a text summary. The figure is saved as
    ``train_test_comparison.png`` in ``save_dir``.

    Does nothing and returns ``None`` when ``training_metrics`` is ``None``.
    Metrics missing from ``training_metrics`` are plotted as 0.
    """
    if training_metrics is None:
        return

    print("\nüìä Creating train/test comparison...")

    def _draw_val_vs_test(panel, tick_labels, val_series, test_series, ylabel, title):
        """Draw side-by-side validation/test bars on ``panel``."""
        positions = np.arange(len(tick_labels))
        bar_width = 0.35
        panel.bar(positions - bar_width/2, val_series, bar_width, label='Validation', alpha=0.8)
        panel.bar(positions + bar_width/2, test_series, bar_width, label='Test', alpha=0.8)
        panel.set_ylabel(ylabel, fontsize=12, fontweight='bold')
        panel.set_title(title, fontsize=14, fontweight='bold')
        panel.set_xticks(positions)
        panel.set_xticklabels(tick_labels, rotation=45, ha='right')
        panel.legend()
        panel.grid(axis='y', alpha=0.3)
        panel.set_ylim([0, 1.05])

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    (overall_panel, f1_panel), (diff_panel, summary_panel) = axes

    # Top-left: overall scores.
    summary_keys = ('accuracy', 'f1_macro', 'f1_weighted', 'precision_macro', 'recall_macro')
    val_values = [training_metrics.get(key, 0) for key in summary_keys]
    test_values = [test_results[key] for key in summary_keys]
    _draw_val_vs_test(
        overall_panel,
        ['Accuracy', 'F1 Macro', 'F1 Weighted', 'Precision\nMacro', 'Recall\nMacro'],
        val_values, test_values,
        'Score', 'Validation vs Test Performance',
    )

    # Top-right: per-class F1.
    val_f1 = training_metrics.get('f1_per_class', [0] * len(emotion_classes))
    test_f1 = test_results['f1_per_class']
    _draw_val_vs_test(
        f1_panel, emotion_classes, val_f1, test_f1,
        'F1 Score', 'Per-Class F1: Validation vs Test',
    )

    # Bottom-left: signed per-class change, green for gains and red for losses.
    diff = np.array(test_f1) - np.array(val_f1)
    bar_colors = ['red' if delta < 0 else 'green' for delta in diff]

    bars = diff_panel.bar(emotion_classes, diff, color=bar_colors, alpha=0.7)
    diff_panel.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
    diff_panel.set_ylabel('F1 Difference (Test - Val)', fontsize=12, fontweight='bold')
    diff_panel.set_title('Performance Change: Validation to Test', fontsize=14, fontweight='bold')
    diff_panel.grid(axis='y', alpha=0.3)
    plt.setp(diff_panel.xaxis.get_majorticklabels(), rotation=45, ha='right')

    for bar, value in zip(bars, diff):
        height = bar.get_height()
        # Labels sit above positive bars and below the others.
        anchor = 'bottom' if height > 0 else 'top'
        diff_panel.text(bar.get_x() + bar.get_width()/2., height,
                        f'{value:+.3f}', ha='center',
                        va=anchor, fontsize=9, fontweight='bold')

    # Bottom-right: text-only summary.
    summary_panel.axis('off')

    summary_text = f"""
    VALIDATION vs TEST SUMMARY
    {'='*40}

    Overall Metrics:
    ‚Ä¢ Accuracy:     Val={val_values[0]:.4f}  Test={test_values[0]:.4f}  Œî={test_values[0]-val_values[0]:+.4f}
    ‚Ä¢ F1 Macro:     Val={val_values[1]:.4f}  Test={test_values[1]:.4f}  Œî={test_values[1]-val_values[1]:+.4f}
    ‚Ä¢ F1 Weighted:  Val={val_values[2]:.4f}  Test={test_values[2]:.4f}  Œî={test_values[2]-val_values[2]:+.4f}

    Best Performing Classes (Test):
    """

    # Three highest test F1 scores, best first.
    best_lines = [
        f"    ‚Ä¢ {emotion_classes[idx]}: {test_f1[idx]:.4f}\n"
        for idx in np.argsort(test_f1)[-3:][::-1]
    ]
    # Up to three largest gains; classes that did not improve are left out.
    improved_lines = [
        f"    ‚Ä¢ {emotion_classes[idx]}: {diff[idx]:+.4f}\n"
        for idx in np.argsort(diff)[-3:][::-1]
        if diff[idx] > 0
    ]
    summary_text = (
        summary_text
        + "".join(best_lines)
        + "\n    Most Improved Classes (Val‚ÜíTest):\n"
        + "".join(improved_lines)
    )

    summary_panel.text(0.1, 0.9, summary_text, fontsize=11, family='monospace',
                       verticalalignment='top', transform=summary_panel.transAxes)

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'train_test_comparison.png'), dpi=300, bbox_inches='tight')
    plt.close()

    print("‚úÖ Comparison plot saved!")

# ==================== CREATE SUMMARY REPORT ====================
_REPORT_RULE = "=" * 90


def _section_banner(title, leading_blank=True):
    """Return the ruled three-line banner that opens a report section."""
    prefix = "\n" if leading_blank else ""
    return [prefix + _REPORT_RULE + "\n", title + "\n", _REPORT_RULE + "\n\n"]


def _header_lines(results, emotion_classes):
    """Return the title block: report name, timestamp, model and dataset size."""
    return [
        _REPORT_RULE + "\n",
        " " * 20 + "DeBERTa v3 MODEL TEST REPORT\n",
        " " * 25 + "SuperEmotion Dataset\n",
        _REPORT_RULE + "\n\n",
        f"Test Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n",
        "Model: DeBERTa v3 Base + Classification Head\n",
        # Derived from the class list so the header cannot disagree with it.
        f"Dataset: {len(emotion_classes)} Emotion Classes\n",
        f"Test Samples: {len(results['labels']):,}\n\n",
    ]


def _overall_lines(results):
    """Return the aggregate-metric section (loss, accuracy, F1, P/R, MCC, ...)."""
    lines = _section_banner("OVERALL PERFORMANCE", leading_blank=False)
    lines += [
        f"Loss:                 {results['loss']:.6f}\n",
        f"Accuracy:             {results['accuracy']:.4f} ({results['accuracy']*100:.2f}%)\n",
        f"Balanced Accuracy:    {results['balanced_accuracy']:.4f} ({results['balanced_accuracy']*100:.2f}%)\n\n",
        "F1 Scores:\n",
        f"  ‚Ä¢ Macro:            {results['f1_macro']:.4f}\n",
        f"  ‚Ä¢ Weighted:         {results['f1_weighted']:.4f}\n",
        f"  ‚Ä¢ Micro:            {results['f1_micro']:.4f}\n\n",
        "Precision:\n",
        f"  ‚Ä¢ Macro:            {results['precision_macro']:.4f}\n",
        f"  ‚Ä¢ Weighted:         {results['precision_weighted']:.4f}\n",
        f"  ‚Ä¢ Micro:            {results['precision_micro']:.4f}\n\n",
        "Recall:\n",
        f"  ‚Ä¢ Macro:            {results['recall_macro']:.4f}\n",
        f"  ‚Ä¢ Weighted:         {results['recall_weighted']:.4f}\n",
        f"  ‚Ä¢ Micro:            {results['recall_micro']:.4f}\n\n",
        "Advanced Metrics:\n",
        f"  ‚Ä¢ Matthews Correlation Coefficient: {results['mcc']:.4f}\n",
        f"  ‚Ä¢ Cohen's Kappa:                    {results['kappa']:.4f}\n",
        f"  ‚Ä¢ AUC-ROC (Macro):                  {results['auc_roc_macro']:.4f}\n",
        f"  ‚Ä¢ AUC-ROC (Weighted):               {results['auc_roc_weighted']:.4f}\n\n",
    ]
    return lines


def _per_class_lines(results, emotion_classes, confusion):
    """Return the per-class table; support is the confusion-matrix row sum."""
    lines = _section_banner("PER-CLASS PERFORMANCE", leading_blank=False)
    lines.append(
        f"{'Emotion':<12} {'F1':>8} {'Precision':>10} {'Recall':>8} {'AUC-ROC':>8} {'Support':>8}\n"
    )
    lines.append("-" * 70 + "\n")
    for i, emotion in enumerate(emotion_classes):
        support = confusion[i].sum()
        lines.append(
            f"{emotion:<12} "
            f"{results['f1_per_class'][i]:>8.4f} "
            f"{results['precision_per_class'][i]:>10.4f} "
            f"{results['recall_per_class'][i]:>8.4f} "
            f"{results['auc_roc_per_class'][i]:>8.4f} "
            f"{support:>8}\n"
        )
    return lines


def _ranking_lines(results, emotion_classes):
    """Return the top-3 / bottom-3 classes ranked by per-class F1."""
    f1_scores = results['f1_per_class']
    ascending = np.argsort(f1_scores)

    lines = _section_banner("BEST & WORST PERFORMING CLASSES")
    lines.append("Top 3 Classes by F1 Score:\n")
    for rank, idx in enumerate(ascending[-3:][::-1], 1):
        lines.append(f"  {rank}. {emotion_classes[idx]:<12} F1={f1_scores[idx]:.4f}\n")

    lines.append("\nBottom 3 Classes by F1 Score:\n")
    for rank, idx in enumerate(ascending[:3], 1):
        lines.append(f"  {rank}. {emotion_classes[idx]:<12} F1={f1_scores[idx]:.4f}\n")
    return lines


def _matrix_table(emotion_classes, matrix, format_cell):
    """Return one labelled square table; ``format_cell`` renders each entry."""
    size = len(emotion_classes)
    header = f"{'':>12} " + "".join(f"{emotion[:8]:>8} " for emotion in emotion_classes)
    lines = [header + "\n"]
    for i, emotion in enumerate(emotion_classes):
        cells = "".join(format_cell(matrix[i][j]) for j in range(size))
        lines.append(f"{emotion[:12]:>12} " + cells + "\n")
    return lines


def _confusion_lines(emotion_classes, confusion, confusion_normalized):
    """Return the raw-count and percentage confusion-matrix tables."""
    lines = _section_banner("CONFUSION MATRIX")
    lines.append("Raw Counts:\n")
    lines += _matrix_table(emotion_classes, confusion, lambda v: f"{v:>8} ")
    lines.append("\nNormalized (%):\n")
    lines += _matrix_table(
        emotion_classes, confusion_normalized, lambda v: f"{v*100:>7.1f}% "
    )
    return lines


def _comparison_lines(results, emotion_classes, training_metrics):
    """Return the validation-vs-test section built from ``training_metrics``."""
    val_acc = training_metrics.get('accuracy', 0)
    val_f1 = training_metrics.get('f1_macro', 0)

    lines = _section_banner("VALIDATION vs TEST COMPARISON")
    lines.append(
        f"Accuracy:     Val={val_acc:.4f}  Test={results['accuracy']:.4f}  "
        f"Œî={results['accuracy']-val_acc:+.4f}\n"
    )
    lines.append(
        f"F1 Macro:     Val={val_f1:.4f}  Test={results['f1_macro']:.4f}  "
        f"Œî={results['f1_macro']-val_f1:+.4f}\n\n"
    )

    if 'f1_per_class' in training_metrics:
        val_f1_per_class = training_metrics['f1_per_class']
        lines.append("Per-Class F1 Changes:\n")
        for i, emotion in enumerate(emotion_classes):
            diff = results['f1_per_class'][i] - val_f1_per_class[i]
            symbol = "‚Üë" if diff > 0 else "‚Üì" if diff < 0 else "="
            lines.append(
                f"  {emotion:<12} Val={val_f1_per_class[i]:.4f}  "
                f"Test={results['f1_per_class'][i]:.4f}  {symbol} {abs(diff):.4f}\n"
            )
    return lines


def _error_lines(emotion_classes, confusion):
    """Return the five largest off-diagonal confusion cells with row percentages."""
    size = len(emotion_classes)
    # Work on a copy: zeroing the diagonal must not alter the caller's matrix.
    error_matrix = confusion.copy()
    np.fill_diagonal(error_matrix, 0)

    flat_errors = [
        (error_matrix[i][j], i, j)
        for i in range(size)
        for j in range(size)
        if i != j and error_matrix[i][j] > 0
    ]
    flat_errors.sort(reverse=True)

    lines = _section_banner("ERROR ANALYSIS")
    lines.append("Top 5 Most Common Misclassifications:\n")
    for rank, (count, true_idx, pred_idx) in enumerate(flat_errors[:5], 1):
        # count > 0 guarantees the row total is non-zero.
        total = confusion[true_idx].sum()
        pct = (count / total) * 100
        lines.append(
            f"  {rank}. {emotion_classes[true_idx]} ‚Üí {emotion_classes[pred_idx]}: "
            f"{count} errors ({pct:.1f}%)\n"
        )
    return lines


def create_summary_report(results, emotion_classes, training_metrics, save_dir):
    """Write a plain-text summary of the test run to ``<save_dir>/test_report.txt``.

    The report contains, in order: a title block, overall metrics, a per-class
    table, the best/worst classes by F1, raw and normalized confusion matrices,
    an optional validation-vs-test comparison, and the most frequent
    misclassifications.

    Args:
        results: Mapping of metric names to values. Reads the scalar metrics
            ('loss', 'accuracy', 'f1_macro', ...), the per-class sequences
            ('f1_per_class', 'precision_per_class', 'recall_per_class',
            'auc_roc_per_class'), 'labels', 'confusion_matrix' and
            'confusion_matrix_normalized'. The two matrices may be NumPy
            arrays or nested sequences.
        emotion_classes: Class names, indexed the same way as the per-class
            sequences and the confusion-matrix rows/columns.
        training_metrics: Mapping with optional 'accuracy', 'f1_macro' and
            'f1_per_class' entries; the comparison section is skipped when
            this is falsy.
        save_dir: Directory for the report; created if it does not exist.

    Raises:
        KeyError: If a required entry is missing from ``results``. The report
            is assembled in memory first, so no partial file is written.
    """
    confusion = np.asarray(results['confusion_matrix'])
    confusion_normalized = np.asarray(results['confusion_matrix_normalized'])

    lines = []
    lines += _header_lines(results, emotion_classes)
    lines += _overall_lines(results)
    lines += _per_class_lines(results, emotion_classes, confusion)
    lines += _ranking_lines(results, emotion_classes)
    lines += _confusion_lines(emotion_classes, confusion, confusion_normalized)
    if training_metrics:
        lines += _comparison_lines(results, emotion_classes, training_metrics)
    lines += _error_lines(emotion_classes, confusion)
    lines += ["\n" + _REPORT_RULE + "\n", "END OF REPORT\n", _REPORT_RULE + "\n"]

    # An empty save_dir means "current directory"; makedirs('') would raise.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    report_path = os.path.join(save_dir, 'test_report.txt')

    # Explicit UTF-8: the report contains non-ASCII symbols, and the platform
    # default encoding is not guaranteed to be able to represent them.
    with open(report_path, 'w', encoding='utf-8') as f:
        f.writelines(lines)

    print(f"‚úÖ Summary report saved: {report_path}")

# ==================== MAIN ====================
def main():
    """Run the whole test pipeline: load, evaluate, report and export.

    Loads metadata, the saved model and the test set, evaluates the model,
    writes every metric/plot/report artifact to ``RESULTS_DIR``, dumps the
    per-sample predictions to ``predictions.csv`` and prints a closing
    summary. Any exception is reported with its traceback and re-raised.
    """
    rule = "=" * 90

    print("\n" + rule)
    print("üöÄ STARTING COMPREHENSIVE MODEL TESTING")
    print(rule)

    started_at = time.time()

    try:
        # --- Inputs: metadata, model weights, tokenized test set ---
        _metadata, emotion_classes, num_classes = load_metadata()
        model, training_metrics = load_model(num_classes)
        test_loader, _test_data = load_test_data()

        # --- Inference and metric computation ---
        results = evaluate_model(model, test_loader, emotion_classes)

        # --- Console output and persisted artifacts ---
        print_results(results, emotion_classes)
        save_results(results, emotion_classes, training_metrics)

        # The comparison plot only makes sense when validation metrics exist.
        if training_metrics:
            create_train_test_comparison(results, training_metrics, emotion_classes, RESULTS_DIR)

        create_summary_report(results, emotion_classes, training_metrics, RESULTS_DIR)

        # --- Per-sample predictions table for downstream analysis ---
        y_true = results['labels']
        y_pred = results['predictions']
        table_columns = {
            'true_label': [emotion_classes[i] for i in y_true],
            'predicted_label': [emotion_classes[i] for i in y_pred],
            'true_label_id': y_true,
            'predicted_label_id': y_pred,
            'correct': y_true == y_pred,
        }
        predictions_df = pd.DataFrame(table_columns)

        # One probability column per class, in class-index order.
        class_probs = results['probabilities']
        for col, emotion in enumerate(emotion_classes):
            predictions_df[f'prob_{emotion}'] = class_probs[:, col]

        predictions_path = os.path.join(RESULTS_DIR, 'predictions.csv')
        predictions_df.to_csv(predictions_path, index=False)
        print(f"\n‚úÖ Predictions saved: {predictions_path}")

        # --- Closing summary ---
        total_time = time.time() - started_at

        print("\n" + rule)
        print("üéâ TESTING COMPLETE!")
        print(rule)
        print(f"\n‚è±Ô∏è  Total time: {total_time:.2f}s ({total_time/60:.2f} minutes)")
        print("\nüìä Key Results:")
        print(f"   ‚Ä¢ Accuracy: {results['accuracy']:.4f}")
        print(f"   ‚Ä¢ F1 Macro: {results['f1_macro']:.4f}")
        print(f"   ‚Ä¢ AUC-ROC:  {results['auc_roc_macro']:.4f}")
        print(f"\nüìÅ All results saved to: {RESULTS_DIR}")
        print("\nüìÑ Generated files:")

        generated = [
            "   ‚Ä¢ test_metrics.json - All metrics in JSON format",
            "   ‚Ä¢ per_class_metrics.csv - Detailed per-class metrics",
            "   ‚Ä¢ predictions.csv - All predictions with probabilities",
            "   ‚Ä¢ test_report.txt - Comprehensive text report",
            "   ‚Ä¢ confusion_matrices.png - Raw and normalized confusion matrices",
            "   ‚Ä¢ per_class_metrics.png - Per-class performance visualization",
            "   ‚Ä¢ overall_metrics.png - Overall model performance",
            "   ‚Ä¢ roc_curves.png - ROC curves for each class",
            "   ‚Ä¢ precision_recall_curves.png - PR curves for each class",
            "   ‚Ä¢ error_analysis.png - Misclassification patterns",
            "   ‚Ä¢ class_accuracies.png - Per-class accuracy breakdown",
        ]
        if training_metrics:
            generated.append("   ‚Ä¢ train_test_comparison.png - Validation vs Test comparison")
        for entry in generated:
            print(entry)

        print("\n" + rule)

    except Exception as e:
        # Top-level boundary: surface the failure with full context, then
        # propagate so the notebook/script still fails visibly.
        print(f"\n‚ùå CRITICAL ERROR: {e}")
        import traceback
        traceback.print_exc()
        raise

# Script entry point: run the evaluation pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

                    üî¨ DeBERTa v3 Comprehensive Model Testing üî¨
                         SuperEmotion - 7 Emotions

üñ•Ô∏è  Device: cuda
   GPU: Tesla T4
   Memory: 14.74 GB

üîó Mounting Google Drive...
Mounted at /content/drive
‚úÖ Drive mounted!
üìÅ Results will be saved to: /content/drive/MyDrive/ModelTestResults/test_run_20251114_233004

üöÄ STARTING COMPREHENSIVE MODEL TESTING

STEP 1: LOADING METADATA
üìÇ Loading from: /content/drive/MyDrive/SuperEmotion/metadata.json
‚úÖ Metadata loaded!
   Emotions: anger, fear, joy, love, neutral, sadness, surprise
   Classes: 7

STEP 2: LOADING MODEL
‚úÖ All required files found in: /content/drive/MyDrive/BestModelSave/best_model/

üìÇ Loading config...
‚úÖ Config loaded
   Hidden size: 768
   Num layers: 12
   Num heads: 12

ü§ñ Initializing model architecture...

üìÇ Loading DeBERTa weights from safetensors...
‚úÖ DeBERTa weights loaded successfully

üìÇ Loading classifier head...
‚úÖ Classifier head loaded successfully

‚úÖ 

Testing:   0%|          | 0/149 [00:00<?, ?batch/s]


‚úÖ Inference complete!
   Total samples: 19,072
   Average loss: 0.2549

üìä Computing metrics...
‚úÖ All metrics computed!

TEST RESULTS

üìä Overall Metrics:
   Loss:               0.2549
   Accuracy:           0.9054
   Balanced Accuracy:  0.9022
   F1 Macro:           0.8992
   F1 Weighted:        0.9064
   F1 Micro:           0.9054
   Precision Macro:    0.8996
   Precision Weighted: 0.9107
   Recall Macro:       0.9022
   Recall Weighted:    0.9054
   MCC:                0.8897
   Cohen's Kappa:      0.8892
   AUC-ROC Macro:      0.9937
   AUC-ROC Weighted:   0.9942

üìä Per-Class Metrics:
Emotion          F1  Precision   Recall  AUC-ROC
--------------------------------------------------
anger        0.9195     0.8814   0.9610   0.9963
fear         0.9045     0.9538   0.8600   0.9959
joy          0.9401     0.9933   0.8923   0.9949
love         0.9261     0.9070   0.9460   0.9952
neutral      0.8402     0.8086   0.8744   0.9883
sadness      0.9409     0.9709   0.9127   0.99