In [1]:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import sys

import numpy as np

import gan

Using CNTK backend


In [48]:
# Load MNIST, rescale pixel intensities to [-1, 1] (the range of the generator's
# tanh output) and append a trailing single-channel axis.
(X_train, y_train), (X_val, y_val) = mnist.load_data()

X_train = np.expand_dims((X_train.astype(np.float32) - 127.5) / 127.5, axis=3)
X_val = np.expand_dims((X_val.astype(np.float32) - 127.5) / 127.5, axis=3)

In [49]:
# Hold out digit 0 from GAN training (the validation cells below treat label 0
# as the anomalous class). Images and labels are filtered with the SAME mask so
# they stay aligned; this also makes the cell idempotent -- re-running it no
# longer indexes the already-filtered X_train with stale y_train positions.
nonzero = y_train != 0
X_train, y_train = X_train[nonzero], y_train[nonzero]

In [51]:
np.shape(X_train)  # sanity check: (n_samples_without_digit_0, 28, 28, 1)

(54077, 28, 28, 1)

In [52]:
# Collapse the validation labels to binary, in place: digit 0 keeps label 0 and
# every other digit becomes 1.
y_val[y_val != 0] = 1

In [53]:
# Partition the validation set by label, preserving dataset order:
#   xx / yy   -> samples whose label is 0
#   xxx / yyy -> all remaining samples
xx, yy = X_val[y_val == 0], y_val[y_val == 0]
xxx, yyy = X_val[y_val != 0], y_val[y_val != 0]

In [54]:
np.shape(yy)  # number of label-0 validation samples

(980,)

In [56]:
# Balanced validation set: every label-0 sample followed by the same number of
# label-1 samples (the first ones in dataset order, not a random draw).
X_val = np.concatenate([xx, xxx[:len(yy)]])
y_val = np.concatenate([yy, yyy[:len(yy)]])

In [57]:
np.shape(X_val)  # sanity check: (2 * number of label-0 samples, 28, 28, 1)

(1960, 28, 28, 1)

In [60]:
# Validation bundle handed to gan.GAN.train as (images, binary labels, step).
# NOTE(review): the training log reports "D val auc" at epochs 0 and 10, so
# val_step is presumably the epoch interval between validation runs -- confirm
# against the gan module.
val_step = 10
val_data = X_val, y_val, val_step

In [21]:
 def build_generator():

    noise_shape = (100,)
        
    model = Sequential()

    model.add(Dense(256, input_shape=noise_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(np.prod((28, 28, 1)), activation='tanh'))
    model.add(Reshape((28, 28, 1)))

    model.summary()

    noise = Input(shape=noise_shape)
    img = model(noise)

    return Model(noise, img)
    
def build_discriminator(img_shape=(28, 28, 1)):
    """Build the fully-connected discriminator that scores an image.

    Args:
        img_shape: Shape of the input image (default (28, 28, 1)); must match
            the generator's output shape.

    Returns:
        A keras ``Model`` mapping an ``img_shape`` input to a single sigmoid
        score in (0, 1) (named ``validity``; presumably the probability that
        the image is real -- confirm against gan.GAN's label convention).

    Side effects:
        Prints the inner Sequential model's summary.
    """
    model = Sequential()

    model.add(Flatten(input_shape=img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))
    model.summary()

    # Wrap in a functional Model with an explicit Input.
    img = Input(shape=img_shape)
    validity = model(img)

    return Model(img, validity)

In [22]:
# Assemble the GAN from its two sub-networks, then compile it.
# Keyword arguments are evaluated left to right, so the discriminator is built
# (and its summary printed) before the generator.
my_gan = gan.GAN(
    discriminator=build_discriminator(),
    generator=build_generator(),
)
my_gan.build_networks(
    optimizer=Adam(0.0002, 0.5),  # presumably (learning rate, beta_1) -- see Keras Adam docs
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_2 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_8 (Dense)              (None, 512)               401920    
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_9 (Dense)              (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
dense_10 (Dense)             (None, 1)                 257       
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
_________________________________________________________________
____

In [61]:
# Train on the digit-0-free training set, validating with `val_data`.
# NOTE(review): the log prints e.g. "[D val auc: 0.83%]"; the value looks like a
# 0-1 fraction, so the '%' suffix in gan.GAN.train's format string is probably
# misleading -- check the gan module.
my_gan.train(
    X_train=X_train,
    epochs=30000,
    batch_size=32,
    save_interval=200,
    val_data=val_data,
)

  (sample.dtype, var.uid, str(var.dtype)))
  (sample.dtype, var.uid, str(var.dtype)))
  (sample.dtype, var.uid, str(var.dtype)))


0 [D loss: 0.628087, acc.: 65.62%] [D val auc: 0.83%] [G loss: 0.892793]
1 [D loss: 0.663567, acc.: 59.38%] [G loss: 0.888631]
2 [D loss: 0.664982, acc.: 59.38%] [G loss: 0.873811]
3 [D loss: 0.632690, acc.: 59.38%] [G loss: 0.843850]
4 [D loss: 0.620304, acc.: 62.50%] [G loss: 0.855652]
5 [D loss: 0.640553, acc.: 65.62%] [G loss: 0.848812]
6 [D loss: 0.614354, acc.: 62.50%] [G loss: 0.888004]
7 [D loss: 0.611507, acc.: 65.62%] [G loss: 0.859326]
8 [D loss: 0.686727, acc.: 46.88%] [G loss: 0.783785]
9 [D loss: 0.761136, acc.: 43.75%] [G loss: 0.830132]
10 [D loss: 0.691324, acc.: 59.38%] [D val auc: 0.82%] [G loss: 0.835948]
11 [D loss: 0.655331, acc.: 68.75%] [G loss: 0.812390]
12 [D loss: 0.719346, acc.: 53.12%] [G loss: 0.810189]
13 [D loss: 0.659222, acc.: 53.12%] [G loss: 0.799531]
14 [D loss: 0.717822, acc.: 43.75%] [G loss: 0.856246]
15 [D loss: 0.675643, acc.: 53.12%] [G loss: 0.860350]
16 [D loss: 0.711183, acc.: 46.88%] [G loss: 0.830892]
17 [D loss: 0.713557, acc.: 50.00%] [

KeyboardInterrupt: 