In [1]:
import tensorflow as tf
import torch
import numpy as np
import pandas as pd
import tensorflow_probability as tfp
import matplotlib.pyplot as plt

In [2]:
from ctgan.transformer import DataTransformer
from ctgan.sampler import Sampler
from ctgan.conditional import ConditionalGenerator

In [3]:
DEMO_URL = 'http://ctgan-data.s3.amazonaws.com/census.csv.gz'
train_data = pd.read_csv(DEMO_URL, compression='gzip')
discrete_columns = [
    'workclass',
    'education',
    'marital-status',
    'occupation',
    'relationship',
    'race',
    'sex',
    'native-country',
    'income'
]

In [4]:
transformer = DataTransformer()
transformer.fit(train_data, discrete_columns)
train_data = transformer.transform(train_data)



In [5]:
data_sampler = Sampler(train_data, transformer.output_info)
data_dim = transformer.output_dimensions
cond_generator = ConditionalGenerator(
    train_data, transformer.output_info, True)

In [6]:
128 + cond_generator.n_opt

232

In [7]:
l2scale = 1e-6

In [8]:
class GenActLayer(tf.keras.layers.Layer):
    def __init__(self, input_dim, num_outputs, transformer_info, tau):
        super(GenActLayer, self).__init__()
        self.num_outputs = num_outputs
        self.transformer_info = transformer_info
        self.tau = tau
        self.fc = tf.keras.layers.Dense(
            num_outputs, input_dim=(input_dim,),
            kernel_initializer=partial(init_bounded, dim=input_dim),
            bias_initializer=partial(init_bounded, dim=input_dim))

    def call(self, inputs, **kwargs):
        outputs = self.fc(inputs, **kwargs)
        shape = outputs.shape
        #print(shape)
        tf.print(shape)
        data_t = tf.zeros(shape)
        #print(self.transformer_info[0])
        x = tf.constant([[2, 3]], dtype=tf.int32)
        #print(x)
        for idx in self.transformer_info[0]:
            #r = tf.range(x[0], idx[1])
            #print(r)
            #print(tf.gather(data_t, idx, axis=1))
            #print("Gather:", tf.gather(data_t, idx, axis=1))
            #print("Idx:", idx)
            #print("Indexing:", data_t[:, :idx[0]])
            tf.print(data_t)
            act = tf.where(idx[5] == 0, 
                           tf.math.tanh(outputs[:,idx[0]:idx[1]]), 
                           self._gumbel_softmax(outputs[:, idx[0]:idx[1]], tau=self.tau))
            #print("act:", act)
            data_t = tf.concat([data_t[:,:idx[0]],
                                act,
                                data_t[:,idx[1]:]],
                               axis=1)
            tf.print(data_t)
            
        #print(data_t)
        return outputs, data_t
    
    @tf.function
    def _activation(self, data_info, data):
        return tf.where(data_info[5] == 0,
                  tf.math.tanh(data[:, data_info[0]:data_info[1]]),
                  self._gumbel_softmax(data[:, data_info[0]:data_info[1]], tau=self.tau))

    @tf.function
    def _gumbel_softmax(self, logits, tau=1.0, hard=False, dim=-1):
        r"""
        Samples from the Gumbel-Softmax distribution (`Link 1`_  `Link 2`_) and optionally discretizes.

        Args:
          logits: `[..., num_features]` unnormalized log probabilities
          tau: non-negative scalar temperature
          hard: if ``True``, the returned samples will be discretized as one-hot vectors,
                but will be differentiated as if it is the soft sample in autograd
          dim (int): A dimension along which softmax will be computed. Default: -1.

        Returns:
          Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
          If ``hard=True``, the returned samples will be one-hot, otherwise they will
          be probability distributions that sum to 1 across `dim`.

        .. note::
          The main trick for `hard` is to do  `y_hard - y_soft.detach() + y_soft`

          It achieves two things:
          - makes the output value exactly one-hot
          (since we add then subtract y_soft value)
          - makes the gradient equal to y_soft gradient
          (since we strip all other gradients)

        .. _Link 1:
            https://arxiv.org/abs/1611.00712
        .. _Link 2:
            https://arxiv.org/abs/1611.01144
        """

        gumbel_dist = tfp.distributions.Gumbel(loc=0, scale=1)
        gumbels = gumbel_dist.sample(tf.shape(logits))
        gumbels = (logits + gumbels) / tau
        y = tf.nn.softmax(gumbels, dim)

        if hard:
            # Straight through.
            index = tf.math.reduce_max(y, 1, keep_dims=True)
            y_hard = tf.cast(tf.equal(y, index), y.dtype)
            y = tf.stop_gradient(y_hard - y) + y
        return y

In [11]:
class Generator(tf.keras.Model):
    def __init__(self, input_dim, gen_dims, data_dim, transformer_info, tau):
        super(Generator, self).__init__()

        self.input_dim = input_dim
        self.model = list()
        dim = input_dim
        for layer_dim in list(gen_dims):
            self.model += [ResidualLayer(dim, layer_dim)]
            dim += layer_dim

        #self.model += [tf.keras.layers.Dense(
        #    data_dim, input_dim=(dim,),
        #    kernel_initializer=partial(init_bounded, dim=dim),
        #    bias_initializer=partial(init_bounded, dim=dim))]
        self.model += [GenActLayer(dim, data_dim, transformer_info, tau)]

    def call(self, x, **kwargs):
        out = x
        for layer in self.model:
            out = layer(out, **kwargs)
        return out

In [8]:
batch_size = 500
embedding_dim = z_dim = 128
mean = torch.zeros(batch_size, embedding_dim)
std = mean + 1

In [9]:
from ctgan.layers import ResidualLayer, init_bounded
from ctgan.models import Critic, Generator
from functools import partial
generator = Generator(
    128 + cond_generator.n_opt, (256,256), data_dim, transformer.output_info_tensor(), 0.2)

critic = Critic(data_dim + cond_generator.n_opt, (256,256), 10)

In [10]:
generator.build((batch_size, generator.input_dim))
critic.build((batch_size, critic.input_dim))



In [56]:
generator.layers

[<ctgan.layers.ResidualLayer at 0x7fc866c3dd90>,
 <ctgan.layers.ResidualLayer at 0x7fc866dfe6d0>,
 <__main__.GenActLayer at 0x7fc866c3de90>]

In [29]:
generator.layers[0]._layers[1].trainable_variables

[<tf.Variable 'residual_layer_2/batch_normalization_2/gamma:0' shape=(256,) dtype=float32, numpy=
 array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.

In [57]:
generator.trainable_variables

[<tf.Variable 'residual_layer_16/dense_48/kernel:0' shape=(232, 256) dtype=float32, numpy=
 array([[-0.05475991, -0.01095014,  0.05805656, ...,  0.02841506,
         -0.03025322, -0.00633234],
        [ 0.01443698,  0.01880741, -0.00380489, ..., -0.03929146,
         -0.00306889,  0.0562725 ],
        [-0.01052445, -0.05357137, -0.06477886, ...,  0.01439786,
          0.04665996, -0.06051522],
        ...,
        [-0.01289388, -0.04914027,  0.05129219, ..., -0.05793067,
          0.03870263, -0.00022382],
        [-0.04000986, -0.03895237, -0.04620291, ..., -0.00065832,
          0.03954718,  0.00330844],
        [ 0.01802916, -0.05134067,  0.0075512 , ...,  0.045027  ,
         -0.04505192,  0.025952  ]], dtype=float32)>,
 <tf.Variable 'residual_layer_16/dense_48/bias:0' shape=(256,) dtype=float32, numpy=
 array([-4.7047235e-02,  4.9967319e-04, -3.6795586e-02, -2.8126828e-02,
         2.6086263e-02,  4.8789091e-02, -1.2748964e-02,  3.8183346e-02,
        -3.9298117e-02, -7.3188022e-0

In [30]:
[np.mean(g) for g in generator.trainable_variables]

AttributeError: 'DType' object has no attribute 'type'

In [11]:
batch_size = 500
embedding_dim = z_dim = 128
mean = torch.zeros(batch_size, embedding_dim)
std = mean + 1

# Copy weights to TF model

In [21]:
critic.layers[1].set_weights([t_critic.seq[0].weight.detach().numpy().T, t_critic.seq[0].bias.detach().numpy()])
critic.layers[3].set_weights([t_critic.seq[2].weight.detach().numpy().T, t_critic.seq[2].bias.detach().numpy()])
critic.layers[5].set_weights([t_critic.seq[4].weight.detach().numpy().T, t_critic.seq[4].bias.detach().numpy()])

In [22]:
generator.layers[1]._layers[0].set_weights([t_gen.seq[0].fc.weight.detach().numpy().T, t_gen.seq[0].fc.bias.detach().numpy()])
generator.layers[2]._layers[0].set_weights([t_gen.seq[1].fc.weight.detach().numpy().T, t_gen.seq[1].fc.bias.detach().numpy()])
generator.layers[3].set_weights([t_gen.seq[2].weight.detach().numpy().T, t_gen.seq[2].bias.detach().numpy()])

# Train models

In [12]:
t_fakez = torch.normal(mean=mean, std=std)
tf_fakez = tf.random.normal([batch_size, z_dim])

fk = np.random.normal(size=(batch_size, z_dim)).astype(np.float32)
t_fakez = torch.from_numpy(fk)
tf_fakez = tf.convert_to_tensor(fk)

In [13]:
condvec = cond_generator.sample(batch_size)
if condvec is None:
    c1, m1, col, opt = None, None, None, None
    real = data_sampler.sample(batch_size, col, opt)
else:
    c1, m1, col, opt = condvec
    c1_tf = tf.convert_to_tensor(c1)
    m1_tf = tf.convert_to_tensor(m1)
    tf_fakez = tf.concat([tf_fakez, c1_tf], axis=1)
    
    c1_t = torch.from_numpy(c1)
    m1_t = torch.from_numpy(m1)
    t_fakez = torch.cat([t_fakez, c1_t], dim=1)

    perm = np.arange(batch_size)
    np.random.shuffle(perm)
    real = data_sampler.sample(batch_size, col[perm], opt[perm])
    tf_c2 = tf.gather(c1_tf, perm)
    t_c2 = c1_t[perm]

In [14]:
tf_fake = generator(tf_fakez, training=True)



In [26]:
t_fake = t_gen(t_fakez)

In [27]:
from ctgan.layers import _apply_activate
tf_fakeact = _apply_activate(tf_fake, transformer.output_info)

In [28]:
from torch.nn import functional
def torch_aa(data, transformer_info):
    data_t = []
    st = 0
    for item in transformer_info:
        if item[1] == 'tanh':
            ed = st + item[0]
            data_t.append(torch.tanh(data[:, st:ed]))
            st = ed
        elif item[1] == 'softmax':
            ed = st + item[0]
            data_t.append(functional.gumbel_softmax(data[:, st:ed], tau=0.2))
            st = ed
        else:
            assert 0

    return torch.cat(data_t, dim=1)
t_fakeact = torch_aa(t_fake, transformer.output_info)

In [29]:
# Copy fakeact tensor
tf_fakeact = tf.convert_to_tensor(t_fakeact.detach().numpy())
tf_real = tf.convert_to_tensor(real.astype('float32'))

In [30]:
tf_real = tf.convert_to_tensor(real.astype('float32'))
t_real = torch.from_numpy(real.astype('float32'))

if c1 is not None:
    tf_fake_cat = tf.concat([tf_fakeact, c1_tf], axis=1)
    tf_real_cat = tf.concat([tf_real, tf_c2], axis=1)
    
    t_fake_cat = torch.cat([t_fakeact, c1_t], dim=1)
    t_real_cat = torch.cat([t_real, t_c2], dim=1)
else:
    tf_real_cat = tf_real
    tf_fake_cat = tf_fake
    
    t_real_cat = t_real
    t_fake_cat = t_fake

In [19]:
@tf.function
def cond_loss(transformer_info, data, c, m):
    loss = []

    loss = tf.zeros(tf.shape(m))
    s = tf.shape(m)
    print("init loss:", loss)
    #i = tf.constant(0, dtype=tf.int32)

    for item in transformer_info:
        print(item)
        if item[4] == 0:
        #st, ed, st_c, ed_c, is_continuous, is_softmax = item
        #if is_continuous == 0 and is_softmax == 1:
            data_logsoftmax = data[:, item[0]:item[1]]
            c_argmax = tf.math.argmax(c[:, item[2]:item[3]], axis=1)
            l = tf.reshape(tf.nn.sparse_softmax_cross_entropy_with_logits(c_argmax, data_logsoftmax), [-1, 1])
            print("l:", l)
            print("loss slice:", loss[:,:item[-1]])
            loss = tf.concat([loss[:, :item[-1]], l, loss[:, item[-1]+1:]], axis=1)
            print("loss:", loss)
            print()
            #i = i + 1

    #loss = tf.stack(loss, axis=1)
    print(loss)
    return tf.reduce_sum(loss * m) / tf.cast(tf.shape(data)[0], dtype=tf.float32)

In [20]:
transformer.cond_info_tensor()

[<tf.Tensor: shape=(5,), dtype=int32, numpy=array([0, 9, 0, 9, 0], dtype=int32)>,
 <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 9, 25,  9, 25,  1], dtype=int32)>,
 <tf.Tensor: shape=(5,), dtype=int32, numpy=array([25, 32, 25, 32,  2], dtype=int32)>,
 <tf.Tensor: shape=(5,), dtype=int32, numpy=array([32, 47, 32, 47,  3], dtype=int32)>,
 <tf.Tensor: shape=(5,), dtype=int32, numpy=array([47, 53, 47, 53,  4], dtype=int32)>,
 <tf.Tensor: shape=(5,), dtype=int32, numpy=array([53, 58, 53, 58,  5], dtype=int32)>,
 <tf.Tensor: shape=(5,), dtype=int32, numpy=array([58, 60, 58, 60,  6], dtype=int32)>,
 <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 60, 102,  60, 102,   7], dtype=int32)>,
 <tf.Tensor: shape=(5,), dtype=int32, numpy=array([102, 104, 102, 104,   8], dtype=int32)>]

In [21]:
#ti = ti, tf.constant(9)
cond_loss = cond_loss(transformer.cond_info_tensor(), tf_fake[1], c1_tf, m1_tf)
cond_loss

init loss: Tensor("zeros:0", shape=(500, 9), dtype=float32)
Tensor("transformer_info:0", shape=(5,), dtype=int32)
l: Tensor("Reshape:0", shape=(500, 1), dtype=float32)
loss slice: Tensor("strided_slice_7:0", shape=(500, None), dtype=float32)
loss: Tensor("concat:0", shape=(500, None), dtype=float32)

Tensor("transformer_info_1:0", shape=(5,), dtype=int32)
l: Tensor("Reshape:0", shape=(500, 1), dtype=float32)
loss slice: Tensor("strided_slice_7:0", shape=(500, None), dtype=float32)
loss: Tensor("concat:0", shape=(500, None), dtype=float32)

Tensor("transformer_info_2:0", shape=(5,), dtype=int32)
l: Tensor("Reshape:0", shape=(500, 1), dtype=float32)
loss slice: Tensor("strided_slice_7:0", shape=(500, None), dtype=float32)
loss: Tensor("concat:0", shape=(500, None), dtype=float32)

Tensor("transformer_info_3:0", shape=(5,), dtype=int32)
l: Tensor("Reshape:0", shape=(500, 1), dtype=float32)
loss slice: Tensor("strided_slice_7:0", shape=(500, None), dtype=float32)
loss: Tensor("concat:0", s

<tf.Tensor: shape=(), dtype=float32, numpy=0.25683954>

In [108]:
r = tf.range(9)

In [109]:
r

<tf.Tensor: shape=(9,), dtype=int32, numpy=array([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=int32)>

In [110]:
r[0]

<tf.Tensor: shape=(), dtype=int32, numpy=0>

In [111]:
i = tf.constant(0, dtype=tf.int32)

In [112]:
i

<tf.Tensor: shape=(), dtype=int32, numpy=0>

In [22]:
@tf.function
def test_random():
    return tf.random.normal([10, 5])

In [23]:
x = test_random()

In [24]:
x

<tf.Tensor: shape=(10, 5), dtype=float32, numpy=
array([[ 1.7181406 ,  1.257453  , -1.2608223 ,  0.66196257, -0.63675344],
       [-0.40817076,  1.289678  , -0.34326625,  1.0015084 , -0.7475674 ],
       [ 0.03847428,  0.5340957 , -0.49519396,  0.24867155, -0.30822927],
       [ 0.26964048,  2.3972166 ,  0.2736266 ,  0.29006767,  1.8663241 ],
       [-0.7362258 ,  1.5243033 , -0.97702163,  1.0311329 , -0.90981704],
       [-1.6452307 , -2.365368  ,  0.6122781 ,  0.22597076, -1.1475091 ],
       [ 0.50468564,  0.569523  ,  0.83959645,  0.33996007,  0.8772535 ],
       [ 1.0880069 , -1.1181055 ,  0.12085038,  1.9761903 ,  0.6264744 ],
       [ 0.45854285,  0.8642469 , -0.997979  ,  0.4568015 , -0.54131883],
       [ 1.0633543 , -1.7459028 ,  0.93327826, -0.14066479, -1.8550262 ]],
      dtype=float32)>

In [25]:
x = test_random()

In [26]:
x

<tf.Tensor: shape=(10, 5), dtype=float32, numpy=
array([[-0.6385254 , -0.1601175 ,  0.09135835,  2.0197587 ,  0.5813612 ],
       [-0.45433167,  0.6446218 , -1.3282415 , -1.028059  , -1.2592136 ],
       [-0.83688265, -0.2657525 ,  1.1453569 , -1.4795412 ,  1.0295745 ],
       [ 0.10933934, -0.32809815,  0.95528835, -0.17743436, -0.62262774],
       [-1.5864108 ,  1.7068386 ,  0.21764839,  0.71861815, -0.9992302 ],
       [-0.6292568 ,  0.38587192,  0.2831245 , -1.8745811 , -0.4392017 ],
       [-0.16635115,  0.58761823,  0.24187803, -0.13324393,  1.8122779 ],
       [ 0.5585834 , -0.8186332 , -1.3387676 ,  2.1509855 ,  1.050257  ],
       [ 1.1381578 ,  0.48276943,  2.020263  , -0.28788325, -0.36081553],
       [ 1.2688764 ,  0.75428027, -0.46783352, -0.45494577, -2.1779795 ]],
      dtype=float32)>