# Kubeflow Trainer: Local Training with Multiple Backends

This notebook demonstrates how to run Kubeflow Trainer locally using **three different backends** and shows how easy it is to switch between them.

## Available Backends

| Backend | Container Runtime | Distributed Training | Best For |
|---------|------------------|---------------------|----------|
| **Local Process** | None (native Python) | ❌ Single process only | Quick testing, debugging |
| **Docker** | Docker required | ✅ Multi-node (PyTorch DDP) | Standard container workflow |
| **Podman** | Podman required | ✅ Multi-node (PyTorch DDP) | Rootless containers, security |

## Prerequisites

- **All backends**: Python 3.8+
- **Local Process**: No additional requirements
- **Docker**: Docker Desktop or Docker Engine
- **Podman**: Podman installed
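
Before picking a container backend, it can help to confirm which runtime CLI is on the `PATH`. The snippet below is a minimal sketch using only the standard library. The helper name `detect_container_runtimes` is hypothetical and not part of the Kubeflow SDK, and finding a binary does not guarantee that the Docker daemon or Podman service is actually running.

```python
import shutil


def detect_container_runtimes(candidates=("docker", "podman")):
    """Map each runtime name to its executable path, or None if it is not on PATH."""
    return {name: shutil.which(name) for name in candidates}


for name, path in detect_container_runtimes().items():
    print(f"{name}: {path or 'not found on PATH'}")
```

If neither binary is found, the Local Process backend remains usable since it needs no container runtime.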

This example trains a CNN on the [Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset using PyTorch.

## Install Kubeflow SDK

In [None]:
# Base installation (required)
%pip install -U kubeflow

# Optional extras for container backends:
# %pip install -U kubeflow[docker]  # For Docker
# %pip install -U kubeflow[podman]  # For Podman

## Define Training Functions

We define two training functions:

1. **Single-process** - For Local Process Backend (no distributed training)
2. **Distributed** - For Docker/Podman Backends (PyTorch DDP across multiple nodes)

### Single-Process Training Function

In [None]:
def train_single_process():
    """Train Fashion MNIST CNN - single process (no distributed)."""
    import torch
    import torch.nn.functional as F
    from torch import nn, optim
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    print("Starting single-process training...")

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 20, 5, 1)
            self.conv2 = nn.Conv2d(20, 50, 5, 1)
            self.fc1 = nn.Linear(4 * 4 * 50, 500)
            self.fc2 = nn.Linear(500, 10)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.max_pool2d(x, 2, 2)
            x = F.relu(self.conv2(x))
            x = F.max_pool2d(x, 2, 2)
            x = x.view(-1, 4 * 4 * 50)
            x = F.relu(self.fc1(x))
            return F.log_softmax(self.fc2(x), dim=1)

    model = Net()
    dataset = datasets.FashionMNIST(
        './data', train=True, download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    )
    train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    
    for epoch in range(1, 3):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            loss = F.nll_loss(model(data), target)
            loss.backward()
            optimizer.step()
            
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')

    torch.save(model.state_dict(), "fashion_mnist_cnn.pt")
    print("Training complete!")
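
The `4 * 4 * 50` input size of `fc1` follows from the layer shapes: Fashion MNIST images are 28x28, each 5x5 convolution with stride 1 and no padding shrinks a side by 4, and each 2x2 max-pool halves it. The arithmetic can be checked with a small sketch (the helper `conv_out` is illustrative only):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                                   # Fashion MNIST images are 28x28
side = conv_out(side, kernel=5)             # conv1: 28 -> 24
side = conv_out(side, kernel=2, stride=2)   # max_pool2d(2, 2): 24 -> 12
side = conv_out(side, kernel=5)             # conv2: 12 -> 8
side = conv_out(side, kernel=2, stride=2)   # max_pool2d(2, 2): 8 -> 4

flattened = side * side * 50                # conv2 has 50 output channels
print(side, flattened)                      # 4 800
```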

### Distributed Training Function (PyTorch DDP)

In [None]:
def train_distributed():
    """Train Fashion MNIST CNN - distributed across multiple nodes with PyTorch DDP."""
    import os
    import torch
    import torch.distributed as dist
    import torch.nn.functional as F
    from torch import nn
    from torch.utils.data import DataLoader, DistributedSampler
    from torchvision import datasets, transforms

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 20, 5, 1)
            self.conv2 = nn.Conv2d(20, 50, 5, 1)
            self.fc1 = nn.Linear(4 * 4 * 50, 500)
            self.fc2 = nn.Linear(500, 10)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.max_pool2d(x, 2, 2)
            x = F.relu(self.conv2(x))
            x = F.max_pool2d(x, 2, 2)
            x = x.view(-1, 4 * 4 * 50)
            x = F.relu(self.fc1(x))
            return F.log_softmax(self.fc2(x), dim=1)

    # Setup distributed training
    device, backend = ("cuda", "nccl") if torch.cuda.is_available() else ("cpu", "gloo")
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    dist.init_process_group(backend=backend)
    rank = dist.get_rank()
    
    print(f"Node {rank}/{dist.get_world_size()} - Device: {device}, Backend: {backend}")

    # Create model with DDP
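    # Index the device only for CUDA; a CPU device does not accept an index other than 0
    device = torch.device(f"cuda:{local_rank}") if device == "cuda" else torch.device("cpu")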
    model = nn.parallel.DistributedDataParallel(Net().to(device))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # Load data with distributed sampler
    data_dir = f"/tmp/fashion-mnist-{rank}"
    os.makedirs(data_dir, exist_ok=True)
    dataset = datasets.FashionMNIST(
        data_dir, train=True, download=True,
        transform=transforms.Compose([transforms.ToTensor()])
    )
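    sampler = DistributedSampler(dataset)
    train_loader = DataLoader(dataset, batch_size=100, sampler=sampler)

    # Training loop
    dist.barrier()
    for epoch in range(1, 3):
        sampler.set_epoch(epoch)  # Without this, every epoch reuses the same shuffle order
        model.train()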
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = F.nll_loss(model(inputs), labels)
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0 and rank == 0:
                print(f"Epoch {epoch} [{batch_idx * len(inputs)}/{len(train_loader.dataset)} "
                      f"({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

    dist.barrier()
    if rank == 0:
        print("Distributed training complete!")
    dist.destroy_process_group()
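
`DistributedSampler` gives each rank its own slice of the dataset so that the nodes do not all train on the same samples. The pure-Python sketch below mirrors the sharding arithmetic for the unshuffled case without dropping the tail. It illustrates the idea and is not PyTorch's implementation; `shard_indices` is a hypothetical helper that assumes the dataset has at least as many samples as there are ranks.

```python
import math


def shard_indices(dataset_len, num_replicas, rank):
    """Indices one rank would see: pad to a multiple of num_replicas, then stride."""
    per_rank = math.ceil(dataset_len / num_replicas)
    total = per_rank * num_replicas
    indices = list(range(dataset_len))
    indices += indices[: total - dataset_len]  # pad by wrapping around to the start
    return indices[rank:total:num_replicas]


# Tiny example: 10 samples over 3 ranks (two samples are repeated as padding)
print([shard_indices(10, 3, r) for r in range(3)])

# Fashion MNIST training split (60,000 images) over 3 nodes with batch_size=100
shards = [shard_indices(60_000, 3, r) for r in range(3)]
steps_per_epoch = math.ceil(len(shards[0]) / 100)
print([len(s) for s in shards], steps_per_epoch)  # [20000, 20000, 20000] 200
```

With three nodes and a per-node batch size of 100, each optimizer step therefore covers 300 samples in total.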

---

## Backend Selection

Choose which backend to use by running **ONE** of the following cells:

### Option 1: Local Process Backend

**Use when**: Quick testing, debugging, no containers needed  
**Note**: Only supports single-process training (no distributed)

In [None]:
from kubeflow.trainer import TrainerClient, LocalProcessBackendConfig

# Configure Local Process Backend
backend_config = LocalProcessBackendConfig(
    cleanup_venv=True  # Auto-cleanup virtual environments
)

# Set training parameters
backend = "Local Process"
training_function = train_single_process  # Single process only
num_nodes = None  # Not applicable for local process
packages = ["torch", "torchvision"]

print(f"Selected: {backend} Backend")
print("   Training mode: Single process (no distributed)")
print("   Container runtime: None")

### Option 2: Docker Backend

**Use when**: Multi-node distributed training, standard container workflow  
**Requires**: Docker Desktop or Docker Engine

In [None]:
from kubeflow.trainer import TrainerClient, ContainerBackendConfig
import os
# Configure Docker Backend
backend_config = ContainerBackendConfig()
# backend_config = ContainerBackendConfig(runtime="docker")  # Force Docker if both Docker/Podman installed

# Optional: For Colima on macOS
# backend_config = ContainerBackendConfig(
#     container_host=f"unix://{os.path.expanduser('~')}/.colima/default/docker.sock"
# )

# Set training parameters
backend = "Docker"
training_function = train_distributed  # Distributed training with PyTorch DDP
num_nodes = 3  # Number of Docker containers (training nodes)
packages = ["torchvision"]

print(f"Selected: {backend} Backend")
print("   Training mode: Distributed (PyTorch DDP)")
print(f"   Number of nodes: {num_nodes}")
print("   Container runtime: Docker")

### Option 3: Podman Backend

**Use when**: Rootless containers, security-focused environments  
**Requires**: Podman installed

**Docker vs Podman**:
- Podman: Daemonless, rootless containers (better security)
- Docker: Daemon-based; the daemon runs as root by default (a rootless mode is available)

In [None]:
from kubeflow.trainer import TrainerClient, ContainerBackendConfig

# Configure Podman Backend
backend_config = ContainerBackendConfig(runtime="podman")  # Specify Podman

# Optional: Custom Podman socket
# backend_config = ContainerBackendConfig(
#     runtime="podman",
#     container_host="unix:///run/user/1000/podman/podman.sock"
# )

# Set training parameters
backend = "Podman"
training_function = train_distributed  # Distributed training with PyTorch DDP
num_nodes = 3  # Number of Podman containers (training nodes)
packages = ["torchvision"]

print(f"Selected: {backend} Backend")
print("   Training mode: Distributed (PyTorch DDP)")
print(f"   Number of nodes: {num_nodes}")
print("   Container runtime: Podman (rootless)")
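
The `1000` in the commented-out socket path is a user ID, so the path differs between accounts. Below is a minimal sketch for deriving the rootless Podman socket URI on Linux, assuming the conventional `$XDG_RUNTIME_DIR/podman/podman.sock` location. The helper `podman_socket_uri` is hypothetical; on macOS, Podman runs inside a VM and the socket lives elsewhere.

```python
import os
from pathlib import Path


def podman_socket_uri(runtime_dir=None):
    """Build a unix:// URI for the rootless Podman socket (Linux convention)."""
    if runtime_dir is None:
        # XDG_RUNTIME_DIR is usually /run/user/<uid> on systemd-based distributions
        runtime_dir = os.environ.get("XDG_RUNTIME_DIR") or f"/run/user/{os.getuid()}"
    return f"unix://{Path(runtime_dir) / 'podman' / 'podman.sock'}"


uri = podman_socket_uri()
socket_path = Path(uri[len("unix://"):])
state = "exists" if socket_path.exists() else "not found; is the Podman socket service running?"
print(f"{uri} ({state})")
```

The resulting string could then be passed as `container_host`, as in the commented-out example.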


## Advanced: Custom Runtime Sources (Optional)

By default, container backends (Docker/Podman) use training runtimes from:
1. **GitHub** - `github://kubeflow/trainer` (official runtimes)
2. **Fallback** - Built-in default images (e.g., PyTorch official image)

You can customize where runtimes are loaded from by configuring `runtime_source`. This is useful for:
- Using custom container images
- Loading runtimes from private repositories
- Testing local runtime definitions

**To use custom sources**, run the cell below after selecting your backend:

In [None]:
# Uncomment and modify to use custom runtime sources
# from kubeflow.trainer import TrainingRuntimeSource

# backend_config.runtime_source = TrainingRuntimeSource(sources=[
#     "github://kubeflow/trainer",                    # Official Kubeflow runtimes (default)
#     "github://myorg/myrepo/path/to/runtimes",       # Custom GitHub repository
#     "https://example.com/custom-runtime.yaml",      # HTTP(S) endpoint
#     "file:///absolute/path/to/runtime.yaml",        # Local YAML file
#     "/absolute/path/to/runtime.yaml",               # Local YAML file (alternate syntax)
# ])

# Sources are checked in priority order. If no runtime is found in any source,
# the system falls back to the default image for the framework (e.g., pytorch/pytorch)

print("Using default runtime sources (no customization needed for this example)")
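
The source strings in the commented-out example use four notations: `github://`, `https://`, `file://`, and a bare absolute path. The following sketch shows one way such strings could be told apart with `urllib.parse`. It is purely illustrative: `classify_source` is a hypothetical helper, and the SDK's actual resolution logic may differ.

```python
from urllib.parse import urlparse


def classify_source(source):
    """Classify a runtime source string by its URL scheme."""
    scheme = urlparse(source).scheme
    if scheme == "github":
        return "github"
    if scheme in ("http", "https"):
        return "http"
    if scheme == "file" or source.startswith("/"):
        return "local"
    raise ValueError(f"Unrecognized runtime source: {source!r}")


sources = [
    "github://kubeflow/trainer",
    "https://example.com/custom-runtime.yaml",
    "file:///absolute/path/to/runtime.yaml",
    "/absolute/path/to/runtime.yaml",
]
for s in sources:
    print(f"{classify_source(s):<7} {s}")
```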

---

## Initialize Client

Initialize the TrainerClient with your selected backend configuration:

In [None]:
client = TrainerClient(backend_config=backend_config)

print(f"\nTrainerClient initialized with {backend} backend")

## List Available Runtimes

In [None]:
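print("Available training runtimes:\n")
torch_runtime = None
for runtime in client.list_runtimes():
    print(f"  • {runtime}")
    if runtime.name == "torch-distributed":
        torch_runtime = runtime

# Fail early with a clear message instead of a NameError in a later cell
if torch_runtime is None:
    raise RuntimeError("Runtime 'torch-distributed' not found; check the runtime sources.")
print(f"\nSelected runtime: {torch_runtime.name}")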

## Start Training Job

Launch the training job with your selected backend:

In [None]:
from kubeflow.trainer import CustomTrainer

# Build trainer config based on selected backend
trainer_config = {
    "func": training_function,
    "packages_to_install": packages,
}

# Add num_nodes only for distributed backends
if num_nodes is not None:
    trainer_config["num_nodes"] = num_nodes

job_name = client.train(
    trainer=CustomTrainer(**trainer_config),
    runtime=torch_runtime,
)

print(f"\nTraining job started: {job_name}")
print(f"   Backend: {backend}")
if num_nodes:
    print(f"   Nodes: {num_nodes}")

## Check Job Status

View the training job status and steps:

In [None]:
job = client.get_job(job_name)

print("\nJob Status:")
print(f"   Name: {job.name}")
print(f"   Status: {job.status}")
print(f"   Created: {job.creation_timestamp}")
print("\n   Steps:")
for step in job.steps:
    devices_info = f", Devices: {step.device} x {step.device_count}" if hasattr(step, 'device') else ""
    print(f"     • {step.name}: {step.status}{devices_info}")

## Stream Training Logs

Watch the training progress in real-time:

In [None]:
print(f"\nStreaming logs from {backend} backend (interrupt the kernel or press Ctrl+C to stop):\n")
print("="*80)

try:
    for log_line in client.get_job_logs(job_name, follow=True):
        print(log_line, end='')
except KeyboardInterrupt:
    print("\n\n⏸  Log streaming stopped by user")

## Wait for Completion

In [None]:
try:
    completed_job = client.wait_for_job_status(
        name=job_name,
        status={"Complete"},
        timeout=600,
        polling_interval=5
    )
    print("\nTraining job completed successfully!")
except TimeoutError:
    print("\nJob did not complete within timeout")
except RuntimeError as e:
    print(f"\nJob failed: {e}")
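
`wait_for_job_status` blocks until the job reaches one of the requested statuses or the timeout elapses, checking at the given polling interval. The general pattern can be sketched as below. This is a generic illustration and not the SDK's implementation; `poll_until` and `get_status` are hypothetical names, and both durations are assumed to be in seconds.

```python
import time


def poll_until(get_status, wanted, timeout=600, polling_interval=5):
    """Call get_status() repeatedly until it returns a value in `wanted`.

    Raises TimeoutError once another poll would not fit within `timeout` seconds.
    """
    deadline = time.monotonic() + timeout
    while True:
        status = get_status()
        if status in wanted:
            return status
        if time.monotonic() + polling_interval > deadline:
            raise TimeoutError(f"Status still {status!r} after {timeout}s")
        time.sleep(polling_interval)


# Simulated job that reports "Created", then "Running" twice, then "Complete"
fake_statuses = iter(["Created", "Running", "Running", "Complete"])
result = poll_until(lambda: next(fake_statuses), {"Complete"}, timeout=5, polling_interval=0.01)
print(result)  # Complete
```

Using `time.monotonic()` rather than wall-clock time keeps the deadline stable if the system clock is adjusted while waiting.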

## Optional: Inspect Container Resources

For Docker/Podman backends, you can inspect the containers and networks:

**Docker:**
```bash
docker ps --filter label=trainer.kubeflow.org/trainjob-name
docker network ls --filter label=trainer.kubeflow.org/trainjob-name
```

**Podman:**
```bash
podman ps --filter label=trainer.kubeflow.org/trainjob-name
podman network ls --filter label=trainer.kubeflow.org/trainjob-name
```
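
The same inspection can be driven from Python. The sketch below only builds the command line. The helper `build_ps_command` is hypothetical, `my-job` is a placeholder job name, and the label key is passed in rather than hard-coded so that it matches whatever key your SDK version applies. The `--filter label=<key>=<value>` form narrows the listing to a single job.

```python
import shlex


def build_ps_command(runtime, label_key, job_name=None):
    """Build a `docker ps` / `podman ps` command filtered by a label."""
    if runtime not in ("docker", "podman"):
        raise ValueError(f"Unsupported runtime: {runtime!r}")
    label = f"{label_key}={job_name}" if job_name else label_key
    # -a also lists containers that have already exited
    return [runtime, "ps", "-a", "--filter", f"label={label}"]


cmd = build_ps_command("podman", "trainer.kubeflow.org/trainjob-name", job_name="my-job")
print(shlex.join(cmd))
# To execute it: subprocess.run(cmd, check=True, capture_output=True, text=True)
```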

## Clean Up

Delete the training job to free up resources:

In [None]:
client.delete_job(job_name)
print(f"\nJob deleted: {job_name}")
print(f"   Backend resources cleaned up: {backend}")

---

## Summary: Switching Backends

To switch backends, simply:

1. **Run a different backend selection cell** (Option 1, 2, or 3)
2. **Re-run all subsequent cells** starting from "Initialize Client"

### Quick Comparison:

```python
# Local Process Backend
backend_config = LocalProcessBackendConfig(cleanup_venv=True)
training_function = train_single_process  # Single process
num_nodes = None

# Docker Backend
backend_config = ContainerBackendConfig()
training_function = train_distributed  # Distributed
num_nodes = 3

# Podman Backend  
backend_config = ContainerBackendConfig(runtime="podman")
training_function = train_distributed  # Distributed
num_nodes = 3
```

The rest of the workflow remains **exactly the same** regardless of backend!
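
One way to make the switch explicit is to keep the per-backend settings in a single table and look them up by name. This is only a sketch: `BackendSettings`, `BACKEND_SETTINGS`, and `select_backend` are hypothetical, and the config classes and training functions are referenced by name so the snippet runs without the SDK installed.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class BackendSettings:
    config_class: str                                   # backend config class to instantiate
    config_kwargs: dict = field(default_factory=dict)   # keyword arguments for that class
    training_function: str = "train_distributed"        # name of the function to run
    num_nodes: Optional[int] = 3                        # None means single process


BACKEND_SETTINGS = {
    "local": BackendSettings(
        "LocalProcessBackendConfig", {"cleanup_venv": True}, "train_single_process", None
    ),
    "docker": BackendSettings("ContainerBackendConfig"),
    "podman": BackendSettings("ContainerBackendConfig", {"runtime": "podman"}),
}


def select_backend(name):
    """Look up the settings for a backend name, with a clear error for typos."""
    try:
        return BACKEND_SETTINGS[name]
    except KeyError:
        choices = sorted(BACKEND_SETTINGS)
        raise ValueError(f"Unknown backend {name!r}; choose from {choices}") from None


settings = select_backend("podman")
print(settings.config_class, settings.config_kwargs, settings.num_nodes)
```

In the notebook, the named class could then be imported from `kubeflow.trainer` and instantiated with `settings.config_kwargs` to produce `backend_config`.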