In [1]:
from keras.datasets import fashion_mnist
from keras.utils import np_utils
from keras.models import Sequential, Model
from keras.layers import Input, Dense, Activation, Flatten, Reshape
from keras.layers.convolutional import Conv2D, Conv2DTranspose, UpSampling2D, Convolution2D
from keras.layers.normalization import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt2
import random
#from tqdm import tqdm_notebook
import math
from keras import layers

import scipy as sp
#from tqdm import tqdm_notebook
from math import floor
from numpy import ones
from numpy import expand_dims
from numpy import log
from numpy import mean
from numpy import std
from numpy import exp
from numpy import resize
from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_v3 import preprocess_input
 
# assumes images have the shape 299x299x3, pixels in [0,255]
def scale_images(images, new_shape):
    images_list = list()
    for image in images:
        # resize with nearest neighbor interpolation
        new_image = resize(image, new_shape)
        # store
        images_list.append(new_image)
    return np.asarray(images_list)
def calculate_inception_score(images, n_split=10, eps=1E-16):
    # load inception v3 model
    model = InceptionV3()
    # convert from uint8 to float32
    processed = images.astype('float32')
    # pre-process raw images for inception v3 model
    processed = preprocess_input(processed)
    # predict class probabilities for images
    yhat = model.predict(processed)
    # enumerate splits of images/predictions
    scores = list()
    n_part = floor(images.shape[0] / n_split)
    for i in range(n_split):
        # retrieve p(y|x)
        ix_start, ix_end = i * n_part, i * n_part + n_part
        p_yx = yhat[ix_start:ix_end]
        # calculate p(y)
        p_y = expand_dims(p_yx.mean(axis=0), 0)
        # calculate KL divergence using log probabilities
        kl_d = p_yx * (log(p_yx + eps) - log(p_y + eps))
        # sum over classes
        sum_kl_d = kl_d.sum(axis=1)
        # average over images
        avg_kl_d = mean(sum_kl_d)
        # undo the log
        is_score = exp(avg_kl_d)
        # store
        scores.append(is_score)
    # average across images
    is_avg, is_std = mean(scores), std(scores)
    return is_avg, is_std
# calculate frechet inception distance
def get_fid(act1, act2):
    # calculate mean and covariance statistics
    mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
    mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
    # calculate sum squared difference between means
    ssdiff = np.sum((mu1 - mu2)**2.0)
    # calculate sqrt of product between cov
    covmean = sp.linalg.sqrtm(sigma1.dot(sigma2))
    # check and correct imaginary numbers from sqrt
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    # calculate score
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid
# Dataset of 60,000 28x28 grayscale images of the 10 digits, along with a test set of 10,000 images.
(X_train, Y_train), (X_test, Y_test) = fashion_mnist.load_data()
# print(X_train.shape)

z_dim = 100

X_train = X_train.reshape(60000, 28, 28, 1)
X_test = X_test.reshape(10000, 28, 28, 1)
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255

X_train = X_train[0:30000]

nch = 20
g_input = Input(shape=[100])
H1 = Dense(nch * 14 * 14, init='glorot_normal')(g_input)
H = BatchNormalization()(H1)
H = Activation('relu')(H)
H = Reshape([nch, 14, 14])(H)
H = UpSampling2D(size=(2, 2))(H)
H = Convolution2D(int(nch / 2), 3, 3, border_mode='same', init='glorot_uniform')(H)
H = BatchNormalization()(H)
H = Activation('relu')(H)
H = Convolution2D(int(nch / 4), 3, 3, border_mode='same', init='glorot_uniform')(H)
H = BatchNormalization()(H)
H = Activation('relu')(H)
H = Convolution2D(1, 1, 1, border_mode='same', init='glorot_uniform')(H)
g_V = Activation('sigmoid')(H)
generator = Model(g_input, g_V)
generator.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
generator.summary()

# Generator
adam = Adam(lr=0.0002, beta_1=0.5)

g = Sequential()
layer = g.add(Dense(7 * 7 * 112, input_dim=z_dim))
g.add(Reshape((7, 7, 112)))
g.add(BatchNormalization())
g.add(Activation(LeakyReLU(alpha=0.2)))
g.add(Conv2DTranspose(56, 5, strides=2, padding='same'))
g.add(BatchNormalization())
g.add(Activation(LeakyReLU(alpha=0.2)))
g.add(Conv2DTranspose(1, 5, strides=2, padding='same', activation='sigmoid'))
g.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy'])
g.summary()

d = Sequential()
d.add(Conv2D(56, 5, strides=2, padding='same', input_shape=(28, 28, 1), activation=LeakyReLU(alpha=0.2)))
d.add(Conv2D(112, 5, strides=2, padding='same'))
g.add(BatchNormalization())
g.add(Activation(LeakyReLU(alpha=0.2)))
d.add(Conv2D(224, 5, strides=2, padding='same'))
g.add(Activation(LeakyReLU(alpha=0.2)))
d.add(Flatten())
d.add(Dense(112, activation=LeakyReLU(alpha=0.2)))
d.add(Dense(1, activation='sigmoid'))
d.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy'])
d.summary()

d.trainable = False
inputs = Input(shape=(z_dim,))
hidden = g(inputs)
output = d(hidden)
gan = Model(inputs, output)
gan.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy'])
gan.summary()


def plot_loss(losses):
    """
    @losses.keys():
        0: loss
        1: accuracy
    """
    d_loss = [v[0] for v in losses["D"]]
    g_loss = [v[0] for v in losses["G"]]

    plt.figure(figsize=(6.4, 4.8))
    plt.plot(d_loss, color='red', label="Discriminator loss")
    plt.plot(g_loss, color='green', label="Generator loss")
    plt.title("GAN : MNIST dataset")

    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('loss.png')
    # plt.show()


def plot_generated(n_ex=20, dim=(2, 10), figsize=(48, 8)):
    noise = np.random.normal(0, 1, size=(n_ex, z_dim))
    generated_images = g.predict(noise)
    generated_images = generated_images.reshape(generated_images.shape[0], 28, 28)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i + 1)
        # plt.imshow(generated_images[i, :, :], interpolation='nearest', cmap='gray_r')
        sss = str(i)

        plt.imsave(sss, generated_images[i, :, :], cmap='gray_r')

        plt.axis('off')
    plt.tight_layout()
    plt.plot()
    plt.show()


# Set up a vector (dict) to store the losses
losses = {"D": [], "G": []}
samples = []


mfid=[]
def train(d, epochs=1, plt_frq=1, BATCH_SIZE=128):
    batchCount = int(X_train.shape[0] / BATCH_SIZE)
    #batchCount=100
    print('Epochs:', epochs)
    print('Batch size:', BATCH_SIZE)
    print('Batches per epoch:', batchCount)

    d_v = []
    for e in range(1, epochs + 1):
        if e == 1 or e % plt_frq == 0:
            print('-' * 15, 'Epoch %d' % e, '-' * 15)
        for _ in range(batchCount):  # tqdm_notebook(range(batchCount), leave=False):
            # Create a batch by drawing random index numbers from the training set
            image_batch = X_train[np.random.randint(0, X_train.shape[0], size=BATCH_SIZE)]
            image_batch = image_batch.reshape(image_batch.shape[0], image_batch.shape[1], image_batch.shape[2], 1)
            # print(image_batch.shape)
            # Create noise vectors for the generator
            noise = np.random.normal(0, 1, size=(BATCH_SIZE, z_dim))

            # Generate the images from the noise
            generated_images = g.predict(noise)
            samples.append(generated_images)
            X = np.concatenate((image_batch, generated_images))
            # Create labels
            y = np.zeros(2 * BATCH_SIZE)
            y[:BATCH_SIZE] = 0.9  # One-sided label smoothing

            # Train discriminator on generated images
            d.trainable = True
            d_loss = d.train_on_batch(X, y)
            # Train generator
            noise = np.random.normal(0, 1, size=(BATCH_SIZE, z_dim))
            y2 = np.ones(BATCH_SIZE)
            d.trainable = False
            g_loss = gan.train_on_batch(noise, y2)
            weights = []
            ccc = 0

        #weights = g.layers[0].get_weights()[0]
        #w = 3
        #weights = weights.reshape(548800)
        #bin = [-0.08, -0.06, -0.04, -0.02, 0, 0.02, 0.04, 0.06, 0.08]

        #px,py,_ = plt2.hist(weights, bins=100)
        #print(str(np.argmax(px)) + '--' + str(np.max(py)))

        sss = str(e)
        # plt2.savefig("plt" +sss +".png")
        # plt2.clf()
        #print(len(weights))

        # Only store losses from final batch of epoch
        #images1 = scale_images(generated_images, (299,299,3))
        #print(calculate_inception_score(images1))
        
        
        image_batch = image_batch.reshape(BATCH_SIZE,784)
        generated_images = generated_images.reshape(BATCH_SIZE,784)
        temp=get_fid(generated_images, image_batch)
        print('fid : ' + str(temp))
        
        mfid.append(temp)
        losses["D"].append(d_loss)
        losses["G"].append(g_loss)

        # Update the plots
        if e == 1 or e % plt_frq == 0:
            plot_generated()
    # print((weights.shape ))
    

    # for fff in range(0,100):
    # print(losses["D"][fff][0])
    # print("-------")
    # for fff in range(0,100):
    # print(losses["G"][fff][0])


train(d, epochs=100, plt_frq=20, BATCH_SIZE=128)

for i in range(0,100):
    print(mfid[i])


Using TensorFlow backend.



Downloading data from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading data from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading data from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading data from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz



















Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 100)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 3920)              395920    
_________________________________________________________________
batch_normalization_1 (Batch (None, 3920)              15680     
_________________________________________________________________
activation_1 (Activation)    (None, 3920)              0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 20, 14, 14)        0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 40, 28, 14)        0         
__________________________________________



Non-trainable params: 7,870
_________________________________________________________________


  identifier=identifier.__class__.__name__))


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_2 (Dense)              (None, 5488)              554288    
_________________________________________________________________
reshape_2 (Reshape)          (None, 7, 7, 112)         0         
_________________________________________________________________
batch_normalization_4 (Batch (None, 7, 7, 112)         448       
_________________________________________________________________
activation_5 (Activation)    (None, 7, 7, 112)         0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 14, 14, 56)        156856    
_________________________________________________________________
batch_normalization_5 (Batch (None, 14, 14, 56)        224       
_________________________________________________________________
activation_6 (Activation)    (None, 14, 14, 56)        0         
__________

<Figure size 4800x800 with 20 Axes>

fid : 49.31038147388017
fid : 140.48658177101356
fid : 96.24495645962845
fid : 55.54836869271897
fid : 61.49900624756225
fid : 66.254260241344
fid : 64.45300340998041
fid : 54.313497471332056
fid : 53.57552667361824
fid : 52.66769545470994
fid : 42.376474924075
fid : 55.32951858606829
fid : 51.17456551515933
fid : 44.51603941235621
fid : 44.6562383670644
fid : 43.816391818020655
fid : 46.645549996089485
fid : 47.6680757621728
--------------- Epoch 20 ---------------
fid : 42.74172222389394


<Figure size 4800x800 with 20 Axes>

fid : 40.5924203641096
fid : 41.415494282729625
fid : 48.04118439104018
fid : 34.44761627721383
fid : 41.63747712204555
fid : 35.850189981227544
fid : 33.61948357412486
fid : 30.674172949744985
fid : 42.78381455536145
fid : 35.994330558296156
fid : 41.24999967302827
fid : 34.32275786191274
fid : 35.99559749248266
fid : 35.06223530165145
fid : 37.476567022218504
fid : 36.729341860975595
fid : 35.101209040977196
fid : 34.49942306853445
fid : 35.96304740648869
--------------- Epoch 40 ---------------
fid : 35.80706583743792


<Figure size 4800x800 with 20 Axes>

fid : 41.29266724817629
fid : 37.980734095891904
fid : 39.5242469104663
fid : 36.3070589146526
fid : 40.23440328351998
fid : 41.149862155418035
fid : 44.0691734365004
fid : 35.817870610050434
fid : 39.82277340986011
fid : 41.01447209119928
fid : 43.020695296260456
fid : 38.014468175361166
fid : 42.97671593502019
fid : 42.81808085746128
fid : 39.22385955436455
fid : 39.03927885063421
fid : 39.135212523973856
fid : 43.64437281451253
fid : 39.42083832306416
--------------- Epoch 60 ---------------
fid : 39.26793374125565


<Figure size 4800x800 with 20 Axes>

fid : 44.72814304826577
fid : 44.908020039606754
fid : 42.32088798543617
fid : 40.78883185586869
fid : 44.86624906211728
fid : 39.13181284990338
fid : 41.625184094251395
fid : 45.821278123056224
fid : 44.20087578043598
fid : 35.39398127309944
fid : 43.04230991157223
fid : 40.97321921942543
fid : 46.18694572848896
fid : 45.06152214322722
fid : 44.63014932409471
fid : 44.72835849132589
fid : 37.95373148534946
fid : 42.27449105849755
fid : 43.11147985917445
--------------- Epoch 80 ---------------
fid : 47.85596774757954


<Figure size 4800x800 with 20 Axes>

fid : 49.583527222581296
fid : 40.470584349181465
fid : 46.16010215407023
fid : 49.57121012473803
fid : 45.00366370314353
fid : 40.44852880352418
fid : 45.12050187841143
fid : 44.378789858220955
fid : 45.13547140693012
fid : 41.391197874513225
fid : 36.07813522803848
fid : 43.09047566258446
fid : 41.96968040522806
fid : 40.40367915452603
fid : 44.233731357022734
fid : 40.07388536674739
fid : 47.48998285285841
fid : 47.181223180990486
fid : 46.207284924957165
--------------- Epoch 100 ---------------
fid : 46.46744868399938


<Figure size 4800x800 with 20 Axes>

199.99079255901583
49.31038147388017
140.48658177101356
96.24495645962845
55.54836869271897
61.49900624756225
66.254260241344
64.45300340998041
54.313497471332056
53.57552667361824
52.66769545470994
42.376474924075
55.32951858606829
51.17456551515933
44.51603941235621
44.6562383670644
43.816391818020655
46.645549996089485
47.6680757621728
42.74172222389394
40.5924203641096
41.415494282729625
48.04118439104018
34.44761627721383
41.63747712204555
35.850189981227544
33.61948357412486
30.674172949744985
42.78381455536145
35.994330558296156
41.24999967302827
34.32275786191274
35.99559749248266
35.06223530165145
37.476567022218504
36.729341860975595
35.101209040977196
34.49942306853445
35.96304740648869
35.80706583743792
41.29266724817629
37.980734095891904
39.5242469104663
36.3070589146526
40.23440328351998
41.149862155418035
44.0691734365004
35.817870610050434
39.82277340986011
41.01447209119928
43.020695296260456
38.014468175361166
42.97671593502019
42.81808085746128
39.22385955436455
39.