In [1]:
# Import Libraries
from keras.models import Sequential, clone_model, Model
from keras.layers import Dense, Dropout, Input
from keras.layers.merge import concatenate
from keras.datasets import mnist
from keras.utils import np_utils
from keras import backend as K
import matplotlib.pyplot as plt
import numpy as np

Using TensorFlow backend.


In [7]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], -1)/255
X_test = X_test.reshape(X_test.shape[0], -1)/255

X_train = X_train[y_train == 0]
X_test = X_test[y_test == 0]
# y_train = np_utils.to_categorical(y_train)
# y_test = np_utils.to_categorical(y_test)
# print(X_train.shape)
# print(X_test.shape)
# print(y_train.shape)
# print(y_test.shape)
y_train = y_train[y_train == 0]
y_test = y_test[y_test == 0]

In [8]:
def baseline_model(num_pixels, num_classes):
    inpt1 = Input(shape=(num_pixels,))
    dense1 = Dense(num_pixels,kernel_initializer='normal',activation='relu')(inpt1)
    
    inpt2 = Input(shape=(num_pixels,))
    dense2 = Dense(num_pixels,kernel_initializer='normal',activation='relu')(inpt2)
    
    concat_layer = concatenate([dense1, dense2], axis=-1)
    logits = Dense(num_classes, kernel_initializer='normal',activation='softmax')(concat_layer)
    model = Model([inpt1, inpt2], logits)
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

In [13]:
num_pixels = X_train.shape[1]
num_classes = 2

model_D = baseline_model(num_pixels=num_pixels, num_classes=num_classes)

In [10]:
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 784)          0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 784)          0                                            
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 784)          615440      input_1[0][0]                    
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 784)          615440      input_2[0][0]                    
__________________________________________________________________________________________________
concatenat

In [25]:
def generator_model(input_dimension, hidden_size, discriminator):
    inpt1 = Input(shape=input_dimension)
    l = inpt1
    for hid in hidden_size:
        l = Dense(units=hid, kernel_initializer='normal',activation='relu')(l)
    model_d = clone_model(model=discriminator)
    for layer in model_d.layers:
        layer.trainable = False
        
    inpt2 = Input(shape=(784,))
    out = model_d([l, inpt2])
    model_gan = Model([inpt1, inpt2], out)
    model_gan.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
    
    model_G = Model(inpt1, l)
    return model_gan, model_G

In [None]:
def plot_images(self, save2file=False, samples=16, step=0):
    ''' Plot and generated images '''
    filename = "./images/mnist_%d.png" % step
    noise = np.random.uniform(high=1,low=0, size=(X_train.shape[0], 10))

    images = model_G.predict(noise)

    plt.figure(figsize=(10, 10))

    for i in range(images.shape[0]):
        plt.subplot(4, 4, i+1)
        image = images[i, :, :, :]
        image = np.reshape(image, [self.height, self.width])
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.tight_layout()

    if save2file:
        plt.savefig(filename)
        plt.close('all')
    else:
        plt.show()

In [63]:
num_pixels = X_train.shape[1]
num_classes = 2

model_D = baseline_model(num_pixels=num_pixels, num_classes=num_classes)
model_GAN, model_G = generator_model((10,), [32, 784], model)
for e in range(40):
    print("Epoch: {}".format(e))
    Noise_x = np.random.uniform(high=1,low=0, size=(X_train.shape[0], 10))
    fake_image = model_G.predict(Noise_x)
    X_train_sync = np.concatenate([fake_image, X_train], axis=0)
    y_train_sync = np_utils.to_categorical(np.array([0]*fake_image.shape[0] + [1] * X_train.shape[0]))

    idx = np.random.choice(np.arange(X_train.shape[0]), replace=False, size=X_train.shape[0])
    X_train_ori = np.concatenate([X_train, X_train[idx]], axis=0)
    
    idx = np.random.choice(np.arange(X_train_sync.shape[0]), replace=False, size=X_train_sync.shape[0])
    X_train_sync = X_train_sync[idx, :]
    X_train_ori =  X_train_ori[idx, :]
    
    y_train_sync = y_train_sync[idx, :]
    
    for i in range(0, X_train_sync.shape[0], 64):
        loss, acc = model_D.train_on_batch(x=[X_train_sync[i:i + 64,:], X_train_ori[i:i+64, :]], y=y_train_sync[i:i+64])
        print(acc, end=", ")
    print()
    model_GAN.layers[4].set_weights(model_D.get_weights())

    for i in range(0, Noise_x.shape[0], 64):
        loss, acc = model_GAN.train_on_batch(x = [Noise_x[i:i+64,:], X_train[i:i+64,:]], y=np.array([[0,1]]*Noise_x[i:i+64,:].shape[0]))
        print(acc, end=", ")
    print()
    
    

Epoch: 0
0.40625, 0.5625, 0.875, 0.984375, 0.984375, 0.984375, 0.9375, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0

0.484375, 0.53125, 0.578125, 0.9375, 0.953125, 0.890625, 0.796875, 0.953125, 0.90625, 0.875, 0.953125, 0.96875, 0.9375, 0.96875, 0.96875, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 0.984375, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.

0.4375, 0.578125, 0.53125, 0.5625, 0.671875, 0.984375, 0.96875, 0.953125, 0.9375, 0.96875, 0.859375, 0.84375, 0.828125, 0.859375, 0.90625, 0.890625, 0.921875, 0.953125, 0.953125, 0.984375, 0.984375, 1.0, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1

0.484375, 0.46875, 0.515625, 0.734375, 0.96875, 0.953125, 0.953125, 0.875, 0.953125, 0.828125, 0.84375, 0.859375, 0.875, 0.9375, 0.84375, 0.890625, 0.921875, 0.9375, 0.984375, 0.96875, 0.96875, 0.984375, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 1.0, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1

0.609375, 0.53125, 0.421875, 0.609375, 0.953125, 1.0, 0.96875, 0.96875, 0.859375, 0.953125, 0.875, 0.84375, 0.9375, 0.921875, 0.96875, 0.875, 0.875, 0.90625, 0.984375, 0.984375, 0.890625, 0.96875, 0.984375, 0.921875, 0.890625, 0.9375, 0.96875, 0.953125, 0.984375, 0.96875, 0.984375, 1.0, 1.0, 0.984375, 0.96875, 1.0, 0.953125, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 0.984375, 1.0, 0.984375, 0.984375, 1.0, 0.953125, 1.0, 0.984375, 0.96875, 1.0, 1.0, 1.0, 1.0, 0.96875, 1.0, 1.0, 1.0, 0.984375, 1.0, 0.984375, 1.0, 1.0, 1.0, 0.984375, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 1.0, 1.0, 0.984375, 1.0, 1.0, 1.0, 0.984375, 1.0, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 1.0, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 0.96875, 1.0, 0.984375, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1

0.5, 0.484375, 0.34375, 0.46875, 0.515625, 0.578125, 0.984375, 0.953125, 1.0, 0.9375, 0.890625, 0.953125, 0.859375, 0.890625, 0.953125, 0.953125, 0.890625, 0.875, 0.90625, 0.953125, 0.96875, 0.953125, 0.9375, 0.953125, 0.90625, 0.9375, 0.9375, 0.890625, 0.9375, 0.90625, 0.984375, 0.953125, 0.984375, 0.9375, 0.984375, 0.9375, 0.953125, 0.96875, 0.953125, 0.953125, 0.984375, 0.96875, 0.984375, 0.953125, 0.96875, 1.0, 1.0, 1.0, 0.984375, 1.0, 0.984375, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 0.984375, 0.96875, 1.0, 1.0, 1.0, 0.984375, 1.0, 0.953125, 1.0, 1.0, 0.984375, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.

0.5625, 0.453125, 0.546875, 0.4375, 0.515625, 0.53125, 0.484375, 0.453125, 0.53125, 0.53125, 0.546875, 0.484375, 0.390625, 0.640625, 0.59375, 0.65625, 0.796875, 0.984375, 0.875, 0.859375, 0.859375, 0.875, 0.859375, 0.875, 0.859375, 0.90625, 0.84375, 0.921875, 0.96875, 0.9375, 0.921875, 0.96875, 0.953125, 1.0, 0.984375, 0.9375, 0.984375, 0.921875, 0.953125, 0.921875, 0.9375, 0.921875, 0.984375, 0.984375, 0.96875, 0.984375, 1.0, 1.0, 1.0, 0.984375, 0.953125, 1.0, 0.96875, 1.0, 0.984375, 1.0, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.984375, 1.0, 1.0, 1.0, 

0.5625, 0.46875, 0.53125, 0.421875, 0.515625, 0.5, 0.578125, 0.53125, 0.453125, 0.453125, 0.46875, 0.453125, 0.421875, 0.515625, 0.40625, 0.53125, 0.5, 0.46875, 0.484375, 0.484375, 0.421875, 0.5, 0.453125, 0.375, 0.5, 0.5625, 0.515625, 0.59375, 0.59375, 0.5, 0.515625, 0.5, 0.46875, 0.53125, 0.46875, 0.5, 0.390625, 0.453125, 0.484375, 0.515625, 0.53125, 0.5, 0.53125, 0.4375, 0.515625, 0.484375, 0.5625, 0.53125, 0.421875, 0.53125, 0.609375, 0.46875, 0.453125, 0.46875, 0.515625, 0.5625, 0.5, 0.578125, 0.515625, 0.46875, 0.59375, 0.625, 0.5, 0.5625, 0.453125, 0.609375, 0.484375, 0.46875, 0.578125, 0.546875, 0.546875, 0.46875, 0.390625, 0.46875, 0.4375, 0.421875, 0.484375, 0.53125, 0.40625, 0.53125, 0.375, 0.5625, 0.5625, 0.453125, 0.5, 0.515625, 0.484375, 0.5, 0.5, 0.484375, 0.484375, 0.46875, 0.59375, 0.5, 0.53125, 0.53125, 0.5625, 0.609375, 0.4375, 0.515625, 0.484375, 0.546875, 0.484375, 0.453125, 0.515625, 0.359375, 0.4375, 0.515625, 0.5, 0.4375, 0.5625, 0.484375, 0.453125, 0.546875, 0.

1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
Epoch: 38
0.40625, 0.46875, 0.421875, 0.453125, 0.3125, 0.5625, 0.421875, 0.46875, 0.53125, 0.609375, 0.515625, 0.546875, 0.46875, 0.40625, 0.515625, 0.609375, 0.46875, 0.4375, 0.4375, 0.5625, 0.546875, 0.4375, 0.5, 0.5, 0.421875, 0.421875, 0.53125, 0.5625, 0.484375, 0.515625, 0.546875, 0.53125, 0.609375, 0.546875, 0.421875, 0.578125, 0.40625, 0.546875, 0.40625, 0.453125, 0.5, 0.546875, 0.40625, 0.453125, 0.421875, 0.453125, 0.46875, 0.390625, 0.5625, 0.546875, 0.421875, 0.59375, 0.359375, 0.578125, 0.59375, 0.421875, 0.53125, 0

In [42]:
tmp = model_D.get_weights()

In [48]:
model_GAN.layers[4].set_weights(tmp)

In [52]:
np.array([[0,1]]*Noise_x.shape[0])

array([[0, 1],
       [0, 1],
       [0, 1],
       ..., 
       [0, 1],
       [0, 1],
       [0, 1]])