In [None]:
import tensorflow as tf
import keras
from keras import backend as K
from keras.layers.merge import _Merge
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
print(tf.__version__)
print(keras.__version__)
import cv2  # for image processing
from sklearn.model_selection import train_test_split
import scipy.io
import os
import h5py
from arts_preprocess_utils import load_dataset
from keras.preprocessing.image import ImageDataGenerator
from IPython import display
from wassertstein_utils import RandomWeightedAverage, gradient_penalty_loss, wasserstein_loss

In [None]:
# Samples per critic/generator update; also sizes the target arrays and RandomWeightedAverage.
BATCH_SIZE = 16
TRAINING_RATIO = 5  # Number of discriminator updates per generator update. The paper uses 5.
GRADIENT_PENALTY_WEIGHT = 10  # Weight of the gradient-penalty term, as per the paper

In [None]:
# !!! call this whenever the graph is rebuilt, otherwise stale graph state can cause out-of-memory errors !!!
def reset_tf_session():
    """Drop the current Keras/TensorFlow graph state and hand back a session.

    Returns:
        The object returned by ``K.get_session()`` after the Keras session was
        cleared and the default TensorFlow graph was reset.
    """
    K.clear_session()
    tf.reset_default_graph()
    return K.get_session()

In [None]:
# Clear the Keras session and default graph before any model is built below.
reset_tf_session()

## Load dataset

In [None]:
# Location of the HDF5 dataset. Set the WIKIART_DATASET_PATH environment variable to run
# the notebook on another machine without editing this cell; the default is unchanged.
DATASET_PATH = os.environ.get('WIKIART_DATASET_PATH',
                              '/root/work/datasets/wikiart_mini_portrait.h5')
train_set_x_orig, train_set_y_orig, test_set_x_orig, test_set_y_orig = load_dataset(DATASET_PATH)

In [None]:
# Axes 1 and 2 of the training image array give the spatial size used for both networks.
img_Height, img_Width = train_set_x_orig.shape[1:3]
# Number of distinct label values found in the test split.
N_CLASSES = np.unique(test_set_y_orig).size

In [None]:
# Training arrays are plain aliases of what the loader returned (no copy is made).
X_train, y_train = train_set_x_orig, train_set_y_orig

# The loader's test split is used as the dev set.
X_dev, y_dev = test_set_x_orig, test_set_y_orig

In [None]:
# Preview the first training image. [..., ::-1] reverses the last (channel) axis before
# plotting — presumably BGR -> RGB; confirm against how the dataset was written.
plt.imshow(X_train[0][...,::-1])

## Discriminator and generator base model

In [None]:
from models import Discriminator_model, Generator_model_complex

# Length of the latent noise vector fed to the generator.
code_shape = 100

# The builder objects get their own names: `generator_model` and `discriminator_model`
# are bound further down to the compiled training graphs, and re-running this cell must
# not replace those graphs with builder objects.

# based on art-DCGAN (robbiebarrat)
generator_builder = Generator_model_complex(filters=80, code_shape=(1, 1, code_shape))
generator = generator_builder.get_model((img_Height, img_Width, 3))

# based on art-DCGAN (robbiebarrat)
discriminator_builder = Discriminator_model(filters=40, code_shape=code_shape, include_top=False)
discriminator = discriminator_builder.get_model((img_Height, img_Width, 3), N_CLASSES, False)

In [None]:
import keras.layers as L

# Append the critic head: flatten the features, then a single-unit Dense layer
# (no activation argument is passed, so the output is a raw score).
for head_layer in (L.Flatten(), L.Dense(1, kernel_initializer='he_normal')):
    discriminator.add(head_layer)

In [None]:
# Layer-by-layer overview of the critic, including the Flatten/Dense head.
discriminator.summary()

In [None]:
# Layer-by-layer overview of the generator network.
generator.summary()

### Create generator model

In [None]:
import keras.layers as L
from keras.models import Model
from keras.optimizers import Adam

# Graph used for generator updates: noise -> generator -> critic score.
# The critic's layers are marked non-trainable before compiling, the generator's
# layers trainable, so this model is meant to update only the generator.
for network, is_trainable in ((discriminator, False), (generator, True)):
    for layer in network.layers:
        layer.trainable = is_trainable

generator_input = L.Input(shape=(1, 1, code_shape))
generator_layers = generator(generator_input)
discriminator_layers_for_generator = discriminator(generator_layers)
generator_model = Model(
    inputs=[generator_input],
    outputs=[discriminator_layers_for_generator],
)
# Adam parameters taken from Gulrajani et al.
generator_model.compile(
    loss=wasserstein_loss,
    optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9),
)

In [None]:
# Overview of the stacked generator -> critic model used for generator updates.
generator_model.summary()

### Create discriminator model

In [None]:
from keras.models import Model

# Graph used for critic updates. The same critic scores three batches: real images,
# generated images, and a RandomWeightedAverage blend of the two (the blended batch is
# what the gradient-penalty loss is given).
# Trainable flags are flipped relative to the generator graph: critic on, generator off.
for network, is_trainable in ((discriminator, True), (generator, False)):
    for layer in network.layers:
        layer.trainable = is_trainable

real_samples = L.Input(shape=X_train.shape[1:])
generator_input_for_discriminator = L.Input(shape=(1, 1, code_shape))

generated_samples_for_discriminator = generator(generator_input_for_discriminator)
discriminator_output_from_generator = discriminator(generated_samples_for_discriminator)
discriminator_output_from_real_samples = discriminator(real_samples)

averaged_samples = RandomWeightedAverage(BATCH_SIZE)(
    [real_samples, generated_samples_for_discriminator]
)
averaged_samples_out = discriminator(averaged_samples)

# Output order matters: it must line up with the loss list and the target arrays.
discriminator_model = Model(
    inputs=[real_samples, generator_input_for_discriminator],
    outputs=[
        discriminator_output_from_real_samples,
        discriminator_output_from_generator,
        averaged_samples_out,
    ],
)

**Define loss functions**

In [None]:
from functools import partial

# The gradient penalty loss needs the averaged samples to take gradients, but Keras loss
# functions only receive two arguments, y_true and y_pred. partial() binds the extra
# arguments up front. The weight is read from GRADIENT_PENALTY_WEIGHT so that changing the
# constant actually changes the loss (it used to be a repeated literal 10).
partial_gp_loss = partial(gradient_penalty_loss,
                          averaged_samples=averaged_samples,
                          gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
partial_gp_loss.__name__ = 'gradient_penalty'  # Functions need names or Keras will throw an error

In [None]:
# One loss per model output, in output order: real score, generated score, and the
# gradient penalty on the averaged samples. Same Adam settings as the generator model.
discriminator_model.compile(
    loss=[wasserstein_loss, wasserstein_loss, partial_gp_loss],
    optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9),
)

In [None]:
# Overview of the three-output model used for critic updates.
discriminator_model.summary()

In [None]:
# Show the labels Keras assigns to this model's loss values (displayed as the cell output).
discriminator_model.metrics_names

## Training

In [None]:
# All-ones array shaped like y_train. It is passed to train_gan as the label argument,
# where it only feeds datagen.flow; the labels that flow yields are never read.
y_train_positive = np.ones_like(y_train)
# Display the shape as the cell output.
y_train_positive.shape

In [None]:
# Fixed target arrays, one row per sample in a batch: +1, -1 and 0.
_target_shape = (BATCH_SIZE, 1)
positive_y = np.ones(_target_shape, dtype=np.float32)
negative_y = np.negative(positive_y)
dummy_y = np.zeros_like(positive_y)

In [None]:
from keras.preprocessing.image import ImageDataGenerator

# Augmentation for the real training images; pixel values are scaled by 1/255.
augmentation_options = dict(
    rescale=1.0/255.,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
train_datagen = ImageDataGenerator(**augmentation_options)

# The dev images are only rescaled, never augmented.
validation_datagen = ImageDataGenerator(rescale=1.0/255.)

# NOTE(review): fit() is kept as in the original; it looks like it only matters when
# feature-wise normalisation or ZCA whitening is enabled — confirm before removing.
train_datagen.fit(X_train)
validation_datagen.fit(X_dev)

In [None]:
from gan_utils import noisy_images, sample_images, sample_probas

def train_gan(X, y_train, datagen, gen_size, epochs=5, sample_interval=1000):
    """Run the WGAN-GP training loop over batches drawn from ``datagen``.

    Every step takes ``BATCH_SIZE * TRAINING_RATIO`` images from ``datagen.flow``,
    performs ``TRAINING_RATIO`` critic updates on consecutive ``BATCH_SIZE`` slices
    and then one generator update. After each epoch the last losses are printed and
    recorded, and ``sample_probas`` is called.

    Args:
        X: array of real images handed to ``datagen.flow``.
        y_train: label array handed to ``datagen.flow``; the yielded labels are unused.
        datagen: object whose ``flow(X, y, batch_size=...)`` yields ``(x_batch, y_batch)``
            pairs indefinitely.
        gen_size: shape of a single noise sample (without the batch axis).
        epochs: number of passes over ``X``.
        sample_interval: every ``sample_interval`` epochs, sample images are produced
            and the generator/discriminator weights are saved.

    Returns:
        ``(d_loss_hist, g_loss_hist)``: one entry per epoch holding the last total
        critic loss and the last generator loss of that epoch.

    Raises:
        ValueError: if ``X`` holds fewer than ``BATCH_SIZE * TRAINING_RATIO`` samples,
            because no full step could ever be formed.
    """
    gen_size = tuple(gen_size)
    size = (BATCH_SIZE,) + gen_size
    minibatches_size = BATCH_SIZE * TRAINING_RATIO

    if len(X) < minibatches_size:
        raise ValueError(
            "train_gan needs at least BATCH_SIZE * TRAINING_RATIO = %d samples, got %d"
            % (minibatches_size, len(X)))

    # Histories live outside the epoch loop so they accumulate across epochs.
    d_loss_hist = []
    g_loss_hist = []

    for epoch in range(epochs):
        batches = 0
        d_loss = None
        g_loss = None

        for x_batch, _y_batch in datagen.flow(X, y_train, batch_size=minibatches_size):
            batches += 1

            # The final batch of a pass can be short when len(X) is not a multiple of the
            # step size. The target arrays and RandomWeightedAverage are built for exactly
            # BATCH_SIZE samples, so short batches are counted but not trained on.
            if len(x_batch) == minibatches_size:
                # ---------------------
                #  Train Discriminator
                # ---------------------
                for j in range(TRAINING_RATIO):
                    image_batch = x_batch[j * BATCH_SIZE:(j + 1) * BATCH_SIZE]
                    noise = np.random.normal(0, 1, size=size)
                    d_loss = discriminator_model.train_on_batch(
                        [image_batch, noise], [positive_y, negative_y, dummy_y])

                # ---------------------
                #  Train Generator
                # ---------------------
                noise = np.random.normal(0, 1, size=size)
                g_loss = generator_model.train_on_batch(noise, positive_y)

            if batches >= len(X) / minibatches_size:
                # we need to break the loop by hand because
                # the generator loops indefinitely
                break

        # Plot the progress
        display.clear_output(wait=True)
        print ("%d [D loss: %f] [D(G(z)) loss: %f] loss: %f" % (epoch, d_loss[1], d_loss[2], d_loss[0]))
        g_loss_hist.append(g_loss)
        d_loss_hist.append(d_loss[0])
        sample_size = (1000,) + gen_size
        #TODO:change because discriminator do not classify between 0-1
        sample_probas(X, 1000, sample_size, discriminator=discriminator, generator=generator)

        # If at save interval => save generated image samples
        if epoch % sample_interval == 0:
            sample_images(epoch, gen_size, generator)
            #checkpoint to save weights
            generator.save_weights(filepath='generator_wasserstein.h5')
            discriminator.save_weights(filepath='discriminator_wasserstein.h5')

    return d_loss_hist, g_loss_hist

In [None]:
# The noise shape is derived from code_shape so it cannot drift from the generator's input size.
train_gan(X=X_train, y_train=y_train_positive, datagen=train_datagen, gen_size=(1, 1, code_shape))