In [1]:
# -*- coding: utf-8 -*-
import sugartensor as tf
import numpy as np
import matplotlib.pyplot as plt

__author__ = 'buriburisuri@gmail.com'
# only use gpu 1
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="3"

# limit gpu mem to 4GB
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)

sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

# set log level to debug
tf.sg_verbosity(10)

#
# hyper parameters
#

batch_size = 100   # batch size
cat_dim = 5   # total categorical factor
con_dim = 2    # total continuous factor
rand_dim = 38  # total random latent dimension


#
# create generator & discriminator function
#

# generator network
def generator(tensor):

    # reuse flag
    reuse = len([t for t in tf.global_variables() if t.name.startswith('generator')]) > 0

    with tf.sg_context(name='generator', size=4, stride=2, act='relu', bn=True, reuse=reuse):
        res = (tensor
               .sg_dense(dim=1024, name='fc1')
               .sg_dense(dim=9*9*128, name='fc2')
               .sg_reshape(shape=(-1, 9, 9, 128))
               .sg_upconv(dim=64, name='conv1')
               .sg_upconv(dim=1, act='sigmoid', bn=False, name='conv2'))
    return res

#
# inputs
#

# target_number
target_num = tf.placeholder(dtype=tf.sg_intx, shape=batch_size)
# target continuous variable # 1
target_cval_1 = tf.placeholder(dtype=tf.sg_floatx, shape=batch_size)
# target continuous variable # 2
target_cval_2 = tf.placeholder(dtype=tf.sg_floatx, shape=batch_size)

# category variables
z = (tf.ones(batch_size, dtype=tf.sg_intx) * target_num).sg_one_hot(depth=cat_dim)

# continuous variables
z = z.sg_concat(target=[target_cval_1.sg_expand_dims(), target_cval_2.sg_expand_dims()])

# random seed = categorical variable + continuous variable + random normal
z = z.sg_concat(target=tf.random_normal((batch_size, rand_dim)))

In [2]:
z

<tf.Tensor 'ConcatV2_1:0' shape=(100, 45) dtype=float32>

In [3]:
# generator
gen = generator(z).sg_squeeze()

#
# run generator
#


def run_generator(sess, num, x1, x2, fig_name='sample.png'):
        tf.sg_init(sess)
        # restore parameters
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint('asset/train/acgan-barcode36'))

        # run generator
        imgs = sess.run(gen, {target_num: num,
                              target_cval_1: x1,
                              target_cval_2: x2})

        # plot result
        _, ax = plt.subplots(10, 10, sharex=True, sharey=True)
        for i in range(10):
            for j in range(10):
                ax[i][j].imshow(imgs[i * 10 + j], 'gray')
                ax[i][j].set_axis_off()
        plt.savefig('asset/train/acgan-barcode36/' + fig_name, dpi=600)
        tf.sg_info('Sample image saved to "asset/train/acgan-barcode%s"' % fig_name)
        plt.close()


#
# draw sample by categorical division
#

# fake image
run_generator(sess, np.random.randint(0, cat_dim, batch_size),
              np.random.uniform(0, 1, batch_size), np.random.uniform(0, 1, batch_size),
              fig_name='fake.png')

# classified image
run_generator(sess, np.arange(10).repeat(10), np.ones(batch_size) * 0.5, np.ones(batch_size) * 0.5)

#
# draw sample by continuous division
#

for i in range(10):
    run_generator(sess, np.ones(batch_size) * i,
                  np.linspace(0, 1, 10).repeat(10),
                  np.expand_dims(np.linspace(0, 1, 10), axis=1).repeat(10, axis=1).T.flatten(),
                  fig_name='sample%d.png' % i)

INFO:tensorflow:Restoring parameters from asset/train/acgan-barcode36/model.ckpt-15500


I 0423:00:15:46.974:<ipython-input-3-6ccfa98286f6>:27] Sample image saved to "asset/train/acgan-barcodefake.png"


INFO:tensorflow:Restoring parameters from asset/train/acgan-barcode36/model.ckpt-15500


I 0423:00:15:54.958:<ipython-input-3-6ccfa98286f6>:27] Sample image saved to "asset/train/acgan-barcodesample.png"


INFO:tensorflow:Restoring parameters from asset/train/acgan-barcode36/model.ckpt-15500


I 0423:00:16:02.762:<ipython-input-3-6ccfa98286f6>:27] Sample image saved to "asset/train/acgan-barcodesample0.png"


INFO:tensorflow:Restoring parameters from asset/train/acgan-barcode36/model.ckpt-15500


I 0423:00:16:10.224:<ipython-input-3-6ccfa98286f6>:27] Sample image saved to "asset/train/acgan-barcodesample1.png"


INFO:tensorflow:Restoring parameters from asset/train/acgan-barcode36/model.ckpt-15500


I 0423:00:16:18.117:<ipython-input-3-6ccfa98286f6>:27] Sample image saved to "asset/train/acgan-barcodesample2.png"


INFO:tensorflow:Restoring parameters from asset/train/acgan-barcode36/model.ckpt-15500


I 0423:00:16:25.777:<ipython-input-3-6ccfa98286f6>:27] Sample image saved to "asset/train/acgan-barcodesample3.png"


INFO:tensorflow:Restoring parameters from asset/train/acgan-barcode36/model.ckpt-15500


I 0423:00:16:33.484:<ipython-input-3-6ccfa98286f6>:27] Sample image saved to "asset/train/acgan-barcodesample4.png"


INFO:tensorflow:Restoring parameters from asset/train/acgan-barcode36/model.ckpt-15500


I 0423:00:16:41.536:<ipython-input-3-6ccfa98286f6>:27] Sample image saved to "asset/train/acgan-barcodesample5.png"


INFO:tensorflow:Restoring parameters from asset/train/acgan-barcode36/model.ckpt-15500


I 0423:00:16:49.326:<ipython-input-3-6ccfa98286f6>:27] Sample image saved to "asset/train/acgan-barcodesample6.png"


INFO:tensorflow:Restoring parameters from asset/train/acgan-barcode36/model.ckpt-15500


I 0423:00:16:57.358:<ipython-input-3-6ccfa98286f6>:27] Sample image saved to "asset/train/acgan-barcodesample7.png"


INFO:tensorflow:Restoring parameters from asset/train/acgan-barcode36/model.ckpt-15500


I 0423:00:17:05.567:<ipython-input-3-6ccfa98286f6>:27] Sample image saved to "asset/train/acgan-barcodesample8.png"


INFO:tensorflow:Restoring parameters from asset/train/acgan-barcode36/model.ckpt-15500


I 0423:00:17:13.723:<ipython-input-3-6ccfa98286f6>:27] Sample image saved to "asset/train/acgan-barcodesample9.png"
