In [1]:
import tensorflow as tf
import torch
import numpy as np
import pandas as pd
import tensorflow_probability as tfp
import matplotlib.pyplot as plt

In [2]:
from ctgan.transformer import DataTransformer
from ctgan.sampler import Sampler
from ctgan.conditional import ConditionalGenerator

In [3]:
from ctgan_torch.synthesizer import CTGANSynthesizer

In [4]:
# Location of the gzip-compressed census CSV used as the demo training set.
DEMO_URL = 'http://ctgan-data.s3.amazonaws.com/census.csv.gz'
# Downloaded over the network on every run of this cell.
train_data = pd.read_csv(DEMO_URL, compression='gzip')
# Column names handed to the transformer as discrete columns.
# Presumably every column not listed here is treated as continuous — confirm
# against DataTransformer.
discrete_columns = [
    'workclass',
    'education',
    'marital-status',
    'occupation',
    'relationship',
    'race',
    'sex',
    'native-country',
    'income'
]

In [5]:
# Fit the data transformer on the raw frame, then encode the frame with it.
transformer = DataTransformer()
transformer.fit(train_data, discrete_columns)
# NOTE(review): this rebinds `train_data` from the raw DataFrame to the
# transformed output, so the cell is not idempotent — re-running it without
# re-running the load cell feeds already-transformed data back into fit/transform.
train_data = transformer.transform(train_data)



In [6]:
# Sampler over the transformed data; later cells call `.sample(...)` on it to
# draw the "real" batch.
data_sampler = Sampler(train_data, transformer.output_info)
# Used below as the generator's output size and as part of the critic's input size.
data_dim = transformer.output_dimensions
# Source of the conditional vectors sampled in the training cells; its `n_opt`
# is added to the generator and critic input widths.
# NOTE(review): the meaning of the third positional argument (True) is not
# visible in this notebook — confirm against ConditionalGenerator.
cond_generator = ConditionalGenerator(
    train_data, transformer.output_info, True)

In [9]:
import math
def w_bounded_initializer(fan_in):
    """Return a uniform initializer on [-1/sqrt(fan_in), +1/sqrt(fan_in)].

    Args:
        fan_in: number of input features of the layer being initialized.

    Returns:
        A ``tf.random_uniform_initializer`` with symmetric bounds.

    Presumably chosen to mirror the default range torch.nn.Linear uses, so the
    TF and torch models start from comparable scales — confirm.
    """
    limit = math.sqrt(1 / fan_in)
    return tf.random_uniform_initializer(maxval=limit, minval=-limit)

In [10]:
class NewResidualLayer(tf.keras.layers.Layer):
    """Dense -> BatchNorm -> ReLU block whose output is concatenated with its input.

    For a 2-D input of shape ``(batch, features)`` the result has
    ``num_outputs + features`` columns: the activated projection first,
    followed by the untouched input.

    Args:
        num_outputs: number of units of the inner Dense layer.
    """

    def __init__(self, num_outputs):
        super().__init__()
        self.num_outputs = num_outputs
        # Sub-layers are created in build(), once the input width is known.
        self.fc = None
        self.bn = None
        self.relu = None

    def build(self, input_shape):
        """Create the sub-layers; ``input_shape[1]`` is used as the Dense fan-in."""
        fan_in = input_shape[1]
        # `fc` must stay the first sub-layer created: the weight-copy cell
        # reaches it through the private `_layers[0]` attribute.
        # No `input_shape` kwarg is passed to Dense: the value available here
        # includes the batch axis (which that kwarg must not contain), and the
        # layer is built from the actual tensor on first call anyway.
        self.fc = tf.keras.layers.Dense(
            self.num_outputs,
            kernel_initializer=w_bounded_initializer(fan_in),
            bias_initializer=w_bounded_initializer(fan_in))
        self.bn = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9)
        self.relu = tf.keras.layers.ReLU()

    def call(self, inputs, training=None, **kwargs):
        """Apply fc -> bn -> relu and append the original inputs on axis 1.

        Args:
            inputs: 2-D tensor ``(batch, features)``.
            training: forwarded to BatchNormalization so batch statistics vs.
                moving averages are selected explicitly; ``None`` lets Keras
                resolve the mode.
        """
        outputs = self.fc(inputs)
        outputs = self.bn(outputs, training=training)
        outputs = self.relu(outputs)
        return tf.concat([outputs, inputs], axis=1)

In [11]:
from ctgan.layers import *
def build_generator(gen_dims, embedding_dim, data_dim):
    """Assemble the Keras generator as a functional model.

    Args:
        gen_dims: iterable of hidden widths; one NewResidualLayer is stacked
            per entry, in order.
        embedding_dim: width of the generator input.
        data_dim: width of the generator output.

    Returns:
        A ``tf.keras.Model`` named 'Generator'. The final Dense layer has no
        activation; this notebook applies the output activations in a
        separate step.
    """
    inputs = tf.keras.Input(shape=(embedding_dim,))
    hidden = inputs
    fan_in = embedding_dim

    for width in list(gen_dims):
        hidden = NewResidualLayer(width)(hidden)
        # Each residual layer concatenates its input to its output, so the
        # feature width seen by the next layer grows by `width`.
        fan_in += width

    head = tf.keras.layers.Dense(
        data_dim,
        kernel_initializer=w_bounded_initializer(fan_in),
        bias_initializer=w_bounded_initializer(fan_in))
    return tf.keras.Model(inputs, head(hidden), name='Generator')

In [12]:
def build_critic(dis_dims, input_dim):
    """Assemble the Keras critic as a functional model.

    The network is Dense + LeakyReLU(0.2) per entry of ``dis_dims`` followed
    by a single-unit Dense output. Rows are fed one by one: the packed
    ("pac") reshape of the input is intentionally not applied in this
    notebook, and no Dropout is used.

    Args:
        dis_dims: iterable of hidden widths, applied in order.
        input_dim: width of one critic input row.

    Returns:
        A ``tf.keras.Model`` named 'Critic' producing one score per row.
    """
    inputs = tf.keras.Input(shape=(input_dim,))
    hidden = inputs
    fan_in = input_dim

    for width in list(dis_dims):
        dense = tf.keras.layers.Dense(
            width, input_shape=(fan_in,),
            kernel_initializer=w_bounded_initializer(fan_in),
            bias_initializer=w_bounded_initializer(fan_in))
        hidden = dense(hidden)
        hidden = tf.keras.layers.LeakyReLU(0.2)(hidden)
        fan_in = width

    score_layer = tf.keras.layers.Dense(
        1, input_shape=(fan_in,),
        kernel_initializer=w_bounded_initializer(fan_in),
        bias_initializer=w_bounded_initializer(fan_in))
    return tf.keras.Model(inputs, score_layer(hidden), name='Critic')

In [13]:
128 + cond_generator.n_opt

232

In [14]:
l2scale = 1e-6

In [15]:
# Generator: input width = 128 + conditional-vector width (n_opt), two
# residual layers of 256 units, output width = data_dim.
# NOTE(review): 128 is hard-coded here and again as `embedding_dim`/`z_dim` in a
# later cell; the two must be kept in sync by hand.
generator = build_generator(
    (256,256), 128 + cond_generator.n_opt, data_dim)
# Critic: input width = data_dim + conditional-vector width, two hidden layers of 256.
critic = build_critic((256,256), data_dim + cond_generator.n_opt)

In [16]:
# Adam for the Keras generator, with the same lr/betas as the torch optimizerG.
# NOTE(review): `decay=l2scale` is not the counterpart of torch's
# `weight_decay=l2scale`. In the Keras (legacy) optimizer API `decay` is a
# learning-rate decay schedule, whereas torch's `weight_decay` adds an L2
# penalty to the gradients. The two generator optimizers are therefore not
# configured equivalently — confirm intent before comparing updated weights.
g_opt = tf.keras.optimizers.Adam(
    learning_rate=2e-4, beta_1=0.5, beta_2=0.9, epsilon=1e-08, decay=l2scale)
# Adam for the Keras critic; no decay / regularization argument is passed.
c_opt = tf.keras.optimizers.Adam(
    learning_rate=2e-4, beta_1=0.5, beta_2=0.9, epsilon=1e-08)

In [17]:
class Discriminator(torch.nn.Module):
    """Torch critic: Linear + LeakyReLU(0.2) per hidden width, then one score.

    The packed ("pac") reshape of the input is deliberately left out here:
    every row is scored on its own. The batch size must still be a multiple
    of ``pack`` (checked in ``forward``).

    Args:
        input_dim: width of one input row.
        dis_dims: iterable of hidden layer widths, applied in order.
        pack: group size the batch size must be divisible by.
    """

    def __init__(self, input_dim, dis_dims, pack=10):
        super().__init__()
        self.pack = pack
        # With packing disabled the per-row width is just `input_dim`.
        self.packdim = input_dim

        layers = []
        width = input_dim
        for hidden in list(dis_dims):
            layers += [torch.nn.Linear(width, hidden), torch.nn.LeakyReLU(0.2)]
            width = hidden
        layers.append(torch.nn.Linear(width, 1))
        self.seq = torch.nn.Sequential(*layers)

    def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lambda_=10):
        """Return the WGAN-GP style gradient penalty on real/fake interpolates.

        One random coefficient is drawn per group of ``pac`` consecutive rows
        and shared by all rows and features of that group. The gradient of the
        critic output w.r.t. the interpolated rows is reshaped to
        ``(-1, pac * features)`` and the penalty is
        ``lambda_ * mean((||grad||_2 - 1) ** 2)``.

        Args:
            real_data: 2-D tensor ``(batch, features)``; batch must be a
                multiple of ``pac``.
            fake_data: tensor of the same shape as ``real_data``.
            device: device for the random coefficients and the ones tensor.
            pac: group size used for the coefficients and the norm.
            lambda_: penalty weight.

        Returns:
            Scalar tensor that can be back-propagated into the critic weights.
        """
        n_features = real_data.size(1)
        alpha = torch.rand(real_data.size(0) // pac, 1, 1, device=device)
        alpha = alpha.repeat(1, pac, n_features)
        alpha = alpha.view(-1, n_features)

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        # autograd.grad needs the differentiated input to require grad. That
        # only happens implicitly when real/fake are attached to a graph, so
        # enable it explicitly for plain (e.g. detached) inputs instead of
        # failing with a RuntimeError.
        if not interpolates.requires_grad:
            interpolates.requires_grad_(True)

        disc_interpolates = self(interpolates)

        # create_graph=True keeps the gradient differentiable, so the penalty
        # itself can be back-propagated.
        gradients = torch.autograd.grad(
            outputs=disc_interpolates, inputs=interpolates,
            grad_outputs=torch.ones(disc_interpolates.size(), device=device),
            create_graph=True, retain_graph=True
        )[0]

        group_norms = gradients.view(-1, pac * n_features).norm(2, dim=1)
        gradient_penalty = ((group_norms - 1) ** 2).mean() * lambda_

        return gradient_penalty

    def forward(self, input):
        """Score each row of ``input``; batch size must be a multiple of ``self.pack``."""
        assert input.size()[0] % self.pack == 0
        return self.seq(input)

In [18]:
from ctgan_torch.models import Generator
t_gen = Generator(
    128 + cond_generator.n_opt,
    (256,256),
    data_dim
)

t_critic = Discriminator(
    data_dim + cond_generator.n_opt,
    (256,256)
)

In [19]:
optimizerG = torch.optim.Adam(
    t_gen.parameters(), lr=2e-4, betas=(0.5, 0.9),
    weight_decay=l2scale
)
optimizerD = torch.optim.Adam(t_critic.parameters(), lr=2e-4, betas=(0.5, 0.9))

In [20]:
batch_size = 500
embedding_dim = z_dim = 128
mean = torch.zeros(batch_size, embedding_dim)
std = mean + 1

# Copy weights to TF model

In [21]:
# Overwrite the Keras critic's Dense parameters with the torch critic's Linear
# parameters so both critics start from identical weights.
# The torch weight matrices are transposed (`.T`) before being passed to
# Keras `set_weights`, because the two frameworks store them with the axes swapped.
# Index mapping: Keras layers 1/3/5 are the Dense layers (0 is the input layer,
# 2/4 are the LeakyReLU layers); torch seq 0/2/4 are the Linear layers.
critic.layers[1].set_weights([t_critic.seq[0].weight.detach().numpy().T, t_critic.seq[0].bias.detach().numpy()])
critic.layers[3].set_weights([t_critic.seq[2].weight.detach().numpy().T, t_critic.seq[2].bias.detach().numpy()])
critic.layers[5].set_weights([t_critic.seq[4].weight.detach().numpy().T, t_critic.seq[4].bias.detach().numpy()])

In [22]:
# Same copy for the generator: layers 1 and 2 are the NewResidualLayer blocks,
# layer 3 is the final Dense.
# `_layers` is a private Keras attribute; index 0 is assumed to be the Dense
# sub-layer (`fc`), which is the first sub-layer NewResidualLayer.build creates.
# NOTE(review): only the Dense/Linear parameters are copied. BatchNorm
# parameters and statistics are left at each framework's initial values —
# confirm those defaults match.
generator.layers[1]._layers[0].set_weights([t_gen.seq[0].fc.weight.detach().numpy().T, t_gen.seq[0].fc.bias.detach().numpy()])
generator.layers[2]._layers[0].set_weights([t_gen.seq[1].fc.weight.detach().numpy().T, t_gen.seq[1].fc.bias.detach().numpy()])
generator.layers[3].set_weights([t_gen.seq[2].weight.detach().numpy().T, t_gen.seq[2].bias.detach().numpy()])

# Train models

In [23]:
# Draw ONE latent batch with NumPy and hand the same values to both frameworks,
# so the torch and TF generators are compared on identical noise.
# (The earlier per-framework draws via torch.normal / tf.random.normal were
# overwritten immediately and have been removed as dead code.)
fk = np.random.normal(size=(batch_size, z_dim)).astype(np.float32)
t_fakez = torch.from_numpy(fk)
tf_fakez = tf.convert_to_tensor(fk)

In [24]:
# Sample the conditional vectors for this batch. The sampler may return None;
# that case is handled as "no conditioning".
condvec = cond_generator.sample(batch_size)
if condvec is None:
    c1, m1, col, opt = None, None, None, None
    real = data_sampler.sample(batch_size, col, opt)
else:
    # c1 is concatenated to the latent noise; m1 is the mask consumed by the
    # conditional loss in the generator step. col/opt are forwarded to the data
    # sampler (presumably the chosen column / category per row — confirm).
    c1, m1, col, opt = condvec
    c1_tf = tf.convert_to_tensor(c1)
    m1_tf = tf.convert_to_tensor(m1)
    tf_fakez = tf.concat([tf_fakez, c1_tf], axis=1)
    
    c1_t = torch.from_numpy(c1)
    m1_t = torch.from_numpy(m1)
    t_fakez = torch.cat([t_fakez, c1_t], dim=1)

    # Real rows are drawn for a shuffled copy of the conditions; the same
    # permutation is applied to c1 so each real row keeps its matching
    # conditional vector (tf_c2 / t_c2).
    perm = np.arange(batch_size)
    np.random.shuffle(perm)
    real = data_sampler.sample(batch_size, col[perm], opt[perm])
    tf_c2 = tf.gather(c1_tf, perm)
    t_c2 = c1_t[perm]

In [25]:
tf_fake = generator(tf_fakez, training=True)

In [26]:
t_fake = t_gen(t_fakez)

In [27]:
from ctgan.layers import _apply_activate
tf_fakeact = _apply_activate(tf_fake, transformer.output_info)

In [28]:
from torch.nn import functional
def torch_aa(data, transformer_info):
    """Apply the per-block output activations to raw generator outputs.

    Args:
        data: 2-D tensor whose columns are laid out block by block in the
            order given by ``transformer_info``.
        transformer_info: sequence of items where ``item[0]`` is the block
            width and ``item[1]`` is ``'tanh'`` or ``'softmax'``.

    Returns:
        Tensor of the same width as the consumed columns: ``tanh`` blocks go
        through ``torch.tanh``; ``softmax`` blocks go through a Gumbel-softmax
        sample with ``tau=0.2`` (random, so results vary between calls).

    Raises:
        AssertionError: for any other activation name.
    """
    pieces = []
    start = 0
    for item in transformer_info:
        kind = item[1]
        if kind not in ('tanh', 'softmax'):
            assert 0
            continue  # only reachable when assertions are disabled
        stop = start + item[0]
        block = data[:, start:stop]
        if kind == 'tanh':
            pieces.append(torch.tanh(block))
        else:
            pieces.append(functional.gumbel_softmax(block, tau=0.2))
        start = stop

    return torch.cat(pieces, dim=1)
t_fakeact = torch_aa(t_fake, transformer.output_info)

In [29]:
# Replace the TF activations with a copy of the torch ones so both critics are
# fed the exact same fake batch (the torch activation draws random
# Gumbel-softmax samples, so the two frameworks would otherwise differ).
# The copy goes through detach()/NumPy, so it is a plain constant on the TF side.
tf_fakeact = tf.convert_to_tensor(t_fakeact.detach().numpy())
# Real batch as a float32 TF tensor (rebuilt again at the top of the next cell).
tf_real = tf.convert_to_tensor(real.astype('float32'))

In [30]:
# Build the critic inputs for both frameworks from the same NumPy batch.
tf_real = tf.convert_to_tensor(real.astype('float32'))
t_real = torch.from_numpy(real.astype('float32'))

if c1 is not None:
    # Conditioned case: append the conditional vector to the fake rows and the
    # permuted conditional vector (matching the sampled real rows) to the real rows.
    tf_fake_cat = tf.concat([tf_fakeact, c1_tf], axis=1)
    tf_real_cat = tf.concat([tf_real, tf_c2], axis=1)

    t_fake_cat = torch.cat([t_fakeact, c1_t], dim=1)
    t_real_cat = torch.cat([t_real, t_c2], dim=1)
else:
    # Unconditioned case: the critic must still see the ACTIVATED generator
    # output (as in the conditioned branch and in the generator-step cells),
    # not the raw pre-activation tensors tf_fake / t_fake.
    tf_real_cat = tf_real
    tf_fake_cat = tf_fakeact

    t_real_cat = t_real
    t_fake_cat = t_fakeact

## Torch Critic gradient

In [31]:
t_y_fake = t_critic(t_fake_cat)
t_y_real = t_critic(t_real_cat)

In [32]:
t_critic.seq[0].weight

Parameter containing:
tensor([[ 0.0394,  0.0553,  0.0593,  ...,  0.0539, -0.0371,  0.0267],
        [ 0.0364, -0.0284,  0.0046,  ..., -0.0178, -0.0449,  0.0512],
        [ 0.0422,  0.0044, -0.0385,  ...,  0.0456, -0.0134, -0.0304],
        ...,
        [-0.0528, -0.0321,  0.0250,  ..., -0.0136, -0.0175, -0.0086],
        [ 0.0126, -0.0208, -0.0324,  ...,  0.0411,  0.0220, -0.0596],
        [-0.0434, -0.0229,  0.0378,  ...,  0.0427, -0.0114, -0.0400]],
       requires_grad=True)

In [33]:
t_critic.seq[0].weight.grad

In [34]:
optimizerD.zero_grad()

In [35]:
t_gp = t_critic.calc_gradient_penalty(t_real_cat, t_fake_cat)
t_gp

tensor(4.4482, grad_fn=<MulBackward0>)

In [36]:
t_gp.backward(retain_graph=True)
t_critic.seq[0].weight.grad

tensor([[-9.1329e-04, -1.3599e-04, -2.0159e-02,  ..., -1.5964e-02,
          2.1048e-02, -1.8032e-02],
        [ 6.3296e-05,  1.1307e-03,  1.3827e-02,  ...,  1.2684e-02,
         -1.4056e-02,  1.2407e-02],
        [-4.6566e-04, -2.3619e-03, -3.4046e-02,  ..., -3.1243e-02,
          3.8162e-02, -2.8887e-02],
        ...,
        [-6.8629e-05,  7.1035e-05,  6.4512e-03,  ...,  5.0679e-03,
         -7.1159e-03,  5.0784e-03],
        [ 5.9869e-04,  8.3954e-04,  2.0180e-02,  ...,  1.3155e-02,
         -2.0582e-02,  1.8380e-02],
        [ 3.4123e-05,  1.4451e-03,  1.3386e-02,  ...,  1.0244e-02,
         -1.4730e-02,  1.2419e-02]])

In [39]:
t_critic.seq[0].bias.grad

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [197]:
t.gradient(cost, critic.trainable_variables)[0].numpy().T

array([[-0.00844331,  0.004779  , -0.02588256, ...,  0.0125904 ,
        -0.01018348,  0.00513233],
       [-0.00792216,  0.00628714, -0.02754391, ...,  0.01410859,
        -0.00976788,  0.00463192],
       [ 0.00805554, -0.00322066,  0.02322641, ..., -0.01283849,
         0.00801958, -0.00228744],
       ...,
       [-0.00577756,  0.00759055, -0.03305208, ...,  0.02059012,
        -0.01217141,  0.00259927],
       [ 0.00254121, -0.00154477,  0.01237051, ..., -0.00832565,
         0.00477842, -0.00101274],
       [-0.00316215,  0.00415935, -0.01578547, ...,  0.00967369,
        -0.0061935 ,  0.00161804]], dtype=float32)

In [202]:
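# The gradient of the gradient penalty w.r.t. the first-layer bias is zero in both the torch and TF critics (expected: with LeakyReLU the critic is piecewise linear, so its input-gradient does not depend on the biases within an activation region).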

In [182]:
t_loss_d = -(torch.mean(t_y_real) - torch.mean(t_y_fake))
t_loss_d

tensor(-0.0212, grad_fn=<NegBackward>)

In [183]:
t_loss_d.backward(retain_graph=True)
t_critic.seq[0].weight.grad

tensor([[-3.2967e-05, -4.1965e-04, -7.9922e-04,  ..., -9.3629e-06,
         -3.4491e-04, -3.1650e-04],
        [-1.2957e-04, -2.2784e-04, -1.4991e-03,  ..., -1.8192e-05,
         -2.2587e-04, -1.1298e-04],
        [ 3.2450e-04,  3.2676e-04,  1.2400e-03,  ...,  5.2052e-05,
          2.2332e-04,  9.1875e-05],
        ...,
        [-1.5260e-04,  3.6859e-04, -1.0196e-03,  ..., -3.6431e-06,
         -2.7294e-05,  3.8206e-04],
        [ 1.8673e-04,  1.6291e-04,  5.7845e-04,  ..., -1.9890e-05,
          2.5632e-04, -9.9192e-05],
        [-2.1609e-04,  2.3110e-04, -3.7279e-05,  ..., -1.3771e-05,
          3.7411e-04, -7.1187e-06]])

In [191]:
t.gradient(tf_loss_d, critic.trainable_variables)[0].numpy().T

array([[-3.29672548e-05, -4.19651158e-04, -7.99217494e-04, ...,
        -9.36291690e-06, -3.44913104e-04, -3.16496531e-04],
       [-1.29574779e-04, -2.27835990e-04, -1.49905507e-03, ...,
        -1.81916221e-05, -2.25866417e-04, -1.12976995e-04],
       [ 3.24495923e-04,  3.26758425e-04,  1.24002690e-03, ...,
         5.20520698e-05,  2.23320327e-04,  9.18747101e-05],
       ...,
       [-1.52602792e-04,  3.68591165e-04, -1.01956772e-03, ...,
        -3.64308426e-06, -2.72941252e-05,  3.82056838e-04],
       [ 1.86732563e-04,  1.62908720e-04,  5.78450388e-04, ...,
        -1.98896741e-05,  2.56315863e-04, -9.91916168e-05],
       [-2.16088883e-04,  2.31102313e-04, -3.72785435e-05, ...,
        -1.37712623e-05,  3.74114403e-04, -7.11862231e-06]], dtype=float32)

In [198]:
t_critic.seq[0].bias.grad

tensor([-3.7141e-03, -3.4286e-03,  1.5972e-03, -2.9054e-03,  6.3167e-04,
        -3.1064e-03, -1.2677e-05,  3.7158e-04, -7.2336e-04, -2.4003e-04,
        -9.8249e-04,  4.1842e-03, -4.7912e-03, -1.3064e-03, -2.6162e-04,
        -9.6417e-04,  4.9891e-04, -2.9432e-03,  6.3967e-04,  1.2029e-03,
        -1.4168e-03, -4.0689e-03,  7.0814e-03,  6.9827e-04, -1.9378e-04,
         1.5502e-03,  7.5785e-04, -2.2987e-03, -8.4559e-04,  1.3174e-03,
        -7.6710e-04,  2.7725e-03, -4.1151e-03,  3.0265e-05,  1.2375e-03,
        -4.8434e-03, -5.3207e-03,  3.7058e-03, -2.7931e-03,  5.1953e-04,
         2.1123e-03, -1.0427e-03, -4.3118e-03, -4.0709e-03,  1.8504e-03,
         2.1879e-03, -4.2011e-04,  3.3535e-03,  2.6802e-03,  1.4859e-03,
        -8.6288e-03, -2.0054e-03,  9.3788e-04, -2.3347e-03, -6.0889e-04,
         1.2244e-03, -8.2609e-05,  2.0363e-03,  1.3671e-04, -1.5748e-03,
        -1.1991e-04, -1.1027e-03,  3.1112e-04,  2.0957e-03, -4.2280e-03,
        -4.6176e-03,  1.9194e-03, -1.5437e-03,  7.3

In [199]:
t.gradient(tf_loss_d, critic.trainable_variables)[1].numpy()

array([-3.71405412e-03, -3.42860166e-03,  1.59716327e-03, -2.90540047e-03,
        6.31667324e-04, -3.10636754e-03, -1.26765808e-05,  3.71579954e-04,
       -7.23355450e-04, -2.40026740e-04, -9.82492696e-04,  4.18418413e-03,
       -4.79119830e-03, -1.30637095e-03, -2.61623412e-04, -9.64172767e-04,
        4.98906709e-04, -2.94324197e-03,  6.39672158e-04,  1.20291091e-03,
       -1.41678413e-03, -4.06886544e-03,  7.08143972e-03,  6.98265620e-04,
       -1.93772372e-04,  1.55020715e-03,  7.57850707e-04, -2.29874533e-03,
       -8.45589209e-04,  1.31738745e-03, -7.67095014e-04,  2.77247746e-03,
       -4.11512703e-03,  3.02651897e-05,  1.23747997e-03, -4.84339288e-03,
       -5.32073528e-03,  3.70577816e-03, -2.79306434e-03,  5.19529276e-04,
        2.11230293e-03, -1.04267639e-03, -4.31181863e-03, -4.07088920e-03,
        1.85036054e-03,  2.18787510e-03, -4.20107041e-04,  3.35351587e-03,
        2.68023903e-03,  1.48588140e-03, -8.62875022e-03, -2.00544763e-03,
        9.37876757e-04, -

In [203]:
optimizerD.step()

## TF Critic gradient

In [184]:
from functools import partial
def gradient_penalty(f, real, fake, pac=10, grad_penalty_lambda=10):
    """Calculates the gradient penalty loss for a batch of "averaged" samples.

    In Improved WGANs, the 1-Lipschitz constraint is enforced by adding a term to the
    loss function that penalizes the network if the gradient norm moves away from 1.
    However, it is impossible to evaluate this function at all points in the input
    space. The compromise used in the paper is to choose random points on the lines
    between real and generated samples, and check the gradients at these points. Note
    that it is the gradient w.r.t. the input averaged samples, not the weights of the
    discriminator, that we're penalizing!

    Args:
        f: callable mapping a ``(batch, features)`` tensor to critic scores.
        real: 2-D tensor of real rows with a static shape; the batch size must
            be a multiple of ``pac``.
        fake: tensor of generated rows, same shape as ``real``.
        pac: number of consecutive rows sharing one interpolation coefficient
            and one gradient norm (default 10, the value previously hard-coded).
        grad_penalty_lambda: penalty weight (default 10, previously hard-coded).

    Returns:
        Scalar tensor ``grad_penalty_lambda * mean((||grad||_2 - 1) ** 2)``.
    """
    n_features = real.shape[1]

    # One coefficient per group of `pac` rows, repeated over the rows and
    # features of that group.
    alpha = tf.random.uniform([real.shape[0] // pac, 1, 1], 0., 1.)
    alpha = tf.tile(alpha, tf.constant([1, pac, n_features], tf.int32))
    alpha = tf.reshape(alpha, [-1, n_features])

    interpolates = alpha * real + ((1 - alpha) * fake)
    # `interpolates` is not a variable, so it must be watched explicitly to
    # differentiate the critic output with respect to it.
    with tf.GradientTape() as tape:
        tape.watch(interpolates)
        pred = f(interpolates)
    grad = tape.gradient(pred, [interpolates])[0]
    grad = tf.reshape(grad, tf.constant([-1, pac * n_features], tf.int32))

    slopes = tf.math.reduce_euclidean_norm(grad, axis=1)
    gp = tf.reduce_mean((slopes - 1.) ** 2) * grad_penalty_lambda
    return gp

In [185]:
# Persistent tape: later cells take several gradients from it (of the loss, of
# the penalty and of the total cost).
with tf.GradientTape(persistent=True) as t:
    tf_y_fake = critic(tf_fake_cat, training=True)
    tf_y_real = critic(tf_real_cat, training=True)
    # Critic loss: mean score on fake rows minus mean score on real rows
    # (same expression as the torch `t_loss_d`).
    tf_loss_d = -(tf.reduce_mean(tf_y_real) - tf.reduce_mean(tf_y_fake))
    # Computed inside the outer tape so the penalty's dependence on the critic
    # weights is recorded and can be differentiated afterwards.
    tf_gp = gradient_penalty(partial(critic, training=True), tf_real_cat, tf_fake_cat)
    cost = tf_loss_d + tf_gp

In [186]:
tf_loss_d

<tf.Tensor: shape=(), dtype=float32, numpy=-0.021227747>

In [187]:
t.gradient(tf_loss_d, critic.trainable_variables)

[<tf.Tensor: shape=(261, 256), dtype=float32, numpy=
 array([[-3.29672548e-05, -1.29574779e-04,  3.24495923e-04, ...,
         -1.52602792e-04,  1.86732563e-04, -2.16088883e-04],
        [-4.19651158e-04, -2.27835990e-04,  3.26758425e-04, ...,
          3.68591165e-04,  1.62908720e-04,  2.31102313e-04],
        [-7.99217494e-04, -1.49905507e-03,  1.24002690e-03, ...,
         -1.01956772e-03,  5.78450388e-04, -3.72785435e-05],
        ...,
        [-9.36291690e-06, -1.81916221e-05,  5.20520698e-05, ...,
         -3.64308426e-06, -1.98896741e-05, -1.37712623e-05],
        [-3.44913104e-04, -2.25866417e-04,  2.23320327e-04, ...,
         -2.72941252e-05,  2.56315863e-04,  3.74114403e-04],
        [-3.16496531e-04, -1.12976995e-04,  9.18747101e-05, ...,
          3.82056838e-04, -9.91916168e-05, -7.11862231e-06]], dtype=float32)>,
 <tf.Tensor: shape=(256,), dtype=float32, numpy=
 array([-3.71405412e-03, -3.42860166e-03,  1.59716327e-03, -2.90540047e-03,
         6.31667324e-04, -3.1063675

In [188]:
tf_gp

<tf.Tensor: shape=(), dtype=float32, numpy=4.42831>

In [194]:
t.gradient(tf_gp, critic.trainable_variables)

[<tf.Tensor: shape=(261, 256), dtype=float32, numpy=
 array([[-0.00841035, -0.00779258,  0.00773105, ..., -0.00562495,
          0.00235448, -0.00294606],
        [ 0.00519865,  0.00651497, -0.00354741, ...,  0.00722196,
         -0.00170768,  0.00392825],
        [-0.02508334, -0.02604486,  0.02198639, ..., -0.03203251,
          0.01179206, -0.01574819],
        ...,
        [ 0.01259976,  0.01412678, -0.01289054, ...,  0.02059377,
         -0.00830576,  0.00968746],
        [-0.00983857, -0.00954202,  0.00779626, ..., -0.01214412,
          0.0045221 , -0.00656762],
        [ 0.00544882,  0.00474489, -0.00237931, ...,  0.00221722,
         -0.00091355,  0.00162516]], dtype=float32)>,
 <tf.Tensor: shape=(256,), dtype=float32, numpy=
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0.,

In [195]:
t.gradient(cost, critic.trainable_variables)

[<tf.Tensor: shape=(261, 256), dtype=float32, numpy=
 array([[-0.00844331, -0.00792216,  0.00805554, ..., -0.00577756,
          0.00254121, -0.00316215],
        [ 0.004779  ,  0.00628714, -0.00322066, ...,  0.00759055,
         -0.00154477,  0.00415935],
        [-0.02588256, -0.02754391,  0.02322641, ..., -0.03305208,
          0.01237051, -0.01578547],
        ...,
        [ 0.0125904 ,  0.01410859, -0.01283849, ...,  0.02059012,
         -0.00832565,  0.00967369],
        [-0.01018348, -0.00976788,  0.00801958, ..., -0.01217141,
          0.00477842, -0.0061935 ],
        [ 0.00513233,  0.00463192, -0.00228744, ...,  0.00259927,
         -0.00101274,  0.00161804]], dtype=float32)>,
 <tf.Tensor: shape=(256,), dtype=float32, numpy=
 array([-3.71405412e-03, -3.42860166e-03,  1.59716327e-03, -2.90540047e-03,
         6.31667324e-04, -3.10636754e-03, -1.26765808e-05,  3.71579954e-04,
        -7.23355450e-04, -2.40026740e-04, -9.82492696e-04,  4.18418413e-03,
        -4.79119830e-03, -1

In [205]:
c_opt.apply_gradients(zip(t.gradient(cost, critic.trainable_variables), critic.trainable_variables))

<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=1>

# Compare critic layer weights after one optimizer step

In [207]:
critic.layers[1].get_weights()

[array([[ 0.0265521 ,  0.05120495, -0.02169383, ..., -0.04435714,
         -0.03071266, -0.05264142],
        [-0.03998381, -0.01836094,  0.00707057, ..., -0.02528278,
         -0.05359489, -0.03072562],
        [-0.04267072,  0.06065835,  0.0498224 , ...,  0.05786606,
          0.03843578,  0.02996124],
        ...,
        [ 0.05995943, -0.01746897, -0.00472673, ..., -0.05517949,
         -0.03008169, -0.02729477],
        [ 0.05577808, -0.03644891,  0.02254241, ..., -0.03830523,
         -0.03216379,  0.02538869],
        [ 0.01388281, -0.04009137, -0.03755167, ..., -0.00310175,
         -0.03693753,  0.06073611]], dtype=float32),
 array([ 0.0214271 , -0.00978368, -0.0285322 , -0.01877919,  0.04006867,
         0.03189117,  0.05600672,  0.00776461, -0.0206714 ,  0.03618724,
         0.00723118, -0.01371707,  0.04144338,  0.06086565, -0.00680179,
        -0.04592123, -0.0556839 ,  0.06013308,  0.03362063,  0.06118033,
         0.06089877,  0.0034836 , -0.03124236, -0.05797988,  0.033

In [208]:
t_critic.seq[0].weight

Parameter containing:
tensor([[ 0.0266, -0.0400, -0.0427,  ...,  0.0600,  0.0558,  0.0139],
        [ 0.0512, -0.0184,  0.0607,  ..., -0.0175, -0.0364, -0.0401],
        [-0.0217,  0.0071,  0.0498,  ..., -0.0047,  0.0225, -0.0376],
        ...,
        [-0.0444, -0.0253,  0.0579,  ..., -0.0552, -0.0383, -0.0031],
        [-0.0307, -0.0536,  0.0384,  ..., -0.0301, -0.0322, -0.0369],
        [-0.0526, -0.0307,  0.0300,  ..., -0.0273,  0.0254,  0.0607]],
       requires_grad=True)

# Train Generator TF

In [212]:
# Fresh latent batch for the generator step: drawn once with NumPy and shared
# by both frameworks so their generators are compared on identical noise.
# (The earlier per-framework draws via torch.normal / tf.random.normal were
# overwritten immediately and have been removed as dead code.)
fk = np.random.normal(size=(batch_size, z_dim)).astype(np.float32)
t_fakez = torch.from_numpy(fk)
tf_fakez = tf.convert_to_tensor(fk)

In [213]:
# New conditional vectors for the generator step. Unlike the critic step, no
# real batch is sampled here.
condvec = cond_generator.sample(batch_size)
if condvec is None:
    c1, m1, col, opt = None, None, None, None
else:
    # c1 is appended to the latent noise of both frameworks; m1 is the mask
    # passed to the conditional loss.
    c1, m1, col, opt = condvec
    c1_tf = tf.convert_to_tensor(c1)
    m1_tf = tf.convert_to_tensor(m1)
    tf_fakez = tf.concat([tf_fakez, c1_tf], axis=1)
    
    c1_t = torch.from_numpy(c1)
    m1_t = torch.from_numpy(m1)
    t_fakez = torch.cat([t_fakez, c1_t], dim=1)

In [244]:
from ctgan.losses import _cond_loss
# TF generator forward pass and loss, recorded on a persistent tape so the
# gradient can be taken in a later cell.
with tf.GradientTape(persistent=True) as tape:
    tf_fake = generator(tf_fakez, training=True)
    tf_fakeact = _apply_activate(tf_fake, transformer.output_info)
    # The critic is only scoring here (training=False); it always receives the
    # activated output, with the conditional vector appended when present.
    if c1 is not None:
        tf_y_fake = critic(tf.concat([tf_fakeact, c1_tf], axis=1), training=False)
    else:
        tf_y_fake = critic(tf_fakeact, training=False)
        
    # The conditional loss is computed on the raw (pre-activation) generator
    # output, like the torch `t_cond_loss` call.
    # NOTE(review): this branch tests `condvec is None` while the one above
    # tests `c1 is not None`; both are set together in the sampling cell.
    if condvec is None:
        tf_cross_entropy = 0
    else:
        tf_cross_entropy = _cond_loss(transformer.output_info_tensor(), tf_fake, c1_tf, m1_tf)
        
    # Generator loss: negated mean critic score plus the conditional loss.
    tf_loss_g = -tf.reduce_mean(tf_y_fake) + tf_cross_entropy

In [246]:
tf_cross_entropy

<tf.Tensor: shape=(), dtype=float32, numpy=2.1578503>

In [245]:
tf_loss_g

<tf.Tensor: shape=(), dtype=float32, numpy=2.1228223>

In [247]:
tf_grad_g = tape.gradient(tf_loss_g, generator.trainable_variables)

In [248]:
tf_grad_g

[<tf.Tensor: shape=(232, 256), dtype=float32, numpy=
 array([[ 2.7526710e-03,  3.0532721e-04, -2.5625275e-03, ...,
          7.7317818e-04,  1.0328727e-03,  4.2899646e-04],
        [ 2.9443530e-03, -2.0082088e-03, -3.3158513e-03, ...,
          2.4447355e-03,  1.4973036e-04, -3.2807905e-03],
        [-2.2296100e-03, -1.2234291e-03, -2.3494317e-04, ...,
         -1.4441580e-03,  2.8443704e-03,  2.2263867e-03],
        ...,
        [-1.8350031e-04,  2.0419480e-05, -6.8647489e-05, ...,
         -4.6593086e-06,  4.2493644e-05,  1.3123547e-06],
        [ 1.2246779e-03, -1.6767414e-04,  1.4468404e-03, ...,
          2.9827594e-03, -3.7376164e-04,  1.5908631e-03],
        [-1.8611145e-03,  4.2059371e-05, -2.0337000e-03, ...,
         -2.4531656e-03,  1.8522781e-04, -6.5414433e-04]], dtype=float32)>,
 <tf.Tensor: shape=(256,), dtype=float32, numpy=
 array([-2.2118911e-09, -2.3283064e-10, -1.9790605e-09, -1.1641532e-10,
         8.1490725e-10, -3.2014214e-10, -6.4028427e-10,  2.3283064e-10,
   

In [249]:
tf_grad_g[0].numpy().T

array([[ 2.7526710e-03,  2.9443530e-03, -2.2296100e-03, ...,
        -1.8350031e-04,  1.2246779e-03, -1.8611145e-03],
       [ 3.0532721e-04, -2.0082088e-03, -1.2234291e-03, ...,
         2.0419480e-05, -1.6767414e-04,  4.2059371e-05],
       [-2.5625275e-03, -3.3158513e-03, -2.3494317e-04, ...,
        -6.8647489e-05,  1.4468404e-03, -2.0337000e-03],
       ...,
       [ 7.7317818e-04,  2.4447355e-03, -1.4441580e-03, ...,
        -4.6593086e-06,  2.9827594e-03, -2.4531656e-03],
       [ 1.0328727e-03,  1.4973036e-04,  2.8443704e-03, ...,
         4.2493644e-05, -3.7376164e-04,  1.8522781e-04],
       [ 4.2899646e-04, -3.2807905e-03,  2.2263867e-03, ...,
         1.3123547e-06,  1.5908631e-03, -6.5414433e-04]], dtype=float32)

In [242]:
t_gen.seq[0].fc.weight.grad

tensor([[ 2.7463e-03,  2.9144e-03, -2.2822e-03,  ..., -1.9383e-04,
          1.2420e-03, -1.8658e-03],
        [ 4.2485e-04, -2.1075e-03, -1.3197e-03,  ...,  2.0667e-05,
         -1.6423e-04,  4.9232e-05],
        [-2.6086e-03, -3.3572e-03,  1.5500e-05,  ..., -6.9507e-05,
          1.4405e-03, -2.0404e-03],
        ...,
        [ 8.3747e-04,  2.3871e-03, -1.5011e-03,  ..., -4.7911e-06,
          2.9972e-03, -2.4614e-03],
        [ 1.0989e-03,  1.2560e-04,  2.8820e-03,  ...,  4.8489e-05,
         -3.8367e-04,  2.0448e-04],
        [ 4.5667e-04, -3.3713e-03,  2.1251e-03,  ...,  1.2318e-06,
          1.5836e-03, -6.3234e-04]])

In [256]:
tf.reduce_mean(tf_grad_g[-1].numpy())

<tf.Tensor: shape=(), dtype=float32, numpy=-5.2918866e-05>

In [257]:
torch.mean(t_gen.seq[2].bias.grad)

tensor(-5.1303e-05)

In [262]:
t_gen.seq[1].bn.running_mean

tensor([ 0.0572, -0.0195,  0.0021, -0.0207,  0.0047,  0.0383, -0.0073,  0.0091,
        -0.0534, -0.0418, -0.0529,  0.0559,  0.0307,  0.0389,  0.0084,  0.0185,
         0.0076,  0.0004,  0.0176, -0.0264, -0.0187,  0.0013,  0.0128, -0.0228,
        -0.0268, -0.0414,  0.0320,  0.0073, -0.0389,  0.0020,  0.0042,  0.0085,
        -0.0025, -0.0163, -0.0196,  0.0188, -0.0171, -0.0374,  0.0420,  0.0213,
        -0.0185,  0.0631, -0.0323,  0.0017,  0.0182, -0.0362, -0.0069,  0.0216,
        -0.0117, -0.0068, -0.0112, -0.0246,  0.0299,  0.0328, -0.0216,  0.0450,
         0.0162,  0.0464, -0.0400, -0.0054,  0.0180,  0.0008,  0.0793,  0.0223,
        -0.0065, -0.0437, -0.0607,  0.0368, -0.0533,  0.0031, -0.0015, -0.0403,
         0.0055, -0.0122, -0.0079, -0.0009, -0.0542, -0.0165,  0.0060, -0.0072,
        -0.0415,  0.0765, -0.0266, -0.0152, -0.0230, -0.0143,  0.0374, -0.0509,
        -0.0105, -0.0247,  0.0037,  0.0547, -0.0265,  0.0300, -0.0491,  0.0452,
        -0.0181, -0.0162, -0.0062,  0.02

In [255]:
tf_grad_g

[<tf.Tensor: shape=(232, 256), dtype=float32, numpy=
 array([[ 2.7526710e-03,  3.0532721e-04, -2.5625275e-03, ...,
          7.7317818e-04,  1.0328727e-03,  4.2899646e-04],
        [ 2.9443530e-03, -2.0082088e-03, -3.3158513e-03, ...,
          2.4447355e-03,  1.4973036e-04, -3.2807905e-03],
        [-2.2296100e-03, -1.2234291e-03, -2.3494317e-04, ...,
         -1.4441580e-03,  2.8443704e-03,  2.2263867e-03],
        ...,
        [-1.8350031e-04,  2.0419480e-05, -6.8647489e-05, ...,
         -4.6593086e-06,  4.2493644e-05,  1.3123547e-06],
        [ 1.2246779e-03, -1.6767414e-04,  1.4468404e-03, ...,
          2.9827594e-03, -3.7376164e-04,  1.5908631e-03],
        [-1.8611145e-03,  4.2059371e-05, -2.0337000e-03, ...,
         -2.4531656e-03,  1.8522781e-04, -6.5414433e-04]], dtype=float32)>,
 <tf.Tensor: shape=(256,), dtype=float32, numpy=
 array([-2.2118911e-09, -2.3283064e-10, -1.9790605e-09, -1.1641532e-10,
         8.1490725e-10, -3.2014214e-10, -6.4028427e-10,  2.3283064e-10,
   

In [265]:
g_opt.apply_gradients(zip(tf_grad_g, generator.trainable_variables))

<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=1>

# Train Generator Torch

In [214]:
t_fake = t_gen(t_fakez)

In [215]:
from torch.nn import functional
def torch_aa(data, transformer_info):
    """Apply the per-block output activations to raw generator outputs.

    This re-defines, with the same behavior, the function of the same name
    declared earlier in the notebook.

    Args:
        data: 2-D tensor whose columns are laid out block by block in the
            order given by ``transformer_info``.
        transformer_info: sequence of items where ``item[0]`` is the block
            width and ``item[1]`` is ``'tanh'`` or ``'softmax'``.

    Returns:
        Concatenation of the activated blocks: ``torch.tanh`` for ``tanh``
        blocks, a Gumbel-softmax sample with ``tau=0.2`` for ``softmax``
        blocks (random, so results vary between calls).

    Raises:
        AssertionError: for any other activation name.
    """
    activated = []
    offset = 0
    for item in transformer_info:
        kind = item[1]
        if kind not in ('tanh', 'softmax'):
            assert 0
            continue  # only reachable when assertions are disabled
        end = offset + item[0]
        block = data[:, offset:end]
        if kind == 'tanh':
            activated.append(torch.tanh(block))
        else:
            activated.append(functional.gumbel_softmax(block, tau=0.2))
        offset = end

    return torch.cat(activated, dim=1)
t_fakeact = torch_aa(t_fake, transformer.output_info)

In [229]:
if c1 is not None:
    t_y_fake = t_critic(torch.cat([t_fakeact, c1_t], dim=1))
else:
    t_y_fake = t_critic(t_fakeact)

In [235]:
def t_cond_loss(transformer_info, data, c, m):
    """Masked cross-entropy between generator outputs and the conditional vector.

    Walks the blocks described by ``transformer_info``. A ``tanh`` block is
    skipped, and so is the ``softmax`` block that directly follows it
    (presumably the mode indicator of a continuous column — confirm). Every
    other ``softmax`` block is compared with the matching slice of ``c``: the
    target class is the arg-max of that slice.

    Args:
        transformer_info: sequence of items with ``item[0]`` = block width and
            ``item[1]`` = ``'tanh'`` or ``'softmax'``.
        data: 2-D tensor of raw generator outputs laid out block by block.
        c: 2-D conditional vector covering only the non-skipped softmax blocks.
        m: mask with one column per non-skipped softmax block.

    Returns:
        Scalar tensor: sum of the masked per-row losses divided by the batch size.

    Raises:
        AssertionError: for an activation name other than 'tanh'/'softmax'.
    """
    per_column = []
    data_pos = 0
    cond_pos = 0
    follows_tanh = False
    for item in transformer_info:
        kind = item[1]
        if kind == 'tanh':
            data_pos += item[0]
            follows_tanh = True
            continue
        if kind != 'softmax':
            assert 0
            continue  # only reachable when assertions are disabled

        width = item[0]
        if follows_tanh:
            # softmax block attached to the preceding tanh block: not conditioned on
            follows_tanh = False
            data_pos += width
            continue

        data_end = data_pos + width
        cond_end = cond_pos + width
        target = torch.argmax(c[:, cond_pos:cond_end], dim=1)
        per_column.append(functional.cross_entropy(
            data[:, data_pos:data_end], target, reduction='none'))
        data_pos, cond_pos = data_end, cond_end

    stacked = torch.stack(per_column, dim=1)
    return (stacked * m).sum() / data.size()[0]

In [236]:
if condvec is None:
    t_cross_entropy = 0
else:
    t_cross_entropy = t_cond_loss(transformer.output_info, t_fake, c1_t, m1_t)
t_cross_entropy

tensor(2.1579, grad_fn=<DivBackward0>)

In [238]:
t_loss_g = -torch.mean(t_y_fake) + t_cross_entropy
t_loss_g

tensor(2.1228, grad_fn=<AddBackward0>)

In [239]:
optimizerG.zero_grad()

In [240]:
t_gen.seq[0].fc.weight

Parameter containing:
tensor([[ 0.0408, -0.0033, -0.0356,  ..., -0.0503, -0.0280,  0.0357],
        [ 0.0062,  0.0638,  0.0546,  ...,  0.0050,  0.0655, -0.0550],
        [-0.0574,  0.0120,  0.0102,  ...,  0.0218, -0.0389,  0.0038],
        ...,
        [ 0.0208, -0.0526, -0.0477,  ...,  0.0300,  0.0052, -0.0034],
        [ 0.0180,  0.0642,  0.0654,  ...,  0.0621,  0.0548,  0.0408],
        [-0.0390, -0.0436, -0.0042,  ...,  0.0448,  0.0636, -0.0119]],
       requires_grad=True)

In [241]:
t_loss_g.backward()

In [242]:
t_gen.seq[0].fc.weight.grad

tensor([[ 2.7463e-03,  2.9144e-03, -2.2822e-03,  ..., -1.9383e-04,
          1.2420e-03, -1.8658e-03],
        [ 4.2485e-04, -2.1075e-03, -1.3197e-03,  ...,  2.0667e-05,
         -1.6423e-04,  4.9232e-05],
        [-2.6086e-03, -3.3572e-03,  1.5500e-05,  ..., -6.9507e-05,
          1.4405e-03, -2.0404e-03],
        ...,
        [ 8.3747e-04,  2.3871e-03, -1.5011e-03,  ..., -4.7911e-06,
          2.9972e-03, -2.4614e-03],
        [ 1.0989e-03,  1.2560e-04,  2.8820e-03,  ...,  4.8489e-05,
         -3.8367e-04,  2.0448e-04],
        [ 4.5667e-04, -3.3713e-03,  2.1251e-03,  ...,  1.2318e-06,
          1.5836e-03, -6.3234e-04]])

In [263]:
optimizerG.step()