In [1]:
# Import necessary libraries
# from src.models.data_augmentation.VAE import *
# from src.models.data_augmentation.WAE import *
from src.models.data_augmentation.GAN import *
from src.utils.evaluation import *

import seaborn as sns
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt


In [2]:
from src.utils.evaluation import *

In [3]:
dataset_path = "data/data_combined_controls.csv"

## Modifying Explore.ipynb to Bypass AE and WAE

In [4]:
# At the start of the notebook
run_models = {'VAE': False, 'WAE': False, 'GAN': True}

In [5]:
dataset, tensor_data, scaled_data, scaler, original_dim = process(dataset_path)
# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [6]:
dataset, tensor_data, scaled_data, scaler, original_dim = process(dataset_path)

# vae = train_vae(dataset, original_dim)
# augmented_df = generate_vae(vae, scaled_data.columns, scaler)

In [7]:
# compare_stats_df = compare_statistics(scaled_data, augmented_df)
# compare_distributions_df = compare_distributions(scaled_data, augmented_df)

In [8]:
# plt.figure(figsize=(6, 6))
# sns.histplot(compare_distributions_df['KS Statistic'], kde=True)
# plt.title('VAE - Distribution of KS Statistics for Original vs. Synthetic Data')
# plt.xlabel('KS Statistic')
# plt.ylabel('Frequency')
# plt.show()

In [9]:
# generate_tsne(scaled_data, augmented_df)

# WAE

In [10]:
# dataset, tensor_data, scaled_data, scaler, original_dim = process(dataset_path)

# wae = train_wae(dataset, original_dim)
# augmented_df = generate_wae(wae, scaled_data.columns, scaler)

In [11]:
# compare_stats_df = compare_statistics(scaled_data, augmented_df)
# compare_distributions_df = compare_distributions(scaled_data, augmented_df)

In [12]:
# plt.figure(figsize=(6, 6))
# sns.histplot(compare_distributions_df['KS Statistic'], kde=True)
# plt.title('Distribution of KS Statistics for Original vs. Synthetic Data')
# plt.xlabel('KS Statistic')
# plt.ylabel('Frequency')
# plt.show()

In [13]:
# generate_tsne(scaled_data, augmented_df)

# WGAN

In [14]:
## WGAN-GP Model

print("Starting WGAN-GP training with K-fold validation...\n")

# Load and process data
dataset_path = "data/data_combined_controls.csv"
dataset, tensor_data, scaled_data, scaler, original_dim = process(dataset_path)

# Print initial data information
print("Dataset information:")
print(f"Original data shape: {scaled_data.shape}")
print(f"Number of features: {original_dim}")
print("\nColumn types:")
print(scaled_data.dtypes)

# Set training parameters
epochs = 20  # Reduced for debugging
batch_size = 32
learning_rate = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_splits = 5

print(f"\nTraining parameters:")
print(f"Epochs: {epochs}")
print(f"Batch size: {batch_size}")
print(f"Learning rate: {learning_rate}")
print(f"Device: {device}")
print(f"Number of folds: {n_splits}")

try:
    # Generate synthetic samples using WGAN-GP with k-fold validation
    print("\nTraining WGAN-GP and generating synthetic samples...")
    generated_samples = train_and_generate(
        filepath=dataset_path,
        batch_size=batch_size,
        epochs=epochs,
        device=device,
        n_splits=n_splits,
        learning_rate=learning_rate
    )
    
    # Create DataFrame with generated samples
    print("\nProcessing generated samples...")
    generated_df = pd.DataFrame(generated_samples, columns=scaled_data.columns)
    
    # Inverse transform the generated samples to original scale
    df_unscaled = pd.DataFrame(
        scaler.inverse_transform(generated_df.drop(['fold', 'type'], axis=1)),
        columns=[col for col in scaled_data.columns if col not in ['fold', 'type']]
    )
    
    # Calculate KS statistics
    print("\nCalculating KS statistics...")
    ks_stats = compare_distributions(scaled_data, df_unscaled)
    
    # Plot KS statistics distribution
    plt.figure(figsize=(8, 6))
    sns.histplot(ks_stats['KS Statistic'], kde=True)
    plt.title('Distribution of KS Statistics: Original vs. Synthetic Data')
    plt.xlabel('KS Statistic')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.show()
    
    # Calculate and plot augmented data with variance
    print("\nApplying recentering...")
    augmented_data_with_variance = recenter_data(df_unscaled, scaled_data)
    
    # Calculate KS statistics after recentering
    ks_stats_with_added_variance = compare_distributions(scaled_data, augmented_data_with_variance)
    
    # Plot KS statistics after recentering
    plt.figure(figsize=(8, 6))
    sns.histplot(ks_stats_with_added_variance['KS Statistic'], kde=True)
    plt.title('Distribution of KS Statistics After Recentering')
    plt.xlabel('KS Statistic')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.show()
    
    # Generate t-SNE visualization
    print("\nGenerating t-SNE visualization...")
    generate_tsne(scaled_data, augmented_data_with_variance)
    
    # Print summary statistics
    print("\nSummary Statistics:")
    print("\nBefore recentering:")
    print(ks_stats['KS Statistic'].describe())
    print("\nAfter recentering:")
    print(ks_stats_with_added_variance['KS Statistic'].describe())
    
    # Print generation statistics
    print(f"\nTotal synthetic samples generated: {len(generated_df)}")
    print(f"Samples per fold: {len(generated_df) // n_splits}")
    
    # Compare mean and std between original and synthetic data
    print("\nFeature Statistics Comparison:")
    for column in scaled_data.columns:
        print(f"\n{column}:")
        print(f"Original - Mean: {scaled_data[column].mean():.4f}, Std: {scaled_data[column].std():.4f}")
        print(f"Synthetic - Mean: {augmented_data_with_variance[column].mean():.4f}, Std: {augmented_data_with_variance[column].std():.4f}")

except Exception as e:
    print(f"\nError occurred during execution:")
    print(f"Error type: {type(e).__name__}")
    print(f"Error message: {str(e)}")
    print("\nFull traceback:")
    import traceback
    traceback.print_exc()

Starting WGAN-GP training with K-fold validation...



Dataset information:
Original data shape: (23, 8063)
Number of features: 8063

Column types:
IGKV2.28     float64
IGKV3D.20    float64
IGKV1.12     float64
IGLC7        float64
IGKV2.29     float64
              ...   
ZSCAN32      float64
ZSWIM8       float64
ZW10         float64
ZWILCH       float64
ZWINT        float64
Length: 8063, dtype: object

Training parameters:
Epochs: 20
Batch size: 32
Learning rate: 0.001
Device: cpu
Number of folds: 5

Training WGAN-GP and generating synthetic samples...

Data Overview:
Original samples: 23
Features: 8063

Feature types:
IGKV2.28     float64
IGKV3D.20    float64
IGKV1.12     float64
IGLC7        float64
IGKV2.29     float64
              ...   
ZSCAN32      float64
ZSWIM8       float64
ZW10         float64
ZWILCH       float64
ZWINT        float64
Length: 8063, dtype: object

Using device: cpu

Processing fold 1/5
Training samples: 18
Validation samples: 5
Train Loader Len: 18


Training WGAN-GP:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch [0/20] Batch [0] G_loss: 27.6189 C_loss: -118.3356


Training WGAN-GP:   5%|▌         | 1/20 [02:14<42:29, 134.18s/it]


Epoch 0 Summary:
Average G_loss: 16.9546
Average C_loss: -312.6441
Epoch [1/20] Batch [0] G_loss: 5.0891 C_loss: -51.8142


Training WGAN-GP:  10%|█         | 2/20 [04:29<40:27, 134.88s/it]


Epoch 1 Summary:
Average G_loss: 34.3543
Average C_loss: -390.6550
Epoch [2/20] Batch [0] G_loss: 8.1996 C_loss: -87.8171


Training WGAN-GP:  15%|█▌        | 3/20 [06:46<38:27, 135.76s/it]


Epoch 2 Summary:
Average G_loss: 20.3857
Average C_loss: -444.1220
Epoch [3/20] Batch [0] G_loss: 1.5772 C_loss: 0.9374


Training WGAN-GP:  20%|██        | 4/20 [09:01<36:11, 135.71s/it]


Epoch 3 Summary:
Average G_loss: 27.6542
Average C_loss: -453.4451
Epoch [4/20] Batch [0] G_loss: 37.6819 C_loss: -26.7995


Training WGAN-GP:  25%|██▌       | 5/20 [11:18<34:00, 136.03s/it]


Epoch 4 Summary:
Average G_loss: 18.5887
Average C_loss: -395.6305
Epoch [5/20] Batch [0] G_loss: -15.0549 C_loss: -42.2535


Training WGAN-GP:  30%|███       | 6/20 [13:34<31:45, 136.14s/it]


Epoch 5 Summary:
Average G_loss: 11.2747
Average C_loss: -379.8270
Epoch [6/20] Batch [0] G_loss: -16.4635 C_loss: 43.2421


## CODE BELOW GENERATING ERROR

In [None]:
# dataset, tensor_data, scaled_data, scaler, original_dim = process(dataset_path)

# # Parameter designation for Debugging
# epochs = 20  # Small number for initial debugging
# batch_size = 32  # A reasonable starting point
# learning_rate = 0.001  # Typical for many applications

# generated_samples = train_and_generate(dataset_path, batch_size=batch_size, epochs=epochs, device=device)

In [None]:
# generated_df = pd.DataFrame(generated_samples, columns=scaled_data.columns)
# df_unscaled = pd.DataFrame(scaler.inverse_transform(generated_df), columns=generated_df.columns)
# ks_stats = compare_distributions(scaled_data, df_unscaled)

In [None]:
# plt.figure(figsize=(6, 6))
# sns.histplot(ks_stats['KS Statistic'], kde=True)
# plt.title('Distribution of KS Statistics for Original vs. Synthetic Data')
# plt.xlabel('KS Statistic')
# plt.ylabel('Frequency')
# plt.show()

In [None]:
# augmented_data_with_variance = recenter_data(df_unscaled, scaled_data)
# ks_stats_with_added_variance = compare_distributions(scaled_data, augmented_data_with_variance)

In [None]:
# plt.figure(figsize=(6, 6))
# sns.histplot(ks_stats_with_added_variance['KS Statistic'], kde=True)
# plt.title('Distribution of KS Statistics for Original vs. Synthetic Data')
# plt.xlabel('KS Statistic')
# plt.ylabel('Frequency')
# plt.show()

In [None]:
# generate_tsne(scaled_data, augmented_data_with_variance)