# Distributed TensorFlow: Scale-up your model training
## Jeongkyu Shin

Prepared for an ML GDE talk (April 2021).

## Jeongkyu Shin

* Lablup Inc.: Make AI Accessible
     * Making Backend.AI.
     * An incredibly convenient open source machine learning cluster platform

* Google Developers Experts
     * ML / DL GDE (Context retrieval)
   
* I like to play with open source. For a long time.
     * Hobby becomes Research, Research becomes Job.


![image](slides/Picture2.png)

![image](slides_en/Slide3.png)

![image](slides_en/Slide10.png)

![image](slides_en/Slide11.png)

# Super simple example

Below is a very, very simple MNIST example.

Today, we're going to refine this example and look at the process of converting it to distributed training code.

First, let's build a basic data loader and model.


In [2]:
%%writefile mnist.py

import os
import tensorflow as tf
import numpy as np

def mnist_dataset(batch_size):
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
  x_train, x_test = x_train / np.float32(255), x_test / np.float32(255)
  y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64)

  train_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
  test_dataset = tf.data.Dataset.from_tensor_slices(
      (x_test, y_test)).batch(batch_size)
  return train_dataset, test_dataset

def build_and_compile_cnn_model():
  model = tf.keras.Sequential([
      tf.keras.Input(shape=(28, 28)),
      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
  ]) 
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
      metrics=['accuracy'])
  return model

Overwriting mnist.py


# Then, let's train the simplest Fashion MNIST model


In [3]:
import os
import json

import tensorflow as tf
import mnist

batch_size = 64

## GPU allocation part (important!)
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for device in gpu_devices:
    tf.config.experimental.set_memory_growth(device, True)


train_dataset, test_dataset = mnist.mnist_dataset(batch_size)

model = mnist.build_and_compile_cnn_model()
model.fit(train_dataset, epochs=3, steps_per_epoch=70)

eval_loss, eval_acc = model.evaluate(test_dataset)

Epoch 1/3
Epoch 2/3
Epoch 3/3


We ran the basic Fashion MNIST model through tf.keras.

Let's dive into distributed training world!

Why distributed training / processing?

# Because training takes sooooo long

* BERT
* T5
* GPT-3



![image](slides/Slide7.png)

![image](slides/Slide8.png)

![image](slides/Slide9.png)

![image](slides_en/Slide15.png)

 * Training needs to be faster.
    * The more you distribute, the faster the training.

 * Less extra effort is required.
    * Code modifications should be minimized.

 * It should be reproducible.
    * System dependencies should be low.



![image](slides_en/Slide16.png)

![image](slides/Slide10.png)

## TensorFlow's support for distributed processing

* Strategy-based support (1.12~)
   * Distributed processing is supported by calling a predefined strategy.

* `tf.data`-based pipeline parallelization
   * Parallelizes the data pipeline
   * Useful when the data source limits throughput

* Mesh TensorFlow
   * In case of distributed model training with multiple nodes
 

## tf.distribute.Strategy

* Goals
   * Convenient use
   * Powerful performance
   * Easy strategy replacement

* General-purpose distribution strategies
   * MirroredStrategy
   * CentralStorageStrategy

* Special case
   * OneDeviceStrategy
   * TPUStrategy
   * ParameterServerStrategy



## Tutorial: distributed processing on multi-node / multi-GPU systems


* MirroredStrategy
    * TensorFlow's default distributed processing policy
    * All-reduce based
    * Optimized for multi-GPU
    * “Simple is the best”

* MultiWorkerMirroredStrategy
    * Applies MirroredStrategy in a multi-node environment
    * It's a little slow, but simple is the best!
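
To make "all-reduce based" concrete, here is a conceptual sketch in plain Python. It is not how TensorFlow implements it: lists stand in for per-replica tensors, the helper names `all_reduce_mean` and `sgd_step` are made up for illustration, and averaging is assumed as the reduction. The point is only that every replica ends up applying the same combined gradient, so the mirrored weights stay identical.

```python
# Conceptual sketch only: plain Python lists stand in for per-replica tensors.
def all_reduce_mean(per_replica_grads):
    """Average gradients element-wise and hand every replica the same result."""
    num_replicas = len(per_replica_grads)
    summed = [sum(values) for values in zip(*per_replica_grads)]
    averaged = [value / num_replicas for value in summed]
    # Every replica receives an identical copy of the reduced gradient.
    return [list(averaged) for _ in range(num_replicas)]


def sgd_step(weights, grads, learning_rate):
    """Apply one plain SGD update."""
    return [w - learning_rate * g for w, g in zip(weights, grads)]


# Two replicas start from the same mirrored weights.
weights = [[1.0, 2.0], [1.0, 2.0]]
# Each replica computed a different gradient from its own slice of the batch.
local_grads = [[0.5, 1.0], [1.5, 3.0]]

reduced = all_reduce_mean(local_grads)
weights = [sgd_step(w, g, learning_rate=0.5) for w, g in zip(weights, reduced)]

# The replicas stay in sync because they applied the same averaged gradient.
print(weights)
```

In a multi-node setup the same exchange has to cross the network, which is why MultiWorkerMirroredStrategy tends to be slower per step than the single-node version.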


![image](slides/Picture1.png)

## Time is running out

* We'll start right away with MultiWorkerMirroredStrategy.

## MultiWorkerMirroredStrategy  
  
* Multi-node version of MirroredStrategy
   * Inter-node communication becomes a bottleneck.
   * Still, if the model is large, training does get faster.
   * The speedup is hard to see with MNIST, but it is obvious with ResNet.
  
* Implementation
   * One of the nodes becomes the master node (the chief, in TensorFlow's terms).
   * Node information is managed through `cluster_resolver`.
   * gRPC and NCCL are used for communication.

## MultiWorkerMirroredStrategy code execution

* Action
    * Each node's role is determined by the value set in `TF_CONFIG` in advance.
    * Starting with TensorFlow 2, you can designate all nodes as workers. (Previously, a separate master had to be designated explicitly.)
    * If all nodes are designated as workers, the first worker becomes the master node.

* Master node
    * The master node waits in the strategy scope of MultiWorkerMirroredStrategy until all worker nodes have contacted it.
* Worker node
    * Worker nodes register themselves with the master node in the scope during execution.
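
As a rough illustration of the `TF_CONFIG` value described above, the sketch below builds one for a two-node cluster where every node is a worker. The host names and port are hypothetical placeholders, and `make_tf_config` is a helper invented for this example. Only the `cluster` / `task` layout follows the documented `TF_CONFIG` format.

```python
import json
import os

# Hypothetical addresses; replace them with the real host:port of each node.
workers = ["node-a.example.com:12345", "node-b.example.com:12345"]


def make_tf_config(worker_hosts, task_index):
    """Build the TF_CONFIG value for one worker in an all-worker cluster."""
    return {
        "cluster": {"worker": list(worker_hosts)},
        "task": {"type": "worker", "index": task_index},
    }


# Each node exports its own TF_CONFIG before the training script starts.
# The cluster section is identical everywhere; only the task index differs.
configs = [make_tf_config(workers, i) for i in range(len(workers))]
os.environ["TF_CONFIG"] = json.dumps(configs[0])  # this process acts as worker 0

# Read it back the same way a training script would.
tf_config = json.loads(os.environ["TF_CONFIG"])
num_workers = len(tf_config["cluster"]["worker"])
is_first_worker = tf_config["task"]["index"] == 0
print(num_workers, is_first_worker)
```

With this layout, the node whose task index is 0 is the first worker and therefore takes the master role.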

In [4]:
%%writefile simple_worker_mnist.py
import os
import json

import tensorflow as tf
import mnist

per_worker_batch_size = 64

## This part is only needed for manual cluster setup.
tf_config = json.loads(os.environ['TF_CONFIG'])
num_workers = len(tf_config['cluster']['worker'])

gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for device in gpu_devices:
    tf.config.experimental.set_memory_growth(device, True)

strategy = tf.distribute.MultiWorkerMirroredStrategy()

global_batch_size = per_worker_batch_size * num_workers ## This part is needed for manual cluster setup.
multi_worker_dataset, test_dataset = mnist.mnist_dataset(global_batch_size)

with strategy.scope():
    # Model building/compiling need to be within `strategy.scope()`.
    multi_worker_model = mnist.build_and_compile_cnn_model()
    multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
eval_loss, eval_acc = multi_worker_model.evaluate(test_dataset)

Overwriting simple_worker_mnist.py


# Practical code for actual use

* So far, we have only looked at example code.
* Now let's look at code you would actually use.
* Below is code you can put in a typical pipeline.


In [5]:
%%writefile practical_mnist.py
import tensorflow_datasets as tfds
import tensorflow as tf

import os

## Enable memory growth so that GPU memory is allocated incrementally
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for device in gpu_devices:
    tf.config.experimental.set_memory_growth(device, True)

## Set the multi worker mirrored strategy
strategy = tf.distribute.MultiWorkerMirroredStrategy()

## Data load
datasets, info = tfds.load(name='fashion_mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']

#multi_worker_dataset = mnist.mnist_dataset(global_batch_size)

## Set  input pipeline
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
print('\nNumber of replicas in sync: {}'.format(strategy.num_replicas_in_sync))

### Normalize
def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255
  return image, label

# Auto-sharding is turned off via `with_options`; mutating `dataset.options()` in place does not apply it.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
train_dataset = mnist_train.map(scale).take(num_train_examples).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).with_options(options)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE).with_options(options)

## Setup functions
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5

# Callback for printing the Learning Rate at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    # Use the model attached to the callback rather than a global variable.
    print('\nLearning rate for epoch {} is {}'.format(
        epoch + 1, self.model.optimizer.learning_rate.numpy()))

checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

with strategy.scope():
  # Model building/compiling need to be within `strategy.scope()`.
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
callbacks = [
  tf.keras.callbacks.TensorBoard(log_dir='/home/work/logs'),
  tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                    save_weights_only=True),
  tf.keras.callbacks.LearningRateScheduler(decay),
  PrintLR()
]
model.fit(train_dataset, epochs=12, callbacks=callbacks)
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
eval_loss, eval_acc = model.evaluate(eval_dataset)
print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
# Model saving
path = 'saved_model/'
model.save(path, save_format='tf')

Overwriting practical_mnist.py



# Pipeline

This is a more serious take on distributed training: a very simple pipeline from data download through training to serving.

In [None]:
# This part saves checkpoints
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

In [None]:
# Change learning rate dynamically.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
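
To see what this schedule does over the 12 training epochs, the sketch below evaluates it outside of Keras. It assumes, as `LearningRateScheduler` does, that `epoch` is a zero-based index. The function is repeated here only so the snippet runs on its own.

```python
def decay(epoch):
    """Same step schedule as above; `epoch` is the zero-based epoch index."""
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    return 1e-5


# Evaluate the schedule for the 12 epochs used in the training script.
schedule = [decay(epoch) for epoch in range(12)]
for epoch, lr in enumerate(schedule):
    print('epoch {:2d}: lr = {:g}'.format(epoch + 1, lr))
```

Epochs 1 to 3 run at 1e-3, epochs 4 to 7 at 1e-4, and the remaining five epochs at 1e-5.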

In [None]:
# Load the latest checkpoint and check the accuracy.
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
eval_loss, eval_acc = model.evaluate(eval_dataset)
print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

# Save the model in SavedModel format.
path = 'saved_model/'
model.save(path, save_format='tf')

# Note: Replica size

* When determining the batch size, you must pass the desired per-replica batch size multiplied by the number of replicas (workers).
    * Otherwise, each node may train on overlapping data, and due to the nature of the all-reduce algorithm, you may see side effects such as overfitting.

```python
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
print('\nNumber of replicas in sync: {}'.format(strategy.num_replicas_in_sync))
```
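
As a worked example of that arithmetic, the sketch below computes the global batch size and the resulting number of steps per epoch for a few replica counts. The helper names are made up for illustration, and the replica counts are arbitrary. The figure 60000 is the size of the Fashion MNIST training split.

```python
import math


def global_batch_size(per_replica_batch_size, num_replicas):
    """Batch size to pass to `.batch()` so each replica sees its intended share."""
    return per_replica_batch_size * num_replicas


def steps_per_epoch(num_examples, batch_size):
    """Number of batches needed to cover the dataset once (last batch may be partial)."""
    return math.ceil(num_examples / batch_size)


NUM_TRAIN_EXAMPLES = 60000  # size of the Fashion MNIST training split
results = {}
for num_replicas in (1, 2, 4):
    batch = global_batch_size(64, num_replicas)
    results[num_replicas] = (batch, steps_per_epoch(NUM_TRAIN_EXAMPLES, batch))
    print(num_replicas, *results[num_replicas])
```

Doubling the number of replicas doubles the global batch and roughly halves the steps per epoch, while each replica still processes 64 examples per step.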

# Guides

* https://www.tensorflow.org/tutorials/distribute/custom_training
* https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
* https://www.tensorflow.org/tutorials/distribute/input

    

# Thank you.

That was fast, wasn't it?

 * inureyes@gmail.com
 * Facebook: jeongkyu.shin
 * GitHub: inureyes

 * For more information, 
   * Lablup Inc: https://www.lablup.com
   * Backend.AI: https://www.backend.ai
   * GitHub    : https://github.com/lablup/backend.ai
   * Backend.AI Cloud: https://cloud.backend.ai
