In [None]:
# generative adversarial network

import io
import os

from keras.optimizers import RMSprop

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import numpy as np
import keras
from keras.models import Sequential, Model
from keras.layers import Input, Dense, Conv2D, BatchNormalization, Dropout, Flatten
from keras.layers import Activation, Reshape, Conv2DTranspose, UpSampling2D

import pandas as pd
import matplotlib.pyplot as plt

%matplotlib inline

Using TensorFlow backend.


In [None]:
# Colab-only: opens a browser upload dialog. `uploaded` is keyed by the uploaded
# filename; the next cell reads 'underwear.npy' from it via io.BytesIO.
from google.colab import files
uploaded = files.upload()

In [None]:
# Load the uploaded array and scale pixel values into [0, 1] (presumably uint8
# pixels in [0, 255] -- TODO confirm). Casting to float32 first avoids the
# float64 array that `/ 255` on integer data would produce; float32 is also the
# dtype Keras computes in by default.
data = np.load(io.BytesIO(uploaded['underwear.npy']))
data = data.astype('float32') / 255

# Add a trailing channel axis: (N, 28, 28, 1), i.e. one grayscale channel per image.
data = np.reshape(data, [data.shape[0], 28, 28, 1])
img_h, img_w = data.shape[1:3]

# Preview one training image. The index is clamped so the cell still works on
# datasets with fewer than 4344 images, and 'gray' matches the colormap used
# for generated samples during training, so real and fake images are not
# shown with inverted intensities.
plt.imshow(data[min(4343, data.shape[0] - 1), :, :, 0], cmap='gray')
plt.show()

In [None]:
def discriminator_builder(width=64, dropout=0.4):
    """Build the (uncompiled) discriminator that scores images as real or fake.

    Architecture: four 5x5 ReLU convolutions with ``width``, 2*``width``,
    4*``width`` and 8*``width`` filters (strides 2, 2, 2, 1; 'same' padding),
    each followed by ``Dropout(dropout)``, then Flatten and a single sigmoid
    unit. The training loop labels real images 1 and generated images 0, so
    the output is the model's estimate that its input is real.

    Parameters
    ----------
    width : int
        Filter count of the first convolution; later layers use multiples of it.
    dropout : float
        Dropout rate applied after every convolution.

    Returns
    -------
    keras.models.Model
        Maps (img_h, img_w, 1) images to a single sigmoid score. Reads the
        module-level ``img_h`` / ``img_w`` set when the data is loaded.
    """
    # define inputs; 1 == color depth (single grayscale channel)
    inputs = Input((img_h, img_w, 1))

    # Convolutional stack as (filter multiplier, stride) pairs. Layers are
    # created in the same order as four hand-written conv/dropout stages.
    x = inputs
    for multiplier, stride in ((1, 2), (2, 2), (4, 2), (8, 1)):
        x = Conv2D(width * multiplier, 5, strides=stride, padding='same', activation='relu')(x)
        x = Dropout(dropout)(x)
    x = Flatten()(x)

    output = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=output)
    # summary() prints the table itself and returns None; wrapping it in
    # print() only emitted an extra "None" line.
    model.summary()

    return model

In [None]:
def generator_builder(latent_space=100, width=64, dropout=0.4):
    """Build the (uncompiled) generator mapping a noise vector to a 28x28x1 image.

    Pipeline: Dense(7*7*64) -> BatchNorm -> ReLU -> reshape to (7, 7, 64) ->
    Dropout, then two UpSampling2D + Conv2DTranspose stages (7 -> 14 -> 28),
    one more Conv2DTranspose at 28x28, and a final 1-channel sigmoid Conv2D so
    every pixel lies in (0, 1), the same range as the scaled training data.

    Parameters
    ----------
    latent_space : int
        Length of the input noise vector.
    width : int
        Controls the transposed-conv filter counts (width/2, width/4, width/8).
        The dense projection depth is fixed at 64 and does not follow ``width``.
    dropout : float
        Dropout rate applied once, right after the reshape.

    Returns
    -------
    keras.models.Model
        Maps (latent_space,) vectors to (28, 28, 1) images.
    """
    # define inputs
    inputs = Input((latent_space,))

    # first dense layer: project the noise to a 7x7x64 feature map
    dense1 = Dense(7 * 7 * 64)(inputs)
    dense1 = BatchNormalization(momentum=0.9)(dense1)
    dense1 = Activation(activation='relu')(dense1)
    dense1 = Reshape((7, 7, 64))(dense1)
    dense1 = Dropout(dropout)(dense1)

    # deconvolutional layers; each UpSampling2D doubles the spatial size
    conv1 = UpSampling2D()(dense1)  # 7x7 -> 14x14
    conv1 = Conv2DTranspose(int(width / 2), kernel_size=5, padding='same', activation=None)(conv1)
    conv1 = BatchNormalization(momentum=0.9)(conv1)
    conv1 = Activation(activation='relu')(conv1)

    conv2 = UpSampling2D()(conv1)  # 14x14 -> 28x28
    conv2 = Conv2DTranspose(int(width / 4), kernel_size=5, padding='same', activation=None)(conv2)
    conv2 = BatchNormalization(momentum=0.9)(conv2)
    conv2 = Activation(activation='relu')(conv2)

    conv3 = Conv2DTranspose(int(width / 8), kernel_size=5, padding='same', activation=None)(conv2)
    conv3 = BatchNormalization(momentum=0.9)(conv3)
    conv3 = Activation(activation='relu')(conv3)

    # output layer; activation = sigmoid to give us the color for each pixel
    output = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(conv3)

    model = Model(inputs=inputs, outputs=output)
    # summary() prints the table itself and returns None; print() around it
    # only emitted an extra "None" line.
    model.summary()

    return model

In [None]:
# Build and compile the standalone discriminator (binary real/fake classifier),
# then build the generator. The generator is not compiled on its own; it is
# trained only through the stacked adversarial model.
discriminator = discriminator_builder()
discriminator.compile(
    optimizer=RMSprop(lr=0.0002, decay=2e-8, clipvalue=1.0),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

generator = generator_builder()

In [None]:
# create adversarial network
def adversarial_builder(latent_space=100):
    """Stack the global ``generator`` and ``discriminator`` and compile the result.

    Training this model on noise labelled "real" (1) should update only the
    generator. For that, the discriminator must be frozen *before* this model
    is compiled: Keras snapshots each model's trainable weights at
    ``compile()`` time, and toggling ``trainable`` afterwards does not change
    what an already-compiled model updates.

    Parameters
    ----------
    latent_space : int
        Unused; kept for backward compatibility. The input size comes from the
        generator's own input layer.

    Returns
    -------
    keras.models.Sequential
        Compiled generator -> discriminator stack.
    """
    # The standalone discriminator was compiled earlier while trainable, so it
    # keeps learning in its own train_on_batch. Freezing it here only affects
    # the stacked model compiled below.
    discriminator.trainable = False

    model = Sequential()
    model.add(generator)
    model.add(discriminator)

    model.compile(loss='binary_crossentropy',
                  optimizer=RMSprop(lr=0.0001, decay=1e-8, clipvalue=1.0),
                  metrics=['accuracy'])
    # summary() prints itself and returns None
    model.summary()
    return model

In [None]:
# Stacked generator -> discriminator model used for the generator updates.
adversarial_model = adversarial_builder(latent_space=100)

In [None]:
def make_trainable(net, val):
    """Set the ``trainable`` flag on ``net`` and then on each of its layers.

    NOTE(review): Keras reads ``trainable`` when a model is compiled. Changing
    it afterwards does not alter what an already-compiled model updates until
    that model is compiled again.

    Parameters
    ----------
    net : object with a ``trainable`` attribute and a ``layers`` sequence
        Typically a Keras model.
    val : bool
        New value for every ``trainable`` flag.
    """
    net.trainable = val
    for sublayer in net.layers:
        setattr(sublayer, 'trainable', val)

In [None]:
def _plot_generated_samples(latent_space, grid=4):
    """Show a ``grid`` x ``grid`` panel of images sampled from the global ``generator``."""
    noise = np.random.uniform(-1.0, 1.0, size=[grid * grid, latent_space])
    gen_imgs = generator.predict(noise)

    plt.figure(figsize=(5, 5))
    for k in range(gen_imgs.shape[0]):
        plt.subplot(grid, grid, k + 1)
        plt.imshow(gen_imgs[k, :, :, 0], cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.show()


def train(epochs=2000, batch=128, latent_space=100, log_every=300):
    """Alternate discriminator and generator updates for ``epochs`` iterations.

    Each iteration ("epoch" here) is one discriminator ``train_on_batch`` on
    ``batch`` real + ``batch`` generated images, followed by one adversarial
    ``train_on_batch`` on ``batch`` noise vectors labelled real. It is not a
    full pass over ``data``.

    Parameters
    ----------
    epochs : int
        Number of training iterations.
    batch : int
        Real images per iteration. They are sampled without replacement, so
        this must not exceed ``data.shape[0]``.
    latent_space : int
        Length of the noise vectors fed to ``generator``. The default of 100
        matches ``generator_builder``'s default.
    log_every : int
        Positive interval, in iterations, between loss/accuracy logs and
        sample previews.

    Returns
    -------
    (a_metrics, d_metrics) : tuple of lists
        One ``[loss, accuracy]`` entry per iteration for the adversarial model
        and for the discriminator respectively.
    """
    d_metrics = []
    a_metrics = []

    # Cumulative sums since training started. When logged they are divided by
    # the number of completed iterations, giving a mean over all iterations so far.
    running_d_loss = 0.0
    running_d_acc = 0.0
    running_a_loss = 0.0
    running_a_acc = 0.0

    for i in range(epochs):

        if i % 100 == 0:
            print(i)

        # Discriminator step: real images are labelled 1, generated images 0.
        # `data` already has shape (N, h, w, 1), so indexing yields a 4-D batch.
        real_idx = np.random.choice(data.shape[0], batch, replace=False)
        real_imgs = data[real_idx]
        fake_imgs = generator.predict(np.random.uniform(-1.0, 1.0, size=[batch, latent_space]))

        x = np.concatenate((real_imgs, fake_imgs))
        y = np.ones([2 * batch, 1])
        y[batch:, :] = 0

        # NOTE(review): Keras applies `trainable` changes at compile() time. These
        # toggles state the intent, but which weights each compiled model updates
        # was fixed when it was compiled.
        make_trainable(discriminator, True)

        d_metrics.append(discriminator.train_on_batch(x, y))
        running_d_loss += d_metrics[-1][0]
        running_d_acc += d_metrics[-1][1]

        make_trainable(discriminator, False)

        # Generator step: fresh noise labelled "real" (1) pushes the generator
        # toward images the discriminator accepts.
        noise = np.random.uniform(-1.0, 1.0, size=[batch, latent_space])
        y = np.ones([batch, 1])

        a_metrics.append(adversarial_model.train_on_batch(noise, y))
        running_a_loss += a_metrics[-1][0]
        running_a_acc += a_metrics[-1][1]

        if (i + 1) % log_every == 0:
            # i is 0-based, so i + 1 iterations have completed. Dividing by i
            # would overstate the means and would divide by zero if log_every == 1.
            steps_done = i + 1
            print('Epoch #{}'.format(steps_done))
            log_mesg = "%d: [D loss: %f, acc: %f]" % (
                steps_done, running_d_loss / steps_done, running_d_acc / steps_done)
            log_mesg = "%s  [A loss: %f, acc: %f]" % (
                log_mesg, running_a_loss / steps_done, running_a_acc / steps_done)
            print(log_mesg)

            _plot_generated_samples(latent_space)

    return a_metrics, d_metrics

In [None]:
# Run 3000 training iterations (one batch each) and keep the per-iteration metrics.
a_metrics_complete, d_metrics_complete = train(epochs=3000, batch=128)

In [None]:
def _plot_metric(metric_idx, title, ylabel, logy=False):
    """Plot generator vs. discriminator values of one metric per training iteration.

    ``metric_idx`` indexes each stored ``train_on_batch`` result: 0 is the loss
    and 1 the accuracy (both models are compiled with ``metrics=['accuracy']``).
    Each row is one training iteration (a single batch), not a full epoch.
    """
    frame = pd.DataFrame(
        {
            'Generator': [metric[metric_idx] for metric in a_metrics_complete],
            'Discriminator': [metric[metric_idx] for metric in d_metrics_complete],
        }
    )
    ax = frame.plot(title=title, logy=logy)
    ax.set_xlabel("Training iteration (batch)")
    ax.set_ylabel(ylabel)
    return ax


ax = _plot_metric(0, 'Training Loss', "Loss", logy=True)
ax = _plot_metric(1, 'Training Accuracy', "Accuracy")
# Render both figures and suppress the trailing Text repr of set_ylabel.
plt.show()
