In [13]:
#adapted from https://github.com/eriklindernoren/Keras-GAN

from __future__ import print_function, division

import os
from keras.preprocessing.image import load_img, img_to_array
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt

import sys

import numpy as np

In [14]:
class GAN():
    """Fully-connected GAN for single-channel (grayscale) images.

    Adapted from https://github.com/eriklindernoren/Keras-GAN.

    * ``generator`` maps a ``latent_dim`` noise vector to an image of shape
      ``img_shape``. Its last layer is tanh, so pixel values lie in [-1, 1].
    * ``discriminator`` maps an image to one sigmoid "real" score.
    * ``combined`` stacks generator -> discriminator with the discriminator
      frozen, so training it only updates the generator.
    """

    # Extensions accepted when loading the training directory. Any other
    # entry (e.g. .DS_Store, Thumbs.db, sub-directories) is skipped rather
    # than handed to load_img, which would fail on it.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff')

    def __init__(self, img_rows=256, img_cols=256, latent_dim=100):
        """Build and compile the discriminator, generator and combined model.

        Args:
            img_rows: height training images are resized to (default 256).
            img_cols: width training images are resized to (default 256).
            latent_dim: length of the generator's noise input (default 100).
        """
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channels = 1  # grayscale: train() loads images with grayscale=True
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = latent_dim

        # Build and compile the discriminator. Each compiled model gets its
        # own optimizer instance: sharing one optimizer between two models
        # shares its internal state (e.g. the iteration counter), and recent
        # tf.keras optimizers raise when asked to update variables other than
        # the ones they were first built for.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=self._make_optimizer(),
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # Freeze the discriminator for the combined model only. Keras captures
        # `trainable` at compile time, so the discriminator compiled above
        # still updates its weights in its own train_on_batch calls.
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model (stacked generator and discriminator)
        # trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy',
            optimizer=self._make_optimizer())

    def _make_optimizer(self):
        """Return a fresh Adam optimizer (learning rate 0.0002, beta_1 0.5)."""
        return Adam(0.0002, 0.5)

    def build_generator(self):
        """Return a Model mapping a (latent_dim,) noise vector to an image.

        Three Dense/LeakyReLU/BatchNorm stages (256, 512, 1024 units) are
        followed by a tanh Dense layer reshaped to ``self.img_shape``.
        """
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        """Return a Model mapping an image of ``self.img_shape`` to one sigmoid score.

        The image is flattened and passed through Dense(512) and Dense(256)
        layers with LeakyReLU. The layer summary is printed as a side effect.
        """
        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def _load_images(self, data_dir):
        """Load every image file in ``data_dir`` as grayscale into one array.

        Files are read in sorted name order and resized to
        (img_rows, img_cols). Pixel values are returned unscaled.

        Raises:
            ValueError: if ``data_dir`` contains no files with an extension
                listed in IMAGE_EXTENSIONS.
        """
        file_names = sorted(
            name for name in os.listdir(data_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(data_dir, name)))
        if not file_names:
            raise ValueError('no image files found in %r' % (data_dir,))

        images = []
        for name in file_names:
            # grayscale=True matches the Keras version this file targets;
            # newer Keras spells it color_mode='grayscale'.
            img = load_img(os.path.join(data_dir, name), grayscale=True,
                           target_size=(self.img_rows, self.img_cols))
            images.append(img_to_array(img))
        return np.array(images)

    def train(self, epochs, batch_size=3, sample_interval=50,
              data_dir='xray/data/val/pneumonia'):
        """Train the GAN on the images found in ``data_dir``.

        Each "epoch" is a single iteration. The discriminator is trained on
        one random batch of real images (indices drawn with replacement) and
        one batch of generated images. The generator is then trained once
        through the combined model. Progress is printed every iteration.

        Args:
            epochs: number of training iterations.
            batch_size: images per discriminator/generator batch.
            sample_interval: save samples every this many iterations; a
                falsy value (0/None) disables sampling.
            data_dir: directory holding the training images.
        """
        # Load the dataset from saved image directory
        X_train = self._load_images(data_dir)

        # Rescale 8-bit pixel values to [-1, 1] to match the generator's tanh output
        X_train = X_train / 127.5 - 1.

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator; each call returns [loss, accuracy]
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if sample_interval and epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch, out_dir='gen_pneumonia', n_images=5):
        """Generate ``n_images`` samples and save each one as a grayscale PNG.

        Files are written to ``<out_dir>/<epoch>_<index>.png``. ``out_dir``
        is created if it does not exist.

        Args:
            epoch: training iteration, used in the file names.
            out_dir: directory the PNGs are written to.
            n_images: number of samples to generate and save.
        """
        noise = np.random.normal(0, 1, (n_images, self.latent_dim))
        gen_imgs = self.generator.predict(noise)
        # Rescale images from the generator's tanh range [-1, 1] to [0, 1]
        gen_imgs = 0.5 * gen_imgs + 0.5

        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

        for i in range(n_images):
            # One file per sample so samples from the same epoch do not
            # overwrite each other. Fixed vmin/vmax give every image the same
            # [0, 1] grey scale; without them imsave stretches each image to
            # its own min/max.
            plt.imsave(os.path.join(out_dir, '%d_%d.png' % (epoch, i)),
                       gen_imgs[i, :, :, 0], cmap='gray', vmin=0.0, vmax=1.0)

In [15]:
if __name__ == '__main__':
    # Training schedule. GAN.train runs one random batch per "epoch", and a
    # sample interval of 1 saves generated images after every iteration.
    EPOCHS = 1200
    BATCH_SIZE = 32
    SAMPLE_INTERVAL = 1

    gan = GAN()
    gan.train(epochs=EPOCHS, batch_size=BATCH_SIZE,
              sample_interval=SAMPLE_INTERVAL)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_7 (Flatten)          (None, 65536)             0         
_________________________________________________________________
dense_43 (Dense)             (None, 512)               33554944  
_________________________________________________________________
leaky_re_lu_31 (LeakyReLU)   (None, 512)               0         
_________________________________________________________________
dense_44 (Dense)             (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_32 (LeakyReLU)   (None, 256)               0         
_________________________________________________________________
dense_45 (Dense)             (None, 1)                 257       
Total params: 33,686,529
Trainable params: 33,686,529
Non-trainable params: 0
________________________________________________________________

96 [D loss: 1.451041, acc.: 90.62%] [G loss: 14.479944]
97 [D loss: 2.994485, acc.: 71.88%] [G loss: 8.345007]
98 [D loss: 3.238767, acc.: 79.69%] [G loss: 8.044483]
99 [D loss: 2.772537, acc.: 81.25%] [G loss: 9.070171]
100 [D loss: 2.632624, acc.: 82.81%] [G loss: 8.508446]
101 [D loss: 2.525341, acc.: 82.81%] [G loss: 9.378815]
102 [D loss: 2.936045, acc.: 79.69%] [G loss: 8.157928]
103 [D loss: 1.487056, acc.: 87.50%] [G loss: 8.060826]
104 [D loss: 1.368590, acc.: 89.06%] [G loss: 12.391882]
105 [D loss: 0.608966, acc.: 95.31%] [G loss: 14.096930]
106 [D loss: 1.446750, acc.: 87.50%] [G loss: 14.611010]
107 [D loss: 0.382156, acc.: 96.88%] [G loss: 13.932178]
108 [D loss: 0.475317, acc.: 92.19%] [G loss: 12.593159]
109 [D loss: 1.320806, acc.: 87.50%] [G loss: 13.114674]
110 [D loss: 0.387377, acc.: 95.31%] [G loss: 15.624297]
111 [D loss: 2.036500, acc.: 71.88%] [G loss: 13.038494]
112 [D loss: 0.860162, acc.: 93.75%] [G loss: 11.620022]
113 [D loss: 1.262814, acc.: 90.62%] [G lo

240 [D loss: 0.996405, acc.: 93.75%] [G loss: 14.037325]
241 [D loss: 0.778781, acc.: 93.75%] [G loss: 15.110748]
242 [D loss: 0.937359, acc.: 85.94%] [G loss: 14.103333]
243 [D loss: 0.595234, acc.: 95.31%] [G loss: 12.981807]
244 [D loss: 0.000000, acc.: 100.00%] [G loss: 14.607023]
245 [D loss: 0.249100, acc.: 98.44%] [G loss: 14.103333]
246 [D loss: 0.695234, acc.: 95.31%] [G loss: 13.343109]
247 [D loss: 0.498200, acc.: 96.88%] [G loss: 14.103333]
248 [D loss: 0.363413, acc.: 96.88%] [G loss: 14.565953]
249 [D loss: 0.463868, acc.: 96.88%] [G loss: 14.400276]
250 [D loss: 0.498318, acc.: 95.31%] [G loss: 15.614405]
251 [D loss: 0.666181, acc.: 95.31%] [G loss: 15.110747]
252 [D loss: 0.004942, acc.: 100.00%] [G loss: 15.297179]
253 [D loss: 0.068668, acc.: 98.44%] [G loss: 15.322525]
254 [D loss: 0.112988, acc.: 96.88%] [G loss: 14.608152]
255 [D loss: 0.000002, acc.: 100.00%] [G loss: 14.722887]
256 [D loss: 0.000000, acc.: 100.00%] [G loss: 15.472911]
257 [D loss: 0.000001, acc.

383 [D loss: 0.011084, acc.: 98.44%] [G loss: 15.614405]
384 [D loss: 0.078888, acc.: 98.44%] [G loss: 15.043044]
385 [D loss: 0.249100, acc.: 98.44%] [G loss: 16.118095]
386 [D loss: 0.249100, acc.: 98.44%] [G loss: 15.614405]
387 [D loss: 0.000006, acc.: 100.00%] [G loss: 15.614405]
388 [D loss: 0.000000, acc.: 100.00%] [G loss: 15.614405]
389 [D loss: 0.473055, acc.: 96.88%] [G loss: 15.177720]
390 [D loss: 0.001785, acc.: 100.00%] [G loss: 15.110859]
391 [D loss: 0.019755, acc.: 100.00%] [G loss: 16.118095]
392 [D loss: 0.247104, acc.: 96.88%] [G loss: 16.118095]
393 [D loss: 0.957541, acc.: 79.69%] [G loss: 14.607023]
394 [D loss: 1.245516, acc.: 92.19%] [G loss: 14.103333]
395 [D loss: 1.135025, acc.: 92.19%] [G loss: 14.560685]
396 [D loss: 0.000034, acc.: 100.00%] [G loss: 15.614405]
397 [D loss: 0.251270, acc.: 98.44%] [G loss: 13.729662]
398 [D loss: 0.249100, acc.: 98.44%] [G loss: 15.614405]
399 [D loss: 1.004433, acc.: 92.19%] [G loss: 14.260523]
400 [D loss: 0.498202, acc

526 [D loss: 0.084741, acc.: 98.44%] [G loss: 16.118095]
527 [D loss: 0.000026, acc.: 100.00%] [G loss: 15.654880]
528 [D loss: 0.095068, acc.: 96.88%] [G loss: 16.118095]
529 [D loss: 0.794367, acc.: 82.81%] [G loss: 14.607023]
530 [D loss: 0.504196, acc.: 96.88%] [G loss: 15.110714]
531 [D loss: 0.498200, acc.: 96.88%] [G loss: 14.981442]
532 [D loss: 0.996399, acc.: 93.75%] [G loss: 14.607024]
533 [D loss: 0.747299, acc.: 95.31%] [G loss: 15.110714]
534 [D loss: 0.000242, acc.: 100.00%] [G loss: 14.787120]
535 [D loss: 0.498200, acc.: 96.88%] [G loss: 15.110714]
536 [D loss: 0.498200, acc.: 96.88%] [G loss: 16.118095]
537 [D loss: 0.249100, acc.: 98.44%] [G loss: 14.166262]
538 [D loss: 0.249100, acc.: 98.44%] [G loss: 15.111763]
539 [D loss: 0.000000, acc.: 100.00%] [G loss: 16.118095]
540 [D loss: 0.498200, acc.: 96.88%] [G loss: 15.614407]
541 [D loss: 0.569280, acc.: 95.31%] [G loss: 15.614405]
542 [D loss: 0.249637, acc.: 98.44%] [G loss: 15.110714]
543 [D loss: 0.498200, acc.:

669 [D loss: 0.000191, acc.: 100.00%] [G loss: 16.118095]
670 [D loss: 0.000933, acc.: 100.00%] [G loss: 16.118095]
671 [D loss: 0.000064, acc.: 100.00%] [G loss: 15.518837]
672 [D loss: 0.000109, acc.: 100.00%] [G loss: 15.614435]
673 [D loss: 0.028033, acc.: 98.44%] [G loss: 15.903028]
674 [D loss: 0.001378, acc.: 100.00%] [G loss: 16.118095]
675 [D loss: 0.007369, acc.: 100.00%] [G loss: 16.118095]
676 [D loss: 0.000138, acc.: 100.00%] [G loss: 15.915470]
677 [D loss: 0.000075, acc.: 100.00%] [G loss: 15.580343]
678 [D loss: 0.000053, acc.: 100.00%] [G loss: 15.663387]
679 [D loss: 0.000031, acc.: 100.00%] [G loss: 14.623150]
680 [D loss: 0.209157, acc.: 98.44%] [G loss: 15.617109]
681 [D loss: 0.347751, acc.: 95.31%] [G loss: 15.614779]
682 [D loss: 0.000005, acc.: 100.00%] [G loss: 15.112358]
683 [D loss: 0.000408, acc.: 100.00%] [G loss: 15.844849]
684 [D loss: 0.001813, acc.: 100.00%] [G loss: 16.118095]
685 [D loss: 0.000009, acc.: 100.00%] [G loss: 15.923180]
686 [D loss: 0.00

811 [D loss: 0.249100, acc.: 98.44%] [G loss: 16.118095]
812 [D loss: 0.157000, acc.: 98.44%] [G loss: 15.614405]
813 [D loss: 0.000806, acc.: 100.00%] [G loss: 16.118095]
814 [D loss: 0.296486, acc.: 96.88%] [G loss: 15.614405]
815 [D loss: 0.249100, acc.: 98.44%] [G loss: 15.110722]
816 [D loss: 0.000000, acc.: 100.00%] [G loss: 16.118095]
817 [D loss: 0.096915, acc.: 98.44%] [G loss: 16.118095]
818 [D loss: 0.000000, acc.: 100.00%] [G loss: 15.614613]
819 [D loss: 0.000000, acc.: 100.00%] [G loss: 14.112394]
820 [D loss: 0.249100, acc.: 98.44%] [G loss: 15.614405]
821 [D loss: 0.000003, acc.: 100.00%] [G loss: 15.755081]
822 [D loss: 0.305779, acc.: 95.31%] [G loss: 16.118095]
823 [D loss: 0.274763, acc.: 96.88%] [G loss: 16.118095]
824 [D loss: 0.000000, acc.: 100.00%] [G loss: 15.477009]
825 [D loss: 0.000000, acc.: 100.00%] [G loss: 16.118095]
826 [D loss: 0.000003, acc.: 100.00%] [G loss: 15.832184]
827 [D loss: 0.249100, acc.: 98.44%] [G loss: 15.614405]
828 [D loss: 0.192263, 

954 [D loss: 0.249100, acc.: 98.44%] [G loss: 14.103333]
955 [D loss: 1.743698, acc.: 89.06%] [G loss: 13.599644]
956 [D loss: 1.028568, acc.: 92.19%] [G loss: 14.103333]
957 [D loss: 0.747299, acc.: 95.31%] [G loss: 15.110714]
958 [D loss: 0.996399, acc.: 93.75%] [G loss: 14.103333]
959 [D loss: 1.021552, acc.: 92.19%] [G loss: 16.118095]
960 [D loss: 0.323703, acc.: 96.88%] [G loss: 15.614406]
961 [D loss: 0.000000, acc.: 100.00%] [G loss: 16.118095]
962 [D loss: 0.127805, acc.: 93.75%] [G loss: 15.110714]
963 [D loss: 0.000000, acc.: 100.00%] [G loss: 15.110714]
964 [D loss: 0.498200, acc.: 96.88%] [G loss: 16.118095]
965 [D loss: 0.481034, acc.: 96.88%] [G loss: 16.118095]
966 [D loss: 0.249102, acc.: 98.44%] [G loss: 16.118095]
967 [D loss: 1.355927, acc.: 89.06%] [G loss: 15.113728]
968 [D loss: 0.000000, acc.: 100.00%] [G loss: 15.614405]
969 [D loss: 0.159365, acc.: 98.44%] [G loss: 16.118095]
970 [D loss: 0.000000, acc.: 100.00%] [G loss: 14.694711]
971 [D loss: 0.251845, acc.

1095 [D loss: 0.009343, acc.: 100.00%] [G loss: 16.118095]
1096 [D loss: 0.249100, acc.: 98.44%] [G loss: 16.118095]
1097 [D loss: 0.009943, acc.: 100.00%] [G loss: 16.118095]
1098 [D loss: 0.000000, acc.: 100.00%] [G loss: 16.118095]
1099 [D loss: 0.249982, acc.: 98.44%] [G loss: 16.118095]
1100 [D loss: 0.249100, acc.: 98.44%] [G loss: 16.118095]
1101 [D loss: 0.330366, acc.: 96.88%] [G loss: 15.614405]
1102 [D loss: 0.249100, acc.: 98.44%] [G loss: 16.118095]
1103 [D loss: 0.249100, acc.: 98.44%] [G loss: 15.762725]
1104 [D loss: 0.498200, acc.: 96.88%] [G loss: 15.110715]
1105 [D loss: 0.435304, acc.: 96.88%] [G loss: 15.273331]
1106 [D loss: 0.000003, acc.: 100.00%] [G loss: 15.614405]
1107 [D loss: 0.736581, acc.: 93.75%] [G loss: 15.614405]
1108 [D loss: 0.249100, acc.: 98.44%] [G loss: 16.118095]
1109 [D loss: 0.747299, acc.: 95.31%] [G loss: 16.118095]
1110 [D loss: 0.249100, acc.: 98.44%] [G loss: 15.614405]
1111 [D loss: 1.195842, acc.: 92.19%] [G loss: 15.110714]
1112 [D lo