In [16]:
from __future__ import absolute_import, division, print_function

import math
import os

import numpy as np
import prettytensor as pt
import scipy.misc
import tensorflow as tf

import scipy

from time import time

import matplotlib.pyplot as plt
%matplotlib inline

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.allow_soft_placement=True
sess = tf.Session(config=config)

In [17]:
flags = tf.flags
logging = tf.logging

flags.DEFINE_integer("batch_size", 128, "batch size")
flags.DEFINE_integer("updates_per_epoch", 1000, "number of updates per epoch")
flags.DEFINE_integer("max_epoch", 100, "max epoch")
flags.DEFINE_float("learning_rate", 1e-2, "learning rate")
flags.DEFINE_string("working_directory", "", "")
flags.DEFINE_integer("hidden_size", 128, "size of the hidden VAE unit")
# flags.DEFINE_integer("hidden_size", 2, "size of the hidden VAE unit")

flags.DEFINE_integer("dim", 1, "dimensionality of the target distribution")
flags.DEFINE_integer("gener_dim", 20, "dimensionality of the generator distribution")


FLAGS = flags.FLAGS

In [18]:
def discriminator(input_tensor):
    '''Create discriminator network.

    Args:
        input_tensor: a batch of flattened images [batch_size, 28*28]

    Returns:
        A tensor that expresses the logit of being a true sample
    '''

    return (pt.wrap(input_tensor).
            fully_connected(128).
            fully_connected(128).
            fully_connected(128).
            dropout(0.9).
            fully_connected(1, activation_fn=None)).tensor

def generator(Z=None):
    '''Create a generator network
    
    '''
    if Z==None:
        Z = tf.random_uniform([FLAGS.batch_size,FLAGS.gener_dim])
    
    return (pt.wrap(Z).
            fully_connected(128).
            fully_connected(128).
            fully_connected(128).
            fully_connected(FLAGS.dim, activation_fn=None)).tensor

def chi2_loss(Xn,Yn):
    C = tf.matmul(Xn,Yn,transpose_a=True)/FLAGS.batch_size # bias correction needed???
    chi2 = tf.reduce_sum(tf.square(C))
    return chi2

# def get_gan_loss(input_tensor,generated_tensor):
    
#     with tf.variable_scope("model-discriminator", reuse=True) as scope:
#         D_input = discriminator(input_tensor)
#     with tf.variable_scope("model-discriminator", reuse=True) as scope:
#         D_generated = discriminator(generated_tensor)
        
#     return tf.reduce_mean(D_input-tf.nn.softplus(D_input)-tf.nn.softplus(D_generated))

In [19]:
mtn_weights = [.3,.3,.3]
mu = np.array([-1,1.5,4])
sig = np.array([.5,.5,.5])


def mg_sampler(n,mtn_weights,mu,sig):
    mtn_samples = np.random.multinomial(1,mtn_weights,(n))
    sampl_idx = np.argmax(mtn_samples,1)
    mg_samples = np.reshape(mu[sampl_idx],[n,1])+np.reshape(sig[sampl_idx],[n,1])*np.random.randn(n,1)
    return mg_samples

# aa = mg_sampler(100000, mtn_weights,mu,sig)
# plt.hist(aa,bins=1000)

In [20]:
input_tensor = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.dim])
learning_rate = tf.placeholder(tf.float32)
# input_code = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.hidden_size])

# ema = tf.train.ExponentialMovingAverage(decay=0.99)
ema = tf.train.ExponentialMovingAverage(decay=0.8)

label = tf.constant(np.concatenate((np.ones([FLAGS.batch_size,1]),\
                    -np.ones([FLAGS.batch_size,1]))),dtype=tf.float32)

with tf.device('/cpu:0'):

    with pt.defaults_scope(activation_fn=tf.nn.elu,
                           batch_normalize=True,
                           learned_moments_update_rate=0.0003,
                           variance_epsilon=0.001,
                           scale_after_normalization=True):
        with pt.defaults_scope(phase=pt.Phase.train):
            with tf.variable_scope("discriminator") as scope:
                D_input = discriminator(input_tensor)
            with tf.variable_scope("generator") as scope:
                generated_tensor = generator()
            with tf.variable_scope("discriminator", reuse=True) as scope:
                D_generated = discriminator(generated_tensor)

        with pt.defaults_scope(phase=pt.Phase.test):
            with tf.variable_scope("discriminator", reuse=True) as scope:
                D_input_test = discriminator(input_tensor)
            with tf.variable_scope("generator", reuse=True) as scope:
                generated_tensor_test = generator()
            with tf.variable_scope("discriminator", reuse=True) as scope:
                D_generated_test = discriminator(generated_tensor)
                
    D = tf.sigmoid(tf.concat((D_input,D_generated),axis=0))

    D_mean,D_var = tf.nn.moments(D,axes=[0])
    maintain_averages_op = ema.apply([D_mean, D_var])

    ema_D_mean = ema.average(D_mean)
    ema_D_std = tf.sqrt(ema.average(D_var))

    Dn = (D-ema_D_mean)/ema_D_std

    data_loss = tf.reduce_mean(D_input-tf.nn.softplus(D_input))
    sampl_loss = -tf.reduce_mean(tf.nn.softplus(D_generated))
    
    discr_loss = -(data_loss+sampl_loss)
    gener_loss = chi2_loss(Dn,label)
    
    discr_vars = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if v.name.startswith('discriminator')]
    gener_vars = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if v.name.startswith('generator')]

    optimizer = tf.train.AdamOptimizer(learning_rate, epsilon=1.0)
    
    train_gan_discr = pt.apply_optimizer(optimizer, losses=[discr_loss], var_list=discr_vars)
    train_gan_gener = pt.apply_optimizer(optimizer, losses=[gener_loss], var_list=gener_vars)

In [21]:
# Initialization

init = tf.global_variables_initializer()
sess.run(init)

In [22]:
lr = 1e-3

max_epoch = 100
updates_per_epoch = 1000
# max_epoch = 10
# updates_per_epoch = 100
# max_epoch = 100
# updates_per_epoch = 100

epoch_record = np.zeros([max_epoch,])
gener_record = list()

for epoch_id in range(max_epoch):
    
    loss_record = np.zeros([updates_per_epoch,])
    
    t0 = time()
    
    for step in range(updates_per_epoch):

#         x = np.random.randn(FLAGS.batch_size,1)
        x = mg_sampler(FLAGS.batch_size,mtn_weights,mu,sig)
        
#         _,loss_val = sess.run([train_gan_discr,discr_loss],{input_tensor: x,
#                     learning_rate: lr})
        
#         _,_ = sess.run([train_gan_gener,gener_loss],{input_tensor: x,
#                     learning_rate: lr})
        
        _,loss_val,_ = sess.run([train_gan_discr,discr_loss,maintain_averages_op],
                    {input_tensor: x, learning_rate: 100.*lr})
        
        _,loss_val,_ = sess.run([train_gan_gener,gener_loss,maintain_averages_op], 
                    {input_tensor: x, learning_rate: lr})
        
        loss_record[step] = loss_val
        
    if epoch_id>0 and np.mod(epoch_id+1,20)==0:
        gener_samples = np.zeros([MLOOP*FLAGS.batch_size,])
        for mloop in range(MLOOP):
            y = sess.run(generated_tensor)
            gener_samples[mloop*FLAGS.batch_size:(mloop+1)*FLAGS.batch_size] = np.reshape(y,[FLAGS.batch_size,])
        plt.subplot(2,3,epoch_id/20+2)
        _,_,_ = plt.hist((gener_samples,real_samples),bins=20)
        _ = plt.legend({'Epoch %d' % (epoch_id+1)})

        gener_record.append(gener_samples)
        
        
    t1 = time()
        
    print([epoch_id+1,np.mean(loss_record),t1-t0])
    epoch_record[epoch_id] = np.mean(loss_record)

plt.subplot(2,3,1)
_ = plt.plot(epoch_record)
_ = plt.plot(range(19,max_epoch,20),epoch_record[19::20],'sg',alpha=.3)
# _ = plt.title('Loss')
# _ = plt.xlabel('epoch')
_ = plt.legend({'Loss'})


[1, 1.475317530155182, 5.082407474517822]
[2, 0.76314093846803366, 4.86935830116272]
[3, 0.17332598644457289, 5.690716743469238]
[4, 0.31082712606526913, 4.892812967300415]
[5, 0.43151562875509264, 5.5575385093688965]
[6, 0.54786040797829627, 4.824517011642456]
[7, 0.57868378311395641, 4.692577123641968]
[8, 0.58623844015598292, 4.533369779586792]
[9, 0.59274677070975301, 4.517522573471069]
[10, 0.58055584787577386, 4.52625846862793]
[11, 0.56197652968019252, 4.517590761184692]
[12, 0.54638612679392096, 4.955818176269531]
[13, 0.52807139160484073, 4.885118007659912]
[14, 0.48705342867970469, 4.495056629180908]
[15, 0.42892089822143314, 4.50938868522644]
[16, 0.37026572960242626, 4.492819547653198]
[17, 0.31926256710290907, 4.5141284465789795]
[18, 0.30337377323769033, 4.498210191726685]
[19, 0.29702285630628467, 4.500646591186523]


NameError: name 'MLOOP' is not defined

In [None]:
gener_samples = np.zeros([MLOOP*FLAGS.batch_size,])
for mloop in range(MLOOP):
    y = sess.run(generated_tensor)
    gener_samples[mloop*FLAGS.batch_size:(mloop+1)*FLAGS.batch_size] = np.reshape(y,[FLAGS.batch_size,])
gener_record.append(gener_samples)

In [None]:
width = 12
height = 6
plt.figure(figsize=(width, height))

for epoch_id in range(5):
    plt.subplot(2,3,epoch_id+2)
    gener_samples = gener_record[epoch_id]
    _,_,_ = plt.hist((gener_samples,real_samples),bins=20,normed=True)
    _ = plt.legend({'Epoch %d' % ((epoch_id+1)*20)},loc='upperleft')
    _ = plt.xlim([-4,4])
plt.subplot(2,3,1)
_ = plt.plot(epoch_record)
_ = plt.plot(range(19,max_epoch,20),epoch_record[19::20],'sg',alpha=.3)
# _ = plt.title('Loss')
# _ = plt.xlabel('epoch')
_ = plt.legend({'Loss'})