# WGAN-GP Training

## imports

In [None]:
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
# Let TensorFlow claim GPU memory on demand instead of reserving all of it
# when the session starts. (To cap usage at a fixed share instead, set
# gpu_options.per_process_gpu_memory_fraction, e.g. to 0.4.)
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
set_session(tf.Session(config=config))

In [None]:
%matplotlib inline

import os
import matplotlib.pyplot as plt

from models.WGANGP import WGANGP
from utils.loaders import load_celeb

import pickle


In [None]:
# run params
SECTION = 'wgan'
RUN_ID = '0003'
DATA_NAME = 'celeb'
RUN_FOLDER = f'run/{SECTION}/{RUN_ID}_{DATA_NAME}'

# Create the run folder and its output subfolders. makedirs also creates the
# missing parents ('run', 'run/<SECTION>'), and exist_ok=True makes re-running
# the cell harmless. Each subfolder is ensured individually, so a partially
# created run folder (e.g. one missing 'viz') is repaired rather than skipped.
for subfolder in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(RUN_FOLDER, subfolder), exist_ok=True)

# 'build' -> the model cell calls gan.save(RUN_FOLDER) on a freshly built model;
# 'load'  -> the model cell restores weights from RUN_FOLDER/weights/weights.h5.
mode = 'build'  # or 'load'

## data

In [None]:
# Batch size and square image side length; both feed load_celeb and,
# directly or via image_dim, the WGANGP constructor.
BATCH_SIZE, IMAGE_SIZE = 64, 64

In [None]:
# Load the training data for DATA_NAME at IMAGE_SIZE, in batches of BATCH_SIZE.
# NOTE(review): the [0][0][0] indexing below and using_generator=True in the
# training cell suggest this returns a Keras-style batch iterator rather than
# an in-memory array — confirm against utils.loaders.load_celeb.
x_train = load_celeb(DATA_NAME, IMAGE_SIZE, BATCH_SIZE)

In [None]:
# Display the raw pixel values of the first image of the first batch.
# Presumably x_train[0] is an (images, labels) pair, so [0][0] is the image
# batch and [0][0][0] its first image — TODO confirm against load_celeb.
x_train[0][0][0]

In [None]:
sample_image = x_train[0][0][0]
# Shift pixels by +1 and halve them before plotting; this maps [-1, 1] onto
# [0, 1], assuming that is the loader's output range — confirm in load_celeb.
plt.imshow((sample_image + 1) / 2)

## architecture

In [None]:
# Image shape seen by the critic and produced by the generator: a square of
# side IMAGE_SIZE with 3 channels last.
image_dim = (IMAGE_SIZE,) * 2 + (3,)
# Size of the generator's input noise vector.
latent_dim = 100

# Starting (height, width, channels) of the generator's feature map before the
# convolutional stack — presumably reshaped from the latent projection; confirm
# in WGANGP.
generator_initial_dim = (4, 4, 512)

generator_activation = 'leaky_relu'
critic_activation = 'leaky_relu'

# Plain floats. A trailing comma here (`0.0002,`) would silently make each rate
# a one-element tuple (0.0002,), and that tuple is what WGANGP would receive.
critic_learning_rate = 0.0002
generator_learning_rate = 0.0002

# None / 0 values presumably disable the corresponding layer (batch norm,
# dense layer, dropout) inside WGANGP — confirm there.
generator_batch_norm_momentum = 0.9
critic_batch_norm_momentum = None
critic_dense_dim = 0
generator_dropout_rate = None
critic_dropout_rate = None

# Coefficient of the gradient-penalty term (10 is the value used in the
# WGAN-GP paper).
gradient_penalty_weight = 10.

# Four stride-2 transposed-conv layers; with 'same' padding (assumed) each
# doubles the spatial size, taking the 4x4 start up to 64x64. The last layer
# emits 3 channels to match image_dim.
generator_convolutional_params = [
    {'strides': (2, 2), 'filters': 256, 'kernel_size': (5, 5), 'upsample': 1, 'transpose': True},
    {'strides': (2, 2), 'filters': 128, 'kernel_size': (5, 5), 'upsample': 1, 'transpose': True},
    {'strides': (2, 2), 'filters': 64, 'kernel_size': (5, 5), 'upsample': 1, 'transpose': True},
    {'strides': (2, 2), 'filters': 3, 'kernel_size': (5, 5), 'upsample': 1, 'transpose': True},
]

# Mirror of the generator: four stride-2 convs widening 64 -> 512 filters.
critic_convolutional_params = [
    {'strides': (2, 2), 'filters': 64, 'kernel_size': (5, 5)},
    {'strides': (2, 2), 'filters': 128, 'kernel_size': (5, 5)},
    {'strides': (2, 2), 'filters': 256, 'kernel_size': (5, 5)},
    {'strides': (2, 2), 'filters': 512, 'kernel_size': (5, 5)},
]

In [None]:
# Assemble the WGAN-GP from the architecture settings above. batch_size is
# forwarded as well — presumably the gradient-penalty graph is built for a
# fixed batch size; confirm in WGANGP.
gan = WGANGP(
    image_dim=image_dim,
    latent_dim=latent_dim,
    generator_initial_dim=generator_initial_dim,
    critic_dense_dim=critic_dense_dim,
    generator_activation=generator_activation,
    critic_activation=critic_activation,
    generator_convolutional_params=generator_convolutional_params,
    critic_learning_rate=critic_learning_rate,
    generator_learning_rate=generator_learning_rate,
    batch_size=BATCH_SIZE,
    critic_convolutional_params=critic_convolutional_params,
    generator_batch_norm_momentum=generator_batch_norm_momentum,
    critic_batch_norm_momentum=critic_batch_norm_momentum,
    generator_dropout_rate=generator_dropout_rate,
    critic_dropout_rate=critic_dropout_rate,
    gradient_penalty_weight=gradient_penalty_weight,
)

# 'build': call gan.save(RUN_FOLDER) on the freshly built model.
# 'load':  restore previously saved weights from RUN_FOLDER/weights/weights.h5.
# Any other value is rejected, so a typo in `mode` can no longer fall through
# to the load branch unnoticed.
if mode == 'build':
    gan.save(RUN_FOLDER)
elif mode == 'load':
    gan.load_weights(os.path.join(RUN_FOLDER, 'weights', 'weights.h5'))
else:
    raise ValueError(f"mode must be 'build' or 'load', got {mode!r}")

In [None]:
# Print the layer summary of the critic network.
gan.critic_model.summary()

In [None]:
# Print the layer summary of the critic's gradient-penalty model (the name
# suggests the critic training graph that includes the GP term — confirm in
# WGANGP).
gan.critic_gp_model.summary()

In [None]:
# Print the layer summary of the generator model.
gan.generator_model.summary()

In [None]:
# NOTE(review): this repeats the generator_model summary printed in the
# previous cell. It was probably meant to show a different WGANGP model
# (e.g. the combined generator-training model) — confirm the attribute name
# in models/WGANGP.py and update or delete this cell.
gan.generator_model.summary()

## training

In [None]:
# Training schedule.
EPOCHS = 6000
PRINT_EVERY_N_BATCHES = 5
# Passed to gan.train as critic_training_steps; presumably the number of
# critic updates per generator update (5 in the WGAN-GP paper).
N_CRITIC = 5
# BATCH_SIZE is intentionally NOT redefined here. The value from the data
# section already sized the loader (load_celeb) and was passed to WGANGP, so a
# single definition keeps the loader, the model and train() from drifting apart.

In [None]:
# Run training. critic_training_steps=N_CRITIC sets how many critic steps are
# taken per iteration (presumably per generator step), and using_generator=True
# presumably tells WGANGP that x_train is a batch iterator rather than an
# in-memory array — confirm in WGANGP.train.
gan.train(
    x_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    run_folder=RUN_FOLDER,
    print_every_n_batches=PRINT_EVERY_N_BATCHES,
    critic_training_steps=N_CRITIC,
    using_generator=True,
)

In [None]:
fig = plt.figure()

# One thin line per tracked loss series, drawn in this order: critic loss
# columns 0, 1 and 2, then generator loss column 0. The meaning of each column
# is defined by WGANGP's loss bookkeeping.
loss_series = (
    (gan.critic_losses, 0, 'black'),
    (gan.critic_losses, 1, 'green'),
    (gan.critic_losses, 2, 'red'),
    (gan.generator_losses, 0, 'orange'),
)
for history, column, line_colour in loss_series:
    plt.plot([record[column] for record in history], color=line_colour, linewidth=0.25)

plt.xlabel('batch', fontsize=18)
plt.ylabel('loss', fontsize=16)

# Only the first 2000 recorded entries are shown; widen this if training ran
# longer. A y-range limit such as (0, 2) can be added with plt.ylim.
plt.xlim(0, 2000)

plt.show()
