# Part 5: Detailed Performance Analysis

## Noise Type & SNR Level Analysis

This notebook performs in-depth analysis by:
- Noise type (Gaussian, Salt & Pepper, Burst)
- SNR level (-30, -25, -20, -15, -10 dB)
- Model comparison across different conditions

## 0. Import and Setup

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
import os
from tqdm import tqdm
import tensorflow as tf
from tensorflow import keras

# Apply the seaborn theme first: sns.set_style() writes its own
# 'font.family' entry into matplotlib's rcParams, so a font selected before
# this call would be silently replaced by the theme default.
sns.set_style('whitegrid')
sns.set_palette('husl')

# Use DejaVu Sans for all figure text and draw minus signs with the ASCII
# hyphen instead of the Unicode minus glyph.
plt.rcParams['font.family'] = 'DejaVu Sans'
plt.rcParams['axes.unicode_minus'] = False

# Later cells write their JSON/CSV/PNG/TXT outputs into results/; create the
# folder up front so a finished evaluation run cannot fail at save time.
os.makedirs('results', exist_ok=True)

print("Environment ready!")

## 1. Load Data and Metadata

In [None]:
# Load the noisy test set, its labels, the clean reference images and the
# per-sample noise metadata table.
print("Loading test data...")
x_test_augmented = np.load('data/x_test_augmented.npy')
y_test_augmented = np.load('data/y_test_augmented.npy')
x_test_clean = np.load('data/x_test_clean.npy')
test_noise_info = pd.read_csv('data/test_noise_info.csv')

print(f"Test data shape: {x_test_augmented.shape}")
print(f"Test noise info shape: {test_noise_info.shape}")

# Prepare clean data for comparison (repeat 3 times for 3 noise types).
# np.repeat keeps the copies of each image adjacent: a, a, a, b, b, b, ...
# NOTE(review): this only lines up with x_test_augmented if the augmented set
# stores the three noisy variants of every image next to each other. If it
# was built block-wise (all of one noise type, then the next) the references
# are misaligned and np.concatenate([x_test_clean] * 3) is needed instead.
# Confirm against the notebook that produced the augmented data.
x_test_clean_repeated = np.repeat(x_test_clean, 3, axis=0)

# The evaluation cells index all of these with one boolean mask derived from
# test_noise_info, so they must agree row for row. Fail here with a readable
# message instead of an IndexError deep inside the evaluation loop.
expected_rows = x_test_augmented.shape[0]
row_counts = {
    'y_test_augmented': len(y_test_augmented),
    'test_noise_info': len(test_noise_info),
    'x_test_clean_repeated': len(x_test_clean_repeated),
}
for array_name, n_rows in row_counts.items():
    if n_rows != expected_rows:
        raise ValueError(
            f"{array_name} has {n_rows} rows but x_test_augmented has "
            f"{expected_rows}; the test arrays are not aligned."
        )

# For BAM models (flattened): one row vector per sample.
x_test_flat = x_test_augmented.reshape(x_test_augmented.shape[0], -1)
x_test_clean_flat = x_test_clean_repeated.reshape(x_test_clean_repeated.shape[0], -1)

print("\nData loaded successfully!")

## 2. Load All Models

In [None]:
print("Loading trained models...")

# Registry of the pipelines to evaluate. Sequential entries point at two
# saved models (restoration + classification); MTL entries point at a single
# saved model. 'input_type' selects flattened vectors ('flat') or image
# tensors ('image') as the network input.
model_configs = [
    dict(name='Sequential BAM', type='sequential', arch='BAM',
         restore_path='weights/sequential_bam_denoise.keras',
         cls_path='weights/sequential_bam_classification.keras',
         input_type='flat'),
    dict(name='Sequential CAE', type='sequential', arch='CAE',
         restore_path='weights/sequential_cae_restore.keras',
         cls_path='weights/sequential_cae_classification.keras',
         input_type='image'),
    dict(name='Sequential U-Net', type='sequential', arch='U-Net',
         restore_path='weights/sequential_unet_restore.keras',
         cls_path='weights/sequential_unet_classification.keras',
         input_type='image'),
    dict(name='MTL BAM', type='mtl', arch='BAM',
         model_path='weights/mtl_bam.keras',
         input_type='flat'),
    dict(name='MTL CAE', type='mtl', arch='CAE',
         model_path='weights/mtl_cae.keras',
         input_type='image'),
    dict(name='MTL U-Net', type='mtl', arch='U-Net',
         model_path='weights/mtl_unet.keras',
         input_type='image'),
]

# Load whatever is available. A model that cannot be loaded is reported and
# left out of loaded_models, so later cells only see the successful ones.
loaded_models = {}
for cfg in model_configs:
    label = cfg['name']
    try:
        if cfg['type'] != 'sequential':
            # Anything that is not 'sequential' is treated as an MTL model.
            entry = {'model': keras.models.load_model(cfg['model_path']),
                     'type': 'mtl'}
            family = 'MTL'
        else:
            restorer = keras.models.load_model(cfg['restore_path'])
            classifier = keras.models.load_model(cfg['cls_path'])
            entry = {'restore': restorer,
                     'classify': classifier,
                     'type': 'sequential'}
            family = 'Sequential'
        entry['arch'] = cfg['arch']
        entry['input_type'] = cfg['input_type']
        loaded_models[label] = entry
        print(f"✓ {label} loaded ({family})")
    except Exception as e:
        print(f"✗ {label} failed to load: {e}")

print(f"\nTotal models loaded: {len(loaded_models)}")

## 3. Evaluation Functions

In [None]:
def calculate_psnr(img1, img2, max_pixel=1.0):
    """Return the peak signal-to-noise ratio between two arrays, in dB.

    The MSE is taken over every element of both inputs at once, so for a
    batch this is the PSNR of the pooled error, not a mean of per-image
    PSNR values.

    Args:
        img1: First array (any array-like accepted by numpy).
        img2: Second array, broadcast-compatible with ``img1``.
        max_pixel: Largest possible pixel value; 1.0 for data scaled to
            [0, 1], 255 for 8-bit images.

    Returns:
        PSNR in dB as a float, or 100 when the inputs are identical
        (MSE of exactly zero).
    """
    # Work in float64 so unsigned integer inputs cannot wrap around when
    # subtracted (uint8 0 - 16 would otherwise become 240).
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        # Identical inputs: report a finite cap instead of infinity.
        return 100.0
    return float(20 * np.log10(max_pixel / np.sqrt(mse)))


def _restoration_metrics(restored, x_clean):
    """Return MSE, MAE and PSNR of *restored* against *x_clean* as floats."""
    error = restored - x_clean
    return {
        'mse': float(np.mean(error ** 2)),
        'mae': float(np.mean(np.abs(error))),
        'psnr': float(calculate_psnr(restored, x_clean)),
    }


def _classification_metrics(class_scores, y_true):
    """Return top-1 accuracy of per-class scores against the labels.

    Labels may be one-hot / score rows (2-D, reduced with argmax) or a 1-D
    vector of integer class ids.
    """
    predicted = np.argmax(class_scores, axis=1)
    y_true = np.asarray(y_true)
    expected = np.argmax(y_true, axis=1) if y_true.ndim > 1 else y_true
    return {'accuracy': float(np.mean(predicted == expected))}


def evaluate_by_condition(model_dict, x_test, x_clean, y_test, condition_mask, batch_size=128):
    """
    Evaluate model on specific condition (noise type or SNR level)

    Args:
        model_dict: Model dictionary. 'type' selects the pipeline:
            'sequential' uses the 'restore' and 'classify' models, any
            other value uses the single multi-output 'model'.
        x_test: Test noisy images
        x_clean: Clean reference images, row-aligned with x_test
        y_test: Test labels (one-hot rows or integer class ids)
        condition_mask: Boolean mask for specific condition (numpy array,
            list or pandas Series with one entry per test sample)
        batch_size: Batch size for prediction

    Returns:
        None when the mask selects no samples, otherwise a dictionary with
        'restoration' (mse, mae, psnr), 'classification' (accuracy) and
        'num_samples'.
    """
    # Convert once so a pandas Series mask is not re-converted for every array.
    mask = np.asarray(condition_mask)

    # Filter data by condition
    x_cond = x_test[mask]
    x_clean_cond = x_clean[mask]
    y_cond = y_test[mask]

    if len(x_cond) == 0:
        return None

    if model_dict['type'] == 'sequential':
        # Stage 1: Restoration
        restored = model_dict['restore'].predict(x_cond, batch_size=batch_size, verbose=0)

        # Stage 2: Classification
        # NOTE(review): the classifier is fed the noisy input, not `restored`.
        # That is only right if the saved classification model already
        # contains the restoration stage; if it expects restored images this
        # should pass `restored` instead. Confirm against the training
        # notebook before changing.
        class_scores = model_dict['classify'].predict(x_cond, batch_size=batch_size, verbose=0)
    else:  # MTL
        # Get both outputs from the single model
        predictions = model_dict['model'].predict(x_cond, batch_size=batch_size, verbose=0)

        if isinstance(predictions, dict):
            restored = predictions['restoration_output']
            class_scores = predictions['classification_output']
        else:
            # Assumes exactly two outputs ordered (restoration,
            # classification) -- TODO confirm against the model definition.
            restored, class_scores = predictions

    return {
        'restoration': _restoration_metrics(restored, x_clean_cond),
        'classification': _classification_metrics(class_scores, y_cond),
        'num_samples': len(x_cond),
    }


print("Evaluation functions defined!")

## 4. Evaluate by Noise Type

In [None]:
print("="*80)
print("Evaluating by Noise Type")
print("="*80)

noise_types = ['gaussian', 'sp', 'burst']
noise_results = {}

# Flattened vectors for 'flat' models, image tensors for everything else.
image_inputs = (x_test_augmented, x_test_clean_repeated)
inputs_by_type = {'flat': (x_test_flat, x_test_clean_flat)}

for model_name, model_dict in loaded_models.items():
    print(f"\nEvaluating {model_name}...")
    noise_results[model_name] = {}

    noisy_inputs, clean_inputs = inputs_by_type.get(model_dict['input_type'], image_inputs)

    for noise_type in noise_types:
        # Select the test rows whose metadata records this noise type.
        outcome = evaluate_by_condition(
            model_dict,
            noisy_inputs,
            clean_inputs,
            y_test_augmented,
            test_noise_info['noise_type'] == noise_type,
        )

        # Nothing is stored for a condition that produced no result.
        if not outcome:
            continue

        noise_results[model_name][noise_type] = outcome
        print(f"  {noise_type:10s}: PSNR={outcome['restoration']['psnr']:.2f}dB, "
              f"Acc={outcome['classification']['accuracy']:.4f}")

print("\n✓ Noise type evaluation complete!")

# Persist the nested {model: {noise type: metrics}} mapping as JSON.
with open('results/noise_type_results.json', 'w') as f:
    json.dump(noise_results, f, indent=2)
print("✓ Results saved: results/noise_type_results.json")

## 5. Evaluate by SNR Level

In [None]:
print("="*80)
print("Evaluating by SNR Level")
print("="*80)

snr_levels = [-30, -25, -20, -15, -10]
snr_results = {}

for model_name, model_dict in loaded_models.items():
    print(f"\nEvaluating {model_name}...")
    snr_results[model_name] = {}

    # Flattened vectors for 'flat' models, image tensors for everything else.
    wants_flat = model_dict['input_type'] == 'flat'
    noisy_inputs = x_test_flat if wants_flat else x_test_augmented
    clean_inputs = x_test_clean_flat if wants_flat else x_test_clean_repeated

    for snr in snr_levels:
        # Select the test rows whose metadata records this SNR level.
        level_mask = test_noise_info['snr_db'] == snr
        outcome = evaluate_by_condition(
            model_dict, noisy_inputs, clean_inputs, y_test_augmented, level_mask
        )

        # Nothing is stored for a level that produced no result.
        if not outcome:
            continue

        # Keys are strings so the mapping looks the same before and after a
        # JSON round trip.
        snr_results[model_name][str(snr)] = outcome
        print(f"  {snr:3d}dB: PSNR={outcome['restoration']['psnr']:.2f}dB, "
              f"Acc={outcome['classification']['accuracy']:.4f}")

print("\n✓ SNR level evaluation complete!")

# Persist the nested {model: {SNR level: metrics}} mapping as JSON.
with open('results/snr_level_results.json', 'w') as f:
    json.dump(snr_results, f, indent=2)
print("✓ Results saved: results/snr_level_results.json")

## 6. Create Summary Tables

In [None]:
def _summary_metric_columns(result):
    """Return the metric columns shared by both summary tables."""
    restoration = result['restoration']
    return {
        'PSNR (dB)': restoration['psnr'],
        'MSE': restoration['mse'],
        'MAE': restoration['mae'],
        'Accuracy': result['classification']['accuracy'],
        'Samples': result['num_samples'],
    }


# Noise Type Summary Table: one row per (model, noise type) with a result.
print("\n" + "="*80)
print("NOISE TYPE PERFORMANCE SUMMARY")
print("="*80)

noise_summary_data = [
    {
        'Model': model_name,
        'Noise Type': noise_type.upper(),
        **_summary_metric_columns(noise_results[model_name][noise_type]),
    }
    for model_name in loaded_models.keys()
    for noise_type in noise_types
    if noise_type in noise_results[model_name]
]

noise_summary_df = pd.DataFrame(noise_summary_data)
print(noise_summary_df.to_string(index=False))
noise_summary_df.to_csv('results/noise_type_summary.csv', index=False)
print("\n✓ Saved: results/noise_type_summary.csv")

# SNR Level Summary Table: one row per (model, SNR level) with a result.
print("\n" + "="*80)
print("SNR LEVEL PERFORMANCE SUMMARY")
print("="*80)

snr_summary_data = [
    {
        'Model': model_name,
        'SNR (dB)': snr,
        **_summary_metric_columns(snr_results[model_name][str(snr)]),
    }
    for model_name in loaded_models.keys()
    for snr in snr_levels
    if str(snr) in snr_results[model_name]
]

snr_summary_df = pd.DataFrame(snr_summary_data)
print(snr_summary_df.to_string(index=False))
snr_summary_df.to_csv('results/snr_level_summary.csv', index=False)
print("\n✓ Saved: results/snr_level_summary.csv")

## 7. Visualization: Noise Type Comparison

In [None]:
# Create figure: 2x3 grid comparing all loaded models across noise types.
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
fig.suptitle('Performance by Noise Type', fontsize=16, fontweight='bold')

# Color scheme: color encodes the training strategy (Sequential vs MTL).
seq_color = '#2E86AB'
mtl_color = '#A23B72'
colors_dict = {
    'Sequential BAM': seq_color,
    'Sequential CAE': seq_color,
    'Sequential U-Net': seq_color,
    'MTL BAM': mtl_color,
    'MTL CAE': mtl_color,
    'MTL U-Net': mtl_color
}

# Marker, line style and hatch encode the architecture. With color alone the
# three models of a family are drawn identically and the legend cannot tell
# them apart.
noise_arch_styles = {
    'BAM': {'marker': 'o', 'linestyle': '-', 'hatch': None},
    'CAE': {'marker': 's', 'linestyle': '--', 'hatch': '//'},
    'U-Net': {'marker': '^', 'linestyle': ':', 'hatch': '..'},
}
noise_model_names = list(loaded_models.keys())

x_pos = np.arange(len(noise_types))
width = 0.15


def _noise_style(model_name):
    """Return the plot style for a model, falling back to the BAM style."""
    arch = loaded_models[model_name]['arch']
    return noise_arch_styles.get(arch, noise_arch_styles['BAM'])


def _noise_series(model_name, section, metric, scale=1.0):
    """Return one metric per entry of noise_types, NaN where not evaluated.

    Keeping a NaN placeholder keeps every series the same length as the
    x-axis, so a missing condition leaves a gap instead of raising.
    """
    per_noise = noise_results[model_name]
    return [per_noise[nt][section][metric] * scale if nt in per_noise else np.nan
            for nt in noise_types]


def _noise_line_panel(ax, section, metric, ylabel, title, scale=1.0):
    """Draw one line per model for the given metric."""
    for name in noise_model_names:
        style = _noise_style(name)
        ax.plot(noise_types, _noise_series(name, section, metric, scale),
                marker=style['marker'], linestyle=style['linestyle'], linewidth=2,
                label=name, color=colors_dict[name], markersize=8)
    ax.set_xlabel('Noise Type', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)


def _noise_bar_panel(ax, section, metric, ylabel, title, scale=1.0):
    """Draw one bar group per noise type with one bar per model."""
    for i, name in enumerate(noise_model_names):
        ax.bar(x_pos + i * width, _noise_series(name, section, metric, scale), width,
               label=name, color=colors_dict[name], alpha=0.8,
               hatch=_noise_style(name)['hatch'])
    ax.set_xlabel('Noise Type', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    # Center each tick under its group for however many models were loaded
    # (this equals width * 2.5 when all six are present).
    ax.set_xticks(x_pos + width * (len(noise_model_names) - 1) / 2)
    ax.set_xticklabels(noise_types)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3, axis='y')


# Plots 1-3: line charts of PSNR, MSE and accuracy by noise type
_noise_line_panel(axes[0, 0], 'restoration', 'psnr', 'PSNR (dB)', 'Restoration PSNR')
_noise_line_panel(axes[0, 1], 'restoration', 'mse', 'MSE', 'Restoration MSE')
_noise_line_panel(axes[0, 2], 'classification', 'accuracy', 'Accuracy (%)',
                  'Classification Accuracy', scale=100)

# Plots 4-5: grouped bar charts of PSNR and accuracy by noise type
_noise_bar_panel(axes[1, 0], 'restoration', 'psnr', 'PSNR (dB)', 'PSNR Comparison (Bar)')
_noise_bar_panel(axes[1, 1], 'classification', 'accuracy', 'Accuracy (%)',
                 'Accuracy Comparison (Bar)', scale=100)

# Plot 6: Heatmap - PSNR by Model and Noise Type (NaN cells are left blank)
ax = axes[1, 2]
heatmap_df = pd.DataFrame(
    [_noise_series(name, 'restoration', 'psnr') for name in noise_model_names],
    index=noise_model_names,
    columns=noise_types,
)
sns.heatmap(heatmap_df, annot=True, fmt='.2f', cmap='YlOrRd', ax=ax, cbar_kws={'label': 'PSNR (dB)'})
ax.set_title('PSNR Heatmap', fontsize=14, fontweight='bold')
ax.set_xlabel('Noise Type', fontsize=12)
ax.set_ylabel('Model', fontsize=12)

plt.tight_layout()
plt.savefig('results/noise_type_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Visualization saved: results/noise_type_analysis.png")

## 8. Visualization: SNR Level Comparison

In [None]:
# Create figure: 2x3 grid comparing all loaded models across input SNR levels.
# Relies on colors_dict from the noise-type figure cell.
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
fig.suptitle('Performance by SNR Level', fontsize=16, fontweight='bold')

# Marker, line style and hatch encode the architecture. With color alone the
# three models of a family are drawn identically and the legend cannot tell
# them apart.
snr_arch_styles = {
    'BAM': {'marker': 'o', 'linestyle': '-', 'hatch': None},
    'CAE': {'marker': 's', 'linestyle': '--', 'hatch': '//'},
    'U-Net': {'marker': '^', 'linestyle': ':', 'hatch': '..'},
}
snr_model_names = list(loaded_models.keys())

x_pos = np.arange(len(snr_levels))
width = 0.15


def _snr_style(model_name):
    """Return the plot style for a model, falling back to the BAM style."""
    arch = loaded_models[model_name]['arch']
    return snr_arch_styles.get(arch, snr_arch_styles['BAM'])


def _snr_series(model_name, section, metric, scale=1.0):
    """Return one metric per entry of snr_levels, NaN where not evaluated.

    Keeping a NaN placeholder keeps every series aligned with snr_levels, so
    a missing level leaves a gap instead of shifting or raising.
    """
    per_snr = snr_results[model_name]
    return [per_snr[str(level)][section][metric] * scale if str(level) in per_snr else np.nan
            for level in snr_levels]


def _snr_line_panel(ax, values_for, ylabel, title):
    """Draw one line per model; values_for(name) yields the y-values."""
    for name in snr_model_names:
        style = _snr_style(name)
        ax.plot(snr_levels, values_for(name),
                marker=style['marker'], linestyle=style['linestyle'], linewidth=2,
                label=name, color=colors_dict[name], markersize=8)
    ax.set_xlabel('Input SNR (dB)', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)


def _snr_psnr_gain(model_name):
    """Return output PSNR minus input SNR for every level.

    NOTE(review): PSNR is peak-referenced while the input SNR is presumably
    power-referenced, so this difference is a relative indicator rather than
    a like-for-like gain -- confirm it is the intended metric.
    """
    psnr_values = _snr_series(model_name, 'restoration', 'psnr')
    return [psnr - level for psnr, level in zip(psnr_values, snr_levels)]


# Plot 1: PSNR vs SNR
_snr_line_panel(axes[0, 0], lambda name: _snr_series(name, 'restoration', 'psnr'),
                'Output PSNR (dB)', 'Restoration PSNR vs SNR')

# Plot 2: MSE vs SNR
_snr_line_panel(axes[0, 1], lambda name: _snr_series(name, 'restoration', 'mse'),
                'MSE', 'Restoration MSE vs SNR')

# Plot 3: Classification Accuracy vs SNR
_snr_line_panel(axes[0, 2], lambda name: _snr_series(name, 'classification', 'accuracy', 100),
                'Accuracy (%)', 'Classification Accuracy vs SNR')

# Plot 4: PSNR Improvement over SNR
_snr_line_panel(axes[1, 0], _snr_psnr_gain, 'PSNR Improvement (dB)', 'PSNR Improvement')

# Plot 5: Grouped Bar Chart of accuracy by SNR level
ax = axes[1, 1]
for i, name in enumerate(snr_model_names):
    ax.bar(x_pos + i * width, _snr_series(name, 'classification', 'accuracy', 100), width,
           label=name, color=colors_dict[name], alpha=0.8,
           hatch=_snr_style(name)['hatch'])
ax.set_xlabel('Input SNR (dB)', fontsize=12)
ax.set_ylabel('Accuracy (%)', fontsize=12)
ax.set_title('Accuracy by SNR (Bar)', fontsize=14, fontweight='bold')
# Center each tick under its group for however many models were loaded
# (this equals width * 2.5 when all six are present).
ax.set_xticks(x_pos + width * (len(snr_model_names) - 1) / 2)
ax.set_xticklabels(snr_levels)
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3, axis='y')

# Plot 6: Heatmap - PSNR by Model and SNR (NaN cells are left blank)
ax = axes[1, 2]
heatmap_df = pd.DataFrame(
    [_snr_series(name, 'restoration', 'psnr') for name in snr_model_names],
    index=snr_model_names,
    columns=[f'{snr}dB' for snr in snr_levels],
)
sns.heatmap(heatmap_df, annot=True, fmt='.2f', cmap='RdYlGn', ax=ax, cbar_kws={'label': 'PSNR (dB)'})
ax.set_title('PSNR Heatmap', fontsize=14, fontweight='bold')
ax.set_xlabel('Input SNR', fontsize=12)
ax.set_ylabel('Model', fontsize=12)

plt.tight_layout()
plt.savefig('results/snr_level_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Visualization saved: results/snr_level_analysis.png")

## 9. Statistical Analysis

In [None]:
print("\n" + "="*80)
print("STATISTICAL ANALYSIS")
print("="*80)

# Compare the two model families (Sequential vs MTL)
print("\n[Sequential vs MTL Comparison]")
print("-" * 40)

seq_models = [name for name in loaded_models.keys() if 'Sequential' in name]
mtl_models = [name for name in loaded_models.keys() if 'MTL' in name]


def _family_values(results, models, condition, section, metric):
    """Collect one metric for *condition* from every model that has it."""
    return [results[m][condition][section][metric]
            for m in models if condition in results[m]]


def _print_family_comparison(results, condition, heading):
    """Print mean±std per family and the Sequential-minus-MTL difference."""
    seq_psnr = _family_values(results, seq_models, condition, 'restoration', 'psnr')
    mtl_psnr = _family_values(results, mtl_models, condition, 'restoration', 'psnr')
    seq_acc = _family_values(results, seq_models, condition, 'classification', 'accuracy')
    mtl_acc = _family_values(results, mtl_models, condition, 'classification', 'accuracy')

    print(f"\n  {heading}:")
    print(f"    Sequential - PSNR: {np.mean(seq_psnr):.2f}±{np.std(seq_psnr):.2f}dB, "
          f"Acc: {np.mean(seq_acc):.4f}±{np.std(seq_acc):.4f}")
    print(f"    MTL        - PSNR: {np.mean(mtl_psnr):.2f}±{np.std(mtl_psnr):.2f}dB, "
          f"Acc: {np.mean(mtl_acc):.4f}±{np.std(mtl_acc):.4f}")
    print(f"    Difference - PSNR: {np.mean(seq_psnr) - np.mean(mtl_psnr):+.2f}dB, "
          f"Acc: {np.mean(seq_acc) - np.mean(mtl_acc):+.4f}")


# By Noise Type
print("\nBy Noise Type:")
for noise_type in noise_types:
    _print_family_comparison(noise_results, noise_type, noise_type.upper())

# By SNR Level (results are keyed by the level as a string)
print("\n\nBy SNR Level:")
for snr in snr_levels:
    _print_family_comparison(snr_results, str(snr), f"{snr}dB")

print("\n" + "="*80)

## 10. Key Findings Summary

In [None]:
print("\n" + "="*80)
print("KEY FINDINGS")
print("="*80)

findings = []


def _mean_over_noise_types(model_name, section, metric):
    """Average one metric over the noise types evaluated for a model."""
    per_noise = noise_results[model_name]
    return np.mean([per_noise[nt][section][metric]
                    for nt in noise_types if nt in per_noise])


def _family_average(model_scores, family):
    """Average the scores of all models whose name contains *family*."""
    return np.mean([score for name, score in model_scores if family in name])


# Findings 1-2: best model overall for restoration and for classification
all_psnr = [(m, _mean_over_noise_types(m, 'restoration', 'psnr'))
            for m in loaded_models.keys()]
best_psnr_model = max(all_psnr, key=lambda x: x[1])
findings.append(f"1. Best Restoration Model: {best_psnr_model[0]} (PSNR: {best_psnr_model[1]:.2f}dB)")

all_acc = [(m, _mean_over_noise_types(m, 'classification', 'accuracy'))
           for m in loaded_models.keys()]
best_acc_model = max(all_acc, key=lambda x: x[1])
findings.append(f"2. Best Classification Model: {best_acc_model[0]} (Accuracy: {best_acc_model[1]:.4f})")

# Finding 3: hardest noise type (lowest PSNR averaged over models)
noise_difficulty = {}
for noise_type in noise_types:
    avg_psnr = np.mean([noise_results[m][noise_type]['restoration']['psnr']
                        for m in loaded_models.keys() if noise_type in noise_results[m]])
    noise_difficulty[noise_type] = avg_psnr
hardest_noise = min(noise_difficulty.items(), key=lambda x: x[1])
findings.append(f"3. Most Challenging Noise: {hardest_noise[0].upper()} (Avg PSNR: {hardest_noise[1]:.2f}dB)")

# Finding 4: SNR sensitivity = PSNR change between the lowest and highest
# evaluated SNR level. Models missing either endpoint are skipped rather
# than raising a KeyError.
low_snr_key = str(min(snr_levels))
high_snr_key = str(max(snr_levels))
snr_improvements = {}
for model_name in loaded_models.keys():
    per_snr = snr_results[model_name]
    if low_snr_key in per_snr and high_snr_key in per_snr:
        snr_improvements[model_name] = (per_snr[high_snr_key]['restoration']['psnr']
                                        - per_snr[low_snr_key]['restoration']['psnr'])
if snr_improvements:
    most_sensitive = max(snr_improvements.items(), key=lambda x: x[1])
    # The sign comes from the format spec so a negative change is not
    # printed as "+-x.xx".
    findings.append(f"4. Most SNR-Sensitive Model: {most_sensitive[0]} "
                    f"({most_sensitive[1]:+.2f}dB improvement from {low_snr_key}dB to {high_snr_key}dB)")

# Finding 5: Sequential vs MTL family averages
seq_avg_psnr = _family_average(all_psnr, 'Sequential')
mtl_avg_psnr = _family_average(all_psnr, 'MTL')
seq_avg_acc = _family_average(all_acc, 'Sequential')
mtl_avg_acc = _family_average(all_acc, 'MTL')

findings.append(f"5. Sequential Models: Avg PSNR={seq_avg_psnr:.2f}dB, Avg Acc={seq_avg_acc:.4f}")
findings.append(f"   MTL Models: Avg PSNR={mtl_avg_psnr:.2f}dB, Avg Acc={mtl_avg_acc:.4f}")
# Signed differences: a negative value means that family is behind on that
# metric, so the line stays correct whichever family wins.
findings.append(f"   Trade-off: Sequential {seq_avg_psnr - mtl_avg_psnr:+.2f}dB PSNR, "
                f"MTL {mtl_avg_acc - seq_avg_acc:+.4f} Accuracy")

# Print findings
for finding in findings:
    print(finding)

# Save findings
with open('results/key_findings.txt', 'w', encoding='utf-8') as f:
    f.write("KEY FINDINGS\n")
    f.write("="*80 + "\n\n")
    f.writelines(finding + "\n" for finding in findings)

print("\n✓ Key findings saved: results/key_findings.txt")
print("\n" + "="*80)
print("DETAILED ANALYSIS COMPLETE!")
print("="*80)