# In [None]:
# notebooks/enhanced/Validation_Results.ipynb

"""
# WGAN-GP Validation Results Analysis
Comprehensive validation analysis of the WGAN-GP model with k-fold cross-validation.

Contents:
1. Cross-Validation Results Analysis
2. Generation Quality Assessment
3. Feature-wise Analysis
4. Statistical Tests
5. Visualization of Results
"""

# %% Setup and Imports
import sys
sys.path.append('../..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances
import torch

from src.utils.enhanced.evaluation import (
    compute_statistics,
    evaluate_generation,
    compute_mmd
)

# %% Data Loading Functions
def load_validation_data():
    """
    Read the three CSV inputs this analysis works on.

    The files are read in a fixed order: original control samples, generated
    samples, then the per-fold validation metrics.

    Returns:
        Tuple ``(original_data, generated_data, validation_metrics)`` of DataFrames
    """
    csv_paths = (
        '../../data/original_controls.csv',
        '../../results/enhanced/generated_samples.csv',
        '../../results/enhanced/metrics/validation_metrics.csv',
    )
    originals, generated, metrics = (pd.read_csv(path) for path in csv_paths)
    return originals, generated, metrics

# %% Validation Analysis Functions
def analyze_cross_validation_performance(validation_metrics):
    """
    Analyze cross-validation performance metrics.

    Draws a 2x3 grid: box plots of ``ks_statistic``, ``wasserstein_distance``
    and ``mmd_score`` on the top row, and ``correlation_preservation`` against
    ``fold`` as a line plot spanning the whole bottom row.

    Args:
        validation_metrics: DataFrame containing validation metrics for each fold;
            expected to contain the five columns named above.
    """
    fig = plt.figure(figsize=(15, 10))
    gs = fig.add_gridspec(2, 3)

    # (metric column, panel title) for each box plot on the top row.
    box_panels = [
        ('ks_statistic', 'KS Statistics Across Folds'),
        ('wasserstein_distance', 'Wasserstein Distances'),
        ('mmd_score', 'MMD Scores'),
    ]
    for col_idx, (metric, title) in enumerate(box_panels):
        ax = fig.add_subplot(gs[0, col_idx])
        # Pass the target Axes explicitly instead of relying on pyplot's
        # implicit "current axes" state.
        sns.boxplot(data=validation_metrics, y=metric, ax=ax)
        ax.set_title(title)

    # Correlation preservation per fold, spanning the full bottom row.
    corr_ax = fig.add_subplot(gs[1, :])
    sns.lineplot(data=validation_metrics, x='fold', y='correlation_preservation', ax=corr_ax)
    corr_ax.set_title('Feature Correlation Preservation')

    fig.tight_layout()
    plt.show()

def analyze_feature_stability(original_data, generated_samples_list):
    """
    Analyze feature stability across different folds.

    For every fold, compares the per-feature means and standard deviations of
    the generated samples with those of the original data. The results are
    plotted as three bar charts.

    Args:
        original_data: Original control samples (DataFrame, one column per feature)
        generated_samples_list: List of generated-sample DataFrames, one per fold.
            A single DataFrame is also accepted and treated as a single fold.

    Returns:
        DataFrame with one row per fold and columns ``fold``, ``mean_diff``
        (mean absolute difference of feature means), ``std_diff`` (mean absolute
        difference of feature standard deviations) and ``correlation``
        (Pearson correlation between the generated and original feature-mean
        profiles).
    """
    # Iterating a bare DataFrame would yield column names, not sample sets.
    if isinstance(generated_samples_list, pd.DataFrame):
        generated_samples_list = [generated_samples_list]

    # The original statistics do not depend on the fold; compute them once.
    original_mean = original_data.mean()
    original_std = original_data.std()

    feature_stats = []
    for fold, samples in enumerate(generated_samples_list):
        # Align on feature names so the generated column order does not matter.
        sample_mean = samples.mean().reindex(original_mean.index)
        sample_std = samples.std().reindex(original_std.index)
        feature_stats.append({
            'fold': fold,
            'mean_diff': np.mean(np.abs(sample_mean - original_mean)),
            'std_diff': np.mean(np.abs(sample_std - original_std)),
            'correlation': np.corrcoef(sample_mean, original_mean)[0, 1],
        })

    stats_df = pd.DataFrame(
        feature_stats, columns=['fold', 'mean_diff', 'std_diff', 'correlation']
    )

    # Visualization
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    sns.barplot(data=stats_df, x='fold', y='mean_diff', ax=axes[0])
    axes[0].set_title('Mean Difference by Fold')

    sns.barplot(data=stats_df, x='fold', y='std_diff', ax=axes[1])
    axes[1].set_title('STD Difference by Fold')

    sns.barplot(data=stats_df, x='fold', y='correlation', ax=axes[2])
    axes[2].set_title('Feature Correlation by Fold')

    plt.tight_layout()
    plt.show()

    return stats_df

def _scatter_embedding(ax, embedding, n_original, title):
    """Scatter a 2-D embedding whose first ``n_original`` rows are original samples."""
    ax.scatter(embedding[:n_original, 0], embedding[:n_original, 1],
               label='Original', alpha=0.7)
    ax.scatter(embedding[n_original:, 0], embedding[n_original:, 1],
               label='Generated', alpha=0.7)
    ax.set_title(title)
    ax.legend()


def visualize_sample_quality(original_data, generated_data, random_state=42):
    """
    Create comprehensive visualizations of sample quality.

    Top row: t-SNE and PCA 2-D embeddings fitted jointly on the original and
    generated samples. Bottom row: KDE plots for up to four distinct, randomly
    chosen features.

    Args:
        original_data: Original control samples (DataFrame, one column per feature)
        generated_data: Generated synthetic samples; must contain every column
            of ``original_data`` (extra columns are ignored)
        random_state: Seed for t-SNE and for the KDE feature selection, so the
            figure is reproducible (default 42, the seed t-SNE used before)
    """
    # Use exactly the original feature set, in the same order, for both sources.
    generated_data = generated_data[original_data.columns]
    n_original = len(original_data)
    combined_data = pd.concat([original_data, generated_data], ignore_index=True)

    fig = plt.figure(figsize=(20, 10))
    gs = fig.add_gridspec(2, 4)

    # t-SNE requires perplexity < n_samples; 30 is sklearn's default.
    perplexity = min(30.0, len(combined_data) - 1)
    tsne = TSNE(n_components=2, random_state=random_state, perplexity=perplexity)
    tsne_results = tsne.fit_transform(combined_data)
    _scatter_embedding(fig.add_subplot(gs[0, :2]), tsne_results, n_original,
                       't-SNE Visualization')

    pca_results = PCA(n_components=2).fit_transform(combined_data)
    _scatter_embedding(fig.add_subplot(gs[0, 2:]), pca_results, n_original,
                       'PCA Visualization')

    # Feature distributions: sample distinct features (no repeats), bounded by
    # both the 4 grid slots and the number of available features.
    rng = np.random.default_rng(random_state)
    n_kde = min(4, len(original_data.columns))
    features = rng.choice(original_data.columns, size=n_kde, replace=False)
    for i, feature in enumerate(features):
        ax = fig.add_subplot(gs[1, i])
        sns.kdeplot(data=original_data[feature], ax=ax, label='Original')
        sns.kdeplot(data=generated_data[feature], ax=ax, label='Generated')
        ax.set_title(f'Feature: {feature}')
        ax.legend()

    plt.tight_layout()
    plt.show()

# %% Statistical Analysis Functions
def perform_statistical_tests(original_data, generated_data, n_quantiles=101):
    """
    Perform per-feature statistical comparisons of original vs generated data.

    The two inputs are independent samples and may have different numbers of
    rows. Every comparison is therefore distributional; rows are never paired.

    Args:
        original_data: Original control samples (DataFrame, one column per feature)
        generated_data: Generated synthetic samples; must contain every column
            of ``original_data``
        n_quantiles: Number of evenly spaced probability levels in [0, 1] used
            for the quantile (Q-Q) correlation (default 101)

    Returns:
        DataFrame with one row per feature and columns ``feature``,
        ``ks_statistic``, ``ks_pvalue`` (two-sample KS test), ``mean_diff``,
        ``std_diff`` (absolute differences) and ``correlation`` (Pearson
        correlation between the two samples' quantiles; 1.0 means the
        distributions have the same shape up to location/scale).
    """
    columns = ['feature', 'ks_statistic', 'ks_pvalue', 'mean_diff', 'std_diff', 'correlation']
    probs = np.linspace(0.0, 1.0, n_quantiles)
    rows = []

    for col in original_data.columns:
        real = original_data[col]
        fake = generated_data[col]

        # KS test
        ks_stat, p_val = stats.ks_2samp(real, fake)

        # Correlating rows of two independent samples is meaningless, and it
        # fails when the sizes differ. Compare matched quantiles instead.
        qq_corr = np.corrcoef(np.quantile(real, probs), np.quantile(fake, probs))[0, 1]

        rows.append({
            'feature': col,
            'ks_statistic': ks_stat,
            'ks_pvalue': p_val,
            'mean_diff': abs(real.mean() - fake.mean()),
            'std_diff': abs(real.std() - fake.std()),
            'correlation': qq_corr,
        })

    return pd.DataFrame(rows, columns=columns)

# %% Run Analysis
if __name__ == "__main__":
    # Load data
    original_data, generated_data, validation_metrics = load_validation_data()

    # Analyze cross-validation performance
    print("Analyzing cross-validation performance...")
    analyze_cross_validation_performance(validation_metrics)

    # Analyze feature stability. The function expects one DataFrame per fold.
    # Only one pooled set of generated samples is loaded here, so treat it as a
    # single fold; passing the bare DataFrame would iterate over column names.
    print("\nAnalyzing feature stability...")
    analyze_feature_stability(original_data, [generated_data])

    # Visualize sample quality
    print("\nVisualizing sample quality...")
    visualize_sample_quality(original_data, generated_data)

    # Perform statistical tests
    print("\nPerforming statistical tests...")
    statistical_results = perform_statistical_tests(original_data, generated_data)
    print("\nStatistical Test Summary:")
    print(statistical_results.describe())