In [5]:

import time

import cv2
import matplotlib.pyplot as plt
import numpy as np

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, Reshape
from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D
from keras.layers import LeakyReLU, Dropout
from keras.layers import BatchNormalization
from keras.optimizers import Adam, RMSprop

In [6]:
class ProgressBar(object):
    """Wall-clock timer that reports elapsed time in seconds, minutes or hours.

    The clock starts when the instance is created.  The class relies on the
    standard-library ``time`` module being imported in the notebook's import
    cell; without it, instantiation fails with ``NameError``.
    """

    def __init__(self):
        # Reference point for elapsed_time(), taken from time.time().
        self.start_time = time.time()

    def elapsed(self, seconds):
        """Format a duration as a string with the largest fitting unit.

        Args:
            seconds: Duration in seconds (int or float).

        Returns:
            ``"<seconds> sec"`` below one minute, ``"<minutes> min"`` below
            one hour, otherwise ``"<hours> hr"``.  Minutes and hours are the
            plain quotient, so they are not rounded.
        """
        if seconds < 60:
            return str(seconds) + " sec"
        if seconds < 60 * 60:
            return str(seconds / 60) + " min"
        return str(seconds / (60 * 60)) + " hr"

    def elapsed_time(self):
        """Print the time elapsed since this object was created."""
        print("Elapsed: %s " % self.elapsed(time.time() - self.start_time))


# Getting Started with generating anime images

    We start by defining a class that expects 400 x 400 images in RGB format (three channels).
    Later on we will build a utility with OpenCV to ensure every image in our training data is converted to that format.
    For now we initialize all the model parts to None; they are filled in at a later time.

In [7]:
class AnimeGAN(object):
    """Container for the image geometry and the four GAN model slots.

    Nothing is built here: the constructor only records the image size and
    sets every model slot to ``None`` so it can be filled in afterwards.
    """

    def __init__(self, img_rows=400, img_cols=400, channel=3):
        # Image geometry: rows x cols x channel.
        self.img_rows, self.img_cols, self.channel = img_rows, img_cols, channel

        # Model slots, in order: discriminator network, generator network,
        # adversarial (stacked) model, discriminator training model.
        for slot in ("discriminator", "generator", "adv_model", "disc_model"):
            setattr(self, slot, None)
        
   

## Building the discriminator

In [9]:
 def discriminator(self):
        """Return the discriminator network held on ``self``, building it if needed.

        The network is four Conv2D -> LeakyReLU -> Dropout stages followed by
        Flatten -> Dense(1) -> sigmoid, i.e. a single probability per image.
        A truthy ``self.discriminator`` is returned as-is; otherwise a new
        model is stored there, its summary is printed, and it is returned.
        """
        cached = self.discriminator
        if cached:
            return cached

        drop_rate = 0.2

        # Input: img_rows x img_cols x channel.
        image_shape = (self.img_rows, self.img_cols, self.channel)

        net = Sequential()
        self.discriminator = net

        # Stage 1 declares the input shape; 256 filters, stride 2.
        net.add(Conv2D(filters=256, kernel_size=10, strides=(2, 2), padding='same', input_shape=image_shape))
        net.add(LeakyReLU(alpha=0.3))
        net.add(Dropout(drop_rate))

        # Stages 2-4 as (filters, stride).  With padding='same' each stride-2
        # stage halves the spatial size (rounded up); the stride-1 stage keeps it.
        for n_filters, stride in ((512, 2), (1024, 2), (2048, 1)):
            net.add(Conv2D(n_filters, 10, strides=stride, padding='same'))
            net.add(LeakyReLU(alpha=0.3))
            net.add(Dropout(drop_rate))

        # Head: flatten the feature maps and squash to one probability.
        for layer in (Flatten(), Dense(1), Activation('sigmoid')):
            net.add(layer)

        net.summary()
        return self.discriminator


## Creating the generator

In [10]:
def generator(self):
    """Return the generator network held on ``self``, building it if needed.

    The generator maps a 100-dimensional noise vector to an image of shape
    ``(self.img_rows, self.img_cols, self.channel)`` with values in [0, 1]
    (sigmoid output), so its output matches the input shape the
    discriminator is built with.

    Returns:
        The Keras ``Sequential`` generator, also cached in ``self.generator``.

    Raises:
        ValueError: if ``img_rows`` or ``img_cols`` is not a positive
            multiple of 4 (the network upsamples by 2 twice).
    """
    if self.generator:
        return self.generator

    # Two UpSampling2D layers (factor 2 each) enlarge the seed map 4x per
    # axis, so the seed must be exactly a quarter of the target size.
    upscale = 4
    rows, cols = self.img_rows, self.img_cols
    if rows < upscale or cols < upscale or rows % upscale or cols % upscale:
        raise ValueError(
            "generator needs img_rows and img_cols to be positive multiples "
            "of %d, got %r x %r" % (upscale, rows, cols))
    seed_rows = rows // upscale
    seed_cols = cols // upscale

    dropout = 0.2
    depth = 256 * 4

    model = Sequential()

    # In: 100-dim noise
    # Out: seed_rows x seed_cols x depth
    # NOTE(review): this Dense layer holds 100 * seed_rows * seed_cols * depth
    # weights (about 1e9 for 400 x 400 images); a smaller seed with more
    # upsampling stages may be needed to fit in memory.
    model.add(Dense(seed_rows * seed_cols * depth, input_dim=100))
    model.add(BatchNormalization(momentum=0.9))
    model.add(Activation('relu'))
    model.add(Reshape((seed_rows, seed_cols, depth)))
    model.add(Dropout(dropout))

    # In: seed_rows x seed_cols x depth
    # Out: 2*seed_rows x 2*seed_cols x depth/2
    model.add(UpSampling2D())
    model.add(Conv2DTranspose(int(depth / 2), 10, padding='same'))
    model.add(BatchNormalization(momentum=0.9))
    model.add(Activation('relu'))

    # Out: img_rows x img_cols x depth/4
    model.add(UpSampling2D())
    model.add(Conv2DTranspose(int(depth / 4), 10, padding='same'))
    model.add(BatchNormalization(momentum=0.9))
    model.add(Activation('relu'))

    # Same resolution, depth/8 feature maps.
    model.add(Conv2DTranspose(int(depth / 8), 10, padding='same'))
    model.add(BatchNormalization(momentum=0.9))
    model.add(Activation('relu'))

    # Out: img_rows x img_cols x channel image, [0.0, 1.0] per pixel
    model.add(Conv2DTranspose(self.channel, 10, padding='same'))
    model.add(Activation('sigmoid'))
    model.summary()

    # Cache only once the model is complete, so a failed build is never reused.
    self.generator = model
    return self.generator

# Create the models

In [13]:
def discriminator_model(self):
    """Return the compiled discriminator training model, building it if needed.

    The model wraps the discriminator network in a ``Sequential`` and
    compiles it with binary cross-entropy, RMSprop(lr=0.0002, decay=6e-8)
    and an accuracy metric.  It is cached in ``self.disc_model``, the slot
    ``AnimeGAN.__init__`` reserves for it.

    Returns:
        The compiled Keras model.
    """
    cached = getattr(self, 'disc_model', None)
    if cached is not None:
        return cached

    optimizer = RMSprop(lr=0.0002, decay=6e-8)
    model = Sequential()
    # `self.discriminator` holds the cached network (or None), not a
    # callable, so the network comes from the module-level builder.
    model.add(discriminator(self))
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

    self.disc_model = model
    return model

def adversarial_model(self):
    """Return the compiled adversarial (generator + discriminator) model.

    The model stacks the generator and the discriminator in a ``Sequential``
    and compiles it with binary cross-entropy, RMSprop(lr=0.0001,
    decay=3e-8) and an accuracy metric.  It is cached in ``self.adv_model``,
    the slot ``AnimeGAN.__init__`` reserves for it.

    Returns:
        The compiled Keras model.
    """
    cached = getattr(self, 'adv_model', None)
    if cached is not None:
        return cached

    optimizer = RMSprop(lr=0.0001, decay=3e-8)
    model = Sequential()
    # `self.generator` / `self.discriminator` hold cached networks (or None),
    # not callables, so the networks come from the module-level builders.
    model.add(generator(self))
    # NOTE(review): the discriminator is stacked without being frozen, so
    # training this model presumably updates its weights too - confirm
    # that this is intended.
    model.add(discriminator(self))
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

    self.adv_model = model
    return model

In [None]:

class TrainAnimeGAN(object):
    """Training harness that alternates discriminator and adversarial updates.

    Args:
        x_train: Training images, array-like of shape
            ``(N, img_rows, img_cols, channel)``; converted to float32.
            NOTE(review): the generator ends in a sigmoid, so pixel values
            are presumably expected in [0, 1] - confirm in the data loader.
        img_rows, img_cols, channel: Geometry of one training image.

    Raises:
        ValueError: if ``x_train`` is missing, empty, or has the wrong shape.
    """

    def __init__(self, x_train=None, img_rows=400, img_cols=400, channel=3):
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channel = channel

        # Validate the data before any model is built.
        expected = (img_rows, img_cols, channel)
        if x_train is None:
            raise ValueError(
                "x_train is required: pass training images with shape "
                "(N, %d, %d, %d)" % expected)
        x_train = np.asarray(x_train, dtype=np.float32)
        if x_train.ndim != 4 or tuple(x_train.shape[1:]) != expected:
            raise ValueError(
                "x_train must have shape (N, %d, %d, %d), got %r"
                % (expected + (x_train.shape,)))
        if x_train.shape[0] == 0:
            raise ValueError("x_train must contain at least one image")
        self.x_train = x_train

        # The model builders are module-level functions that take the GAN
        # state object as their argument.
        self.DCGAN = AnimeGAN(img_rows, img_cols, channel)
        self.discriminator = discriminator_model(self.DCGAN)
        self.adversarial = adversarial_model(self.DCGAN)
        self.generator = generator(self.DCGAN)

    def train(self, train_steps=2000, batch_size=256, save_interval=0):
        """Run ``train_steps`` alternating discriminator/adversarial updates.

        Each step trains the discriminator on ``batch_size`` real images
        (label 1) plus ``batch_size`` generated images (label 0), then trains
        the adversarial model on fresh noise labelled 1.  Losses are printed
        every step.

        Args:
            train_steps: Number of update steps.
            batch_size: Real (and fake) images per discriminator update.
            save_interval: If > 0, save a sample grid every that many steps,
                always generated from the same 16 noise vectors.
        """
        noise_input = None
        if save_interval > 0:
            # Fixed noise so the saved grids are comparable across steps.
            noise_input = np.random.uniform(-1.0, 1.0, size=[16, 100])

        for i in range(train_steps):
            # Discriminator update: real batch followed by generated batch.
            picks = np.random.randint(0, self.x_train.shape[0], size=batch_size)
            images_train = self.x_train[picks, :, :, :]
            noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
            images_fake = self.generator.predict(noise)
            x = np.concatenate((images_train, images_fake))
            y = np.ones([2 * batch_size, 1])
            y[batch_size:, :] = 0
            d_loss = self.discriminator.train_on_batch(x, y)

            # Adversarial update: generated images are labelled as real.
            y = np.ones([batch_size, 1])
            noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
            a_loss = self.adversarial.train_on_batch(noise, y)

            log_mesg = "%d: [D loss: %f, acc: %f]" % (i, d_loss[0], d_loss[1])
            log_mesg = "%s  [A loss: %f, acc: %f]" % (log_mesg, a_loss[0], a_loss[1])
            print(log_mesg)

            if save_interval > 0 and (i + 1) % save_interval == 0:
                self.plot_images(save2file=True, samples=noise_input.shape[0],
                                 noise=noise_input, step=(i + 1))

    def plot_images(self, save2file=False, fake=True, samples=16, noise=None, step=0):
        """Show or save a grid of generated (or real) images.

        Args:
            save2file: Save the figure to disk instead of showing it.
            fake: Plot generator output if True, random training images if not.
            samples: Number of images when no ``noise`` is supplied.
            noise: Optional noise batch for the generator; when given, the
                file name includes ``step``.
            step: Training step used in the file name.
        """
        filename = 'mnist.png'
        if fake:
            if noise is None:
                noise = np.random.uniform(-1.0, 1.0, size=[samples, 100])
            else:
                filename = "mnist_%d.png" % step
            images = self.generator.predict(noise)
        else:
            picks = np.random.randint(0, self.x_train.shape[0], samples)
            images = self.x_train[picks, :, :, :]

        count = images.shape[0]
        # Keep the 4 x 4 layout for up to 16 images; grow the square grid
        # beyond that instead of failing on a subplot index that is out of range.
        side = max(4, int(np.ceil(np.sqrt(count))))

        fig = plt.figure(figsize=(10, 10))
        for idx in range(count):
            plt.subplot(side, side, idx + 1)
            image = images[idx, :, :, :]
            if image.shape[-1] == 1:
                # Single channel: drop the channel axis and draw as grayscale.
                plt.imshow(image[:, :, 0], cmap='gray')
            else:
                # Multi-channel (e.g. RGB): imshow takes (rows, cols, channels).
                plt.imshow(image)
            plt.axis('off')
        plt.tight_layout()
        if save2file:
            plt.savefig(filename)
            # Close only this figure so repeated saves do not accumulate.
            plt.close(fig)
        else:
            plt.show()