In [1]:
import tensorflow as tf

import numpy as np
import os

print(tf.__version__)

2.2.0


In [2]:
fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

train_images = train_images[..., None]
test_images = test_images[..., None]

train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


In [3]:
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


In [4]:
print(f"Number of devices: {strategy.num_replicas_in_sync}")

Number of devices: 1


In [5]:
BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

In [6]:
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

In [7]:
def create_model():
  model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, activation="relu"),
    tf.keras.layers.MaxPool2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(10)                             
  ])

  return model

In [8]:
checkpoint_dir = "./training/checkpoints"
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

In [9]:
with strategy.scope():
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE
  )
  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

In [10]:
with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name="test_loss")

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="test_accuracy")

In [11]:
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam()

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

In [12]:
def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss

def test_step(inputs):
  images, labels = inputs

  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)

  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)

In [13]:
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))

for epoch in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  for x in test_dist_dataset:
    distributed_test_step(x)
  
  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)
  
  print(f"Epoch {epoch+1}, Loss: {train_loss}, Accuracy: {train_accuracy.result()*100}, Test Loss: {test_loss.result()}, Test Accuracy: {test_accuracy.result()*100}")

  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1, Loss: 0.5160505771636963, Accuracy: 81.47333526611328, Test Loss: 0.40085989236831665, Test Accuracy: 85.12999725341797
INFO:tensorflow:Reduce to /job:

In [14]:
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="eval_accuracy")

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)

In [15]:
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)

In [16]:
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels, in test_dataset:
  eval_step(images, labels)

print(f"Accuracy after restoring the saved model without strategy: {eval_accuracy.result()*100}")

Accuracy after restoring the saved model without strategy: 90.97000122070312


In [19]:
for _ in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)

  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches

  print(f"Epoch {epoch+1}, Loss: {train_loss}, Accuracy: {train_accuracy.result()*100}")
  train_accuracy.reset_states()

Epoch 10, Loss: 0.1558409035205841, Accuracy: 96.09375
Epoch 10, Loss: 0.1558409035205841, Accuracy: 95.15625
Epoch 10, Loss: 0.1558409035205841, Accuracy: 93.59375
Epoch 10, Loss: 0.1558409035205841, Accuracy: 96.25
Epoch 10, Loss: 0.1558409035205841, Accuracy: 95.15625
Epoch 10, Loss: 0.1558409035205841, Accuracy: 93.75
Epoch 10, Loss: 0.1558409035205841, Accuracy: 94.21875
Epoch 10, Loss: 0.1558409035205841, Accuracy: 94.21875
Epoch 10, Loss: 0.1558409035205841, Accuracy: 95.46875
Epoch 10, Loss: 0.1558409035205841, Accuracy: 95.625


In [21]:
@tf.function
def distributed_train_epoch(dataset):
  total_loss = 0.0
  num_batches = 0
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None
    )
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)

for epoch in range(EPOCHS):
  train_loss = distributed_train_epoch(train_dist_dataset)

  print(f"Epoch {epoch+1}, Loss: {train_loss}, Accuracy: {train_accuracy.result()*100}")
  train_accuracy.reset_states()

Epoch 1, Loss: 0.14259059727191925, Accuracy: 94.74666595458984
Epoch 2, Loss: 0.13192428648471832, Accuracy: 95.11000061035156
Epoch 3, Loss: 0.11616019904613495, Accuracy: 95.67833709716797
Epoch 4, Loss: 0.10837100446224213, Accuracy: 95.99833679199219
Epoch 5, Loss: 0.10041636973619461, Accuracy: 96.27833557128906
Epoch 6, Loss: 0.0895528495311737, Accuracy: 96.69499969482422
Epoch 7, Loss: 0.08159716427326202, Accuracy: 97.0
Epoch 8, Loss: 0.07712753862142563, Accuracy: 97.1433334350586
Epoch 9, Loss: 0.07018831372261047, Accuracy: 97.40666961669922
Epoch 10, Loss: 0.06551890075206757, Accuracy: 97.51333618164062
