<a href="https://colab.research.google.com/github/mohsenh17/jaxLearning/blob/main/jax/parallel101.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install --upgrade flax orbax jax treescope

In [4]:
import os
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding, PartitionSpec as P

In [5]:
# Device configuration. XLA reads XLA_FLAGS when the backend is first
# initialized, so run this cell before any JAX computation.

# Set to False to run on real GPUs instead of simulated CPU devices.
USE_CPU_ONLY = True
# Number of host (CPU) devices to simulate. The meshes built later in this
# notebook -- (2, 4), (8,) and (4, 2) -- each need exactly 8 devices.
NUM_FAKE_DEVICES = 8

if USE_CPU_ONLY:
    # Simulate several devices on the host platform.
    _wanted_flags = [f"--xla_force_host_platform_device_count={NUM_FAKE_DEVICES}"]
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
else:
    # GPU flags
    _wanted_flags = [
        "--xla_gpu_enable_triton_softmax_fusion=true",
        "--xla_gpu_triton_gemm_any=false",
        "--xla_gpu_enable_async_collectives=true",
        "--xla_gpu_enable_latency_hiding_scheduler=true",
        "--xla_gpu_enable_highest_priority_async_stream=true",
    ]

# Merge with any pre-existing XLA_FLAGS. Flags are handled as a list so they are
# always separated by whitespace, and an existing flag with the same name is
# replaced instead of duplicated, which makes re-running this cell idempotent.
_wanted_names = {flag.split("=", 1)[0] for flag in _wanted_flags}
_kept_flags = [
    flag
    for flag in os.environ.get("XLA_FLAGS", "").split()
    if flag.split("=", 1)[0] not in _wanted_names
]
flags = " ".join(_kept_flags + _wanted_flags)
os.environ["XLA_FLAGS"] = flags

# Report the device count JAX actually sees rather than a hard-coded number.
_devices = jax.devices()
if USE_CPU_ONLY:
    print(f'You have {len(_devices)} “fake” JAX devices now: {_devices}')
else:
    print(f'You have {len(_devices)} JAX devices now: {_devices}')

You have 8 “fake” JAX devices now: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]


In [25]:
# Build a 4 x 8 float matrix holding the values 0..31.
arr = jnp.reshape(jnp.arange(32.0), (4, 8))
# Note: only the last expression of a cell is displayed, so the value of
# arr.devices() is computed but not shown; arr.sharding is what gets rendered.
arr.devices()
arr.sharding

SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)

In [9]:
# Show the array's values (the cell's last expression is rendered as output).
arr

Array([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
       [ 8.,  9., 10., 11., 12., 13., 14., 15.],
       [16., 17., 18., 19., 20., 21., 22., 23.],
       [24., 25., 26., 27., 28., 29., 30., 31.]], dtype=float32)

In [8]:
# Render a diagram of how `arr` is laid out across devices.
jax.debug.visualize_array_sharding(arr)

In [26]:
# Arrange the devices in a 2 x 4 grid whose axes are named 'x' and 'y'.
mesh = jax.make_mesh(axis_shapes=(2, 4), axis_names=('x', 'y'))
# P('x', 'y') maps array dimension 0 onto mesh axis 'x' and dimension 1 onto 'y'.
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)

NamedSharding(mesh=Mesh('x': 2, 'y': 4), spec=PartitionSpec('x', 'y'), memory_kind=unpinned_host)


In [27]:
# Distribute `arr` across devices according to `sharding`.
arr_sharded = jax.device_put(arr, sharding)

# Printing shows the full logical array; the visualization shows which device
# holds which block of it.
print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)

[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29. 30. 31.]]


In [17]:
@jax.jit
def f_elementwise(x):
    """Apply the affine map ``2 * x + 1`` to every element of ``x``.

    The function is compiled with ``jax.jit``; the result has the same shape
    as the input.
    """
    doubled = x * 2
    return doubled + 1

# Call the jitted function directly on the sharded input, then visualize how
# the output is laid out across devices.
result = f_elementwise(arr_sharded)
jax.debug.visualize_array_sharding(result)

In [29]:
@jax.jit
def f_contract(x):
    """Average ``x`` over its leading axis (axis 0), removing that axis."""
    return jnp.mean(x, axis=0)

# Reduce over axis 0 of the sharded input (the dimension that `sharding` maps
# to mesh axis 'x'), visualize the layout of the reduced result, then show its
# values.
result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
result

Array([12., 13., 14., 15., 16., 17., 18., 19.], dtype=float32)

In [33]:
@jax.jit
def f_contract_2(x):
  out = x.sum(axis=0)
  mesh = jax.make_mesh((8,), ('x',))
  sharding = jax.sharding.NamedSharding(mesh, P('x'))
  return jax.lax.with_sharding_constraint(out, sharding)

# Run the reduction that carries an explicit output sharding constraint,
# visualize the resulting layout, and print the summed values.
result = f_contract_2(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)

[48. 52. 56. 60. 64. 68. 72. 76.]


In [35]:
# A flat mesh over all 8 devices with a single axis named 'x'.
mesh = jax.make_mesh((8,), ('x',))

# shard_map (imported in the notebook's import cell) applies f_elementwise to
# each device's local block: in_specs splits the input along mesh axis 'x' and
# out_specs reassembles the per-device outputs along that same axis.
f_elementwise_sharded = shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

# Use a dedicated name for the 1-D input so the (4, 8) `arr` defined earlier is
# not overwritten and the earlier cells that use it remain re-runnable.
arr_1d = jnp.arange(48)
f_elementwise_sharded(arr_1d)

Array([ 1,  3,  5,  7,  9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33,
       35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 67,
       69, 71, 73, 75, 77, 79, 81, 83, 85, 87, 89, 91, 93, 95],      dtype=int32)

In [38]:
# A 1-D array of the integers 0..47; print its full (global) shape.
x = jnp.arange(0, 48)
print(f"global shape: {x.shape=}")

def f(x):
    """Print the shape of the block ``x`` this function receives and return ``x`` doubled."""
    print(f"device local shape: {x.shape=}")
    doubled = x * 2
    return doubled

# Map `f` over the mesh: the input is split along mesh axis 'x' and the
# per-device results are reassembled along the same axis. Inside `f` the
# printed shape is that of the block it is given, not the global shape.
y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)

global shape: x.shape=(48,)
device local shape: x.shape=(6,)


In [39]:
# Create a 4 x 2 device mesh with axes named 'x' and 'y'. The Sharding object
# that uses this mesh to distribute a value across devices is built in the
# following cell.
mesh = jax.make_mesh((4, 2), ('x', 'y'))

In [41]:
# Create an 8192 x 8192 array of random normal values (PRNG key 0):
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices: dimension 0 is split
# over mesh axis 'x' and dimension 1 over mesh axis 'y'.
# NamedSharding and P come from the notebook's import cell.
y = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
jax.debug.visualize_array_sharding(y)

In [42]:
# Apply an elementwise op to the sharded array and visualize how the result is
# laid out across devices.
z = jnp.sin(y)
jax.debug.visualize_array_sharding(z)

In [43]:
# Time the same elementwise op on both layouts (5 loops x 5 runs each).
# block_until_ready() waits for the computation to finish so the timing covers
# the actual work and not just the dispatch of the call.
# NOTE(review): presumably the simulated CPU devices share the same host cores,
# so sharding is not expected to yield a speedup here -- confirm on real hardware.
# `x` is present on a single device
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()
# `y` is sharded across 8 devices.
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()

639 ms ± 101 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
635 ms ± 102 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
