In [1]:
import os

# Define the folders to be mounted
folders_to_mount = ['src', 'data', 'notebooks', 'plots']

# Define the target directory in /content/
content_dir = '/content/'

for folder in folders_to_mount:
    source_path = os.path.join('..', folder) # Assuming folders are in the parent directory
    target_path = os.path.join(content_dir, folder)

    if not os.path.exists(target_path):
        try:
            os.symlink(source_path, target_path)
            print(f'Symlinked {source_path} to {target_path}')
        except Exception as e:
            print(f'Error creating symlink for {folder}: {e}')
    else:
        print(f'{target_path} already exists, skipping symlink creation.')

Symlinked ../src to /content/src
Symlinked ../data to /content/data
Symlinked ../notebooks to /content/notebooks
Symlinked ../plots to /content/plots


# Improved CTGAN Training: Enhanced Synthetic Data Generation

## Key Improvements

1. **Proper handling of mixed data types**
   - Continuous features: Mode-specific normalization
   - Binary features: Preserved as 0/1
   - One-hot encoded: Gumbel-softmax for differentiability

2. **Better architecture**
   - Deeper networks (512x512x512)
   - Separate output heads for different data types
   - Improved dropout and batch normalization

3. **Enhanced training**
   - More epochs (300 vs 100)
   - Early stopping based on Wasserstein distance
   - Better monitoring and validation

4. **Comprehensive validation**
   - Type-specific quality metrics
   - One-hot validity checking
   - Distribution alignment metrics

## Setup

In [2]:
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import pickle
import os
import warnings
warnings.filterwarnings('ignore')

# Import improved modules
from src.improved_ctgan import build_improved_ctgan, DataTransformer
from src.improved_training import train_improved_ctgan, validate_synthetic_quality, print_validation_report

# Set style
sns.set_style('whitegrid')
np.random.seed(42)
tf.random.set_seed(42)

print(f'TensorFlow: {tf.__version__}')
print('Improved CTGAN modules loaded successfully!')

ModuleNotFoundError: No module named 'src'

## 1. Load Data

In [None]:
# Load unscaled training data
train_data = pd.read_csv('../data/processed/train_data_unscaled.csv')

print(f'Training data shape: {train_data.shape}')
print(f'\nColumn names:')
print(list(train_data.columns))
print(f'\nFirst few rows:')
print(train_data.head())
print(f'\nData types:')
print(train_data.dtypes)

## 2. Build Improved CTGAN

In [None]:
# Convert to numpy array
train_array = train_data.values.astype(np.float32)
column_names = list(train_data.columns)

# Build improved CTGAN with data transformer
improved_ctgan, data_transformer = build_improved_ctgan(
    data=train_array,
    column_names=column_names,
    noise_dim=128,
    generator_lr=2e-4,
    discriminator_lr=2e-4
)

print('\n✓ Improved CTGAN initialized!')

## 3. Transform Data

In [None]:
# Transform data for training (already done during build)
transformed_data = data_transformer.transform(train_array)

print(f'Transformed data shape: {transformed_data.shape}')
print(f'Transformed data range: [{transformed_data.min():.4f}, {transformed_data.max():.4f}]')
print(f'\nTransformation complete!')

## 4. Train Improved CTGAN

In [None]:
# Train with improved training loop
print('Training Improved CTGAN...')
print('This will take longer but produce much better results!')
print('Expected training time: 5-10 minutes\n')

history = train_improved_ctgan(
    ctgan=improved_ctgan,
    real_data=transformed_data,
    epochs=300,
    batch_size=500,
    n_critic=5,
    verbose=True,
    early_stopping_patience=50,
    validation_interval=10
)

print('\n✓ Training completed!')
print(f'Best epoch: {history["best_epoch"]}')
print(f'Final G loss: {history["g_loss"][-1]:.4f}')
print(f'Final D loss: {history["d_loss"][-1]:.4f}')
print(f'Final W distance: {history["w_distance"][-1]:.4f}')

## 5. Visualize Training Progress

In [None]:
# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Generator loss
axes[0, 0].plot(history['g_loss'], linewidth=2, color='blue', alpha=0.8)
axes[0, 0].axvline(history['best_epoch'], color='red', linestyle='--', label='Best Epoch')
axes[0, 0].set_xlabel('Epoch', fontsize=12, fontweight='bold')
axes[0, 0].set_ylabel('Generator Loss', fontsize=12, fontweight='bold')
axes[0, 0].set_title('Generator Loss Over Time', fontsize=13, fontweight='bold')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Discriminator loss
axes[0, 1].plot(history['d_loss'], linewidth=2, color='red', alpha=0.8)
axes[0, 1].axvline(history['best_epoch'], color='red', linestyle='--', label='Best Epoch')
axes[0, 1].set_xlabel('Epoch', fontsize=12, fontweight='bold')
axes[0, 1].set_ylabel('Discriminator Loss', fontsize=12, fontweight='bold')
axes[0, 1].set_title('Discriminator Loss Over Time', fontsize=13, fontweight='bold')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Wasserstein distance
axes[1, 0].plot(history['w_distance'], linewidth=2, color='green', alpha=0.8)
axes[1, 0].axvline(history['best_epoch'], color='red', linestyle='--', label='Best Epoch')
axes[1, 0].set_xlabel('Epoch', fontsize=12, fontweight='bold')
axes[1, 0].set_ylabel('Wasserstein Distance', fontsize=12, fontweight='bold')
axes[1, 0].set_title('Wasserstein Distance Over Time', fontsize=13, fontweight='bold')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Gradient penalty
axes[1, 1].plot(history['gp'], linewidth=2, color='purple', alpha=0.8)
axes[1, 1].axvline(history['best_epoch'], color='red', linestyle='--', label='Best Epoch')
axes[1, 1].set_xlabel('Epoch', fontsize=12, fontweight='bold')
axes[1, 1].set_ylabel('Gradient Penalty', fontsize=12, fontweight='bold')
axes[1, 1].set_title('Gradient Penalty Over Time', fontsize=13, fontweight='bold')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('../plots/improved_ctgan_training.png', dpi=300, bbox_inches='tight')
plt.show()

print('Training curves saved!')

## 6. Generate Synthetic Data

In [None]:
# Generate 5x synthetic data
num_real_samples = len(train_data)
num_synthetic_samples = 5 * num_real_samples

print(f'Generating {num_synthetic_samples:,} synthetic samples (5x augmentation)...')

# Generate in transformed space
synthetic_transformed = improved_ctgan.generate_samples(num_synthetic_samples)

# Inverse transform to original scale
synthetic_data = data_transformer.inverse_transform(synthetic_transformed)

# Convert to DataFrame
synthetic_df = pd.DataFrame(synthetic_data, columns=column_names)

print(f'\n✓ Synthetic data generated!')
print(f'Shape: {synthetic_df.shape}')
print(f'\nFirst few rows:')
print(synthetic_df.head())

## 7. Comprehensive Quality Validation

In [None]:
# Validate synthetic data quality
print('Validating synthetic data quality...')

validation_results = validate_synthetic_quality(
    real_data=train_array,
    synthetic_data=synthetic_data,
    column_names=column_names,
    continuous_cols=data_transformer.continuous_columns,
    binary_cols=data_transformer.binary_columns,
    onehot_groups=data_transformer.onehot_groups
)

# Print detailed report
print_validation_report(validation_results)

## 8. Visualize Distribution Comparisons

In [None]:
# Plot distributions for continuous features
continuous_cols = data_transformer.continuous_columns
n_plots = min(6, len(continuous_cols))

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

for i in range(n_plots):
    col_idx = continuous_cols[i]
    col_name = column_names[col_idx]

    ax = axes[i]

    # Plot histograms
    ax.hist(train_data.iloc[:, col_idx], bins=30, alpha=0.6, label='Real',
            color='blue', density=True, edgecolor='black')
    ax.hist(synthetic_df.iloc[:, col_idx], bins=30, alpha=0.6, label='Synthetic',
            color='red', density=True, edgecolor='black')

    # Get validation result
    result = next((r for r in validation_results['continuous'] if r['column'] == col_name), None)
    if result:
        status = '✓ PASS' if result['passed'] else '✗ FAIL'
        ax.set_title(f"{col_name}\n{status} (p={result['ks_p_value']:.4f})",
                    fontsize=11, fontweight='bold')
    else:
        ax.set_title(col_name, fontsize=11, fontweight='bold')

    ax.set_xlabel('Value', fontsize=10)
    ax.set_ylabel('Density', fontsize=10)
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('../plots/improved_ctgan_distributions.png', dpi=300, bbox_inches='tight')
plt.show()

print('Distribution comparison saved!')

## 9. Validate One-Hot Encoded Features

In [None]:
# Check one-hot validity
print('Checking one-hot encoded features...')

for group_idx, group in enumerate(data_transformer.onehot_groups):
    group_cols = [column_names[i] for i in group]

    # Real data
    real_group = train_data[group_cols].values
    real_sums = np.sum(real_group, axis=1)
    real_valid = np.mean(np.isclose(real_sums, 1.0))

    # Synthetic data
    syn_group = synthetic_df[group_cols].values
    syn_sums = np.sum(syn_group, axis=1)
    syn_valid = np.mean(np.isclose(syn_sums, 1.0))

    print(f'\nGroup {group_idx}: {group_cols[0].split("_")[0]}')
    print(f'  Real validity: {real_valid:.2%}')
    print(f'  Synthetic validity: {syn_valid:.2%}')
    print(f'  Status: {"✓ PASS" if syn_valid > 0.95 else "✗ FAIL"}')

    # Distribution
    print(f'  Category distribution:')
    for col in group_cols:
        real_freq = train_data[col].mean()
        syn_freq = synthetic_df[col].mean()
        print(f'    {col}: Real={real_freq:.2%}, Synthetic={syn_freq:.2%}')

## 10. Save Models and Data

In [None]:
import pickle
import os

os.makedirs('../models/improved', exist_ok=True)

# Save using pickle (more reliable)
with open('../models/improved/improved_ctgan.pkl', 'wb') as f:
    pickle.dump(improved_ctgan, f)

with open('../models/improved/data_transformer.pkl', 'wb') as f:
    pickle.dump(data_transformer, f)

synthetic_df.to_csv('../models/improved/synthetic_data.csv', index=False)

with open('../models/improved/training_history.pkl', 'wb') as f:
    pickle.dump(history, f)

with open('../models/improved/validation_results.pkl', 'wb') as f:
    pickle.dump(validation_results, f)

print('✓ All models and data saved successfully!')

## 11. Summary & Comparison

In [None]:
# Load original CTGAN results for comparison
try:
    original_ks_results = pd.read_csv('../models/ks_test_results.csv')
    original_pass_rate = (original_ks_results['P_Value'] > 0.05).sum() / len(original_ks_results) * 100

    print('\n' + '='*80)
    print('COMPARISON: ORIGINAL CTGAN vs IMPROVED CTGAN')
    print('='*80)
    print(f'\nOriginal CTGAN:')
    print(f'  Overall pass rate: {original_pass_rate:.1f}%')
    print(f'  Training epochs: 100')
    print(f'  Architecture: Simple (256x256)')
    print(f'\nImproved CTGAN:')
    print(f'  Overall pass rate: {validation_results["summary"]["overall_pass_rate"]:.1f}%')
    print(f'  Training epochs: {len(history["g_loss"])}')
    print(f'  Architecture: Advanced (512x512x512)')
    print(f'\nImprovement:')
    improvement = validation_results['summary']['overall_pass_rate'] - original_pass_rate
    print(f'  Pass rate improvement: {improvement:+.1f}%')
    print(f'  Relative improvement: {improvement / original_pass_rate * 100:.1f}%')
    print('='*80)
except:
    print('\nCould not load original results for comparison')

# Summary
print('\n' + '='*80)
print('IMPROVED CTGAN TRAINING SUMMARY')
print('='*80)
print(f'\nDataset:')
print(f'  Real samples: {num_real_samples:,}')
print(f'  Synthetic samples: {num_synthetic_samples:,}')
print(f'  Augmentation factor: 5.0x')
print(f'\nQuality Metrics:')
print(f'  Continuous features: {validation_results["summary"]["continuous_pass_rate"]:.1f}%')
print(f'  Binary features: {validation_results["summary"]["binary_pass_rate"]:.1f}%')
print(f'  One-hot groups: {validation_results["summary"]["onehot_pass_rate"]:.1f}%')
print(f'  Overall quality: {validation_results["summary"]["overall_pass_rate"]:.1f}%')
print(f'\nTraining:')
print(f'  Epochs completed: {len(history["g_loss"])}')
print(f'  Best epoch: {history["best_epoch"]}')
print(f'  Final Wasserstein distance: {history["w_distance"][-1]:.4f}')
print('='*80)
print('\n✓ Improved CTGAN training complete!')
print('Next: Run augmented model evaluation with improved synthetic data')