In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deepfake Detection Results Analysis\n",
    "\n",
    "This notebook provides tools for analyzing the results of deepfake detection models. It includes:\n",
    "\n",
    "1. Loading and visualizing evaluation metrics\n",
    "2. Comparative analysis of different models\n",
    "3. Error analysis and identification of challenging cases\n",
    "4. Cross-dataset generalization analysis\n",
    "5. ROC curves and precision-recall analysis\n",
    "6. Ensembling and fusion performance assessment\n",
    "\n",
    "These analyses help reveal the strengths and weaknesses of different approaches and guide future improvements."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Import necessary libraries\n",
    "import os\n",
    "import sys\n",
    "import json\n",
    "import yaml\n",
    "import glob\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from PIL import Image\n",
    "import cv2\n",
    "from pathlib import Path\n",
    "from sklearn.metrics import roc_curve, precision_recall_curve, auc\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# Add parent directory to path to enable imports from project\n",
    "sys.path.append(os.path.abspath('..'))\n",
    "\n",
    "# Set plot style\n",
    "plt.style.use('fivethirtyeight')\n",
    "sns.set(style=\"whitegrid\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load Evaluation Results\n",
    "\n",
    "Load evaluation results from the evaluation output directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Set the evaluation results directory - update to your local path\n",
    "EVAL_DIR = \"../evaluation_results\"\n",
    "\n",
    "# Check if directory exists\n",
    "if not os.path.exists(EVAL_DIR):\n",
    "    print(f\"Evaluation directory {EVAL_DIR} does not exist. Please update the path.\")\n",
    "else:\n",
    "    print(f\"Found evaluation directory at {EVAL_DIR}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def load_evaluation_results(eval_dir):\n",
    "    \"\"\"Load evaluation results from directory\"\"\"\n",
    "    if not os.path.exists(eval_dir):\n",
    "        return None\n",
    "    \n",
    "    results = {}\n",
    "    \n",
    "    # Find all model directories\n",
    "    model_dirs = [d for d in os.listdir(eval_dir) if os.path.isdir(os.path.join(eval_dir, d))]\n",
    "    \n",
    "    for model_name in model_dirs:\n",
    "        model_path = os.path.join(eval_dir, model_name)\n",
    "        results_file = os.path.join(model_path, \"all_results.json\")\n",
    "        \n",
    "        if os.path.exists(results_file):\n",
    "            with open(results_file, 'r') as f:\n",
    "                model_results = json.load(f)\n",
    "            \n",
    "            # Store results\n",
    "            results[model_name] = model_results\n",
    "    \n",
    "    return results\n",
    "\n",
    "# Load results\n",
    "evaluation_results = load_evaluation_results(EVAL_DIR)\n",
    "\n",
    "if evaluation_results is None:\n",
    "    print(\"No evaluation results found.\")\n",
    "elif not evaluation_results:\n",
    "    print(\"No models found in evaluation results.\")\n",
    "else:\n",
    "    print(f\"Loaded evaluation results for {len(evaluation_results)} models:\")\n",
    "    for model_name in evaluation_results.keys():\n",
    "        print(f\"  - {model_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def load_raw_predictions(eval_dir):\n",
    "    \"\"\"Load raw predictions for detailed analysis\"\"\"\n",
    "    if not os.path.exists(eval_dir):\n",
    "        return None\n",
    "    \n",
    "    predictions = {}\n",
    "    \n",
    "    # Find all model directories\n",
    "    model_dirs = [d for d in os.listdir(eval_dir) if os.path.isdir(os.path.join(eval_dir, d))]\n",
    "    \n",
    "    for model_name in model_dirs:\n",
    "        model_path = os.path.join(eval_dir, model_name)\n",
    "        model_predictions = {}\n",
    "        \n",
    "        # Find dataset directories\n",
    "        dataset_dirs = [d for d in os.listdir(model_path) \n",
    "                       if os.path.isdir(os.path.join(model_path, d))]\n",
    "        \n",
    "        for dataset_name in dataset_dirs:\n",
    "            dataset_path = os.path.join(model_path, dataset_name)\n",
    "            pred_file = os

In [None]:
## 5. Cross-Dataset Generalization Analysis

def analyze_cross_dataset_generalization(error_df, train_datasets=None):
    """Compare each model's accuracy between ordered pairs of datasets.

    For every model in ``error_df`` and every ordered pair of distinct datasets
    (``train_dataset``, ``test_dataset``) on which that model was evaluated,
    the overall, real-class and fake-class accuracy on ``test_dataset`` is
    subtracted from the corresponding accuracy on ``train_dataset``.

    Caveat: ``error_df`` carries no information about which dataset a model
    was actually trained on. By default *every* dataset is therefore treated
    as a candidate "train" dataset, each pair appears in both directions with
    opposite signs, and the per-model mean ``accuracy_drop`` is exactly zero.
    Pass ``train_datasets`` to restrict the "train" side to each model's real
    training source.

    Args:
        error_df: DataFrame as produced by ``analyze_errors``. It must contain
            the columns ``model``, ``dataset``, ``accuracy``,
            ``real_accuracy`` and ``fake_accuracy``. If a (model, dataset)
            combination has several rows, the first one is used.
        train_datasets: Optional mapping ``{model_name: dataset_name}``. For
            models in the mapping, only pairs whose ``train_dataset`` equals
            the mapped dataset are produced. Models absent from the mapping
            use all ordered pairs, which is the default behaviour.

    Returns:
        DataFrame with one row per (model, train_dataset, test_dataset), or
        None if fewer than two datasets exist or no pair could be formed.
    """
    if error_df is None or error_df.empty:
        print("No error data to analyze.")
        return None

    if len(error_df['dataset'].unique()) < 2:
        print("Need at least 2 datasets for cross-dataset analysis.")
        return None

    if train_datasets is None:
        train_datasets = {}

    generalization_data = []

    for model in error_df['model'].unique():
        model_data = error_df[error_df['model'] == model]
        datasets = model_data['dataset'].unique()

        # Restrict the "train" side to the declared training dataset, if any.
        if model in train_datasets:
            train_candidates = [train_datasets[model]]
        else:
            train_candidates = datasets

        for train_dataset in train_candidates:
            train_perf = model_data[model_data['dataset'] == train_dataset]
            if train_perf.empty:
                continue
            train_row = train_perf.iloc[0]

            for test_dataset in datasets:
                if test_dataset == train_dataset:
                    continue
                test_perf = model_data[model_data['dataset'] == test_dataset]
                if test_perf.empty:
                    continue
                test_row = test_perf.iloc[0]

                generalization_data.append({
                    'model': model,
                    'train_dataset': train_dataset,
                    'test_dataset': test_dataset,
                    'train_accuracy': train_row['accuracy'],
                    'test_accuracy': test_row['accuracy'],
                    'accuracy_drop': train_row['accuracy'] - test_row['accuracy'],
                    'real_acc_drop': train_row['real_accuracy'] - test_row['real_accuracy'],
                    'fake_acc_drop': train_row['fake_accuracy'] - test_row['fake_accuracy'],
                })

    if generalization_data:
        return pd.DataFrame(generalization_data)

    return None

def visualize_cross_dataset_generalization(gen_df):
    """Plot a 2x2 overview of cross-dataset accuracy changes.

    Panels:
        1. Overall accuracy drop per model, split by test dataset.
        2. Real-class vs fake-class accuracy drop per model.
        3. Train vs test accuracy scatter with an x = y reference line.
        4. Mean accuracy drop per model, sorted ascending.

    Args:
        gen_df: DataFrame returned by ``analyze_cross_dataset_generalization``.

    Returns:
        DataFrame with columns ``model`` and ``accuracy_drop`` (the mean drop
        per model, sorted ascending), or None when there is nothing to plot.
    """
    if gen_df is None or gen_df.empty:
        print("No generalization data to visualize.")
        return None

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20, 15))

    # 1. Accuracy drop across datasets
    sns.barplot(x='model', y='accuracy_drop', hue='test_dataset', data=gen_df, ax=ax1)
    ax1.set_title('Accuracy Drop when Testing on Different Datasets')
    ax1.set_xlabel('Model')
    ax1.set_ylabel('Accuracy Drop')
    # setp rotates the existing tick labels without replacing them, which
    # avoids matplotlib's FixedFormatter warning from set_xticklabels.
    plt.setp(ax1.get_xticklabels(), rotation=45, ha='right')
    ax1.axhline(y=0, color='r', linestyle='--')

    # 2. Per-class accuracy drop
    class_drops = pd.melt(gen_df,
                          id_vars=['model', 'test_dataset'],
                          value_vars=['real_acc_drop', 'fake_acc_drop'],
                          var_name='Class', value_name='Accuracy Drop')
    sns.barplot(x='model', y='Accuracy Drop', hue='Class', data=class_drops, ax=ax2)
    ax2.set_title('Per-class Accuracy Drop')
    ax2.set_xlabel('Model')
    ax2.set_ylabel('Accuracy Drop')
    plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')
    ax2.axhline(y=0, color='r', linestyle='--')

    # 3. Train vs test accuracy scatter plot
    sns.scatterplot(x='train_accuracy', y='test_accuracy', hue='model',
                    style='test_dataset', s=100, data=gen_df, ax=ax3)
    lo = min(gen_df['train_accuracy'].min(), gen_df['test_accuracy'].min())
    hi = max(gen_df['train_accuracy'].max(), gen_df['test_accuracy'].max())
    # Points below the x = y line have lower test accuracy than train accuracy
    ax3.plot([lo, hi], [lo, hi], 'k--')
    ax3.set_title('Train vs Test Accuracy')
    ax3.set_xlabel('Train Accuracy')
    ax3.set_ylabel('Test Accuracy')
    ax3.grid(True, linestyle='--', alpha=0.7)

    # 4. Generalization ranking
    # NOTE: when gen_df holds every ordered dataset pair (A->B and B->A), each
    # model's drops cancel and this mean is 0 by construction. It is only
    # informative when gen_df is restricted to each model's real training set.
    model_drops = (gen_df.groupby('model')['accuracy_drop'].mean()
                   .reset_index()
                   .sort_values('accuracy_drop'))
    sns.barplot(x='model', y='accuracy_drop', data=model_drops, ax=ax4)
    ax4.set_title('Average Accuracy Drop (Lower is Better)')
    ax4.set_xlabel('Model')
    ax4.set_ylabel('Average Accuracy Drop')
    plt.setp(ax4.get_xticklabels(), rotation=45, ha='right')
    ax4.axhline(y=0, color='r', linestyle='--')

    fig.suptitle('Cross-Dataset Generalization Analysis', fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

    return model_drops

# Analyze cross-dataset generalization.
# NOTE: this cell reads `error_df`, which the error-analysis cell creates. On a
# fresh kernel, that cell must run before this one.
if 'error_df' in locals() and error_df is not None and not error_df.empty:
    gen_df = analyze_cross_dataset_generalization(error_df)

    if gen_df is not None:
        print("\nCross-Dataset Generalization:")
        display(gen_df)

        # Visualize generalization results (returns the per-model mean drops)
        model_drops = visualize_cross_dataset_generalization(gen_df)
    else:
        print("Could not perform cross-dataset analysis.")
else:
    # Define gen_df explicitly so later cells never see a stale value
    gen_df = None
    print("No error data available for cross-dataset analysis.")

## 4. Error Analysis

def analyze_errors(raw_predictions):
    """Compute per-model, per-dataset error statistics from raw predictions.

    Args:
        raw_predictions: Nested mapping ``{model_name: {dataset_name: data}}``.
            ``data`` is a mapping that should contain ``'labels'`` and
            ``'predictions'``; entries missing either key are skipped. Class
            0 is counted as "real" and class 1 as "fake".

    Returns:
        DataFrame with one row per analysed (model, dataset) pair, containing
        sample counts, overall and per-class accuracy, and the counts of
        real->fake and fake->real errors. Returns None if nothing could be
        analysed. Accuracies are 0 when the corresponding sample count is 0.
    """
    if not raw_predictions:
        print("No predictions to analyze.")
        return None

    error_analysis = []

    for model_name, model_preds in raw_predictions.items():
        for dataset_name, data in model_preds.items():
            if 'labels' not in data or 'predictions' not in data:
                continue

            # np.asarray lets plain lists (e.g. loaded from JSON) work with the
            # element-wise comparisons and `.sum()` calls below. On lists,
            # `==` would return a single bool.
            labels = np.asarray(data['labels'])
            predictions = np.asarray(data['predictions'])

            total = len(labels)
            correct = int((predictions == labels).sum())
            incorrect = total - correct
            # Guard against 0/0 for an empty dataset
            accuracy = correct / total if total > 0 else 0

            real_samples = int((labels == 0).sum())
            fake_samples = int((labels == 1).sum())

            real_correct = int(((predictions == 0) & (labels == 0)).sum())
            fake_correct = int(((predictions == 1) & (labels == 1)).sum())

            real_to_fake = int(((predictions == 1) & (labels == 0)).sum())
            fake_to_real = int(((predictions == 0) & (labels == 1)).sum())

            real_accuracy = real_correct / real_samples if real_samples > 0 else 0
            fake_accuracy = fake_correct / fake_samples if fake_samples > 0 else 0

            error_analysis.append({
                'model': model_name,
                'dataset': dataset_name,
                'total_samples': total,
                'accuracy': accuracy,
                'correct': correct,
                'incorrect': incorrect,
                'real_samples': real_samples,
                'fake_samples': fake_samples,
                'real_accuracy': real_accuracy,
                'fake_accuracy': fake_accuracy,
                'real_to_fake': real_to_fake,
                'fake_to_real': fake_to_real,
            })

    if error_analysis:
        return pd.DataFrame(error_analysis)

    return None

def visualize_error_patterns(error_df):
    """Plot a 2x2 overview of error patterns from ``analyze_errors`` output.

    Panels:
        1. Overall vs per-class accuracy per model.
        2. Counts of real->fake and fake->real errors per model.
        3. Real/fake class share per dataset.
        4. Per-class error rate against that class's share.

    Side effect: adds the columns ``real_percent``, ``fake_percent``,
    ``real_error_rate`` and ``fake_error_rate`` to ``error_df`` in place. This
    is kept for backward compatibility; later cells may read these columns.

    Args:
        error_df: DataFrame with the columns produced by ``analyze_errors``.
    """
    if error_df is None or error_df.empty:
        print("No error data to visualize.")
        return None

    # Derived columns (written into the caller's frame, see docstring).
    # In pandas, 0/0 yields NaN, and NaN points are simply not drawn.
    error_df['real_percent'] = error_df['real_samples'] / error_df['total_samples'] * 100
    error_df['fake_percent'] = 100 - error_df['real_percent']
    error_df['real_error_rate'] = error_df['real_to_fake'] / error_df['real_samples']
    error_df['fake_error_rate'] = error_df['fake_to_real'] / error_df['fake_samples']

    fig = plt.figure(figsize=(20, 15))

    # 1. Overall accuracy vs per-class accuracy
    ax1 = fig.add_subplot(2, 2, 1)
    accuracy_long = pd.melt(error_df,
                            id_vars=['model', 'dataset'],
                            value_vars=['accuracy', 'real_accuracy', 'fake_accuracy'],
                            var_name='Metric', value_name='Value')
    sns.barplot(x='model', y='Value', hue='Metric', data=accuracy_long, ax=ax1)
    ax1.set_title('Overall vs Per-Class Accuracy')
    ax1.set_xlabel('Model')
    ax1.set_ylabel('Accuracy')
    # setp rotates existing tick labels (no FixedFormatter warning)
    plt.setp(ax1.get_xticklabels(), rotation=45, ha='right')

    # 2. Error distribution
    ax2 = fig.add_subplot(2, 2, 2)
    errors_long = pd.melt(error_df,
                          id_vars=['model', 'dataset'],
                          value_vars=['real_to_fake', 'fake_to_real'],
                          var_name='Error Type', value_name='Count')
    sns.barplot(x='model', y='Count', hue='Error Type', data=errors_long, ax=ax2)
    ax2.set_title('Error Distribution')
    ax2.set_xlabel('Model')
    ax2.set_ylabel('Error Count')
    plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')

    # 3. Class balance
    ax3 = fig.add_subplot(2, 2, 3)
    balance_long = pd.melt(error_df,
                           id_vars=['dataset'],
                           value_vars=['real_percent', 'fake_percent'],
                           var_name='Class', value_name='Percentage')
    sns.barplot(x='dataset', y='Percentage', hue='Class', data=balance_long, ax=ax3)
    ax3.set_title('Class Balance')
    ax3.set_xlabel('Dataset')
    ax3.set_ylabel('Percentage')

    # 4. Error rate vs class balance
    ax4 = fig.add_subplot(2, 2, 4)
    rates_long = pd.melt(error_df,
                         id_vars=['model', 'dataset', 'real_percent', 'fake_percent'],
                         value_vars=['real_error_rate', 'fake_error_rate'],
                         var_name='Error Type', value_name='Error Rate')
    # Pair each error rate with the share of the class it refers to (vectorized)
    rates_long['class_percent'] = np.where(rates_long['Error Type'] == 'real_error_rate',
                                           rates_long['real_percent'],
                                           rates_long['fake_percent'])
    for model in rates_long['model'].unique():
        model_rates = rates_long[rates_long['model'] == model]
        ax4.scatter(model_rates['class_percent'], model_rates['Error Rate'], label=model)
    ax4.set_title('Error Rate vs Class Percentage')
    ax4.set_xlabel('Class Percentage')
    ax4.set_ylabel('Error Rate')
    ax4.legend()
    ax4.grid(True, linestyle='--', alpha=0.7)

    fig.suptitle('Error Analysis', fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

# Analyze errors: build `error_df` from the raw predictions, show it and plot it
if not ('raw_predictions' in locals() and raw_predictions):
    print("No raw predictions available for error analysis.")
else:
    error_df = analyze_errors(raw_predictions)

    if error_df is None:
        print("Could not perform error analysis.")
    else:
        print("\nError Analysis:")
        display(error_df)

        # Visualize error patterns
        visualize_error_patterns(error_df)

## 3. Visualize ROC Curves and PR Curves

def plot_roc_curves(raw_predictions, dataset_name=None):
    """Plot ROC curves for all models, on one dataset or on every dataset.

    Args:
        raw_predictions: Nested mapping ``{model_name: {dataset_name: data}}``.
            A ``data`` entry is plotted only if it has ``'labels'`` and a
            non-None ``'probabilities'`` value (passed to
            ``sklearn.metrics.roc_curve`` as scores).
        dataset_name: When truthy, only this dataset is plotted and curves are
            labelled by model name. Otherwise every dataset of every model is
            plotted and labelled ``"<model> - <dataset>"``.
    """
    if not raw_predictions:
        print("No predictions to plot.")
        return None

    _, ax = plt.subplots(figsize=(12, 8))
    datasets_plotted = set()

    for model_name, model_preds in raw_predictions.items():
        # Select the (dataset, data) pairs to draw for this model
        if dataset_name:
            if dataset_name in model_preds:
                selected = [(dataset_name, model_preds[dataset_name])]
            else:
                selected = []
        else:
            selected = list(model_preds.items())

        for ds_name, data in selected:
            if 'labels' not in data or data.get('probabilities') is None:
                continue

            fpr, tpr, _ = roc_curve(data['labels'], data['probabilities'])
            roc_auc = auc(fpr, tpr)
            prefix = model_name if dataset_name else f'{model_name} - {ds_name}'
            ax.plot(fpr, tpr, lw=2, label=f'{prefix} (AUC = {roc_auc:.3f})')
            datasets_plotted.add(ds_name)

    # Diagonal reference line for a random classifier
    ax.plot([0, 1], [0, 1], 'k--', lw=2)

    if dataset_name and dataset_name in datasets_plotted:
        ax.set_title(f'ROC Curves - {dataset_name}')
    else:
        ax.set_title('ROC Curves')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.grid(True, linestyle='--', alpha=0.7)
    # Only add a legend when a labelled curve exists (avoids an empty-legend warning)
    if datasets_plotted:
        ax.legend(loc="lower right")

    plt.show()

def plot_pr_curves(raw_predictions, dataset_name=None):
    """Plot Precision-Recall curves for all models, on one or every dataset.

    Args:
        raw_predictions: Nested mapping ``{model_name: {dataset_name: data}}``.
            A ``data`` entry is plotted only if it has ``'labels'`` and a
            non-None ``'probabilities'`` value (passed to
            ``sklearn.metrics.precision_recall_curve`` as scores).
        dataset_name: When truthy, only this dataset is plotted and curves are
            labelled by model name. Otherwise every dataset of every model is
            plotted and labelled ``"<model> - <dataset>"``.
    """
    if not raw_predictions:
        print("No predictions to plot.")
        return None

    _, ax = plt.subplots(figsize=(12, 8))
    datasets_plotted = set()

    for model_name, model_preds in raw_predictions.items():
        # Select the (dataset, data) pairs to draw for this model
        if dataset_name:
            if dataset_name in model_preds:
                selected = [(dataset_name, model_preds[dataset_name])]
            else:
                selected = []
        else:
            selected = list(model_preds.items())

        for ds_name, data in selected:
            if 'labels' not in data or data.get('probabilities') is None:
                continue

            precision, recall, _ = precision_recall_curve(data['labels'], data['probabilities'])
            # Trapezoidal area under the PR curve. scikit-learn notes this can
            # be optimistic; average_precision_score is the usual alternative.
            pr_auc = auc(recall, precision)
            prefix = model_name if dataset_name else f'{model_name} - {ds_name}'
            ax.plot(recall, precision, lw=2, label=f'{prefix} (AUC = {pr_auc:.3f})')
            datasets_plotted.add(ds_name)

    if dataset_name and dataset_name in datasets_plotted:
        ax.set_title(f'Precision-Recall Curves - {dataset_name}')
    else:
        ax.set_title('Precision-Recall Curves')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.grid(True, linestyle='--', alpha=0.7)
    # Only add a legend when a labelled curve exists (avoids an empty-legend warning)
    if datasets_plotted:
        ax.legend(loc="lower left")

    plt.show()

# Plot ROC and PR curves
if 'raw_predictions' in locals() and raw_predictions:
    # Collect every dataset name seen for any model. The names are sorted so
    # plot order is reproducible: set iteration order of strings can change
    # between interpreter runs.
    available_datasets = sorted({ds for model_preds in raw_predictions.values()
                                 for ds in model_preds})

    print(f"Available datasets: {', '.join(available_datasets)}")

    for dataset in available_datasets:
        print(f"\nROC and PR curves for {dataset} dataset:")
        plot_roc_curves(raw_predictions, dataset)
        plot_pr_curves(raw_predictions, dataset)
else:
    print("No raw predictions available for plotting curves.")


def rank_models(performance_df, metrics=None):
    """Rank models within each dataset by averaging their per-metric ranks.

    Args:
        performance_df: DataFrame with ``model`` and ``dataset`` columns plus
            one column per metric (as built by ``create_performance_summary``).
        metrics: Metric columns to rank on. Defaults to accuracy, precision,
            recall, f1_score and auc, plus eer when that column exists.
            Names not present in ``performance_df`` are ignored.

    Returns:
        DataFrame with columns ``model``, ``dataset``, the metrics used and
        ``avg_rank`` (1 = best), sorted by dataset and then by average rank.
        Returns None when there is nothing to rank.
    """
    if performance_df is None or performance_df.empty:
        print("No performance data to rank.")
        return None

    if metrics is None:
        metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'auc']
        if 'eer' in performance_df.columns:
            metrics.append('eer')

    # Keep only metrics that actually exist in the DataFrame
    metrics = [m for m in metrics if m in performance_df.columns]
    if not metrics:
        print("No valid metrics to rank.")
        return None

    df = performance_df.copy()
    for metric in metrics:
        # EER is an error rate, so lower values rank better; all others are scores.
        df[f'{metric}_rank'] = df.groupby('dataset')[metric].rank(ascending=(metric == 'eer'))

    rank_cols = [f'{metric}_rank' for metric in metrics]
    df['avg_rank'] = df[rank_cols].mean(axis=1)
    df = df.sort_values(['dataset', 'avg_rank'])

    display_cols = ['model', 'dataset'] + metrics + ['avg_rank']
    return df[display_cols].reset_index(drop=True)

# Rank models
if 'performance_df' in locals() and performance_df is not None and not performance_df.empty:
    ranked_df = rank_models(performance_df)

    if ranked_df is not None:
        print("\nModel Rankings:")
        display(ranked_df)
    else:
        print("Could not rank models.")
else:
    print("No performance data to rank.")


def visualize_performance_metrics(performance_df):
    """Draw one bar chart per metric comparing models, grouped by dataset.

    The metrics considered are accuracy, precision, recall, f1_score, auc and
    eer. A metric is skipped if its column is missing or holds no values; for
    example, the eer column can exist but be entirely NaN. This avoids empty
    panels, and the subplot grid is sized to the metrics actually plotted.

    Args:
        performance_df: DataFrame with ``model``, ``dataset`` and metric columns.
    """
    if performance_df is None or performance_df.empty:
        print("No performance data to visualize.")
        return None

    candidate_metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'auc', 'eer']
    metrics = [m for m in candidate_metrics
               if m in performance_df.columns and performance_df[m].notna().any()]
    if not metrics:
        print("No metric values to visualize.")
        return None

    n_cols = 2
    n_rows = (len(metrics) + 1) // 2
    fig = plt.figure(figsize=(20, 15))

    for i, metric in enumerate(metrics):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        sns.barplot(x='model', y=metric, hue='dataset', data=performance_df, ax=ax)

        ax.set_title(metric.replace("_", " ").title())
        ax.set_xlabel('Model')
        ax.set_ylabel(metric)
        # setp rotates existing tick labels (no FixedFormatter warning)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        # EER is an error rate (lower is better). For the score metrics the
        # axis is zoomed to [0.5, 1.05], so bars below 0.5 are clipped.
        if metric == 'eer':
            ax.set_ylim(0, 0.5)
        else:
            ax.set_ylim(0.5, 1.05)

        ax.legend(title='Dataset')

    fig.suptitle('Model Performance Comparison', fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

# Visualize performance metrics when a non-empty performance summary exists
if not ('performance_df' in locals() and performance_df is not None and not performance_df.empty):
    print("No performance data to visualize.")
else:
    visualize_performance_metrics(performance_df)

## 2. Analyze Model Performance Metrics

def create_performance_summary(evaluation_results):
    """Flatten nested evaluation results into one row per (model, dataset).

    Args:
        evaluation_results: Mapping ``{model_name: {dataset_name: metrics}}``,
            as returned by ``load_evaluation_results``. Each ``metrics`` is a
            dict that may contain accuracy, precision, recall, f1_score, auc
            and eer; missing metrics become NaN. Entries whose value is not a
            dict are skipped instead of crashing on ``.get``. NOTE(review):
            presumably such entries are metadata, not per-dataset results;
            confirm against the evaluation script.

    Returns:
        DataFrame with columns model, dataset and the six metric columns cast
        to float, or None if there is nothing to summarize.
    """
    if not evaluation_results:
        return None

    metric_names = ['accuracy', 'precision', 'recall', 'f1_score', 'auc', 'eer']
    performance_data = []

    for model_name, model_results in evaluation_results.items():
        for dataset_name, dataset_results in model_results.items():
            if not isinstance(dataset_results, dict):
                continue
            row = {'model': model_name, 'dataset': dataset_name}
            row.update({name: dataset_results.get(name, None) for name in metric_names})
            performance_data.append(row)

    if not performance_data:
        return None

    df = pd.DataFrame(performance_data)

    # Ensure all metric columns are float (None becomes NaN)
    for col in metric_names:
        df[col] = df[col].astype(float)

    return df

# Create performance summary.
# Builds `performance_df` (one row per model/dataset pair) from
# `evaluation_results`, which the "Load Evaluation Results" section loads.
# The ranking and plotting cells read `performance_df`. This cell does not
# assign it when `evaluation_results` is empty or None.
if evaluation_results:
    performance_df = create_performance_summary(evaluation_results)
    
    # create_performance_summary returns None when no rows could be built
    if performance_df is not None:
        print("Performance Summary:")
        display(performance_df)
    else:
        print("Could not create performance summary.")
else:
    # NOTE(review): the trailing `{` on the next line is residue of the notebook
    # JSON concatenated after this cell; it is not Python and should be removed
    # when the notebook is repaired.
    print("No evaluation results to analyze."){
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deepfake Detection Results Analysis\n",
    "\n",
    "This notebook provides tools for analyzing the results of deepfake detection models. It includes:\n",
    "\n",
    "1. Loading and visualizing evaluation metrics\n",
    "2. Comparative analysis of different models\n",
    "3. Error analysis and identification of challenging cases\n",
    "4. Cross-dataset generalization analysis\n",
    "5. ROC curves and precision-recall analysis\n",
    "6. Ensembling and fusion performance assessment\n",
    "\n",
    "These analyses help reveal the strengths and weaknesses of different approaches and guide future improvements."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Import necessary libraries\n",
    "import os\n",
    "import sys\n",
    "import json\n",
    "import yaml\n",
    "import glob\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from PIL import Image\n",
    "import cv2\n",
    "from pathlib import Path\n",
    "from sklearn.metrics import roc_curve, precision_recall_curve, auc\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# Add parent directory to path to enable imports from project\n",
    "sys.path.append(os.path.abspath('..'))\n",
    "\n",
    "# Set plot style\n",
    "plt.style.use('fivethirtyeight')\n",
    "sns.set(style=\"whitegrid\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load Evaluation Results\n",
    "\n",
    "Load evaluation results from the evaluation output directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Set the evaluation results directory - update to your local path\n",
    "EVAL_DIR = \"../evaluation_results\"\n",
    "\n",
    "# Check if directory exists\n",
    "if not os.path.exists(EVAL_DIR):\n",
    "    print(f\"Evaluation directory {EVAL_DIR} does not exist. Please update the path.\")\n",
    "else:\n",
    "    print(f\"Found evaluation directory at {EVAL_DIR}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def load_evaluation_results(eval_dir):\n",
    "    \"\"\"Load evaluation results from directory\"\"\"\n",
    "    if not os.path.exists(eval_dir):\n",
    "        return None\n",
    "    \n",
    "    results = {}\n",
    "    \n",
    "    # Find all model directories\n",
    "    model_dirs = [d for d in os.listdir(eval_dir) if os.path.isdir(os.path.join(eval_dir, d))]\n",
    "    \n",
    "    for model_name in model_dirs:\n",
    "        model_path = os.path.join(eval_dir, model_name)\n",
    "        results_file = os.path.join(model_path, \"all_results.json\")\n",
    "        \n",
    "        if os.path.exists(results_file):\n",
    "            with open(results_file, 'r') as f:\n",
    "                model_results = json.load(f)\n",
    "            \n",
    "            # Store results\n",
    "            results[model_name] = model_results\n",
    "    \n",
    "    return results\n",
    "\n",
    "# Load results\n",
    "evaluation_results = load_evaluation_results(EVAL_DIR)\n",
    "\n",
    "if evaluation_results is None:\n",
    "    print(\"No evaluation results found.\")\n",
    "elif not evaluation_results:\n",
    "    print(\"No models found in evaluation results.\")\n",
    "else:\n",
    "    print(f\"Loaded evaluation results for {len(evaluation_results)} models:\")\n",
    "    for model_name in evaluation_results.keys():\n",
    "        print(f\"  - {model_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def load_raw_predictions(eval_dir):\n",
    "    \"\"\"Load raw predictions for detailed analysis\"\"\"\n",
    "    if not os.path.exists(eval_dir):\n",
    "        return None\n",
    "    \n",
    "    predictions = {}\n",
    "    \n",
    "    # Find all model directories\n",
    "    model_dirs = [d for d in os.listdir(eval_dir) if os.path.isdir(os.path.join(eval_dir, d))]\n",
    "    \n",
    "    for model_name in model_dirs:\n",
    "        model_path = os.path.join(eval_dir, model_name)\n",
    "        model_predictions = {}\n",
    "        \n",
    "        # Find dataset directories\n",
    "        dataset_dirs = [d for d in os.listdir(model_path) \n",
    "                       if os.path.isdir(os.path.join(model_path, d))]\n",
    "        \n",
    "        for dataset_name in dataset_dirs:\n",
    "            dataset_path = os.path.join(model_path, dataset_name)\n",
    "            pred_file = os

In [None]:
## 8. Summary and Recommendations

def _summary_table(name, table):
    """Return ``table``, or the notebook-global ``name`` when ``table`` is None.

    Missing tables and empty DataFrames both come back as None, so callers only
    need a single ``is not None`` check before indexing rows.
    """
    if table is None:
        table = globals().get(name)
    if table is None or getattr(table, 'empty', False):
        return None
    return table


def _summary_calibration_errors(conf_df):
    """Return a ``model``/``calibration_error`` DataFrame sorted best (lowest) first.

    Bins of each model are pooled across datasets by ``bin_confidence``, with
    accuracy weighted by ``bin_samples``. The reference accuracy of a bin is
    ``0.5 + bin_confidence / 2``, matching confidence defined as ``|p - 0.5| * 2``
    in analyze_prediction_confidence. The error is the sample-weighted mean
    absolute gap between observed and reference accuracy.
    """
    records = []
    for model_name in conf_df['model'].unique():
        model_data = conf_df[conf_df['model'] == model_name]
        pooled = (model_data
                  .assign(_correct=model_data['bin_accuracy'] * model_data['bin_samples'])
                  .groupby('bin_confidence')[['_correct', 'bin_samples']]
                  .sum())
        accuracy = pooled['_correct'] / pooled['bin_samples']
        expected = 0.5 + pooled.index.to_numpy() / 2
        records.append({
            'model': model_name,
            'calibration_error': np.average(np.abs(accuracy.to_numpy() - expected),
                                            weights=pooled['bin_samples'].to_numpy()),
        })
    return pd.DataFrame(records).sort_values('calibration_error')


def generate_summary_and_recommendations(performance_df=None, ranked_df=None, error_df=None,
                                         gen_df=None, conf_df=None, comp_df=None):
    """Print a Markdown-formatted summary of findings followed by recommendations.

    Every argument is one of the result tables built earlier in the notebook.
    An omitted (None) argument falls back to the notebook-global variable of the
    same name, so calling this with no arguments works after the analysis cells
    have run. A missing or empty table makes its section report that no data is
    available.

    Args:
        performance_df: needs columns 'model', 'dataset', 'accuracy'.
        ranked_df: needs columns 'model', 'dataset', 'avg_rank'.
        error_df: needs columns 'dataset', 'real_accuracy', 'fake_accuracy'.
        gen_df: needs columns 'model', 'train_dataset', 'test_dataset', 'accuracy_drop'.
        conf_df: needs columns 'model', 'bin_confidence', 'bin_accuracy', 'bin_samples'.
        comp_df: needs columns 'model1', 'model2', 'dataset', 'complementarity',
            'model1_acc', 'model2_acc', 'ensemble_acc', 'oracle_acc'
            ('ensemble_improvement' is used when present, otherwise derived).
    """
    performance_df = _summary_table('performance_df', performance_df)
    ranked_df = _summary_table('ranked_df', ranked_df)
    error_df = _summary_table('error_df', error_df)
    gen_df = _summary_table('gen_df', gen_df)
    conf_df = _summary_table('conf_df', conf_df)
    comp_df = _summary_table('comp_df', comp_df)

    # Derived tables shared by the findings and the recommendations below.
    model_drops = None
    if gen_df is not None:
        model_drops = (gen_df.groupby('model')['accuracy_drop'].mean()
                       .reset_index()
                       .sort_values('accuracy_drop'))
    ce_df = _summary_calibration_errors(conf_df) if conf_df is not None else None

    # Summary
    print("## Summary of Findings\n")

    # 1. Overall Performance
    print("### 1. Overall Performance\n")
    if performance_df is not None:
        best_model = performance_df.loc[performance_df['accuracy'].idxmax()]
        print(f"- The best overall performing model is **{best_model['model']}** on the {best_model['dataset']} dataset, with an accuracy of {best_model['accuracy']:.4f}.")

        datasets = performance_df['dataset'].unique()
        if len(datasets) > 1:
            print("- Performance across datasets:")
            for dataset in datasets:
                dataset_avg = performance_df.loc[performance_df['dataset'] == dataset, 'accuracy'].mean()
                print(f"  - {dataset}: Average accuracy of {dataset_avg:.4f}")
    else:
        print("- No performance data available.")

    # 2. Model Ranking
    print("\n### 2. Model Ranking\n")
    if ranked_df is not None:
        print("- Models ranked by overall performance:")
        for dataset in ranked_df['dataset'].unique():
            dataset_ranks = ranked_df[ranked_df['dataset'] == dataset].sort_values('avg_rank')
            print(f"  - **{dataset}**: {', '.join(dataset_ranks['model'])}")
    else:
        print("- No ranking data available.")

    # 3. Error Analysis
    print("\n### 3. Error Analysis\n")
    if error_df is not None:
        # Positive gap: real samples are classified more accurately than fake ones.
        real_fake_gap = error_df['real_accuracy'].mean() - error_df['fake_accuracy'].mean()

        if abs(real_fake_gap) > 0.05:
            if real_fake_gap > 0:
                print(f"- Models generally perform better on **real** samples (by {real_fake_gap:.4f} on average).")
            else:
                print(f"- Models generally perform better on **fake** samples (by {-real_fake_gap:.4f} on average).")
        else:
            print("- Models show relatively balanced performance between real and fake samples.")

        # Dataset-specific imbalances use a looser 0.1 threshold.
        for dataset in error_df['dataset'].unique():
            dataset_error = error_df[error_df['dataset'] == dataset]
            real_fake_diff = dataset_error['real_accuracy'].mean() - dataset_error['fake_accuracy'].mean()

            if abs(real_fake_diff) > 0.1:
                print(f"  - On {dataset}, there's a significant imbalance: " +
                      f"{'real' if real_fake_diff > 0 else 'fake'} samples are easier to classify " +
                      f"(by {abs(real_fake_diff):.4f} accuracy difference).")
    else:
        print("- No error analysis data available.")

    # 4. Cross-Dataset Generalization
    print("\n### 4. Cross-Dataset Generalization\n")
    if gen_df is not None:
        print(f"- **{model_drops.iloc[0]['model']}** shows the best cross-dataset generalization with an average accuracy drop of {model_drops.iloc[0]['accuracy_drop']:.4f}.")

        worst_pair = gen_df.sort_values('accuracy_drop', ascending=False).iloc[0]
        print(f"- The most challenging dataset transfer is from **{worst_pair['train_dataset']}** to **{worst_pair['test_dataset']}**, with an accuracy drop of {worst_pair['accuracy_drop']:.4f}.")
    else:
        print("- No cross-dataset generalization data available.")

    # 5. Confidence Analysis
    print("\n### 5. Confidence Analysis\n")
    if ce_df is not None:
        print(f"- **{ce_df.iloc[0]['model']}** has the best calibration with an error of {ce_df.iloc[0]['calibration_error']:.4f}.")
        print(f"- **{ce_df.iloc[-1]['model']}** has the worst calibration with an error of {ce_df.iloc[-1]['calibration_error']:.4f}.")

        avg_error = ce_df['calibration_error'].mean()
        if avg_error < 0.05:
            print("- Models generally show good calibration between confidence and accuracy.")
        elif avg_error < 0.1:
            print("- Models show moderate calibration issues between confidence and accuracy.")
        else:
            print("- Models show significant calibration issues between confidence and accuracy.")
    else:
        print("- No confidence analysis data available.")

    # 6. Model Complementarity
    print("\n### 6. Model Complementarity\n")
    if comp_df is not None:
        top_pair = comp_df.sort_values('complementarity', ascending=False).iloc[0]
        print(f"- The most complementary model pair is **{top_pair['model1']}** and **{top_pair['model2']}** on {top_pair['dataset']} dataset, with a complementarity score of {top_pair['complementarity']:.4f}.")

        # 'ensemble_improvement' is only present after visualize_model_complementarity
        # has run; derive it the same way (ensemble minus mean individual accuracy) otherwise.
        if 'ensemble_improvement' in comp_df.columns:
            ensemble_improvement = comp_df['ensemble_improvement']
        else:
            ensemble_improvement = comp_df['ensemble_acc'] - (comp_df['model1_acc'] + comp_df['model2_acc']) / 2
        best_pos = int(np.argmax(ensemble_improvement.to_numpy()))
        top_ensemble = comp_df.iloc[best_pos]
        best_improvement = ensemble_improvement.iloc[best_pos]
        print(f"- The best ensemble improvement is achieved by combining **{top_ensemble['model1']}** and **{top_ensemble['model2']}** on {top_ensemble['dataset']} dataset, with an accuracy improvement of {best_improvement:.4f}.")

        avg_oracle = comp_df['oracle_acc'].mean()
        avg_ensemble = comp_df['ensemble_acc'].mean()
        print(f"- On average, ensembles achieve {avg_ensemble:.4f} accuracy, while the theoretical maximum (oracle) is {avg_oracle:.4f}, showing room for improvement of {avg_oracle - avg_ensemble:.4f}.")
    else:
        print("- No model complementarity data available.")

    # Recommendations
    print("\n## Recommendations\n")

    # 1. Model Selection
    print("### 1. Model Selection\n")
    if ranked_df is not None:
        best_model = ranked_df.sort_values('avg_rank').iloc[0]['model']
        print(f"- **Primary Model**: Use **{best_model}** as the primary model for most scenarios, as it shows the best overall performance.")

    if model_drops is not None:
        best_gen_model = model_drops.iloc[0]['model']
        print(f"- **Cross-Dataset Scenarios**: Use **{best_gen_model}** for scenarios where the test data may differ significantly from the training data, as it shows the best generalization capabilities.")

    if ce_df is not None:
        best_calibrated = ce_df.iloc[0]['model']
        print(f"- **Confidence-Critical Applications**: Use **{best_calibrated}** for applications where accurate confidence estimation is crucial, as it shows the best calibration.")

    # 2. Ensemble Strategy
    print("\n### 2. Ensemble Strategy\n")
    if comp_df is not None:
        top_pairs = comp_df.sort_values('complementarity', ascending=False).head(3)

        print("- **Recommended Ensembles**:")
        for _, row in top_pairs.iterrows():
            print(f"  - Combine **{row['model1']}** and **{row['model2']}** for {row['dataset']} dataset (complementarity: {row['complementarity']:.4f}, ensemble accuracy: {row['ensemble_acc']:.4f}).")
    else:
        print("- No specific ensemble recommendations can be made without complementarity analysis.")

    # 3. Future Improvements
    print("\n### 3. Future Improvements\n")

    recommendations = []

    if error_df is not None:
        real_mean = error_df['real_accuracy'].mean()
        fake_mean = error_df['fake_accuracy'].mean()
        if abs(real_mean - fake_mean) > 0.05:
            weaker_class = "real" if real_mean < fake_mean else "fake"
            recommendations.append(f"- **Balanced Training**: Focus on improving performance on **{weaker_class}** samples, which are more challenging across models.")

    if gen_df is not None and gen_df['accuracy_drop'].mean() > 0.05:
        recommendations.append("- **Domain Adaptation**: Implement domain adaptation techniques to improve cross-dataset generalization.")

    if ce_df is not None and ce_df['calibration_error'].mean() > 0.05:
        recommendations.append("- **Calibration**: Apply post-hoc calibration methods (like Platt scaling or temperature scaling) to improve confidence calibration.")

    if comp_df is not None and comp_df['oracle_acc'].mean() - comp_df['ensemble_acc'].mean() > 0.05:
        recommendations.append("- **Advanced Ensembling**: Explore more sophisticated ensemble methods (like stacking or feature-level fusion) to better leverage model complementarity.")

    # Fall back to general advice when no data-driven recommendation applies.
    if not recommendations:
        recommendations.append("- **Data Augmentation**: Expand the training dataset with more diverse examples to improve robustness.")
        recommendations.append("- **Model Architecture**: Experiment with different model architectures or hybrids to capture different aspects of deepfakes.")
        recommendations.append("- **Feature Engineering**: Investigate what features are most predictive of deepfakes and enhance those aspects in the models.")

    for recommendation in recommendations:
        print(recommendation)

    # 4. Deployment Considerations
    print("\n### 4. Deployment Considerations\n")
    print("- **Threshold Tuning**: Adjust decision thresholds based on the specific application needs (higher precision vs. higher recall).")
    print("- **Monitoring**: Implement monitoring for detecting performance degradation as new deepfake techniques emerge.")
    print("- **Human Oversight**: For high-stakes decisions, maintain human oversight to verify model predictions, especially for cases with moderate confidence.")
    print("- **Explainability**: Include visualizations like attention maps or Grad-CAM to help users understand model decisions.")

# Frame the report with a banner; sep="\n" reproduces one print() call per line.
print("\n\n",
      "=" * 80,
      "                      SUMMARY AND RECOMMENDATIONS                      ",
      "=" * 80,
      "\n",
      sep="\n")

generate_summary_and_recommendations()

print("\n\nResults Analysis Notebook Complete!")
## 7. Model Complementarity Analysis

def analyze_model_complementarity(raw_predictions):
    """Measure how often pairs of models make different mistakes on the same dataset.

    For every unordered pair of models evaluated on the same dataset, the
    per-sample correctness of both models is cross-tabulated and summarised as
    individual, ensemble and oracle accuracies.

    Args:
        raw_predictions: nested dict ``{model_name: {dataset_name: data}}`` where
            ``data`` holds array-like ``'labels'`` and ``'predictions'`` (class
            ids; this notebook uses 0 = real, 1 = fake). Entries missing either
            key are ignored. Pairs whose label arrays differ, or whose lengths do
            not match, are skipped with a message because their samples cannot
            be compared one-to-one.

    Returns:
        DataFrame with one row per (model1, model2, dataset) and columns
        ``total_samples``, ``both_correct``, ``model1_only``, ``model2_only``,
        ``both_wrong``, ``model1_acc``, ``model2_acc``, ``ensemble_acc``,
        ``oracle_acc`` and ``complementarity`` (fraction of samples exactly one
        model got right), or None when nothing could be compared.
    """
    import itertools

    if raw_predictions is None or not raw_predictions:
        print("No predictions to analyze.")
        return

    if len(raw_predictions) < 2:
        print("Need at least 2 models for complementarity analysis.")
        return

    model_dataset_pairs = [
        (model_name, dataset_name)
        for model_name, model_preds in raw_predictions.items()
        for dataset_name in model_preds.keys()
    ]

    complementarity_data = []

    # combinations() yields (earlier, later) pairs in the same order as nested i < j loops.
    for (model1, dataset1), (model2, dataset2) in itertools.combinations(model_dataset_pairs, 2):
        if dataset1 != dataset2:
            continue

        data1 = raw_predictions[model1][dataset1]
        data2 = raw_predictions[model2][dataset2]
        if not ('labels' in data1 and 'predictions' in data1 and
                'labels' in data2 and 'predictions' in data2):
            continue

        labels = np.asarray(data1['labels'])
        preds1 = np.asarray(data1['predictions'])
        preds2 = np.asarray(data2['predictions'])

        total = len(labels)
        if total == 0:
            continue
        # Both models must have scored the same samples in the same order.
        if (len(preds1) != total or len(preds2) != total
                or not np.array_equal(labels, np.asarray(data2['labels']))):
            print(f"Skipping {model1} vs {model2} on {dataset1}: predictions are not aligned.")
            continue

        correct1 = preds1 == labels
        correct2 = preds2 == labels

        both_correct = int(np.sum(correct1 & correct2))
        model1_only = int(np.sum(correct1 & ~correct2))
        model2_only = int(np.sum(~correct1 & correct2))
        both_wrong = int(np.sum(~correct1 & ~correct2))

        acc1 = np.sum(correct1) / total
        acc2 = np.sum(correct2) / total

        # Two-model "vote": predict 1 if either model does, i.e. ties break towards class 1.
        ensemble_preds = (preds1 + preds2 > 0.5).astype(int)
        ensemble_acc = np.sum(ensemble_preds == labels) / total

        # Oracle: a sample counts as correct if at least one of the two models got it right.
        oracle_acc = np.sum(correct1 | correct2) / total

        complementarity_data.append({
            'model1': model1,
            'model2': model2,
            'dataset': dataset1,
            'total_samples': total,
            'both_correct': both_correct,
            'model1_only': model1_only,
            'model2_only': model2_only,
            'both_wrong': both_wrong,
            'model1_acc': acc1,
            'model2_acc': acc2,
            'ensemble_acc': ensemble_acc,
            'oracle_acc': oracle_acc,
            'complementarity': (model1_only + model2_only) / total
        })

    if complementarity_data:
        return pd.DataFrame(complementarity_data)

    return None

def visualize_model_complementarity(comp_df):
    """Plot agreement statistics for the model pairs produced by analyze_model_complementarity.

    First shows one agreement heatmap per model pair. Then it shows a 2x2 summary
    figure: the agreement matrix pooled over all pairs, complementarity vs.
    average accuracy, ensemble gain per pair, and accuracy distributions.

    In every agreement matrix, rows are the first model (``model1``) being
    correct/wrong and columns are the second model (``model2``) being correct/wrong.

    Side effect: adds ``avg_acc``, ``ensemble_improvement`` and
    ``oracle_improvement`` columns to ``comp_df`` in place.
    (generate_summary_and_recommendations reads ``ensemble_improvement``.)

    Args:
        comp_df: DataFrame with columns ``model1``, ``model2``, ``dataset``,
            ``total_samples``, ``both_correct``, ``model1_only``, ``model2_only``,
            ``both_wrong``, ``model1_acc``, ``model2_acc``, ``ensemble_acc``,
            ``oracle_acc`` and ``complementarity``.

    Returns:
        ``comp_df`` sorted by ``complementarity`` (descending), or None when
        ``comp_df`` is None or empty.
    """
    if comp_df is None or comp_df.empty:
        print("No complementarity data to visualize.")
        return

    # Derived columns, written back into comp_df (see docstring).
    comp_df['avg_acc'] = (comp_df['model1_acc'] + comp_df['model2_acc']) / 2
    comp_df['ensemble_improvement'] = comp_df['ensemble_acc'] - comp_df['avg_acc']
    comp_df['oracle_improvement'] = comp_df['oracle_acc'] - comp_df['avg_acc']

    # Per-pair heatmaps, each in its own figure. They are shown before the summary
    # figure is created because, with the inline backend, plt.show() flushes every
    # open figure, which would emit a half-drawn summary figure.
    for _, row in comp_df.iterrows():
        pair_matrix = np.array([
            [row['both_correct'], row['model1_only']],   # model1 correct
            [row['model2_only'], row['both_wrong']],     # model1 wrong
        ], dtype=float) / row['total_samples']

        pair_fig, pair_ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(pair_matrix, annot=True, fmt='.2%', cmap='Blues',
                    xticklabels=['Correct', 'Wrong'], yticklabels=['Correct', 'Wrong'],
                    ax=pair_ax)
        pair_ax.set_title(f"{row['model1']} vs {row['model2']} - {row['dataset']}")
        pair_ax.set_xlabel(f"{row['model2']} Predictions")
        pair_ax.set_ylabel(f"{row['model1']} Predictions")
        pair_fig.tight_layout()
        plt.show()
        plt.close(pair_fig)

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20, 15))

    # 1. Agreement matrix pooled over all pairs
    pooled_matrix = np.array([
        [comp_df['both_correct'].sum(), comp_df['model1_only'].sum()],
        [comp_df['model2_only'].sum(), comp_df['both_wrong'].sum()],
    ], dtype=float) / comp_df['total_samples'].sum()
    sns.heatmap(pooled_matrix, annot=True, fmt='.2%', cmap='Blues',
                xticklabels=['Correct', 'Wrong'], yticklabels=['Correct', 'Wrong'], ax=ax1)
    ax1.set_title('Agreement Pooled over All Model Pairs')
    ax1.set_xlabel('Second Model (model2) Predictions')
    ax1.set_ylabel('First Model (model1) Predictions')

    # 2. Complementarity vs individual accuracy
    sns.scatterplot(x='avg_acc', y='complementarity', hue='dataset',
                    size='total_samples', sizes=(50, 200), data=comp_df, ax=ax2)
    ax2.set_title('Complementarity vs Average Accuracy')
    ax2.set_xlabel('Average Accuracy')
    ax2.set_ylabel('Complementarity')
    ax2.grid(True, linestyle='--', alpha=0.7)

    # 3. Accuracy improvement from ensemble (dataset in the label keeps pairs distinct)
    gain_df = comp_df.sort_values('ensemble_improvement', ascending=False).copy()
    gain_df['model_pair'] = [
        f"{m1}\n+\n{m2}\n({ds})"
        for m1, m2, ds in zip(gain_df['model1'], gain_df['model2'], gain_df['dataset'])
    ]
    sns.barplot(x='model_pair', y='ensemble_improvement', data=gain_df, ax=ax3)
    ax3.set_title('Accuracy Improvement from Ensemble')
    ax3.set_xlabel('Model Pair')
    ax3.set_ylabel('Accuracy Improvement')
    ax3.axhline(y=0, color='r', linestyle='--')

    # 4. Individual vs ensemble vs oracle accuracy
    accuracy_long = pd.melt(comp_df,
                            id_vars=['model1', 'model2', 'dataset'],
                            value_vars=['model1_acc', 'model2_acc', 'ensemble_acc', 'oracle_acc'],
                            var_name='Method', value_name='Accuracy')
    sns.boxplot(x='Method', y='Accuracy', data=accuracy_long, ax=ax4)
    ax4.set_title('Performance Comparison')
    ax4.set_xlabel('Method')
    ax4.set_ylabel('Accuracy')

    fig.suptitle('Model Complementarity Analysis', fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

    return comp_df.sort_values('complementarity', ascending=False)

# Complementarity compares models pairwise, so it needs predictions from two or more models.
if not ('raw_predictions' in locals() and raw_predictions and len(raw_predictions) >= 2):
    print("Need at least 2 models with predictions for complementarity analysis.")
else:
    comp_df = analyze_model_complementarity(raw_predictions)
    if comp_df is None:
        print("Could not perform complementarity analysis.")
    else:
        print("\nModel Complementarity Analysis:")
        display(comp_df)

        # Plot, then list pairs from most to least complementary.
        ranked_pairs = visualize_model_complementarity(comp_df)
        if ranked_pairs is not None:
            print("\nModel Pairs Ranked by Complementarity (Higher is Better):")
            display(ranked_pairs[['model1', 'model2', 'dataset', 'complementarity', 'ensemble_acc', 'oracle_acc']])
## 6. Confidence Analysis

def analyze_prediction_confidence(raw_predictions, n_bins=10):
    """Bin predictions by confidence and measure the accuracy inside each bin.

    Confidence is the distance of the positive-class probability ``p`` from the
    0.5 decision boundary, rescaled to [0, 1]: ``|p - 0.5| * 2``.

    Args:
        raw_predictions: nested dict ``{model_name: {dataset_name: data}}`` where
            ``data`` holds array-like ``'labels'``, ``'predictions'`` and
            ``'probabilities'``. Probabilities are either 1-D, or 2-D with the
            positive class in column 1. A single-column 2-D array is flattened
            (assumed to already be the positive-class probability; confirm
            against the producer). Entries missing any key, or with
            ``probabilities`` set to None, are skipped.
        n_bins: number of equal-width confidence bins on [0, 1] (default 10).

    Returns:
        DataFrame with one row per non-empty bin and columns ``model``,
        ``dataset``, ``bin_confidence`` (bin centre), ``bin_accuracy`` and
        ``bin_samples``, or None when nothing could be analysed.
    """
    if raw_predictions is None or not raw_predictions:
        print("No predictions to analyze.")
        return

    bin_edges = np.linspace(0, 1, n_bins + 1)
    bin_centres = (bin_edges[:-1] + bin_edges[1:]) / 2
    confidence_data = []

    for model_name, model_preds in raw_predictions.items():
        for dataset_name, data in model_preds.items():
            if not ('labels' in data and 'predictions' in data and
                    'probabilities' in data and data['probabilities'] is not None):
                continue

            labels = np.asarray(data['labels'])
            predictions = np.asarray(data['predictions'])
            probabilities = np.asarray(data['probabilities'])

            # Reduce to the positive-class probability.
            if probabilities.ndim > 1 and probabilities.shape[1] > 1:
                probabilities = probabilities[:, 1]
            elif probabilities.ndim > 1:
                probabilities = probabilities.reshape(-1)

            confidence = np.abs(probabilities - 0.5) * 2

            # digitize() returns n_bins + 1 for confidence == 1.0 (p exactly 0 or 1);
            # clip so those samples land in the top bin instead of being dropped.
            # Non-finite confidences get -1 and fall outside every bin.
            bin_indices = np.where(
                np.isfinite(confidence),
                np.clip(np.digitize(confidence, bin_edges) - 1, 0, n_bins - 1),
                -1,
            )
            correct = predictions == labels

            for bin_idx in range(n_bins):
                bin_mask = bin_indices == bin_idx
                n_in_bin = int(np.sum(bin_mask))
                if n_in_bin > 0:
                    confidence_data.append({
                        'model': model_name,
                        'dataset': dataset_name,
                        'bin_confidence': bin_centres[bin_idx],
                        'bin_accuracy': np.mean(correct[bin_mask]),
                        'bin_samples': n_in_bin
                    })

    if confidence_data:
        return pd.DataFrame(confidence_data)

    return None

def visualize_confidence_analysis(conf_df):
    """Plot reliability curves and per-model calibration error.

    The bins of each model are pooled across datasets by ``bin_confidence``,
    with accuracy weighted by ``bin_samples``. With confidence defined as
    ``|p - 0.5| * 2``, a perfectly calibrated model has accuracy
    ``0.5 + confidence / 2`` in each bin. The calibration error is the
    sample-weighted mean absolute gap to that line.

    Args:
        conf_df: DataFrame with columns ``model``, ``bin_confidence``,
            ``bin_accuracy`` and ``bin_samples`` (as produced by
            analyze_prediction_confidence).

    Returns:
        DataFrame with columns ``model`` and ``calibration_error``, sorted
        ascending (best first), or None when ``conf_df`` is None or empty.
    """
    if conf_df is None or conf_df.empty:
        print("No confidence data to visualize.")
        return

    # Pool bins once per model; both panels reuse the same table.
    pooled_by_model = {}
    calibration_errors = []
    for model_name in conf_df['model'].unique():
        model_data = conf_df[conf_df['model'] == model_name]
        pooled = (model_data
                  .assign(_correct=model_data['bin_accuracy'] * model_data['bin_samples'])
                  .groupby('bin_confidence')[['_correct', 'bin_samples']]
                  .sum())
        pooled['accuracy'] = pooled['_correct'] / pooled['bin_samples']
        pooled['expected_accuracy'] = 0.5 + pooled.index.to_numpy() / 2
        pooled_by_model[model_name] = pooled

        calibration_errors.append({
            'model': model_name,
            'calibration_error': np.average(
                np.abs(pooled['accuracy'] - pooled['expected_accuracy']),
                weights=pooled['bin_samples'])
        })

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

    # 1. Confidence vs. accuracy; point area grows with the number of samples in the bin.
    for model_name, pooled in pooled_by_model.items():
        line, = ax1.plot(pooled.index, pooled['accuracy'], alpha=0.5, label=model_name)
        ax1.scatter(pooled.index, pooled['accuracy'], s=pooled['bin_samples'] / 10,
                    alpha=0.7, color=line.get_color())

    ax1.plot([0, 1], [0.5, 1], 'k--', alpha=0.5, label='Perfect Calibration')
    ax1.set_title('Confidence vs. Accuracy')
    ax1.set_xlabel('Confidence')
    ax1.set_ylabel('Accuracy')
    ax1.set_xlim(0, 1)
    # Bin accuracy can fall below 0.5 (confidently wrong models); don't clip those points.
    ax1.set_ylim(0, 1.05)
    ax1.grid(True, linestyle='--', alpha=0.7)
    ax1.legend()

    # 2. Model calibration comparison
    ce_df = pd.DataFrame(calibration_errors).sort_values('calibration_error')

    sns.barplot(x='model', y='calibration_error', data=ce_df, ax=ax2)
    ax2.set_title('Model Calibration Error (Lower is Better)')
    ax2.set_xlabel('Model')
    ax2.set_ylabel('Calibration Error')
    plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')

    fig.suptitle('Prediction Confidence Analysis', fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

    return ce_df

# Confidence analysis needs the raw per-sample predictions loaded earlier.
if not ('raw_predictions' in locals() and raw_predictions):
    print("No raw predictions available for confidence analysis.")
else:
    conf_df = analyze_prediction_confidence(raw_predictions)
    if conf_df is None:
        print("Could not perform confidence analysis.")
    else:
        print("\nConfidence Analysis:")
        display(conf_df)

        # Plot reliability curves and keep the per-model calibration errors.
        ce_df = visualize_confidence_analysis(conf_df)
        if ce_df is not None:
            print("\nModels Ranked by Calibration Error (Lower is Better):")
            display(ce_df)
## 5. Cross-Dataset Generalization Analysis

def analyze_cross_dataset_generalization(error_df, train_datasets=None):
    """Compare each model's accuracy between pairs of evaluation datasets.

    ``error_df`` records how each model scored on each dataset, but not which
    dataset the model was trained on. Without ``train_datasets``, every ordered
    pair (A, B) of datasets a model was evaluated on is reported, with A as the
    reference ("train") dataset. Each pair then appears in both directions with
    opposite signs, so a model's mean ``accuracy_drop`` over all rows is ~0.
    Pass ``train_datasets`` to use only the real training set as the reference.

    Args:
        error_df: DataFrame with columns ``model``, ``dataset``, ``accuracy``,
            ``real_accuracy`` and ``fake_accuracy``. If a model has several rows
            for one dataset, the first is used. Rows with a missing dataset are ignored.
        train_datasets: optional reference dataset, either one dataset name used
            for every model, or a dict ``{model: dataset}``. Models absent from
            the dict keep the all-pairs behaviour. Models never evaluated on
            their reference dataset are skipped with a message.

    Returns:
        DataFrame with columns ``model``, ``train_dataset``, ``test_dataset``,
        ``train_accuracy``, ``test_accuracy``, ``accuracy_drop``,
        ``real_acc_drop`` and ``fake_acc_drop`` (positive = lower accuracy on
        the test dataset), or None when no pair could be formed.
    """
    if error_df is None or error_df.empty:
        print("No error data to analyze.")
        return

    if len(error_df['dataset'].unique()) < 2:
        print("Need at least 2 datasets for cross-dataset analysis.")
        return

    metrics = ['accuracy', 'real_accuracy', 'fake_accuracy']
    generalization_data = []

    for model in error_df['model'].unique():
        model_data = error_df[error_df['model'] == model].dropna(subset=['dataset'])
        # One row per dataset (first occurrence), indexed by dataset name.
        per_dataset = model_data.drop_duplicates('dataset').set_index('dataset')[metrics]
        datasets = list(per_dataset.index)

        reference = train_datasets.get(model) if isinstance(train_datasets, dict) else train_datasets
        if reference is None:
            reference_datasets = datasets
        elif reference in per_dataset.index:
            reference_datasets = [reference]
        else:
            print(f"Model {model} has no results on its training dataset {reference}; skipping.")
            continue

        for train_dataset in reference_datasets:
            train_row = per_dataset.loc[train_dataset]
            for test_dataset in datasets:
                if test_dataset == train_dataset:
                    continue
                test_row = per_dataset.loc[test_dataset]

                generalization_data.append({
                    'model': model,
                    'train_dataset': train_dataset,
                    'test_dataset': test_dataset,
                    'train_accuracy': train_row['accuracy'],
                    'test_accuracy': test_row['accuracy'],
                    'accuracy_drop': train_row['accuracy'] - test_row['accuracy'],
                    'real_acc_drop': train_row['real_accuracy'] - test_row['real_accuracy'],
                    'fake_acc_drop': train_row['fake_accuracy'] - test_row['fake_accuracy']
                })

    if generalization_data:
        return pd.DataFrame(generalization_data)

    return None

def visualize_cross_dataset_generalization(gen_df):
    """Plot how each model's accuracy changes between dataset pairs.

    Panels: accuracy drop per test dataset, per-class (real/fake) drop,
    train vs. test accuracy against the y = x line, and the mean drop per model.

    Note: when ``gen_df`` contains both (A, B) and (B, A) for every dataset pair,
    the drops cancel and the per-model means in the last panel are close to 0.

    Args:
        gen_df: DataFrame with columns ``model``, ``train_dataset``,
            ``test_dataset``, ``train_accuracy``, ``test_accuracy``,
            ``accuracy_drop``, ``real_acc_drop`` and ``fake_acc_drop``.

    Returns:
        DataFrame of the mean ``accuracy_drop`` per model, sorted ascending
        (smallest drop first), or None when ``gen_df`` is None or empty.
    """
    if gen_df is None or gen_df.empty:
        print("No generalization data to visualize.")
        return

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20, 15))

    # 1. Accuracy drop across datasets
    sns.barplot(x='model', y='accuracy_drop', hue='test_dataset', data=gen_df, ax=ax1)
    ax1.set_title('Accuracy Drop when Testing on Different Datasets')
    ax1.set_xlabel('Model')
    ax1.set_ylabel('Accuracy Drop')
    plt.setp(ax1.get_xticklabels(), rotation=45, ha='right')
    ax1.axhline(y=0, color='r', linestyle='--')

    # 2. Per-class accuracy drop
    class_drops = pd.melt(gen_df,
                          id_vars=['model', 'test_dataset'],
                          value_vars=['real_acc_drop', 'fake_acc_drop'],
                          var_name='Class', value_name='Accuracy Drop')
    sns.barplot(x='model', y='Accuracy Drop', hue='Class', data=class_drops, ax=ax2)
    ax2.set_title('Per-class Accuracy Drop')
    ax2.set_xlabel('Model')
    ax2.set_ylabel('Accuracy Drop')
    plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')
    ax2.axhline(y=0, color='r', linestyle='--')

    # 3. Train vs test accuracy; points below the dashed y = x line lost accuracy.
    sns.scatterplot(x='train_accuracy', y='test_accuracy', hue='model',
                    style='test_dataset', s=100, data=gen_df, ax=ax3)
    low = min(gen_df['train_accuracy'].min(), gen_df['test_accuracy'].min())
    high = max(gen_df['train_accuracy'].max(), gen_df['test_accuracy'].max())
    ax3.plot([low, high], [low, high], 'k--')
    ax3.set_title('Train vs Test Accuracy')
    ax3.set_xlabel('Train Accuracy')
    ax3.set_ylabel('Test Accuracy')
    ax3.grid(True, linestyle='--', alpha=0.7)

    # 4. Generalization ranking
    model_drops = (gen_df.groupby('model')['accuracy_drop'].mean()
                   .reset_index()
                   .sort_values('accuracy_drop'))
    sns.barplot(x='model', y='accuracy_drop', data=model_drops, ax=ax4)
    ax4.set_title('Average Accuracy Drop (Lower is Better)')
    ax4.set_xlabel('Model')
    ax4.set_ylabel('Average Accuracy Drop')
    plt.setp(ax4.get_xticklabels(), rotation=45, ha='right')
    ax4.axhline(y=0, color='r', linestyle='--')

    fig.suptitle('Cross-Dataset Generalization Analysis', fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

    return model_drops

# Cross-dataset comparison is derived from the per-dataset error analysis table.
if not ('error_df' in locals() and error_df is not None and not error_df.empty):
    print("No error data available for cross-dataset generalization analysis.")
else:
    gen_df = analyze_cross_dataset_generalization(error_df)
    if gen_df is None:
        print("Could not perform cross-dataset generalization analysis.")
    else:
        print("\nCross-Dataset Generalization Analysis:")
        display(gen_df)

        # Plot the drops and show models ordered by mean accuracy drop.
        model_ranks = visualize_cross_dataset_generalization(gen_df)
        if model_ranks is not None:
            print("\nModels Ranked by Generalization Performance (Lower Drop is Better):")
            display(model_ranks)
## 4. Error Analysis

def analyze_errors(raw_predictions):
    """Count correct and incorrect predictions per class for each model/dataset.

    Args:
        raw_predictions: Mapping ``{model_name: {dataset_name: data}}``. Only
            ``data`` entries containing both ``'labels'`` and ``'predictions'``
            are analysed. Label 0 is treated as real and label 1 as fake.

    Returns:
        DataFrame with one row per analysed (model, dataset) pair and columns
        model, dataset, total_samples, accuracy, correct, incorrect,
        real_samples, fake_samples, real_accuracy, fake_accuracy,
        real_to_fake (real predicted as fake) and fake_to_real (fake predicted
        as real); None if nothing could be analysed. A per-class accuracy is
        reported as 0 when that class has no samples.
    """
    if not raw_predictions:
        print("No predictions to analyze.")
        return None

    error_analysis = []

    for model_name, model_preds in raw_predictions.items():
        for dataset_name, data in model_preds.items():
            if 'labels' not in data or 'predictions' not in data:
                continue

            # Accept plain lists (e.g. loaded from JSON) as well as arrays, and
            # flatten so an (n, 1) column cannot broadcast against an (n,) vector.
            labels = np.asarray(data['labels']).ravel()
            predictions = np.asarray(data['predictions']).ravel()
            if labels.shape != predictions.shape:
                print(f"Skipping {model_name}/{dataset_name}: {labels.size} labels "
                      f"but {predictions.size} predictions.")
                continue

            total = int(labels.size)
            if total == 0:
                # Nothing to score; avoids a 0/0 accuracy.
                continue

            correct = int((predictions == labels).sum())
            real_samples = int((labels == 0).sum())
            fake_samples = int((labels == 1).sum())
            real_correct = int(((predictions == 0) & (labels == 0)).sum())
            fake_correct = int(((predictions == 1) & (labels == 1)).sum())

            error_analysis.append({
                'model': model_name,
                'dataset': dataset_name,
                'total_samples': total,
                'accuracy': correct / total,
                'correct': correct,
                'incorrect': total - correct,
                'real_samples': real_samples,
                'fake_samples': fake_samples,
                'real_accuracy': real_correct / real_samples if real_samples > 0 else 0,
                'fake_accuracy': fake_correct / fake_samples if fake_samples > 0 else 0,
                'real_to_fake': int(((predictions == 1) & (labels == 0)).sum()),
                'fake_to_real': int(((predictions == 0) & (labels == 1)).sum()),
            })

    return pd.DataFrame(error_analysis) if error_analysis else None

def visualize_error_patterns(error_df):
    """Plot a 2x2 overview of error patterns from ``analyze_errors`` output.

    Panels: (1) overall vs per-class accuracy per model, (2) misclassification
    counts by direction per model, (3) real/fake class balance per dataset,
    (4) each class's error rate against that class's share of its dataset.

    Args:
        error_df: DataFrame with the columns produced by ``analyze_errors``
            (model, dataset, accuracy, real_accuracy, fake_accuracy,
            total_samples, real_samples, fake_samples, real_to_fake,
            fake_to_real).

    Returns:
        None. The figure is rendered with ``plt.show()``. ``error_df`` itself
        is left unmodified.
    """
    if error_df is None or error_df.empty:
        print("No error data to visualize.")
        return

    # Derive helper columns on a copy. Writing them into the caller's frame
    # would silently change a table that was already displayed and leak the
    # extra columns into any later analysis of `error_df`.
    df = error_df.copy()
    # Zero denominators become NaN, so empty classes yield NaN instead of inf.
    total = df['total_samples'].replace(0, np.nan)
    n_real = df['real_samples'].replace(0, np.nan)
    n_fake = df['fake_samples'].replace(0, np.nan)
    df['real_percent'] = df['real_samples'] / total * 100
    df['fake_percent'] = 100 - df['real_percent']
    df['real_error_rate'] = df['real_to_fake'] / n_real
    df['fake_error_rate'] = df['fake_to_real'] / n_fake

    fig, axes = plt.subplots(2, 2, figsize=(20, 15))
    ax1, ax2, ax3, ax4 = axes.ravel()

    # 1. Overall accuracy vs per-class accuracy
    accuracy_long = df.melt(
        id_vars=['model', 'dataset'],
        value_vars=['accuracy', 'real_accuracy', 'fake_accuracy'],
        var_name='Metric', value_name='Value',
    )
    sns.barplot(x='model', y='Value', hue='Metric', data=accuracy_long, ax=ax1)
    ax1.set_title('Overall vs Per-Class Accuracy')
    ax1.set_xlabel('Model')
    ax1.set_ylabel('Accuracy')
    plt.setp(ax1.get_xticklabels(), rotation=45, ha='right')

    # 2. Error distribution
    errors_long = df.melt(
        id_vars=['model', 'dataset'],
        value_vars=['real_to_fake', 'fake_to_real'],
        var_name='Error Type', value_name='Count',
    )
    sns.barplot(x='model', y='Count', hue='Error Type', data=errors_long, ax=ax2)
    ax2.set_title('Error Distribution')
    ax2.set_xlabel('Model')
    ax2.set_ylabel('Error Count')
    plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')

    # 3. Class balance
    balance_long = df.melt(
        id_vars=['dataset'],
        value_vars=['real_percent', 'fake_percent'],
        var_name='Class', value_name='Percentage',
    )
    sns.barplot(x='dataset', y='Percentage', hue='Class', data=balance_long, ax=ax3)
    ax3.set_title('Class Balance')
    ax3.set_xlabel('Dataset')
    ax3.set_ylabel('Percentage')

    # 4. Error rate vs class balance
    rates_long = df.melt(
        id_vars=['model', 'dataset', 'real_percent', 'fake_percent'],
        value_vars=['real_error_rate', 'fake_error_rate'],
        var_name='Error Type', value_name='Error Rate',
    )
    # Pair each error rate with the share of the class it refers to.
    rates_long['class_percent'] = np.where(
        rates_long['Error Type'] == 'real_error_rate',
        rates_long['real_percent'],
        rates_long['fake_percent'],
    )
    # sort=False keeps models in first-appearance order.
    for model, model_rows in rates_long.groupby('model', sort=False):
        ax4.scatter(model_rows['class_percent'], model_rows['Error Rate'], label=model)

    ax4.set_title('Error Rate vs Class Percentage')
    ax4.set_xlabel('Class Percentage')
    ax4.set_ylabel('Error Rate')
    ax4.legend()
    ax4.grid(True, linestyle='--', alpha=0.7)

    fig.suptitle('Error Analysis', fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

# Analyze errors.
# Start from None so that a re-run without predictions cannot leave a stale
# `error_df` in the kernel for later cells (e.g. the generalization cell).
error_df = None
if 'raw_predictions' in locals() and raw_predictions:
    error_df = analyze_errors(raw_predictions)

    if error_df is not None:
        print("\nError Analysis:")
        display(error_df)

        # Visualize error patterns
        visualize_error_patterns(error_df)
    else:
        print("Could not perform error analysis.")
else:
    print("No raw predictions available for error analysis.")

## 3. Visualize ROC Curves and PR Curves

def plot_roc_curves(raw_predictions, dataset_name=None):
    """Plot ROC curves from stored prediction scores.

    Args:
        raw_predictions: Mapping ``{model_name: {dataset_name: data}}``. A
            ``data`` entry is plotted only if it has ``'labels'`` and a
            non-None ``'probabilities'`` value (presumably positive-class
            scores; TODO confirm against the evaluation script).
        dataset_name: If given, draw one curve per model for this dataset
            only. If None, draw one curve per (model, dataset) pair.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    if not raw_predictions:
        print("No predictions to plot.")
        return

    # (legend prefix, dataset, data) for every candidate curve, so both modes
    # share a single plotting loop.
    candidates = []
    for model_name, model_preds in raw_predictions.items():
        if dataset_name:
            if dataset_name in model_preds:
                candidates.append((model_name, dataset_name, model_preds[dataset_name]))
        else:
            candidates.extend(
                (f'{model_name} - {ds_name}', ds_name, data)
                for ds_name, data in model_preds.items()
            )

    fig, ax = plt.subplots(figsize=(12, 8))
    datasets_plotted = set()

    for prefix, ds_name, data in candidates:
        if 'labels' not in data or 'probabilities' not in data or data['probabilities'] is None:
            continue
        labels = np.asarray(data['labels'])
        # With only one class present the ROC curve is undefined (sklearn warns
        # and returns NaN rates), so skip it rather than draw an empty curve.
        if np.unique(labels).size < 2:
            print(f"Skipping {prefix}: ROC curve needs both classes in the labels.")
            continue
        fpr, tpr, _ = roc_curve(labels, data['probabilities'])
        roc_auc = auc(fpr, tpr)
        ax.plot(fpr, tpr, lw=2, label=f'{prefix} (AUC = {roc_auc:.3f})')
        datasets_plotted.add(ds_name)

    # Chance-level reference line
    ax.plot([0, 1], [0, 1], 'k--', lw=2)

    if dataset_name and dataset_name in datasets_plotted:
        ax.set_title(f'ROC Curves - {dataset_name}')
    else:
        ax.set_title('ROC Curves')

    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.grid(True, linestyle='--', alpha=0.7)
    # Only the model curves carry labels; skip the legend if none were drawn.
    if datasets_plotted:
        ax.legend(loc="lower right")

    plt.show()

def plot_pr_curves(raw_predictions, dataset_name=None):
    """Plot precision-recall curves from stored prediction scores.

    Each curve is summarised in the legend by average precision (AP, from
    ``sklearn.metrics.average_precision_score``). AP is used instead of a
    trapezoidal ``auc(recall, precision)``, because linear interpolation
    between PR points overstates the area under a PR curve.

    Args:
        raw_predictions: Mapping ``{model_name: {dataset_name: data}}``. A
            ``data`` entry is plotted only if it has ``'labels'`` and a
            non-None ``'probabilities'`` value (presumably positive-class
            scores; TODO confirm against the evaluation script).
        dataset_name: If given, draw one curve per model for this dataset
            only. If None, draw one curve per (model, dataset) pair.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    if not raw_predictions:
        print("No predictions to plot.")
        return

    # Local import: the notebook's import cell only brings in roc_curve,
    # precision_recall_curve and auc from sklearn.metrics.
    from sklearn.metrics import average_precision_score

    # (legend prefix, dataset, data) for every candidate curve, so both modes
    # share a single plotting loop.
    candidates = []
    for model_name, model_preds in raw_predictions.items():
        if dataset_name:
            if dataset_name in model_preds:
                candidates.append((model_name, dataset_name, model_preds[dataset_name]))
        else:
            candidates.extend(
                (f'{model_name} - {ds_name}', ds_name, data)
                for ds_name, data in model_preds.items()
            )

    fig, ax = plt.subplots(figsize=(12, 8))
    datasets_plotted = set()

    for prefix, ds_name, data in candidates:
        if 'labels' not in data or 'probabilities' not in data or data['probabilities'] is None:
            continue
        labels = np.asarray(data['labels'])
        # Precision/recall are ill-defined without both classes (sklearn warns).
        if np.unique(labels).size < 2:
            print(f"Skipping {prefix}: PR curve needs both classes in the labels.")
            continue
        scores = data['probabilities']
        precision, recall, _ = precision_recall_curve(labels, scores)
        avg_precision = average_precision_score(labels, scores)
        ax.plot(recall, precision, lw=2, label=f'{prefix} (AP = {avg_precision:.3f})')
        datasets_plotted.add(ds_name)

    if dataset_name and dataset_name in datasets_plotted:
        ax.set_title(f'Precision-Recall Curves - {dataset_name}')
    else:
        ax.set_title('Precision-Recall Curves')

    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.grid(True, linestyle='--', alpha=0.7)
    if datasets_plotted:
        ax.legend(loc="lower left")

    plt.show()

# Plot ROC and PR curves
if 'raw_predictions' in locals() and raw_predictions:
    # Union of dataset names across models. Sorted so the output order is
    # stable between runs; iterating a set of strings is not, because string
    # hashing is randomized per interpreter process.
    available_datasets = sorted(
        {ds for model_preds in raw_predictions.values() for ds in model_preds}
    )

    print(f"Available datasets: {', '.join(available_datasets)}")

    # Plot curves for each dataset
    for dataset in available_datasets:
        print(f"\nROC and PR curves for {dataset} dataset:")
        plot_roc_curves(raw_predictions, dataset)
        plot_pr_curves(raw_predictions, dataset)
else:
    print("No raw predictions available for plotting curves.")


def rank_models(performance_df, metrics=None, lower_is_better=('eer',)):
    """Rank models within each dataset and average their ranks across metrics.

    Args:
        performance_df: DataFrame with 'model' and 'dataset' columns plus one
            column per metric.
        metrics: Metric columns to rank on. Defaults to accuracy, precision,
            recall, f1_score, auc and eer. Names that are not columns of
            ``performance_df`` are ignored.
        lower_is_better: Collection of metric names for which a smaller value
            ranks higher (default: only 'eer'). All other metrics rank larger
            values higher.

    Returns:
        DataFrame with columns model, dataset, the ranked metrics and
        avg_rank (1 = best), sorted by dataset then avg_rank. Returns None if
        there is no data or no usable metric.
    """
    if performance_df is None or performance_df.empty:
        print("No performance data to rank.")
        return None

    if metrics is None:
        metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'auc', 'eer']
    # Builds a new list, so a caller-supplied list is never modified.
    metrics = [m for m in metrics if m in performance_df.columns]

    if not metrics:
        print("No valid metrics to rank.")
        return None

    ascending_metrics = set(lower_is_better or ())
    df = performance_df.copy()

    for metric in metrics:
        # Rank 1 = best model on that dataset. NaN values stay NaN (unranked)
        # and are skipped by the row-wise mean below.
        df[f'{metric}_rank'] = df.groupby('dataset')[metric].rank(
            ascending=metric in ascending_metrics
        )

    rank_cols = [f'{metric}_rank' for metric in metrics]
    df['avg_rank'] = df[rank_cols].mean(axis=1)
    df = df.sort_values(['dataset', 'avg_rank'])

    return df[['model', 'dataset'] + metrics + ['avg_rank']].reset_index(drop=True)

# Rank models
if 'performance_df' in locals() and performance_df is not None and not performance_df.empty:
    ranked_df = rank_models(performance_df)

    if ranked_df is not None:
        print("\nModel Rankings:")
        display(ranked_df)
    else:
        print("Could not rank models.")
else:
    print("No performance data to rank.")


def visualize_performance_metrics(performance_df):
    """Draw one bar chart per metric comparing models, coloured by dataset.

    Args:
        performance_df: DataFrame with 'model' and 'dataset' columns and any of
            accuracy, precision, recall, f1_score, auc, eer (the layout built
            by ``create_performance_summary``).

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    if performance_df is None or performance_df.empty:
        print("No performance data to visualize.")
        return

    # Keep only metrics that are present and have at least one value. An
    # all-NaN column (e.g. 'eer' when no result reported it) would otherwise
    # produce an empty panel, and missing columns would leave holes in the grid.
    metrics = [
        m for m in ['accuracy', 'precision', 'recall', 'f1_score', 'auc', 'eer']
        if m in performance_df.columns and performance_df[m].notna().any()
    ]
    if not metrics:
        print("No performance data to visualize.")
        return

    n_cols = 2
    n_rows = (len(metrics) + n_cols - 1) // n_cols
    fig = plt.figure(figsize=(20, 5 * n_rows))

    for i, metric in enumerate(metrics, start=1):
        ax = fig.add_subplot(n_rows, n_cols, i)
        sns.barplot(x='model', y=metric, hue='dataset', data=performance_df, ax=ax)

        ax.set_title(metric.replace("_", " ").title())
        ax.set_xlabel('Model')
        ax.set_ylabel(metric)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        values = pd.to_numeric(performance_df[metric], errors='coerce')
        if metric == 'eer':
            # Lower is better. Keep the usual 0-0.5 window, but extend it so a
            # model with EER above 0.5 is not clipped.
            ax.set_ylim(0, max(0.5, values.max() + 0.05))
        else:
            # Zoom to the 0.5-1 band by default, extending down (never below 0)
            # when a score is under 0.5 so its bar stays visible.
            ax.set_ylim(max(0.0, min(0.5, values.min() - 0.05)), 1.05)

        ax.legend(title='Dataset')

    fig.suptitle('Model Performance Comparison', fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

# Visualize performance metrics (uses `performance_df` from the summary cell)
if not ('performance_df' in locals() and performance_df is not None and not performance_df.empty):
    print("No performance data to visualize.")
else:
    visualize_performance_metrics(performance_df)

## 2. Analyze Model Performance Metrics

def create_performance_summary(evaluation_results):
    """Flatten nested evaluation results into a per-(model, dataset) table.

    Args:
        evaluation_results: Mapping ``{model_name: {dataset_name: metrics}}``,
            such as the one built from each model's ``all_results.json``.
            ``metrics`` is a dict that may hold any of accuracy, precision,
            recall, f1_score, auc and eer. Entries that are not dicts at
            either level are skipped.

    Returns:
        DataFrame with columns model, dataset, accuracy, precision, recall,
        f1_score, auc, eer. Metric columns are float, with NaN where a value is
        missing or not numeric. Returns None if there is nothing to summarise.
    """
    if not evaluation_results:
        return None

    metric_names = ['accuracy', 'precision', 'recall', 'f1_score', 'auc', 'eer']

    def _as_float(value):
        # None, non-numeric strings and nested structures become NaN instead
        # of aborting the whole summary.
        try:
            return float(value)
        except (TypeError, ValueError):
            return np.nan

    performance_data = []
    for model_name, model_results in evaluation_results.items():
        if not isinstance(model_results, dict):
            continue
        for dataset_name, dataset_results in model_results.items():
            if not isinstance(dataset_results, dict):
                continue
            row = {'model': model_name, 'dataset': dataset_name}
            for name in metric_names:
                row[name] = _as_float(dataset_results.get(name))
            performance_data.append(row)

    if not performance_data:
        return None

    return pd.DataFrame(performance_data, columns=['model', 'dataset'] + metric_names)

# Create performance summary
if evaluation_results:
    performance_df = create_performance_summary(evaluation_results)
    
    if performance_df is not None:
        print("Performance Summary:")
        display(performance_df)
    else:
        print("Could not create performance summary.")
else:
    print("No evaluation results to analyze."){
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deepfake Detection Results Analysis\n",
    "\n",
    "This notebook provides tools for analyzing the results of deepfake detection models. It includes:\n",
    "\n",
    "1. Loading and visualizing evaluation metrics\n",
    "2. Comparative analysis of different models\n",
    "3. Error analysis and identification of challenging cases\n",
    "4. Cross-dataset generalization analysis\n",
    "5. ROC curves and precision-recall analysis\n",
    "6. Ensembling and fusion performance assessment\n",
    "\n",
    "These analyses reveal the strengths and weaknesses of different approaches and help guide future improvements."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Import necessary libraries\n",
    "import os\n",
    "import sys\n",
    "import json\n",
    "import yaml\n",
    "import glob\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from PIL import Image\n",
    "import cv2\n",
    "from pathlib import Path\n",
    "from sklearn.metrics import roc_curve, precision_recall_curve, auc\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# Add parent directory to path to enable imports from project\n",
    "sys.path.append(os.path.abspath('..'))\n",
    "\n",
    "# Set plot style\n",
    "plt.style.use('fivethirtyeight')\n",
    "sns.set(style=\"whitegrid\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load Evaluation Results\n",
    "\n",
    "Load evaluation results from the evaluation output directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Set the evaluation results directory - update to your local path\n",
    "EVAL_DIR = \"../evaluation_results\"\n",
    "\n",
    "# Check if directory exists\n",
    "if not os.path.exists(EVAL_DIR):\n",
    "    print(f\"Evaluation directory {EVAL_DIR} does not exist. Please update the path.\")\n",
    "else:\n",
    "    print(f\"Found evaluation directory at {EVAL_DIR}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def load_evaluation_results(eval_dir):\n",
    "    \"\"\"Load evaluation results from directory\"\"\"\n",
    "    if not os.path.exists(eval_dir):\n",
    "        return None\n",
    "    \n",
    "    results = {}\n",
    "    \n",
    "    # Find all model directories\n",
    "    model_dirs = [d for d in os.listdir(eval_dir) if os.path.isdir(os.path.join(eval_dir, d))]\n",
    "    \n",
    "    for model_name in model_dirs:\n",
    "        model_path = os.path.join(eval_dir, model_name)\n",
    "        results_file = os.path.join(model_path, \"all_results.json\")\n",
    "        \n",
    "        if os.path.exists(results_file):\n",
    "            with open(results_file, 'r') as f:\n",
    "                model_results = json.load(f)\n",
    "            \n",
    "            # Store results\n",
    "            results[model_name] = model_results\n",
    "    \n",
    "    return results\n",
    "\n",
    "# Load results\n",
    "evaluation_results = load_evaluation_results(EVAL_DIR)\n",
    "\n",
    "if evaluation_results is None:\n",
    "    print(\"No evaluation results found.\")\n",
    "elif not evaluation_results:\n",
    "    print(\"No models found in evaluation results.\")\n",
    "else:\n",
    "    print(f\"Loaded evaluation results for {len(evaluation_results)} models:\")\n",
    "    for model_name in evaluation_results.keys():\n",
    "        print(f\"  - {model_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def load_raw_predictions(eval_dir):\n",
    "    \"\"\"Load raw predictions for detailed analysis\"\"\"\n",
    "    if not os.path.exists(eval_dir):\n",
    "        return None\n",
    "    \n",
    "    predictions = {}\n",
    "    \n",
    "    # Find all model directories\n",
    "    model_dirs = [d for d in os.listdir(eval_dir) if os.path.isdir(os.path.join(eval_dir, d))]\n",
    "    \n",
    "    for model_name in model_dirs:\n",
    "        model_path = os.path.join(eval_dir, model_name)\n",
    "        model_predictions = {}\n",
    "        \n",
    "        # Find dataset directories\n",
    "        dataset_dirs = [d for d in os.listdir(model_path) \n",
    "                       if os.path.isdir(os.path.join(model_path, d))]\n",
    "        \n",
    "        for dataset_name in dataset_dirs:\n",
    "            dataset_path = os.path.join(model_path, dataset_name)\n",
    "            pred_file = os