# In [None]:
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
from matplotlib.patches import Rectangle
import matplotlib.patches as mpatches

# Start from matplotlib's stock style, then make seaborn's "husl" palette the
# default color cycle (used wherever a plot passes no explicit color).
plt.style.use('default')
sns.set_palette("husl")

# Thesis-quality rcParams shared by every figure below.
PLOT_CONFIG = {
    # Resolution for on-screen figures and saved files
    'figure.dpi': 300,
    'savefig.dpi': 300,
    # Typography
    'font.size': 16,
    'axes.titlesize': 20,
    'axes.labelsize': 18,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'legend.fontsize': 16,
    # Default canvas size (inches)
    'figure.figsize': (12, 8),
    # Line and marker weight
    'lines.linewidth': 3,
    'lines.markersize': 10,
    # Faint grid
    'grid.alpha': 0.3,
}

plt.rcParams.update(PLOT_CONFIG)

# Fixed hex colors so the same strategy / architecture keeps its hue across
# figures. The architecture entries reuse the strategy hues (blue, purple,
# orange) and add red for the LSTM.
COLORS = {
    # Pruning strategies
    'BNScale': '#2E86AB',      # blue
    'MagnitudeL2': '#A23B72',  # purple / pink
    'Random': '#F18F01',       # orange
    # Model architectures
    'MobileNetV2': '#2E86AB',  # blue
    'ResNet-18': '#A23B72',    # purple / pink
    'MLP': '#F18F01',          # orange
    'LSTM': '#C73E1D',         # red
}

def load_all_results(data_dir=None):
    """Load the four experiment result files produced by the pruning runs.

    Args:
        data_dir: Optional directory holding the JSON files. ``None`` (the
            default) looks them up relative to the current working directory,
            exactly as before.

    Returns:
        Tuple ``(mobilenet, resnet, mlp, lstm)`` of the parsed JSON objects.
        If any file is missing, the error is printed and
        ``(None, None, None, None)`` is returned instead.
    """
    filenames = (
        'mobileNetV2_results.json',
        'resnet18_results.json',
        'mlp_results.json',
        'lstm_results.json',
    )
    loaded = []
    try:
        for name in filenames:
            path = name if data_dir is None else os.path.join(data_dir, name)
            # JSON is UTF-8; open()'s default encoding depends on the platform.
            with open(path, 'r', encoding='utf-8') as f:
                loaded.append(json.load(f))
    except FileNotFoundError as e:
        print(f"Error loading files: {e}")
        return None, None, None, None
    return tuple(loaded)

def results_to_dataframe(results, model_name, task_type='classification'):
    """Flatten a nested results dict into one DataFrame row per (strategy, ratio).

    Args:
        results: Mapping ``{strategy: {pruning_ratio_str: metrics}}``. Every
            ``metrics`` dict must provide ``macs``, ``params`` and ``loss``;
            classification runs must also provide ``accuracy``.
        model_name: Value written to the ``model`` column of every row.
        task_type: ``'classification'`` adds an ``accuracy`` column. Any other
            value is treated as regression and adds ``mse`` and ``mae``.

    Returns:
        DataFrame with columns ``model, strategy, pruning_ratio,
        pruning_ratio_percent, macs, macs_millions, params, params_millions,
        loss`` followed by the task-specific metric columns. An empty
        ``results`` still yields these columns, so later column lookups do
        not raise ``KeyError``.
    """
    is_classification = task_type == 'classification'
    columns = [
        'model', 'strategy', 'pruning_ratio', 'pruning_ratio_percent',
        'macs', 'macs_millions', 'params', 'params_millions', 'loss',
    ]
    columns += ['accuracy'] if is_classification else ['mse', 'mae']

    data = []
    for strategy, strategy_results in results.items():
        for pruning_ratio_str, metrics in strategy_results.items():
            pruning_ratio = float(pruning_ratio_str)
            macs = float(metrics['macs'])
            params = int(metrics['params'])
            loss = float(metrics['loss'])
            row = {
                'model': model_name,
                'strategy': strategy,
                'pruning_ratio': pruning_ratio,
                'pruning_ratio_percent': pruning_ratio * 100,
                'macs': macs,
                'macs_millions': macs / 1e6,
                'params': params,
                'params_millions': params / 1e6,
                'loss': loss,
            }

            if is_classification:
                row['accuracy'] = float(metrics['accuracy'])
            else:  # regression
                # Use the loss when no separate MSE was recorded.
                row['mse'] = float(metrics.get('mse', loss))
                # A missing MAE is unknown, not perfect. NaN makes plots skip
                # the point instead of drawing a misleading 0.
                row['mae'] = float(metrics.get('mae', np.nan))

            data.append(row)

    return pd.DataFrame(data, columns=columns)

def save_plot(filename, output_dir='thesis_plots', dpi=300):
    """Save the current matplotlib figure into ``output_dir`` and close it.

    Args:
        filename: Output file name, e.g. ``'01_baseline_comparison.png'``.
        output_dir: Target directory; created if it does not exist.
        dpi: Resolution passed to ``savefig`` (defaults to the previous
            hard-coded 300).

    Returns:
        The path the figure was written to.
    """
    os.makedirs(output_dir, exist_ok=True)
    filepath = os.path.join(output_dir, filename)
    try:
        plt.tight_layout()
        plt.savefig(filepath, dpi=dpi, bbox_inches='tight', facecolor='white')
    finally:
        # Close even if saving fails, so a long run of plot functions does
        # not accumulate open figures.
        plt.close()
    print(f"✅ Saved: {filepath}")
    return filepath

# 1. Baseline Performance Comparison
def plot_baseline_comparison(mobilenet_df, resnet_df, mlp_df, lstm_df):
    """Compare the unpruned (``pruning_ratio == 0.0``) models side by side.

    Draws a 2x2 grid: baseline accuracy of the classification models, baseline
    MSE of the regression models, and MACs and parameter counts (log scale)
    of all four. The grid is saved as ``01_baseline_comparison.png``.

    Args:
        mobilenet_df, resnet_df: Classification frames (need ``accuracy``).
        mlp_df, lstm_df: Regression frames (need ``mse``).
        All four also need ``pruning_ratio``, ``macs_millions`` and
        ``params_millions``. Each model's baseline is its first row with
        ``pruning_ratio == 0.0``.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

    def fmt_value(value):
        # One decimal for values >= 1 (as before), three significant digits
        # below that, so a small MSE or a sub-million count is not shown as "0.0".
        return f'{value:.1f}' if abs(value) >= 1 else f'{value:.3g}'

    def label_bars(ax, bars, suffix, fontsize):
        # Write each bar's height just above the bar.
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., height,
                    f'{fmt_value(height)}{suffix}', ha='center', va='bottom',
                    fontsize=fontsize)

    def colors_for(models):
        # Look colors up by model name so they stay right whichever rows
        # survive the filters below.
        return [COLORS[model] for model in models]

    # Get baseline data (pruning_ratio = 0.0)
    baselines = []
    for df, model in [(mobilenet_df, 'MobileNetV2'), (resnet_df, 'ResNet-18'),
                      (mlp_df, 'MLP'), (lstm_df, 'LSTM')]:
        baseline = df[df['pruning_ratio'] == 0.0].iloc[0]
        baselines.append({
            'model': model,
            # Series.get yields None when the frame lacks that metric column.
            'accuracy': baseline.get('accuracy', None),
            'mse': baseline.get('mse', None),
            'macs_millions': baseline['macs_millions'],
            'params_millions': baseline['params_millions'],
        })
    baseline_df = pd.DataFrame(baselines)

    # Accuracy comparison (models that report accuracy)
    cnn_models = baseline_df[baseline_df['accuracy'].notna()]
    bars1 = ax1.bar(cnn_models['model'], cnn_models['accuracy'],
                    color=colors_for(cnn_models['model']), alpha=0.8)
    ax1.set_ylabel('Accuracy (%)')
    ax1.set_title('Baseline Accuracy Comparison')
    ax1.grid(True, alpha=0.3)
    label_bars(ax1, bars1, '%', 14)

    # MSE comparison (models that report MSE)
    ts_models = baseline_df[baseline_df['mse'].notna()]
    bars2 = ax2.bar(ts_models['model'], ts_models['mse'],
                    color=colors_for(ts_models['model']), alpha=0.8)
    ax2.set_ylabel('MSE')
    ax2.set_title('Baseline MSE Comparison')
    ax2.grid(True, alpha=0.3)
    label_bars(ax2, bars2, '', 14)

    # MACs comparison
    bars3 = ax3.bar(baseline_df['model'], baseline_df['macs_millions'],
                    color=colors_for(baseline_df['model']), alpha=0.8)
    ax3.set_ylabel('MACs (Millions)')
    ax3.set_title('Baseline Computational Cost')
    ax3.set_yscale('log')
    ax3.grid(True, alpha=0.3)
    label_bars(ax3, bars3, 'M', 12)

    # Model parameters comparison
    bars4 = ax4.bar(baseline_df['model'], baseline_df['params_millions'],
                    color=colors_for(baseline_df['model']), alpha=0.8)
    ax4.set_ylabel('Parameters (Millions)')
    ax4.set_title('Baseline Model Parameters')
    ax4.set_yscale('log')
    ax4.grid(True, alpha=0.3)
    label_bars(ax4, bars4, 'M', 12)

    plt.suptitle('Baseline Performance Comparison Across All Models', fontsize=24, fontweight='bold')
    save_plot('01_baseline_comparison.png')

# 2. MobileNetV2 Accuracy vs Pruning Ratio
def plot_mobilenetv2_accuracy_pruning_ratio(mobilenet_df):
    """Draw one accuracy-vs-pruning-ratio curve per strategy for MobileNetV2.

    The y-range is the observed accuracy range padded by one point on each
    side. Saved as ``02_mobilenetv2_accuracy_vs_pruning_ratio.png``.
    """
    plt.figure(figsize=(12, 8))

    strategy_col = mobilenet_df['strategy']
    for name in strategy_col.unique():
        curve = mobilenet_df.loc[strategy_col.eq(name)].sort_values('pruning_ratio')
        plt.plot(curve['pruning_ratio_percent'], curve['accuracy'], 'o-',
                 linewidth=3, markersize=10, label=name, color=COLORS[name])

    plt.xlabel('Pruning Ratio (%)')
    plt.ylabel('Accuracy (%)')
    plt.title('MobileNetV2: Accuracy vs Pruning Ratio', fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Pad the observed accuracy range by one point on each side.
    accuracy = mobilenet_df['accuracy']
    lower, upper = accuracy.min() - 1, accuracy.max() + 1
    plt.ylim(lower, upper)

    save_plot('02_mobilenetv2_accuracy_vs_pruning_ratio.png')

# 3. ResNet-18 Accuracy vs Pruning Ratio
def plot_resnet18_accuracy_pruning_ratio(resnet_df):
    """Draw one accuracy-vs-pruning-ratio curve per strategy for ResNet-18.

    The y-range is the observed accuracy range padded by one point on each
    side. Saved as ``03_resnet18_accuracy_vs_pruning_ratio.png``.
    """
    plt.figure(figsize=(12, 8))

    for strategy_name in resnet_df['strategy'].unique():
        selected = resnet_df.loc[resnet_df['strategy'] == strategy_name]
        ordered = selected.sort_values('pruning_ratio')
        plt.plot(ordered['pruning_ratio_percent'], ordered['accuracy'], 'o-',
                 linewidth=3, markersize=10, label=strategy_name,
                 color=COLORS[strategy_name])

    plt.xlabel('Pruning Ratio (%)')
    plt.ylabel('Accuracy (%)')
    plt.title('ResNet-18: Accuracy vs Pruning Ratio', fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # One accuracy point of headroom above and below the data.
    acc_floor = resnet_df['accuracy'].min()
    acc_ceiling = resnet_df['accuracy'].max()
    plt.ylim(acc_floor - 1, acc_ceiling + 1)

    save_plot('03_resnet18_accuracy_vs_pruning_ratio.png')

# 4. MLP Performance vs Pruning Ratio
def plot_mlp_performance_pruning_ratio(mlp_df):
    """Side-by-side MLP panels: MSE (left) and MAE (right) vs pruning ratio.

    Each panel has one curve per strategy. Saved as
    ``04_mlp_performance_vs_pruning_ratio.png``.
    """
    _, axes = plt.subplots(1, 2, figsize=(20, 8))

    panels = (('mse', 'MSE'), ('mae', 'MAE'))
    for ax, (column, metric_label) in zip(axes, panels):
        for strategy in mlp_df['strategy'].unique():
            subset = mlp_df.loc[mlp_df['strategy'] == strategy].sort_values('pruning_ratio')
            ax.plot(subset['pruning_ratio_percent'], subset[column], 'o-',
                    linewidth=3, markersize=10, label=strategy, color=COLORS[strategy])

        ax.set_xlabel('Pruning Ratio (%)')
        ax.set_ylabel(metric_label)
        ax.set_title(f'MLP: {metric_label} vs Pruning Ratio', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

    save_plot('04_mlp_performance_vs_pruning_ratio.png')

# 5. LSTM Performance vs Pruning Ratio
def plot_lstm_performance_pruning_ratio(lstm_df):
    """Side-by-side LSTM panels: MSE (left) and MAE (right) vs pruning ratio.

    Each panel has one curve per strategy. Saved as
    ``05_lstm_performance_vs_pruning_ratio.png``.
    """
    _, (mse_ax, mae_ax) = plt.subplots(1, 2, figsize=(20, 8))

    for target_ax, column, metric_label in ((mse_ax, 'mse', 'MSE'),
                                            (mae_ax, 'mae', 'MAE')):
        for strategy in lstm_df['strategy'].unique():
            rows = lstm_df[lstm_df['strategy'] == strategy]
            rows = rows.sort_values('pruning_ratio')
            target_ax.plot(rows['pruning_ratio_percent'], rows[column], 'o-',
                           linewidth=3, markersize=10, label=strategy,
                           color=COLORS[strategy])

        target_ax.set_xlabel('Pruning Ratio (%)')
        target_ax.set_ylabel(metric_label)
        target_ax.set_title(f'LSTM: {metric_label} vs Pruning Ratio', fontweight='bold')
        target_ax.legend()
        target_ax.grid(True, alpha=0.3)

    save_plot('05_lstm_performance_vs_pruning_ratio.png')

# 6. MobileNetV2 Efficiency Frontier (Accuracy vs MACs)
def plot_mobilenetv2_efficiency_frontier(mobilenet_df):
    """Scatter MobileNetV2 accuracy against MACs, one dashed series per strategy.

    Points of each strategy are joined in order of increasing MACs. Saved as
    ``06_mobilenetv2_efficiency_frontier.png``.
    """
    plt.figure(figsize=(12, 8))

    for strategy in mobilenet_df['strategy'].unique():
        points = mobilenet_df.loc[mobilenet_df['strategy'] == strategy].sort_values('macs_millions')
        xs, ys = points['macs_millions'], points['accuracy']
        hue = COLORS[strategy]
        plt.scatter(xs, ys, s=150, label=strategy, color=hue, alpha=0.8,
                    edgecolors='black', linewidth=1)
        plt.plot(xs, ys, linestyle='--', alpha=0.7, linewidth=2, color=hue)

    plt.xlabel('MACs (Millions)')
    plt.ylabel('Accuracy (%)')
    plt.title('MobileNetV2: Accuracy vs Computational Cost (Pareto Frontier)', fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)

    save_plot('06_mobilenetv2_efficiency_frontier.png')

# 7. ResNet-18 Efficiency Frontier (Accuracy vs MACs)
def plot_resnet18_efficiency_frontier(resnet_df):
    """Scatter ResNet-18 accuracy against MACs, one dashed series per strategy.

    Points of each strategy are joined in order of increasing MACs. Saved as
    ``07_resnet18_efficiency_frontier.png``.
    """
    plt.figure(figsize=(12, 8))

    for name in resnet_df['strategy'].unique():
        frontier = resnet_df[resnet_df['strategy'] == name]
        frontier = frontier.sort_values('macs_millions')
        macs_axis = frontier['macs_millions']
        acc_axis = frontier['accuracy']
        plt.scatter(macs_axis, acc_axis,
                    s=150, label=name, color=COLORS[name], alpha=0.8,
                    edgecolors='black', linewidth=1)
        plt.plot(macs_axis, acc_axis,
                 linestyle='--', alpha=0.7, linewidth=2, color=COLORS[name])

    plt.xlabel('MACs (Millions)')
    plt.ylabel('Accuracy (%)')
    plt.title('ResNet-18: Accuracy vs Computational Cost (Pareto Frontier)', fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)

    save_plot('07_resnet18_efficiency_frontier.png')

# 8. Time-Series Models Efficiency Frontier (NEW)
def plot_timeseries_efficiency_frontier(mlp_df, lstm_df):
    """Overlay MSE-vs-MACs series of every strategy for the MLP and the LSTM.

    MLP points are circles joined by dashed lines; LSTM points are squares
    joined by dotted lines. Colors follow the pruning strategy. Saved as
    ``08_timeseries_efficiency_frontier.png``.
    """
    plt.figure(figsize=(12, 8))

    series_specs = (
        (mlp_df, 'MLP', 'o', '--'),
        (lstm_df, 'LSTM', 's', ':'),
    )
    for frame, prefix, marker, line_style in series_specs:
        for strategy in frame['strategy'].unique():
            points = frame[frame['strategy'] == strategy].sort_values('macs_millions')
            shade = COLORS[strategy]
            plt.scatter(points['macs_millions'], points['mse'],
                        s=150, label=f'{prefix}-{strategy}', color=shade,
                        alpha=0.8, edgecolors='black', linewidth=1, marker=marker)
            plt.plot(points['macs_millions'], points['mse'],
                     linestyle=line_style, alpha=0.7, linewidth=2, color=shade)

    plt.xlabel('MACs (Millions)')
    plt.ylabel('MSE')
    plt.title('Time-Series Models: MSE vs Computational Cost (Pareto Frontier)', fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)

    save_plot('08_timeseries_efficiency_frontier.png')

# 9. Combined Efficiency Frontier (All Models) (NEW)
def plot_combined_efficiency_frontier(mobilenet_df, resnet_df, mlp_df, lstm_df):
    """Two panels comparing architectures on metric vs MACs (log x-axis).

    Left: accuracy of MobileNetV2 and ResNet-18. Right: MSE of MLP and LSTM.
    Each model contributes one series, taken from its MagnitudeL2 runs when
    present, otherwise from its first listed strategy. Saved as
    ``09_combined_efficiency_frontier.png``.
    """
    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

    panels = (
        (ax1, 'accuracy', 'Accuracy (%)',
         'Classification Models: Combined Efficiency Frontier',
         ((mobilenet_df, 'MobileNetV2'), (resnet_df, 'ResNet-18'))),
        (ax2, 'mse', 'MSE',
         'Regression Models: Combined Efficiency Frontier',
         ((mlp_df, 'MLP'), (lstm_df, 'LSTM'))),
    )

    for ax, metric, y_label, title, entries in panels:
        for frame, model_name in entries:
            available = frame['strategy'].unique()
            chosen = 'MagnitudeL2' if 'MagnitudeL2' in available else available[0]
            series = frame[frame['strategy'] == chosen].sort_values('macs_millions')
            tint = COLORS[model_name]
            ax.scatter(series['macs_millions'], series[metric],
                       s=150, label=model_name, color=tint,
                       alpha=0.8, edgecolors='black', linewidth=1)
            ax.plot(series['macs_millions'], series[metric],
                    linestyle='--', alpha=0.7, linewidth=2, color=tint)

        ax.set_xlabel('MACs (Millions)')
        ax.set_ylabel(y_label)
        ax.set_title(title, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_xscale('log')

    save_plot('09_combined_efficiency_frontier.png')

# 10. CNN Models Comparison
def plot_cnn_comparison(mobilenet_df, resnet_df):
    """Overlay MobileNetV2 and ResNet-18 accuracy vs pruning ratio (MagnitudeL2 only).

    MobileNetV2 is drawn with circles, ResNet-18 with squares. Saved as
    ``10_cnn_comparison.png``.
    """
    plt.figure(figsize=(12, 8))

    curves = ((mobilenet_df, 'MobileNetV2', 'o-'),
              (resnet_df, 'ResNet-18', 's-'))
    for frame, label, fmt in curves:
        magnitude_rows = frame[frame['strategy'] == 'MagnitudeL2']
        ordered = magnitude_rows.sort_values('pruning_ratio')
        plt.plot(ordered['pruning_ratio_percent'], ordered['accuracy'], fmt,
                 linewidth=3, markersize=10, label=label, color=COLORS[label])

    plt.xlabel('Pruning Ratio (%)')
    plt.ylabel('Accuracy (%)')
    plt.title('CNN Architectures Comparison: Accuracy vs Pruning Ratio (MagnitudeL2)', fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)

    save_plot('10_cnn_comparison.png')

# 11. Time-Series Models Comparison
def plot_timeseries_comparison(mlp_df, lstm_df):
    """Overlay MLP and LSTM MSE vs pruning ratio (MagnitudeL2 only).

    MLP is drawn with circles, LSTM with squares. Saved as
    ``11_timeseries_comparison.png``.
    """
    plt.figure(figsize=(12, 8))

    for frame, label, fmt in ((mlp_df, 'MLP', 'o-'), (lstm_df, 'LSTM', 's-')):
        selection = frame.loc[frame['strategy'] == 'MagnitudeL2'].sort_values('pruning_ratio')
        plt.plot(selection['pruning_ratio_percent'], selection['mse'], fmt,
                 linewidth=3, markersize=10, label=label, color=COLORS[label])

    plt.xlabel('Pruning Ratio (%)')
    plt.ylabel('MSE')
    plt.title('Time-Series Models Comparison: MSE vs Pruning Ratio (MagnitudeL2)', fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)

    save_plot('11_timeseries_comparison.png')

# 12. Model Parameters Reduction Analysis
def plot_model_params_reduction(mobilenet_df, resnet_df, mlp_df, lstm_df):
    """Grouped bars of parameter counts at 0/20/50/70% pruning for every model.

    Each model uses its MagnitudeL2 runs when present, otherwise its first
    listed strategy. The y-axis is logarithmic. Saved as
    ``12_model_params_reduction.png``.
    """
    plt.figure(figsize=(14, 8))

    models = ['MobileNetV2', 'ResNet-18', 'MLP', 'LSTM']
    frames = [mobilenet_df, resnet_df, mlp_df, lstm_df]
    pruning_ratio_levels = [0.0, 0.2, 0.5, 0.7]

    x = np.arange(len(models))
    width = 0.2

    # Resolve each model's strategy once. The old fallback took the first row
    # at a given ratio from *any* strategy, which could mix strategies
    # across the bars of one model.
    strategy_rows = []
    for df in frames:
        available = df['strategy'].unique()
        strategy = 'MagnitudeL2' if 'MagnitudeL2' in available else available[0]
        strategy_rows.append(df[df['strategy'] == strategy])

    for i, pruning_ratio in enumerate(pruning_ratio_levels):
        params = [
            rows.loc[rows['pruning_ratio'] == pruning_ratio, 'params_millions'].values[0]
            for rows in strategy_rows
        ]

        bars = plt.bar(x + i * width, params, width,
                       label=f'{int(pruning_ratio*100)}% Pruning Ratio', alpha=0.8)

        # Add value labels on bars. With a fixed one-decimal format,
        # sub-million models would read "0.0M".
        for bar, param_count in zip(bars, params):
            text = f'{param_count:.1f}M' if param_count >= 1 else f'{param_count:.3g}M'
            plt.text(bar.get_x() + bar.get_width() / 2., bar.get_height(), text,
                     ha='center', va='bottom', fontsize=12)

    plt.xlabel('Model Architecture')
    plt.ylabel('Parameters (Millions)')
    plt.title('Parameter Reduction Across Architectures', fontweight='bold')
    # Center each tick under its group of bars.
    plt.xticks(x + width * (len(pruning_ratio_levels) - 1) / 2, models)
    plt.legend()
    plt.grid(True, alpha=0.3, axis='y')
    plt.yscale('log')

    save_plot('12_model_params_reduction.png')

# 13. MACs Reduction Analysis
def plot_macs_reduction(mobilenet_df, resnet_df, mlp_df, lstm_df):
    """Plot the share of baseline MACs each model keeps as pruning increases.

    Each curve uses the model's MagnitudeL2 runs when present, otherwise its
    first listed strategy, and is normalized by that strategy's 0% run. A
    dotted guide marks 50% retained. Saved as ``13_macs_reduction.png``.
    """
    plt.figure(figsize=(12, 8))

    lineup = ((mobilenet_df, 'MobileNetV2'), (resnet_df, 'ResNet-18'),
              (mlp_df, 'MLP'), (lstm_df, 'LSTM'))
    for frame, label in lineup:
        candidates = frame['strategy'].unique()
        chosen = 'MagnitudeL2' if 'MagnitudeL2' in candidates else candidates[0]
        curve = frame[frame['strategy'] == chosen].sort_values('pruning_ratio')

        reference = curve.loc[curve['pruning_ratio'] == 0.0, 'macs_millions'].values[0]
        retained_pct = (curve['macs_millions'] / reference) * 100

        plt.plot(curve['pruning_ratio_percent'], retained_pct, 'o-',
                 linewidth=3, markersize=10, label=label, color=COLORS[label])

    plt.xlabel('Pruning Ratio (%)')
    plt.ylabel('MACs Retained (%)')
    plt.title('Computational Cost Reduction Across Models', fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.ylim(0, 105)

    # Dotted guide at the halfway mark.
    plt.axhline(y=50, color='gray', linestyle=':', alpha=0.7, linewidth=2)
    plt.text(35, 52, '50% MACs Retained', fontsize=14, color='gray')

    save_plot('13_macs_reduction.png')

# 14. Performance Retention Analysis
def plot_performance_retention(mobilenet_df, resnet_df, mlp_df, lstm_df):
    """Grouped bars of each model's performance retained at 20/50/70% pruning.

    Retention is measured against the same strategy's unpruned run (MagnitudeL2
    when available, else the model's first strategy). For accuracy it is
    ``pruned / baseline * 100``; for MSE it is ``baseline / pruned * 100``. For
    both metrics, values above 100% mean the pruned model did better. Saved as
    ``14_performance_retention.png``.
    """
    plt.figure(figsize=(14, 8))

    pruning_ratio_levels = [0.2, 0.5, 0.7]
    model_specs = [
        (mobilenet_df, 'MobileNetV2', 'accuracy'),
        (resnet_df, 'ResNet-18', 'accuracy'),
        (mlp_df, 'MLP', 'mse'),
        (lstm_df, 'LSTM', 'mse'),
    ]

    x = np.arange(len(pruning_ratio_levels))
    width = 0.2

    for i, (df, model_name, metric) in enumerate(model_specs):
        # Use MagnitudeL2 if available; filter by strategy once per model.
        available = df['strategy'].unique()
        strategy = 'MagnitudeL2' if 'MagnitudeL2' in available else available[0]
        rows = df[df['strategy'] == strategy]
        baseline = rows[rows['pruning_ratio'] == 0.0].iloc[0]

        retention_rates = []
        for pruning_ratio in pruning_ratio_levels:
            pruned = rows[rows['pruning_ratio'] == pruning_ratio].iloc[0]
            if metric == 'accuracy':
                retention = (pruned['accuracy'] / baseline['accuracy']) * 100
            else:
                # Lower MSE is better, so invert the ratio.
                retention = (baseline['mse'] / pruned['mse']) * 100
            retention_rates.append(retention)

        bars = plt.bar(x + i * width, retention_rates, width,
                       label=model_name, alpha=0.8, color=COLORS[model_name])

        # Add value labels
        for bar, rate in zip(bars, retention_rates):
            plt.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                     f'{rate:.1f}%', ha='center', va='bottom', fontsize=12)

    plt.xlabel('Pruning Ratio (%)')
    plt.ylabel('Performance Retention (%)')
    plt.title('Performance Retention Across Models and Pruning Ratios', fontweight='bold')
    # Center each tick under its group of len(model_specs) bars.
    plt.xticks(x + width * (len(model_specs) - 1) / 2,
               [f'{int(s*100)}%' for s in pruning_ratio_levels])
    plt.legend()
    plt.grid(True, alpha=0.3, axis='y')

    # Reference line at 100% (no change from baseline)
    plt.axhline(y=100, color='black', linestyle='--', alpha=0.5, linewidth=2)

    save_plot('14_performance_retention.png')

# 15. Strategy Effectiveness Heatmap
def plot_strategy_effectiveness_heatmap(mobilenet_df, resnet_df):
    """Heatmaps of CNN accuracy relative to each strategy's unpruned run.

    Rows are strategies and columns are pruning ratios (%). Each cell is the
    accuracy divided by the same strategy's 0% accuracy, times 100. Saved as
    ``15_strategy_effectiveness_heatmap.png``.
    """
    _, axes = plt.subplots(1, 2, figsize=(20, 8))

    panels = zip(axes, (mobilenet_df, resnet_df), ('MobileNetV2', 'ResNet-18'))
    for ax, frame, model_name in panels:
        table = frame.pivot(index='strategy', columns='pruning_ratio_percent', values='accuracy')
        # Divide every row by its own 0% column, then express as a percentage.
        relative = table.div(table[0.0], axis=0) * 100

        sns.heatmap(relative, annot=True, fmt='.1f', cmap='RdYlGn',
                    center=100, ax=ax, cbar_kws={'label': 'Relative Accuracy (%)'})
        ax.set_title(f'{model_name}: Strategy Effectiveness', fontweight='bold')
        ax.set_xlabel('Pruning Ratio (%)')
        ax.set_ylabel('Pruning Strategy')

    save_plot('15_strategy_effectiveness_heatmap.png')

# 16. Comprehensive Strategy Effectiveness (NEW)
def plot_comprehensive_strategy_effectiveness(mobilenet_df, resnet_df, mlp_df, lstm_df):
    """2x2 heatmaps of relative performance per (strategy, pruning ratio) for every model.

    Each cell compares a pruned run with the same strategy's unpruned run:
    ``pruned / baseline * 100`` for accuracy, ``baseline / pruned * 100`` for
    MSE, so >100% is better for both. The color scale spans 90-105% and
    saturates outside it; every cell is annotated with its exact value. Saved
    as ``16_comprehensive_strategy_effectiveness.png``.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20, 16))

    # Identical for every panel, so build it once.
    pruning_ratios = [0.2, 0.5, 0.7]

    models_data = [
        (mobilenet_df, 'MobileNetV2', 'accuracy', ax1),
        (resnet_df, 'ResNet-18', 'accuracy', ax2),
        (mlp_df, 'MLP', 'mse', ax3),
        (lstm_df, 'LSTM', 'mse', ax4)
    ]

    for df, model_name, metric, ax in models_data:
        strategies = df['strategy'].unique()

        # Calculate strategy effectiveness (one row per strategy)
        effectiveness_data = []
        for strategy in strategies:
            rows = df[df['strategy'] == strategy]
            baseline = rows[rows['pruning_ratio'] == 0.0].iloc[0]
            strategy_scores = []
            for ratio in pruning_ratios:
                pruned = rows[rows['pruning_ratio'] == ratio].iloc[0]
                if metric == 'accuracy':
                    score = (pruned['accuracy'] / baseline['accuracy']) * 100
                else:  # MSE: lower is better, so invert
                    score = (baseline['mse'] / pruned['mse']) * 100
                strategy_scores.append(score)
            effectiveness_data.append(strategy_scores)

        effectiveness_array = np.array(effectiveness_data)
        im = ax.imshow(effectiveness_array, cmap='RdYlGn', aspect='auto', vmin=90, vmax=105)
        # The image handle used to be discarded. A colorbar shows the fixed
        # 90-105% scale, which matters because values outside it clip in color.
        fig.colorbar(im, ax=ax, label='Relative Performance (%)')

        # Annotate each cell (row = strategy, column = pruning ratio).
        for (i, j), value in np.ndenumerate(effectiveness_array):
            ax.text(j, i, f'{value:.1f}%',
                    ha="center", va="center", color="black", fontsize=14)

        ax.set_xticks(range(len(pruning_ratios)))
        ax.set_xticklabels([f'{int(r*100)}%' for r in pruning_ratios])
        ax.set_yticks(range(len(strategies)))
        ax.set_yticklabels(strategies)
        ax.set_title(f'{model_name}: Strategy Effectiveness', fontweight='bold')
        ax.set_xlabel('Pruning Ratio')
        ax.set_ylabel('Strategy')

    plt.suptitle('Comprehensive Strategy Effectiveness Analysis', fontsize=24, fontweight='bold')
    save_plot('16_comprehensive_strategy_effectiveness.png')

# 17. LSTM Special Analysis
def plot_lstm_special_analysis(lstm_df):
    """Two LSTM panels: MACs vs pruning ratio, and MSE vs parameter count.

    Each panel has one series per strategy. Parameters are shown in thousands.
    Saved as ``17_lstm_special_analysis.png``.
    """
    _, (macs_ax, params_ax) = plt.subplots(1, 2, figsize=(20, 8))

    for strategy in lstm_df['strategy'].unique():
        subset = lstm_df[lstm_df['strategy'] == strategy]
        tint = COLORS[strategy]

        by_ratio = subset.sort_values('pruning_ratio')
        macs_ax.plot(by_ratio['pruning_ratio_percent'], by_ratio['macs_millions'], 'o-',
                     linewidth=3, markersize=10, label=strategy, color=tint)

        by_size = subset.sort_values('params')
        thousands = by_size['params'] / 1000
        params_ax.scatter(thousands, by_size['mse'], s=150,
                          label=strategy, alpha=0.8, color=tint)
        params_ax.plot(thousands, by_size['mse'], '--',
                       alpha=0.7, linewidth=2, color=tint)

    macs_ax.set_xlabel('Pruning Ratio (%)')
    macs_ax.set_ylabel('MACs (Millions)')
    macs_ax.set_title('LSTM: MACs vs Pruning Ratio (Minimal Reduction)', fontweight='bold')
    macs_ax.legend()
    macs_ax.grid(True, alpha=0.3)

    params_ax.set_xlabel('Parameters (Thousands)')
    params_ax.set_ylabel('MSE')
    params_ax.set_title('LSTM: Parameter Reduction vs Performance', fontweight='bold')
    params_ax.legend()
    params_ax.grid(True, alpha=0.3)

    save_plot('17_lstm_special_analysis.png')

# 18. Performance Degradation Rate Analysis (NEW)
def plot_performance_degradation_rate(mobilenet_df, resnet_df, mlp_df, lstm_df,
                                      max_pruning_ratio=0.7):
    """Bar chart of how much each model's metric worsens per 1% of pruning.

    The rate is the metric change between the unpruned run and the run at
    ``max_pruning_ratio``, divided by that ratio in percent. It uses
    MagnitudeL2 when available, else the model's first strategy. Positive bars
    mean the metric got worse (accuracy fell / MSE rose). Negative bars mean
    it improved. Units differ per model: accuracy points for the CNNs, MSE
    units for the regression models. Saved as
    ``18_performance_degradation_rate.png``.

    Args:
        max_pruning_ratio: Pruning ratio (fraction) compared against the
            baseline. Defaults to 0.7, the previously hard-coded value.
    """
    plt.figure(figsize=(14, 8))

    models_data = [
        (mobilenet_df, 'MobileNetV2', 'accuracy'),
        (resnet_df, 'ResNet-18', 'accuracy'),
        (mlp_df, 'MLP', 'mse'),
        (lstm_df, 'LSTM', 'mse'),
    ]
    span_percent = max_pruning_ratio * 100

    degradation_rates = []
    model_names = []

    for df, model_name, metric in models_data:
        available = df['strategy'].unique()
        strategy = 'MagnitudeL2' if 'MagnitudeL2' in available else available[0]
        data = df[df['strategy'] == strategy]

        baseline_perf = data[data['pruning_ratio'] == 0.0][metric].values[0]
        max_pruned_perf = data[data['pruning_ratio'] == max_pruning_ratio][metric].values[0]

        if metric == 'accuracy':
            change = baseline_perf - max_pruned_perf
        else:  # MSE: an increase is a degradation
            change = max_pruned_perf - baseline_perf

        # Keep the sign. abs() used to draw an improvement as a degradation.
        degradation_rates.append(change / span_percent)
        model_names.append(model_name)

    bars = plt.bar(model_names, degradation_rates,
                   color=[COLORS[name] for name in model_names], alpha=0.8)
    plt.axhline(y=0, color='black', linewidth=1)

    plt.ylabel('Performance Degradation Rate (per 1% Pruning)')
    plt.title('Model Sensitivity to Pruning: Performance Degradation Rate', fontweight='bold')
    plt.grid(True, alpha=0.3, axis='y')

    # Value labels: above positive bars, below negative ones.
    for bar, rate in zip(bars, degradation_rates):
        plt.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                 f'{rate:.3f}', ha='center', va='bottom' if rate >= 0 else 'top',
                 fontsize=14)

    save_plot('18_performance_degradation_rate.png')

# 19. Edge Deployment Feasibility Analysis
def plot_edge_deployment_analysis(mobilenet_df, resnet_df, mlp_df, lstm_df,
                                  max_macs=10, max_params=2):
    """Scatter every run in MACs / parameter space against edge-device limits.

    A run is feasible when ``macs_millions <= max_macs`` and
    ``params_millions <= max_params``. Feasible runs are drawn as circles in
    the model's color; the rest as small light-gray crosses. Each model uses
    its MagnitudeL2 runs when present, else its first strategy. Both axes are
    log-scaled. Saved as ``19_edge_deployment_analysis.png``.

    Args:
        max_macs: MACs budget in millions (default 10, as before).
        max_params: Parameter budget in millions (default 2, as before).
    """
    plt.figure(figsize=(12, 8))

    models_data = [
        (mobilenet_df, 'MobileNetV2'),
        (resnet_df, 'ResNet-18'),
        (mlp_df, 'MLP'),
        (lstm_df, 'LSTM')
    ]

    # scatter's ``s`` is an area in points**2. These are the marker sizes the
    # old code computed (12 / 8) but never passed, so every point had s=150.
    feasible_size, infeasible_size = 12 ** 2, 8 ** 2

    for df, model_name in models_data:
        available = df['strategy'].unique()
        strategy = 'MagnitudeL2' if 'MagnitudeL2' in available else available[0]
        data = df[df['strategy'] == strategy]

        feasible = ((data['macs_millions'] <= max_macs) &
                    (data['params_millions'] <= max_params))
        fits, misses = data[feasible], data[~feasible]

        plt.scatter(fits['macs_millions'], fits['params_millions'],
                    s=feasible_size, color=COLORS[model_name], alpha=1.0, marker='o')
        plt.scatter(misses['macs_millions'], misses['params_millions'],
                    s=infeasible_size, color='lightgray', alpha=0.3, marker='x')
        # Legend entry in the model's own color. The label used to be attached
        # to the baseline point, so a model with an infeasible baseline showed
        # up in the legend as a gray cross.
        plt.scatter([], [], s=feasible_size, color=COLORS[model_name], marker='o',
                    label=model_name)

    # One shared legend entry for all infeasible points.
    plt.scatter([], [], s=infeasible_size, color='lightgray', alpha=0.3, marker='x',
                label='Exceeds Constraints')

    # Add constraint lines
    plt.axvline(x=max_macs, color='red', linestyle='--',
                linewidth=3, alpha=0.7, label=f'MACs Constraint ({max_macs}M)')
    plt.axhline(y=max_params, color='red', linestyle='--',
                linewidth=3, alpha=0.7, label=f'Parameters Constraint ({max_params}M)')

    # Add feasible region
    plt.fill_between([0, max_macs], 0, max_params,
                     alpha=0.1, color='green', label='Feasible Region')

    plt.xlabel('MACs (Millions)')
    plt.ylabel('Parameters (Millions)')
    plt.title('Edge Deployment Feasibility Analysis', fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.yscale('log')
    plt.xscale('log')

    save_plot('19_edge_deployment_analysis.png')

# 20. Optimal Deployment Points Analysis (NEW)
def plot_optimal_deployment_points(mobilenet_df, resnet_df, mlp_df, lstm_df):
    """Plot optimal deployment points for different edge scenarios.

    Draws a 2x2 grid, one subplot per model, showing the model's metric
    against the pruning ratio. It uses the MagnitudeL2 strategy when present,
    otherwise the first strategy found in the DataFrame.

    For every edge scenario, the best point that satisfies both the scenario's
    MACs and parameter budgets is marked with a star. "Best" means the highest
    ``accuracy`` for the classification models and the lowest ``mse`` for the
    regression models.

    Args:
        mobilenet_df, resnet_df: DataFrames with ``strategy``,
            ``pruning_ratio``, ``pruning_ratio_percent``, ``macs_millions``,
            ``params_millions`` and ``accuracy`` columns.
        mlp_df, lstm_df: Same columns, with ``mse`` instead of ``accuracy``.

    The figure is written with ``save_plot('20_optimal_deployment_points.png')``.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20, 16))

    # Edge scenarios, ordered from tightest to loosest budget
    # (MACs and parameters, both in millions).
    scenarios = {
        'Ultra-Low Power': {'max_macs': 1, 'max_params': 0.5},
        'Mobile Device': {'max_macs': 10, 'max_params': 2},
        'Edge Server': {'max_macs': 50, 'max_params': 10}
    }
    # Fixed star style per scenario. This keeps a scenario the same colour in
    # every subplot; the default colour cycle shifts whenever a scenario has
    # no feasible point. Sizes shrink from tight to loose budgets and stars
    # are drawn in that order, so stars landing on the same optimum stay
    # visible instead of hiding each other.
    scenario_styles = {
        'Ultra-Low Power': {'color': '#2CA02C', 's': 500},
        'Mobile Device': {'color': '#FFD166', 's': 320},
        'Edge Server': {'color': '#7F7F7F', 's': 160},
    }

    models_data = [
        (mobilenet_df, 'MobileNetV2', 'accuracy', ax1),
        (resnet_df, 'ResNet-18', 'accuracy', ax2),
        (mlp_df, 'MLP', 'mse', ax3),
        (lstm_df, 'LSTM', 'mse', ax4)
    ]

    for df, model_name, metric, ax in models_data:
        # Without this guard, unique()[0] raises IndexError on an empty frame
        # and the whole figure is lost.
        if df.empty or 'strategy' not in df.columns:
            ax.set_title(f'{model_name}: No data available', fontweight='bold')
            continue

        strategies = df['strategy'].unique()
        strategy = 'MagnitudeL2' if 'MagnitudeL2' in strategies else strategies[0]
        data = df[df['strategy'] == strategy].sort_values('pruning_ratio')

        # Plot the performance curve
        curve_label = (f'{model_name} Performance' if metric == 'accuracy'
                       else f'{model_name} MSE')
        ax.plot(data['pruning_ratio_percent'], data[metric], 'o-',
                linewidth=3, markersize=10, color=COLORS[model_name], label=curve_label)

        # Rows without a metric value cannot be ranked; on an all-NaN column
        # idxmax/idxmin return NaN and the subsequent .loc lookup fails.
        rankable = data.dropna(subset=[metric])

        # Find the optimal point for each scenario
        for scenario_name, constraints in scenarios.items():
            feasible_points = rankable[
                (rankable['macs_millions'] <= constraints['max_macs']) &
                (rankable['params_millions'] <= constraints['max_params'])
            ]
            if feasible_points.empty:
                continue

            # Accuracy: higher is better. MSE: lower is better.
            if metric == 'accuracy':
                best_idx = feasible_points[metric].idxmax()
            else:
                best_idx = feasible_points[metric].idxmin()
            optimal_point = feasible_points.loc[best_idx]

            style = scenario_styles[scenario_name]
            # zorder=3 draws the star above the curve (scatter defaults to 1).
            ax.scatter(optimal_point['pruning_ratio_percent'], optimal_point[metric],
                       s=style['s'], color=style['color'], marker='*',
                       edgecolors='black', linewidth=2, alpha=0.8, zorder=3,
                       label=f'{scenario_name} Optimal')

        ax.set_xlabel('Pruning Ratio (%)')
        ax.set_ylabel(metric.upper() if metric != 'accuracy' else 'Accuracy (%)')
        ax.set_title(f'{model_name}: Optimal Deployment Points', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.suptitle('Optimal Deployment Points for Different Edge Scenarios', fontsize=24, fontweight='bold')
    save_plot('20_optimal_deployment_points.png')

def main():
    """Generate all plots for Chapter 5.

    Loads the four JSON result files and converts each to a DataFrame. It
    then calls every plotting function with the DataFrames it expects.

    A failure in one plot is reported and the remaining plots are still
    generated. The closing summary says whether every plot succeeded or
    lists the ones that failed. Returns early, printing an error, if any
    result file could not be loaded.
    """
    print("🚀 Generating comprehensive thesis plots for Chapter 5...")

    # Load data
    mobilenet_results, resnet_results, mlp_results, lstm_results = load_all_results()

    if any(x is None for x in [mobilenet_results, resnet_results, mlp_results, lstm_results]):
        print("❌ Error loading data files. Please check file names and paths.")
        return

    # Convert to DataFrames
    mobilenet_df = results_to_dataframe(mobilenet_results, 'MobileNetV2', 'classification')
    resnet_df = results_to_dataframe(resnet_results, 'ResNet-18', 'classification')
    mlp_df = results_to_dataframe(mlp_results, 'MLP', 'regression')
    lstm_df = results_to_dataframe(lstm_results, 'LSTM', 'regression')

    print("📊 Data loaded successfully:")
    print(f"   • MobileNetV2: {len(mobilenet_df)} data points")
    print(f"   • ResNet-18: {len(resnet_df)} data points")
    print(f"   • MLP: {len(mlp_df)} data points")
    print(f"   • LSTM: {len(lstm_df)} data points")

    all_dfs = (mobilenet_df, resnet_df, mlp_df, lstm_df)
    cnn_dfs = (mobilenet_df, resnet_df)
    timeseries_dfs = (mlp_df, lstm_df)

    # Each entry is (plot function, description, positional DataFrame
    # arguments). Listing the arguments explicitly replaces dispatch by
    # substring-matching function names, which silently skipped any function
    # that matched no rule.
    plot_jobs = [
        (plot_baseline_comparison, "Baseline comparison", all_dfs),
        (plot_mobilenetv2_accuracy_pruning_ratio, "MobileNetV2 accuracy vs pruning ratio", (mobilenet_df,)),
        (plot_resnet18_accuracy_pruning_ratio, "ResNet-18 accuracy vs pruning ratio", (resnet_df,)),
        (plot_mlp_performance_pruning_ratio, "MLP performance vs pruning ratio", (mlp_df,)),
        (plot_lstm_performance_pruning_ratio, "LSTM performance vs pruning ratio", (lstm_df,)),
        (plot_mobilenetv2_efficiency_frontier, "MobileNetV2 efficiency frontier", (mobilenet_df,)),
        (plot_resnet18_efficiency_frontier, "ResNet-18 efficiency frontier", (resnet_df,)),
        (plot_timeseries_efficiency_frontier, "Time-series efficiency frontier", timeseries_dfs),
        (plot_combined_efficiency_frontier, "Combined efficiency frontier", all_dfs),
        (plot_cnn_comparison, "CNN models comparison", cnn_dfs),
        (plot_timeseries_comparison, "Time-series models comparison", timeseries_dfs),
        (plot_model_params_reduction, "Model parameters reduction analysis", all_dfs),
        (plot_macs_reduction, "MACs reduction analysis", all_dfs),
        (plot_performance_retention, "Performance retention analysis", all_dfs),
        (plot_strategy_effectiveness_heatmap, "Strategy effectiveness heatmap", cnn_dfs),
        (plot_comprehensive_strategy_effectiveness, "Comprehensive strategy effectiveness", all_dfs),
        (plot_lstm_special_analysis, "LSTM special analysis", (lstm_df,)),
        (plot_performance_degradation_rate, "Performance degradation rate analysis", all_dfs),
        (plot_edge_deployment_analysis, "Edge deployment feasibility", all_dfs),
        (plot_optimal_deployment_points, "Optimal deployment points", all_dfs),
    ]

    total = len(plot_jobs)
    failed = []
    for i, (plot_func, description, args) in enumerate(plot_jobs, 1):
        print(f"📈 Generating plot {i:2d}/{total}: {description}")
        try:
            plot_func(*args)
        except Exception as e:  # top-level boundary: report and keep generating the rest
            print(f"❌ Error generating {description}: {e}")
            failed.append(description)

    # Only claim full success when no plot raised.
    if failed:
        print(f"\n⚠️ {total - len(failed)}/{total} plots generated; {len(failed)} failed:")
        for description in failed:
            print(f"   • {description}")
    else:
        print("\n🎉 All plots generated successfully!")
    print("📁 Plots saved in: thesis_plots/")
    print(f"📋 Total plots: {total}")
    print("🔧 Plot specifications:")
    print("   • Resolution: 300 DPI")
    print("   • Format: PNG")
    print("   • Font sizes: Title(20), Labels(18), Ticks(16)")
    print("   • Consistent color scheme applied")
    print("   • Added value labels where appropriate")
    print("   • Removed size-related plots (parameters cover storage)")

# Entry point: build every Chapter 5 figure when this file is run directly
# (including as a notebook cell, where __name__ is also "__main__").
if __name__ == "__main__":
    main()