# Distributed Collective Operations

This notebook demonstrates HPX collective communication patterns across multiple localities (processes). A single worker script runs on every locality simultaneously (the SPMD pattern: single program, multiple data) and executes each collective in turn.

**Collectives demonstrated:**
1. **all_reduce** - combine local values, result available on all localities
2. **broadcast** - send data from root to all localities
3. **gather** - collect data from all localities to root
4. **scatter** - distribute chunks from root to all localities
5. **barrier** - synchronize all localities

In [None]:
%%writefile _distributed_reduction_worker.py
"""SPMD collective operations demo.

Each locality runs this script. Collectives coordinate across all localities.
"""
import numpy as np
import hpxpy as hpx
from hpxpy.launcher import init_from_args

init_from_args()

my_id = hpx.locality_id()
num_locs = hpx.num_localities()
print(f"[Locality {my_id}/{num_locs}] Started with {hpx.num_threads()} threads")

# ============================================================
# 1. ALL-REDUCE: each locality contributes, all get the result
# ============================================================
print(f"\n{'='*50}")
print("Phase 1: All-Reduce")
print(f"{'='*50}")

# Each locality has different local data
np.random.seed(100 + my_id)
local_values = np.random.randn(5)
local_arr = hpx.from_numpy(local_values)

local_sum = float(hpx.sum(local_arr))
print(f"[Locality {my_id}] Local values: {local_values.round(2)}")
print(f"[Locality {my_id}] Local sum: {local_sum:.4f}")

# All-reduce combines across localities; every locality gets the result
global_arr = hpx.all_reduce(local_arr, op='sum')
global_sum = float(hpx.sum(global_arr))
print(f"[Locality {my_id}] Global sum after all_reduce: {global_sum:.4f}")

hpx.barrier("after_allreduce")

# ============================================================
# 2. BROADCAST: root sends data to all localities
# ============================================================
print(f"\n{'='*50}")
print("Phase 2: Broadcast")
print(f"{'='*50}")

if my_id == 0:
    params = np.array([0.01, 0.9, 256.0])  # learning_rate, momentum, batch_size
    print(f"[Locality 0] Broadcasting parameters: {params}")
else:
    # Non-root placeholder with the same shape as the root's parameters
    params = np.zeros(3)

params_arr = hpx.from_numpy(params)
received = hpx.broadcast(params_arr, root=0)
print(f"[Locality {my_id}] Received: {received.to_numpy()}")

hpx.barrier("after_broadcast")

# ============================================================
# 3. GATHER: all localities send to root
# ============================================================
print(f"\n{'='*50}")
print("Phase 3: Gather")
print(f"{'='*50}")

# Each locality computes local statistics
local_mean = float(np.mean(local_values))
local_std = float(np.std(local_values))
my_stats = hpx.from_numpy(np.array([local_mean, local_std, float(len(local_values))]))
print(f"[Locality {my_id}] Sending stats: mean={local_mean:.4f}, std={local_std:.4f}, n={len(local_values)}")

gathered = hpx.gather(my_stats, root=0)

if my_id == 0:
    print(f"[Locality 0] Gathered from {len(gathered)} localities:")
    for i, stats in enumerate(gathered):
        print(f"  Locality {i}: mean={stats[0]:.4f}, std={stats[1]:.4f}, n={int(stats[2])}")

hpx.barrier("after_gather")

# ============================================================
# 4. SCATTER: root distributes chunks to each locality
# ============================================================
print(f"\n{'='*50}")
print("Phase 4: Scatter")
print(f"{'='*50}")

if my_id == 0:
    # Root creates work assignments for each locality
    all_work = np.arange(num_locs * 4, dtype=np.float64)
    print(f"[Locality 0] Scattering work array: {all_work}")
else:
    # Non-root localities pass an empty placeholder; only the root's data is split
    all_work = np.empty(0)

work_arr = hpx.from_numpy(all_work)
my_chunk = hpx.scatter(work_arr, root=0)
print(f"[Locality {my_id}] Received chunk: {my_chunk.to_numpy()}")

hpx.barrier("after_scatter")

# ============================================================
# 5. BARRIER: synchronize all localities
# ============================================================
print(f"\n{'='*50}")
print("Phase 5: Barrier")
print(f"{'='*50}")

import time
# Simulate different amounts of work per locality
time.sleep(0.1 * my_id)
print(f"[Locality {my_id}] Finished work, waiting at barrier...")
hpx.barrier("final_sync")
print(f"[Locality {my_id}] All localities synchronized!")

hpx.barrier("cleanup")
if my_id == 0:
    print("\nAll phases completed successfully.")
hpx.finalize()

## Launch Distributed Execution

Each locality runs the same worker script (SPMD), and the collectives exchange data between localities over HPX's TCP parcelport. The cell below launches 2 localities with 2 threads each.

In [None]:
from hpxpy.launcher import launch_localities

launch_localities(
    "_distributed_reduction_worker.py",
    num_localities=2,
    threads_per_locality=2,
    verbose=True,
)

## Collective Operation Patterns

### All-Reduce
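Every locality contributes an array. The arrays are combined with the reduction operator (sum, min, max, prod), and the result is available on **all** localities. Communication: O(N * log(P)) per locality with a tree-based algorithm, where N is the per-locality data size and P is the number of localities; bandwidth-optimal ring algorithms reduce the per-locality volume to about 2N(P - 1)/P.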
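To show where an O(N * log(P)) figure can come from, here is a pure-NumPy simulation of recursive doubling, one common all-reduce algorithm. It is a sketch under stated assumptions: the number of localities P is a power of two, the reduction is a sum, and a Python list stands in for the localities. It does not use hpxpy, the function name is hypothetical, and it makes no claim about how HPX implements `all_reduce`.

```python
import numpy as np


def recursive_doubling_allreduce(local_values):
    """Simulate a recursive-doubling sum all-reduce over P localities.

    local_values: list of equal-length NumPy arrays, one per simulated locality.
    Assumes P is a power of two. Returns (results, rounds).
    """
    p = len(local_values)
    if p == 0 or p & (p - 1):
        raise ValueError("number of localities must be a power of two")
    buffers = [np.asarray(v, dtype=np.float64).copy() for v in local_values]
    rounds = 0
    step = 1
    while step < p:
        # Each locality exchanges its full buffer with the partner whose id
        # differs in exactly one bit, then both add the two buffers.
        buffers = [buffers[r] + buffers[r ^ step] for r in range(p)]
        step <<= 1
        rounds += 1
    return buffers, rounds


# Usage: 4 simulated localities, each holding a length-3 vector.
data = [np.full(3, float(rank)) for rank in range(4)]
results, rounds = recursive_doubling_allreduce(data)
print(rounds)      # 2 == log2(4)
print(results[0])  # [6. 6. 6.] (identical on every locality)
```

Each of the log2(P) rounds moves a full N-element buffer per locality, which is where the O(N * log(P)) per-locality volume comes from.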

### Broadcast
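One locality (the root) sends its data to all others. Non-root localities still call `broadcast`, passing a placeholder array (the worker uses `np.zeros(3)`). Communication: O(N * log(P)) with a binomial-tree algorithm, which finishes in ceil(log2(P)) rounds.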
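The round count of a binomial-tree broadcast can be checked with a small simulation. This is a sketch that assumes a binomial tree (other broadcast algorithms exist) and uses a Python list in place of localities; the function name is hypothetical and nothing here calls hpxpy.

```python
import numpy as np


def binomial_tree_broadcast(num_ranks, root, data):
    """Simulate a binomial-tree broadcast from `root` to `num_ranks` localities.

    Localities are relabelled relative to the root so the root acts as
    virtual rank 0. Returns (buffers, rounds).
    """
    buffers = [None] * num_ranks
    buffers[root] = np.asarray(data).copy()
    rounds = 0
    have = 1  # virtual ranks 0 .. have-1 already hold the data
    while have < num_ranks:
        # Every holder v forwards to virtual rank v + have, if it exists,
        # so the number of holders doubles each round.
        for v in range(have):
            target = v + have
            if target < num_ranks:
                src = (v + root) % num_ranks
                dst = (target + root) % num_ranks
                buffers[dst] = buffers[src].copy()
        have *= 2
        rounds += 1
    return buffers, rounds


params = np.array([0.01, 0.9, 256.0])
buffers, rounds = binomial_tree_broadcast(5, root=0, data=params)
print(rounds)  # 3 == ceil(log2(5))
```

Each round forwards the full N-element payload, so the root's data reaches all P localities after about N * log2(P) transfer time on the critical path.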

### Gather
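All localities send their data to the root, which receives a list of arrays in locality order (the worker's loop labels entry `i` as locality `i`). Communication: O(N * P) at the root, where N is the per-locality data size.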
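A minimal model of gather semantics helps make the O(N * P) receive volume at the root concrete. The model below is an assumption, not hpxpy's documented behaviour: the root holds a list of every locality's array and non-root localities hold `None`. The function name is hypothetical.

```python
import numpy as np


def simulated_gather(per_rank_arrays, root=0):
    """Return what each locality would hold after a gather to `root`.

    The root gets a list of every locality's array (in locality order);
    other localities get None. Also returns how many elements the root
    receives from the other localities.
    """
    p = len(per_rank_arrays)
    result = [None] * p
    result[root] = [np.asarray(a).copy() for a in per_rank_arrays]
    received = sum(
        np.asarray(a).size for i, a in enumerate(per_rank_arrays) if i != root
    )
    return result, received


# Each of 4 localities contributes [mean, std, n], as in Phase 3 of the worker.
stats = [np.array([0.1 * r, 1.0, 5.0]) for r in range(4)]
held, received = simulated_gather(stats, root=0)
print(len(held[0]))  # 4 arrays at the root
print(received)      # 9 == 3 elements * (4 - 1) localities
```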

### Scatter
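The root splits its array into equal-sized chunks, one per locality, and each locality receives its portion. Communication: O(N) total, where N here is the length of the root's array (the root sends N * (P - 1)/P elements). The worker sizes the array as `num_locs * 4` so it divides evenly.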
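The chunking can be sketched with `numpy.split`. This assumes locality `i` receives the `i`-th contiguous chunk, which matches the equal-sized split described above but is not confirmed against hpxpy's `scatter`; the function name is hypothetical.

```python
import numpy as np


def simulated_scatter(root_array, num_ranks):
    """Split the root's array into equal contiguous chunks, one per locality.

    Assumes locality i receives elements [i*c, (i+1)*c) where c = N / P.
    Raises ValueError if the length is not divisible by num_ranks.
    """
    arr = np.asarray(root_array)
    if len(arr) % num_ranks != 0:
        raise ValueError("array length must be divisible by the number of localities")
    return np.split(arr, num_ranks)


# Same layout as Phase 4 with 2 localities: 8 work items.
num_locs = 2
all_work = np.arange(num_locs * 4, dtype=np.float64)
chunks = simulated_scatter(all_work, num_locs)
print(chunks[1])  # [4. 5. 6. 7.]

# The root keeps its own chunk and sends the rest.
sent_by_root = sum(c.size for i, c in enumerate(chunks) if i != 0)
print(sent_by_root)  # 4 == N * (P - 1) / P with N=8, P=2
```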

### Barrier
Pure synchronization point. No user data is exchanged. All localities block until every locality has reached the barrier.
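The barrier guarantee can be illustrated on a single machine with Python's `threading.Barrier`, using threads as stand-ins for localities. This is an analogy only: it shows the ordering property a barrier provides, not how `hpx.barrier` is implemented across processes. The staggered sleeps mimic Phase 5's `time.sleep(0.1 * my_id)`; `run_barrier_demo` is a hypothetical name.

```python
import threading
import time


def run_barrier_demo(num_workers=3):
    """Threads stand in for localities; threading.Barrier stands in for a barrier."""
    barrier = threading.Barrier(num_workers)
    events = []
    lock = threading.Lock()

    def worker(rank):
        time.sleep(0.05 * rank)  # uneven amounts of simulated work
        with lock:
            events.append(("arrived", rank))
        barrier.wait()  # blocks until all num_workers threads have called wait()
        with lock:
            events.append(("released", rank))

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return events


events = run_barrier_demo()
kinds = [kind for kind, _ in events]
print(kinds)  # ['arrived', 'arrived', 'arrived', 'released', 'released', 'released']
```

However the arrivals are staggered, no thread records "released" until every thread has recorded "arrived".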


In [None]:
import os
os.remove("_distributed_reduction_worker.py")
print("Cleaned up worker script.")