In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import os
import datetime

class Generator(nn.Module):
    """Conditional generator mapping noise plus covariates to a sample.

    Every categorical covariate gets its own embedding table whose width is
    ``int(sqrt(vocab_size)) + 1``.  The noise vector, the numeric covariates
    and all embeddings are concatenated (in that order) and passed through a
    stack of ``Linear -> ReLU`` layers followed by a final linear projection
    to ``x_dim`` outputs.

    Args:
        x_dim: Width of the generated sample.
        vocab_sizes: Number of categories of each categorical covariate; one
            embedding table is created per entry.
        nb_numeric: Number of numeric covariate columns.
        h_dims: Hidden layer widths; ``None`` selects ``[256, 256]``.
        z_dim: Width of the noise vector.
    """

    def __init__(self, x_dim, vocab_sizes, nb_numeric, h_dims=None, z_dim=10):
        super().__init__()
        hidden = [256, 256] if h_dims is None else h_dims

        self.z_dim = z_dim
        self.nb_categoric = len(vocab_sizes)

        # One embedding table per categorical covariate.
        self.embeddings = nn.ModuleList(
            [nn.Embedding(size, int(size ** 0.5) + 1) for size in vocab_sizes]
        )
        cond_width = nb_numeric + sum(
            table.embedding_dim for table in self.embeddings
        )

        # Hidden Linear+ReLU pairs, then the linear output head.
        widths = [z_dim + cond_width, *hidden]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        stack.append(nn.Linear(widths[-1], x_dim))
        self.layers = nn.Sequential(*stack)

    def forward(self, z, cat, num):
        """Generate samples from noise ``z`` conditioned on ``cat`` and ``num``.

        Column ``i`` of ``cat`` is looked up in the ``i``-th embedding table;
        ``cat`` is not touched when the model has no categorical covariates.
        """
        pieces = [z, num]
        pieces.extend(
            table(cat[:, col]) for col, table in enumerate(self.embeddings)
        )
        return self.layers(torch.cat(pieces, dim=1))

class Discriminator(nn.Module):
    """Conditional critic scoring a sample given its covariates.

    Categorical covariates are embedded (table width
    ``int(sqrt(vocab_size)) + 1``), concatenated after the sample and the
    numeric covariates, and fed through ``Linear -> ReLU`` layers ending in a
    single unbounded linear output.

    Args:
        x_dim: Width of the sample being scored.
        vocab_sizes: Number of categories of each categorical covariate.
        nb_numeric: Number of numeric covariate columns.
        h_dims: Hidden layer widths; ``None`` selects ``[256, 256]``.
    """

    def __init__(self, x_dim, vocab_sizes, nb_numeric, h_dims=None):
        super().__init__()
        if h_dims is None:
            h_dims = [256, 256]

        self.nb_categoric = len(vocab_sizes)

        # Embedding tables, one per categorical covariate.
        emb_widths = [int(n_cat ** 0.5) + 1 for n_cat in vocab_sizes]
        self.embeddings = nn.ModuleList(
            [nn.Embedding(n_cat, width)
             for n_cat, width in zip(vocab_sizes, emb_widths)]
        )

        # MLP over [sample | numeric covariates | embeddings] -> scalar score.
        in_features = x_dim + nb_numeric + sum(emb_widths)
        blocks = []
        for out_features in h_dims:
            blocks.append(nn.Linear(in_features, out_features))
            blocks.append(nn.ReLU())
            in_features = out_features
        blocks.append(nn.Linear(in_features, 1))
        self.layers = nn.Sequential(*blocks)

    def forward(self, x, cat, num):
        """Score ``x`` conditioned on ``cat`` (column-wise indices) and ``num``."""
        features = [x, num]
        if self.nb_categoric > 0:
            looked_up = [
                table(cat[:, col]) for col, table in enumerate(self.embeddings)
            ]
            features.append(torch.cat(looked_up, dim=1))
        return self.layers(torch.cat(features, dim=1))

class Encoder(nn.Module):
    """MLP that maps a sample of width ``x_dim`` to a code of width ``z_dim``.

    Args:
        x_dim: Width of the input sample.
        h_dims: Hidden layer widths; ``None`` selects ``[256, 256]``.
        z_dim: Width of the produced code.
    """

    def __init__(self, x_dim, h_dims=None, z_dim=10):
        super().__init__()
        hidden = [256, 256] if h_dims is None else h_dims

        # Each hidden width contributes a Linear followed by a ReLU; the
        # output projection has no activation.
        sizes = [x_dim, *hidden]
        modules = []
        for idx, width in enumerate(sizes[1:]):
            modules.append(nn.Linear(sizes[idx], width))
            modules.append(nn.ReLU())
        modules.append(nn.Linear(sizes[-1], z_dim))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Encode ``x`` by running it through the layer stack."""
        return self.layers(x)

def compute_gradient_penalty(discriminator, real_samples, fake_samples, cat, num):
    """Return the WGAN-GP gradient penalty for a batch.

    Each real sample is mixed with the matching fake sample using its own
    uniform random coefficient, the critic is evaluated on the mixtures, and
    the squared deviation of the per-sample gradient L2 norm from 1 is
    averaged over the batch.

    Args:
        discriminator: Callable invoked as ``discriminator(x, cat, num)``.
        real_samples: Batch of real samples, batch dimension first; any rank.
        fake_samples: Batch of generated samples with the same shape as
            ``real_samples``.  Callers that do not want the penalty to
            back-propagate into the generator should pass it detached.
        cat: Categorical covariates, forwarded unchanged to the critic.
        num: Numeric covariates, forwarded unchanged to the critic.

    Returns:
        Scalar tensor; the autograd graph is kept (``create_graph=True``) so
        the penalty can itself be back-propagated.
    """
    batch = real_samples.size(0)

    # One mixing coefficient per sample, shaped (batch, 1, ..., 1) so that it
    # broadcasts over every non-batch dimension regardless of input rank.
    alpha = torch.rand(
        batch, *([1] * (real_samples.dim() - 1)), device=real_samples.device
    )
    interpolates = (
        alpha * real_samples + (1 - alpha) * fake_samples
    ).requires_grad_(True)
    d_interpolates = discriminator(interpolates, cat, num)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Flatten so the norm is taken over all features of each sample, not just
    # along dim 1.
    grad_norm = gradients.reshape(batch, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()

def train(dataset, cat_covs, num_covs, z_dim, epochs, batch_size,
          generator, discriminator, encoder, score_fn, save_fn,
          device='cuda', lr=5e-4, nb_critic=5, cycle_weight=10,
          gradient_penalty_weight=10, checkpoint_dir='./checkpoints',
          log_dir='./logs', patience=10):
    """Train the conditional WGAN-GP together with its latent encoder.

    For every mini-batch the critic is updated ``nb_critic`` times with the
    Wasserstein loss plus gradient penalty, then the generator and encoder
    are updated once with the adversarial loss plus an L1 cycle-consistency
    term between generated samples and their re-generation from the encoded
    latent.  Mean losses are written to TensorBoard once per epoch.

    Args:
        dataset: Samples, one row per observation; converted to float32.
        cat_covs: Categorical covariates as integer indices; converted to
            int64.
        num_covs: Numeric covariates; converted to float32.
        z_dim: Width of the noise vector sampled for the generator.
        epochs: Maximum number of epochs.
        batch_size: Mini-batch size for the shuffled data loader.
        generator: Module called as ``generator(z, cat, num)``.
        discriminator: Module called as ``discriminator(x, cat, num)``.
        encoder: Module called as ``encoder(x)``.
        score_fn: Called as ``score_fn(generator)`` every fifth epoch
            (starting at epoch 0); a strictly larger score counts as an
            improvement.
        save_fn: Called with no arguments whenever the score improves.
        device: Device the data and the noise are placed on.  The three
            modules are not moved here and must already live on it.
        lr: Learning rate shared by the three RMSprop optimizers.
        nb_critic: Critic updates per generator/encoder update.
        cycle_weight: Weight of the cycle-consistency term.
        gradient_penalty_weight: Weight of the gradient penalty.
        checkpoint_dir: Not used by this function; saving is delegated to
            ``save_fn``.
        log_dir: Directory passed to ``SummaryWriter``.
        patience: Number of non-improving evaluations tolerated since the
            last improvement before training stops early.

    Returns:
        None.
    """
    # as_tensor accepts arrays, nested lists and existing tensors alike,
    # unlike the legacy FloatTensor/LongTensor constructors.
    data = torch.as_tensor(dataset, dtype=torch.float32, device=device)
    cat_all = torch.as_tensor(cat_covs, dtype=torch.long, device=device)
    num_all = torch.as_tensor(num_covs, dtype=torch.float32, device=device)

    train_loader = DataLoader(TensorDataset(data, cat_all, num_all),
                              batch_size=batch_size, shuffle=True)

    gen_optimizer = optim.RMSprop(generator.parameters(), lr=lr)
    disc_optimizer = optim.RMSprop(discriminator.parameters(), lr=lr)
    encoder_optimizer = optim.RMSprop(encoder.parameters(), lr=lr)

    writer = SummaryWriter(log_dir)

    best_score = float('-inf')
    patience_counter = patience

    # try/finally so the event file is flushed and closed even when training
    # is interrupted by an exception.
    try:
        for epoch in range(epochs):
            gen_losses = []
            disc_losses = []

            for real_data, cat, num in train_loader:
                # The final batch can be smaller than ``batch_size``; use the
                # actual size and leave the parameter untouched.
                n_samples = real_data.size(0)

                # ---- Critic updates ----
                for _ in range(nb_critic):
                    disc_optimizer.zero_grad()

                    z = torch.randn(n_samples, z_dim, device=device)
                    # The critic step never needs generator gradients.
                    with torch.no_grad():
                        fake_data = generator(z, cat, num)

                    real_validity = discriminator(real_data, cat, num)
                    fake_validity = discriminator(fake_data, cat, num)
                    gradient_penalty = compute_gradient_penalty(
                        discriminator, real_data, fake_data, cat, num)

                    # No cycle-consistency term here: it does not depend on
                    # the critic's parameters, so it would contribute zero
                    # gradient while costing extra generator/encoder passes
                    # and skewing the logged critic loss.
                    disc_loss = (-torch.mean(real_validity)
                                 + torch.mean(fake_validity)
                                 + gradient_penalty_weight * gradient_penalty)

                    disc_loss.backward()
                    disc_optimizer.step()
                    disc_losses.append(disc_loss.item())

                # ---- Generator / encoder update ----
                gen_optimizer.zero_grad()
                encoder_optimizer.zero_grad()

                z = torch.randn(n_samples, z_dim, device=device)
                fake_data = generator(z, cat, num)
                fake_validity = discriminator(fake_data, cat, num)

                # Cycle consistency: encode the fake sample and regenerate it
                # from the recovered latent.
                z_rec = encoder(fake_data)
                cycled_data = generator(z_rec, cat, num)
                cycle_loss = torch.mean(torch.abs(fake_data - cycled_data))

                gen_loss = -torch.mean(fake_validity) + cycle_weight * cycle_loss
                gen_loss.backward()

                gen_optimizer.step()
                encoder_optimizer.step()
                gen_losses.append(gen_loss.item())

            # Logging
            gen_mean = np.mean(gen_losses)
            disc_mean = np.mean(disc_losses)
            writer.add_scalar('Generator Loss', gen_mean, epoch)
            writer.add_scalar('Discriminator Loss', disc_mean, epoch)

            # Model evaluation and checkpointing
            if epoch % 5 == 0:
                score = score_fn(generator)
                if score > best_score:
                    best_score = score
                    save_fn()
                    patience_counter = patience
                else:
                    patience_counter -= 1

                # ``<=`` so a counter that has gone negative (e.g. patience=0)
                # still stops training.
                if patience_counter <= 0:
                    print(f'Early stopping at epoch {epoch}')
                    break

            print(f'Epoch {epoch}/{epochs} - '
                  f'Gen Loss: {gen_mean:.4f} - '
                  f'Disc Loss: {disc_mean:.4f}')
    finally:
        writer.close()

In [None]:
# Build the three networks on the target device.  The generator has to be
# given the same z_dim that train() samples noise with and that the encoder
# outputs; otherwise it falls back to its default of 10 and the first linear
# layer no longer matches the noise width.
generator = Generator(x_dim, vocab_sizes, nb_numeric, z_dim=z_dim).to(device)
discriminator = Discriminator(x_dim, vocab_sizes, nb_numeric).to(device)
encoder = Encoder(x_dim, z_dim=z_dim).to(device)

# Pass the device explicitly so the data and noise tensors are created on the
# same device as the models instead of train()'s default of 'cuda'.
train(dataset, cat_covs, num_covs, z_dim, epochs, batch_size,
      generator, discriminator, encoder, score_fn, save_fn,
      device=device)