In [49]:
import numpy as np
import matplotlib.pyplot as plt

import idx2numpy

from keras.models import Model, Sequential
from keras.layers import Input, Dense, Flatten, Reshape, BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam

%matplotlib tk

In [50]:
noise_dims = 10
image_dims = 20, 20

In [51]:
images = idx2numpy.convert_from_file('circles.idx')/127 - 1

## Generator

In [61]:
model = Sequential()

model.add(Dense(32, activation='relu'))
model.add(BatchNormalization(momentum=.9))
model.add(Dense(512, activation='relu'))
model.add(BatchNormalization(momentum=.9))
model.add(Dense(1024, activation='relu'))
model.add(BatchNormalization(momentum=.9))
model.add(Dense(1024, activation='relu'))
model.add(Dense(np.prod(image_dims), activation='tanh'))
model.add(Reshape(image_dims))

noise = Input(shape=(noise_dims, ))
image_out = model(noise)

generator = Model(noise, image_out)

## Discriminator

In [62]:
model = Sequential()

model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dense(32, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

image_in = Input(shape=image_dims)
validity = model(image_in)
discriminator = Model(image_in, validity)

discriminator.compile(
    optimizer='Adam',
    loss='binary_crossentropy',
    metrics=['accuracy']
)

## Stacked

In [63]:
noise_in = Input(shape=(noise_dims,))
image = generator(noise_in)

discriminator.trainable = False
for layer in discriminator.layers:
    layer.trainable = False

validity = discriminator(image)

stacked = Model(noise_in, validity)

stacked.compile(
    optimizer='Adam',
    loss='binary_crossentropy'
)

In [64]:
def show_samples(samples):
    n = len(samples)
    fig, axes = plt.subplots(1, n)
    
    for ax, samples in zip(axes, samples):
        ax.imshow(samples, cmap='Greys_r')
        ax.axis('off')

In [65]:
def train_generator(batch_size=32):
    noise = np.random.normal(0, 1, (batch_size, noise_dims))
    real = np.ones((batch_size, 1))

    return stacked.train_on_batch(noise, real)
    

In [66]:
def train_discriminator(batch_size=32):
    fake = np.zeros((batch_size, 1))
    real = np.ones((batch_size, 1))
    
    
    noise = np.random.normal(0, 1, (batch_size, noise_dims))
    random_indices = np.random.choice(images.shape[0], batch_size)
    
    real_images = images[random_indices, :, :]
    fake_images = generator.predict(noise)
    
    loss_fake, acc_fake = discriminator.train_on_batch(fake_images, fake)
    loss_real, acc_real = discriminator.train_on_batch(real_images, real)
    
    return loss_fake, loss_real, acc_fake, acc_real
    

In [67]:
history = []
stats = []

In [68]:
n_samples = 10

for epoch in range(8000):
    stacked_loss = train_generator(32)
    _, _, acc_fake, acc_real = train_discriminator(32)
    
    stats.append((stacked_loss, acc_fake, acc_real))
    
    
    if epoch % 100 == 0:
        print(f'epoch {epoch:10}    loss stacked: {stacked_loss:2.3f}    acc fake: {acc_fake:0.5f}    acc real: {acc_real:0.5f}')

        samples = generator.predict(np.random.normal(0, 1, (n_samples, noise_dims)))
        history.append(samples)

epoch          0    loss stacked: 0.807    acc fake: 0.00000    acc real: 0.00000
epoch        100    loss stacked: 8.852    acc fake: 1.00000    acc real: 1.00000
epoch        200    loss stacked: 10.845    acc fake: 1.00000    acc real: 1.00000
epoch        300    loss stacked: 10.139    acc fake: 1.00000    acc real: 1.00000
epoch        400    loss stacked: 12.871    acc fake: 1.00000    acc real: 1.00000
epoch        500    loss stacked: 13.864    acc fake: 1.00000    acc real: 1.00000
epoch        600    loss stacked: 14.227    acc fake: 1.00000    acc real: 1.00000
epoch        700    loss stacked: 14.522    acc fake: 1.00000    acc real: 1.00000
epoch        800    loss stacked: 14.621    acc fake: 1.00000    acc real: 1.00000
epoch        900    loss stacked: 14.849    acc fake: 1.00000    acc real: 1.00000
epoch       1000    loss stacked: 12.536    acc fake: 1.00000    acc real: 1.00000


KeyboardInterrupt: 

In [69]:
fig, axes = plt.subplots(10, 10)
for row, samples in zip(axes, history[:10]):
    for ax, sample in zip(row, samples):
        ax.imshow(sample, cmap='Greys_r')
        ax.axis('off')

fig.set_size_inches(20, 20)
plt.show()

In [None]:
fig, axes = plt.subplots(1, 10)
for ax, sample in zip(axes, images[:10]):
    ax.imshow(sample, cmap='Greys_r')
    ax.axis('off')

fig.set_size_inches(20, 1)
plt.show()