# In [None]:
# notebooks/04_validation.ipynb
"""
Validation Analysis Notebook
Purpose: Validate generated samples and assess model performance
"""

import pickle
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from scipy import stats

# Add project root to path
project_root = Path.cwd().parent
sys.path.append(str(project_root))

from src.evaluation.metrics import EvaluationMetrics
from src.data.preprocessing import DataProcessor

# Load results
# Pick the last checkpoint in lexicographic file-name order. This is only the
# "latest" run if the `wgan_model_*` suffix sorts chronologically
# (presumably a timestamp -- TODO confirm against the training script).
# NOTE(review): the paths below are relative to the current working
# directory, while `project_root` above is the *parent* of the cwd -- confirm
# which directory this notebook is meant to be run from.
checkpoint_dir = Path('results/models')
checkpoints = sorted(checkpoint_dir.glob('wgan_model_*.pt'))
if not checkpoints:
    # Fail with an actionable message instead of a bare IndexError on [-1].
    raise FileNotFoundError(
        f"No 'wgan_model_*.pt' checkpoint found in {checkpoint_dir.resolve()}"
    )
results_path = checkpoints[-1]

# map_location='cpu' lets a checkpoint written on a GPU machine be read on a
# CPU-only one; nothing below needs the data on a GPU.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints. Recent PyTorch releases default to weights_only=True, which
# rejects non-tensor objects; `results` appears to hold such objects
# (e.g. 'generated_samples' is indexed by column name below) -- verify and
# pass weights_only=False explicitly if required.
results = torch.load(results_path, map_location='cpu')

# Load original data
# `processor` is not referenced again in the visible part of this notebook.
processor = DataProcessor()
original_data = pd.read_csv('data/data_combined_controls.csv')

# Initialize evaluator
evaluator = EvaluationMetrics()

# 1. Statistical Validation
# Ask the evaluator to compare the original data with the generated samples.
# The result is kept in `stats_comparison` because it is saved at the end.
print("=== Statistical Validation ===")
stats_comparison = evaluator.compare_statistics(original_data, results['generated_samples'])
# One call, newline-separated: header line followed by the comparison itself.
print("\nStatistical Comparisons:", stats_comparison, sep="\n")

# 2. Distribution Analysis
# Same pair of datasets, handed to the evaluator's distribution plotting helper.
evaluator.visualize_distributions(original_data, results['generated_samples'])

# 3. Feature-wise Analysis
# For each of the first five columns of the original data, print the mean and
# then the standard deviation, original value first and generated value second.
# Assumes results['generated_samples'] can be indexed by the same column
# names -- TODO confirm.
print("\n=== Feature-wise Analysis ===")
for feature in original_data.columns[:5]:  # First 5 features
    print(f"\nFeature: {feature}")
    for stat_name in ("mean", "std"):
        print(f"Original {stat_name}:", getattr(original_data[feature], stat_name)())
        print(f"Generated {stat_name}:", getattr(results['generated_samples'][feature], stat_name)())

# 4. Cross-validation Results
# `cv_results` is read from the loaded checkpoint and iterated as a mapping of
# fold -> {metric name -> value}; each value is printed with four decimals,
# so it must support the '.4f' format spec.
print("\n=== Cross-validation Results ===")
cv_results = results['cv_results']
for fold_id, fold_metrics in cv_results.items():
    print(f"\nFold {fold_id}:")
    for metric_name, metric_value in fold_metrics.items():
        print(f"{metric_name}: {metric_value:.4f}")

# 5. Quality Metrics
# Pass the underlying `.values` of both datasets to the evaluator's MMD
# routine. The score is kept in `mmd_score` because it is saved at the end.
print("\n=== Quality Metrics ===")
mmd_score = evaluator.compute_mmd(original_data.values, results['generated_samples'].values)
print(f"MMD Score: {mmd_score:.4f}")

# 6. Visualization
# One wide figure with two side-by-side panels, driven through pyplot's
# "current axes" state, so the order of the calls below matters.
plt.figure(figsize=(12, 4))

# Left panel: the evaluator's t-SNE helper is called while this subplot is
# the current axes (presumably it draws onto it -- TODO confirm).
plt.subplot(1, 2, 1)
evaluator.generate_tsne(original_data, results['generated_samples'])
plt.title('t-SNE Visualization')

# Right panel: one box per dataset, each summarising that dataset's
# per-column means.
plt.subplot(1, 2, 2)
sns.boxplot(data=[original_data.mean(), results['generated_samples'].mean()])
plt.title('Feature Means Comparison')
plt.xticks(ticks=[0, 1], labels=['Original', 'Generated'])

plt.tight_layout()
plt.show()

# Save validation results
# Collect the outputs computed earlier in the notebook into a single dict so
# they can be reloaded later without re-running the analysis.
validation_results = {
    'statistical_comparison': stats_comparison,
    'mmd_score': mmd_score,
    'cv_results': cv_results,
}

# Make sure the output folder exists so the write cannot fail on a missing
# directory, then serialise the dict with pickle.
output_path = Path('results/validation_results.pkl')
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'wb') as f:
    pickle.dump(validation_results, f)