In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, UpSampling2D, BatchNormalization, Activation, Dropout, Reshape
from tensorflow.keras import Model
from tensorflow.keras.initializers import RandomNormal
import numpy as np
import matplotlib.pyplot as plt

# Check GPUs: enable per-GPU memory growth so TensorFlow allocates device memory
# as needed instead of claiming all memory of every visible GPU.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            # Prevent TensorFlow from allocating all memory of all GPUs:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Per the TensorFlow GPU guide, memory growth must be set before the GPUs
        # are initialized; otherwise a RuntimeError is raised. Presumably this is
        # hit when the cell is re-run after a model was built — the error is only
        # printed so the rest of the notebook can continue.
        print(e)

In [None]:
# Path to the Quick, Draw! camel bitmaps. The configuration cell further down
# assigns the same value; it is also set here because this cell runs first, and
# reading INPUT_PATH before that cell has executed fails on Restart & Run All.
# Keep the two assignments in sync.
INPUT_PATH = 'D:/oystein/quick-draw/full_numpy_bitmap_camel.npy'

input_data = np.load(INPUT_PATH)
# Reshape to (num_images, 28, 28, 1): a trailing single channel axis for Conv2D.
input_data = input_data.reshape(input_data.shape[0], 28, 28, 1)
# NOTE(review): pixel values are not rescaled here, while the generator ends in
# tanh (outputs in [-1, 1]). If the raw data is 0-255 and no later cell rescales
# it, real and generated images live on different scales — TODO confirm, and
# consider e.g. (input_data.astype('float32') - 127.5) / 127.5.
input_data.shape

In [None]:
# Display one drawing to sanity-check the reshape; the index is an arbitrary example.
SAMPLE_INDEX = 300
fig, ax = plt.subplots()
ax.imshow(input_data[SAMPLE_INDEX, :, :, 0], cmap='gray')
ax.set_title(f'input_data[{SAMPLE_INDEX}]')
ax.axis('off')
plt.show()

# Discriminator

In [None]:
# ---- Data --------------------------------------------------------------------
INPUT_PATH = 'D:/oystein/quick-draw/full_numpy_bitmap_camel.npy'

# ---- Discriminator: per-block settings (one list entry per conv block) --------
DISCR_BLOCKS = 4
DISCR_CONV_FILTERS     = [64, 64, 128, 128]
DISCR_CONV_KERNEL_SIZE = [5, 5, 5, 5]
DISCR_CONV_STRIDES     = [2, 2, 2, 1]

# ---- Discriminator: settings shared by every block ----------------------------
DISCR_ACTIVATION = 'relu'
DISCR_BATCH_NORM_MOMENTUM = None   # falsy -> no BatchNormalization layers
DISCR_DROPOUT_RATE = 0.4           # falsy -> no Dropout layers

# Presumably consumed by the discriminator's optimiser in a later cell.
DISCR_LEARNING_RATE = 0.0008

In [None]:
# Discriminator: DISCR_BLOCKS blocks of Conv2D -> [BatchNorm] -> activation ->
# [Dropout], then Flatten and a single sigmoid unit scoring a 28x28x1 image.

# Every per-block list must have exactly one entry per block; fail loudly here
# rather than with an IndexError (or silently unused entries) inside the loop.
assert (len(DISCR_CONV_FILTERS) == len(DISCR_CONV_KERNEL_SIZE)
        == len(DISCR_CONV_STRIDES) == DISCR_BLOCKS), \
    'DISCR_CONV_FILTERS / DISCR_CONV_KERNEL_SIZE / DISCR_CONV_STRIDES must each have DISCR_BLOCKS entries'

discriminator_input = Input((28,28,1), name = 'discriminator_input')
x = discriminator_input

for i in range(DISCR_BLOCKS):
    # Same N(0, 0.02) kernel init as the output Dense layer (the DCGAN
    # recommendation); a fresh initializer instance per layer avoids reusing
    # one unseeded initializer object across layers.
    x = Conv2D(filters = DISCR_CONV_FILTERS[i],
               kernel_size = DISCR_CONV_KERNEL_SIZE[i],
               strides = DISCR_CONV_STRIDES[i],
               padding = 'same',
               kernel_initializer = RandomNormal(mean=0., stddev=0.02),
               name = 'discriminator_conv_' + str(i))(x)
    # BatchNorm is never applied to the first block; a falsy momentum disables it.
    if DISCR_BATCH_NORM_MOMENTUM and i > 0:
        x = BatchNormalization(momentum = DISCR_BATCH_NORM_MOMENTUM)(x)

    x = Activation(DISCR_ACTIVATION)(x)

    if DISCR_DROPOUT_RATE:
        x = Dropout(rate = DISCR_DROPOUT_RATE)(x)

x = Flatten()(x)
# Single sigmoid unit: a score in (0, 1) per input image.
discriminator_output = Dense(1,
                             activation = 'sigmoid',
                             kernel_initializer = RandomNormal(mean=0., stddev=0.02))(x)
discriminator = Model(discriminator_input, discriminator_output)

discriminator.summary()

# Generator

In [None]:
# ---- Generator: input and initial projection ----------------------------------
GEN_INPUT_SIZE = 100                        # length of the 1-D generator input vector
GEN_INITIAL_DENSE_LAYER_SIZE = [7, 7, 64]   # Dense output is reshaped to this shape

# ---- Generator: per-block settings (intended: one list entry per block) -------
GEN_BLOCKS = 4
GEN_UPSAMPLE         = [2, 2, 1, 1]         # intended upsampling factor per block
GEN_CONV_FILTERS     = [128, 64, 64, 1]     # the last block emits the single image channel
GEN_CONV_KERNEL_SIZE = [5, 5, 5, 5]
GEN_CONV_STRIDES     = [1, 1, 1, 1]         # intended conv stride per block

# ---- Generator: settings shared by every block --------------------------------
GEN_ACTIVATION = 'relu'
GEN_BATCH_NORM_MOMENTUM = 0.9               # falsy -> no BatchNormalization layers
GEN_DROPOUT_RATE = None                     # falsy -> no Dropout layer

# Presumably consumed by the generator's optimiser in a later cell.
GEN_LEARNING_RATE = 0.0004

In [None]:
# Generator: input vector -> Dense -> [BatchNorm] -> activation -> reshape to
# GEN_INITIAL_DENSE_LAYER_SIZE -> [Dropout], then GEN_BLOCKS blocks of
# [UpSampling2D] -> Conv2D -> [BatchNorm] -> activation, with tanh on the last block.

# Every per-block list must have exactly one entry per block; fail loudly here
# rather than with an IndexError (or silently unused entries) inside the loop.
assert (len(GEN_UPSAMPLE) == len(GEN_CONV_FILTERS) == len(GEN_CONV_KERNEL_SIZE)
        == len(GEN_CONV_STRIDES) == GEN_BLOCKS), \
    'GEN_UPSAMPLE / GEN_CONV_FILTERS / GEN_CONV_KERNEL_SIZE / GEN_CONV_STRIDES must each have GEN_BLOCKS entries'

generator_input = Input(shape = (GEN_INPUT_SIZE,), name = 'generator_input')
x = generator_input

x = Dense(np.prod(GEN_INITIAL_DENSE_LAYER_SIZE),
          kernel_initializer = RandomNormal(mean=0., stddev=0.02))(x)

if GEN_BATCH_NORM_MOMENTUM:
    x = BatchNormalization(momentum = GEN_BATCH_NORM_MOMENTUM)(x)

x = Activation(GEN_ACTIVATION)(x)
x = Reshape(GEN_INITIAL_DENSE_LAYER_SIZE)(x)

if GEN_DROPOUT_RATE:
    x = Dropout(rate = GEN_DROPOUT_RATE)(x)

for i in range(GEN_BLOCKS):
    # Upsample as configured instead of by a hard-coded block index; e.g.
    # GEN_UPSAMPLE = [2, 2, 1, 1] doubles the spatial size in the first two blocks.
    if GEN_UPSAMPLE[i] > 1:
        x = UpSampling2D(size = GEN_UPSAMPLE[i])(x)

    # N(0, 0.02) kernel init to match the Dense layer above (DCGAN recommendation);
    # a fresh initializer instance per layer avoids reusing one unseeded object.
    x = Conv2D(filters = GEN_CONV_FILTERS[i],
               kernel_size = GEN_CONV_KERNEL_SIZE[i],
               strides = GEN_CONV_STRIDES[i],
               padding = 'same',
               kernel_initializer = RandomNormal(mean=0., stddev=0.02),
               name = 'generator_conv_' + str(i))(x)

    if i < GEN_BLOCKS - 1:
        if GEN_BATCH_NORM_MOMENTUM:
            x = BatchNormalization(momentum = GEN_BATCH_NORM_MOMENTUM)(x)
        x = Activation(GEN_ACTIVATION)(x)
    else:
        # Output block: tanh bounds generated pixel values to [-1, 1].
        x = Activation('tanh')(x)

generator_output = x
generator = Model(generator_input, generator_output)
generator.summary()