# Training a simple GAN

In [3]:
%matplotlib inline
import sys
sys.path.append("/home/ubuntu/part2")
%matplotlib inline
import importlib
import utils2; importlib.reload(utils2)
from utils2 import *
import PIL
from scipy.optimize import fmin_l_bfgs_b
from scipy.misc import imsave
from keras import metrics
from vgg16_avg import VGG16_Avg
from IPython.display import SVG
from keras.utils.visualize_util import model_to_dot
from keras_tqdm import TQDMNotebookCallback
from keras.datasets import mnist
from tqdm import tqdm
from IPython import display

Using TensorFlow backend.


In [4]:
(X_train, y_train), (X_test, y_test) = mnist.load_data(path="mnist.pkl.gz")

In [5]:

img_rows, img_cols = 28, 28

# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()

X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255

print(np.min(X_train), np.max(X_train))

print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')

0.0 1.0
X_train shape: (60000, 28, 28, 1)
60000 train samples


In [6]:
y_train_labels = keras.utils.np_utils.to_categorical(y_train)
y_test = keras.utils.np_utils.to_categorical(y_test)

## Generator

In [7]:
nch = 200
shp = X_train.shape[1:]
dropout_rate = 0.25
opt = Adam(lr=1e-4)
dopt = Adam(lr=1e-3)

g_input = Input(shape=[100])
H = Dense(nch*14*14, init='glorot_normal')(g_input)
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Reshape( [14, 14, nch] )(H)
H = UpSampling2D(size=(2, 2))(H)
H = Convolution2D(nch//2, 3, 3, border_mode='same', init='glorot_uniform')(H)
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Convolution2D(nch//4, 3, 3, border_mode='same', init='glorot_uniform')(H)
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Convolution2D(1, 1, 1, border_mode='same', init='glorot_uniform')(H)
g_V = Activation('sigmoid')(H)
#the input image is made up of sigmoid pixels as well, thus the loss makes sense.
generator = Model(g_input, g_V)
generator.compile(loss='binary_crossentropy', optimizer=opt)
generator.summary()


____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_1 (InputLayer)             (None, 100)           0                                            
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 39200)         3959200     input_1[0][0]                    
____________________________________________________________________________________________________
batchnormalization_1 (BatchNorma (None, 39200)         156800      dense_1[0][0]                    
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 39200)         0           batchnormalization_1[0][0]       
___________________________________________________________________________________________

# Discriminator

In [8]:

def make_trainable(net, val):
    net.trainable = val
    for l in net.layers:
        l.trainable = val

ip_node = Input(shape=((28,28,1)))
res = Reshape([28 * 28 * 1])(ip_node)
l1 = Dense(512, activation='relu')(res)
#disc_output = Dense(1, activation='sigmoid')(l1)
disc_output = Dense(1)(l1)
discriminator = Model(ip_node, disc_output)
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
discriminator.summary()


____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_2 (InputLayer)             (None, 28, 28, 1)     0                                            
____________________________________________________________________________________________________
reshape_2 (Reshape)              (None, 784)           0           input_2[0][0]                    
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 512)           401920      reshape_2[0][0]                  
____________________________________________________________________________________________________
dense_3 (Dense)                  (None, 1)             513         dense_2[0][0]                    
Total params: 402,433
Trainable params: 402,433
Non-trainable params: 0
___________________

In [9]:
gan_input = Input(shape=[100])
generated_output = generator(gan_input)
discriminators_thinks = discriminator(generated_output)
gan = Model(gan_input, discriminators_thinks)
gan.compile(loss='binary_crossentropy', optimizer='adam')
gan.summary()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_3 (InputLayer)             (None, 100)           0                                            
____________________________________________________________________________________________________
model_1 (Model)                  (None, 28, 28, 1)     4341801     input_3[0][0]                    
____________________________________________________________________________________________________
model_2 (Model)                  (None, 1)             402433      model_1[1][0]                    
Total params: 4,744,234
Trainable params: 4,665,534
Non-trainable params: 78,700
____________________________________________________________________________________________________


In [10]:
def plot_loss(losses):
    display.clear_output(wait=True)
    display.display(plt.gcf())
    plt.figure(figsize=(10,8))
    plt.plot(losses["d"], label='discriminitive loss')
    plt.plot(losses["g"], label='generative loss')
    plt.legend()
    plt.show()

In [11]:
def plot_gen(n_ex=6,dim=(4, 4), figsize=(10,10) ):
    noise = np.random.uniform(0,1,size=[n_ex,100])
    generated_images = generator.predict(noise)
    print(generated_images.shape)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0],dim[1], i + 1)
        img = generated_images[i,:,:, 0]
        plt.imshow(img)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [12]:

ntrain = 10000
trainidx = random.sample(range(0,X_train.shape[0]), ntrain)
XT = X_train[trainidx,:,:,:]

# Pre-train the discriminator network ...
noise_gen = np.random.uniform(0,1,size=[XT.shape[0],100])
generated_images = generator.predict(noise_gen)
print(XT.shape)
print(generated_images.shape)
X = np.concatenate((XT, generated_images))
n = XT.shape[0]
y = np.vstack([np.ones((n,1)), np.zeros((n,1))])
#y = np.zeros([2*n,2])
#y[:n,1] = 1
#y[n:,0] = 1
print(y.shape)
make_trainable(discriminator,True)
discriminator.fit(X,y, nb_epoch=1, batch_size=128)
y_hat = discriminator.predict(X)

(10000, 28, 28, 1)
(10000, 28, 28, 1)
(20000, 1)
Epoch 1/1


In [13]:
y_hat_idx = np.argmax(y_hat,axis=1)
y_idx = np.argmax(y,axis=1)
diff = y_idx-y_hat_idx
n_tot = y.shape[0]
n_rig = (diff==0).sum()
acc = n_rig*100.0/n_tot
print("Accuracy: %0.02f pct (%d of %d) right"%(acc, n_rig, n_tot))

Accuracy: 100.00 pct (20000 of 20000) right


In [14]:
# set up loss storage vector
losses = {"d":[], "g":[]}

In [15]:
def train_for_n(nb_epoch=5000, plt_frq=25,BATCH_SIZE=32):

    for e in tqdm(range(nb_epoch)):  
        
        # Make generative images
        image_batch = X_train[np.random.randint(0,X_train.shape[0],size=BATCH_SIZE),:,:,:]    
        noise_gen = np.random.uniform(0,1,size=[BATCH_SIZE,100])
        generated_images = generator.predict(noise_gen)
        
        # Train discriminator on generated images
        X = np.concatenate((image_batch, generated_images))
        #y = np.zeros([2*BATCH_SIZE,2])
        #y[0:BATCH_SIZE,1] = 1
        #y[BATCH_SIZE:,0] = 1
        y = np.vstack([np.ones((BATCH_SIZE,1)), np.zeros((BATCH_SIZE,1))])

        make_trainable(discriminator,True)
        d_loss  = discriminator.train_on_batch(X, y)
        losses["d"].append(d_loss)
        # train Generator-Discriminator stack on input noise to non-generated output class
        noise_tr = np.random.uniform(0,1,size=[BATCH_SIZE,100])
        #y2 = np.zeros([BATCH_SIZE,2])
        #y2[:,1] = 1
        y2 = np.vstack([np.ones((BATCH_SIZE,1))])

        make_trainable(discriminator,False)
        g_loss = gan.train_on_batch(noise_tr, y2 )
        losses["g"].append(g_loss)
        
        # Updates plots
        if e%plt_frq==plt_frq-1:
            #plot_loss(losses)
            plot_gen()


In [16]:
%%html
<style>
.output_wrapper, .output {
    height:auto !important;
    max-height:1000px;  /* your desired max-height here */
}
.output_scroll {
    box-shadow:none !important;
    webkit-box-shadow:none !important;
}
</style>

In [21]:
train_for_n(nb_epoch=6000, plt_frq=1000,BATCH_SIZE=128)



  0%|          | 0/6000 [00:00<?, ?it/s][A
  0%|          | 1/6000 [00:01<2:40:26,  1.60s/it][A
  0%|          | 2/6000 [00:02<2:05:48,  1.26s/it][A
  0%|          | 3/6000 [00:02<1:33:17,  1.07it/s][A
  0%|          | 4/6000 [00:02<1:10:24,  1.42it/s][A
  0%|          | 5/6000 [00:02<54:26,  1.84it/s]  [A
  0%|          | 6/6000 [00:02<43:13,  2.31it/s][A
  0%|          | 7/6000 [00:02<35:23,  2.82it/s][A
  0%|          | 8/6000 [00:03<29:54,  3.34it/s][A
  0%|          | 9/6000 [00:03<26:02,  3.84it/s][A
  0%|          | 10/6000 [00:03<23:14,  4.30it/s][A
  0%|          | 11/6000 [00:03<21:18,  4.68it/s][A
  0%|          | 12/6000 [00:03<19:58,  5.00it/s][A
  0%|          | 13/6000 [00:03<19:03,  5.24it/s][A
  0%|          | 14/6000 [00:04<22:44,  4.39it/s][A
  7%|▋         | 437/6000 [02:43<53:11,  1.74it/s]

KeyboardInterrupt: 

In [18]:
print("this")

this
