# Distributed Training Concepts with tf.distribute
- As datasets and models grow, single-device training becomes too slow.
- Production ML systems often train across multiple GPUs, multiple machines, or clusters of machines
- TensorFlow provides a unified API for this: tf.distribute

Even if you only have a CPU today, you can write code that scales to more devices later. In this notebook we will:
- Introduce data parallelism
- Use MirroredStrategy
- Train inside a distribution scope
- Compare batch size scaling
- Understand how distribution changes system behavior
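One way to keep code device-agnostic is to choose the strategy once at startup and build everything inside strategy.scope(). The sketch below assumes you want MirroredStrategy only when more than one GPU is visible and TensorFlow's default single-device strategy otherwise; pick_strategy is a hypothetical helper, not part of tf.distribute.

```python
import tensorflow as tf


def pick_strategy():
    """Hypothetical helper: mirror across GPUs only if there are several."""
    gpus = tf.config.list_physical_devices("GPU")
    if len(gpus) > 1:
        # One replica per visible GPU
        return tf.distribute.MirroredStrategy()
    # Default strategy: a single device, but the same scope()/run() API
    return tf.distribute.get_strategy()


strategy = pick_strategy()
print("Replicas in sync:", strategy.num_replicas_in_sync)

# The model code is identical whichever strategy was chosen
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
```

The design choice here is that only pick_strategy knows about hardware; moving to a multi-GPU machine changes the returned strategy but none of the model code.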

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import time

print("TensorFlow version:", tf.__version__)

print("Available devices:")
for device in tf.config.list_physical_devices():
    print(" -", device)


TensorFlow version: 2.9.1
Available devices:
 - PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')


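## Data Parallel Training
In data parallel training:
- Each device (replica) holds a full copy of the model
- Each replica processes a different slice of every batch
- The replicas' gradients are averaged (an all-reduce)
- Every replica applies the same averaged update, so the weights stay synchronized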

Conceptually: Batch → Split → Parallel Compute → Aggregate Gradients → Update

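TensorFlow handles these steps automatically: with Keras, model.fit() performs the split, the gradient aggregation, and the weight synchronization for any model built inside strategy.scope().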
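The Split → Compute → Aggregate → Update loop can be sketched without TensorFlow. This NumPy example assumes a toy linear model with mean-squared-error loss and two hypothetical replicas; it shows that averaging the gradients of equal-sized shards reproduces the full-batch gradient, which is why every replica can apply the same update.

```python
import numpy as np

# Toy linear model y_hat = X @ w with mean-squared-error loss
rng = np.random.default_rng(0)
X = rng.normal(size=(64, 3))
y = rng.normal(size=64)
w = np.zeros(3)


def mse_grad(X_shard, y_shard, w):
    """Gradient of mean((X_shard @ w - y_shard)**2) with respect to w."""
    residual = X_shard @ w - y_shard
    return 2.0 * X_shard.T @ residual / len(y_shard)


NUM_REPLICAS = 2  # hypothetical number of devices

# Split: each replica receives an equal slice of the global batch
X_shards = np.array_split(X, NUM_REPLICAS)
y_shards = np.array_split(y, NUM_REPLICAS)

# Parallel compute: every replica holds the same w and computes a local gradient
local_grads = [mse_grad(Xs, ys, w) for Xs, ys in zip(X_shards, y_shards)]

# Aggregate: average the local gradients (the result an all-reduce would give)
avg_grad = np.mean(local_grads, axis=0)

# Update: every replica applies the same averaged gradient, so weights stay equal
lr = 0.1
w_new = w - lr * avg_grad

full_grad = mse_grad(X, y, w)
print("Matches full-batch gradient:", np.allclose(avg_grad, full_grad))
```

The equality relies on the shards being the same size; with uneven shards, the per-shard gradients would need to be weighted by shard size.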

In [2]:
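# Initialize the strategy. MirroredStrategy uses every visible GPU, or the CPU
# if no GPU is found: 1 device → 1 replica, 2 GPUs → 2 replicas

strategy = tf.distribute.MirroredStrategy()

print("Number of replicas in sync:", strategy.num_replicas_in_sync)

# Load Dataset

IMG_SIZE = 128
BATCH_SIZE = 32  # global batch size: fit() splits each batch across the replicas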

(ds_train, ds_val), ds_info = tfds.load(
    "tf_flowers",
    split=["train[:80%]", "train[80%:]"],
    as_supervised=True,
    with_info=True
)

NUM_CLASSES = ds_info.features["label"].num_classes



INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Number of replicas in sync: 1
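With only a CPU, you can still exercise a multi-replica setup by splitting the physical CPU into logical devices. A minimal sketch, assuming it runs in a fresh Python process (tf.config.set_logical_device_configuration raises an error once TensorFlow has initialized its devices) and that two logical CPUs are enough to simulate two replicas:

```python
import tensorflow as tf

# Split the single physical CPU into two logical CPUs.
# Must run before TensorFlow initializes its runtime (i.e. in a fresh process).
cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(
    cpus[0],
    [tf.config.LogicalDeviceConfiguration(),
     tf.config.LogicalDeviceConfiguration()],
)

logical_cpus = tf.config.list_logical_devices("CPU")
print("Logical CPUs:", [d.name for d in logical_cpus])

# Mirror variables across both logical CPUs, giving 2 replicas
strategy = tf.distribute.MirroredStrategy([d.name for d in logical_cpus])
print("Number of replicas in sync:", strategy.num_replicas_in_sync)
```

This does not make training faster (both replicas share the same cores), but it lets you check that the batch splitting and gradient aggregation paths work before you have real GPUs.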


In [3]:
# Preprocess

def preprocess(image, label):
    # Resize to a fixed shape, then scale pixel values to [0, 1]
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

ds_train = (
    ds_train
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(1000)  # reshuffle training examples every epoch
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

ds_val = (
    ds_val
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)


In [4]:
# Model WITHOUT Distribution (Baseline)

# Build Baseline Model

baseline_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation="relu", input_shape=(IMG_SIZE, IMG_SIZE, 3)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax")
])

baseline_model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

# Train Baseline

start = time.time()

baseline_model.fit(ds_train, epochs=2)

baseline_time = time.time() - start
print("Baseline training time:", baseline_time)


Epoch 1/2
Epoch 2/2
Baseline training time: 16.231674671173096


In [5]:
# Model WITH Distribution Strategy
# Create and compile the model inside strategy.scope() so that its variables
# (and the optimizer's) are mirrored on every replica

# Build Distributed Model

with strategy.scope():
    dist_model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation="relu", input_shape=(IMG_SIZE, IMG_SIZE, 3)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(NUM_CLASSES, activation="softmax")
    ])

    dist_model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"]
    )

# Train Distributed Model

start = time.time()

dist_model.fit(ds_train, epochs=2)

dist_time = time.time() - start
print("Distributed training time:", dist_time)


Epoch 1/2
Epoch 2/2
Distributed training time: 23.187735557556152
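On this single-CPU machine the distributed run was slower (about 23.2 s vs 16.2 s). A plausible explanation is that with one replica, MirroredStrategy adds coordination work without adding any compute. To see what that work is, here is a minimal custom training loop under the same strategy, following the pattern of TensorFlow's custom-training-loop guide: strategy.experimental_distribute_dataset splits batches, strategy.run executes the step on each replica, and strategy.reduce combines the results. The data is hypothetical random toy data, and this is an illustration, not a literal copy of what fit() runs internally.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # falls back to the CPU if no GPU
PER_REPLICA_BATCH_SIZE = 8
GLOBAL_BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync

# Hypothetical toy data: 64 random examples, 4 features, 3 classes
features = tf.random.normal((64, 4), seed=0)
labels = tf.random.uniform((64,), maxval=3, dtype=tf.int32, seed=0)
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(GLOBAL_BATCH_SIZE)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

with strategy.scope():
    # Variables created here are mirrored onto every replica
    model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(3)])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    # reduction="none": one loss per example; we average over the global batch ourselves
    loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction="none")


def replica_step(inputs):
    x, y = inputs
    with tf.GradientTape() as tape:
        per_example_loss = loss_obj(y, model(x, training=True))
        # Divide by the GLOBAL batch size so the sum over replicas is the mean loss
        loss = tf.nn.compute_average_loss(
            per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
    grads = tape.gradient(loss, model.trainable_variables)
    # Under MirroredStrategy, the optimizer aggregates gradients across replicas here
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


@tf.function
def train_step(dist_inputs):
    per_replica_losses = strategy.run(replica_step, args=(dist_inputs,))
    # Summing the per-replica losses gives the mean loss over the global batch
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)


for epoch in range(2):
    total, batches = 0.0, 0
    for batch in dist_dataset:
        total += float(train_step(batch))
        batches += 1
    print(f"epoch {epoch}: mean loss {total / batches:.4f}")
```

Each of these steps (distributing the dataset, dispatching per replica, reducing losses and gradients) runs even when there is only one replica, which is consistent with the extra time measured above.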


## Batch Size Scaling Concept

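With multiple replicas: global batch size = per-replica batch size × number of replicas

Under MirroredStrategy, the batch size passed to Dataset.batch() is the global batch: fit() splits each batch evenly across the replicas. With BATCH_SIZE = 32 and 2 replicas, each replica therefore processes 16 examples per step. To keep 32 examples per replica on 2 replicas, batch the dataset with 32 × 2 = 64.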
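To see the split directly, this sketch simulates two replicas on a CPU (assumption: a fresh process in which the CPU is divided into two logical devices) and inspects what each replica receives from strategy.experimental_distribute_dataset when the dataset is batched with the global batch size.

```python
import tensorflow as tf

# Assumption: simulate 2 devices by splitting the CPU into two logical CPUs
# (must run before TensorFlow initializes its runtime)
cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(
    cpus[0], [tf.config.LogicalDeviceConfiguration()] * 2
)
strategy = tf.distribute.MirroredStrategy(
    [d.name for d in tf.config.list_logical_devices("CPU")]
)

PER_REPLICA_BATCH_SIZE = 4
GLOBAL_BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync

# Batch with the GLOBAL size; the strategy splits each batch across replicas
dataset = tf.data.Dataset.range(16).batch(GLOBAL_BATCH_SIZE)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Look at the first global batch as seen by each replica
first_batch = next(iter(dist_dataset))
local_slices = strategy.experimental_local_results(first_batch)
per_replica_sizes = [int(t.shape[0]) for t in local_slices]

print("Global batch size:", GLOBAL_BATCH_SIZE)
print("Per-replica slice sizes:", per_replica_sizes)
```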

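A larger global batch means fewer optimizer steps per epoch, so scaling out often requires:
- Adjusting the learning rate (for example, scaling it with the global batch size)
- Monitoring convergence, since a larger batch can change how quickly and how well the model trains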
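A common starting heuristic is to scale the learning rate linearly with the global batch size, often combined with a short warmup. The sketch below assumes the notebook's settings as the baseline (Keras Adam's default learning rate of 0.001 at a global batch of 32); the helper names are hypothetical, and the heuristic is a starting point to validate, not a guarantee.

```python
def scaled_learning_rate(base_lr, base_batch_size, global_batch_size):
    """Linear scaling heuristic: LR grows in proportion to the global batch size."""
    return base_lr * global_batch_size / base_batch_size


def warmup_learning_rate(step, target_lr, warmup_steps):
    """Ramp the LR linearly up to target_lr over the first warmup_steps steps."""
    if step < warmup_steps:
        return target_lr * (step + 1) / warmup_steps
    return target_lr


BASE_LR = 0.001       # Keras default for Adam
BASE_BATCH_SIZE = 32  # the notebook's BATCH_SIZE

for replicas in (1, 2, 4):
    global_batch = BASE_BATCH_SIZE * replicas
    lr = scaled_learning_rate(BASE_LR, BASE_BATCH_SIZE, global_batch)
    print(f"{replicas} replica(s): global batch {global_batch}, lr {lr:g}")

# Warmup toward the 4-replica learning rate over 4 steps
target = scaled_learning_rate(BASE_LR, BASE_BATCH_SIZE, 128)
print([round(warmup_learning_rate(s, target, 4), 6) for s in range(6)])
```

The resulting value can be passed to the optimizer, for example tf.keras.optimizers.Adam(learning_rate=lr); the linear rule is most often cited for SGD, so with Adam treat it as a first guess and check the loss curves.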

In [6]:
replicas = strategy.num_replicas_in_sync
print("Replicas:", replicas)
# BATCH_SIZE was passed to .batch(), so it is the global batch size
print("Global batch size:", BATCH_SIZE)
print("Per-replica batch size:", BATCH_SIZE // replicas)


Replicas: 1
Global batch size: 32
Per-replica batch size: 32
