# Elastic Fault-Tolerant Distributed Training Harness

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mmcmanus1/elastic-training-harness/blob/main/notebooks/elastic_training_colab.ipynb)

This notebook demonstrates the `elastic_harness` package for fault-tolerant distributed training:

- **Multi-tier checkpointing** - Memory → NVMe → S3 fallback hierarchy
- **LR scaling** - Automatic learning rate adjustment for topology changes
- **Gradient accumulation** - Maintain constant global batch size
- **Fault tolerance** - Resume from checkpoint after crashes
- **Distributed training** - Using torchrun with elastic scaling
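The fallback order can be pictured as a list of tiers that are consulted from fastest to most durable. The dictionary-backed sketch below is hypothetical: `Tier` and `load_latest` are names made up for this illustration and are not the `elastic_harness` API. Its memory tier keeps only the two newest snapshots, mirroring the `max_snapshots=2` setting used later in this notebook.

```python
from typing import Any, Dict, List, Optional, Tuple


class Tier:
    """Hypothetical checkpoint tier: a dict mapping step -> saved state."""

    def __init__(self, name: str, max_entries: Optional[int] = None):
        self.name = name
        self.max_entries = max_entries  # None means the tier keeps everything
        self.store: Dict[int, Any] = {}

    def save(self, step: int, state: Any) -> None:
        self.store[step] = state
        # Evict the oldest steps once the tier is over capacity (memory tier)
        while self.max_entries is not None and len(self.store) > self.max_entries:
            del self.store[min(self.store)]

    def latest(self) -> Optional[Tuple[int, Any]]:
        # Newest checkpoint held by this tier, or None if the tier is empty
        if not self.store:
            return None
        step = max(self.store)
        return step, self.store[step]


def load_latest(tiers: List[Tier]) -> Optional[Tuple[str, int, Any]]:
    """Consult tiers fastest-first and return the first checkpoint found."""
    for tier in tiers:
        hit = tier.latest()
        if hit is not None:
            step, state = hit
            return tier.name, step, state
    return None


memory = Tier("memory", max_entries=2)  # plays the role of max_snapshots=2
nvme = Tier("nvme")
s3 = Tier("s3")
tiers = [memory, nvme, s3]

s3.save(25, {"weights": "w25"})
nvme.save(50, {"weights": "w50"})
for step in (55, 60, 65):
    memory.save(step, {"weights": f"w{step}"})

print(load_latest(tiers))  # memory hit at step 65 (step 55 was evicted)
memory.store.clear()       # simulate losing host memory, e.g. the process died
print(load_latest(tiers))  # falls back to the NVMe checkpoint at step 50
```

Consulting tiers fastest-first assumes the faster tiers hold newer state, which is usually true because memory snapshots are taken more often than persistent checkpoints; a more defensive loader could compare steps across all tiers and take the newest.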

## 1. Setup

Clone the repository and install the package.

In [None]:
# Clone and install the elastic-training-harness package
!git clone https://github.com/mmcmanus1/elastic-training-harness.git
%cd elastic-training-harness
!pip install -e . --quiet
%cd /content

In [None]:
import torch
import torch.nn as nn
import time
import random
import os

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Import from elastic_harness Package

Import the core components from the installed package.

In [None]:
# Checkpoint management
from elastic_harness.checkpoint import (
    CheckpointConfig,
    CheckpointState,
    CheckpointManager,
    CheckpointTier,
    MemorySnapshotBackend,
    NVMeBackend,
)

# LR scaling and gradient accumulation
from elastic_harness.scaling import (
    ScalingRule,
    ScalingConfig,
    LRScalingManager,
    GradAccumulationConfig,
    GradientAccumulationManager,
    ElasticScalingManager,
)

# Distributed training utilities
from elastic_harness.agent import (
    setup_distributed_environment,
    get_world_info,
)

print("Successfully imported elastic_harness components!")

## 3. Define Model

A small transformer language model for demonstration, sized to fit comfortably on a T4 GPU (16 GB VRAM).

In [None]:
class SimpleTransformerLM(nn.Module):
    """Simple transformer language model for demonstration."""

    def __init__(
        self,
        vocab_size: int = 50257,
        d_model: int = 256,
        nhead: int = 4,
        num_layers: int = 2,
        dim_feedforward: int = 512,
        max_seq_length: int = 128,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_length, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.output_proj = nn.Linear(d_model, vocab_size)
        self.output_proj.weight = self.embedding.weight  # Tie weights

    def forward(self, input_ids: torch.Tensor, labels=None):
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

        x = self.embedding(input_ids) + self.pos_embedding(positions)

        mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=input_ids.device)
        x = self.transformer(x, mask=mask, is_causal=True)

        logits = self.output_proj(x)

        output = {"logits": logits}

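        if labels is not None:
            # Labels are compared position-by-position with the logits (there is
            # no internal shift), so callers pass targets already aligned to the
            # next token
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            output["loss"] = loss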

        return output


# Test model
model = SimpleTransformerLM()
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
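To see why this configuration fits on a T4, the parameter count can be derived by hand from the layer shapes. It should match the count printed by the cell above, because `model.parameters()` counts the tied embedding/output weight only once. The 16-bytes-per-parameter figure is an assumption for plain fp32 training with AdamW (weights, gradients, and two moment buffers) and ignores activations and CUDA workspace.

```python
# Hand count of SimpleTransformerLM parameters with its default hyperparameters.
# (nhead does not change the count: heads split d_model rather than add weights.)
vocab_size, d_model, num_layers, dim_ff, max_seq = 50257, 256, 2, 512, 128

token_emb = vocab_size * d_model          # nn.Embedding(vocab_size, d_model)
pos_emb = max_seq * d_model               # nn.Embedding(max_seq_length, d_model)

per_layer = (
    3 * d_model * d_model + 3 * d_model   # attention in_proj: Q, K, V weights + biases
    + d_model * d_model + d_model         # attention out_proj
    + d_model * dim_ff + dim_ff           # linear1
    + dim_ff * d_model + d_model          # linear2
    + 2 * (2 * d_model)                   # norm1, norm2: weight + bias each
)
encoder = num_layers * per_layer

# output_proj.weight is tied to the token embedding, so only its bias adds parameters
output_bias = vocab_size

total = token_emb + pos_emb + encoder + output_bias
print(f"Total parameters: {total:,}")

# fp32 weights + grads + AdamW's exp_avg and exp_avg_sq = 4 tensors x 4 bytes each
bytes_per_param = 4 * 4
print(f"Weights, grads and AdamW state: {total * bytes_per_param / 2**30:.2f} GiB")
```

Roughly 0.2 GiB of persistent state leaves most of the T4's memory for activations, so batch size and sequence length, not parameter count, are the knobs that matter here.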

## 4. Training with Fault Tolerance

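Training loop that uses `MemorySnapshotBackend` and `NVMeBackend` from `elastic_harness` for checkpointing and `LRScalingManager` for learning-rate warmup. Memory snapshots are taken more often but live only inside the running function; the NVMe checkpoints are persistent and are what a re-run resumes from.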

In [None]:
def train_with_fault_tolerance(
    model: nn.Module,
    max_steps: int = 100,
    batch_size: int = 4,
    seq_length: int = 128,
    vocab_size: int = 50257,
    checkpoint_interval: int = 25,
    memory_snapshot_interval: int = 5,
    chaos_enabled: bool = False,
    chaos_probability: float = 0.02,
    chaos_after_step: int = 30,
):
    """Training loop with fault tolerance using elastic_harness."""

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Setup optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

    # Setup LR scaling manager from elastic_harness
    lr_config = ScalingConfig(
        base_lr=1e-4,
        base_batch_size=batch_size,
        base_world_size=1,
        scaling_rule=ScalingRule.LINEAR,
        warmup_steps=10,
    )
    lr_manager = LRScalingManager(lr_config, optimizer)

    # Setup checkpoint backends from elastic_harness
    memory_backend = MemorySnapshotBackend(max_snapshots=2)
    nvme_backend = NVMeBackend(base_path='/tmp/elastic_checkpoints')

    # Try to resume from checkpoint
    start_step = 0
    latest_path = nvme_backend.get_latest_checkpoint()
    if latest_path:
        checkpoint = nvme_backend.load(latest_path)
        model.load_state_dict(checkpoint.model_state_dict)
        optimizer.load_state_dict(checkpoint.optimizer_state_dict)
        start_step = checkpoint.step + 1
        print(f"Resumed from checkpoint at step {checkpoint.step}")
    else:
        print("Starting fresh training run")

    # Synthetic data generator
    def get_batch():
        input_ids = torch.randint(0, vocab_size, (batch_size, seq_length), device=device)
        labels = torch.randint(0, vocab_size, (batch_size, seq_length), device=device)
        return input_ids, labels

    # Training loop
    model.train()
    losses = []
    step_times = []

    print(f"\nStarting training from step {start_step} to {max_steps}")
    print(f"Chaos mode: {'ENABLED' if chaos_enabled else 'disabled'}")
    print("-" * 50)

    for step in range(start_step, max_steps):
        step_start = time.time()

        # Chaos: maybe crash (for testing recovery)
        if chaos_enabled and step > chaos_after_step:
            if random.random() < chaos_probability:
                print(f"\nCHAOS: Simulated crash at step {step}!")
                print("Run this cell again to resume from checkpoint.")
                return losses, step

        # Get batch and forward pass
        input_ids, labels = get_batch()
        outputs = model(input_ids, labels=labels)
        loss = outputs['loss']

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Step LR warmup
        current_lr = lr_manager.step_warmup()

        # Memory snapshot (using elastic_harness backend), taken after the update
        # so that, as with the NVMe checkpoints below, a snapshot labelled `step`
        # holds the state after that step and resuming starts at step + 1
        if step > 0 and step % memory_snapshot_interval == 0:
            state = CheckpointState(
                step=step,
                model_state_dict=model.state_dict(),
                optimizer_state_dict=optimizer.state_dict(),
            )
            memory_backend.save(state)

        losses.append(loss.item())
        step_times.append(time.time() - step_start)

        # Logging
        if step % 10 == 0:
            avg_loss = sum(losses[-10:]) / min(10, len(losses))
            avg_time = sum(step_times[-10:]) / min(10, len(step_times))
            tokens_per_sec = batch_size * seq_length / avg_time
            print(f"Step {step:4d} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e} | Tokens/s: {tokens_per_sec:.0f}")

        # NVMe checkpoint (using elastic_harness backend)
        if step > 0 and step % checkpoint_interval == 0:
            state = CheckpointState(
                step=step,
                model_state_dict=model.state_dict(),
                optimizer_state_dict=optimizer.state_dict(),
                metrics={'loss': loss.item()},
            )
            path = f"checkpoint_step_{step:08d}.pt"
            nvme_backend.save(state, path)
            nvme_backend.cleanup_old_checkpoints(keep_last=2)
            print(f"  [Checkpoint saved: {path}]")
            memory_backend.clear()  # Clear memory after persistent save

    print("-" * 50)
    if losses:
        print(f"Training complete! Final loss: {losses[-1]:.4f}")
    else:
        print("Nothing to train: the latest checkpoint is already at or past max_steps.")
    return losses, max_steps
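The zero-padded filename `checkpoint_step_{step:08d}.pt` matters because it makes lexicographic order agree with step order, so a plain sort finds the newest file. The helpers below are hypothetical stand-ins written with `pathlib` (not the `NVMeBackend` implementation) that show one way latest-checkpoint lookup and `keep_last` pruning can work.

```python
import tempfile
from pathlib import Path
from typing import List, Optional

PATTERN = "checkpoint_step_*.pt"


def list_checkpoints(directory: Path) -> List[Path]:
    # With {step:08d} padding, lexicographic order equals step order;
    # unpadded names would sort "..._100.pt" before "..._25.pt"
    return sorted(directory.glob(PATTERN))


def latest_checkpoint(directory: Path) -> Optional[Path]:
    # Hypothetical helper: newest checkpoint file, or None if there is none
    checkpoints = list_checkpoints(directory)
    return checkpoints[-1] if checkpoints else None


def cleanup_old_checkpoints(directory: Path, keep_last: int = 2) -> None:
    # Hypothetical helper: delete all but the `keep_last` newest checkpoint files
    checkpoints = list_checkpoints(directory)
    stale = checkpoints[:-keep_last] if keep_last > 0 else checkpoints
    for path in stale:
        path.unlink()


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    for step in (25, 50, 75, 100):
        (ckpt_dir / f"checkpoint_step_{step:08d}.pt").touch()
    cleanup_old_checkpoints(ckpt_dir, keep_last=2)
    print([p.name for p in list_checkpoints(ckpt_dir)])  # steps 75 and 100 remain
    print(latest_checkpoint(ckpt_dir))
```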

## 5. Run Training

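Run the training loop. If chaos mode is enabled and a crash occurs, re-run the cell to resume. Checkpoints persist in `/tmp/elastic_checkpoints`, so re-running this cell after a completed run also resumes from the latest saved step rather than starting over; delete that directory for a fresh run.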

In [None]:
# Create model
model = SimpleTransformerLM(
    vocab_size=50257,
    d_model=256,
    nhead=4,
    num_layers=2,
    dim_feedforward=512,
    max_seq_length=128,
)

# Run training (set chaos_enabled=True to test recovery)
losses, final_step = train_with_fault_tolerance(
    model,
    max_steps=100,
    batch_size=4,
    seq_length=128,
    checkpoint_interval=25,
    memory_snapshot_interval=5,
    chaos_enabled=False,  # Set to True to test crash recovery
    chaos_probability=0.05,
    chaos_after_step=20,
)
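Re-running after a crash works because a checkpoint labelled `step` is written after that step's update and training restarts at `step + 1`; every step after the last checkpoint is simply redone. The pure-Python toy below (SGD on a single scalar; `train` and `batch_for` are illustrative names, not part of `elastic_harness`) checks that a crashed-and-resumed run ends exactly where an uninterrupted run does. That exact match relies on each step's data being a function of the step index. The notebook's `get_batch` draws from PyTorch's global RNG, which is not checkpointed, so a resumed run there sees different synthetic batches.

```python
import random
from typing import Any, Dict, Optional


def batch_for(step: int) -> float:
    # Deterministic "data": each step's sample depends only on the step index
    return random.Random(step).uniform(-1.0, 1.0)


def train(max_steps: int, ckpt: Dict[str, Any], interval: int,
          crash_at: Optional[int] = None) -> Optional[float]:
    """Toy SGD on one weight w, minimising (w - x)^2; returns w, or None on a crash."""
    # Same convention as the notebook: a checkpoint labelled `step` is written
    # after that step's update, so a resumed run starts at step + 1
    w, start = (ckpt["w"], ckpt["step"] + 1) if ckpt else (0.0, 0)
    lr = 0.1
    for step in range(start, max_steps):
        if step == crash_at:
            return None                    # simulated crash before this step runs
        x = batch_for(step)
        w -= lr * 2 * (w - x)              # gradient of (w - x)^2 is 2 * (w - x)
        if step > 0 and step % interval == 0:
            ckpt.update(w=w, step=step)    # "persist" state after the update
    return w


baseline = train(100, {}, interval=25)

disk: Dict[str, Any] = {}
crashed = train(100, disk, interval=25, crash_at=63)
print(crashed, disk["step"])               # None 50: last checkpoint was step 50
resumed = train(100, disk, interval=25)    # redoes steps 51..99
print(resumed == baseline)                 # True
```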

## 6. Visualize Training

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True, alpha=0.3)
plt.show()
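Because `get_batch` samples labels uniformly at random and independently of the inputs, the model cannot learn anything predictive; the best it can do is output the uniform distribution over the vocabulary. The loss curve should therefore flatten out near ln(vocab_size) rather than trend toward zero, which makes this run a check of the harness mechanics, not of learning. The floor is a one-line calculation:

```python
import math

vocab_size = 50257
# Labels are uniform and independent of the inputs, so the best achievable predictor
# is the uniform distribution, whose cross-entropy is ln(vocab_size) nats
loss_floor = math.log(vocab_size)
print(f"Loss floor on random labels: {loss_floor:.3f} nats")  # about 10.825
```

To overlay it on the plot above, `plt.axhline(loss_floor, linestyle='--')` can be added before `plt.show()`.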

## 7. Topology Change Simulation

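Simulate workers leaving and joining. On each world-size change, `LRScalingManager.on_topology_change` returns a new learning rate and `GradientAccumulationManager.on_topology_change` returns a new number of accumulation steps, chosen to keep the global batch (`local_batch_size × world_size × accumulation_steps`) close to `target_global_batch_size`. When the target is not divisible by `local_batch_size × world_size` (256 on 3 GPUs with a local batch of 8), the effective global batch can only approximate it.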

In [None]:
print("Simulating Topology Changes")
print("=" * 50)

# Setup with elastic_harness scaling managers
dummy_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

lr_config = ScalingConfig(
    base_lr=1e-4,
    base_batch_size=8,
    base_world_size=4,
    scaling_rule=ScalingRule.LINEAR,
)
lr_manager = LRScalingManager(lr_config, dummy_optimizer)

accum_config = GradAccumulationConfig(
    target_global_batch_size=256,
    local_batch_size=8,
    base_world_size=4,
)
accum_manager = GradientAccumulationManager(accum_config)

print("Initial state (4 GPUs):")
print(f"  LR: {lr_manager.current_lr:.2e}")
print(f"  Accumulation steps: {accum_manager.accumulation_steps}")
print(f"  Global batch: {8 * 4 * accum_manager.accumulation_steps}")
print()

# Simulate losing a GPU
print("Worker failure: 4 GPUs -> 3 GPUs")
new_lr = lr_manager.on_topology_change(3)
new_accum = accum_manager.on_topology_change(3)
print(f"  New LR: {new_lr:.2e}")
print(f"  New accumulation steps: {new_accum}")
print(f"  Effective global batch: {8 * 3 * new_accum}")
print()

# Simulate adding GPUs
print("Workers joining: 3 GPUs -> 6 GPUs")
new_lr = lr_manager.on_topology_change(6)
new_accum = accum_manager.on_topology_change(6)
print(f"  New LR: {new_lr:.2e}")
print(f"  New accumulation steps: {new_accum}")
print(f"  Effective global batch: {8 * 6 * new_accum}")
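The arithmetic behind these numbers can be sketched independently of the library. The sketch below makes two assumptions that `elastic_harness` may not share: accumulation steps are rounded up, and the linear scaling rule is applied to the *effective* global batch. `LRScalingManager` may instead scale by world size, so the LRs it prints can differ from these. The helper names are hypothetical.

```python
import math


def accumulation_steps(target_global: int, local_batch: int, world_size: int) -> int:
    # Assumption: round up so the effective global batch never falls below target
    return max(1, math.ceil(target_global / (local_batch * world_size)))


def linear_scaled_lr(base_lr: float, base_global: int, effective_global: int) -> float:
    # Linear scaling rule: learning rate proportional to the effective global batch
    return base_lr * effective_global / base_global


base_lr, local_batch, target = 1e-4, 8, 256
for world in (4, 3, 6):
    steps = accumulation_steps(target, local_batch, world)
    effective = local_batch * world * steps
    lr = linear_scaled_lr(base_lr, target, effective)
    print(f"{world} GPUs: accum={steps:2d}  global batch={effective}  LR={lr:.3e}")
```

When accumulation already holds the global batch near its target, the linear rule barely moves the learning rate; rescaling by world size on top of that would over-correct, which is one reason to let a single component such as `ElasticScalingManager` own both decisions.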

## 8. Distributed Training Demo

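Run a multi-process training script that imports from `elastic_harness`. The script is written to `/tmp/distributed_demo.py` and launched with `torchrun --standalone`, which starts two local worker processes that train on CPU over the `gloo` backend.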

In [None]:
# Write a distributed training script that uses elastic_harness
training_script = '''
import os
import sys
sys.path.insert(0, '/content/elastic-training-harness/src')

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from elastic_harness.checkpoint import MemorySnapshotBackend, CheckpointState
from elastic_harness.scaling import ScalingConfig, ScalingRule, LRScalingManager

def main():
    # Setup distributed
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    print(f"Worker {rank}/{world_size} started")

    # Simple model with DDP
    model = nn.Linear(10, 10)
    model = DDP(model)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Setup elastic_harness components
    lr_config = ScalingConfig(
        base_lr=0.01,
        base_batch_size=4,
        base_world_size=world_size,
        scaling_rule=ScalingRule.LINEAR,
    )
    lr_manager = LRScalingManager(lr_config, optimizer)
    memory_backend = MemorySnapshotBackend(max_snapshots=2)

    # Training loop
    for step in range(10):
        x = torch.randn(4, 10)
        loss = model(x).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Save memory snapshot every 5 steps
        if step % 5 == 0 and step > 0:
            state = CheckpointState(
                step=step,
                model_state_dict=model.module.state_dict(),
                optimizer_state_dict=optimizer.state_dict(),
            )
            memory_backend.save(state)

        if rank == 0:
            print(f"Step {step}, Loss: {loss.item():.4f}")

    dist.destroy_process_group()
    print(f"Worker {rank} finished")

if __name__ == "__main__":
    main()
'''

with open('/tmp/distributed_demo.py', 'w') as f:
    f.write(training_script)

print("Distributed training script written to /tmp/distributed_demo.py")
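Under `torchrun`, the script can call `dist.init_process_group` without passing a rank or world size because the launcher exports them as environment variables (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT`, and, for elastic restarts, `TORCHELASTIC_RESTART_COUNT`), which the default `env://` initialization reads. The stdlib-only helper below is hypothetical (it is not the package's `get_world_info`) and shows how a worker can read that information, falling back to single-process defaults outside `torchrun`.

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class WorldInfo:
    rank: int
    local_rank: int
    world_size: int
    restart_count: int


def read_world_info(env: Mapping[str, str] = os.environ) -> WorldInfo:
    # Hypothetical helper: torchrun exports these for every worker;
    # default to a single process when they are absent
    return WorldInfo(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
        # Incremented by torchrun each time the worker group is restarted
        restart_count=int(env.get("TORCHELASTIC_RESTART_COUNT", 0)),
    )


print(read_world_info({}))
print(read_world_info({"RANK": "1", "LOCAL_RANK": "1", "WORLD_SIZE": "2"}))
```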

In [None]:
# Run distributed training with 2 processes
!torchrun --standalone --nproc-per-node=2 /tmp/distributed_demo.py
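In the script above, each worker computes gradients on its own batch of 4 samples, and DDP averages those gradients across ranks during `loss.backward()`, so every rank applies the same update. For a per-sample mean loss with equal shard sizes, that average equals the gradient over the combined batch, which is why the effective global batch grows with `world_size` and why the LR and accumulation managers exist. The plain-Python toy below (hypothetical helper names, a one-parameter least-squares model) checks the equivalence.

```python
from typing import List, Sequence


def mse_grad(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    # d/dw of mean((w * x - y)^2) over this shard
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def allreduce_mean(values: List[float]) -> float:
    # What DDP's gradient all-reduce leaves on every rank: the mean over ranks
    return sum(values) / len(values)


xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
ys = [2.0 * x + 0.1 for x in xs]
w = 0.3
world_size = 2
shard = len(xs) // world_size  # equal-sized shards, one per rank

per_rank = [
    mse_grad(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
    for r in range(world_size)
]
synced = allreduce_mean(per_rank)
full_batch = mse_grad(w, xs, ys)
print(f"per-rank: {per_rank}  synced: {synced:.6f}  full batch: {full_batch:.6f}")
```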

## Summary

This notebook demonstrated the `elastic_harness` package:

1. **Package Installation** - Clone and install from GitHub
2. **LR Scaling** - `LRScalingManager` adjusts learning rate on topology changes
3. **Gradient Accumulation** - `GradientAccumulationManager` maintains constant global batch size
4. **Multi-tier Checkpointing** - `MemorySnapshotBackend` and `NVMeBackend`, with old NVMe checkpoints pruned via `cleanup_old_checkpoints`
5. **Fault Tolerance** - Resume from checkpoint after simulated crashes
6. **Distributed Training** - Using torchrun with `elastic_harness` components

For production use:
- Use `CheckpointManager` for coordinated multi-tier checkpointing
- Deploy etcd for multi-node rendezvous
- Configure `S3Backend` for durable checkpoints
- Use `ElasticScalingManager` for unified LR and accumulation management