In [32]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.utils as vutils
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F

import pandas as pd
import numpy as np

In [44]:
from pytorch_tabnet.tab_network import TabNetEncoder, TabNetDecoder, EmbeddingGenerator
from torch.distributions.normal import Normal

class Generator(nn.Module):
	"""
	Autoencoder-style generator built from pytorch_tabnet components.

	The input batch is passed through an EmbeddingGenerator and encoded by a
	TabNetEncoder into the parameters of a Normal distribution. A latent
	sample is drawn with ``rsample`` (reparameterised, so gradients flow
	through it) and decoded by a TabNetDecoder.

	Parameters
	----------
	input_dim : int
		Number of features; used to size the embedding, encoder and decoder.
	output_dim : int or list of int
		Passed to the TabNetEncoder as its output dimension.
	cat_dims :
		Passed unchanged to EmbeddingGenerator; presumably the number of
		categories of each categorical feature (TODO confirm for the installed
		pytorch_tabnet version).
	cat_idxs :
		Passed unchanged to EmbeddingGenerator; presumably the column indices
		of the categorical features.
	cat_emb_dim :
		Passed unchanged to EmbeddingGenerator; presumably the embedding
		size(s) for the categorical features.

	NOTE(review): the encoder and decoder are sized with ``input_dim``, but the
	encoder consumes the EmbeddingGenerator output, whose width may differ from
	``input_dim`` once categorical embeddings are applied -- confirm.
	"""
	def __init__(self, input_dim, output_dim, cat_dims, cat_idxs, cat_emb_dim):
		super(Generator, self).__init__()
		self.embed = EmbeddingGenerator(input_dim, cat_dims, cat_idxs, cat_emb_dim)
		self.encode = TabNetEncoder(input_dim, output_dim)
		self.decode = TabNetDecoder(input_dim)

	def forward(self, data):
		"""
		Autoencode a batch, returning the decoding and the latent distribution.

		Parameters
		----------
		data : Tensor
			Batch of tabular data; presumably shape (batch_size, input_dim) --
			TODO confirm against the DataLoader used for training.

		Returns
		-------
		decoding :
			Output of ``self.decode`` applied to one latent sample.
		q_dist : torch.distributions.Normal
			Distribution built from the encoder outputs; the latent sample was
			drawn from it.
		"""
		embedded_x = self.embed(data)
		# NOTE(review): this unpacks the encoder output as (mean, stddev). Confirm that
		# TabNetEncoder.forward in the installed pytorch_tabnet really returns
		# distribution parameters; as far as I know it returns (per-step outputs,
		# mask loss), which would not form a valid Normal. Normal also requires a
		# strictly positive stddev, which nothing here enforces.
		q_mean, q_stddev = self.encode(embedded_x)
		q_dist = Normal(q_mean, q_stddev)
		z_sample = q_dist.rsample() # one reparameterised sample per element, so gradients reach q_mean/q_stddev
		decoding = self.decode(z_sample)
		return decoding, q_dist


In [46]:
class Discriminator(nn.Module):
    """
    MLP discriminator that maps each sample to one unnormalised logit.

    Every sample is flattened to a vector of ``input_dim`` features and passed
    through Linear/LeakyReLU hidden layers followed by a final
    ``nn.Linear(..., 1)``. No sigmoid is applied, so pair the output with a
    logits-based loss such as ``nn.BCEWithLogitsLoss``.

    Parameters
    ----------
    input_dim : int, default 28*28
        Number of features per sample after flattening. The default keeps the
        original 784-feature (28x28) layout; pass the tabular feature count
        when discriminating tabular rows.
    hidden_dims : sequence of int, default (300, 100)
        Widths of the hidden layers, in order.
    negative_slope : float, default 0.2
        Negative slope of the LeakyReLU activations.
    """
    def __init__(self, input_dim=28 * 28, hidden_dims=(300, 100), negative_slope=0.2):
        super().__init__()
        layers = []
        in_features = input_dim
        for width in hidden_dims:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.LeakyReLU(negative_slope, inplace=True))
            in_features = width
        layers.append(nn.Linear(in_features, 1))
        # With the defaults, the layer order (and so the state_dict keys
        # model.0 / model.2 / model.4) matches the original fixed architecture.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """
        Return one logit per sample, as a tensor of shape (batch_size,).

        ``x`` may have any shape whose leading dimension is the batch; the
        remaining dimensions are flattened and must total ``input_dim``.
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(x.size(0), -1)
        out = self.model(x)
        return out.view(x.size(0))


In [None]:
from torch.distributions.kl import kl_divergence

def kl_divergence_loss(q_dist):
    """Return KL(q_dist || N(0, 1)) summed over the last dimension.

    The prior is a standard Normal built with ``zeros_like`` / ``ones_like``
    from ``q_dist``, so it matches its shape, dtype and device. The result
    has the last dimension of ``q_dist``'s batch shape reduced away.
    """
    prior = Normal(
        torch.zeros_like(q_dist.mean),
        torch.ones_like(q_dist.stddev),
    )
    elementwise_kl = kl_divergence(q_dist, prior)
    return elementwise_kl.sum(dim=-1)

# Binary cross-entropy on raw logits, summed over all elements.
# The Discriminator ends in a plain nn.Linear, so it returns unnormalised
# logits. nn.BCELoss requires inputs already in [0, 1], whereas
# BCEWithLogitsLoss applies the sigmoid internally in a numerically stable way.
reconstruction_loss = nn.BCEWithLogitsLoss(reduction='sum')

In [None]:
from torch.utils.data.dataloader import DataLoader
from torchvision import datasets, transforms

# Placeholder: assign a DataLoader over the training data before running the training cells.
train_dataloader = None

# Prefer the GPU whenever torch can see one.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [None]:
num_epochs = 5
lr = 0.001

# Seed the global torch RNG so parameter initialisation is reproducible across runs.
torch.manual_seed(0)

# Both networks must live on the same device as the batches; previously only the
# generator was moved, which fails as soon as `device` is 'cuda'.
# NOTE(review): Discriminator() defaults to 28*28 input features -- confirm this
# matches the width of the tabular rows being trained on.
discriminator = Discriminator().to(device)
# NOTE(review): Generator requires (input_dim, output_dim, cat_dims, cat_idxs,
# cat_emb_dim); calling it with no arguments raises TypeError. Fill these in
# from the dataset before running this cell.
generator = Generator().to(device)

# Optimisers are created after `.to(device)` so they reference the moved parameters.
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)

In [None]:
# Fail early with a clear message instead of "'NoneType' object is not iterable".
if train_dataloader is None:
	raise ValueError('train_dataloader is not set; assign a DataLoader before training.')

# Keep both networks on the same device as the batches (a no-op if they already are).
generator.to(device)
discriminator.to(device)

for epoch in range(num_epochs):
	# The sampling step at the end of every epoch switches both networks to eval
	# mode, so switch them back before training on the next epoch.
	generator.train()
	discriminator.train()

	for data, _ in train_dataloader:
		data = data.to(device)
		batch_size = data.size(0)

		# Generator.forward consumes a real batch (not a noise vector) and returns
		# (generated batch, latent distribution); compute it once per step.
		# NOTE(review): `encoding` is unused; a VAE-style objective would add
		# kl_divergence_loss(encoding) to g_loss.
		fake_data, encoding = generator(data)

		# ---- Train the Discriminator ----
		# Label convention: 0 = real, 1 = fake. The fake batch is detached so the
		# discriminator update does not backpropagate into the generator.
		inputs = torch.cat([data, fake_data.detach()])
		labels = torch.cat([torch.zeros(batch_size, device=device),  # real
							torch.ones(batch_size, device=device)])  # fake

		d_optimizer.zero_grad()
		d_outputs = discriminator(inputs)
		# The discriminator returns raw logits, so use the logits form of BCE.
		d_loss = F.binary_cross_entropy_with_logits(d_outputs, labels, reduction='sum')
		d_loss.backward()
		d_optimizer.step()

		# ---- Train the Generator ----
		# The generator is rewarded when its output is classified as real (label 0).
		g_optimizer.zero_grad()
		outputs = discriminator(fake_data)
		g_loss = F.binary_cross_entropy_with_logits(
			outputs, torch.zeros(batch_size, device=device), reduction='sum')
		g_loss.backward()
		g_optimizer.step()

	# Epoch summary from the last batch: the first `batch_size` scores belong to the
	# real half of `inputs`, the rest to the fake half. Under the 0 = real / 1 = fake
	# convention, sigmoid(logit) is the discriminator's probability of "fake".
	scores = torch.sigmoid(d_outputs.detach())
	real_score = scores[:batch_size].mean().item()
	fake_score = scores[batch_size:].mean().item()

	print(f'Epoch {epoch+1}/{num_epochs}, d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, '
		f'D(x): {real_score:.2f}, D(G(z)): {fake_score:.2f}')

	# Sample a few generated rows for inspection. Generator.forward takes a data
	# batch rather than a noise vector, so reconstruct rows from the last batch.
	num_test_samples = 16
	generator.eval()
	discriminator.eval()
	with torch.no_grad():
		test_images, _ = generator(data[:num_test_samples])