In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
import wandb

In [None]:
class Generator(nn.Module):
    def __init__(self, x_dim, vocab_sizes, nb_numeric, h_dims, z_dim, max_emb_dim=50):
        """
        Generator network for conditional GAN

        Maps latent noise plus covariates (embedded categorical + raw numeric)
        to a vector of size ``x_dim``.

        Args:
            x_dim: Dimension of output data
            vocab_sizes: List of vocabulary sizes for each categorical variable
                (may be empty when there are no categorical covariates)
            nb_numeric: Number of numeric covariates
            h_dims: List of hidden dimensions
            z_dim: Dimension of latent noise vector
            max_emb_dim: Upper bound on the embedding width of each categorical
                variable; the width used is ``min(max_emb_dim, vocab_size)``
        """
        super().__init__()

        vocab_sizes = list(vocab_sizes)

        # Embedding width per categorical variable, computed once so the
        # embedding tables and the first Linear layer cannot disagree.
        emb_dims = [min(max_emb_dim, vocab_size) for vocab_size in vocab_sizes]

        # Embedding layers for categorical variables
        self.embeddings = nn.ModuleList([
            nn.Embedding(vocab_size, emb_dim)
            for vocab_size, emb_dim in zip(vocab_sizes, emb_dims)
        ])

        # Input dimension is latent dim + embedding dim + numeric covariates
        input_dim = z_dim + sum(emb_dims) + nb_numeric

        # Hidden stack: Linear -> BatchNorm1d -> ReLU for each hidden width
        layers = []
        current_dim = input_dim
        for h_dim in h_dims:
            layers.extend([
                nn.Linear(current_dim, h_dim),
                nn.BatchNorm1d(h_dim),
                nn.ReLU()
            ])
            current_dim = h_dim

        # Output layer (linear, no activation)
        layers.append(nn.Linear(current_dim, x_dim))

        self.network = nn.Sequential(*layers)

    def forward(self, z, cat_covs, num_covs):
        """
        Generate samples conditioned on covariates.

        Args:
            z: Latent noise, shape (batch, z_dim)
            cat_covs: Integer category indices, shape (batch, len(vocab_sizes));
                column ``i`` is looked up in the ``i``-th embedding table. Not
                accessed when the model has no categorical variables.
            num_covs: Numeric covariates, shape (batch, nb_numeric)

        Returns:
            Tensor of shape (batch, x_dim)
        """
        # Concatenate everything in a single call. Building the embedded block
        # separately would call torch.cat on an empty list (an error) when there
        # are no categorical variables.
        parts = [z]
        parts.extend(emb(cat_covs[:, i]) for i, emb in enumerate(self.embeddings))
        parts.append(num_covs)
        gen_input = torch.cat(parts, dim=1)

        # Generate output
        return self.network(gen_input)

In [None]:
class Discriminator(nn.Module):
    def __init__(self, x_dim, vocab_sizes, nb_numeric, h_dims, max_emb_dim=50):
        """
        Discriminator network for conditional GAN

        Scores a data vector together with its covariates (embedded categorical
        + raw numeric). The output is a single unbounded value per sample (no
        final activation).

        Args:
            x_dim: Dimension of input data
            vocab_sizes: List of vocabulary sizes for each categorical variable
                (may be empty when there are no categorical covariates)
            nb_numeric: Number of numeric covariates
            h_dims: List of hidden dimensions
            max_emb_dim: Upper bound on the embedding width of each categorical
                variable; the width used is ``min(max_emb_dim, vocab_size)``
        """
        super().__init__()

        vocab_sizes = list(vocab_sizes)

        # Embedding width per categorical variable, computed once so the
        # embedding tables and the first Linear layer cannot disagree.
        emb_dims = [min(max_emb_dim, vocab_size) for vocab_size in vocab_sizes]

        # Embedding layers for categorical variables
        self.embeddings = nn.ModuleList([
            nn.Embedding(vocab_size, emb_dim)
            for vocab_size, emb_dim in zip(vocab_sizes, emb_dims)
        ])

        # Input dimension is data dim + embedding dim + numeric covariates
        input_dim = x_dim + sum(emb_dims) + nb_numeric

        # Hidden stack: Linear -> LeakyReLU(0.2) -> Dropout(0.3) per hidden width
        layers = []
        current_dim = input_dim
        for h_dim in h_dims:
            layers.extend([
                nn.Linear(current_dim, h_dim),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3)
            ])
            current_dim = h_dim

        # Output layer: one score per sample
        layers.append(nn.Linear(current_dim, 1))

        self.network = nn.Sequential(*layers)

    def forward(self, x, cat_covs, num_covs):
        """
        Score samples conditioned on covariates.

        Args:
            x: Data vectors, shape (batch, x_dim)
            cat_covs: Integer category indices, shape (batch, len(vocab_sizes));
                column ``i`` is looked up in the ``i``-th embedding table. Not
                accessed when the model has no categorical variables.
            num_covs: Numeric covariates, shape (batch, nb_numeric)

        Returns:
            Tensor of shape (batch, 1)
        """
        # Concatenate everything in a single call. Building the embedded block
        # separately would call torch.cat on an empty list (an error) when there
        # are no categorical variables.
        parts = [x]
        parts.extend(emb(cat_covs[:, i]) for i, emb in enumerate(self.embeddings))
        parts.append(num_covs)
        disc_input = torch.cat(parts, dim=1)

        # Generate output
        return self.network(disc_input)


In [None]:
def train_gan(generator, discriminator, dataloader, cat_covs, num_covs, 
              config, device, score_fn=None, save_fn=None):
    """
    Train the conditional GAN with a weight-clipped Wasserstein objective.

    Args:
        generator: Module called as ``generator(z, cat_covs, num_covs)``.
        discriminator: Module called as ``discriminator(x, cat_covs, num_covs)``.
        dataloader: Iterable of batches. Each batch is either
            ``(real_data, batch_cat_covs, batch_num_covs)``, in which case the
            covariates are used as given and stay aligned with the real
            samples, or ``(real_data,)``, in which case covariates are drawn at
            random from ``cat_covs`` / ``num_covs`` (legacy behaviour).
        cat_covs: Categorical covariates (array-like or tensor), converted to a
            long tensor on ``device``.
        num_covs: Numeric covariates (array-like or tensor), converted to a
            float32 tensor on ``device``.
        config: Mapping with keys ``'lr'``, ``'epochs'``, ``'nb_critic'`` and
            ``'latent_dim'``; optional ``'clip_value'`` (default 0.01) sets the
            weight-clipping bound.
        device: Device on which batches, noise and covariates are placed. The
            models are expected to already live on this device.
        score_fn: Optional callable ``score_fn(generator)``, evaluated and
            printed every 10 epochs.
        save_fn: Optional callable ``save_fn(generator, discriminator)``,
            invoked every 100 epochs.
    """
    # Optimizers
    g_optimizer = optim.RMSprop(generator.parameters(), lr=config['lr'])
    d_optimizer = optim.RMSprop(discriminator.parameters(), lr=config['lr'])

    # Convert covariates to tensors (as_tensor avoids the copy warning that
    # torch.tensor emits when the input is already a tensor)
    cat_covs = torch.as_tensor(cat_covs, dtype=torch.long).to(device)
    num_covs = torch.as_tensor(num_covs, dtype=torch.float32).to(device)

    clip_value = config.get('clip_value', 0.01)

    for epoch in range(config['epochs']):
        # score_fn / save_fn may have switched the models to eval mode
        generator.train()
        discriminator.train()

        d_losses = []
        g_losses = []

        for batch in dataloader:
            if len(batch) == 3:
                # Covariates supplied with the batch: already aligned with real_data
                real_data, batch_cat_covs, batch_num_covs = batch
                batch_cat_covs = batch_cat_covs.to(device=device, dtype=torch.long)
                batch_num_covs = batch_num_covs.to(device=device, dtype=torch.float32)
            else:
                # Legacy single-tensor batches: covariates are sampled at random.
                # NOTE(review): in this mode real samples are NOT paired with their
                # own covariates; have the loader yield them to keep the pairing.
                (real_data,) = batch
                batch_indices = torch.randint(
                    0, cat_covs.size(0), (real_data.size(0),), device=device
                )
                batch_cat_covs = cat_covs[batch_indices]
                batch_num_covs = num_covs[batch_indices]

            # The loader yields tensors wherever the dataset lives (usually CPU)
            real_data = real_data.to(device)
            batch_size = real_data.size(0)

            # Train Discriminator
            for _ in range(config['nb_critic']):
                d_optimizer.zero_grad()

                # Generate fake data; no generator graph is needed for this step
                z = torch.randn(batch_size, config['latent_dim'], device=device)
                with torch.no_grad():
                    fake_data = generator(z, batch_cat_covs, batch_num_covs)

                # Calculate discriminator loss
                real_validity = discriminator(real_data, batch_cat_covs, batch_num_covs)
                fake_validity = discriminator(fake_data, batch_cat_covs, batch_num_covs)

                d_loss = -(torch.mean(real_validity) - torch.mean(fake_validity))
                d_loss.backward()
                d_optimizer.step()

                # Clip discriminator weights (Wasserstein GAN)
                with torch.no_grad():
                    for p in discriminator.parameters():
                        p.clamp_(-clip_value, clip_value)

                d_losses.append(d_loss.item())

            # Train Generator
            g_optimizer.zero_grad()

            # Generate fake data
            z = torch.randn(batch_size, config['latent_dim'], device=device)
            fake_data = generator(z, batch_cat_covs, batch_num_covs)

            # Calculate generator loss
            fake_validity = discriminator(fake_data, batch_cat_covs, batch_num_covs)
            g_loss = -torch.mean(fake_validity)

            g_loss.backward()
            g_optimizer.step()

            g_losses.append(g_loss.item())

        # Log metrics (NaN rather than a numpy warning if the loader was empty)
        if wandb.run is not None:
            wandb.log({
                'epoch': epoch,
                'd_loss': float(np.mean(d_losses)) if d_losses else float('nan'),
                'g_loss': float(np.mean(g_losses)) if g_losses else float('nan')
            })

        # Evaluate and save model if needed
        if score_fn is not None and epoch % 10 == 0:
            score = score_fn(generator)
            print(f'Epoch {epoch}: Score = {score:.4f}')

        if save_fn is not None and epoch % 100 == 0:
            save_fn(generator, discriminator)

In [None]:
def main():
    """
    Configure, build and train the conditional GAN.

    NOTE(review): the data-loading / preprocessing section is elided in this
    cell. Before the tensor conversion below it must define ``x_train``,
    ``x_test``, ``cat_dicts``, ``num_covs``, ``cat_covs_train`` and
    ``num_covs_train``; all of them are read further down.
    """
    # Configuration
    CONFIG = {
        'epochs': 2000,
        'latent_dim': 64,
        'batch_size': 32,
        'nb_layers': 2,
        'hdim': 256,
        'lr': 5e-4,
        'nb_critic': 5,
        'seed': 42
    }

    # Seed the RNGs so weight init, shuffling and noise are repeatable
    np.random.seed(CONFIG['seed'])
    torch.manual_seed(CONFIG['seed'])

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load and preprocess data (similar to original code)
    # ... (data loading code remains the same until standardization)

    # Convert data to PyTorch tensors
    x_train = torch.tensor(x_train, dtype=torch.float32)
    x_test = torch.tensor(x_test, dtype=torch.float32)

    # Create data loader (drop_last keeps every batch at the full batch size)
    train_dataset = TensorDataset(x_train)
    train_loader = DataLoader(
        train_dataset, 
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        drop_last=True
    )

    # Initialize models
    vocab_sizes = [len(c) for c in cat_dicts]
    nb_numeric = num_covs.shape[-1]
    x_dim = x_train.shape[-1]
    h_dims = [CONFIG['hdim']] * CONFIG['nb_layers']

    generator = Generator(
        x_dim=x_dim,
        vocab_sizes=vocab_sizes,
        nb_numeric=nb_numeric,
        h_dims=h_dims,
        z_dim=CONFIG['latent_dim']
    ).to(device)

    discriminator = Discriminator(
        x_dim=x_dim,
        vocab_sizes=vocab_sizes,
        nb_numeric=nb_numeric,
        h_dims=h_dims
    ).to(device)

    # Initialize wandb
    wandb.init(project='adversarial_gene_expr', config=CONFIG)

    # Train model; always close the wandb run, even if training fails, so the
    # run is not left open in the notebook kernel
    try:
        train_gan(
            generator=generator,
            discriminator=discriminator,
            dataloader=train_loader,
            cat_covs=cat_covs_train,
            num_covs=num_covs_train,
            config=CONFIG,
            device=device
        )
    finally:
        wandb.finish()

# Run the training pipeline only when this code executes as the main module,
# not when it is imported from elsewhere.
if __name__ == '__main__':
    main()