In [None]:
import tensorflow as tf
import keras
from keras import backend as K
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
print(tf.__version__)
print(keras.__version__)
import cv2  # for image processing
from sklearn.model_selection import train_test_split
import scipy.io
import os
import h5py
from arts_preprocess_utils import load_dataset
from keras.preprocessing.image import ImageDataGenerator
from IPython import display

In [None]:
# !!! remember to clear session/graph if you rebuild your graph to avoid out-of-memory errors !!!
def reset_tf_session():
    """Drop the current Keras/TensorFlow graph and hand back a fresh session.

    Run this before rebuilding models in the notebook; otherwise stale graphs
    accumulate across cell re-runs.

    Returns:
        The session object returned by ``K.get_session()`` after the reset.
    """
    K.clear_session()
    tf.reset_default_graph()
    return K.get_session()

In [None]:
# Start from a clean Keras/TensorFlow graph; the returned session is shown as the cell output.
reset_tf_session()

## Load Dataset

In [None]:
train_set_x_orig, train_set_y_orig, test_set_x_orig, test_set_y_orig = load_dataset('./wikiart_mini_portrait.h5')

In [None]:
img_Height = train_set_x_orig.shape[1]
img_Width = train_set_x_orig.shape[2]
N_CLASSES = len(np.unique(test_set_y_orig))

In [None]:
X_train = train_set_x_orig
y_train = train_set_y_orig

X_dev = test_set_x_orig
y_dev = test_set_y_orig

**Get impressionist images**

In [None]:
train_imp_index = np.argwhere(y_train == 1).reshape((-1,))
test_imp_index = np.argwhere(y_dev == 1 ).reshape((-1,))

In [None]:
X_train_imp = X_train[train_imp_index, :, :, :]
X_dev_imp = X_dev[test_imp_index,:, :, :]

In [None]:
X_imp = np.concatenate((X_train, X_dev), axis=0)
X_imp.shape

**Plot image**

In [None]:
# Preview the first image: clip to the 8-bit range, cast, and reverse the channel axis for display
# (stored channel order assumed BGR — TODO confirm).
plt.imshow(X_imp[0].clip(0, 255).astype(np.uint8)[:, :, ::-1])

### Discriminator

**Inception**

In [None]:
from models import Inception_model

# Inception-based discriminator candidate, built for (img_Height, img_Width, 3) inputs and N_CLASSES.
# The meaning of the trailing False flag is defined in models.py (TODO document here).
inception_model = Inception_model()
model = inception_model.get_model((img_Height, img_Width, 3), N_CLASSES, False)

In [None]:
# Load stored weights for the Inception variant.
model.load_weights('./inception01.h5')

**ResNet50**

In [None]:
from models import ResNet_model

# ResNet50-based discriminator candidate, built with the same get_model arguments as the Inception variant.
# The instance name must match on both lines; a mismatch here raises NameError.
resnet_model = ResNet_model()
model = resnet_model.get_model((img_Height, img_Width, 3), N_CLASSES, False)

**Simple Discriminator**

In [None]:
from models import Discriminator_model

# Simple discriminator based on art-DCGAN (robbiebarrat).
# TODO: change the LeakyReLU parameter.
discriminator_model = Discriminator_model(filters=40, code_shape=100)
model = discriminator_model.get_model((img_Height, img_Width, 3), N_CLASSES, False)

In [None]:
# The Inception, ResNet and simple-discriminator cells all assign `model`; whichever ran last is used.
discriminator = model

In [None]:
discriminator.summary()

In [None]:
# Make every layer trainable, then lower BatchNormalization momentum (Keras default 0.99)
# so the moving mean/variance adapt faster to this dataset.
for layer in discriminator.layers:
    layer.trainable = True
    if isinstance(layer, keras.layers.BatchNormalization):
        # A smaller momentum gives recent batches more weight in the running statistics
        # (i.e. less smoothing), so they track the new data sooner.
        layer.momentum = 0.8
    
# Freeze everything except the last 5 layers, so fine-tuning only updates those.
for layer in discriminator.layers[:-5]:
    layer.trainable = False

In [None]:
# Freeze the discriminator model as a whole.
# NOTE(review): Keras reads `trainable` at compile() time, so a model compiled while this is
# False will not update its weights in train_on_batch.
discriminator.trainable = False

In [None]:
# Load discriminator weights written by save_weights at the end of a previous run.
discriminator.load_weights('./discriminator01.h5')

**Test discriminator**

In [None]:
X = X_imp * (1./255)

In [None]:
valid = np.ones((X.shape[0]))

In [None]:
valid.shape

In [None]:
discriminator.predict(x=X[:10])

In [None]:
discriminator.train_on_batch(x=X[:10], y=valid[:10])

In [None]:
pred = discriminator.predict(X).round().reshape((-1,))

In [None]:
unique, counts = np.unique(pred, return_counts=True)

In [None]:
unique

In [None]:
counts

In [None]:
img = plt.imread('berni_retrato.jpg')

In [None]:
plt.imshow(img_norm[0])

In [None]:
img = cv2.resize(img, (img_Height, img_Width), interpolation=cv2.INTER_CUBIC)
img_norm = img *(1./255)
img_norm = np.expand_dims(img_norm, axis=0)

In [None]:
discriminator.predict(img_norm).round()

### Generator

In [None]:
# Dimensionality of the latent noise vector z fed to the generator.
NOISE = 100

**Simple Generator**

In [None]:
from models import Generator_model

# Simple generator; get_model receives the latent size NOISE.
generator_model = Generator_model()
generator = generator_model.get_model(NOISE)

**Complex generator**

In [None]:
from models import Generator_model_complex

# Complex generator based on art-DCGAN (robbiebarrat), with a (1, 1, NOISE) latent code.
# get_model receives (img_Height, img_Width, 3), presumably the output image shape — confirm in models.py.
# NOTE(review): on Restart & Run All this cell rebinds `generator`, replacing the simple generator.
generator_model = Generator_model_complex(filters=80, code_shape= (1,1,NOISE))
generator = generator_model.get_model((img_Height, img_Width, 3))

In [None]:
# Initialise the generator from a pretrained autoencoder's decoder weights.
generator.load_weights('./decoder01.h5')

In [None]:
# Alternatively, resume from the generator weights saved at the end of a previous GAN run.
# NOTE(review): on Restart & Run All this overwrites the decoder weights loaded in the previous cell;
# run only one of the two.
generator.load_weights('./generator01.h5')

In [None]:
generator.summary()

In [None]:
# NOTE(review): train_gan updates the generator through combined_model, so this standalone
# compile does not appear to be used for training — confirm before tuning its learning rate.
generator.compile(loss='binary_crossentropy',
                 optimizer=keras.optimizers.adamax(lr=1e-2))

In [None]:
# Latent batch of 100 samples for a quick visual check. The last axis must equal NOISE to match
# the generator's (1, 1, NOISE) input, and z is drawn from N(0, 1) like the noise used in train_gan.
noise = np.random.normal(0, 1, size=(100, 1, 1, NOISE))

In [None]:
fakes = generator.predict(noise)

In [None]:
plt.imshow(fakes[90][...,::-1])

In [None]:
plt.imshow(X[1])

In [None]:
type(X[0])

In [None]:
discriminator.predict(fakes).round()

### Content restriction

In [None]:
#TODO

#### Auxiliary function to save generated sample images

In [None]:
def sample_images(epoch, gen_size):
    """Save a 2x2 grid of generator samples to ``images/<epoch>.png``.

    Args:
        epoch: Training iteration number; used as the output file name.
        gen_size: Shape of one latent sample, excluding the batch axis,
            e.g. ``(1, 1, NOISE)``. Any sequence is accepted.

    Uses the module-level ``generator``. Creates the ``images`` directory if missing.
    """
    r, c = 2, 2
    size = (r * c,) + tuple(gen_size)
    noise = np.random.normal(0, 1, size=size)
    gen_imgs = generator.predict(noise)

    fig, axs = plt.subplots(r, c)
    # axs.flat walks the grid row by row, matching the sample order in gen_imgs.
    for cnt, ax in enumerate(axs.flat):
        # Reverse the channel axis for display, matching the generated-image preview in this notebook
        # (stored channel order assumed BGR — TODO confirm).
        ax.imshow(gen_imgs[cnt][..., ::-1])
        ax.axis('off')

    # savefig does not create missing directories.
    os.makedirs("images", exist_ok=True)
    fig.savefig("images/%d.png" % epoch)

    # Close this specific figure so repeated calls during training do not accumulate open figures.
    plt.close(fig)

**Auxiliary function to plot the distributions of discriminator scores D(x) (real) and D(G(z)) (generated)**

In [None]:
def sample_probas(batch_size, gen_size):
    """Overlay histograms of discriminator scores on real vs. generated images.

    Args:
        batch_size: Number of real images drawn (indices may repeat) from ``X``.
        gen_size: Full shape of the latent batch, *including* the batch axis,
            e.g. ``(1000, 1, 1, NOISE)``.

    Uses the module-level ``X``, ``generator`` and ``discriminator``.
    """
    plt.title('Generated vs real data')

    real_batch = X[np.random.randint(0, X.shape[0], batch_size)]
    real_scores = discriminator.predict(real_batch)[:, 0]
    plt.hist(real_scores, label='D(x)', alpha=0.5, range=[0, 1])

    z = np.random.normal(0, 1, size=gen_size)
    fake_scores = discriminator.predict(generator.predict(z))[:, 0]
    plt.hist(fake_scores, label='D(G(z))', alpha=0.5, range=[0, 1])

    plt.legend(loc='best')
    plt.show()

## Training

In [None]:
# Keras reads the `trainable` flag when compile() is called. Earlier in the notebook the discriminator
# was frozen (discriminator.trainable = False); compiling in that state would make
# discriminator.train_on_batch in train_gan leave every weight unchanged. Re-enable it for D's own
# training. The combined model sets trainable = False again before it is compiled.
# NOTE(review): in standalone Keras 2.x this model-level flag does not reset the per-layer flags set
# during fine-tuning setup; tf.keras propagates it to all sublayers — confirm for your version.
discriminator.trainable = True
discriminator.compile(
    loss='binary_crossentropy',  
    optimizer=keras.optimizers.adamax(lr=1e-4),
    metrics=['accuracy']  # report accuracy during training
)

In [None]:
import keras.layers as L
from keras.models import Model

# Stack G and D: z -> G(z) -> D(G(z)).
z = L.Input(shape=(1,1, NOISE))
# Named gen_img so the loaded portrait stored in `img` is not clobbered by a symbolic tensor.
gen_img = generator(z)

# Freeze D before the combined model is compiled, so generator updates leave D's weights fixed.
discriminator.trainable = False
real = discriminator(gen_img)
combined_model = Model(z, real)

In [None]:
# Compile the stacked G -> D model. train_gan fits it with "valid" (1) targets, pushing G toward
# samples that D scores as real.
combined_model.compile(
    loss='binary_crossentropy',  
    optimizer=keras.optimizers.adamax(lr=1e-4)
)

In [None]:
combined_model.summary()

In [None]:
batch_size = 64

# Real images rescaled to [0, 1] (assumes X_imp holds 0-255 pixel values — TODO confirm).
X = X_imp * (1./255)

# Adversarial targets for one minibatch: 1 = real, 0 = generated.
valid, fake = np.ones((batch_size,)), np.zeros((batch_size,))

In [None]:
def train_gan(gen_size, epochs = 10000, sample_interval = 2000, log_interval = 100):
    """Alternate discriminator and generator updates on random minibatches.

    Each "epoch" is a single iteration. D gets one real batch and one generated batch,
    then G gets one step through ``combined_model``.

    Args:
        gen_size: Shape of one latent sample, excluding the batch axis,
            e.g. ``(1, 1, NOISE)``. Any sequence is accepted.
        epochs: Number of iterations to run.
        sample_interval: Save a 2x2 sample grid (``sample_images``) every this
            many iterations.
        log_interval: Print losses, record history and plot the D(x)/D(G(z))
            histograms every this many iterations (default 100).

    Returns:
        ``(g_loss_hist, d_loss_hist)``: generator and discriminator losses
        recorded every ``log_interval`` iterations.

    Relies on the module-level ``X``, ``batch_size``, ``valid``, ``fake``,
    ``generator``, ``discriminator`` and ``combined_model``.
    """
    g_loss_hist = []
    d_loss_hist = []
    size = (batch_size,) + tuple(gen_size)
    probe_size = (1000,) + tuple(gen_size)  # latent batch for the histogram plot
    for epoch in range(epochs):

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Random batch of real images (indices may repeat).
        idx = np.random.randint(0, X.shape[0], batch_size)
        imgs = X[idx]

        noise = np.random.normal(0, 1, size=size)
        gen_imgs = generator.predict(noise)

        d_loss_real = discriminator.train_on_batch(imgs, valid)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        # Element-wise mean of [loss, accuracy] over the real and fake batches.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------

        # Fresh noise; G is rewarded when the frozen D labels its samples as valid.
        noise = np.random.normal(0, 1, size=size)
        g_loss = combined_model.train_on_batch(noise, valid)

        if epoch % log_interval == 0:
            display.clear_output(wait=True)
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
            g_loss_hist.append(g_loss)
            d_loss_hist.append(d_loss[0])
            sample_probas(1000, probe_size)

        if epoch % sample_interval == 0:
            sample_images(epoch, gen_size)

    return g_loss_hist, d_loss_hist


In [None]:
# Run the adversarial training loop with latent samples shaped (1, 1, NOISE).
train_gan(gen_size=(1,1,NOISE))

In [None]:
# Persist the trained generator; this is the file the notebook can reload via generator.load_weights.
generator.save_weights(filepath='generator01.h5')

In [None]:
# Persist the trained discriminator; this overwrites the discriminator01.h5 loaded earlier.
discriminator.save_weights(filepath='discriminator01.h5')