In [None]:
import tensorflow as tf
import keras
from keras import backend as K
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
print(tf.__version__)
print(keras.__version__)
import cv2  # for image processing
from sklearn.model_selection import train_test_split
import scipy.io
import os
import h5py
from arts_preprocess_utils import load_dataset
from keras.preprocessing.image import ImageDataGenerator
from IPython import display

In [None]:
# !!! remember to clear session/graph if you rebuild your graph to avoid out-of-memory errors !!!
def reset_tf_session():
    """Discard the current Keras session and TF default graph, then open a new session.

    Returns:
        The new session object obtained from the Keras backend.
    """
    K.clear_session()
    tf.reset_default_graph()
    return K.get_session()

In [None]:
# Start from a clean session/graph; the trailing ';' hides the returned Session object's repr.
reset_tf_session();

## Load Dataset

In [None]:
# HDF5 dataset file; load_dataset returns train images/labels followed by test images/labels.
dataset_path = './wikiart_mini_portrait.h5'
(train_set_x_orig, train_set_y_orig,
 test_set_x_orig, test_set_y_orig) = load_dataset(dataset_path)

In [None]:
# Image height/width read from axes 1 and 2 of the training tensor (assumed (m, H, W, C) — confirm).
img_Height, img_Width = train_set_x_orig.shape[1], train_set_x_orig.shape[2]
# Number of distinct labels present in the test split.
N_CLASSES = np.unique(test_set_y_orig).size

In [None]:
# Short aliases; the original test split serves as the dev set.
X_train, y_train = train_set_x_orig, train_set_y_orig
X_dev, y_dev = test_set_x_orig, test_set_y_orig

**Get impressionist images**

In [None]:
# Label value used for impressionist paintings in this notebook.
IMPRESSIONIST_LABEL = 1
# flatnonzero(ravel(...)) yields plain sample indices whether labels are stored as (m,), (m, 1)
# or (1, m); argwhere(...).reshape(-1) would interleave row/column coordinates for 2-D labels.
train_imp_index = np.flatnonzero(np.ravel(y_train) == IMPRESSIONIST_LABEL)
test_imp_index = np.flatnonzero(np.ravel(y_dev) == IMPRESSIONIST_LABEL)

In [None]:
# Keep only the impressionist samples by indexing the leading (sample) axis.
X_train_imp = X_train[train_imp_index]
X_dev_imp = X_dev[test_imp_index]

In [None]:
# GAN training pool: impressionist images from both splits.
# Concatenate the *_imp subsets, not the full X_train/X_dev, so that only the selected style is modelled.
X_imp = np.concatenate((X_train_imp, X_dev_imp), axis=0)
X_imp.shape

In [None]:
# Normalize pixel values from [0, 255] to [0, 1] as float32.
# The guard makes the cell idempotent: after the first run every value is <= 1, so a
# re-run does not divide by 255 a second time.
if X_imp.size and X_imp.max() > 1.0:
    X_imp = X_imp.astype(np.float32) / 255.

**Plot image**

In [None]:
# The channel axis is reversed for display, i.e. the stored images are treated as BGR
# (OpenCV order) — NOTE(review): confirm against arts_preprocess_utils.
fig, ax = plt.subplots()
ax.imshow(X_imp[0][..., ::-1])
ax.set_title('Training image 0')
ax.axis('off');

### Discriminator

In [None]:
from models import Discriminator_model

# Architecture based on art-DCGAN (robbiebarrat).
# The builder gets its own name: `discriminator_model` is later rebound to the compiled
# Keras Model that train_gan() trains, and reusing one name for both objects makes an
# out-of-order re-run silently hand the builder to the training loop.
discriminator_builder = Discriminator_model(filters=40, code_shape=100)
# NOTE(review): the meaning of the third positional argument (False) is defined in models.py — confirm.
discriminator = discriminator_builder.get_model((img_Height, img_Width, 3), N_CLASSES, False)

In [None]:
# Print the discriminator's layer summary.
discriminator.summary()

In [None]:
# Warm-start the discriminator from weights saved by an earlier run (the save cell near the end writes this file).
discriminator.load_weights('./discriminator01.h5')

### Generator

In [None]:
# Size of the generator's latent noise vector; noise batches are shaped (batch, 1, 1, NOISE).
NOISE = 100

**Complex generator**

In [None]:
from models import Generator_model_complex

# Architecture based on art-DCGAN (robbiebarrat).
# The builder gets its own name: `generator_model` is later rebound to the compiled
# stacked Model (z -> G(z) -> D(G(z))) that train_gan() trains.
generator_builder = Generator_model_complex(filters=80, code_shape=(1, 1, NOISE), leaky_alpha=0.001)
generator = generator_builder.get_model((img_Height, img_Width, 3))

In [None]:
# Print the generator's layer summary.
generator.summary()

In [None]:
# Initial generator weights. By default the generator is warm-started from a pretrained
# autoencoder's decoder. Set this to './generator01.h5' to resume from weights saved by a
# previous GAN run instead.
GENERATOR_WEIGHTS = './decoder01.h5'
generator.load_weights(GENERATOR_WEIGHTS)

### Content restriction

In [None]:
# TODO: content restriction is not implemented yet.

## Build the training models

**Discriminator model**

In [None]:
import keras.layers as L
from keras.models import Model

# Discriminator training phase: discriminator weights trainable, generator frozen.
# The flags are set before discriminator_model is compiled (Keras 2 reads them at compile time).
for net, is_trainable in ((discriminator, True), (generator, False)):
    for layer in net.layers:
        layer.trainable = is_trainable
    net.trainable = is_trainable

# Two inputs: a batch of real images and a batch of latent noise vectors.
real_samples = L.Input(shape=X_imp.shape[1:], name='real_samples')
noisy_input = L.Input(shape=(1, 1, NOISE))

# One discriminator instance scores both the generated and the real batch.
generated_samples = generator(noisy_input)
generated_samples_prediction = discriminator(generated_samples)
real_samples_prediction = discriminator(real_samples)

# Outputs are ordered [D(x), D(G(z))]; train_gan() passes targets in the same order.
discriminator_model = Model(
    inputs=[real_samples, noisy_input],
    outputs=[real_samples_prediction, generated_samples_prediction],
)


In [None]:
# One binary cross-entropy term per output (D(x) on real, D(G(z)) on generated images),
# averaged with equal weight.
discriminator_model.compile(
    loss='binary_crossentropy',
    loss_weights=[0.5, 0.5],
    # Documented `Adamax` class rather than the lowercase `adamax` alias, which is not present in every Keras release.
    optimizer=keras.optimizers.Adamax(lr=1e-3),
    metrics=['accuracy'],  # report accuracy during training
)

In [None]:
# Inspect the combined discriminator training model (inputs: real images + noise; outputs: D(x), D(G(z))).
discriminator_model.summary()

**Generator model**

In [None]:

# Generator training phase: discriminator frozen, generator trainable.
for net, is_trainable in ((discriminator, False), (generator, True)):
    for layer in net.layers:
        layer.trainable = is_trainable
    net.trainable = is_trainable

# Stacked model z -> G(z) -> D(G(z)). train_gan() fits it against "real" targets, so
# only the generator learns to fool the frozen discriminator.
z = L.Input(shape=(1, 1, NOISE))
img = generator(z)

real = discriminator(img)
generator_model = Model(z, real)

In [None]:
# Binary cross-entropy of D(G(z)) against "real" targets, with a smaller learning rate (1e-4)
# than the discriminator's optimizer (1e-3).
generator_model.compile(
    loss='binary_crossentropy',
    # Documented `Adamax` class rather than the lowercase `adamax` alias, which is not present in every Keras release.
    optimizer=keras.optimizers.Adamax(lr=1e-4),
)

In [None]:
# Inspect the stacked generator training model z -> G(z) -> D(G(z)).
generator_model.summary()

## Training

In [None]:
batch_size = 64

# Adversarial ground truths for one batch: 1.0 = real, 0.0 = generated.
valid, fake = np.ones(batch_size), np.zeros(batch_size)

In [None]:
from gan_utils import noisy_images, sample_images, sample_probas

def train_gan(X, gen_size, epochs=200000, sample_interval=5000, history=None):
    """Alternately train the discriminator and the generator, one mini-batch each per epoch.

    Relies on notebook globals: `batch_size`, the compiled `discriminator_model` and
    `generator_model`, and the underlying `discriminator` / `generator` networks.

    Parameters
    ----------
    X : array
        Pool of normalized real images; each epoch draws `batch_size` rows at random
        (with replacement).
    gen_size : tuple
        Shape of one latent noise sample, e.g. (1, 1, NOISE).
    epochs : int
        Number of training iterations.
    sample_interval : int
        `sample_images` is called every `sample_interval` epochs.
    history : dict or None
        Optional. If given, every 100 epochs the generator loss is appended to
        `history['g_loss']` and the total discriminator loss to `history['d_loss']`.
        The dict is filled in place, so the recorded history survives a kernel interrupt.
    """
    # TODO add noise to real images after 2000 epochs

    # Labels are built locally so that a later cell rebinding the global `valid`
    # (e.g. to a full-dataset vector) cannot change the training targets.
    real_labels = np.ones((batch_size,))
    fake_labels = np.zeros((batch_size,))

    if history is None:
        history = {}
    g_loss_hist = history.setdefault('g_loss', [])
    d_loss_hist = history.setdefault('d_loss', [])

    size = (batch_size,) + gen_size
    sample_size = (1000,) + gen_size  # loop-invariant, used by the periodic probability report
    half_batch = batch_size // 2

    for epoch in range(epochs):

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Select a random batch of real images.
        idx = np.random.randint(0, X.shape[0], batch_size)
        imgs = X[idx]
        # Every 10th epoch, half of the real batch is replaced by noise-perturbed copies
        # (perturbation defined by gan_utils.noisy_images).
        if epoch % 10 == 0:
            noisy_imgs = noisy_images(imgs[:half_batch])
            imgs = np.concatenate([imgs[half_batch:], noisy_imgs])

        # Latent noise for the generator.
        noise = np.random.normal(0, 1, size=size)

        # d_loss = [total loss, loss D(x), loss D(G(z)), acc D(x), acc D(G(z))]
        d_loss = discriminator_model.train_on_batch([imgs, noise], [real_labels, fake_labels])

        # ---------------------
        #  Train Generator
        # ---------------------

        noise = np.random.normal(0, 1, size=size)

        # Train the generator so that the discriminator labels its samples as real.
        g_loss = generator_model.train_on_batch(noise, real_labels)

        # Report progress every 100 epochs. d_loss[2] is the discriminator's loss on
        # generated images; the generator's own loss is g_loss.
        if epoch % 100 == 0:
            display.clear_output(wait=True)
            print("%d [D loss real: %f, fake: %f, total: %f] [D acc real: %.2f%%, fake: %.2f%%] [G loss: %f]"
                  % (epoch, d_loss[1], d_loss[2], d_loss[0], 100 * d_loss[3], 100 * d_loss[4], g_loss))
            g_loss_hist.append(g_loss)
            d_loss_hist.append(d_loss[0])
            sample_probas(X, 1000, sample_size, discriminator=discriminator, generator=generator)

        # At every sample interval, save generated image samples.
        if epoch % sample_interval == 0:
            sample_images(epoch, gen_size, generator)


In [None]:
# Run the adversarial training loop on X_imp with (1, 1, NOISE) latent vectors
# (defaults: 200000 epochs, image samples every 5000 epochs).
train_gan(X_imp, gen_size=(1,1,NOISE))

In [None]:
# Save the trained generator; a later session can load 'generator01.h5' instead of the autoencoder decoder weights.
generator.save_weights(filepath='generator01.h5')

In [None]:
# Save the trained discriminator; the discriminator section warm-starts from this file.
discriminator.save_weights(filepath='discriminator01.h5')

**Test discriminator**

In [None]:
# X_imp is already scaled to [0, 1] by the normalization cell, so use it as-is;
# dividing by 255 again would squash every pixel into [0, 1/255].
X = X_imp
# Every image in X is real, so its target label is 1. It gets its own name because
# rebinding the global `valid` would replace the per-batch labels the training loop uses.
y_real = np.ones(X.shape[0])

In [None]:
# Rounded discriminator outputs on the real images (1 = judged real).
# NOTE(review): assumes one sigmoid probability per image, as the reshape to 1-D implies — confirm in models.py.
pred = discriminator.predict(X).round().reshape((-1,))
# Fraction of real images the discriminator accepts as real.
(pred == 1).mean()

**Test on a new image**

In [None]:
# Load an external portrait and bring it to the training resolution and [0, 1] range.
img = plt.imread('berni_retrato.jpg')
# cv2.resize takes the target size as (width, height), not (height, width).
img = cv2.resize(img, (img_Width, img_Height), interpolation=cv2.INTER_CUBIC)
# NOTE(review): plt.imread returns RGB, whereas the training images appear to be stored in
# BGR order (the sample plot reverses channels with [..., ::-1]). Confirm the channel order
# before trusting the discriminator's verdict on this image.
img_norm = img * (1./255)
img_norm = np.expand_dims(img_norm, axis=0)  # add a batch axis: (1, img_Height, img_Width, channels)

In [None]:
# Show the resized portrait (the single item in the batch).
portrait = img_norm[0]
plt.imshow(portrait)

In [None]:
# Discriminator's rounded verdict for the new portrait.
np.round(discriminator.predict(img_norm))

**Test generator**

In [None]:
# Draw latent vectors from the same N(0, 1) distribution that train_gan() uses, with the
# generator's latent size NOISE. Uniform [0, 1) noise would be out of distribution for the trained generator.
n_fakes = 100
noise = np.random.normal(0, 1, size=(n_fakes, 1, 1, NOISE))
fakes = generator.predict(noise)

In [None]:
# The generator is trained to mimic X_imp, so reverse the channel axis for display, as
# the real-sample plot does.
# NOTE(review): the generator's output range is defined in models.py; imshow clips float RGB outside [0, 1].
plt.imshow(fakes[0][..., ::-1])

In [None]:
# Summarize the first generated image (shape and value range) instead of dumping every pixel.
fakes[0].shape, float(fakes[0].min()), float(fakes[0].max())