In [1]:
import logging
import os

import numpy as np
import tensorflow as tf

from config import InfoGANConfig, SAVE_DIR
from nets import GenConv, DisConv, QConv
from ops import mnist_for_gan, optimizer, clip
# Module-level logger emitting everything from DEBUG upwards; the root
# handler prefixes each record with a "[MMDD HH:MM:SS]" timestamp.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logging.basicConfig(
    datefmt="%m%d %H:%M:%S",
    format="[%(asctime)s] %(message)s",
)

In [2]:
# Mini-batch size; InfoGAN.__init__ passes this value to the GenConv generator.
batch_size = 100

In [1]:
def sample_c(c_size, add_cv, index = -1):
    '''
    Sample latent codes: a one-hot category followed by continuous codes.

    Args:
        c_size - int
            number of samples (rows)
        add_cv - int
            number of additional continuous random variables,
            each drawn uniformly from [-1, 1)
        index - int
            category (0..9) assigned to every row; when negative
            (default -1) every row gets its own category drawn
            uniformly at random
    Return:
        [c_size, 10 + add_cv] float array
            first 10 columns: one-hot category
            last add_cv columns: independent continuous codes
    '''

    classify = np.zeros([c_size, 10])
    conti = np.random.uniform(low = -1.0, high = 1.0, size = [c_size, add_cv])
    if index < 0:
        # One independent draw per row; a single scalar draw would give the
        # whole batch the same category.
        index = np.random.randint(10, size = c_size)
    # Works for both a scalar index (broadcast to all rows) and a per-row array.
    classify[np.arange(c_size), index] = 1
    return np.concatenate((classify, conti), axis = 1)

def sample_z(z_size, z_dim):
    '''
    Draw generator noise.

    Args:
        z_size - int
            number of noise vectors (rows)
        z_dim - int
            length of each noise vector
    Return:
        [z_size, z_dim] array, entries uniform on [-1, 1)
    '''
    noise_shape = (z_size, z_dim)
    return np.random.uniform(-1, 1, noise_shape)

In [None]:
class InfoGAN(InfoGANConfig):
    """InfoGAN trained with a weight-clipped Wasserstein critic (TF1 graph mode).

    The graph consists of a generator fed with noise Z concatenated with a
    latent code C, a critic scored on real and generated images, and a
    recognition network Q that tries to recover C from generated images.

    The hyper-parameters batch_size, x_size, z_dim, c_dim, clip_b and
    log_every are read from attributes that are expected to be set by
    InfoGANConfig (not visible in this file -- confirm in config).
    The first 10 entries of C are a one-hot category, the remaining
    c_dim - 10 are continuous codes.
    """

    def __init__(self):
        """Build networks, placeholders, losses, train ops, saver and session."""
        InfoGANConfig.__init__(self)
        # NOTE(review): the generator is sized with the module-level batch_size
        # while the placeholders below use self.batch_size; the two must agree.
        self.generator = GenConv(name ='g_conv', batch_size=batch_size)
        self.discriminator = DisConv(name='d_conv')
        self.classifier = QConv(name='q_conv')
        self.dataset = mnist_for_gan()

        # Real images as [batch, height, width, channels]. A single channel is
        # assumed because MNIST is grayscale -- TODO confirm against the
        # arrays returned by mnist_for_gan.
        self.X = tf.placeholder(tf.float32, shape = [self.batch_size, self.x_size, self.x_size, 1])
        self.Z = tf.placeholder(tf.float32, shape = [self.batch_size, self.z_dim])
        self.C = tf.placeholder(tf.float32, shape = [self.batch_size, self.c_dim])

        self.G_sample = self.generator(tf.concat([self.Z, self.C], axis=1))
        self.D_real = self.discriminator(self.X)
        self.D_fake = self.discriminator(self.G_sample, reuse = True)
        self.Q_rct = self.classifier(self.G_sample, self.c_dim)

        # Split both the reconstruction and the fed code into the categorical
        # (first 10) and continuous (remaining) parts.
        code_split = [10, self.c_dim - 10]
        self.Q_rct_classify, self.Q_rct_conti = tf.split(self.Q_rct, code_split, axis = 1)
        C_classify, C_conti = tf.split(self.C, code_split, axis = 1)

        # Critic minimises D(fake) - D(real); the generator therefore has to
        # maximise D(fake), i.e. minimise its negative.
        self.D_loss = -tf.reduce_mean(self.D_real)+tf.reduce_mean(self.D_fake)
        self.G_loss = -tf.reduce_mean(self.D_fake)

        # Code reconstruction loss: cross-entropy for the category plus squared
        # error for the continuous codes.
        # Assumes the classifier returns unnormalised logits for the categorical
        # part -- TODO confirm in QConv.
        self.Q_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=C_classify, logits=self.Q_rct_classify))
        if self.c_dim > 10:
            # Skipped when there are no continuous codes: the mean of an empty
            # tensor would be NaN.
            self.Q_loss += tf.reduce_mean(tf.square(C_conti - self.Q_rct_conti))

        self.D_optimizer = optimizer(self.D_loss, self.discriminator.vars)

        # Running this list performs one critic update and then clips every
        # critic variable into [-clip_b, clip_b].
        with tf.control_dependencies([self.D_optimizer]):
            self.D_optimizer_wrapped = [
                tf.assign(var, tf.clip_by_value(var, -self.clip_b, self.clip_b))
                for var in self.discriminator.vars
            ]

        self.G_optimizer = optimizer(self.G_loss, self.generator.vars)
        self.Q_optimizer = optimizer(self.Q_loss, self.generator.vars + self.classifier.vars)

        # Created after the whole graph so it covers every variable built above.
        self.saver = tf.train.Saver()
        self.sess = tf.Session()

    def initialize(self):
        """Initialize all variables in graph"""
        self.sess.run(tf.global_variables_initializer())

    def restore(self):
        """Restore all variables in graph from the latest checkpoint in SAVE_DIR"""
        logger.info("Restoring model starts...")
        self.saver.restore(self.sess, tf.train.latest_checkpoint(SAVE_DIR))
        logger.info("Restoring model done.")

    def train(self, train_epochs):
        """Run the training loop and periodically save checkpoints.

        Each iteration performs d_iter clipped critic updates (100 during the
        first 25 iterations, 5 afterwards), then one generator update and one
        Q update. A checkpoint is written to SAVE_DIR every log_every
        iterations.

        Args:
            train_epochs - int
                number of outer training iterations
        """
        for epoch in range(train_epochs):
            d_iter = 100 if epoch < 25 else 5

            for _ in range(d_iter):
                X_sample = self.dataset(self.batch_size)
                z_sample = sample_z(self.batch_size, self.z_dim)
                # c_dim already includes the 10 categorical entries, so only
                # the remainder is requested as continuous codes.
                c_sample = sample_c(self.batch_size, self.c_dim - 10)
                self.sess.run(self.D_optimizer_wrapped, feed_dict = {self.X : X_sample, self.Z : z_sample, self.C : c_sample})

            # The generator and Q steps reuse the latent batch drawn in the
            # last critic step.
            latent_feed = {self.Z : z_sample, self.C : c_sample}
            self.sess.run(self.G_optimizer, feed_dict = latent_feed)
            self.sess.run(self.Q_optimizer, feed_dict = latent_feed)

            if (epoch + 1) % self.log_every == 0:
                self.saver.save(self.sess, os.path.join(SAVE_DIR, 'model'), global_step = epoch+1)
                logger.info("Model save in %s", SAVE_DIR)