In [19]:

from __future__ import print_function, division

from keras.datasets import mnist
from keras.datasets import cifar10
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, merge
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.layers import Concatenate, Dense, LSTM, Input, concatenate
from keras.optimizers import RMSprop
import keras.backend as K
 

import tensorflow as tf
from scipy.misc import imread, imsave

import matplotlib.pyplot as plt

import sys
import os
from PIL import Image
from glob import glob

import numpy as np
from keras.models import load_model


In [20]:
class WGAN():
    
    
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 3
        self.img_shape = (1,self.img_rows*self.img_cols*self.channels)

        # Following parameter and optimizer set as recommended in paper
        self.n_critic = 5
        self.clip_value = 0.01
        optimizer = RMSprop(lr=0.00005)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss=self.wasserstein_loss,
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()
       

        # The generator takes noise as input and generated imgs
        z = Input(shape=(4596,))
        img = self.generator(z)
        
        img_merge = concatenate([img,z], axis=1)
        print ("img_merge")
        print (img_merge)
        print (img_merge.shape)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The valid takes generated images as input and determines validity
        valid = self.discriminator(img_merge)

        # The combined model  (stacked generator and discriminator) takes
        # noise as input => generates images => determines validity
        self.combined = Model(z, valid)
        self.combined.compile(loss=self.wasserstein_loss, optimizer=optimizer, metrics=['accuracy'])
        
    def wasserstein_loss(self, y_true, y_pred):
        return K.mean(y_true * y_pred)


    def build_generator(self):

        noise_shape = (4596,)

        model = Sequential()

        model.add(Dense(128 * 7 * 7, activation="relu", input_shape=noise_shape))
        model.add(Reshape((7, 7, 128)))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=4, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=4, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
        model.add(Activation("tanh"))
        
        #model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=noise_shape)
        img = model(noise)

        #img = Concatenate([img,noise])
        
        return Model(noise, img)
 

    def build_discriminator(self):

        #img_shape = (self.img_rows, self.img_cols, self.channels)
        img_shape = (self.img_rows*self.img_cols*self.channels+4596,)

        model = Sequential()

        model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
        model.add(ZeroPadding2D(padding=((0,1),(0,1))))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Flatten())
        
        
        model.summary()

        img = Input(shape=img_shape)
        features = model(img)
        valid = Dense(1, activation="linear")(features)

        return Model(img, valid)

    def get_image(self, image_path, width, height, mode):

        image = Image.open(image_path)
        # image = image.resize([width, height], Image.BILINEAR)
        if image.size != (width, height):
        # Remove most pixels that aren't part of a face
            face_width = face_height = 108
            j = (image.size[0] - face_width) // 2
            i = (image.size[1] - face_height) // 2
            image = image.crop([j, i, j + face_width, i + face_height])
            image = image.resize([width, height])

        return np.array(image.convert(mode))

    def get_batch(self, image_files, width, height, mode):
        data_batch = np.array(
            [self.get_image(sample_file, width, height, mode) for sample_file in image_files])

        return data_batch

    def train(self, epochs, batch_size=128, save_interval=50):

        X_train_green_fr = self.get_batch(glob(os.path.join('data_final/green_fruit', '*'))[:490], 28, 28, 'RGB')
        X_train_purple_fr = self.get_batch(glob(os.path.join('data_final/purple_fruit', '*'))[:490], 28, 28, 'RGB')
        X_train_red_fl = self.get_batch(glob(os.path.join('data_final/red_flower', '*'))[:322], 28, 28, 'RGB')
        X_train_yellow_fl = self.get_batch(glob(os.path.join('data_final/yellow_flower', '*'))[:490], 28, 28, 'RGB')
        
        
        
        
        
        
        
        y_train_green_fr = np.genfromtxt('embeddings/greenpear.csv',delimiter=",").reshape(1,4096);
        y_train_purple_fr = np.genfromtxt('embeddings/purplepear.csv',delimiter=",").reshape(1,4096); 
        y_train_red_fl = np.genfromtxt('embeddings/redbanana.csv',delimiter=",").reshape(1,4096);
        y_train_yellow_fl = np.genfromtxt('embeddings/yellowbanana.csv',delimiter=",").reshape(1,4096); 
        
     
        
        
        
        
        
        #print(X_train_blue_fl.shape)
        # Rescale -1 to 1
        X_train_green_fr = (X_train_green_fr.astype(np.float32) - 127.5) / 127.5
        X_train_purple_fr = (X_train_purple_fr.astype(np.float32) - 127.5) / 127.5
        X_train_red_fl = (X_train_red_fl.astype(np.float32) - 127.5) / 127.5
        X_train_yellow_fl = (X_train_yellow_fl.astype(np.float32) - 127.5) / 127.5
        # Adversarial ground truth
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            
            for _ in range(self.n_critic):
            
            
            
                # ---------------------
                #  Train Discriminator
                # ---------------------
            
            
            
            
                G_repeated = np.asarray([y_train_green_fr,]*int(batch_size/4))
                P_repeated = np.asarray([y_train_purple_fr,]*int(batch_size/4))
                R_repeated = np.asarray([y_train_red_fl,]*int(batch_size/4))
                Y_repeated = np.asarray([y_train_yellow_fl,]*int(batch_size/4))
            
            
            
            
            

                # Sample a random batch of images
                idx = np.random.randint(0, X_train_green_fr.shape[0], int(batch_size/4))
                imgs_green_fr = X_train_green_fr[idx]
            
                #appending with embedding blue flowers
                print (imgs_green_fr.shape)
                print (G_repeated.shape)
            
                #reshaping
                imgs_green_fr = imgs_green_fr.reshape(int(batch_size/4),28*28*3)
                G_repeated = G_repeated.reshape(int(batch_size/4),4096)

                imgs_green_fr_with_emb = np.concatenate((imgs_green_fr, G_repeated),axis=1)
            
            
            
            
            
                # Sample a random batch of images
                idx = np.random.randint(0, X_train_purple_fr.shape[0], int(batch_size/4))
                imgs_purple_fr = X_train_purple_fr[idx]
            
                #reshaping
                imgs_purple_fr = imgs_purple_fr.reshape(int(batch_size/4),28*28*3)
                P_repeated = P_repeated.reshape(int(batch_size/4),4096)
            
                #appending with embedding purple fruit
                imgs_purple_fr_with_emb = np.concatenate((imgs_purple_fr, P_repeated),axis=1)
            
            
            
            
            
            
            
                # Sample a random batch of images
                idx = np.random.randint(0, X_train_red_fl.shape[0], int(batch_size/4))
                imgs_red_fl = X_train_red_fl[idx]
            
                #appending with embedding blue flowers
            
                #reshaping
                imgs_red_fl = imgs_red_fl.reshape(int(batch_size/4),28*28*3)
                R_repeated = R_repeated.reshape(int(batch_size/4),4096)

                imgs_red_fl_with_emb = np.concatenate((imgs_red_fl, R_repeated),axis=1)
            
            
            
            
            
            
            
                # Sample a random batch of images
                idx = np.random.randint(0, X_train_yellow_fl.shape[0], int(batch_size/4))
                imgs_yellow_fl = X_train_yellow_fl[idx]
            
                #reshaping
                imgs_yellow_fl = imgs_yellow_fl.reshape(int(batch_size/4),28*28*3)
                Y_repeated = Y_repeated.reshape(int(batch_size/4),4096)
            
                #appending with embedding yellow fruit
                imgs_yellow_fl_with_emb = np.concatenate((imgs_yellow_fl, Y_repeated),axis=1)
            
            
            
                imgs = np.concatenate((imgs_green_fr_with_emb, imgs_purple_fr_with_emb),axis=0)
                imgs = np.concatenate((imgs, imgs_red_fl_with_emb),axis=0)
                imgs = np.concatenate((imgs, imgs_yellow_fl_with_emb),axis=0)
            
            
                #Generating noise
                noise = np.random.normal(0, 1, (batch_size, 500))
                noise = noise.reshape(batch_size,500)
            
            
                imgs = np.concatenate((imgs, noise),axis=1)
            
            
                # Sample generator input
                #noise = np.random.normal(0, 1, (batch_size, 100))
            
            
                #Generating noise
                noise = np.random.normal(0, 1, (batch_size, 500))
                noise = noise.reshape(batch_size,500)

                text_embed = np.concatenate((G_repeated,P_repeated),axis=0)
                text_embed = np.concatenate((text_embed,R_repeated),axis=0)
                text_embed = np.concatenate((text_embed,Y_repeated),axis=0)  
            
            
                text_embed = np.concatenate((text_embed, noise),axis=1)
            
            
                # Generate a batch of new images
                gen_imgs = self.generator.predict(text_embed)
            
            
                print (gen_imgs.shape)
                print (text_embed.shape)
            
                gen_imgs = gen_imgs.reshape(batch_size, 28*28*3)
                print (text_embed.shape)
            
                #concatenate embeddings wit the generated images
                gen_imgs = np.concatenate((gen_imgs, text_embed),axis=1)
            

                # Train the discriminator
                d_loss_real = self.discriminator.train_on_batch(imgs, -np.ones((batch_size, 1)))
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.ones((batch_size, 1)))
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
                
                
                # Clip discriminator weights
                for l in self.discriminator.layers:
                    weights = l.get_weights()
                    weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights]
                    l.set_weights(weights)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Sample generator input
            #noise = np.random.normal(0, 1, (batch_size, 100))
            # Train the generator
            g_loss = self.combined.train_on_batch(text_embed, -np.ones((batch_size, 1)))

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, 1-d_loss[0], 100*d_loss[1], 1-g_loss[0]))
            file = open ("128_batch","w+")
            file.write ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, 1-d_loss[0], 100*d_loss[1], 1-g_loss[0]))
            

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch, text_embed)
                
            

    def save_imgs(self, epoch, text_embed):
        r, c = 4, 8
        #noise = np.random.normal(0, 1, (r * c, 100))
        gen_imgs = self.generator.predict(text_embed)
        
        #dimensions
        print(gen_imgs.shape)
        gen_imgs = gen_imgs.reshape(32,28,28,3) #hardcoded batch-size 32

        # Rescale images 0 - 1
        gen_imgs = (1/2.5) * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,:])
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("WGAN_output_banana_pear/%d.png" % epoch)
        plt.close()

In [21]:
if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=10000, batch_size=32, save_interval=50)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_32 (Dense)             (None, 512)               3557888   
_________________________________________________________________
leaky_re_lu_23 (LeakyReLU)   (None, 512)               0         
_________________________________________________________________
dense_33 (Dense)             (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_24 (LeakyReLU)   (None, 256)               0         
_________________________________________________________________
dense_34 (Dense)             (None, 1)                 257       
Total params: 3,689,473
Trainable params: 3,689,473
Non-trainable params: 0
_________________________________________________________________
_________________________________________________________________
Layer (type)                 Output Shape              Param #   


  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 0.966238, acc.: 45.31%] [G loss: 1.520115]
(32, 2352)
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
1 [D loss: 1.196330, acc.: 25.00%] [G loss: 1.440716]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
2 [D loss: 1.232466, acc.: 20.31%] [G loss: 1.389605]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
3 [D loss: 1.268457, acc.: 12.50%] [G loss: 1.351772]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
4 [D loss: 1.289591, acc.: 3.12%] [G loss: 1.313580]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
5 [D loss: 1.298710, acc.: 1.56%] [G loss: 1.290181]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
6 [D loss: 1.312957, acc.: 4.69%] [G loss: 1.279313]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
7 [D loss: 1.334416, acc.: 0.00%] [G loss: 1.215967]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
8 [D loss: 1.342258, acc.: 0.00%] [G loss: 1.228833]
(8, 28, 28, 3)
(8, 1

72 [D loss: 1.499873, acc.: 0.00%] [G loss: 1.000192]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
73 [D loss: 1.499860, acc.: 0.00%] [G loss: 1.000205]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
74 [D loss: 1.499899, acc.: 0.00%] [G loss: 1.000145]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
75 [D loss: 1.499875, acc.: 0.00%] [G loss: 1.000154]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
76 [D loss: 1.499915, acc.: 0.00%] [G loss: 1.000146]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
77 [D loss: 1.499914, acc.: 0.00%] [G loss: 1.000127]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
78 [D loss: 1.499936, acc.: 0.00%] [G loss: 1.000100]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
79 [D loss: 1.499930, acc.: 0.00%] [G loss: 1.000097]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
80 [D loss: 1.499940, acc.: 0.00%] [G loss: 1.000100]
(8, 28, 28, 3)
(8, 1, 4096

143 [D loss: 1.499999, acc.: 0.00%] [G loss: 1.000001]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
144 [D loss: 1.499999, acc.: 0.00%] [G loss: 1.000001]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
145 [D loss: 1.500000, acc.: 0.00%] [G loss: 1.000001]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
146 [D loss: 1.500000, acc.: 0.00%] [G loss: 1.000001]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
147 [D loss: 1.500000, acc.: 0.00%] [G loss: 1.000001]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
148 [D loss: 1.500000, acc.: 0.00%] [G loss: 1.000001]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
149 [D loss: 1.500000, acc.: 0.00%] [G loss: 1.000001]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
150 [D loss: 1.500000, acc.: 0.00%] [G loss: 1.000001]
(32, 2352)
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
151 [D loss: 1.500000, acc.: 0.00%] [G loss: 1.000001]
(8, 28

214 [D loss: 1.500000, acc.: 0.00%] [G loss: 1.000000]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
215 [D loss: 1.500000, acc.: 0.00%] [G loss: 1.000000]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
216 [D loss: 1.500000, acc.: 0.00%] [G loss: 1.000000]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
217 [D loss: 1.500000, acc.: 0.00%] [G loss: 1.000000]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
218 [D loss: 1.500000, acc.: 0.00%] [G loss: 1.000000]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
219 [D loss: 1.500000, acc.: 0.00%] [G loss: 1.000000]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
220 [D loss: 1.500000, acc.: 0.00%] [G loss: 1.000000]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
221 [D loss: 1.500000, acc.: 0.00%] [G loss: 1.000000]
(8, 28, 28, 3)
(8, 1, 4096)
(32, 2352)
(32, 4596)
(32, 4596)
222 [D loss: 1.500000, acc.: 0.00%] [G loss: 1.000000]
(8, 28, 28, 3)
(8

KeyboardInterrupt: 