# <font color="#418FDE" size="6.5" uppercase>**Scaling Practices**</font>

>Last update: 20260130.
    
By the end of this Lecture, you will be able to:
- Apply gradient accumulation and effective batch size strategies in distributed training setups. 
- Combine mixed precision with DDP to improve throughput while maintaining convergence. 
- Implement basic logging and checkpointing that work correctly in multi‑process environments. 


## **1. Scaling Batch Size**

### **1.1. Effective batch size math**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master PyTorch 2.10.0/Module_08/Lecture_B/image_01_01.jpg?v=1769770220" width="250">



>* Effective batch size counts samples per optimizer update
>* It controls optimization behavior and learning rate tuning

>* With DDP, each GPU runs its own local batch and gradients are averaged, so one update covers all of them
>* Gradient accumulation multiplies the effective batch size further

>* Match target recipes via effective batch size
>* Tune devices, local batches, accumulation to scale
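The arithmetic behind these bullets fits in two small helpers. This is a minimal sketch: the function names and the 256-sample recipe are hypothetical, and the second helper simply refuses targets that the device layout cannot hit exactly.

```python
def effective_batch_size(per_device_batch: int, world_size: int, accum_steps: int) -> int:
    """Number of samples that contribute to one optimizer update."""
    return per_device_batch * world_size * accum_steps


def accum_steps_for_target(target_batch: int, per_device_batch: int, world_size: int) -> int:
    """Accumulation steps needed to hit target_batch exactly."""
    samples_per_micro_step = per_device_batch * world_size
    if target_batch % samples_per_micro_step != 0:
        raise ValueError(
            f"{target_batch} is not a multiple of {samples_per_micro_step}; "
            "adjust the per-device batch or the number of devices"
        )
    return target_batch // samples_per_micro_step


# Hypothetical recipe: effective batch 256 on 4 GPUs with 16 samples each.
steps = accum_steps_for_target(256, per_device_batch=16, world_size=4)
print("Accumulation steps:", steps)                   # 4
print("Check:", effective_batch_size(16, 4, steps))   # 256
```

Moving the same recipe to 8 GPUs would halve the accumulation steps to 2 while keeping the effective batch, and therefore the tuned learning rate, unchanged.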



### **1.2. Accumulation Loop Patterns**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master PyTorch 2.10.0/Module_08/Lecture_B/image_01_02.jpg?v=1769770244" width="250">



>* Accumulate gradients over several mini-batches before stepping
>* Step the optimizer once per window to mimic a larger batch

>* In DDP, skip the all-reduce on intermediate micro-steps (`no_sync()`) and sync on the last one
>* All processes accumulate the same number of steps before the synchronized optimizer step

>* Adapt accumulation to tokens, memory, sequence variability
>* Keep a shared, deterministic rule across processes
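Skipping synchronization on intermediate micro-steps is safe because averaging is linear: syncing once at the end of the window gives the same update as syncing every micro-step, with fewer all-reduces. In PyTorch this is what the `DistributedDataParallel.no_sync()` context manager enables. The sketch below only checks the arithmetic, using plain Python lists as made-up per-rank gradients; `all_reduce_mean` is a hypothetical stand-in, not a real collective.

```python
import math


def all_reduce_mean(values):
    """Stand-in for a DDP all-reduce: average one gradient value across ranks."""
    return sum(values) / len(values)


# grads[rank][micro_step]: made-up scalar gradients for 2 ranks x 4 micro-steps.
grads = [
    [0.5, -1.0, 2.0, 0.25],   # rank 0
    [1.5, 0.0, -0.5, 0.75],   # rank 1
]
accum_steps = len(grads[0])

# Pattern A: all-reduce after every micro-step, then accumulate.
sync_every_step = sum(
    all_reduce_mean([rank_grads[k] for rank_grads in grads]) / accum_steps
    for k in range(accum_steps)
)
allreduces_every_step = accum_steps

# Pattern B: accumulate locally (what no_sync() enables), all-reduce once.
local_sums = [sum(g / accum_steps for g in rank_grads) for rank_grads in grads]
sync_once = all_reduce_mean(local_sums)
allreduces_once = 1

print("Sync every micro-step:", sync_every_step, "| all-reduces:", allreduces_every_step)
print("Sync once per window: ", sync_once, "| all-reduces:", allreduces_once)
print("Same update:", math.isclose(sync_every_step, sync_once))
```

The equivalence holds only if every rank runs the same number of micro-steps, which is why the accumulation rule must be shared across processes.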



In [None]:
#@title Python Code - Accumulation Loop Patterns

# This script shows gradient accumulation patterns.
# We simulate effective batch size scaling behavior.
# Focus is on simple loop logic demonstration.

# !pip install tensorflow==2.20.0

# Import required standard libraries.
import os
import random
import numpy as np

# Import tensorflow and check version.
import tensorflow as tf

# Set deterministic random seeds everywhere.
seed_value = 42
random.seed(seed_value)
np.random.seed(seed_value)
tf.random.set_seed(seed_value)

# Print tensorflow version in one short line.
print("TensorFlow version:", tf.__version__)

# Detect device type for information only.
physical_gpus = tf.config.list_physical_devices("GPU")
use_gpu = bool(physical_gpus)
device_type = "GPU" if use_gpu else "CPU"

# Print detected device type briefly.
print("Using device type:", device_type)

# Create a tiny synthetic regression dataset.
num_samples = 64
num_features = 10
x_data = np.random.randn(num_samples, num_features).astype("float32")

# Create targets with a simple linear relationship.
true_w = np.arange(1, num_features + 1, dtype="float32")
y_data = x_data @ true_w + 0.1

# Wrap data into a tf.data.Dataset pipeline.
base_batch_size = 4
train_ds = tf.data.Dataset.from_tensor_slices((x_data, y_data))
train_ds = train_ds.shuffle(num_samples, seed=seed_value)

# Batch the dataset with a small per device batch.
train_ds = train_ds.batch(base_batch_size, drop_remainder=True)

# Build a tiny linear regression model with an explicit input layer.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(num_features,)),
    tf.keras.layers.Dense(1),
])

# Create an optimizer with a small learning rate.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

# Define a simple mean squared error loss.
loss_fn = tf.keras.losses.MeanSquaredError()

# Pretend the job runs on two devices. Only this one process trains here,
# so each update below averages per_device_batch * accum_steps samples.
world_size = 2
per_device_batch = base_batch_size

# Choose accumulation steps for the experiment.
accum_steps = 4

# Effective batch size a real two-process DDP run would use.
effective_batch = per_device_batch * world_size * accum_steps

# Print the effective batch size information.
print("Per device batch:", per_device_batch)
print("World size used:", world_size)
print("Accumulation steps:", accum_steps)
print("Effective batch size:", effective_batch)

# Prepare variables for gradient accumulation.
accum_grads = [tf.zeros_like(v) for v in model.trainable_variables]
step_in_accum = 0
epochs = 1

# Training loop with manual accumulation pattern.
for epoch in range(epochs):
    for step, (x_batch, y_batch) in enumerate(train_ds):

        # Validate batch shapes before using them.
        assert x_batch.shape[0] == per_device_batch
        assert x_batch.shape[1] == num_features

        # Record operations for automatic differentiation.
        with tf.GradientTape() as tape:
            preds = model(x_batch, training=True)
            loss_value = loss_fn(y_batch, tf.squeeze(preds))

        # Compute gradients for current mini batch.
        grads = tape.gradient(loss_value, model.trainable_variables)

        # Scale gradients by accumulation steps factor.
        scaled_grads = [g / float(accum_steps) for g in grads]

        # Accumulate scaled gradients into buffers.
        accum_grads = [ag + sg for ag, sg in zip(accum_grads, scaled_grads)]
        step_in_accum += 1

        # Simulate all processes reaching same accumulation count.
        if step_in_accum == accum_steps:

            # Apply optimizer step using accumulated gradients.
            optimizer.apply_gradients(zip(accum_grads, model.trainable_variables))

            # Reset accumulation buffers to zero tensors.
            accum_grads = [tf.zeros_like(v) for v in model.trainable_variables]
            step_in_accum = 0

# Evaluate model loss on full dataset once.
full_preds = model(x_data, training=False)
final_loss = loss_fn(y_data, tf.squeeze(full_preds)).numpy()

# Print final loss and confirm accumulation pattern.
print("Final training loss with accumulation:", round(float(final_loss), 4))



### **1.3. Learning Rate Scaling**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master PyTorch 2.10.0/Module_08/Lecture_B/image_01_03.jpg?v=1769770278" width="250">



>* Changing batch size changes optimization and stability
>* Larger batches need retuned learning rates for convergence

>* Treat linear or square-root batch-size scaling rules as starting guesses
>* Run short tests, watch metrics, adjust conservatively

>* Combine scaled LR, warmup, and decay schedules
>* Monitor metrics to keep training stable, generalizable
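The steps above (scale the reference LR, warm up, then decay) can be sketched as two pure functions. This assumes linear warmup followed by linear decay to zero, which is one common choice rather than the only one; the recipe numbers (LR 0.1 at batch 256, 10 warmup steps, 100 total steps) are hypothetical.

```python
import math


def scaled_lr(base_lr: float, base_batch: int, effective_batch: int, rule: str = "linear") -> float:
    """Scale a reference learning rate to a new effective batch size."""
    ratio = effective_batch / base_batch
    if rule == "linear":
        return base_lr * ratio
    if rule == "sqrt":
        return base_lr * math.sqrt(ratio)
    raise ValueError(f"unknown rule: {rule}")


def lr_at_step(step: int, peak_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup to peak_lr, then linear decay to zero at total_steps."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    decay_steps = total_steps - warmup_steps
    return peak_lr * max(total_steps - step, 0) / decay_steps


# Hypothetical recipe: LR 0.1 at batch 256, now scaled up to batch 1024.
peak = scaled_lr(0.1, base_batch=256, effective_batch=1024)
for step in (0, 4, 9, 10, 55, 100):
    print(step, round(lr_at_step(step, peak, warmup_steps=10, total_steps=100), 4))
```

In PyTorch, a function like this could drive `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's initial LR by the returned factor, so the lambda would return `lr_at_step(...) / peak`.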



In [None]:
#@title Python Code - Learning Rate Scaling

# This script illustrates learning rate scaling.
# We compare small and large effective batch sizes.
# Focus is on simple gradient descent behavior.

# !pip install tensorflow==2.20.0

# Import required libraries.
import random
import numpy as np

# Import tensorflow and check version.
import tensorflow as tf
print("TensorFlow version:", tf.__version__)

# Set deterministic random seeds everywhere.
seed_value = 42
random.seed(seed_value)
np.random.seed(seed_value)
tf.random.set_seed(seed_value)

# Define a simple quadratic loss function.
def loss_fn(x):
    return (x - 3.0) ** 2

# Compute gradient of the loss analytically.
def grad_fn(x):
    return 2.0 * (x - 3.0)

# Create synthetic "batch" of target values.
true_targets = np.full(shape=(32,), fill_value=3.0)

# Validate the batch shape before training.
assert true_targets.shape[0] == 32
assert true_targets.ndim == 1

# Helper function to run simple gradient descent.
def run_gd(initial_x, lr, batch_size, steps):
    x = tf.Variable(initial_x, dtype=tf.float32)
    history = []
    for step in range(steps):
        # Select a mini batch slice, wrapping around the 32 targets.
        start = (step * batch_size) % 32
        end = start + batch_size
        batch = true_targets[start:end]

        # Shift x by the batch mean's offset from the optimum (zero here).
        shift = float(np.mean(batch) - 3.0)
        loss = loss_fn(x - shift)

        # Differentiate the same shifted loss, then update the parameter.
        grad = grad_fn(x - shift)
        x.assign_sub(lr * grad)
        history.append(float(loss.numpy()))
    return float(x.numpy()), history

# Define base configuration for small and large batches.
initial_x = 0.0
small_batch = 4
large_batch = 16
small_steps = 20

# Give the large batch the same sample budget: 4x larger, 4x fewer steps.
large_steps = small_steps * small_batch // large_batch

# Linear scaling rule: multiply the base LR by the batch size ratio.
base_lr = 0.05
scaled_lr = base_lr * (large_batch / small_batch)

# Run training with small batch and base learning rate.
small_x_final, small_hist = run_gd(
    initial_x=initial_x,
    lr=base_lr,
    batch_size=small_batch,
    steps=small_steps,
)

# Run training with large batch and scaled learning rate.
large_x_final, large_hist = run_gd(
    initial_x=initial_x,
    lr=scaled_lr,
    batch_size=large_batch,
    steps=large_steps,
)

# Run training with large batch but the unscaled base learning rate.
unscaled_x_final, _ = run_gd(
    initial_x=initial_x,
    lr=base_lr,
    batch_size=large_batch,
    steps=large_steps,
)

# Print concise comparison of configurations.
print("Small batch:", small_batch, "LR:", base_lr, "Steps:", small_steps)
print("Large batch:", large_batch, "LR:", scaled_lr, "Steps:", large_steps)
print("Final x small batch:", round(small_x_final, 4))
print("Final x large batch, scaled LR:", round(large_x_final, 4))
print("Final x large batch, unscaled LR:", round(unscaled_x_final, 4))

# Show first and last losses for the two scaled runs.
print("Small batch loss start:", round(small_hist[0], 4))
print("Small batch loss end:", round(small_hist[-1], 4))
print("Large batch loss start:", round(large_hist[0], 4))
print("Large batch loss end:", round(large_hist[-1], 4))

# With the same sample budget, only the scaled LR keeps pace.
print("Scaling the LR compensates for fewer, larger steps.")




## **2. Mixed Precision with DDP**

### **2.1. Autocast with DDP**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master PyTorch 2.10.0/Module_08/Lecture_B/image_02_01.jpg?v=1769770345" width="250">



>* Autocast picks a dtype per op: matmuls run in lower precision, reductions and losses in float32
>* Every process runs the same code, so dtype choices match across ranks

>* Each process runs autocast independently yet identically
>* Mixed precision boosts speed while keeping gradients consistent

>* Wrap only the forward pass and loss in autocast; run backward outside it
>* Parameters and optimizer state stay in float32 while autocast speeds up compute
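Why does autocast keep reductions such as `sum` and the loss functions in float32? The toy below hints at the reason. It emulates float16 rounding with Python's `struct` half-precision format (`"e"`) and adds 0.01 ten thousand times. The running float16 sum stalls once 0.01 drops below half the spacing between neighbouring float16 values. This is an illustration of float16 precision under assumed values, not a model of what autocast does internally.

```python
import struct


def to_fp16(value: float) -> float:
    """Round a Python float to the nearest float16 value via struct."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


increment = 0.01
n = 10_000

fp16_total = 0.0
fp64_total = 0.0
for _ in range(n):
    # Each addition is rounded back to float16, as float16 accumulation would be.
    fp16_total = to_fp16(fp16_total + to_fp16(increment))
    fp64_total += increment

print("float16 running sum:", fp16_total)
print("float64 running sum:", round(fp64_total, 4))
```

The float64 sum lands at 100 while the float16 sum stops far short of it, which is the kind of silent error autocast avoids by running accumulations in float32.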



In [None]:
#@title Python Code - Autocast with DDP

# This script shows autocast with DDP style concepts.
# We simulate distributed mixed precision using simple tensors.
# Focus is on understanding autocast behavior not performance.

# Required PyTorch install for mixed precision and DDP examples.
# Uncomment next line if running outside prepared environments.
# !pip install torch

# Import the standard library module used for seeding.
import random

# Import torch for tensors, autocast, and device handling.
import torch
import torch.nn as nn
import torch.optim as optim

# Set deterministic seeds for reproducible behavior across runs.
random.seed(0)
torch.manual_seed(0)

# Select GPU if available else fall back to CPU device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Print torch version and selected device in one short line.
print("Torch version:", torch.__version__, "Device:", device)

# Define a tiny model that mimics a DDP replica locally.
class TinyModel(nn.Module):

    # Initialize a simple two layer network for demonstration.
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    # Forward pass uses standard layers without manual casting.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Create two model replicas to mimic two DDP processes.
model_rank0 = TinyModel(input_dim=4, hidden_dim=8, output_dim=2).to(device)
model_rank1 = TinyModel(input_dim=4, hidden_dim=8, output_dim=2).to(device)

# Ensure both replicas start from identical initial parameters.
model_rank1.load_state_dict(model_rank0.state_dict())

# Create a simple optimizer for each model replica.
optimizer_rank0 = optim.SGD(model_rank0.parameters(), lr=0.1)
optimizer_rank1 = optim.SGD(model_rank1.parameters(), lr=0.1)

# Create a tiny batch of inputs and integer labels.
inputs = torch.randn(2, 4, device=device)
labels = torch.tensor([0, 1], device=device)

# Validate shapes to avoid silent broadcasting mistakes.
assert inputs.shape == (2, 4)
assert labels.shape == (2,)

# Define a standard cross entropy loss; autocast runs it in float32.
criterion = nn.CrossEntropyLoss()

# Mixed precision (autocast + GradScaler) is only enabled on CUDA here.
use_amp = device.type == "cuda"

# Give each replica its own GradScaler, as each DDP process would.
scaler_rank0 = torch.amp.GradScaler("cuda", enabled=use_amp)
scaler_rank1 = torch.amp.GradScaler("cuda", enabled=use_amp)

# Function to run one mixed precision step for a replica and its scaler.
def mixed_precision_step(model, optimizer, scaler):
    # Zero gradients before computing new ones for this step.
    optimizer.zero_grad(set_to_none=True)

    # Autocast only the forward pass and the loss computation.
    with torch.autocast(device_type=device.type, enabled=use_amp):
        outputs = model(inputs)
        loss = criterion(outputs, labels)

    # Scale the loss, then backpropagate outside the autocast region.
    scaler.scale(loss).backward()

    # Unscale gradients and skip the step if they contain inf or NaN.
    scaler.step(optimizer)

    # Adjust the scale factor for the next iteration.
    scaler.update()

    # Return the loss computed before this update.
    return loss.detach().item()

# Run one mixed precision step on each replica with its own scaler.
loss_rank0 = mixed_precision_step(model_rank0, optimizer_rank0, scaler_rank0)
loss_rank1 = mixed_precision_step(model_rank1, optimizer_rank1, scaler_rank1)

# Collect parameters from both replicas for numerical comparison.
params_rank0 = torch.cat([p.detach().flatten() for p in model_rank0.parameters()])
params_rank1 = torch.cat([p.detach().flatten() for p in model_rank1.parameters()])

# Compute maximum absolute difference between replica parameters.
max_param_diff = (params_rank0 - params_rank1).abs().max().item()

# Print a few concise lines summarizing mixed precision behavior.
print("Loss rank0 computed in its autocast step:", round(loss_rank0, 4))
print("Loss rank1 computed in its autocast step:", round(loss_rank1, 4))
print("Max parameter difference between replicas:", max_param_diff)
print("Identical inputs and autocast choices kept the replicas in sync.")
print("In real DDP, the gradient all-reduce keeps replicas aligned.")




### **2.2. Per Process GradScaler**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master PyTorch 2.10.0/Module_08/Lecture_B/image_02_02.jpg?v=1769770426" width="250">



>* Each DDP process creates its own GradScaler instance
>* All scalers start from the same initial scale

>* A scaler halves its scale after an overflow and grows it after many clean steps
>* DDP gives every rank the same gradients, so all scalers make the same choice

>* Gradients are all-reduced while still scaled, then unscaled locally on each rank
>* Mixed precision stays per process, while DDP focuses on aggregation
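The reason to scale at all is that small gradients underflow to zero in float16. The sketch below emulates float16 rounding with Python's `struct` `"e"` format; the gradient value `1e-8` and the scale `1024` are illustrative assumptions, not GradScaler defaults.

```python
import struct


def to_fp16(value: float) -> float:
    """Round a Python float to the nearest float16 value via struct."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


tiny_grad = 1e-8   # hypothetical gradient, below float16's smallest subnormal
scale = 1024.0     # hypothetical loss-scale factor

# Without scaling, the gradient flushes to zero in float16.
unscaled_fp16 = to_fp16(tiny_grad)

# With scaling, the scaled gradient is representable in float16.
scaled_fp16 = to_fp16(tiny_grad * scale)

# Unscaling in higher precision approximately recovers the original value.
recovered = scaled_fp16 / scale

print("float16 without scaling:", unscaled_fp16)
print("float16 with scaling:   ", scaled_fp16)
print("recovered after unscale:", recovered)
```

Growing the scale pushes more small gradients into float16's representable range; the overflow check is what stops it from growing too far.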



In [None]:
#@title Python Code - Per Process GradScaler

# This script shows per process gradient scaling conceptually.
# We simulate two processes each with its own GradScaler.
# Focus is on logic not real distributed training.

# No external packages are needed for this simulation.

# Import the standard library module used for noise.
import random

# Set deterministic random seed.
random.seed(42)

# Define a simple GradScaler like helper.
class SimpleGradScaler:
    # Initialize with starting scale value.
    def __init__(self, init_scale=2.0):
        self.scale = float(init_scale)

    # Scale a scalar loss value.
    def scale_loss(self, loss_value):
        return loss_value * self.scale

    # Unscale a gradient value.
    def unscale_grad(self, grad_value):
        return grad_value / self.scale

    # Check for overflow and update scale.
    def update(self, overflow):
        if overflow:
            self.scale = max(self.scale / 2.0, 0.125)
        else:
            self.scale = min(self.scale * 2.0, 128.0)


# Simulate one DDP step across all simulated processes.
def simulate_ddp_step(scalers, base_loss):
    # Each process sees a slightly different loss on its own data shard.
    local_losses = [
        max(base_loss + (random.random() - 0.5) * 0.2, 0.01) for _ in scalers
    ]

    # Each process scales its loss; pretend the gradient equals the scaled loss.
    scaled_grads = [s.scale_loss(loss) for s, loss in zip(scalers, local_losses)]

    # DDP averages the still-scaled gradients across processes during backward.
    reduced_grad = sum(scaled_grads) / len(scaled_grads)

    # Every process sees the same reduced gradient, so overflow checks agree.
    overflow = reduced_grad > 10.0

    rows = []
    for pid, scaler in enumerate(scalers):
        # Each process unscales the shared gradient with its own scaler.
        unscaled_grad = scaler.unscale_grad(reduced_grad)

        # Identical overflow flags keep every scale in lockstep.
        # A real scaler would also skip optimizer.step() on overflow.
        scaler.update(overflow)

        rows.append({
            "pid": pid,
            "loss": round(local_losses[pid], 4),
            "scaled": round(scaled_grads[pid], 4),
            "unscaled": round(unscaled_grad, 4),
            "overflow": overflow,
            "scale": round(scaler.scale, 4),
        })
    return rows

# Create one scaler per simulated process, all with the same starting scale.
scalers = [SimpleGradScaler(init_scale=2.0) for _ in range(2)]

# Base loss shared conceptually across processes.
base_loss_value = 1.5

# Print short header explaining columns.
print("step pid loss scaled unscaled overflow scale")

# Run a few simulated steps; the scales change together on every process.
for step in range(4):
    for row in simulate_ddp_step(scalers, base_loss_value):
        print(step, row["pid"], row["loss"], row["scaled"],
              row["unscaled"], row["overflow"], row["scale"])



### **2.3. Handling Overflow Errors**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master PyTorch 2.10.0/Module_08/Lecture_B/image_02_03.jpg?v=1769770490" width="250">



>* Mixed precision increases risk of gradient overflow
>* Overflows can spread across processes, destabilizing training

>* Gradient scaling detects overflow and safely skips updates
>* All DDP processes must agree on overflow handling

>* Monitor overflows, scaler behavior, and training logs
>* Fix chronic overflows by lowering the learning rate, clipping unscaled gradients, or using bfloat16
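Overflow happens when a scaled gradient exceeds float16's largest finite value, 65504. The sketch below emulates that with Python's `struct` `"e"` format. `struct` raises `OverflowError` instead of returning infinity, so the hypothetical `to_fp16` helper maps that case to `inf`, as float16 hardware would. `scaled_step` is a simplified stand-in for the skip-and-back-off logic, not GradScaler's implementation; the starting scale 65536 matches GradScaler's default `init_scale`, and the gradients are made up.

```python
import math
import struct

FP16_MAX = 65504.0  # largest finite float16 value


def to_fp16(value: float) -> float:
    """Round to float16, returning +/-inf when the value is out of range."""
    try:
        return struct.unpack("<e", struct.pack("<e", value))[0]
    except OverflowError:
        # struct refuses out-of-range values; float16 arithmetic would give inf.
        return math.copysign(math.inf, value)


def scaled_step(raw_grads, scale, backoff=0.5):
    """Return (applied, new_scale): skip the update and back off on overflow."""
    scaled = [to_fp16(g * scale) for g in raw_grads]
    if not all(math.isfinite(g) for g in scaled):
        return False, scale * backoff
    return True, scale


grads = [0.5, 2.0]   # hypothetical raw gradients
scale = 65536.0      # large starting scale

# Retry until the scaled gradients fit in float16, counting skipped steps.
skipped = 0
applied = False
while not applied:
    applied, scale = scaled_step(grads, scale)
    if not applied:
        skipped += 1

print(f"skipped {skipped} steps, applied at scale {scale:.0f}")
```

In a DDP job each rank would reach the same verdict, because an inf on any rank propagates through the gradient all-reduce.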



In [None]:
#@title Python Code - Handling Overflow Errors

# This script shows overflow handling conceptually.
# We simulate mixed precision overflow detection behavior.
# Focus is on clear beginner friendly printed output.

# No external packages are needed for this simulation.

# Import standard libraries for math and randomness.
import math
import random

# Set deterministic random seed for reproducibility.
random.seed(42)

# Define a simple class to mimic gradient scaler.
class SimpleGradScaler:

    # Initialize with starting scale and growth factor.
    def __init__(self, init_scale=1024.0, growth=2.0):
        self.scale = float(init_scale)
        self.growth = float(growth)

    # Scale the loss value before backward pass.
    def scale_loss(self, loss_value):
        return loss_value * self.scale

    # Check gradients for overflow and update scale.
    def step(self, grads):
        has_overflow = any(
            (math.isinf(g) or math.isnan(g)) for g in grads
        )
        if has_overflow:
            self.scale = max(self.scale / self.growth, 1.0)
            return False
        self.scale = self.scale * self.growth
        return True

# Simulate two processes like two DDP workers.
num_processes = 2

# Create one scaler instance per simulated process.
scalers = [SimpleGradScaler() for _ in range(num_processes)]

# Define a helper to simulate gradient computation safely.
def simulate_gradients(step_index, process_index):
    base_grad = 0.001 * (step_index + 1)
    noise = 0.0001 * (process_index + 1)
    grad = base_grad + noise
    if step_index == 2 and process_index == 1:
        return [float('inf'), grad]
    return [grad, grad * 1.5]

# Validate gradient list shapes before using them.
for p in range(num_processes):
    test_grads = simulate_gradients(0, p)
    assert isinstance(test_grads, list)
    assert len(test_grads) == 2

# Run a short loop to show overflow handling.
max_steps = 5

# Print a short header explaining the columns.
print("step, scaled_loss_p0, scaled_loss_p1, update_applied")

# Iterate over training steps and simulate behavior.
for step in range(max_steps):
    losses = [0.1 * (step + 1) for _ in range(num_processes)]
    scaled_losses = []
    for p in range(num_processes):
        scaled = scalers[p].scale_loss(losses[p])
        scaled_losses.append(scaled)

    grads_per_process = []
    for p in range(num_processes):
        grads = simulate_gradients(step, p)
        grads_per_process.append(grads)

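    # Emulate the all-reduce: an inf or NaN on any process reaches every process,
    # so each scaler checks every process's gradients, not only its own.
    all_rank_grads = [g for grads in grads_per_process for g in grads]
    local_can_step = []
    for p in range(num_processes):
        can_step = scalers[p].step(all_rank_grads)
        local_can_step.append(can_step)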

    global_can_step = all(local_can_step)

    if global_can_step:
        decision_text = "update"
    else:
        decision_text = "skip"

    print(
        f"{step}, {scaled_losses[0]:.1f}, {scaled_losses[1]:.1f}, {decision_text}"
    )

# Print final scales to show adaptation behavior.
print(f"final_scale_p0={scalers[0].scale:.1f}, final_scale_p1={scalers[1].scale:.1f}")




## **3. Robust Distributed Training**

### **3.1. Rank Aware Logging**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master PyTorch 2.10.0/Module_08/Lecture_B/image_03_01.jpg?v=1769770562" width="250">



>* Attach rank and role to every log
>* Limit which ranks emit which message types

>* Control verbosity by splitting logging responsibilities across ranks
>* Primary logs high level progress; others detailed diagnostics

>* Rank tagged logs expose subtle distributed failures clearly
>* Standardized formats enable correlation, debugging, and recovery
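The same idea can be built on Python's standard `logging` module, so rank tags and per-rank verbosity come from the formatter and the logger level rather than from `if` checks. This is a minimal sketch: `make_rank_logger` is a hypothetical helper, and it assumes the launcher exports `RANK` and `WORLD_SIZE` (as `torchrun` does), falling back to a single process otherwise.

```python
import io
import logging
import os


def make_rank_logger(rank: int, world_size: int, stream=None) -> logging.Logger:
    """Build a logger that tags every record with rank info."""
    logger = logging.getLogger(f"train.rank{rank}")
    logger.handlers.clear()    # avoid duplicate handlers on notebook re-runs
    logger.propagate = False   # keep records out of the root logger
    handler = logging.StreamHandler(stream)
    handler.setFormatter(logging.Formatter(
        f"[rank {rank}/{world_size}] %(levelname)s %(message)s"))
    logger.addHandler(handler)
    # Rank 0 reports progress at INFO; other ranks only surface warnings.
    logger.setLevel(logging.INFO if rank == 0 else logging.WARNING)
    return logger


# torchrun exports RANK and WORLD_SIZE; default to one process otherwise.
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
log = make_rank_logger(rank, world_size)
log.info("Starting training")

# Simulate two ranks writing to one shared buffer to show the filtering.
buffer = io.StringIO()
log0 = make_rank_logger(0, 2, buffer)
log1 = make_rank_logger(1, 2, buffer)
log0.info("step 10 loss=0.512")
log1.info("step 10 loss=0.498")   # suppressed: below WARNING on rank 1
log1.warning("data loader stalled for 3.2 s")
demo_output = buffer.getvalue()
print(demo_output, end="")
```

Putting the rank in the format string means every line carries it, even messages emitted deep inside library code that uses the same logger.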



In [None]:
#@title Python Code - Rank Aware Logging

# This script demonstrates rank aware logging basics.
# We simulate distributed ranks using simple integer identifiers.
# Focus on clear concise logs from different simulated ranks.

# Required external installs would be placed here if needed.
# No extra libraries are required for this simple example.

# Import the standard module used for timestamps.
import time


# Define a small helper to format timestamps clearly.
def current_time_str() -> str:
    return time.strftime("%H:%M:%S", time.localtime())


# Create a simple rank aware logger class.
class RankLogger:
    def __init__(self, global_rank: int, world_size: int,
                 is_primary: bool) -> None:
        self.global_rank = global_rank
        self.world_size = world_size
        self.is_primary = is_primary

    # Build a short prefix including rank and role.
    def _prefix(self) -> str:
        role = "PRIMARY" if self.is_primary else "WORKER"
        return (f"[{current_time_str()}] "
                f"[rank {self.global_rank}/{self.world_size}] "
                f"[{role}]")

    # Log messages that only primary should emit.
    def log_primary(self, message: str) -> None:
        if not self.is_primary:
            return
        print(f"{self._prefix()} {message}")

    # Log messages that any rank may emit.
    def log_rank(self, message: str) -> None:
        print(f"{self._prefix()} {message}")


# Simulate a tiny distributed training loop.
def simulated_train_step(rank: int, step: int) -> float:
    base_loss = 1.0
    loss = base_loss / (step + 1)
    if loss <= 0.0:
        raise ValueError("Loss must stay positive in this demo.")
    return loss + 0.01 * rank


# Run a short demonstration with three simulated ranks.
def run_demo(world_size: int = 3, steps: int = 3) -> None:
    if world_size <= 0:
        raise ValueError("World size must be positive.")
    if steps <= 0:
        raise ValueError("Steps must be positive.")

    # Create loggers for each simulated rank.
    loggers = []
    for rank in range(world_size):
        is_primary = rank == 0
        logger = RankLogger(global_rank=rank,
                            world_size=world_size,
                            is_primary=is_primary)
        loggers.append(logger)

    # Primary announces the start of training.
    loggers[0].log_primary("Starting simulated distributed training run.")

    # Each rank performs a few training steps.
    for step in range(steps):
        for rank, logger in enumerate(loggers):
            loss = simulated_train_step(rank=rank, step=step)
            if logger.is_primary:
                logger.log_primary(
                    f"Step {step} summary loss={loss:.3f}.")
            else:
                logger.log_rank(
                    f"Step {step} local loss={loss:.3f}.")

    # Primary prints a short checkpoint style message.
    loggers[0].log_primary("Saving checkpoint at final step.")


# Execute the demonstration when running this script.
run_demo(world_size=3, steps=2)




### **3.2. Distributed Checkpoint Design**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master PyTorch 2.10.0/Module_08/Lecture_B/image_03_02.jpg?v=1769770647" width="250">



>* Centralized checkpoints capture full, consistent training state
>* Enable seamless resume despite failures or hardware differences

>* Distributed checkpoints may be sharded across ranks
>* Choose sharding or centralization based on scale

>* Plan checkpoint timing, naming, and global coordination
>* Use atomic writes to avoid corruption, enable recovery
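Checkpoint naming and retention are easy to get subtly wrong, so a fixed policy helps. The sketch below assumes a hypothetical scheme (`ckpt_step########.json`, written only by rank 0): zero-padded step numbers, a strict filename match so leftover `.tmp` files from a crashed atomic write are ignored, and pruning to the newest few checkpoints.

```python
import os
import re
import tempfile

# Hypothetical naming scheme: zero-padded steps keep names sortable.
CKPT_RE = re.compile(r"ckpt_step(\d{8})\.json")


def ckpt_name(step: int) -> str:
    return f"ckpt_step{step:08d}.json"


def list_checkpoints(ckpt_dir: str) -> list:
    """Return (step, path) pairs for finished checkpoints, oldest first."""
    found = []
    for name in os.listdir(ckpt_dir):
        match = CKPT_RE.fullmatch(name)
        if match:  # "*.tmp" files left by a crashed write never match
            found.append((int(match.group(1)), os.path.join(ckpt_dir, name)))
    return sorted(found)


def latest_checkpoint(ckpt_dir: str):
    ckpts = list_checkpoints(ckpt_dir)
    return ckpts[-1][1] if ckpts else None


def prune_checkpoints(ckpt_dir: str, keep_last: int) -> None:
    """Delete all but the newest keep_last checkpoints (rank 0 only)."""
    if keep_last < 1:
        raise ValueError("keep_last must be at least 1")
    for _, path in list_checkpoints(ckpt_dir)[:-keep_last]:
        os.remove(path)


with tempfile.TemporaryDirectory() as ckpt_dir:
    for step in (200, 400, 600, 800):
        with open(os.path.join(ckpt_dir, ckpt_name(step)), "w") as f:
            f.write("{}")
    # Simulate a crash in the middle of writing the step-1000 checkpoint.
    with open(os.path.join(ckpt_dir, ckpt_name(1000) + ".tmp"), "w") as f:
        f.write("{partial")
    prune_checkpoints(ckpt_dir, keep_last=2)
    remaining_steps = [step for step, _ in list_checkpoints(ckpt_dir)]
    latest_name = os.path.basename(latest_checkpoint(ckpt_dir))

print("Remaining checkpoint steps:", remaining_steps)
print("Resume from:", latest_name)
```

Because the half-written file never matches the pattern, resume logic falls back to the last complete checkpoint without special cases.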



In [None]:
#@title Python Code - Distributed Checkpoint Design

# This script shows simple distributed style checkpointing.
# We simulate two workers and a coordinator process.
# Focus is on safe logging and checkpoint saving.

# Only the Python standard library is needed here.

# Import standard libraries for paths and randomness.
import os
import json
import random

# Seed Python's RNG for reproducibility. PYTHONHASHSEED only takes effect
# when set before the interpreter starts, so it is not set here.
random.seed(42)

# Define a tiny training state dictionary structure.
training_state = {
    "global_step": 0,
    "epoch": 0,
}

# Define a tiny model state dictionary structure.
model_state = {
    "weight": 0.0,
    "bias": 0.0,
}

# Define a tiny optimizer state dictionary structure.
optimizer_state = {
    "learning_rate": 0.1,
}

# Create a directory for checkpoints if missing.
CHECKPOINT_DIR = "checkpoints_demo"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Define a helper to build rank specific file names.
def rank_file_name(rank, suffix):
    return os.path.join(CHECKPOINT_DIR, f"rank{rank}_{suffix}.json")


# Define a helper to build the coordinator index name.
INDEX_FILE = os.path.join(CHECKPOINT_DIR, "checkpoint_index.json")

# Simulate a worker that owns a shard of model state.
def simulate_worker_step(rank, step):
    shard = {
        "rank": rank,
        "local_step": step,
        "weight_shard": model_state["weight"] + rank,
    }
    return shard

# Write JSON to a temporary file, flush it to disk, then atomically rename.
def safe_write_json(path, data):
    tmp_path = path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp_path, path)

# Simulate all workers saving local shards to disk.
def save_sharded_checkpoint(global_step, world_size):
    shards = []
    for rank in range(world_size):
        shard_state = simulate_worker_step(rank, global_step)
        shard_path = rank_file_name(rank, f"step{global_step}")
        safe_write_json(shard_path, shard_state)
        shards.append({"rank": rank, "path": shard_path})
    # Write the index last so it only ever points to shards already on disk.
    index = {
        "global_step": global_step,
        "shards": shards,
        "training_state": dict(training_state),
        "optimizer_state": dict(optimizer_state),
    }
    safe_write_json(INDEX_FILE, index)

# Load the latest checkpoint index if it exists.
def load_latest_index():
    if not os.path.exists(INDEX_FILE):
        return None
    with open(INDEX_FILE, "r", encoding="utf-8") as f:
        index = json.load(f)
    return index

# Reconstruct full state from index and shard files.
def restore_from_index(index):
    shards = []
    for shard_info in index["shards"]:
        path = shard_info["path"]
        if not os.path.exists(path):
            raise FileNotFoundError(f"Missing shard {path}")
        with open(path, "r", encoding="utf-8") as f:
            shards.append(json.load(f))
    return {
        "index": index,
        "shards": shards,
    }

# Simulate a short training loop with checkpointing.
def run_training(world_size, total_steps):
    print("Simulating distributed style checkpointing.")
    for step in range(1, total_steps + 1):
        training_state["global_step"] = step
        training_state["epoch"] = 0
        model_state["weight"] += 0.01
        model_state["bias"] += 0.001
        if step % 2 == 0:
            save_sharded_checkpoint(step, world_size)
            print(f"Checkpoint saved at global_step {step}.")

# Demonstrate a clean resume from the last checkpoint.
def demo_resume():
    index = load_latest_index()
    if index is None:
        print("No checkpoint index found on disk.")
        return
    restored = restore_from_index(index)
    print("Restored global_step from index:", restored["index"]["global_step"])
    print("Number of shards restored:", len(restored["shards"]))
    print("Example shard keys:", list(restored["shards"][0].keys()))

# Main execution that ties everything together.
if __name__ == "__main__":
    WORLD_SIZE = 2
    TOTAL_STEPS = 4
    run_training(WORLD_SIZE, TOTAL_STEPS)
    demo_resume()




### **3.3. Reliable Run Recovery**

<img src="https://cdn.jsdelivr.net/gh/mhrafiei/contents@main/LFF/Master PyTorch 2.10.0/Module_08/Lecture_B/image_03_03.jpg?v=1769770731" width="250">



>* Assume failures are normal in distributed training
>* Checkpoint full, consistent training state across ranks

>* Coordinate one writer and synchronized checkpoint save steps
>* Use atomic file writes to avoid corrupted checkpoints

>* Regularly test stopping and resuming from checkpoints
>* Store rich metadata so restarts stay reproducible
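Random-number state is one piece of metadata that is easy to forget. The sketch below stores Python's `random` state in a JSON checkpoint and shows that a resumed run draws the same numbers as an uninterrupted one. The helper names are hypothetical; in a PyTorch job the same idea would extend to `torch.get_rng_state()` and any NumPy generators in use.

```python
import json
import random


def capture_rng_state() -> dict:
    """Convert random.getstate() into JSON-friendly types."""
    version, internal, gauss_next = random.getstate()
    return {"version": version, "internal": list(internal), "gauss_next": gauss_next}


def restore_rng_state(blob: dict) -> None:
    """Rebuild the tuple structure random.setstate() expects."""
    random.setstate((blob["version"], tuple(blob["internal"]), blob["gauss_next"]))


# Run A: draw for 3 "steps", checkpoint, then continue without interruption.
random.seed(7)
draws_before_ckpt = [random.random() for _ in range(3)]
checkpoint_json = json.dumps({"step": 3, "rng": capture_rng_state()})
uninterrupted = [random.random() for _ in range(3)]

# Run B: a fresh process (different seed) restores the checkpoint and resumes.
random.seed(12345)
restored = json.loads(checkpoint_json)
restore_rng_state(restored["rng"])
resumed = [random.random() for _ in range(3)]

print("Resumed at step", restored["step"], "| same draws:", resumed == uninterrupted)
```

Re-seeding on restart, by contrast, would replay the first draws again and silently change the data order after recovery.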



In [None]:
#@title Python Code - Reliable Run Recovery

# This script shows reliable run recovery basics.
# We simulate checkpoints and safe resume behavior.
# Focus is on logging and atomic checkpoint saving.
# !pip install tensorflow==2.20.0

# Import standard libraries for paths and randomness.
import os
import json
import random

# Import TensorFlow for a tiny training example.
import tensorflow as tf

# Seed Python and TensorFlow RNGs. PYTHONHASHSEED would need to be set
# before the interpreter starts, so it is not set here.
random.seed(7)
tf.random.set_seed(7)

# Define a small directory for checkpoints.
BASE_DIR = "reliable_run_demo"
CHECKPOINT_DIR = os.path.join(BASE_DIR, "ckpts")

# Create directories if they do not already exist.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Define simple helper to print a short header.
def print_header():
    print("TensorFlow version:", tf.__version__)

# Define a tiny model for demonstration only.
def build_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model

# Load a tiny subset of MNIST for quick runs.
def load_data(num_samples=2048):
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    x_train = x_train[:num_samples].astype("float32") / 255.0
    y_train = y_train[:num_samples].astype("int64")
    return x_train, y_train

# Build a small dictionary describing run state.
def build_state(epoch, step, history):
    state = {
        "epoch": int(epoch),
        "step": int(step),
        "history": history,
    }
    return state

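# Save weights and state via temporary files, then rename them into place.
# save_weights stores weights only, so optimizer state is not checkpointed here,
# and each rename is atomic on its own but the pair is not.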
def save_checkpoint_atomic(model, state, base_path):
    tmp_path = base_path + ".tmp.weights.h5"
    final_path = base_path + ".ckpt.weights.h5"
    model.save_weights(tmp_path)
    with open(tmp_path + ".json", "w", encoding="utf-8") as f:
        json.dump(state, f)
    os.replace(tmp_path, final_path)
    os.replace(tmp_path + ".json", final_path + ".json")
    return final_path

# Load checkpoint if it exists and return state.
def load_checkpoint_if_available(model, base_path):
    final_path = base_path + ".ckpt.weights.h5"
    state_path = final_path + ".json"
    if not os.path.exists(final_path):
        return None
    if not os.path.exists(state_path):
        return None
    model.load_weights(final_path)
    with open(state_path, "r", encoding="utf-8") as f:
        state = json.load(f)
    return state

# Simulate a single process acting as rank zero.
IS_MAIN_RANK = True

# Main training function with safe resume behavior.
def run_training(max_epochs=3, steps_per_epoch=5):
    print_header()
    x_train, y_train = load_data()
    model = build_model()
    base_path = os.path.join(CHECKPOINT_DIR, "demo")
    state = load_checkpoint_if_available(model, base_path)
    start_epoch = 0
    history = {"loss": [], "acc": []}
    if state is not None:
        start_epoch = state.get("epoch", 0)
        history = state.get("history", history)
        print("Resumed from epoch", start_epoch)
    else:
        print("No checkpoint found, starting fresh")
    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    dataset = dataset.shuffle(2048, seed=7).batch(64)
    for epoch in range(start_epoch, max_epochs):
        step = 0
        for batch_x, batch_y in dataset.take(steps_per_epoch):
            if batch_x.shape[0] == 0:
                continue
            result = model.train_on_batch(batch_x, batch_y, return_dict=True)
            history["loss"].append(float(result["loss"]))
            history["acc"].append(float(result["accuracy"]))
            step += 1
        if IS_MAIN_RANK:
            state = build_state(epoch + 1, step, history)
            save_checkpoint_atomic(model, state, base_path)
            print("Saved checkpoint at epoch", epoch + 1)
    print("Total recorded training steps:", len(history["loss"]))


run_training()



# <font color="#418FDE" size="6.5" uppercase>**Scaling Practices**</font>


In this lecture, you learned to:
- Apply gradient accumulation and effective batch size strategies in distributed training setups. 
- Combine mixed precision with DDP to improve throughput while maintaining convergence. 
- Implement basic logging and checkpointing that work correctly in multi‑process environments. 

In the next module (Module 9), we will cover 'Export and Deployment'.