# High-level Keras (TF) MNIST Example

In [1]:
import os
import sys
import numpy as np
os.environ['KERAS_BACKEND'] = "tensorflow"
import keras as K
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from common.params import *
from common.utils import *

Using TensorFlow backend.


In [2]:
# Report the library versions and the active Keras backend configuration,
# one value per line.
print(K.__version__,
      np.__version__,
      K.backend.backend(),
      K.backend.image_data_format(),
      sep="\n")

2.0.3
1.11.2
tensorflow
channels_last


In [3]:
def create_lenet(n_classes=None, input_shape=(28, 28, 1)):
    """Build the (uncompiled) LeNet-style CNN used in this example.

    Architecture: two 5x5 tanh convolutions (20 then 50 filters), each followed
    by 2x2 max-pooling with stride 2, then a flattened 500-unit tanh dense
    layer and a softmax output layer.

    Args:
        n_classes: Number of units in the softmax output layer. When ``None``
            (the default) the project-wide ``N_CLASSES`` constant is used.
        input_shape: Shape of a single input sample, without the batch axis.
            The default ``(28, 28, 1)`` is the shape this notebook previously
            hard-coded.

    Returns:
        A ``Sequential`` model; it still has to be compiled before training.
    """
    if n_classes is None:
        # Resolved at call time so that create_lenet() with no arguments
        # behaves exactly as it did before the parameter existed.
        n_classes = N_CLASSES

    model = Sequential()
    # Only the first layer needs the input shape; later shapes are inferred.
    model.add(Conv2D(20, kernel_size=(5, 5), activation='tanh', input_shape=input_shape))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
    model.add(Conv2D(50, (5, 5), activation='tanh'))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
    # Collapse the feature maps into a vector for the dense classifier head.
    model.add(Flatten())
    model.add(Dense(500, activation='tanh'))
    model.add(Dense(n_classes, activation='softmax'))
    return model

In [4]:
def init_model():
    """Create the LeNet network and compile it ready for training.

    The network is compiled with categorical cross-entropy loss, an SGD
    optimizer using the project's ``LR`` and ``MOMENTUM`` settings (no decay,
    no Nesterov momentum), and accuracy as the reported metric.

    Returns:
        The compiled model.
    """
    net = create_lenet()
    # The optimizer is built after the network, as in the original call order.
    sgd = K.optimizers.SGD(lr=LR, momentum=MOMENTUM, decay=0.0, nesterov=False)
    net.compile(
        loss="categorical_crossentropy",
        optimizer=sgd,
        metrics=['accuracy'],
    )
    return net

In [5]:
%%time
# Data into format for library
# Loads the four MNIST arrays through the project helper `mnist_for_library`
# (star-imported from the `common` package). channel_first=False goes with the
# channels_last image data format Keras reported earlier in this notebook.
# NOTE(review): one_hot=True presumably returns one-hot encoded labels, which
# the categorical_crossentropy loss expects -- confirm against the helper.
x_train, x_test, y_train, y_test = mnist_for_library(channel_first=False, one_hot=True)

CPU times: user 248 ms, sys: 200 ms, total: 448 ms
Wall time: 446 ms


In [6]:
%%time
# Initialise model: build the LeNet network and compile it via init_model().
# Kept in its own timed cell so construction cost is reported separately from
# training.
model = init_model()

CPU times: user 120 ms, sys: 8 ms, total: 128 ms
Wall time: 126 ms


In [7]:
%%time
# Train model on the training arrays for EPOCHS passes using mini-batches of
# BATCHSIZE samples; verbose=1 turns on Keras' per-epoch progress logging.
# fit() returns a History object, which is the value this cell displays.
model.fit(x_train,
          y_train,
          batch_size=BATCHSIZE,
          epochs=EPOCHS,
          verbose=1)

Epoch 1/12
Epoch 2/12
Epoch 3/12
Epoch 4/12
Epoch 5/12
Epoch 6/12
Epoch 7/12
Epoch 8/12
Epoch 9/12
Epoch 10/12
Epoch 11/12
Epoch 12/12
CPU times: user 2min 19s, sys: 46.6 s, total: 3min 6s
Wall time: 3min 11s


<keras.callbacks.History at 0x7f8aa8e31390>

In [8]:
%%time
# Test model on the held-out test arrays.
# The model is compiled with metrics=['accuracy'], so evaluate() returns a list
# of [loss, accuracy]. Despite its name, `acc` holds that whole list; the
# accuracy itself is the last entry.
acc = model.evaluate(x_test, y_test, verbose=0)
print("Accuracy ", acc[-1])

Accuracy  0.9893
CPU times: user 1.43 s, sys: 296 ms, total: 1.73 s
Wall time: 1.32 s
