# DualGAN

Reference: Yi, Zili, et al.  
     "DualGAN: Unsupervised Dual Learning for Image-to-Image Translation."  
     In: Proceedings of the IEEE International Conference on Computer Vision (ICCV), 2017, pp. 2849-2857.

In [1]:
import os
import numpy as np
import scipy
import matplotlib.pyplot as plt

In [2]:
from keras.datasets import mnist
from keras.layers import Input, Dense, Dropout
from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import keras.backend as K

Using TensorFlow backend.


## Generator

In [None]:
def build_generator(img_dim):
    """Build a fully connected generator for flattened images.

    The network maps a vector of length ``img_dim`` to a vector of the same
    length. Three hidden stages (256, 512 and 1024 units) each apply
    Dense -> LeakyReLU(0.2) -> BatchNormalization(momentum=0.8) -> Dropout(0.4);
    the output layer uses ``tanh``.

    Args:
        img_dim: length of the flattened input (and output) image vector.

    Returns:
        A Keras ``Model`` wrapping the layer stack.
    """
    source = Input(shape=(img_dim,))

    net = Sequential()
    hidden_units = (256, 512, 1024)
    for position, units in enumerate(hidden_units):
        # Only the first Dense layer declares the input size.
        extra = {'input_dim': img_dim} if position == 0 else {}
        net.add(Dense(units, **extra))
        net.add(LeakyReLU(alpha=0.2))
        net.add(BatchNormalization(momentum=0.8))
        net.add(Dropout(0.4))
    net.add(Dense(img_dim, activation='tanh'))

    return Model(source, net(source))

## Discriminator

In [None]:
def build_discriminator(img_dim):
    """Build a fully connected critic for flattened images.

    The network maps a vector of length ``img_dim`` to a single unbounded
    score (the final Dense layer has one unit and no activation).

    Args:
        img_dim: length of the flattened input image vector.

    Returns:
        A Keras ``Model`` wrapping the layer stack.
    """
    critic_input = Input(shape=(img_dim,))

    stack = Sequential([
        Dense(512, input_dim=img_dim),
        LeakyReLU(alpha=0.2),
        Dense(256),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(1),
    ])

    return Model(critic_input, stack(critic_input))

## Helpers

In [None]:
def sample_generator_input(X, batch_size):
    """Return a random batch of rows from ``X``, drawn with replacement.

    Args:
        X: array indexed along its first axis.
        batch_size: number of rows to draw.

    Returns:
        ``X`` indexed by ``batch_size`` uniformly random row indices.
    """
    n_available = X.shape[0]
    picks = np.random.randint(0, n_available, batch_size)
    return X[picks]

In [None]:
def wasserstein_loss(y_true, y_pred):
    """Return the mean of the element-wise product of labels and scores.

    Args:
        y_true: target tensor (the training loop passes -1 / +1 labels).
        y_pred: tensor of critic outputs.

    Returns:
        A backend tensor holding ``mean(y_true * y_pred)``.
    """
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)

In [None]:
def save_imgs(G_AB, G_BA, epoch, 
              X_A, X_B, 
              img_rows, img_cols):
    """Save a 4x4 grid of real and translated samples to ``images/``.

    Grid rows, top to bottom: real A, A translated to B, real B,
    B translated to A.

    Args:
        G_AB: generator translating domain A to domain B.
        G_BA: generator translating domain B to domain A.
        epoch: iteration number, used in the output file name.
        X_A: flattened domain-A images to sample from.
        X_B: flattened domain-B images to sample from.
        img_rows: image height used to un-flatten the samples.
        img_cols: image width used to un-flatten the samples.
    """
    r, c = 4, 4

    # Sample generator inputs
    imgs_A = sample_generator_input(X_A, c)
    imgs_B = sample_generator_input(X_B, c)

    # Images translated to their opposite domain
    fake_B = G_AB.predict(imgs_A)
    fake_A = G_BA.predict(imgs_B)

    gen_imgs = np.concatenate([imgs_A, fake_B, imgs_B, fake_A])
    gen_imgs = gen_imgs.reshape((r, c, img_rows, img_cols, 1))

    # Rescale images from [-1, 1] to [0, 1] for display
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_imgs[i, j, :, :, 0], cmap='gray')
            axs[i, j].axis('off')

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs("images", exist_ok=True)
    fig.savefig("images/mnist_%d.png" % epoch)
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)

In [None]:
def train(D_A, D_B, G_AB, G_BA, combined,
          epochs, batch_size=128, sample_interval=50,
          clip_value=0.01, n_critic=4):
    """Train the two critics and the two generators on MNIST.

    Domain A is the first half of the MNIST training images; domain B is the
    second half rotated by 90 degrees. Each iteration runs ``n_critic``
    critic updates (with weight clipping) followed by one generator update
    through ``combined``.

    Args:
        D_A: compiled critic for domain A.
        D_B: compiled critic for domain B.
        G_AB: generator translating domain A to domain B.
        G_BA: generator translating domain B to domain A.
        combined: compiled model taking ``[imgs_A, imgs_B]`` and trained
            against ``[valid, valid, imgs_A, imgs_B]``.
        epochs: number of training iterations (one generator update each).
        batch_size: number of images sampled per domain for each update.
        sample_interval: save a sample image grid every this many iterations.
        clip_value: critic weights are clipped to ``[-clip_value, clip_value]``.
        n_critic: critic updates per generator update; must be at least 1.

    Raises:
        ValueError: if ``n_critic`` is smaller than 1.
    """
    # Imported locally: a bare ``import scipy`` is not guaranteed to expose
    # the ``ndimage`` subpackage on every SciPy version.
    from scipy import ndimage

    if n_critic < 1:
        raise ValueError("n_critic must be at least 1")

    # Load the dataset
    (X_train, _), (_, _) = mnist.load_data()

    # Rescale -1 to 1
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5

    # Image geometry is taken from the data instead of external state.
    img_rows, img_cols = X_train.shape[1:3]
    img_dim = img_rows * img_cols

    # Domain A and B (rotated)
    half = X_train.shape[0] // 2
    X_A = X_train[:half]
    X_B = ndimage.rotate(X_train[half:], 90, axes=(1, 2))

    X_A = X_A.reshape(X_A.shape[0], img_dim)
    X_B = X_B.reshape(X_B.shape[0], img_dim)

    # Adversarial ground truths (Wasserstein sign convention: -1 real, +1 fake)
    valid = -np.ones((batch_size, 1))
    fake = np.ones((batch_size, 1))

    for epoch in range(epochs):

        # Train the discriminator for n_critic iterations
        for _ in range(n_critic):

            # ----------------------
            #  Train Discriminators
            # ----------------------

            # Sample generator inputs
            imgs_A = sample_generator_input(X_A, batch_size)
            imgs_B = sample_generator_input(X_B, batch_size)

            # Translate images to their opposite domain
            fake_B = G_AB.predict(imgs_A)
            fake_A = G_BA.predict(imgs_B)

            # Train the discriminators
            D_A_loss_real = D_A.train_on_batch(imgs_A, valid)
            D_A_loss_fake = D_A.train_on_batch(fake_A, fake)

            D_B_loss_real = D_B.train_on_batch(imgs_B, valid)
            D_B_loss_fake = D_B.train_on_batch(fake_B, fake)

            D_A_loss = 0.5 * np.add(D_A_loss_real, D_A_loss_fake)
            D_B_loss = 0.5 * np.add(D_B_loss_real, D_B_loss_fake)

            # Clip discriminator weights
            for critic in (D_A, D_B):
                for layer in critic.layers:
                    clipped = [np.clip(w, -clip_value, clip_value)
                               for w in layer.get_weights()]
                    layer.set_weights(clipped)

        # ------------------
        #  Train Generators
        # ------------------

        # Uses the batches sampled in the last critic iteration.
        g_loss = combined.train_on_batch([imgs_A, imgs_B],
                                         [valid, valid, imgs_A, imgs_B])

        # Plot the progress
        print("%d [D1 loss: %f] [D2 loss: %f] [G loss: %f]"
              % (epoch, D_A_loss[0], D_B_loss[0], g_loss[0]))

        # If at save interval => save generated image samples
        if epoch % sample_interval == 0:
            save_imgs(G_AB, G_BA, epoch, X_A, X_B, img_rows, img_cols)


## main()

In [3]:
# Image geometry: 28x28 single-channel images, flattened to one vector.
img_rows, img_cols, channels = 28, 28, 1
img_dim = img_rows * img_cols

In [4]:
# Create the single Adam optimizer instance passed to every compile() call
# in this notebook (both critics and the combined model).
# NOTE(review): the positional arguments are presumably the learning rate
# (0.0002) and beta_1 (0.5) -- confirm against the installed Keras version.
optimizer = Adam(0.0002, 0.5)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [None]:
# Build and compile the discriminators (critics): D_A is trained on domain-A
# images, D_B on domain-B images. Both use the Wasserstein loss.
# NOTE(review): 'accuracy' is presumably not meaningful for an unbounded
# critic score, but train() indexes the train_on_batch result with [0], so
# the metrics list looks required -- confirm before removing it.
D_A = build_discriminator(img_dim)
D_A.compile(loss=wasserstein_loss,
            optimizer=optimizer,
            metrics=['accuracy'])
D_B = build_discriminator(img_dim)
D_B.compile(loss=wasserstein_loss,
            optimizer=optimizer,
            metrics=['accuracy'])

In [None]:
#-------------------------
# Construct Computational
#   Graph of Generators
#-------------------------

In [None]:
# Build the generators: G_AB translates domain A to domain B and G_BA
# translates domain B to domain A. They are not compiled on their own; they
# are trained through the combined model.
G_AB = build_generator(img_dim)
G_BA = build_generator(img_dim)

In [None]:
# For the combined model we will only train the generators, so both critics
# are frozen here -- after their own compile() calls and before the combined
# model is compiled.
# NOTE(review): this relies on Keras presumably capturing the trainable flag
# at compile time, so the standalone critics would still update in
# train_on_batch -- confirm for the installed Keras version.
D_A.trainable = False
D_B.trainable = False

In [None]:
# Symbolic inputs for the combined model: one flattened image per domain,
# each a vector of length img_dim.
imgs_A = Input(shape=(img_dim,))
imgs_B = Input(shape=(img_dim,))

In [None]:
# The generators translate each input image to the opposite domain.
fake_B = G_AB(imgs_A)
fake_A = G_BA(imgs_B)

In [None]:
# Each critic scores the translated image that landed in its own domain:
# D_A scores fake_A and D_B scores fake_B.
valid_A = D_A(fake_A)
valid_B = D_B(fake_B)

In [None]:
# The generators translate the fakes back to their original domain
# (A -> B -> A and B -> A -> B); these reconstructions are compared with the
# original inputs by the combined model's 'mae' losses.
recov_A = G_BA(fake_B)
recov_B = G_AB(fake_A)

In [None]:
# The combined model (stacked generators and discriminators).
# Outputs, in order: critic score of fake_A, critic score of fake_B,
# recovered A, recovered B. The loss and loss_weights lists follow the same
# order: Wasserstein loss (weight 1) for the two scores and mean absolute
# error (weight 100) for the two reconstructions.
combined = Model(inputs=[imgs_A, imgs_B], outputs=[valid_A, valid_B, recov_A, recov_B])
combined.compile(loss=[wasserstein_loss, wasserstein_loss, 'mae', 'mae'],
                 optimizer=optimizer,
                 loss_weights=[1, 1, 100, 100])

In [None]:
# Number of training iterations. The notebook previously assigned 30000 and
# then immediately overwrote it with 30; only the effective value is kept.
# Raise it (e.g. to 30000) for a full-length run.
epochs = 30
train(D_A, D_B, G_AB, G_BA, combined,
      epochs=epochs, batch_size=32, sample_interval=200)