In [1]:
from __future__ import print_function

from keras import backend as K
K.set_image_dim_ordering('th') # ensure our dimension notation matches

from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.layers import Reshape
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Convolution2D, AveragePooling2D
from keras.layers.core import Flatten
from keras.optimizers import SGD, Adam
from keras.datasets import mnist
from keras import utils
import numpy as np
from PIL import Image, ImageOps
import argparse
import math

import os
import os.path

import glob

Using TensorFlow backend.


In [29]:
def generator_model():
    """Build the (uncompiled) generator network.

    Maps a 100-dimensional noise vector to a single-channel 128x128 image in
    channels-first layout:
    Dense(512) -> Dense(128*8*8) + BatchNorm -> reshape to (128, 8, 8)
    -> [4x upsample, 5x5 conv] twice (8 -> 32 -> 128 pixels per side).
    Every stage ends in a softsign activation, so outputs lie in (-1, 1),
    matching the [-1, 1] scaling applied to the training images in ``train``.

    :returns: a Keras ``Sequential`` model (not compiled).
    """
    model = Sequential()
    layers = [
        # Project the noise vector up to a 128 x 8 x 8 feature volume.
        Dense(input_dim=100, output_dim=512),
        Activation('softsign'),
        Dense(128*8*8),
        BatchNormalization(),
        Activation('softsign'),
        Reshape((128, 8, 8), input_shape=(128*8*8,)),
        # 8x8 -> 32x32 with 64 feature maps.
        UpSampling2D(size=(4, 4)),
        Convolution2D(64, 5, 5, border_mode='same'),
        Activation('softsign'),
        # 32x32 -> 128x128, collapsed to a single output channel.
        UpSampling2D(size=(4, 4)),
        Convolution2D(1, 5, 5, border_mode='same'),
        Activation('softsign'),
    ]
    for layer in layers:
        model.add(layer)
    return model


def discriminator_model():
    """Build the (uncompiled) real-vs-generated image classifier.

    Input: channels-first single-channel images of shape (1, 128, 128).
    Two convolution + average-pooling stages (5x5 'same' conv then 4x4 pool;
    5x5 conv with the layer's default padding then 2x2 pool) feed a 256-unit
    dense layer and a single sigmoid unit, so the output lies in (0, 1).

    :returns: a Keras ``Sequential`` model (not compiled).
    """
    model = Sequential()
    feature_stages = [
        Convolution2D(64, 5, 5, border_mode='same', input_shape=(1, 128, 128)),
        Activation('softsign'),
        AveragePooling2D(pool_size=(4, 4)),
        Convolution2D(128, 5, 5),
        Activation('softsign'),
        AveragePooling2D(pool_size=(2, 2)),
    ]
    classifier_head = [
        Flatten(),
        Dense(256),
        Activation('softsign'),
        Dense(1),
        Activation('sigmoid'),
    ]
    for layer in feature_stages + classifier_head:
        model.add(layer)
    return model


def generator_containing_discriminator(generator, discriminator):
    """Chain ``generator -> discriminator`` into one Sequential model.

    Side effect: sets ``discriminator.trainable = False`` on the passed-in
    model before it is appended. Keras reads ``trainable`` when a model is
    compiled, so compiling the returned model presumably trains only the
    generator's weights.

    :returns: a Keras ``Sequential`` model (not compiled).
    """
    # Adding the generator does not read the discriminator's flag, so freezing first is equivalent.
    discriminator.trainable = False
    chained = Sequential()
    for stage in (generator, discriminator):
        chained.add(stage)
    return chained


def combine_images(generated_images):
    """Tile a batch of images into one 2-D mosaic for visual inspection.

    :param generated_images: array of shape (N, C, H, W); only channel 0 of
        each image is used.
    :returns: 2-D array of shape (rows*H, cols*W) with the input's dtype, where
        cols = floor(sqrt(N)) and rows = ceil(N / cols). Images are placed row
        by row; any unused cells at the end stay zero.
    :raises ValueError: if the batch is empty (there is nothing to tile).
    """
    num = generated_images.shape[0]
    if num == 0:
        # Without this guard, width == 0 and the ceil division below raises ZeroDivisionError.
        raise ValueError("combine_images needs at least one image")
    width = int(math.sqrt(num))
    height = int(math.ceil(float(num) / width))
    img_h, img_w = generated_images.shape[2:4]
    image = np.zeros((height * img_h, width * img_w),
                     dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        # divmod gives exact integer row/column indices on Python 2 and 3.
        row, col = divmod(index, width)
        image[row * img_h:(row + 1) * img_h, col * img_w:(col + 1) * img_w] = \
            img[0, :, :]
    return image

In [30]:
# Build a generator purely to inspect its architecture.
# Model.summary() prints the table itself and returns None, so it is not
# wrapped in print() (that would emit a stray "None" line).
model = generator_model()
model.summary()

  This is separate from the ipykernel package so we can avoid doing imports until
  # Remove the CWD from sys.path while we load stuff.
  del sys.path[0]


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_59 (Dense)             (None, 512)               51712     
_________________________________________________________________
activation_117 (Activation)  (None, 512)               0         
_________________________________________________________________
dense_60 (Dense)             (None, 8192)              4202496   
_________________________________________________________________
batch_normalization_19 (Batc (None, 8192)              32768     
_________________________________________________________________
activation_118 (Activation)  (None, 8192)              0         
_________________________________________________________________
reshape_19 (Reshape)         (None, 128, 8, 8)         0         
_________________________________________________________________
up_sampling2d_37 (UpSampling (None, 128, 32, 32)       0         
__________

In [31]:
def load_data(pixels=128, verbose=False, directory='jaffe', pattern='*.tiff'):
    """Load every image matching ``pattern`` in ``directory`` as 8-bit grayscale.

    Each image is fitted (cropped to the target aspect ratio, then resized) to
    ``pixels`` x ``pixels`` with ``ImageOps.fit`` and converted to grayscale.

    :param pixels: side length of the square output images.
    :param verbose: if True, print each path as it is loaded.
    :param directory: folder to search; a relative path is resolved against
        the current working directory (default ``'jaffe'``, as before).
    :param pattern: glob pattern for the image files (default ``'*.tiff'``).
    :returns: array of shape (num_images, pixels, pixels). Files are read in
        sorted path order so batches are reproducible across runs. If nothing
        matches, an empty uint8 array of shape (0, pixels, pixels) is returned
        so downstream reshapes still see the expected rank.
    """
    print("Loading data")
    search = os.path.normpath(os.path.join(os.getcwd(), directory, pattern))
    # glob order is filesystem-dependent; sort for deterministic batching.
    paths = sorted(glob.glob(search))
    images = []
    for path in paths:
        if verbose:
            print(path)
        # The context manager closes the file handle; fit() returns a new, fully loaded image.
        with Image.open(path) as im:
            # LANCZOS is the filter ANTIALIAS aliased; ANTIALIAS was removed in Pillow 10.
            fitted = ImageOps.fit(im, (pixels, pixels), Image.LANCZOS)
        gray = ImageOps.grayscale(fitted)
        images.append(np.asarray(gray))
    print("Finished loading data")
    if not images:
        return np.empty((0, pixels, pixels), dtype=np.uint8)
    return np.array(images)

def _sample_noise(batch_size, noise_dim=100):
    """Draw a (batch_size, noise_dim) matrix of latent vectors from U(-1, 1)."""
    return np.random.uniform(-1, 1, (batch_size, noise_dim))


def _save_sample_grid(generated_images, epoch, index,
                      out_dir='jaffe-generated-images'):
    """Tile ``generated_images``, rescale [-1, 1] -> [0, 255], and save as ``<epoch>_<index>.png``."""
    out_dir = os.path.join(os.getcwd(), out_dir)
    if not os.path.isdir(out_dir):
        # Image.save does not create missing directories.
        os.makedirs(out_dir)
    grid = combine_images(generated_images) * 127.5 + 127.5
    destpath = os.path.normpath(
        os.path.join(out_dir, "%d_%d.png" % (epoch, index)))
    Image.fromarray(grid.astype(np.uint8)).save(destpath)


def train(epochs, BATCH_SIZE, weights=False):
    """
    Train the GAN on the images returned by ``load_data()``.

    :param epochs: Train for this many epochs
    :param BATCH_SIZE: Size of minibatch
    :param weights: If True, load weights from file, otherwise train the model from scratch.
    Use this if you have already saved state of the network and want to train it further.

    Side effects:
    - every 20th batch of every 10th epoch, a tiled PNG of the current
      generator samples is written to ``jaffe-generated-images/`` (created if
      missing);
    - after every epoch with ``epoch % 10 == 9``, and after the final epoch,
      weights are saved to ``goodgenerator.h5`` / ``gooddiscriminator.h5``.

    A trailing partial batch is skipped each epoch.
    """
    X_train = load_data()
    # Scale pixels from [0, 255] to [-1, 1] to match the generator's softsign output range.
    X_train = (X_train.astype(np.float32) - 127.5)/127.5
    # Add a channel axis: (N, H, W) -> (N, 1, H, W) for the channels-first models.
    X_train = X_train.reshape((X_train.shape[0], 1) + X_train.shape[1:])
    discriminator = discriminator_model()
    generator = generator_model()
    if weights:
        generator.load_weights('goodgenerator.h5')
        discriminator.load_weights('gooddiscriminator.h5')
    discriminator_on_generator = \
        generator_containing_discriminator(generator, discriminator)
    d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    generator.compile(loss='binary_crossentropy', optimizer="SGD")
    # The combined model is compiled while the discriminator is frozen
    # (generator_containing_discriminator set trainable=False); the
    # discriminator is then unfrozen and compiled for its own updates.
    discriminator_on_generator.compile(
        loss='binary_crossentropy', optimizer=g_optim)
    discriminator.trainable = True
    discriminator.compile(loss='binary_crossentropy', optimizer=d_optim)

    num_batches = X_train.shape[0] // BATCH_SIZE
    # Labels never change between batches, so build them once.
    d_labels = np.array([1] * BATCH_SIZE + [0] * BATCH_SIZE)
    g_labels = np.ones(BATCH_SIZE)
    for epoch in range(epochs):
        print("Epoch is", epoch)
        print("Number of batches", num_batches)
        for index in range(num_batches):
            noise = _sample_noise(BATCH_SIZE)
            image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]
            generated_images = generator.predict(noise, verbose=0)
            if index % 20 == 0 and epoch % 10 == 0:
                _save_sample_grid(generated_images, epoch, index)
            # Discriminator step: real images labelled 1, generated images 0.
            X = np.concatenate((image_batch, generated_images))
            d_loss = discriminator.train_on_batch(X, d_labels)
            print("batch %d d_loss : %f" % (index, d_loss))
            # Generator step: fresh noise labelled "real", so the generator is
            # pushed toward outputs the discriminator scores as real.
            noise = _sample_noise(BATCH_SIZE)
            discriminator.trainable = False
            g_loss = discriminator_on_generator.train_on_batch(noise, g_labels)
            discriminator.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss))
        # Checkpoint once per qualifying epoch (this previously ran after every
        # batch), and always after the last epoch so the final state is kept.
        if epoch % 10 == 9 or epoch == epochs - 1:
            generator.save_weights('goodgenerator.h5', True)
            discriminator.save_weights('gooddiscriminator.h5', True)

def clean(image):
    """Brighten pixels whose 5-point neighbourhood is mostly bright.

    For every interior pixel (the outer border rows/columns are skipped), if
    the sum of the pixel and its four direct neighbours exceeds 127 * 5
    (i.e. their mean is above 127), the pixel is set to 255.

    The 2-D array is modified in place and also returned. Because updates
    happen in place during a row-by-row scan, a pixel already set to 255
    counts toward the sums of the neighbours visited after it.

    :param image: 2-D numpy array (integer or float pixel values).
    :returns: the same array object, modified.
    """
    threshold = 127 * 5
    rows, cols = image.shape[0], image.shape[1]
    for i in range(1, rows - 1):
        for j in range(1, cols - 1):
            # .item() yields Python scalars, so uint8 input cannot wrap around
            # (e.g. 5 * 200 would otherwise overflow to 232 and fail the test).
            total = (image[i, j].item() + image[i + 1, j].item()
                     + image[i, j + 1].item() + image[i - 1, j].item()
                     + image[i, j - 1].item())
            if total > threshold:
                image[i, j] = 255
    return image


def generate(BATCH_SIZE):
    """Sample images from the saved generator, then display and save them.

    Loads weights from ``goodgenerator.h5`` and feeds ``BATCH_SIZE`` noise
    vectors drawn from U(-1, 1). For each generated image:
    - channel 0 is rescaled from [-1, 1] to [0, 255];
    - the raw result is saved as ``dirty.png`` and shown;
    - ``clean`` is applied, and the cleaned result is shown and saved as
      ``clean.png``.

    With more than one image, the files are named ``dirty_<i>.png`` /
    ``clean_<i>.png``, so later images no longer overwrite earlier ones.

    :param BATCH_SIZE: number of images to generate.
    """
    generator = generator_model()
    generator.compile(loss='binary_crossentropy', optimizer="SGD")
    generator.load_weights('goodgenerator.h5')
    noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
    generated_images = generator.predict(noise, verbose=1)
    print(generated_images.shape)
    single = len(generated_images) == 1
    for index, raw in enumerate(generated_images):
        suffix = "" if single else "_%d" % index
        # Take channel 0 and map the generator's (-1, 1) output to pixel range.
        image = raw[0] * 127.5 + 127.5
        dirty = Image.fromarray(image.astype(np.uint8))
        dirty.save("dirty%s.png" % suffix)
        dirty.show()
        image = clean(image)
        cleaned = Image.fromarray(image.astype(np.uint8))
        cleaned.show()
        cleaned.save("clean%s.png" % suffix)


def get_args(argv=None):
    """Parse command-line options for the train/generate script.

    :param argv: list of argument strings to parse. ``None`` (the default)
        means ``sys.argv[1:]``, as before. Passing an explicit list makes the
        parser usable from a notebook, where ``sys.argv`` typically holds the
        kernel launcher's own arguments.
    :returns: ``argparse.Namespace`` with ``mode`` (str or None),
        ``batch_size`` (int, default 128) and ``nice`` (bool, default False).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--nice", dest="nice", action="store_true")
    parser.set_defaults(nice=False)
    return parser.parse_args(argv)


In [34]:
# Train from scratch (no saved weights loaded) for 40 epochs, 10 images per minibatch.
train(epochs=40, BATCH_SIZE=10, weights=False)

Loading data
Finished loading data


  This is separate from the ipykernel package so we can avoid doing imports until
  # Remove the CWD from sys.path while we load stuff.
  del sys.path[0]


Epoch is 0
Number of batches 21
batch 0 d_loss : 0.685145
batch 0 g_loss : 0.690927
batch 1 d_loss : 0.674899
batch 1 g_loss : 0.693193
batch 2 d_loss : 0.674612
batch 2 g_loss : 0.691920
batch 3 d_loss : 0.673626
batch 3 g_loss : 0.696336
batch 4 d_loss : 0.647582
batch 4 g_loss : 0.697387
batch 5 d_loss : 0.629557
batch 5 g_loss : 0.700939
batch 6 d_loss : 0.617320
batch 6 g_loss : 0.695952
batch 7 d_loss : 0.625635
batch 7 g_loss : 0.698420
batch 8 d_loss : 0.616364
batch 8 g_loss : 0.695661
batch 9 d_loss : 0.594169
batch 9 g_loss : 0.711157
batch 10 d_loss : 0.574421
batch 10 g_loss : 0.700526
batch 11 d_loss : 0.562472
batch 11 g_loss : 0.703684
batch 12 d_loss : 0.545179
batch 12 g_loss : 0.709139
batch 13 d_loss : 0.537994
batch 13 g_loss : 0.713015
batch 14 d_loss : 0.524559
batch 14 g_loss : 0.715892
batch 15 d_loss : 0.503633
batch 15 g_loss : 0.707399
batch 16 d_loss : 0.490070
batch 16 g_loss : 0.710030
batch 17 d_loss : 0.482451
batch 17 g_loss : 0.708531
batch 18 d_loss 

batch 3 d_loss : 0.135968
batch 3 g_loss : 2.573383
batch 4 d_loss : 0.101697
batch 4 g_loss : 2.770319
batch 5 d_loss : 0.085295
batch 5 g_loss : 2.845725
batch 6 d_loss : 0.112030
batch 6 g_loss : 2.852346
batch 7 d_loss : 0.262572
batch 7 g_loss : 3.196014
batch 8 d_loss : 0.227890
batch 8 g_loss : 3.249223
batch 9 d_loss : 0.125260
batch 9 g_loss : 3.223820
batch 10 d_loss : 0.115015
batch 10 g_loss : 3.256944
batch 11 d_loss : 0.117754
batch 11 g_loss : 3.465155
batch 12 d_loss : 0.100023
batch 12 g_loss : 3.388020
batch 13 d_loss : 0.090388
batch 13 g_loss : 3.219933
batch 14 d_loss : 0.060092
batch 14 g_loss : 3.352074
batch 15 d_loss : 0.066594
batch 15 g_loss : 3.486783
batch 16 d_loss : 0.063537
batch 16 g_loss : 3.449914
batch 17 d_loss : 0.098069
batch 17 g_loss : 3.280754
batch 18 d_loss : 0.069415
batch 18 g_loss : 3.432607
batch 19 d_loss : 0.042505
batch 19 g_loss : 3.438237
batch 20 d_loss : 0.037244
batch 20 g_loss : 3.227122
Epoch is 8
Number of batches 21
batch 0 d_

batch 6 g_loss : 3.066665
batch 7 d_loss : 0.045399
batch 7 g_loss : 2.876821
batch 8 d_loss : 0.069618
batch 8 g_loss : 3.345383
batch 9 d_loss : 0.065584
batch 9 g_loss : 3.200198
batch 10 d_loss : 0.067961
batch 10 g_loss : 3.108772
batch 11 d_loss : 0.071993
batch 11 g_loss : 2.971527
batch 12 d_loss : 0.055393
batch 12 g_loss : 2.944953
batch 13 d_loss : 0.041920
batch 13 g_loss : 3.105916
batch 14 d_loss : 0.039301
batch 14 g_loss : 3.190125
batch 15 d_loss : 0.034472
batch 15 g_loss : 2.489513
batch 16 d_loss : 0.035991
batch 16 g_loss : 2.847138
batch 17 d_loss : 0.097699
batch 17 g_loss : 3.108871
batch 18 d_loss : 0.061417
batch 18 g_loss : 2.970307
batch 19 d_loss : 0.029285
batch 19 g_loss : 3.339906
batch 20 d_loss : 0.023535
batch 20 g_loss : 2.978132
Epoch is 15
Number of batches 21
batch 0 d_loss : 0.066378
batch 0 g_loss : 2.713145
batch 1 d_loss : 0.051719
batch 1 g_loss : 3.025420
batch 2 d_loss : 0.069604
batch 2 g_loss : 2.833320
batch 3 d_loss : 0.079663
batch 3 g

batch 10 d_loss : 0.126830
batch 10 g_loss : 2.349583
batch 11 d_loss : 0.187480
batch 11 g_loss : 2.737910
batch 12 d_loss : 0.130521
batch 12 g_loss : 1.895429
batch 13 d_loss : 0.262095
batch 13 g_loss : 1.926732
batch 14 d_loss : 0.064954
batch 14 g_loss : 1.759877
batch 15 d_loss : 0.052325
batch 15 g_loss : 2.308745
batch 16 d_loss : 0.111801
batch 16 g_loss : 2.250788
batch 17 d_loss : 0.106684
batch 17 g_loss : 2.219728
batch 18 d_loss : 0.078567
batch 18 g_loss : 2.795258
batch 19 d_loss : 0.086028
batch 19 g_loss : 3.247577
batch 20 d_loss : 0.080594
batch 20 g_loss : 2.591826
Epoch is 22
Number of batches 21
batch 0 d_loss : 0.211883
batch 0 g_loss : 2.242903
batch 1 d_loss : 0.075835
batch 1 g_loss : 2.311716
batch 2 d_loss : 0.232151
batch 2 g_loss : 2.049137
batch 3 d_loss : 0.154910
batch 3 g_loss : 2.651480
batch 4 d_loss : 0.094079
batch 4 g_loss : 2.599940
batch 5 d_loss : 0.067444
batch 5 g_loss : 2.257005
batch 6 d_loss : 0.069108
batch 6 g_loss : 2.583817
batch 7 d

batch 13 g_loss : 2.555515
batch 14 d_loss : 0.114756
batch 14 g_loss : 1.700557
batch 15 d_loss : 0.144039
batch 15 g_loss : 2.608939
batch 16 d_loss : 0.099484
batch 16 g_loss : 2.180799
batch 17 d_loss : 0.092807
batch 17 g_loss : 2.102007
batch 18 d_loss : 0.104039
batch 18 g_loss : 2.255781
batch 19 d_loss : 0.133798
batch 19 g_loss : 2.674006
batch 20 d_loss : 0.081096
batch 20 g_loss : 2.902783
Epoch is 29
Number of batches 21
batch 0 d_loss : 0.287160
batch 0 g_loss : 2.633706
batch 1 d_loss : 0.071384
batch 1 g_loss : 1.960870
batch 2 d_loss : 0.252204
batch 2 g_loss : 2.244709
batch 3 d_loss : 0.251613
batch 3 g_loss : 2.143317
batch 4 d_loss : 0.164761
batch 4 g_loss : 1.822076
batch 5 d_loss : 0.238356
batch 5 g_loss : 2.494422
batch 6 d_loss : 0.101031
batch 6 g_loss : 2.442804
batch 7 d_loss : 0.132979
batch 7 g_loss : 1.733963
batch 8 d_loss : 0.146895
batch 8 g_loss : 2.103508
batch 9 d_loss : 0.267755
batch 9 g_loss : 2.177410
batch 10 d_loss : 0.138162
batch 10 g_loss

batch 17 d_loss : 0.115459
batch 17 g_loss : 1.832267
batch 18 d_loss : 0.128846
batch 18 g_loss : 2.943428
batch 19 d_loss : 0.138254
batch 19 g_loss : 3.274588
batch 20 d_loss : 0.104658
batch 20 g_loss : 2.613083
Epoch is 36
Number of batches 21
batch 0 d_loss : 0.254215
batch 0 g_loss : 2.302601
batch 1 d_loss : 0.047602
batch 1 g_loss : 2.201018
batch 2 d_loss : 0.279822
batch 2 g_loss : 2.426237
batch 3 d_loss : 0.131265
batch 3 g_loss : 1.979366
batch 4 d_loss : 0.193815
batch 4 g_loss : 1.450736
batch 5 d_loss : 0.252733
batch 5 g_loss : 2.169759
batch 6 d_loss : 0.198566
batch 6 g_loss : 2.656620
batch 7 d_loss : 0.074610
batch 7 g_loss : 2.241168
batch 8 d_loss : 0.083027
batch 8 g_loss : 1.663472
batch 9 d_loss : 0.116924
batch 9 g_loss : 2.716230
batch 10 d_loss : 0.146196
batch 10 g_loss : 2.070404
batch 11 d_loss : 0.202434
batch 11 g_loss : 2.158445
batch 12 d_loss : 0.199852
batch 12 g_loss : 2.441461
batch 13 d_loss : 0.261617
batch 13 g_loss : 2.587444
batch 14 d_loss

In [33]:
# Sample, show and save one image using the weights in goodgenerator.h5.
generate(BATCH_SIZE=1)

  This is separate from the ipykernel package so we can avoid doing imports until
  # Remove the CWD from sys.path while we load stuff.
  del sys.path[0]


(1, 1, 128, 128)
