In [1]:
import logging
import os

import numpy as np
import tensorflow as tf

from config import InfoGANConfig, SAVE_DIR
from nets import GenConv, DisConv, QConv
from ops import mnist_for_gan, optimizer, clip, get_shape
from utils import show_gray_image_3d
# Emit timestamped records such as "[0722 11:48:47] message".
logging.basicConfig(
    format="[%(asctime)s] %(message)s",
    datefmt="%m%d %H:%M:%S",
)

# Notebook-level logger; DEBUG is the threshold so info/debug records are let through.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

In [2]:
def sample_c(c_size, add_cv, index = -1):
    '''
    Draw a batch of latent codes: a one-hot class part plus uniform continuous parts.

    Args:
        c_size - int
            number of samples (rows)
        add_cv - int
            number of additional continuous random variables, uniform in [-1, 1]
        index - int
            class position to set in the one-hot part; any negative value
            (default -1) picks one class at random
    Return:
        array of shape [c_size, 10 + add_cv]
            first 10 columns: one-hot class code
            last add_cv columns: independent continuous codes

    Note: a single class index is used for every row of the batch, whether it
    is given explicitly or drawn at random.
    '''
    # Continuous codes are drawn before the class index (same RNG order as always).
    continuous = np.random.uniform(-1.0, 1.0, (c_size, add_cv))
    digit = index if index >= 0 else np.random.randint(10)

    one_hot = np.zeros((c_size, 10))
    one_hot[:, digit] = 1
    return np.hstack([one_hot, continuous])

def sample_z(z_size, z_dim):
    '''
    Draw a noise batch of shape [z_size, z_dim], uniform in [-1, 1).
    '''
    noise_shape = (z_size, z_dim)
    return np.random.uniform(-1, 1, noise_shape)

In [3]:
class InfoGAN(InfoGANConfig):
    """InfoGAN with a mean-difference (WGAN-style) critic loss and weight clipping.

    The latent code C has `c_dim` columns: the first 10 are a one-hot class
    code, the remaining `c_dim - 10` are continuous codes. Hyper-parameters
    (`batch_size`, `x_size`, `x_channel`, `z_dim`, `c_dim`, `clip_b`,
    `log_every`) are read from InfoGANConfig.

    Typical use:
        model = InfoGAN()
        model.initialize()      # or model.restore()
        model.train(train_epochs)
    """

    def __init__(self):
        """Build the graph, the losses, the update ops, a Saver and a Session.

        Resets the default TensorFlow graph first. Variables are NOT
        initialized here; call `initialize()` or `restore()` afterwards.
        """
        InfoGANConfig.__init__(self)
        logger.info("Building model starts...")
        tf.reset_default_graph()
        self.generator = GenConv(name='g_conv', batch_size=self.batch_size)
        self.discriminator = DisConv(name='d_conv')
        self.classifier = QConv(name='q_conv', c_dim=self.c_dim)
        self.dataset = mnist_for_gan()

        self.X = tf.placeholder(tf.float32, shape=[self.batch_size, self.x_size, self.x_size, self.x_channel])
        self.Z = tf.placeholder(tf.float32, shape=[self.batch_size, self.z_dim])
        self.C = tf.placeholder(tf.float32, shape=[self.batch_size, self.c_dim])

        # The generator consumes noise and latent code concatenated column-wise.
        self.G_sample = self.generator(tf.concat([self.Z, self.C], axis=1))
        print(self.G_sample)
        self.D_real = self.discriminator(self.X)
        print(self.D_real)
        self.D_fake = self.discriminator(self.G_sample, reuse=True)
        print(self.D_fake)
        # Q tries to reconstruct the latent code from the generated image.
        self.Q_rct = self.classifier(self.G_sample)
        print(self.Q_rct)

        # Split both the reconstruction and the target into class / continuous parts.
        code_split = [10, self.c_dim - 10]
        self.Q_rct_classify, self.Q_rct_conti = tf.split(self.Q_rct, code_split, axis=1)
        self.C_classify, self.C_conti = tf.split(self.C, code_split, axis=1)

        self.D_loss = -tf.reduce_mean(self.D_real) + tf.reduce_mean(self.D_fake)
        # Cross-entropy on the one-hot part plus squared error on the continuous part.
        classify_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=self.C_classify, logits=self.Q_rct_classify))
        conti_loss = tf.reduce_mean(tf.square(self.C_conti - self.Q_rct_conti))
        self.Q_loss = classify_loss + conti_loss
        # D_loss pushes D_fake down, so the generator must push it up. Both
        # losses go through the same `optimizer` helper, hence the generator's
        # D_fake term needs the opposite sign of the one in D_loss.
        self.G_loss = -tf.reduce_mean(self.D_fake)

        print("Generator_variables")
        self.generator.print_vars()
        print("Discriminator_variables")
        self.discriminator.print_vars()
        print("Classifier_variables")
        self.classifier.print_vars()

        self.D_optimizer = optimizer(self.D_loss, self.discriminator.vars)

        # Clip every discriminator variable to [-clip_b, clip_b] via the `clip`
        # helper; the control dependency ties the clipping to the optimizer step.
        with tf.control_dependencies([self.D_optimizer]):
            self.D_optimizer_wrapped = [tf.assign(var, clip(var, -self.clip_b, self.clip_b)) for var in self.discriminator.vars]

        # The Q loss updates both the generator and the classifier.
        self.Q_optimizer = optimizer(self.Q_loss, self.generator.vars + self.classifier.vars)
        self.G_optimizer = optimizer(self.G_loss, self.generator.vars)

        # Created once, after every variable exists, and shared by restore()/train().
        self.saver = tf.train.Saver()

        logger.info("Building model done.")
        self.sess = tf.Session()

    def initialize(self):
        """Initialize all variables in graph"""
        self.sess.run(tf.global_variables_initializer())

    def restore(self):
        """Restore all variables in graph from the latest checkpoint in SAVE_DIR.

        Raises:
            ValueError - if SAVE_DIR contains no checkpoint.
        """
        logger.info("Restoring model starts...")
        checkpoint = tf.train.latest_checkpoint(SAVE_DIR)
        if checkpoint is None:
            raise ValueError("No checkpoint found in %s" % SAVE_DIR)
        self.saver.restore(self.sess, checkpoint)
        logger.info("Restoring model done.")

    def _sample_latent_feed(self):
        """Return a fresh feed dict for the Z and C placeholders."""
        z_sample = sample_z(self.batch_size, self.z_dim)
        # sample_c returns 10 + add_cv columns, so pass only the continuous count
        # to match the [batch_size, c_dim] shape of the C placeholder.
        c_sample = sample_c(self.batch_size, self.c_dim - 10)
        return {self.Z: z_sample, self.C: c_sample}

    def train(self, train_epochs):
        """Run `train_epochs` outer training iterations.

        Each iteration ("epoch" here is one outer step, not a pass over the
        dataset) performs `d_iter` discriminator updates with weight clipping,
        then one generator update and one Q update on the last latent batch.
        Every `log_every` iterations the losses are logged, a batch of
        generated images is shown and a checkpoint is written to SAVE_DIR.

        Args:
            train_epochs - int
                number of outer iterations
        """
        for epoch in range(train_epochs):
            # Many more discriminator steps during the first 25 iterations.
            d_iter = 100 if epoch < 25 else 5

            for _ in range(d_iter):
                X_sample = self.dataset(self.batch_size)
                latent_feed = self._sample_latent_feed()
                d_feed = dict(latent_feed)
                d_feed[self.X] = X_sample
                self.sess.run(self.D_optimizer_wrapped, feed_dict=d_feed)

            # Generator and Q steps reuse the latent batch of the last D step.
            self.sess.run(self.G_optimizer, feed_dict=latent_feed)
            self.sess.run(self.Q_optimizer, feed_dict=latent_feed)

            # Log on the last iteration of every `log_every`-sized window.
            if (epoch + 1) % self.log_every == 0:
                X_sample = self.dataset(self.batch_size)
                eval_feed = self._sample_latent_feed()
                eval_feed[self.X] = X_sample

                # One run so the losses and the images come from the same samples.
                D_loss, G_loss, Q_loss, gray_3d = self.sess.run(
                    [self.D_loss, self.G_loss, self.Q_loss, self.G_sample],
                    feed_dict=eval_feed)

                # Drop size-1 axes (e.g. a single channel) before plotting.
                gray_3d = np.squeeze(gray_3d)
                show_gray_image_3d(gray_3d, col=5, fig_size=(10, 40), dataformat='CHW')

                logger.info("Epoch({}/{}) D_loss : {}, G_loss : {}, Q_loss : {}".format(epoch+1, train_epochs, D_loss, G_loss, Q_loss))
                # Saver.save does not create the parent directory itself.
                if not os.path.isdir(SAVE_DIR):
                    os.makedirs(SAVE_DIR)
                self.saver.save(self.sess, os.path.join(SAVE_DIR, 'model'), global_step=epoch+1)
                logger.info("Model save in %s"%SAVE_DIR)

In [4]:
# Build the graph and open a session; variables still need initialize() or restore().
infogan = InfoGAN()

[0722 11:48:47] Building model starts...


Extracting ../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../MNIST_data/t10k-labels-idx1-ubyte.gz
Tensor("g_conv/deconv2/Sigmoid:0", shape=(100, 28, 28, 1), dtype=float32)
Tensor("d_conv/d_fc2/Sigmoid:0", shape=(100, 1), dtype=float32)
Tensor("d_conv_1/d_fc2/Sigmoid:0", shape=(100, 1), dtype=float32)
Tensor("q_conv/d_fc2/xw_plus_b:0", shape=(100, 12), dtype=float32)
Generator_variables
g_conv/fc1/w:0:[112, 6272]
g_conv/fc1/b:0:[6272]
g_conv/deconv1/w:0:[5, 5, 64, 128]
g_conv/deconv1/b:0:[64]
g_conv/deconv2/w:0:[5, 5, 1, 64]
g_conv/deconv2/b:0:[1]
Discriminator_variables
d_conv/conv1/w:0:[4, 4, 1, 32]
d_conv/conv1/b:0:[32]
d_conv/conv2/w:0:[4, 4, 32, 64]
d_conv/conv2/b:0:[64]
d_conv/d_fc1/w:0:[3136, 128]
d_conv/d_fc1/b:0:[128]
d_conv/d_fc2/w:0:[128, 1]
d_conv/d_fc2/b:0:[1]
Classifier_variables
q_conv/conv1/w:0:[4, 4, 1, 32]
q_conv/conv1/b:0:[32]
q_conv/conv2/w:0:[4, 4, 32, 64]
q_c

[0722 11:48:49] Building model done.
