In [None]:
from IPython import display
import matplotlib.pyplot as plt
import cPickle as pickle
import numpy as np
from keras.models import load_model
from keras.datasets import cifar10

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [None]:
def plotRecord(pickle_file, ylim=None):
    """Plot the loss / accuracy curves stored in a pickled training record.

    Parameters
    ----------
    pickle_file : str
        Path to a pickle holding a mapping with the keys "disc", "gen",
        "acc_real", "acc_gen" and "acc_unl". Each value is plotted as one
        line against its position in the sequence.
    ylim : tuple of (float, float), optional
        y-axis limits, e.g. ``(0, 4)``. ``None`` (the default) keeps
        matplotlib's automatic scaling, which was the previous behavior.

    Raises
    ------
    KeyError
        If the record lacks one of the expected keys.
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only open
    # record files produced by our own training runs.
    with open(pickle_file, 'rb') as f:
        losses = pickle.load(f)

    # (record key, legend label), in the original plotting order.
    series = [
        ("disc", 'discriminative loss'),
        ("gen", 'generative loss'),
        ("acc_real", 'discriminator accuracy on real images'),
        ("acc_gen", 'discriminator accuracy on generated images'),
        ("acc_unl", 'discriminator accuracy on unlabeled images'),
    ]

    # Build a fresh figure directly. The old clear_output()/display(plt.gcf())
    # pair (a live-training leftover) rendered a stray, unrelated figure first.
    fig, ax = plt.subplots(figsize=(10, 8))
    for key, label in series:
        ax.plot(losses[key], label=label)
    ax.set_title(pickle_file)
    ax.set_xlabel('record index')
    # Legend outside the axes (upper-left corner anchored to the right edge).
    ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    if ylim is not None:
        ax.set_ylim(ylim)
    plt.show()

In [None]:
(X_train, y_train), (X_test, y_test) = cifar10.load_data()

def plotOriginalImages(image_class, n_samples=16, images=None, labels=None):
    """Show a random grid of distinct real images belonging to one class.

    Parameters
    ----------
    image_class : int
        Label value to draw samples from (compared against ``labels``).
    n_samples : int, optional
        Number of distinct images to show, laid out 8 per row. The default
        of 16 gives the original 2 x 8 grid.
    images, labels : array-like, optional
        Image array indexed along axis 0 and the matching label array.
        They default to ``X_train`` / ``y_train`` loaded at the top of this cell.

    Raises
    ------
    ValueError
        If ``n_samples`` < 1, or if fewer than ``n_samples`` images carry
        ``image_class``. The previous rejection-sampling loop spun forever
        in the latter case.
    """
    if images is None:
        images = X_train
    if labels is None:
        labels = y_train
    if n_samples < 1:
        raise ValueError('n_samples must be >= 1, got %r' % (n_samples,))

    # keras returns y_train with shape (N, 1); ravel() makes the comparison
    # work for both (N,) and (N, 1) label arrays.
    candidates = np.flatnonzero(np.ravel(labels) == image_class)
    if len(candidates) < n_samples:
        raise ValueError('need %d images of class %r but only %d are available'
                         % (n_samples, image_class, len(candidates)))
    # Sample without replacement so no image appears twice in the grid.
    sample_indices = np.random.choice(candidates, size=n_samples, replace=False)

    n_cols = 8
    n_rows = -(-n_samples // n_cols)  # ceiling division
    plt.figure(figsize=(10, 1.5 * n_rows))  # (10, 3) for the default 2 rows
    for pos, index in enumerate(sample_indices, start=1):
        plt.subplot(n_rows, n_cols, pos)
        plt.imshow(images[index])
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
def plotGeneratedImages(generator_file, n_samples=16, latent_dim=100):
    """Load a saved Keras generator and show a grid of images sampled from it.

    Parameters
    ----------
    generator_file : str
        Path to a model file accepted by ``keras.models.load_model``.
    n_samples : int, optional
        Number of images to generate, laid out 8 per row. The default of 16
        gives the original 2 x 8 grid.
    latent_dim : int, optional
        Length of each noise vector fed to the generator (default 100). It
        must match the input size the generator was built with.

    Raises
    ------
    ValueError
        If ``n_samples`` < 1.
    """
    if n_samples < 1:
        raise ValueError('n_samples must be >= 1, got %r' % (n_samples,))

    generator = load_model(generator_file)

    # NOTE(review): noise is drawn from U(0, 1); confirm this matches the
    # distribution used during training, otherwise samples are off-distribution.
    noise = np.random.uniform(0, 1, size=[n_samples, latent_dim])
    generated_images = generator.predict(noise)

    n_cols = 8
    n_rows = -(-n_samples // n_cols)  # ceiling division
    plt.figure(figsize=(10, 1.5 * n_rows))  # (10, 3) for the default 2 rows
    for pos, img in enumerate(generated_images, start=1):
        plt.subplot(n_rows, n_cols, pos)
        # NOTE(review): imshow expects floats in [0, 1]; if the generator's
        # output range differs (e.g. tanh -> [-1, 1]) values get clipped.
        if img.shape[-1] == 1:
            # imshow rejects (H, W, 1); show single-channel output as greyscale.
            plt.imshow(img[..., 0], cmap='gray')
        else:
            plt.imshow(img)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
plotRecord('record0-1200.pickle')

In [None]:
plotOriginalImages(0) # Original CIFAR-10 images

In [None]:
plotGeneratedImages('gen0-120.h5') # After 120 epochs

In [None]:
plotGeneratedImages('gen0-1200.h5') # After 1200 epochs

In [None]:
plotRecord('record1-1200.pickle')

In [None]:
plotOriginalImages(1) # Original CIFAR-10 images

In [None]:
plotGeneratedImages('gen1-120.h5') # After 120 epochs

In [None]:
plotGeneratedImages('gen1-1200.h5') # After 1200 epochs

In [None]:
plotRecord('record2-1200.pickle')

In [None]:
plotOriginalImages(2) # Original CIFAR-10 images

In [None]:
plotGeneratedImages('gen2-120.h5') # After 120 epochs

In [None]:
plotGeneratedImages('gen2-1200.h5') # After 1200 epochs

In [None]:
plotRecord('record3-1200.pickle')

In [None]:
plotOriginalImages(3) # Original CIFAR-10 images

In [None]:
plotGeneratedImages('gen3-120.h5') # After 120 epochs

In [None]:
plotGeneratedImages('gen3-1200.h5') # After 1200 epochs

In [None]:
plotRecord('record4-1200.pickle')

In [None]:
plotOriginalImages(4) # Original CIFAR-10 images

In [None]:
plotGeneratedImages('gen4-120.h5') # After 120 epochs

In [None]:
plotGeneratedImages('gen4-1200.h5') # After 1200 epochs

In [None]:
plotRecord('record5-1200.pickle')

In [None]:
plotOriginalImages(5) # Original CIFAR-10 images

In [None]:
plotGeneratedImages('gen5-120.h5') # After 120 epochs

In [None]:
plotGeneratedImages('gen5-1200.h5') # After 1200 epochs

In [None]:
plotRecord('record6-1200.pickle')

In [None]:
plotOriginalImages(6) # Original CIFAR-10 images

In [None]:
plotGeneratedImages('gen6-120.h5') # After 120 epochs

In [None]:
plotGeneratedImages('gen6-1200.h5') # After 1200 epochs

In [None]:
plotRecord('record7-1200.pickle')

In [None]:
plotOriginalImages(7) # Original CIFAR-10 images

In [None]:
plotGeneratedImages('gen7-120.h5') # After 120 epochs

In [None]:
plotGeneratedImages('gen7-1200.h5') # After 1200 epochs

In [None]:
plotRecord('record8-1200.pickle')

In [None]:
plotOriginalImages(8) # Original CIFAR-10 images

In [None]:
plotGeneratedImages('gen8-120.h5') # After 120 epochs

In [None]:
plotGeneratedImages('gen8-1200.h5') # After 1200 epochs

In [None]:
plotRecord('record9-1200.pickle')

In [None]:
plotOriginalImages(9) # Original CIFAR-10 images

In [None]:
plotGeneratedImages('gen9-120.h5') # After 120 epochs

In [None]:
plotGeneratedImages('gen9-1200.h5') # After 1200 epochs