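# In [None]:
# notebooks/wgan_training.ipynb
"""
WGAN-GP training with k-fold cross-validation.

Loads the control data, trains a WGAN-GP with k-fold cross-validation,
generates synthetic samples, evaluates and plots them against the real data,
and saves the model state, training results, config and metrics.
"""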

import sys
from pathlib import Path
import yaml
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import logging
from datetime import datetime

# Add project root to path so the `src.*` imports below resolve.
# Assumes the notebook runs from a direct subdirectory of the project root
# (e.g. notebooks/) — TODO confirm for other launch directories.
project_root = Path.cwd().parent
# Insert at the front so the project's `src` package wins over any
# same-named installed package, and skip the insert when the entry is
# already present so re-running this cell does not keep growing sys.path.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Updated imports
from src.models.wgan.trainer import WGANGPTrainer
from src.data.preprocessing import DataProcessor
from src.evaluation.metrics import EvaluationMetrics
from src.evaluation.visualization import VisualizationTools

# Setup logging: root handler at INFO level emitting
# "timestamp - LEVEL - message" lines; this module logs through `logger`.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Load configuration. The path is tried relative to the working directory
# first; if it is not found there, it is resolved against project_root
# (the parent of the working directory, see the sys.path setup above), which
# is where it lives when the notebook is launched from notebooks/.
config_path = Path('config/wgan_config.yaml')
if not config_path.exists():
    config_path = project_root / config_path
# YAML files are UTF-8; don't depend on the platform's locale encoding.
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
# safe_load returns None for an empty file; fail here with a clear message
# instead of handing None to the trainer.
if config is None:
    raise ValueError(f"Configuration file is empty: {config_path}")

# Initialize the pipeline components, constructed left to right:
# data processor, GAN trainer (the only one that takes the config),
# evaluation metrics, and visualization helpers.
processor, trainer, evaluator, visualizer = (
    DataProcessor(),
    WGANGPTrainer(config),
    EvaluationMetrics(),
    VisualizationTools(),
)

# Load and preprocess data. The CSV path is tried relative to the working
# directory first, then resolved against project_root when not found there.
data_path = Path('data/data_combined_controls.csv')
if not data_path.exists():
    data_path = project_root / data_path
processed_data = processor.load_and_process_data(str(data_path))
# NOTE(review): element 2 of the returned tuple is taken to be the control
# DataFrame — confirm against DataProcessor.load_and_process_data.
control_data = processed_data[2]

# Train model
logger.info("Starting WGAN-GP training...")
results = trainer.train_with_kfold(control_data)

# Generate samples for evaluation. The count is chosen to match the number
# of cases; change it here rather than at the call site.
N_GENERATED_SAMPLES = 77
generated_samples = trainer.generate_samples(n_samples=N_GENERATED_SAMPLES)

# Comprehensive evaluation: compare the real control values (raw array via
# .values) against the synthetic samples.
real_values = control_data.values
metrics = evaluator.evaluate_all(real_values, generated_samples)

# Visualizations
logger.info("Generating visualizations...")

# 1. KS statistics: histogram (with KDE) of the KS statistics reported by the
#    evaluator — presumably one value per feature; confirm in EvaluationMetrics.
plt.figure(figsize=(10, 6))
sns.histplot(metrics['basic_metrics']['ks_stats']['ks_statistics'], kde=True)
plt.title('Distribution of KS Statistics')
plt.xlabel('KS Statistic')
plt.ylabel('Frequency')
plt.show()

# 2. Distribution comparisons. Label the synthetic columns with the real
#    feature names so both frames line up column-for-column; without this
#    the synthetic frame would get default 0..k-1 column labels.
#    Assumes generated_samples has one column per control_data column.
visualizer.plot_distribution_comparison(
    pd.DataFrame(control_data),
    pd.DataFrame(generated_samples, columns=control_data.columns)
)

# 3. Quality metrics
visualizer.plot_quality_metrics(metrics['quality'])

# Save results to a timestamped checkpoint under results/models/.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_path = Path(f'results/models/wgan_model_{timestamp}.pt')
# torch.save does not create missing directories; ensure the target exists.
save_path.parent.mkdir(parents=True, exist_ok=True)
# NOTE(review): assumes WGANGPTrainer exposes state_dict(). `results` and
# `metrics` are arbitrary Python objects, so loading this file may need
# torch.load(..., weights_only=False) — confirm.
torch.save({
    'model_state': trainer.state_dict(),
    'results': results,
    'config': config,
    'metrics': metrics
}, save_path)

logger.info(f"Model and results saved to {save_path}")

# Print summary statistics
print("\n=== Training Summary ===")
print(f"Original samples: {len(control_data)}")
print(f"Generated samples: {len(generated_samples)}")
print("\nMetrics Summary:")
for metric_name, value in metrics['statistical'].items():
    # Scalars are shown with 4 decimal places. Values that don't accept the
    # float format spec (e.g. dicts, lists, None, strings, arrays) are printed
    # with str() instead of aborting the rest of the summary.
    try:
        formatted = f"{value:.4f}"
    except (TypeError, ValueError):
        formatted = str(value)
    print(f"{metric_name}: {formatted}")