# HPC Integration and Checkpointing

> Configure NLSQ for HPC clusters with fault-tolerant checkpointing

**30 minutes** | **Level: Advanced**

---

## What You'll Learn

By the end of this notebook, you will be able to:

- Use `ClusterDetector` and `ClusterInfo` to detect PBS Pro environments
- Configure `WorkflowTier.STREAMING_CHECKPOINT` for fault tolerance
- Set up checkpointing with `enable_checkpoints=True` and `checkpoint_dir`
- Use `create_checkpoint_directory()` for timestamp-based directories
- Implement checkpoint resume workflows
- Create PBS Pro job scripts for NLSQ

---

## Learning Path

**You are here:** Workflow System > **HPC and Checkpointing**

```
YAML Configuration --> [You are here: HPC & Checkpointing]
```

**Prerequisites:**
- [05_yaml_configuration.ipynb](05_yaml_configuration.ipynb) - YAML configuration basics

---

## Before You Begin

**Required knowledge:**
- Familiarity with HPC batch schedulers (PBS Pro)
- Understanding of NLSQ workflow tiers

**Required software:**
- NLSQ >= 0.3.4
- Python >= 3.12

**Note:** This tutorial demonstrates HPC features. Some examples require
a PBS Pro cluster environment to run fully.

---

## Why This Matters

HPC clusters enable fitting on massive datasets (100M+ points) but introduce challenges:

1. **Job time limits:** Cluster jobs have wall time limits (hours to days)
2. **Node failures:** Hardware failures can terminate jobs unexpectedly
3. **Preemption:** Higher-priority jobs may preempt your job
4. **Multi-GPU scaling:** Need to efficiently use available GPUs

**Checkpointing addresses the first three by:**
- Saving optimization state periodically
- Enabling resume from last checkpoint
- Preventing loss of computation on failure

---

## Quick Start (30 seconds)

Enable checkpointing with a single `WorkflowConfig` call:

In [None]:
from nlsq.workflow import WorkflowConfig, WorkflowTier, create_checkpoint_directory

# Create checkpoint-enabled workflow
config = WorkflowConfig(
    tier=WorkflowTier.STREAMING_CHECKPOINT,
    enable_checkpoints=True,
    checkpoint_dir=create_checkpoint_directory(),
)

print(f"Tier: {config.tier}")
print(f"Checkpoints enabled: {config.enable_checkpoints}")
print(f"Checkpoint directory: {config.checkpoint_dir}")

---

## Setup

In [None]:
import os
import pickle
from pathlib import Path

import numpy as np
import jax.numpy as jnp

from nlsq.workflow import (
    WorkflowConfig,
    WorkflowTier,
    OptimizationGoal,
    ClusterDetector,
    ClusterInfo,
    create_checkpoint_directory,
    get_multi_gpu_config,
    create_distributed_config,
    WORKFLOW_PRESETS,
)

np.random.seed(42)

---

## Tutorial Content

### Section 1: ClusterDetector and ClusterInfo

The `ClusterDetector` automatically detects PBS Pro cluster environments
via the `$PBS_NODEFILE` environment variable.
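
As a rough illustration of what detection through `$PBS_NODEFILE` involves, the sketch below reads a node file using only the standard library. It is not NLSQ's implementation: the helper name `read_pbs_nodefile` is hypothetical, and it assumes the file lists one host name per line, with a host repeated when several chunks land on the same node.

```python
import os
import tempfile
from collections import Counter
from pathlib import Path


def read_pbs_nodefile(env=None):
    """Hypothetical helper: summarize $PBS_NODEFILE, or return None outside PBS."""
    env = os.environ if env is None else env
    nodefile = env.get("PBS_NODEFILE")
    if not nodefile or not Path(nodefile).is_file():
        return None  # Variable unset or file missing: not inside a PBS job
    lines = Path(nodefile).read_text().splitlines()
    # Counter keeps first-seen order, so node_list follows the file order
    counts = Counter(line.strip() for line in lines if line.strip())
    return {"node_list": list(counts), "entries_per_node": dict(counts)}


# A temporary file stands in for the scheduler-provided node file
with tempfile.TemporaryDirectory() as tmp:
    fake_nodefile = Path(tmp) / "nodes"
    fake_nodefile.write_text("node01\nnode01\nnode02\n")
    summary = read_pbs_nodefile({"PBS_NODEFILE": str(fake_nodefile)})

outside_pbs = read_pbs_nodefile({})  # Empty environment: no PBS job

print(summary)
print(outside_pbs)
```

Passing the environment as an argument keeps the helper testable on a laptop, where `$PBS_NODEFILE` is normally unset.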

In [None]:
# Create cluster detector
detector = ClusterDetector()

print("ClusterDetector methods:")
print("  - detect(): Auto-detect cluster environment")
print("  - is_pbs_environment(): Check for PBS Pro")
print("  - detect_pbs(): Get PBS-specific info")
print("  - detect_local_gpus(): Get local GPU count")

In [None]:
# Check current environment
print(f"PBS environment detected: {detector.is_pbs_environment()}")

# Try to detect cluster
cluster_info = detector.detect()

if cluster_info:
    print(f"\nCluster detected:")
    print(f"  Scheduler: {cluster_info.scheduler}")
    print(f"  Node count: {cluster_info.node_count}")
    print(f"  GPUs per node: {cluster_info.gpus_per_node}")
    print(f"  Total GPUs: {cluster_info.total_gpus}")
    if cluster_info.job_id:
        print(f"  Job ID: {cluster_info.job_id}")
else:
    print("\nNo cluster environment detected (running locally)")

In [None]:
# ClusterInfo structure (for reference)
print("ClusterInfo fields:")
print("  - node_count: int")
print("  - gpus_per_node: int")
print("  - total_gpus: int")
print("  - node_list: list[str]")
print("  - scheduler: str ('pbs', 'local', 'unknown')")
print("  - job_id: str | None")
print("  - interconnect: str | None ('infiniband', 'ethernet')")

In [None]:
# Simulated PBS environment for demonstration
simulated_cluster = ClusterInfo(
    node_count=4,
    gpus_per_node=8,
    total_gpus=32,
    node_list=["node01", "node02", "node03", "node04"],
    scheduler="pbs",
    job_id="12345.pbs_server",
    interconnect="infiniband",
)

print("Simulated PBS cluster:")
print(f"  Nodes: {simulated_cluster.node_count}")
print(f"  GPUs: {simulated_cluster.total_gpus}")
print(f"  Job ID: {simulated_cluster.job_id}")
print(f"  Interconnect: {simulated_cluster.interconnect}")

### Section 2: WorkflowTier.STREAMING_CHECKPOINT

The `STREAMING_CHECKPOINT` tier combines streaming optimization with
automatic checkpointing for fault tolerance on massive datasets.

In [None]:
# View available workflow tiers
print("Available WorkflowTiers:")
for tier in WorkflowTier:
    print(f"  - {tier.name}")

print("\nTier descriptions:")
print("  STANDARD: curve_fit() for small datasets")
print("  CHUNKED: LargeDatasetFitter with automatic chunking")
print("  STREAMING: AdaptiveHybridStreamingOptimizer for huge datasets")
print("  STREAMING_CHECKPOINT: Streaming + automatic checkpointing")

In [None]:
# Create STREAMING_CHECKPOINT configuration
hpc_config = WorkflowConfig(
    tier=WorkflowTier.STREAMING_CHECKPOINT,
    goal=OptimizationGoal.ROBUST,
    gtol=1e-7,
    ftol=1e-7,
    xtol=1e-7,
    enable_checkpoints=True,
    checkpoint_dir="./nlsq_checkpoints",
    enable_multistart=True,
    n_starts=10,
)

print("HPC Configuration:")
print(f"  tier: {hpc_config.tier}")
print(f"  goal: {hpc_config.goal}")
print(f"  tolerances: gtol={hpc_config.gtol}, ftol={hpc_config.ftol}")
print(f"  enable_checkpoints: {hpc_config.enable_checkpoints}")
print(f"  checkpoint_dir: {hpc_config.checkpoint_dir}")
print(f"  enable_multistart: {hpc_config.enable_multistart}")
print(f"  n_starts: {hpc_config.n_starts}")

### Section 3: Checkpointing Configuration

Configure checkpointing with these key parameters:

| Parameter | Description |
|-----------|-------------|
| `enable_checkpoints` | Enable automatic checkpointing |
| `checkpoint_dir` | Directory to save checkpoints |
| `checkpoint_interval` | Iterations between checkpoints (set in `HybridStreamingConfig`) |

In [None]:
# create_checkpoint_directory() creates timestamped directories
checkpoint_dir = create_checkpoint_directory()

print(f"Created checkpoint directory: {checkpoint_dir}")
print(f"Directory exists: {Path(checkpoint_dir).exists()}")

In [None]:
# Custom base directory
custom_checkpoint_dir = create_checkpoint_directory(base_dir="./my_project_checkpoints")

print(f"Custom checkpoint directory: {custom_checkpoint_dir}")
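
A timestamped directory name changes from run to run, so a resubmitted job will only find earlier checkpoints if it is pointed at the directory the previous run used. The sketch below shows one way to handle that with the standard library alone: reuse a fixed directory when one is supplied, otherwise create a timestamped one. The helper `resolve_checkpoint_dir` and the `run_YYYYMMDD_HHMMSS` naming are hypothetical and may differ from what `create_checkpoint_directory()` produces. The variable name `NLSQ_CHECKPOINT_DIR` is borrowed from the job script in this notebook and is read here by the sketch itself, not by NLSQ.

```python
import tempfile
from datetime import datetime
from pathlib import Path


def resolve_checkpoint_dir(env, base_dir, now=None):
    """Hypothetical helper: prefer a fixed directory, else a timestamped one."""
    fixed = env.get("NLSQ_CHECKPOINT_DIR")
    if fixed:
        path = Path(fixed)  # Same path on every resubmission, so resume works
    else:
        now = datetime.now() if now is None else now
        path = Path(base_dir) / f"run_{now:%Y%m%d_%H%M%S}"  # Assumed name format
    path.mkdir(parents=True, exist_ok=True)
    return path


with tempfile.TemporaryDirectory() as tmp:
    stamp = datetime(2025, 1, 2, 3, 4, 5)
    fresh = resolve_checkpoint_dir({}, tmp, now=stamp)
    resumed = resolve_checkpoint_dir(
        {"NLSQ_CHECKPOINT_DIR": str(Path(tmp) / "resume_here")}, tmp
    )
    fresh_name, resumed_name = fresh.name, resumed.name
    both_exist = fresh.is_dir() and resumed.is_dir()

print(fresh_name, resumed_name, both_exist)
```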

In [None]:
# Full checkpoint configuration
checkpoint_config = WorkflowConfig(
    tier=WorkflowTier.STREAMING_CHECKPOINT,
    goal=OptimizationGoal.ROBUST,
    enable_checkpoints=True,
    checkpoint_dir=checkpoint_dir,
)

# Serialize config for saving
config_dict = checkpoint_config.to_dict()
print("Serialized config:")
for key, value in config_dict.items():
    if not key.startswith("_"):
        print(f"  {key}: {value}")

### Section 4: Checkpoint Resume Workflow

Implement checkpoint-resume logic for fault-tolerant optimization.

In [None]:
def save_checkpoint(checkpoint_dir, iteration, params, loss, metadata=None):
    """Save optimization checkpoint.
    
    Parameters
    ----------
    checkpoint_dir : str
        Directory to save checkpoint
    iteration : int
        Current iteration number
    params : np.ndarray
        Current parameter values
    loss : float
        Current loss value
    metadata : dict, optional
        Additional metadata to save
    """
    checkpoint_path = Path(checkpoint_dir) / f"checkpoint_{iteration:06d}.pkl"
    
    checkpoint_data = {
        "iteration": iteration,
        "params": np.array(params),
        "loss": float(loss),
        "metadata": metadata or {},
    }
    
    with open(checkpoint_path, "wb") as f:
        pickle.dump(checkpoint_data, f)
    
    print(f"Saved checkpoint: {checkpoint_path.name}")
    return checkpoint_path


def load_latest_checkpoint(checkpoint_dir):
    """Load the most recent checkpoint.
    
    Parameters
    ----------
    checkpoint_dir : str
        Directory containing checkpoints
    
    Returns
    -------
    dict or None
        Checkpoint data if found, None otherwise
    """
    checkpoint_dir = Path(checkpoint_dir)
    if not checkpoint_dir.exists():
        return None
    
    # Find all checkpoint files
    checkpoints = list(checkpoint_dir.glob("checkpoint_*.pkl"))
    if not checkpoints:
        return None
    
    # Sort by name: zero-padded iteration numbers sort in iteration order
    # (valid while iterations stay below 1,000,000, the width of the 06d field)
    latest = sorted(checkpoints)[-1]
    
    with open(latest, "rb") as f:
        checkpoint_data = pickle.load(f)
    
    print(f"Loaded checkpoint: {latest.name}")
    return checkpoint_data


print("Checkpoint functions defined")
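
One caveat for helpers like `save_checkpoint` above: if a job is killed while the file is being written, the newest checkpoint can be left truncated, and the next resume would fail when unpickling it. A common remedy is to write to a temporary file in the same directory and then rename it into place with `os.replace`, which swaps the file in a single step on the same filesystem. The sketch below shows the pattern; `atomic_pickle_dump` is a hypothetical name, not part of NLSQ.

```python
import os
import pickle
import tempfile
from pathlib import Path


def atomic_pickle_dump(obj, path):
    """Hypothetical helper: write a pickle so readers never see a partial file."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # The temp file lives in the target directory so the rename stays on one filesystem
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, prefix=path.name, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(obj, f)
            f.flush()
            os.fsync(f.fileno())  # Push the bytes to disk before renaming
        os.replace(tmp_name, path)  # Swap the finished file into place
    except BaseException:
        if os.path.exists(tmp_name):
            os.unlink(tmp_name)  # Do not leave partial files behind
        raise
    return path


with tempfile.TemporaryDirectory() as tmp:
    target = atomic_pickle_dump(
        {"iteration": 10, "params": [2.1, 0.95, 0.5], "loss": 0.05},
        Path(tmp) / "checkpoint_000010.pkl",
    )
    with open(target, "rb") as f:
        restored = pickle.load(f)
    leftovers = [p.name for p in Path(tmp).glob("*.tmp")]

print(restored, leftovers)
```

Because the temporary name ends in `.tmp`, a glob for `checkpoint_*.pkl` does not pick up a half-written file.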

In [None]:
# Demonstrate checkpoint save/load
demo_dir = create_checkpoint_directory(base_dir="./demo_checkpoints")

# Simulate saving checkpoints during optimization
for i in range(0, 30, 10):
    params = np.array([2.0 + 0.01 * i, 1.0 - 0.005 * i, 0.5])
    loss = 0.1 / (1 + i * 0.1)
    save_checkpoint(demo_dir, i, params, loss, metadata={"epoch": i // 10})

In [None]:
# Load latest checkpoint for resume
latest = load_latest_checkpoint(demo_dir)

if latest:
    print(f"\nResuming from iteration {latest['iteration']}")
    print(f"  Parameters: {latest['params']}")
    print(f"  Loss: {latest['loss']:.6f}")
    print(f"  Metadata: {latest['metadata']}")

In [None]:
# Resume-aware optimization loop pattern
def optimization_with_checkpoints(checkpoint_dir, max_iterations=100, checkpoint_interval=10):
    """Example optimization loop with checkpoint support."""
    
    # Try to resume from checkpoint
    checkpoint = load_latest_checkpoint(checkpoint_dir)
    
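    if checkpoint:
        start_iteration = checkpoint["iteration"] + 1
        params = checkpoint["params"]
        loss = checkpoint["loss"]  # Keeps `loss` defined if no iterations remain
        print(f"Resuming from iteration {start_iteration}")
    else:
        start_iteration = 0
        params = np.array([1.0, 1.0, 0.0])  # Initial guess
        loss = float(np.sum(params ** 2))  # Loss at the initial guess
        print("Starting fresh optimization")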
    
    # Optimization loop
    for iteration in range(start_iteration, max_iterations):
        # Simulate optimization step
        params = params + 0.001 * np.random.randn(3)
        loss = np.sum(params ** 2)  # Dummy loss
        
        # Checkpoint at intervals
        if iteration > 0 and iteration % checkpoint_interval == 0:
            save_checkpoint(checkpoint_dir, iteration, params, loss)
    
    # Final checkpoint
    save_checkpoint(checkpoint_dir, max_iterations - 1, params, loss)
    
    return params

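# Run optimization. demo_dir already holds checkpoints from the cells above,
# so this run resumes after iteration 20 and then checkpoints every 10 iterations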
final_params = optimization_with_checkpoints(demo_dir, max_iterations=50)
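
Interval-based checkpoints can be complemented by a wall-time check, so the loop saves its state and exits cleanly shortly before the scheduler would kill the job. The sketch below is a minimal version under stated assumptions: the wall time is passed in as an `HH:MM:SS` string (for example, the value requested in the job script), and `parse_walltime` and `should_stop` are hypothetical helpers rather than NLSQ functions.

```python
import time


def parse_walltime(text):
    """Convert an 'HH:MM:SS' wall-time string to seconds."""
    hours, minutes, seconds = (int(part) for part in text.split(":"))
    return hours * 3600 + minutes * 60 + seconds


def should_stop(start_time, walltime_seconds, margin_seconds, now=None):
    """Return True once elapsed time reaches the wall time minus a safety margin."""
    now = time.monotonic() if now is None else now
    return (now - start_time) >= walltime_seconds - margin_seconds


limit = parse_walltime("24:00:00")

# Fixed clock values keep the demonstration deterministic
early = should_stop(start_time=0.0, walltime_seconds=limit, margin_seconds=600, now=85_000.0)
late = should_stop(start_time=0.0, walltime_seconds=limit, margin_seconds=600, now=85_900.0)

print(limit, early, late)
```

Inside an optimization loop, the pattern would be to record `time.monotonic()` at startup, call `should_stop(...)` once per iteration, and save a checkpoint and break out of the loop when it returns `True`. The margin should comfortably exceed the time one checkpoint write takes.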

### Section 5: HPC Distributed Configuration

Use `create_distributed_config()` to generate HPC-optimized settings.

In [None]:
# Create distributed config from cluster info
dist_config = create_distributed_config(simulated_cluster)

print("Distributed configuration for PBS cluster:")
for key, value in dist_config.items():
    print(f"  {key}: {value}")

In [None]:
# Get multi-GPU configuration
gpu_config = get_multi_gpu_config(simulated_cluster)

if gpu_config:
    print("Multi-GPU configuration:")
    print(f"  n_devices: {gpu_config.n_devices}")
    print(f"  per_device_batch_size: {gpu_config.per_device_batch_size}")
    print(f"  total_batch_size: {gpu_config.total_batch_size}")
    print(f"  use_pmap: {gpu_config.use_pmap}")
    print(f"  use_pjit: {gpu_config.use_pjit}")

In [None]:
# View the hpc_distributed preset
hpc_preset = WORKFLOW_PRESETS["hpc_distributed"]

print("HPC Distributed Preset:")
for key, value in hpc_preset.items():
    print(f"  {key}: {value}")

### Section 6: PBS Pro Job Script

Here's an example PBS Pro job script for running NLSQ on an HPC cluster.

**Note:** This script targets PBS Pro. SLURM uses different directives (`#SBATCH`) and environment variables, so the script is not interchangeable between the two schedulers.

In [None]:
# Generate example PBS job script
pbs_script = '''#!/bin/bash
#PBS -N nlsq_fit
#PBS -l select=4:ncpus=32:ngpus=8:mem=256gb
#PBS -l walltime=24:00:00
#PBS -q gpu
#PBS -j oe
#PBS -o nlsq_fit.log

# NLSQ Curve Fitting Job Script for PBS Pro
# ===========================================
#
# This script demonstrates running NLSQ on a multi-node GPU cluster.
# Adjust the resource requests based on your cluster configuration.
#
# Resources requested:
#   - 4 nodes with 32 CPUs each
#   - 8 GPUs per node (e.g., NVIDIA A100)
#   - 256GB RAM per node
#   - 24 hour walltime

# Change to submission directory
cd "$PBS_O_WORKDIR"

# Load required modules (adjust for your cluster)
module load python/3.12
module load cuda/12.0
module load cudnn/8.9

# Activate virtual environment
source ./venv/bin/activate

# Set NLSQ environment variables
export NLSQ_WORKFLOW_GOAL=robust
export NLSQ_MEMORY_LIMIT_GB=200
export NLSQ_CHECKPOINT_DIR="$PBS_O_WORKDIR/checkpoints"

# Create checkpoint directory
mkdir -p "$NLSQ_CHECKPOINT_DIR"

# Display job information
echo "========================================"
echo "NLSQ Fitting Job Started"
echo "========================================"
echo "Job ID: $PBS_JOBID"
echo "Node list:"
cat $PBS_NODEFILE
echo "========================================"

# Run NLSQ fitting script
python fit_large_dataset.py \\
    --data-file ./data/large_dataset.h5 \\
    --output-dir ./results \\
    --checkpoint-dir $NLSQ_CHECKPOINT_DIR \\
    --enable-checkpoints \\
    --checkpoint-interval 50

echo "========================================"
echo "Job Completed: $(date)"
echo "========================================"
'''

# Save the PBS script
pbs_script_path = Path("nlsq_fit.pbs")
pbs_script_path.write_text(pbs_script)

print("Created PBS job script: nlsq_fit.pbs")
print()
print("Contents:")
print("=" * 50)
print(pbs_script)

### Section 7: Example Fitting Script for HPC

Here's an example Python script to use with the PBS job.

In [None]:
# Example HPC fitting script content
hpc_script = '''#!/usr/bin/env python
"""NLSQ HPC Fitting Script with Checkpointing.

This script demonstrates running NLSQ on HPC clusters with:
- Automatic cluster detection
- Checkpoint-based fault tolerance
- Environment variable configuration

Usage:
    python fit_large_dataset.py --data-file data.h5 --output-dir results
"""

import argparse
import os
from pathlib import Path

import numpy as np
import jax.numpy as jnp

from nlsq import curve_fit_large
from nlsq.workflow import (
    ClusterDetector,
    create_checkpoint_directory,
    load_config_with_overrides,
)


def model(x, a, b, c):
    """Model function."""
    return a * jnp.exp(-b * x) + c


def main():
    parser = argparse.ArgumentParser(description="NLSQ HPC Fitting")
    parser.add_argument("--data-file", required=True)
    parser.add_argument("--output-dir", default="./results")
    parser.add_argument("--checkpoint-dir", default=None)
    parser.add_argument("--enable-checkpoints", action="store_true")
    parser.add_argument("--checkpoint-interval", type=int, default=50)
    args = parser.parse_args()
    
    # Detect cluster environment
    detector = ClusterDetector()
    cluster_info = detector.detect()
    
    if cluster_info:
        print(f"Running on {cluster_info.scheduler} cluster")
        print(f"  Nodes: {cluster_info.node_count}")
        print(f"  Total GPUs: {cluster_info.total_gpus}")
    else:
        print("Running locally")
    
    # Load configuration with environment overrides
    config = load_config_with_overrides()
    memory_limit = config.get("memory_limit_gb", 16.0)
    
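    # Setup checkpointing
    # NOTE: checkpoint_dir and --checkpoint-interval are resolved here but are
    # not passed to curve_fit_large() below. Connect them to your workflow
    # configuration before relying on them for fault tolerance.
    checkpoint_dir = args.checkpoint_dir
    if checkpoint_dir is None and args.enable_checkpoints:
        checkpoint_dir = create_checkpoint_directory()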
    
    # Load data (example - replace with actual data loading)
    # In production, use h5py or similar for large datasets
    print(f"Loading data from: {args.data_file}")
    # x_data, y_data = load_hdf5_data(args.data_file)
    
    # For demonstration, generate synthetic data
    n_points = 10_000_000
    x_data = np.linspace(0, 10, n_points)
    y_data = 2.5 * np.exp(-1.3 * x_data) + 0.5 + 0.1 * np.random.randn(n_points)
    
    print(f"Dataset size: {n_points:,} points")
    
    # Run fitting
    popt, pcov = curve_fit_large(
        model,
        x_data,
        y_data,
        p0=[1.0, 1.0, 0.0],
        memory_limit_gb=memory_limit,
        multistart=True,
        n_starts=10,
    )
    
    # Save results
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    np.save(output_dir / "popt.npy", popt)
    np.save(output_dir / "pcov.npy", pcov)
    
    print(f"\\nResults saved to: {output_dir}")
    print(f"Fitted parameters: a={popt[0]:.4f}, b={popt[1]:.4f}, c={popt[2]:.4f}")


if __name__ == "__main__":
    main()
'''

# Save the example script
script_path = Path("fit_large_dataset.py")
script_path.write_text(hpc_script)

print("Created example fitting script: fit_large_dataset.py")

---

## Key Takeaways

After completing this notebook, remember:

1. **Cluster detection:**
   - `ClusterDetector().detect()` auto-detects PBS environments
   - `ClusterInfo` provides node count, GPU count, job ID

2. **Checkpointing configuration:**
   - Use `WorkflowTier.STREAMING_CHECKPOINT` for fault tolerance
   - Set `enable_checkpoints=True` and `checkpoint_dir`
   - Use `create_checkpoint_directory()` for timestamped directories

3. **PBS Pro integration:**
   - Set resources with `#PBS -l select=...:ngpus=...`
   - Use environment variables for configuration
   - Create checkpoint directory with `mkdir -p`

4. **Resume workflow:**
   - Check for existing checkpoints at job start
   - Resume from latest checkpoint iteration
   - Save checkpoints at regular intervals

---

## Common Questions

**Q: How often should I checkpoint?**

A: Balance checkpoint frequency against I/O overhead. For hour-long jobs, checkpoint every 10-15 minutes. For multi-day jobs, every 30-60 minutes.
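
Since `checkpoint_interval` is expressed in iterations while the guidance above is in minutes, a small conversion helps. The sketch below assumes you have measured the average time per iteration yourself; `checkpoint_interval_iterations` is a hypothetical helper name.

```python
import math


def checkpoint_interval_iterations(seconds_per_iteration, target_minutes):
    """Convert a time-based checkpoint target into a whole number of iterations."""
    if seconds_per_iteration <= 0:
        raise ValueError("seconds_per_iteration must be positive")
    # Round down so checkpoints are never further apart than the target
    return max(1, math.floor(target_minutes * 60 / seconds_per_iteration))


# 2.5 s per iteration with a 15-minute target
short_job_interval = checkpoint_interval_iterations(2.5, 15)
# 12 s per iteration with a 45-minute target
long_job_interval = checkpoint_interval_iterations(12.0, 45)

print(short_job_interval, long_job_interval)
```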

**Q: Can I use SLURM instead of PBS?**

A: NLSQ currently auto-detects PBS via `$PBS_NODEFILE`. For SLURM, you can manually create `ClusterInfo` from `$SLURM_*` variables.
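
A minimal sketch of gathering those values is shown below. It collects the fields into a plain dictionary so it runs without NLSQ installed; the helper name `slurm_cluster_fields` is hypothetical. It assumes the job exports `SLURM_JOB_ID`, `SLURM_JOB_NUM_NODES`, `SLURM_JOB_NODELIST`, and `SLURM_GPUS_ON_NODE`; the last one may be unset depending on how GPUs were requested. Whether `ClusterInfo` accepts a scheduler value other than the ones listed earlier in this notebook is not confirmed here.

```python
import os


def slurm_cluster_fields(env=None):
    """Hypothetical helper: collect ClusterInfo-style fields from SLURM variables."""
    env = os.environ if env is None else env
    if "SLURM_JOB_ID" not in env:
        return None  # Not running inside a SLURM job
    node_count = int(env.get("SLURM_JOB_NUM_NODES", "1"))
    gpus_per_node = int(env.get("SLURM_GPUS_ON_NODE", "0"))
    return {
        "node_count": node_count,
        "gpus_per_node": gpus_per_node,
        "total_gpus": node_count * gpus_per_node,
        "job_id": env["SLURM_JOB_ID"],
        # Compressed form such as "node[01-04]"; expand it with
        # `scontrol show hostnames` before building a node_list
        "node_list_raw": env.get("SLURM_JOB_NODELIST", ""),
    }


fields = slurm_cluster_fields(
    {
        "SLURM_JOB_ID": "987654",
        "SLURM_JOB_NUM_NODES": "4",
        "SLURM_GPUS_ON_NODE": "8",
        "SLURM_JOB_NODELIST": "node[01-04]",
    }
)
not_in_slurm = slurm_cluster_fields({})

print(fields)
```

The numeric fields map directly onto the `ClusterInfo` arguments of the same name shown in Section 1. The node list needs expanding into individual host names first.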

**Q: How much disk space do checkpoints use?**

A: Each checkpoint stores parameters (small) plus optimizer state. Typically a few MB per checkpoint.

---

## Related Resources

**Previous tutorials:**
- [05_yaml_configuration.ipynb](05_yaml_configuration.ipynb) - YAML configuration
- [06_auto_selection.ipynb](06_auto_selection.ipynb) - Automatic workflow selection

**Further reading:**
- [PBS Pro User Guide](https://www.pbsworks.com/documentation)
- [JAX Multi-GPU Documentation](https://jax.readthedocs.io/en/latest/multi_process.html)

---

## Glossary

**PBS Pro:** Portable Batch System Professional - an HPC job scheduler.

**Checkpoint:** A saved snapshot of optimization state that enables resume.

**pmap/pjit:** JAX transformations for multi-device parallelism.

In [None]:
# Cleanup
import shutil

# Clean up demo directories and files
for path in ["nlsq_checkpoints", "demo_checkpoints", "my_project_checkpoints"]:
    if Path(path).exists():
        shutil.rmtree(path)
        print(f"Cleaned up: {path}")

for path in ["nlsq_fit.pbs", "fit_large_dataset.py"]:
    if Path(path).exists():
        Path(path).unlink()
        print(f"Cleaned up: {path}")

# Final summary
print("\n" + "=" * 50)
print("Summary")
print("=" * 50)
print("\nHPC Integration:")
print("  - ClusterDetector for PBS Pro detection")
print("  - ClusterInfo for cluster metadata")
print("\nCheckpointing:")
print("  - WorkflowTier.STREAMING_CHECKPOINT")
print("  - enable_checkpoints=True")
print("  - create_checkpoint_directory()")
print("\nPBS Pro:")
print("  - #PBS -l select=N:ngpus=M")
print("  - Environment variable overrides")