In [25]:
import tensorflow as tf
from tensorflow import keras
import numpy as np

In [12]:
# Handle to the MNIST dataset module bundled with Keras.
mnist = keras.datasets.mnist

# load_data() hands back two (images, labels) pairs: the training split first,
# then the held-out test split.
train_split, test_split = mnist.load_data()
train_images, train_labels = train_split
test_images, test_labels = test_split

In [13]:
# Scale pixel values into [0, 1] by dividing by 255.
#
# This cell rebinds train_images/test_images to their own scaled versions, so
# running it twice would divide by 255 twice. Guard on the dtype to keep the
# cell idempotent: the arrays returned by mnist.load_data() are integer-typed
# (uint8 per the Keras documentation), while the scaled arrays are floating
# point and are therefore skipped on a re-run.
if np.issubdtype(train_images.dtype, np.integer):
    train_images = train_images / 255.0
if np.issubdtype(test_images.dtype, np.integer):
    test_images = test_images / 255.0

# 1. DNN 模型

In [14]:
# Fully connected baseline classifier, assembled one layer at a time.
model = keras.Sequential()
# Flatten each 28x28 image into a 784-element vector.
model.add(keras.layers.Flatten(input_shape=(28, 28)))
# Hidden fully connected layer: 128 units with ReLU activation.
model.add(keras.layers.Dense(128, activation='relu'))
# Dropout (rate 0.2) to reduce overfitting.
model.add(keras.layers.Dropout(0.2))
# Output layer: softmax probability distribution over the 10 classes.
model.add(keras.layers.Dense(10, activation='softmax'))

In [15]:
# Training configuration for the dense model: Adam optimizer, sparse
# categorical cross-entropy as the loss, and accuracy as the reported metric.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [16]:
# Train the dense model for 5 epochs on the training split. The call stays the
# cell's final expression, so the returned History object is displayed.
# NOTE(review): no random seed is set in the visible cells, so the resulting
# accuracy will vary slightly from run to run.
model.fit(x=train_images, y=train_labels, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7ababfcd6bc0>

In [17]:
# Score the trained dense model on the held-out test split. evaluate() yields
# the loss followed by the metric configured at compile time (accuracy).
test_loss, test_accuracy = model.evaluate(x=test_images, y=test_labels)
print("Test accuracy:", test_accuracy)

Test accuracy: 0.9796000123023987


In [18]:
# Softmax outputs of the dense model for every test image (one row per image,
# one column per class).
# NOTE(review): `predictions` is reassigned by the CNN's predict cell further
# down before anything reads it, so this result is currently unused.
predictions = model.predict(x=test_images)



# 2. CNN 模型

In [19]:
# Convolutional classifier: two conv + max-pool stages followed by a dense head.
model2 = keras.Sequential()

# Stage 1: 32 filters of size 3x3 (ReLU) on 28x28 single-channel input,
# followed by 2x2 max pooling.
model2.add(keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model2.add(keras.layers.MaxPooling2D((2, 2)))

# Stage 2: 64 filters of size 3x3 (ReLU), followed by 2x2 max pooling.
model2.add(keras.layers.Conv2D(64, (3, 3), activation='relu'))
model2.add(keras.layers.MaxPooling2D((2, 2)))

# Classification head: flatten the feature maps, one 64-unit ReLU layer, then
# a softmax over the 10 classes.
model2.add(keras.layers.Flatten())
model2.add(keras.layers.Dense(64, activation='relu'))
model2.add(keras.layers.Dense(10, activation='softmax'))

In [20]:
# Same training configuration as the dense model: Adam optimizer, sparse
# categorical cross-entropy loss, accuracy as the reported metric.
model2.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [21]:
# Train the CNN for 5 epochs on the training split.
#
# The first Conv2D layer is declared with input_shape=(28, 28, 1), but the
# image arrays carry no channel axis (the dense model consumes the same arrays
# as (28, 28)). Append the trailing channel axis explicitly instead of relying
# on Keras to reconcile the rank mismatch on its own.
#
# The call stays the cell's final expression so the History object is displayed.
model2.fit(train_images[..., np.newaxis], train_labels, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7abaa00231f0>

In [22]:
# Score the trained CNN on the held-out test split. evaluate() yields the loss
# followed by the metric configured at compile time (accuracy).
#
# The CNN's first layer expects a trailing channel axis (input_shape is
# (28, 28, 1)); the test images are stored without one, so add it explicitly
# rather than depending on Keras to adjust the input rank.
#
# NOTE(review): test_loss/test_accuracy rebind the names that held the dense
# model's scores, so those earlier values are no longer available after this.
test_loss, test_accuracy = model2.evaluate(test_images[..., np.newaxis], test_labels)
print("Test accuracy:", test_accuracy)

Test accuracy: 0.9890999794006348


In [23]:
# Softmax outputs of the CNN for every test image (one row per image, one
# column per class).
#
# The CNN's first layer expects a trailing channel axis (input_shape is
# (28, 28, 1)); the test images are stored without one, so add it explicitly
# rather than depending on Keras to adjust the input rank.
predictions = model2.predict(test_images[..., np.newaxis])



In [26]:
# Reduce each row of class probabilities to the index of its largest entry,
# which is the predicted class for that image.
predicted_labels = predictions.argmax(axis=1)

In [27]:
# Show the true label next to the predicted label for the first ten test
# images, leaving an empty line after each pair.
for sample_idx in range(10):
    print("True Label:", test_labels[sample_idx])
    print("Predicted Label:", predicted_labels[sample_idx], end="\n\n")

True Label: 7
Predicted Label: 7

True Label: 2
Predicted Label: 2

True Label: 1
Predicted Label: 1

True Label: 0
Predicted Label: 0

True Label: 4
Predicted Label: 4

True Label: 1
Predicted Label: 1

True Label: 4
Predicted Label: 4

True Label: 9
Predicted Label: 9

True Label: 5
Predicted Label: 5

True Label: 9
Predicted Label: 9

