# Model

## Architecture

For the "Gender swap" task we use CycleGAN. The architecture consists of four models: two discriminators and two generators.

![Diagram](pictures/GanDiagram.png)

![Diagram](pictures/GanDiagram2.png)

## Implementation

The <b>discriminator</b> is a deep convolutional neural network that performs image classification. It takes an image as input and predicts whether that image is real or generated. Two discriminator models are used, one for Domain-A and one for Domain-B.

The output of the model depends on the size of the input image and may be a single value or a square activation map of values. Each value scores how likely the corresponding patch of the input image is to be real. Because the final layer has no sigmoid and the model is trained with a mean squared error loss, these are least-squares scores rather than strict probabilities. The values can be averaged to give an overall score if needed.

In [None]:
from keras.initializers import RandomNormal
from keras.layers import Input
from keras.models import Model
from keras.layers import Conv2D
from keras.layers import LeakyReLU
from keras.layers import Activation
from keras.layers import Concatenate
from keras.layers import Conv2DTranspose
from keras.optimizers import Adam
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization

def discriminator(image_shape):
    init = RandomNormal(stddev=0.02)
    in_image = Input(shape=image_shape)
    
    layer = Conv2D(64, (4,4), strides=(2,2),padding='same',kernel_initializer=init)(in_image)
    layer = LeakyReLU(alpha=0.2)(layer)

    layer = Conv2D(128, (4,4), strides=(2,2),padding='same',kernel_initializer=init)(layer)
    layer = InstanceNormalization(axis=-1)(layer)
    layer = LeakyReLU(alpha=0.2)(layer)

    layer = Conv2D(256, (4,4), strides=(2,2),padding='same',kernel_initializer=init)(layer)
    layer = InstanceNormalization(axis=-1)(layer)
    layer = LeakyReLU(alpha=0.2)(layer)

    layer = Conv2D(512, (4,4), strides=(2,2),padding='same',kernel_initializer=init)(layer)
    layer = InstanceNormalization(axis=-1)(layer)
    layer = LeakyReLU(alpha=0.2)(layer)

    layer = Conv2D(512, (4,4), padding='same',kernel_initializer=init)(layer)
    layer = InstanceNormalization(axis=-1)(layer)
    layer = LeakyReLU(alpha=0.2)(layer)

    patch_out = Conv2D(1,(4,4), padding='same', kernel_initializer=init)(layer)

    model = Model(in_image, patch_out)

    model.compile(loss='mse', optimizer=Adam(lr=0.0002,beta_1=0.5),loss_weights=[0.5])
    
    return model
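
As a rough check on the patch output described above, the sketch below traces the spatial size through the four stride-2, `padding='same'` convolutions of `discriminator` (the two final stride-1 convolutions leave the size unchanged) and averages a patch map into one score. The helper names `patch_output_size` and `mean_patch_score` are illustrative only and are not part of the notebook, and the 256-pixel input is an assumed example size.

```python
import math


def patch_output_size(image_size, n_strided=4, stride=2):
    """Spatial size after n_strided stride-2 convolutions with 'same' padding."""
    size = image_size
    for _ in range(n_strided):
        # With padding='same', the output size is ceil(input_size / stride).
        size = math.ceil(size / stride)
    return size


def mean_patch_score(patch_map):
    """Average a 2-D list of per-patch scores into one overall score."""
    values = [v for row in patch_map for v in row]
    return sum(values) / len(values)


print(patch_output_size(256))                       # 16
print(mean_patch_score([[1.0, 0.0], [0.5, 0.5]]))   # 0.5
```

Under these assumptions a 256x256 input yields a 16x16 map of patch scores, which is the value `train()` later reads from `d_model_A.output_shape[1]`.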

The <b>generator</b> is an encoder-decoder model architecture. The model takes a source image (e.g. male photo) and generates a target image (e.g. female photo). It does this by first downsampling or encoding the input image down to a bottleneck layer, then interpreting the encoding with a number of ResNet layers that use skip connections, followed by a series of layers that upsample or decode the representation to the size of the output image.

In [None]:
def resnet_block(n_filters, input_layer):
    init = RandomNormal(stddev=0.02)

    layer = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(input_layer)
    layer = InstanceNormalization(axis=-1)(layer)
    layer = Activation('relu')(layer)

    layer = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(layer)
    layer = InstanceNormalization(axis=-1)(layer)

    layer = Concatenate()([layer, input_layer])

    return layer

def generator(image_shape, n_resnet=9):
    init = RandomNormal(stddev=0.02)
    in_image = Input(shape=image_shape)

    layer = Conv2D(64, (7,7), padding='same', kernel_initializer=init)(in_image)
    layer = InstanceNormalization(axis=-1)(layer)
    layer = Activation('relu')(layer)

    layer = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(layer)
    layer = InstanceNormalization(axis=-1)(layer)
    layer = Activation('relu')(layer)

    layer = Conv2D(256, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(layer)
    layer = InstanceNormalization(axis=-1)(layer)
    layer = Activation('relu')(layer)

    for _ in range(n_resnet):
        layer = resnet_block(256,layer)
    
    layer = Conv2DTranspose(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(layer)
    layer = InstanceNormalization(axis=-1)(layer)
    layer = Activation('relu')(layer)

    layer = Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(layer)
    layer = InstanceNormalization(axis=-1)(layer)
    layer = Activation('relu')(layer)

    layer = Conv2D(3, (7,7), padding='same', kernel_initializer=init)(layer)
    layer = InstanceNormalization(axis=-1)(layer)
    out_image = Activation('tanh')(layer)

    model = Model(in_image, out_image)
    return model
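
One detail worth checking in `resnet_block` is how the skip connection is merged. The code uses `Concatenate`, which stacks the block input onto the block output along the channel axis, whereas a classic residual block adds the two tensors element-wise. The notebook does not say which was intended, so the sketch below only counts channels for both options. The function name `channels_after_blocks` and the `merge` argument are hypothetical and exist only for this illustration.

```python
def channels_after_blocks(in_channels, n_filters, n_blocks, merge="concat"):
    """Count output channels after a stack of skip-connection blocks."""
    channels = in_channels
    for _ in range(n_blocks):
        if merge == "concat":
            # Concatenate: block output (n_filters) is stacked on the block input.
            channels = n_filters + channels
        elif merge == "add":
            # Element-wise add: both tensors must have the same channel count.
            if channels != n_filters:
                raise ValueError("add merge needs matching channel counts")
            channels = n_filters
        else:
            raise ValueError("merge must be 'concat' or 'add'")
    return channels


print(channels_after_blocks(256, 256, 9, merge="concat"))  # 2560
print(channels_after_blocks(256, 256, 9, merge="add"))     # 256
```

With the default `n_resnet=9`, concatenation grows the tensor from 256 to 2560 channels before the first `Conv2DTranspose`, which increases memory use and training time compared with an additive merge.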


Each generator model is optimized via the combination of four outputs with four loss functions:

- Adversarial loss (L2 or mean squared error).
- Identity loss (L1 or mean absolute error).
- Forward cycle loss (L1 or mean absolute error).
- Backward cycle loss (L1 or mean absolute error).

We achieve this by defining a composite model for each generator. The composite model updates only that generator's weights, while reusing the frozen weights of the related discriminator and of the other generator.



In [None]:
def composite_model(g_model_1,d_model,g_model_2,image_shape):
    g_model_1.trainable = True
    d_model.trainable = False
    g_model_2.trainable = False

    input_gen = Input(shape=image_shape)
    gen1_out = g_model_1(input_gen)
    output_d = d_model(gen1_out)

    input_id = Input(shape=image_shape)
    output_id = g_model_1(input_id)

    output_f = g_model_2(gen1_out)

    gen2_out = g_model_2(input_id)
    output_b = g_model_1(gen2_out)

    model = Model([input_gen, input_id], [output_d, output_id, output_f, output_b])

    opt = Adam(lr=0.0002, beta_1=0.5)

    model.compile(loss=['mse', 'mae', 'mae', 'mae'], loss_weights=[1, 5, 10, 10], optimizer=opt)
    return model
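
To make the weighting concrete, here is a minimal plain-Python sketch of how the four losses could combine into the single value used to update a generator, assuming the total is the weighted sum of the individual losses with the weights `[1, 5, 10, 10]` passed to `compile` above. The helpers `mse`, `mae`, and `composite_loss` and the example loss values are made up for illustration.

```python
def mse(y_true, y_pred):
    """Mean squared error (L2), used for the adversarial output."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def mae(y_true, y_pred):
    """Mean absolute error (L1), used for identity and cycle outputs."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def composite_loss(adversarial, identity, forward, backward,
                   weights=(1, 5, 10, 10)):
    """Weighted sum of the four generator losses."""
    losses = (adversarial, identity, forward, backward)
    return sum(w * l for w, l in zip(weights, losses))


print(mse([1.0, 1.0], [0.5, 1.0]))               # 0.125
print(mae([0.0, 1.0], [0.5, 0.5]))               # 0.5
print(composite_loss(0.5, 0.25, 0.125, 0.125))   # 4.25
```

The larger weights on the two cycle terms mean that reconstruction errors dominate the generator update unless the adversarial loss is comparatively large.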

In [None]:
from numpy import load
from numpy import ones
from numpy import zeros
from numpy.random import randint
from matplotlib import pyplot

def load_real_samples(filename):
    data = load(filename)
    X1, X2 = data['arr_0'], data['arr_1']

    X1 = (X1 - 127.5) / 127.5
    X2 = (X2 - 127.5) / 127.5
    return [X1,X2]

def generate_real_samples(dataset,n_samples, patch_shape):
    ix = randint(0, dataset.shape[0],n_samples)
    X = dataset[ix]
    Y = ones((n_samples,patch_shape,patch_shape,1))
    return X, Y

def generate_fake_samples(g_model, dataset, patch_shape):
    X = g_model.predict(dataset)
    Y = zeros((len(X),patch_shape,patch_shape,1))
    return X, Y

# This function saves each generator model to the ../../models/ directory in H5 format,
# including the training iteration number in the filename.
def save_models(step, g_model_AtoB, g_model_BtoA):
    path = '../../models/'
    filename1 = 'g_model_AtoB_%06d.h5' % (step+1)
    g_model_AtoB.save(path+filename1)
    filename2 = 'g_model_BtoA_%06d.h5' % (step+1)
    g_model_BtoA.save(path+filename2)

# This function uses a given generator model to generate translated versions of
# a few randomly selected source photographs and saves the plot to file.
def summarize_performance(step, g_model, trainX, name, n_samples=5):
    path = '../../models/'
    X_in, _ = generate_real_samples(trainX, n_samples, 0)
    X_out, _ = generate_fake_samples(g_model, X_in, 0)
    X_in = (X_in + 1) / 2.0
    X_out = (X_out + 1) / 2.0
    for i in range(n_samples):
        pyplot.subplot(2, n_samples, 1 + i)
        pyplot.axis('off')
        pyplot.imshow(X_in[i])
    for i in range(n_samples):
        pyplot.subplot(2, n_samples, 1 + n_samples + i)
        pyplot.axis('off')
        pyplot.imshow(X_out[i])
    filename1 = '%s_generated_plot_%06d.png' % (name, (step+1))
    pyplot.savefig(path+filename1)
    pyplot.close()
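
The helpers above move pixel values between three ranges: `load_real_samples` maps 8-bit values to [-1, 1] to match the generator's `tanh` output, and `summarize_performance` maps them to [0, 1] for plotting. A small sketch of that arithmetic on single values follows; the function names `to_tanh_range` and `to_plot_range` are illustrative and not used in the notebook.

```python
def to_tanh_range(pixel):
    """Map an 8-bit pixel value in [0, 255] to [-1, 1]."""
    return (pixel - 127.5) / 127.5


def to_plot_range(value):
    """Map a value in [-1, 1] to [0, 1] for plotting."""
    return (value + 1) / 2.0


print(to_tanh_range(0), to_tanh_range(255))   # -1.0 1.0
print(to_plot_range(to_tanh_range(255)))      # 1.0
```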

In [None]:
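from random import random
from numpy import asarray

# Maintain a pool of up to max_size generated images for discriminator updates.
def update_image_pool(pool, images, max_size=50):
    selected = list()
    for image in images:
        if len(pool) < max_size:
            # Pool not full yet: store the image and use it
            pool.append(image)
            selected.append(image)
        elif random() < 0.5:
            # Use the new image without adding it to the pool
            selected.append(image)
        else:
            # Use an older image from the pool and replace it with the new one
            ix = randint(0, len(pool))
            selected.append(pool[ix])
            pool[ix] = image
    return asarray(selected)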
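
The pool logic can be tried without Keras or NumPy. The following is a simplified sketch of the same idea using plain lists and a seeded `random.Random` so the first call is reproducible. The name `update_pool` and the `rng` parameter are assumptions made for this example and differ from the notebook's `update_image_pool`.

```python
import random


def update_pool(pool, images, max_size=50, rng=random):
    """Return the images to show the discriminator, updating pool in place."""
    selected = []
    for image in images:
        if len(pool) < max_size:
            # Pool not full: keep the new image and use it directly.
            pool.append(image)
            selected.append(image)
        elif rng.random() < 0.5:
            # Use the new image but leave the pool untouched.
            selected.append(image)
        else:
            # Use a stored image and overwrite its slot with the new one.
            ix = rng.randrange(len(pool))
            selected.append(pool[ix])
            pool[ix] = image
    return selected


rng = random.Random(0)
pool = []
first = update_pool(pool, ["img0", "img1"], max_size=2, rng=rng)
later = update_pool(pool, ["img2", "img3", "img4"], max_size=2, rng=rng)
print(first)      # ['img0', 'img1']
print(len(pool))  # 2
```

Once the pool is full, each call returns a mix of fresh and older generated images, so the discriminator does not only see the generator's most recent output.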

The train() function takes all six models (two discriminator, two generator, and two composite models) as arguments along with the dataset and trains the models.

The order of model updates is implemented to match the official Torch implementation. First, a batch of real images from each domain is selected, then a batch of fake images for each domain is generated. The fake images are then used to update each discriminator’s fake image pool.

Next, the Generator-A model (male to female, `g_model_BtoA`) is updated via its composite model, followed by the Discriminator-A model (female, `d_model_A`). Then the Generator-B composite model (female to male, `g_model_AtoB`) and the Discriminator-B model (male, `d_model_B`) are updated.

The loss for each updated model is reported at the end of every training iteration. Importantly, for each generator only the total loss, that is, the weighted sum of its four losses, is reported.

In [None]:
def train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset):
    n_epochs, n_batch = 100, 1
    n_patch = d_model_A.output_shape[1]
    trainA, trainB = dataset
    poolA, poolB = list(), list()
    bat_per_epo = int(len(trainA) / n_batch)
    n_steps = bat_per_epo * n_epochs

    for i in range(n_steps):

        X_realA, y_realA = generate_real_samples(trainA, n_batch, n_patch)
        X_realB, y_realB = generate_real_samples(trainB, n_batch, n_patch)

        X_fakeA, y_fakeA = generate_fake_samples(g_model_BtoA, X_realB, n_patch)
        X_fakeB, y_fakeB = generate_fake_samples(g_model_AtoB, X_realA, n_patch)

        X_fakeA = update_image_pool(poolA, X_fakeA)
        X_fakeB = update_image_pool(poolB, X_fakeB)

        g_loss2, _, _, _, _  = c_model_BtoA.train_on_batch([X_realB, X_realA], [y_realA, X_realA, X_realB, X_realA])

        dA_loss1 = d_model_A.train_on_batch(X_realA, y_realA)
        dA_loss2 = d_model_A.train_on_batch(X_fakeA, y_fakeA)

        g_loss1, _, _, _, _ = c_model_AtoB.train_on_batch([X_realA, X_realB], [y_realB, X_realB, X_realA, X_realB])

        dB_loss1 = d_model_B.train_on_batch(X_realB, y_realB)
        dB_loss2 = d_model_B.train_on_batch(X_fakeB, y_fakeB)

        print('>%d, dA[%.3f,%.3f] dB[%.3f,%.3f] g[%.3f,%.3f]' % (i+1, dA_loss1,dA_loss2, dB_loss1,dB_loss2, g_loss1,g_loss2))

        if (i+1) % (bat_per_epo * 1) == 0:

            summarize_performance(i, g_model_AtoB, trainA, 'AtoB')

            summarize_performance(i, g_model_BtoA, trainB, 'BtoA')
        # Integer division keeps the save interval a whole number of steps
        if (i+1) % max(1, bat_per_epo // 10) == 0:

            save_models(i, g_model_AtoB, g_model_BtoA)


dataset = load_real_samples('../data/genderswaptest.npz')
print('Loaded', dataset[0].shape, dataset[1].shape)

image_shape = dataset[0].shape[1:]

g_model_AtoB = generator(image_shape)

g_model_BtoA = generator(image_shape)

d_model_A = discriminator(image_shape)

d_model_B = discriminator(image_shape)

c_model_AtoB = composite_model(g_model_AtoB, d_model_B, g_model_BtoA, image_shape)

c_model_BtoA = composite_model(g_model_BtoA, d_model_A, g_model_AtoB, image_shape)

train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset)
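
The step counting inside `train()` can be checked in isolation. The sketch below reproduces that arithmetic under assumed example values (3000 images per domain, batch size 1, 5 epochs). The function name `training_schedule` and the `saves_per_epoch` parameter are hypothetical. It uses integer division for the checkpoint interval, because a float interval may never divide the step counter exactly.

```python
def training_schedule(n_images, n_batch, n_epochs, saves_per_epoch=10):
    """Return batches per epoch, total steps, save interval, and plot steps."""
    bat_per_epo = n_images // n_batch
    n_steps = bat_per_epo * n_epochs
    # Checkpoint roughly saves_per_epoch times per epoch, at least every step.
    save_every = max(1, bat_per_epo // saves_per_epoch)
    # Sample plots are written once at the end of each epoch.
    plot_steps = [s for s in range(1, n_steps + 1) if s % bat_per_epo == 0]
    return bat_per_epo, n_steps, save_every, plot_steps


print(training_schedule(3000, 1, 5))
# (3000, 15000, 300, [3000, 6000, 9000, 12000, 15000])
```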

### First results 

We trained these models on 3000 male and 3000 female photos for 5 epochs. Training took approximately 8 hours on a Google Cloud virtual machine. For now, we think this is too few epochs and too small a dataset for such a complex set of neural networks to produce usable transformations.

![](models/AtoB_generated_plot_015080.png)
![](models/BtoA_generated_plot_015080.png)