<a href="https://colab.research.google.com/github/mridul-sahu/jax-sharding-tutorials/blob/main/Chapter_1_3_Global_Device_Mesh.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip3 install --upgrade jax

# [Series 1, Chapter 1.3: Global Device Mesh - Abstracting Hardware Topology for Advanced Sharding] - The Aurora Project 🌌

## Introduction

Welcome back, Aurora Architect. In Chapter 1.2, we mastered `jax.pmap`, our first tool for single-host, multi-device data parallelism. While powerful, `pmap` treats devices as a flat list and maps a single array axis across it. Automatic sharding under `jit` is more flexible, but there the layout is largely implicit, relying on compiler heuristics (sometimes guided by `with_sharding_constraint`). This can lead to a "compiler tickling" development cycle when precise control is needed.

Project Aurora demands more sophisticated parallelism. We need to move towards JAX's modern "explicit sharding" or "sharding in types" paradigm. Here, how an array is distributed (its sharding) becomes an explicit part of its JAX-level type. This provides clarity, predictability, and fine-grained control. The cornerstone of this approach is the `jax.sharding.Mesh`.

**Chapter Goal:** This chapter introduces the `jax.sharding.Mesh` API. You will learn to define multi-dimensional logical grids of devices. This abstraction is crucial for implementing advanced sharding strategies where sharding decisions are made explicitly at the JAX level, during trace time, rather than being solely inferred by the compiler. We will focus on how `Mesh`, particularly with `AxisType.Explicit`, enables this powerful new programming model.

**Topic Introduction:** We will explore how to create 1D, 2D, and n-D device meshes, map physical devices to logical mesh axes, and understand the critical role of `AxisType.Explicit`. This `AxisType` signals that sharding along a mesh axis is determined by explicit user specifications, which is fundamental to the "sharding in types" model.

**Outcome Statement:** By the end of this chapter, you will be able to construct and interpret `jax.sharding.Mesh` instances. You will grasp the significance of `AxisType.Explicit` and be prepared to use `Mesh` as the foundation for defining explicit sharding patterns (using `NamedSharding` and `PartitionSpec` in Chapter 1.4), enabling Project Aurora to scale with precision and predictability.

### Learning Objectives for This Phase:

* Understand the motivation for a device mesh beyond `pmap` and as a foundation for "explicit sharding."
* Create and configure 1D, 2D, and 3D `jax.sharding.Mesh` instances with appropriate `axis_types`.
* Explain the significance of `AxisType.Explicit` for the "sharding in types" programming model.
* Relate logical mesh axes to physical device layouts and hardware interconnects.
* Inspect mesh properties and device assignments.

### Chapter Outline:

1.  **The Need for a Global, Explicit Device View: Foundation for "Sharding in Types"**
2.  **Defining Device Topologies: `jax.sharding.Mesh`**
    * Creating 1D Meshes
    * Creating 2D Meshes
    * Creating n-D Meshes
    * Mapping Physical Devices to Logical Meshes
    * A Note on `jax.make_mesh`
3.  **The Heart of Explicit Sharding: `jax.sharding.AxisType`**
    * `AxisType.Explicit`: The Key to "Sharding in Types"
    * Other Axis Types: `AxisType.Auto` and `AxisType.Manual` (Brief Overview)
4.  **Inspecting Your Mesh: Essential Utilities**
    * `mesh.devices`, `mesh.device_ids`, `mesh.axis_names`, `mesh.shape`, `mesh.size`
5.  **Conceptualizing Meshes: Aligning Logical Grids with Physical Hardware**
6.  **Mesh Context: `jax.sharding.use_mesh`**

## Core Concepts Refresher

* **`jax.devices()`**: Returns a list of available JAX devices.
* **Sharding**: The process of splitting arrays across multiple devices.
* **"Sharding in Types"**: The JAX paradigm where an array's distribution (sharding) across a `Mesh` is part of its static JAX-level type, queryable at trace time. This allows for deterministic propagation of sharding information through JAX operations.
* **Collective Operations**: Operations involving communication across devices.

`Mesh` allows us to define a *named, multi-dimensional grid* of devices, forming the basis for specifying how arrays are sharded in the "sharding in types" model.

In [None]:
import os

# --- JAX Device Setup (Simulate Multiple CPUs if needed) ---
# Set XLA_FLAGS before JAX initializes its backends. The flag only affects the CPU platform.
desired_num_devices = 8 # Let's aim for 8 devices for 2D/3D mesh examples
os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={desired_num_devices}'

import jax
import jax.numpy as jnp
import numpy as np

# For `Mesh` and related types
from jax.sharding import Mesh, AxisType
# from jax.experimental import mesh_utils # Older utility, direct Mesh creation is preferred

print(f"JAX version: {jax.__version__}")
print(f"JAX default backend: {jax.default_backend()}")

# JAX context is implicitly initialized here (or after restart)
physical_devices = jax.devices()
num_physical_devices = len(physical_devices)

print(f"Number of local JAX devices available: {num_physical_devices}")
print(f"Available devices: {physical_devices}")

if num_physical_devices == 0:
  raise RuntimeError("No JAX devices found. Meshes require at least one device.")

### 1. The Need for a Global, Explicit Device View: Foundation for "Sharding in Types"

Traditional automatic sharding in JAX often leaves precise sharding decisions to the compiler. While `jax.lax.with_sharding_constraint` offers some guidance, achieving specific sharding configurations can be an iterative process of "compiler tickling."

The "explicit sharding" (or "sharding in types") paradigm fundamentally changes this. The core idea is that the sharding of a JAX array becomes part of its JAX-level type, making it queryable and propagable at trace time. This provides greater transparency and control. `jax.sharding.Mesh` is the bedrock of this approach. It defines a structured, logical view of your devices, against which these explicit sharding types are defined.

A `Mesh` provides:
* **Structured Topology:** Instead of a flat list, a `Mesh` is a multi-dimensional array of devices with named axes (e.g., 'data', 'model').
* **Basis for Explicit Sharding:** Arrays can be explicitly sharded across these named mesh axes using `NamedSharding` and `PartitionSpec` (Chapter 1.4). The sharding becomes part of the array's type (e.g., `int32[4@data, 2@model]`).
* **Predictable Sharding Propagation:** JAX operations have rules for how these typed shardings propagate from inputs to outputs, making the system more deterministic.

### 2. Defining Device Topologies: `jax.sharding.Mesh`

The primary way to define a mesh is `jax.sharding.Mesh(devices, axis_names, axis_types=None)`.

* `devices`: A NumPy array of JAX device objects, shaped according to the desired logical mesh topology.
* `axis_names`: A tuple or list of strings, naming each dimension of the mesh.
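* `axis_types`: A tuple or list of `AxisType`, one per mesh axis. If omitted, every axis defaults to `AxisType.Auto` (see Section 3), so pass `AxisType.Explicit` for each axis you intend to use in the "sharding in types" model. You can confirm what you got with `mesh.axis_types` on your installed JAX version.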

In [None]:
# --- 1D Mesh ---
# Often used for pure data parallelism, similar to pmap's device view.
# Uses all available devices in a single dimension.
try:
  mesh_1d_shape = (num_physical_devices,)
  # Reshape the flat list of devices into the desired mesh shape
  devices_for_1d_mesh = np.array(physical_devices).reshape(mesh_1d_shape)
  mesh_1d = Mesh(devices_for_1d_mesh, axis_names=('data_parallel_axis',))
  print(f"--- 1D Mesh ---")
  print(f"Mesh shape: {mesh_1d.shape}")
  print(f"Device IDs in mesh:\n{mesh_1d.device_ids}")
  print(f"Axis names: {mesh_1d.axis_names}")
  print(f"Axis types: {mesh_1d.axis_types}") # axis_types was omitted, so expect (AxisType.Auto,)
except ValueError as e:
  print(f"Could not create 1D mesh. Error: {e}. Check num_physical_devices ({num_physical_devices}) and mesh shape {mesh_1d_shape}.")


# --- 2D Mesh ---
# Example: 2 devices for model parallelism, rest for data parallelism.
try:
  # Example: (num_data_parallel_groups, num_model_parallel_devices_per_group)
  # Let's try to make model_parallel_dim_size = 2, if possible, or 1.
  model_parallel_dim_size = 2 if num_physical_devices >= 2 else 1
  if num_physical_devices % model_parallel_dim_size != 0:
    model_parallel_dim_size = 1 # Fall back to 1 if the device count is not divisible by 2

  data_parallel_dim_size = num_physical_devices // model_parallel_dim_size
  mesh_2d_shape = (data_parallel_dim_size, model_parallel_dim_size)

  devices_for_2d_mesh = np.array(physical_devices[:data_parallel_dim_size * model_parallel_dim_size]).reshape(mesh_2d_shape)
  mesh_2d = Mesh(devices_for_2d_mesh, axis_names=('data', 'model')) # axis_types omitted, so both axes are AxisType.Auto

  print(f"\n--- 2D Mesh ---")
  print(f"Mesh shape: {mesh_2d.shape}")
  print(f"Device IDs in mesh:\n{mesh_2d.device_ids}") # Shows the 2D arrangement
  print(f"Axis names: {mesh_2d.axis_names}")
except ValueError as e:
  print(f"Could not create 2D mesh. Error: {e}. Check num_physical_devices ({num_physical_devices}) and mesh shape {mesh_2d_shape}.")



# --- 3D Mesh ---
# Example: (pipeline_stages, data_parallel_groups, model_parallel_devices)
# A (2, 2, k) mesh needs at least 4 devices; with 8 devices, k = 2.
try:
  mesh_3d_shape = (2, 2, num_physical_devices // 4) # Adjust as per your devices
  if num_physical_devices < 4:
    raise ValueError("a (2, 2, k) mesh needs at least 4 devices")
  num_devices_for_3d_mesh = int(np.prod(mesh_3d_shape))
  devices_for_3d_mesh = np.array(physical_devices[:num_devices_for_3d_mesh]).reshape(mesh_3d_shape)
  mesh_3d = Mesh(devices_for_3d_mesh, axis_names=('pipeline', 'data', 'model'))

  print(f"\n--- 3D Mesh ---")
  print(f"Mesh shape: {mesh_3d.shape}")
  print(f"Device IDs in mesh:\n{mesh_3d.device_ids}")
  print(f"Axis names: {mesh_3d.axis_names}")
except ValueError as e:
  print(f"Could not create 3D mesh. Error: {e}. Check num_physical_devices ({num_physical_devices}) and mesh shape {mesh_3d_shape}.")

**Mapping Physical Devices to Logical Meshes:**
The critical step `np.array(physical_devices).reshape(...)` directly maps the ordered list of physical JAX devices to the logical grid of the mesh. This control over mapping is key for performance tuning, as the order of devices in `jax.devices()` determines their position in your logical mesh.
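To see what the reshape does to device placement, here is a small NumPy-only sketch. It uses plain integer IDs as a stand-in for the `Device` objects returned by `jax.devices()`; that substitution is an assumption made so the layout is easy to read.

```python
import numpy as np

# Stand-in for jax.devices(): integer IDs instead of Device objects.
flat_device_ids = np.arange(8)

# NumPy reshapes in row-major order, so the LAST mesh axis varies fastest.
grid = flat_device_ids.reshape(2, 4)  # logical axes: ('data', 'model')
print(grid)
# [[0 1 2 3]
#  [4 5 6 7]]

# Devices that share a 'data' index (one row) are adjacent in the flat list.
model_group_0 = grid[0, :]  # [0 1 2 3]

# Devices that share a 'model' index (one column) are 4 apart in the flat list.
data_group_0 = grid[:, 0]   # [0 4]
```

The consequence is that neighbours in the flat device list end up together along the last mesh axis, which is why the order of the list matters before reshaping.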

**A Note on `jax.make_mesh`:**
JAX also provides `jax.make_mesh(axis_shapes, axis_names, ...)`, which additionally accepts `devices` and `axis_types`. If `devices` is not provided, it uses `jax.devices()` and attempts to compute a good mapping, especially for TPU topologies. In the JAX versions this chapter targets, omitting `axis_types` gives `AxisType.Auto` axes, matching `Mesh`.
Example: `simple_mesh_auto = jax.make_mesh((2, 4), ("X", "Y"))` (uses 8 devices; the 'X' and 'Y' axes are `Auto`).
To get `Explicit` axes with `make_mesh`, specify them, which also avoids depending on the default of whichever JAX release is installed:
`simple_mesh_explicit = jax.make_mesh((2, 4), ("X", "Y"), axis_types=(AxisType.Explicit, AxisType.Explicit))`
While `make_mesh` can be convenient, directly constructing `Mesh` with a reshaped `jax.devices()` array gives you more control over, and visibility into, the device mapping. This is often beneficial for complex setups or when aligning with specific non-TPU hardware topologies.

### 3. The Heart of Explicit Sharding: `jax.sharding.AxisType`

The `axis_types` parameter in `jax.sharding.Mesh` dictates how JAX treats each named mesh axis. For the "sharding in types" paradigm, which is central to modern JAX sharding with `jax.Array` and `NamedSharding`, `AxisType.Explicit` is the most important.

* **`AxisType.Explicit` (Fundamental for "Sharding in Types"):**
    * When a mesh axis has this type, it means sharding behavior along this axis *must be explicitly defined by the user* through annotations like `NamedSharding` and `PartitionSpec` (covered in Chapter 1.4).
    * Critically, only axes marked as `Explicit` can be part of an array's JAX-level type information that includes sharding (e.g., `float32[N@axis_name]`). The JAX documentation states: *"Shardings (on JAX-level types) can only mention explicit mesh axes."*
    * The compiler relies on these explicit user specifications for these axes, rather than inferring sharding.
    * While not the default if `axis_types` is omitted (see `AxisType.Auto` below), **`AxisType.Explicit` is the strongly recommended and necessary `AxisType` for all mesh axes you intend to use with `NamedSharding` and `PartitionSpec` to achieve "sharding in types".**

* **`AxisType.Auto` (Default if `axis_types` is unspecified):**
    * If `axis_types` is not provided when creating a `Mesh` (either via `jax.sharding.Mesh` or `jax.make_mesh`), axes will default to `AxisType.Auto`.
    * This type gives more freedom to the XLA compiler to automatically manage sharding or replication along this axis for optimization.
    * Sharding along `Auto` axes is *not* part of the JAX-level type information visible to the user via `jax.typeof()` (e.g., you won't see `@axis_name` in the type if `axis_name` is `Auto`). The compiler makes these decisions, and they are not explicitly represented in the type.

* **`AxisType.Manual` (Niche):**
    * Indicates the user will manage all data distribution and collective communication manually using lower-level primitives. This is the mode mesh axes are in inside `shard_map`, where collectives are written out explicitly. It is less common in the `jax.Array` and `NamedSharding` workflow, which favors `AxisType.Explicit`.

In [None]:
# Determine a suitable 2D mesh shape: (2, num_devices // 2).
# Floor division guarantees dim1_shape * dim2_shape <= num_physical_devices;
# with an odd device count, the leftover device is simply not used.
dim1_shape = 2
dim2_shape = num_physical_devices // dim1_shape

mesh_shape_for_types_example = (dim1_shape, dim2_shape)
devices_needed_for_example_mesh = np.prod(mesh_shape_for_types_example)

if devices_needed_for_example_mesh == 0:
    raise ValueError("Calculated mesh shape results in zero devices.")

if num_physical_devices < devices_needed_for_example_mesh:
      raise ValueError(f"Not enough devices ({num_physical_devices}) for mesh shape {mesh_shape_for_types_example} requiring {devices_needed_for_example_mesh} devices.")

# Prepare device array for mesh constructor
devices_for_mesh_constructor = np.array(physical_devices[:devices_needed_for_example_mesh]).reshape(mesh_shape_for_types_example)

# Scenario 1: All axes Explicit (Necessary for "sharding in types")
# To use these axes with NamedSharding for type-level sharding, they MUST be Explicit.
mesh_all_explicit = Mesh(devices_for_mesh_constructor,
                          axis_names=('data_explicit', 'model_explicit'),
                          axis_types=(AxisType.Explicit, AxisType.Explicit)) # Explicitly set!
print(f"--- Mesh with all Explicit axes (Manually set, necessary for sharding-in-types) ---")
print(f"Mesh shape: {mesh_all_explicit.shape}")
print(f"Axis names: {mesh_all_explicit.axis_names}")
print(f"Axis types: {mesh_all_explicit.axis_types}") # Will show (AxisType.Explicit, AxisType.Explicit)

# Scenario 2: Illustrate default behavior (Auto) if axis_types is omitted
mesh_default_auto = Mesh(devices_for_mesh_constructor,
                          axis_names=('data_auto_default', 'model_auto_default'))
                          # axis_types is omitted, so it defaults to (AxisType.Auto, AxisType.Auto)
print(f"\n--- Mesh with default AxisTypes (Auto) ---")
print(f"Mesh shape: {mesh_default_auto.shape}")
print(f"Axis names: {mesh_default_auto.axis_names}")
print(f"Axis types: {mesh_default_auto.axis_types}") # Will show (AxisType.Auto, AxisType.Auto)

# Scenario 3: Mixing types explicitly
# 'data_auto_illustration' sharding would be compiler-managed and not in jax.typeof().
# 'model_explicit_illustration' can be used for sharding-in-types.
mesh_mixed_types = Mesh(devices_for_mesh_constructor,
                        axis_names=('data_auto_illustration', 'model_explicit_illustration'),
                        axis_types=(AxisType.Auto, AxisType.Explicit)) # Explicitly set mixed types
print(f"\n--- Mesh with mixed AxisTypes (Explicitly Set) ---")
print(f"Mesh shape: {mesh_mixed_types.shape}")
print(f"Axis names: {mesh_mixed_types.axis_names}")
print(f"Axis types: {mesh_mixed_types.axis_types}") # Shows (AxisType.Auto, AxisType.Explicit)

**Crucially:** For an array's sharding to be part of its JAX-level type (e.g., `int32[N@my_axis_name]`), `my_axis_name` must be an axis in the `Mesh` that was defined with `AxisType.Explicit`. If the axis is `AxisType.Auto` (which is the default if `axis_types` is not specified), its sharding behavior will be determined by the compiler and will *not* be annotated in the array's JAX type. This makes explicitly setting `AxisType.Explicit` fundamental for the "sharding in types" methodology.

### 4. Inspecting Your Mesh: Essential Utilities

Once a `Mesh` is created, you can inspect its properties to understand its structure and the devices it encompasses. This is vital for verifying your setup and for debugging sharding configurations in Project Aurora.

Key attributes to inspect:

* `mesh.size`: The total number of devices in the mesh.
* `mesh.shape`: An ordered mapping from each axis name to its size (e.g., `{'data': 4, 'model': 2}` for a 2D mesh). The sizes alone, as a plain tuple, are available from `mesh.devices.shape`.
* `mesh.axis_names`: The tuple of names assigned to each dimension of the mesh.
* `mesh.devices`: A NumPy array containing the actual JAX `Device` objects, arranged in the same N-dimensional layout as the logical mesh.
* `mesh.device_ids`: A NumPy array of the same shape as `mesh.devices`, but containing the integer ID of each device. This helps map logical mesh coordinates to physical device IDs.
* `mesh.axis_types`: A tuple indicating the `AxisType` for each mesh dimension.

These utilities are invaluable for debugging and verifying that your mesh is configured as intended for Project Aurora. Understanding the device layout within your mesh ensures that your sharding strategies align with your hardware capabilities.

In [None]:
mesh_to_inspect = mesh_all_explicit
print(f"\n--- Inspecting Mesh: '{mesh_to_inspect.axis_names}' ---")

print(f"\n1. Total number of devices in the mesh:")
print(f"   mesh.size: {mesh_to_inspect.size}")

print(f"\n2. Shape of the logical mesh (dimensions):")
print(f"   mesh.shape: {mesh_to_inspect.shape}") # Mapping of axis name -> axis size

print(f"\n3. Names of the mesh axes:")
print(f"   mesh.axis_names: {mesh_to_inspect.axis_names}")

print(f"\n4. JAX Device objects arranged in the mesh's topology:")
# This is a NumPy array of device objects.
print(f"   mesh.devices:\n{mesh_to_inspect.devices}")
print(f"   Shape of mesh.devices: {mesh_to_inspect.devices.shape}")

print(f"\n5. Integer IDs of the devices in the mesh:")
# This has the same shape as mesh.devices.
print(f"   mesh.device_ids:\n{mesh_to_inspect.device_ids}")
print(f"   Shape of mesh.device_ids: {mesh_to_inspect.device_ids.shape}")

print(f"\n6. Types of the mesh axes:")
print(f"   mesh.axis_types: {mesh_to_inspect.axis_types}")

# Example: Accessing a specific device using logical mesh coordinates
# This is useful for understanding the mapping.
# Construct logical coordinates with one index per mesh axis, e.g., (0, 0) for a 2D mesh, (0,) for 1D.
logical_coords_example = tuple(0 for _ in mesh_to_inspect.axis_names)
try:
    device_at_example_coords = mesh_to_inspect.devices[logical_coords_example]
    print(f"\n7. Device at logical coordinates {logical_coords_example}:")
    print(f"   Device object: {device_at_example_coords}")
    print(f"   Its ID: {device_at_example_coords.id}")
    print(f"   Its platform: {device_at_example_coords.platform}")
except IndexError:
    print(f"\nCould not access device at {logical_coords_example} (mesh might be empty or coords out of bounds).")
except Exception as e:
    print(f"\nAn error occurred while accessing device by coordinates: {e}")

### 5. Conceptualizing Meshes: Aligning Logical Grids with Physical Hardware

While a `Mesh` is a *logical* abstraction, its performance implications are deeply tied to the *physical* layout and interconnects of your hardware. When you arrange your `jax.devices()` list into the multi-dimensional array for the `Mesh(devices_array, ...)` constructor, the order of devices in `devices_array` matters significantly.

**Considerations for Aurora's Architecture:**

* **Intra-Node vs. Inter-Node:** On a single Aurora node with multiple GPUs, these GPUs are often connected by high-speed interconnects like NVLink/NVSwitch. For multi-node setups (which we'll discuss in Series 2), communication between nodes is typically slower (e.g., over InfiniBand or Ethernet).
* **Mapping Mesh Axes to Interconnects:**
    * **Fast Axes:** Logical mesh axes that will involve frequent and high-bandwidth collective communication (e.g., a `'model_parallel'` axis for tensor parallelism, requiring `all_reduce` or `all_gather` for activations/gradients *within* a model layer) should ideally map to devices that are physically close and share the fastest interconnects. For example, on a multi-GPU server, map the `'model'` axis to GPUs within the same NVLink group.
    * **Slower Axes:** Axes used for forms of parallelism with less intensive communication, or communication that can tolerate higher latency (e.g., a `'data'` axis for data parallelism, where gradient synchronization happens once per step), might span across slower interconnects if necessary.
* **TPU Pods:** TPUs in a pod slice have a dedicated, high-speed Inter-Chip Interconnect (ICI), often a 2D or 3D torus. When creating a mesh on TPUs, `jax.devices()` usually returns devices in an order that naturally maps well to this ICI topology. You would still design your logical mesh axes (e.g., `'data'`, `'model'`) to align with the dimensions of this physical torus for optimal performance.

**Example Scenario (Conceptual for a Multi-GPU Node):**
Suppose you have an Aurora node with 8 GPUs. GPUs 0-3 might be on one CPU socket and tightly connected via NVLink, and GPUs 4-7 on another socket, similarly connected. The connection *between* these two groups of 4 might be standard PCIe, which is slower than NVLink.

If you define a `Mesh` as `Mesh(np.array(jax.devices()).reshape(2, 4), axis_names=('dp', 'mp'))`:
* You'd want the `'mp'` (model_parallel) axis of size 4 to map to devices `[0,1,2,3]` for the first `dp` replica and devices `[4,5,6,7]` for the second `dp` replica. This means the initial flat `physical_devices` array passed to `np.array()` must be ordered correctly before reshaping (e.g., `[dev0, dev1, dev2, dev3, dev4, dev5, dev6, dev7]`).
* Communication along the `'mp'` axis (within each group of 4) would then leverage the fast NVLink.
* Communication along the `'dp'` axis (between the two groups of 4) would go over the potentially slower link between sockets/groups.

Understanding your hardware is key to defining effective meshes. For very large scale systems or custom hardware, this mapping becomes even more critical.
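The scenario above can be sketched with plain integers. Everything hardware-specific here is an assumption for illustration: the interleaved enumeration order, and the rule that IDs 0-3 and 4-7 form the two NVLink groups. On a real machine you would derive the grouping from your hardware topology.

```python
import numpy as np

# HYPOTHETICAL: the runtime lists devices in an interleaved order.
enumerated_ids = [0, 4, 1, 5, 2, 6, 3, 7]


def nvlink_group(device_id):
    """ASSUMPTION: IDs 0-3 share one NVLink group, IDs 4-7 the other."""
    return device_id // 4


def rows_stay_in_one_group(grid):
    """True if every 'mp' row contains devices from a single NVLink group."""
    return all(len({nvlink_group(d) for d in row}) == 1 for row in grid)


# Naive reshape: each 'mp' row mixes both groups.
naive = np.array(enumerated_ids).reshape(2, 4)
print(naive, rows_stay_in_one_group(naive))      # rows [0 4 1 5], [2 6 3 7] -> False

# Sort by (group, id) first, then reshape: each 'mp' row is one group.
ordered = sorted(enumerated_ids, key=lambda d: (nvlink_group(d), d))
aligned = np.array(ordered).reshape(2, 4)
print(aligned, rows_stay_in_one_group(aligned))  # rows [0 1 2 3], [4 5 6 7] -> True
```

With real `Device` objects, the same idea applies: sort the list from `jax.devices()` by a key that reflects your interconnect groups, reshape, and pass the result to `Mesh`.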

### 6. Mesh Context: `jax.sharding.use_mesh`

JAX allows setting a default mesh for a scope of code using the `jax.sharding.use_mesh(mesh_instance)` context manager. You can also set it globally (though this is less common in libraries) using `jax.sharding.set_mesh(mesh_instance)`. While a mesh is active, JAX APIs that rely on an ambient mesh pick it up when no mesh is passed to them explicitly. These entry points may differ between JAX releases, so consult the documentation for your installed version if either name is unavailable.

In [None]:
# Build a fresh 2x4 mesh for this demo. jax.make_mesh((2, 4), ...) requires 8 devices.
active_mesh_for_context = jax.make_mesh((2, 4), ("x", "y"))

with jax.sharding.use_mesh(active_mesh_for_context):
    print(f"Current global/abstract mesh INSIDE 'with use_mesh(...)' context: {jax.sharding.get_abstract_mesh()}")
    # Any JAX operations here that implicitly use a global mesh would pick up 'active_mesh_for_context'.
    # For example, jax.experimental.shard.reshard (an older API) would use this context.
    # However, for creating jax.Arrays with NamedSharding (Chapter 1.4),
    # we will explicitly pass the Mesh object to the NamedSharding constructor for clarity.
print(f"Current global/abstract mesh AFTER exiting context: {jax.sharding.get_abstract_mesh()}")

While `use_mesh` can be convenient for setting a default, for the "sharding in types" paradigm with `jax.Array` and `NamedSharding` (which we'll cover in the next chapter), it's often clearer and more robust to explicitly pass the `Mesh` object when creating `NamedSharding` instances. This avoids ambiguity about which mesh is being used.

## Structured Parallelism with Mesh

**Summary:**
Architect, you've now established `jax.sharding.Mesh` as a critical piece of Project Aurora's foundation for advanced, explicit parallelism. You've moved beyond `pmap`'s 1D device view to defining rich, multi-dimensional logical device topologies. You have learned to create 1D, 2D, and 3D meshes, map physical devices to them, and, crucially, understand the different `AxisType`s. You now know that `AxisType.Auto` is the default if `axis_types` is unspecified. However, for sharding to be represented in the JAX-level type (e.g., `int32[4@X,2]`, visible via `jax.typeof()`) and for the full "sharding in types" trace-time propagation to apply, the axes involved must be declared as `AxisType.Explicit` in the `Mesh` definition.

**Key Takeaways:**

* `jax.sharding.Mesh` provides a named, multi-dimensional abstraction for a collection of devices, essential for explicit sharding strategies.
* If `axis_types` is omitted during `Mesh` creation, axes default to `AxisType.Auto`.
* The physical-to-logical device mapping in `Mesh` construction (i.e., how `jax.devices()` is reshaped) is key for performance, especially aligning with hardware interconnects.
* Mesh utilities (`mesh.devices`, `mesh.shape`, etc.) allow inspection and verification of your mesh setup.
* A mesh context can be set using `jax.sharding.use_mesh`, but explicit passing of `Mesh` objects to sharding APIs like `NamedSharding` is often preferred for clarity.

**Transition:**
With `Mesh` established as our way of describing *how devices are organized*, the next vital step is to define *how specific JAX arrays (data, parameters) are distributed or sharded across these devices*. This is where `PartitionSpec` and `NamedSharding` come into play. Prepare for **Chapter 1.4: `PartitionSpec` - The Logical Blueprint for Array Sharding**, where you'll learn to draft the precise distribution plans for Aurora's vast datasets and complex models across your meticulously defined device meshes.

## Further Reading & Resources

* **JAX Distributed Arrays and Automatic Parallelization**: [https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) (This notebook covers `Mesh` and newer sharding APIs).