In [65]:
import tensorflow as tf

from tensorflow.keras.layers import Flatten, Dense, Conv2D
from tensorflow.keras import Model

In [66]:
# dataset: MNIST digits, pixel values scaled into [0, 1]
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Cast before scaling: dividing the raw uint8 arrays by 255.0 would yield float64
# arrays (twice the memory of float32), while the Keras layers compute in float32.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# add a trailing channels axis (Conv2D expects channels-last input by default)
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

SHUFFLE_BUFFER = 10000  # number of examples held in the shuffle buffer
BATCH_SIZE = 32

train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(SHUFFLE_BUFFER)
    .batch(BATCH_SIZE)
)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

In [67]:
# model
class MyModel(Model):
    """Small convolutional classifier, used in this notebook for MNIST digits.

    Architecture: Conv2D(32 filters, 3x3 kernel, ReLU) -> Flatten
    -> Dense(128, ReLU) -> Dense(10, softmax).

    The final softmax means the model outputs per-class probabilities, not
    logits, so it must be paired with a loss that uses ``from_logits=False``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        """Run the forward pass on a batch of images.

        :param x: batch of channels-last images; in this notebook presumably
            MNIST digits of shape (batch, 28, 28, 1) from ``train_ds`` /
            ``test_ds`` -- TODO confirm if reused elsewhere.
        :return: tensor of shape (batch, 10) holding softmax class probabilities.
        """
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)


model = MyModel()

In [68]:
# loss and optimizer
# Integer labels + probability outputs (the model ends in softmax), so the loss
# consumes probabilities; from_logits=False is the Keras default, made explicit.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
# Adam with its default hyperparameters.
optimizer = tf.keras.optimizers.Adam()

In [69]:
# metrics
# Streaming metrics: accumulated batch by batch, reset by the epoch loop.
# Training split: mean of the per-batch losses, and classification accuracy.
train_loss, train_accuracy = (
    tf.keras.metrics.Mean(name='train_loss'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy'),
)

# Test split: the same pair of metrics, tracked independently.
test_loss, test_accuracy = (
    tf.keras.metrics.Mean(name='test_loss'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy'),
)

In [70]:
# fit
@tf.function
def train_step(images, labels):
    """Run one optimization step on a single batch and update training metrics.

    Reads the notebook-level ``model``, ``loss_fn`` and ``optimizer``, and
    updates ``train_loss`` / ``train_accuracy``.

    :param images: one batch of input images (from ``train_ds``).
    :param labels: the matching batch of integer class labels.
    """
    with tf.GradientTape() as tape:
        # training=True makes layers that behave differently in training vs.
        # inference (e.g. Dropout, BatchNormalization) use training mode if
        # they are ever added to the model; it is a no-op for the current layers.
        predictions = model(images, training=True)
        loss = loss_fn(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)

In [71]:
# evaluate
@tf.function
def test_step(images, labels):
    """Evaluate the model on a single batch and update the test metrics.

    No gradients are computed. Reads the notebook-level ``model`` and
    ``loss_fn``, and updates ``test_loss`` / ``test_accuracy``.

    :param images: one batch of input images (from ``test_ds``).
    :param labels: the matching batch of integer class labels.
    """
    # training=False: always evaluate in inference mode (matters once layers
    # such as Dropout or BatchNormalization are present).
    predictions = model(images, training=False)
    loss = loss_fn(labels, predictions)

    test_loss(loss)
    test_accuracy(labels, predictions)

In [72]:
EPOCHS = 5

for epoch in range(EPOCHS):
    # Reset the metrics at the start of every epoch so each printed line covers
    # only that epoch. `reset_state()` is the current API; `reset_states()` is
    # a deprecated alias.
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        metric.reset_state()

    for images, labels in train_ds:
        train_step(images, labels)
    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)

    # Accuracies are fractions in [0, 1]; scale to percent for display.
    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch + 1,
                          train_loss.result(),
                          train_accuracy.result() * 100,
                          test_loss.result(),
                          test_accuracy.result() * 100))

2021-08-09 17:26:42.221842: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
2021-08-09 17:26:42.226110: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2021-08-09 17:26:42.226187: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-08-09 17:26:58.613549: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-08-09 17:27:01.069550: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.


Epoch 1, Loss: 0.1395365297794342, Accuracy: 95.94833374023438, Test Loss: 0.0636647567152977, Test Accuracy: 97.91000366210938
Epoch 2, Loss: 0.04483456164598465, Accuracy: 98.61500549316406, Test Loss: 0.04852450639009476, Test Accuracy: 98.36000061035156
Epoch 3, Loss: 0.0247790589928627, Accuracy: 99.16666412353516, Test Loss: 0.04836534708738327, Test Accuracy: 98.5300064086914
Epoch 4, Loss: 0.013457763008773327, Accuracy: 99.57500457763672, Test Loss: 0.05110668018460274, Test Accuracy: 98.52000427246094
Epoch 5, Loss: 0.008318557403981686, Accuracy: 99.73833465576172, Test Loss: 0.06929241120815277, Test Accuracy: 98.2800064086914
