# TPU Demo Notebook

This notebook validates the Iris cluster's TPU functionality by running real JAX
computations on TPU hardware.

**Prerequisites:**
- SSH tunnel to controller: `uv run python scripts/cluster-tools.py --zone europe-west4-b --project hai-gcp-models tunnel`
- Environment variables:
  - `IRIS_CONTROLLER_ADDRESS=http://localhost:10000`
  - `IRIS_WORKSPACE=/path/to/lib/iris`

**Tests:**
1. Basic TPU job - verify TPU provisioning
2. JAX matrix multiplication - verify JAX execution
3. Multi-device pmap - verify data parallelism
4. Coscheduled multi-host job - verify distributed execution

## Setup and Connection

Connect to the cluster via the SSH tunnel.

In [None]:
import os
from pathlib import Path

from iris.client import IrisClient
from iris.cluster.types import (
    CoschedulingConfig,
    Entrypoint,
    EnvironmentSpec,
    ResourceSpec,
    tpu_device,
)
from iris.rpc import cluster_pb2

# Configuration
TPU_TYPE = "v5litepod-16"  # 4 VMs per slice
JOB_TIMEOUT = 600  # 10 minutes for TPU provisioning

# Connect to the cluster. The default address is the local end of the SSH
# tunnel described in the prerequisites.
controller_address = os.environ.get("IRIS_CONTROLLER_ADDRESS", "http://localhost:10000")

# An empty IRIS_WORKSPACE would silently turn into Path(".") (the notebook's
# working directory), so treat it the same as an unset variable.
workspace_str = os.environ.get("IRIS_WORKSPACE")
if not workspace_str:
    raise RuntimeError(
        "IRIS_WORKSPACE not set. Set it to the iris project root:\n"
        "  export IRIS_WORKSPACE=/path/to/lib/iris"
    )
workspace = Path(workspace_str)
# Fail fast, before contacting the controller, if the path is wrong.
if not workspace.is_dir():
    raise RuntimeError(f"IRIS_WORKSPACE={workspace_str!r} is not a directory")

client = IrisClient.remote(controller_address, workspace=workspace)
print(f"Connected to cluster at {controller_address}")
print(f"Workspace: {workspace}")
print(f"TPU type: {TPU_TYPE}")

## Test 1: Basic TPU Job

Submit a simple job that runs on TPU and reports device info. This validates
that the autoscaler can provision TPU slices and jobs execute correctly.

In [None]:
def tpu_device_info(expected_backend="tpu"):
    """Report the JAX devices visible to this task and check the backend.

    Args:
        expected_backend: Backend name that ``jax.default_backend()`` must
            report (``"tpu"`` by default). JAX falls back to CPU when no
            accelerator is found, and the job would still succeed, so without
            this check the test could pass without touching a TPU. Pass
            ``None`` to skip the check.

    Returns:
        A dict with ``device_count``, ``backend`` and ``device_kinds``.

    Raises:
        RuntimeError: If the default backend differs from ``expected_backend``.
    """
    import jax

    devices = jax.devices()
    backend = jax.default_backend()

    print(f"JAX devices: {len(devices)}")
    for i, device in enumerate(devices):
        print(f"  Device {i}: {device.device_kind} @ {device.platform}")

    print(f"\nDefault backend: {backend}")
    print(f"Local device count: {jax.local_device_count()}")

    if expected_backend is not None and backend != expected_backend:
        raise RuntimeError(
            f"Expected JAX backend {expected_backend!r}, got {backend!r}"
        )

    return {
        "device_count": len(devices),
        "backend": backend,
        "device_kinds": [d.device_kind for d in devices],
    }


print("Submitting TPU device info job (may trigger TPU provisioning)...")
print("Note: First job can take 5-10 minutes while TPU slice provisions.\n")

# Request one TPU device of TPU_TYPE and run tpu_device_info on it.
job = client.submit(
    entrypoint=Entrypoint.from_callable(tpu_device_info),
    name="tpu-device-info",
    resources=ResourceSpec(device=tpu_device(TPU_TYPE)),
    environment=EnvironmentSpec(workspace="/app"),
)
print(f"Job submitted: {job.job_id}")

# Wait for the job (timeout of JOB_TIMEOUT seconds) while streaming its logs.
status = job.wait(timeout=JOB_TIMEOUT, stream_logs=True)
state_name = cluster_pb2.JobState.Name(status.state)
print(f"\nJob completed: {state_name}")

if status.state != cluster_pb2.JOB_STATE_SUCCEEDED:
    raise RuntimeError(f"Job failed: {status.error}")

print("\nTest 1 PASSED: Basic TPU job executed successfully")

## Test 2: JAX Matrix Multiplication

Run a real computation on TPU: matrix multiplication using JAX. This validates
that JAX operations execute correctly on TPU hardware.

In [None]:
def jax_matmul(n=4096, iterations=10):
    """Benchmark a JIT-compiled ``n x n`` matrix multiplication.

    Args:
        n: Side length of the square input matrices. The default (4096) is
            large enough for meaningful TPU utilization.
        iterations: Number of timed runs after the warmup/compile run; must
            be at least 1.

    Returns:
        A dict with the achieved throughput in TFLOPS (``tflops``) and the
        mean per-iteration wall time in milliseconds (``avg_ms``).

    Raises:
        ValueError: If ``iterations`` is less than 1 (nothing to average).
    """
    import time

    import jax
    import jax.numpy as jnp

    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    print(f"Running on: {jax.default_backend()}")
    print(f"Devices: {jax.device_count()}")

    print(f"\nCreating {n}x{n} matrices...")
    key1, key2 = jax.random.split(jax.random.PRNGKey(42))
    a = jax.random.normal(key1, (n, n))
    b = jax.random.normal(key2, (n, n))

    @jax.jit
    def matmul(x, y):
        return jnp.dot(x, y)

    # The first call triggers compilation; keep it out of the timings.
    print("Compiling...")
    matmul(a, b).block_until_ready()

    # JAX dispatches asynchronously, so block_until_ready() is required to
    # time the computation itself rather than just its dispatch.
    print(f"Running benchmark ({iterations} iterations)...")
    times = []
    for _ in range(iterations):
        start = time.perf_counter()
        c = matmul(a, b).block_until_ready()
        times.append(time.perf_counter() - start)

    avg_time = sum(times) / len(times)
    # An (n, n) @ (n, n) product performs n**3 multiply-adds = 2 * n**3 FLOPs.
    tflops = 2 * n**3 / avg_time / 1e12

    print("\nResults:")
    print(f"  Matrix shape: ({n}, {n})")
    print(f"  Avg time: {avg_time*1000:.2f} ms")
    print(f"  Throughput: {tflops:.2f} TFLOPS")
    print(f"  Output shape: {c.shape}")
    print(f"  Output sum: {float(jnp.sum(c)):.4e}")

    return {"tflops": tflops, "avg_ms": avg_time * 1000}


print("Submitting JAX matmul job...\n")

# Same TPU request and job environment as the device-info job, but running
# the matmul benchmark.
job = client.submit(
    entrypoint=Entrypoint.from_callable(jax_matmul),
    name="tpu-matmul",
    resources=ResourceSpec(device=tpu_device(TPU_TYPE)),
    environment=EnvironmentSpec(workspace="/app"),
)
print(f"Job submitted: {job.job_id}")

status = job.wait(timeout=JOB_TIMEOUT, stream_logs=True)
state_name = cluster_pb2.JobState.Name(status.state)
print(f"\nJob completed: {state_name}")

# Only a SUCCEEDED final state counts as a pass; anything else aborts the cell.
if status.state == cluster_pb2.JOB_STATE_SUCCEEDED:
    print("\nTest 2 PASSED: JAX matrix multiplication executed successfully")
else:
    raise RuntimeError(f"Job failed: {status.error}")

## Test 3: Multi-Device pmap

Use JAX's `pmap` for data-parallel computation across all TPU devices local to a single worker.
This validates that JAX can use every local TPU device and run a cross-device `psum` collective.

In [None]:
def jax_pmap_demo(batch_size=1024):
    """Run data-parallel computations across all local JAX devices with ``pmap``.

    Squares a ``(n_devices, batch_size)`` integer array (one row per device),
    then uses ``jax.lax.psum`` so every device receives the global sum. Both
    results are checked against values computed without ``pmap``.

    Args:
        batch_size: Number of elements placed on each device.

    Returns:
        A dict with ``n_devices``, ``batch_size`` and the psum ``global_sum``.

    Raises:
        RuntimeError: If the squared values or the psum result do not match
            the reference values.
    """
    import jax
    import jax.numpy as jnp

    n_devices = jax.local_device_count()
    print(f"Local devices: {n_devices}")
    print(f"Devices: {[d.device_kind for d in jax.local_devices()]}")

    @jax.pmap
    def parallel_square(shard):
        return shard ** 2

    # pmap maps over the leading axis, so build one row (batch) per device.
    x = jnp.arange(n_devices * batch_size).reshape(n_devices, batch_size)
    print(f"\nInput shape: {x.shape}")
    print(f"Input (first 5 per device): {x[:, :5]}")

    y = parallel_square(x)
    print(f"\nOutput shape: {y.shape}")
    print(f"Output (first 5 per device): {y[:, :5]}")

    # Raise rather than assert so the check is not stripped under `python -O`.
    if not jnp.allclose(y, x ** 2):
        raise RuntimeError("pmap result mismatch!")
    print("\nVerification: Results match expected values")

    # psum needs a named mapped axis. jax.pmap takes the function as its first
    # positional argument (it is not a decorator factory), so
    # `@jax.pmap(axis_name=...)` fails; wrap the function explicitly instead.
    def local_then_global_sum(shard):
        return jax.lax.psum(jnp.sum(shard), "i")

    parallel_global_sum = jax.pmap(local_then_global_sum, axis_name="i")

    global_sums = parallel_global_sum(x)
    expected_sum = jnp.sum(x)
    print(f"\nGlobal sum via psum (each device sees): {global_sums[0]}")
    print(f"Expected global sum: {expected_sum}")
    # Every device should have received the same, correct total.
    if not bool(jnp.all(global_sums == expected_sum)):
        raise RuntimeError(
            f"psum mismatch: devices saw {global_sums}, expected {expected_sum}"
        )

    return {
        "n_devices": n_devices,
        "batch_size": batch_size,
        "global_sum": float(global_sums[0]),
    }


print("Submitting pmap demo job...\n")

# Single-worker job: pmap only spans the devices local to the one task.
job = client.submit(
    entrypoint=Entrypoint.from_callable(jax_pmap_demo),
    name="tpu-pmap-demo",
    resources=ResourceSpec(device=tpu_device(TPU_TYPE)),
    environment=EnvironmentSpec(workspace="/app"),
)
print(f"Job submitted: {job.job_id}")

status = job.wait(timeout=JOB_TIMEOUT, stream_logs=True)
state_name = cluster_pb2.JobState.Name(status.state)
print(f"\nJob completed: {state_name}")

# Only a SUCCEEDED final state counts as a pass; anything else aborts the cell.
if status.state == cluster_pb2.JOB_STATE_SUCCEEDED:
    print("\nTest 3 PASSED: Multi-device pmap executed successfully")
else:
    raise RuntimeError(f"Job failed: {status.error}")

## Test 4: Coscheduled Multi-Host Job

Submit a coscheduled job with one replica per VM in the slice. A v5litepod-16
slice has 4 VMs, so this test validates that:
- All 4 tasks are scheduled on workers from the same TPU slice
- Task indices are correctly assigned (0, 1, 2, 3)
- Every task runs to completion and reports its worker placement

In [None]:
from iris.cluster.types import get_tpu_topology

def multi_host_task():
    """Run one replica of the coscheduled multi-host job.

    Prints where this replica landed (task index, Iris worker ID and the
    ``TPU_WORKER_ID`` environment variable) and the local JAX device info,
    then computes the sum of squares of 1000 normal samples seeded by the
    task index as a small sanity computation.

    Returns:
        dict: ``task_index``, ``num_tasks``, ``worker_id`` and the computed
        ``result``.

    Raises:
        RuntimeError: If ``get_job_info()`` returns None, i.e. the function is
            not running inside an Iris job.
    """
    import os

    import jax
    import jax.numpy as jnp
    from iris.cluster.client import get_job_info

    job_info = get_job_info()
    if job_info is None:
        raise RuntimeError("Not running in Iris job context")

    print(f"=== Task {job_info.task_index} of {job_info.num_tasks} ===")
    print(f"Worker ID: {job_info.worker_id}")
    tpu_worker_id = os.environ.get("TPU_WORKER_ID", "N/A")
    print(f"TPU worker ID: {tpu_worker_id}")

    print(f"\nJAX backend: {jax.default_backend()}")
    print(f"Local devices: {jax.local_device_count()}")

    # Seed with the task index so each replica computes a different value.
    samples = jax.random.normal(jax.random.PRNGKey(job_info.task_index), (1000,))
    sum_of_squares = float(jnp.sum(samples ** 2))
    print(f"Computation result: {sum_of_squares:.4f}")

    return dict(
        task_index=job_info.task_index,
        num_tasks=job_info.num_tasks,
        worker_id=job_info.worker_id,
        result=sum_of_squares,
    )


# Run one replica per VM in the slice; the topology supplies the VM count so
# the test follows TPU_TYPE instead of hard-coding it.
topo = get_tpu_topology(TPU_TYPE)
replicas = topo.vm_count

print(f"TPU topology: {TPU_TYPE}")
print(f"VMs per slice: {replicas}")
print(f"Chips per VM: {topo.chips_per_vm}")
print(f"\nSubmitting coscheduled job with {replicas} replicas...\n")

# NOTE(review): group_by="tpu-name" presumably groups candidate workers by the
# TPU slice they belong to, so all replicas land on one slice - confirm
# against CoschedulingConfig.
job = client.submit(
    entrypoint=Entrypoint.from_callable(multi_host_task),
    name="tpu-multi-host",
    resources=ResourceSpec(
        device=tpu_device(TPU_TYPE),
        replicas=replicas,
    ),
    environment=EnvironmentSpec(workspace="/app"),
    coscheduling=CoschedulingConfig(group_by="tpu-name"),
)
print(f"Job submitted: {job.job_id}")

status = job.wait(timeout=JOB_TIMEOUT, stream_logs=True)
state_name = cluster_pb2.JobState.Name(status.state)
print(f"\nJob completed: {state_name}")

if status.state != cluster_pb2.JOB_STATE_SUCCEEDED:
    raise RuntimeError(f"Job failed: {status.error}")

# Report task placement.
print("\nTask placement:")
tasks = list(job.tasks())
for task in tasks:
    task_status = task.status()
    print(f"  Task {task.task_index}: worker={task_status.worker_id}")

# Each index 0..replicas-1 must appear exactly once; printing alone would let
# a missing or duplicated task go unnoticed.
expected_indices = list(range(replicas))
observed_indices = sorted(t.task_index for t in tasks)
if observed_indices != expected_indices:
    raise RuntimeError(
        f"Expected task indices {expected_indices}, got {observed_indices}"
    )

print(f"\nTest 4 PASSED: Coscheduled multi-host job ({replicas} tasks) completed")

## Summary

All TPU tests completed. The cluster is functioning correctly for:
- TPU slice provisioning via autoscaler
- JAX execution on TPU hardware
- Multi-device data parallelism (pmap)
- Coscheduled multi-host distributed jobs

In [None]:
# Single source of truth for the test list, so the count and the names can't
# drift apart.
completed_tests = [
    "Basic TPU job",
    "JAX matrix multiplication",
    "Multi-device pmap",
    "Coscheduled multi-host job",
]

print("=" * 60)
print("TPU DEMO COMPLETED SUCCESSFULLY")
print("=" * 60)
print(f"\nCluster: {controller_address}")
print(f"TPU type: {TPU_TYPE}")
print(f"\nAll {len(completed_tests)} tests passed:")
for number, test_name in enumerate(completed_tests, start=1):
    print(f"  {number}. {test_name}")