<a href="https://colab.research.google.com/github/mridul-sahu/jax-sharding-tutorials/blob/main/Mastering_JAX_Shardings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip3 install --upgrade jax



# [Series 1, Module: Mastering JAX Shardings] - The Aurora Project 🌌

## I. Module Introduction: From Blueprints to Advanced Distributed Systems

Welcome back, Aurora Architect! In our previous briefings, we've laid the groundwork: understanding JAX data primitives (Chapter 1.1), exploring initial parallelism with `jax.pmap` (Chapter 1.2), and crucially, learning to define structured device topologies with `jax.sharding.Mesh` and `AxisType` (Chapter 1.3). We've particularly highlighted the importance of `AxisType.Explicit` for enabling JAX's modern "sharding in types" paradigm, even though `AxisType.Auto` is the default for mesh axes if unspecified.

This comprehensive module takes you deep into the world of **explicit sharding**. We will cover:
* **Crafting the Blueprints:** Defining precisely how arrays should be distributed using `jax.sharding.PartitionSpec`.
* **Bringing Blueprints to Life:** Using `jax.sharding.NamedSharding` and `jax.device_put` to create `jax.Array`s that are physically sharded according to your specifications, and understanding how "sharding in types" makes this distribution a queryable part of the array's JAX-level type (when using `Explicit` mesh axes).
* **Understanding the Flow:** How these explicit shardings propagate through JAX operations and how the GSPMD compiler interacts with them.
* **Alternative Control:** Exploring `jax.shard_map` for per-device SPMD programming.
* **Advanced Techniques:** Delving into mixed sharding modes, and designing layouts for patterns like FSDP and tensor parallelism.

By the end of this module, you will have a robust toolkit to design, implement, and analyze sophisticated distributed systems for Project Aurora's most demanding AI models.

## II. Overall Module Objectives & Outline

### Overall Learning Objectives for this Module:

* Master the syntax, semantics, and application of `jax.sharding.PartitionSpec` (`P`) to define logical sharding blueprints.
* Create and use `jax.sharding.NamedSharding` objects with `jax.device_put` (and `jax.experimental.shard.reshard`) to produce explicitly sharded `jax.Array`s.
* Clearly distinguish between concrete array sharding (`array.sharding`) and JAX type-level sharding (`jax.typeof(array).sharding`), understanding the role of `Mesh` `AxisType.Explicit` in the latter.
* Inspect and visualize array sharding distributions.
* Describe how type-specified shardings propagate through JAX operations and how `jax.jit`/ XLA/ GSPMD utilize this.
* Utilize `out_sharding` to resolve propagation ambiguities and `jax.lax.with_sharding_constraint` to guide compiler choices.
* Implement SPMD-style parallelism using `jax.shard_map`.
* Employ `jax.experimental.shard.auto_axes` and `jax.experimental.shard.explicit_axes` to mix sharding modes.
* Design `PartitionSpec` configurations for advanced strategies like FSDP and tensor parallelism.
* Troubleshoot common errors related to sharding definitions.

### Module Outline:

**Part 1: Defining and Applying Sharding Blueprints**
* 1.A: `jax.sharding.PartitionSpec` (`P`) - The Logical Blueprint
* 1.B: `jax.sharding.NamedSharding` - Binding Blueprints to Meshes
* 1.C: Creating Explicitly Sharded Arrays - The "Sharding in Types" Core
* 1.D: Introspecting Sharded Arrays
* 1.E: Visualizing Sharding

**Part 2: Sharding Propagation and Compiler Interaction**
* 2.A: Explicit Sharding Propagation at JAX-Level (Trace Time)
* 2.B: `out_sharding` for Ambiguity Resolution
* 2.C: Interaction with JAX Transformations
* 2.D: GSPMD Compiler Integration & Automatic Resharding
* 2.E: `jax.lax.with_sharding_constraint`

**Part 3: `shard_map` - Explicit Per-Device Programming**
* 3.A: `shard_map` vs. `jit` with Global View
* 3.B: SPMD Programming with `shard_map`
* 3.C: `in_specs` and `out_specs` with `PartitionSpec`
* 3.D: Manual Collective Invocation in `shard_map`

**Part 4: Advanced Sharding Techniques & Mixed Modes**
* 4.A: Mixing Sharding Modes (`Mesh` with mixed `AxisType`s)
* 4.B: `jax.experimental.shard.auto_axes`
* 4.C: `jax.experimental.shard.explicit_axes`
* 4.D: Concrete Array Sharding (`x.sharding`) vs. Type-Specified Sharding (`jax.typeof(x).sharding`) - Revisited
* 4.E: Applied Advanced Patterns (FSDP, Tensor Parallelism with `NamedSharding`)

## III. Prerequisites & Setup

This module assumes familiarity with:
* JAX fundamentals (Chapter 1.1).
* `jax.pmap` (Chapter 1.2, for context on earlier sharding).
* **Crucially, `jax.sharding.Mesh` and `AxisType` (Chapter 1.3).** We will be building directly upon the concept of a device mesh with named axes.

Let's set up our environment.

In [3]:
import jax
import jax.numpy as jnp
import numpy as np
import os
from functools import partial # For shard_map examples later

# JAX sharding APIs
from jax.sharding import Mesh, PartitionSpec, NamedSharding, AxisType
from jax.experimental.shard import reshard, auto_axes, explicit_axes # For later sections

# Visualization
from jax.debug import visualize_array_sharding

# For consistent pretty printing of JAX types (optional)
# from jax.interpreters.pretty_printing import stringify_dtype
# jax.config.update("jax_pprint_dtype_in_shape", False) # Example preference

# --- JAX Device Setup ---
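# Aim for a consistent number of devices (e.g., 8 for the 4x2 mesh examples below).
# XLA_FLAGS must be set before JAX initializes its backend (i.e., before the first call such as
# jax.devices()). On a Colab CPU runtime where JAX has already been used, restart the session and rerun.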
desired_num_devices = 8
os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={desired_num_devices}'

physical_devices_list = jax.devices()
num_physical_devices_available = len(physical_devices_list)

print(f"JAX version: {jax.__version__}")
print(f"JAX default backend: {jax.default_backend()}")
print(f"Number of local JAX devices available: {num_physical_devices_available}")

# --- Create a Reusable Mesh for Examples ---
# We will create a 2D mesh. For "sharding in types" to be fully visible via jax.typeof(),
# the mesh axes used in PartitionSpec should be AxisType.Explicit.
# If axis_types is omitted, it defaults to Auto for all axes.

mesh_axis_names = ('dp', 'mp') # For Data Parallel and Model Parallel dimensions

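mp_dim_size = 2 # Try to have a model parallel dimension of size 2
# Guard against a single-device runtime, where the mesh below would end up empty.
assert num_physical_devices_available >= mp_dim_size, (
    f"Need at least {mp_dim_size} devices; set XLA_FLAGS before JAX initializes and restart the runtime.")
dp_dim_size = num_physical_devices_available // mp_dim_size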

mesh_shape = (dp_dim_size, mp_dim_size)
devices_for_mesh = np.array(physical_devices_list[:np.prod(mesh_shape)]).reshape(mesh_shape)

# Mesh with EXPLICIT axes for "sharding-in-types" demonstrations
aurora_explicit_mesh = Mesh(devices_for_mesh,
                            mesh_axis_names,
                            axis_types=(AxisType.Explicit, AxisType.Explicit))

# Mesh with AUTO axes (default if axis_types is omitted) for demonstrating concrete sharding
aurora_auto_mesh = Mesh(devices_for_mesh, mesh_axis_names)

print(f"\nCreated Aurora Explicit Mesh ('aurora_explicit_mesh'):")
print(f"  Shape: {aurora_explicit_mesh.shape}, Names: {aurora_explicit_mesh.axis_names}, Types: {aurora_explicit_mesh.axis_types}")
print(f"  Device IDs:\n{aurora_explicit_mesh.device_ids}")

print(f"\nCreated Aurora Auto Mesh ('aurora_auto_mesh'):")
print(f"  Shape: {aurora_auto_mesh.shape}, Names: {aurora_auto_mesh.axis_names}, Types: {aurora_auto_mesh.axis_types}")

# Alias PartitionSpec for convenience
P = PartitionSpec



JAX version: 0.6.0
JAX default backend: cpu
Number of local JAX devices available: 8

Created Aurora Explicit Mesh ('aurora_explicit_mesh'):
  Shape: OrderedDict([('dp', 4), ('mp', 2)]), Names: ('dp', 'mp'), Types: (Explicit, Explicit)
  Device IDs:
[[0 1]
 [2 3]
 [4 5]
 [6 7]]

Created Aurora Auto Mesh ('aurora_auto_mesh'):
  Shape: OrderedDict([('dp', 4), ('mp', 2)]), Names: ('dp', 'mp'), Types: (Auto, Auto)


## Part 1: Defining and Applying Sharding Blueprints

This part focuses on how to describe your desired sharding layout using `PartitionSpec`, how to combine it with a `Mesh` into a `NamedSharding` object, and then how to use this to create `jax.Array`s that are physically distributed. We'll pay close attention to the "sharding in types" concept.

### 1.A: `jax.sharding.PartitionSpec` (`P`) - The Logical Blueprint

`jax.sharding.PartitionSpec` (aliased as `P`) is the primary tool for specifying the logical layout of an N-dimensional array across an M-dimensional `Mesh`.

**Structure:**
A `PartitionSpec` is a tuple-like object with one entry per array dimension, ordered from axis 0 to axis N-1. It may have fewer entries than the array's rank (`ndim`), in which case the trailing dimensions are treated as unsharded, but it must not have more entries than the rank.

**Possible values for each element in the `PartitionSpec` tuple:**

1.  **`None`**:
    * Indicates that the corresponding array dimension is **not sharded** along any mesh axis: every device that holds a piece of the array holds this dimension in full, so the data is **replicated** along it.
    * Example: `P(None, 'dp')` for a 2D array means axis 0 is replicated, axis 1 is sharded along mesh axis `'dp'`.

2.  **`'mesh_axis_name'` (string)**:
    * Indicates that the corresponding array dimension is **sharded** (split) across the devices along the named axis of the `Mesh`. The size of this array dimension must be divisible by the size of the mesh axis it's mapped to.
    * Example: `P('dp', 'mp')` for a 2D array means array axis 0 is sharded over mesh axis `'dp'`, and array axis 1 is sharded over mesh axis `'mp'`.

3.  **Tuple of mesh axis names `('axisA', 'axisB', ...)`**:
    * Indicates that the corresponding array dimension is sharded across **multiple mesh axes simultaneously**. The array dimension is effectively "flattened" onto these mesh axes. The size of the array dimension must be divisible by the product of the sizes of the specified mesh axes.
    * Example: `P(('dp', 'mp'),)` for a 1D array on a 2D mesh (`'dp'`, `'mp'`) shards the array's single dimension across all devices in the mesh.
    * This is a more advanced use case, allowing, for instance, a 2D array to be fully sharded over a 3D or 4D mesh by mapping some array axes to tuples of mesh axes.

**Constraints:**
* A mesh axis name may appear at most once in a `PartitionSpec`, so it cannot shard two different array dimensions. For example, `P('dp', 'dp')` is invalid.
* The total number of devices implied by the sharding of an array dimension (i.e., the product of sizes of mesh axes it's sharded over) must evenly divide the size of that array dimension.
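The arithmetic behind these rules can be sketched in plain Python, without touching any devices. The helper `local_shard_shape` below is a hypothetical illustration, not a JAX API: it takes a global shape, a spec written as an ordinary tuple, and a dictionary of mesh axis sizes, and returns the per-device shard shape. As an assumption of this sketch, dimensions the spec does not mention are treated as unsharded.

```python
from math import prod

def local_shard_shape(global_shape, spec, mesh_shape):
    """Hypothetical helper: per-device shard shape implied by a PartitionSpec-like tuple.

    Each spec entry is None, a mesh axis name, or a tuple of mesh axis names.
    """
    if len(spec) > len(global_shape):
        raise ValueError("spec has more entries than the array has dimensions")
    # Assumption of this sketch: unmentioned trailing dimensions are unsharded.
    spec = tuple(spec) + (None,) * (len(global_shape) - len(spec))

    used_axes = []
    local = []
    for dim_size, entry in zip(global_shape, spec):
        if entry is None:
            axes = ()
        elif isinstance(entry, str):
            axes = (entry,)
        else:
            axes = tuple(entry)
        used_axes.extend(axes)
        n_parts = prod(mesh_shape[a] for a in axes)  # product over no axes is 1
        if dim_size % n_parts != 0:
            raise ValueError(f"dimension of size {dim_size} is not divisible by {n_parts}")
        local.append(dim_size // n_parts)

    if len(used_axes) != len(set(used_axes)):
        raise ValueError("a mesh axis may appear at most once in a spec")
    return tuple(local)

mesh_shape = {"dp": 4, "mp": 2}
print(local_shard_shape((8, 4), ("dp", "mp"), mesh_shape))    # (2, 2)
print(local_shard_shape((8, 4), (None, "mp"), mesh_shape))    # (8, 2)
print(local_shard_shape((16,), (("dp", "mp"),), mesh_shape))  # (2,)
```

Calling the helper with `("dp", "dp")` or with a dimension size that does not divide evenly raises `ValueError`, mirroring the two constraints listed above.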

Let's see some examples using our `aurora_explicit_mesh` (or `aurora_auto_mesh`). The `PartitionSpec` itself is just a declaration; its effect is seen when used with `NamedSharding` and `jax.device_put`.

In [4]:
print(f"Using mesh: {aurora_explicit_mesh.axis_names} with shape {aurora_explicit_mesh.shape}")

# Examples for a 2D array, e.g., shape (dim0, dim1)
# 1. Fully Replicated on the mesh
spec_repl_2d = P(None, None)
print(f"\n1. Fully Replicated 2D: {spec_repl_2d}")

# 2. Shard dim0 by 'dp', replicate dim1
spec_shard_dim0_on_dp = P('dp', None)
print(f"2. Shard dim0 on 'dp', replicate dim1: {spec_shard_dim0_on_dp}")

# 3. Replicate dim0, shard dim1 by 'mp'
spec_shard_dim1_on_mp = P(None, 'mp')
print(f"3. Replicate dim0, shard dim1 on 'mp': {spec_shard_dim1_on_mp}")

# 4. Shard dim0 by 'dp', shard dim1 by 'mp' (fully sharded over 2D mesh)
spec_fully_sharded_2d = P('dp', 'mp')
print(f"4. Shard dim0 on 'dp', dim1 on 'mp': {spec_fully_sharded_2d}")

# Examples for a 1D array, e.g., shape (dim0)
# 5. Replicated 1D array
spec_repl_1d = P(None,) # Trailing comma for single element tuple
print(f"\n5. Fully Replicated 1D: {spec_repl_1d}")

# 6. Shard 1D array along 'dp'
spec_shard_1d_on_dp = P('dp',)
print(f"6. Shard 1D on 'dp': {spec_shard_1d_on_dp}")

# 7. Shard 1D array along all axes of the 2D mesh (flatten mesh)
spec_shard_1d_on_dp_mp = P(('dp','mp'),)
print(f"7. Shard 1D on ('dp', 'mp') combined: {spec_shard_1d_on_dp_mp}")

# Example for a 4D array (e.g., CNN weights: [output_channels, input_channels, height, width])
# Shard output_channels by 'mp', input_channels by 'dp', replicate spatial dims
spec_cnn_weights = P('mp', 'dp', None, None)
print(f"\n8. CNN Weights [O,I,H,W]: {spec_cnn_weights}")

Using mesh: ('dp', 'mp') with shape OrderedDict([('dp', 4), ('mp', 2)])

1. Fully Replicated 2D: PartitionSpec(None, None)
2. Shard dim0 on 'dp', replicate dim1: PartitionSpec('dp', None)
3. Replicate dim0, shard dim1 on 'mp': PartitionSpec(None, 'mp')
4. Shard dim0 on 'dp', dim1 on 'mp': PartitionSpec('dp', 'mp')

5. Fully Replicated 1D: PartitionSpec(None,)
6. Shard 1D on 'dp': PartitionSpec('dp',)
7. Shard 1D on ('dp', 'mp') combined: PartitionSpec(('dp', 'mp'),)

8. CNN Weights [O,I,H,W]: PartitionSpec('mp', 'dp', None, None)


### 1.B: `jax.sharding.NamedSharding` - Binding Blueprints to Meshes

A `PartitionSpec` is a logical blueprint. To make it concrete and usable for sharding an array, it must be combined with a specific `jax.sharding.Mesh`. This combination is encapsulated by the `jax.sharding.NamedSharding` object.

`NamedSharding(mesh, partition_spec)`

* `mesh`: An instance of `jax.sharding.Mesh`.
* `partition_spec`: An instance of `jax.sharding.PartitionSpec`.

The `NamedSharding` object holds the complete instruction for how an array of a compatible rank should be distributed across the given mesh.
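A `NamedSharding` can be queried before any data is moved. The sketch below asks a sharding for the per-device shard shape of a given global shape. The mesh `demo_mesh` is a stand-in built from whatever devices are available (all of them along `'dp'`, with a size-1 `'mp'` axis), so it is not the module's `aurora_explicit_mesh`, and the variable names are illustrative only.

```python
import os
# Only takes effect if JAX has not initialized its backend yet in this process.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
# Illustrative mesh: every device along 'dp', a singleton 'mp' axis.
demo_mesh = Mesh(devices.reshape(n, 1), ("dp", "mp"))

row_sharding = NamedSharding(demo_mesh, P("dp", None))
replicated = NamedSharding(demo_mesh, P(None, None))

global_shape = (2 * n, 4)
# Rows are split n ways, so each device would hold 2 rows and all 4 columns.
print(row_sharding.shard_shape(global_shape))
# A fully replicated sharding keeps the global shape on every device.
print(replicated.shard_shape(global_shape))
print(replicated.is_fully_replicated)
```

Because the global shape is derived from the device count, the row-sharded shard shape stays `(2, 4)` whether the runtime exposes one device or eight.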

In [5]:
# Create a NamedSharding object for a 2D array, fully sharded on aurora_explicit_mesh
named_sharding_2d_explicit = NamedSharding(aurora_explicit_mesh, spec_fully_sharded_2d)
print(f"--- NamedSharding with Explicit Mesh ---")
print(f"NamedSharding object: {named_sharding_2d_explicit}")
print(f"  Mesh used: {named_sharding_2d_explicit.mesh.axis_names}, Types: {named_sharding_2d_explicit.mesh.axis_types}")
print(f"  PartitionSpec used: {named_sharding_2d_explicit.spec}")

# Example with aurora_auto_mesh (defaulting to Auto axes)
named_sharding_2d_auto = NamedSharding(aurora_auto_mesh, spec_fully_sharded_2d)
print(f"\n--- NamedSharding with Auto Mesh ---")
print(f"NamedSharding object: {named_sharding_2d_auto}")
print(f"  Mesh used: {named_sharding_2d_auto.mesh.axis_names}, Types: {named_sharding_2d_auto.mesh.axis_types}")
print(f"  PartitionSpec used: {named_sharding_2d_auto.spec}")


--- NamedSharding with Explicit Mesh ---
NamedSharding object: NamedSharding(mesh=Mesh('dp': 4, 'mp': 2, axis_types=(Explicit, Explicit)), spec=PartitionSpec('dp', 'mp'), memory_kind=unpinned_host)
  Mesh used: ('dp', 'mp'), Types: (Explicit, Explicit)
  PartitionSpec used: PartitionSpec('dp', 'mp')

--- NamedSharding with Auto Mesh ---
NamedSharding object: NamedSharding(mesh=Mesh('dp': 4, 'mp': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('dp', 'mp'), memory_kind=unpinned_host)
  Mesh used: ('dp', 'mp'), Types: (Auto, Auto)
  PartitionSpec used: PartitionSpec('dp', 'mp')


### 1.C: Creating Explicitly Sharded Arrays - The "Sharding in Types" Core

With a `NamedSharding` object, we can now create a `jax.Array` that is physically distributed according to our blueprint. The primary way to do this is with `jax.device_put()`:

`sharded_array = jax.device_put(numpy_array_or_python_list_or_jax_array, named_sharding_object)`

This operation takes the input data, transfers it to the devices specified by the `NamedSharding` object's mesh, and arranges it according to the `PartitionSpec`. The result is a `jax.Array` whose `sharding` attribute *is* that `NamedSharding` object.
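As a small sanity check of this behavior, the following sketch shards a NumPy array along its first dimension and then gathers it back to the host. It assumes a 1D mesh built from all available devices, and the names `demo_mesh`, `host_data`, and `sharded` are illustrative rather than part of the module.

```python
import os
# Only takes effect if JAX has not initialized its backend yet in this process.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
demo_mesh = Mesh(devices.reshape(n), ("dp",))  # illustrative 1D mesh

host_data = np.arange(2 * n * 3, dtype=np.float32).reshape(2 * n, 3)
sharded = jax.device_put(host_data, NamedSharding(demo_mesh, P("dp", None)))

# The global shape is unchanged; each device holds a (2, 3) slice of the rows.
print(sharded.shape)
print(sharded.addressable_shards[0].data.shape)

# Converting back to NumPy gathers the shards and reproduces the original data.
roundtrip = np.asarray(sharded)
print(np.array_equal(roundtrip, host_data))
```

The round trip shows that sharding changes where the data lives, not what the array contains.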

**The "Sharding in Types" Connection:**

* **Concrete Sharding (`array.sharding`):** When you use `jax.device_put(data, named_sharding)`, the resulting `jax.Array` *is* concretely sharded according to your `named_sharding`. You can inspect this via `array.sharding`. This works whether the mesh axes in `named_sharding.mesh` are `AxisType.Auto` or `AxisType.Explicit`.
* **Type-Level Sharding (`jax.typeof(array).sharding`):** For the sharding to become part of the array's **JAX-level type** (visible as annotations like `@dp` in the output of `jax.typeof()`), the mesh axes referenced in the `PartitionSpec` (and thus in the `NamedSharding`) **must be `AxisType.Explicit`**. If they are `Auto` (the default if `axis_types` is not specified when creating the mesh), `jax.typeof(array).sharding` will not show sharding along those `Auto` axes, even if the array is concretely sharded.

This distinction is vital. "Sharding in types" enables JAX's trace-time sharding propagation and more predictable distributed behavior.

**Alternative: `jax.experimental.shard.reshard`**
`reshard(array, partition_spec)` can be used to change the sharding of an existing array or to shard a NumPy array. It needs an active mesh context (set via `with jax.sharding.use_mesh(mesh):`, as in the cell below, or `jax.sharding.set_mesh(mesh)`), and the `partition_spec` refers to this context mesh. The mesh should have `Explicit` axes for type-level sharding.

In [6]:
# Create some data
data_np = np.arange(32, dtype=np.float32).reshape(8, 4) # e.g., (batch, features)
# The examples assume the (4, 2) mesh built above from 8 devices, so both array
# dimensions divide evenly: (8, 4) = (dp_dim_size * 2, mp_dim_size * 2).
# With a different mesh, e.g. (8, 1), the same data can still be sharded as P('dp', None).


# For data_np (8,4) and mesh (4,2):
# dim0 sharded by 'dp' (size 4) -> 8/4 = 2 per device-slice
# dim1 sharded by 'mp' (size 2) -> 4/2 = 2 per device-slice
# Each device gets a (2,2) shard.

print(f"--- Creating Sharded Arrays ---")
print(f"Original NumPy data shape: {data_np.shape}")
print(f"Using Explicit Mesh: {aurora_explicit_mesh.axis_names} with shape {aurora_explicit_mesh.shape}")
print(f"Using Auto Mesh: {aurora_auto_mesh.axis_names} with shape {aurora_auto_mesh.shape}")

# Define a PartitionSpec - let's fully shard this 2D array
# Adjust P if data_np or mesh shapes are different from (8,4) and (4,2) respectively
# This P assumes data_np.shape[0] is sharded by 'dp' and data_np.shape[1] by 'mp'
spec_for_data = P(mesh_axis_names[0], mesh_axis_names[1]) # P('dp', 'mp')

# 1. Using Mesh with EXPLICIT axes
named_sharding_explicit = NamedSharding(aurora_explicit_mesh, spec_for_data)
arr_sharded_explicit = jax.device_put(data_np, named_sharding_explicit)

print(f"\n1. Array sharded with Explicit Mesh:")
print(f"   arr_sharded_explicit.sharding: {arr_sharded_explicit.sharding}")
# This will show sharding annotations like @dp, @mp because axes are Explicit
print(f"   jax.typeof(arr_sharded_explicit): {jax.typeof(arr_sharded_explicit)}")
print(f"   Device buffers exist on devices: {arr_sharded_explicit.devices()}")


# 2. Using Mesh with AUTO axes
named_sharding_auto = NamedSharding(aurora_auto_mesh, spec_for_data)
arr_sharded_auto = jax.device_put(data_np, named_sharding_auto)
print(f"\n2. Array sharded with Auto Mesh:")
print(f"   arr_sharded_auto.sharding: {arr_sharded_auto.sharding}")
# This will NOT show @dp, @mp in the type, as axes are Auto, even though concretely sharded.
print(f"   jax.typeof(arr_sharded_auto): {jax.typeof(arr_sharded_auto)}")
print(f"   Device buffers exist on devices: {arr_sharded_auto.devices()}")

# 3. Example with reshard (requires an active mesh context)
print(f"\n3. Array sharded with reshard (using Explicit Mesh):")
# Activate mesh context for reshard
with jax.sharding.use_mesh(aurora_explicit_mesh):
  arr_resharded = reshard(data_np, spec_for_data)
print(f"   arr_resharded.sharding: {arr_resharded.sharding}")
print(f"   jax.typeof(arr_resharded): {jax.typeof(arr_resharded)}")

--- Creating Sharded Arrays ---
Original NumPy data shape: (8, 4)
Using Explicit Mesh: ('dp', 'mp') with shape OrderedDict([('dp', 4), ('mp', 2)])
Using Auto Mesh: ('dp', 'mp') with shape OrderedDict([('dp', 4), ('mp', 2)])

1. Array sharded with Explicit Mesh:
   arr_sharded_explicit.sharding: NamedSharding(mesh=Mesh('dp': 4, 'mp': 2, axis_types=(Explicit, Explicit)), spec=PartitionSpec('dp', 'mp'), memory_kind=unpinned_host)
   jax.typeof(arr_sharded_explicit): ShapedArray(float32[8@dp,4@mp])
   Device buffers exist on devices: {CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=1), CpuDevice(id=4), CpuDevice(id=2), CpuDevice(id=7), CpuDevice(id=3), CpuDevice(id=0)}

2. Array sharded with Auto Mesh:
   arr_sharded_auto.sharding: NamedSharding(mesh=Mesh('dp': 4, 'mp': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('dp', 'mp'), memory_kind=unpinned_host)
   jax.typeof(arr_sharded_auto): ShapedArray(float32[8,4])
   Device buffers exist on devices: {CpuDevice(id=5), CpuDevice(id=6), CpuDev

### 1.D: Introspecting Sharded Arrays

Once you have a `jax.Array`, whether created via `jax.device_put` with a `NamedSharding` or from a JIT-compiled function, you can inspect its sharding properties:

* **`array.sharding`**: This attribute directly gives you the `Sharding` object (e.g., `NamedSharding`) associated with the array. This describes the *concrete, physical layout* of the array across devices. It includes the mesh and the `PartitionSpec`. This information is always available for sharded arrays.
    * `array.sharding.mesh`: The `Mesh` object used.
    * `array.sharding.spec`: The `PartitionSpec` used.
* **`jax.typeof(array)`**: This function returns the JAX-level type of the array.
    * If the array was sharded using `NamedSharding` where the referenced mesh axes are `AxisType.Explicit`, the returned type string will include sharding annotations like `float32[dim0@mesh_axisA, dim1@mesh_axisB]`. This is "sharding in types."
    * If the array was sharded using `NamedSharding` but the referenced mesh axes are `AxisType.Auto` (the default if `axis_types` is omitted), `jax.typeof(array)` will show the shape and dtype but will *not* include the `@mesh_axisA` annotations for those `Auto` axes. The sharding is concrete but not part of the JAX-level static type for those axes.
* **`array.devices()`**: Returns the set of devices on which this array has data.
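Beyond these attributes, `array.addressable_shards` exposes the individual pieces held by the local devices. The sketch below walks over the shards of a small array and checks each one against the matching slice of the host data. The mesh and variable names (`demo_mesh`, `host`, `x`) are illustrative assumptions and do not refer to the arrays created earlier in this module.

```python
import os
# Only takes effect if JAX has not initialized its backend yet in this process.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
demo_mesh = Mesh(devices.reshape(n, 1), ("dp", "mp"))  # illustrative mesh

host = np.arange(2 * n * 4, dtype=np.float32).reshape(2 * n, 4)
x = jax.device_put(host, NamedSharding(demo_mesh, P("dp", "mp")))

for shard in x.addressable_shards:
    # shard.index is a tuple of slices locating this piece in the global array.
    matches = np.array_equal(np.asarray(shard.data), host[shard.index])
    print(shard.device, shard.index, shard.data.shape, matches)

print(len(x.devices()))
```

Each shard reports the device that holds it, where it sits in the global array, and its local data, which is often the quickest way to confirm that a layout is what you intended.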

Let's see this with the arrays we created.

In [7]:
print(f"--- Introspection of 'arr_sharded_explicit' (Mesh with Explicit Axes) ---")
print(f"1. Concrete Sharding object: arr_sharded_explicit.sharding")
print(f"   Type: {type(arr_sharded_explicit.sharding)}")
print(f"   Mesh used: {arr_sharded_explicit.sharding.mesh.axis_names}")
print(f"   PartitionSpec used: {arr_sharded_explicit.sharding.spec}")

print(f"\n2. JAX-level type (sharding in type): jax.typeof(arr_sharded_explicit)")
type_explicit = jax.typeof(arr_sharded_explicit)
print(f"   Type string: {type_explicit}")
# For Explicit axes, type_explicit.sharding should also reflect the PartitionSpec
if hasattr(type_explicit, 'sharding'):
    print(f"   jax.typeof().sharding.spec: {type_explicit.sharding.spec}")
else:
    print(f"   jax.typeof().sharding attribute not found (unexpected for explicit sharding).")


print(f"\n3. Devices hosting data: arr_sharded_explicit.devices()")
print(f"   {arr_sharded_explicit.devices()}")


print(f"\n--- Introspection of 'arr_sharded_auto' (Mesh with Auto Axes) ---")
print(f"1. Concrete Sharding object: arr_sharded_auto.sharding")
print(f"   Type: {type(arr_sharded_auto.sharding)}")
print(f"   Mesh used: {arr_sharded_auto.sharding.mesh.axis_names}")
print(f"   PartitionSpec used: {arr_sharded_auto.sharding.spec}")

print(f"\n2. JAX-level type: jax.typeof(arr_sharded_auto)")
type_auto = jax.typeof(arr_sharded_auto)
print(f"   Type string: {type_auto}") # Expect no @dp, @mp annotations here
# For Auto axes, type_auto.sharding might be trivial or None for those axes
if hasattr(type_auto, 'sharding'):
    print(f"   jax.typeof().sharding.spec: {type_auto.sharding.spec}") # Likely P(None,None) or similar
else:
    print(f"   jax.typeof().sharding attribute not found.")


print(f"\n3. Devices hosting data: arr_sharded_auto.devices()")
print(f"   {arr_sharded_auto.devices()}")

--- Introspection of 'arr_sharded_explicit' (Mesh with Explicit Axes) ---
1. Concrete Sharding object: arr_sharded_explicit.sharding
   Type: <class 'jaxlib.xla_extension.NamedSharding'>
   Mesh used: ('dp', 'mp')
   PartitionSpec used: PartitionSpec('dp', 'mp')

2. JAX-level type (sharding in type): jax.typeof(arr_sharded_explicit)
   Type string: ShapedArray(float32[8@dp,4@mp])
   jax.typeof().sharding.spec: PartitionSpec('dp', 'mp')

3. Devices hosting data: arr_sharded_explicit.devices()
   {CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=1), CpuDevice(id=4), CpuDevice(id=2), CpuDevice(id=7), CpuDevice(id=3), CpuDevice(id=0)}

--- Introspection of 'arr_sharded_auto' (Mesh with Auto Axes) ---
1. Concrete Sharding object: arr_sharded_auto.sharding
   Type: <class 'jaxlib.xla_extension.NamedSharding'>
   Mesh used: ('dp', 'mp')
   PartitionSpec used: PartitionSpec('dp', 'mp')

2. JAX-level type: jax.typeof(arr_sharded_auto)
   Type string: ShapedArray(float32[8,4])
   jax.typeof().shar

The difference in `jax.typeof()` output when using `Explicit` vs. `Auto` mesh axes is the core of "sharding in types." With `Explicit` axes, the sharding becomes a static, checkable part of the type system that JAX uses at trace time.

### 1.E: Visualizing Sharding

For a more intuitive view of how your data is physically laid out across the devices in your mesh, use `jax.debug.visualize_array_sharding()`. It prints a text rendering of the array divided into blocks, each labeled with the device or devices that hold it.

In [8]:

print(f"--- Visualizing Sharding for 'arr_sharded_explicit' ---")
print(f"Data shape: {arr_sharded_explicit.shape}")
print(f"Mesh shape: {arr_sharded_explicit.sharding.mesh.shape}, Axis names: {arr_sharded_explicit.sharding.mesh.axis_names}")
print(f"PartitionSpec: {arr_sharded_explicit.sharding.spec}\n")
visualize_array_sharding(arr_sharded_explicit)

print(f"\n--- Visualizing Sharding for 'arr_sharded_auto' ---")
# Visualization will look identical to arr_sharded_explicit if sharding spec and mesh shape are the same,
# as it shows the CONCRETE layout.
print(f"Data shape: {arr_sharded_auto.shape}")
print(f"Mesh shape: {arr_sharded_auto.sharding.mesh.shape}, Axis names: {arr_sharded_auto.sharding.mesh.axis_names}")
print(f"PartitionSpec: {arr_sharded_auto.sharding.spec}\n")
visualize_array_sharding(arr_sharded_auto)

spec_replicated_on_explicit_mesh = P(None, None)
named_sharding_repl_explicit = NamedSharding(aurora_explicit_mesh, spec_replicated_on_explicit_mesh)
arr_replicated_explicit = jax.device_put(data_np, named_sharding_repl_explicit)

print(f"\n--- Visualizing Replicated Array on Explicit Mesh ---")
print(f"Data shape: {arr_replicated_explicit.shape}")
print(f"Mesh shape: {aurora_explicit_mesh.shape}, Axis names: {aurora_explicit_mesh.axis_names}")
print(f"PartitionSpec: {spec_replicated_on_explicit_mesh}\n")
visualize_array_sharding(arr_replicated_explicit)
print(f"jax.typeof(arr_replicated_explicit): {jax.typeof(arr_replicated_explicit)}") # Will show no @ annotations

--- Visualizing Sharding for 'arr_sharded_explicit' ---
Data shape: (8, 4)
Mesh shape: OrderedDict([('dp', 4), ('mp', 2)]), Axis names: ('dp', 'mp')
PartitionSpec: PartitionSpec('dp', 'mp')




--- Visualizing Sharding for 'arr_sharded_auto' ---
Data shape: (8, 4)
Mesh shape: OrderedDict([('dp', 4), ('mp', 2)]), Axis names: ('dp', 'mp')
PartitionSpec: PartitionSpec('dp', 'mp')




--- Visualizing Replicated Array on Explicit Mesh ---
Data shape: (8, 4)
Mesh shape: OrderedDict([('dp', 4), ('mp', 2)]), Axis names: ('dp', 'mp')
PartitionSpec: PartitionSpec(None, None)



jax.typeof(arr_replicated_explicit): ShapedArray(float32[8,4])


This visualization helps confirm that your `PartitionSpec` and `NamedSharding` are achieving the physical data layout you intended for Project Aurora's distributed arrays.
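A layout can also be checked before any data is placed on devices. The sketch below uses the sharding's `devices_indices_map` method to list which rows of a given global shape each device would hold. It is a plain-text complement to the visualization, under the assumption of an illustrative 1D-style mesh (`demo_mesh`) built from the available devices.

```python
import os
# Only takes effect if JAX has not initialized its backend yet in this process.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
demo_mesh = Mesh(devices.reshape(n, 1), ("dp", "mp"))  # illustrative mesh

sharding = NamedSharding(demo_mesh, P("dp", None))
global_shape = (2 * n, 4)

# Maps each device to the tuple of slices it would own for this global shape.
layout = sharding.devices_indices_map(global_shape)
for device, index in sorted(layout.items(), key=lambda item: item[0].id):
    # slice.indices() resolves open-ended slices to concrete bounds.
    start, stop, _ = index[0].indices(global_shape[0])
    print(f"device {device.id}: rows {start}:{stop}")
```

Since no array is created, this is a cheap way to test a `PartitionSpec` against the shapes you plan to use.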

## Part 2: Sharding Propagation and Compiler Interaction

Now that we can create explicitly sharded arrays where sharding is part of the JAX-level type (when using `Mesh` with `AxisType.Explicit`), we need to understand how these shardings propagate through JAX operations and how the JAX Just-In-Time (JIT) compiler, specifically its GSPMD (General SPMD) partitioner, interacts with them.

### 2.A: Explicit Sharding Propagation at JAX-Level (Trace Time)

When using a `Mesh` with `AxisType.Explicit` axes, the sharding of `jax.Array`s (as seen by `jax.typeof()`) is considered during JAX's tracing phase (e.g., when a function is JIT-compiled for the first time). JAX operations have defined "sharding rules" that determine the sharding of their output(s) based on the type-level shardings of their input(s).

**Key Principles:**
* **Deterministic:** This propagation is designed to be deterministic.
* **Trace Time:** It happens at the JAX level during tracing, *before* XLA compilation.
* **Queryable:** Because it's part of the type, you can `print(jax.typeof(intermediate_array))` inside a JITted function to see how shardings evolve.
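The trace-time principle can be observed without any sharding. The sketch below records a Python side effect inside a jitted function, which is the same mechanism that makes a `print(jax.typeof(...))` inside `jit` appear only while tracing. The function `scaled_sin` and the list `trace_log` are illustrative names, not part of this module.

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def scaled_sin(x):
    # Python side effects such as this append run only while JAX traces the function.
    trace_log.append((x.shape, str(x.dtype)))
    return 2.0 * jnp.sin(x)

scaled_sin(jnp.ones(8, dtype=jnp.float32))   # first call: traced and compiled
scaled_sin(jnp.zeros(8, dtype=jnp.float32))  # same shape and dtype: cached, no retrace
scaled_sin(jnp.ones(4, dtype=jnp.float32))   # new shape: traced again

print(trace_log)
```

The log ends up with two entries for three calls, so anything printed inside a jitted function describes the traced types, not each runtime invocation.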

**General Sharding Rules for Common Operations:**

1.  **Nullary Ops (e.g., `jnp.zeros`, `jnp.arange`):**
    * These create new arrays without sharded inputs to propagate from.
    * By default, their output is **unsharded** (replicated across the entire mesh if used in a sharded context, or placed on the default device).
    * You can override this using an `out_sharding` argument specific to the operation (if available/implemented) or by applying `jax.device_put` with a `NamedSharding` to the result.

2.  **Unary Elementwise Ops (e.g., `jnp.sin`, `jnp.exp`, `jnp.log`):**
    * The output array typically **inherits the exact sharding** of the input array.
    * Example: If `x` is `f32[N@dp]`, then `jnp.sin(x)` will also be `f32[N@dp]`.

3.  **Binary Elementwise Ops (e.g., `jnp.add`, `jnp.multiply`):**
    * Input arrays must have compatible shapes for broadcasting.
    * Their shardings must also be compatible. For dimensions being "zipped" together by the elementwise operation:
        * If one input's dimension is sharded by a mesh axis (e.g., `'dp'`) and the other input's corresponding dimension is also sharded by the *same* mesh axis (`'dp'`), the output dimension will also be sharded by `'dp'`.
        * If one is sharded by `'dp'` and the other is replicated (`None`) for that dimension, the output dimension is sharded by `'dp'` (replication broadcasts).
        * If one is sharded by `'dp'` and the other by a *different* mesh axis (e.g., `'mp'`) for corresponding tensor dimensions that are being zipped, this is usually an error, as it implies an ambiguous or illegal output sharding (e.g., trying to shard a single output dimension by both `'dp'` and `'mp'` in a conflicting way).
        * If the operation results in an output sharding that uses the same mesh axis multiple times for different dimensions of the output tensor (e.g., `P('dp', 'dp')`), it's an error.
    * "Outer product" dimensions (those that only appear in one argument due to broadcasting) typically retain their input sharding.

4.  **`reshape`:**
    * `reshape` is complex. An output axis can map to multiple input axes (merging) or part of an input axis (splitting).
    * Simple cases (e.g., splitting/merging axes that are fully replicated) might propagate replication.
    * In many other cases, especially involving sharded axes being split or merged, JAX's sharding rule for `reshape` will raise a trace-time error, requiring the user to provide an explicit `out_sharding` argument to `reshape` to disambiguate.
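The binary elementwise rule can be modeled in a few lines of plain Python. The function `merge_elementwise_specs` below is a hypothetical, simplified sketch for two inputs of equal rank whose spec entries are either `None` or a single mesh axis name. It is not JAX's implementation and ignores broadcasting and tuple entries.

```python
def merge_elementwise_specs(spec_a, spec_b):
    """Simplified model: output spec of an elementwise op on two equal-rank inputs."""
    if len(spec_a) != len(spec_b):
        raise ValueError("this sketch only handles inputs of equal rank")
    out = []
    for a, b in zip(spec_a, spec_b):
        if a is None:
            out.append(b)          # replicated vs. anything: take the other side
        elif b is None or a == b:
            out.append(a)          # same axis, or sharded vs. replicated
        else:
            raise ValueError(f"conflicting shardings for one dimension: {a!r} vs {b!r}")
    named = [axis for axis in out if axis is not None]
    if len(named) != len(set(named)):
        raise ValueError(f"mesh axis used for more than one output dimension: {tuple(out)}")
    return tuple(out)

print(merge_elementwise_specs(("dp", "mp"), ("dp", None)))  # ('dp', 'mp')
print(merge_elementwise_specs((None, None), (None, "mp")))  # (None, 'mp')

for bad_a, bad_b in [(("dp", None), ("mp", None)), (("dp", None), (None, "dp"))]:
    try:
        merge_elementwise_specs(bad_a, bad_b)
    except ValueError as err:
        print("error:", err)
```

The two failing pairs correspond to the error cases in the list above: different mesh axes on the same dimension, and a result that would reuse `'dp'` for two output dimensions.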

Let's see some examples using our `aurora_explicit_mesh`.

In [9]:
print(f"--- Sharding Propagation Examples (using Explicit Mesh: {aurora_explicit_mesh.axis_names}) ---")

# Setup some sharded arrays on the explicit mesh
# Array A: shape (M, K), sharded (dp, mp)
# Array B: shape (K, N), sharded (mp, None) -> K sharded by 'mp', N replicated
# Array C: shape (M, N), sharded (dp, None) -> M sharded by 'dp', N replicated

# Let M=mesh.shape['dp']*2, K=mesh.shape['mp']*2, N=5 (arbitrary)
# Example: mesh=(4,2) -> M=8, K=4, N=5
M = aurora_explicit_mesh.shape['dp'] * 2
K = aurora_explicit_mesh.shape['mp'] * 2
N_val = 5

data_A_np = np.arange(M * K, dtype=jnp.float32).reshape(M, K)
data_B_np = np.arange(K * N_val, dtype=jnp.float32).reshape(K, N_val)
data_C_np = np.arange(M * N_val, dtype=jnp.float32).reshape(M, N_val)

sharding_A = NamedSharding(aurora_explicit_mesh, P('dp', 'mp'))
sharding_B = NamedSharding(aurora_explicit_mesh, P('mp', None)) # K sharded on 'mp', N replicated
sharding_C = NamedSharding(aurora_explicit_mesh, P('dp', None)) # M sharded on 'dp', N replicated

arr_A_explicit = jax.device_put(data_A_np, sharding_A)
arr_B_explicit = jax.device_put(data_B_np, sharding_B)
arr_C_explicit = jax.device_put(data_C_np, sharding_C) # For binary op example

print(f"jax.typeof(arr_A_explicit): {jax.typeof(arr_A_explicit)}")
print(f"jax.typeof(arr_B_explicit): {jax.typeof(arr_B_explicit)}")
print(f"jax.typeof(arr_C_explicit): {jax.typeof(arr_C_explicit)}")

@jax.jit
def demonstrate_propagation(a, b, c):
  print(f"\n--- Inside JIT ---")
  print(f"Input a: {jax.typeof(a)}") # (M@dp, K@mp)
  print(f"Input b: {jax.typeof(b)}") # (K@mp, N)
  print(f"Input c: {jax.typeof(c)}") # (M@dp, N)

  # 1. Unary op
  unary_out = jnp.sin(a)
  print(f"sin(a): {jax.typeof(unary_out)}") # Should be (M@dp, K@mp)

  # 2. Binary op with identically sharded operands.
  # a (M@dp, K@mp) and c (M@dp, N) cannot be added element-wise because K != N,
  # so use c + c: both operands are (M@dp, N) and the output keeps that sharding.
  binary_out_cc = c + c
  print(f"c + c: {jax.typeof(binary_out_cc)}") # Should be (M@dp, N)

  # 3. Nullary op
  zeros_out = jnp.zeros_like(a) # zeros_like follows the sharding of its argument (see the recorded output)
  print(f"jnp.zeros_like(a): {jax.typeof(zeros_out)}")

  zeros_new = jnp.zeros((M,K), dtype=a.dtype) # Created without sharded input context
  print(f"jnp.zeros((M,K)): {jax.typeof(zeros_new)}") # Expected to be unsharded (replicated on all mesh devices)

  return unary_out, binary_out_cc, zeros_out, zeros_new

# Calling the jitted function triggers tracing, so the prints inside it show the traced types
_ = demonstrate_propagation(arr_A_explicit, arr_B_explicit, arr_C_explicit)

--- Sharding Propagation Examples (using Explicit Mesh: ('dp', 'mp')) ---
jax.typeof(arr_A_explicit): ShapedArray(float32[8@dp,4@mp])
jax.typeof(arr_B_explicit): ShapedArray(float32[4@mp,5])
jax.typeof(arr_C_explicit): ShapedArray(float32[8@dp,5])

--- Inside JIT ---
Input a: ShapedArray(float32[8@dp,4@mp])
Input b: ShapedArray(float32[4@mp,5])
Input c: ShapedArray(float32[8@dp,5])
sin(a): ShapedArray(float32[8@dp,4@mp])
c + c: ShapedArray(float32[8@dp,5])
jnp.zeros_like(a): ShapedArray(float32[8@dp,4@mp])
jnp.zeros((M,K)): ShapedArray(float32[8,4])


### 2.B: `out_sharding` for Ambiguity Resolution

Sometimes the sharding propagation rules for a JAX operation are ambiguous, or the default propagated sharding is not what you want for subsequent operations. For these cases, many JAX operations accept an `out_sharding` argument.

* **Purpose**: `out_sharding` allows you to explicitly specify the desired `NamedSharding` (or sometimes just a `PartitionSpec` if a mesh is contextually clear) for the output of an operation.
* **When to Use**:
    * **Ambiguity**: If JAX cannot determine a unique valid output sharding (e.g., complex `reshape` operations).
    * **Override Defaults**: If the default propagated sharding is correct but suboptimal for the next steps in your computation, leading to later resharding. Specifying `out_sharding` can preemptively set the desired layout.
    * **Operations without Clear Input Propagation**: For operations like `jnp.einsum` or complex custom operations, specifying output sharding is often necessary.

If `out_sharding` is provided, JAX will attempt to produce the output with that sharding. If this requires data movement (resharding from what would have been the default propagated sharding), XLA/GSPMD will insert the necessary collective communication operations during compilation.

Which operations accept `out_sharding` depends on your JAX version, so check the signature of the function you are calling. Where the argument is not available, you can achieve a similar effect by applying `jax.device_put(..., new_sharding)` or `reshard(..., new_spec)` immediately after the operation.
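
To see why an operation can need `out_sharding`, consider a matrix product whose contraction dimension is sharded on both inputs. The plain-Python sketch below is not JAX code: nested lists stand in for device shards, and the `matmul` helper and the two-device split are assumptions made for illustration. It shows that each device ends up with only a partial sum, so the results must still be combined across devices.

```python
def matmul(a, b):
    """Dense matmul on nested lists: (m, k) @ (k, n) -> (m, n)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

A = [[1, 2, 3, 4],
     [5, 6, 7, 8]]                    # shape (2, 4)
B = [[1, 0], [0, 1], [1, 1], [2, 2]]  # shape (4, 2)

# Shard the contraction dimension K=4 across two hypothetical devices.
A_shards = [[row[0:2] for row in A], [row[2:4] for row in A]]
B_shards = [B[0:2], B[2:4]]

# Each device multiplies its local shards and obtains a partial result.
partials = [matmul(a, b) for a, b in zip(A_shards, B_shards)]

# Summing the partials across devices (an all-reduce) gives a replicated
# output; a reduce-scatter would instead leave each device with one slice.
reduced = [[sum(p[i][j] for p in partials) for j in range(2)] for i in range(2)]

print(partials)
print(reduced == matmul(A, B))
```

Whether the combined result should end up replicated or sharded is the kind of choice that an explicit `out_sharding` communicates.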

In [10]:
# Example: Reshape where output sharding might be ambiguous or need override
# Imagine data_R_np (4, 2, 3) sharded P('dp', 'mp', None) on a mesh (4,2)
# Reshaping to (4, 6) -> P('dp', 'mp') might be desired.

data_R_shape = (aurora_explicit_mesh.shape['dp'], aurora_explicit_mesh.shape['mp'], 3)
data_R_np = np.arange(np.prod(data_R_shape), dtype=jnp.float32).reshape(data_R_shape)

sharding_R_in = NamedSharding(aurora_explicit_mesh, P('dp', 'mp', None))
arr_R_in = jax.device_put(data_R_np, sharding_R_in)
print(f"--- out_sharding Example (Conceptual) ---")
print(f"Initial arr_R_in type: {jax.typeof(arr_R_in)}")

# Desired output sharding after reshape to (dim0, dim1*dim2) -> (mesh_dp_size, mesh_mp_size * 3)
# We want this new 2D array to be sharded P('dp', 'mp')
# New shape will be (aurora_explicit_mesh.shape[0], aurora_explicit_mesh.shape[1] * 3)
new_shape_R = (data_R_shape[0], data_R_shape[1] * data_R_shape[2])
desired_sharding_R_out = NamedSharding(aurora_explicit_mesh, P('dp', 'mp'))

# This cell does not pass out_sharding to reshape directly. Instead it
# reshapes first and then forces the desired sharding on the result.

@partial(jax.jit, static_argnums=(1, 2)) # the target shape and the sharding object are static arguments
def reshape_with_explicit_out_sharding(x, new_shape_static, out_spec):
  # out_spec is a NamedSharding; it is hashable, so it can be passed as a static argument.
  # Sharding rules for reshape are subtle, so first observe what propagates by default.
  y_reshaped_default = jnp.reshape(x, new_shape_static)
  print(f"  Inside JIT - y_reshaped_default type: {jax.typeof(y_reshaped_default)}") # Observe default propagation

  # Force the desired sharding using jax.device_put (reshard is an alternative).
  y_resharded_explicitly = jax.device_put(y_reshaped_default, out_spec)
  print(f"  Inside JIT - y_resharded_explicitly type: {jax.typeof(y_resharded_explicitly)}")
  return y_resharded_explicitly

try:
  arr_R_out = reshape_with_explicit_out_sharding(arr_R_in, new_shape_R, desired_sharding_R_out)
  print(f"Output arr_R_out type: {jax.typeof(arr_R_out)}")
  print(f"Output arr_R_out.sharding: {arr_R_out.sharding}")
  visualize_array_sharding(arr_R_out)
except Exception as e:
  print(f"Error in reshape_with_explicit_out_sharding: {e}")
  print("  This highlights that complex reshapes with sharding often require careful handling.")

--- out_sharding Example (Conceptual) ---
Initial arr_R_in type: ShapedArray(float32[4@dp,2@mp,3])
  Inside JIT - y_reshaped_default type: ShapedArray(float32[4@dp,6@mp])
  Inside JIT - y_resharded_explicitly type: ShapedArray(float32[4@dp,6@mp])
Output arr_R_out type: ShapedArray(float32[4@dp,6@mp])
Output arr_R_out.sharding: NamedSharding(mesh=Mesh('dp': 4, 'mp': 2, axis_types=(Explicit, Explicit)), spec=PartitionSpec('dp', 'mp'), memory_kind=unpinned_host)


### 2.C: Interaction with JAX Transformations

Explicit sharding (via "sharding in types" with `AxisType.Explicit` meshes) interacts predictably with JAX's core transformations:

* **`jax.jit`**: This is the primary context where explicit sharding shines. `jit` uses the type-level sharding information to guide the XLA/GSPMD compiler, as discussed in the next section.
* **`jax.grad` (and other auto-diff transformations)**:
    * The backward pass generally tries to mirror the sharding of the forward pass. If activations are sharded `P('dp', None)` (data parallel), gradients w.r.t. those activations will also tend to be `P('dp', None)`.
    * Gradients w.r.t. weights have the same sharding as the weights themselves. If weights are replicated `P(None, None)` while the batch is sharded along `'dp'`, their gradients are also replicated, which requires an all-reduce over `'dp'` (inserted by the compiler under `jit`). If weights are sharded (e.g., for model parallelism), their gradients are sharded the same way.
* **`jax.vmap`**:
    * `vmap` primarily deals with adding a batch dimension. If the inputs to `vmap` are already sharded, `vmap` typically preserves the sharding of the non-batched dimensions and the new batch dimension is often replicated by default unless specific sharding is applied to the output of `vmap`.
    * If you `vmap` a function that itself contains sharded operations, the interaction becomes more complex. Generally, `vmap` happens "outside" the sharding concerns of the mapped function's body, but the sharding of the inputs *to* the `vmap`ped function matters.
* **`jax.lax.scan` (and other control flow)**:
    * The sharding of loop-carried state (`carry`) and per-iteration inputs/outputs (`xs`/`ys`) is maintained across iterations.
    * The body function of `scan` is traced and compiled with the (potentially sharded) types of the initial carry and the first slice of `xs`.
    * It's crucial that sharding specifications are consistent across loop iterations.

**General Guideline:** When "sharding in types" is active, JAX aims to propagate these shardings through transformations. If ambiguity arises, an error will usually be thrown at trace time, prompting you to be more explicit (e.g., using `out_sharding` if available, or `reshard`/`device_put`).
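
The remark above about replicated weights under `jax.grad` rests on a simple identity: with equally sized batch shards, averaging the per-shard gradients reproduces the full-batch gradient, which is why an all-reduce over the data-parallel axis is needed. A minimal plain-Python check follows, assuming a hypothetical one-parameter model `y = w * x` with a mean-squared-error loss and a four-way batch split (none of this is JAX code).

```python
import math

def grad_mse(w, xs, ts):
    """d/dw of mean((w*x - t)^2) over the given examples."""
    return sum(2.0 * (w * x - t) * x for x, t in zip(xs, ts)) / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ts = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]

n_devices = 4                                  # stands in for the 'dp' axis size
per_dev = len(xs) // n_devices
local_grads = [
    grad_mse(w, xs[d * per_dev:(d + 1) * per_dev], ts[d * per_dev:(d + 1) * per_dev])
    for d in range(n_devices)
]

# Averaging across devices plays the role of the all-reduce (a sum divided by n_devices).
all_reduced = sum(local_grads) / n_devices
full_batch = grad_mse(w, xs, ts)
print(local_grads, all_reduced, full_batch)
```

Each device's local gradient differs, but after the reduction every device holds the same value, matching the replicated sharding of the weights.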

### 2.D: GSPMD Compiler Integration & Automatic Resharding

The real power of explicit JAX-level sharding (from `AxisType.Explicit` meshes) comes from its integration with the **GSPMD** (General and Scalable Parallelization for ML Computation Graphs) partitioner in the XLA compiler, which `jax.jit` invokes.

**How it Works (Conceptual):**

1.  **JAX Tracing & HLO Generation**: When `jax.jit` traces your Python function with explicitly sharded input types, this sharding information is embedded into the JAX intermediate representation (jaxpr). This jaxpr is then lowered to XLA's HLO (High Level Operations) representation, and the sharding annotations are passed to XLA along with it.
2.  **XLA/GSPMD Partitioning**:
    * GSPMD takes the HLO computation and the sharding annotations for inputs (and potentially outputs, if `out_sharding` was used or propagated).
    * It attempts to honor these explicit shardings. For each operation in the HLO graph, GSPMD has rules to determine the sharding of the operation's output based on its inputs' shardings.
    * **Automatic Resharding (Collectives Insertion)**: If an operation requires its inputs to have a different sharding than what they currently possess (according to the propagated types), GSPMD will automatically insert collective communication operations (like `all-gather`, `all-to-all`, `reduce-scatter`, `all-reduce`) to reshard the data.
        * **Example**: If `op(x, y)` requires `x` to be replicated but `x` is currently sharded along `'dp'`, GSPMD might insert an `all-gather` to replicate `x` before feeding it to `op`.
    * The goal is to produce an XLA HLO program that is partitioned for SPMD execution, where each device runs the same program but on its shard of the data.
3.  **Performance Costs of Automatic Resharding**:
    * While incredibly convenient, automatic resharding by GSPMD is not free. Collective operations involve communication between devices, which can be a significant performance bottleneck.
    * Understanding *when and why* GSPMD inserts collectives is crucial for optimizing distributed performance. This often involves:
        * Carefully designing your `PartitionSpec`s to minimize sharding mismatches between operations.
        * Aligning sharding of contraction dimensions in matmuls.
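        * Inspecting the result: `jax.debug.visualize_array_sharding` shows how an array is laid out across devices, while the compiled HLO shows where collectives were inserted. You can dump it by setting the `JAX_DUMP_IR_TO` environment variable to a directory (e.g., `/tmp/my_jit_hlo`) before launching Python, or print it with `jax.jit(f).lower(...).compile().as_text()`.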
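
The resharding example in step 2 (a sharded `x` that an operation needs replicated) can be mimicked without any devices. The sketch below is a plain-Python stand-in with a hypothetical `all_gather` helper over lists. It counts how many elements each device must receive, which is the communication cost paid by the compiler-inserted collective.

```python
def all_gather(shards):
    """Give every device the concatenation of all shards."""
    full = [v for shard in shards for v in shard]
    return [list(full) for _ in shards]

n_devices = 4                                   # stands in for the 'dp' axis size
x = list(range(8))
per_dev = len(x) // n_devices
shards = [x[d * per_dev:(d + 1) * per_dev] for d in range(n_devices)]

replicated = all_gather(shards)

# Each device already owns one shard and must receive the other n-1 shards.
received_per_device = (n_devices - 1) * per_dev
print(shards)
print(replicated[0], received_per_device)
```

The received volume grows with both the array size and the number of devices on the axis, which is why avoiding unnecessary sharding mismatches matters.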

**Explicit is Preferred**: Relying on "sharding in types" gives you the most direct control and predictability over how GSPMD will partition your program. While GSPMD can also work with `AxisType.Auto` meshes (where JAX doesn't carry sharding in the type and GSPMD infers it), making shardings explicit at the JAX level is generally recommended for complex models.

### 2.E: `jax.lax.with_sharding_constraint`

`jax.lax.with_sharding_constraint(x, sharding)` is a utility that acts as an identity function at runtime but provides a hint to the XLA/GSPMD compiler.

* **Purpose**: It tells GSPMD that the array `x` *should* have the specified `sharding` (a `NamedSharding` object) at that particular point in the computation.
* **How it Works**:
    * If `x` already has the `sharding` specified, `with_sharding_constraint` is effectively a no-op.
    * If `x` has a different sharding, GSPMD will attempt to reshard `x` to match the `sharding` constraint by inserting necessary collectives.
* **Use Cases**:
    * **Fine-tuning Compiler Choices**: When you are not using fully "sharding in types" (e.g., using `AxisType.Auto` meshes, or parts of your code where sharding isn't explicit in the type), `with_sharding_constraint` can guide GSPMD's sharding decisions for intermediate tensors.
    * **Forcing a Layout**: You might use it to ensure a specific layout before an operation that is sensitive to input sharding, even if type propagation might have led to something else.
    * **Influencing Fusion**: Forcing a resharding step at a chosen point can be used to break up an undesirable kernel fusion, although this is a side effect of the constraint rather than a guarantee.
* **Not a Replacement for Explicit Types**: It's more of a compiler hint than a way to define the primary sharding of an array like `jax.device_put` does. For core array sharding, prefer `NamedSharding` with `AxisType.Explicit` meshes and `device_put`.
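* **Applies only to `AxisType.Auto` axes**: `jax.lax.with_sharding_constraint` constrains only the auto axes of a mesh. For `Explicit` axes, change the sharding with `reshard` or `jax.device_put` instead. The cell below places its arrays on a mesh whose axes are all `Explicit`, and its recorded output shows their sharding unchanged by the constraint.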

It's an identity function to JAX's tracing and differentiation, so it doesn't change the numerical outcome, only potentially the compiled program and its performance.

In [11]:
# Host (NumPy) data for the example
data_wsc_np = np.arange(16, dtype=jnp.float32).reshape(4, 4)
# Place it on the explicit mesh, sharded P('dp', None)
initial_sharding_wsc = NamedSharding(aurora_explicit_mesh, P('dp', None))
arr_wsc = jax.device_put(data_wsc_np, initial_sharding_wsc)

print(f"--- with_sharding_constraint Example ---")
print(f"Initial arr_wsc type: {jax.typeof(arr_wsc)}")
print(f"Initial arr_wsc.sharding: {arr_wsc.sharding}")

# Define a desired sharding constraint, e.g., P(None, 'mp')
# This would mean resharding from ('dp', None) to (None, 'mp')
constrained_sharding_spec = P(None, 'mp')
# jax.lax.with_sharding_constraint can only be applied to auto axes of the mesh
constraint = NamedSharding(aurora_auto_mesh, constrained_sharding_spec)

@partial(jax.jit, static_argnums=(1,))
def apply_constraint(x, target_sharding):
    # Apply the constraint
    y = jax.lax.with_sharding_constraint(x, target_sharding)
    # Note: jax.typeof(y) might not reflect the constraint if the original mesh axes were Auto.
    # However, y.sharding (concrete sharding) should reflect the constraint if XLA honored it.
    # And if x was on an Explicit mesh, jax.typeof(y) should reflect the new sharding.
    return y

# Run the JITted function
try:
    arr_constrained = apply_constraint(arr_wsc, constraint)
    # The .sharding attribute of the output will show the new sharding if XLA resharded.
    # jax.typeof() will reflect it if the mesh uses Explicit axes.
    print(f"\nOutput arr_constrained type: {jax.typeof(arr_constrained)}")
    print(f"Output arr_constrained.sharding: {arr_constrained.sharding}") # This shows the concrete sharding
    visualize_array_sharding(arr_constrained)
except Exception as e:
    print(f"Error during with_sharding_constraint example: {e}")
    print("  Ensure that the mesh shapes and array dimensions are compatible for the constraint.")
    print(f"  Array shape: {data_wsc_np.shape}, Mesh shape: {aurora_explicit_mesh.shape}")
    print(f"  Initial spec: {initial_sharding_wsc.spec}, Constraint spec: {constrained_sharding_spec}")


# Example: Constraining a replicated array to be sharded
data_repl_np = np.arange(16, dtype=jnp.float32).reshape(4,4)
repl_sharding = NamedSharding(aurora_explicit_mesh, P(None,None)) # Replicated
arr_repl = jax.device_put(data_repl_np, repl_sharding)

target_sharded_spec = P('dp','mp') # Target sharded P('dp','mp')
target_sharding = NamedSharding(aurora_auto_mesh, target_sharded_spec)

print(f"\n--- Constraining a Replicated array ---")
print(f"Initial arr_repl type: {jax.typeof(arr_repl)}")
print(f"Initial arr_repl.sharding: {arr_repl.sharding}")

try:
    arr_constrained_from_repl = apply_constraint(arr_repl, target_sharding)
    print(f"\nOutput arr_constrained_from_repl type: {jax.typeof(arr_constrained_from_repl)}")
    print(f"Output arr_constrained_from_repl.sharding: {arr_constrained_from_repl.sharding}")
    visualize_array_sharding(arr_constrained_from_repl)
except Exception as e:
    print(f"Error constraining replicated array: {e}")

--- with_sharding_constraint Example ---
Initial arr_wsc type: ShapedArray(float32[4@dp,4])
Initial arr_wsc.sharding: NamedSharding(mesh=Mesh('dp': 4, 'mp': 2, axis_types=(Explicit, Explicit)), spec=PartitionSpec('dp', None), memory_kind=unpinned_host)

Output arr_constrained type: ShapedArray(float32[4@dp,4])
Output arr_constrained.sharding: NamedSharding(mesh=Mesh('dp': 4, 'mp': 2, axis_types=(Explicit, Explicit)), spec=PartitionSpec('dp', None), memory_kind=unpinned_host)



--- Constraining a Replicated array ---
Initial arr_repl type: ShapedArray(float32[4,4])
Initial arr_repl.sharding: NamedSharding(mesh=Mesh('dp': 4, 'mp': 2, axis_types=(Explicit, Explicit)), spec=PartitionSpec(None, None), memory_kind=unpinned_host)

Output arr_constrained_from_repl type: ShapedArray(float32[4,4])
Output arr_constrained_from_repl.sharding: NamedSharding(mesh=Mesh('dp': 4, 'mp': 2, axis_types=(Explicit, Explicit)), spec=PartitionSpec(None, None), memory_kind=unpinned_host)


## Part 3: `shard_map` - Explicit Per-Device Programming (SPMD)

While `jax.jit` with global array views and explicit `NamedSharding` (or compiler-driven sharding with `AxisType.Auto` meshes) offers a powerful way to write distributed programs, JAX also provides `jax.shard_map`. This is a lower-level utility that allows you to write code from the perspective of a single device/shard operating on its local piece of data, following the SPMD (Single Program, Multiple Data) model more directly.

### 3.A: `shard_map` vs. `jit` with Global View

| Feature             | `jax.jit` with Global View (using `NamedSharding`) | `jax.shard_map`                                      |
|---------------------|----------------------------------------------------|------------------------------------------------------|
| **Programming Model** | Global program, logical arrays. JAX/XLA handles sharding. | Per-device program logic (SPMD). Programmer manages local shards. |
| **Array View** | Operates on global `jax.Array`s.                   | Function receives and returns local shards of data.  |
| **Sharding Control**| Via `NamedSharding` on inputs/outputs, `with_sharding_constraint`, or GSPMD inference. | Via `in_specs` and `out_specs` (`PartitionSpec`s) arguments to `shard_map`. |
| **Collectives** | Often inserted automatically by XLA/GSPMD.         | Must be explicitly invoked by the user within the mapped function using `jax.lax` collectives (e.g., `psum`, `all_gather`). |
| **Complexity** | Can be higher-level, abstracting some distribution details. | More explicit control, but can be more verbose for managing shards and collectives. |
| **Use Cases** | General purpose, large models, complex dataflows where global view is helpful. | Implementing custom communication patterns, fine-grained control over per-device computation and collectives, potentially easier for developers from MPI backgrounds. |
| **Mesh Requirement**| `jax.sharding.Mesh`                                | `jax.sharding.Mesh`                                  |

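`shard_map` composes with automatic differentiation. `jax.grad` can be applied outside a `shard_map`ped function, in which case JAX differentiates through the collectives it contains, or inside the per-shard function, in which case any cross-device reduction of gradients must be written explicitly with collectives.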

### 3.B: SPMD Programming with `shard_map`

`jax.shard_map(f, mesh=mesh, in_specs=in_specs, out_specs=out_specs)`

* `f`: The function to be applied per shard. This function receives arguments that are local shards of the global input arrays.
* `mesh`: The `jax.sharding.Mesh` over which the computation is distributed.
* `in_specs`: A Pytree (tuple/list/dict) of `PartitionSpec`s, matching the structure of the arguments to `f`. Each `PartitionSpec` describes how the corresponding global input array is sharded. `shard_map` uses this to provide the correct local shard to `f`.
* `out_specs`: A Pytree of `PartitionSpec`s, matching the structure of the return values from `f`. This describes how the local shards returned by `f` should be assembled into global output arrays.

The function `f` will be JIT-compiled and run on each device in the mesh, operating on its local data shard.
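
As a mental model, and not as a description of how JAX implements it, `shard_map` can be pictured as three steps: cut each global input into blocks according to `in_specs`, call `f` on every block, and stitch the returned blocks together according to `out_specs`. The plain-Python sketch below models this for a 1-D list on a single hypothetical mesh axis. The function `toy_shard_map` and its `in_sharded` / `out_sharded` flags are invented for illustration.

```python
def toy_shard_map(f, n_devices, in_sharded, out_sharded):
    """1-D model: shard (or replicate) the input, map f, reassemble."""
    def wrapped(x):
        if in_sharded:
            size = len(x) // n_devices
            blocks = [x[d * size:(d + 1) * size] for d in range(n_devices)]
        else:
            blocks = [list(x) for _ in range(n_devices)]   # replicated input
        outs = [f(block) for block in blocks]              # same program on every device
        if out_sharded:
            return [v for out in outs for v in out]        # concatenate local results
        return outs[0]                                     # replicated: take one copy
    return wrapped

double = toy_shard_map(lambda block: [2 * v for v in block],
                       n_devices=4, in_sharded=True, out_sharded=True)
print(double(list(range(8))))

# f only ever sees its local block, so len() reports the local size.
local_len = toy_shard_map(lambda block: [len(block)],
                          n_devices=4, in_sharded=True, out_sharded=True)
print(local_len(list(range(8))))
```

The second call makes the key point concrete: inside `f`, shapes are local shard shapes, not global ones.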

In [12]:
print(f"--- shard_map Basic Example ---")
print(f"Using mesh: {aurora_explicit_mesh.axis_names} with shape {aurora_explicit_mesh.shape}")

# Global data
# Array X: sharded P('dp', 'mp')
# Array Y: sharded P('dp', None)
# Mesh (dp_size, mp_size) e.g. (4,2)

dp_axis_size = aurora_explicit_mesh.shape['dp']
mp_axis_size = aurora_explicit_mesh.shape['mp']

global_X_shape = (dp_axis_size * 2, mp_axis_size * 3) # e.g. (8, 6) for (4,2) mesh
global_Y_shape = (dp_axis_size * 2, mp_axis_size * 3) # e.g. (8, 6)

data_X_np = np.arange(np.prod(global_X_shape), dtype=jnp.float32).reshape(global_X_shape)
data_Y_np = np.ones(global_Y_shape, dtype=jnp.float32) * 2

# Define PartitionSpecs for global arrays
# X is sharded along 'dp' for its first dim, 'mp' for its second dim
# Y is sharded along 'dp' for its first dim, replicated for its second dim
spec_X = P('dp', 'mp')
spec_Y = P('dp', None) # Y's second dim is replicated across 'mp' devices

# Create global sharded arrays (not strictly needed for shard_map input if data is just NumPy,
# but good for consistency if data already exists as sharded JAX arrays)
sharding_X = NamedSharding(aurora_explicit_mesh, spec_X)
sharding_Y = NamedSharding(aurora_explicit_mesh, spec_Y)

arr_X_global = jax.device_put(data_X_np, sharding_X)
arr_Y_global = jax.device_put(data_Y_np, sharding_Y)

print(f"Global arr_X_global type: {jax.typeof(arr_X_global)}")
print(f"Global arr_Y_global type: {jax.typeof(arr_Y_global)}")

# The function 'f' for shard_map receives local shards
def spmd_function(local_x_shard, local_y_shard):
  # local_x_shard and local_y_shard are the data portions on *this specific device*
  print(f"  Inside spmd_function (on one device):")
  print(f"    local_x_shard shape: {local_x_shard.shape}, type: {jax.typeof(local_x_shard)}")
  print(f"    local_y_shard shape: {local_y_shard.shape}, type: {jax.typeof(local_y_shard)}")

  # Example: element-wise addition on local shards
  local_result = local_x_shard + local_y_shard
  return local_result

# `in_specs` define how global X and Y are sharded to produce local_x_shard, local_y_shard
# `out_specs` define how the local_result shards should combine into a global result
# Let the output have the same sharding as X

# The Pytrees for in_specs and out_specs must match args/return of spmd_function
in_specs_for_map = (spec_X, spec_Y)
out_specs_for_map = spec_X # Output sharding P('dp', 'mp')

# Apply shard_map
# Use functools.partial if spmd_function had non-sharded static arguments
# In this case, it does not.

# Note: JAX might print from spmd_function multiple times during tracing/compilation.
# The shapes printed are the *local shard shapes*.

# Import shard_map from jax.experimental, for JAX versions that do not yet expose it as jax.shard_map.
from jax.experimental.shard_map import shard_map

# For arr_X_global (8,6) with P('dp','mp') on mesh (4,2):
# local_x_shard shape = (8/4, 6/2) = (2,3)
# For arr_Y_global (8,6) with P('dp',None) on mesh (4,2):
# local_y_shard shape = (8/4, 6) = (2,6) (since dim 1 is replicated over 'mp')
# The spmd_function will complain if shapes are not broadcastable for '+'
# In this case (2,3) + (2,6) is NOT directly broadcastable.
# This highlights a common issue: ensuring local shards are compatible.

# Let's adjust spec_Y or the operation for a valid example:
# Option 1: Make Y also P('dp', 'mp')
# Option 2: Replicate X's second dim to match Y: spec_X_alt = P('dp', None)
# Option 3: Modify spmd_function to handle different shard shapes if logic allows

print("\n--- Rerunning shard_map with compatible sharding for '+' ---")
spec_Y_compatible = P('dp', 'mp') # Make Y also P('dp', 'mp')
sharding_Y_compatible = NamedSharding(aurora_explicit_mesh, spec_Y_compatible)
arr_Y_global_compatible = jax.device_put(data_Y_np, sharding_Y_compatible)

print(f"Global arr_X_global type: {jax.typeof(arr_X_global)}")
print(f"Global arr_Y_global_compatible type: {jax.typeof(arr_Y_global_compatible)}")

global_result_arr_compatible = shard_map(
    spmd_function,
    mesh=aurora_explicit_mesh,
    in_specs=(spec_X, spec_Y_compatible), # Use compatible spec for Y
    out_specs=spec_X
)(arr_X_global, arr_Y_global_compatible)

print(f"\nCompatible Output global_result_arr_compatible type: {jax.typeof(global_result_arr_compatible)}")
print(f"Compatible Output global_result_arr_compatible.sharding: {global_result_arr_compatible.sharding}")
visualize_array_sharding(global_result_arr_compatible)

--- shard_map Basic Example ---
Using mesh: ('dp', 'mp') with shape OrderedDict([('dp', 4), ('mp', 2)])
Global arr_X_global type: ShapedArray(float32[8@dp,6@mp])
Global arr_Y_global type: ShapedArray(float32[8@dp,6])

--- Rerunning shard_map with compatible sharding for '+' ---
Global arr_X_global type: ShapedArray(float32[8@dp,6@mp])
Global arr_Y_global_compatible type: ShapedArray(float32[8@dp,6@mp])
  Inside spmd_function (on one device):
    local_x_shard shape: (2, 3), type: ShapedArray(float32[2,3]{mp,dp})
    local_y_shard shape: (2, 3), type: ShapedArray(float32[2,3]{mp,dp})

Compatible Output global_result_arr_compatible type: ShapedArray(float32[8@dp,6@mp])
Compatible Output global_result_arr_compatible.sharding: NamedSharding(mesh=Mesh('dp': 4, 'mp': 2, axis_types=(Explicit, Explicit)), spec=PartitionSpec('dp', 'mp'), memory_kind=unpinned_host)


The print statements inside `spmd_function` will typically execute during JAX's tracing phase on one device to determine shapes and types, and then the compiled code runs on all devices. The key is that `local_x_shard` and `local_y_shard` are `jax.Array`s representing only the portion of the data local to the device executing that instance of `spmd_function`.

### 3.C: `in_specs` and `out_specs` with `PartitionSpec`

As seen above, `in_specs` and `out_specs` are crucial. They are Pytrees of `PartitionSpec` objects that mirror the argument and return structure of the function `f` passed to `shard_map`.

* **`in_specs`**: Tells `shard_map` how each global input array is logically partitioned. `shard_map` then ensures that `f` receives the correct local shard corresponding to that `PartitionSpec` and the device's position in the `mesh`.
* **`out_specs`**: Tells `shard_map` how to interpret the local arrays returned by `f` from each device. It specifies how these local shards should be assembled to form the global output array(s) with the desired sharding.

**Important Considerations:**
* The `PartitionSpec`s in `in_specs` and `out_specs` refer to the partitioning of the *global* array.
* A `PartitionSpec` in `in_specs` may not have more entries than the corresponding global input array has dimensions. If it has fewer, the remaining trailing dimensions are treated as replicated.
* The same holds for `out_specs` and the corresponding global output array (which is formed by assembling the local shards returned by `f`).
* The shapes of the local shards received by `f` are determined by the global array shape and the `PartitionSpec` in `in_specs`. For example, if a global array of shape `(G1, G2)` is sharded with `P('dp', 'mp')` on a mesh with shape `(D, M)`, the local shard received by `f` will have shape `(G1/D, G2/M)`.

Mismatches in Pytree structure or incompatible operations on local shard shapes are common sources of errors when using `shard_map`.
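
The local-shard shape rule above can be written as a small helper. This is a sketch under stated assumptions: `local_shard_shape` is a hypothetical function (not part of JAX), the spec is a plain tuple with one entry per dimension, each entry is `None`, a mesh-axis name, or a tuple of names, and every sharded dimension is assumed to divide evenly.

```python
from math import prod

def local_shard_shape(global_shape, spec, mesh_shape):
    """Per-device block shape for a global array under a partition spec."""
    if len(spec) != len(global_shape):
        raise ValueError("this sketch expects one spec entry per dimension")
    local = []
    for dim, entry in zip(global_shape, spec):
        if entry is None:
            parts = 1                                   # replicated dimension
        else:
            names = (entry,) if isinstance(entry, str) else entry
            parts = prod(mesh_shape[name] for name in names)
        if dim % parts:
            raise ValueError(f"dimension {dim} not divisible by {parts}")
        local.append(dim // parts)
    return tuple(local)

mesh_shape = {"dp": 4, "mp": 2}
print(local_shard_shape((8, 6), ("dp", "mp"), mesh_shape))          # (2, 3)
print(local_shard_shape((8, 6), ("dp", None), mesh_shape))          # (2, 6)
print(local_shard_shape((8, 6), (("dp", "mp"), None), mesh_shape))  # (1, 6)
```

The first two results match the `(2, 3)` and `(2, 6)` local shapes discussed in the basic example, which is why adding those two shards directly fails.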

### 3.D: Manual Collective Invocation in `shard_map`

A key difference from `jax.jit` with global arrays is that inside a `shard_map`ped function, if you need communication between devices (e.g., a sum-reduction across the `'dp'` axis, or an all-gather), you must **explicitly invoke `jax.lax` collective operations.**

Common `jax.lax` collectives:
* `jax.lax.psum(x, axis_name)`: Sums `x` across all devices in the mesh axis `axis_name`. The `axis_name` must be one of the `mesh.axis_names`.
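* `jax.lax.all_gather(x, axis_name, axis=0, tiled=False)`: Gathers `x` from all devices along `axis_name`. With the default `tiled=False`, the shards are stacked along a new dimension (at position `axis`) of size `mesh.shape[axis_name]`. With `tiled=True`, they are concatenated along the existing dimension `axis`.
* `jax.lax.psum_scatter(x, axis_name, scatter_dimension=0, tiled=False)`: JAX's reduce-scatter. Sums `x` across `axis_name`, then leaves each device with one slice of the result along `scatter_dimension`.
* `jax.lax.all_to_all(x, axis_name, split_axis, concat_axis)`: Exchanges data between devices along `axis_name`, splitting `x` along `split_axis` and concatenating the received pieces along `concat_axis`.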

When using collectives inside `shard_map`:
* The `axis_name` argument refers to a name in the `mesh.axis_names` that `shard_map` is using.
* The input `x` to the collective is the *local data shard* on the device.
* The output of the collective is also a *local data shard* on the device (which might be a shard of a now globally modified/aggregated array).
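
The semantics of these collectives can be illustrated with plain Python lists standing in for the per-device shards along one mesh axis. The helpers below (`psum`, `all_gather`, `psum_scatter`) are toy models that borrow the `jax.lax` names. They are not the JAX functions themselves: each takes the list of all devices' local values and returns the list of all devices' results, rather than running once per device.

```python
def psum(local_values):
    """Every device receives the element-wise sum over all devices."""
    total = [sum(vals) for vals in zip(*local_values)]
    return [list(total) for _ in local_values]

def all_gather(local_values, tiled=False):
    """tiled=False stacks shards on a new leading axis; tiled=True concatenates them."""
    if tiled:
        gathered = [v for shard in local_values for v in shard]
    else:
        gathered = [list(shard) for shard in local_values]
    return [gathered for _ in local_values]

def psum_scatter(local_values):
    """Sum across devices, then device d keeps only chunk d (tiled-style)."""
    total = psum(local_values)[0]
    size = len(total) // len(local_values)
    return [total[d * size:(d + 1) * size] for d in range(len(local_values))]

shards = [[1, 2], [10, 20]]               # two devices, local shape (2,)
print(psum(shards))                       # [[11, 22], [11, 22]]
print(all_gather(shards)[0])              # [[1, 2], [10, 20]]
print(all_gather(shards, tiled=True)[0])  # [1, 2, 10, 20]
print(psum_scatter(shards))               # [[11], [22]]
```

Note that `psum` leaves the local shape unchanged, while `all_gather` grows it and `psum_scatter` shrinks it; `out_specs` must describe the array that results.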

In [13]:
print(f"--- shard_map with Collectives Example ---")

# Global data: an array sharded P('dp', None)
# We want to sum its elements along the 'dp' axis of the mesh.
dp_axis_size = aurora_explicit_mesh.shape['dp']
other_dim_size = 4 # Let's say each shard has this many elements for the other dimension

global_data_shape_coll = (dp_axis_size * 2, other_dim_size) # e.g. (8,4) for mesh (4,2)
data_coll_np = np.arange(np.prod(global_data_shape_coll), dtype=jnp.float32).reshape(global_data_shape_coll)

in_specs_coll = P('dp', None) # Shard dim0 by 'dp', replicate dim1
sharding_coll_in = NamedSharding(aurora_explicit_mesh, in_specs_coll)
arr_coll_global_in = jax.device_put(data_coll_np, sharding_coll_in)

print(f"Input arr_coll_global_in type: {jax.typeof(arr_coll_global_in)}")

# SPMD function that performs a psum
def spmd_psum_function(local_shard):
  # local_shard is the data on this device.
  # For P('dp', None) on mesh (D,M), local_shard shape is (global_dim0/D, global_dim1)
  print(f"  Inside spmd_psum_function (on one device): local_shard shape {local_shard.shape}")

  # Sum local_shard contributions along the 'dp' mesh axis.
  # The result of psum will be replicated across devices in the 'dp' axis.
  # So, the sharding of sum_result_local effectively becomes P(None, None) *for the 'dp' part*.
  # If the input was P('dp', 'mp'), output of psum(..., 'dp') would be P(None, 'mp').
  sum_result_local = jax.lax.psum(local_shard, axis_name='dp')
  print(f"    After psum(local_shard, 'dp'): sum_result_local shape {sum_result_local.shape}")
  return sum_result_local

# Input spec: P('dp', None), so each device receives a
# (global_dim0 / dp_axis_size, other_dim_size) block.
# Output spec: psum hands the same summed block to every device in the 'dp' group,
# so the result is replicated along 'dp' (and along 'mp', which the input never used).
# Declaring the output as P(None, None) tells shard_map that each device already
# holds the complete result. The global output shape therefore equals the local
# block shape: (global_dim0 / dp_axis_size, other_dim_size), e.g. (2, 4) on the (4, 2) mesh.

out_specs_coll = P(None, None) # Result is replicated along both 'dp' and 'mp'

global_sum_result = shard_map(
    spmd_psum_function,
    mesh=aurora_explicit_mesh,
    in_specs=in_specs_coll,
    out_specs=out_specs_coll
)(arr_coll_global_in)

print(f"\nOutput global_sum_result type: {jax.typeof(global_sum_result)}")
print(f"Output global_sum_result.sharding: {global_sum_result.sharding}")
print(f"Output global_sum_result (values from one device):\n{global_sum_result}") # The result is replicated, so every device holds these same values
visualize_array_sharding(global_sum_result) # A replicated array shows each device holding the whole array

# Verify column 0 by hand: view it as (dp_axis_size, rows_per_shard) and sum over the 'dp' blocks
expected_sum_col0 = np.sum(data_coll_np[:, 0].reshape(dp_axis_size, -1), axis=0)
print(f"\nExpected sum for first element of each 'dp' group (summed over 'dp' devices): {expected_sum_col0[0]}")

--- shard_map with Collectives Example ---
Input arr_coll_global_in type: ShapedArray(float32[8@dp,4])
  Inside spmd_psum_function (on one device): local_shard shape (2, 4)
    After psum(local_shard, 'dp'): sum_result_local shape (2, 4)

Output global_sum_result type: ShapedArray(float32[2,4])
Output global_sum_result.sharding: NamedSharding(mesh=Mesh('dp': 4, 'mp': 2, axis_types=(Explicit, Explicit)), spec=PartitionSpec(), memory_kind=unpinned_host)
Output global_sum_result (values from one device):
[[48. 52. 56. 60.]
 [64. 68. 72. 76.]]



Expected sum for first element of each 'dp' group (summed over 'dp' devices): 48.0
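
The same pattern carries over to the other collectives listed above. The sketch below is an illustration, not part of the notebook's own cells: it assumes a hypothetical 1D mesh named `mesh_1d` built from whatever devices are available (so it also runs on a single device) and compares `jax.lax.psum` with `jax.lax.psum_scatter` on the same input. The fallback import is an assumption to cover JAX releases where `shard_map` still lives under `jax.experimental`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# Hypothetical 1D mesh over every available device (a single device works too).
devices = np.array(jax.devices())
n = devices.size
mesh_1d = Mesh(devices, ('dp',))

# Global array: under P('dp', None) each device receives an (n, 3) block.
x = jnp.arange(n * n * 3, dtype=jnp.float32).reshape(n * n, 3)

def psum_keep_all(block):
    # All-reduce: every device ends up with the full (n, 3) sum.
    return jax.lax.psum(block, 'dp')

def psum_then_scatter(block):
    # Reduce-scatter: the same sum, but each device keeps only one row of it.
    return jax.lax.psum_scatter(block, 'dp', scatter_dimension=0, tiled=True)

all_reduced = jax.jit(shard_map(
    psum_keep_all, mesh=mesh_1d,
    in_specs=P('dp', None), out_specs=P(None, None)))(x)

reduce_scattered = jax.jit(shard_map(
    psum_then_scatter, mesh=mesh_1d,
    in_specs=P('dp', None), out_specs=P('dp', None)))(x)

# NumPy reference: sum the n per-device blocks element-wise.
expected = np.asarray(x).reshape(n, n, 3).sum(axis=0)
print("psum matches reference:        ", np.allclose(all_reduced, expected))
print("psum_scatter matches reference:", np.allclose(reduce_scattered, expected))
```

Both calls produce the same global values. The difference is the layout: the `psum` result is replicated (`out_specs=P(None, None)`), while the `psum_scatter` result stays sharded along `'dp'`, which is the building block FSDP-style gradient reduction relies on.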


## Part 4: Advanced Sharding Techniques & Mixed Modes

Project Aurora will often require more sophisticated sharding strategies than simple data or model parallelism. This part explores mixing sharding modes, understanding nuances between concrete and type-specified sharding, and applying these concepts to advanced patterns.

### 4.A: Mixing Sharding Modes (`Mesh` with mixed `AxisType`s)

In Project Aurora, while the "sharding-in-types" paradigm with `AxisType.Explicit` offers maximum predictability and control, there are scenarios where a more nuanced approach involving `AxisType.Auto` is beneficial. JAX allows a single `jax.sharding.Mesh` to have axes with different `AxisType`s.

For example, you could define a mesh like this:
`mixed_mesh = Mesh(devices_array, ('dp_explicit', 'mp_auto'), axis_types=(AxisType.Explicit, AxisType.Auto))`

**Why Mix AxisTypes?**

1.  **Gradual Adoption/Migration:** If you're migrating a large codebase from an older, more compiler-driven sharding model, you might start by making critical data paths or parameters use `Explicit` axes while leaving others as `Auto`.
2.  **Unsupported Operations:** As the "sharding-in-types" system evolves, some JAX operations might not yet have well-defined sharding propagation rules for `Explicit` axes. For such operations, sharding along an `Auto` axis allows the GSPMD compiler to attempt to find a valid partitioning.
3.  **Compiler Freedom for Non-Critical Dimensions:** For certain array dimensions or intermediate tensors, you might not have a strong opinion on their sharding, or the optimal sharding might be complex and best left to GSPMD. Marking the corresponding mesh axis as `Auto` cedes this control.
4.  **Performance Exploration:** You might experiment by switching an axis between `Explicit` and `Auto` to see if the compiler can find a more performant solution than your explicit specification for certain sub-computations.

**Interaction with `NamedSharding` and `PartitionSpec`:**

* When you create a `NamedSharding` object using a mixed-type mesh and a `PartitionSpec`:
    * If a `PartitionSpec` entry refers to a mesh axis that is `AxisType.Explicit`, that part of the sharding will be reflected in the array's JAX-level type (via `jax.typeof()`) and will follow explicit propagation rules.
    * If a `PartitionSpec` entry refers to a mesh axis that is `AxisType.Auto`, that part of the sharding will define the array's *concrete* layout (`array.sharding`), but it will *not* appear in `jax.typeof(array).sharding` for that axis. GSPMD will have more freedom regarding this `Auto` axis.

The tools `jax.experimental.shard.auto_axes` and `jax.experimental.shard.explicit_axes` (which we'll cover next) provide dynamic ways to temporarily change the `AxisType` of mesh axes within a specific code scope, further enhancing this flexibility.

### 4.B: Temporarily Enabling Compiler Control with `jax.experimental.shard.auto_axes`

In our pursuit of optimal performance and flexibility for Project Aurora, there will be times when we want to temporarily cede sharding control for certain mesh axes to the GSPMD compiler, even if those axes are part of a `Mesh` that is otherwise predominantly `Explicit`. The `jax.experimental.shard.auto_axes` utility provides exactly this capability.

**What is `auto_axes`?**

`auto_axes` is a function wrapper, typically applied as a decorator (often with `functools.partial`). It lets you specify one or more axes of the current mesh that should behave as if they are `AxisType.Auto` *only within the body of the wrapped function*.

`@auto_axes` (treat every mesh axis as `Auto`) or
`@functools.partial(auto_axes, axes=('axis1', 'axis2', ...))`

**Key Use Cases:**

1.  **Fallback for Unimplemented Sharding Rules:** This is a primary motivation. If you're using a `Mesh` with `Explicit` axes and encounter a JAX operation that doesn't yet have a well-defined sharding propagation rule for those `Explicit` axes (leading to a trace-time error), you can wrap the problematic operation within an `auto_axes` context for the relevant mesh axis. This allows GSPMD to attempt to find a valid sharding and execution plan for that operation.
2.  **Overriding "Sharding-in-Types" System for Ambiguity Resolution:** Sometimes, even with defined rules, the explicit propagation might lead to an ambiguous situation or an illegal sharding (e.g., an attempt to shard a tensor dimension along the same mesh axis twice). If you have a desired final sharding for the operation's output, `auto_axes` can let the compiler try to achieve it by temporarily relaxing the strict explicit rules on the problematic axes. We saw this as a solution for the ambiguous matrix multiplication in Part 2.A.
3.  **Intentional Compiler Optimization:** For certain sub-computations, you might hypothesize that allowing GSPMD more freedom (by making some axes `Auto`) could lead to better performance than your current explicit sharding strategy for that local part. `auto_axes` allows for such experimentation.

**How it Works:**

* Inside the `auto_axes` scope, the specified mesh axes behave as `AxisType.Auto`.
* Sharding information along these temporarily `Auto` axes will *not* be part of the JAX-level type (`jax.typeof()`) within this scope. For instance, an array that was `f32[N@dp, M@mp]` appears as `f32[N@dp, M]` inside a function wrapped with `auto_axes` for `axes='mp'`, because `'dp'` remains `Explicit`.
* **The `out_shardings` Argument (Crucial!):** When you call the wrapped function, you **must** provide an `out_shardings` keyword argument, as in `g(x, out_shardings=P("dp", "mp"))`. This argument takes a `PartitionSpec` (or a Pytree of `PartitionSpec`s matching the function's output structure) that specifies how the result, computed within the `Auto` context, should be sharded with respect to the *original mesh* (which likely has `Explicit` axes that you want the output to conform to). This re-establishes the JAX-level type sharding for the result as it "exits" the `auto_axes` context. JAX will ensure the output conforms to this `out_shardings` specification, potentially inserting resharding operations if the compiler's internal choice in the `Auto` context differed.

This mechanism provides a powerful bridge between the strict, predictable world of `AxisType.Explicit` and the more flexible, compiler-driven world of `AxisType.Auto`.

In [14]:
@partial(auto_axes, axes='dp')
def g(y):
  print(f'mesh inside g: {jax.sharding.get_abstract_mesh()}')
  print(f'y.sharding inside g: {jax.typeof(y) = }', end='\n\n')
  return y * 2

@jax.jit
def f(arr1):
  print(f'mesh inside f: {jax.sharding.get_abstract_mesh()}')
  x = jnp.sin(arr1)
  print(f'x.sharding: {jax.typeof(x)}', end='\n\n')

  z = g(x, out_shardings=P("dp", "mp"))

  print(f'z.sharding: {jax.typeof(z)}', end="\n\n")
  return z + 1

with jax.sharding.use_mesh(aurora_explicit_mesh):
  some_x = reshard(np.arange(16).reshape(4, 4), P("dp", "mp"))
  print(f"Final value: {jax.typeof(f(some_x))}")

mesh inside f: AbstractMesh('dp': 4, 'mp': 2, axis_types=(Explicit, Explicit))
x.sharding: ShapedArray(float32[4@dp,4@mp])

mesh inside g: AbstractMesh('dp': 4, 'mp': 2, axis_types=(Auto, Explicit))
y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4,4@mp])

z.sharding: ShapedArray(float32[4@dp,4@mp])

Final value: ShapedArray(float32[4@dp,4@mp])


### 4.C: Temporarily Enforcing Explicit Sharding with `jax.experimental.shard.explicit_axes`

While `auto_axes` lets us give more control to the compiler for specific sections by treating axes as `Auto`, there are situations where we might start with a `Mesh` whose axes are predominantly `Auto` (perhaps created via `jax.make_mesh` without specifying `axis_types`, or by design for general compiler-driven sharding) but need to enforce the strict "sharding-in-types" behavior for a particular block of code. `jax.experimental.shard.explicit_axes` provides this capability.

**What is `explicit_axes`?**

`explicit_axes` is a function wrapper, typically applied as a decorator, that temporarily changes specified axes of the current mesh to `AxisType.Explicit` *within the body of the wrapped function*.

`explicit_axes(your_function, axes=('axis1', 'axis2', ...))` or
`@functools.partial(explicit_axes, axes=('axis1', 'axis2', ...))`

**Key Use Cases:**

1.  **Enforcing "Sharding-in-Types" Locally:** If you are working within a broader context where mesh axes are `Auto`, but for a critical function or code segment, you need the sharding to be part of the JAX-level type (`jax.typeof()`) and want to ensure JAX's explicit sharding propagation rules apply, `explicit_axes` allows this.
2.  **Precise Sharding Control for a Sub-computation:** You might use an `Auto` mesh for most of your program but switch to `Explicit` axes for a specific module or operation where you have strong opinions about its sharding and want to ensure it's not altered by compiler heuristics for those axes.
3.  **Interfacing Components:** When combining different code components, some of which expect `Explicit` sharding semantics.

**How it Works:**

* Inside the `explicit_axes` scope, the axes named in `axes` behave as if their `AxisType` is `Explicit`.
* Sharding information along these temporarily `Explicit` axes *will* now be part of the JAX-level type (`jax.typeof()`) within this scope.
* **The `in_shardings` Argument (Often Crucial!):** When a function wrapped by `explicit_axes` is called, you typically **must** provide an `in_shardings` keyword argument. This argument takes a `PartitionSpec` (or a Pytree of them) that specifies how the input data (which might be coming from an `Auto` context without explicit type-level sharding) should be sharded with respect to the now `Explicit` axes of the mesh *within* this context. This establishes the initial explicit sharding for the type system to work with inside the function.
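* When the wrapped function returns, the result re-enters the surrounding context, where those axes are `Auto` again, so their sharding drops out of the JAX-level type. In the example below, `z` has type `float32[4@dp,4@mp]` inside `explicit_g` but `float32[4,4]` once it is back in `f`.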

This mechanism is powerful for selectively "opting-in" to the full "sharding-in-types" model for specific parts of your code when operating with meshes that might otherwise be `Auto`.

In [15]:
@partial(explicit_axes, axes=("dp", "mp"))
def explicit_g(y):
  print(f'mesh inside g: {jax.sharding.get_abstract_mesh()}')
  print(f'y.sharding inside g: {jax.typeof(y) = }')
  z = y * 2
  print(f'z.sharding inside g: {jax.typeof(z) = }', end='\n\n')
  return z

@jax.jit
def f(arr1):
  print(f'mesh inside f: {jax.sharding.get_abstract_mesh()}', end='\n\n')
  x = jnp.sin(arr1)

  z = explicit_g(x, in_shardings=P("dp", "mp"))
  print(f'z.sharding inside f: {jax.typeof(z) = }', end='\n\n')
  return z + 1

with jax.sharding.use_mesh(aurora_auto_mesh):
  some_x = jax.device_put(np.arange(16).reshape(4, 4), P("dp", "mp"))
  print(f"Final value: {jax.typeof(f(some_x))}")

mesh inside f: AbstractMesh('dp': 4, 'mp': 2, axis_types=(Auto, Auto))

mesh inside g: AbstractMesh('dp': 4, 'mp': 2, axis_types=(Explicit, Explicit))
y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4@dp,4@mp])
z.sharding inside g: jax.typeof(z) = ShapedArray(float32[4@dp,4@mp])

z.sharding inside f: jax.typeof(z) = ShapedArray(float32[4,4])

Final value: ShapedArray(float32[4,4])


### 4.D: Concrete Array Sharding (`x.sharding`) vs. Type-Specified Sharding (`jax.typeof(x).sharding`) - Revisited

Throughout our exploration of `Mesh`, `AxisType`, `NamedSharding`, and context managers like `auto_axes`, a recurring theme is the distinction between how an array is *actually* laid out across devices versus how its sharding is represented *in JAX's type system at trace time*. Understanding this is paramount for debugging and for predictable distributed programming in Project Aurora.

**1. Concrete Array Sharding (`x.sharding`)**

* **What it is:** This attribute of a `jax.Array` object (`x`) directly provides its `jax.sharding.Sharding` object (e.g., a `NamedSharding` instance). This `Sharding` object describes the array's **actual, physical layout** across the devices of its associated `Mesh`.
* **Visibility:** It reflects sharding along **all** mesh axes involved in its `PartitionSpec`, regardless of whether those mesh axes were `AxisType.Explicit` or `AxisType.Auto` at the time the array's sharding was defined or inferred.
* **Source:** This concrete sharding is determined when the array is created (e.g., via `jax.device_put(data, NamedSharding(mesh, spec))`) or as the result of a JIT-compiled computation where GSPMD decided on the layout.
* **The Ground Truth:** `x.sharding` tells you how the data is *really* distributed right now. `jax.debug.visualize_array_sharding(x)` visualizes this concrete sharding.

**2. Type-Specified Sharding (`jax.typeof(x).sharding`)**

* **What it is:** This refers to the sharding information that is embedded into the **JAX-level static type** of the array `x`. This is the information JAX's tracer sees and uses for its trace-time sharding propagation rules.
* **Visibility & Dependency on `AxisType.Explicit`:** Crucially, as stated in the JAX documentation ("Explicit sharding (a.k.a. “sharding in types”)" guide):
    * "*Shardings (on JAX-level types) can only mention `explicit` mesh axes.*"
    * "*the type-specified sharding, `jax.typeof(x).sharding`, only describes the sharding along `Explicit` mesh axes. The `Auto` axes are deliberately hidden from the type because they’re the purview of the compiler.*"
* **Effect:** If an array `x` is sharded using a `NamedSharding` whose `Mesh` has an axis `'my_axis'` defined as `AxisType.Explicit`, then `jax.typeof(x)` will show an annotation like `...dtype[..., dim_size@my_axis, ...]`. If `'my_axis'` was `AxisType.Auto` in the mesh at the point this type was established, this annotation will be absent for `'my_axis'` in `jax.typeof(x).sharding.spec` (it will typically appear as `None` for that dimension's mapping to mesh axes).
* **"Sharding in Types" Paradigm:** This type-level sharding is the core of the "sharding in types" model, enabling predictable, JAX-level reasoning about distributed layouts before XLA compilation.

**Why This Distinction is Critical for Project Aurora:**

* **Predictability:** Relying on `AxisType.Explicit` for mesh axes used in your `NamedSharding` specifications gives you predictable type-level sharding that JAX can reason about statically.
* **Debugging:**
    * If `jax.typeof(x).sharding` doesn't show the sharding you expect (e.g., missing an `@my_axis` annotation), it likely means `my_axis` was not `AxisType.Explicit` in the mesh when that type was determined, or the sharding was never explicitly set at the type level.
    * If `x.sharding` (concrete) is different from what you intended, it means the physical layout is not as planned, perhaps due to compiler decisions on `Auto` axes or an incorrect `PartitionSpec`.
* **Performance:** Unexpected concrete shardings, especially if they differ from what subsequent operations expect (based on type-level propagation), can lead to silent, automatic resharding by XLA/GSPMD, introducing costly communication. Making critical shardings `Explicit` at the type level helps GSPMD make better choices.
* **Mixed Contexts:** When using `auto_axes` or `explicit_axes`, the `AxisType` of mesh axes changes temporarily *within the wrapped function*. This directly affects what `jax.typeof()` will report inside versus outside that function. The `in_shardings` and `out_shardings` arguments passed when calling these wrapped functions are vital for bridging these transitions and (re-)establishing clear type-level shardings.

In essence, for Project Aurora's most complex components where precise sharding control and predictability are paramount, we will strive to use `Mesh`es with `AxisType.Explicit` for the axes we intend to control, and use `NamedSharding` to define these layouts. This ensures that the concrete sharding (`x.sharding`) aligns with the JAX-level type sharding (`jax.typeof(x).sharding`), giving us the full benefits of the "sharding in types" model.

In [16]:
jax.sharding.set_mesh(aurora_explicit_mesh)

def compare_shardings(x):
  print(f"=== with mesh: {jax.sharding.get_abstract_mesh()} ===")
  print(f"Concrete value sharding: {x.sharding.spec}")
  print(f"Type-specified sharding: {jax.typeof(x).sharding.spec}")

my_array = jnp.sin(reshard(np.arange(8), P("dp")))
compare_shardings(my_array)

@auto_axes
def check_in_auto_context(x):
  compare_shardings(x)
  return x

check_in_auto_context(my_array, out_shardings=P("dp"))

=== with mesh: AbstractMesh('dp': 4, 'mp': 2, axis_types=(Explicit, Explicit)) ===
Concrete value sharding: PartitionSpec('dp',)
Type-specified sharding: PartitionSpec('dp',)
=== with mesh: AbstractMesh('dp': 4, 'mp': 2, axis_types=(Auto, Auto)) ===
Concrete value sharding: PartitionSpec('dp',)
Type-specified sharding: PartitionSpec(None,)


Array([ 0.        ,  0.84147096,  0.9092974 ,  0.14112   , -0.7568025 ,
       -0.9589243 , -0.2794155 ,  0.6569866 ], dtype=float32)

### 4.E: Conceptualizing Advanced Sharding Patterns with `NamedSharding`

The `Mesh`, `PartitionSpec`, and `NamedSharding` system, particularly when leveraging `AxisType.Explicit` for "sharding in types," provides the precise building blocks for describing the data layouts required by sophisticated parallelism techniques. Here, we'll conceptually outline how you'd approach designing `PartitionSpec`s for patterns like FSDP and basic Tensor Parallelism.

**1. Fully Sharded Data Parallelism (FSDP) - Conceptual Layout**

FSDP is designed to minimize memory usage per device by sharding not just the input data batch, but also the model's parameters, gradients, and optimizer states across data-parallel workers.

* **Mesh Setup (Conceptual for FSDP):**
    * Typically, a 1D mesh is sufficient for pure FSDP, representing the data-parallel dimension.
    * `fsdp_mesh = Mesh(physical_devices, ('dp',), axis_types=(AxisType.Explicit,))`
    * Let `N_dp` be the size of this `'dp'` axis (i.e., `len(physical_devices)`).

* **Data (Activations) `PartitionSpec`:**
    * Input batch (e.g., shape `[global_batch_size, seq_len, hidden_dim]`) is sharded along the batch dimension.
    * `P_activations_fsdp = P('dp', None, None)`
    * Each device processes `global_batch_size / N_dp` samples.

* **Model Parameters (Weights) `PartitionSpec`:**
    * **Key Idea:** Each parameter tensor is fully sharded across the `'dp'` axis. Each device holds only `1/N_dp` of each parameter.
    * For a 2D weight matrix `W` (shape `[dim_out, dim_in]`), you could shard it like:
        * `P_W_fsdp = P('dp', None)` (sharding `dim_out` across data-parallel workers)
        * Or `P_W_fsdp = P(None, 'dp')` (sharding `dim_in`)
        * Or, on a multi-dimensional mesh whose axes you are conceptually "flattening" into a single FSDP group, a tuple entry such as `P_W_fsdp = P(('dp', 'mp'), None)`, which shards `dim_out` across both mesh axes at once. For a 1D FSDP mesh, `P('dp', None)` or `P(None, 'dp')` are common.
    * Similarly for biases and other parameters.

* **Computational Flow (Conceptual, if JITted with these shardings):**
    * **Forward Pass:** Before a parameter is used in a layer, the full version of that parameter needed for the current local data shard must be reconstructed on the device. This typically involves an **`all-gather`** collective operation across the `'dp'` axis to gather all shards of that parameter. After the local computation, the full parameter can be discarded to save memory.
    * **Backward Pass:** Each device computes gradients for the full (gathered) parameter, but only from its local slice of the batch, so these are partial contributions to the global gradient. Before the optimizer step, the contributions must be summed across devices and re-sharded. This is typically done with a **`reduce-scatter`** collective along the `'dp'` axis. Each device ends up with the shard of the summed gradient corresponding to its shard of the parameter.
    * **Optimizer Step:** The optimizer updates its local shard of parameters using its local shard of summed gradients. Optimizer states are also sharded identically to the parameters.

    *Using `jax.jit` on operations involving these explicitly sharded parameters and activations would ideally lead GSPMD to insert these collectives. Libraries implementing FSDP often manage these patterns more explicitly.*
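
The all-gather and reduce-scatter steps described above can be sketched by hand with `shard_map`. This is a minimal illustration under stated assumptions, not how any particular FSDP library works: the weight is stored as `[dim_in, dim_out]` (so the layer is `x @ w`), it is sharded along its first dimension, the upstream gradient `g` is random stand-in data, and the mesh is built from whatever devices are available. The helper names `fsdp_forward` and `fsdp_weight_grad` are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

devices = np.array(jax.devices())
n_dp = devices.size
fsdp_mesh = Mesh(devices, ('dp',))

# Sizes chosen so every sharded dimension divides evenly by n_dp.
batch, dim_in, dim_out = 2 * n_dp, 2 * n_dp, 3
rng = np.random.default_rng(0)
x = jnp.asarray(rng.normal(size=(batch, dim_in)), dtype=jnp.float32)
w = jnp.asarray(rng.normal(size=(dim_in, dim_out)), dtype=jnp.float32)
g = jnp.asarray(rng.normal(size=(batch, dim_out)), dtype=jnp.float32)  # stand-in for dL/dY

def fsdp_forward(x_local, w_shard):
    # all-gather: rebuild the full weight from the per-device row blocks.
    w_full = jax.lax.all_gather(w_shard, 'dp', axis=0, tiled=True)
    return x_local @ w_full  # the gathered copy is only a temporary

def fsdp_weight_grad(x_local, g_local):
    # Full-shaped gradient contribution from this device's batch slice only.
    local_grad = x_local.T @ g_local
    # reduce-scatter: sum contributions over 'dp', keep this device's row block.
    return jax.lax.psum_scatter(local_grad, 'dp', scatter_dimension=0, tiled=True)

row_sharded = P('dp', None)
y = jax.jit(shard_map(fsdp_forward, mesh=fsdp_mesh,
                      in_specs=(row_sharded, row_sharded),
                      out_specs=row_sharded))(x, w)
dw = jax.jit(shard_map(fsdp_weight_grad, mesh=fsdp_mesh,
                       in_specs=(row_sharded, row_sharded),
                       out_specs=row_sharded))(x, g)

print("forward matches x @ w:      ", np.allclose(y, x @ w, atol=1e-4))
print("weight grad matches x.T @ g:", np.allclose(dw, x.T @ g, atol=1e-4))
```

The gradient `dw` comes back with the same `P('dp', None)` layout as the weight shards, so an optimizer update can be applied shard by shard without further communication.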

**2. Tensor Parallelism (TP) for Transformer MLPs - Conceptual Layout**

Tensor parallelism involves splitting individual weight matrices and activations *within* a layer across a set of devices, often called a model-parallel group.

* **Mesh Setup (Conceptual for TP):**
    * Often a 2D mesh is used to combine data parallelism with tensor parallelism.
    * `hybrid_mesh = Mesh(devices_array.reshape(dp_groups, mp_size), ('dp', 'mp'), axis_types=(AxisType.Explicit, AxisType.Explicit))`

* **Example: MLP Block (`Y = GELU(X @ W1 + B1) @ W2 + B2`)**
    * **Activations `X`** (global shape `[global_batch_size, seq, hidden_in]`):
        * Sharded by `'dp'` for data parallelism.
        * Replicated across `'mp'` for the start of the MLP.
        * `P_X = P('dp', None, None)`

    * **First Linear Layer (`W1`, `B1`): Column Parallelism**
        * `W1` (shape `[hidden_in, hidden_ff]`) is sharded on its *output* dimension (`hidden_ff`) along the `'mp'` axis.
            `P_W1 = P(None, 'mp')` (assuming `hidden_in` is replicated across `'mp'`)
        * `B1` (shape `[hidden_ff]`) is also sharded along `'mp'`.
            `P_B1 = P('mp',)`
        * Computation `X @ W1`: `X` is `P('dp', None, None)`. `W1` is `P(None, 'mp')`.
            The output `(X @ W1)` will have its last dimension (originally `hidden_ff`) sharded by `'mp'`.
            Result: `P('dp', None, 'mp')`.
        * No communication is needed here. Each device in the `'mp'` group holds its own slice of the `hidden_ff` dimension, and both the bias add and GELU are elementwise, so they apply to each slice independently.

    * **Second Linear Layer (`W2`, `B2`): Row Parallelism**
        * Input to this layer (output of GELU, let's call it `Y_intermediate`) is `P('dp', None, 'mp')`: still sharded along `hidden_ff`, exactly as the first layer produced it.
        * `W2` (shape `[hidden_ff, hidden_out]`) is sharded on its *input* dimension (`hidden_ff`) along the `'mp'` axis.
            `P_W2 = P('mp', None)`
        * `B2` (shape `[hidden_out]`) is typically replicated.
            `P_B2 = P(None,)`
        * Computation `Y_intermediate @ W2`: The contraction runs over `hidden_ff`, which is sharded by `'mp'` in both operands. Each device multiplies its slice of `Y_intermediate` by the matching rows of `W2`, producing a full-shaped *partial sum* of the output. An **`all-reduce`** (sum) across the `'mp'` axis combines these partial sums into the final output `Z`, with sharding `P('dp', None, None)`. `B2` is added once, after the reduction. This all-reduce is the only collective the MLP block needs in its forward pass.

Designing these `PartitionSpec`s with `NamedSharding` and then applying them with `jax.device_put` sets up the initial state. The behavior of `jax.jit` on operations with these sharded inputs will then determine the necessary data movements and collective operations. For very complex interactions, `shard_map` (Part 3) or explicit resharding might be used to control specific steps.

This conceptual understanding is key before diving into full implementations, which often involve careful management of these implied or explicit collectives.
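
The column-parallel and row-parallel layout above can be checked numerically with a hand-written `shard_map` version of the MLP. This is a sketch under stated assumptions: the mesh places every available device on the `'mp'` axis with a size-1 `'dp'` axis, all tensor sizes are made up, and `tp_mlp` is a hypothetical helper, not an Aurora or JAX API.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

devices = np.array(jax.devices())
mp_size = devices.size
hybrid_mesh = Mesh(devices.reshape(1, mp_size), ('dp', 'mp'))

batch, seq, hidden_in, hidden_out = 2, 3, 4, 4
hidden_ff = 4 * mp_size  # divisible by the 'mp' axis size
rng = np.random.default_rng(0)

def rand(*shape):
    return jnp.asarray(rng.normal(scale=0.5, size=shape), dtype=jnp.float32)

x, w1, b1 = rand(batch, seq, hidden_in), rand(hidden_in, hidden_ff), rand(hidden_ff)
w2, b2 = rand(hidden_ff, hidden_out), rand(hidden_out)

def tp_mlp(x, w1, b1, w2, b2):
    # Column-parallel layer: each device computes its own slice of hidden_ff.
    h = jax.nn.gelu(x @ w1 + b1)        # elementwise, so no communication
    # Row-parallel layer: contraction over the local hidden_ff slice only.
    partial_z = h @ w2                  # full-shaped partial sum
    z = jax.lax.psum(partial_z, 'mp')   # all-reduce across the 'mp' group
    return z + b2                       # replicated bias, added once

sharded_mlp = jax.jit(shard_map(
    tp_mlp, mesh=hybrid_mesh,
    in_specs=(P('dp', None, None), P(None, 'mp'), P('mp'), P('mp', None), P(None)),
    out_specs=P('dp', None, None)))

z_sharded = sharded_mlp(x, w1, b1, w2, b2)
z_reference = jax.nn.gelu(x @ w1 + b1) @ w2 + b2
print("matches unsharded MLP:", np.allclose(z_sharded, z_reference, atol=1e-4))
```

The `in_specs` are the `PartitionSpec`s from the walkthrough (`P_X`, `P_W1`, `P_B1`, `P_W2`, `P_B2`), and the single `psum` sits exactly where the row-parallel matmul leaves partial sums.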

## Overall Module Debrief & Next Steps

**Module Summary:**
Architect, this comprehensive module has equipped you with the core principles and advanced techniques for mastering explicit data distribution in JAX for Project Aurora. We began by learning to draft precise data layout blueprints using `jax.sharding.PartitionSpec` (`P`). We then saw how to bind these blueprints to a `jax.sharding.Mesh` using `jax.sharding.NamedSharding` and instantiate physically distributed `jax.Array`s with `jax.device_put`.

A critical theme was "sharding in types," understanding how, with `AxisType.Explicit` mesh axes, an array's sharding becomes an inspectable part of its JAX-level type (`jax.typeof()`), distinct from its always-present concrete sharding (`array.sharding`). We explored how these type-level shardings propagate through JAX operations and how the GSPMD compiler in `jax.jit` uses this information, including mechanisms like `out_sharding` for ambiguity and the nuanced role of `jax.lax.with_sharding_constraint` (primarily for guiding `Auto` axes).

We then contrasted this global-view, explicit-sharding model with the per-device programming paradigm of `jax.shard_map`, where you take full control of local computations and explicitly invoke collectives. Finally, we delved into advanced topics, including dynamically mixing sharding modes with `auto_axes` and `explicit_axes`, reinforcing the concrete vs. type-specified sharding distinction, and conceptually designing `PartitionSpec`s for sophisticated strategies like FSDP and Tensor Parallelism.

You now possess a powerful and nuanced understanding of how to command data layout and parallelism in JAX, moving far beyond basic device placement into the realm of true distributed system design.

**Key Takeaways from this Module:**

* **`PartitionSpec` (`P`)** is the language for defining logical sharding (how tensor axes map to mesh axes or are replicated).
* **`NamedSharding`** combines a `Mesh` with a `PartitionSpec` to create a concrete sharding instruction.
* **`jax.device_put(data, named_sharding)`** is the primary way to create a `jax.Array` with a specific physical sharding.
* **"Sharding in Types"** is achieved when using `NamedSharding` with a `Mesh` whose relevant axes are `AxisType.Explicit`. This makes sharding visible in `jax.typeof()` and enables predictable JAX-level propagation.
* **`array.sharding`** always shows the concrete, physical sharding, while `jax.typeof(array).sharding` only reflects type-level sharding along `Explicit` mesh axes.
* Explicit shardings (type-level) influence GSPMD; mismatches can lead to automatic resharding (collective insertions). `out_sharding` (on ops, if available) and `auto_axes` (with its `out_shardings` argument) are key tools for resolving ambiguities or overriding default propagation when an explicit output sharding is desired.
* **`jax.lax.with_sharding_constraint`** is primarily for providing sharding hints to GSPMD for `Auto` axes or within an `auto_axes` context.
* **`jax.shard_map`** offers an alternative SPMD model with a per-device view and fully explicit user-managed collectives.
* **`auto_axes` / `explicit_axes`** provide vital flexibility to temporarily change the `AxisType` behavior of mesh axes within specific code sections, bridging explicit control with compiler automation.
* These tools form the vocabulary for designing complex distributed layouts needed for patterns like FSDP and various forms of Tensor Parallelism.
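
As a compact recap of the first three takeaways, the sketch below goes from mesh to `PartitionSpec` to `NamedSharding` to a placed array, then inspects the concrete layout. It is illustrative only: the mesh is built from whatever devices are available, uses the default axis type, and the array shape is arbitrary.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices, ('dp',))          # device topology

spec = P('dp', None)                   # logical layout: shard dim 0, replicate dim 1
sharding = NamedSharding(mesh, spec)   # bind the layout to the mesh
arr = jax.device_put(np.zeros((2 * n, 4), np.float32), sharding)  # place the data

# The concrete sharding is always available on the array itself.
print("Concrete sharding spec:", arr.sharding.spec)
for shard in arr.addressable_shards:
    print(f"  device {shard.device}: block shape {shard.data.shape}, index {shard.index}")
```

Each device should report a `(2, 4)` block, whose `index` gives the slice of the global array it holds.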

**Transition & Next Steps for Project Aurora:**

Architect, your mastery of single-host sharding is now profound. You can define intricate data layouts, understand how they interact with JAX's type system and compiler, and choose the right paradigm (`jit` with `NamedSharding` or `shard_map`) for different tasks.

This concludes Series 1 of our "JAX Pro-Level Tutorial Series: From Fundamentals to Huge-Scale Systems." We have laid the definitive guide to data distribution on a single host.

The next grand challenge for Project Aurora is to transcend the boundaries of a single machine. Our models are growing too vast. We must learn to orchestrate computations across a *fleet* of interconnected hosts. Therefore, prepare for **Series 2: JAX Multi-Controller & Distributed Topologies - Mastering the Fleet**. In Series 2, we will explore:
* The multi-controller paradigm with `jax.distributed`.
* Defining and using global `Mesh`es that span multiple hosts.
* Strategies for distributed data loading.
* Understanding the impact of physical hardware interconnects (like InfiniBand, NVLink across nodes, TPU ICI) on multi-host performance and mesh design.
* Orchestrating and debugging multi-controller JAX applications.

The skills you've honed in Series 1 are the direct prerequisites for commanding this fleet. The principles of `Mesh`, `PartitionSpec`, and `NamedSharding` will extend directly into the multi-host world.