# Generative Adversarial Network

In [4]:
import ndjson
import numpy as np
import matplotlib.pyplot as plt

from os import mkdir, walk
from os.path import join, exists
from PIL import Image, ImageDraw

from src.model.GAN import GAN

## Data

### Load Training Data

In [8]:
def load_images_from_npy_file(directory, n_images, data_root="./data"):
    """Load about ``n_images`` 28x28 single-channel images from ``.npy`` files.

    ``<data_root>/<directory>`` is walked recursively in sorted order. From
    each visited folder the first ``.npy`` file (alphabetically) is used, and
    each file gets its own integer label (0, 1, 2, ... in discovery order).
    ``n_images // n_files`` randomly chosen rows are taken from every file.

    Args:
        directory: Sub-folder of ``data_root`` that holds the ``.npy`` files.
        n_images: Total number of images wanted across all files.
        data_root: Root data folder; defaults to ``"./data"``.

    Returns:
        ``(x, y)``: ``x`` is a float32 array of shape ``(N, 28, 28, 1)``,
        rescaled with ``(v - 127.5) / 127.5`` (so 0..255 maps to -1..1);
        ``y`` is an integer ndarray of shape ``(N,)`` holding each row's
        file label.

    Raises:
        FileNotFoundError: if no ``.npy`` file exists under the folder.
    """
    folder = join(data_root, directory)

    file_paths = []
    for dirpath, dirnames, filenames in walk(folder):
        # Sorting in place also fixes the order os.walk descends into
        # sub-folders, so labels are reproducible between runs.
        dirnames.sort()
        for name in sorted(filenames):
            if name.endswith('.npy'):
                # Keep the full path: a file found in a sub-folder must be
                # joined onto its own dirpath, not onto the top folder.
                file_paths.append(join(dirpath, name))
                # NOTE(review): only the first .npy file per folder is used,
                # as in the original implementation -- confirm intended.
                break

    if not file_paths:
        raise FileNotFoundError(f"no .npy files found under {folder!r}")

    per_file = n_images // len(file_paths)
    # A private generator, so the global np.random state (used elsewhere,
    # e.g. for noise sampling) is not re-seeded as a side effect.
    rng = np.random.default_rng()

    xs, ys = [], []
    for label, path in enumerate(file_paths):
        raw = np.load(path)
        # Choose the random subset first, then convert: avoids building a
        # float32 copy of the entire file.
        idx = rng.permutation(len(raw))[:per_file]
        x = (raw[idx].astype('float32') - 127.5) / 127.5
        xs.append(x.reshape(-1, 28, 28, 1))
        ys.append(np.full(len(idx), label))

    return np.concatenate(xs, axis=0), np.concatenate(ys, axis=0)

In [9]:
# 80,000 camel drawings (images + per-file labels) for training.
x_train, y_train = load_images_from_npy_file('camel', 80000)

## GAN Architecture

In [None]:
# Build the GAN for 28x28x1 images.
gan = GAN(
    input_dim=(28, 28, 1),
    # Discriminator: four conv layers.
    discriminator_conv_filters=[64, 64, 128, 128],
    discriminator_conv_kernel_size=[5, 5, 5, 5],
    discriminator_conv_strides=[2, 2, 2, 1],
    discriminator_batch_norm_momentum=None,
    discriminator_activation='relu',
    discriminator_dropout_rate=0.4,
    discriminator_learning_rate=0.0008,
    # Generator: dense layer reshaped to 7x7x64, then four conv layers
    # (the first two with upsample factor 2: 7 -> 14 -> 28).
    generator_initial_dense_layer_size=(7, 7, 64),
    generator_upsample=[2, 2, 1, 1],
    generator_conv_filters=[128, 64, 64, 1],
    generator_conv_kernel_size=[5, 5, 5, 5],
    generator_conv_strides=[1, 1, 1, 1],
    generator_batch_norm_momentum=0.9,
    generator_activation='relu',
    generator_dropout_rate=None,
    generator_learning_rate=0.0004,
    # Shared settings.
    optimizer='rmsprop',
    z_dim=100,
)

In [None]:
# Print a summary of the discriminator network built by GAN(...) above.
gan.discriminator.summary()

In [None]:
# Print a summary of the generator network built by GAN(...) above.
gan.generator.summary()

## Train GAN

In [None]:
# Training hyper-parameters passed to gan.train(...).
BATCH_SIZE = 64              # images per batch
EPOCHS = 6_000               # value passed as gan.train(epochs=...)
PRINT_EVERY_N_BATCHES = 5    # value passed as print_every_n_batches

In [None]:
# Train on the images loaded above. print_every_n_batches presumably sets
# how often progress is reported -- see GAN.train.
gan.train(
    x_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    print_every_n_batches=PRINT_EVERY_N_BATCHES,
)

In [None]:
# Discriminator loss per batch: total (black), real images (green) and
# fake images (red). Each curve is drawn exactly once so none is hidden.
# NOTE(review): a fourth (orange) curve, presumably the generator loss, may
# have been intended here -- add it once the attribute holding the
# generator's loss history on `gan` is confirmed.
fig = plt.figure()
for key, colour in (
    ('d_loss', 'black'),
    ('d_loss_real', 'green'),
    ('d_loss_fake', 'red'),
):
    plt.plot(
        [entry[key] for entry in gan.discriminator_losses],
        color=colour,
        linewidth=0.25,
        label=key,
    )

plt.xlabel('batch', fontsize=18)
plt.ylabel('loss', fontsize=16)

plt.xlim(0, 2000)
plt.ylim(0, 2)
plt.legend()

plt.show()

In [None]:
# Discriminator accuracy per batch: overall (black), on real images (green)
# and on fake images (red). Only accuracy series are plotted, so the
# "accuracy" y-axis is not mixed with loss values.
fig = plt.figure()
for key, colour in (
    ('d_acc', 'black'),
    ('d_acc_real', 'green'),
    ('d_acc_fake', 'red'),
):
    plt.plot(
        [entry[key] for entry in gan.discriminator_losses],
        color=colour,
        linewidth=0.25,
        label=key,
    )

plt.xlabel('batch', fontsize=18)
plt.ylabel('accuracy', fontsize=16)

plt.xlim(0, 2000)
plt.legend()

plt.show()

In [None]:
# Draw r*c generator samples from standard-normal noise and show them
# as an r x c grid of grayscale images.
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, gan.z_dim))
gen_imgs = gan.generator.predict(noise)

# squeeze=False keeps `axs` 2-D even when r or c is 1, so .flat always works.
fig, axs = plt.subplots(r, c, figsize=(15, 15), squeeze=False)
for ax, img in zip(axs.flat, gen_imgs):
    ax.imshow(np.squeeze(img), cmap='gray')
    ax.axis('off')

# pyplot.show(), as in the other cells: Figure.show() only warns or returns
# immediately on non-GUI or inline backends.
plt.show()

In [None]:
import pickle as pkl

In [None]:
def save_gan(folder='./model', gan_model=None):
    """Save the GAN's settings and its three models into ``folder``.

    Writes ``params.pkl``, a pickled list of the settings in the same order
    as the keyword arguments of the ``GAN(...)`` call in this notebook
    (``optimizer_str`` stands in for ``optimizer``). It also writes
    ``model.h5``, ``discriminator.h5`` and ``generator.h5`` by calling each
    model's ``save`` method.

    Args:
        folder: Output directory; created (with parents) if it is missing.
            Defaults to ``'./model'``.
        gan_model: Object to save; defaults to the notebook-level ``gan``.
    """
    import os

    if gan_model is None:
        gan_model = gan

    param_names = (
        'input_dim',
        'discriminator_conv_filters',
        'discriminator_conv_kernel_size',
        'discriminator_conv_strides',
        'discriminator_batch_norm_momentum',
        'discriminator_activation',
        'discriminator_dropout_rate',
        'discriminator_learning_rate',
        'generator_initial_dense_layer_size',
        'generator_upsample',
        'generator_conv_filters',
        'generator_conv_kernel_size',
        'generator_conv_strides',
        'generator_batch_norm_momentum',
        'generator_activation',
        'generator_dropout_rate',
        'generator_learning_rate',
        'optimizer_str',
        'z_dim',
    )
    # Collect every attribute before opening the file, so a missing
    # attribute cannot leave a truncated params.pkl behind.
    params = [getattr(gan_model, name) for name in param_names]

    # open() below would raise FileNotFoundError if the folder is missing.
    os.makedirs(folder, exist_ok=True)

    with open(join(folder, 'params.pkl'), 'wb') as f:
        pkl.dump(params, f)

    gan_model.model.save(join(folder, 'model.h5'))
    gan_model.discriminator.save(join(folder, 'discriminator.h5'))
    gan_model.generator.save(join(folder, 'generator.h5'))

In [None]:
# Persist the trained GAN (settings + models) using save_gan's defaults.
save_gan()