In [1]:
from keras.models import Model
from keras.layers import (
    Input,
    Activation,
    Dense,
    Flatten
)
from keras.layers.convolutional import (
    Conv2D,
    MaxPooling2D,
    AveragePooling2D
)
from keras.layers.merge import add
from keras.layers.normalization import BatchNormalization
from keras.regularizers import l2
from keras import backend as K

Using TensorFlow backend.


In [22]:
# Axis indices of 4-D image tensors depend on the backend's data format.
# Both layouts are handled so the axis constants are always defined
# (previously a non-'tf' ordering left them undefined -> NameError later).
if K.image_data_format() == 'channels_last':
    ROW_AXIS = 1
    COL_AXIS = 2
    CHANNEL_AXIS = 3
    input_shape = (32, 32, 3)
else:  # 'channels_first'
    CHANNEL_AXIS = 1
    ROW_AXIS = 2
    COL_AXIS = 3
    input_shape = (3, 32, 32)
# Used as the unit count of the final softmax Dense layer (number of classes).
output_shape = 100

In [32]:
def _conv_bn_relu(filters, kernel_size=(3,3), strides=(1, 1)):
    """
    conv -> BN -> relu
    """
    def f(inputs):
        conv = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
                      padding='same', kernel_initializer="he_normal", kernel_regularizer=l2(1e-4))(inputs)
        norm = BatchNormalization(axis=CHANNEL_AXIS)(conv)
        return Activation("relu")(norm)
    return f

def _bn_relu_conv(filters, kernel_size=(3,3), strides=(1, 1)):
    """
    BN -> relu -> conv
    """
    def f(inputs):
        norm = BatchNormalization(mode=0, axis=CHANNEL_AXIS)(inputs)
        activation = Activation("relu")(norm)
        return Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
                      padding='same', kernel_initializer="he_normal", kernel_regularizer=l2(1e-4))(activation)
    return f

def _basic_block(filters, init_strides=(1, 1), is_first_block_of_first_layer=False):
    """
    basic residual block : 3*3 kernel
    """
    def f(inputs):
        if is_first_block_of_first_layer:
            # don't repeat bn->relu since we just did bn->relu->maxpool
            conv1 = Conv2D(filters=filters, kernel_size=(3, 3),
                           strides=init_strides,
                           padding="same",
                           kernel_initializer="he_normal",
                           kernel_regularizer=l2(1e-4))(inputs)
        else:
            conv1 = _bn_relu_conv(filters, (3, 3), strides=init_strides)(inputs)
        residual = _bn_relu_conv(filters, (3, 3))(conv1)
        return _shortcut(inputs, residual)
    return f

def bottleneck(filters, init_strides=(1, 1), is_first_block_of_first_layer=False):
    """Build a pre-activation bottleneck residual unit.

    Branch: 1x1 conv (``filters``, carries ``init_strides``) -> 3x3 conv
    (``filters``) -> 1x1 conv (``4 * filters``), merged with the input by
    ``_shortcut``.

    Args:
        filters: width of the first two convolutions; the last 1x1 conv
            expands to ``4 * filters`` channels.
        init_strides: strides of the first 1x1 convolution.
        is_first_block_of_first_layer: when True, the first convolution is a
            plain Conv2D with no BN/ReLU in front of it.

    Returns:
        A callable mapping an input tensor to ``_shortcut(input, branch)``.
    """
    def unit(x):
        if not is_first_block_of_first_layer:
            out = _bn_relu_conv(filters=filters, kernel_size=(1, 1),
                                strides=init_strides)(x)
        else:
            # The stem just ran BN -> ReLU -> max-pool, so skip a second BN/ReLU.
            out = Conv2D(filters=filters, kernel_size=(1, 1),
                         strides=init_strides,
                         padding="same",
                         kernel_initializer="he_normal",
                         kernel_regularizer=l2(1e-4))(x)
        # 3x3 at the reduced width, then 1x1 expansion to 4x the width.
        for width, size in ((filters, (3, 3)), (filters * 4, (1, 1))):
            out = _bn_relu_conv(filters=width, kernel_size=size)(out)
        return _shortcut(x, out)
    return unit

def _shortcut(inputs, residual):
    input_shape = K.int_shape(inputs)
    residual_shape = K.int_shape(residual)
    stride_width = int(round(input_shape[ROW_AXIS] / residual_shape[ROW_AXIS]))
    stride_height = int(round(input_shape[COL_AXIS] / residual_shape[COL_AXIS]))
    equal_channels = residual._keras_shape[CHANNEL_AXIS] == inputs._keras_shape[CHANNEL_AXIS]

    shortcut = inputs
    if stride_width > 1 or stride_height > 1 or not equal_channels:
        shortcut = Conv2D(filters=residual._keras_shape[CHANNEL_AXIS], kernel_size=(1, 1), strides=(stride_width, stride_height),
                          padding="valid", kernel_initializer="he_normal")(inputs)

    return add([shortcut, residual])

def _residual_block(block_function, filters, repetitions, is_first_layer=False):
    def f(inputs):
        for i in range(repetitions):
            init_strides = (1, 1)
            if i == 0 and not is_first_layer:
                init_strides = (2, 2)
            inputs = block_function(filters=filters, init_strides=init_strides,
                                    is_first_block_of_first_layer=(is_first_layer and i == 0))(inputs)
        return inputs
    return f

In [33]:
inputs = Input(shape=input_shape)
# Stem: 7x7 stride-2 conv -> BN -> ReLU, then 3x3 stride-2 max-pool.
conv1 = _conv_bn_relu(64, (7, 7), (2, 2))(inputs)
pool1 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding="same")(conv1)

# Residual blocks: four stages of two basic blocks each (the ResNet-18 layout).
# Every stage after the first starts with a stride-2 unit.
block_fn = _basic_block
block1 = _residual_block(block_fn, filters=64, repetitions=2, is_first_layer=True)(pool1)
block2 = _residual_block(block_fn, filters=128, repetitions=2)(block1)
block3 = _residual_block(block_fn, filters=256, repetitions=2)(block2)
block4 = _residual_block(block_fn, filters=512, repetitions=2)(block3)

# Classifier block: global average pooling over block4's actual spatial size.
# A fixed 7x7 window with stride 1 and padding='same' preserves block4's
# spatial size, so it only acts as a global pool when block4 is already 1x1
# (as it happens to be for a 32x32 input); e.g. a 64x64 input would leave a
# 2x2 map and Flatten would emit 2048 features instead of 512.
block4_shape = K.int_shape(block4)
pool2 = AveragePooling2D(pool_size=(block4_shape[ROW_AXIS], block4_shape[COL_AXIS]),
                         strides=(1, 1))(block4)
flatten1 = Flatten()(pool2)
dense = Dense(units=output_shape, kernel_initializer="he_normal", activation="softmax")(flatten1)

model = Model(inputs=inputs, outputs=dense)

