# Elastic Fault-Tolerant Distributed Training Harness

This notebook demonstrates the core components of an elastic training harness that:
- Automatically detects worker failures and rebalances the workload
- Resumes training from the latest checkpoint with a recovery time objective (RTO) under 30 seconds
- Handles dynamic `world_size` changes without duplicating or skipping data

## Features
- **Token-based data sharding** - Correct resumption even when topology changes
- **Multi-tier checkpointing** - Memory → NVMe → S3 fallback hierarchy
- **LR scaling** - Automatic learning rate adjustment for topology changes
- **Gradient accumulation** - Keep the global batch size approximately constant as workers join or leave
- **Chaos testing** - Built-in random crash simulation
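
Token-based sharding is not implemented later in this notebook, so here is a minimal sketch of the idea. It assumes the dataset is addressed by a single global token offset and that every rank reads a fixed number of tokens per step. `shard_ranges` and `tokens_per_rank` are hypothetical names, not part of the harness API.

```python
from typing import List, Tuple


def shard_ranges(consumed_tokens: int, tokens_per_rank: int, world_size: int) -> List[Tuple[int, int]]:
    """Return the [start, end) token range each rank reads in the next step."""
    ranges = []
    for rank in range(world_size):
        start = consumed_tokens + rank * tokens_per_rank
        ranges.append((start, start + tokens_per_rank))
    return ranges


# Two steps with 4 workers, then a resume with 3 workers.
consumed = 0
seen: List[Tuple[int, int]] = []
for world_size in (4, 4, 3):
    step_ranges = shard_ranges(consumed, tokens_per_rank=512, world_size=world_size)
    seen.extend(step_ranges)
    # Only this integer needs to be checkpointed to resume correctly.
    consumed = step_ranges[-1][1]

# Consecutive ranges touch exactly, so no token is read twice or skipped.
contiguous = all(a[1] == b[0] for a, b in zip(seen, seen[1:]))
print(consumed, contiguous)
```

Because the checkpointed state is a token offset rather than a per-rank sample index, the resumed run can split the remaining stream across any number of workers.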

## 1. Setup

In [None]:
# Install dependencies
!pip install torch torchvision torchaudio --quiet
!pip install omegaconf pyyaml boto3 psutil matplotlib --quiet

In [None]:
# Clone the repository (or upload files)
# Uncomment and modify the URL if using a remote repository
# !git clone https://github.com/your-org/elastic-training-harness.git
# %cd elastic-training-harness

# For this demo, we'll create the necessary files inline

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.distributed as dist
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, Dict, Any
import time
import random

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Core Components

### 2.1 Learning Rate Scaling

When the world size changes (workers join or leave), the effective batch size changes with it. To keep training stable, the learning rate is rescaled and then ramped linearly to its new value over `warmup_steps`.
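
As a worked example of the two scaling rules, the sketch below computes the rescaled learning rate for a 4 to 3 worker change. The per-worker batch size cancels out of the ratio, so only the world sizes matter. `scale_lr` is a hypothetical standalone helper that mirrors the arithmetic of the manager defined in the next cell.

```python
import math


def scale_lr(lr: float, old_world_size: int, new_world_size: int, rule: str = "linear") -> float:
    """Rescale lr by the change in effective batch size."""
    ratio = new_world_size / old_world_size
    if rule == "linear":
        return lr * ratio
    if rule == "sqrt":
        return lr * math.sqrt(ratio)
    if rule == "none":
        return lr
    raise ValueError(f"unknown scaling rule: {rule}")


# Losing one of four workers shrinks the effective batch to 3/4 of its size.
linear_lr = scale_lr(1e-4, old_world_size=4, new_world_size=3, rule="linear")
sqrt_lr = scale_lr(1e-4, old_world_size=4, new_world_size=3, rule="sqrt")
print(f"linear: {linear_lr:.2e}, sqrt: {sqrt_lr:.2e}")
```

The square-root rule changes the learning rate less aggressively than the linear rule for the same change in batch size.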

In [None]:
class ScalingRule(Enum):
    """Learning rate scaling rules for topology changes."""
    LINEAR = "linear"  # lr *= batch_new / batch_old
    SQRT = "sqrt"      # lr *= sqrt(batch_new / batch_old)
    NONE = "none"      # No scaling


@dataclass
class ScalingConfig:
    """Configuration for LR scaling."""
    base_lr: float
    base_batch_size: int
    base_world_size: int
    scaling_rule: ScalingRule = ScalingRule.LINEAR
    warmup_steps: int = 100


@dataclass
class GradAccumulationConfig:
    """Configuration for gradient accumulation."""
    target_global_batch_size: int
    local_batch_size: int
    base_world_size: int


class LRScalingManager:
    """Manages learning rate scaling during topology changes."""

    def __init__(self, config: ScalingConfig, optimizer: torch.optim.Optimizer):
        self.config = config
        self.optimizer = optimizer
        self.current_world_size = config.base_world_size
        self.current_lr = config.base_lr
        self._warmup_step = 0
        self._warmup_active = False
        self._warmup_start_lr = config.base_lr
        self._warmup_target_lr = config.base_lr

    def on_topology_change(self, new_world_size: int) -> float:
        """Adjust learning rate when world size changes."""
        if new_world_size == self.current_world_size:
            return self.current_lr

        old_effective_batch = self.config.base_batch_size * self.current_world_size
        new_effective_batch = self.config.base_batch_size * new_world_size

        # Calculate new LR based on scaling rule
        if self.config.scaling_rule == ScalingRule.LINEAR:
            scale = new_effective_batch / old_effective_batch
        elif self.config.scaling_rule == ScalingRule.SQRT:
            scale = (new_effective_batch / old_effective_batch) ** 0.5
        else:
            scale = 1.0

        new_lr = self.current_lr * scale

        # Start warmup from current LR to new LR
        self._warmup_start_lr = self.current_lr
        self._warmup_target_lr = new_lr
        self._warmup_step = 0
        self._warmup_active = True

        self.current_world_size = new_world_size
        self.current_lr = new_lr

        return new_lr

    def step_warmup(self) -> float:
        """Step the warmup schedule."""
        if not self._warmup_active:
            return self.current_lr

        self._warmup_step += 1
        progress = min(1.0, self._warmup_step / self.config.warmup_steps)

        # Linear interpolation
        current = self._warmup_start_lr + progress * (self._warmup_target_lr - self._warmup_start_lr)

        # Update optimizer
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = current

        if self._warmup_step >= self.config.warmup_steps:
            self._warmup_active = False

        return current


class GradientAccumulationManager:
    """Manages gradient accumulation to maintain constant global batch size."""

    def __init__(self, config: GradAccumulationConfig):
        self.config = config
        self.current_world_size = config.base_world_size
        self._micro_step = 0
        self._accumulation_steps = self._calculate_accumulation(config.base_world_size)

    def _calculate_accumulation(self, world_size: int) -> int:
        """Calculate accumulation steps for given world size."""
        per_step_batch = self.config.local_batch_size * world_size
        return max(1, round(self.config.target_global_batch_size / per_step_batch))

    def on_topology_change(self, new_world_size: int) -> int:
        """Adjust accumulation steps when world size changes."""
        self.current_world_size = new_world_size
        self._accumulation_steps = self._calculate_accumulation(new_world_size)
        self._micro_step = 0
        return self._accumulation_steps

    @property
    def accumulation_steps(self) -> int:
        return self._accumulation_steps

    def should_step(self) -> bool:
        """Check if optimizer should step."""
        self._micro_step += 1
        if self._micro_step >= self._accumulation_steps:
            self._micro_step = 0
            return True
        return False


print("LR Scaling and Gradient Accumulation managers defined.")

### 2.2 Checkpoint Manager

Multi-tier checkpointing across memory, NVMe, and S3 tiers. This notebook implements the memory and NVMe backends; the S3 tier appears only as an enum value.
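
The fallback order can be sketched as a loop that asks each tier in turn and returns the first checkpoint it finds. This is an illustration under the assumption that every backend exposes `load_latest()` returning `None` when empty, as the backends below do. `load_from_fastest_tier` and `StubBackend` are hypothetical names.

```python
from typing import Any, Optional, Protocol, Sequence, Tuple


class Backend(Protocol):
    def load_latest(self) -> Optional[Any]: ...


def load_from_fastest_tier(
    tiers: Sequence[Tuple[str, Backend]],
) -> Tuple[Optional[str], Optional[Any]]:
    """Walk the tiers in order and return the first checkpoint found."""
    for name, backend in tiers:
        state = backend.load_latest()
        if state is not None:
            return name, state
    return None, None


class StubBackend:
    """Hypothetical stand-in for a real backend, used only for this demo."""

    def __init__(self, state: Optional[Any] = None):
        self._state = state

    def load_latest(self) -> Optional[Any]:
        return self._state


# Memory is empty (for example after a process restart), so NVMe answers first.
tiers = [
    ("memory", StubBackend()),
    ("nvme", StubBackend({"step": 50})),
    ("s3", StubBackend({"step": 25})),
]
tier_name, state = load_from_fastest_tier(tiers)
print(tier_name, state)
```

This ordering relies on faster tiers holding checkpoints at least as recent as slower ones; a harness that cannot guarantee that would need to compare step numbers across tiers instead.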

In [None]:
class CheckpointTier(Enum):
    """Checkpoint storage tiers in order of speed."""
    MEMORY = "memory"  # Fastest, volatile
    NVME = "nvme"      # Fast, persistent
    S3 = "s3"          # Slow, durable


@dataclass
class CheckpointState:
    """Complete training state for checkpointing."""
    step: int
    model_state: Dict[str, torch.Tensor]
    optimizer_state: Dict[str, Any]
    dataset_state: Dict[str, Any] = field(default_factory=dict)
    world_size: int = 1
    metrics: Dict[str, float] = field(default_factory=dict)
    timestamp: float = field(default_factory=time.time)


import copy  # used by MemorySnapshotBackend to isolate optimizer state


class MemorySnapshotBackend:
    """In-memory checkpoint storage with circular buffer."""

    def __init__(self, max_snapshots: int = 2):
        self.max_snapshots = max_snapshots
        self._snapshots: list = []

    def save(self, state: CheckpointState) -> None:
        """Save snapshot to memory."""
        # Copy everything so later training steps cannot mutate the snapshot.
        # Model tensors are cloned onto the CPU. optimizer.state_dict() returns
        # references to the live state tensors, so it is deep-copied as well.
        cpu_state = CheckpointState(
            step=state.step,
            model_state={k: v.detach().cpu().clone() for k, v in state.model_state.items()},
            optimizer_state=copy.deepcopy(state.optimizer_state),
            dataset_state=state.dataset_state.copy(),
            world_size=state.world_size,
            metrics=state.metrics.copy(),
            timestamp=state.timestamp,
        )

        self._snapshots.append(cpu_state)
        if len(self._snapshots) > self.max_snapshots:
            self._snapshots.pop(0)

    def load_latest(self) -> Optional[CheckpointState]:
        """Load most recent snapshot."""
        if not self._snapshots:
            return None
        return self._snapshots[-1]

    def clear(self) -> None:
        """Clear all snapshots."""
        self._snapshots.clear()


class NVMeBackend:
    """Local NVMe/SSD checkpoint storage."""

    def __init__(self, base_path: str, keep_last_n: int = 3):
        self.base_path = base_path
        self.keep_last_n = keep_last_n
        os.makedirs(base_path, exist_ok=True)

    def save(self, state: CheckpointState) -> str:
        """Save checkpoint to disk."""
        filename = f"checkpoint_step_{state.step}.pt"
        path = os.path.join(self.base_path, filename)

        torch.save({
            'step': state.step,
            'model_state': state.model_state,
            'optimizer_state': state.optimizer_state,
            'dataset_state': state.dataset_state,
            'world_size': state.world_size,
            'metrics': state.metrics,
            'timestamp': state.timestamp,
        }, path)

        self._cleanup_old_checkpoints()
        return path

    def load_latest(self) -> Optional[CheckpointState]:
        """Load most recent checkpoint."""
        checkpoints = sorted(
            [f for f in os.listdir(self.base_path) if f.startswith('checkpoint_')],
            key=lambda x: int(x.split('_')[-1].replace('.pt', '')),
            reverse=True
        )

        if not checkpoints:
            return None

        path = os.path.join(self.base_path, checkpoints[0])
        data = torch.load(path, map_location='cpu')

        return CheckpointState(
            step=data['step'],
            model_state=data['model_state'],
            optimizer_state=data['optimizer_state'],
            dataset_state=data.get('dataset_state', {}),
            world_size=data.get('world_size', 1),
            metrics=data.get('metrics', {}),
            timestamp=data.get('timestamp', 0),
        )

    def _cleanup_old_checkpoints(self) -> None:
        """Remove old checkpoints beyond keep_last_n."""
        checkpoints = sorted(
            [f for f in os.listdir(self.base_path) if f.startswith('checkpoint_')],
            key=lambda x: int(x.split('_')[-1].replace('.pt', '')),
            reverse=True
        )

        for old_ckpt in checkpoints[self.keep_last_n:]:
            os.remove(os.path.join(self.base_path, old_ckpt))


print("Checkpoint backends defined.")

### 2.3 Simple Transformer Model

In [None]:
class SimpleTransformerLM(nn.Module):
    """Simple transformer language model for demonstration."""

    def __init__(
        self,
        vocab_size: int = 50257,
        d_model: int = 256,
        nhead: int = 4,
        num_layers: int = 2,
        dim_feedforward: int = 512,
        max_seq_length: int = 128,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_length, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.output_proj = nn.Linear(d_model, vocab_size)
        self.output_proj.weight = self.embedding.weight  # Tie weights

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None):
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

        x = self.embedding(input_ids) + self.pos_embedding(positions)

        mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=input_ids.device)
        x = self.transformer(x, mask=mask, is_causal=True)

        logits = self.output_proj(x)

        output = {"logits": logits}

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            output["loss"] = loss

        return output


# Test model
model = SimpleTransformerLM()
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

## 3. Training Loop with Fault Tolerance

This demonstrates a single-process training loop with:
- Multi-tier checkpointing
- LR scaling simulation
- Chaos mode (random crashes)

In [None]:
def train_with_fault_tolerance(
    model: nn.Module,
    max_steps: int = 200,
    batch_size: int = 4,
    seq_length: int = 128,
    vocab_size: int = 50257,
    checkpoint_interval: int = 50,
    memory_snapshot_interval: int = 10,
    chaos_enabled: bool = False,
    chaos_probability: float = 0.02,
    chaos_after_step: int = 30,
):
    """Training loop with fault tolerance features."""

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Setup optimizer and scaling
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

    lr_config = ScalingConfig(
        base_lr=1e-4,
        base_batch_size=batch_size,
        base_world_size=1,
        scaling_rule=ScalingRule.LINEAR,
        warmup_steps=10,
    )
    lr_manager = LRScalingManager(lr_config, optimizer)

    # Setup checkpointing
    memory_backend = MemorySnapshotBackend(max_snapshots=2)
    nvme_backend = NVMeBackend('/tmp/elastic_checkpoints', keep_last_n=2)

    # Try to resume from the newest NVMe checkpoint. Memory snapshots live only
    # inside this call, so they are not consulted here.
    start_step = 0
    checkpoint = nvme_backend.load_latest()
    if checkpoint:
        model.load_state_dict(checkpoint.model_state)
        optimizer.load_state_dict(checkpoint.optimizer_state)
        start_step = checkpoint.step + 1
        print(f"Resumed from checkpoint at step {checkpoint.step}")
    else:
        print("Starting fresh training run")

    # Synthetic data generator
    def get_batch():
        input_ids = torch.randint(0, vocab_size, (batch_size, seq_length), device=device)
        labels = torch.randint(0, vocab_size, (batch_size, seq_length), device=device)
        return input_ids, labels

    # Training loop
    model.train()
    losses = []
    step_times = []

    print(f"\nStarting training from step {start_step} to {max_steps}")
    print(f"Chaos mode: {'ENABLED' if chaos_enabled else 'disabled'}")
    print("-" * 50)

    for step in range(start_step, max_steps):
        step_start = time.time()

        # Chaos: maybe crash (for testing recovery)
        if chaos_enabled and step > chaos_after_step:
            if random.random() < chaos_probability:
                print(f"\nCHAOS: Simulated crash at step {step}!")
                print("Run this cell again to resume from checkpoint.")
                return losses, step

        # Memory snapshot
        if step > 0 and step % memory_snapshot_interval == 0:
            state = CheckpointState(
                step=step,
                model_state=model.state_dict(),
                optimizer_state=optimizer.state_dict(),
            )
            memory_backend.save(state)

        # Get batch and forward pass
        input_ids, labels = get_batch()
        outputs = model(input_ids, labels=labels)
        loss = outputs['loss']

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Step LR warmup
        current_lr = lr_manager.step_warmup()

        losses.append(loss.item())
        step_times.append(time.time() - step_start)

        # Logging
        if step % 10 == 0:
            avg_loss = sum(losses[-10:]) / min(10, len(losses))
            avg_time = sum(step_times[-10:]) / min(10, len(step_times))
            tokens_per_sec = batch_size * seq_length / avg_time
            print(f"Step {step:4d} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e} | Tokens/s: {tokens_per_sec:.0f}")

        # NVMe checkpoint
        if step > 0 and step % checkpoint_interval == 0:
            state = CheckpointState(
                step=step,
                model_state=model.state_dict(),
                optimizer_state=optimizer.state_dict(),
                metrics={'loss': loss.item()},
            )
            path = nvme_backend.save(state)
            print(f"  [Checkpoint saved: {path}]")
            memory_backend.clear()  # Clear memory after persistent save

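    print("-" * 50)
    if losses:
        print(f"Training complete! Final loss: {losses[-1]:.4f}")
    else:
        # Resumed at or past max_steps, so no steps ran in this call
        print("Training already complete; no steps were run.")
    return losses, max_steps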

## 4. Run Training

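Run the training loop. If chaos mode is enabled and a crash occurs, run the cell again to resume from the last NVMe checkpoint in `/tmp/elastic_checkpoints`. Checkpoints left by an earlier run are picked up too, so delete that directory to start from scratch.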
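
Re-running the cell by hand can also be automated with a small supervisor loop. The sketch below assumes a training callable that returns `(losses, reached_step)` the way `train_with_fault_tolerance` does. `run_until_complete` and `make_fake_trainer` are hypothetical helpers, and the fake trainer only imitates checkpoint and crash behaviour so the example runs without a model.

```python
import random
from typing import Callable, List, Tuple


def run_until_complete(
    train_once: Callable[[], Tuple[List[float], int]],
    max_steps: int,
    max_restarts: int = 50,
) -> int:
    """Re-invoke train_once until it reports max_steps; return restarts used."""
    for restarts in range(max_restarts + 1):
        _, reached_step = train_once()
        if reached_step >= max_steps:
            return restarts
    raise RuntimeError(f"still incomplete after {max_restarts} restarts")


def make_fake_trainer(max_steps: int, checkpoint_interval: int,
                      crash_probability: float, seed: int):
    """Build a stand-in trainer that persists progress only at checkpoints."""
    rng = random.Random(seed)
    saved = {"step": 0}  # plays the role of the on-disk checkpoint

    def train_once() -> Tuple[List[float], int]:
        step = saved["step"]
        while step < max_steps:
            if rng.random() < crash_probability:
                return [], step  # simulated crash: unsaved progress is lost
            step += 1
            if step % checkpoint_interval == 0:
                saved["step"] = step
        return [], max_steps

    return train_once


trainer = make_fake_trainer(max_steps=100, checkpoint_interval=25,
                            crash_probability=0.02, seed=0)
restarts = run_until_complete(trainer, max_steps=100)
print(f"finished after {restarts} restart(s)")
```

In this notebook the same loop could wrap a lambda that builds a fresh `SimpleTransformerLM` and calls `train_with_fault_tolerance(..., chaos_enabled=True)`, since each call resumes from the newest checkpoint on disk.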

In [None]:
# Create model
model = SimpleTransformerLM(
    vocab_size=50257,
    d_model=256,
    nhead=4,
    num_layers=2,
    dim_feedforward=512,
    max_seq_length=128,
)

# Run training (set chaos_enabled=True to test recovery)
losses, final_step = train_with_fault_tolerance(
    model,
    max_steps=100,
    batch_size=4,
    seq_length=128,
    checkpoint_interval=25,
    memory_snapshot_interval=5,
    chaos_enabled=False,  # Set to True to test crash recovery
    chaos_probability=0.05,
    chaos_after_step=20,
)

## 5. Visualize Training

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True, alpha=0.3)
plt.show()

## 6. Demonstrate Topology Change Handling

This simulates what happens when workers join or leave the training cluster.
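
Because the number of accumulation steps is rounded to an integer, the global batch size is only approximately constant when the world size does not divide the target evenly. The sketch below reproduces that arithmetic for the configuration used in the next cell (target 256, local batch 8). `accumulation_plan` is a hypothetical helper name.

```python
from typing import Tuple


def accumulation_plan(target_global_batch: int, local_batch: int, world_size: int) -> Tuple[int, int]:
    """Return (accumulation_steps, resulting_global_batch)."""
    per_step = local_batch * world_size
    steps = max(1, round(target_global_batch / per_step))
    return steps, per_step * steps


for world_size in (4, 3, 6):
    steps, actual = accumulation_plan(256, 8, world_size)
    print(f"{world_size} GPUs: {steps} accumulation steps, global batch {actual}")
```

With 4 GPUs the target is met exactly, while 3 and 6 GPUs land slightly above and below it.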

In [None]:
print("Simulating Topology Changes")
print("=" * 50)

# Setup
dummy_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

lr_config = ScalingConfig(
    base_lr=1e-4,
    base_batch_size=8,
    base_world_size=4,
    scaling_rule=ScalingRule.LINEAR,
)
lr_manager = LRScalingManager(lr_config, dummy_optimizer)

accum_config = GradAccumulationConfig(
    target_global_batch_size=256,
    local_batch_size=8,
    base_world_size=4,
)
accum_manager = GradientAccumulationManager(accum_config)

print("Initial state (4 GPUs):")
print(f"  LR: {lr_manager.current_lr:.2e}")
print(f"  Accumulation steps: {accum_manager.accumulation_steps}")
print(f"  Global batch: {8 * 4 * accum_manager.accumulation_steps}")
print()

# Simulate losing a GPU
print("Worker failure: 4 GPUs -> 3 GPUs")
new_lr = lr_manager.on_topology_change(3)
new_accum = accum_manager.on_topology_change(3)
print(f"  New LR: {new_lr:.2e}")
print(f"  New accumulation steps: {new_accum}")
print(f"  Effective global batch: {8 * 3 * new_accum}")
print()

# Simulate adding GPUs
print("Workers joining: 3 GPUs -> 6 GPUs")
new_lr = lr_manager.on_topology_change(6)
new_accum = accum_manager.on_topology_change(6)
print(f"  New LR: {new_lr:.2e}")
print(f"  New accumulation steps: {new_accum}")
print(f"  Effective global batch: {8 * 6 * new_accum}")

## 7. Multi-Process Training (Local)

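For actual distributed training, use the harness's launch script. The cells below write a minimal `DistributedDataParallel` script and launch it locally with `torchrun`, which also works in Colab.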
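
`torchrun` passes each worker its identity through the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables, which `init_process_group` reads by default. A minimal sketch of reading them directly, with single-process fallbacks, might look like this; `WorkerEnv` and `read_worker_env` are hypothetical names.

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class WorkerEnv:
    rank: int
    local_rank: int
    world_size: int


def read_worker_env(env: Mapping[str, str] = os.environ) -> WorkerEnv:
    """Read the worker identity set by torchrun, defaulting to one process."""
    return WorkerEnv(
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
    )


# Simulate the environment of the second of two workers.
worker = read_worker_env({"RANK": "1", "LOCAL_RANK": "1", "WORLD_SIZE": "2"})
print(worker)
```

The defaults let the same script run unchanged as a plain single-process job when it is not started by `torchrun`.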

In [None]:
# Write a minimal training script
training_script = '''
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # Setup
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    print(f"Worker {rank}/{world_size} started")

    # Simple model
    model = nn.Linear(10, 10)
    model = DDP(model)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Training
    for step in range(10):
        x = torch.randn(4, 10)
        loss = model(x).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if rank == 0:
            print(f"Step {step}, Loss: {loss.item():.4f}")

    dist.destroy_process_group()
    print(f"Worker {rank} finished")

if __name__ == "__main__":
    main()
'''

with open('/tmp/distributed_demo.py', 'w') as f:
    f.write(training_script)

print("Distributed training script written to /tmp/distributed_demo.py")

In [None]:
# Run distributed training with 2 processes
# Note: --standalone uses torchrun's built-in c10d rendezvous on localhost (no etcd needed)
!torchrun --standalone --nproc-per-node=2 /tmp/distributed_demo.py

## Summary

This notebook demonstrated:

1. **LR Scaling** - Automatically adjust learning rate when topology changes
2. **Gradient Accumulation** - Keep the global batch size approximately constant (accumulation steps are rounded)
3. **Multi-tier Checkpointing** - Memory and NVMe backends with automatic cleanup
4. **Fault Tolerance** - Resume from checkpoint after simulated crashes
5. **Distributed Training** - Using torchrun with its built-in c10d rendezvous

For production use:
- Use the full `elastic_harness` package
- Deploy etcd for multi-node rendezvous
- Configure S3 backend for durable checkpoints
- Enable chaos testing to validate fault tolerance
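
The S3 tier is not implemented in this notebook. As a rough sketch of what pushing a local checkpoint to S3 could involve, the helper below accepts any client with an `upload_file(Filename, Bucket, Key)` method, which is the signature a boto3 S3 client provides. The bucket name, key prefix, and `upload_checkpoint` itself are assumptions for illustration, not the harness's API.

```python
import os


def upload_checkpoint(s3_client, local_path: str, bucket: str, prefix: str = "checkpoints") -> str:
    """Upload one checkpoint file and return the object key it was stored under."""
    key = f"{prefix}/{os.path.basename(local_path)}"
    s3_client.upload_file(local_path, bucket, key)
    return key


class RecordingClient:
    """Hypothetical stand-in for boto3.client("s3") so the sketch runs offline."""

    def __init__(self):
        self.calls = []

    def upload_file(self, Filename, Bucket, Key):
        self.calls.append((Filename, Bucket, Key))


client = RecordingClient()
key = upload_checkpoint(
    client,
    "/tmp/elastic_checkpoints/checkpoint_step_50.pt",
    bucket="my-training-bucket",  # hypothetical bucket name
)
print(key)
```

With real credentials, passing `boto3.client("s3")` in place of `RecordingClient()` would perform the actual upload.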