# In [None]:
# notebooks/enhanced/Metrics_Analysis.ipynb

# %% [markdown]
"""
# WGAN-GP Metrics Analysis
Comprehensive analysis of training metrics and generation quality.

## Contents:
1. Training Metrics Analysis
2. Distribution Analysis
3. Feature-wise Analysis
4. Quality Metrics
5. Cross-validation Results
"""

# %% Setup
import sys
sys.path.append('../..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from sklearn.metrics import pairwise_distances
import torch

from src.utils.enhanced.evaluation import compute_statistics
from src.validation.metrics_tracking import MetricsTracker

# %% Training Metrics Analysis
def analyze_training_metrics(metrics_tracker):
    """
    Plot a 2x2 overview of the training history held by a metrics tracker.

    Panels (left to right, top to bottom): generator/critic loss curves,
    overlaid histograms of both losses, the gradient-penalty curve and the
    Wasserstein-distance curve. The figure is shown with ``plt.show()``;
    nothing is returned.

    Args:
        metrics_tracker: Object exposing a ``metrics`` mapping with the keys
            'generator_loss', 'critic_loss', 'gradient_penalty' and
            'wasserstein_distance'. Presumably a MetricsTracker holding one
            value per logged step -- confirm against the tracker.
    """
    history = metrics_tracker.metrics
    generator_loss = history['generator_loss']
    critic_loss = history['critic_loss']

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # Loss trajectories
    trajectory_ax = axes[0, 0]
    trajectory_ax.plot(generator_loss, label='Generator')
    trajectory_ax.plot(critic_loss, label='Critic')
    trajectory_ax.set_title('Loss Trajectories')
    trajectory_ax.legend()

    # Loss distributions
    distribution_ax = axes[0, 1]
    sns.histplot(generator_loss, ax=distribution_ax,
                 label='Generator', alpha=0.5)
    sns.histplot(critic_loss, ax=distribution_ax,
                 label='Critic', alpha=0.5)
    distribution_ax.set_title('Loss Distributions')
    # The labels passed to histplot only become visible once a legend is
    # drawn; without it the two overlaid histograms are indistinguishable.
    distribution_ax.legend()

    # Gradient penalty
    axes[1, 0].plot(history['gradient_penalty'])
    axes[1, 0].set_title('Gradient Penalty')

    # Wasserstein distance
    axes[1, 1].plot(history['wasserstein_distance'])
    axes[1, 1].set_title('Wasserstein Distance')

    plt.tight_layout()
    plt.show()

# %% Distribution Analysis
def analyze_distributions(original_data, generated_data):
    """
    Compare per-feature distributions of original and generated data.

    Runs a two-sample Kolmogorov-Smirnov test on every numeric column that
    exists in both frames, then shows a 2x2 figure: histogram of the KS
    statistics, histogram of the p-values, and scatter plots of the
    per-feature means and standard deviations (original on x, generated
    on y). Nothing is returned.

    Args:
        original_data: DataFrame of original control samples.
        generated_data: DataFrame of generated synthetic samples.

    Raises:
        ValueError: If the two frames share no numeric columns.
    """
    # Non-numeric columns (IDs, labels, ...) and columns missing from the
    # generated frame cannot be compared and would make ks_2samp / mean()
    # fail, so restrict the analysis to shared numeric features.
    generated_numeric = set(
        generated_data.select_dtypes(include='number').columns
    )
    shared_cols = [
        col
        for col in original_data.select_dtypes(include='number').columns
        if col in generated_numeric
    ]
    if not shared_cols:
        raise ValueError(
            'original_data and generated_data share no numeric columns'
        )

    # Statistical tests
    ks_stats = []
    p_values = []

    for col in shared_cols:
        # Drop NaNs so the test sees the same values that mean()/std()
        # use below (both skip NaNs by default).
        stat, p = stats.ks_2samp(original_data[col].dropna(),
                                 generated_data[col].dropna())
        ks_stats.append(stat)
        p_values.append(p)

    original_features = original_data[shared_cols]
    generated_features = generated_data[shared_cols]

    # Visualization
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # KS statistics
    sns.histplot(ks_stats, ax=axes[0, 0])
    axes[0, 0].set_title('KS Statistics Distribution')

    # P-values
    sns.histplot(p_values, ax=axes[0, 1])
    axes[0, 1].set_title('P-values Distribution')

    # Feature means comparison
    sns.scatterplot(x=original_features.mean(),
                    y=generated_features.mean(),
                    ax=axes[1, 0])
    axes[1, 0].set_title('Feature Means Comparison')

    # Feature STDs comparison
    sns.scatterplot(x=original_features.std(),
                    y=generated_features.std(),
                    ax=axes[1, 1])
    axes[1, 1].set_title('Feature STDs Comparison')

    plt.tight_layout()
    plt.show()

# %% Cross-validation Analysis
def analyze_cv_results(cv_results):
    """
    Draw one box plot per cross-validation metric.

    The metrics plotted are 'generator_loss', 'critic_loss' and
    'ks_statistic'; each gets its own subplot, stacked vertically. The
    figure is shown with ``plt.show()``; nothing is returned.

    Args:
        cv_results: Either a dictionary whose values are per-fold mappings
            from metric name to value(s), or a pandas DataFrame with the
            metric names as columns (assumed to hold one row per fold --
            confirm against how the CV results are written).
    """
    metrics = ['generator_loss', 'critic_loss', 'ks_statistic']

    # A DataFrame's `.values` is an ndarray attribute rather than a method,
    # so calling `.values()` on CSV-loaded results raises TypeError.
    # Normalise both accepted inputs to a list of per-fold records.
    if isinstance(cv_results, pd.DataFrame):
        folds = cv_results.to_dict(orient='records')
    else:
        folds = list(cv_results.values())

    fig, axes = plt.subplots(len(metrics), 1, figsize=(12, 4*len(metrics)))

    for ax, metric in zip(axes, metrics):
        fold_values = [fold[metric] for fold in folds]
        ax.boxplot(fold_values)
        ax.set_title(f'{metric} Across Folds')

    plt.tight_layout()
    plt.show()

# %% Run Analysis
# Load results
# NOTE(review): the paths below look like placeholders -- replace them with
# the real artifact locations before running this cell.
metrics_tracker = MetricsTracker.load('path_to_metrics')
original_data = pd.read_csv('path_to_original_data.csv')
generated_data = pd.read_csv('path_to_generated_data.csv')
# analyze_cv_results calls cv_results.values() and indexes each fold by
# metric name, which fails on a raw DataFrame (its `.values` is an ndarray
# attribute). Convert to {row_label: {column: value}} instead.
# Assumes one CSV row per fold with the metric names as columns -- TODO confirm.
cv_results = pd.read_csv('path_to_cv_results.csv').to_dict(orient='index')

# Perform analysis
analyze_training_metrics(metrics_tracker)
analyze_distributions(original_data, generated_data)
analyze_cv_results(cv_results)