# Assignment 4: Generative Models with GAN/cGAN (CTGAN Fixes)

**Key Changes from Original:**
- PAC = 10 (packing for discriminator)
- HIDDEN_DIM = 256 (larger networks)
- BATCH_SIZE = 500 (divisible by PAC)
- N_CRITIC = 1 (one critic update per generator update, instead of the 5 used in standard WGAN-GP)
- Same learning rate for G and D
- Residual blocks in Generator
- Log transform (`log1p`) for long-tail features (`capital-gain`, `capital-loss`)

## 1. Imports and Setup

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.io import arff
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import warnings
import os
warnings.filterwarnings('ignore')

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

def set_seed(seed):
    """Seed NumPy's global RNG and PyTorch's RNG (plus every CUDA device when available)."""
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

Using device: cuda


## 2. Data Loading

In [3]:
# Read the Adult ARFF file; scipy returns the records plus the attribute metadata.
data, meta = arff.loadarff('adult.arff')
df = pd.DataFrame(data)

# Nominal attributes arrive as byte strings -> decode every object column to str.
for col in df.columns:
    if df[col].dtype != object:
        continue
    df[col] = df[col].str.decode('utf-8')

print(f"Dataset shape: {df.shape}")
df.head()

Dataset shape: (32561, 15)


Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,income
0,39.0,State-gov,77516.0,Bachelors,13.0,Never-married,Adm-clerical,Not-in-family,White,Male,2174.0,0.0,40.0,United-States,<=50K
1,50.0,Self-emp-not-inc,83311.0,Bachelors,13.0,Married-civ-spouse,Exec-managerial,Husband,White,Male,0.0,0.0,13.0,United-States,<=50K
2,38.0,Private,215646.0,HS-grad,9.0,Divorced,Handlers-cleaners,Not-in-family,White,Male,0.0,0.0,40.0,United-States,<=50K
3,53.0,Private,234721.0,11th,7.0,Married-civ-spouse,Handlers-cleaners,Husband,Black,Male,0.0,0.0,40.0,United-States,<=50K
4,28.0,Private,338409.0,Bachelors,13.0,Married-civ-spouse,Prof-specialty,Wife,Black,Female,0.0,0.0,40.0,Cuba,<=50K


In [4]:
# Report how many '?' placeholders each text column contains (only non-zero counts are shown).
print("Missing values ('?') per column:")
for col in df.columns:
    if df[col].dtype != object:
        continue
    missing_count = (df[col] == '?').sum()
    if missing_count == 0:
        continue
    print(f"  {col}: {missing_count} ({missing_count/len(df)*100:.2f}%)")

Missing values ('?') per column:
  workclass: 1836 (5.64%)
  occupation: 1843 (5.66%)
  native-country: 583 (1.79%)


## 3. Feature Definitions

In [5]:
# Column roles for the Adult dataset.
# education-num is discrete, so it is modelled as a categorical column.
CONTINUOUS_COLS = [
    'age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week',
]
CATEGORICAL_COLS = [
    'workclass', 'education', 'marital-status', 'occupation',
    'relationship', 'race', 'sex', 'native-country', 'education-num',
]
TARGET_COL = 'income'

# Long-tailed columns that receive a log transform.
LOG_TRANSFORM_COLS = ['capital-gain', 'capital-loss']

print(f"Continuous features ({len(CONTINUOUS_COLS)}): {CONTINUOUS_COLS}")
print(f"Categorical features ({len(CATEGORICAL_COLS)}): {CATEGORICAL_COLS}")
print(f"Log transform features: {LOG_TRANSFORM_COLS}")
print(f"Target: {TARGET_COL}")

Continuous features (5): ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']
Categorical features (9): ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country', 'education-num']
Log transform features: ['capital-gain', 'capital-loss']
Target: income


## 4. VGM Transformer with Log Transform

In [6]:
from sklearn.mixture import GaussianMixture

class VGMTransformer:
    """Mode-specific (Gaussian-mixture) encoding of continuous columns, CTGAN style.

    For every continuous column a ``GaussianMixture`` with ``n_modes`` components
    is fitted, after ``log1p(|x|)`` for columns in ``log_transform_cols``. Each
    value is then encoded per column as

    * regular columns:             ``[normalized, mode_onehot]``
    * zero-/peak-inflated columns: ``[is_special, normalized, mode_onehot]``

    ``mode_onehot`` marks the most likely GMM component k. ``normalized`` is
    ``(x - mean_k) / (4 * std_k)``, clipped to ``[-0.99, 0.99]``. For inflated
    columns the GMM is fitted on the non-special values only, provided more than
    ``n_modes`` of them exist. ``is_special`` flags rows equal to 0 (zero-inflated)
    or to the configured peak value (peak-inflated).
    """

    def __init__(self, n_modes=5, zero_inflated_cols=None, peak_inflated_cols=None, log_transform_cols=None):
        self.n_modes = n_modes
        self.zero_inflated_cols = zero_inflated_cols or []
        self.peak_inflated_cols = peak_inflated_cols or {}   # {column: peak value}
        self.log_transform_cols = log_transform_cols or []
        self.gmms = {}
        # Per-column placeholder fed to the GMM for peak-valued rows. It is learned in
        # fit() so a row's encoding does not depend on the other rows in the batch.
        self.peak_fill_values = {}
        self.fitted = False

    def _apply_log(self, values, col):
        """Return log1p(|values|) for configured log-transform columns, else values unchanged."""
        if col in self.log_transform_cols:
            return np.log1p(np.abs(values))  # log(1 + |x|) to handle zeros
        return values

    def fit(self, df, continuous_cols):
        """Fit one GMM per continuous column (and the peak fill values); returns self."""
        for col in continuous_cols:
            raw = df[col].values
            values = self._apply_log(raw.copy(), col)

            if col in self.zero_inflated_cols:
                keep = raw != 0
            elif col in self.peak_inflated_cols:
                keep = raw != self.peak_inflated_cols[col]
                # Mean of the full (possibly log-transformed) training column. On the
                # training set this equals the placeholder previously computed per batch.
                self.peak_fill_values[col] = float(values.mean())
            else:
                keep = None

            # Exclude special values from the mixture fit only if enough regular values remain.
            if keep is not None and keep.sum() > self.n_modes:
                fit_values = values[keep].reshape(-1, 1)
            else:
                fit_values = values.reshape(-1, 1)

            gmm = GaussianMixture(n_components=self.n_modes, random_state=42, covariance_type='full')
            gmm.fit(fit_values)
            self.gmms[col] = gmm

        self.continuous_cols = continuous_cols
        self.fitted = True
        return self

    def transform(self, df, continuous_cols):
        """Encode ``continuous_cols`` of ``df``; returns a 2-D float array (rows x encoded dims).

        Raises:
            ValueError: if called before ``fit``.
        """
        if not self.fitted:
            raise ValueError("VGMTransformer must be fitted first")

        results = []
        for col in continuous_cols:
            gmm = self.gmms[col]
            original_values = df[col].values
            values = self._apply_log(original_values.copy(), col)

            if col in self.zero_inflated_cols:
                special = original_values == 0
                values_for_gmm = np.where(special, 1e-6, values)
            elif col in self.peak_inflated_cols:
                special = original_values == self.peak_inflated_cols[col]
                # Fixed placeholder learned at fit time (not the mean of this batch).
                values_for_gmm = np.where(special, self.peak_fill_values[col], values)
            else:
                special = None
                values_for_gmm = values
            values_for_gmm = values_for_gmm.reshape(-1, 1)

            # Hard assignment to the most likely mixture component.
            modes = gmm.predict(values_for_gmm)

            # Normalize each value within its assigned mode.
            means = gmm.means_.flatten()
            stds = np.sqrt(gmm.covariances_.flatten())
            normalized = (values_for_gmm.flatten() - means[modes]) / (4 * stds[modes] + 1e-8)
            normalized = np.clip(normalized, -0.99, 0.99).reshape(-1, 1)

            mode_onehot = np.zeros((len(values), self.n_modes))
            mode_onehot[np.arange(len(values)), modes] = 1

            parts = [normalized, mode_onehot]
            if special is not None:
                parts.insert(0, special.astype(np.float32).reshape(-1, 1))
            results.append(np.hstack(parts))

        return np.hstack(results) if results else np.array([]).reshape(len(df), 0)

## 5. Data Preprocessor

In [7]:
class DataPreprocessor:
    """Encode the tabular data for the GAN: VGM for continuous, one-hot for categorical.

    Shared pipeline of ``fit`` / ``transform``:
      1. ``education-num`` (if categorical) is cast to ``str``.
      2. ``'?'`` in categorical columns is replaced by the column mode (computed on non-'?' rows at fit).
      3. Categories with relative frequency < ``rare_threshold`` at fit time become ``'Other'``.
      4. Continuous columns go through ``VGMTransformer``; categoricals are one-hot encoded.
      5. The target column is label-encoded.

    ``zero_inflated_cols``, ``peak_inflated_cols`` (``{col: peak_value}``) and
    ``log_transform_cols`` are optional. When omitted they default to the
    Adult-specific choices: capital-gain/capital-loss zero-inflated and
    log-transformed, hours-per-week peak-inflated at 40.
    """

    def __init__(self, continuous_cols, categorical_cols, target_col,
                 n_modes=5, rare_threshold=0.01,
                 zero_inflated_cols=None, peak_inflated_cols=None, log_transform_cols=None):
        self.continuous_cols = continuous_cols
        self.categorical_cols = categorical_cols
        self.target_col = target_col
        self.n_modes = n_modes
        self.rare_threshold = rare_threshold

        # Copies so callers' lists/dicts are never shared or mutated.
        self.zero_inflated_cols = (list(zero_inflated_cols) if zero_inflated_cols is not None
                                   else ['capital-gain', 'capital-loss'])
        self.peak_inflated_cols = (dict(peak_inflated_cols) if peak_inflated_cols is not None
                                   else {'hours-per-week': 40.0})
        self.log_transform_cols = (list(log_transform_cols) if log_transform_cols is not None
                                   else ['capital-gain', 'capital-loss'])

        self.vgm = VGMTransformer(
            n_modes=n_modes,
            zero_inflated_cols=self.zero_inflated_cols,
            peak_inflated_cols=self.peak_inflated_cols,
            log_transform_cols=self.log_transform_cols
        )
        self.label_encoder = LabelEncoder()
        self.category_mappings = {}   # col -> {category: one-hot index}
        self.category_dims = {}       # col -> number of categories after grouping
        self.mode_values = {}         # col -> replacement for '?'
        self.rare_categories = {}     # col -> categories grouped into 'Other'

    def _group_rare_categories(self, df):
        """Learn (and apply) the rare-category grouping; used during fit."""
        df = df.copy()
        for col in self.categorical_cols:
            if col == 'education-num':
                df[col] = df[col].astype(str)
            freq = df[col].value_counts(normalize=True)
            rare_cats = freq[freq < self.rare_threshold].index.tolist()
            self.rare_categories[col] = rare_cats
            if rare_cats:
                df[col] = df[col].replace(rare_cats, 'Other')
        return df

    def _apply_rare_grouping(self, df):
        """Apply the grouping learned in fit; used during transform."""
        df = df.copy()
        for col in self.categorical_cols:
            if col == 'education-num':
                df[col] = df[col].astype(str)
            if self.rare_categories.get(col):
                df[col] = df[col].replace(self.rare_categories[col], 'Other')
        return df

    def fit(self, df):
        """Learn imputation modes, rare groupings, category indices, VGM and label encoder; returns self."""
        df = df.copy()
        if 'education-num' in self.categorical_cols:
            df['education-num'] = df['education-num'].astype(str)

        for col in self.categorical_cols:
            if df[col].dtype == object or col == 'education-num':
                valid_values = df[df[col] != '?'][col]
                if len(valid_values) > 0:
                    self.mode_values[col] = valid_values.mode()[0]

        for col in self.categorical_cols:
            if col in self.mode_values:
                df[col] = df[col].replace('?', self.mode_values[col])

        df = self._group_rare_categories(df)

        for col in self.categorical_cols:
            unique_vals = sorted(df[col].unique())
            self.category_mappings[col] = {v: i for i, v in enumerate(unique_vals)}
            self.category_dims[col] = len(unique_vals)

        self.vgm.fit(df, self.continuous_cols)
        self.label_encoder.fit(df[self.target_col])

        return self

    def _one_hot(self, series, col):
        """Vectorized one-hot of ``series`` using the fitted mapping for ``col``.

        Values not seen at fit time go to the 'Other' bucket when that bucket exists;
        otherwise their row is left all-zero.
        """
        mapping = self.category_mappings[col]
        codes = series.map(mapping).to_numpy(dtype=float)   # unseen -> NaN
        if 'Other' in mapping:
            codes = np.where(np.isnan(codes), mapping['Other'], codes)
        onehot = np.zeros((len(series), self.category_dims[col]))
        rows = np.flatnonzero(~np.isnan(codes))
        onehot[rows, codes[rows].astype(int)] = 1
        return onehot

    def transform(self, df):
        """Encode ``df``; returns ``(X float32 [rows x get_output_dim()], y encoded labels)``."""
        df = df.copy()
        if 'education-num' in self.categorical_cols:
            df['education-num'] = df['education-num'].astype(str)

        for col in self.categorical_cols:
            if col in self.mode_values:
                df[col] = df[col].replace('?', self.mode_values[col])

        df = self._apply_rare_grouping(df)

        continuous_transformed = self.vgm.transform(df, self.continuous_cols)

        categorical_arrays = [self._one_hot(df[col], col) for col in self.categorical_cols]
        categorical_transformed = (np.hstack(categorical_arrays) if categorical_arrays
                                   else np.array([]).reshape(len(df), 0))

        X = np.hstack([continuous_transformed, categorical_transformed])
        y = self.label_encoder.transform(df[self.target_col])

        return X.astype(np.float32), y

    def get_output_dim(self):
        return self.get_continuous_dim() + sum(self.category_dims.values())

    def get_continuous_dim(self):
        # Inflated columns carry an extra is_special flag in front of [value, modes].
        dim = 0
        for col in self.continuous_cols:
            if col in self.zero_inflated_cols or col in self.peak_inflated_cols:
                dim += 2 + self.n_modes
            else:
                dim += 1 + self.n_modes
        return dim

    def get_categorical_dims(self):
        return [self.category_dims[col] for col in self.categorical_cols]

    def get_n_modes(self):
        return self.n_modes

    def get_feature_structure(self):
        """Per-continuous-column layout descriptors consumed by the generators and losses."""
        structure = []
        for col in self.continuous_cols:
            if col in self.zero_inflated_cols:
                structure.append({'name': col, 'type': 'zero_inflated', 'dim': 2 + self.n_modes, 'special_value': 0})
            elif col in self.peak_inflated_cols:
                structure.append({'name': col, 'type': 'peak_inflated', 'dim': 2 + self.n_modes, 'special_value': self.peak_inflated_cols[col]})
            else:
                structure.append({'name': col, 'type': 'regular', 'dim': 1 + self.n_modes})
        return structure

    def get_raw_continuous_dim(self):
        return len(self.continuous_cols)

In [8]:
# Fit a preprocessor on the full dataset, here only to report the encoded dimensions.
# NOTE: the prepare_data cell rebinds `preprocessor` to one fitted on the training split.
preprocessor = DataPreprocessor(
    CONTINUOUS_COLS,
    CATEGORICAL_COLS,
    TARGET_COL,
    n_modes=5,
    rare_threshold=0.01,
).fit(df)

print(f"Total feature dimension: {preprocessor.get_output_dim()}")
print(f"VGM continuous dimension: {preprocessor.get_continuous_dim()}")
print(f"Categorical dimensions: {preprocessor.get_categorical_dims()}")

Total feature dimension: 105
VGM continuous dimension: 33
Categorical dimensions: [7, 15, 7, 13, 6, 4, 2, 3, 15]


## 6. Train-Test Split

In [9]:
def prepare_data(df, seed=42, test_size=0.2):
    """Stratified train/test split, then fit a DataPreprocessor on the training rows only.

    Returns:
        (X_train, X_test, y_train, y_test, fitted_preprocessor)
    """
    train_df, test_df = train_test_split(
        df, test_size=test_size, random_state=seed, stratify=df[TARGET_COL]
    )

    # Fitting on train only keeps test statistics out of the encoding.
    fitted = DataPreprocessor(CONTINUOUS_COLS, CATEGORICAL_COLS, TARGET_COL).fit(train_df)

    X_train, y_train = fitted.transform(train_df)
    X_test, y_test = fitted.transform(test_df)

    print(f"Training set: {X_train.shape[0]} samples")
    print(f"Test set: {X_test.shape[0]} samples")
    print(f"Label distribution (train): {np.bincount(y_train) / len(y_train)}")

    return X_train, X_test, y_train, y_test, fitted

X_train, X_test, y_train, y_test, preprocessor = prepare_data(df, seed=42)

Training set: 26048 samples
Test set: 6513 samples
Label distribution (train): [0.75917537 0.24082463]


## 7. Model Definitions with CTGAN-style Architecture

**Key CTGAN features:**
- PAC (packing) = 10 samples for discriminator
- Residual blocks in Generator
- Larger hidden dimensions (256)

In [10]:
def gumbel_softmax(logits, temperature=0.2, hard=True):
    """Sample from the Gumbel-softmax relaxation over the last dimension of ``logits``.

    With ``hard=True`` the forward value is an exact one-hot of the argmax, while
    gradients flow through the soft sample (straight-through estimator).
    """
    eps = 1e-20
    uniform = torch.rand_like(logits)
    noise = -torch.log(-torch.log(uniform + eps) + eps)
    soft = torch.softmax((logits + noise) / temperature, dim=-1)
    if not hard:
        return soft
    winner = soft.max(dim=-1, keepdim=True).indices
    one_hot = torch.zeros_like(logits).scatter_(-1, winner, 1.0)
    return one_hot - soft.detach() + soft


class ResidualBlock(nn.Module):
    """Additive residual block: ``relu(x + BN2(FC2(relu(BN1(FC1(x))))))`` with constant width."""

    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.bn1 = nn.BatchNorm1d(dim)
        self.fc2 = nn.Linear(dim, dim)
        self.bn2 = nn.BatchNorm1d(dim)

    def forward(self, x):
        branch = self.bn2(self.fc2(F.relu(self.bn1(self.fc1(x)))))
        return F.relu(branch + x)


class Generator(nn.Module):
    """CTGAN-style unconditional generator with residual blocks and zero/peak inflation.

    Network: Linear(latent_dim -> hidden_dim) -> BatchNorm -> ReLU -> 2 x ResidualBlock
    -> Linear(hidden_dim -> output_dim). The raw output is then activated slice by
    slice, following ``feature_structure`` and then ``categorical_dims``:

    * zero/peak-inflated column: [flag (sigmoid; straight-through 0/1 when hard),
      value (tanh), mode one-hot (Gumbel-softmax)]
    * regular column: [value (tanh), mode one-hot (Gumbel-softmax)]
    * each categorical column: one-hot sample via Gumbel-softmax.
    """

    def __init__(self, latent_dim, feature_structure, categorical_dims,
                 hidden_dim=256, temperature=0.2, n_modes=5):
        super().__init__()
        self.feature_structure = feature_structure
        self.categorical_dims = categorical_dims
        self.temperature = temperature
        self.n_modes = n_modes

        self.continuous_output_dim = sum(spec['dim'] for spec in feature_structure)
        n_outputs = self.continuous_output_dim + sum(categorical_dims)

        # Creation order determines parameter-initialisation RNG use; keep it stable.
        self.fc_input = nn.Linear(latent_dim, hidden_dim)
        self.bn_input = nn.BatchNorm1d(hidden_dim)
        self.res_blocks = nn.Sequential(*(ResidualBlock(hidden_dim) for _ in range(2)))
        self.fc_output = nn.Linear(hidden_dim, n_outputs)

    def forward(self, z, hard=True):
        raw = self.fc_output(self.res_blocks(F.relu(self.bn_input(self.fc_input(z)))))

        chunks = []
        cursor = 0

        def take(width):
            # Consume the next `width` raw output columns.
            nonlocal cursor
            piece = raw[:, cursor:cursor + width]
            cursor += width
            return piece

        for spec in self.feature_structure:
            inflated = spec['type'] in ('zero_inflated', 'peak_inflated')
            flag_prob = torch.sigmoid(take(1)) if inflated else None
            value = torch.tanh(take(1))
            modes = gumbel_softmax(take(self.n_modes), self.temperature, hard)
            if inflated:
                if hard:
                    chunks.append((flag_prob > 0.5).float() - flag_prob.detach() + flag_prob)
                else:
                    chunks.append(flag_prob)
            chunks.extend((value, modes))

        for width in self.categorical_dims:
            chunks.append(gumbel_softmax(take(width), self.temperature, hard))

        return torch.cat(chunks, dim=1)


class Discriminator(nn.Module):
    """CTGAN-style critic that scores packs of ``pac`` consecutive rows jointly.

    Three [Linear -> LeakyReLU(0.2) -> Dropout] stages followed by Linear(hidden_dim -> 1).
    Input rows must come in a count divisible by ``pac``.
    """

    def __init__(self, input_dim, hidden_dim=256, pac=10, dropout=0.5):
        super().__init__()
        self.pac = pac
        self.pacdim = input_dim * pac  # width of one packed row

        layers = []
        in_features = self.pacdim
        for _ in range(3):
            layers.extend([nn.Linear(in_features, hidden_dim), nn.LeakyReLU(0.2), nn.Dropout(dropout)])
            in_features = hidden_dim
        layers.append(nn.Linear(hidden_dim, 1))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        # (batch, input_dim) -> (batch / pac, input_dim * pac)
        packed = x.reshape(x.size(0) // self.pac, self.pacdim)
        return self.main(packed)


def gradient_penalty(discriminator, real_data, fake_data, device, pac=10):
    """WGAN-GP penalty computed per pack: mean over packs of (||dD/dx_pack||_2 - 1)^2."""
    n_packs = real_data.size(0) // pac
    # One mixing coefficient per pack, shared by its `pac` consecutive rows.
    mix = torch.rand(n_packs, 1, device=device).repeat(1, pac).reshape(-1, 1).expand_as(real_data)

    blended = (mix * real_data + (1 - mix) * fake_data).requires_grad_(True)
    score = discriminator(blended)

    (grad,) = torch.autograd.grad(
        outputs=score, inputs=blended,
        grad_outputs=torch.ones_like(score),
        create_graph=True, retain_graph=True
    )

    per_pack_norm = grad.reshape(n_packs, -1).norm(2, dim=1)
    return ((per_pack_norm - 1) ** 2).mean()


def correlation_loss(real, fake, feature_structure, n_modes):
    """MSE between the correlation matrices of the per-mode normalized continuous values.

    Only the scalar ``value`` column of each continuous feature is used. For inflated
    features it sits one position after the is_special flag.
    """
    value_columns = []
    offset = 0
    for spec in feature_structure:
        has_flag = spec['type'] in ('zero_inflated', 'peak_inflated')
        value_columns.append(offset + 1 if has_flag else offset)
        offset += (2 if has_flag else 1) + n_modes

    def corr_matrix(batch):
        cont = torch.cat([batch[:, c:c + 1] for c in value_columns], dim=1)
        z = (cont - cont.mean(0)) / (cont.std(0) + 1e-8)
        return torch.mm(z.t(), z) / z.size(0)

    return F.mse_loss(corr_matrix(fake), corr_matrix(real))


def special_proportion_loss(real, fake, feature_structure, n_modes):
    """Mean squared gap between real and fake special-value rates over inflated features.

    Returns 0.0 (a plain float) when no feature is zero- or peak-inflated.
    """
    flag_columns = []
    offset = 0
    for spec in feature_structure:
        if spec['type'] in ('zero_inflated', 'peak_inflated'):
            flag_columns.append(offset)
            offset += 2 + n_modes
        else:
            offset += 1 + n_modes

    total = sum((real[:, c].mean() - fake[:, c].mean()) ** 2 for c in flag_columns)
    return total / max(len(flag_columns), 1)

## 8. Conditional GAN Models

In [11]:
class ConditionalGenerator(nn.Module):
    """CTGAN-style conditional generator: noise ``z`` is concatenated with a one-hot class label.

    Network and output activations mirror the unconditional generator:
    Linear -> BatchNorm -> ReLU -> 2 x ResidualBlock -> Linear. Each slice is then
    activated per ``feature_structure`` (flag / tanh value / Gumbel mode one-hot),
    and each categorical column is sampled via Gumbel-softmax. ``labels`` may be
    class indices (1-D) or an already one-hot / soft matrix (2-D).
    """

    def __init__(self, latent_dim, num_classes, feature_structure, categorical_dims,
                 hidden_dim=256, temperature=0.2, n_modes=5):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.feature_structure = feature_structure
        self.categorical_dims = categorical_dims
        self.temperature = temperature
        self.n_modes = n_modes

        self.continuous_output_dim = sum(spec['dim'] for spec in feature_structure)
        n_outputs = self.continuous_output_dim + sum(categorical_dims)

        # Creation order determines parameter-initialisation RNG use; keep it stable.
        self.fc_input = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.bn_input = nn.BatchNorm1d(hidden_dim)
        self.res_blocks = nn.Sequential(*(ResidualBlock(hidden_dim) for _ in range(2)))
        self.fc_output = nn.Linear(hidden_dim, n_outputs)

    def forward(self, z, labels, hard=True):
        if labels.dim() == 1:
            cond = torch.zeros(labels.size(0), self.num_classes, device=z.device)
            cond.scatter_(1, labels.unsqueeze(1), 1)
        else:
            cond = labels

        hidden = F.relu(self.bn_input(self.fc_input(torch.cat([z, cond], dim=1))))
        raw = self.fc_output(self.res_blocks(hidden))

        chunks = []
        cursor = 0

        def take(width):
            # Consume the next `width` raw output columns.
            nonlocal cursor
            piece = raw[:, cursor:cursor + width]
            cursor += width
            return piece

        for spec in self.feature_structure:
            inflated = spec['type'] in ('zero_inflated', 'peak_inflated')
            flag_prob = torch.sigmoid(take(1)) if inflated else None
            value = torch.tanh(take(1))
            modes = gumbel_softmax(take(self.n_modes), self.temperature, hard)
            if inflated:
                if hard:
                    chunks.append((flag_prob > 0.5).float() - flag_prob.detach() + flag_prob)
                else:
                    chunks.append(flag_prob)
            chunks.extend((value, modes))

        for width in self.categorical_dims:
            chunks.append(gumbel_softmax(take(width), self.temperature, hard))

        return torch.cat(chunks, dim=1)


class ConditionalDiscriminator(nn.Module):
    """CTGAN-style conditional critic: each row is concatenated with its class one-hot, then packed.

    ``labels`` may be class indices (1-D) or an already one-hot / soft matrix (2-D).
    Input rows must come in a count divisible by ``pac``.
    """

    def __init__(self, input_dim, num_classes, hidden_dim=256, pac=10, dropout=0.5):
        super().__init__()
        self.pac = pac
        self.num_classes = num_classes
        self.pacdim = (input_dim + num_classes) * pac  # width of one packed row

        layers = []
        in_features = self.pacdim
        for _ in range(3):
            layers.extend([nn.Linear(in_features, hidden_dim), nn.LeakyReLU(0.2), nn.Dropout(dropout)])
            in_features = hidden_dim
        layers.append(nn.Linear(hidden_dim, 1))
        self.main = nn.Sequential(*layers)

    def forward(self, x, labels):
        if labels.dim() == 1:
            cond = torch.zeros(labels.size(0), self.num_classes, device=x.device)
            cond.scatter_(1, labels.unsqueeze(1), 1)
        else:
            cond = labels

        joined = torch.cat([x, cond], dim=1)
        packed = joined.reshape(joined.size(0) // self.pac, self.pacdim)
        return self.main(packed)


def gradient_penalty_cgan(discriminator, real_data, fake_data, labels, device, pac=10):
    """Conditional WGAN-GP penalty per pack; gradients are taken w.r.t. the data, not the labels."""
    n_packs = real_data.size(0) // pac
    # One mixing coefficient per pack, shared by its `pac` consecutive rows.
    mix = torch.rand(n_packs, 1, device=device).repeat(1, pac).reshape(-1, 1).expand_as(real_data)

    blended = (mix * real_data + (1 - mix) * fake_data).requires_grad_(True)
    score = discriminator(blended, labels)

    (grad,) = torch.autograd.grad(
        outputs=score, inputs=blended,
        grad_outputs=torch.ones_like(score),
        create_graph=True, retain_graph=True
    )

    per_pack_norm = grad.reshape(n_packs, -1).norm(2, dim=1)
    return ((per_pack_norm - 1) ** 2).mean()

## 9. Training Functions with CTGAN Hyperparameters

In [12]:
def train_gan(X_train, preprocessor, latent_dim=128, hidden_dim=256,
              batch_size=500, epochs=300, lr=0.0002,
              n_critic=1, lambda_gp=10, lambda_corr=0.1, lambda_special=0.5,
              temperature=0.2, dropout=0.5, pac=10, seed=42,
              save_dir='plots_v2/gan'):
    """
    Train WGAN-GP with CTGAN-style architecture.

    Key CTGAN differences:
    - PAC = 10 (packing)
    - n_critic = 1 (not 5)
    - Same lr for G and D
    - Larger hidden_dim = 256
    - batch_size = 500

    Args:
        X_train: encoded training matrix; presumably the output of
            ``preprocessor.transform`` (rows x ``preprocessor.get_output_dim()``).
        preprocessor: fitted preprocessor supplying the feature layout.
        batch_size: rounded down to a multiple of ``pac`` (the critic packs rows).
        n_critic: critic updates per generator update (must be >= 1).
        save_dir: created if missing.

    Returns:
        (generator, g_losses, d_losses): the generator (still in train mode) and
        per-epoch mean generator / critic losses (critic loss = last critic step per batch).

    Raises:
        ValueError: if ``pac`` or ``n_critic`` is < 1, if ``batch_size < pac``, or if
            ``X_train`` has fewer rows than one batch. With ``drop_last=True`` that
            last case would produce zero batches and a division by zero.
    """
    # Validate before touching the RNG or building models, so misuse fails fast and clearly.
    if pac < 1:
        raise ValueError(f"pac must be >= 1, got {pac}")
    if n_critic < 1:
        raise ValueError(f"n_critic must be >= 1, got {n_critic}")
    # Ensure batch_size is divisible by pac
    effective_batch_size = (batch_size // pac) * pac
    if effective_batch_size == 0:
        raise ValueError(f"batch_size ({batch_size}) must be >= pac ({pac})")
    if len(X_train) < effective_batch_size:
        raise ValueError(
            f"X_train has {len(X_train)} rows, fewer than one batch of {effective_batch_size}; "
            f"with drop_last=True no batch would be produced"
        )

    set_seed(seed)
    os.makedirs(save_dir, exist_ok=True)

    feature_structure = preprocessor.get_feature_structure()
    n_modes = preprocessor.get_n_modes()
    categorical_dims = preprocessor.get_categorical_dims()
    data_dim = preprocessor.get_output_dim()

    generator = Generator(
        latent_dim=latent_dim,
        feature_structure=feature_structure,
        categorical_dims=categorical_dims,
        hidden_dim=hidden_dim,
        temperature=temperature,
        n_modes=n_modes
    ).to(device)

    discriminator = Discriminator(
        input_dim=data_dim,
        hidden_dim=hidden_dim,
        pac=pac,
        dropout=dropout
    ).to(device)

    optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.9))

    dataset = TensorDataset(torch.FloatTensor(X_train))
    # drop_last keeps every batch a multiple of pac.
    dataloader = DataLoader(dataset, batch_size=effective_batch_size, shuffle=True, drop_last=True)

    g_losses = []
    d_losses = []

    for epoch in tqdm(range(epochs), desc="Training GAN (CTGAN-style)"):
        epoch_g_loss = 0
        epoch_d_loss = 0
        n_batches = 0

        for (real_data,) in dataloader:
            batch_size_actual = real_data.size(0)
            real_data = real_data.to(device)

            # Train Discriminator
            for _ in range(n_critic):
                optimizer_d.zero_grad()

                z = torch.randn(batch_size_actual, latent_dim, device=device)
                # The fake batch is only a critic input here; skip building G's autograd graph.
                with torch.no_grad():
                    fake_data = generator(z)

                d_real = discriminator(real_data)
                d_fake = discriminator(fake_data)

                gp = gradient_penalty(discriminator, real_data, fake_data, device, pac)
                d_loss = d_fake.mean() - d_real.mean() + lambda_gp * gp

                d_loss.backward()
                optimizer_d.step()

            # Train Generator
            optimizer_g.zero_grad()

            z = torch.randn(batch_size_actual, latent_dim, device=device)
            fake_data = generator(z)
            d_fake = discriminator(fake_data)

            g_loss_wgan = -d_fake.mean()
            corr_loss = correlation_loss(real_data, fake_data, feature_structure, n_modes)
            special_loss = special_proportion_loss(real_data, fake_data, feature_structure, n_modes)

            g_loss = g_loss_wgan + lambda_corr * corr_loss + lambda_special * special_loss
            g_loss.backward()

            torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
            optimizer_g.step()

            epoch_d_loss += d_loss.item()
            epoch_g_loss += g_loss.item()
            n_batches += 1

        d_losses.append(epoch_d_loss / n_batches)
        g_losses.append(epoch_g_loss / n_batches)

        if (epoch + 1) % 50 == 0:
            print(f"Epoch [{epoch+1}/{epochs}] D_loss: {d_losses[-1]:.4f} G_loss: {g_losses[-1]:.4f}")

    return generator, g_losses, d_losses


def train_cgan(X_train, y_train, preprocessor, latent_dim=128, hidden_dim=256,
               batch_size=500, epochs=300, lr=0.0002,
               n_critic=1, lambda_gp=10, lambda_corr=0.1, lambda_special=0.5,
               temperature=0.2, dropout=0.5, pac=10, seed=42,
               save_dir='plots_v2/cgan'):
    """
    Train Conditional WGAN-GP with CTGAN-style architecture.

    The critic scores samples in packs of ``pac`` rows, so every mini-batch must
    contain a multiple of ``pac`` rows: ``batch_size`` is rounded down to the
    nearest multiple of ``pac`` and the incomplete trailing batch is dropped.

    Args:
        X_train: Preprocessed training matrix (one row per sample) in the layout
            produced by ``preprocessor``.
        y_train: Integer class labels aligned with ``X_train``. They condition both
            networks and size the label embedding
            (``num_classes = len(np.unique(y_train))``).
        preprocessor: Fitted preprocessor exposing ``get_feature_structure``,
            ``get_n_modes``, ``get_categorical_dims`` and ``get_output_dim``.
        latent_dim: Size of the generator's noise vector.
        hidden_dim: Hidden width of generator and critic.
        batch_size: Requested mini-batch size (rounded down to a multiple of ``pac``).
        epochs: Number of passes over the training data.
        lr: Adam learning rate, shared by generator and critic.
        n_critic: Critic updates per generator update (must be >= 1).
        lambda_gp: Gradient-penalty weight in the critic loss.
        lambda_corr: Weight of the auxiliary correlation loss in the generator loss.
        lambda_special: Weight of the special-value-proportion loss in the generator loss.
        temperature: Gumbel-softmax temperature passed to the generator.
        dropout: Critic dropout rate.
        pac: Packing factor for the critic (must be >= 1).
        seed: Seed applied before model construction and data shuffling.
        save_dir: Directory created up front for this model's artifacts.

    Returns:
        (generator, g_losses, d_losses): the trained generator and the per-epoch
        mean generator and critic losses. The critic loss recorded for a batch is
        the one from that batch's last critic step.

    Raises:
        ValueError: if ``pac < 1``, ``n_critic < 1``, ``batch_size < pac``, or the
            training set is smaller than one (pac-aligned) batch.
    """
    # Validate first. Otherwise each of these configurations fails deep inside the
    # loop with an unhelpful error:
    # - pac < 1: division by zero.
    # - batch_size < pac: DataLoader(batch_size=0).
    # - n_critic < 1: d_loss is never bound.
    # - fewer rows than one batch: zero batches (drop_last=True), then a
    #   ZeroDivisionError when averaging the epoch losses.
    if pac < 1:
        raise ValueError(f"pac must be >= 1, got {pac}")
    if n_critic < 1:
        raise ValueError(f"n_critic must be >= 1, got {n_critic}")
    effective_batch_size = (batch_size // pac) * pac
    if effective_batch_size < 1:
        raise ValueError(f"batch_size ({batch_size}) must be at least pac ({pac})")
    if len(X_train) < effective_batch_size:
        raise ValueError(
            f"X_train has {len(X_train)} rows, fewer than one batch of "
            f"{effective_batch_size}; no full batch can be formed"
        )

    set_seed(seed)
    os.makedirs(save_dir, exist_ok=True)

    feature_structure = preprocessor.get_feature_structure()
    n_modes = preprocessor.get_n_modes()
    categorical_dims = preprocessor.get_categorical_dims()
    data_dim = preprocessor.get_output_dim()
    num_classes = len(np.unique(y_train))

    generator = ConditionalGenerator(
        latent_dim=latent_dim,
        num_classes=num_classes,
        feature_structure=feature_structure,
        categorical_dims=categorical_dims,
        hidden_dim=hidden_dim,
        temperature=temperature,
        n_modes=n_modes
    ).to(device)

    discriminator = ConditionalDiscriminator(
        input_dim=data_dim,
        num_classes=num_classes,
        hidden_dim=hidden_dim,
        pac=pac,
        dropout=dropout
    ).to(device)

    optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.9))

    # drop_last=True keeps every batch a multiple of pac so the critic can pack it.
    dataset = TensorDataset(torch.FloatTensor(X_train), torch.LongTensor(y_train))
    dataloader = DataLoader(dataset, batch_size=effective_batch_size, shuffle=True, drop_last=True)

    g_losses = []
    d_losses = []

    for epoch in tqdm(range(epochs), desc="Training cGAN (CTGAN-style)"):
        epoch_g_loss = 0
        epoch_d_loss = 0
        n_batches = 0

        for real_data, labels in dataloader:
            batch_size_actual = real_data.size(0)
            real_data = real_data.to(device)
            labels = labels.to(device)

            # Train critic: WGAN loss + gradient penalty, both conditioned on the
            # real batch's labels; fake samples are detached so only D updates.
            for _ in range(n_critic):
                optimizer_d.zero_grad()

                z = torch.randn(batch_size_actual, latent_dim, device=device)
                fake_data = generator(z, labels)

                d_real = discriminator(real_data, labels)
                d_fake = discriminator(fake_data.detach(), labels)

                gp = gradient_penalty_cgan(discriminator, real_data, fake_data.detach(), labels, device, pac)
                d_loss = d_fake.mean() - d_real.mean() + lambda_gp * gp

                d_loss.backward()
                optimizer_d.step()

            # Train generator: adversarial term plus auxiliary correlation and
            # special-value-proportion losses against the same real batch.
            optimizer_g.zero_grad()

            z = torch.randn(batch_size_actual, latent_dim, device=device)
            fake_data = generator(z, labels)
            d_fake = discriminator(fake_data, labels)

            g_loss_wgan = -d_fake.mean()
            corr_loss = correlation_loss(real_data, fake_data, feature_structure, n_modes)
            special_loss = special_proportion_loss(real_data, fake_data, feature_structure, n_modes)

            g_loss = g_loss_wgan + lambda_corr * corr_loss + lambda_special * special_loss
            g_loss.backward()

            torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
            optimizer_g.step()

            # d_loss is the loss of the last critic step for this batch.
            epoch_d_loss += d_loss.item()
            epoch_g_loss += g_loss.item()
            n_batches += 1

        d_losses.append(epoch_d_loss / n_batches)
        g_losses.append(epoch_g_loss / n_batches)

        if (epoch + 1) % 50 == 0:
            print(f"Epoch [{epoch+1}/{epochs}] D_loss: {d_losses[-1]:.4f} G_loss: {g_losses[-1]:.4f}")

    return generator, g_losses, d_losses

## 10. Data Generation

In [13]:
def generate_synthetic_data(generator, n_samples, latent_dim, device):
    """Generate synthetic data using trained GAN.

    Draws ``n_samples`` standard-normal latent vectors and decodes them with
    ``generator(z, hard=True)`` under ``torch.no_grad()``. The generator is
    switched to eval mode for sampling and then restored to whatever mode it was
    in before the call.

    Args:
        generator: Trained generator module, called as ``generator(z, hard=True)``.
        n_samples: Number of rows to generate.
        latent_dim: Size of each latent vector.
        device: Device on which the latent noise is created.

    Returns:
        NumPy array with the generator output for all ``n_samples`` rows.
    """
    # Restore the caller's mode afterwards. Leaving the model in eval mode would
    # silently disable dropout / batch-norm updates if training is resumed later.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(n_samples, latent_dim, device=device)
            synthetic_data = generator(z, hard=True)
    finally:
        generator.train(was_training)
    return synthetic_data.cpu().numpy()


def generate_conditional_synthetic_data(generator, n_samples, latent_dim, label_ratios, device):
    """Generate synthetic data using trained cGAN with specified label ratios.

    Class ``k`` (the k-th entry of ``label_ratios``) receives
    ``int(label_ratios[k] * n_samples)`` rows. The last class absorbs the rounding
    remainder, so exactly ``n_samples`` rows are returned. Rows are grouped by
    class in label order (all class 0, then class 1, ...), not shuffled. The
    generator is sampled in eval mode and restored to its previous mode afterwards.

    Args:
        generator: Trained conditional generator, called as
            ``generator(z, labels, hard=True)``.
        n_samples: Total number of rows to generate (must be positive).
        latent_dim: Size of each latent vector.
        label_ratios: Non-negative per-class proportions indexed by label that sum
            to at most 1, e.g. ``np.bincount(y) / len(y)``.
        device: Device on which noise and label tensors are created.

    Returns:
        (X, y): the stacked generator outputs as a NumPy array and the matching
        integer labels.

    Raises:
        ValueError: if ``n_samples`` is not positive, ``label_ratios`` is empty,
            not 1-D, or has negative entries, or the ratios sum to more than 1
            (the per-class counts could then not add up to ``n_samples``).
    """
    ratios = np.asarray(label_ratios, dtype=float)
    if n_samples <= 0:
        raise ValueError(f"n_samples must be positive, got {n_samples}")
    if ratios.ndim != 1 or ratios.size == 0:
        raise ValueError("label_ratios must be a non-empty 1-D sequence")
    if (ratios < 0).any():
        raise ValueError("label_ratios must be non-negative")

    samples_per_class = (ratios * n_samples).astype(int)
    # Give the rounding remainder to the last class so the counts sum to n_samples.
    samples_per_class[-1] = n_samples - samples_per_class[:-1].sum()
    if samples_per_class[-1] < 0:
        # Previously this class was silently skipped and the wrong number of rows returned.
        raise ValueError("label_ratios sum to more than 1; per-class counts exceed n_samples")

    all_data = []
    all_labels = []

    # Sample in eval mode, but restore the caller's mode afterwards.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            for label, n in enumerate(samples_per_class):
                if n > 0:
                    z = torch.randn(n, latent_dim, device=device)
                    labels = torch.full((n,), label, dtype=torch.long, device=device)
                    synthetic = generator(z, labels, hard=True)
                    all_data.append(synthetic.cpu().numpy())
                    all_labels.append(np.full(n, label))
    finally:
        generator.train(was_training)

    return np.vstack(all_data), np.concatenate(all_labels)

## 11. Evaluation Metrics

In [14]:
def compute_detection_metric(X_real, X_synthetic, n_folds=4, seed=42):
    """Measure how easily a classifier tells real rows from synthetic rows.

    Real rows are labelled 0 and synthetic rows 1. A 100-tree Random Forest is
    trained and evaluated with shuffled, stratified K-fold cross-validation. The
    ROC AUC on each held-out fold is collected; an AUC near 0.5 means the two
    sets are hard to distinguish.

    Returns:
        (mean, std) of the per-fold ROC AUC values.
    """
    features = np.vstack([X_real, X_synthetic])
    targets = np.concatenate([np.zeros(len(X_real)), np.ones(len(X_synthetic))])

    splitter = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)

    fold_aucs = []
    for fit_rows, eval_rows in splitter.split(features, targets):
        forest = RandomForestClassifier(n_estimators=100, random_state=seed, n_jobs=-1)
        forest.fit(features[fit_rows], targets[fit_rows])
        # Probability of the "synthetic" class (label 1) on the held-out fold.
        synthetic_score = forest.predict_proba(features[eval_rows])[:, 1]
        fold_aucs.append(roc_auc_score(targets[eval_rows], synthetic_score))

    return np.mean(fold_aucs), np.std(fold_aucs)


def _positive_class_scores(model, X, positive_label):
    """Return the fitted model's probability of ``positive_label`` for each row of ``X``.

    The probability column is located through ``model.classes_`` rather than
    assuming index 1. A classifier fitted on single-class labels exposes only one
    probability column, and ``[:, 1]`` would raise IndexError. If the positive
    class never appeared during fitting, every row scores 0.
    """
    classes = list(model.classes_)
    if positive_label not in classes:
        return np.zeros(len(X))
    return model.predict_proba(X)[:, classes.index(positive_label)]


def compute_efficacy_metric(X_train_real, y_train_real, X_synthetic, y_synthetic,
                            X_test, y_test, seed=42):
    """Compute efficacy metric: synthetic-trained vs. real-trained test AUC.

    Two 100-tree Random Forests are fitted: one on the real training data and one
    on the synthetic data. Both are scored by ROC AUC on the same real test set.

    Returns:
        (efficacy_ratio, auc_real, auc_synthetic) where
        ``efficacy_ratio = auc_synthetic / auc_real``.
    """
    # For binary targets roc_auc_score treats the larger label as positive, so
    # score that class explicitly (same as column 1 when labels are {0, 1}).
    positive_label = np.max(y_test)

    rf_real = RandomForestClassifier(n_estimators=100, random_state=seed, n_jobs=-1)
    rf_real.fit(X_train_real, y_train_real)
    auc_real = roc_auc_score(y_test, _positive_class_scores(rf_real, X_test, positive_label))

    rf_synthetic = RandomForestClassifier(n_estimators=100, random_state=seed, n_jobs=-1)
    rf_synthetic.fit(X_synthetic, y_synthetic)
    auc_synthetic = roc_auc_score(y_test, _positive_class_scores(rf_synthetic, X_test, positive_label))

    efficacy_ratio = auc_synthetic / auc_real

    return efficacy_ratio, auc_real, auc_synthetic

## 12. Plotting Functions

In [15]:
def plot_training_losses(g_losses, d_losses, title="GAN Training Losses", save_path=None):
    """Draw per-epoch generator and discriminator losses on a single axes.

    When ``save_path`` is given, the figure is written there (150 dpi, tight
    bounding box) and the path is echoed before the figure is shown.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    curves = ((g_losses, 'Generator Loss'), (d_losses, 'Discriminator Loss'))
    for series, curve_label in curves:
        ax.plot(series, label=curve_label, alpha=0.8)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    plt.show()


def plot_feature_distributions(X_real, X_synthetic, preprocessor, n_features=5,
                               title="Feature Distributions", save_path=None):
    """Plot histograms comparing real vs synthetic continuous features.

    The encoded column layout is walked using ``preprocessor.get_feature_structure()``:
    - A zero-/peak-inflated feature occupies ``2 + n_modes`` columns (indicator,
      value, mode columns).
    - Any other feature occupies ``1 + n_modes`` columns (value, mode columns).

    For inflated features the indicator and value are merged into one histogram,
    with special-value rows placed at -1.5 (left of the dashed line). The panel
    title reports each dataset's share of special values.

    Args:
        X_real: Real encoded matrix in the preprocessor's column layout.
        X_synthetic: Synthetic encoded matrix in the same layout.
        preprocessor: Fitted preprocessor providing the feature structure and mode count.
        n_features: Number of leading features to plot (capped at the number available).
        title: Figure title.
        save_path: Optional path; when given the figure is saved there at 150 dpi.
    """
    feature_structure = preprocessor.get_feature_structure()
    n_modes = preprocessor.get_n_modes()
    n_plots = max(0, min(n_features, len(feature_structure)))

    # Three panels per row. The original 2x3 grid is the minimum, so the default
    # call renders as before. Extra rows are added when more than six features are
    # requested; the fixed grid raised IndexError there.
    n_cols = 3
    n_rows = max(2, (n_plots + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows))
    axes = axes.flatten()

    pos = 0
    for i, feat in enumerate(feature_structure[:n_plots]):
        ax = axes[i]

        if feat['type'] in ['zero_inflated', 'peak_inflated']:
            is_special_real = X_real[:, pos]
            is_special_synth = X_synthetic[:, pos]
            value_real = X_real[:, pos + 1]
            value_synth = X_synthetic[:, pos + 1]

            # Park special-value rows at -1.5 so they form their own bar.
            combined_real = np.where(is_special_real > 0.5, -1.5, value_real)
            combined_synth = np.where(is_special_synth > 0.5, -1.5, value_synth)

            # NOTE(review): assumes encoded values lie within [-1, 1]; values outside
            # the histogram range are not drawn — confirm against the preprocessor.
            ax.hist(combined_real, bins=40, alpha=0.5, label='Real', density=True, range=(-1.5, 1))
            ax.hist(combined_synth, bins=40, alpha=0.5, label='Synthetic', density=True, range=(-1.5, 1))
            ax.axvline(x=-1.25, color='red', linestyle='--', linewidth=0.5)

            special_prop_real = (is_special_real > 0.5).mean() * 100
            special_prop_synth = (is_special_synth > 0.5).mean() * 100
            special_name = "Zeros" if feat['type'] == 'zero_inflated' else f"Peak@{feat['special_value']}"
            ax.set_title(f'{feat["name"]}\n{special_name}: R={special_prop_real:.1f}%, S={special_prop_synth:.1f}%')

            pos += 2 + n_modes
        else:
            ax.hist(X_real[:, pos], bins=30, alpha=0.5, label='Real', density=True)
            ax.hist(X_synthetic[:, pos], bins=30, alpha=0.5, label='Synthetic', density=True)
            ax.set_title(f'{feat["name"]}')
            pos += 1 + n_modes

        ax.legend()
        ax.grid(True, alpha=0.3)

    # Hide every panel that has no feature. Previously only the sixth panel was
    # hidden, so empty axes stayed visible when fewer than five features were plotted.
    for ax in axes[n_plots:]:
        ax.axis('off')

    fig.suptitle(title, fontsize=14)
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    plt.show()


def plot_correlation_matrices(X_real, X_synthetic, preprocessor,
                              title="Correlation Comparison", save_path=None):
    """Compare pairwise Pearson correlations of the continuous features.

    One scalar value column is taken per continuous feature from the encoded
    matrix. For zero-/peak-inflated features this is the column after the
    indicator. Three annotated heatmaps are drawn side by side: real
    correlations, synthetic correlations, and their difference (real minus
    synthetic).
    """
    feature_structure = preprocessor.get_feature_structure()
    n_modes = preprocessor.get_n_modes()

    # First pass: locate the value column of each feature in the encoded layout.
    value_columns = []
    names = []
    offset = 0
    for feat in feature_structure:
        inflated = feat['type'] in ['zero_inflated', 'peak_inflated']
        # Inflated features carry an indicator column in front of the value.
        value_columns.append(offset + 1 if inflated else offset)
        offset += (2 if inflated else 1) + n_modes
        names.append(feat['name'])

    real_cont = np.column_stack([X_real[:, col] for col in value_columns])
    synth_cont = np.column_stack([X_synthetic[:, col] for col in value_columns])

    corr_real = np.corrcoef(real_cont.T)
    corr_synth = np.corrcoef(synth_cont.T)

    panels = (
        (corr_real, 'Real Data Correlation'),
        (corr_synth, 'Synthetic Data Correlation'),
        (corr_real - corr_synth, 'Difference (Real - Synthetic)'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    for ax, (matrix, panel_title) in zip(axes, panels):
        sns.heatmap(matrix, ax=ax, cmap='coolwarm', center=0,
                    xticklabels=names, yticklabels=names,
                    annot=True, fmt='.2f', square=True)
        ax.set_title(panel_title)

    plt.suptitle(title, fontsize=14)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    plt.show()

## 13. Configuration (CTGAN Defaults)

In [16]:
# CTGAN-style Configuration
SEEDS = [42, 123, 456]
LATENT_DIM = 128
HIDDEN_DIM = 256       # CTGAN uses 256 (was 128)
BATCH_SIZE = 500       # CTGAN uses 500 (was 128)
EPOCHS = 300           # CTGAN uses 300 (was 500)
LR = 0.0002            # Same for G and D (CTGAN style)
N_CRITIC = 1           # CTGAN uses 1 (was 5)
LAMBDA_GP = 10
LAMBDA_CORR = 0.1
LAMBDA_SPECIAL = 0.5
TEMPERATURE = 0.2
DROPOUT = 0.5
PAC = 10               # CTGAN packing (NEW)
N_MODES = 5
RARE_THRESHOLD = 0.01

# Invariants the training loops rely on. The critic consumes packs of PAC rows,
# so a batch must split evenly into packs (train_cgan otherwise silently rounds
# the batch down). At least one critic step must run per generator step.
assert BATCH_SIZE % PAC == 0, f"BATCH_SIZE ({BATCH_SIZE}) must be divisible by PAC ({PAC})"
assert N_CRITIC >= 1, f"N_CRITIC must be >= 1, got {N_CRITIC}"

os.makedirs('plots_v2/gan', exist_ok=True)
os.makedirs('plots_v2/cgan', exist_ok=True)

print("="*60)
print("CTGAN-STYLE CONFIGURATION")
print("="*60)
print(f"Seeds: {SEEDS}")
print(f"Latent dimension: {LATENT_DIM}")
print(f"Hidden dimension: {HIDDEN_DIM} (CTGAN: 256)")
print(f"Batch size: {BATCH_SIZE} (CTGAN: 500)")
print(f"Epochs: {EPOCHS}")
print(f"Learning rate (G & D): {LR}")
print(f"Critic iterations: {N_CRITIC} (CTGAN: 1)")
print(f"PAC (packing): {PAC} (CTGAN: 10)")
print(f"Gradient penalty lambda: {LAMBDA_GP}")
print(f"Correlation loss lambda: {LAMBDA_CORR}")
print(f"Special proportion loss lambda: {LAMBDA_SPECIAL}")
print(f"Gumbel-Softmax temperature: {TEMPERATURE}")
print(f"Discriminator dropout: {DROPOUT}")
print(f"VGM modes: {N_MODES}")
print("="*60)
print("\nKEY CTGAN CHANGES:")
print("- PAC = 10 (packing samples for discriminator)")
print("- N_CRITIC = 1 (not 5 like WGAN-GP)")
print("- Same learning rate for G and D")
print("- Larger network (256 hidden)")
print("- Residual blocks in Generator")
print("- Log transform for capital-gain/loss")
print("="*60)

CTGAN-STYLE CONFIGURATION
Seeds: [42, 123, 456]
Latent dimension: 128
Hidden dimension: 256 (CTGAN: 256)
Batch size: 500 (CTGAN: 500)
Epochs: 300
Learning rate (G & D): 0.0002
Critic iterations: 1 (CTGAN: 1)
PAC (packing): 10 (CTGAN: 10)
Gradient penalty lambda: 10
Correlation loss lambda: 0.1
Special proportion loss lambda: 0.5
Gumbel-Softmax temperature: 0.2
Discriminator dropout: 0.5
VGM modes: 5

KEY CTGAN CHANGES:
- PAC = 10 (packing samples for discriminator)
- N_CRITIC = 1 (not 5 like WGAN-GP)
- Same learning rate for G and D
- Larger network (256 hidden)
- Residual blocks in Generator
- Log transform for capital-gain/loss


## 14. Run Experiments

In [17]:
# Run GAN experiment with seed 42
seed = 42
print(f"\n{'='*60}")
print(f"Running GAN with seed {seed}")
print(f"{'='*60}")

X_train, X_test, y_train, y_test, preprocessor_seed = prepare_data(df, seed=seed)

generator, g_losses, d_losses = train_gan(
    X_train, preprocessor_seed,
    latent_dim=LATENT_DIM, hidden_dim=HIDDEN_DIM,
    batch_size=BATCH_SIZE, epochs=EPOCHS, lr=LR,
    n_critic=N_CRITIC, lambda_gp=LAMBDA_GP,
    lambda_corr=LAMBDA_CORR, lambda_special=LAMBDA_SPECIAL,
    temperature=TEMPERATURE, dropout=DROPOUT, pac=PAC,
    seed=seed, save_dir='plots_v2/gan'
)

# Plot losses
plot_training_losses(g_losses, d_losses, 
                    title=f"GAN Training Losses (Seed {seed})",
                    save_path=f'plots_v2/gan/losses_seed{seed}.png')

# Generate synthetic data
X_synthetic = generate_synthetic_data(generator, len(X_train), LATENT_DIM, device)
# NOTE(review): the unconditional generator produces no labels. These are the
# real training labels in their original row order, paired with independently
# sampled synthetic rows, so the features carry no information about them and
# the efficacy score below is not a meaningful utility measure for the GAN.
# TODO: model the label jointly (e.g. include it as a categorical column in the
# GAN's training data) and decode it from the generated rows.
y_synthetic = y_train.copy()

# === DIAGNOSTIC: Check for issues ===
print("\n=== DIAGNOSTIC INFO ===")
print(f"X_train shape: {X_train.shape}, X_synthetic shape: {X_synthetic.shape}")
print(f"X_train range: [{X_train.min():.4f}, {X_train.max():.4f}]")
print(f"X_synthetic range: [{X_synthetic.min():.4f}, {X_synthetic.max():.4f}]")
print(f"X_train has NaN: {np.isnan(X_train).any()}, X_synthetic has NaN: {np.isnan(X_synthetic).any()}")

# Categorical one-hot blocks start right after the continuous block, at column cont_dim
cont_dim = preprocessor_seed.get_continuous_dim()
cat_dims = preprocessor_seed.get_categorical_dims()
print(f"\nContinuous dim: {cont_dim}, Categorical dims: {cat_dims}")

# Inspect the first three categorical features: one-hot row sums and category frequencies
pos = cont_dim
for i, (col, dim) in enumerate(zip(CATEGORICAL_COLS[:3], cat_dims[:3])):
    real_cat = X_train[:, pos:pos+dim]
    synth_cat = X_synthetic[:, pos:pos+dim]
    
    # Check if one-hot (should sum to 1)
    real_sum = real_cat.sum(axis=1)
    synth_sum = synth_cat.sum(axis=1)
    
    print(f"\n{col} (dim={dim}):")
    print(f"  Real - sum range: [{real_sum.min():.4f}, {real_sum.max():.4f}], mean per cat: {real_cat.mean(axis=0)[:3]}")
    print(f"  Synth - sum range: [{synth_sum.min():.4f}, {synth_sum.max():.4f}], mean per cat: {synth_cat.mean(axis=0)[:3]}")
    pos += dim

# Check mode vectors in continuous features
print("\n=== MODE VECTOR CHECK ===")
feature_structure = preprocessor_seed.get_feature_structure()
n_modes = preprocessor_seed.get_n_modes()
pos = 0
for feat in feature_structure[:2]:
    if feat['type'] in ['zero_inflated', 'peak_inflated']:
        mode_start = pos + 2
        mode_end = pos + 2 + n_modes
        pos += 2 + n_modes
    else:
        mode_start = pos + 1
        mode_end = pos + 1 + n_modes
        pos += 1 + n_modes
    
    real_modes = X_train[:, mode_start:mode_end]
    synth_modes = X_synthetic[:, mode_start:mode_end]
    
    print(f"{feat['name']} modes:")
    print(f"  Real - sum range: [{real_modes.sum(axis=1).min():.4f}, {real_modes.sum(axis=1).max():.4f}]")
    print(f"  Synth - sum range: [{synth_modes.sum(axis=1).min():.4f}, {synth_modes.sum(axis=1).max():.4f}]")
    print(f"  Real mode distribution: {real_modes.mean(axis=0)}")
    print(f"  Synth mode distribution: {synth_modes.mean(axis=0)}")

print("=== END DIAGNOSTIC ===\n")

# Fail loudly before plotting/metrics if sampling went wrong (the diagnostics
# above only print these conditions).
assert X_synthetic.shape == X_train.shape, "synthetic data must match the training layout"
assert not np.isnan(X_synthetic).any(), "generator produced NaN values"

# Plot feature distributions
plot_feature_distributions(X_train, X_synthetic, preprocessor_seed,
                          title=f"GAN Feature Distributions (Seed {seed})",
                          save_path=f'plots_v2/gan/features_seed{seed}.png')

# Plot correlations
plot_correlation_matrices(X_train, X_synthetic, preprocessor_seed,
                         title=f"GAN Correlation Comparison (Seed {seed})",
                         save_path=f'plots_v2/gan/correlation_seed{seed}.png')

# Compute metrics
detection_auc, detection_std = compute_detection_metric(X_train, X_synthetic, seed=seed)
efficacy, auc_real, auc_synth = compute_efficacy_metric(
    X_train, y_train, X_synthetic, y_synthetic, X_test, y_test, seed=seed
)

print(f"\nGAN Results (Seed {seed}):")
print(f"  Detection AUC: {detection_auc:.4f} (+/- {detection_std:.4f})")
print(f"  Efficacy Ratio: {efficacy:.4f}")
print(f"  AUC (Real): {auc_real:.4f}, AUC (Synthetic): {auc_synth:.4f}")


Running GAN with seed 42
Training set: 26048 samples
Test set: 6513 samples
Label distribution (train): [0.75917537 0.24082463]


Training GAN (CTGAN-style):  17%|█▋        | 50/300 [01:42<08:26,  2.03s/it]

Epoch [50/300] D_loss: 0.0143 G_loss: 0.2034


Training GAN (CTGAN-style):  33%|███▎      | 100/300 [03:23<06:40,  2.00s/it]

Epoch [100/300] D_loss: -0.0357 G_loss: -0.0468


Training GAN (CTGAN-style):  50%|█████     | 150/300 [05:03<04:56,  1.98s/it]

Epoch [150/300] D_loss: -0.0367 G_loss: 0.3860


Training GAN (CTGAN-style):  67%|██████▋   | 200/300 [06:44<03:19,  1.99s/it]

Epoch [200/300] D_loss: -0.0029 G_loss: 0.0211


Training GAN (CTGAN-style):  83%|████████▎ | 250/300 [08:28<01:46,  2.13s/it]

Epoch [250/300] D_loss: 0.0269 G_loss: -0.0556


Training GAN (CTGAN-style):  84%|████████▍ | 253/300 [08:36<01:35,  2.04s/it]


KeyboardInterrupt: 

In [None]:
# Run cGAN experiment with seed 42
seed = 42
print(f"\n{'='*60}")
print(f"Running cGAN with seed {seed}")
print(f"{'='*60}")

X_train, X_test, y_train, y_test, preprocessor_seed = prepare_data(df, seed=seed)

cgan_generator, cgan_g_losses, cgan_d_losses = train_cgan(
    X_train, y_train, preprocessor_seed,
    latent_dim=LATENT_DIM, hidden_dim=HIDDEN_DIM,
    batch_size=BATCH_SIZE, epochs=EPOCHS, lr=LR,
    n_critic=N_CRITIC, lambda_gp=LAMBDA_GP,
    lambda_corr=LAMBDA_CORR, lambda_special=LAMBDA_SPECIAL,
    temperature=TEMPERATURE, dropout=DROPOUT, pac=PAC,
    seed=seed, save_dir='plots_v2/cgan'
)

# Plot losses
plot_training_losses(cgan_g_losses, cgan_d_losses,
                    title=f"cGAN Training Losses (Seed {seed})",
                    save_path=f'plots_v2/cgan/losses_seed{seed}.png')

# Generate synthetic data with the training set's label proportions
label_ratios = np.bincount(y_train) / len(y_train)
X_synthetic_cgan, y_synthetic_cgan = generate_conditional_synthetic_data(
    cgan_generator, len(X_train), LATENT_DIM, label_ratios, device
)

# Fail loudly before plotting/metrics if sampling went wrong.
assert X_synthetic_cgan.shape == X_train.shape, "synthetic data must match the training layout"
assert len(y_synthetic_cgan) == len(X_synthetic_cgan), "expected one label per synthetic row"
assert not np.isnan(X_synthetic_cgan).any(), "generator produced NaN values"

# Plot feature distributions
plot_feature_distributions(X_train, X_synthetic_cgan, preprocessor_seed,
                          title=f"cGAN Feature Distributions (Seed {seed})",
                          save_path=f'plots_v2/cgan/features_seed{seed}.png')

# Plot correlations
plot_correlation_matrices(X_train, X_synthetic_cgan, preprocessor_seed,
                         title=f"cGAN Correlation Comparison (Seed {seed})",
                         save_path=f'plots_v2/cgan/correlation_seed{seed}.png')

# Compute metrics
cgan_detection_auc, cgan_detection_std = compute_detection_metric(X_train, X_synthetic_cgan, seed=seed)
cgan_efficacy, cgan_auc_real, cgan_auc_synth = compute_efficacy_metric(
    X_train, y_train, X_synthetic_cgan, y_synthetic_cgan, X_test, y_test, seed=seed
)

print(f"\ncGAN Results (Seed {seed}):")
print(f"  Detection AUC: {cgan_detection_auc:.4f} (+/- {cgan_detection_std:.4f})")
print(f"  Efficacy Ratio: {cgan_efficacy:.4f}")
print(f"  AUC (Real): {cgan_auc_real:.4f}, AUC (Synthetic): {cgan_auc_synth:.4f}")


Running cGAN with seed 42
Training set: 26048 samples
Test set: 6513 samples
Label distribution (train): [0.75917537 0.24082463]


Training cGAN (CTGAN-style):  17%|█▋        | 50/300 [01:47<08:45,  2.10s/it]

Epoch [50/300] D_loss: 0.0609 G_loss: 1.0297


Training cGAN (CTGAN-style):  26%|██▌       | 78/300 [02:48<08:03,  2.18s/it]

In [None]:
# Run remaining seeds (123, 456) for GAN
gan_results = []

for seed in [123, 456]:
    print(f"\n{'='*60}")
    print(f"Running GAN with seed {seed}")
    print(f"{'='*60}")
    
    X_train, X_test, y_train, y_test, preprocessor_seed = prepare_data(df, seed=seed)
    
    generator, g_losses, d_losses = train_gan(
        X_train, preprocessor_seed,
        latent_dim=LATENT_DIM, hidden_dim=HIDDEN_DIM,
        batch_size=BATCH_SIZE, epochs=EPOCHS, lr=LR,
        n_critic=N_CRITIC, lambda_gp=LAMBDA_GP,
        lambda_corr=LAMBDA_CORR, lambda_special=LAMBDA_SPECIAL,
        temperature=TEMPERATURE, dropout=DROPOUT, pac=PAC,
        seed=seed, save_dir='plots_v2/gan'
    )
    
    plot_training_losses(g_losses, d_losses, 
                        title=f"GAN Training Losses (Seed {seed})",
                        save_path=f'plots_v2/gan/losses_seed{seed}.png')
    
    X_synthetic = generate_synthetic_data(generator, len(X_train), LATENT_DIM, device)
    # NOTE(review): the unconditional generator produces no labels. The real
    # training labels (in original row order) are paired with independently
    # sampled synthetic rows, so efficacy is not a meaningful utility measure for
    # the GAN until the label is generated jointly with the features.
    y_synthetic = y_train.copy()
    
    # Fail loudly before plotting/metrics if sampling went wrong.
    assert X_synthetic.shape == X_train.shape, "synthetic data must match the training layout"
    assert not np.isnan(X_synthetic).any(), "generator produced NaN values"
    
    plot_feature_distributions(X_train, X_synthetic, preprocessor_seed,
                              title=f"GAN Feature Distributions (Seed {seed})",
                              save_path=f'plots_v2/gan/features_seed{seed}.png')
    
    detection_auc, detection_std = compute_detection_metric(X_train, X_synthetic, seed=seed)
    efficacy, auc_real, auc_synth = compute_efficacy_metric(
        X_train, y_train, X_synthetic, y_synthetic, X_test, y_test, seed=seed
    )
    
    # Keep every computed metric instead of discarding the fold std and the
    # underlying real/synthetic AUCs.
    gan_results.append({
        'seed': seed,
        'detection_auc': detection_auc,
        'detection_std': detection_std,
        'efficacy': efficacy,
        'auc_real': auc_real,
        'auc_synth': auc_synth,
    })
    
    print(f"\nGAN Results (Seed {seed}):")
    print(f"  Detection AUC: {detection_auc:.4f} (+/- {detection_std:.4f})")
    print(f"  Efficacy Ratio: {efficacy:.4f}")
    print(f"  AUC (Real): {auc_real:.4f}, AUC (Synthetic): {auc_synth:.4f}")

In [None]:
# Run remaining seeds (123, 456) for cGAN
cgan_results = []

for seed in [123, 456]:
    print(f"\n{'='*60}")
    print(f"Running cGAN with seed {seed}")
    print(f"{'='*60}")
    
    X_train, X_test, y_train, y_test, preprocessor_seed = prepare_data(df, seed=seed)
    
    cgan_generator, cgan_g_losses, cgan_d_losses = train_cgan(
        X_train, y_train, preprocessor_seed,
        latent_dim=LATENT_DIM, hidden_dim=HIDDEN_DIM,
        batch_size=BATCH_SIZE, epochs=EPOCHS, lr=LR,
        n_critic=N_CRITIC, lambda_gp=LAMBDA_GP,
        lambda_corr=LAMBDA_CORR, lambda_special=LAMBDA_SPECIAL,
        temperature=TEMPERATURE, dropout=DROPOUT, pac=PAC,
        seed=seed, save_dir='plots_v2/cgan'
    )
    
    plot_training_losses(cgan_g_losses, cgan_d_losses,
                        title=f"cGAN Training Losses (Seed {seed})",
                        save_path=f'plots_v2/cgan/losses_seed{seed}.png')
    
    # Generate with the training set's label proportions.
    label_ratios = np.bincount(y_train) / len(y_train)
    X_synthetic_cgan, y_synthetic_cgan = generate_conditional_synthetic_data(
        cgan_generator, len(X_train), LATENT_DIM, label_ratios, device
    )
    
    # Fail loudly before plotting/metrics if sampling went wrong.
    assert X_synthetic_cgan.shape == X_train.shape, "synthetic data must match the training layout"
    assert len(y_synthetic_cgan) == len(X_synthetic_cgan), "expected one label per synthetic row"
    assert not np.isnan(X_synthetic_cgan).any(), "generator produced NaN values"
    
    plot_feature_distributions(X_train, X_synthetic_cgan, preprocessor_seed,
                              title=f"cGAN Feature Distributions (Seed {seed})",
                              save_path=f'plots_v2/cgan/features_seed{seed}.png')
    
    cgan_detection_auc, cgan_detection_std = compute_detection_metric(X_train, X_synthetic_cgan, seed=seed)
    cgan_efficacy, cgan_auc_real, cgan_auc_synth = compute_efficacy_metric(
        X_train, y_train, X_synthetic_cgan, y_synthetic_cgan, X_test, y_test, seed=seed
    )
    
    # Keep every computed metric instead of discarding the fold std and the
    # underlying real/synthetic AUCs.
    cgan_results.append({
        'seed': seed,
        'detection_auc': cgan_detection_auc,
        'detection_std': cgan_detection_std,
        'efficacy': cgan_efficacy,
        'auc_real': cgan_auc_real,
        'auc_synth': cgan_auc_synth,
    })
    
    print(f"\ncGAN Results (Seed {seed}):")
    print(f"  Detection AUC: {cgan_detection_auc:.4f} (+/- {cgan_detection_std:.4f})")
    print(f"  Efficacy Ratio: {cgan_efficacy:.4f}")
    print(f"  AUC (Real): {cgan_auc_real:.4f}, AUC (Synthetic): {cgan_auc_synth:.4f}")

## 15. Final Results Summary

In [None]:
# Compile all results
print("\n" + "="*60)
print("FINAL RESULTS SUMMARY (CTGAN-Style)")
print("="*60)

# Note: the per-seed loop cells collect only seeds 123 and 456 into
# gan_results / cgan_results. The seed-42 cells print their metrics but do not
# append them, and the loops later overwrite their metric variables.

print("\nConfiguration:")
print(f"  Hidden Dim: {HIDDEN_DIM}, Batch Size: {BATCH_SIZE}")
print(f"  PAC: {PAC}, N_CRITIC: {N_CRITIC}, LR: {LR}")
print(f"  Epochs: {EPOCHS}")

print("\n" + "-"*60)
# Look the result lists up defensively, so this cell still runs (and says so)
# when the per-seed experiment cells have not been executed in this kernel.
for model_name, results_var in (("GAN", "gan_results"), ("cGAN", "cgan_results")):
    collected = globals().get(results_var, [])
    if not collected:
        print(f"{model_name}: no collected results ({results_var} is empty or undefined).")
        continue
    per_seed = pd.DataFrame(collected)[['seed', 'detection_auc', 'efficacy']]
    print(f"{model_name} results per seed:")
    print(per_seed.to_string(index=False, float_format=lambda v: f"{v:.4f}"))
    # Population std (ddof=0), matching np.std used for the fold spread.
    print(f"  Mean Detection AUC: {per_seed['detection_auc'].mean():.4f} "
          f"(+/- {per_seed['detection_auc'].std(ddof=0):.4f})")
    print(f"  Mean Efficacy: {per_seed['efficacy'].mean():.4f} "
          f"(+/- {per_seed['efficacy'].std(ddof=0):.4f})")
print("Target: Detection AUC < 0.70, Efficacy > 0.85")
print("-"*60)