In [1]:
import pandas as pd
import numpy as np
import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader
from torch.optim import SGD, Adam

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

import matplotlib.pyplot as plt

In [2]:
import tableGAN
from tableGAN.utils import create_GAN_data, TabularDataset, Preprocessor
from tableGAN.tableGAN import make_noise, CriticNet, WGAN, GeneratorNet

In [3]:
training = create_GAN_data(20000,class_ratio=0.05,random_state=123)
test = create_GAN_data(20000,class_ratio=0.05,random_state=123)
minority = training.groupby("group").get_group(1).reset_index(drop=True)

In [4]:
input_dim=training.shape[1]

In [5]:
emb_dims = [(1,5),(1,5),(3,10)]
generator = GeneratorNet(noise_dim=10, lin_layer_sizes=[10,10,10], lin_layer_dropouts=None,
                        no_of_cont=7, emb_dims=emb_dims)
#generator = GeneratorNet(noise_dim=100,n_output_continuous=7,n_output_binary=2,n_output_categorical=[3])
critic = CriticNet(no_of_cont=7, emb_dims=emb_dims, lin_layer_sizes=[10,10,10], 
                   emb_dropout = 0, lin_layer_dropouts = [0,0,0])
wgan = WGAN(generator, critic)

In [6]:
batch_size = 128
#learning_rate = 1e-5
critic_rounds = 5
gradient_penalty_coefficient = 10

In [7]:
minority_tab = TabularDataset(minority, cat_cols=['group',8,9,10,11])
data_loader = DataLoader(minority_tab, batch_size=batch_size, shuffle=True)

In [8]:
critic_optimizer = Adam(critic.parameters(), lr=1e-3)
generator_optimizer = Adam(generator.parameters(), lr=1e-3)

In [9]:
num_epochs = 20

In [10]:
critic_performance, generator_performance = wgan.train_WGAN(
    data_loader=data_loader, critic_optimizer=critic_optimizer, generator_optimizer=generator_optimizer,
    num_epochs =num_epochs, gradient_penalty_coefficient= gradient_penalty_coefficient,
    critic_rounds=critic_rounds,
    val_data=None) #torch.from_numpy(test.values).float()

1.218724 | 
3.978997 | 
5.971148 | 
8.126633 | 
10.938429 | 
10.490806 | 
13.776539 | 
19.359386 | 
35.364006 | 
46.441959 | 
95.632416 | 
118.584999 | 
229.817413 | 
495.755554 | 
957.134644 | 
1626.579468 | 
2495.883545 | 


KeyboardInterrupt: 

In [None]:
fake = pd.DataFrame(wgan.generator.sample(make_noise(20000, dim=wgan.generator.noise_dim)).detach().numpy())

In [None]:
pd.DataFrame(np.vstack([np.round(np.mean(minority, axis=0),4), 
                        np.round(np.mean(fake.values, axis=0),4)]))