In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np

In [None]:
from models.GenerativeAdversarialNetwork import GenerativeAdversarialNetwork
from utils.loaders import load_safari

In [None]:
# run params
SECTION = 'gan'
RUN_ID = '0001'
DATA_NAME = 'camel'
RUN_FOLDER = f'run/{SECTION}/'
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])

# Create the run folder and its output sub-folders. os.makedirs also creates
# any missing parent directories (e.g. run/gan/), where a bare os.mkdir would
# raise FileNotFoundError. exist_ok=True keeps the cell safe to re-run and fills
# in any sub-folder missing from an already-existing run folder.
for subfolder in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(RUN_FOLDER, subfolder), exist_ok=True)

# Either 'build' or 'load' (per the inline alternative); not read by any cell shown here.
mode = 'build' #'load' #

## Data

In [None]:
# Load the training set for DATA_NAME; only x_train is used by the cells below.
x_train, y_train = load_safari(DATA_NAME)

In [None]:
# Dimensions of the loaded training array.
np.shape(x_train)

In [None]:
# Preview a single training example (index chosen arbitrarily), showing
# channel 0 of the last axis in greyscale. The explicit fig/ax interface plus
# plt.show() avoids leaving a bare AxesImage repr as the cell output.
sample_index = 200
fig, ax = plt.subplots()
ax.imshow(x_train[sample_index, :, :, 0], cmap='gray')
ax.set_title(f'{DATA_NAME} training example #{sample_index}')
ax.axis('off')
plt.show()

## Model

In [None]:
# Shape of one image (rows, cols, channels) and the length of each latent
# noise vector fed to the generator.
image_dim, latent_dim = (28, 28, 1), 100

In [None]:
# Architecture and optimiser hyper-parameters for GenerativeAdversarialNetwork.
generator_initial_dim = (7, 7, 64)
generator_activation = 'relu'
discriminator_activation = 'relu'
# NOTE(review): use_batch_norm and use_drop_out are not passed to the
# GenerativeAdversarialNetwork constructor in this notebook, so they currently
# have no effect — wire them up or remove them.
use_batch_norm = True
# Learning rates must be plain floats: a trailing comma after the literal
# would silently turn each into a 1-tuple such as (0.0008,).
discriminator_learning_rate = 0.0008
generator_learning_rate = 0.0004
generator_batch_norm_momentum = 0.9
discriminator_batch_norm_momentum = None
use_drop_out = True
discriminator_dense_dim = 0
generator_dropout_rate = None
discriminator_dropout_rate = 0.4
# One dict per generator conv layer; 'upsample' factors 2 * 2 take the 7x7
# initial grid to 28x28, and the last layer has a single output filter.
generator_convolutional_params = [
    {'strides': (1, 1), 'filters': 128, 'kernel_size': (5, 5), 'upsample': 2},
    {'strides': (1, 1), 'filters': 64, 'kernel_size': (5, 5), 'upsample': 2},
    {'strides': (1, 1), 'filters': 64, 'kernel_size': (5, 5), 'upsample': 1},
    {'strides': (1, 1), 'filters': 1, 'kernel_size': (5, 5), 'upsample': 1},
]
# One dict per discriminator conv layer.
discriminator_convolutional_params = [
    {'strides': (2, 2), 'filters': 64, 'kernel_size': (5, 5)},
    {'strides': (2, 2), 'filters': 64, 'kernel_size': (5, 5)},
    {'strides': (2, 2), 'filters': 128, 'kernel_size': (5, 5)},
    {'strides': (1, 1), 'filters': 128, 'kernel_size': (5, 5)},
]

In [None]:
# Assemble the GAN from the hyper-parameters defined in the cells above,
# grouping the keyword arguments by sub-network.
gan = GenerativeAdversarialNetwork(
    image_dim=image_dim,
    latent_dim=latent_dim,
    # generator
    generator_initial_dim=generator_initial_dim,
    generator_activation=generator_activation,
    generator_convolutional_params=generator_convolutional_params,
    generator_batch_norm_momentum=generator_batch_norm_momentum,
    generator_dropout_rate=generator_dropout_rate,
    generator_learning_rate=generator_learning_rate,
    # discriminator
    discriminator_dense_dim=discriminator_dense_dim,
    discriminator_activation=discriminator_activation,
    discriminator_convolutional_params=discriminator_convolutional_params,
    discriminator_batch_norm_momentum=discriminator_batch_norm_momentum,
    discriminator_dropout_rate=discriminator_dropout_rate,
    discriminator_learning_rate=discriminator_learning_rate,
)

In [None]:
# Architecture summary of the generator sub-model (assumes a Keras-style Model.summary()).
gan.generator_model.summary()

In [None]:
# Architecture summary of the discriminator sub-model (assumes a Keras-style Model.summary()).
gan.discriminator_model.summary()

In [None]:
# Architecture summary of the combined adversarial model (assumes a Keras-style Model.summary()).
gan.adversarial_model.summary()

## Training

In [None]:
# Training schedule passed to gan.train below.
BATCH_SIZE = 64
EPOCHS = 6_000
PRINT_EVERY_N_BATCHES = 5  # forwarded as print_every_n_batches

In [None]:
# Train on x_train with the schedule above; RUN_FOLDER is handed to the
# model as its run_folder.
gan.train(
    x_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    run_folder=RUN_FOLDER,
    print_every_n_batches=PRINT_EVERY_N_BATCHES,
)

In [None]:
# Ten standard-normal latent vectors, one row per sample to generate.
latent_noise = np.random.normal(loc=0., scale=1., size=(10, gan.latent_dim))

In [None]:
# Run each latent noise vector through the trained generator to produce samples.
images = gan.generator_model.predict(latent_noise)

In [None]:
# Dimensions of the generated batch.
np.shape(images)

In [None]:
# Show every generated sample (channel 0, greyscale) in a single grid figure
# instead of emitting one full-size pyplot figure per image.
n_images = images.shape[0]
assert n_images > 0, 'generator returned no images'
n_cols = min(n_images, 5)
n_rows = -(-n_images // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
for i, ax in enumerate(axes.flat):
    if i < n_images:
        ax.imshow(images[i, :, :, 0], cmap='gray')
        ax.set_title(f'sample {i}')
    ax.axis('off')  # also hides unused cells in the last row
plt.show()