# Distributed Computing Basics

This guide introduces the distributed computing capabilities in PtDAlgorithms, showing how to scale computations from a single machine to 100+ node clusters with minimal code changes.

## Overview

The new distributed computing interface provides:

- **One-line initialization** - Replace 200+ lines of SLURM boilerplate
- **Automatic environment detection** - Works locally and on SLURM clusters
- **JAX integration** - Full support for pmap/vmap parallelization
- **Configuration management** - YAML-based cluster configs

## The Problem: SLURM Boilerplate

### Before (200+ lines)

```python
import os
import subprocess

import jax

# Detect SLURM environment
if 'SLURM_JOB_ID' in os.environ:
    num_processes = int(os.environ['SLURM_NTASKS'])
    process_id = int(os.environ['SLURM_PROCID'])
    cpus_per_task = int(os.environ.get('SLURM_CPUS_PER_TASK', '1'))

    # Get coordinator node
    nodelist = os.environ['SLURM_JOB_NODELIST']
    result = subprocess.run(['scontrol', 'show', 'hostnames', nodelist], ...)
    coordinator_node = result.stdout.strip().split('\n')[0]

    # Setup environment
    os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={cpus_per_task}'

    # Initialize JAX distributed
    coordinator_address = f"{coordinator_node}:{coordinator_port}"
    jax.distributed.initialize(...)

# ... 150+ more lines ...
```

### After (1 line)

```python
from ptdalgorithms import initialize_distributed

dist_info = initialize_distributed()
```

That's it! SLURM detection, coordinator setup, and JAX initialization all happen automatically.

## Quick Start Example

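The cell below initializes the distributed environment and prints the resulting configuration. Outside SLURM it falls back to a single-node setup, as the log output shows: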

```python
from ptdalgorithms import initialize_distributed

# Initialize distributed computing (handles everything automatically)
dist_info = initialize_distributed(
    coordinator_port=12345,
    platform="cpu",
    enable_x64=True
)

# Check configuration
print(f"Process ID: {dist_info.process_id}")
print(f"Total processes: {dist_info.num_processes}")
print(f"Local devices: {dist_info.local_device_count}")
print(f"Global devices: {dist_info.global_device_count}")
print(f"Is coordinator: {dist_info.is_coordinator}")
print("\nFull configuration:")
print(dist_info)
```

```text
[INFO] Not running under SLURM - using single-node setup
[INFO] Configured JAX for 1 CPU devices
[INFO] JAX x64 precision enabled
[INFO] Single-node setup - no distributed initialization needed
[INFO]
Distributed Configuration:
  Job ID: N/A
  Process: 0/1
  Coordinator: localhost:12345 (this node)
  Local devices: 1
  Global devices: 1
  Platform: cpu


Process ID: 0
Total processes: 1
Local devices: 1
Global devices: 1
Is coordinator: True

Full configuration:
Distributed Configuration:
  Job ID: N/A
  Process: 0/1
  Coordinator: localhost:12345 (this node)
  Local devices: 1
  Global devices: 1
  Platform: cpu
```


## DistributedConfig Object

The `initialize_distributed()` function returns a `DistributedConfig` object with all the information you need:

| Attribute | Description |
|-----------|-------------|
| `num_processes` | Total number of processes (one per SLURM task, typically one per node) |
| `process_id` | This process's rank (0 to num_processes-1) |
| `local_device_count` | Number of devices on this node |
| `global_device_count` | Total devices across all nodes |
| `is_coordinator` | True if this is the coordinator (rank 0) |
| `coordinator_address` | Address of coordinator ("host:port") |
| `job_id` | SLURM job ID (if running under SLURM) |
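
`process_id` and `num_processes` are enough to give each process its own share of a workload. The sketch below is illustrative only: `local_slice` is a hypothetical helper, and a small dataclass stands in for the `DistributedConfig` returned by `initialize_distributed()` so the snippet runs without a cluster.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class StandInConfig:
    """Hypothetical stand-in exposing the two DistributedConfig fields used here."""
    process_id: int
    num_processes: int


def local_slice(data, dist_info):
    """Return this process's contiguous chunk of `data` (hypothetical helper)."""
    # array_split tolerates lengths that do not divide evenly:
    # the first len(data) % num_processes chunks get one extra element.
    chunks = np.array_split(np.asarray(data), dist_info.num_processes)
    return chunks[dist_info.process_id]


work = np.arange(10)
for rank in range(3):
    print(rank, local_slice(work, StandInConfig(process_id=rank, num_processes=3)))
```

With a real configuration object, each process would call `local_slice(work, dist_info)` and receive a different chunk.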

## Simple Distributed Computation

Let's build a complete example that evaluates a phase-type distribution across multiple devices:

```python
import numpy as np
import jax
import jax.numpy as jnp
from ptdalgorithms import Graph

def build_erlang_model(num_stages=5):
    """
    Build an Erlang distribution (sum of exponentials).

    This represents the time until the num_stages-th event
    in a Poisson process.
    """
    g = Graph(1)
    start = g.starting_vertex()

    # Create chain of states
    vertices = [start]
    for i in range(num_stages):
        v = g.find_or_create_vertex([i + 1])
        vertices.append(v)

    # Add edges with rate 1.0
    for i in range(num_stages):
        vertices[i].add_edge(vertices[i + 1], 1.0)

    return g

# Build model
graph = build_erlang_model(num_stages=5)
print(f"Built Erlang(5) distribution with {graph.vertices_length()} states")
```

```text
Built Erlang(5) distribution with 6 states
```
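
As a sanity check on the values `graph.pdf` returns, the Erlang density has a closed form, f(t) = rate^k * t^(k-1) * exp(-rate * t) / (k-1)!, for k stages with a common rate. The sketch below uses only the standard library; `erlang_pdf` is a hypothetical helper, not part of PtDAlgorithms, and which `k` to compare against depends on how the library counts the starting and absorbing vertices of the graph above.

```python
import math


def erlang_pdf(t, k, rate=1.0):
    """Closed-form Erlang(k, rate) density at time t (hypothetical helper)."""
    if t < 0:
        return 0.0
    return rate**k * t ** (k - 1) * math.exp(-rate * t) / math.factorial(k - 1)


# Density of the sum of 5 unit-rate exponentials at t = 2.0
print(f"{erlang_pdf(2.0, k=5):.6f}")
```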


### Parallel PDF Evaluation

Now let's evaluate the PDF at multiple time points in parallel using JAX. `graph.pdf` is an ordinary Python call rather than a JAX operation, so applying `float(t)` to a value traced by `vmap` fails with `ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected`. The version below instead wraps the call in `jax.pure_callback`, which declares the output shape and dtype to JAX and runs the Python function on the host with concrete values, one shard of time points per local device:

```python
def evaluate_pdf_parallel(graph, time_points, dist_info):
    """
    Evaluate the PDF at multiple time points, one shard per local device.
    """
    n_devices = dist_info.local_device_count
    points_per_device = len(time_points) // n_devices

    # Reshape for pmap: (n_devices, points_per_device).
    # Trailing points that do not fill a whole shard are dropped.
    usable = np.asarray(time_points[: n_devices * points_per_device])
    shards = usable.reshape((n_devices, points_per_device))

    # graph.pdf is not a JAX operation, so it cannot be traced directly.
    # This host-side function evaluates one shard with plain Python floats.
    def host_pdf(ts):
        ts = np.asarray(ts)
        return np.array([graph.pdf(float(t)) for t in ts], dtype=ts.dtype)

    # pure_callback tells JAX the output shape/dtype and calls host_pdf
    # with concrete values at run time
    def eval_pdf_shard(ts):
        return jax.pure_callback(
            host_pdf, jax.ShapeDtypeStruct(ts.shape, ts.dtype), ts
        )

    # Parallelize across local devices
    pmap_pdf = jax.pmap(eval_pdf_shard)
    pdf_values = pmap_pdf(jnp.asarray(shards))

    # Flatten results (np.asarray also waits for the computation to finish)
    return usable, np.asarray(pdf_values).reshape(-1)

# Generate time points
time_points = np.linspace(0.1, 10.0, 32)

# Evaluate in parallel
times, pdf_vals = evaluate_pdf_parallel(graph, time_points, dist_info)

# Display results
if dist_info.is_coordinator:
    print(f"\nEvaluated {len(times)} time points in parallel")
    print(f"Distributed across {dist_info.global_device_count} devices")
    print("\nSample values:")
    for i in range(min(5, len(times))):
        print(f"  t={times[i]:.2f} -> PDF={pdf_vals[i]:.6f}")
```
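
For comparison, when a density can be written directly in `jax.numpy`, it is traceable, and the `vmap`-inside-`pmap` pattern works without any callback. The sketch below uses the exponential density `rate * exp(-rate * t)` as a stand-in; it is not the Erlang model built above, and the name `exponential_pdf` is illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np


def exponential_pdf(t, rate=1.0):
    """Exponential(rate) density written in jax.numpy, so JAX can trace it."""
    return rate * jnp.exp(-rate * t)


n_devices = jax.local_device_count()
per_device = 8

# Shape (n_devices, per_device): pmap maps the leading axis onto devices,
# vmap maps the trailing axis within each device
t = np.linspace(0.1, 10.0, n_devices * per_device).reshape(n_devices, per_device)
pdf = jax.pmap(jax.vmap(exponential_pdf))(t)
flat = np.asarray(pdf).reshape(-1)
print(flat[:3])
```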

### Visualization

Let's visualize the PDF we just computed:

```python
import matplotlib.pyplot as plt

if dist_info.is_coordinator:
    plt.figure(figsize=(10, 6))
    plt.plot(times, pdf_vals, 'b-', linewidth=2, label='Erlang(5) PDF')
    plt.xlabel('Time', fontsize=12)
    plt.ylabel('Probability Density', fontsize=12)
    plt.title('Erlang Distribution PDF (Computed in Parallel)', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=11)
    plt.tight_layout()
    plt.show()

    print(f"\nPDF computed using {dist_info.global_device_count} parallel devices")
```

## Best Practices

### 1. Use Coordinator Check for Output

Only the coordinator (rank 0) should print summary information:

```python
if dist_info.is_coordinator:
    print(f"Starting computation with {dist_info.global_device_count} devices")
    # ... other logging ...
```

### 2. Distribute Work Evenly

Ensure work is divisible by device count:

```python
# Good: evenly divisible
n_particles = dist_info.global_device_count * 4  # Exactly 4 per device

# Bad: not evenly divisible
# n_particles = 37  # Won't divide evenly

print(f"Using {n_particles} particles across {dist_info.global_device_count} devices")
print(f"= {n_particles // dist_info.global_device_count} particles per device")
```
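
When the natural problem size is not a multiple of the device count (37 particles on 8 devices, say), one common workaround is to pad up to the next multiple and carry a mask marking the real entries. A minimal sketch: `pad_to_multiple` is a hypothetical helper, and repeating the last value as filler is an arbitrary choice that any downstream reduction must ignore via the mask.

```python
import numpy as np


def pad_to_multiple(x, n_devices):
    """Pad 1-D array x to a length divisible by n_devices (hypothetical helper)."""
    x = np.asarray(x)
    remainder = len(x) % n_devices
    pad = 0 if remainder == 0 else n_devices - remainder
    # Repeat the final element as filler; the mask marks which entries are real
    padded = np.concatenate([x, np.repeat(x[-1:], pad)])
    mask = np.arange(len(padded)) < len(x)
    return padded.reshape(n_devices, -1), mask.reshape(n_devices, -1)


padded, mask = pad_to_multiple(np.arange(37.0), n_devices=8)
print(padded.shape, int(mask.sum()))
```

The reshaped `(n_devices, per_device)` array is then ready for `pmap`, and results at masked-out positions are discarded afterwards.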

### 3. Use Different Seeds per Process

Avoid identical random numbers across processes:

```python
# Set unique seed for each process
np.random.seed(42 + dist_info.process_id)

# Generate some random numbers (different on each process)
random_vals = np.random.randn(5)
print(f"Process {dist_info.process_id} random values: {random_vals[:3]}")
```
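
Offsetting a base seed by `process_id` works, but NumPy's `SeedSequence.spawn` is the documented way to derive statistically independent streams from a single seed. A sketch of that alternative; `process_rng` is a hypothetical helper name, and in real code the `dist_info.process_id` and `dist_info.num_processes` values would be passed in.

```python
import numpy as np


def process_rng(base_seed, process_id, num_processes):
    """Independent Generator for one process, derived from a shared base seed."""
    children = np.random.SeedSequence(base_seed).spawn(num_processes)
    return np.random.default_rng(children[process_id])


rng = process_rng(42, process_id=0, num_processes=4)
print(rng.standard_normal(3))
```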

## Running on SLURM Clusters

The same code works seamlessly on SLURM clusters! No code changes needed.

### Local Testing

```bash
# Test on your laptop
python my_script.py
```

### SLURM Submission

```bash
# Submit to cluster (no code changes!)
sbatch <(python generate_slurm_script.py --profile medium --script my_script.py)
```

### What Happens Automatically

When running on SLURM, `initialize_distributed()` automatically:

1. Detects SLURM environment variables
2. Identifies coordinator node
3. Configures JAX devices
4. Initializes distributed JAX
5. Sets up inter-node communication

When running locally, it:

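1. Configures local CPU devices for parallel execution (the run above used one device)
2. Skips distributed initialization, since there is only one process
3. Returns the same `DistributedConfig` object as on SLURM, so downstream code is unchanged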
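
On a single machine, JAX can be made to expose several CPU devices with the XLA flag `--xla_force_host_platform_device_count`, the same flag the SLURM boilerplate above sets. Whether `initialize_distributed()` uses this flag internally in local mode is an assumption here; the sketch only demonstrates the flag's effect. It must be set before JAX creates its backends, so in a notebook it belongs at the very top, before any other JAX code runs.

```python
import os

# Must run before JAX creates its backends (i.e. before any JAX computation)
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
).strip()

import jax

cpu_devices = jax.devices("cpu")
print(f"{len(cpu_devices)} CPU devices: {cpu_devices}")
```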

## Performance Scaling

Let's measure how throughput changes with workload size on the available devices:

```python
import time

def benchmark_parallel_computation(graph, n_evaluations, dist_info):
    """
    Benchmark parallel PDF evaluation.
    """
    # Generate time points
    time_points = np.linspace(0.1, 10.0, n_evaluations)

    # Warm up with the same input size: pmap compiles once per input shape,
    # so a smaller warm-up would leave compilation inside the timed region
    evaluate_pdf_parallel(graph, time_points, dist_info)

    # Benchmark (perf_counter is a monotonic clock suited to interval timing)
    start = time.perf_counter()
    times, pdf_vals = evaluate_pdf_parallel(graph, time_points, dist_info)
    elapsed = time.perf_counter() - start

    return elapsed, len(times)

if dist_info.is_coordinator:
    print("\nPerformance Benchmark")
    print("=" * 50)

# Every process runs the pmap'd computation (multi-process JAX expects all
# processes to execute the same computations); only the coordinator prints
for n in [32, 64, 128]:
    elapsed, n_computed = benchmark_parallel_computation(graph, n, dist_info)
    if dist_info.is_coordinator:
        throughput = n_computed / elapsed
        print(f"\nEvaluations: {n_computed}")
        print(f"  Time: {elapsed:.4f}s")
        print(f"  Throughput: {throughput:.1f} evals/sec")
        print(f"  Devices: {dist_info.global_device_count}")

if dist_info.is_coordinator:
    print("\n" + "=" * 50)
```
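
JAX dispatches work asynchronously, so a timer can stop before the computation has finished unless the result is explicitly waited on. For functions that return JAX arrays, `block_until_ready()` is the standard way to do that. A minimal sketch; `time_jax_call` is a hypothetical helper and `jnp.sin` is only a placeholder workload.

```python
import time

import jax
import jax.numpy as jnp


def time_jax_call(fn, *args, repeats=5):
    """Average wall-clock seconds per call of a JAX function (hypothetical helper)."""
    # Warm-up call: triggers compilation for these argument shapes
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # Wait for each result so the timer measures execution, not dispatch
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


x = jnp.linspace(0.0, 10.0, 1_000_000)
elapsed = time_jax_call(jax.jit(jnp.sin), x)
print(f"{elapsed * 1e3:.3f} ms per call")
```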

## Next Steps

Now that you understand the basics, check out:

1. **[Distributed SVGD Inference](distributed_svgd_inference.ipynb)** - Bayesian inference across multiple nodes
2. **[SLURM Cluster Setup](slurm_cluster_setup.ipynb)** - Configure and manage cluster resources
3. **[API Reference](../api/index.html)** - Complete API documentation

## Summary

- **One-line initialization** replaces 200+ lines of boilerplate
- **Automatic environment detection** works locally and on SLURM
- **Full JAX integration** with pmap/vmap/jit support
- **Same code everywhere** - develop locally, deploy to cluster

```python
# That's all you need!
dist_info = initialize_distributed()
```