In [1]:
import utils
# Load the preprocessed train/test splits from the project-local `utils` module.
# NOTE(review): the model in this notebook takes (28, 28) inputs and is trained
# with categorical cross-entropy over 10 classes, so labels presumably arrive
# one-hot encoded — confirm against utils.processed_data.
train_split, test_split = utils.processed_data()
train_x, train_y = train_split
test_x, test_y = test_split
del train_split, test_split

In [8]:
import keras
import keras.layers as layers
import keras.models as models
import keras.backend as K


def residual_block(x, filters, strides=1, **kwargs):
    """Constructs a pre-activation residual convolutional block.

    Main path: BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 3x3 conv. Its output is
    added to a shortcut of ``x``. The shortcut is the identity when shapes
    already match, and a 1x1 convolution otherwise.

    Args:
        x: Input Keras tensor. Assumes channels-last layout (the Keras
            default; create_resnet reshapes to (H, W, 1)) — TODO confirm if
            the image data format is ever changed.
        filters: Pair ``(f1, f2)`` with the filter counts of the first and
            second 3x3 convolutions. The block outputs ``f2`` channels.
        strides: Stride of the first 3x3 convolution (and of the projection
            shortcut). Values other than 1 downsample spatially.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        Keras tensor: main path + shortcut, with ``filters[1]`` channels.
    """
    # scale=False: the following ReLU/conv can absorb BN's gamma scaling.
    bn1 = layers.BatchNormalization(scale=False)(x)
    relu1 = layers.Activation("relu")(bn1)
    conv1 = layers.Conv2D(filters[0], 3, strides=strides, padding="same")(relu1)
    bn2 = layers.BatchNormalization(scale=False)(conv1)
    relu2 = layers.Activation("relu")(bn2)
    conv2 = layers.Conv2D(filters[1], 3, padding="same")(relu2)

    # Project the shortcut whenever it cannot be added as-is. That is the case
    # when the spatial size changes (strides != 1) OR the channel count
    # changes. Keras' Add broadcasts size-1 axes, so without the channel check
    # a 1-channel input would be silently tiled across all filters[1]
    # channels instead of raising.
    shortcut = x
    if strides != 1 or K.int_shape(x)[-1] != filters[1]:
        shortcut = layers.Conv2D(filters[1], 1, strides=strides, padding="same")(x)

    return layers.Add()([conv2, shortcut])


def create_resnet(input_shape=(28, 28), num_classes=10):
    """Builds a small pre-activation ResNet classifier for single-channel images.

    Args:
        input_shape: ``(height, width)`` of the input images. A trailing
            channel axis of size 1 is added inside the model.
        num_classes: Number of output classes (width of the softmax layer).

    Returns:
        An uncompiled ``keras.models.Model`` that maps inputs of shape
        ``input_shape`` to ``num_classes`` softmax probabilities.
    """
    inputs = layers.Input(tuple(input_shape))
    reshape = layers.Reshape(tuple(input_shape) + (1,))(inputs)

    # Stem convolution: lifts the 1-channel image to the 8 channels expected
    # by the first block's identity shortcut, instead of letting Add
    # broadcast raw pixels across channels.
    stem = layers.Conv2D(8, 3, padding="same")(reshape)

    res1 = residual_block(stem, (8, 8))
    res2 = residual_block(res1, (8, 8))
    res3 = residual_block(res2, (8, 16), strides=2)
    res4 = residual_block(res3, (16, 16))
    res5 = residual_block(res4, (16, 16))
    res6 = residual_block(res5, (16, 32), strides=2)
    res7 = residual_block(res6, (32, 32))

    # Pool the output of the LAST block. Pooling an earlier block (previously
    # res4) leaves the later blocks disconnected from the output. Keras then
    # omits them from the Model, so they are never trained or used.
    pooling = layers.GlobalMaxPooling2D()(res7)
    outputs = layers.Dense(num_classes, activation="softmax")(pooling)
    return models.Model(inputs=inputs, outputs=outputs)


# Build the network, then train it with SGD (lr=1e-3, other settings default).
model = create_resnet()
model.compile(
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.SGD(lr=1e-3),
    metrics=[keras.metrics.categorical_accuracy],
)

# NOTE(review): newer Keras releases rename SGD's `lr` to `learning_rate`;
# `lr` is kept here to match the Keras version this notebook was run with.
# NOTE(review): the test split doubles as validation data, so the per-epoch
# validation metrics are not an unbiased held-out estimate.
history = model.fit(
    train_x,
    train_y,
    epochs=20,
    batch_size=256,
    validation_data=(test_x, test_y),
    shuffle=True,
    verbose=1,
)

Train on 60000 samples, validate on 10000 samples
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
