In [None]:
import jax
from jax import jit
from jax import numpy as jnp
from jax import random

# List the devices JAX can see. The mesh built further down is (4, 2), so the
# sharding benchmarks assume 8 devices are available.
jax.devices()

In [None]:
def square(x):
    """Return the matrix square of ``x``: the matrix product ``x @ x``.

    This is the benchmark workload for every timing cell below. The input's
    placement/sharding is whatever the caller passes in; this function does
    not change it.
    """
    product = x @ x
    return product

# Fixed seed so every run benchmarks the same matrix.
prng_key = random.key(0)
n = 2**14  # 16384 x 16384 float32 matrix, ~1.07 GB
x = random.normal(prng_key, (n, n), dtype=jnp.float32)
# Derive the size from the array itself instead of hard-coding 4 bytes per
# element, so the printed figure stays correct if the dtype above changes.
print("Size of x:", x.nbytes / 1e9, "GB")
# Show which device(s) hold x; no sharding was requested, so this is where
# JAX placed it by default.
x.devices()

In [None]:
# Baseline: x was created without an explicit sharding, so this shows it on a
# single device.
jax.debug.visualize_array_sharding(x)
# JAX dispatches work asynchronously; block on the result so the timing covers
# the matmul itself, not just its dispatch. `square` is not jitted here.
%timeit square(x).block_until_ready()

In [None]:
jit_square = jit(square)
_ = jit_square(x) # compiles on the first call
%timeit jit_square(x)

In [None]:
from jax.sharding import NamedSharding, PartitionSpec, SingleDeviceSharding

# Logical device mesh: 4 devices along mesh axis 'x', 2 along mesh axis 'y'
# (assumes 4 * 2 = 8 devices are available). The mesh axis names 'x'/'y' are
# unrelated to the array variable `x`.
mesh = jax.make_mesh((4, 2), ('x', 'y'))
# Naming for the sharded copies below: y_<row axis><column axis>, with `n`
# marking an unsharded (None) dimension.
# Here rows are split 4 ways over 'x' and columns 2 ways over 'y', so every
# device holds a distinct tile and nothing is replicated.
y_xy = jax.device_put(x, NamedSharding(mesh, PartitionSpec('x', 'y')))
jax.debug.visualize_array_sharding(y_xy)
%timeit square(y_xy).block_until_ready()

In [None]:
# Rows split 4 ways over mesh axis 'x'; columns not split. Mesh axis 'y' is not
# named in the spec, so each row block is replicated across its 2 devices.
y_xn = jax.device_put(x, NamedSharding(mesh, PartitionSpec('x', None)))
jax.debug.visualize_array_sharding(y_xn)
%timeit square(y_xn).block_until_ready()

In [None]:
# Rows split 2 ways over mesh axis 'y'; columns not split. Mesh axis 'x' is not
# named in the spec, so each row block is replicated across its 4 devices.
y_yn = jax.device_put(x, NamedSharding(mesh, PartitionSpec('y', None)))
jax.debug.visualize_array_sharding(y_yn)
%timeit square(y_yn).block_until_ready()

In [None]:
# Columns split 4 ways over mesh axis 'x'; rows not split. Mesh axis 'y' is not
# named in the spec, so each column block is replicated across its 2 devices.
y_nx = jax.device_put(x, NamedSharding(mesh, PartitionSpec(None, 'x')))
jax.debug.visualize_array_sharding(y_nx)
%timeit square(y_nx).block_until_ready()

In [None]:
# Columns split 2 ways over mesh axis 'y'; rows not split. Mesh axis 'x' is not
# named in the spec, so each column block is replicated across its 4 devices.
y_ny = jax.device_put(x, NamedSharding(mesh, PartitionSpec(None, 'y')))
jax.debug.visualize_array_sharding(y_ny)
%timeit square(y_ny).block_until_ready()

In [None]:
# Gather the tiled y_xy back onto a single device (the first one JAX reports)
# to compare an unsharded run against the sharded timings above.
z = jax.device_put(y_xy, SingleDeviceSharding(jax.devices()[0]))
jax.debug.visualize_array_sharding(z)
%timeit square(z).block_until_ready()