<a href="https://colab.research.google.com/github/kaixih/JAX101/blob/master/pjit_reduce_op.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [32]:
import os
from typing import Optional
import jax
import jax.numpy as jnp

In [33]:
# Ask XLA to expose 8 host-platform devices; the 4x2 device mesh built later in
# this notebook needs 8 devices in total.
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Short alias used when writing partition specs.
P = PartitionSpec

In [34]:
# Create an 8192x8192 array of normally distributed random values, drawn from
# a fixed PRNG key (seed 0):
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))

In [35]:
# Arrange the available devices into a 4x2 grid (8 devices in total). This
# grid is what the Mesh and the NamedSharding objects below are built on:
devices = mesh_utils.create_device_mesh((4, 2))

In [36]:
# Assign names to the axes of the device mesh: 'a' is the first (size-4) axis
# and 'b' is the second (size-2) axis of the 4x2 device grid.
mesh = Mesh(devices, axis_names=('a', 'b'))

In [37]:
# `a` splits the rows of x across mesh axis 'a'; `b` splits the columns of x
# across mesh axis 'b'. A None entry leaves that array dimension unsplit.
a = jax.device_put(x, NamedSharding(mesh, P('a', None)))
b = jax.device_put(x, NamedSharding(mesh, P(None, 'b')))
for caption, sharded in (('a sharding:', a), ('b sharding:', b)):
  print(caption)
  jax.debug.visualize_array_sharding(sharded)

a sharding:


b sharding:


In [38]:
# Reduce over both axes but keep them as size-1 dims, because
# visualize_array_sharding only works with 1-D or 2-D arrays.
a_max = jnp.max(a, axis=(0, 1), keepdims=True)
b_max = jnp.max(b, axis=(0, 1), keepdims=True)
d = jnp.dot(a, b)
for caption, result in (
    ('d sharding:', d),
    ('a_max sharding:', a_max),
    ('b_max sharding:', b_max),
):
  print(caption)
  jax.debug.visualize_array_sharding(result)

d sharding:


a_max sharding:


b_max sharding:


In [39]:
@jax.jit
def f(a, b):
  """Compute the global max of each input and their product under jit.

  Args:
    a: left operand of the dot product.
    b: right operand of the dot product.

  Returns:
    A tuple `(dot(a, b), max(a), max(b))`. Each max is taken over axes 0 and 1
    with `keepdims=True`, so the reduced axes remain as size-1 dimensions.
  """
  reduce_kwargs = dict(axis=(0, 1), keepdims=True)
  max_of_a = jnp.max(a, **reduce_kwargs)
  max_of_b = jnp.max(b, **reduce_kwargs)
  product = jnp.dot(a, b)
  return product, max_of_a, max_of_b
# Run the jitted function on the sharded inputs and show how each of its
# outputs ends up sharded.
d, a_max, b_max = f(a, b)
for caption, result in (
    ('d sharding:', d),
    ('a_max sharding:', a_max),
    ('b_max sharding:', b_max),
):
  print(caption)
  jax.debug.visualize_array_sharding(result)

d sharding:


a_max sharding:


b_max sharding:


In [40]:
# Lower `f` for the sharded inputs `a` and `b`, compile it, and fetch the
# post-compilation IR. Despite its name, `compiled` holds the IR modules
# (iterated and printed in the next cell), not the compiled executable.
# NOTE(review): `Compiled.compiler_ir()` may not be available in newer JAX
# releases — confirm against the installed version.
lowered = f.lower(a, b)
compiled = lowered.compile().compiler_ir()

In [41]:
# Dump every module of the post-compilation IR as text (HLO) so the
# partitioned program can be inspected.
for module in compiled:
  print(module.to_string())

HloModule jit_f, entry_computation_layout={(f32[2048,8192]{1,0},f32[8192,4096]{1,0})->(f32[2048,4096]{1,0}, f32[1,1]{1,0}, f32[1,1]{1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true}

%region_0.4 (Arg_0.5: f32[], Arg_1.6: f32[]) -> f32[] {
  %Arg_0.5 = f32[] parameter(0)
  %Arg_1.6 = f32[] parameter(1)
  ROOT %maximum.7 = f32[] maximum(f32[] %Arg_0.5, f32[] %Arg_1.6), metadata={op_name="jit(f)/jit(main)/reduce_max[axes=(0, 1)]" source_file="<ipython-input-16-f595938591a5>" source_line=1}
}

%region_1.10 (Arg_0.11: f32[], Arg_1.12: f32[]) -> f32[] {
  %Arg_0.11 = f32[] parameter(0)
  %Arg_1.12 = f32[] parameter(1)
  ROOT %maximum.13 = f32[] maximum(f32[] %Arg_0.11, f32[] %Arg_1.12), metadata={op_name="jit(f)/jit(main)/reduce_max[axes=(0, 1)]" source_file="<ipython-input-16-f595938591a5>" source_line=1}
}

ENTRY %main.18_spmd (param: f32[2048,8192], param.1: f32[8192,4096]) -> (f32[2048,4096], f32[1,1], f32[1,1]) {
  %param = f32[2048,8192]{1,0} parameter(0), sharding={de