In [1]:
import os

import numpy as np
import tensorflow as tf

# Record the TensorFlow version this notebook was executed with.
print(tf.__version__)

2.13.0


In [2]:
# Load the Fashion-MNIST dataset bundled with Keras. `load_data()` returns a
# (train, test) pair, each of which is an (images, labels) pair.
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

In [3]:
# Append a trailing channel axis: (60000, 28, 28) -> (60000, 28, 28, 1).
# The rank check makes this cell safe to re-run: without it, every extra
# execution would append yet another axis and break the model input shape.
if train_images.ndim == 3:
    train_images = train_images[..., None]
if test_images.ndim == 3:
    test_images = test_images[..., None]

In [4]:
# Scale pixel values by 1/255 using a float32 divisor.
# Only integer-typed arrays are scaled, so re-running this cell does not
# divide already-scaled float data by 255 a second time.
# NOTE(review): assumes load_data() yields integer (uint8) images - confirm.
if np.issubdtype(train_images.dtype, np.integer):
    train_images = train_images / np.float32(255)
if np.issubdtype(test_images.dtype, np.integer):
    test_images = test_images / np.float32(255)

In [5]:
# Create the distribution strategy and report how many replicas it keeps in sync.
strategy = tf.distribute.MirroredStrategy()
print(f'Number of devices:{strategy.num_replicas_in_sync}')

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Number of devices:1


In [6]:
# Set up the input pipeline.

# Seed Python, NumPy and TensorFlow before any dataset or model is created,
# so that shuffling and weight initialisation start from the same random
# state on every Restart & Run All.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Shuffle buffer covers the whole training set.
BUFFER_SIZE = len(train_images)
# Each replica processes this many examples per step; the global batch is
# the per-replica batch multiplied by the number of replicas.
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
EPOCHS = 10

In [7]:
# Create the datasets and distribute them across the strategy's replicas.
# Training data is shuffled and then batched; test data is only batched.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    .shuffle(BUFFER_SIZE)
    .batch(GLOBAL_BATCH_SIZE)
)
test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_images, test_labels))
    .batch(GLOBAL_BATCH_SIZE)
)

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

In [8]:
# Create the model.
def create_model():
    """Build the convolutional classifier used in this notebook.

    The network stacks two Conv2D + MaxPooling2D stages, flattens, and ends
    with a 64-unit ReLU Dense layer followed by a 10-unit Dense layer that
    has no activation. Every layer with a kernel shares one L2 regularizer.

    Returns:
        A `tf.keras.Sequential` model. No input shape is specified here.
    """
    l2_penalty = tf.keras.regularizers.L2(1e-5)  # shared regularizer
    relu_block = dict(activation='relu', kernel_regularizer=l2_penalty)

    net = tf.keras.Sequential()
    net.add(tf.keras.layers.Conv2D(32, 3, **relu_block))
    net.add(tf.keras.layers.MaxPooling2D())
    net.add(tf.keras.layers.Conv2D(64, 3, **relu_block))
    net.add(tf.keras.layers.MaxPooling2D())
    net.add(tf.keras.layers.Flatten())
    net.add(tf.keras.layers.Dense(64, **relu_block))
    # Final layer has no activation.
    net.add(tf.keras.layers.Dense(10, kernel_regularizer=l2_penalty))
    return net

In [9]:
# Directory in which checkpoints are stored, and the path prefix
# (directory + 'ckpt') that is later passed to `checkpoint.save`.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')

In [10]:
# Define the loss function.
with strategy.scope():
    # Reduction is set to `NONE` so the loss object does no reduction itself;
    # the averaging is done explicitly in `compute_loss`.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        reduction=tf.keras.losses.Reduction.NONE,
        from_logits=True)


    # Function that computes the loss.
    def compute_loss(labels, predictions, model_losses):
        """Combine the averaged data loss with the scaled regularization losses.

        Args:
            labels: Labels passed to `loss_object`.
            predictions: Model outputs passed to `loss_object` (treated as
                logits, since the loss is built with `from_logits=True`).
            model_losses: Collection of regularization losses; when it is
                empty no regularization term is added.

        Returns:
            The result of `tf.nn.compute_average_loss` on the unreduced
            loss, plus `tf.nn.scale_regularization_loss` of the summed
            `model_losses` when there are any.
        """
        data_loss = tf.nn.compute_average_loss(loss_object(labels, predictions))
        if not model_losses:
            return data_loss
        penalty = tf.nn.scale_regularization_loss(tf.add_n(model_losses))
        return data_loss + penalty

In [11]:
# Metrics that track loss and accuracy. Three are defined:
# test_loss, train_accuracy and test_accuracy.
with strategy.scope():
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [12]:
# Training setup.
# A model, an optimizer, and a checkpoint must be created under `strategy.scope`.
with strategy.scope():
    # Create the model.
    model = create_model()
    # Create the optimizer.
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    # Create the checkpoint object that tracks the optimizer and the model
    # (this is a `tf.train.Checkpoint`, not a Keras callback).
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

In [13]:
def train_step(inputs):
    """Run one optimisation step on a single batch.

    Args:
        inputs: An `(images, labels)` pair.

    Returns:
        The loss returned by `compute_loss` for this batch.

    Side effects:
        Applies gradients to `model` through `optimizer` and updates the
        `train_accuracy` metric.
    """
    batch_images, batch_labels = inputs

    # Record the forward pass so gradients of the loss can be taken.
    with tf.GradientTape() as tape:
        logits = model(batch_images, training=True)
        step_loss = compute_loss(batch_labels, logits, model.losses)

    weights = model.trainable_variables
    grads = tape.gradient(step_loss, weights)
    optimizer.apply_gradients(zip(grads, weights))

    train_accuracy.update_state(batch_labels, logits)
    return step_loss


def test_step(inputs):
    """Evaluate the model on a single batch and update the test metrics.

    Args:
        inputs: An `(images, labels)` pair.

    Side effects:
        Feeds the unreduced loss from `loss_object` into `test_loss` and
        updates `test_accuracy`. Returns nothing.
    """
    batch_images, batch_labels = inputs

    logits = model(batch_images, training=False)
    batch_loss = loss_object(batch_labels, logits)

    test_loss.update_state(batch_loss)
    test_accuracy.update_state(batch_labels, logits)

In [0]:
# `strategy.run` replicates `train_step` and executes it with the
# distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
    """Run `train_step` through the strategy and sum the per-replica losses.

    Args:
        dataset_inputs: One element drawn from the distributed training
            dataset.

    Returns:
        The per-replica losses reduced with `ReduceOp.SUM` (`axis=None`).
    """
    replica_losses = strategy.run(train_step, args=(dataset_inputs,))
    summed_loss = strategy.reduce(
        tf.distribute.ReduceOp.SUM, replica_losses, axis=None)
    return summed_loss

In [14]:
@tf.function
def distributed_test_step(dataset_inputs):
    """Run `test_step` through the strategy with the distributed input.

    Args:
        dataset_inputs: One element drawn from the distributed test dataset.

    Returns:
        Whatever `strategy.run` returns for `test_step`.
    """
    step_result = strategy.run(test_step, args=(dataset_inputs,))
    return step_result


# Train for EPOCHS epochs; each epoch runs a full pass over the training
# set, evaluates on the test set, optionally checkpoints, and reports metrics.
for epoch in range(EPOCHS):
    # TRAIN LOOP: accumulate the reduced loss of every step and average it
    # over the number of batches seen in this epoch.
    total_loss = 0.0
    num_batches = 0
    for x in train_dist_dataset:
        total_loss += distributed_train_step(x)
        num_batches += 1
    train_loss = total_loss / num_batches

    # TEST LOOP: the test metrics are updated inside `test_step`.
    for x in test_dist_dataset:
        distributed_test_step(x)

    # Save a checkpoint on every other epoch (epoch index 0, 2, 4, ...).
    if epoch % 2 == 0:
        checkpoint.save(checkpoint_prefix)

    template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
                "Test Accuracy: {}")
    print(template.format(epoch + 1, train_loss,
                          train_accuracy.result() * 100, test_loss.result(),
                          test_accuracy.result() * 100))

    # Clear the metrics so the next epoch starts from zero. `reset_state()`
    # is the supported name; `reset_states()` is only a backward-compatibility
    # alias and is absent from newer Keras releases.
    test_loss.reset_state()
    train_accuracy.reset_state()
    test_accuracy.reset_state()

Epoch 1, Loss: 0.5092597603797913, Accuracy: 81.71333312988281, Test Loss: 0.3729873597621918, Test Accuracy: 86.44000244140625
Epoch 2, Loss: 0.34039872884750366, Accuracy: 87.66999816894531, Test Loss: 0.32193049788475037, Test Accuracy: 88.37000274658203
Epoch 3, Loss: 0.29635900259017944, Accuracy: 89.38833618164062, Test Loss: 0.3093103766441345, Test Accuracy: 89.16000366210938
Epoch 4, Loss: 0.2683040201663971, Accuracy: 90.45333099365234, Test Loss: 0.2826783359050751, Test Accuracy: 89.84000396728516
Epoch 5, Loss: 0.24595151841640472, Accuracy: 91.22666931152344, Test Loss: 0.28471988439559937, Test Accuracy: 89.45999908447266
Epoch 6, Loss: 0.22732500731945038, Accuracy: 91.90666961669922, Test Loss: 0.2745018005371094, Test Accuracy: 89.99000549316406
Epoch 7, Loss: 0.21033954620361328, Accuracy: 92.61500549316406, Test Loss: 0.25763633847236633, Test Accuracy: 90.55999755859375
Epoch 8, Loss: 0.19666792452335358, Accuracy: 93.08999633789062, Test Loss: 0.25555938482284546,