In [1]:
import tensorflow as tf
import numpy as np

## Define the distribution strategy ##

In [2]:
# Synchronous data-parallel strategy for the devices on this machine
# (the logged run below reports a single GPU replica).
strategy = tf.distribute.MirroredStrategy()
# Number of replicas that take part in each training step; used later to size the global batch.
NUM_WORKERS = strategy.num_replicas_in_sync

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


## Download and Preprocess the Data ##

In [3]:
# Fetch the Fashion-MNIST train/test split as (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# Divide the pixel values by 255 (as float32), then append a trailing channel
# axis so that Conv2D receives inputs shaped (28, 28, 1).
x_train = tf.expand_dims(x_train / np.float32(255), axis=-1)
x_test = tf.expand_dims(x_test / np.float32(255), axis=-1)

# Sanity-check the resulting shapes.
print(x_train.shape, x_test.shape)
print(y_train.shape, y_test.shape)

(60000, 28, 28, 1) (10000, 28, 28, 1)
(60000,) (10000,)


In [4]:
# Shuffle buffer covers one tenth of the training set (6,000 of the 60,000 examples),
# so the shuffle is approximate rather than a full permutation.
BUFFER_SIZE = len(x_train) // 10 
# Number of examples each replica processes per step.
BATCH_SIZE_PER_WORKER = 64 
# Batch size used by the tf.data pipeline: one per-worker batch for every replica.
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_WORKER * NUM_WORKERS

In [5]:
# Wrap the in-memory tensors in tf.data pipelines.
# Training data is shuffled before being grouped into global batches.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(GLOBAL_BATCH_SIZE)

# Test data keeps its original order; it is only batched.
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.batch(GLOBAL_BATCH_SIZE)

**When working with multiple GPUs, it is important to convert the dataset into a distributed dataset. This way, TensorFlow can perform data parallelism by splitting each batch of this dataset across the workers.**

In [6]:
# Hand both pipelines to the strategy so each global batch is distributed across its replicas.
# The resulting objects are what the training and evaluation loops iterate over.
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

## Define Simple Sequential Model (Nothing New Here) ##  

In [7]:
#This function has to be called from within the scope of the strategy object.

def build_model():
  """Build the (uncompiled) CNN classifier used in this notebook.

  Architecture: a 3x3 convolution with 64 filters, 2x2 max-pooling, a
  128-unit ReLU dense layer, and a 10-way softmax output. Inputs are
  expected to have shape (28, 28, 1).

  Call this inside `strategy.scope()` so that the model's variables are
  created under the distribution strategy.

  Returns:
    A `tf.keras.Sequential` model.
  """
  layers = tf.keras.layers
  return tf.keras.Sequential([
      layers.Conv2D(
          filters=64,
          kernel_size=3,
          activation="relu",
          padding="same",
          input_shape=(28, 28, 1),
      ),
      layers.MaxPooling2D(2),
      layers.Flatten(),
      layers.Dense(128, activation="relu"),
      layers.Dense(10, activation="softmax"),
  ])

## Custom Training Loop ## 

-Here we set up the loss, metrics, optimizer and model by hand instead of simply calling model.compile(...).

In [32]:
with strategy.scope():
  # Reduction is switched off so the loss object returns one value per example.
  # In a custom training loop the reduction has to be done explicitly (see
  # compute_loss); Keras only takes care of it for you in model.compile()/fit().
  criterion = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)

  def compute_loss(y_true, y_pred):
    """Return this replica's share of the global-batch average loss.

    The per-example losses are summed and divided by GLOBAL_BATCH_SIZE (the
    number of examples per step across ALL replicas), not by the number of
    examples this replica happened to receive. Summing the values returned by
    every replica (as distributed_train_step does) therefore yields the mean
    loss over the whole global batch, and the summed gradients are scaled
    correctly even when replicas receive different numbers of examples.

    Args:
      y_true: integer class labels for the examples on this replica.
      y_pred: predicted class probabilities for the same examples.

    Returns:
      A scalar tensor: sum(per-example loss) / GLOBAL_BATCH_SIZE.
    """
    # One loss value per example, because the criterion does not reduce.
    per_example_loss = criterion(y_true, y_pred)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

  # NOTE: everything below must live at the `with` level. Indented one level
  # deeper it would sit after compute_loss's `return` and never be executed,
  # leaving the metrics, optimizer and model undefined.

  # Running mean of the test loss over an evaluation pass.
  test_loss = tf.keras.metrics.Mean(name="test_loss")

  # Metrics.
  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="test_accuracy")

  # Optimizer.
  optimizer = tf.keras.optimizers.Adam()

  # Create the model within the scope so its variables are distributed.
  model = build_model()

## Train and Test Step Functions ##

-Notice that these functions are decorated with @tf.function. This way, AutoGraph will convert them into TensorFlow graphs and they will be run in graph mode, potentially providing a dramatic performance increase. 

In [33]:
'''
Notice that this function is doing nothing but calling the train_step function in the strategy.run()
function. strategy.run() allows for distributing a workload according to a given strategy (of which 'strategy' is an object)
'''

@tf.function
def distributed_train_step(dataset_inputs):
  """Run one training step on every replica and return the combined loss.

  Args:
    dataset_inputs: one element drawn from the distributed training dataset.

  Returns:
    The sum over replicas of the values returned by `train_step`.
  """
  # strategy.run() executes train_step according to the strategy and hands
  # back the value each replica returned.
  replica_losses = strategy.run(train_step, args=(dataset_inputs,))

  # strategy.reduce() with SUM adds up the loss reported by each replica.
  return strategy.reduce(tf.distribute.ReduceOp.SUM, replica_losses, axis=None)


'''
Notice that there's no need to decorate this function with @tf.function.
This is because it will be called in a function that is decorated (right above).
Thus, train_step itself will be executed in graph mode.
'''
def train_step(inputs, LR=0.001):
  """Run one optimisation step on a single replica.

  Args:
    inputs: an `(images, labels)` pair for this replica.
    LR: unused by this function; kept only so existing callers keep working.
      The step size is whatever the module-level `optimizer` was built with.

  Returns:
    The value of `compute_loss` for this replica's examples.
  """
  images, labels = inputs
  with tf.GradientTape() as tape:
    # training=True so that layers which behave differently during training
    # (e.g. Dropout, BatchNormalization) run in training mode; this mirrors
    # the explicit training=False used in test_step.
    y_preds = model(images, training=True)
    batch_loss = compute_loss(labels, y_preds)

  gradients = tape.gradient(batch_loss, model.trainable_weights)
  optimizer.apply_gradients(zip(gradients, model.trainable_weights))
  train_accuracy.update_state(labels, y_preds)
  return batch_loss


@tf.function
def distributed_test_step(dataset_inputs):
  """Run one evaluation step on every replica.

  `test_step` returns nothing; it only updates the `test_loss` and
  `test_accuracy` metrics, so there is no result to collect or reduce here.

  Args:
    dataset_inputs: one element drawn from the distributed test dataset.
  """
  strategy.run(test_step, args=(dataset_inputs,))


def test_step(inputs):
  """Evaluate a single replica's examples and update the test metrics.

  Args:
    inputs: an `(images, labels)` pair for this replica.

  Returns:
    None. Results are accumulated in `test_loss` and `test_accuracy`.
  """
  images, labels = inputs

  # training=False is essential here: otherwise layers that should only be
  # active during training (e.g. Dropout, BatchNormalization) could be active
  # during testing and contaminate the evaluation.
  y_preds = model(images, training=False)

  # Feed the unreduced per-example losses to the Mean metric. Using
  # compute_loss() here would record one scalar already divided by
  # GLOBAL_BATCH_SIZE per replica, so the mean of those scalars would
  # understate the loss with more than one replica and mis-weight the
  # final, smaller batch.
  per_example_loss = criterion(labels, y_preds)
  test_loss.update_state(per_example_loss)
  test_accuracy.update_state(labels, y_preds)




## Finally, Training Loop ## 

In [None]:
EPOCHS = 10
template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, " "Test Accuracy: {}")


for epoch in range(EPOCHS):
  # Training pass: accumulate the reduced loss of every global batch.
  # `step` ends up as the number of batches in the distributed dataset.
  step = 0
  epoch_loss = 0.0
  for batch in train_dist_dataset:
    epoch_loss += distributed_train_step(batch)
    step += 1
  train_loss = epoch_loss / step

  # Evaluation pass: the test metrics are updated as a side effect.
  for batch in test_dist_dataset:
    distributed_test_step(batch)

  report = template.format(
      epoch + 1,
      train_loss,
      train_accuracy.result() * 100,
      test_loss.result(),
      test_accuracy.result() * 100,
  )
  print(report)

  # Clear the metric accumulators so the next epoch starts from zero.
  for metric in (test_loss, train_accuracy, test_accuracy):
    metric.reset_states()

Epoch 1, Loss: 0.006500512361526489, Accuracy: 99.59720611572266, Test Loss: 0.49756261706352234, Test Accuracy: 91.91999816894531
Epoch 2, Loss: 0.003369492245838046, Accuracy: 99.9566650390625, Test Loss: 0.5027914643287659, Test Accuracy: 92.11000061035156
Epoch 3, Loss: 0.00244527799077332, Accuracy: 99.98333740234375, Test Loss: 0.5107318162918091, Test Accuracy: 92.20999908447266
Epoch 4, Loss: 0.0019254875369369984, Accuracy: 99.99500274658203, Test Loss: 0.5217400789260864, Test Accuracy: 92.25
