# <font color="#418FDE" size="6.5" uppercase>**Implementing Strategies**</font>

>Last update: 20260126.
    
By the end of this lecture, you will be able to:
- Configure and run model.fit under a tf.distribute strategy for multi-GPU or multi-worker setups. 
- Implement a custom training loop that uses strategy.run and distributed datasets. 
- Troubleshoot common distributed training errors related to shapes, batch sizes, and variable placement. 


## **1. Using fit with Strategies**

### **1.1. Working With strategy scope**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master TensorFlow 2.20.0/Module_08/Lecture_B/image_01_01.jpg?v=1769459157" width="250">



>* Strategy scope controls model and variable creation
>* Enables automatic multi-device placement using model.fit

>* Create strategy, enter scope, then build components
>* Building outside scope leaves variables undistributed across devices

>* Strategy scope controls variable placement and communication
>* Ensures consistent training, checkpoints, and scaling behavior



In [None]:
#@title Python Code - Working With strategy scope

# This script shows using strategy scope correctly.
# It demonstrates model creation inside strategy scope.
# It keeps training simple with minimal printed output.

# Install TensorFlow if needed in some environments.
# !pip install tensorflow==2.20.0

# Import required standard libraries.
import os
import random
import numpy as np

# Import TensorFlow and Keras utilities.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Set deterministic seeds for reproducibility.
seed_value = 42
random.seed(seed_value)
np.random.seed(seed_value)
tf.random.set_seed(seed_value)

# Print TensorFlow version in one short line.
print("TensorFlow version:", tf.__version__)

# Detect available GPUs for potential distribution.
physical_gpus = tf.config.list_physical_devices("GPU")
print("GPUs detected:", len(physical_gpus))

# Choose a simple strategy based on available devices.
if len(physical_gpus) > 1:
    strategy = tf.distribute.MirroredStrategy()
else:
    strategy = tf.distribute.get_strategy()

# Print strategy type to confirm selection.
print("Using strategy:", type(strategy).__name__)

# Load a small subset of MNIST dataset.
(x_train, y_train), _ = keras.datasets.mnist.load_data()

# Reduce dataset size for quick demonstration.
x_train = x_train[:4000]
y_train = y_train[:4000]

# Normalize images and add channel dimension.
x_train = x_train.astype("float32") / 255.0
x_train = np.expand_dims(x_train, axis=-1)

# Validate shapes before building the model.
print("Train shape:", x_train.shape, "Labels:", y_train.shape)

# Define global batch size for distributed training.
num_replicas = strategy.num_replicas_in_sync
global_batch_size = 128 * max(1, num_replicas)

# Create a tf.data dataset with batching.
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(4000, seed=seed_value)
train_ds = train_ds.batch(global_batch_size)

# Prefetch for better pipeline performance.
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

# Enter strategy scope before creating model and optimizer.
with strategy.scope():
    # Define a simple sequential CNN model.
    model = keras.Sequential([
        layers.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(32, activation="relu"),
        layers.Dense(10, activation="softmax"),
    ])

    # Compile model with optimizer and loss.
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

# Confirm how many replicas share the mirrored variables.
print("Replicas in sync:", strategy.num_replicas_in_sync)

# Train the model briefly with silent logs.
history = model.fit(
    train_ds,
    epochs=2,
    verbose=0,
)

# Extract final loss and accuracy from history.
final_loss = history.history["loss"][-1]
final_acc = history.history["accuracy"][-1]

# Print concise training results for inspection.
print("Final loss:", round(float(final_loss), 4))
print("Final accuracy:", round(float(final_acc), 4))

# Confirm that training finished without distribution errors.
print("Training completed successfully under strategy scope.")



### **1.2. Global vs per replica**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master TensorFlow 2.20.0/Module_08/Lecture_B/image_01_02.jpg?v=1769459244" width="250">



>* Per replica values come from each individual device
>* Global values aggregate results across all replicas

>* Global batch splits into per-replica mini-batches
>* Metrics aggregate per-replica results; reading a local value as global misleads

>* Design datasets using global, not per-replica batches
>* Know which tensors are local versus globally synced
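To make the global versus per-replica distinction concrete, here is a minimal pure-Python sketch of splitting one global batch across replicas. The helper `split_global_batch` is hypothetical, and the rule that earlier replicas absorb any leftover example is an illustrative assumption; it conveys the idea behind `strategy.experimental_distribute_dataset`, not its exact rebatching algorithm.

```python
def split_global_batch(global_batch, num_replicas):
    """Split one global batch into contiguous per-replica shards.

    Hypothetical helper: when the batch does not divide evenly, the
    first replicas each receive one extra example (an assumption made
    for illustration only).
    """
    base_size, leftover = divmod(len(global_batch), num_replicas)
    shards = []
    start = 0
    for replica_id in range(num_replicas):
        size = base_size + (1 if replica_id < leftover else 0)
        shards.append(global_batch[start:start + size])
        start += size
    return shards


# A global batch of 8 examples on 2 replicas: 4 examples per replica.
global_batch = list(range(8))
shards = split_global_batch(global_batch, num_replicas=2)
print("Global batch size:", len(global_batch))
print("Per replica sizes:", [len(shard) for shard in shards])

# A partial final batch of 3 examples cannot split evenly.
partial_shards = split_global_batch(list(range(3)), num_replicas=2)
print("Partial batch per replica sizes:", [len(s) for s in partial_shards])
```

The global size is what you pass to `Dataset.batch`; the per-replica size is what each device's forward pass actually sees.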



In [None]:
#@title Python Code - Global vs per replica

# This script shows global versus per replica concepts.
# It uses TensorFlow distribution strategies with small data.
# Run in Colab to explore multi device batch behavior.

# !pip install tensorflow==2.20.0

# Import required modules safely.
import os
import random
import numpy as np
import tensorflow as tf

# Set deterministic seeds for reproducibility.
seed_value = 42
random.seed(seed_value)
np.random.seed(seed_value)
tf.random.set_seed(seed_value)

# Print TensorFlow version in one short line.
print("TensorFlow version:", tf.__version__)

# Choose a distribution strategy based on available devices.
if len(tf.config.list_logical_devices("GPU")) > 1:
    strategy = tf.distribute.MirroredStrategy()
else:
    strategy = tf.distribute.OneDeviceStrategy("/cpu:0")

# Show how many replicas are in sync.
num_replicas = strategy.num_replicas_in_sync
print("Replicas in sync:", num_replicas)

# Define a small global batch size for the example.
global_batch_size = 8
per_replica_batch_size = max(global_batch_size // num_replicas, 1)
print("Global batch size:", global_batch_size)

# Print the computed per replica batch size.
print("Per replica batch size:", per_replica_batch_size)

# Create a tiny synthetic dataset with known size.
num_samples = 32
features = np.random.randn(num_samples, 4).astype("float32")
labels = np.random.randint(0, 2, size=(num_samples, 1)).astype("float32")

# Build a tf.data.Dataset with global batch size.
ds = tf.data.Dataset.from_tensor_slices((features, labels))
ds = ds.batch(global_batch_size, drop_remainder=True)

# Distribute the dataset using the chosen strategy.
dist_ds = strategy.experimental_distribute_dataset(ds)

# Define a simple model building function.
def create_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(4,)),
        tf.keras.layers.Dense(4, activation="relu"),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    return model

# Build and compile the model inside strategy scope.
with strategy.scope():
    model = create_model()
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )

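# Run a single epoch with silent training logs.
# model.fit distributes a plain tf.data.Dataset itself, so pass ds
# rather than the pre-distributed dist_ds.
history = model.fit(
    ds,
    epochs=1,
    steps_per_epoch=2,
    verbose=0,
)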

# Print global metrics from model.fit history.
print("Global loss after epoch:", float(history.history["loss"][0]))
print("Global accuracy after epoch:", float(history.history["accuracy"][0]))

# Take one distributed batch to inspect shapes.
for batch_features, batch_labels in iter(dist_ds):
    per_replica_x = batch_features
    per_replica_y = batch_labels
    break

# Show the type of per replica objects.
print("Type of per_replica_x:", type(per_replica_x).__name__)

# Define a function to inspect per replica shapes.
# experimental_local_results returns one value per local replica and
# also works for plain tensors under a single-device strategy.
def inspect_per_replica(tensor_per_replica, name):
    local_values = strategy.experimental_local_results(tensor_per_replica)
    shapes = [tuple(value.shape.as_list()) for value in local_values]
    print(name, "per replica shapes:", shapes)

# Call the inspection function for features and labels.
inspect_per_replica(per_replica_x, "Features")
inspect_per_replica(per_replica_y, "Labels")

# Create a loss object that keeps one loss value per example.
# Keras 3 expects the string "none" instead of a Reduction enum.
loss_obj = tf.keras.losses.BinaryCrossentropy(reduction="none")

# Create a distributed forward step using strategy.run.
@tf.function
def distributed_step(inputs):
    def replica_step(x, y):
        # The model ends in a sigmoid, so outputs are probabilities.
        probs = model(x, training=True)
        per_example_loss = loss_obj(y, probs)
        return per_example_loss

    per_replica_losses = strategy.run(replica_step, args=inputs)
    return per_replica_losses

# Prepare one batch tuple for the distributed step.
inputs = (per_replica_x, per_replica_y)
per_replica_losses = distributed_step(inputs)

# Inspect per replica loss shapes before reduction.
inspect_per_replica(per_replica_losses, "Per example loss")

# Reduce per replica losses to a single global mean.
# axis=0 also averages over each replica's batch dimension,
# so the result is one scalar for the whole global batch.
global_loss = strategy.reduce(
    tf.distribute.ReduceOp.MEAN,
    per_replica_losses,
    axis=0,
)

# Print the final global loss value for this batch.
print("Global mean loss for one batch:", float(global_loss))



### **1.3. Distributed Checkpointing Basics**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master TensorFlow 2.20.0/Module_08/Lecture_B/image_01_03.jpg?v=1769459380" width="250">



>* Distributed checkpointing coordinates state across all workers
>* Strategy syncs model, optimizer, schedules, and variables

>* Checkpoints store logical model state, not replicas
>* Same checkpoint works across changing devices and workers

>* Choose smart save intervals and responsible writer
>* Use portable checkpoints to build robust pipelines
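The "responsible writer" idea can be sketched without TensorFlow. The snippet below assumes the `TF_CONFIG` JSON layout used for multi-worker training (`cluster` plus `task` with `type` and `index`) and the common convention that a `chief` task writes checkpoints, or worker 0 when no chief is listed. The function names are hypothetical; giving non-chief workers a throwaway directory follows the pattern in TensorFlow's multi-worker Keras tutorial, so confirm it against the version you run.

```python
import json
import os
import tempfile


def is_chief(tf_config_json):
    """Return True if this task should write the real checkpoints.

    Hypothetical helper. An empty TF_CONFIG means single-machine
    training, where the only process is the chief.
    """
    if not tf_config_json:
        return True
    config = json.loads(tf_config_json)
    task = config.get("task", {})
    task_type = task.get("type")
    task_index = task.get("index", 0)
    if task_type == "chief":
        return True
    has_chief = "chief" in config.get("cluster", {})
    return task_type == "worker" and task_index == 0 and not has_chief


def checkpoint_dir_for_task(base_dir, tf_config_json):
    """Chief writes to base_dir; other workers get a per-task temp path."""
    if is_chief(tf_config_json):
        return base_dir
    task_index = json.loads(tf_config_json)["task"]["index"]
    return os.path.join(tempfile.gettempdir(), f"workertemp_{task_index}")


cluster = {"worker": ["host1:12345", "host2:12345"]}
for index in range(2):
    tf_config = json.dumps({"cluster": cluster,
                            "task": {"type": "worker", "index": index}})
    print("worker", index, "->", checkpoint_dir_for_task("ckpts", tf_config))

# Single-machine run: no TF_CONFIG, so this process writes directly.
print("local ->", checkpoint_dir_for_task("ckpts", ""))
```

Because only one directory holds the real checkpoints, restoring later does not depend on how many workers wrote them.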



In [None]:
#@title Python Code - Distributed Checkpointing Basics

# This script shows distributed checkpointing basics.
# It uses MirroredStrategy with model.fit safely.
# It keeps output short and beginner friendly.

# !pip install tensorflow==2.20.0

# Import required modules from TensorFlow.
import os
import pathlib
import numpy as np
import tensorflow as tf

# Set deterministic seeds for reproducibility.
np.random.seed(7)
tf.random.set_seed(7)

# Print TensorFlow version in one short line.
print("TensorFlow version:", tf.__version__)

# Detect available GPUs for potential distribution.
physical_gpus = tf.config.list_physical_devices("GPU")
print("GPUs detected:", len(physical_gpus))

# Choose strategy based on available GPUs.
if len(physical_gpus) > 1:
    strategy = tf.distribute.MirroredStrategy()
else:
    strategy = tf.distribute.get_strategy()

# Show which strategy class is being used.
print("Using strategy:", strategy.__class__.__name__)

# Create a small directory for checkpoints.
base_dir = pathlib.Path("distributed_ckpt_demo")
base_dir.mkdir(exist_ok=True)
ckpt_dir = base_dir / "checkpoints"
ckpt_dir.mkdir(exist_ok=True)

# Prepare a tiny synthetic dataset for classification.
num_samples = 256
num_features = 20
num_classes = 3

# Create random features and integer labels.
features = np.random.randn(num_samples, num_features).astype("float32")
labels = np.random.randint(num_classes, size=(num_samples,)).astype("int32")

# Validate shapes before building the dataset.
assert features.shape[0] == labels.shape[0]
assert features.shape[1] == num_features

# Build a tf.data.Dataset with small batch size.
batch_size = 32
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

# Shuffle and batch the dataset deterministically.
dataset = dataset.shuffle(buffer_size=num_samples, seed=7)
dataset = dataset.batch(batch_size)

# Define a simple model inside the strategy scope.
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(num_features,)),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])

# Compile the model with optimizer and loss.
with strategy.scope():
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=["accuracy"],
    )

# Create a checkpoint callback that saves every epoch.
ckpt_path = str(ckpt_dir / "model.epoch{epoch:02d}.keras")
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=ckpt_path,
    save_weights_only=False,
    save_freq="epoch",
    monitor="loss",
    save_best_only=False,
    verbose=0,
)

# Train briefly with model.fit under the strategy.
history = model.fit(
    dataset,
    epochs=2,
    verbose=0,
    callbacks=[checkpoint_cb],
)

# List checkpoint files to show what was saved.
ckpt_files = sorted(ckpt_dir.glob("*.keras"))
print("Checkpoint files saved:")
for path in ckpt_files:
    print("-", path.name)

# Load the last checkpoint into a new model instance.
with strategy.scope():
    restored_model = tf.keras.models.load_model(ckpt_files[-1])

# Evaluate original and restored models on one batch.
for batch_features, batch_labels in dataset.take(1):
    original_eval = model.evaluate(
        batch_features,
        batch_labels,
        verbose=0,
    )
    restored_eval = restored_model.evaluate(
        batch_features,
        batch_labels,
        verbose=0,
    )

# Print a short comparison of evaluation results.
print("Original model loss, accuracy:", original_eval)
print("Restored model loss, accuracy:", restored_eval)




## **2. Custom Distributed Loops**

### **2.1. Writing Strategy Run Steps**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master TensorFlow 2.20.0/Module_08/Lecture_B/image_02_01.jpg?v=1769459468" width="250">



>* The step function runs on each replica over its batch slice
>* You define the computation; the strategy runs it and aggregates updates

>* Put the per-replica forward pass and loss inside the step
>* Share variables, then compute, reduce, and apply gradients

>* Keep step functions efficient, deterministic, and side-effect free
>* Scale losses consistently so aggregated gradients train stably
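The loss-scaling bullet can be checked with plain arithmetic. This sketch uses a hypothetical pure-Python stand-in for `tf.nn.compute_average_loss` (ignoring its optional `sample_weight`): each replica sums its per-example losses and divides by the *global* batch size, so adding the replica results gives the global mean. Dividing by the local batch size instead inflates the sum by the number of replicas.

```python
def average_loss(per_example_loss, global_batch_size):
    """Pure-Python stand-in for tf.nn.compute_average_loss (no weights)."""
    return sum(per_example_loss) / global_batch_size


global_batch = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
GLOBAL_BATCH_SIZE = len(global_batch)
replica_slices = [global_batch[:4], global_batch[4:]]  # two replicas

# Correct: scale by the global batch, then SUM across replicas.
scaled = [average_loss(s, GLOBAL_BATCH_SIZE) for s in replica_slices]
summed_scaled = sum(scaled)

# Incorrect: scale by the local batch, then SUM across replicas.
local_means = [sum(s) / len(s) for s in replica_slices]
summed_local = sum(local_means)

print("Global mean loss:", sum(global_batch) / GLOBAL_BATCH_SIZE)
print("SUM of globally scaled losses:", summed_scaled)
print("SUM of local means:", summed_local)
```

Since gradients are linear in the loss, the same factor-of-replicas error would carry into the summed gradients, acting like an unintended learning-rate increase.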



In [None]:
#@title Python Code - Writing Strategy Run Steps

# This script shows a simple distributed training step.
# It focuses on strategy.run with a custom step.
# Use it to understand replica step functions.

# !pip install tensorflow==2.20.0

# Import required modules from TensorFlow.
import tensorflow as tf

# Set deterministic seeds for reproducibility.
tf.random.set_seed(7)

# Detect available GPUs and choose a strategy.
physical_gpus = tf.config.list_physical_devices("GPU")

# Use MirroredStrategy when any GPU is present; otherwise use the CPU.
if physical_gpus:
    strategy = tf.distribute.MirroredStrategy()
else:
    strategy = tf.distribute.OneDeviceStrategy("/cpu:0")

# Print TensorFlow version and strategy type.
print("TF", tf.__version__, "Strategy", type(strategy).__name__)

# Create a tiny synthetic dataset for demonstration.
features = tf.random.normal(shape=(64, 4))

# Create simple labels as a linear function.
labels = tf.reduce_sum(features, axis=1, keepdims=True)

# Build a tf.data.Dataset from tensors.
base_ds = tf.data.Dataset.from_tensor_slices((features, labels))

# Define global batch size for batching and loss scaling.
GLOBAL_BATCH_SIZE = 8

# Batch the dataset with the global batch size.
base_ds = base_ds.batch(GLOBAL_BATCH_SIZE, drop_remainder=True)

# Distribute the dataset using the chosen strategy.
train_ds = strategy.experimental_distribute_dataset(base_ds)

# Create model and optimizer inside strategy scope.
with strategy.scope():
    # Build a tiny sequential regression model.
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(1)
    ])

    # Use a simple optimizer for gradient updates.
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.05)

    # Define a mean squared error loss that keeps per example values.
    loss_obj = tf.keras.losses.MeanSquaredError(reduction="none")

# Define a function to compute per replica loss.
def compute_loss(labels, predictions):
    # Compute unscaled per example loss values.
    per_example_loss = loss_obj(labels, predictions)

    # Scale loss by global batch size for correctness.
    return tf.nn.compute_average_loss(
        per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE
    )

# Define one training step to run on each replica.
# It is traced inside distributed_train_step, so it needs no decorator.
def train_step(dist_inputs):
    # Unpack distributed features and labels.
    dist_features, dist_labels = dist_inputs

    # Record operations for automatic differentiation.
    with tf.GradientTape() as tape:
        # Forward pass through the model on each replica.
        predictions = model(dist_features, training=True)

        # Compute scaled loss for this replica.
        loss = compute_loss(dist_labels, predictions)

    # Compute gradients of loss with respect to variables.
    gradients = tape.gradient(loss, model.trainable_variables)

    # Apply gradients to update shared model weights.
    optimizer.apply_gradients(
        zip(gradients, model.trainable_variables)
    )

    # Return the replica loss for later reduction.
    return loss

# Define a function that calls strategy.run on the step.
@tf.function
def distributed_train_step(dist_inputs):
    # Run train_step on each replica in parallel.
    per_replica_losses = strategy.run(train_step, args=(dist_inputs,))

    # Reduce losses to get a single scalar value.
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None
    )

# Run a few epochs over the small dataset.
for epoch in range(3):
    # Initialize metric to track average loss.
    epoch_loss = 0.0

    # Initialize counter for number of batches.
    num_batches = 0

    # Iterate over distributed batches from dataset.
    for batch in train_ds:
        # Call the distributed training step.
        loss_value = distributed_train_step(batch)

        # Accumulate loss and batch count.
        epoch_loss += loss_value
        num_batches += 1

    # Compute mean loss for this epoch.
    mean_loss = epoch_loss / tf.cast(num_batches, tf.float32)

    # Print a short summary line for the epoch.
    print("Epoch", epoch, "mean loss", float(mean_loss))

# Run one more distributed step and keep its loss.
for batch in iter(train_ds):
    final_loss = distributed_train_step(batch)
    break

# Print final loss value to confirm training behavior.
print("Final distributed step loss", float(final_loss))



### **2.2. Building Distributed Datasets**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master TensorFlow 2.20.0/Module_08/Lecture_B/image_02_02.jpg?v=1769459541" width="250">



>* Think in global batches split across replicas
>* Design pipelines for consistent, non-overlapping replica data

>* Build a simple pipeline that outputs global batches
>* Strategy splits batches so replicas get balanced shards

>* Shard data so each worker gets unique subsets
>* Combine sharding, shuffling, batching for efficient pipelines
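The sharding bullets rest on one rule: every worker must see a disjoint slice of the data. The pure-Python sketch below mimics the documented behavior of `tf.data.Dataset.shard(num_shards, index)`, which keeps elements whose position modulo `num_shards` equals `index`; the `shard` and `batch` helpers are hypothetical stand-ins, not TensorFlow code. Sharding first, then shuffling and batching within each shard, keeps workers from overlapping.

```python
def shard(elements, num_shards, index):
    """Keep every num_shards-th element starting at position index."""
    return [x for position, x in enumerate(elements)
            if position % num_shards == index]


def batch(elements, batch_size, drop_remainder=False):
    """Group elements into lists of batch_size, like Dataset.batch."""
    batches = [elements[i:i + batch_size]
               for i in range(0, len(elements), batch_size)]
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


examples = list(range(10))
num_workers = 2
worker_shards = [shard(examples, num_workers, w) for w in range(num_workers)]
for worker_id, worker_data in enumerate(worker_shards):
    print("worker", worker_id, "batches:", batch(worker_data, batch_size=2))
```

Each worker ends up with a short final batch here, which is the situation `drop_remainder=True` avoids.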



In [None]:
#@title Python Code - Building Distributed Datasets

# This script shows distributed dataset basics.
# It uses TensorFlow strategy with simple data.
# Focus is on global and per replica batches.

# !pip install tensorflow==2.20.0

# Import required modules for TensorFlow training.
import os
import random
import numpy as np
import tensorflow as tf

# Set deterministic seeds for reproducible behavior.
seed_value = 42
random.seed(seed_value)
np.random.seed(seed_value)
tf.random.set_seed(seed_value)

# Print TensorFlow version in one short line.
print("TensorFlow version:", tf.__version__)

# Choose a distribution strategy based on hardware.
if len(tf.config.list_logical_devices("GPU")) > 1:
    strategy = tf.distribute.MirroredStrategy()
else:
    strategy = tf.distribute.OneDeviceStrategy("/cpu:0")

# Show which strategy class is being used.
print("Using strategy:", strategy.__class__.__name__)

# Define small synthetic dataset parameters clearly.
num_examples = 32
feature_dim = 4
num_classes = 3

# Create simple numeric features and integer labels.
features = np.random.randn(num_examples, feature_dim).astype("float32")
labels = np.random.randint(num_classes, size=(num_examples,))

# Wrap arrays in a tf.data Dataset object.
base_ds = tf.data.Dataset.from_tensor_slices((features, labels))

# Define a small global batch size for training.
global_batch_size = 8

# Shuffle with a fixed seed, then batch to form global batches.
train_ds = (base_ds.shuffle(buffer_size=num_examples, seed=seed_value)
            .batch(global_batch_size))

# Validate that batched shapes match expectations.
for batch_x, batch_y in train_ds.take(1):
    print("Global batch shape:", batch_x.shape)

# Create a distributed dataset from the global dataset.
dist_train_ds = strategy.experimental_distribute_dataset(train_ds)

# Define a simple model building function.
def build_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(feature_dim,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(num_classes)
    ])
    return model

# Build model, optimizer, and loss inside strategy scope.
with strategy.scope():
    model = build_model()
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True,
        reduction="none"
    )

# Define a function to compute per example loss.
def compute_loss(labels, logits):
    per_example_loss = loss_obj(labels, logits)
    return tf.nn.compute_average_loss(
        per_example_loss,
        global_batch_size=global_batch_size
    )

# Define one training step run on each replica.
@tf.function
def train_step(dist_inputs):
    def replica_step(inputs):
        x, y = inputs
        with tf.GradientTape() as tape:
            logits = model(x, training=True)
            loss = compute_loss(y, logits)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

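    per_replica_losses = strategy.run(replica_step, args=(dist_inputs,))
    # Each replica loss is already divided by the global batch size,
    # so summing across replicas yields the global mean loss.
    mean_loss = strategy.reduce(
        tf.distribute.ReduceOp.SUM,
        per_replica_losses,
        axis=None
    )
    return mean_loss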

# Run a short custom loop over distributed dataset.
step = 0
for dist_batch in dist_train_ds:
    step += 1
    loss_value = train_step(dist_batch)
    print("Step", step, "loss:", float(loss_value))

# Confirm that training finished without shape issues.
print("Finished custom distributed loop successfully.")



### **2.3. Replica Metrics Reduction**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master TensorFlow 2.20.0/Module_08/Lecture_B/image_02_03.jpg?v=1769459665" width="250">



>* Each replica computes its own local metrics
>* We must aggregate replicas’ metrics for trustworthy monitoring

>* Compute per-replica metrics, then reduce centrally
>* Sum losses already scaled by global batch, and sum counts

>* Consistent reduction keeps metrics stable when scaling
>* Correct reductions ensure trustworthy gradients and alerts
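Why sum counts before dividing? In this pure-Python sketch, two hypothetical replicas report `(correct, seen)` counts for an uneven final batch; the numbers are made up for illustration. Summing both counts and dividing once gives the true accuracy, which matches reducing with `tf.distribute.ReduceOp.SUM`; averaging the per-replica accuracies over-weights the replica that saw fewer examples.

```python
# Hypothetical per-replica results for one uneven batch: (correct, seen).
replica_results = [(3, 4), (1, 1)]

# SUM-style reduction: total correct divided by total examples.
total_correct = sum(correct for correct, _ in replica_results)
total_seen = sum(seen for _, seen in replica_results)
summed_accuracy = total_correct / total_seen

# MEAN-style reduction of per-replica accuracies.
per_replica_accuracy = [correct / seen for correct, seen in replica_results]
mean_of_accuracies = sum(per_replica_accuracy) / len(per_replica_accuracy)

print("Accuracy from summed counts:", summed_accuracy)
print("Mean of per-replica accuracies:", mean_of_accuracies)
```

With equal shard sizes the two numbers agree, which is why the mistake often goes unnoticed until a partial batch appears.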



In [None]:
#@title Python Code - Replica Metrics Reduction

# This script shows replica metrics reduction simply.
# It uses TensorFlow strategy with a tiny dataset.
# Focus is on custom loop and metric aggregation.

# !pip install tensorflow==2.20.0

# Import required standard libraries safely.
import os
import random
import numpy as np

# Import TensorFlow and check version.
import tensorflow as tf

# Set deterministic seeds for reproducibility.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Print TensorFlow version in one short line.
print("TensorFlow version:", tf.__version__)

# Choose strategy based on available GPUs.
if len(tf.config.list_physical_devices("GPU")) > 0:
    strategy = tf.distribute.MirroredStrategy()
else:
    strategy = tf.distribute.OneDeviceStrategy("/cpu:0")

# Create a tiny synthetic classification dataset.
num_samples = 64
num_features = 8
num_classes = 3

# Generate random features and integer labels.
x_data = np.random.randn(num_samples, num_features).astype("float32")
y_data = np.random.randint(num_classes, size=(num_samples,)).astype("int32")

# Validate shapes before building dataset.
assert x_data.shape[0] == y_data.shape[0]

# Create a tf.data.Dataset with small batch size.
batch_size = 16
dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data))

# Shuffle and batch the dataset deterministically.
dataset = dataset.shuffle(num_samples, seed=SEED).batch(batch_size)

# Distribute the dataset using the chosen strategy.
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Build a simple model inside strategy scope.
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(num_features,)),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(num_classes)
    ])

# Define loss object and optimizer inside scope.
with strategy.scope():
    loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction="none"
    )
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

# Define a function to compute per replica loss.
# It runs inside the traced train_step, so it needs no decorator.
def compute_loss(labels, predictions):
    per_example_loss = loss_obj(labels, predictions)
    return tf.nn.compute_average_loss(
        per_example_loss, global_batch_size=batch_size
    )

# Define one training step run on each replica.
@tf.function
def train_step(dist_inputs):
    def replica_step(inputs):
        features, labels = inputs
        with tf.GradientTape() as tape:
            logits = model(features, training=True)
            loss = compute_loss(labels, logits)
            preds = tf.argmax(logits, axis=1, output_type=tf.int32)
            correct = tf.cast(tf.equal(preds, labels), tf.float32)
            correct_count = tf.reduce_sum(correct)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss, correct_count

    per_replica_loss, per_replica_correct = strategy.run(
        replica_step, args=(dist_inputs,)
    )

    # Losses are scaled by the global batch size, so SUM gives the mean.
    mean_loss = strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None
    )
    total_correct = strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_correct, axis=None
    )
    return mean_loss, total_correct

# Run a single epoch over the tiny distributed dataset.
num_epochs = 1
for epoch in range(num_epochs):
    epoch_loss = 0.0
    epoch_correct = 0.0
    num_batches = 0

    for batch in dist_dataset:
        mean_loss, total_correct = train_step(batch)
        epoch_loss += mean_loss.numpy()
        epoch_correct += total_correct.numpy()
        num_batches += 1

    avg_loss = epoch_loss / float(num_batches)
    total_examples = float(num_batches * batch_size)
    accuracy = epoch_correct / total_examples

# Print final reduced metrics from all replicas.
print("Reduced mean loss over epoch:", float(avg_loss))
print("Reduced accuracy over epoch:", float(accuracy))



## **3. Distributed Debugging Essentials**

### **3.1. Replica Shape Mismatches**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master TensorFlow 2.20.0/Module_08/Lecture_B/image_03_01.jpg?v=1769459765" width="250">



>* Replicas must see tensors with identical shapes
>* Uneven batches or preprocessing differences cause mismatches

>* Align global and per-replica batch sizes
>* Normalize input shapes to avoid replica mismatches

>* Data-dependent conditionals can desynchronize replica tensor shapes
>* Use deterministic logic, preprocessing, and shape logging
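Before launching a run, you can check the arithmetic behind most shape mismatches. This hypothetical `batch_plan` helper reports how many full global batches a dataset yields, how large the leftover batch is (the one `drop_remainder=True` discards), and whether the global batch divides evenly across replicas. The usage line matches the 11-example, global-batch-8 setup in the code below, assuming two replicas.

```python
def batch_plan(num_samples, global_batch_size, num_replicas):
    """Summarize how a dataset splits into global and per-replica batches."""
    full_batches, leftover = divmod(num_samples, global_batch_size)
    per_replica, uneven = divmod(global_batch_size, num_replicas)
    return {
        "full_batches": full_batches,
        "leftover_examples": leftover,  # dropped if drop_remainder=True
        "per_replica_batch": per_replica,
        "splits_evenly": uneven == 0,
    }


plan = batch_plan(num_samples=11, global_batch_size=8, num_replicas=2)
print(plan)
if plan["leftover_examples"]:
    print("Warning: last batch is partial; consider drop_remainder=True.")
if not plan["splits_evenly"]:
    print("Warning: global batch does not divide across replicas.")
```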



In [None]:
#@title Python Code - Replica Shape Mismatches

# This script demonstrates replica shape mismatches.
# It uses TensorFlow distribution strategy with simple data.
# Focus is on debugging shapes across replicas.

# !pip install tensorflow==2.20.0

# Import required modules safely.
import os
import random
import numpy as np
import tensorflow as tf

# Set deterministic seeds for reproducibility.
seed_value = 42
random.seed(seed_value)
np.random.seed(seed_value)
tf.random.set_seed(seed_value)

# Print TensorFlow version and device information.
print("TensorFlow version:", tf.__version__)
print("GPUs:", len(tf.config.list_physical_devices("GPU")))

# Choose strategy based on available GPUs.
if len(tf.config.list_physical_devices("GPU")) > 1:
    strategy = tf.distribute.MirroredStrategy()
else:
    strategy = tf.distribute.OneDeviceStrategy("/cpu:0")

# Show chosen strategy for clarity.
print("Using strategy:", type(strategy).__name__)

# Define global and per replica batch sizes.
num_replicas = strategy.num_replicas_in_sync
global_batch_size = 8
per_replica_batch = global_batch_size // num_replicas

# Print basic batch configuration.
print("Replicas:", num_replicas, "Global batch:", global_batch_size)
print("Per replica batch:", per_replica_batch)

# Create toy features and labels with odd size.
num_samples = global_batch_size + 3
x_data = tf.random.normal((num_samples, 4))
y_data = tf.random.normal((num_samples, 1))

# Build dataset without dropping remainder first.
ds_bad = tf.data.Dataset.from_tensor_slices((x_data, y_data))
ds_bad = ds_bad.batch(global_batch_size, drop_remainder=False)

# Build dataset with drop_remainder to fix shapes.
ds_good = tf.data.Dataset.from_tensor_slices((x_data, y_data))
ds_good = ds_good.batch(global_batch_size, drop_remainder=True)

# Distribute both datasets using the strategy.
dist_bad = strategy.experimental_distribute_dataset(ds_bad)
dist_good = strategy.experimental_distribute_dataset(ds_good)

# Define a simple model inside strategy scope.
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(4,)),
        tf.keras.layers.Dense(4, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
    loss_fn = tf.keras.losses.MeanSquaredError(reduction="none")


# Define one training step using strategy.run.
@tf.function
def train_step(dist_inputs):
    def step_fn(inputs):
        x_batch, y_batch = inputs
        with tf.GradientTape() as tape:
            preds = model(x_batch, training=True)
            per_example_loss = loss_fn(y_batch, preds)
            loss = tf.nn.compute_average_loss(
                per_example_loss, global_batch_size=global_batch_size
            )
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss, tf.shape(x_batch)[0]

    per_replica_loss, per_replica_size = strategy.run(step_fn, (dist_inputs,))
    # Losses are scaled by the global batch size, so SUM gives the mean.
    mean_loss = strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None
    )
    total_size = strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_size, axis=None
    )
    return mean_loss, total_size


# Helper to inspect one distributed batch shapes.
def inspect_batch(dist_batch, label):
    print("\nInspecting", label)

    def show_fn(inputs):
        x_batch, y_batch = inputs
        print("Replica x shape:", x_batch.shape)
        print("Replica y shape:", y_batch.shape)
        return 0

    _ = strategy.run(show_fn, (dist_batch,))


# Inspect every bad batch: only the last one (3 examples) is partial.
for batch_index, bad_batch in enumerate(dist_bad):
    inspect_batch(bad_batch, f"batch {batch_index} without drop_remainder")

# The good dataset yields only full batches, so one is enough.
for good_batch in iter(dist_good):
    inspect_batch(good_batch, "batch with drop_remainder")
    break

# Run one safe training step on the good batch.
loss_value, seen_size = train_step(good_batch)
print("\nTrain step finished. Loss:", float(loss_value))
print("Examples seen in step:", int(seen_size))
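
Before choosing `drop_remainder`, it can help to check whether the dataset size divides evenly by the global batch size. The helper below is a hypothetical pure-Python sketch (`plan_batches` is not part of `tf.distribute`); it assumes the global batch is `per_replica_batch * num_replicas`, as in the examples in this lecture.

```python
# Hypothetical helper (not part of tf.distribute): check how a dataset of
# num_examples splits into global batches before distributing it.
def plan_batches(num_examples, per_replica_batch, num_replicas):
    """Describe full global batches and any leftover partial batch."""
    global_batch = per_replica_batch * num_replicas
    full_batches, remainder = divmod(num_examples, global_batch)
    return {
        "global_batch": global_batch,
        "full_batches": full_batches,
        # Examples left over after the last full batch. With
        # drop_remainder=True they are dropped each epoch; without it they
        # form a smaller final batch whose size differs from global_batch.
        "remainder": remainder,
        "examples_used_with_drop": full_batches * global_batch,
        "has_partial_batch": remainder > 0,
    }


# Example: 100 examples, 8 per replica, 4 replicas (global batch of 32).
plan = plan_batches(num_examples=100, per_replica_batch=8, num_replicas=4)
print(plan)
```

A non-zero `remainder` is exactly the case where `drop_remainder=True` avoids a ragged final batch, at the cost of skipping those examples in that epoch.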



### **3.2. Managing Local Variables**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master TensorFlow 2.20.0/Module_08/Lecture_B/image_03_02.jpg?v=1769459896" width="250">



>* Local variables keep separate values on each replica
>* Confusing local and global variables can silently corrupt training

>* Create variables in the correct strategy and replica scopes
>* Avoid unintended per-replica copies of weights and metric counters

>* Decide which variables must stay local and which are shared
>* Trace creation, updates, and aggregation to spot issues
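
The difference between per-replica (local) state and its aggregated value can be illustrated without TensorFlow. The sketch below is a simplified NumPy model, assuming a `Mean`-style metric that stores a running total and count on each replica and sums them when read; it does not reproduce how `tf.distribute` actually stores these variables.

```python
import numpy as np

# Pure-NumPy simulation (no TensorFlow): each replica keeps its own local
# (total, count) pair, like a Mean metric's state variables, and the copies
# are only combined (summed) when the metric is read.
per_example_losses = np.array([0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6])
num_replicas = 2

# Split the global batch evenly across replicas.
shards = np.split(per_example_losses, num_replicas)

# Local state: each replica updates only its own copy during the step.
local_totals = [float(shard.sum()) for shard in shards]
local_counts = [int(shard.size) for shard in shards]

# Reading the metric sums the local copies, then divides total by count.
global_mean = sum(local_totals) / sum(local_counts)

# Reading only one replica's copy gives a biased, replica-local answer.
replica0_mean = local_totals[0] / local_counts[0]

print("local totals:", local_totals)
print("aggregated mean:", round(global_mean, 4))
print("replica 0 only:", round(replica0_mean, 4))
```

If a variable that should be aggregated is accidentally read from a single replica, the result looks plausible but describes only part of the batch, which is why tracing how each variable is aggregated matters.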



In [None]:
#@title Python Code - Managing Local Variables

# This script shows managing local variables.
# It uses TensorFlow distribution strategies safely.
# Focus on debugging variable placement issues.

# Install TensorFlow in some environments if needed.
# !pip install tensorflow==2.20.0

# Import required standard libraries.
import os
import random
import numpy as np

# Import TensorFlow and distribution strategies.
import tensorflow as tf
from tensorflow import keras

# Set deterministic seeds for reproducibility.
seed_value = 42
random.seed(seed_value)
np.random.seed(seed_value)

# Set TensorFlow random seed deterministically.
tf.random.set_seed(seed_value)

# Print TensorFlow version in one short line.
print("TensorFlow version:", tf.__version__)

# Choose a simple distribution strategy for demo.
strategy = tf.distribute.MirroredStrategy()

# Create tiny synthetic dataset for quick training.
features = np.random.randn(64, 4).astype("float32")
labels = np.random.randint(0, 2, size=(64, 1)).astype("float32")

# Validate shapes before building dataset.
print("Features shape:", features.shape)
print("Labels shape:", labels.shape)

# Build a tf.data dataset with global batch size.
batch_size = 8
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

# Shuffle and batch the dataset safely.
dataset = dataset.shuffle(64, seed=seed_value).batch(batch_size)

# Distribute the dataset using the chosen strategy.
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Define a simple model building function.
def build_model():
    # Create a tiny sequential model.
    model = keras.Sequential([
        keras.layers.Input(shape=(4,)),
        keras.layers.Dense(4, activation="relu"),
        keras.layers.Dense(1, activation="sigmoid"),
    ])

    # Compile model with simple optimizer.
    model.compile(
        optimizer=keras.optimizers.SGD(learning_rate=0.1),
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )
    return model

# Create model inside strategy scope for correct placement.
with strategy.scope():
    model = build_model()

# Define a metric created once, not per replica.
with strategy.scope():
    train_loss = keras.metrics.Mean(name="train_loss")

# Inspect the metric's state variable type; the exact class name
# depends on the TensorFlow and Keras versions in use.
print("Metric variable type:", type(train_loss.variables[0]).__name__)

# Define loss object for custom training loop.
# "none" keeps one loss per example (Keras 3 has no Reduction enum).
loss_obj = keras.losses.BinaryCrossentropy(reduction="none")

# Define per replica loss computation function.
def compute_loss(labels_replica, preds_replica):
    # Compute unreduced per example loss.
    per_example_loss = loss_obj(labels_replica, preds_replica)

    # Scale loss by global batch size.
    return tf.nn.compute_average_loss(
        per_example_loss,
        global_batch_size=batch_size,
    )

# Create optimizer inside strategy scope for safety.
with strategy.scope():
    optimizer = keras.optimizers.SGD(learning_rate=0.1)

# Define one training step run on each replica.
@tf.function
def train_step(dist_inputs):
    # Unpack distributed features and labels.
    def step_fn(inputs):
        x_batch, y_batch = inputs

        # Use GradientTape for custom training.
        with tf.GradientTape() as tape:
            preds = model(x_batch, training=True)
            loss = compute_loss(y_batch, preds)

        # Compute gradients and apply updates.
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

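        # Track the unscaled per-example loss so the metric reports the true
        # mean; `loss` is divided by the global batch size and would
        # under-report by a factor of num_replicas_in_sync.
        train_loss.update_state(loss_obj(y_batch, preds))
        return loss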

    # Run step function on each replica.
    per_replica_losses = strategy.run(step_fn, args=(dist_inputs,))

    # Sum the scaled per-replica losses; since each is already divided by
    # the global batch size, the sum equals the global mean loss.
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM,
        per_replica_losses,
        axis=None,
    )

# Run a short custom training loop.
for epoch in range(2):
    # Reset metric at the start of each epoch.
    train_loss.reset_state()

    # Iterate over distributed batches.
    for batch_inputs in dist_dataset:
        loss_value = train_step(batch_inputs)

    # Print epoch summary with metric result.
    print(
        "Epoch",
        epoch,
        "mean loss:",
        float(train_loss.result()),
    )
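
Why `ReduceOp.SUM` is the right reduction for losses scaled with `tf.nn.compute_average_loss`, and why feeding those scaled losses directly into a `Mean` metric under-reports by the replica count, can be checked with plain arithmetic. This NumPy sketch assumes an even split of the global batch across replicas and mirrors only the scaling rule, not TensorFlow's execution.

```python
import numpy as np

# Pure-NumPy sketch of the compute_average_loss scaling rule: each replica
# sums its per-example losses and divides by the GLOBAL batch size.
per_example_losses = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
num_replicas = 3
global_batch_size = per_example_losses.size

shards = np.split(per_example_losses, num_replicas)
scaled_per_replica = [shard.sum() / global_batch_size for shard in shards]

# SUM over replicas recovers the true mean over the global batch.
sum_reduced = float(np.sum(scaled_per_replica))
# MEAN over replicas divides by num_replicas a second time.
mean_reduced = float(np.mean(scaled_per_replica))

print("true mean:", float(per_example_losses.mean()))
print("SUM reduce:", sum_reduced)
print("MEAN reduce:", mean_reduced)
```

With a single replica both reductions agree, so this mistake often goes unnoticed until training moves to several GPUs.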




### **3.3. Monitoring Distributed Training**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master TensorFlow 2.20.0/Module_08/Lecture_B/image_03_03.jpg?v=1769459946" width="250">



>* Monitor replica behavior and resource use continuously
>* Track per-replica metrics to catch hidden issues

>* Track both global and per-worker training metrics
>* Use detailed logs to expose replica-specific failures

>* Log key steps to catch rare errors
>* Use dashboards and profiling to trace replica activity
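
Per-replica tracking can be as simple as comparing each replica's loss for a step against the others. The helper below is a hypothetical pure-Python sketch (`flag_outlier_replicas` and its `rel_tolerance` threshold are illustrative choices, not a TensorFlow API); in TensorFlow, the per-replica values could be obtained with `strategy.experimental_local_results(per_replica_loss)`.

```python
import math
import statistics


def flag_outlier_replicas(replica_losses, rel_tolerance=0.5):
    """Return indices of replicas with a non-finite loss, or a loss that
    differs from the median of the finite losses by more than
    rel_tolerance * |median|."""
    finite = [v for v in replica_losses if math.isfinite(v)]
    if not finite:
        # Every replica diverged; flag them all.
        return list(range(len(replica_losses)))
    median = statistics.median(finite)
    flagged = []
    for idx, value in enumerate(replica_losses):
        if not math.isfinite(value):
            flagged.append(idx)
        elif abs(value - median) > rel_tolerance * abs(median):
            flagged.append(idx)
    return flagged


# Example: replica 2 has drifted far from the others and replica 3 hit NaN.
losses = [0.52, 0.49, 2.10, float("nan")]
print(flag_outlier_replicas(losses))  # [2, 3]
```

The median is used rather than the mean so that a single diverging replica does not drag the reference value toward itself.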



In [None]:
#@title Python Code - Monitoring Distributed Training

# This script shows basic distributed monitoring.
# It uses MirroredStrategy with simple metrics.
# Focus is on safe concise debugging ideas.

# !pip install tensorflow==2.20.0

# Import required standard libraries.
import os
import random
import numpy as np

# Import TensorFlow and check version.
import tensorflow as tf

# Set deterministic seeds for reproducibility.
seed_value = 42
random.seed(seed_value)

# Set numpy and tensorflow seeds deterministically.
np.random.seed(seed_value)
tf.random.set_seed(seed_value)

# Print TensorFlow version in one short line.
print("TensorFlow version:", tf.__version__)

# Choose devices automatically for strategy.
if tf.config.list_logical_devices("GPU"):
    strategy = tf.distribute.MirroredStrategy()
else:
    strategy = tf.distribute.OneDeviceStrategy("/cpu:0")

# Print number of replicas for quick check.
print("Replicas in sync:", strategy.num_replicas_in_sync)

# Create a tiny synthetic dataset safely.
num_samples = 256
features = np.random.randn(num_samples, 8).astype("float32")

# Create simple binary labels from features.
labels = (np.sum(features, axis=1) > 0).astype("float32")

# Validate shapes before building dataset.
assert features.shape[0] == labels.shape[0]

# Define a global batch size that is a multiple of the replica count.
per_replica_batch = 8
global_batch_size = per_replica_batch * strategy.num_replicas_in_sync

# Build tf.data dataset with batching.
base_ds = tf.data.Dataset.from_tensor_slices((features, labels))
base_ds = base_ds.shuffle(256, seed=seed_value)

# Batch and repeat for a few steps.
base_ds = base_ds.batch(global_batch_size).repeat(3)

# Distribute dataset for the chosen strategy.
dist_ds = strategy.experimental_distribute_dataset(base_ds)

# Define a simple model building function.
def create_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(8,)),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    return model

# Create the model and optimizer inside strategy scope.
with strategy.scope():
    model = create_model()
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

# Use BinaryCrossentropy with no reduction (one loss per example).
# Keras 3 has no Reduction enum, so pass the string "none".
with strategy.scope():
    loss_obj = tf.keras.losses.BinaryCrossentropy(
        from_logits=False,
        reduction="none",
    )

# Define a function to compute per replica loss.
def compute_loss(labels, predictions):
    per_example_loss = loss_obj(labels, predictions)
    per_example_loss = tf.reshape(per_example_loss, [-1])
    return tf.nn.compute_average_loss(
        per_example_loss,
        global_batch_size=global_batch_size,
    )

# Create metrics for monitoring training.
with strategy.scope():
    train_loss = tf.keras.metrics.Mean(name="train_loss")
    train_acc = tf.keras.metrics.BinaryAccuracy(name="train_accuracy")

# Define the per replica train step function.
def train_step(inputs):
    features_batch, labels_batch = inputs
    tf.debugging.assert_shapes(
        [(features_batch, (None, 8)), (labels_batch, (None,))]
    )
    # Keras losses average over the last axis, so rank-1 labels and
    # predictions would collapse to one scalar per replica. Give labels a
    # trailing axis to match the (batch, 1) model output instead.
    labels_2d = tf.expand_dims(labels_batch, axis=-1)
    with tf.GradientTape() as tape:
        predictions = model(features_batch, training=True)
        loss = compute_loss(labels_2d, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # Log the unscaled per-example loss so the metric reports the true mean.
    train_loss.update_state(loss_obj(labels_2d, predictions))
    train_acc.update_state(labels_2d, predictions)
    return loss

# Define one distributed training step with monitoring.
@tf.function
def distributed_train_step(dist_inputs):
    per_replica_losses = strategy.run(train_step, args=(dist_inputs,))
    # Losses are pre-scaled by the global batch size, so SUM gives the mean.
    mean_loss = strategy.reduce(
        tf.distribute.ReduceOp.SUM,
        per_replica_losses,
        axis=None,
    )
    return mean_loss

# Run a short custom training loop with logs.
step_times = []
for step, batch in enumerate(dist_ds):
    if step >= 5:
        break
    # Step 0 includes tf.function tracing, so expect it to be slower.
    start = tf.timestamp()
    mean_loss = distributed_train_step(batch)
    end = tf.timestamp()
    step_times.append(float(end - start))
    if (step + 1) % 2 == 0:
        print(
            "Step",
            step + 1,
            "loss=",
            float(train_loss.result()),
            "acc=",
            float(train_acc.result()),
        )

# Print simple monitoring summary at the end.
avg_step_time = sum(step_times) / len(step_times)
print("Average step time (seconds):", round(avg_step_time, 4))
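
The average above includes the first step, which also pays for `tf.function` tracing. A hypothetical helper like the one below (a pure-Python sketch, not a TensorFlow API; `warmup_steps` is an illustrative parameter) excludes warm-up steps and reports throughput, assuming every timed step processed a full global batch.

```python
# Hypothetical helper (not a TensorFlow API): summarize step timings while
# skipping warm-up steps, whose duration includes tracing overhead.
def summarize_step_times(step_times, global_batch_size, warmup_steps=1):
    steady = step_times[warmup_steps:]
    if not steady:
        raise ValueError("Need more steps than warmup_steps to summarize.")
    avg = sum(steady) / len(steady)
    return {
        "avg_step_s": avg,
        "max_step_s": max(steady),
        # Assumes each step processed a full global batch.
        "examples_per_s": global_batch_size / avg,
    }


# Example: the first step is slow because it includes tracing.
summary = summarize_step_times([2.0, 0.1, 0.1, 0.2, 0.1], global_batch_size=16)
print(summary)
```

In the script above, it could be called as `summarize_step_times(step_times, global_batch_size)`; a large gap between `avg_step_s` and `max_step_s` is a hint to look for stragglers or input-pipeline stalls.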



# <font color="#418FDE" size="6.5" uppercase>**Implementing Strategies**</font>


In this lecture, you learned to:
- Configure and run model.fit under a tf.distribute strategy for multi-GPU or multi-worker setups. 
- Implement a custom training loop that uses strategy.run and distributed datasets. 
- Troubleshoot common distributed training errors related to shapes, batch sizes, and variable placement. 

In the next Module (Module 9), we will go over 'Production and Serving'.