# In [32]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.utils as vutils
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F

import pandas as pd
import numpy as np

# In [44]:
from pytorch_tabnet.tab_network import TabNetEncoder, TabNetDecoder, EmbeddingGenerator
from torch.distributions.normal import Normal

class Generator(nn.Module):
    """
    Variational TabNet autoencoder used as the generator.

    Input rows are embedded, then encoded by a TabNet encoder. The
    concatenated step outputs are mapped to the parameters of a diagonal
    Normal posterior q(z | x). A latent sample is drawn with the
    reparameterisation trick, and a TabNet decoder maps it back to
    ``input_dim`` features.

    Parameters
    ----------
    input_dim : int
        Number of raw input features.
    output_dim : int or list of int
        Forwarded unchanged to ``TabNetEncoder`` as its ``output_dim``.
    cat_dims : list of int
        Forwarded to ``EmbeddingGenerator`` (cardinality of each categorical
        feature, per the TabNet convention).
    cat_idxs : list of int
        Forwarded to ``EmbeddingGenerator`` (column indices of the
        categorical features).
    cat_emb_dim : int or list of int
        Forwarded to ``EmbeddingGenerator`` (embedding size(s) of the
        categorical features).
    n_d : int, default 8
        Width of each TabNet decision step, and therefore the per-step latent
        width. The default equals TabNet's own default, which is what the
        encoder used before this argument existed.
    n_steps : int, default 3
        Number of TabNet decision steps (TabNet's default). The latent code
        has ``n_d * n_steps`` dimensions.
    """

    def __init__(self, input_dim, output_dim, cat_dims, cat_idxs, cat_emb_dim,
                 n_d=8, n_steps=3):
        super(Generator, self).__init__()
        self.n_d = n_d
        self.n_steps = n_steps
        self.embed = EmbeddingGenerator(input_dim, cat_dims, cat_idxs, cat_emb_dim)
        # Categorical embeddings can widen each row, so the encoder is sized
        # from the embedder's output width rather than from input_dim.
        self.encode = TabNetEncoder(
            self.embed.post_embed_dim, output_dim, n_d=n_d, n_steps=n_steps
        )
        latent_dim = n_d * n_steps
        # TabNetEncoder produces features, not distribution parameters.
        # These heads turn the concatenated step outputs into mean and scale.
        self.to_mean = nn.Linear(latent_dim, latent_dim)
        self.to_scale = nn.Linear(latent_dim, latent_dim)
        self.decode = TabNetDecoder(input_dim, n_d=n_d, n_steps=n_steps)
        # Mask-sparsity regulariser from the most recent forward pass (None
        # before the first call). Training code may add it to the loss.
        self.last_mask_loss = None

    def forward(self, data):
        """
        Encode a batch, sample a latent code and decode it.

        Parameters
        ----------
        data : Tensor
            Batch of tabular rows. Presumably shaped (batch_size, input_dim)
            with categorical columns label-encoded; confirm against the caller.

        Returns
        -------
        decoding : Tensor
            Reconstruction produced by the TabNet decoder.
        q_dist : torch.distributions.Normal
            Approximate posterior over the latent code. Its batch shape is
            (batch_size, n_d * n_steps).
        """
        embedded_x = self.embed(data)
        # TabNetEncoder.forward returns (list of per-step outputs, mask loss),
        # not (mean, stddev).
        steps_output, mask_loss = self.encode(embedded_x)
        self.last_mask_loss = mask_loss
        hidden = torch.cat(steps_output, dim=-1)
        q_mean = self.to_mean(hidden)
        # Normal requires scale > 0. Softplus plus a small floor guarantees it.
        q_stddev = F.softplus(self.to_scale(hidden)) + 1e-6
        q_dist = Normal(q_mean, q_stddev)
        # rsample is reparameterised, so gradients flow back into the encoder.
        z_sample = q_dist.rsample()
        # TabNetDecoder consumes one n_d-wide tensor per decision step.
        z_steps = list(torch.split(z_sample, self.n_d, dim=-1))
        decoding = self.decode(z_steps)
        return decoding, q_dist

# In [45]:
from torch.distributions.kl import kl_divergence
def kl_divergence_loss(q_dist):
    """Return KL(q_dist || N(0, I)), summed over the last dimension.

    Parameters
    ----------
    q_dist : torch.distributions.Normal
        Approximate posterior, e.g. the ``q_dist`` returned by
        ``Generator.forward``.

    Returns
    -------
    Tensor
        Per-sample KL divergence. It has the batch shape of ``q_dist`` with
        the last dimension reduced away.
    """
    # A standard Normal prior with the same shape as the posterior.
    prior_loc = torch.zeros_like(q_dist.mean)
    prior_scale = torch.ones_like(q_dist.stddev)
    standard_normal = Normal(prior_loc, prior_scale)

    per_dim_kl = kl_divergence(q_dist, standard_normal)
    return per_dim_kl.sum(-1)