# 数据预处理

In [67]:
from keras.datasets import mnist
from keras.utils import np_utils

In [68]:
# Load the MNIST dataset as two (images, labels) pairs: a training split and a test split.
# The shape cells that follow show 60000 training images of 28x28 and 60000 training labels.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

In [69]:
# Inspect the dimensions of the raw training image array.
X_train.shape

(60000, 28, 28)

In [70]:
# Inspect the dimensions of the raw training label array.
y_train.shape

(60000,)

In [71]:
# Reshape every image to a channels-first (1, 28, 28) sample and divide the
# pixel values by 255.
# NOTE: this cell overwrites its own inputs, so running it twice rescales and
# re-encodes already-processed arrays; re-run the loading cell first.
X_train, X_test = (
    images.reshape(-1, 1, 28, 28) / 255.
    for images in (X_train, X_test)
)
# One-hot encode the labels into 10-element class vectors.
y_train, y_test = (
    np_utils.to_categorical(labels, num_classes=10)
    for labels in (y_train, y_test)
)

In [72]:
# Shape of a single preprocessed training sample (channels, height, width).
X_train[0].shape

(1, 28, 28)

In [73]:
# One-hot label vector of the second training sample.
y_train[1]

array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

# 建立神经网络

In [74]:
from keras.models import Sequential
from keras.layers import Dense, Activation, Convolution2D, MaxPooling2D, Flatten
from keras.optimizers import Adam

## 卷积

In [75]:
# Start from an empty layer stack; layers are appended cell by cell below.
model = Sequential()
# First convolution: 32 filters of size 5x5, stride 1, "same" padding, on
# channels-first input of shape (1, 28, 28).
# output shape: (32, 28, 28)
model.add(
    Convolution2D(
        32,
        (5, 5),
        strides=(1, 1),
        padding="same",
        data_format="channels_first",
        batch_input_shape=(None, 1, 28, 28),
    )
)
model.add(Activation("relu"))

## 池化

In [76]:
# 2x2 max pooling with stride 2 halves each spatial dimension.
# output shape: (32, 14, 14)
# data_format is passed explicitly because the layer otherwise defaults to
# channels_last, i.e. (batch, height, width, channels).
model.add(
    MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="same", data_format="channels_first")
)

## 卷积

In [77]:
# Second convolution: 64 filters of size 5x5, stride 1, "same" padding.
# output shape: (64, 14, 14)
model.add(
    Convolution2D(
        filters=64,
        kernel_size=5,
        strides=1,
        padding="same",
        data_format="channels_first",
    )
)
model.add(Activation("relu"))

## 池化

In [78]:
# Second 2x2 max pooling with stride 2 over the two spatial axes.
# output shape: (64, 7, 7)
# data_format must be given explicitly: the preceding layers produce
# channels-first tensors, and without it the layer falls back to the
# channels_last default, pooling over the feature-map axis and one spatial
# axis (yielding (32, 7, 14)) without raising any error.
model.add(MaxPooling2D(
        pool_size=2,
        strides=2,
        padding="same",
        data_format="channels_first"
    ))

## 全连接

In [79]:
# Flatten the pooled feature maps into one vector per sample
# (64 * 7 * 7 = 3136 values), then apply a 1024-unit dense layer with ReLU.
model.add(Flatten())
model.add(Dense(units=1024))
model.add(Activation("relu"))

## 分类

In [80]:
# Output layer: one unit per class (10), turned into class probabilities by softmax.
model.add(Dense(units=10))
model.add(Activation("softmax"))

## 优化器

In [81]:
# Adam optimizer with a learning rate of 0.0001.
# NOTE(review): newer Keras releases name this argument `learning_rate`;
# `lr` is kept to match the legacy Keras API used throughout this notebook.
adam = Adam(lr=0.0001)

# 编译模型

In [82]:
# Configure training: categorical cross-entropy matches the one-hot labels,
# and accuracy is tracked alongside the loss.
model.compile(
    loss="categorical_crossentropy",
    optimizer=adam,
    metrics=["accuracy"],
)

# 训练

In [83]:
# Train for a single pass over the training set in mini-batches of 32 samples.
model.fit(x=X_train, y=y_train, batch_size=32, epochs=1)

Epoch 1/1


<keras.callbacks.History at 0xef147f0>

# 测试模型

In [84]:
# Score the trained model on the test split; with metrics=["accuracy"] the
# result unpacks into the loss followed by the accuracy.
loss, accuracy = model.evaluate(x=X_test, y=y_test)

# Display both values as the cell output.
(loss, accuracy)



(0.08725842625834047, 0.971)