In [1]:
import os
import tensorflow as tf
import time
import datetime
import numpy as np
from mlp.data_providers import MNISTDataProvider, AugmentedCIFAR10DataProvider, AugmentedCIFAR100DataProvider, CIFAR100DataProvider, CIFAR10DataProvider
from mlp.tf_layers import FCLayer, ConvLayer, max_pool_2x2
from mlp.image_transforms import random_flip, random_crop, center_crop, random_flip_small
from mlp.GAN_models import GAN, WasserstienGAN, CWGAN
import mlp.tf_utils as utils
import matplotlib.pyplot as plt
import seaborn as sns
import cPickle
%matplotlib inline

# Seed the random number generators so runs are reproducible.
# `rng` is a dedicated generator for notebook code. The global NumPy RNG is
# seeded as well, in case any cell or library code draws from `np.random`
# directly.
# NOTE(review): TensorFlow's graph-level seed is not set here; training itself
# may still be non-deterministic.
seed = 24102016
rng = np.random.RandomState(seed)
np.random.seed(seed)

In [2]:
# Run identifier passed to create_network, plus a wall-clock stamp for this session.
model_name = 'CWGAN_with_TTUR'
timestamp = "{:%Y-%m-%d_%H-%M-%S}".format(datetime.datetime.now())

In [3]:
# ---- Experiment configuration ---------------------------------------------
batch_size = 64  # batch size for training
z_dim = 100  # latent noise dimension passed to the model constructor
gen_learning_rate = 2e-4
disc_learning_rate = 2e-4
# NOTE(review): model_name mentions TTUR (two time-scale update rule), but the
# generator and discriminator learning rates are equal; confirm intended.
optimizer_param = 0.9  # beta1 for Adam optimizer / decay for RMSProp
iterations = 100000  # No. of iterations to train model (an int, not the float 1e5)
image_size = 32  # Size of actual images, Size of images to be generated at.
# Model selector: 0 - plain GAN (not implemented here), 1 - conditional WGAN (CWGAN).
# NOTE: the build cell replaces this integer with the model object.
model = 1
optimizer = "Adam"  # Optimizer to use for training
gen_dim = 16  # dimension of first layer in generator
mode = "train"  # train / visualize model
# Log directory handed to model.initialize_network; override via the
# CWGAN_LOGS_DIR environment variable instead of editing this path.
logs_dir = os.environ.get(
    'CWGAN_LOGS_DIR',
    '/home/ben/Dissertation/Multitask-Learning-With-GANs/4_cwgan_3/tf-log')

In [None]:
# CIFAR-100 training split, served in batches matching the model's batch size.
# Optional preprocessing (currently disabled):
#     train_data.inputs = center_crop(train_data.inputs, rng)
train_data = CIFAR100DataProvider(batch_size=batch_size, which_set='train')

In [None]:
# Layer widths: the generator narrows down to 3 output channels, the
# discriminator widens from 3 input channels to a single output.
generator_dims = [128 * gen_dim, 64 * gen_dim // 2, 64 * gen_dim // 4, 3]
discriminator_dims = [3, 64, 64 * 2, 64 * 4, 64 * 8, 1]

# Validate the mode before doing any expensive graph construction; previously
# an unknown mode silently did nothing after the network was built.
if mode not in ("train", "visualize"):
    raise ValueError("Unknown mode - mode=%r (expected 'train' or 'visualize')" % (mode,))

# `model` holds the integer selector from the configuration cell until it is
# replaced by the model object below, so re-running this cell requires
# re-running the configuration cell first. Fail with a clear message rather
# than the confusing TypeError from the "%d" format in the else-branch.
if not isinstance(model, int):
    raise RuntimeError(
        "`model` now holds a %s object, not the integer model selector; "
        "re-run the configuration cell before rebuilding the network."
        % type(model).__name__)

if model == 0:
    raise NotImplementedError("model=0 (plain GAN) is not supported in this notebook")
elif model == 1:
    # Conditional WGAN. The fifth positional argument (100) presumably is the
    # number of conditioning classes for CIFAR-100 -- TODO confirm against CWGAN.
    # NOTE(review): both weight clipping and a gradient penalty are requested;
    # confirm how CWGAN combines them.
    model = CWGAN(z_dim, batch_size, train_data, image_size, 100,
                  clip_values=(-0.01, 0.01),
                  critic_iterations=5,
                  penalise_gradient_norm=True, lmbda=10)
else:
    raise ValueError("Unknown model identifier - model=%d" % model)

model.create_network(generator_dims, discriminator_dims, model_name, optimizer,
                     gen_learning_rate, disc_learning_rate, optimizer_param)
model.initialize_network(logs_dir)

if mode == "train":
    model.train_model(int(1 + iterations))
else:  # mode == "visualize", validated above
    model.visualize_model()

Setting up model...
Initializing network...
Training CWGAN model...
Time: 0.516834/itr, Step: 100, generator loss: 0.286539, discriminator_loss: -0.226323
Time: 0.301517/itr, Step: 200, generator loss: 0.186883, discriminator_loss: -0.11647
Time: 0.301485/itr, Step: 300, generator loss: 0.0069702, discriminator_loss: -0.113229
Time: 0.301232/itr, Step: 400, generator loss: 0.0154631, discriminator_loss: -0.0691061
Time: 0.312119/itr, Step: 500, generator loss: -0.122473, discriminator_loss: -0.0773889
Saving images... Step: 500
Time: 0.303511/itr, Step: 600, generator loss: 0.428218, discriminator_loss: -0.135569
Time: 0.302392/itr, Step: 700, generator loss: 0.357727, discriminator_loss: -0.0327141
Time: 0.302097/itr, Step: 800, generator loss: 0.326918, discriminator_loss: -0.070046
Time: 0.302356/itr, Step: 900, generator loss: 0.237024, discriminator_loss: -0.0555693
Time: 0.310601/itr, Step: 1000, generator loss: 0.167078, discriminator_loss: -0.0282843
Saving images... Step: 1000

In [None]:
def class_to_vec(class_name, label_map=None, num_rows=None):
    """Build a batch of identical one-hot rows selecting ``class_name``.

    Args:
        class_name: label to encode; must appear in ``label_map``.
        label_map: sequence of class labels whose order defines the one-hot
            positions. Defaults to ``train_data.label_map`` (previously read
            from a global ``map`` that was only defined in a later cell).
        num_rows: number of rows to return. Defaults to ``batch_size``.

    Returns:
        float64 ndarray of shape (num_rows, len(label_map)); every row has a
        single 1.0 at the index of ``class_name``.

    Raises:
        ValueError: if ``class_name`` is not in ``label_map``.
    """
    if label_map is None:
        label_map = train_data.label_map
    if num_rows is None:
        num_rows = batch_size

    labels = list(label_map)
    try:
        class_index = labels.index(class_name)
    except ValueError:
        raise ValueError("Unknown class name %r: not among the %d labels in label_map"
                         % (class_name, len(labels)))

    one_hot = np.zeros([len(labels)])
    one_hot[class_index] = 1
    # Repeat the same one-hot row once per batch element.
    return np.tile(one_hot, (num_rows, 1))

In [None]:
# Class-name lookup table for the training split. This binds the global name
# `map`, shadowing the builtin; the name is kept because `class_to_vec` may
# read labels through this global.
map = train_data.label_map

# Sanity check: decode a one-hot vector back to its class name.
a = np.array([0, 0, 1])
i = np.flatnonzero(a == 1)[0]
map[i]

In [None]:
print("Sampling images from model...")
# One-hot conditioning rows for the requested class, one per batch element.
class_v = class_to_vec("sunflower")

# Noise part of the generator input. It is sized by the configured z_dim
# rather than a hard-coded total of 200, so it tracks the model
# configuration. It is drawn from the seeded `rng` for reproducibility.
# Assumes model.z_vec has width z_dim + number of classes -- TODO confirm.
noise = rng.uniform(-1.0, 1.0, size=[batch_size, z_dim]).astype(np.float32)
z = np.concatenate([noise, class_v.astype(np.float32)], axis=1)

# NOTE(review): train_phase=True is fed while sampling; if the generator keys
# batch-norm behaviour on this flag, batch statistics are used -- confirm.
feed_dict = {model.z_vec: z, model.train_phase: True}

images = model.sess.run(model.gen_images, feed_dict=feed_dict)
shape = [1, 1]  # grid layout passed to save_imshow_grid
utils.save_imshow_grid(images, model.logs_dir, "generated_sample_" + str(time.time()) + ".png", shape=shape)

In [None]:
# Hand a fresh copy of the conditioning batch to the model's own sampling helper.
model.generate_sample(np.array(class_v, copy=True))

In [None]:
# First element of the generator output. `gen_images` is fetched with
# `model.sess.run` elsewhere, so this presumably displays a graph element's
# repr, not pixel values.
first_generated_image = model.gen_images[0]
first_generated_image