# WGAN Training

## imports

In [None]:
%matplotlib inline

import os
import numpy as np
import matplotlib.pyplot as plt

from models.WassersteinGenerativeAdversarialNetwork import WassersteinGenerativeAdversarialNetwork as WGAN
from utils.loaders import load_cifar



In [None]:
# run params
# NOTE(review): SECTION is defined but not used in the path below — confirm
# whether the run folder was meant to live under run/<SECTION>/.
SECTION = 'wgan'
RUN_ID = '0002'
DATA_NAME = 'horses'
RUN_FOLDER = f'run/{RUN_ID}/{RUN_ID}_{DATA_NAME}'

# Create the run folder and its output sub-folders. makedirs builds any
# missing parent ('run', 'run/<RUN_ID>') and exist_ok makes re-running the
# cell harmless, so a partially created run folder gets completed instead of
# being skipped.
for sub_folder in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(RUN_FOLDER, sub_folder), exist_ok=True)

# 'build' creates a fresh model; 'load' restores previously saved weights.
mode =  'build' #'load' #


## data

In [None]:
# Map each supported dataset name to the class label handed to load_cifar.
label_by_dataset = {'cars': 1, 'horses': 7}

# Fail immediately with a clear message on an unsupported DATA_NAME rather
# than hitting an undefined `label` further down.
if DATA_NAME not in label_by_dataset:
    raise ValueError(
        f"Unsupported DATA_NAME {DATA_NAME!r}; "
        f"expected one of {sorted(label_by_dataset)}"
    )
label = label_by_dataset[DATA_NAME]

# NOTE(review): the second argument (10) presumably selects CIFAR-10 — confirm
# against utils.loaders.load_cifar.
(x_train, y_train) = load_cifar(label, 10)


In [None]:
# Preview a single training image (index 150). The (x + 1) / 2 transform
# presumably maps data stored in [-1, 1] into the [0, 1] range that imshow
# expects for float images — TODO confirm load_cifar's scaling.
plt.imshow((x_train[150,:,:,:]+1)/2)

## architecture

In [None]:
# Hyper-parameters handed to the WGAN constructor.
image_dim = (32,32,3)
latent_dim = 100

# NOTE(review): presumably the shape the latent vector is projected to before
# the generator's convolutional stack — confirm against the WGAN class.
generator_initial_dim = (4, 4, 128)

generator_activation = 'leaky_relu'
critic_activation = 'leaky_relu'

# Plain floats: a trailing comma here would silently turn each value into a
# one-element tuple.
critic_learning_rate = 0.00005
generator_learning_rate = 0.00005
generator_batch_norm_momentum = 0.8
critic_batch_norm_momentum = None
critic_dense_dim = 0
generator_dropout_rate = None
critic_dropout_rate = None

# One dict per generator layer. If 'upsample': 2 doubles the spatial size,
# three such layers take 4x4 up to 32x32, and the last layer's 3 filters match
# the channel count in image_dim (assumption — verify in the model code).
generator_convolutional_params = [
    {'strides': (1, 1), 'filters': 128, 'kernel_size': (5, 5), 'upsample': 2, },
    {'strides': (1, 1), 'filters': 64, 'kernel_size': (5, 5), 'upsample': 2, },
    {'strides': (1, 1), 'filters': 32, 'kernel_size': (5, 5), 'upsample': 2, },
    {'strides': (1, 1), 'filters': 3, 'kernel_size': (5, 5), 'upsample': 1, },
    ]

# One dict per critic layer.
critic_convolutional_params = [
    {'strides': (2, 2), 'filters': 32, 'kernel_size': (5, 5),},
    {'strides': (2, 2), 'filters': 64, 'kernel_size': (5, 5),},
    {'strides': (2, 2), 'filters': 128, 'kernel_size': (5, 5),},
    {'strides': (1, 1), 'filters': 128, 'kernel_size': (5, 5),},
    ]

In [None]:
# The model object is needed in both modes: 'build' saves the freshly created
# model, while 'load' restores saved weights into it — so it is constructed
# unconditionally before branching on `mode`.
gan = WGAN(
    image_dim=image_dim,
    latent_dim=latent_dim,
    generator_initial_dim=generator_initial_dim,
    critic_dense_dim=critic_dense_dim,
    generator_activation=generator_activation,
    critic_activation=critic_activation,
    generator_convolutional_params=generator_convolutional_params,
    critic_learning_rate=critic_learning_rate,
    generator_learning_rate=generator_learning_rate,
    critic_convolutional_params=critic_convolutional_params,
    generator_batch_norm_momentum=generator_batch_norm_momentum,
    critic_batch_norm_momentum=critic_batch_norm_momentum,
    generator_dropout_rate=generator_dropout_rate,
    critic_dropout_rate=critic_dropout_rate,
)

if mode == 'build':
    gan.save(RUN_FOLDER)
else:
    gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))








In [None]:
# Show the critic network's summary (presumably a Keras-style layer listing).
gan.critic_model.summary()

In [None]:
# Show the generator network's summary (presumably a Keras-style layer listing).
gan.generator_model.summary()

## training

In [None]:
# Settings forwarded to gan.train() in the next cell.
BATCH_SIZE = 128
EPOCHS = 6_000
PRINT_EVERY_N_BATCHES = 5    # passed as print_every_n_batches
N_CRITIC = 5                 # passed as critic_training_steps
CLIP_THRESHOLD = 1e-2        # passed as clip_threshold

In [None]:
# Train on the selected images using the settings from the cell above;
# RUN_FOLDER is handed over as the run's output location.
gan.train(
    x_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    run_folder=RUN_FOLDER,
    print_every_n_batches=PRINT_EVERY_N_BATCHES,
    critic_training_steps=N_CRITIC,
    clip_threshold=CLIP_THRESHOLD,
)

In [None]:
# Ask the model to produce sample images for this run; RUN_FOLDER is passed as
# the target (presumably where the images are written — see WGAN.sample_images).
gan.sample_images(RUN_FOLDER)

In [None]:
# Plot the recorded critic and generator loss histories on one figure.
fig = plt.figure()

thin_line = {'linewidth': 0.25}

# Components 0 and 1 of each critic "valid" loss record.
plt.plot([record[0] for record in gan.critic_valid_losses], color='black', **thin_line)
plt.plot([record[1] for record in gan.critic_valid_losses], color='green', **thin_line)
# Component 0 of each critic "generated" loss record.
plt.plot([record[0] for record in gan.critic_generated_losses], color='red', **thin_line)
# Generator loss history, plotted as stored.
plt.plot(gan.generator_losses, color='orange', **thin_line)

plt.xlabel('batch', fontsize=18)
plt.ylabel('loss', fontsize=16)

# plt.xlim(0, 2000)
# plt.ylim(0, 2)

plt.show()

In [None]:
def compare_images(img1, img2):
    """Return the mean absolute element-wise difference between two images."""
    abs_diff = np.abs(img1 - img2)
    return np.mean(abs_diff)

In [None]:

# Save an r x c grid of randomly chosen training images for reference.
r, c = 5, 5

# Draw BATCH_SIZE random indices (with replacement) and rescale the chosen
# images with (x + 1) * 0.5 for display.
idx = np.random.randint(0, x_train.shape[0], BATCH_SIZE)
true_imgs = (x_train[idx] + 1) *0.5

fig, axs = plt.subplots(r, c, figsize=(15,15))

# axs.flat walks the grid row by row, matching the original nested i/j loops.
for cnt, ax in enumerate(axs.flat):
    ax.imshow(true_imgs[cnt], cmap = 'gray_r')
    ax.axis('off')

fig.savefig(os.path.join(RUN_FOLDER, "images/real.png"))
plt.close()

In [None]:
# Generate an r x c grid of images, save it, then save a matching grid that
# shows each generated image's closest training image (smallest mean absolute
# pixel difference).
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, gan.latent_dim))
gen_imgs = gan.generator_model.predict(noise)

# Rescale images 0 - 1 (assumes the generator outputs values in [-1, 1] —
# TODO confirm against the generator's output activation).
gen_imgs = 0.5 * (gen_imgs + 1)
# gen_imgs = np.clip(gen_imgs, 0, 1)

fig, axs = plt.subplots(r, c, figsize=(15,15))

# axs.flat walks the grid row by row, so cell number == generated image index.
for cnt, ax in enumerate(axs.flat):
    ax.imshow(np.squeeze(gen_imgs[cnt]), cmap = 'gray_r')
    ax.axis('off')
fig.savefig(os.path.join(RUN_FOLDER, "images/sample.png"))
plt.close()


# Rescale the training set once, outside the loop, instead of once per cell.
train_imgs = (x_train + 1) * 0.5
# Average over every axis except the leading (image) axis.
per_image_axes = tuple(range(1, train_imgs.ndim))

fig, axs = plt.subplots(r, c, figsize=(15,15))

for cnt, ax in enumerate(axs.flat):
    # Mean absolute difference between this generated image and every training
    # image, computed in a single vectorised pass.
    distances = np.abs(train_imgs - gen_imgs[cnt]).mean(axis=per_image_axes)
    # argmin returns the first minimum, i.e. the earliest training image on ties.
    closest_img = train_imgs[np.argmin(distances)]
    ax.imshow(closest_img, cmap = 'gray_r')
    ax.axis('off')

fig.savefig(os.path.join(RUN_FOLDER, "images/sample_closest.png"))
plt.close()