In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
import os
import sys

import pandas as pd

from keras.models import Model
import keras.backend as K
from keras.layers import Input, Dense, Conv2D, Add, Activation
from keras.callbacks import LearningRateScheduler

from keras.datasets import mnist, cifar10

Using TensorFlow backend.


In [3]:
def obtain(dir_path):
    """
    Downloads the binarized MNIST dataset to ``dir_path``.

    The train, validation and test ``.amat`` files are fetched one after the
    other and stored in ``dir_path`` under their original file names.

    Params
    ------

    dir_path: string
        Destination directory. A leading ``~`` is expanded and the directory
        is created when it does not exist yet.
    """
    # ``urlretrieve`` lives in ``urllib.request`` on Python 3 and directly in
    # ``urllib`` on Python 2; resolve it here so the function runs on both.
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve

    base_url = 'http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/'

    dir_path = os.path.expanduser(dir_path)
    if not os.path.isdir(dir_path):
        os.makedirs(dir_path)

    print('Downloading the dataset')
    for split in ('train', 'valid', 'test'):
        file_name = 'binarized_mnist_{}.amat'.format(split)
        urlretrieve(base_url + file_name, os.path.join(dir_path, file_name))

    print('Done                     ')

# Model specifications

## Custom Layers

In [4]:
class MaskedConv2D(Conv2D):
    """
    Masked Convolution from [1]. Contains the same implementation of Conv2D from keras, but
    allows one to specify whether the mask type is 'A', 'B', or None.

    The mask is a 0/1 array with the shape of the kernel that is multiplied
    element-wise with the kernel before the convolution. It zeroes every tap
    to the right of the centre in the centre row and every tap in the rows
    below the centre.

    Params
    ------

    filters, kernel_size, **kwargs:
        Forwarded unchanged to ``Conv2D``. Both kernel dimensions must be odd.

    padding: string, default='same'
        Forwarded to ``Conv2D``.

    mask_type: string, default=None
        Determines the masking type for the convolution from [1].
        'A' additionally zeroes the centre tap. Any other non-None value
        keeps the centre tap (type 'B'). ``None`` disables masking and the
        layer behaves like a plain ``Conv2D``.

    mask_rgb: bool, default=True
        When true, the centre tap of output filter ``k`` (``k`` = 0, 1, 2) is
        zeroed for every input channel with index greater than ``k``.
        NOTE(review): only the first three output filters are constrained this
        way; the remaining filters keep the full centre tap. Confirm this is
        the intended channel ordering scheme.

    References
    ----------

    [1] https://arxiv.org/pdf/1601.06759.pdf
    """

    def __init__(self, filters, kernel_size, padding='same', mask_type=None, mask_rgb=True, **kwargs):
        super(MaskedConv2D, self).__init__(filters, kernel_size, padding=padding, **kwargs)
        self.mask_type = mask_type
        self.mask_rgb = mask_rgb

    def build(self, input_shape):
        """Creates the Conv2D weights and the kernel-shaped 0/1 mask."""
        super(MaskedConv2D, self).build(input_shape)

        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[channel_axis]
        kernel_shape = self.kernel_size + (input_dim, self.filters)

        # assert that the kernel size is odd, so that a centre tap exists
        assert self.kernel_size[0] % 2 == 1
        assert self.kernel_size[1] % 2 == 1

        center_row = self.kernel_size[0] // 2
        center_col = self.kernel_size[1] // 2
        mask = np.ones(kernel_shape)

        # mask out values right of center (in the center row)
        mask[center_row, center_col + 1:, :, :] = 0

        # mask out values below center
        mask[center_row + 1:, :, :, :] = 0

        # mask out center if masking type is 'A'
        # (row AND column index: the two differ for non-square kernels)
        if self.mask_type == 'A':
            mask[center_row, center_col, :, :] = 0

        # mask RGB channels: at the center tap, output filter k must not read
        # input channels k+1 and above. Bounded by the number of filters so
        # that layers with fewer than three filters do not index out of range.
        if self.mask_rgb:
            for channel in range(min(3, input_dim, self.filters)):
                mask[center_row, center_col, channel + 1:, channel] = 0

        self.mask = K.variable(mask)

    def call(self, inputs):
        """Applies the convolution with the masked kernel, then bias and activation."""
        if self.mask_type is None:
            # No masking requested: defer entirely to the stock Conv2D.
            return super(MaskedConv2D, self).call(inputs)
        outputs = K.conv2d(
            inputs,
            self.kernel * self.mask,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate
        )
        if self.use_bias:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format
            )
        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def get_config(self):
        """Adds ``mask_type`` and ``mask_rgb`` to the Conv2D config so they are kept when the layer is serialized."""
        config = super(MaskedConv2D, self).get_config()
        config['mask_type'] = self.mask_type
        config['mask_rgb'] = self.mask_rgb
        return config

In [5]:
class ResidualBlock(object):
    """
    Callable that appends one residual block to a tensor.

    The residual branch is: relu -> 1x1 Conv2D (``filters // 2``, relu) ->
    3x3 MaskedConv2D of type 'B' (``filters // 2``, relu) -> 1x1 Conv2D
    (``filters``). Its output is added to the block input.

    Params
    ------

    filters: int
        Number of filters of the last 1x1 convolution; the two inner
        convolutions use ``filters // 2``.

    mask_rgb: bool, default=True
        Forwarded to the inner ``MaskedConv2D``.
    """

    def __init__(self, filters, mask_rgb=True):
        self.filters, self.mask_rgb = filters, mask_rgb

    def __call__(self, model):
        half = self.filters // 2
        branch_layers = (
            Activation('relu'),
            Conv2D(half, 1, activation='relu'),
            MaskedConv2D(half, 3, mask_type='B', mask_rgb=self.mask_rgb, activation='relu'),
            Conv2D(self.filters, 1),
        )

        # Thread the input through the residual branch, layer by layer.
        branch = model
        for layer in branch_layers:
            branch = layer(branch)

        # Skip connection: input + residual branch.
        return Add()([model, branch])

## Models

In [16]:
def pixel_cnn(input_shape, nb_features, nb_blocks, mask_rgb=True):
    """
    Assembles the PixelCNN as a Keras ``Model``.

    Layout: 7x7 ``MaskedConv2D`` of type 'A' -> ``nb_blocks`` residual blocks
    -> two (relu, 1x1 ``Conv2D``) pairs with ``input_shape[-1]`` filters ->
    sigmoid.

    Params
    ------

    input_shape: tuple
        Shape passed to ``Input``; its last entry sets the number of output
        filters of the final 1x1 convolutions.

    nb_features: int
        Number of filters of the first masked convolution and of every
        residual block.

    nb_blocks: int
        Number of residual blocks.

    mask_rgb: bool, default=True
        Forwarded to every masked layer.
    """
    inputs = Input(shape=input_shape)

    # Type 'A' entry layer, followed by the residual stack.
    net = MaskedConv2D(nb_features, 7, mask_type='A', mask_rgb=mask_rgb)(inputs)
    for _ in range(nb_blocks):
        net = ResidualBlock(nb_features, mask_rgb=mask_rgb)(net)

    # Output head: two relu + 1x1 convolution pairs, then a sigmoid.
    nb_outputs = input_shape[-1]
    for _ in range(2):
        net = Activation('relu')(net)
        net = Conv2D(nb_outputs, 1)(net)
    net = Activation('sigmoid')(net)

    return Model(inputs, net)

# Data

In [17]:
def _load_stochastic_mnist(nb_channels, seed=None):
    """
    Loads MNIST and stochastically binarizes the images.

    Each image gets a trailing channel axis that is repeated ``nb_channels``
    times. Pixel values are scaled by 1/255 and every element is then set to
    1 when its scaled value exceeds an independent uniform draw from [0, 1),
    and to 0 otherwise. Labels are returned untouched.

    Params
    ------

    nb_channels: int
        Size of the trailing channel axis of the returned images.

    seed: int, default=None
        Seed of a private ``RandomState`` used for the uniform draws. With
        ``None`` the global ``np.random`` state is used.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)

    def binarize(images):
        images = np.expand_dims(images, -1)
        if nb_channels > 1:
            # Repeat before thresholding so each channel gets its own draws.
            images = np.repeat(images, nb_channels, axis=-1)
        intensities = images / 255.
        return (intensities > rng.rand(*intensities.shape)).astype(np.int32)

    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # Train is drawn before test, so a given seed always yields the same split.
    x_train = binarize(x_train)
    x_test = binarize(x_test)

    return (x_train, y_train), (x_test, y_test)


def load_binary_mnist(seed=None):
    """
    Returns stochastically binarized MNIST with a single channel.

    Params
    ------

    seed: int, default=None
        Optional seed that makes the binarization reproducible.

    Returns ``(x_train, y_train), (x_test, y_test)``; the images are int32
    arrays of zeros and ones with a trailing axis of size 1.
    """
    return _load_stochastic_mnist(1, seed=seed)


def load_color_mnist(seed=None):
    """
    Returns stochastically binarized MNIST with three channels.

    The grey-scale image is copied into three channels and each channel is
    binarized with its own random draws.

    Params
    ------

    seed: int, default=None
        Optional seed that makes the binarization reproducible.

    Returns ``(x_train, y_train), (x_test, y_test)``; the images are int32
    arrays of zeros and ones with a trailing axis of size 3.
    """
    return _load_stochastic_mnist(3, seed=seed)

# Training

In [18]:
# Three-channel, stochastically binarized MNIST.
(x_train, y_train), (x_test, y_test) = load_color_mnist()

# Shape of a single sample: everything after the leading sample axis.
input_shape = x_train.shape[1:]

# 16 features per layer, 12 residual blocks.
model = pixel_cnn(input_shape, nb_features=16, nb_blocks=12)
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            (None, 28, 28, 3)    0                                            
__________________________________________________________________________________________________
masked_conv2d_27 (MaskedConv2D) (None, 28, 28, 16)   2368        input_3[0][0]                    
__________________________________________________________________________________________________
activation_31 (Activation)      (None, 28, 28, 16)   0           masked_conv2d_27[0][0]           
__________________________________________________________________________________________________
conv2d_53 (Conv2D)              (None, 28, 28, 8)    136         activation_31[0][0]              
__________________________________________________________________________________________________
masked_con

In [None]:
def pixel_cnn_loss(y_true, y_pred):
    """
    Binary cross-entropy between ``y_true`` and ``y_pred``, summed over the
    last three axes.

    For the 4-D image batches used in this notebook that leaves one loss
    value per sample (assumes the leading axis is the sample axis).
    """
    per_element = K.binary_crossentropy(y_true, y_pred)
    return K.sum(per_element, axis=(-1, -2, -3))

# The learning rate is lowered each time the epoch index reaches one of the
# boundaries in ``regions``: 300 plus the running sum of 3**0 .. 3**nb_regions.
nb_regions = 5
regions = np.cumsum([3**i for i in range(nb_regions+1)]) + 300


def schedule(i, lr):
    """
    Learning rate for epoch ``i``.

    Starts at 0.001 and is scaled by ``10**(-1/nb_regions)`` for every
    boundary in ``regions`` that is less than or equal to ``i``. The current
    rate ``lr`` is ignored.
    """
    regions_passed = np.sum(regions <= i)
    return 0.001 * 10**(-regions_passed/nb_regions)


lr_schedule = LearningRateScheduler(schedule)

# Train until the last boundary has been reached.
epochs = regions[-1]
batch_size = 1000

model.compile(loss=pixel_cnn_loss, optimizer='adam')
# The images are both input and target.
model.fit(
    x_train,
    x_train,
    batch_size=batch_size,
    epochs=epochs,
    callbacks=[lr_schedule],
)

In [None]:
# Copy the first 100 training images and blank everything from row 14 down,
# so only the upper half of each image remains.
x_half = x_train[:100].copy()
x_half[:, 14:] = 0

# Show the first channel of image 1.
plt.imshow(x_half[1, :, :, 0])

In [None]:
# 100 copies of the same half-blanked image (index 1); each copy is completed
# independently below.
x_pred = np.repeat(x_half[1][np.newaxis], 100, axis=0)

# Fill in rows 14..27 one pixel at a time, in raster order. The model is
# re-evaluated before every pixel so it sees the pixels sampled so far.
for i in range(14, 28):
    for j in range(28):
        pred = model.predict(x_pred)
        probs = pred[:, i, j]
        noise = np.random.rand(*probs.shape)
        # Set a value to 1 when its predicted probability exceeds the draw.
        x_pred[:, i, j] = probs > noise

In [None]:
# 5x5 grid of completed samples, first channel only, without tick marks.
fig, axes = plt.subplots(5, 5, dpi=300, figsize=(3, 3), constrained_layout=True)

for i, ax in enumerate(axes):
    for j, a in enumerate(ax):
        # Row i of the grid shows samples 10*i .. 10*i + 4.
        sample = x_pred[i*10 + j]
        a.imshow(sample[:, :, 0])
        a.set_xticks(())
        a.set_yticks(())

In [130]:
np.repeat?