<a href="https://colab.research.google.com/github/kanru-wang/Distributed_Training_and_Quantization_Pruning_Distillation/blob/main/multi_gpu_mirrored_distributed_train_flower.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi-GPU Mirrored Strategy Flower Image Classification

The code works for both single-device and multi-device setups.

```python
import tensorflow as tf
import tensorflow_hub as hub

# Helper libraries
import numpy as np
import os
from tqdm import tqdm

import tensorflow_datasets as tfds
tfds.disable_progress_bar()
```

## Data Preparation

```python
splits = ['train[:80%]', 'train[80%:90%]', 'train[90%:]']

(train_examples, validation_examples, test_examples), info = tfds.load(
    'oxford_flowers102',
    with_info=True,
    as_supervised=True,
    split=splits,
    data_dir='data/'
)

# Size of the full 'train' split (all three slices above are taken from it).
num_examples = info.splits['train'].num_examples
num_classes = info.features['label'].num_classes
```

```
Downloading and preparing dataset 328.90 MiB (download: 328.90 MiB, generated: 331.34 MiB, total: 660.25 MiB) to data/oxford_flowers102/2.1.1...
Dataset oxford_flowers102 downloaded and prepared to data/oxford_flowers102/2.1.1. Subsequent calls will reuse this data.
```
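
All three slices come from the `train` split. Assuming that split has 1,020 examples (the size listed in the TFDS catalog) and that each percentage boundary is rounded to the nearest example, a minimal sketch of the resulting slice sizes is:

```python
def split_sizes(total, boundaries_pct):
    """Turn percentage boundaries such as [0, 80, 90, 100] into slice sizes.

    Assumption: each boundary is rounded to the nearest example.
    """
    cuts = [round(total * p / 100) for p in boundaries_pct]
    return [end - start for start, end in zip(cuts, cuts[1:])]


TOTAL_TRAIN = 1020  # assumed size of the oxford_flowers102 'train' split
train_n, val_n, test_n = split_sizes(TOTAL_TRAIN, [0, 80, 90, 100])
print(train_n, val_n, test_n)  # 816 102 102
```

With 102 test examples, every test accuracy in the training log is a multiple of 1/102 (for example, 9.80% is 10/102).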


How does `tf.distribute.MirroredStrategy` work?

*   All variables and the model graph are replicated on every replica (one replica per device).
*   Each global batch of input is split evenly across the replicas.
*   Each replica computes the loss and gradients for the portion of the batch it received.
*   The gradients are synchronized across all replicas by summing them (an all-reduce).
*   After synchronization, the same update is applied to each replica's copy of the variables, so the copies stay identical.
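
The steps above can be simulated without TensorFlow. The sketch below is a hypothetical, pure-Python illustration of one synchronous data-parallel step for a one-parameter linear model with a squared-error loss; it is not how `MirroredStrategy` is implemented, only the arithmetic it relies on. Each replica divides its gradient by the global batch size, so summing the replica gradients reproduces the single-device gradient.

```python
def per_example_grad(w, x, y):
    # Gradient of the squared error (w * x - y) ** 2 with respect to w.
    return 2.0 * (w * x - y) * x


def replica_grad(w, xs, ys, global_batch_size):
    # Sum this replica's per-example gradients, divided by the GLOBAL batch size.
    return sum(per_example_grad(w, x, y) for x, y in zip(xs, ys)) / global_batch_size


def mirrored_step(replica_weights, xs, ys, lr):
    """One synchronous step; replica_weights holds one copy of w per replica."""
    num_replicas = len(replica_weights)
    global_batch_size = len(xs)
    shard = global_batch_size // num_replicas  # assumes an even split
    grads = [
        replica_grad(w, xs[r * shard:(r + 1) * shard],
                     ys[r * shard:(r + 1) * shard], global_batch_size)
        for r, w in enumerate(replica_weights)
    ]
    summed = sum(grads)  # the "all-reduce": every replica receives the same sum
    return [w - lr * summed for w in replica_weights]  # identical update on each copy


xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]
print(mirrored_step([0.0], xs, ys, lr=0.01))                 # one replica
print(mirrored_step([0.0, 0.0, 0.0, 0.0], xs, ys, lr=0.01))  # four replicas
```

Both calls produce the same weight, and all four replica copies end up equal.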

```python
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()

print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
```

```
Number of devices: 1
```


Set up the input pipeline

```python
BUFFER_SIZE = num_examples
EPOCHS = 10
pixels = 224

MODULE_HANDLE = 'https://tfhub.dev/tensorflow/resnet_50/feature_vector/1'

IMAGE_SIZE = (pixels, pixels)
print("Using {} with input size {}".format(MODULE_HANDLE, IMAGE_SIZE))

# Resize the image and scale the pixel values to the range [0, 1]
def format_image(image, label):
    image = tf.image.resize(image, IMAGE_SIZE) / 255.0
    return image, label
```

```
Using https://tfhub.dev/tensorflow/resnet_50/feature_vector/1 with input size (224, 224)
```


Set the global batch size

```python
def set_global_batch_size(batch_size_per_replica, strategy):
    '''
    Args:
        batch_size_per_replica (int) - batch size per replica
        strategy (tf.distribute.Strategy) - distribution strategy
    '''
    global_batch_size = batch_size_per_replica * strategy.num_replicas_in_sync
    return global_batch_size


BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = set_global_batch_size(BATCH_SIZE_PER_REPLICA, strategy)
```
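
As a quick check of the arithmetic, the sketch below shows how the global batch size and the number of steps per epoch change with the replica count. It assumes the training slice has 816 examples (80% of a 1,020-example `train` split); `global_batch_size_for` is a hypothetical helper that takes the replica count directly instead of a strategy object.

```python
import math


def global_batch_size_for(batch_size_per_replica, num_replicas):
    # Same arithmetic as set_global_batch_size, without a strategy object.
    return batch_size_per_replica * num_replicas


def steps_per_epoch(num_examples, global_batch_size):
    # tf.data keeps the final partial batch unless drop_remainder=True.
    return math.ceil(num_examples / global_batch_size)


NUM_TRAIN = 816  # assumed: 80% of the 1,020-example 'train' split
for replicas in (1, 2, 4):
    gbs = global_batch_size_for(64, replicas)
    print(f"{replicas} replica(s): global batch {gbs}, "
          f"{steps_per_epoch(NUM_TRAIN, gbs)} steps/epoch")
```

With one replica this gives 13 steps per epoch (12 full batches plus one batch of 48 examples), which matches the `13it` progress bars in the training log.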

Create the training, validation, and test datasets, batched with the global batch size; `experimental_distribute_dataset` then splits each global batch across the replicas.

```python
# Batch with the GLOBAL batch size: experimental_distribute_dataset splits each batch across the replicas.
train_batches = train_examples.shuffle(num_examples // 4).map(format_image).batch(GLOBAL_BATCH_SIZE).prefetch(1)
validation_batches = validation_examples.map(format_image).batch(GLOBAL_BATCH_SIZE).prefetch(1)
test_batches = test_examples.map(format_image).batch(GLOBAL_BATCH_SIZE)
```

Define the distributed datasets

```python
def distribute_datasets(strategy, train_batches, validation_batches, test_batches):
    train_dist_dataset = strategy.experimental_distribute_dataset(train_batches)
    val_dist_dataset = strategy.experimental_distribute_dataset(validation_batches)
    test_dist_dataset = strategy.experimental_distribute_dataset(test_batches)
    return train_dist_dataset, val_dist_dataset, test_dist_dataset


train_dist_dataset, val_dist_dataset, test_dist_dataset = distribute_datasets(
    strategy,
    train_batches,
    validation_batches,
    test_batches
)
```
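
`experimental_distribute_dataset` expects each element to be a global batch and divides it among the replicas. The hypothetical helper below sketches that idea with contiguous, near-equal shards; it is a simplification for intuition, not TensorFlow's exact splitting logic, which may handle partial batches differently.

```python
def split_global_batch(batch, num_replicas):
    """Split one global batch into contiguous per-replica shards.

    Hypothetical: the first `remainder` replicas receive one extra example.
    """
    base, remainder = divmod(len(batch), num_replicas)
    shards, start = [], 0
    for r in range(num_replicas):
        size = base + (1 if r < remainder else 0)
        shards.append(batch[start:start + size])
        start += size
    return shards


full_batch = list(range(64))  # stand-ins for 64 images
last_batch = list(range(48))  # the final, partial batch of an epoch
print([len(s) for s in split_global_batch(full_batch, 4)])
print([len(s) for s in split_global_batch(last_batch, 4)])
```

This is also why the datasets should be batched with `GLOBAL_BATCH_SIZE`: batching with the per-replica size would give each of N replicas only 64 / N examples while `compute_loss` still divides by the global batch size, shrinking the loss by a factor of N.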

With a single replica, each batch holds 64 images and their 64 labels.

Take a look at a single batch from `train_dist_dataset`:

```python
# On a single replica, `x` is a plain (images, labels) tuple of tensors.
# With several replicas, each element is a PerReplica value rather than one tensor;
# use `strategy.experimental_local_results(x[0])` to inspect the per-device tensors.
x = iter(train_dist_dataset).get_next()

print(f"x is a tuple that contains {len(x)} values ")
print(f"x[0] contains the features, and has shape {x[0].shape}")
print(f"  so it has {x[0].shape[0]} examples in the batch, each is an image that is {x[0].shape[1:]}")
print(f"x[1] contains the labels, and has shape {x[1].shape}")
```

```
x is a tuple that contains 2 values 
x[0] contains the features, and has shape (64, 224, 224, 3)
  so it has 64 examples in the batch, each is an image that is (224, 224, 3)
x[1] contains the labels, and has shape (64,)
```


## Create the model

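Use the Model Subclassing API to define `ResNetModel`, a subclass of `tf.keras.Model` that passes images through a frozen ResNet-50 feature extractor from TF Hub (`trainable=False`) and then a dense softmax classifier.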

```python
class ResNetModel(tf.keras.Model):
    def __init__(self, classes):
        super().__init__()
        self._feature_extractor = hub.KerasLayer(MODULE_HANDLE,
                                                 trainable=False)
        self._classifier = tf.keras.layers.Dense(classes, activation='softmax')

    def call(self, inputs):
        x = self._feature_extractor(inputs)
        x = self._classifier(x)
        return x
```

Define the directory and file prefix for the checkpoints, which store the model's weights and the optimizer's state during training.

```python
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
```

## Define the loss function

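- `loss_object` computes unreduced per-example losses; it is used directly for the test loss and inside `compute_loss`.
- `compute_loss` sums a replica's per-example training losses and divides by the global batch size, so adding the results of all replicas gives the average loss over the global batch.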

```python
with strategy.scope():
    # Set reduction to `NONE` so we can do the reduction afterwards and divide by
    # the global batch size.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        reduction=tf.keras.losses.Reduction.NONE
    )
    # (The functional form, tf.keras.losses.sparse_categorical_crossentropy,
    # also returns unreduced per-example losses.)
    def compute_loss(labels, predictions):
        per_example_loss = loss_object(labels, predictions)
        return tf.nn.compute_average_loss(
            per_example_loss,
            global_batch_size=GLOBAL_BATCH_SIZE
        )

    test_loss = tf.keras.metrics.Mean(name='test_loss')
```
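
What `compute_loss` returns can be illustrated in plain Python. Assuming `tf.nn.compute_average_loss` sums a replica's per-example losses and divides by `global_batch_size` (its documented behaviour when that argument is passed), the per-replica values add up to the mean over the whole global batch. The `average_loss` helper below is a hypothetical re-implementation of that arithmetic only.

```python
def average_loss(per_example_loss, global_batch_size):
    # Sum of this replica's per-example losses divided by the GLOBAL batch size.
    return sum(per_example_loss) / global_batch_size


# A global batch of 4 examples split across 2 replicas.
replica_losses = [[1.0, 2.0], [3.0, 6.0]]
per_replica = [average_loss(l, global_batch_size=4) for l in replica_losses]
print(per_replica, sum(per_replica))  # [0.75, 2.25] 3.0, the mean of all 4 losses

# Caveat: a final partial batch of 48 examples is still divided by 64.
print(average_loss([2.0] * 48, global_batch_size=64))  # 1.5 instead of 2.0
```

The last step of each epoch therefore under-reports its loss by a factor of 48/64. When `global_batch_size` is omitted, TensorFlow derives it from the actual per-replica batch size times the number of replicas, which reduces this bias; check the documentation for your TF version.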

## Define the metrics

Use `.result()` to get the accumulated statistics, e.g. `train_accuracy.result()`.

```python
with strategy.scope():
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_accuracy'
    )
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='test_accuracy'
    )
```
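
Keras metrics accumulate state across every `update_state` call until they are reset, which is why the training loop calls `reset_states()` at the end of each epoch. The hypothetical class below is a pure-Python stand-in for that pattern for sparse categorical accuracy (argmax of the predicted probabilities compared with the integer label); it is not the Keras implementation.

```python
class StreamingSparseAccuracy:
    """Hypothetical stand-in for the update_state / result / reset pattern."""

    def __init__(self):
        self.reset_state()

    def reset_state(self):
        self.correct = 0
        self.count = 0

    def update_state(self, labels, probabilities):
        # probabilities: one list of class scores per example
        for label, probs in zip(labels, probabilities):
            predicted = max(range(len(probs)), key=probs.__getitem__)  # argmax
            self.correct += int(predicted == label)
            self.count += 1

    def result(self):
        return self.correct / self.count if self.count else 0.0


m = StreamingSparseAccuracy()
m.update_state([0, 2], [[0.9, 0.05, 0.05], [0.1, 0.7, 0.2]])  # 1 of 2 correct
m.update_state([1], [[0.2, 0.5, 0.3]])                         # correct
print(m.result())  # accumulated over both calls: 2 of 3
```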

## Instantiate the model, optimizer, and checkpoints

```python
# Model and optimizer must be created under `strategy.scope`.
with strategy.scope():
    model = ResNetModel(classes=num_classes)
    optimizer = tf.keras.optimizers.Adam()
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
```

## Training loop

Define ordinary training and test steps that would also work without a distribution strategy. Then use `strategy.run` to execute them on every replica.

```python
def train_test_step_fns(strategy, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy):
    with strategy.scope():
        def train_step(inputs):
            images, labels = inputs
            with tf.GradientTape() as tape:
                predictions = model(images, training=True)
                loss = compute_loss(labels, predictions)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            train_accuracy.update_state(labels, predictions)
            return loss

        def test_step(inputs):
            images, labels = inputs
            predictions = model(images, training=False)
            t_loss = loss_object(labels, predictions)
            test_loss.update_state(t_loss)
            test_accuracy.update_state(labels, predictions)

        return train_step, test_step


train_step, test_step = train_test_step_fns(
    strategy, model, compute_loss, optimizer, train_accuracy, loss_object,
    test_loss, test_accuracy
)
```
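
In `test_step`, `loss_object` returns one loss per example (its reduction is `NONE`), and a Keras `Mean` metric treats each element passed to `update_state` as a separate observation. The epoch's test loss is therefore the mean over examples, independent of how the test set is batched. The hypothetical `StreamingMean` class below mimics that behaviour in plain Python and contrasts it with naively averaging per-batch means.

```python
class StreamingMean:
    """Hypothetical stand-in for a Keras Mean metric."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, values):
        # Every element counts as one observation.
        self.total += sum(values)
        self.count += len(values)

    def result(self):
        return self.total / self.count


def batches(values, size):
    return [values[i:i + size] for i in range(0, len(values), size)]


losses = [1.0, 3.0, 5.0]  # per-example test losses

for batch_size in (1, 2):
    m = StreamingMean()
    for b in batches(losses, batch_size):
        m.update_state(b)
    print(batch_size, m.result())  # 3.0 for both batch sizes

batch_means = [sum(b) / len(b) for b in batches(losses, 2)]
print(sum(batch_means) / len(batch_means))  # 3.5: skewed by the partial batch
```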

## Distributed training and testing

```python
def distributed_train_test_step_fns(strategy, train_step, test_step, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy):
    with strategy.scope():
        @tf.function
        def distributed_train_step(dataset_inputs):
            per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
            return strategy.reduce(
                tf.distribute.ReduceOp.SUM, per_replica_losses,
                axis=None
            )

        @tf.function
        def distributed_test_step(dataset_inputs):
            return strategy.run(test_step, args=(dataset_inputs,))

        return distributed_train_step, distributed_test_step


distributed_train_step, distributed_test_step = distributed_train_test_step_fns(
    strategy, train_step, test_step, model, compute_loss, optimizer,
    train_accuracy, loss_object, test_loss, test_accuracy
)
```

## Training

`distributed_train_step` returns the scaled loss. `tf.distribute.Strategy.reduce` sums it across replicas, and the training loop then averages these per-step values over the epoch by accumulating them in `total_loss` and dividing by `num_batches`.
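
The aggregation described above can be sketched in plain Python. Assume two replicas and two training steps, with each replica returning its loss already divided by the global batch size (the numbers are hypothetical): `strategy.reduce` with `ReduceOp.SUM` adds the replica values for each step, and the loop then averages those step losses.

```python
def epoch_train_loss(per_step_replica_losses):
    """per_step_replica_losses[i][r] is replica r's scaled loss at step i."""
    # strategy.reduce(tf.distribute.ReduceOp.SUM, ...) for each step
    step_losses = [sum(replica_losses) for replica_losses in per_step_replica_losses]
    # total_loss / num_batches in the training loop
    return sum(step_losses) / len(step_losses)


print(epoch_train_loss([[0.5, 0.75], [0.25, 0.5]]))  # (1.25 + 0.75) / 2 = 1.0
```

Each step counts equally, so the final partial batch of an epoch weighs as much as a full one.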

```python
with strategy.scope():
    for epoch in range(EPOCHS):
        # TRAIN LOOP
        total_loss = 0.0
        num_batches = 0
        for x in tqdm(train_dist_dataset):
            total_loss += distributed_train_step(x)
            num_batches += 1
        train_loss = total_loss / num_batches

        # TEST LOOP
        for x in test_dist_dataset:
            distributed_test_step(x)

        template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}")
        print(
            template.format(
                epoch+1,
                train_loss,
                train_accuracy.result()*100,
                test_loss.result(),
                test_accuracy.result()*100
            )
        )

        test_loss.reset_states()
        train_accuracy.reset_states()
        test_accuracy.reset_states()
```

```
13it [00:21,  1.68s/it]
Epoch 1, Loss: 4.6524457931518555, Accuracy: 4.779411792755127, Test Loss: 3.83022403717041, Test Accuracy: 9.803921699523926
13it [00:05,  2.51it/s]
Epoch 2, Loss: 2.5824027061462402, Accuracy: 50.0, Test Loss: 2.7955212593078613, Test Accuracy: 40.19607925415039
13it [00:02,  4.81it/s]
Epoch 3, Loss: 1.4408522844314575, Accuracy: 80.88235473632812, Test Loss: 2.1607375144958496, Test Accuracy: 59.80392074584961
13it [00:02,  4.61it/s]
Epoch 4, Loss: 0.8598877191543579, Accuracy: 94.48529815673828, Test Loss: 1.8130377531051636, Test Accuracy: 62.74510192871094
13it [00:02,  4.50it/s]
Epoch 5, Loss: 0.5642362833023071, Accuracy: 97.05882263183594, Test Loss: 1.6388604640960693, Test Accuracy: 64.70588684082031
13it [00:02,  4.54it/s]
Epoch 6, Loss: 0.3913654088973999, Accuracy: 98.52941131591797, Test Loss: 1.4864661693572998, Test Accuracy: 66.66667175292969
13it [00:02,  4.55it/s]
Epoch 7, Loss: 0.2857825756072998, Accuracy: 99.26470184326172, Test Loss: 1.4010533094406128, Test Accuracy: 65.68627166748047
13it [00:02,  4.62it/s]
Epoch 8, Loss: 0.2221342921257019, Accuracy: 99.63235473632812, Test Loss: 1.3403692245483398, Test Accuracy: 65.68627166748047
13it [00:02,  4.56it/s]
Epoch 9, Loss: 0.1780683994293213, Accuracy: 99.75489807128906, Test Loss: 1.307769775390625, Test Accuracy: 65.68627166748047
13it [00:05,  2.51it/s]
Epoch 10, Loss: 0.14703938364982605, Accuracy: 99.87745666503906, Test Loss: 1.2656513452529907, Test Accuracy: 66.66667175292969
```
