# TableGAN

For a summary of changes introduced by Wasserstein GANs, see the original paper or https://www.alexirpan.com/2017/02/22/wasserstein-gan.html:

- Discriminator gives continuous output
- The critic's loss is the difference between its mean output on generated data and its mean output on real data; the generator's loss is the negative mean output of the (ideally optimal) critic on generated data
- Apply a gradient penalty to enforce the assumption that the critic is a K-Lipschitz function

Thanks to the pytorch implementation in https://github.com/caogang/wgan-gp/blob/master/gan_toy.py

In [1]:
import torch
from torch import nn, optim
# Variable provides a wrapper around tensors to allow automatic differentiation, etc.
from torch.autograd.variable import Variable 
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
class SimulationDataset(Dataset):
    """Map-style dataset backed by a CSV file of simulated observations.

    Each CSV row is one observation. The value of the ``group`` column is
    returned as the label next to the complete row.
    """

    def __init__(self, csv_file):
        """
        Read the whole CSV file into memory.

        Args:
            csv_file (string): Path to the csv file holding the observations.
        """
        self.data_frame = pd.read_csv(csv_file)

    def __len__(self):
        n_rows = self.data_frame.shape[0]
        return n_rows

    def __getitem__(self, idx):
        frame = self.data_frame
        # The observation is the complete row, "group" column included.
        row_values = frame.iloc[idx, :].values
        group_label = frame.loc[idx, "group"]
        return row_values, group_label

In [3]:
def make_noise(size, noise_dim=100):
    """
    Draw a batch of standard-normal noise vectors to feed the generator.

    Args:
        size (int): Number of noise vectors, i.e. the batch size.
        noise_dim (int): Length of each noise vector. Defaults to 100.

    Returns:
        torch.Tensor: Tensor of shape (size, noise_dim) with independent
        N(0, 1) entries.
    """
    # Plain tensors carry autograd support since PyTorch 0.4, so the
    # deprecated Variable wrapper is not needed.
    return torch.randn(size, noise_dim)

In [None]:
class GeneratorNet(torch.nn.Module):
    """
    Generator mapping noise vectors to rows of mixed-type tabular data.

    Three shared hidden layers feed separate output heads: one linear head
    for all continuous variables, one sigmoid head per binary variable and
    one softmax head per categorical variable. Output columns are ordered
    continuous, then binary, then categorical.

    Args:
        n_output_continuous: Number of continuous variables.
        n_output_binary: Number of binary variables.
        n_output_categorical: List with the number of columns (levels) of
            each categorical variable.
        noise_dim: Length of the input noise vectors.
    """
    def __init__(self, n_output_continuous, n_output_binary, n_output_categorical, noise_dim=100):
        super().__init__()
        self.n_output_continuous = n_output_continuous
        self.n_output_binary = n_output_binary
        self.n_output_categorical = n_output_categorical

        self.hidden0 = nn.Sequential(
            nn.Linear(noise_dim, 128),
            nn.LeakyReLU(0.2)
            # TODO: Why no dropout in generator?
        )

        self.hidden1 = nn.Sequential(
            nn.Linear(128, 256),
            nn.LeakyReLU(0.2)
        )

        self.hidden2 = nn.Sequential(
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2)
        )

        # The heads live in nn.ModuleList rather than plain Python lists so
        # that they are registered as sub-modules. Otherwise their weights
        # are missing from parameters(), state_dict() and .to(device), and
        # the optimizer never trains them.
        self.binary = nn.ModuleList([
            nn.Sequential(nn.Linear(256, 1), nn.Sigmoid())
            for _ in range(n_output_binary)
        ])

        # Softmax over dim=1 (the levels of one variable), not over the batch.
        self.categorical = nn.ModuleList([
            nn.Sequential(nn.Linear(256, n_levels), nn.Softmax(dim=1))
            for n_levels in n_output_categorical
        ])

        self.continuous = nn.Sequential(
            nn.Linear(256, n_output_continuous),
        )

    def _hidden(self, x):
        """Run the three shared hidden layers on a (batch, noise_dim) input."""
        return self.hidden2(self.hidden1(self.hidden0(x)))

    def forward(self, x):
        """
        Return differentiable output used for training.

        Binary columns hold sigmoid probabilities and each categorical block
        holds softmax probabilities. The result has shape
        (batch, n_continuous + n_binary + sum(n_output_categorical)).
        """
        hidden = self._hidden(x)
        out_binary = [head(hidden) for head in self.binary]
        out_categorical = [head(hidden) for head in self.categorical]
        out_continuous = self.continuous(hidden)
        return torch.cat((out_continuous, *out_binary, *out_categorical), dim=1)

    def sample(self, x):
        """
        Return discretised output with the same column layout as forward().

        Binary columns are thresholded at 0.5. Each categorical block is a
        one-hot vector drawn from the predicted level probabilities.
        """
        hidden = self._hidden(x)
        out_binary = [(head(hidden) > 0.5).float() for head in self.binary]
        out_categorical = []
        for n_levels, head in zip(self.n_output_categorical, self.categorical):
            probs = head(hidden)
            # squeeze(1) keeps the batch dimension even for a single row;
            # a bare squeeze() would collapse it and break the concatenation.
            drawn = torch.multinomial(probs, 1).squeeze(1)
            identity = torch.eye(n_levels, dtype=probs.dtype, device=probs.device)
            out_categorical.append(identity[drawn])

        out_continuous = self.continuous(hidden)
        return torch.cat((out_continuous, *out_binary, *out_categorical), dim=1)
        

In [None]:
class CriticNet(torch.nn.Module):
    """
    Critic with three dropout-regularised hidden layers and a linear head.

    Maps every input row of width ``input_dim`` to one unbounded score.
    """
    def __init__(self, input_dim):
        super().__init__()

        def make_block(n_in, n_out):
            # Affine map -> LeakyReLU(0.2) -> dropout with p=0.2
            return nn.Sequential(
                nn.Linear(n_in, n_out),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.2),
            )

        self.hidden0 = make_block(input_dim, 256)
        self.hidden1 = make_block(256, 256)
        self.hidden2 = make_block(256, 128)
        # No activation on the output: the score is a raw linear value.
        self.out = nn.Sequential(nn.Linear(128, 1))

    def forward(self, x):
        for stage in (self.hidden0, self.hidden1, self.hidden2, self.out):
            x = stage(x)
        return x
    

In [None]:
def interpolate_data(real_data, fake_data):
    """
    Draw one random point on the line between each real/fake observation pair.

    Args:
        real_data: Batch of real observations, batch dimension first.
        fake_data: Batch of generated observations with the same shape.

    Returns:
        A leaf tensor with the shape of ``real_data`` and
        ``requires_grad=True``, cut off from the graphs of both inputs so
        that gradients can be taken with respect to it.
    """
    # One uniform weight per observation, shaped (N, 1, ..., 1) so that it
    # broadcasts over every non-batch dimension. It is created on the
    # data's device to avoid CPU/GPU mismatches.
    eps_shape = (real_data.size(0),) + (1,) * (real_data.dim() - 1)
    eps = torch.rand(eps_shape, device=real_data.device)
    interpolated_data = eps * real_data + (1 - eps) * fake_data
    # Equivalent to the deprecated Variable(..., requires_grad=True).
    return interpolated_data.detach().requires_grad_(True)

def calc_gradient_penalty(critic, real_data, fake_data):
    """
    Compute the per-observation gradient penalty (||grad||_2 - 1)^2.

    The gradient of the critic output is taken with respect to random
    interpolations between real and fake observations.

    Args:
        critic: Callable mapping a batch of observations to critic scores.
        real_data: Batch of real observations.
        fake_data: Batch of generated observations with the same shape.

    Returns:
        Tensor with one penalty value per observation. It is not averaged
        and stays attached to the graph, so the caller can scale, average
        and back-propagate it.
    """
    interpolated_data = interpolate_data(real_data, fake_data)
    critic_output = critic(interpolated_data)
    gradients = autograd.grad(
        outputs=critic_output,
        inputs=interpolated_data,
        # ones_like follows the device and dtype of the critic output.
        grad_outputs=torch.ones_like(critic_output),
        # The penalty itself is differentiated during the critic update.
        create_graph=True,
        retain_graph=True,
    )[0]
    # Flatten all non-batch dimensions so the norm is taken per observation.
    gradient_norm = gradients.reshape(gradients.size(0), -1).norm(2, dim=1)
    gradient_penalty = (gradient_norm - 1) ** 2
    return gradient_penalty

In [None]:
# def critic_loss_GP(critic, real_data, fake_data, penalty_coefficient):
#     """
#     Wasserstein distance to minimize as loss for the critic, regularized by
#     Lipschitz-1 gradient penalty
#     -(E[D(x_real)] - E[D(x_fake)]) + lambda*E[(||D(x_imputed)'||_2 -1)**2]
#     """
#     # Original critic loss
#     # D(x_real)
#     output_real = critic.forward(real_data)
#     # D(x_fake)
#     output_fake = critic.forward(fake_data)
#     raw_loss = (output_fake - output_real).squeeze()
    
#     # Gradient penalty for Lipschitz-1
#     gradient_penalty = calc_gradient_penalty(critic, real_data,fake_data)
    
#     # Total loss
#     loss = (raw_loss + penalty_coefficient * gradient_penalty).mean()
#     return loss

def critic_loss(output_real, output_fake):
    """
    Critic loss: the negated gap between mean real and mean fake scores.
    -(E[D(x_real)] - E[D(x_fake)])
    """
    mean_real = output_real.mean()
    mean_fake = output_fake.mean()
    score_gap = mean_real - mean_fake
    return torch.neg(score_gap)

def generator_loss(output_fake):
    """
    Generator loss: the negated mean critic score on generated data.
    -E[D(G(noise))]
    """
    mean_score = output_fake.mean()
    return mean_score.neg()

In [None]:
def train_critic(optimizer, real_data, fake_data, gradient_penalty_coefficient=10):
    """
    Run one critic update on a batch of real and generated observations.

    Relies on the module-level ``critic`` network and on the helpers
    ``critic_loss`` and ``calc_gradient_penalty``.

    Args:
        optimizer: Optimizer holding the critic's parameters.
        real_data: Batch of real observations.
        fake_data: Batch of generated observations.
        gradient_penalty_coefficient: Weight of the gradient penalty term.

    Returns:
        Tuple ``(raw_loss, output_real, output_fake)``: the Wasserstein loss
        without the penalty, and the critic scores on both batches, for
        monitoring.
    """
    # backward() accumulates gradients, so clear the previous step's first.
    optimizer.zero_grad()

    # Call the module itself rather than .forward() so registered hooks run.
    output_real = critic(real_data)  # D(x_real)
    output_fake = critic(fake_data)  # D(x_fake)
    raw_loss = critic_loss(output_real, output_fake)

    # Gradient penalty
    gradient_penalty = calc_gradient_penalty(critic, real_data, fake_data)

    # Minimize the raw loss pushed upwards by the penalty (always positive),
    # averaged over the batch.
    loss = raw_loss + gradient_penalty_coefficient * gradient_penalty
    loss = loss.mean()

    # Weight update
    loss.backward()
    optimizer.step()

    # Return error and predictions for monitoring
    return raw_loss.mean(), output_real, output_fake

def train_generator(optimizer, fake_data):
    """
    Run one generator update against the current critic.

    Relies on the module-level ``critic`` network and on ``generator_loss``.

    Args:
        optimizer: Optimizer holding the generator's parameters.
        fake_data: Generated batch still attached to the generator's graph,
            so the loss can be back-propagated into the generator.

    Returns:
        The scalar generator loss, -E[D(G(noise))], for monitoring.
    """
    optimizer.zero_grad()

    # Call the module itself rather than .forward() so registered hooks run.
    critic_prediction = critic(fake_data)

    # The generator is rewarded for high critic scores on its samples, so it
    # minimizes the negated mean score.
    loss_generator = generator_loss(critic_prediction)
    # Only the generator's optimizer steps here, so the critic's weights
    # stay unchanged.
    loss_generator.backward()

    # Weight update
    optimizer.step()

    return loss_generator

### Training

In [4]:
# Training set as a Dataset, validation set as one float tensor, plus a fixed
# batch of 2000 noise vectors.
data = SimulationDataset(csv_file="../simulation_data/simulation.csv")
val_data = torch.from_numpy(
    pd.read_csv("../simulation_data/simulation_val.csv").values
).float()
validation_noise = make_noise(2000)

In [5]:
# Loss histories, appended to once per training batch.
critic_performance, generator_performance = [], []

In [6]:
# 7 continuous + 2 binary + one 3-level categorical variable = 12 columns,
# which is the critic's input width.
generator = GeneratorNet(
    n_output_continuous=7,
    n_output_binary=2,
    n_output_categorical=[3],
)
critic = CriticNet(input_dim=12)

NameError: name 'GeneratorNet' is not defined

In [7]:
# Training hyperparameters
batch_size = 256  # observations per batch drawn by the data loader
learning_rate = 1e-4  # shared by the critic and generator Adam optimizers
critic_rounds = 5  # critic updates per generator update
gradient_penalty_coefficient = 5  # weight of the gradient penalty in the critic loss

In [8]:
# Reshuffle the training set every epoch and serve it in batches.
data_loader = DataLoader(dataset=data, batch_size=batch_size, shuffle=True)

In [12]:
# One Adam optimizer per network, both with the same learning rate.
critic_optimizer = optim.Adam(params=critic.parameters(), lr=learning_rate)
generator_optimizer = optim.Adam(params=generator.parameters(), lr=learning_rate)

In [10]:
# Number of full passes over the training data
num_epochs = 20

In [None]:
for epoch in range(num_epochs):
    # The data loader yields (observations, labels); the labels are not needed.
    for n_batch, (real_batch, _) in enumerate(data_loader):
        N = real_batch.size(0)  # number of observations in this batch

        ## Train critic
        real_data = real_batch.float()

        for k in range(critic_rounds):
            # Detach the fake batch: the generator's graph is not needed
            # while only the critic is updated.
            fake_data = generator(make_noise(N)).detach()
            disc_error, _, _ = train_critic(real_data=real_data, fake_data=fake_data,
                                            optimizer=critic_optimizer,
                                            gradient_penalty_coefficient=gradient_penalty_coefficient)
        # NOTE: only the loss of the last critic round is recorded, not the
        # average over all rounds.
        critic_performance.append(-disc_error.detach().cpu().numpy())

        ## Train generator
        # This time the graph is kept, because train_generator backprops through it.
        fake_data = generator(make_noise(N))
        gen_error = train_generator(optimizer=generator_optimizer, fake_data=fake_data)
        generator_performance.append(gen_error.detach().cpu().numpy())

        if n_batch % 50 == 0 and val_data is not None:
            # Monitoring only, so no autograd graphs are built for these
            # forward passes.
            with torch.no_grad():
                n_val = len(val_data)
                train_gap = critic(real_data).mean() - critic(generator(make_noise(n_val))).mean()
                val_gap = critic(val_data).mean() - critic(generator(make_noise(n_val))).mean()
            print("{train:.6f} | {val:.6f}".format(
                train=train_gap.item(),
                val=val_gap.item()
            ))
            #print(pd.DataFrame(generator(validation_noise).detach().numpy()).mean())

In [11]:
from tableGAN.simulation import create_GAN_data
from tableGAN.tableGAN import make_noise
from tableGAN import tableGAN

# Rebuild both networks from the packaged implementation and hand them to
# tableGAN.WGAN.
generator = tableGAN.GeneratorNet(
    n_output_continuous=7,
    n_output_binary=2,
    n_output_categorical=[3],
)
critic = tableGAN.CriticNet(input_dim=12)
wgan = tableGAN.WGAN(generator, critic)

In [13]:
# Train through the packaged WGAN wrapper with the settings defined above.
critic_performance, generator_performance = wgan.train_WGAN(
    data_loader=data_loader,
    critic_optimizer=critic_optimizer,
    generator_optimizer=generator_optimizer,
    num_epochs=num_epochs,
    gradient_penalty_coefficient=gradient_penalty_coefficient,
    critic_rounds=critic_rounds,
    val_data=val_data,
)

0.319896 | 0.305091
17.684021 | 16.547075
5.628733 | 6.530212
4.506333 | 4.725708
5.565834 | 5.130300


KeyboardInterrupt: 

In [None]:
# Plot both loss histories without their first 10 entries: critic in red,
# generator in the default colour.
for history, line_style in ((critic_performance, {"color": "red"}),
                            (generator_performance, {})):
    plt.plot(history[10:], **line_style)
plt.show()

In [None]:
# Load the training CSV as a DataFrame for comparison with generated samples.
real = pd.read_csv("../simulation_data/simulation.csv")

In [None]:
# Sample 20000 discretised synthetic rows and round the first column to
# whole numbers.
fake = pd.DataFrame(
    generator.sample(make_noise(20000))
    .detach()
    .numpy()
)
fake.iloc[:, 0] = fake.iloc[:, 0].round()

In [None]:
# Column means of real and generated data side by side. `fake` carries
# integer column labels while `real` keeps the CSV header, so label-based
# alignment would pair nothing. The means are therefore matched by position.
# NOTE(review): assumes both frames have the same columns in the same order.
pd.DataFrame(
    {
        "real": np.round(np.mean(real, axis=0), 4).values,
        "fake": np.round(np.mean(fake, axis=0), 4).values,
    },
    index=real.columns,
)

In [None]:
# One figure per column: density histograms of the real data (default
# colour) and the generated data (orange), overlaid.
for col in range(fake.shape[1]):
    for frame, extra in ((real, {}), (fake, {"color": "orange"})):
        plt.hist(frame.iloc[:, col], alpha=0.5, bins=25, density=True, **extra)
    plt.show()

In [None]:
# Per-group means over the columns at positions 7 and 8. The grouping column
# is "group" in `real` and the column labelled 7 in `fake`.
# NOTE(review): presumably the same variable — confirm against the CSV layout.
print(real.iloc[:,[7,8]].groupby("group").apply(np.mean, axis=0))
print(fake.iloc[:,[7,8]].groupby(7).apply(np.mean, axis=0))

In [None]:
# Per-group means over the columns at positions 9-11, grouped by "group" in
# `real` and by the column labelled 7 in `fake`.
# NOTE(review): positions 9-11 look like the one-hot block of the 3-level
# categorical variable — confirm against the CSV layout.
print(real.iloc[:,[7,9,10,11]].groupby("group").apply(np.mean, axis=0))
print(fake.iloc[:,[7,9,10,11]].groupby(7).apply(np.mean, axis=0))

In [None]:
# Covariance matrix of the real columns at positions 4-6, to 3 decimals.
np.cov(real.iloc[:, 4:7].values, rowvar=False).round(3)

In [None]:
# Covariance matrix of the generated columns at positions 4-6, rounded to
# 3 decimals like the real-data matrix so the two can be compared directly.
np.round(np.cov(fake.iloc[:, 4:7].values, rowvar=False), 3)