
# [Series 1, Chapter 1.2: Single-Host Sharding with `jax.pmap` - Foundations of Data Parallelism] - The Aurora Project 🌌

## Introduction

Welcome back, Aurora Architect! In our previous briefing (Chapter 1.1), we established a deep understanding of JAX data primitives—`Array`s, host versus device memory, and the art of explicit data placement using `jax.device_put` and `jax.device_get`. We learned to command individual data elements with precision.

However, Project Aurora's ambitions require us to process datasets and train models far too large for any single accelerator core. We must harness the combined power of multiple devices working in concert. Our first major step into this parallel universe is **`jax.pmap`** (parallel map). This transformation is a cornerstone for single-host, multi-device parallelism, allowing us to replicate a computation across multiple local devices (like GPUs or TPU cores on a single machine) and process different chunks of data simultaneously—the essence of **data parallelism**.

**Chapter Goal:** This chapter will equip you to use `jax.pmap` to distribute computations and data across all available devices on a single Aurora node. You will learn to control how data is split (scattered) to devices and how results are combined (gathered), and to perform essential inter-device communication using JAX collectives.

**Topic Introduction:** We'll delve into the mechanics of `jax.pmap`, including function replication and the critical `in_axes` and `out_axes` arguments for controlling data mapping. We will explore how to prepare data for parallel processing, perform collective operations like `psum` for aggregating results (e.g., gradients), and understand the memory implications and limitations of this powerful, yet foundational, parallelism strategy.

**Outcome Statement:** By the end of this chapter, you will be proficient in implementing basic data parallelism using `jax.pmap`, enabling significant speedups for many of Aurora's training tasks on multi-accelerator single hosts. You'll also understand its operational model, laying the groundwork for more advanced sharding techniques to come.

### Learning Objectives for This Phase:

* Implement data parallelism on multiple devices on a single host using `jax.pmap`.
* Master the usage of `in_axes` and `out_axes` to control data distribution and collection.
* Utilize `jax.lax` collectives (e.g., `psum`, `pmean`) within `pmap` for inter-device communication.
* Analyze memory usage patterns and limitations of the `pmap` model.

### Chapter Outline:

1.  **The Essence of `pmap`: Function Replication & "SPMD Lite"**
2.  **Preparing Data for Parallel Execution: Physical Splitting** (`numpy.split`)
3.  **Mapping Data to Devices: The `in_axes` Argument**
4.  **Consolidating Results: The `out_axes` Argument**
5.  **Handling Constants Efficiently: `static_broadcasted_argnums`**
6.  **Targeting Specific Devices with `pmap`'s `devices` Argument**
7.  **Inter-Device Alchemy: Collectives within `pmap`** (`jax.lax.psum`, `pmean`, `all_gather`)
8.  **The Memory Footprint of `pmap`: Considerations for Aurora**
9.  **`pmap` in Practice: Use Cases, Strengths, and Limitations**

## Core Concepts Refresher

Before diving into `pmap`, let's recall from Chapter 1.1:
* **JAX Arrays & Device Affinity:** JAX arrays (`jax.Array`) live on specific devices (or are sharded across several).
* **Explicit Placement:** We can use `jax.device_put` for precise control, though `pmap` will also manage placement.
* **Asynchronous Dispatch:** JAX operations, especially on accelerators, are often asynchronous. `block_until_ready()` is key for synchronization when needed.

`pmap` builds upon these concepts by taking a function and running it concurrently across multiple devices, each operating on its designated piece of data. It's a form of SPMD (Single Program, Multiple Data) programming, though JAX handles much of the complexity.


First, let's set up our JAX environment and check available devices. `pmap` only pays off when multiple devices are present.

In [2]:
# Force 4 host (CPU) devices for the pmap demonstrations on a CPU runtime.
# XLA_FLAGS is read when JAX initializes its backends, so set it before importing jax.
# If JAX was already initialized in this session, restart the runtime for it to take effect.
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

import jax
import jax.numpy as jnp
import numpy as np

print(f"JAX version: {jax.__version__}")
print(f"JAX default backend: {jax.default_backend()}")

# Query the local devices (the backend is initialized with the flag above applied).
num_devices = jax.local_device_count()
devices = jax.local_devices()

print(f"Number of local JAX devices available: {num_devices}")
print(f"Available devices: {devices}")

# Ensure we have at least one device to proceed
if num_devices == 0:
    raise RuntimeError("No JAX devices found. pmap requires at least one device.")

JAX version: 0.5.2
JAX default backend: cpu
Number of local JAX devices available: 4
Available devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]


### 1. The Essence of `pmap`: Function Replication & "SPMD Lite"

`jax.pmap` (parallel map) is a JAX transformation that takes a function and compiles it to run in parallel on multiple XLA devices (e.g., GPUs, TPU cores).

**Core Mechanics:**
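* **Function Replication:** The function you pass to `pmap` is compiled once and replicated across the participating devices: by default N local devices, where N is the size of the mapped input axis (N may not exceed `jax.local_device_count()`), or an explicit `devices` list. Each device executes the *same* function.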
* **Data Parallelism (SPMD style):** While the program (function) is the same, each replica typically operates on a different *slice* of the input data. This is the essence of Single Program, Multiple Data (SPMD).
* **Input/Output Handling:** `pmap` needs to know how the input data should be distributed (sharded/split) across devices and how the outputs from each device should be combined or returned. This is primarily controlled by the `in_axes` and `out_axes` arguments.

Think of it as an advanced version of Python's `map`, but supercharged for parallel execution on hardware accelerators.
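As a quick sanity check of the `map` analogy, the sketch below compares `jax.pmap` against a plain Python loop over the leading axis. The function name `scale_and_shift` is purely illustrative, and the input is sized to however many local devices are available, so it should also run on a single CPU device.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Illustrative per-device computation.
def scale_and_shift(x_shard):
    return 2.0 * x_shard + 1.0

n = jax.local_device_count()  # adapts to the number of local devices
xs = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)

# pmap: slice i of the leading axis runs on device i.
parallel_out = jax.pmap(scale_and_shift)(xs)
# Plain Python "map": same function applied to each row sequentially, then stacked.
sequential_out = jnp.stack([scale_and_shift(x) for x in xs])

print(parallel_out.shape)  # (num_devices, 3)
```

The two results agree element-wise; the difference is that `pmap` executes the slices concurrently on separate devices.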

In [8]:
# A simple function we want to run in parallel
def scale(x_shard, scale_factor_on_device):
  return x_shard * scale_factor_on_device

# Prepare data: ensure the leading dimension matches num_devices for sharding.
# Each device will get one row from 'sharded_x_data' and one element from 'sharded_scale_factors_data'.
example_elements_per_device = 4
sharded_x_data = jnp.arange(num_devices * example_elements_per_device, dtype=jnp.float32).reshape(num_devices, example_elements_per_device)

# Example: different scale factor for each device to demonstrate sharded arguments.
sharded_scale_factors_data = jnp.array([10.0 * (i + 1) for i in range(num_devices)], dtype=jnp.float32)

# pmap the function
# in_axes=(0, 0) means:
# - For 'x_shard' (first arg): its 0-th axis is the device axis (sharded).
# - For 'scale_factor_on_device' (second arg): its 0-th axis is the device axis (sharded).
# axis_name='i' names the mapped axis so collectives (e.g. jax.lax.psum) and
# jax.lax.axis_index can refer to it; this simple example does not use it.
pmapped_scale_fn = jax.pmap(scale, in_axes=(0, 0), axis_name='i')

# Run the pmapped function
result_from_pmap = pmapped_scale_fn(sharded_x_data, sharded_scale_factors_data)

print(f"\nInput sharded_x_data:\n{sharded_x_data}")
print(f"Input sharded_scale_factors_data:\n{sharded_scale_factors_data}")
print(f"Pmap result (output from all devices stacked):\n{result_from_pmap}")
print(f"Result shape: {result_from_pmap.shape} (Should be: num_devices, elements_per_device)")


Input sharded_x_data:
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]
 [12. 13. 14. 15.]]
Input sharded_scale_factors_data:
[10. 20. 30. 40.]
Pmap result (output from all devices stacked):
[[  0.  10.  20.  30.]
 [ 80. 100. 120. 140.]
 [240. 270. 300. 330.]
 [480. 520. 560. 600.]]
Result shape: (4, 4) (Should be: num_devices, elements_per_device)


### 2. Preparing Data for Parallel Execution: Physical Splitting (`numpy.split`)

For `pmap` to distribute data effectively using `in_axes` (covered next), the input arrays typically need a leading dimension that matches the number of devices we are `pmap`ping over. Each slice along this leading dimension goes to one device.

A common way to prepare data is to use `numpy.split` (or `jnp.split`) if your data isn't already in this shape.

In [9]:
# Global data for Aurora's simulations (e.g., a batch of inputs)
# Ensure batch_size is a multiple of num_devices for easy splitting.
total_batch_size = 8 * num_devices
feature_size = 3
global_aurora_data = np.random.rand(total_batch_size, feature_size).astype(np.float32)
print(f"Global Aurora data shape: {global_aurora_data.shape}")

# We need to split the data into num_devices chunks along the batch axis (axis 0).
# Each chunk will be processed by one device.
sharded_data_list_np = np.split(global_aurora_data, num_devices, axis=0)

# sharded_data_list_np is a list of NumPy arrays.
# For pmap, we typically stack them back into a single NumPy array
# where the first dimension is the device dimension.
prepared_data_for_pmap_stacked = np.stack(sharded_data_list_np, axis=0)

# Alternatively, if global_aurora_data's first dimension is already divisible by num_devices,
# a direct reshape yields the same contiguous chunks.
per_device_batch_size = total_batch_size // num_devices
prepared_data_for_pmap_reshaped = global_aurora_data.reshape(num_devices, per_device_batch_size, feature_size)

print(f"\nShape after np.stack(np.split(...)): {prepared_data_for_pmap_stacked.shape}")
print(f"Shape after direct reshape: {prepared_data_for_pmap_reshaped.shape}")
assert prepared_data_for_pmap_stacked.shape == prepared_data_for_pmap_reshaped.shape
# Both shapes should be: (num_devices, per_device_batch_size, feature_size)

# For subsequent examples, we'll use the reshaped version.
current_prepared_data = prepared_data_for_pmap_reshaped

Global Aurora data shape: (32, 3)

Shape after np.stack(np.split(...)): (4, 8, 3)
Shape after direct reshape: (4, 8, 3)


This `current_prepared_data` now has a leading axis that `pmap` can map to devices.
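The cell above relies on the batch dividing evenly across devices. When it does not, `np.split` raises an error. One common workaround, sketched here assuming 4 devices and zero-padding, is to pad the batch and keep a validity mask so padded rows can be excluded from losses or metrics. `pad_and_shard` is a hypothetical helper, not a JAX or NumPy API.

```python
import numpy as np

n_dev = 4  # assumption: 4 devices, as in the forced-CPU setup; pure NumPy, no devices needed

batch = np.arange(10 * 3, dtype=np.float32).reshape(10, 3)  # 10 rows: not divisible by 4

# np.split requires equal chunks and raises ValueError otherwise.
try:
    np.split(batch, n_dev, axis=0)
except ValueError as e:
    print("np.split failed:", e)

def pad_and_shard(x, num_devices):
    """Hypothetical helper: zero-pad rows to a multiple of num_devices, then add a device axis."""
    remainder = (-x.shape[0]) % num_devices  # rows needed to reach the next multiple
    padded = np.concatenate([x, np.zeros((remainder,) + x.shape[1:], x.dtype)])
    # True for real rows, False for padding rows.
    mask = np.concatenate([np.ones(x.shape[0], bool), np.zeros(remainder, bool)])
    return padded.reshape((num_devices, -1) + x.shape[1:]), mask.reshape(num_devices, -1)

sharded, valid = pad_and_shard(batch, n_dev)
print(sharded.shape, valid.sum())  # (4, 3, 3) 10
```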

### 3. Mapping Data to Devices: The `in_axes` Argument

The `in_axes` argument of `pmap` is crucial. It tells `pmap` how to distribute the input arguments of the mapped function across the devices.
* `in_axes` is a tuple/list specifying, for each positional argument of the function, which axis of that argument should be mapped to the devices.
* An integer `d` means the `d`-th axis of the corresponding input array is the "device axis." Data along this axis is split and distributed.
* `None` means the corresponding input argument is *replicated* (broadcasted) to all devices. This is for data that should be identical for all parallel executions.
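The mapped axis does not have to be axis 0. This minimal sketch (shapes and names are illustrative) maps over axis 1 of a `(rows, devices, features)` array while replicating a bias, then checks the result against manual slicing. Note that the mapped axis is removed from each device's view.

```python
import jax
import jax.numpy as jnp
import numpy as np

n = jax.local_device_count()

# Device axis in the middle: shape (rows, devices, features).
x = jnp.arange(2 * n * 3, dtype=jnp.float32).reshape(2, n, 3)
bias = jnp.array([1.0, 2.0, 3.0])  # replicated via in_axes=None

def add_bias(x_slice, b):
    # x_slice has the mapped axis removed: shape (2, 3) on every device.
    return x_slice + b

out = jax.pmap(add_bias, in_axes=(1, None))(x, bias)

# Device i saw x[:, i, :]; outputs are stacked along axis 0 (default out_axes=0).
expected = jnp.stack([x[:, i, :] + bias for i in range(n)])
```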

In [10]:
def process_data_shard_example(data_shard, replicated_scalar_param):
  # data_shard is the piece of data specific to this device.
  # replicated_scalar_param is the same for all devices.
  return jnp.sum(data_shard * replicated_scalar_param, axis=0) # Sum over this device's batch axis: one value per feature

# Our current_prepared_data has shape (num_devices, per_device_batch_size, feature_size)
# We want to map axis 0 of current_prepared_data to devices.
# The replicated_scalar_param should be the same for all devices.
aurora_scalar_param = jnp.array(5.0, dtype=jnp.float32) # A scalar JAX array

# pmap the function
# in_axes=(0, None) means:
# - For 'data_shard' (first arg): map its 0-th axis to devices.
# - For 'replicated_scalar_param' (second arg): replicate it to all devices.
pmapped_process_fn = jax.pmap(process_data_shard_example,
                              in_axes=(0, None),
                              axis_name='data_axis')

# Run the pmapped function
result_in_axes_example = pmapped_process_fn(current_prepared_data, aurora_scalar_param)

print(f"\n--- pmap with in_axes=(0, None) ---")
print(f"Input data shape for pmap: {current_prepared_data.shape}")
print(f"Replicated param: {aurora_scalar_param}")
print(f"Result from pmap (raw): \n{result_in_axes_example}")
print(f"Result shape: {result_in_axes_example.shape}")
# The result will have a leading dimension equal to num_devices,
# each element being the output from one device.
# Shape: (num_devices, feature_size) because we summed over per_device_batch_size.


--- pmap with in_axes=(0, None) ---
Input data shape for pmap: (4, 8, 3)
Replicated param: 5.0
Result from pmap (raw): 
[[23.570044 17.356874 21.184027]
 [18.736385 23.411463 19.134218]
 [18.431837 19.805702 25.691887]
 [17.981094 19.538786 15.820544]]
Result shape: (4, 3)


Here, `current_prepared_data` has shape `(N, B, F)` (devices, per-device batch, features). `in_axes=(0, None)` means:
* The first argument (`data_shard`) receives slices from `current_prepared_data` along its 0-th axis. So, device `i` gets `current_prepared_data[i]`, which has shape `(B, F)`.
* The second argument (`replicated_scalar_param`) receives the `aurora_scalar_param` value replicated on all devices.

### 4. Consolidating Results: The `out_axes` Argument

Just as `in_axes` controls input distribution, `out_axes` controls how results from each device are combined (or kept separate) in the final output array.
* If `out_axes` is `d` (an integer), the outputs from each device (which are expected to be arrays) are stacked along a new axis `d` in the output. The default for `out_axes` is `0`.
* If `out_axes` is `None`, the output from each device must be identical, and `pmap` returns just one copy of this identical output. This is useful if all devices are expected to compute the exact same final scalar after some collective, for example.
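Similarly, an integer `out_axes` other than 0 places the device axis elsewhere in the output. The sketch below (illustrative shapes) contrasts the default with `out_axes=1`, which for a 2-D result should simply be the transpose.

```python
import jax
import jax.numpy as jnp
import numpy as np

n = jax.local_device_count()
x = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)

f_default = jax.pmap(lambda v: v * 10.0)              # out_axes=0: device axis first
f_axis1 = jax.pmap(lambda v: v * 10.0, out_axes=1)    # device axis placed second

out0 = f_default(x)  # shape (n, 3)
out1 = f_axis1(x)    # shape (3, n)
```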

In [11]:
print(f"\n--- Understanding out_axes ---")
print(f"Previous pmap result (default out_axes=0): \n{result_in_axes_example}")
print(f"Shape (default out_axes=0): {result_in_axes_example.shape}") # (num_devices, feature_size)

# Example: What if the function returned a scalar from each device,
# and we still want them stacked (default out_axes=0)?
def process_to_scalar_sum_example(data_shard, replicated_param):
    return jnp.sum(data_shard * replicated_param) # Returns a scalar

pmapped_scalar_output_fn = jax.pmap(process_to_scalar_sum_example, in_axes=(0, None)) # out_axes defaults to 0
scalar_results_stacked = pmapped_scalar_output_fn(current_prepared_data, aurora_scalar_param) # uses data from previous cell
print(f"\nScalar results from each device (stacked by default out_axes=0):\n{scalar_results_stacked}")
print(f"Shape: {scalar_results_stacked.shape}") # (num_devices,)

# Example: If out_axes=None, all devices must return the same value.
# This is usually used when a collective operation (like psum over all results)
# already produces an identical result on all devices.
def process_and_return_fixed_value(data_shard_ignored, param_ignored):
    # For out_axes=None to be valid, the values returned by each device MUST be identical.
    # This function simulates that by returning a constant.
    # A more realistic scenario involves a collective (see later section).
    return jnp.array(42.0)

# The function ignores its inputs, but pmap still needs them: it infers the
# number of devices from the size of the mapped (in_axes=0) axis.
# If every in_axes entry were None, there would be no mapped axis to infer it from.
# We provide a dummy sharded input to make the number of devices explicit for pmap.
dummy_sharded_input_for_out_axes_none = jnp.zeros((num_devices, 1))

pmapped_identical_output_fn = jax.pmap(process_and_return_fixed_value,
                                      in_axes=(0, None), # Dummy sharded, dummy replicated
                                      out_axes=None) # Key part for this example

identical_result = pmapped_identical_output_fn(dummy_sharded_input_for_out_axes_none, None)
print(f"\nIdentical result (out_axes=None):\n{identical_result}")
print(f"Shape: {identical_result.shape}") # Scalar, no device dimension


--- Understanding out_axes ---
Previous pmap result (default out_axes=0): 
[[23.570044 17.356874 21.184027]
 [18.736385 23.411463 19.134218]
 [18.431837 19.805702 25.691887]
 [17.981094 19.538786 15.820544]]
Shape (default out_axes=0): (4, 3)

Scalar results from each device (stacked by default out_axes=0):
[62.110943 61.282063 63.929436 53.340424]
Shape: (4,)

Identical result (out_axes=None):
42.0
Shape: ()


In practice, the default `out_axes=0` is by far the most common choice, giving you an array whose 0-th axis indexes the devices.
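The comment in the cell above points to a more realistic use of `out_axes=None`: reducing with a collective first, so every device holds the same value by construction. A minimal sketch, assuming a total-sum reduction and an arbitrary axis name `'dev'`:

```python
import jax
import jax.numpy as jnp
import numpy as np

n = jax.local_device_count()
x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)

def global_total(x_shard):
    # After psum, every device holds the same global sum,
    # so returning a single copy via out_axes=None is safe.
    return jax.lax.psum(jnp.sum(x_shard), axis_name="dev")

total = jax.pmap(global_total, axis_name="dev", out_axes=None)(x)
print(total.shape)  # () : no device dimension
```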

### 5. Handling Constants Efficiently: `static_broadcasted_argnums`

Sometimes, arguments to your `pmap`ped function are Python values (ints, strings, flags) that should be treated as compile-time constants and broadcast to every device, rather than processed as per-device data or runtime-replicated JAX arrays. This is where `static_broadcasted_argnums` comes in.

* Arguments listed in `static_broadcasted_argnums` are "baked into" the compiled function for each device.
* They must be hashable and define `__eq__` (e.g., standard Python types like int, str, bool, or tuples of these). JAX and NumPy arrays are not hashable, so they cannot be passed as static arguments.
* This avoids unnecessary device transfers or replication of data that is truly static.
* When `in_axes` is a tuple, it still needs one entry per positional argument, including the static ones; `pmap` ignores the entries at static positions (`None` is a conventional placeholder).

In [15]:
def my_static_op_example(x_dynamic_shard, static_python_val: int, another_dynamic_val_replicated):
  # static_python_val will be a compile-time constant here.
  # JAX traces the function for each unique value of static_python_val.
  print(f"my_static_op_example TRACED/RUN with static_python_val = {static_python_val}")
  if static_python_val > 5:
    return x_dynamic_shard * 100 + another_dynamic_val_replicated
  else:
    return x_dynamic_shard * static_python_val + another_dynamic_val_replicated

# static_python_val (arg index 1) will be static.
# A tuple in_axes has one entry per positional argument; pmap ignores the entry
# at the static position (index 1), so None serves as a placeholder there.
# arg0 (x_dynamic_shard) is sharded -> in_axes[0] = 0
# arg2 (another_dynamic_val_replicated) is replicated -> in_axes[2] = None
pmapped_static_op_fn = jax.pmap(my_static_op_example,
                              static_broadcasted_argnums=(1,), # Index of 'static_python_val'
                              in_axes=(0, None, None)
                            )

# Take a slice for simplicity, e.g., first feature from each device's batch.
# Shape: (num_devices, per_device_batch_size)
dynamic_input_shards = current_prepared_data[:, :, 0]
another_dynamic_param_replicated = jnp.array(0.5, dtype=jnp.float32) # Replicated via in_axes=None

print(f"\n--- pmap with static_broadcasted_argnums ---")
# Call with static_python_val = 7
# JAX will compile a version of my_static_op_example specialized for static_python_val=7.
print("Calling with static_val=7 (expect compilation if first time with this value)...")
result_static_7 = pmapped_static_op_fn(dynamic_input_shards, 7, another_dynamic_param_replicated)
print(f"Result with static_val=7 (first element from each device):\n{result_static_7[:, 0]}")

# Call with static_python_val = 3
# This will trigger a re-compilation because 'static_python_val' changed,
# unless a version for 3 was already compiled and cached.
print("\nCalling with a different static_val (static_val=3), expect a retrace (the print inside the function runs again) if not cached...")
result_static_3 = pmapped_static_op_fn(dynamic_input_shards, 3, another_dynamic_param_replicated)
print(f"Result with static_val=3 (first element from each device):\n{result_static_3[:, 0]}")


--- pmap with static_broadcasted_argnums ---
Calling with static_val=7 (expect compilation if first time with this value)...
my_static_op_example TRACED/RUN with static_python_val = 7
Result with static_val=7 (first element from each device):
[65.99184  89.81062  68.202095 53.22373 ]

Calling with a different static_val (static_val=3), expect a retrace (the print inside the function runs again) if not cached...
my_static_op_example TRACED/RUN with static_python_val = 3
Result with static_val=3 (first element from each device):
[2.464755  3.1793187 2.5310628 2.0817118]


**Pro-Tip:** Using `static_broadcasted_argnums` is crucial for arguments that control the structure of the computation (e.g., boolean flags, dimensions used in reshapes inside the function). Changing a static argument forces recompilation if that specific version isn't cached.
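To make the Pro-Tip concrete, here is a small sketch (the function `fold_rows` is illustrative) where a static argument sets a reshape dimension. Passed as an ordinary traced argument, `num_rows` would no longer be a concrete value and the reshape would fail during tracing; as a static argument, each distinct value gets its own compiled version. `in_axes` still carries a placeholder at the static position.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

def fold_rows(x_shard, num_rows):
    # num_rows must be a concrete Python int: reshape needs static shapes.
    return x_shard.reshape(num_rows, -1)

# in_axes[1] is a placeholder for the static argument and is ignored by pmap.
fold = jax.pmap(fold_rows, in_axes=(0, None), static_broadcasted_argnums=(1,))

x = jnp.arange(n * 6, dtype=jnp.float32).reshape(n, 6)
two_rows = fold(x, 2)    # shape (n, 2, 3)
three_rows = fold(x, 3)  # shape (n, 3, 2); a new static value triggers a separate compilation
```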

### 6. Targeting Specific Devices with `pmap`'s `devices` Argument

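By default, `pmap` runs on N local devices chosen by JAX, where N is the size of the mapped input axis (N may not exceed `jax.local_device_count()`). You can instead pass an explicit list through the `devices` argument, which the JAX documentation marks as experimental; the mapped axis size must then equal the number of devices listed. This is useful for more complex scenarios, like manually assigning parts of a model to specific subsets of devices on a host (though `jax.sharding` APIs are generally preferred for more advanced device management today).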

In [17]:
# Use the first half of available devices, or just the first device if only 1 would result.
num_subset_devices = max(1, num_devices // 2)
subset_target_devices = devices[:num_subset_devices]

print(f"\n--- pmap with explicit 'devices' argument ---")
print(f"All available devices: {devices}")
print(f"Targeting subset of devices for pmap: {subset_target_devices}")

# Prepare data specifically for this subset of devices.
# Ensure total_batch_size_for_subset is a multiple of num_subset_devices.
per_device_batch_for_subset = current_prepared_data.shape[1] # Use same per-device batch as before
total_batch_size_for_subset = num_subset_devices * per_device_batch_for_subset

# For simplicity, we'll reshape a leading slice of the original global_aurora_data.
subset_global_data_slice = global_aurora_data[:total_batch_size_for_subset, :]

subset_prepared_data = subset_global_data_slice.reshape(num_subset_devices, per_device_batch_for_subset, feature_size)

pmapped_on_subset_fn = jax.pmap(process_data_shard_example, # Using the function from in_axes example
                              in_axes=(0, None), # data sharded, param replicated
                              devices=subset_target_devices,
                              axis_name='data_axis_subset')

result_on_subset = pmapped_on_subset_fn(subset_prepared_data, aurora_scalar_param)
print(f"Input data shape for subset pmap: {subset_prepared_data.shape}")
print(f"Result from pmap on subset of devices:\n{result_on_subset}")
print(f"Result shape: {result_on_subset.shape}") # Leading dim will be num_subset_devices


--- pmap with explicit 'devices' argument ---
All available devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]
Targeting subset of devices for pmap: [CpuDevice(id=0), CpuDevice(id=1)]
Input data shape for subset pmap: (2, 8, 3)
Result from pmap on subset of devices:
[[23.570044 17.356874 21.184027]
 [18.736385 23.411463 19.134218]]
Result shape: (2, 3)
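To confirm where results actually live, you can ask the output array for its devices. This sketch repeats the subset idea with a trivial function and checks that the output is placed on exactly the requested devices; it relies on `jax.Array.devices()`, which returns a set of devices.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
subset = jax.local_devices()[:max(1, n // 2)]  # first half of the devices (at least one)

# The mapped axis size must equal len(subset) when `devices` is given.
f = jax.pmap(lambda v: v + 1.0, devices=subset)
out = f(jnp.zeros((len(subset), 3)))

print(out.devices())  # the devices in `subset`, as a set
```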


### 7. Inter-Device Alchemy: Collectives within `pmap`

When each device processes its data shard, we often need to combine or share information *between* these parallel executions. This is where **collective operations** come in. `jax.lax` provides several collectives that work seamlessly inside `pmap`.

Inside a `pmap`ped function:
* An axis name is *not* defined implicitly: you name the mapped axis with `pmap`'s `axis_name` argument (it defaults to `None`) and pass that same name to each collective. A collective that refers to an unknown axis name fails at trace time.
* **`jax.lax.psum(x, axis_name)`**: Computes the sum of `x` across all devices participating in the `pmap` along the named axis. The result is broadcast back to all devices, so each device receives the same total sum.
* **`jax.lax.pmean(x, axis_name)`**: Computes the mean of `x` across all devices. Result is broadcast.
* **`jax.lax.all_gather(x, axis_name, tiled=False)`**: Gathers `x` from all devices and stacks the copies along a new leading axis, so each device receives an array of shape `(num_devices, *x.shape)` (a scalar becomes a length-`num_devices` vector). With `tiled=True`, the pieces are instead concatenated along the existing leading axis, which reconstructs the full array when each device holds one slice of it.
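The sketch below exercises these collectives side by side on a tiny array, plus `jax.lax.axis_index` to show each device's position along the named axis. The axis name `'d'` is arbitrary; shapes are annotated for `n` local devices.

```python
import jax
import jax.numpy as jnp
import numpy as np

n = jax.local_device_count()
x = jnp.arange(n * 2, dtype=jnp.float32).reshape(n, 2)  # device i holds x[i], shape (2,)

def collectives(v):
    return (jax.lax.psum(v, "d"),                     # (2,): sum over devices
            jax.lax.pmean(v, "d"),                    # (2,): mean over devices
            jax.lax.all_gather(v, "d"),               # (n, 2): stacked copies
            jax.lax.all_gather(v, "d", tiled=True),   # (n*2,): concatenated copies
            jax.lax.axis_index("d"))                  # scalar: this device's index

# Default out_axes=0 adds a leading device axis to every output.
s, m, g, gt, idx = jax.pmap(collectives, axis_name="d")(x)
```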

In [18]:
# Define the axis name for pmap collectives, matching pmap's axis_name argument.
PMAP_COLLECTIVES_AXIS_NAME = 'aurora_data_parallel_mesh'

def process_with_collectives_example(data_shard_input, model_weights_replicated):
    # 1. Per-device computation (e.g., local "predictions" or "gradients")
    # data_shard_input shape: (per_device_batch_size, in_features)
    # model_weights_replicated shape: (in_features, out_features)
    local_output = jnp.dot(data_shard_input, model_weights_replicated) # Shape: (per_device_batch_size, out_features)

    # 2. Sum outputs across all devices using psum.
    # Each device will receive the same 'global_sum_output'.
    global_sum_output = jax.lax.psum(local_output, axis_name=PMAP_COLLECTIVES_AXIS_NAME)

    # 3. Calculate the mean of a specific value (e.g., first element of the first batch item's output) across devices.
    mean_of_specific_value = jax.lax.pmean(local_output[0,0], axis_name=PMAP_COLLECTIVES_AXIS_NAME)

    # 4. Gather all 'local_output' first rows from all devices using all_gather.
    # Each device's local_output[0,:] is a vector of shape (out_features,).
    # all_gather will stack these vectors, resulting in an array of shape (num_devices, out_features) on each device.
    gathered_first_rows_from_all_devices = jax.lax.all_gather(local_output[0,:], axis_name=PMAP_COLLECTIVES_AXIS_NAME)
    # If tiled=True, all_gather(local_output, ...) would try to reconstruct a global tensor assuming local_output is a tile.

    return global_sum_output, mean_of_specific_value, gathered_first_rows_from_all_devices


# Dummy model weights for the example
in_features = current_prepared_data.shape[2] # feature_size
out_feature_size_example = 2
# These weights will be replicated to all devices via in_axes=None.
aurora_model_weights = jnp.array(np.random.rand(in_features, out_feature_size_example).astype(np.float32))

pmapped_collective_op_fn = jax.pmap(process_with_collectives_example,
                                  in_axes=(0, None), # data sharded, weights replicated
                                  axis_name=PMAP_COLLECTIVES_AXIS_NAME) # Crucial for collectives

# Run the pmapped function
# current_prepared_data shape: (num_devices, per_device_batch_size, in_features)
summed_res, mean_res, gathered_res = pmapped_collective_op_fn(current_prepared_data, aurora_model_weights)

print(f"\n--- pmap with Collectives (axis_name='{PMAP_COLLECTIVES_AXIS_NAME}') ---")
print(f"Model weights (replicated) shape: {aurora_model_weights.shape}")

# --- Analyzing psum result ---
# `summed_res` itself will be sharded by pmap's default out_axes=0.
# So, summed_res has shape (num_devices, per_device_batch_size, out_feature_size_example).
# The key is that summed_res[0], summed_res[1], etc., are ALL IDENTICAL, containing the global sum.
print(f"\nGlobal Summed Result (via psum):")
print(f"  Shape of 'summed_res' (outer device dim from pmap): {summed_res.shape}")
# Verify all devices have the same sum by comparing the slice from device 0 with device 1 (if available).
print(f"  summed_res[0] and summed_res[1] are identical: {np.allclose(summed_res[0], summed_res[1])}")
print(f"  Example value from one device (actual global sum, e.g., first batch item):\n{summed_res[0,0,:]}")

# --- Analyzing pmean result ---
# `mean_res` will have shape (num_devices,) because pmap's default out_axes=0 stacks the scalar mean from each device.
# Again, mean_res[0], mean_res[1], etc., are all identical.
print(f"\nGlobal Mean of Specific Value (via pmean):")
print(f"  Shape of 'mean_res': {mean_res.shape}") # (num_devices,)
print(f"  Value on first device (actual global mean): {mean_res[0]}")
print(f"  mean_res[0] and mean_res[1] are identical: {np.allclose(mean_res[0], mean_res[1])}")

# --- Analyzing all_gather result ---
# `gathered_res` will have shape (num_devices, num_devices_again, out_feature_size_example).
# The outer num_devices is from pmap's default out_axes=0.
# The inner num_devices_again is because all_gather(vector_from_each_device) produces a matrix of shape (num_devices, out_feature_size_example) on EACH device.
# So, gathered_res[i] contains this (num_devices, out_feature_size_example) matrix as seen by device i.
# And gathered_res[i,j,:] is the local_output[0,:] from original device j.
print(f"\nGathered First Rows (via all_gather):")
print(f"  Shape of 'gathered_res': {gathered_res.shape}")
print(f"  Example value (all gathered rows as seen by device 0):\n{gathered_res[0]}")
print(f"  gathered_res[0] and gathered_res[1] are identical: {np.allclose(gathered_res[0], gathered_res[1])}")


--- pmap with Collectives (axis_name='aurora_data_parallel_mesh') ---
Model weights (replicated) shape: (3, 2)

Global Summed Result (via psum):
  Shape of 'summed_res' (outer device dim from pmap): (4, 8, 2)
  summed_res[0] and summed_res[1] are identical: True
  Example value from one device (actual global sum, e.g., first batch item):
[3.1516733 2.5336175]

Global Mean of Specific Value (via pmean):
  Shape of 'mean_res': (4,)
  Value on first device (actual global mean): 0.7879183292388916
  mean_res[0] and mean_res[1] are identical: True

Gathered First Rows (via all_gather):
  Shape of 'gathered_res': (4, 4, 2)
  Example value (all gathered rows as seen by device 0):
[[0.6746523  0.5609971 ]
 [0.699285   0.61584073]
 [1.0759616  0.7655081 ]
 [0.7017747  0.5912713 ]]
  gathered_res[0] and gathered_res[1] are identical: True


Collectives are the backbone of distributed training, allowing gradients to be summed, metrics to be averaged, or parameters to be synchronized across all parallel workers.
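To tie collectives to training, here is a minimal data-parallel SGD sketch for a linear least-squares model (the model, learning rate, and step count are illustrative assumptions, not Aurora's actual setup). Each device computes gradients on its own shard, and `jax.lax.pmean` averages them so every replica applies the same update and the parameter copies stay identical.

```python
import jax
import jax.numpy as jnp
import numpy as np

n = jax.local_device_count()
AXIS = "batch"

def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

def train_step(w, x, y, lr):
    loss, grads = jax.value_and_grad(loss_fn)(w, x, y)
    grads = jax.lax.pmean(grads, AXIS)  # average gradients across devices
    loss = jax.lax.pmean(loss, AXIS)    # report the global mean loss
    return w - lr * grads, loss

# Params and data carry a device axis (in_axes=0); the learning rate is replicated.
p_train_step = jax.pmap(train_step, axis_name=AXIS, in_axes=(0, 0, 0, None))

rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0, 0.5], np.float32)
x = rng.normal(size=(n, 64, 3)).astype(np.float32)  # 64 examples per device
y = x @ true_w
w = np.zeros((n, 3), np.float32)  # one identical parameter copy per device

losses = []
for _ in range(200):
    w, loss = p_train_step(w, x, y, 0.1)
    losses.append(float(loss[0]))
```

Because the averaged gradients are identical on every device, `w[0]`, `w[1]`, ... remain equal throughout training, which is the invariant simple `pmap` data parallelism depends on.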

### 8. The Memory Footprint of `pmap`: Considerations for Aurora

While `pmap` enables parallelism, it's crucial to understand its memory implications for Project Aurora:
* **Data Replication (for `in_axes=None`):** Arguments with `in_axes=None` are replicated. If these are large arrays (e.g., model parameters in simple data parallelism), a full copy resides on *each* device. This can severely limit model size.
* **Per-Device Data Shards:** The sharded data still consumes memory on each device.
* **Intermediate Activations:** Each device computes its own intermediate activations during the forward/backward pass. These also consume memory per device.
* **Function Code:** The compiled function code is replicated on each device, usually a smaller concern.
* **Output Buffers:** Buffers for outputs are allocated on each device.

If your model parameters are too large to fit on a single device alongside activations and data shards, simple `pmap` data parallelism (where parameters are replicated) won't work. This limitation is a key motivator for more advanced sharding strategies (model parallelism, Fully Sharded Data Parallel - FSDP) that we'll explore later in Project Aurora.
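A back-of-envelope calculation makes the replication cost tangible. The helper below is hypothetical and deliberately rough: it assumes float32 values and two optimizer slots per parameter (as with Adam), and ignores activations and compiled code.

```python
def pmap_per_device_bytes(num_params, per_device_batch, feature_size,
                          bytes_per_elem=4, optimizer_slots=2):
    """Hypothetical rough estimate of per-device memory under pmap data parallelism."""
    # Replicated on every device: params + grads + optimizer slots (2 assumed, Adam-like).
    replicated = num_params * bytes_per_elem * (2 + optimizer_slots)
    # Sharded with in_axes=0: each device holds only its slice of the batch.
    data_shard = per_device_batch * feature_size * bytes_per_elem
    return replicated, data_shard

rep, shard = pmap_per_device_bytes(num_params=1_000_000_000, per_device_batch=8, feature_size=3)
print(f"{rep / 1e9:.0f} GB replicated, {shard} bytes of data per device")
# 16 GB replicated, 96 bytes of data per device
```

The replicated term does not shrink as you add devices, which is exactly the limitation that FSDP-style sharding addresses.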

### 9. `pmap` in Practice: Use Cases, Strengths, and Limitations

**Ideal Use Cases & Strengths for Aurora:**
* **Data Parallel Training (Small to Medium Models):** When model parameters and optimizer states fit comfortably on each device, `pmap` is excellent for distributing the data batch and speeding up training.
* **Parallel Inference/Evaluation:** Evaluating a model on large batches of data by splitting the batch across devices.
* **Simple Embarrassingly Parallel Tasks:** Any task where you can split work into independent chunks (e.g., running multiple simulations with different initial seeds, if seeds are part of sharded input).
* **Ease of Use:** For basic data parallelism, `pmap` is relatively straightforward to implement compared to more manual sharding approaches.
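For the "different initial seeds" use case, one common pattern is to split a PRNG key into one key per device and pass the keys as the sharded input. The random-walk `simulate` function below is only a stand-in for a real simulation.

```python
import jax
import jax.numpy as jnp
import numpy as np

n = jax.local_device_count()
keys = jax.random.split(jax.random.PRNGKey(0), n)  # one independent key per device

def simulate(key):
    # Stand-in simulation: final position of a 100-step Gaussian random walk.
    steps = jax.random.normal(key, (100,))
    return jnp.cumsum(steps)[-1]

finals = jax.pmap(simulate)(keys)  # shape (n,): one result per device/seed
```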

**Shortcomings & Limitations:**
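* **Host-Local View:** Each `pmap` call maps only over the calling process's local devices. `pmap` can run across multiple hosts (after `jax.distributed.initialize()`, with collectives spanning all processes), but each host must load and feed its own data shards, which quickly becomes awkward. Global meshes and `jax.sharding`, covered later, handle multi-node scaling more cleanly.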
* **Memory Constraints (Replicated Parameters):** As mentioned, replicating large model parameters across all devices is a major bottleneck for very large models.
* **Limited Parallelism Types:** `pmap` is primarily for data parallelism. Implementing complex model parallelism or pipeline parallelism with `pmap` alone is cumbersome and often inefficient.
* **Collectives Scope:** Collectives operate over the entire named `pmap` axis. Subgroup communication is possible only through the limited `axis_index_groups` argument of some collectives or by nesting `pmap`s, neither of which is as flexible as the multi-axis meshes used with `shard_map` in later chapters.

For Project Aurora, `pmap` is a vital initial tool for leveraging multi-accelerator nodes. However, to build truly colossal models, we will need to transcend its limitations with the more advanced sharding APIs introduced in subsequent chapters.

## Chapter Debrief: `pmap` Milestones Achieved!

**Summary:**
Excellent work, Architect! You've successfully unlocked single-host, multi-device parallelism for Project Aurora using `jax.pmap`. You've learned to replicate computations, strategically distribute data shards using `in_axes`, manage output aggregation with `out_axes`, and leverage powerful inter-device communication via collectives like `psum`, `pmean`, and `all_gather`. You also now appreciate the memory model of `pmap` and its ideal use cases and limitations.

**Key Takeaways:**

* `jax.pmap` enables SPMD-style data parallelism by replicating a function across local devices.
* `in_axes` dictates how input array axes are mapped to devices (sharded) or replicated.
* `out_axes` controls how results from devices are combined or returned.
* `static_broadcasted_argnums` handles compile-time constants efficiently.
* `jax.lax` collectives (`psum`, `pmean`, `all_gather`) are essential for inter-device communication within `pmap`, using a specified `axis_name`.
* `pmap` is powerful for data parallelism when models fit in single-device memory, but parameter replication limits its use for extremely large models. Data must typically be pre-sharded with a leading device axis.

**Transition:**
While `pmap` significantly boosts our processing power on a single Aurora node, its view of devices is somewhat "flat"—a simple list. For more sophisticated parallelism strategies, especially when coordinating many devices in potentially multi-dimensional topologies (even on a single host initially), we need a more structured way to view and address our hardware. This leads us to **Chapter 1.3: Global Device `Mesh` - Abstracting Hardware Topology for Advanced Sharding**, where we'll learn to define logical meshes that map to our physical devices, paving the way for the modern `jax.sharding` API.

## Further Reading & Resources

* **JAX API Documentation - `jax.pmap`**: [https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
* **JAX Lax Parallel Operators**: [https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators)