# Context encoder

Reference: Pathak, Deepak, et al. "Context Encoders: Feature Learning by Inpainting." In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, pp. 2536-2544.

In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from keras.datasets import cifar10
from keras.layers import Input, Dense, Flatten, Dropout
from keras.layers import BatchNormalization, Activation
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

Using TensorFlow backend.


# Generator

In [5]:
def build_generator(img_shape, channels):
    """Assemble the generator network; the layer layout is fixed.

    The encoder is three stride-2 3x3 convolutions (32, 64, 128 filters),
    each followed by LeakyReLU and batch normalisation, then a stride-2
    1x1 convolution with 512 filters, LeakyReLU and dropout. The decoder is
    two upsample + 3x3 convolution stages (128, 64 filters) with ReLU and
    batch normalisation, finished by a ``channels``-filter convolution with
    a ``tanh`` activation.

    Args:
        img_shape: Shape of the masked input image; used as the
            ``input_shape`` of the first layer and for the ``Input`` tensor.
        channels: Number of filters of the final convolution, i.e. the
            number of channels of the generated output.

    Returns:
        A ``Model`` mapping a masked image to the generated output.
    """
    net = Sequential()

    # Encoder: only the very first layer carries the input shape.
    for position, filters in enumerate((32, 64, 128)):
        first_layer_kwargs = {"input_shape": img_shape} if position == 0 else {}
        net.add(Conv2D(filters, kernel_size=3, strides=2, padding="same",
                       **first_layer_kwargs))
        net.add(LeakyReLU(alpha=0.2))
        net.add(BatchNormalization(momentum=0.8))

    # Bottleneck: 1x1 convolution, no batch normalisation, dropout instead.
    net.add(Conv2D(512, kernel_size=1, strides=2, padding="same"))
    net.add(LeakyReLU(alpha=0.2))
    net.add(Dropout(0.5))

    # Decoder
    for filters in (128, 64):
        net.add(UpSampling2D())
        net.add(Conv2D(filters, kernel_size=3, padding="same"))
        net.add(Activation('relu'))
        net.add(BatchNormalization(momentum=0.8))
    net.add(Conv2D(channels, kernel_size=3, padding="same"))
    net.add(Activation('tanh'))

    net.summary()

    # Wrap the sequential stack in a functional Model with an explicit input.
    masked_img = Input(shape=img_shape)
    return Model(masked_img, net(masked_img))

# Discriminator

In [6]:
def build_discriminator(missing_shape):
    """Assemble the discriminator that scores an image patch.

    Three 3x3 convolutions (64 and 128 filters with stride 2, then 256
    filters with the default stride), each followed by LeakyReLU and batch
    normalisation, are flattened into a single sigmoid unit.

    Args:
        missing_shape: Shape of the patch fed to the discriminator; used as
            the ``input_shape`` of the first layer and for the ``Input``
            tensor.

    Returns:
        A ``Model`` mapping a patch to a single sigmoid output.
    """
    # (filters, extra Conv2D keyword arguments) for each convolution stage.
    conv_specs = (
        (64, {"strides": 2, "input_shape": missing_shape}),
        (128, {"strides": 2}),
        (256, {}),
    )

    net = Sequential()
    for filters, extra in conv_specs:
        net.add(Conv2D(filters, kernel_size=3, padding="same", **extra))
        net.add(LeakyReLU(alpha=0.2))
        net.add(BatchNormalization(momentum=0.8))
    net.add(Flatten())
    net.add(Dense(1, activation='sigmoid'))
    net.summary()

    # Wrap the sequential stack in a functional Model with an explicit input.
    patch = Input(shape=missing_shape)
    return Model(patch, net(patch))

# to mask the image

In [14]:
def mask_randomly(imgs, img_rows, mask_height, mask_width, channels):
    """Zero out one randomly placed rectangle in every image of a batch.

    Args:
        imgs: Batch of images indexed as ``(batch, row, column, channel)``.
        img_rows: Image height; bounds the vertical mask position.
        mask_height: Height of the masked rectangle.
        mask_width: Width of the masked rectangle.
        channels: Number of channels of the extracted patches.

    Returns:
        A tuple ``(masked_imgs, missing_parts, (y1, y2, x1, x2))`` where
        ``masked_imgs`` is a copy of ``imgs`` with the rectangle set to 0,
        ``missing_parts`` holds the original content of each rectangle
        (shape ``(batch, mask_height, mask_width, channels)``), and the
        coordinate arrays give, per image, the half-open row range
        ``[y1, y2)`` and column range ``[x1, x2)`` of the rectangle.

    Raises:
        ValueError: From ``np.random.randint`` if the mask is larger than
            the image.
    """
    n_imgs = imgs.shape[0]
    # The horizontal range is bounded by the real image width, so
    # non-square images are handled correctly.
    img_cols = imgs.shape[2]

    # randint's upper bound is exclusive; "+ 1" lets the mask reach the
    # bottom/right edge and makes a full-size mask valid.
    y1 = np.random.randint(0, img_rows - mask_height + 1, n_imgs)
    y2 = y1 + mask_height
    x1 = np.random.randint(0, img_cols - mask_width + 1, n_imgs)
    x2 = x1 + mask_width

    masked_imgs = imgs.copy()
    missing_parts = np.empty((n_imgs, mask_height, mask_width, channels))
    for i in range(n_imgs):
        top, bottom, left, right = y1[i], y2[i], x1[i], x2[i]
        # Save the patch from the untouched input before blanking the copy.
        missing_parts[i] = imgs[i, top:bottom, left:right, :]
        masked_imgs[i, top:bottom, left:right, :] = 0

    return masked_imgs, missing_parts, (y1, y2, x1, x2)

# save images

In [19]:
def sample_images(G, epoch,
                  imgs, img_rows, mask_height, mask_width,
                  channels):
    """Save a 3x6 grid comparing originals, masked inputs and in-painted results.

    Row 0 shows the original images, row 1 the masked images and row 2 the
    originals with the masked rectangle replaced by the generator output.
    The figure is written to ``images/<epoch>.png``; the ``images``
    directory is created if it does not exist yet.

    Args:
        G: Generator; its ``predict`` is called on the masked batch.
        epoch: Integer used as the output file name.
        imgs: Batch of at least 6 images (the first 6 are plotted).
            Values are rescaled with ``0.5 * x + 0.5`` before plotting,
            so they are presumably expected in ``[-1, 1]``.
        img_rows, mask_height, mask_width, channels: Forwarded unchanged
            to ``mask_randomly``.
    """
    import os  # local import: only needed to ensure the output directory exists

    r, c = 3, 6

    masked_imgs, missing_parts, (y1, y2, x1, x2) = mask_randomly(
        imgs, img_rows, mask_height, mask_width, channels)
    gen_missing = G.predict(masked_imgs)

    # Map everything to the range imshow expects for float images.
    imgs = 0.5 * imgs + 0.5
    masked_imgs = 0.5 * masked_imgs + 0.5
    gen_missing = 0.5 * gen_missing + 0.5

    fig, axs = plt.subplots(r, c)
    for i in range(c):
        axs[0, i].imshow(imgs[i, :, :])
        axs[0, i].axis('off')
        axs[1, i].imshow(masked_imgs[i, :, :])
        axs[1, i].axis('off')
        # Paste the generated patch back into the original at the mask position.
        filled_in = imgs[i].copy()
        filled_in[y1[i]:y2[i], x1[i]:x2[i], :] = gen_missing[i]
        axs[2, i].imshow(filled_in)
        axs[2, i].axis('off')

    # savefig does not create missing directories.
    os.makedirs("images", exist_ok=True)
    fig.savefig("images/%d.png" % epoch)
    # Close this specific figure rather than whichever one is "current".
    plt.close(fig)


# train the model

In [17]:
def train(D, G, combined,
          img_rows, mask_height, mask_width, channels,
          epochs, batch_size=128, sample_interval=50):
    """Run the adversarial training loop on the CIFAR-10 cat and dog images.

    Each "epoch" of this loop is a single training step on one random
    batch, not a full pass over the data set.

    Args:
        D: Discriminator; trained on real patches (label 1) and generated
            patches (label 0). The progress line reads index 0 and 1 of its
            ``train_on_batch`` result as loss and accuracy, so it is assumed
            to be compiled with an accuracy metric.
        G: Generator; used through ``predict`` to create patches.
        combined: Model trained on ``masked_imgs -> [missing_parts, valid]``;
            index 0 and 1 of its ``train_on_batch`` result are printed as
            total loss and mse.
        img_rows, mask_height, mask_width, channels: Forwarded to
            ``mask_randomly`` and ``sample_images``.
        epochs: Number of training steps.
        batch_size: Number of images drawn (with replacement) per step.
        sample_interval: A sample image grid is saved every this many steps.
    """
    # Load the dataset (only the training split is used)
    (X_train, y_train), (_, _) = cifar10.load_data()

    # Extract dogs and cats (labels 3 and 5)
    X_cats = X_train[(y_train == 3).flatten()]
    X_dogs = X_train[(y_train == 5).flatten()]
    X_train = np.vstack((X_cats, X_dogs))

    # Rescale -1 to 1
    X_train = X_train / 127.5 - 1.

    # Adversarial ground truths
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Select a random batch of images
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]

        masked_imgs, missing_parts, _ = mask_randomly(
            imgs, img_rows, mask_height, mask_width, channels)

        # Generate the missing patches for this batch
        gen_missing = G.predict(masked_imgs)

        # Train the discriminator on real, then on generated patches
        d_loss_real = D.train_on_batch(missing_parts, valid)
        d_loss_fake = D.train_on_batch(gen_missing, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------

        g_loss = combined.train_on_batch(masked_imgs, [missing_parts, valid])

        # Plot the progress
        print("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[0], g_loss[1]))

        # If at save interval => save generated image samples
        if epoch % sample_interval == 0:
            # sample_images plots exactly 6 columns
            idx = np.random.randint(0, X_train.shape[0], 6)
            imgs = X_train[idx]
            sample_images(G, epoch, imgs, img_rows, mask_height, mask_width, channels)

# main()

In [3]:
# Geometry of the input images (rows, columns, colour channels).
img_rows, img_cols, channels = 32, 32, 3
# Size of the rectangle that is cut out of every image.
mask_height, mask_width = 8, 8
num_classes = 2
# Shapes handed to the generator (full image) and discriminator (patch).
img_shape = (img_rows, img_cols, channels)
missing_shape = (mask_height, mask_width, channels)

In [4]:
# Create the Adam optimizer used when compiling the models below.
# NOTE(review): the two positional arguments are presumably the learning
# rate (0.0002) and beta_1 (0.5) - confirm against the installed Keras version.
optimizer = Adam(0.0002, 0.5)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [8]:
# Build the discriminator for patches of shape `missing_shape` and compile
# it for binary real/generated classification, tracking accuracy.
D = build_discriminator(missing_shape)
D.compile(
    optimizer=optimizer,
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_1 (Conv2D)            (None, 4, 4, 64)          1792      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 4, 4, 64)          0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 4, 4, 64)          256       
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 2, 2, 128)         73856     
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 2, 2, 128)         0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 2, 2, 128)         512       
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 2, 2, 256)        

In [9]:
# Build the generator: it takes a full image of shape `img_shape` and emits
# an output with `channels` channels (build_generator also prints a summary).
G = build_generator(img_shape, channels)

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_4 (Conv2D)            (None, 16, 16, 32)        896       
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 16, 16, 32)        0         
_________________________________________________________________
batch_normalization_4 (Batch (None, 16, 16, 32)        128       
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 8, 8, 64)          18496     
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 8, 8, 64)          0         
_________________________________________________________________
batch_normalization_5 (Batch (None, 8, 8, 64)          256       
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 4, 4, 128)        

In [10]:
# The generator takes a masked image (not noise) as input and generates
# the missing part of the image
masked_img = Input(shape=img_shape)
gen_missing = G(masked_img)

In [11]:
# Freeze the discriminator before it is wired into the combined model, so
# that training the combined model only updates the generator.
# NOTE(review): D was compiled before this flag was cleared, so
# D.train_on_batch presumably still updates D - confirm for the Keras
# version in use.
D.trainable = False

In [12]:
# The discriminator takes the generated patch as input and outputs its
# estimate of whether the patch is real or generated
valid = D(gen_missing)

In [13]:
# Combined model: generator stacked with the discriminator.
# It maps a masked image to [generated patch, discriminator score] and is
# trained with a weighted sum of a reconstruction loss (mse on the patch)
# and an adversarial loss (binary cross-entropy on the score).
combined = Model(masked_img, [gen_missing, valid])
combined.compile(
    optimizer=optimizer,
    loss=['mse', 'binary_crossentropy'],
    loss_weights=[0.999, 0.001],
)

# train

In [None]:
# Number of training steps (one batch each); 30000 is the longer alternative.
epochs = 3000
train(
    D, G, combined,
    img_rows, mask_height, mask_width, channels,
    epochs=epochs,
    batch_size=64,
    sample_interval=50,
)

0 [D loss: 0.070684, acc: 98.44%] [G loss: 0.575176, mse: 0.575215]
1 [D loss: 0.063318, acc: 100.00%] [G loss: 0.557987, mse: 0.558184]
2 [D loss: 0.058033, acc: 100.00%] [G loss: 0.488121, mse: 0.488197]
3 [D loss: 0.085284, acc: 99.22%] [G loss: 0.526732, mse: 0.526857]
4 [D loss: 0.086994, acc: 100.00%] [G loss: 0.481686, mse: 0.481867]
5 [D loss: 0.071761, acc: 100.00%] [G loss: 0.542891, mse: 0.542966]
6 [D loss: 0.051698, acc: 100.00%] [G loss: 0.477299, mse: 0.477404]
7 [D loss: 0.098667, acc: 98.44%] [G loss: 0.465935, mse: 0.465790]
8 [D loss: 0.078381, acc: 98.44%] [G loss: 0.490100, mse: 0.489839]
9 [D loss: 0.059609, acc: 100.00%] [G loss: 0.445786, mse: 0.445683]
10 [D loss: 0.294534, acc: 88.28%] [G loss: 0.392258, mse: 0.391574]
11 [D loss: 0.187543, acc: 94.53%] [G loss: 0.434239, mse: 0.433696]
12 [D loss: 0.281152, acc: 88.28%] [G loss: 0.392676, mse: 0.391662]
13 [D loss: 0.302655, acc: 87.50%] [G loss: 0.397096, mse: 0.396132]
14 [D loss: 0.185801, acc: 92.97%] [G 