In [None]:
```python
"""
WGAN-GP Training and Evaluation Notebook

This notebook demonstrates the training and evaluation of WGAN-GP for synthetic data generation.
The implementation focuses on the WGAN-GP model, which showed superior performance over
VAE and WAE in preliminary testing.
"""
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Import local modules
from src.models.data_augmentation.GAN import train_and_generate
from src.utils.preprocessing import process
from src.utils.evaluation import SyntheticDataEvaluator
from src.utils.evaluation import (
    compare_statistics,
    compare_distributions,
    generate_tsne,
    recenter_data
)

# Set random seeds for reproducibility.
# One named constant keeps the PyTorch and NumPy generators on the same seed
# if the value is ever changed.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Configuration
# NOTE: every entry of 'model_params' is forwarded as a keyword argument to
# train_and_generate() later in this notebook, so do not add keys here that
# the function does not accept.
CONFIG = {
    'data_path': "data/data_combined_controls.csv",
    'results_dir': Path("results"),
    'model_params': {
        'batch_size': 32,
        'epochs': 20,
        'learning_rate': 0.001,
        'n_splits': 3,  # Using 3-fold CV as recommended
        'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    },
}

# Create the results directory. parents=True lets 'results_dir' point at a
# nested location (e.g. results/run_01) without failing on missing ancestors;
# exist_ok=True keeps re-running the cell harmless.
CONFIG['results_dir'].mkdir(parents=True, exist_ok=True)

# Echo the effective settings so they are recorded in the notebook output.
print("Configuration:")
print(f"Data path: {CONFIG['data_path']}")
print(f"Device: {CONFIG['model_params']['device']}")
print(f"Epochs: {CONFIG['model_params']['epochs']}")
print(f"Batch size: {CONFIG['model_params']['batch_size']}")
print(f"Learning rate: {CONFIG['model_params']['learning_rate']}")
print(f"Number of folds: {CONFIG['model_params']['n_splits']}")

# --- Data loading and initial processing ---
# process() is unpacked into five values; the first two are discarded and the
# remaining three are kept under the names used by the rest of the notebook.
print("\nLoading and processing data...")
processed = process(CONFIG['data_path'])
_, _, scaled_data, scaler, n_features = processed

# Short overview of what was loaded, emitted as a single multi-line message.
print(
    "\nDataset Information:",
    f"Original data shape: {scaled_data.shape}",
    f"Number of features: {n_features}",
    "\nFeature types:",
    sep="\n",
)
print(scaled_data.dtypes)

# --- Generate synthetic data ---
# The call receives the data path, save_info=True, and every entry of
# CONFIG['model_params'] as keyword arguments.
print("\nTraining WGAN-GP and generating synthetic samples...")
synthetic_data, original_data = train_and_generate(
    filepath=CONFIG['data_path'], save_info=True, **CONFIG['model_params']
)

# --- Evaluate the synthetic data ---
# The evaluator is pointed at the results directory and asked to recenter.
print("\nEvaluating synthetic data quality...")
evaluator = SyntheticDataEvaluator(output_dir=CONFIG['results_dir'])
results = evaluator.evaluate_synthetic_data(
    original_data=original_data, synthetic_data=synthetic_data, recenter=True
)

# Pull the (possibly recentered) synthetic data and the recentering flag out
# of the evaluation results for use by the later sections.
recentered_synthetic, recentering_status = (
    results['synthetic_data'],
    results['recentering_applied'],
)

# 1. Statistical comparison: print describe() of the comparison table
print("\nGenerating statistical comparisons...")
stats_comparison = results['statistical_comparison']
print("\nStatistical Comparison Summary:")
print(stats_comparison.describe())

# 2. Distribution comparison (KS test): histogram of the 'KS_Statistic' column
ks_results = results['distribution_comparison']

# Drawn through explicit figure/axes handles instead of the pyplot state machine.
ks_fig, ks_ax = plt.subplots(figsize=(10, 6))
sns.histplot(data=ks_results, x='KS_Statistic', kde=True, ax=ks_ax)
ks_ax.set_title('Distribution of KS Statistics: Original vs Synthetic Data')
ks_ax.set_xlabel('KS Statistic')
ks_ax.set_ylabel('Frequency')
ks_ax.grid(True)
ks_fig.savefig(CONFIG['results_dir'] / "ks_statistics_distribution.png")
plt.show()

# 3. Feature-wise Visualization
# For each plottable feature, overlay the kernel density estimate of the
# original data and of the recentered synthetic data, and save one PNG per
# feature into the results directory.
print("\nGenerating feature-wise visualizations...")

# Columns that are never plotted; built once instead of on every iteration.
excluded_columns = {'data_type', 'fold', 'sample_id'}

for column in original_data.columns:
    if column in excluded_columns:
        continue

    # A column missing from the synthetic frame would raise KeyError and abort
    # the whole loop; skip it with a notice instead.
    if column not in recentered_synthetic.columns:
        print(f"Skipping '{column}': not present in synthetic data")
        continue

    # A density estimate is only defined for numeric data; skip anything else
    # rather than letting kdeplot fail part-way through the loop.
    if not (
        pd.api.types.is_numeric_dtype(original_data[column])
        and pd.api.types.is_numeric_dtype(recentered_synthetic[column])
    ):
        print(f"Skipping '{column}': non-numeric feature")
        continue

    # Path separators inside a column name would send the file into a
    # (possibly non-existent) subdirectory, so neutralise them. Names without
    # separators keep exactly the same file name as before.
    safe_name = str(column).replace('/', '_').replace('\\', '_')

    plt.figure(figsize=(10, 6))
    sns.kdeplot(data=original_data[column], label='Original', alpha=0.6)
    sns.kdeplot(data=recentered_synthetic[column], label='Synthetic (Recentered)', alpha=0.6)
    plt.title(f'Distribution Comparison: {column}')
    plt.xlabel('Value')
    plt.ylabel('Density')
    plt.legend()
    plt.grid(True)
    plt.savefig(CONFIG['results_dir'] / f"feature_distribution_{safe_name}.png")
    # Close each figure so a wide dataset does not accumulate open figures.
    plt.close()

# 4. t-SNE visualization of original vs. recentered synthetic data
print("\nGenerating t-SNE visualization...")
generate_tsne(original_data, recentered_synthetic)
# savefig() writes the current pyplot figure; presumably that is the one
# generate_tsne() just drew -- TODO confirm it does not close or show it first.
tsne_png = CONFIG['results_dir'] / "tsne_visualization.png"
plt.savefig(tsne_png)
plt.show()

# 5. Loss plots (from training)
# Only drawn when the evaluation results carry a 'training_metrics' entry.
# Each fold gets one figure with two side-by-side panels: generator loss on
# the left, critic loss on the right.
if 'training_metrics' in results:
    for fold, metrics in results['training_metrics'].items():
        # (metrics key, legend label, panel title) for the left and right panel.
        panels = (
            ('g_loss', f"Generator Loss - {fold}", "Generator Training Loss"),
            ('c_loss', f"Critic Loss - {fold}", "Critic Training Loss"),
        )

        loss_fig, loss_axes = plt.subplots(1, 2, figsize=(12, 4))
        for axis, (metric_key, curve_label, panel_title) in zip(loss_axes, panels):
            axis.plot(metrics[metric_key], label=curve_label)
            axis.set_title(panel_title)
            axis.set_xlabel("Iteration")
            axis.set_ylabel("Loss")
            axis.legend()
            axis.grid(True)

        loss_fig.tight_layout()
        loss_fig.savefig(CONFIG['results_dir'] / f"training_losses_{fold}.png")
        plt.show()

# 6. Summary statistics: describe() of both data sets, original first
print("\nSummary Statistics:")
for heading, frame in (
    ("\nOriginal Data:", original_data),
    ("\nSynthetic Data (Recentered):", recentered_synthetic),
):
    print(heading)
    print(frame.describe())

# 7. Quality metrics
# A feature counts as well matched when its KS statistic is below 0.1; the
# boolean mask is computed once and reused for both the count and the share.
print("\nQuality Metrics:")
well_matched = ks_results['KS_Statistic'] < 0.1
print(f"Total features with KS statistic < 0.1: {well_matched.sum()}")
print(f"Percentage of well-matched features: {well_matched.mean()*100:.2f}%")

# Collect the headline numbers of this run into one mapping.
train_cfg = CONFIG['model_params']
results_summary = dict(
    n_original_samples=len(original_data),
    n_synthetic_samples=len(synthetic_data),
    n_features=n_features,
    n_folds=train_cfg['n_splits'],
    recentering_applied=recentering_status,
    mean_ks_statistic=ks_results['KS_Statistic'].mean(),
    percent_good_features=(ks_results['KS_Statistic'] < 0.1).mean()*100,
    training_summary=dict(
        epochs=train_cfg['epochs'],
        batch_size=train_cfg['batch_size'],
        learning_rate=train_cfg['learning_rate'],
    ),
)

# Write the summary as plain "key: value" lines, one entry per line; the
# nested training summary is written using its dict string form.
summary_path = CONFIG['results_dir'] / "results_summary.txt"
with summary_path.open("w", encoding='utf-8') as summary_file:
    summary_file.writelines(
        f"{key}: {value}\n" for key, value in results_summary.items()
    )

print("\nAnalysis complete. Results saved in:", CONFIG['results_dir'])
```
