In [96]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import time

In [97]:
# Mini-batch size shared by the training and validation pipelines.
batch_size = 64

# Load MNIST and flatten every image into a 784-element row vector.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784)
x_test = x_test.reshape(-1, 784)

# Sanity check: show the shape of each array, one per line.
print(
    *(split.shape for split in (x_train, y_train, x_test, y_test)),
    sep="\n",
)

(60000, 784)
(60000,)
(10000, 784)
(10000,)


In [98]:
# Hold out the last 10,000 training examples as a validation split.
# NOTE: x_train / y_train are rebound here, so re-running this cell without
# first re-running the loading cell trims another 10,000 samples.
x_val, y_val = x_train[-10000:], y_train[-10000:]
x_train, y_train = x_train[:-10000], y_train[:-10000]

# Training pipeline: shuffle through a 1024-element buffer, then batch.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(batch_size)
)

# Validation pipeline: batched only, no shuffling.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size)

In [99]:
class MyModel(tf.keras.Model):
    """Small fully connected classifier that returns raw logits.

    Two 64-unit ReLU hidden layers feed a final Dense layer with one unit per
    class. The final layer has no activation, so no softmax is applied here.
    """

    def __init__(self, name=None, num_classes=10):
        """Create the three Dense layers.

        Args:
            name: Optional model name, forwarded to ``tf.keras.Model``.
            num_classes: Number of units in the output layer.
        """
        super().__init__(name=name)

        hidden_units = 64
        self.dense_1 = tf.keras.layers.Dense(hidden_units, activation='relu')
        self.dense_2 = tf.keras.layers.Dense(hidden_units, activation='relu')
        self.dense_3 = tf.keras.layers.Dense(num_classes)

    def call(self, inputs):
        """Forward pass: apply dense_1, dense_2 and dense_3 in sequence."""
        hidden = self.dense_2(self.dense_1(inputs))
        return self.dense_3(hidden)

In [100]:
# Custom training loop written without @tf.function: the forward pass, the
# gradient computation and the metric updates are issued directly from Python
# on every step.
model = MyModel(num_classes=10)

# Plain SGD; the loss takes integer labels and raw logits (from_logits=True),
# matching MyModel, whose last Dense layer has no activation.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Separate accuracy accumulators for the training and validation passes.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    # Timer starts here and is read after validation, so the reported time
    # covers both the training pass and the validation pass.
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Record the forward pass and loss so they can be differentiated.
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        # Gradients of the batch loss w.r.t. every trainable weight, applied
        # as one optimizer step.
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(y_batch_train, logits)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            # Sample count is derived from the step index and batch_size.
            print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    # NOTE(review): newer Keras releases name this method reset_state();
    # confirm against the installed version.
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    # Read the accumulated validation accuracy, then clear the accumulator so
    # the next epoch starts from zero.
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))

    print("Time taken: %.2fs" % (time.time() - start_time))


Start of epoch 0
Training loss (for one batch) at step 0: 104.1715
Seen so far: 64 samples
Training loss (for one batch) at step 200: 2.3563
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 1.4102
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 1.2952
Seen so far: 38464 samples
Training acc over epoch: 0.6892
Validation acc: 0.8110
Time taken: 6.23s

Start of epoch 1
Training loss (for one batch) at step 0: 0.7798
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.6245
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.6610
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.2747
Seen so far: 38464 samples
Training acc over epoch: 0.8138
Validation acc: 0.8434
Time taken: 6.26s


In [101]:
# New model, optimizer, loss and metric objects for the second run, so it does
# not continue from the objects used by the previous training cell.
model = MyModel(num_classes=10)

# Same hyperparameters as the previous cell: SGD with learning rate 1e-3 and a
# sparse categorical cross-entropy computed from raw logits.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Separate accuracy accumulators for the training and validation passes.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()


@tf.function
def train_step(x, y):
    """Run one optimization step on a single batch and return its loss.

    Uses the notebook-level ``model``, ``loss_fn``, ``optimizer`` and
    ``train_acc_metric`` objects.

    Args:
        x: Batch of inputs fed to ``model``.
        y: Labels for the batch, passed to ``loss_fn`` and the accuracy metric.

    Returns:
        The loss computed by ``loss_fn`` for this batch.
    """
    with tf.GradientTape() as tape:
        predictions = model(x, training=True)  # forward pass
        batch_loss = loss_fn(y, predictions)  # loss for this batch
    # Read the weight list only after the forward pass: the Dense layers are
    # created without an input shape, so presumably their variables do not
    # exist until the model has been called once.
    weights = model.trainable_weights
    gradients = tape.gradient(batch_loss, weights)  # gradient computation
    optimizer.apply_gradients(zip(gradients, weights))  # one optimizer step
    train_acc_metric.update_state(y, predictions)  # accumulate training accuracy
    return batch_loss


@tf.function
def test_step(x, y):
    """Evaluate one batch and fold the result into ``val_acc_metric``.

    Args:
        x: Batch of inputs fed to the notebook-level ``model`` with
            ``training=False``.
        y: Labels for the batch.
    """
    # Forward pass, then accumulate validation accuracy; nothing is returned.
    val_acc_metric.update_state(y, model(x, training=False))


# Same epoch loop as the previous training cell, but each batch is delegated to
# the @tf.function-decorated train_step / test_step instead of being computed
# inline.
epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    # Timer starts here and is read after validation, so the reported time
    # covers both the training pass and the validation pass.
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # train_step applies the update, feeds train_acc_metric and returns
        # the batch loss.
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            # Sample count is derived from the step index and batch_size.
            print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    # NOTE(review): newer Keras releases name this method reset_state();
    # confirm against the installed version.
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        # test_step only updates val_acc_metric; it returns nothing.
        test_step(x_batch_val, y_batch_val)

    # Read the accumulated validation accuracy, then clear the accumulator so
    # the next epoch starts from zero.
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))

    print("Time taken: %.2fs" % (time.time() - start_time))

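# Much faster: with the steps wrapped in @tf.function, each epoch takes well under 2 s here versus roughly 6 s for the previous loop.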


Start of epoch 0
Training loss (for one batch) at step 0: 112.7825
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.6522
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.6289
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.8838
Seen so far: 38464 samples
Training acc over epoch: 0.7054
Validation acc: 0.8164
Time taken: 1.71s

Start of epoch 1
Training loss (for one batch) at step 0: 0.5920
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.8249
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.3632
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 1.0091
Seen so far: 38464 samples
Training acc over epoch: 0.8366
Validation acc: 0.8572
Time taken: 1.34s
