<a href="https://colab.research.google.com/github/mridul-sahu/jax-sharding-tutorials/blob/main/Shard_Map_Detailed.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [77]:
!pip install --upgrade jax



# The Grand Shard_map Tutorial: A Wizard's Guide to Parallel Spells 🧙‍♂️📜✨

Welcome, young Data Wizards! Archmage Jaxus here, your guide on this quest to master the **Scroll of Sharding**, known in the common tongue as `jax.shard_map`. You've learned individual spells (`jax.jit`, basic functions), but some Magic Scrolls (datasets, models) are too immense for a single wizard. `shard_map` is the ancient art of **manual parallelism**, where a coven of wizards (your devices) work in concert, each tackling a fragment of the Great Scroll, and communicating through **explicit collective spells**.

It's more hands-on than the "Auto-Sort Spell" (`jit` with automatic partitioning), which often does a good job but doesn't always know the specific enchantments you wish to cast. And it's an evolution of the "Elder Scroll of Mapping" (`pmap`), offering more power, flexibility, and even the ability to debug your spells eagerly before committing them to the aether (compiling them).

Let's begin by preparing our School of Magic (Python environment) and summoning our apprentice wizards (simulating 8 devices for this tutorial).

In [2]:
import os
# Ensure we have 8 "apprentice wizards" (CPU devices) for our spells.
# For this tutorial, we'll force JAX to see 8 CPU devices.
# In a real scenario, these would be your actual hardware devices (CPUs, GPUs, TPUs).
# IMPORTANT: Run this cell *before* importing JAX for the first time in a session if you're running locally.
# If you've already imported JAX, you might need to restart your kernel.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
import numpy as np # Standard numpy for some array creations
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.tree_util import tree_map, tree_all
from jax import lax # For collective operations and other utilities

# A little helper spell to check if our distributed magic matches single-wizard results
def allclose(a, b, atol=1e-2, rtol=1e-2): # Using a slightly looser tolerance for tutorial examples
  """Checks if all elements in two pytrees are close."""
  return tree_all(tree_map(partial(jnp.allclose, atol=atol, rtol=rtol), a, b))

print(f"Welcome, Wizards! We have {len(jax.devices())} apprentices ready for their first lesson.")
if len(jax.devices()) != 8:
    print("\n⚠️ Warning: Could not simulate 8 devices as expected.")
    print("The Archmage's incantation for device simulation might have been resisted, or you might be running in an environment where this flag is overridden.")
    print("You can still follow the tutorial, but visualizations and sharding specifics might differ slightly.")
    print("The core principles of the shard_map spell remain the same!")



Welcome, Wizards! We have 8 apprentices ready for their first lesson.


## Chapter 1: The First Sharding Spell - `matmul_basic` 💥

Our first group spell will be a fundamental one: multiplying two Magic Scrolls (matrices). To do this, Archmage Jaxus explains, "We must first define how our wizards are arranged for collaborative spellcasting. This arrangement is known as a **`Mesh`**. Think of it as assigning each wizard a specific workstation in a magical grid."

We'll tell JAX how many wizards are in each dimension of this grid and give names to these dimensions. For instance, we could have a 2D grid.
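`jax.make_mesh` picks the device order for you and may arrange devices to suit the hardware topology. The sketch below builds the same 4x2 grid by hand with the `Mesh` constructor, which later sections of this tutorial also use. It assumes at least 8 visible devices, here simulated CPU devices. `hand_built_mesh` is an illustrative name.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # only effective before JAX is first imported

import numpy as np
import jax
from jax.sharding import Mesh

# Lay the first 8 devices out as a 4x2 grid; the names label the grid's two dimensions.
grid_of_devices = np.array(jax.devices()[:8]).reshape(4, 2)
hand_built_mesh = Mesh(grid_of_devices, ('row_of_wizards', 'col_of_wizards'))

print(dict(hand_built_mesh.shape))    # {'row_of_wizards': 4, 'col_of_wizards': 2}
print(hand_built_mesh.devices.shape)  # (4, 2)
print(hand_built_mesh.devices[1, 0])  # the wizard at row 1, column 0
```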

In [3]:
# Let's arrange our 8 wizards in a 4x2 formation.
# We give names to these dimensions: 'row_of_wizards' and 'col_of_wizards'.
mesh_shape = (4, 2)  # 4 rows of wizards, 2 columns of wizards
mesh_axis_names = ('row_of_wizards', 'col_of_wizards') # Naming the dimensions of our magical grid

# Create the mesh
wizard_mesh = jax.make_mesh(mesh_shape, mesh_axis_names)

print(f"Our wizarding mesh configuration: {wizard_mesh}")
print(f"Total wizards in this mesh: {wizard_mesh.size}")
print(f"Wizardry is organized along these axes: {wizard_mesh.axis_names}")
print(f"Number of wizards in the '{wizard_mesh.axis_names[0]}' dimension: {wizard_mesh.shape[wizard_mesh.axis_names[0]]}")
print(f"Number of wizards in the '{wizard_mesh.axis_names[1]}' dimension: {wizard_mesh.shape[wizard_mesh.axis_names[1]]}")

Our wizarding mesh configuration: Mesh('row_of_wizards': 4, 'col_of_wizards': 2, axis_types=(Auto, Auto))
Total wizards in this mesh: 8
Wizardry is organized along these axes: ('row_of_wizards', 'col_of_wizards')
Number of wizards in the 'row_of_wizards' dimension: 4
Number of wizards in the 'col_of_wizards' dimension: 2


"Excellent!" exclaims Archmage Jaxus. "With our wizards arranged in the `wizard_mesh`, we now need to tell them how to divide the work and how to combine their results. For this, we use **Partitioning Runes**, known as `PartitionSpec` (which we'll call `P` for short). These runes are part of the `in_specs` (how to divide input scrolls) and `out_specs` (how to assemble the final scroll)."

* **`in_specs`**: A set of runes, one for each input scroll. Each rune specifies how that scroll's dimensions are sharded (split) across the named axes of our `wizard_mesh`.
    * For example, `P('row_of_wizards', 'col_of_wizards')` means the scroll's first dimension is split among the 'row_of_wizards' and its second dimension among the 'col_of_wizards'.
    * If a rune uses `None` for a scroll dimension, that dimension is *not* split, and each wizard receives that dimension in full. Any mesh axis the rune does not mention at all is one the scroll is *replicated* across, so wizards that differ only along that axis receive identical fragments.
* **`out_specs`**: Similar runes for the output scroll, dictating how the pieces computed by each wizard are reassembled.
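Before casting anything, you can ask a rune what fragment each wizard will receive. This is a minimal sketch that assumes the 4x2 mesh from above and uses `NamedSharding.shard_shape` and `devices_indices_map` to print the per-device block shapes of the two scrolls in the next cell. It also confirms that wizards in the same column share the same slice of B. `rune_mesh` and the other variable names are illustrative.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # only effective before JAX is first imported

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

rune_mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('row_of_wizards', 'col_of_wizards'))

sharding_A = NamedSharding(rune_mesh, P('row_of_wizards', 'col_of_wizards'))
sharding_B = NamedSharding(rune_mesh, P('col_of_wizards', None))

# Per-device block shapes for an 8x16 scroll A and a 16x4 scroll B.
block_A = sharding_A.shard_shape((8, 16))
block_B = sharding_B.shard_shape((16, 4))
print(block_A, block_B)  # (2, 8) (8, 4)

# The slice of B that each device receives (one tuple of slices per device).
slices_B = sharding_B.devices_indices_map((16, 4))
grid = rune_mesh.devices
# 'row_of_wizards' is absent from B's rune, so wizards in the same column share a slice...
same_column_same_slice = slices_B[grid[0, 0]] == slices_B[grid[3, 0]]
# ...while the two columns hold different halves of B's rows.
columns_differ = slices_B[grid[0, 0]] != slices_B[grid[0, 1]]
print(same_column_same_slice, columns_differ)  # True True
```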

Let's prepare two Magic Scrolls (matrices `scroll_A` and `scroll_B`) and define our first sharded matrix multiplication spell.

In [4]:
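# At the time of writing, the top-level jax.shard_map did not work in this setup,
# so we import shard_map from jax.experimental.shard_map instead.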
from jax.experimental.shard_map import shard_map

# Our two Magic Scrolls (matrices) to be multiplied
# scroll_A: 8x16, scroll_B: 16x4. Result should be 8x4.
scroll_A = jnp.arange(8 * 16., dtype=jnp.float32).reshape(8, 16)
scroll_B = jnp.arange(16 * 4., dtype=jnp.float32).reshape(16, 4)

# The Sharding Spell for matrix multiplication
# We use the 'wizard_mesh' (4 rows, 2 columns of wizards) defined earlier.
@partial(shard_map,
         mesh=wizard_mesh,
         in_specs=(P('row_of_wizards', 'col_of_wizards'), P('col_of_wizards', None)), # How scroll_A and scroll_B are sharded
         out_specs=P('row_of_wizards', None)) # How the result scroll_C is sharded
def matmul_basic_sharded(scroll_A_fragment, scroll_B_fragment):
  # scroll_A is 8x16. Sharded by ('row_of_wizards'(4), 'col_of_wizards'(2)).
  # So, scroll_A_fragment for each wizard will have shape (8/4, 16/2) = (2, 8).

  # scroll_B is 16x4. Sharded by ('col_of_wizards'(2), None).
  # So, scroll_B_fragment for each wizard will have shape (16/2, 4) = (8, 4).
  # The 'None' means B's second dimension (size 4) is not split. Because 'row_of_wizards' does not appear
  # in B's spec, B is replicated across it: wizards in the same 'col_of_wizards' column (but different rows)
  # get different scroll_A_fragments but the *same* (8, 4) slice of B, while the two columns get different slices.
  print(f"Each wizard received A_frag: {scroll_A_fragment.shape}, B_frag: {scroll_B_fragment.shape}")

  # Each wizard computes their part of the sum (local dot product)
  partial_sum_C = jnp.dot(scroll_A_fragment, scroll_B_fragment)
  # partial_sum_C will be (2,4) for each wizard

  # Now, for the magic of combining results!
  # Wizards who are in the same 'row_of_wizards' but different 'col_of_wizards' worked on different
  # parts of the original scroll_A's columns and scroll_B's rows (the contracting dimension).
  # We need to sum their partial_sum_C contributions.
  # The 'col_of_wizards' mesh axis was used for this split, so we sum over it.
  final_C_fragment_for_row = jax.lax.psum(partial_sum_C, axis_name='col_of_wizards')
  # final_C_fragment_for_row is still (2,4) but now contains the sum over 'col_of_wizards' for that row.

  return final_C_fragment_for_row

print("Defining the 'matmul_basic_sharded' spell. Prints will appear when it's first cast (traced).")

Defining the 'matmul_basic_sharded' spell. Prints will appear when it's first cast (traced).


In [5]:
# Cast the spell!
# JAX automatically handles the distribution of scroll_A and scroll_B fragments
# to the wizards according to in_specs.
print("Casting matmul_basic_sharded for the first time (with prints from inside the spell):")
scroll_C_sharded = matmul_basic_sharded(scroll_A, scroll_B)

# Casting it again may not print anything: Python print() runs only while JAX traces the spell,
# and a repeated call may reuse an earlier trace instead of re-tracing.
# print("\nCasting again (prints from inside may not appear):")
# scroll_C_sharded_again = matmul_basic_sharded(scroll_A, scroll_B)

print(f"\nShape of original scroll_A: {scroll_A.shape}")
print(f"Shape of original scroll_B: {scroll_B.shape}")
print(f"Shape of the sharded result scroll_C: {scroll_C_sharded.shape}") # Expected: (8, 4)

# Let's look at a piece of the output
print(f"A peek at the resulting scroll_C_sharded (first 2 rows):\n{scroll_C_sharded[:2]}")

Casting matmul_basic_sharded for the first time (with prints from inside the spell):
Each wizard received A_frag: (2, 8), B_frag: (8, 4)

Shape of original scroll_A: (8, 16)
Shape of original scroll_B: (16, 4)
Shape of the sharded result scroll_C: (8, 4)
A peek at the resulting scroll_C_sharded (first 2 rows):
[[ 4960.  5080.  5200.  5320.]
 [12640. 13016. 13392. 13768.]]


"Observe!" Archmage Jaxus gestures towards the glowing result. "Each wizard performed a local calculation on their fragments. Then, they used a crucial collective spell: `jax.lax.psum` (Parallel Sum). This spell instructed wizards sharing the same 'row_of_wizards' to sum their individual `partial_sum_C` pieces along the 'col_of_wizards' dimension. This is how the full dot product is correctly calculated across the distributed pieces."

The `out_specs=P('row_of_wizards', None)` then tells JAX how to assemble `scroll_C_sharded`: the (2x4) `final_C_fragment_for_row` blocks from the 4 'row_of_wizards' groups are concatenated along the first dimension, giving the final 8x4 scroll. The `None` means the second dimension is not split. Because 'col_of_wizards' does not appear in the spec at all, `shard_map` keeps a single copy from each pair of column wizards. That is only valid because `psum` has already made the two wizards in each row hold identical blocks.
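The same bookkeeping can be mirrored with plain NumPy on a single machine. No devices are involved in this sketch, and the block lists are illustrative. It splits the scrolls exactly as the runes do, performs each wizard's local dot product, sums across the two columns as `psum` would, and then concatenates the row blocks as `out_specs` does.

```python
import numpy as np

A_np = np.arange(8 * 16, dtype=np.float32).reshape(8, 16)
B_np = np.arange(16 * 4, dtype=np.float32).reshape(16, 4)
n_rows, n_cols = 4, 2  # the 4x2 wizard grid

# A_blocks[i][j] is wizard (i, j)'s (2, 8) fragment: rows split by mesh row, columns by mesh column.
A_blocks = [np.split(band, n_cols, axis=1) for band in np.split(A_np, n_rows, axis=0)]
# B_blocks[j] is the (8, 4) fragment held by every wizard in column j.
B_blocks = np.split(B_np, n_cols, axis=0)

row_pieces = []
for i in range(n_rows):
    # Local dot product on each wizard in row i (the jnp.dot step).
    partials = [A_blocks[i][j] @ B_blocks[j] for j in range(n_cols)]
    # psum over 'col_of_wizards': both wizards in row i end up with this same (2, 4) block.
    row_pieces.append(sum(partials))

# out_specs=P('row_of_wizards', None): stack the 4 row blocks, keeping one copy per row.
C_np = np.concatenate(row_pieces, axis=0)
print(C_np.shape)  # (8, 4)
print(C_np[0])     # [4960. 5080. 5200. 5320.]
```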

"But is our collective magic true?" he ponders. "We must always verify our sharded spells against the work of a single, focused wizard."

In [6]:
# Single wizard computation for reference
scroll_C_single_wizard = jnp.dot(scroll_A, scroll_B)

print(f"A peek at the single wizard's scroll_C (first 2 rows):\n{scroll_C_single_wizard[:2]}")

# Verify our sharded spell against the single wizard's result
verification_passed = allclose(scroll_C_sharded, scroll_C_single_wizard)
if verification_passed:
    print(f"\nSpell Succeeded! Our sharded spell matches the single wizard's calculation. ✨")
else:
    print(f"\nSpell Miscast! Our sharded spell does NOT match the single wizard. Check the incantations! ❌")

A peek at the single wizard's scroll_C (first 2 rows):
[[ 4960.  5080.  5200.  5320.]
 [12640. 13016. 13392. 13768.]]

Spell Succeeded! Our sharded spell matches the single wizard's calculation. ✨


"Indeed, our magic holds true!" declares Archmage Jaxus. "Now, let's peer into the Crystal Ball of `visualize_array_sharding`. This magical tool allows us to see precisely how our `scroll_C_sharded` is distributed across the wizarding workstations defined by our `wizard_mesh`."

This visualization will confirm that the output `scroll_C` (8x4) is sharded along its first dimension over 'row_of_wizards' (size 4 in our 4x2 mesh): each of the 4 row groups holds a 2x4 piece of the final scroll. Because 'col_of_wizards' does not appear in `out_specs=P('row_of_wizards', None)`, the result is replicated across that mesh axis, so both wizards in a row hold the same 2x4 piece. This is consistent only because `psum` had already made their blocks identical.
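You can confirm the same layout in plain text by listing the result's addressable shards. This sketch assumes 8 simulated devices. It re-creates the spell under the illustrative name `matmul_demo` without the debug print and wraps it in `jax.jit`. The `try`/`except` import is a hedge across JAX versions, and the `shmap` alias avoids overwriting the notebook's own `shard_map`.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # only effective before JAX is first imported

from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
try:
    from jax import shard_map as shmap  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map as shmap  # older releases, as used in this notebook

demo_mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('row_of_wizards', 'col_of_wizards'))

@jax.jit
@partial(shmap, mesh=demo_mesh,
         in_specs=(P('row_of_wizards', 'col_of_wizards'), P('col_of_wizards', None)),
         out_specs=P('row_of_wizards', None))
def matmul_demo(a_frag, b_frag):
    return jax.lax.psum(jnp.dot(a_frag, b_frag), axis_name='col_of_wizards')

A_demo = jnp.arange(8 * 16., dtype=jnp.float32).reshape(8, 16)
B_demo = jnp.arange(16 * 4., dtype=jnp.float32).reshape(16, 4)
C_demo = matmul_demo(A_demo, B_demo)

# One line per device: the device, the slice of C it holds, and that block's shape.
for shard in C_demo.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)

# Normalize each shard's row slice to its starting row.
row_starts = sorted({shard.index[0].indices(C_demo.shape[0])[0] for shard in C_demo.addressable_shards})
print(row_starts)  # [0, 2, 4, 6]: 4 distinct row blocks, each held by 2 devices
```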

In [7]:
print("Sharding of the resulting Scroll C (scroll_C_sharded):")
jax.debug.visualize_array_sharding(scroll_C_sharded)

# To understand what shard_map's `in_specs` effectively does to the inputs,
# we can manually shard them using jax.device_put with NamedSharding.
# This is conceptually what shard_map prepares for its mapped function.
print("\nFor context, let's visualize how scroll_A would be sharded based on its in_spec P('row_of_wizards', 'col_of_wizards'):")
sharded_A_for_visualization = jax.device_put(scroll_A, NamedSharding(wizard_mesh, P('row_of_wizards', 'col_of_wizards')))
jax.debug.visualize_array_sharding(sharded_A_for_visualization)

print("\nAnd how scroll_B would be sharded based on its in_spec P('col_of_wizards', None):")
sharded_B_for_visualization = jax.device_put(scroll_B, NamedSharding(wizard_mesh, P('col_of_wizards', None)))
jax.debug.visualize_array_sharding(sharded_B_for_visualization)

Sharding of the resulting Scroll C (scroll_C_sharded):



For context, let's visualize how scroll_A would be sharded based on its in_spec P('row_of_wizards', 'col_of_wizards'):



And how scroll_B would be sharded based on its in_spec P('col_of_wizards', None):


The visualizations above show how `shard_map` distributes data based on the `Mesh` and `PartitionSpec` you provide. This gives you explicit control over data layout and computation.

This contrasts with `jax.jit`'s automatic parallelization, where the JAX compiler attempts to figure out an efficient sharding strategy for you behind the scenes. With `shard_map`, you are the one defining the parallel execution plan. Both approaches are valuable: `jit` for automatic optimization and `shard_map` for fine-grained manual control when needed.
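For contrast, here is a sketch of the `jit` route on the same problem, assuming the 4x2 mesh. Only the input layouts (via `jax.device_put`) and the output layout (via `jax.jit`'s `out_shardings` parameter) are stated. The function body contains no `psum`, so the compiler is left to insert whatever cross-device reduction the sharded contraction needs. `matmul_pinned` and `jit_mesh` are illustrative names.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # only effective before JAX is first imported

from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jit_mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('row_of_wizards', 'col_of_wizards'))
out_layout = NamedSharding(jit_mesh, P('row_of_wizards', None))

# Only the boundary layouts are specified; XLA decides how to combine the partial products.
@partial(jax.jit, out_shardings=out_layout)
def matmul_pinned(a, b):
    return jnp.dot(a, b)

a_jit = jax.device_put(jnp.arange(8 * 16., dtype=jnp.float32).reshape(8, 16),
                       NamedSharding(jit_mesh, P('row_of_wizards', 'col_of_wizards')))
b_jit = jax.device_put(jnp.arange(16 * 4., dtype=jnp.float32).reshape(16, 4),
                       NamedSharding(jit_mesh, P('col_of_wizards', None)))
c_jit = matmul_pinned(a_jit, b_jit)
print(c_jit.sharding)  # the layout requested through out_shardings
```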

In [13]:
from jax.sharding import AxisType
from jax.experimental.shard import reshard, auto_axes

@jax.jit
def matmul(a, b):
  return jnp.dot(a, b)

scroll_A_auto = jax.device_put(scroll_A, NamedSharding(wizard_mesh, P('row_of_wizards', 'col_of_wizards')))
scroll_B_auto = jax.device_put(scroll_B, NamedSharding(wizard_mesh, P('col_of_wizards', None)))

print(f"scroll_A_auto: {jax.typeof(scroll_A_auto)}")
print(f"scroll_B_auto: {jax.typeof(scroll_B_auto)}")

scroll_C_auto = matmul(scroll_A_auto, scroll_B_auto)
print(f"scroll_C_Auto: {jax.typeof(scroll_C_auto)}")

# Let's look at a piece of the output
print(f"A peek at the resulting scroll_C_auto (first 2 rows):\n{scroll_C_auto[:2]}")

# Use auto_axes for now, as this matmul operation is not yet supported with Explicit axes.
@auto_axes
def matmul_explicit(a, b):
  return jnp.dot(a, b)

# Explicit Sharding
wizard_mesh_explicit = jax.make_mesh(mesh_shape, mesh_axis_names, axis_types=(AxisType.Explicit, AxisType.Explicit))
scroll_A_explicit = jax.device_put(scroll_A, NamedSharding(wizard_mesh_explicit, P('row_of_wizards', 'col_of_wizards')))
scroll_B_explicit = jax.device_put(scroll_B, NamedSharding(wizard_mesh_explicit, P('col_of_wizards', None)))

print(f"scroll_A_explicit: {jax.typeof(scroll_A_explicit)}")
print(f"scroll_B_explicit: {jax.typeof(scroll_B_explicit)}")

scroll_C_explicit = matmul_explicit(scroll_A_explicit, scroll_B_explicit, out_shardings=NamedSharding(wizard_mesh_explicit, P('row_of_wizards', None)))
print(f"scroll_C_explicit: {jax.typeof(scroll_C_explicit)}")

# Let's look at a piece of the output
print(f"A peek at the resulting scroll_C_explicit (first 2 rows):\n{scroll_C_explicit[:2]}")
print(f"Are all results equal? {allclose(scroll_C_sharded, scroll_C_auto)=}")

scroll_A_auto: ShapedArray(float32[8,16])
scroll_B_auto: ShapedArray(float32[16,4])
scroll_C_Auto: ShapedArray(float32[8,4])
A peek at the resulting scroll_C_auto (first 2 rows):
[[ 4960.  5080.  5200.  5320.]
 [12640. 13016. 13392. 13768.]]
scroll_A_explicit: ShapedArray(float32[8@row_of_wizards,16@col_of_wizards])
scroll_B_explicit: ShapedArray(float32[16@col_of_wizards,4])
scroll_C_explicit: ShapedArray(float32[8@row_of_wizards,4])
A peek at the resulting scroll_C_explicit (first 2 rows):
[[ 4960.  5080.  5200.  5320.]
 [12640. 13016. 13392. 13768.]]
Are all results equal? allclose(scroll_C_sharded, scroll_C_auto)=True


## Understanding How Maps Handle Array Dimensions

JAX provides different functions for mapping operations over arrays, like `jax.vmap` and `jax.shard_map`. A key difference between them is how they treat the dimensions (or rank) of the arrays they operate on.

Let's first look at `jax.vmap`. It's often described as a **rank-reducing map**.

In [None]:
# Example: jax.vmap
# Suppose we have a function that processes a single matrix (2D array)
def process_matrix(matrix):
  # Example operation: sum each row
  return jnp.sum(matrix, axis=1)

# Now, let's say we have a stack of matrices (a 3D array)
# e.g., 4 matrices, each of size 2x3
stacked_matrices = jnp.arange(4 * 2 * 3.).reshape(4, 2, 3)
print(f"Original stacked_matrices shape: {stacked_matrices.shape} (rank 3)")

# When we use vmap to map process_matrix over the first axis (the stack of 4 matrices):
# in_axes=0 means map over the 0-th axis of stacked_matrices.
# The 'process_matrix' function will receive individual 2x3 matrices (rank 2).
vmapped_output = jax.vmap(process_matrix, in_axes=0)(stacked_matrices)

# Each 2x3 matrix yields a (2,) vector (one sum per row),
# and with 4 such matrices the output is 4x2.
print(f"vmapped_output shape: {vmapped_output.shape}")

# The number of logical applications of 'process_matrix' is determined
# by the size of the input axis being mapped over (here, 4).

Original stacked_matrices shape: (4, 2, 3) (rank 3)
vmapped_output shape: (4, 2)


In contrast, `jax.shard_map` is a **rank-preserving map**.

When you use `shard_map`, the function you define operates on blocks (shards) of the input arrays. These blocks retain the same rank as the original input arrays. The number of logical applications of your function is determined by the total size of the `Mesh` you define, not by the size of any particular input array axis.

In [15]:
# Example: jax.shard_map
# Let's use a simple 1D mesh of 4 devices for this example.
simple_mesh_1d = Mesh(np.array(jax.devices()[:4]), ('data_parallel_dim',))

# An 8x5 array (rank 2)
big_array = jnp.arange(8 * 5.).reshape(8, 5)
print(f"Original big_array shape: {big_array.shape} (rank 2)")

# The function to be mapped by shard_map
# It will receive a block of 'big_array'.
@partial(shard_map,
         mesh=simple_mesh_1d,
         in_specs=P('data_parallel_dim', None), # Shard 0-th axis of big_array across 'data_parallel_dim' (size 4)
         out_specs=P('data_parallel_dim', None))
def process_block(array_block):
  # If big_array is (8,5) and mesh has 4 devices along 'data_parallel_dim',
  # each array_block will be (8/4, 5) = (2,5). Its rank is still 2.
  print(f"  Inside shard_map: process_block received block of shape: {array_block.shape}")
  # Example operation: sum its rows, result will be (2,) for each block
  return jnp.sum(array_block, axis=1)

# Execute with shard_map
sharded_output = process_block(big_array)

# Each of the 4 devices produces a (2,) vector.
# out_specs=P('data_parallel_dim', None) concatenates these along the first dimension.
# So, final output shape is (4*2,) = (8,).
print(f"sharded_output shape: {sharded_output.shape}")

# The number of logical applications of 'process_block' is determined
# by the mesh size (here, 4).

Original big_array shape: (8, 5) (rank 2)
  Inside shard_map: process_block received block of shape: (2, 5)
sharded_output shape: (8,)


So, the key distinction:
* `vmap`: Reduces rank for the mapped function; number of instances depends on input axis size.
* `shard_map`: Preserves rank for the mapped function; number of instances depends on mesh size.
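One way to see that the instance count tracks the mesh rather than the data is to have every instance report its position with `jax.lax.axis_index`. In this sketch, assuming a 4-device 1D mesh, an (8, 5) input and a (40, 5) input both yield exactly four instances. Only the block size each instance receives changes. `instance_id` and `count_mesh` are illustrative names.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # only effective before JAX is first imported

from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
try:
    from jax import shard_map as shmap  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map as shmap  # older releases

count_mesh = Mesh(np.array(jax.devices()[:4]), ('data_parallel_dim',))

@jax.jit
@partial(shmap, mesh=count_mesh, in_specs=P('data_parallel_dim', None), out_specs=P('data_parallel_dim'))
def instance_id(block):
    # Each instance returns its own index along the mesh axis as a length-1 array.
    return jax.lax.axis_index('data_parallel_dim')[None]

ids_small = instance_id(jnp.zeros((8, 5)))   # each instance sees a (2, 5) block
ids_large = instance_id(jnp.zeros((40, 5)))  # each instance sees a (10, 5) block
print(ids_small, ids_large)  # [0 1 2 3] [0 1 2 3]
```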

## Understanding How Mappings Affect Array Rank

JAX provides different ways to map functions over data. A key distinction is how they handle the "rank" (number of dimensions) of arrays.

**`jax.vmap`: The Rank-Reducing Map**

Think of `jax.vmap` as a way to automatically "vectorize" a function. If you have a function that works on a single item (e.g., a vector), `vmap` can make it work on a batch of items (e.g., a matrix where each row is a vector).

When `vmap` applies your function, the function receives inputs with one less dimension than what `vmap` was given (the dimension being mapped over is "removed" for the function's view). The number of times your function is logically applied is determined by the size of this mapped axis.

In [16]:
# vmap example:
# This function expects a 1D array (vector)
def process_vector(vector):
  return jnp.sum(vector * 2)

# A 2D array (a batch of 3 vectors, each of size 4)
batched_vectors = jnp.arange(3 * 4.).reshape(3, 4)
print(f"Original batched_vectors shape: {batched_vectors.shape}")

# Map process_vector over the 0-th axis (the batch dimension)
# process_vector will receive individual 1D arrays of shape (4,).
vmapped_result = jax.vmap(process_vector, in_axes=0, out_axes=0)(batched_vectors)

print(f"vmapped_result: {vmapped_result}")
print(f"vmapped_result shape: {vmapped_result.shape}") # Shape (3,) - one result per input vector

# Reference: what vmap does semantically
expected_vmap_result = jnp.array([process_vector(vector) for vector in batched_vectors])
print(f"Does vmap work as expected? {allclose(vmapped_result, expected_vmap_result)}")
print(f"Number of logical applications of process_vector: {batched_vectors.shape[0]}")

Original batched_vectors shape: (3, 4)
vmapped_result: [12. 44. 76.]
vmapped_result shape: (3,)
Does vmap work as expected? True
Number of logical applications of process_vector: 3


**`jax.shard_map`: The Rank-Preserving Map**

In contrast, `jax.shard_map` does *not* reduce the rank of the arrays for the function it maps. Your function receives "shards" or "blocks" of the input arrays that have the *same rank* as the original global inputs.

The number of logical applications of your function in `shard_map` is determined by the total size of the `Mesh` you define (e.g., a 2x2 mesh means 4 logical applications), not by the size of any particular input array axis.

In [17]:
# shard_map example:
# Let's use a simple 1D mesh of 2 devices for this.
simple_mesh_1d = Mesh(np.array(jax.devices()[:2]), ('data_parallel_axis',))

# An 8x4 array
large_array = jnp.arange(8 * 4.).reshape(8, 4)
print(f"Original large_array shape: {large_array.shape}")

# This function will receive a block of large_array
@partial(shard_map,
         mesh=simple_mesh_1d,
         in_specs=P('data_parallel_axis'), # Shard the 0-th axis of large_array
         out_specs=P('data_parallel_axis'))
def process_block_shmap(array_block):
  # If large_array is (8,4) and mesh has 2 devices along 'data_parallel_axis',
  # array_block will be (8/2, 4) = (4,4). Rank is preserved!
  print(f"  Inside shard_map: process_block_shmap received block of shape: {array_block.shape}")
  return jnp.sum(array_block, axis=1) # Example: sum each row, result (4,) per block

print("Running shard_map example:")
shmap_result = process_block_shmap(large_array)

# Expected output shape:
# Each of the 2 devices produces a (4,) result.
# out_specs=P('data_parallel_axis') concatenates these along the 0-th axis.
# So, (2*4,) = (8,).
print(f"\nshmap_result: {shmap_result}")
print(f"shmap_result shape: {shmap_result.shape}")
print(f"Number of logical applications of process_block_shmap: {simple_mesh_1d.size}")

Original large_array shape: (8, 4)
Running shard_map example:
  Inside shard_map: process_block_shmap received block of shape: (4, 4)

shmap_result: [  6.  22.  38.  54.  70.  86. 102. 118.]
shmap_result shape: (8,)
Number of logical applications of process_block_shmap: 2


In summary:
* `vmap`: Reduces rank for the mapped function; number of instances = size of mapped input axis.
* `shard_map`: Preserves rank for the mapped function; number of instances = size of the `Mesh`.

This rank-preserving nature is key to how `shard_map` allows you to write code that operates on local blocks of a larger, distributed array, while still reasoning about the full dimensionality of those blocks.
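The two maps also compose naturally. Inside a `shard_map` instance you still hold a rank-2 block, and `jax.vmap` can then strip the row axis for a per-row function. This sketch assumes the 2-device mesh used just above. `row_total` and `per_row_totals` are illustrative names.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # only effective before JAX is first imported

from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
try:
    from jax import shard_map as shmap  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map as shmap  # older releases

compose_mesh = Mesh(np.array(jax.devices()[:2]), ('data_parallel_axis',))

def row_total(row):  # expects a single row: rank 1
    return jnp.sum(row)

@jax.jit
@partial(shmap, mesh=compose_mesh, in_specs=P('data_parallel_axis'), out_specs=P('data_parallel_axis'))
def per_row_totals(block):  # block keeps rank 2: (rows_per_device, 4)
    return jax.vmap(row_total)(block)  # vmap removes the row axis for row_total

x_compose = jnp.arange(8 * 4.).reshape(8, 4)
totals = per_row_totals(x_compose)
print(totals)  # [  6.  22.  38.  54.  70.  86. 102. 118.]
```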

## Fine-Tuning Data Distribution: `in_specs`

The `in_specs` argument to `shard_map` is crucial. It's a pytree (often a tuple) of `PartitionSpec` objects, one for each input argument to your mapped function. Each `PartitionSpec` (aliased as `P`) tells JAX how to split the corresponding input array's dimensions across the named axes of your `Mesh`.

* If an input array's dimension is mapped to a mesh axis in `P`, it's split along that array dimension by the size of that mesh axis.
* If an input array's dimension is `None` in `P`, that dimension is not split, and every function instance receives it in full.
* If a mesh axis is not mentioned anywhere in an input's `P`, that input is replicated across that mesh axis: function instances that differ only in their position along it receive identical blocks.
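Both rules can be checked without running any function. This sketch, assuming the 4x2 `wizard_mesh` layout and the 12x10 input used next, reports two values for a few specs: the per-instance block shape, and how many *distinct* blocks exist across the 8 devices. Fewer distinct blocks than devices means some instances hold replicated data. The helper `describe` is hypothetical.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # only effective before JAX is first imported

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

spec_mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('row_of_wizards', 'col_of_wizards'))
global_shape = (12, 10)

def describe(spec):
    """Hypothetical helper: (block shape per instance, number of distinct blocks)."""
    sharding = NamedSharding(spec_mesh, spec)
    # Normalize each device's slices to (start, stop) pairs so identical blocks compare equal.
    blocks = {tuple(s.indices(n)[:2] for s, n in zip(idx, global_shape))
              for idx in sharding.devices_indices_map(global_shape).values()}
    return sharding.shard_shape(global_shape), len(blocks)

for spec in [P('row_of_wizards', None), P(None, 'col_of_wizards'), P()]:
    print(spec, describe(spec))
# Expected: ((3, 10), 4), then ((12, 5), 2), then ((12, 10), 1)
```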

In [20]:
# Example data: a 12x10 array
input_array_in_specs = jnp.arange(12 * 10.).reshape(12, 10)
print(f"Original input_array_in_specs shape: {input_array_in_specs.shape}")

@partial(shard_map,
         mesh=wizard_mesh,
         in_specs=(P('row_of_wizards', None),), # Shard axis 0 over 'row_of_wizards'; axis 1 is not split, and the input is replicated over 'col_of_wizards'
         out_specs=P('row_of_wizards', 'col_of_wizards')) # out_specs just to make output shape clear
def process_with_in_specs(data_block):
  # data_block shape expected:
  # Axis 0 of input_array_in_specs (size 12) is sharded by 'row_of_wizards' (size 4) -> 12/4 = 3.
  # Axis 1 of input_array_in_specs (size 10) has 'None' in P, so it's not sharded by 'col_of_wizards'.
  # So, each of the 4 'row_of_wizards' groups gets a (3, 10) block.
  # Within each 'row_of_wizards' group, the 2 'col_of_wizards' instances get THE SAME (3,10) block.
  return data_block # Simply return the block for observation

print("\nRunning process_with_in_specs:")
output_from_in_specs = process_with_in_specs(input_array_in_specs)

# Output shape explanation:
# Each of the 8 devices returns a (3,10) block (though blocks are replicated along 'col_of_wizards').
# out_specs P('row_of_wizards', 'col_of_wizards') means:
# - Concatenate the 4 blocks from 'row_of_wizards' -> 4 * 3 = 12 for the first dimension.
# - Concatenate the 2 blocks from 'col_of_wizards' -> 2 * 10 = 20 for the second dimension.
# Resulting shape: (12, 20)
print(f"\nOutput shape from process_with_in_specs: {output_from_in_specs.shape}")

Original input_array_in_specs shape: (12, 10)

Running process_with_in_specs:

Output shape from process_with_in_specs: (12, 20)


## Assembling Results: `out_specs`

Similarly, `out_specs` dictates how the output blocks from each function instance are assembled into the final global array.

* If a mesh axis is named in an output `PartitionSpec` for a particular output array dimension, the blocks from along that mesh axis are concatenated to form that dimension of the global output.
* **Un-tiling**: If a mesh axis is *not* mentioned anywhere in an output `P`, you are promising `shard_map` that all function instances differing *only* along that mesh axis produce the *exact same output block*. `shard_map` then uses just one of these identical blocks, effectively "un-tiling" the result along that mesh axis. This is common after a collective like `psum` has made the outputs identical.
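The `psum` case looks like this in practice. In this sketch, assuming a 4x2 mesh, each device starts with a different 1x1 block. After `psum` over 'col_of_wizards', both wizards in a row hold the same value, so the output spec can omit that axis and keep one copy per row. `row_sums` and `untile_mesh` are illustrative names.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # only effective before JAX is first imported

from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
try:
    from jax import shard_map as shmap  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map as shmap  # older releases

untile_mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('row_of_wizards', 'col_of_wizards'))

@jax.jit
@partial(shmap, mesh=untile_mesh,
         in_specs=P('row_of_wizards', 'col_of_wizards'),
         out_specs=P('row_of_wizards', None))
def row_sums(block):
    # block is (1, 1) and different on every device.
    # After psum, both wizards in a row hold the same (1, 1) block, so out_specs
    # may leave 'col_of_wizards' out and keep just one copy.
    return jax.lax.psum(block, 'col_of_wizards')

grid_values = jnp.arange(8.).reshape(4, 2)
out_untiled = row_sums(grid_values)
print(out_untiled)  # rows sum to 1, 5, 9, 13; shape (4, 1)
```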

In [23]:
# Example 1: Full concatenation (tiling)
@partial(shard_map, mesh=wizard_mesh, in_specs=(), out_specs=P('row_of_wizards', 'col_of_wizards'))
def get_constant_block_tiled():
  # Each of the 8 devices returns a 1x1 array
  return jnp.array([[1.0]])

result_tiled = get_constant_block_tiled()
print(f"Result with out_specs=P('row_of_wizards', 'col_of_wizards'):\n{result_tiled}")
print(f"Shape: {result_tiled.shape}") # Expected (4*1, 2*1) = (4,2)

# Example 2: Un-tiling along 'col_of_wizards'
# This is safe because all devices return the same constant block.
@partial(shard_map, mesh=wizard_mesh, in_specs=(), out_specs=P('row_of_wizards', None))
def get_constant_block_untiled_col():
  return jnp.array([[2.0]]) # Each device returns a 1x1 array

result_untiled_col = get_constant_block_untiled_col()
print(f"\nResult with out_specs=P('row_of_wizards', None) (un-tiling 'col_of_wizards'):\n{result_untiled_col}")
print(f"Shape: {result_untiled_col.shape}") # Expected (4*1, 1) = (4,1)

# Example 3: Un-tiling along both axes
@partial(shard_map, mesh=wizard_mesh, in_specs=(), out_specs=P(None, None))
def get_constant_block_untiled_both():
  return jnp.array([[3.0]])

result_untiled_both = get_constant_block_untiled_both()
print(f"\nResult with out_specs=P(None, None) (un-tiling both axes):\n{result_untiled_both}")
print(f"Shape: {result_untiled_both.shape}") # Expected (1,1)

Result with out_specs=P('row_of_wizards', 'col_of_wizards'):
[[1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]]
Shape: (4, 2)

Result with out_specs=P('row_of_wizards', None) (un-tiling 'col_of_wizards'):
[[2.]
 [2.]
 [2.]
 [2.]]
Shape: (4, 1)

Result with out_specs=P(None, None) (un-tiling both axes):
[[3.]]
Shape: (1, 1)


It's important to understand that `out_specs` primarily defines how JAX *interprets* the existing data buffers on the devices to form a logical global array. It does not cause physical data movement of output shards between devices.

Conversely, `in_specs` *can* imply data movement if the input arrays passed to the `shard_map`-transformed function are not already sharded according to the `in_specs` and `Mesh`. JAX will handle reshuffling the data to match your specifications.
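To avoid that reshuffle, you can place inputs in the `in_specs` layout ahead of time. This sketch assumes the 4x2 mesh. It puts the input in that layout with `jax.device_put`, so `shard_map` should have nothing to move on the way in, and then reads back the output layout, which follows `out_specs`. `double` and `layout_mesh` are illustrative names.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # only effective before JAX is first imported

from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
try:
    from jax import shard_map as shmap  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map as shmap  # older releases

layout_mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('row_of_wizards', 'col_of_wizards'))
in_spec = P('row_of_wizards', None)
out_spec = P('row_of_wizards', None)

@jax.jit
@partial(shmap, mesh=layout_mesh, in_specs=in_spec, out_specs=out_spec)
def double(block):
    return 2 * block

x_host = jnp.arange(12 * 10.).reshape(12, 10)
# Lay x out exactly as in_specs requests before the call.
x_placed = jax.device_put(x_host, NamedSharding(layout_mesh, in_spec))
y_doubled = double(x_placed)
print(x_placed.sharding.spec, y_doubled.sharding)
```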

## Tracking Data Variance: Varying Manual Axes (VMA)

Inside a `shard_map`, JAX can track whether a value is the same or different across the function instances running on your mesh. This is called "Varying Manual Axes" (VMA) analysis. It's enabled by default with `check_vma=True`.

You can inspect the VMA of an array `x` within the mapped function using `jax.typeof(x).vma`. This will show a set of mesh axis names over which `x` is currently varying.
* If an input is sharded along a mesh axis (e.g., `'my_axis'`), its VMA will include `{'my_axis'}`.
* If a collective operation like `jax.lax.psum` sums values across `'my_axis'`, the result of the `psum` will typically be unvarying over `'my_axis'`, so its VMA will not include `'my_axis'`.
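Without inspecting types at all, you can see that a `psum` result is unvarying by tiling it back out. This sketch assumes the same 2-device 'batch_dim' setup as the next cell. With `out_specs=P('batch_dim')`, both devices' copies appear side by side, and they match. `show_both_copies` is an illustrative name.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # only effective before JAX is first imported

from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
try:
    from jax import shard_map as shmap  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map as shmap  # older releases

vma_demo_mesh = Mesh(np.array(jax.devices()[:2]), ('batch_dim',))

@jax.jit
@partial(shmap, mesh=vma_demo_mesh, in_specs=P('batch_dim'), out_specs=P('batch_dim'))
def show_both_copies(x_block):
    # x_block differs per device: [0, 1, 2] on one, [3, 4, 5] on the other.
    # After psum both devices hold [3, 5, 7]; out_specs tiles the two copies side by side.
    return jax.lax.psum(x_block, 'batch_dim')

tiled = show_both_copies(jnp.arange(6.))
print(tiled)  # [3. 5. 7. 3. 5. 7.]
```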

In [25]:
# Using a simple 1D mesh of 2 devices for VMA demonstration
vma_mesh = Mesh(np.array(jax.devices()[:2]), ('batch_dim',))

# Data to be sharded
vma_data = jnp.arange(6.) # Shape (6,)

@partial(shard_map, mesh=vma_mesh,
         in_specs=P('batch_dim'), # Shard data along 'batch_dim'
         out_specs=P())          # Expect a fully replicated (unvarying) output
def check_vma_effects(x_block):
  # x_block will be (3,) per device, and it varies along 'batch_dim'
  print(f"  Inside shard_map: x_block type: {jax.typeof(x_block)}")
  print(f"  VMA of x_block: {jax.typeof(x_block).vma}")

  # Sum x_block across all devices in the 'batch_dim' mesh axis
  y_summed = jax.lax.psum(x_block, 'batch_dim')
  # y_summed will be the same on both devices, so it's unvarying over 'batch_dim'
  print(f"  Inside shard_map: y_summed type: {jax.typeof(y_summed)}")
  print(f"  VMA of y_summed: {jax.typeof(y_summed).vma}")

  return y_summed

print("Running check_vma_effects:")
result_vma_check = check_vma_effects(vma_data)
print(f"\nFinal result_vma_check: {result_vma_check}")
print(f"Final result_vma_check type: {jax.typeof(result_vma_check)}") # VMA should be empty globally

Running check_vma_effects:
  Inside shard_map: x_block type: ShapedArray(float32[3]{batch_dim})
  VMA of x_block: frozenset({'batch_dim'})
  Inside shard_map: y_summed type: ShapedArray(float32[3])
  VMA of y_summed: frozenset()

Final result_vma_check: [3. 5. 7.]
Final result_vma_check type: ShapedArray(float32[3])


The `check_vma=True` setting is very useful because it lets `shard_map` statically verify your `out_specs`. If your `out_specs` implies that an output is replicated along a mesh axis (i.e., that mesh axis does not appear in `out_specs`, so nothing is concatenated along it), but the VMA analysis shows the output actually varies along that axis, `shard_map` raises an error. This helps catch subtle bugs in your sharding logic.

In [26]:
# Example of check_vma catching an out_specs error
@partial(shard_map, mesh=vma_mesh,
         in_specs=P('batch_dim'),
         out_specs=P()) # P() implies the output is replicated across ALL mesh axes.
def incorrect_out_specs_example(x_block):
  # x_block IS varying over 'batch_dim' because of in_specs=P('batch_dim').
  # Returning it directly contradicts out_specs=P() which expects an unvarying result.
  return x_block

print("Attempting to run incorrect_out_specs_example (expect an error):")
try:
  incorrect_out_specs_example(vma_data)
except Exception as e:
  print(f"\nSuccessfully caught an error as expected:\n-----\n{e}\n-----")
  print("This error indicates that out_specs requires replication over 'batch_dim', but the output value is varying over 'batch_dim'.")

Attempting to run incorrect_out_specs_example (expect an error):

Successfully caught an error as expected:
-----
shard_map applied to the function 'incorrect_out_specs_example' was given out_specs which require replication which can't be statically inferred given the mesh:

The mesh given has shape (2,) with corresponding axis names ('batch_dim',).

out_specs is PartitionSpec() which implies that the corresponding output value is replicated across mesh axis 'batch_dim', but could not infer replication over any axes

Check if these output values are meant to be replicated over those mesh axes. If not, consider revising the corresponding out_specs entries. If so, consider disabling the check by passing the check_rep=False argument to shard_map.
-----
This error indicates that out_specs requires replication over 'batch_dim', but the output value is varying over 'batch_dim'.


Sometimes, a value might be unvarying over a mesh axis, but you need to treat it as varying, for instance, when combining it with another value that *is* varying, or when passing it as a carry in `jax.lax.scan`. For this, JAX provides `jax.lax.pvary`.

`jax.lax.pvary(x, 'axis_name')` tells JAX to consider `x` as varying along `'axis_name'`, even if it was previously unvarying. It's primarily a type system hint and often a no-op at runtime, but it becomes important for correctness in automatic differentiation (where its transpose is `psum`). JAX often inserts `pvary` implicitly for binary operations to make VMA types match.

In [27]:
unvarying_data_for_pvary = jnp.arange(3.) # This will be replicated

@partial(shard_map, mesh=vma_mesh,
         in_specs=P(), # P() means x_unvarying is NOT sharded, so it's unvarying over 'batch_dim'
         out_specs=P('batch_dim')) # Expect the output to be sharded/varying along 'batch_dim'
def pvary_example_func(x_unvarying):
  print(f"  Inside shard_map, x_unvarying type: {jax.typeof(x_unvarying)}")
  print(f"  VMA of x_unvarying: {jax.typeof(x_unvarying).vma}")

  # Now, explicitly mark it as varying over 'batch_dim'
  y_now_varying = jax.lax.pvary(x_unvarying, 'batch_dim')
  print(f"  Inside shard_map, y_now_varying type: {jax.typeof(y_now_varying)}")
  print(f"  VMA of y_now_varying: {jax.typeof(y_now_varying).vma}")

  # The out_specs=P('batch_dim') is consistent with y_now_varying's VMA.
  return y_now_varying

print("Running pvary_example_func:")
result_pvary = pvary_example_func(unvarying_data_for_pvary)
print(f"\nFinal result_pvary: {result_pvary}")
print(f"Shape of result_pvary: {result_pvary.shape}") # Should be (2 devices * 3 elements) = (6,)

Running pvary_example_func:
  Inside shard_map, x_unvarying type: ShapedArray(float32[3])
  VMA of x_unvarying: frozenset()
  Inside shard_map, y_now_varying type: ShapedArray(float32[3]{batch_dim})
  VMA of y_now_varying: frozenset({'batch_dim'})

Final result_pvary: [0. 1. 2. 0. 1. 2.]
Shape of result_pvary: (6,)


JAX automatically inserts `pvary` for simple binary operations (like `+`, `*`) if one operand is varying and the other is not, to make their VMA types match.
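A small sketch of that implicit behavior, assuming at least 2 devices and reusing the `'batch_dim'` layout from the examples above (`add_blocks` is an illustrative name). No explicit `pvary` is written, yet the sum counts as varying, so a tiled `out_specs` is accepted while `out_specs=P()` is rejected by the default check.

```python
import os
# Assumption: simulate 8 CPU devices, as at the start of this tutorial.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # top-level API in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

mesh = Mesh(np.array(jax.devices()[:2]), ("batch_dim",))

def add_blocks(x_block, y_rep):
    # x_block varies over 'batch_dim'; y_rep is identical on every instance.
    # No explicit pvary is written here: the sum is treated as varying over 'batch_dim'.
    return x_block + y_rep

x = jnp.arange(6.0)                # sharded: a (3,) block per instance
y = jnp.array([10.0, 20.0, 30.0])  # replicated: the same (3,) array on every instance
in_specs = (P("batch_dim"), P())

# A tiled out_specs agrees with the varying result and concatenates the blocks.
tiled_out = shard_map(add_blocks, mesh=mesh, in_specs=in_specs,
                      out_specs=P("batch_dim"))(x, y)
print(tiled_out)

# A replicated out_specs contradicts the varying result, so the check rejects it.
try:
    shard_map(add_blocks, mesh=mesh, in_specs=in_specs, out_specs=P())(x, y)
    rejected = False
except Exception as err:
    rejected = True
    print("Rejected as expected:", type(err).__name__)
```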

However, for more complex operations like `jax.lax.scan`, you might need to use `jax.lax.pvary` explicitly. `jax.lax.scan` requires that the VMA type of a carry variable at the end of the loop body matches its VMA type at the beginning. If they don't match, JAX will raise an error.

In [29]:
# scan example that demonstrates a VMA type mismatch error
scan_mesh = vma_mesh # mesh_axis_names = ('batch_dim',)

# x is varying, y is unvarying
scan_x_varying = jnp.arange(6.) # Sharded to (3,) per device
scan_y_unvarying = jnp.arange(10., 13.) # Replicated (3,) on each device

@partial(shard_map, mesh=scan_mesh,
         in_specs=(P('batch_dim'), P()), # x is varying, y is not
         out_specs=(P('batch_dim'), P('batch_dim')))
def scan_vma_problem_func(x_vary_block, y_unvary_block):
  # x_vary_block has VMA {'batch_dim'}
  # y_unvary_block has VMA {}

  def loop_body(carry, _unused_input):
    carry_x, carry_y = carry
    # Attempt to swap them:
    # new_carry_x comes from carry_y (VMA {})
    # new_carry_y comes from carry_x (VMA {'batch_dim'})
    # This creates a VMA mismatch for scan's carry consistency rule.
    return (carry_y, carry_x), ()

  # Initial carry: (x_vary_block, y_unvary_block)
  # Expected carry after 1st iter (if it worked): (y_unvary_block, x_vary_block)
  # Problem:
  # - 1st carry element: input VMA {'batch_dim'}, output VMA {} -> Mismatch!
  # - 2nd carry element: input VMA {}, output VMA {'batch_dim'} -> Mismatch!
  try:
    (final_x, final_y), _ = jax.lax.scan(loop_body, (x_vary_block, y_unvary_block), (), length=2)
    return final_x, final_y
  except Exception as e:
    print(f"--- Successfully caught VMA error in scan (as expected) ---\n{e}\n--- End of error ---")
    # Return placeholders to allow the notebook to continue if run interactively
    return jnp.array([-1.0]), jnp.array([-1.0])


print("Running scan_vma_problem_func (expecting an error message):")
_ = scan_vma_problem_func(scan_x_varying, scan_y_unvarying)

Running scan_vma_problem_func (expecting an error message):
--- Successfully caught VMA error in scan (as expected) ---
scan body function carry input and carry output must have equal types, but they differ:

  * the input carry component carry[0] has type float32[3]{batch_dim} but the corresponding output carry component has type float32[3], so the varying manual axes do not match;
  * the input carry component carry[1] has type float32[3] but the corresponding output carry component has type float32[3]{batch_dim}, so the varying manual axes do not match.

This might be fixed by applying `jax.lax.pvary(..., ('batch_dim',))` to the initial carry value corresponding to the input carry component carry[1].
See https://docs.jax.dev/en/latest/notebooks/shard_map.html#scan-vma for more information.

Revise the function so that all output types match the corresponding input types.
--- End of error ---


To fix this, we can use `jax.lax.pvary` on the initially unvarying part of the carry (`y_unvary_block` in this case) to tell `scan` (and JAX's type system) to treat it as if it varies along the `'batch_dim'` mesh axis. This makes its VMA type consistent with `x_vary_block` for the purpose of the scan operation.

In [30]:
@partial(shard_map, mesh=scan_mesh,
         in_specs=(P('batch_dim'), P()), # x is varying, y is not initially
         out_specs=(P('batch_dim'), P('batch_dim')))
def scan_vma_fixed_func(x_vary_block, y_unvary_block):
  # x_vary_block has VMA {'batch_dim'}
  # y_unvary_block has VMA {}

  # FIX: Explicitly mark y_unvary_block as varying over 'batch_dim' for scan
  y_treated_as_varying = jax.lax.pvary(y_unvary_block, 'batch_dim')
  # Now, y_treated_as_varying has VMA {'batch_dim'}

  def loop_body(carry, _unused_input):
    carry_x, carry_y = carry # Both now have VMA {'batch_dim'}
    return (carry_y, carry_x), () # Swapping is fine, VMA types remain consistent

  # Initial carry: (x_vary_block, y_treated_as_varying)
  # Both elements of the carry tuple now have VMA {'batch_dim'}
  (final_x, final_y), _ = jax.lax.scan(loop_body, (x_vary_block, y_treated_as_varying), (), length=2)
  return final_x, final_y

print("Running scan_vma_fixed_func:")
fixed_res_x, fixed_res_y = scan_vma_fixed_func(scan_x_varying, scan_y_unvarying)

print(f"\nResult from fixed scan: final_x shape {fixed_res_x.shape}, final_y shape {fixed_res_y.shape}")
# Each iteration swaps the two carry elements, so after 2 iterations they are back in their original order.
# Example: x_vary_block is [0,1,2] on dev0 and [3,4,5] on dev1;
# y_treated_as_varying is [10,11,12] on both devices (same values, but typed as varying over 'batch_dim').
# After iter 1: carry = (y_treated_as_varying, x_vary_block)
# After iter 2: carry = (x_vary_block, y_treated_as_varying)
# So final_x is x_vary_block and final_y is y_treated_as_varying:
# fixed_res_x matches scan_x_varying, and fixed_res_y is the y values tiled once per device by out_specs.
print(f"fixed_res_x (should correspond to original x values):\n{fixed_res_x}")
print(f"fixed_res_y (should correspond to original y values, now sharded):\n{fixed_res_y}")

Running scan_vma_fixed_func:

Result from fixed scan: final_x shape (6,), final_y shape (6,)
fixed_res_x (should correspond to original x values):
[0. 1. 2. 3. 4. 5.]
fixed_res_y (should correspond to original y values, now sharded):
[10. 11. 12. 10. 11. 12.]


We've seen how VMA helps ensure correctness. Now, let's explore how different function instances in `shard_map` actually communicate and coordinate. This is done using **Collective Operations**. These are special functions within `jax.lax` that operate across devices in a mesh, using the mesh axis names you've defined.

## Collective Operation: `jax.lax.psum` (Parallel Sum)

The first collective we'll look at (and which we've already used) is `jax.lax.psum(x, axis_name)`.
This computes an all-reduce sum of `x` across all devices along the specified `axis_name` (or a tuple of axis names).
* Each device contributes its local value of `x`.
* All devices participating in the sum receive the *same* total sum.
* Because the result is identical on all devices along the summed `axis_name`, you can leave that mesh axis out of `out_specs` (an untiled output): e.g., `P()` or `P(None)` if it was the only mesh axis, or `P('other_axis', None)` if you summed over one axis of a 2D mesh.

In [33]:
# psum example with a 1D mesh
# Let's use 4 devices for this
psum_mesh_1d = Mesh(np.array(jax.devices()[:4]), ('device_axis',))

# Each device will have a small array
data_for_psum = jnp.arange(4. * 2.).reshape(4, 2) # Global data: 4 devices, each gets a (1,2) block

@partial(shard_map, mesh=psum_mesh_1d,
         in_specs=P('device_axis'),    # Shard along 'device_axis', each device gets a (1,2) block
         out_specs=P(None))          # Output is replicated (un-tiled)
def psum_example_1d(x_block):
  # x_block is (1,2) on each of the 4 devices
  # Sum x_block across all devices along 'device_axis'
  summed_block = jax.lax.psum(x_block, 'device_axis')
  # summed_block will be identical on all 4 devices.
  # Its VMA over 'device_axis' will be empty.
  print(f"AFTER psum VMA: {jax.typeof(summed_block).vma}")

  return summed_block

print("Running psum_example_1d:")
result_psum_1d = psum_example_1d(data_for_psum)

print(f"\nFinal result_psum_1d: {result_psum_1d}")
print(f"Shape of final result: {result_psum_1d.shape}") # Expected (1,2) because of P(None) out_spec

# Manually calculate what we expect:
# Device 0 block: [[0., 1.]]
# Device 1 block: [[2., 3.]]
# Device 2 block: [[4., 5.]]
# Device 3 block: [[6., 7.]]
# Sum: [[0+2+4+6, 1+3+5+7]] = [[12., 16.]]
expected_sum = jnp.sum(data_for_psum.reshape(4,1,2), axis=0) # Reshape to be (num_devices, block_shape)
print(f"Expected sum (manual calculation): {expected_sum}")
assert allclose(result_psum_1d, expected_sum)

Running psum_example_1d:
AFTER psum VMA: frozenset()

Final result_psum_1d: [[12. 16.]]
Shape of final result: (1, 2)
Expected sum (manual calculation): [[12. 16.]]


In the example above, `out_specs=P(None)` was used. Since `psum` over `'device_axis'` ensures every device in that mesh axis gets the same summed result, we don't need to concatenate these identical results. `P(None)` tells `shard_map` to just pick one copy. If it were `P('device_axis')`, the (1,2) result from each of the 4 devices would be concatenated, yielding a (4,2) array where each row is identical.

`psum` can also operate over multiple mesh axes if you have a multi-dimensional mesh.
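A sketch of `psum` on a 2D mesh, assuming at least 4 devices arranged as a 2 x 2 mesh with illustrative axis names `'rows'` and `'cols'`. Summing over one axis leaves the result replicated only along that axis, so only that axis is dropped from `out_specs`; passing a tuple of axis names sums over all of them, and the result can then use `out_specs=P()`.

```python
import os
# Assumption: simulate 8 CPU devices, as at the start of this tutorial.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # top-level API in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# 2 x 2 mesh built from the first 4 devices.
mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ("rows", "cols"))
a = jnp.arange(16.0).reshape(4, 4)  # each instance gets a (2, 2) block

def sum_over_cols(block):
    # Identical across 'cols' afterwards, but still different across 'rows'.
    return jax.lax.psum(block, "cols")

def sum_over_both(block):
    # A tuple of axis names reduces over all of them in one collective.
    return jax.lax.psum(block, ("rows", "cols"))

row_sums = shard_map(sum_over_cols, mesh=mesh2d, in_specs=P("rows", "cols"),
                     out_specs=P("rows", None))(a)  # 'cols' omitted: replicated there
total = shard_map(sum_over_both, mesh=mesh2d, in_specs=P("rows", "cols"),
                  out_specs=P())(a)                # replicated over both axes
print(row_sums.shape, total.shape)  # (4, 2) (2, 2)
```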

## Collective Operation: `jax.lax.all_gather`

Another fundamental collective is `jax.lax.all_gather(x, axis_name, tiled=...)`. This operation gathers the local arrays `x` from all devices along the specified `axis_name` and makes the combined data available to each of those devices.

* `axis_name`: The mesh axis (or axes) along which to gather.
* `tiled=True`: The gathered blocks are concatenated along an existing axis of `x` (selected by the `axis` argument, default `0`). The size of this axis in the output is its original size multiplied by the number of devices along `axis_name`.
* `tiled=False` (default): The gathered blocks are stacked along a *new* axis, effectively increasing the rank of the array. The size of this new axis will be the number of devices in `axis_name`.

In [43]:
# all_gather example with a 1D mesh and tiled=True
# Let's use 4 devices again
all_gather_mesh_1d = Mesh(np.array(jax.devices()[:4]), ('gather_axis',))

# Each device gets a small, unique slice of this array
data_for_all_gather = jnp.arange(4, dtype=jnp.float32) # Global data: [0., 1., 2., 3.]
# With in_specs=P('gather_axis'), device 0 gets [0.], device 1 gets [1.], etc. (shape (1,) each)

def all_gather_example_tiled(x_block):
  # x_block has shape (1,) on each device
  print(f"  BEFORE all_gather: x_block = {x_block} (shape {x_block.shape})")

  # Gather x_block from all devices along 'gather_axis'.
  # tiled=True concatenates along axis 0: four (1,) blocks -> (4,) on each device.
  # (If x_block were (2,), the result would be (2*4,) = (8,).)
  gathered_data = jax.lax.all_gather(x_block, 'gather_axis', tiled=True)

  print(f"  AFTER all_gather: gathered_data = {gathered_data} (shape {gathered_data.shape})")
  # Each device now holds the full array [0., 1., 2., 3.]
  return gathered_data

print("Running all_gather_example_tiled (tiled=True):")
result_all_gather_tiled = shard_map(
    all_gather_example_tiled,
    mesh=all_gather_mesh_1d,
    in_specs=P('gather_axis'),  # Each device gets a (1,) slice of data_for_all_gather
    out_specs=P('gather_axis'), # Concatenates the (4,) gathered array from every device
)(data_for_all_gather)

print(f"\nFinal result_all_gather_tiled: {result_all_gather_tiled}")
print(f"Shape of final result: {result_all_gather_tiled.shape}") # Expected (16,): four copies of [0., 1., 2., 3.]


# To get the gathered array once (replicated) instead of four concatenated copies,
# use out_specs=P().
print("Running all_gather_example_tiled_replicated_out (tiled=True):")
result_all_gather_tiled = shard_map(
    all_gather_example_tiled,
    mesh=all_gather_mesh_1d,
    in_specs=P('gather_axis'),
    out_specs=P(), # Output is the full gathered array, replicated.
    # all_gather's output is typed as varying, so the replication check must be
    # disabled (check_vma=False in newer JAX); jax.lax.all_gather_invariant would
    # avoid this, but it has no public API yet.
    check_rep=False,
)(data_for_all_gather)
print(f"\nFinal result_all_gather_tiled: {result_all_gather_tiled}")
print(f"Shape of final result: {result_all_gather_tiled.shape}") # Expected (4,)

# All devices should have [0., 1., 2., 3.] after all_gather.
# With out_specs=P(), the global result is this replicated array.
assert allclose(result_all_gather_tiled, data_for_all_gather)

Running all_gather_example_tiled (tiled=True):
  BEFORE all_gather: x_block = On TFRT_CPU_0 at mesh coordinates (gather_axis,) = (0,):
[0.]

On TFRT_CPU_1 at mesh coordinates (gather_axis,) = (1,):
[1.]

On TFRT_CPU_2 at mesh coordinates (gather_axis,) = (2,):
[2.]

On TFRT_CPU_3 at mesh coordinates (gather_axis,) = (3,):
[3.]
 (shape (1,))
  AFTER all_gather: gathered_data = On TFRT_CPU_0 at mesh coordinates (gather_axis,) = (0,):
[0. 1. 2. 3.]

On TFRT_CPU_1 at mesh coordinates (gather_axis,) = (1,):
[0. 1. 2. 3.]

On TFRT_CPU_2 at mesh coordinates (gather_axis,) = (2,):
[0. 1. 2. 3.]

On TFRT_CPU_3 at mesh coordinates (gather_axis,) = (3,):
[0. 1. 2. 3.]
 (shape (4,))

Final result_all_gather_tiled: [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
Shape of final result: (16,)
Running all_gather_example_tiled_replicated_out (tiled=True):
  BEFORE all_gather: x_block = On TFRT_CPU_0 at mesh coordinates (gather_axis,) = (0,):
[0.]

On TFRT_CPU_1 at mesh coordinates (gather_axis,) = (1

**VMA and `out_specs` for `all_gather`:**

The output of `jax.lax.all_gather` is typed as **varying** over the `axis_name` it gathered along, even though every device ends up holding the same values. This is tied to how its transpose (which involves `psum_scatter`) is handled in JAX's automatic differentiation system.

Therefore, if you use `check_vma=True` (the default):
* If you return the direct result of `all_gather` and your `out_specs` implies replication along the gathered axis (e.g., `P()` when gathering along the only mesh axis), expect a VMA error.
* One option that passes the check is `out_specs=P('gather_axis')`: the final global array concatenates the gathered arrays from every device (identical in value, but typed as varying), so the gathered data is repeated once per device.
* If you truly need an output typed as invariant, `jax.lax.all_gather_invariant` is the intended tool (though it has no public user API yet). Alternatively, treat the `all_gather` as an intermediate step and let a subsequent operation (such as a `psum`) make the final returned value invariant.

For simplicity, some examples in this tutorial pair `out_specs=P()` with `check_rep=False` after an `all_gather` when the goal is just to show that each device now holds the full data. In more complex scenarios, think carefully about VMA and `out_specs` instead of disabling the check.

In [50]:
@partial(shard_map, mesh=all_gather_mesh_1d,
         in_specs=P('gather_axis'),   # Each device gets a (1,) slice
         out_specs=P(),               # Output is the full gathered array, replicated
         check_rep=False)             # all_gather output is typed as varying, so disable the check
def all_gather_example_stacked(x_block):
  # x_block has shape (1,) on each device

  # Gather x_block from all devices along 'gather_axis'.
  # tiled=False stacks the blocks along a new leading axis of size 4 (the number of devices).
  # Input (1,) -> result (4, 1) on each device; in general, (S,) -> (4, S).
  gathered_data_stacked = jax.lax.all_gather(x_block, 'gather_axis', tiled=False)
  print(f"  AFTER all_gather: gathered_data = {gathered_data_stacked} (shape {gathered_data_stacked.shape})")
  # Each device now has [[0.], [1.], [2.], [3.]]

  return gathered_data_stacked

print("Running all_gather_example_stacked (tiled=False):")
result_all_gather_stacked = all_gather_example_stacked(data_for_all_gather)

print(f"\nFinal result_all_gather_stacked: {result_all_gather_stacked}")
print(f"Shape of final result: {result_all_gather_stacked.shape}") # Expected (4, 1): four (1,) slices stacked along a new axis

print(f"Shape {data_for_all_gather.shape=}")

Running all_gather_example_stacked (tiled=False):
  AFTER all_gather: gathered_data = On TFRT_CPU_0 at mesh coordinates (gather_axis,) = (0,):
[[0.]
 [1.]
 [2.]
 [3.]]

On TFRT_CPU_1 at mesh coordinates (gather_axis,) = (1,):
[[0.]
 [1.]
 [2.]
 [3.]]

On TFRT_CPU_2 at mesh coordinates (gather_axis,) = (2,):
[[0.]
 [1.]
 [2.]
 [3.]]

On TFRT_CPU_3 at mesh coordinates (gather_axis,) = (3,):
[[0.]
 [1.]
 [2.]
 [3.]]
 (shape (4, 1))

Final result_all_gather_stacked: [[0.]
 [1.]
 [2.]
 [3.]]
Shape of final result: (4, 1)
Shape data_for_all_gather.shape=(4,)


`jax.lax.all_gather` is particularly useful in scenarios like:
* **Fully Sharded Data Parallelism (FSDP):** When model parameters are sharded across devices, but a full parameter tensor is needed for a local computation (e.g., a matrix multiplication). The sharded parameter can be all-gathered before the operation.
* When an intermediate activation is sharded, but the next layer requires access to the full activation on each device.

In both cases, it lets each device obtain a complete copy of data that was previously distributed across a mesh axis.
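A toy version of the FSDP pattern, assuming 4 devices and an illustrative axis name `'fsdp'`: the batch and the weight rows are both sharded, and each instance all-gathers the weight right before its local matmul. This only illustrates the communication pattern; a real FSDP setup also has to handle gradients and optimizer state, which are not shown here.

```python
import os
# Assumption: simulate 8 CPU devices, as at the start of this tutorial.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # top-level API in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

mesh = Mesh(np.array(jax.devices()[:4]), ("fsdp",))

def fsdp_style_matmul(x_block, w_block):
    # w_block: a (2, 3) row-shard of the (8, 3) weight; each instance stores only this.
    # Reassemble the full weight just before it is needed.
    w_full = jax.lax.all_gather(w_block, "fsdp", tiled=True)  # (8, 3)
    # x_block: this instance's (2, 8) slice of the batch.
    return x_block @ w_full                                   # (2, 3)

x = jnp.arange(64.0).reshape(8, 8) / 10.0  # batch of 8 examples, 8 features
w = jnp.arange(24.0).reshape(8, 3) / 10.0  # weight, sharded by rows

y = shard_map(fsdp_style_matmul, mesh=mesh,
              in_specs=(P("fsdp"), P("fsdp")),  # shard batch rows and weight rows
              out_specs=P("fsdp"))(x, w)        # concatenate per-instance outputs
print(y.shape)  # (8, 3)
```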

## Collective Operation: `jax.lax.psum_scatter`

Next is `jax.lax.psum_scatter(x, axis_name, scatter_dimension=0, tiled=False)`.
This collective is like `psum` in that it computes a sum of `x` across devices along `axis_name`. However, instead of every device receiving the *full* sum, each device receives only a *shard* of that total sum. The `scatter_dimension` argument specifies which axis of `x` (the local array) is used for scattering the summed result.

* `tiled=False` (default): The size of `x` along `scatter_dimension` must equal the size of `axis_name` (or of the device group, if `axis_index_groups` is used). This dimension is consumed by the scatter, so the output on each device has one less dimension than `x`. Device `i` along `axis_name` gets the `i`-th slice of the sum.
* `tiled=True`: The size of `x` along `scatter_dimension` must be divisible by the size of `axis_name`. The output on each device has the same rank as `x`, but its `scatter_dimension` is divided by the size of `axis_name`. This is often more intuitive when you think of the input `x` itself as already being a "shard" of a larger conceptual array that you want to sum and then re-shard.
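To make the two `tiled` modes concrete, here is a sketch (assuming 4 devices; the axis name `'s'` is illustrative) that scatters the same sum both ways over `scatter_dimension=0`. The values are the same; only the per-instance shapes differ: `(2,)` without tiling, `(1, 2)` with it.

```python
import os
# Assumption: simulate 8 CPU devices, as at the start of this tutorial.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # top-level API in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

mesh = Mesh(np.array(jax.devices()[:4]), ("s",))
data = jnp.arange(32.0).reshape(16, 2)  # each instance gets a (4, 2) block

def scatter_untiled(block):
    # tiled=False: block.shape[0] (4) must equal the axis size (4).
    # That dimension is consumed, leaving a (2,) piece of the sum per instance.
    return jax.lax.psum_scatter(block, "s", scatter_dimension=0, tiled=False)

def scatter_tiled(block):
    # tiled=True: block.shape[0] (4) must be divisible by the axis size (4).
    # The rank is kept, leaving a (1, 2) piece of the sum per instance.
    return jax.lax.psum_scatter(block, "s", scatter_dimension=0, tiled=True)

untiled = shard_map(scatter_untiled, mesh=mesh, in_specs=P("s"), out_specs=P("s"))(data)
tiled = shard_map(scatter_tiled, mesh=mesh, in_specs=P("s"), out_specs=P("s"))(data)
print(untiled.shape, tiled.shape)  # (8,) (4, 2)
```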

In [55]:
# psum_scatter example with a 1D mesh and tiled=True
# Let's use 4 devices
psum_scatter_mesh_1d = Mesh(np.array(jax.devices()[:4]), ('scatter_axis',))

# Global data: an array of shape (4, 4).
data_for_pscatter = jnp.arange(16, dtype=jnp.float32).reshape(4, 4)
# With in_specs=P('scatter_axis'), device 0 gets [[0,1,2,3]], device 1 gets [[4,5,6,7]], etc.
# Each device therefore holds a (1, 4) block; psum_scatter sums these blocks
# element-wise across devices and then scatters the (1, 4) total.

@partial(shard_map, mesh=psum_scatter_mesh_1d,
         in_specs=P('scatter_axis'),    # Each device gets a (1,4) row from data_for_pscatter
         out_specs=P('scatter_axis')) # Concatenate the scattered (1,1) results along dimension 0
def psum_scatter_example_tiled(x_block):
  # x_block is (1, 4) on each of the 4 devices.
  print(f"  BEFORE psum_scatter: x_block = {x_block.shape}")

  # The element-wise sum across 'scatter_axis' is the (1,4) array [[24, 28, 32, 36]]
  # (e.g., first element: 0+4+8+12 = 24).
  # With tiled=True and scatter_dimension=1, dimension 1 is split into 4 pieces of size 1:
  # device 0 gets [[24]], device 1 gets [[28]], etc.
  scattered_sum_block = jax.lax.psum_scatter(x_block, 'scatter_axis', scatter_dimension=1, tiled=True)

  print(f"  AFTER psum_scatter: scattered_sum_block = {scattered_sum_block.shape}")
  return scattered_sum_block

print("Running psum_scatter_example_tiled (tiled=True):")
result_pscatter_tiled = psum_scatter_example_tiled(data_for_pscatter)

print(f"\nFinal result_pscatter_tiled: {result_pscatter_tiled}")
# Each device produces a (1,1) block. out_specs=P('scatter_axis') concatenates them.
# So, the global output is (4,1).
print(f"Shape of final result: {result_pscatter_tiled.shape}")

# Manually calculate the expected result:
# Device inputs: [[0,1,2,3]], [[4,5,6,7]], [[8,9,10,11]], [[12,13,14,15]]
# Element-wise sum over devices: [[0+4+8+12, 1+5+9+13, 2+6+10+14, 3+7+11+15]] = [[24, 28, 32, 36]]
# With tiled=True, the size-4 scatter_dimension=1 is split into 4 pieces of size 1,
# and device k gets the k-th piece: device 0 gets [[24]], device 1 gets [[28]], etc.
expected_pscatter_result = jnp.array([24., 28., 32., 36.], dtype=jnp.float32)
assert allclose(result_pscatter_tiled.flatten(), expected_pscatter_result)

Running psum_scatter_example_tiled (tiled=True):
  BEFORE psum_scatter: x_block = (1, 4)
  AFTER psum_scatter: scattered_sum_block = (1,)

Final result_pscatter_tiled: [24. 28. 32. 36.]
Shape of final result: (4,)


**Relationship between `psum`, `psum_scatter`, and `all_gather`**

Interestingly, a full `psum` (where every device gets the total sum) can be thought of as a `psum_scatter` (where each device gets a piece of the sum) followed by an `all_gather` (where each device collects all the pieces to reconstruct the full sum):

`psum(x, axis) == all_gather(psum_scatter(x, axis, scatter_dimension=0, tiled=True), axis, axis=0, tiled=True)`

For the shapes to line up exactly, both calls must use the same `tiled` setting, and `all_gather` must reassemble the dimension that `psum_scatter` split (via its `axis` argument). With `tiled=True`, `x.shape[0]` must be divisible by the axis size; with `tiled=False` on both calls, it must equal the axis size, because the scattered dimension is removed and then re-created as a new axis. This decomposition is often how `psum` is implemented efficiently in practice, especially on TPUs, and `psum_scatter` on its own does about half the communication of a full `psum`.

In [69]:
@partial(shard_map, mesh=psum_scatter_mesh_1d,
         in_specs=P('scatter_axis'),
         out_specs=P(),
         check_rep=False) # The full sum is replicated, but all_gather's output is typed as varying, so disable the check
def psum_via_scatter_gather(x_block):
  # x_block is (1, 4) on each device

  # 1. First half: Reduce-Scatter. Each device gets a part of the sum.
  # Total sum is [24, 28, 32, 36].
  # scattered_piece on dev 0 is [24.], dev 1 is [28.], etc. Shape (1,) on each device.
  scattered_piece = jax.lax.psum_scatter(x_block, 'scatter_axis', scatter_dimension=1, tiled=False)

  # 2. Second half: All-Gather. Each device gathers all scattered pieces.
  # Input to all_gather is (1,) on each device.
  # Output will be (4,) on each device, containing [24., 28., 32., 36.].
  full_sum_replicated = jax.lax.all_gather(scattered_piece, 'scatter_axis', tiled=True)

  return full_sum_replicated

print("Running psum_via_scatter_gather:")
result_composed_psum = psum_via_scatter_gather(data_for_pscatter)
print(f"Result from psum_via_scatter_gather: {result_composed_psum}")
print(f"Shape: {result_composed_psum.shape}")


# Compare with direct psum
@partial(shard_map, mesh=psum_scatter_mesh_1d,
         in_specs=P('scatter_axis'),
         out_specs=P())
def direct_psum_func(x_block):
    return jax.lax.psum(x_block, 'scatter_axis')

result_direct_psum = direct_psum_func(data_for_pscatter)
print(f"\nResult from direct_psum_func: {result_direct_psum}")
print(f"Shape: {result_direct_psum.shape}")

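# Same values; the shapes differ ((4,) vs. (1, 4)) because psum_scatter with tiled=False
# removed the scattered dimension before the tiled all_gather concatenated the pieces.
assert allclose(result_composed_psum, result_direct_psum.reshape(-1))
print("\nResults match: psum is equivalent to psum_scatter + all_gather.")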

Running psum_via_scatter_gather:
Result from psum_via_scatter_gather: [24. 28. 32. 36.]
Shape: (4,)

Result from direct_psum_func: [[24. 28. 32. 36.]]
Shape: (1, 4)

Results match: psum is equivalent to psum_scatter + all_gather.
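In the run above, the composed version returned shape `(4,)` while `psum` returned `(1, 4)`: `psum_scatter` with `tiled=False` squeezed out the scattered dimension, and the tiled `all_gather` then concatenated the pieces along axis 0. Here is a sketch of a shape-preserving variant, assuming the same 4-device layout: use `tiled=True` on both sides and pass `axis=1` to `all_gather` so it reassembles the dimension that was scattered. Each instance then holds a `(1, 4)` block, just as with `psum`. The function names are illustrative.

```python
import os
# Assumption: simulate 8 CPU devices, as at the start of this tutorial.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # top-level API in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

mesh = Mesh(np.array(jax.devices()[:4]), ("scatter_axis",))
data = jnp.arange(16.0).reshape(4, 4)  # each instance gets a (1, 4) row

def psum_via_tiled_pair(block):
    # (1, 4) -> (1, 1): each instance keeps one column of the total sum.
    piece = jax.lax.psum_scatter(block, "scatter_axis", scatter_dimension=1, tiled=True)
    # (1, 1) -> (1, 4): gather the pieces back along the same dimension.
    return jax.lax.all_gather(piece, "scatter_axis", axis=1, tiled=True)

def direct_psum(block):
    return jax.lax.psum(block, "scatter_axis")

# A tiled out_specs keeps the default check on: the all_gather result is typed
# as varying, so the four identical (1, 4) blocks are concatenated into (4, 4).
composed = shard_map(psum_via_tiled_pair, mesh=mesh,
                     in_specs=P("scatter_axis"), out_specs=P("scatter_axis"))(data)
direct = shard_map(direct_psum, mesh=mesh,
                   in_specs=P("scatter_axis"), out_specs=P())(data)  # (1, 4)
print(composed[:1], direct)
```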


## Collective Operation: `jax.lax.ppermute` (Parallel Permute)

The `jax.lax.ppermute(x, axis_name, perm=...)` collective allows for direct data exchange between devices (function instances) along a specified `axis_name`.
* `x`: The local array on each device that will be sent.
* `axis_name`: The mesh axis along which the permutation occurs.
* `perm`: A list of `(source_device_index, destination_device_index)` pairs. This defines which device sends its `x` to which other device. The indices are positions along `axis_name`, from `0` to the axis size minus one.

For example, if device `s` is a source and device `d` is its destination in a pair `(s, d)`, then device `d` will receive the `x` value that was originally on device `s`.

In [71]:
# ppermute example: cyclic shift on a 1D mesh
# Using 4 devices
ppermute_mesh_1d = Mesh(np.array(jax.devices()[:4]), ('shift_axis',))
num_devices_in_shift_axis = ppermute_mesh_1d.shape['shift_axis']

# Each device's data is its own index (a (1,) slice of this array)
data_for_ppermute = jnp.arange(num_devices_in_shift_axis, dtype=jnp.float32)
# With in_specs=P('shift_axis'), device i gets jnp.array([float(i)])

# Define the permutation: device i sends to (i+1) % num_devices
cyclic_perm = [(i, (i + 1) % num_devices_in_shift_axis) for i in range(num_devices_in_shift_axis)]
print(f"Cyclic permutation defined: {cyclic_perm}")

@partial(shard_map, mesh=ppermute_mesh_1d,
         in_specs=P('shift_axis'),
         out_specs=P('shift_axis'))
def ppermute_cyclic_shift(x_block):
  print(f"  BEFORE ppermute: x_block = {x_block}")

  permuted_block = jax.lax.ppermute(x_block, axis_name='shift_axis', perm=cyclic_perm)

  print(f"  AFTER ppermute:  permuted_block = {permuted_block}")
  return permuted_block

print("\nRunning ppermute_cyclic_shift:")
result_ppermute = ppermute_cyclic_shift(data_for_ppermute)

print(f"\nFinal result_ppermute: {result_ppermute}")
# Expected: Device 0 had 0, sent to 1. Device 1 had 1, sent to 2. ... Device 3 had 3, sent to 0.
# So, device 0 receives 3. Device 1 receives 0. Device 2 receives 1. Device 3 receives 2.
# Global result (concatenated by out_specs=P('shift_axis')): [3., 0., 1., 2.]
expected_ppermute_result = jnp.array([3., 0., 1., 2.], dtype=jnp.float32)
assert allclose(result_ppermute, expected_ppermute_result)

Cyclic permutation defined: [(0, 1), (1, 2), (2, 3), (3, 0)]

Running ppermute_cyclic_shift:
  BEFORE ppermute: x_block = On TFRT_CPU_0 at mesh coordinates (shift_axis,) = (0,):
[0.]

On TFRT_CPU_1 at mesh coordinates (shift_axis,) = (1,):
[1.]

On TFRT_CPU_2 at mesh coordinates (shift_axis,) = (2,):
[2.]

On TFRT_CPU_3 at mesh coordinates (shift_axis,) = (3,):
[3.]

  AFTER ppermute:  permuted_block = On TFRT_CPU_0 at mesh coordinates (shift_axis,) = (0,):
[3.]

On TFRT_CPU_1 at mesh coordinates (shift_axis,) = (1,):
[0.]

On TFRT_CPU_2 at mesh coordinates (shift_axis,) = (2,):
[1.]

On TFRT_CPU_3 at mesh coordinates (shift_axis,) = (3,):
[2.]


Final result_ppermute: [3. 0. 1. 2.]


**Notes on `ppermute`:**
* **Uniqueness:** In the `perm` list, each device index must appear at most once as a source and at most once as a destination; JAX raises an error otherwise.
* **Unspecified Destinations:** If a device index does not appear as a `destination_index` in any pair in `perm`, the corresponding device instance will receive an array of zeros of the appropriate shape and dtype.
* `ppermute` is a fundamental building block for many complex communication patterns.
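To see the zero-fill rule for unspecified destinations in action, here is a minimal sketch (not part of the original tutorial) that shifts data one position to the right *without* wrapping around, so device 0 never appears as a destination. It assumes a recent JAX where `shard_map` can be imported from `jax` (falling back to `jax.experimental.shard_map` on older releases), and it sets `XLA_FLAGS` to request 4 simulated CPU devices, which only takes effect if JAX has not been imported yet in the session. The names `shift_right` and `shift` are illustrative.

```python
import os
# Request 4 simulated CPU devices; only effective before JAX is first imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

n = min(4, len(jax.devices()))
mesh = Mesh(np.array(jax.devices()[:n]), ("shift_axis",))

# Non-cyclic shift: device i sends to i + 1, and the last device sends nowhere.
# Device 0 is never a destination, so it should receive zeros.
shift_right = [(i, i + 1) for i in range(n - 1)]

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P("shift_axis"), out_specs=P("shift_axis"))
def shift(x_block):
    return jax.lax.ppermute(x_block, "shift_axis", perm=shift_right)

data = jnp.arange(1, n + 1, dtype=jnp.float32)  # device i holds [i + 1]
shifted = shift(data)
print(shifted)  # with 4 devices: [0. 1. 2. 3.]
```

Starting the data at 1 rather than 0 makes the zero-filled slot on device 0 easy to spot.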

**Implementing Other Collectives with `ppermute`**

More complex collectives can often be constructed using `ppermute`. For instance, `psum_scatter` can be implemented as a sequence of `ppermute` operations that pass partial sums between neighboring devices around a ring, with a local addition at each step. Following that with an `all_gather` then yields a full `psum`.

Imagine a "bucket brigade" or a ring reduction over `N` devices:
1. Each device splits its local array into `N` blocks, one per device.
2. In each step `k` (for `k = 1, ..., N-1`):
   a. Each device sends one partial sum (a block plus whatever has been accumulated into it so far) to its neighbor (e.g., device `i` sends to `(i-1) % N`).
   b. Each device receives a partial sum from its other neighbor, `(i+1) % N`, and adds it to the block with the same index in its own data.
3. After `N-1` steps, device `i` holds the fully reduced `i`-th block, i.e. its share of the scattered sum.

This illustrates how point-to-point communication (`ppermute`) can build up global reductions.
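The ring reduction described above can be sketched in code as follows. `ring_psum_scatter` is a hypothetical helper written for this illustration (its index arithmetic is one possible choice, not JAX's internal implementation); it runs inside `shard_map`, and its result is checked against the built-in `jax.lax.psum_scatter` with `tiled=True`. As in the tutorial's setup cell, the `XLA_FLAGS` line requests 4 simulated CPU devices and only takes effect before JAX is first imported; the `shard_map` import fallback is an assumption to cover both newer and older JAX releases.

```python
import os
# Request 4 simulated CPU devices; only effective before JAX is first imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

n = min(4, len(jax.devices()))
ring_mesh = Mesh(np.array(jax.devices()[:n]), ("ring",))

def ring_psum_scatter(x, axis_name, axis_size):
    """Hypothetical helper: reduce-scatter along axis 0 using only ppermute."""
    idx = jax.lax.axis_index(axis_name)   # this device's position on the ring
    blocks = x.reshape(axis_size, -1)     # step 1: one block per device
    perm = [(i, (i - 1) % axis_size) for i in range(axis_size)]  # i -> i - 1
    for step in range(1, axis_size):      # N - 1 communication steps
        # Send the partial sum accumulated so far for block (idx + step)
        # to the left neighbor...
        outgoing = blocks[(idx + step) % axis_size]
        incoming = jax.lax.ppermute(outgoing, axis_name, perm)
        # ...and add the right neighbor's partial sum into block (idx + step + 1).
        blocks = blocks.at[(idx + step + 1) % axis_size].add(incoming)
    # After N - 1 steps, block idx holds the sum of block idx over all devices.
    return blocks[idx]

@jax.jit
@partial(shard_map, mesh=ring_mesh, in_specs=P("ring"),
         out_specs=(P("ring"), P("ring")))
def compare(x_block):
    mine = ring_psum_scatter(x_block, "ring", n)
    builtin = jax.lax.psum_scatter(x_block, "ring", scatter_dimension=0, tiled=True)
    return mine, builtin

data = jnp.arange(n * n, dtype=jnp.float32)  # device i holds data[i*n:(i+1)*n]
ring_result, builtin_result = compare(data)
expected = data.reshape(n, n).sum(axis=0)    # column sums = scattered total
print(ring_result, builtin_result, expected)
```

The axis size is passed in as a plain Python `int` (read from the mesh here) because `ppermute` needs a static `perm` list that can be built at trace time.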

## Collective Operation: `jax.lax.all_to_all`

The `jax.lax.all_to_all(x, axis_name, split_axis, concat_axis, tiled=...)` collective is used for more complex data redistributions. It essentially performs a "block transpose" of data across devices and within the local data arrays.

Here's how it works conceptually:
1. Each local array `x` on a device is split into blocks along its `split_axis` (a regular data dimension). The number of blocks equals the size of the mesh axis `axis_name`.
2. These blocks are then exchanged between devices: device `i` sends its `j`-th block to device `j`. As a result, device `j` receives the `j`-th block from every device, ordered by the sender's index `i`.
3. On each receiving device, the collected blocks are concatenated along the `concat_axis` (another regular data dimension), in order of sender index, to form the output array for that device.
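These three steps can be modelled on the host with plain NumPy. `reference_all_to_all` below is a hypothetical helper, not a JAX API: it treats a Python list of arrays as the per-device local arrays and follows the `tiled=True` behavior (blocks concatenated along `concat_axis`). Fed the same input as the `all_to_all_example` cell, it should reproduce the per-device outputs printed there.

```python
import numpy as np

def reference_all_to_all(local_arrays, split_axis, concat_axis):
    """Hypothetical host-side model of a tiled all_to_all (not a JAX API).

    local_arrays[i] stands in for device i's local array x; the return value
    is the list of per-device outputs.
    """
    n = len(local_arrays)
    # Step 1: each "device" splits its array into n blocks along split_axis.
    blocks = [np.split(x, n, axis=split_axis) for x in local_arrays]
    # Steps 2-3: "device" j collects block j from every device i, in order of i,
    # and concatenates those blocks along concat_axis.
    return [
        np.concatenate([blocks[i][j] for i in range(n)], axis=concat_axis)
        for j in range(n)
    ]

# Same input as all_to_all_example: device i holds [4i, 4i+1, 4i+2, 4i+3].
local_inputs = np.split(np.arange(16, dtype=np.float32), 4)
outputs = reference_all_to_all(local_inputs, split_axis=0, concat_axis=0)
for j, out in enumerate(outputs):
    print(f"device {j}: {out}")  # device 0 gets 0, 4, 8, 12; device 1 gets 1, 5, 9, 13; ...
```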

In [85]:
# all_to_all example with a 1D mesh and tiled=True
all_to_all_mesh = Mesh(np.array(jax.devices()[:4]), ('device_group',))
mesh_size = all_to_all_mesh.size

data_for_all_to_all = jnp.arange(mesh_size * mesh_size, dtype=jnp.float32)  # 4 values per device

@partial(shard_map, mesh=all_to_all_mesh,
         in_specs=P('device_group'),
         out_specs=P('device_group'))
def all_to_all_example(x_block):
  print(f"  BEFORE all_to_all: x_block = \n{x_block}")
  y_block = jax.lax.all_to_all(x_block, 'device_group', split_axis=0, concat_axis=0, tiled=True)
  print(f"  AFTER all_to_all:  y_block = \n{y_block}")
  return y_block

print("Running all_to_all_example (tiled=True):")
result_all_to_all = all_to_all_example(data_for_all_to_all)

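print(f"\nFinal global result_all_to_all:\n{result_all_to_all}")
# Expected: device j collects block j from every device, so the global result is
# the block transpose of the input: [0, 4, 8, 12, 1, 5, 9, 13, ...].
expected_all_to_all = data_for_all_to_all.reshape(mesh_size, mesh_size).T.reshape(-1)
assert allclose(result_all_to_all, expected_all_to_all)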

Running all_to_all_example (tiled=True):
  BEFORE all_to_all: x_block = 
On TFRT_CPU_0 at mesh coordinates (device_group,) = (0,):
[0. 1. 2. 3.]

On TFRT_CPU_1 at mesh coordinates (device_group,) = (1,):
[4. 5. 6. 7.]

On TFRT_CPU_2 at mesh coordinates (device_group,) = (2,):
[ 8.  9. 10. 11.]

On TFRT_CPU_3 at mesh coordinates (device_group,) = (3,):
[12. 13. 14. 15.]

  AFTER all_to_all:  y_block = 
On TFRT_CPU_0 at mesh coordinates (device_group,) = (0,):
[ 0.  4.  8. 12.]

On TFRT_CPU_1 at mesh coordinates (device_group,) = (1,):
[ 1.  5.  9. 13.]

On TFRT_CPU_2 at mesh coordinates (device_group,) = (2,):
[ 2.  6. 10. 14.]

On TFRT_CPU_3 at mesh coordinates (device_group,) = (3,):
[ 3.  7. 11. 15.]


Final global result_all_to_all:
[ 0.  4.  8. 12.  1.  5.  9. 13.  2.  6. 10. 14.  3.  7. 11. 15.]


**The `tiled` argument in `all_to_all`:**

* **`tiled=True`**: This is usually the more intuitive mode for block transposes.
    * The size of the `split_axis` in the local input `x` must be evenly divisible by the size of the mesh axis `axis_name`.
    * The output has the same rank as `x`: the `split_axis` dimension is cut into chunks that are exchanged between devices and then concatenated along `concat_axis`. That dimension therefore shrinks by a factor of the axis size, while the `concat_axis` dimension grows by the same factor. If `split_axis == concat_axis`, the local shape is unchanged but the data now comes from different original devices, as in the example above.

* **`tiled=False`** (default):
    * The size of the `split_axis` in the local input `x` must equal the size of the mesh axis `axis_name`.
    * The `split_axis` dimension is removed and a new dimension of that same size (`axis_size`, the size of the mesh axis) is inserted at position `concat_axis` in the output, i.e. the output shape is `np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size)`. The rank stays the same; only the position of the exchanged dimension moves.

For many common use cases like transposing blocks of data, `tiled=True` is used.
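The sketch below (an illustration, not from the original tutorial) checks both shape rules inside `shard_map`. With `tiled=True`, `split_axis=1` and `concat_axis=0`, each device's `(2, n)` row block becomes a `(2n, 1)` column block, so the same global array goes from row-sharded to column-sharded. With `tiled=False`, an `(n, 3)` block becomes `(3, n)`. `rows_to_columns` and `untiled` are illustrative names; the `XLA_FLAGS` line (4 simulated CPU devices, only effective before JAX is first imported) and the `shard_map` import fallback are assumptions for running it standalone.

```python
import os
# Request 4 simulated CPU devices; only effective before JAX is first imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

n = min(4, len(jax.devices()))
mesh = Mesh(np.array(jax.devices()[:n]), ("g",))

# Global (2n, n) array, sharded by rows: each device holds a (2, n) block.
data = jnp.arange(2 * n * n, dtype=jnp.float32).reshape(2 * n, n)

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P("g", None), out_specs=P(None, "g"))
def rows_to_columns(x_block):
    # tiled=True: axis 1 (size n) is cut into n chunks of width 1, which are
    # concatenated along axis 0, so (2, n) -> (2n, 1). The rank is unchanged.
    return jax.lax.all_to_all(x_block, "g", split_axis=1, concat_axis=0, tiled=True)

resharded = rows_to_columns(data)

# Global (n*n, 3) array: each device holds an (n, 3) block.
data_untiled = jnp.arange(n * n * 3, dtype=jnp.float32).reshape(n * n, 3)

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P("g", None), out_specs=P("g", None))
def untiled(x_block):
    # tiled=False: axis 0 must have size n; it is removed and a new axis of
    # size n (indexed by sender) is inserted at position 1: (n, 3) -> (3, n).
    return jax.lax.all_to_all(x_block, "g", split_axis=0, concat_axis=1, tiled=False)

untiled_out = untiled(data_untiled)
print(resharded.shape, untiled_out.shape)
```

Resharding from rows to columns like this is one common reason to reach for `all_to_all`: the global values of `data` are unchanged; only the dimension that is split across devices differs.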

# Conclusion: The Power of `shard_map`

This tutorial has introduced you to `jax.shard_map`, a powerful tool for explicit, manual control over multi-device parallelism in JAX. We've covered:
* The concept of a `Mesh` and `PartitionSpec` for defining data and computation layout.
* How `in_specs` and `out_specs` control data sharding and assembly.
* The rank-preserving nature of `shard_map` compared to `vmap`.
* Varying Manual Axes (VMA) for type checking and `jax.lax.pvary` for managing variance.
* Fundamental collective operations: `psum`, `all_gather`, `psum_scatter`, `ppermute`, and `all_to_all`.

`shard_map` allows you to implement sophisticated parallel algorithms by giving you direct control over how data is split, where computation happens, and how devices communicate. It composes seamlessly with `jax.jit` for compilation and `jax.grad` for automatic differentiation.

While strategies like tensor parallelism and pipeline parallelism for neural networks are more complex (as seen in the full JAX documentation), the building blocks you've learned here are fundamental to them.

Continue exploring the JAX documentation to discover even more advanced patterns and unlock the full potential of your multi-device hardware! Happy sharding!