# In [None]:
"""
Explore.py: Comprehensive WGAN-GP Training and Evaluation for COVID-19 Data Augmentation

This script sets up the environment, processes data, trains a WGAN-GP model, evaluates synthetic data,
and generates informative visualizations to assess the quality of synthetic samples compared to the original dataset.

Key Functionalities:
- Set random seeds for reproducibility.
- Load, preprocess, and split data for training.
- Train WGAN-GP model to generate high-quality synthetic data.
- Evaluate synthetic data using statistical tests and visualization (KS tests, t-SNE).
- Save results and plots to designated directories.

Modules Used:
- config for global settings
- preprocessing for data handling
- GAN for model training
- evaluation for data quality assessment

Authors: CM Kiekhaefer
Version: 4.0
"""

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.manifold import TSNE
from scipy.stats import ks_2samp
from datetime import datetime

# Local module imports
import config
from src.models.data_augmentation.GAN import Generator, Critic, train_wgan_gp
from src.utils.preprocessing import preprocess_data
from src.utils.evaluation import compare_distributions, recenter_data

# Seed every random number generator used by this script from the single
# value config.RANDOM_STATE, so repeated runs start from the same RNG state.
_SEED = config.RANDOM_STATE
np.random.seed(_SEED)
torch.manual_seed(_SEED)
if torch.cuda.is_available():
    # Explicitly seed the CUDA generators on all devices when a GPU is present.
    torch.cuda.manual_seed_all(_SEED)

def main():
    """Run the WGAN-GP augmentation pipeline end to end.

    Steps, in order:
      1. Load the CSV at ``config.DATA_PATH`` and pass it to
         ``preprocess_data``, which returns the processed data and a scaler.
      2. Build a ``Generator`` and ``Critic`` sized to the number of processed
         features and train them with ``train_wgan_gp``.
      3. Sample one synthetic row per processed row from the generator and
         pass the samples through ``recenter_data`` with the scaler.
      4. Call ``compare_distributions`` and print a two-sample KS test
         computed on the flattened original and synthetic values.
      5. Write the synthetic rows to ``config.RESULT_DIR / 'synthetic_data.csv'``.
      6. Show a 2-D t-SNE scatter plot of original vs. synthetic rows.

    Returns:
        None. Progress and results are printed; the synthetic data is written
        to disk.
    """
    # Local import: only this entry point needs pathlib.
    from pathlib import Path

    start_time = datetime.now()
    print(f"Starting Explore.py at {start_time.strftime('%Y-%m-%d %H:%M:%S')}\n")

    # Data loading and preprocessing
    original_data = pd.read_csv(config.DATA_PATH)
    processed_data, scaler = preprocess_data(original_data)
    n_samples = len(processed_data)
    n_features = processed_data.shape[1]

    # Initialize models; the noise dimension equals the feature dimension.
    generator = Generator(input_dim=n_features, output_dim=n_features).to(config.DEVICE)
    critic = Critic(input_dim=n_features).to(config.DEVICE)

    # Train the model
    # NOTE(review): DataLoader indexes its dataset by integer position, so this
    # assumes preprocess_data returns an array/tensor rather than a DataFrame
    # -- confirm against src.utils.preprocessing.
    train_loader = DataLoader(processed_data, batch_size=config.BATCH_SIZE, shuffle=True)
    train_wgan_gp(train_loader, generator, critic, config.DEVICE, config.EPOCHS)

    # Generate synthetic data. Sampling needs no gradients, so skip building
    # an autograd graph for the whole dataset.
    with torch.no_grad():
        noise = torch.randn((n_samples, n_features), device=config.DEVICE)
        synthetic_data = generator(noise).cpu().numpy()
    synthetic_data = recenter_data(synthetic_data, scaler)

    # Evaluate synthetic data. pd.read_csv returns a DataFrame, which has no
    # .ravel(); convert to a NumPy array once and reuse it below.
    original_values = original_data.to_numpy()
    # The return value was never used by this script; the call is kept in case
    # compare_distributions reports or plots on its own.
    compare_distributions(original_values, synthetic_data)
    # Both arrays are flattened, so this is one pooled test across all
    # columns rather than a per-feature comparison.
    ks_statistic, p_value = ks_2samp(original_values.ravel(), np.asarray(synthetic_data).ravel())
    print(f"KS Statistic: {ks_statistic}, P-Value: {p_value}")

    # Save results before the t-SNE/plot step so a slow embedding or a
    # plotting error cannot cost the generated data. Create the output
    # directory if it is missing.
    result_dir = Path(config.RESULT_DIR)
    result_dir.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(synthetic_data, columns=original_data.columns).to_csv(result_dir / 'synthetic_data.csv', index=False)
    print(f"Results saved in {config.RESULT_DIR}")

    # Visualize results using t-SNE. Original rows are stacked first, so the
    # first n_original embedded points are the original data.
    n_original = len(original_values)
    data_combined = np.vstack((original_values, synthetic_data))
    tsne_results = TSNE(n_components=2, random_state=config.RANDOM_STATE).fit_transform(data_combined)

    plt.figure(figsize=(10, 6))
    plt.scatter(tsne_results[:n_original, 0], tsne_results[:n_original, 1], c='blue', label='Original')
    plt.scatter(tsne_results[n_original:, 0], tsne_results[n_original:, 1], c='red', label='Synthetic')
    plt.title('t-SNE visualization of Original and Synthetic Data')
    plt.legend()
    plt.show()

    # Completion
    print(f"Process completed in {datetime.now() - start_time}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()