In [1]:
import sugartensor as tf
import numpy as np


__author__ = 'namju.kim@kakaobrain.com'

# Restrict this process to a single GPU (device index 0).
import os
# Order devices by PCI bus id so the index used below refers to a fixed device.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
# Expose only device "0" to CUDA.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# set log level to debug (verbosity value 10)
tf.sg_verbosity(10)

#
# hyper parameters
#
# The latent vector z fed to the generator is the concatenation of a one-hot
# categorical code (cat_dim wide), the continuous codes (con_dim wide) and
# unstructured noise (rand_dim wide): 10 + 2 + 38 = 50 values per sample.

batch_size = 32   # number of samples per batch
cat_dim = 10   # number of categories of the categorical latent code
con_dim = 2    # number of continuous latent codes
rand_dim = 38  # number of unstructured random noise dimensions


#
# create generator & discriminator function
#

def generator(tensor):
    """Build the generator network on top of a latent tensor.

    Two dense layers are followed by a reshape to (-1, 7, 7, 128) and two
    up-convolutions; the last layer has a single channel, a sigmoid
    activation and no batch normalization.

    The first call creates the variables under the 'generator' name; later
    calls detect the existing variables and reuse them.

    :param tensor: latent input tensor.
    :return: output tensor of the final up-convolution
             (presumably a 28x28x1 image batch given the two stride-2
             up-convolutions from 7x7 -- TODO confirm).
    """
    # reuse weights when generator variables already exist in the graph
    already_built = any(v.name.startswith('generator') for v in tf.global_variables())

    with tf.sg_context(name='generator', size=4, stride=2, act='leaky_relu', bn=True, reuse=already_built):
        # fully connected part
        hidden = tensor.sg_dense(dim=1024, name='fc1')
        hidden = hidden.sg_dense(dim=7*7*128, name='fc2')

        # turn the flat activations into a 7x7 feature map with 128 channels
        feature_map = hidden.sg_reshape(shape=(-1, 7, 7, 128))

        # up-convolutional part; the output layer overrides the context defaults
        feature_map = feature_map.sg_upconv(dim=64, name='conv1')
        image = feature_map.sg_upconv(dim=1, act='sigmoid', bn=False, name='conv2')

    return image


def discriminator(tensor):
    """Build the discriminator with its auxiliary recognition heads.

    A convolutional trunk ('conv1', 'conv2', 'fc1') is shared by three heads:

    * 'disc' -- one linear unit, squeezed; the real/fake output.
    * 'cat'  -- cat_dim linear units on top of the 'recog' layer.
    * 'con'  -- con_dim sigmoid units on top of the 'recog' layer.

    The first call creates the variables under the 'discriminator' name;
    later calls detect the existing variables and reuse them.

    :param tensor: input image tensor.
    :return: tuple (disc, cat, con) of the three head outputs.
    """
    # reuse weights when discriminator variables already exist in the graph
    variables_exist = any(v.name.startswith('discriminator') for v in tf.global_variables())

    with tf.sg_context(name='discriminator', size=4, stride=2, act='leaky_relu', bn=True,
                       reuse=variables_exist):
        # trunk shared by every head
        features = tensor.sg_conv(dim=64, name='conv1')
        features = features.sg_conv(dim=128, name='conv2')
        features = features.sg_flatten()
        shared = features.sg_dense(dim=1024, name='fc1')

        # real/fake head: linear output, no batch norm
        disc_out = shared.sg_dense(dim=1, act='linear', bn=False, name='disc')
        disc = disc_out.sg_squeeze()

        # layer shared by the two latent-code heads
        recog_shared = shared.sg_dense(dim=128, name='recog')

        # categorical code head (linear outputs)
        cat = recog_shared.sg_dense(dim=cat_dim, act='linear', bn=False, name='cat')

        # continuous code head (sigmoid outputs)
        con = recog_shared.sg_dense(dim=con_dim, act='sigmoid', bn=False, name='con')

    return disc, cat, con


#
# inputs
#

# MNIST input tensor ( with QueueRunner )
data = tf.sg_data.Mnist(batch_size=batch_size)

# input images from the training split (the labels are not read here)
x = data.train.image

# target values for the discriminator losses: ones for "real", zeros for "fake",
# one value per sample in the batch
y_real = tf.ones(batch_size)
y_fake = tf.zeros(batch_size)

# categorical latent variable: one integer category index per sample.
# Every entry passed to tf.multinomial is the same constant (1 / cat_dim), so
# presumably all categories are equally likely -- TODO confirm against the
# tf.multinomial documentation (its first argument is interpreted as logits).
z_cat = tf.multinomial(tf.ones((batch_size, cat_dim), dtype=tf.sg_floatx) / cat_dim, 1).sg_squeeze().sg_int()
# continuous latent variable: con_dim uniform random values per sample
z_con = tf.random_uniform((batch_size, con_dim))
# unstructured random noise: rand_dim uniform random values per sample
z_rand = tf.random_uniform((batch_size, rand_dim))
# full latent variable: one-hot categorical code, continuous codes and noise
# concatenated along axis 1
z = tf.concat([z_cat.sg_one_hot(depth=cat_dim), z_con, z_rand], 1)


#
# Computational graph
#

# generator: fake images produced from the latent variable z
gen = generator(z)

# add image summaries for the real and the generated batches
tf.sg_summary_image(x, name='real')
tf.sg_summary_image(gen, name='fake')

# discriminator applied to both batches. The first call creates the
# 'discriminator' variables; the second call finds them and reuses them.
# Only the real/fake output is kept for real images; for generated images the
# categorical and continuous code predictions are kept as well.
disc_real, _, _ = discriminator(x)
disc_fake, cat_fake, con_fake = discriminator(gen)


#
# loss
#

# discriminator loss: binary cross-entropy with target 1 for real images and
# target 0 for generated images, averaged over the two terms
loss_d_r = disc_real.sg_bce(target=y_real, name='disc_real')
loss_d_f = disc_fake.sg_bce(target=y_fake, name='disc_fake')
loss_d = (loss_d_r + loss_d_f) / 2


# generator loss: binary cross-entropy of the discriminator output on
# generated images against the "real" target
loss_g = disc_fake.sg_bce(target=y_real, name='gen')

# categorical factor loss: cross-entropy between the predicted category and
# the category index z_cat that was fed to the generator
loss_c = cat_fake.sg_ce(target=z_cat, name='cat')

# continuous factor loss: squared error between the predicted and the sampled
# continuous codes, averaged over axis 1 (the con_dim codes of each sample)
loss_con = con_fake.sg_mse(target=z_con, name='con').sg_mean(axis=1)


#
# train ops
#

# discriminator train ops
# NOTE(review): category= presumably restricts the update to variables whose
# name starts with the given prefix -- confirm against sg_optim.
train_disc = tf.sg_optim(loss_d + loss_c + loss_con, lr=0.0001, category='discriminator')
# generator train ops
# loss_c and loss_con are computed from discriminator(gen), so they depend on
# the generator output as well. Including them here trains the generator to
# produce images from which the latent codes z_cat and z_con can be recovered.
train_gen = tf.sg_optim(loss_g + loss_c + loss_con, lr=0.001, category='generator')


#
# training
#

# def alternate training func
@tf.sg_train_func
def alt_train(sess, opt):
    """Run one alternating training step and return its combined loss.

    The discriminator train op is executed first, then the generator train
    op, each in its own ``sess.run`` call.

    :param sess: session used to execute the train ops.
    :param opt: options passed in by the training wrapper (unused here).
    :return: mean discriminator loss plus mean generator loss for this step.
    """
    # discriminator update; only the fetched loss value is kept
    disc_loss, _ = sess.run([loss_d, train_disc])
    # generator update; only the fetched loss value is kept
    gen_loss, _ = sess.run([loss_g, train_gen])
    return np.mean(disc_loss) + np.mean(gen_loss)

# start training: 30 epochs, one epoch = data.train.num_batch steps,
# results are written under asset/train/infogan
alt_train(
    log_interval=10,
    max_ep=30,
    ep_size=data.train.num_batch,
    early_stop=False,
    save_dir='asset/train/infogan'
)


Extracting ./asset/data/mnist/train-images-idx3-ubyte.gz
Extracting ./asset/data/mnist/train-labels-idx1-ubyte.gz
Extracting ./asset/data/mnist/t10k-images-idx3-ubyte.gz
Extracting ./asset/data/mnist/t10k-labels-idx1-ubyte.gz
INFO:tensorflow:global_step/sec: 0


I 0421:04:53:36.727:sg_train.py:327] Training started from epoch[000]-step[0].
train:  23%|█████▌                  | 402/1718 [00:09<00:29, 45.06b/s]

INFO:tensorflow:global_step/sec: 81.3651


train:  51%|████████████▏           | 876/1718 [00:19<00:19, 44.22b/s]

INFO:tensorflow:global_step/sec: 94.6017


train:  79%|██████████████████▏    | 1358/1718 [00:29<00:07, 45.29b/s]

INFO:tensorflow:global_step/sec: 96.1996


I 0421:04:54:14.360:sg_train.py:301] 	Epoch[000:gs=3436] - loss = 1.428414
train:   6%|█▍                      | 101/1718 [00:02<00:33, 48.73b/s]

INFO:tensorflow:global_step/sec: 92.6001


train:  34%|████████                | 581/1718 [00:12<00:26, 42.52b/s]

INFO:tensorflow:global_step/sec: 95.6006


train:  61%|██████████████         | 1050/1718 [00:22<00:14, 45.09b/s]

INFO:tensorflow:global_step/sec: 94.0999


train:  89%|████████████████████▍  | 1523/1718 [00:32<00:03, 50.15b/s]

INFO:tensorflow:global_step/sec: 95.0001


I 0421:04:54:50.431:sg_train.py:301] 	Epoch[001:gs=6872] - loss = 1.433505
train:  17%|████                    | 288/1718 [00:06<00:28, 49.59b/s]

INFO:tensorflow:global_step/sec: 96.5


train:  44%|██████████▋             | 763/1718 [00:16<00:20, 46.06b/s]

INFO:tensorflow:global_step/sec: 94.4001


I 0421:04:55:26.797:sg_train.py:301] 	Epoch[002:gs=10308] - loss = 1.438771
I 0421:04:56:03.790:sg_train.py:301] 	Epoch[003:gs=13744] - loss = 1.430687
I 0421:04:56:40.675:sg_train.py:301] 	Epoch[004:gs=17180] - loss = 1.459116
I 0421:04:57:17.722:sg_train.py:301] 	Epoch[005:gs=20616] - loss = 1.454191
I 0421:04:57:54.223:sg_train.py:301] 	Epoch[006:gs=24052] - loss = 1.462838
I 0421:04:58:29.565:sg_train.py:301] 	Epoch[007:gs=27488] - loss = 1.487397
I 0421:04:59:03.272:sg_train.py:301] 	Epoch[008:gs=30924] - loss = 1.499143
I 0421:04:59:36.898:sg_train.py:301] 	Epoch[009:gs=34360] - loss = 1.494002
I 0421:05:00:10.604:sg_train.py:301] 	Epoch[010:gs=37796] - loss = 1.499426
I 0421:05:00:44.384:sg_train.py:301] 	Epoch[011:gs=41232] - loss = 1.494208
I 0421:05:01:17.978:sg_train.py:301] 	Epoch[012:gs=44668] - loss = 1.526422
I 0421:05:01:51.472:sg_train.py:301] 	Epoch[013:gs=48104] - loss = 1.577077
I 0421:05:02:24.942:sg_train.py:301] 	Epoch[014:gs=51540] - loss = 1.526411
I 0421:05:02