<a href="https://colab.research.google.com/github/kaixih/JAX101/blob/master/pjit_sharding_positional_vs_named.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

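*`PositionalSharding` and `NamedSharding` are two ways to express how an array is sharded across devices. This Colab places the same data with both APIs, side by side, to show how each layout is written in either style. (Note: `PositionalSharding` is deprecated in newer JAX releases in favor of `NamedSharding`.)*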

In [78]:
import os

# Expose 8 virtual CPU devices so the 4x2 device mesh used in this notebook can
# be built on a single host. XLA parses XLA_FLAGS when JAX initializes its
# backends, so the flag must be in place before any JAX work runs; setting it
# ahead of `import jax` is the safest spot. The flag is appended (not assigned)
# so XLA_FLAGS the user already exported are kept, and re-running this cell
# does not add it twice.
_DEVICE_COUNT_FLAG = "--xla_force_host_platform_device_count=8"
_xla_flags = os.environ.get("XLA_FLAGS", "")
if _DEVICE_COUNT_FLAG not in _xla_flags.split():
    os.environ["XLA_FLAGS"] = f"{_xla_flags} {_DEVICE_COUNT_FLAG}".strip()

# NOTE(review): functools, Optional, np and jnp are not used in the cells shown;
# drop them if nothing else in the notebook needs them.
import functools
from typing import Optional

import numpy as np

import jax
import jax.numpy as jnp

In [79]:
import os

# Ask XLA's host (CPU) platform for 8 virtual devices. This only takes effect if
# JAX has not initialized its backends yet, so run it before any JAX computation
# (restart the runtime if a computation has already run). Existing XLA_FLAGS are
# preserved, and re-running the cell does not duplicate the flag.
_DEVICE_COUNT_FLAG = "--xla_force_host_platform_device_count=8"
_xla_flags = os.environ.get("XLA_FLAGS", "")
if _DEVICE_COUNT_FLAG not in _xla_flags.split():
    os.environ["XLA_FLAGS"] = f"{_xla_flags} {_DEVICE_COUNT_FLAG}".strip()

import jax
from jax.experimental import mesh_utils
# NOTE(review): PositionalSharding is deprecated in newer JAX releases; pin a JAX
# version that still provides it, or keep only the NamedSharding examples.
from jax.sharding import PositionalSharding
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding

# Short alias for PartitionSpec, as commonly used in the JAX docs.
P = PartitionSpec

In [80]:
# Build the matrix to shard from a fixed PRNG seed so reruns give the same
# values; the cells below only inspect how it is placed on devices.
rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, shape=(8192, 8192))

In [81]:
# Arrange the devices into a (4, 2) grid (a NumPy array of Device objects).
# This is not a Sharding yet: both the PositionalSharding and the Mesh below
# are built from this same grid.
# Fail fast with an actionable message: create_device_mesh needs exactly
# 4 * 2 = 8 devices. A mismatch usually means XLA_FLAGS was set after JAX had
# already initialized, or the runtime is a GPU with fewer devices.
assert jax.device_count() == 8, (
    f"This notebook needs exactly 8 devices for a 4x2 mesh, found "
    f"{jax.device_count()}. Set XLA_FLAGS before JAX initializes (restart the "
    "runtime if needed) or switch to a CPU runtime."
)
devices = mesh_utils.create_device_mesh((4, 2))

Positional Sharding for 4x2

In [82]:
# Wrap the (4, 2) device grid; the later positional cells derive their layouts
# from this `sharding` object.
sharding = PositionalSharding(devices)
pos_4x2 = sharding.reshape(4, 2)  # grid is already (4, 2); reshape spells the layout out
y = jax.device_put(x, pos_4x2)
jax.debug.visualize_array_sharding(y)

Named Sharding for 4x2

In [83]:
# Name the axes of the (4, 2) device grid: 'a' runs over its 4 rows and 'b'
# over its 2 columns.
mesh_axis_names = ('a', 'b')
mesh = Mesh(devices, axis_names=mesh_axis_names)

In [84]:
# Array dim 0 is split over mesh axis 'a' (4 ways), dim 1 over 'b' (2 ways).
named_4x2 = NamedSharding(mesh, P('a', 'b'))
y = jax.device_put(x, named_4x2)
jax.debug.visualize_array_sharding(y)

Positional Sharding for 2x4 (Transpose)

In [85]:
# Transposing the positional device grid gives a (2, 4) layout:
# array rows are split 2 ways and columns 4 ways.
pos_2x4 = sharding.T
y = jax.device_put(x, pos_2x4)
jax.debug.visualize_array_sharding(y)

Named Sharding for 2x4 (Transpose)

In [86]:
# Same layout expressed by name: swap which mesh axis each array dim uses.
named_2x4 = NamedSharding(mesh, P('b', 'a'))
y = jax.device_put(x, named_2x4)
jax.debug.visualize_array_sharding(y)

Positional Sharding for 4x1 (Replicate second axis)

In [87]:
# Collapse device-grid axis 1 to get (4, 1): rows split 4 ways, and each row
# block is replicated on the 2 devices of its grid row.
pos_4x1 = sharding.replicate(axis=1)
y = jax.device_put(x, pos_4x1)
jax.debug.visualize_array_sharding(y)

Named Sharding for 4x1 (Replicate second axis)

In [88]:
# `None` leaves dim 1 unsharded, so each shard is replicated across mesh axis 'b'.
named_4x1 = NamedSharding(mesh, P('a', None))
y = jax.device_put(x, named_4x1)
jax.debug.visualize_array_sharding(y)

Positional Sharding for 1x2 (Replicate first axis)

In [89]:
# Collapse device-grid axis 0 instead to get (1, 2): columns split 2 ways, and
# each column block is replicated on the 4 devices of its grid column.
pos_1x2 = sharding.replicate(axis=0)
y = jax.device_put(x, pos_1x2)
jax.debug.visualize_array_sharding(y)

Named Sharding for 1x2 (Replicate first axis)

In [90]:
# Dim 0 unsharded (replicated over mesh axis 'a'); dim 1 split over 'b'.
named_1x2 = NamedSharding(mesh, P(None, 'b'))
y = jax.device_put(x, named_1x2)
jax.debug.visualize_array_sharding(y)

Positional Sharding for 8x1 (Flatten all devices)

In [91]:
# Reshape the (4, 2) device grid into one (8, 1) column: rows split 8 ways.
pos_8x1 = sharding.reshape(8, 1)
y = jax.device_put(x, pos_8x1)
jax.debug.visualize_array_sharding(y)

Named Sharding for 8x1 (Flatten all devices)

In [92]:
# A tuple of mesh axes shards a single dim over their combined size
# (4 * 2 = 8); dim 1 stays unsharded.
named_8x1 = NamedSharding(mesh, P(('a', 'b'), None))
y = jax.device_put(x, named_8x1)
jax.debug.visualize_array_sharding(y)

Positional Sharding for 1x4 (Transpose and replicate)

In [93]:
# Transpose to (2, 4), then collapse axis 0 to get (1, 4): columns split 4 ways,
# and each column block is replicated on 2 devices.
pos_1x4 = sharding.T.replicate(0)
y = jax.device_put(x, pos_1x4)
jax.debug.visualize_array_sharding(y)

Named Sharding for 1x4 (Transpose and replicate)

In [94]:
# Dim 1 split over mesh axis 'a' (4 ways); dim 0 unsharded, i.e. replicated over 'b'.
named_1x4 = NamedSharding(mesh, P(None, 'a'))
y = jax.device_put(x, named_1x4)
jax.debug.visualize_array_sharding(y)