<a href="https://colab.research.google.com/github/kaixih/JAX101/blob/master/pjit_flax_logical_axis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
# Ask XLA for 8 host-platform devices. This assignment replaces any XLA_FLAGS
# value already present in the environment, and it is made before `import jax`
# below (presumably the flag must be set before JAX initialises its backend to
# take effect -- confirm against the JAX docs).
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax

from jax import lax, random, numpy as jnp

import flax
from flax import struct, traverse_util, linen as nn
from flax.linen import spmd # Flax Linen SPMD.

In [2]:
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.sharding import Mesh, PartitionSpec
from jax.experimental import mesh_utils

# Arrange the available devices into a 4 x 2 grid.
device_mesh = mesh_utils.create_device_mesh((4, 2))

# Name the grid axes: the first (size 4) is 'data', the second (size 2) is
# 'model'. The bare `mesh` on the last line displays the result.
mesh = Mesh(device_mesh, ('data', 'model'))
mesh



Mesh(device_ids=array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7]]), axis_names=('data', 'model'))

In [3]:
class SuperDot(nn.Module):
  """Dot-product layer that also returns the maxima of its input and weight.

  Fields:
    depth: size of the second dimension of the weight `W1`.
    max_history_length: length of the two 1-D history parameters.
  """
  depth: int
  max_history_length: int

  @nn.compact
  def __call__(self, x):
    """Computes `dot(x, W1)` together with the maxima of `x` and `W1`.

    Args:
      x: input array. Its last dimension sets the first dimension of `W1`;
        the reduction below names axes 0 and 1, so `x` needs at least two
        dimensions.

    Returns:
      A tuple `(y, x_max, w_max)` where `y = dot(x, W1)` and `x_max` /
      `w_max` are the maxima of `x` / `W1` over axes (0, 1), with the reduced
      axes kept as size-1 dimensions.
    """
    # Weight annotated with the logical axis names ('embed', 'hidden').
    weight_init = spmd.with_logical_partitioning(
        nn.initializers.xavier_normal(), ('embed', 'hidden'))
    W1 = self.param('W1', weight_init, (x.shape[-1], self.depth))

    # Two zero-initialised history parameters, registered in the same order
    # as before (x first, then w). They are only declared here; the forward
    # computation below does not read them.
    history_shape = (self.max_history_length,)
    for history_name in ('x_max_history', 'w_max_history'):
      self.param(
          history_name,
          spmd.with_logical_partitioning(
              nn.initializers.zeros_init(), ('history_length',)),
          history_shape)

    reduce_axes = (0, 1)
    y = jnp.dot(x, W1)
    x_max = jnp.max(x, axis=reduce_axes, keepdims=True)
    w_max = jnp.max(W1, axis=reduce_axes, keepdims=True)
    return y, x_max, w_max

In [4]:
# Split one root key into two independent keys: one to draw the input data
# and one (`k`) for parameter initialisation. Previously PRNGKey(0) was used
# for both, and JAX keys are meant to be consumed only once.
data_key, k = random.split(random.PRNGKey(0))
x = random.normal(data_key, (8192, 8192))

# Fields passed by keyword so the two sizes are self-describing.
model = SuperDot(depth=8192, max_history_length=16)

In [5]:
# Model initialisation wrapped in a plain function, so it can be handed to
# JAX transformations such as `jax.eval_shape`.
def init_fn(k, x):
  """Returns the variables produced by `model.init(k, x)`."""
  return model.init(k, x)

# Evaluate `init_fn` abstractly to obtain the structure of its output.
abstract_variables = jax.eval_shape(init_fn, k, x)
# `logical_output_spec` has the same pytree structure as the output of
# `init_fn`, holding a PartitionSpec of logical axis names per parameter.
logical_output_spec = nn.get_partition_spec(abstract_variables)
logical_output_spec

FrozenDict({
    params: {
        W1: PartitionSpec('embed', 'hidden'),
        w_max_history: PartitionSpec('history_length',),
        x_max_history: PartitionSpec('history_length',),
    },
})

In [6]:
# Rules map a logical axis name to a mesh axis name. A logical axis without a
# rule is left unsharded, so there is no need to spell out `('embed', None)`
# or `('history_length', None)`.
rules = (
    ('batch', 'data'),
    ('hidden', 'model'),
)

# Translate the logical axis names into mesh axis names using the rules.
logical_state_spec = spmd.logical_to_mesh(logical_output_spec, rules)
logical_state_spec

FrozenDict({
    params: {
        W1: PartitionSpec(None, 'model'),
        w_max_history: PartitionSpec(None,),
        x_max_history: PartitionSpec(None,),
    },
})

In [7]:
# `in_shardings` / `out_shardings` are the current names of the pjit keywords
# formerly called `in_axis_resources` / `out_axis_resources`.
pjit_init_fn = pjit(
    init_fn,
    in_shardings=(PartitionSpec(None), PartitionSpec('data', None)),  # PRNG key and x
    out_shardings=logical_state_spec,  # params
)
# The shardings are given as PartitionSpecs, so the call needs a mesh context.
with mesh:
  initialized_state = pjit_init_fn(k, x)
# `jax.tree_util.tree_map` is the stable spelling of the old `jax.tree_map`
# alias; show the shape of every leaf of the initialised state.
jax.tree_util.tree_map(jnp.shape, initialized_state)

FrozenDict({
    params: {
        W1: LogicallyPartitioned(value=(8192, 8192), names=('embed', 'hidden'), mesh=None, rules=None),
        w_max_history: LogicallyPartitioned(value=(16,), names=('history_length',), mesh=None, rules=None),
        x_max_history: LogicallyPartitioned(value=(16,), names=('history_length',), mesh=None, rules=None),
    },
})

In [8]:
# Print the sharding of each initialised parameter, in the same order as
# before: W1, x_max_history, w_max_history.
for param_name in ('W1', 'x_max_history', 'w_max_history'):
  print(initialized_state['params'][param_name].value.sharding)

GSPMDSharding({devices=[1,2,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate})
GSPMDSharding({replicated})
GSPMDSharding({replicated})


In [9]:
def infer_step(state, x):
  """Applies `model` to `x` using the parameters stored in `state`.

  Args:
    state: mapping with a 'params' entry holding the model parameters.
    x: input passed to the model.

  Returns:
    The tuple `(y, x_max, w_max)` produced by `model.apply`.
  """
  variables = {'params': state['params']}
  outputs = model.apply(variables, x)
  # Placeholder kept from the original notebook (not implemented here):
  # update initialized_state['params']['x_max_history'] with x_max and
  # initialized_state['params']['w_max_history'] with w_max.
  y, x_max, w_max = outputs
  return y, x_max, w_max

# `in_shardings` / `out_shardings` are the current names of the pjit keywords
# formerly called `in_axis_resources` / `out_axis_resources`.
pjit_step_fn = pjit(
    infer_step,
    in_shardings=(logical_state_spec, PartitionSpec('data', None)),  # params and x
    out_shardings=(
        PartitionSpec('data', 'model'),  # y
        PartitionSpec(None),             # x_max
        PartitionSpec(None),             # w_max
    ),
)
# The shardings are given as PartitionSpecs, so the call needs a mesh context.
with mesh:
  y, x_max, w_max = pjit_step_fn(initialized_state, x)

# Visualise how each output is laid out; the printed labels are unchanged.
for label, array in (('y', y), ('x_max', x_max), ('w_max', w_max)):
  print(f'{label} sharding:')
  jax.debug.visualize_array_sharding(array)

y sharding:


x_max sharding:


w_max sharding:


In [10]:
# Lower the pjit-ed step function for these arguments inside the mesh context.
with mesh:
  lowered = pjit_step_fn.lower(initialized_state, x)
# Compile the lowered computation, then fetch the compiler's IR from it.
# Despite its name, `compiled` holds that IR (iterated over in the next cell),
# not the executable itself.
executable = lowered.compile()
compiled = executable.compiler_ir()

In [11]:
# Print the text form of every module in the compiler IR.
for hlo_module in compiled:
  hlo_text = hlo_module.to_string()
  print(hlo_text)

HloModule pjit_infer_step, entry_computation_layout={(f32[8192,4096]{1,0},f32[2048,8192]{1,0})->(f32[2048,4096]{1,0}, f32[1,1]{1,0}, f32[1,1]{1,0})}, allow_spmd_sharding_propagation_to_output={false,false,false}

%region_0.5 (Arg_0.6: f32[], Arg_1.7: f32[]) -> f32[] {
  %Arg_0.6 = f32[] parameter(0)
  %Arg_1.7 = f32[] parameter(1)
  ROOT %maximum.8 = f32[] maximum(f32[] %Arg_0.6, f32[] %Arg_1.7), metadata={op_name="pjit(infer_step)/jit(main)/SuperDot/reduce_max[axes=(0, 1)]" source_file="<ipython-input-3-7f4f0c7fea61>" source_line=20}
}

%region_1.11 (Arg_0.12: f32[], Arg_1.13: f32[]) -> f32[] {
  %Arg_0.12 = f32[] parameter(0)
  %Arg_1.13 = f32[] parameter(1)
  ROOT %maximum.14 = f32[] maximum(f32[] %Arg_0.12, f32[] %Arg_1.13), metadata={op_name="pjit(infer_step)/jit(main)/SuperDot/reduce_max[axes=(0, 1)]" source_file="<ipython-input-3-7f4f0c7fea61>" source_line=20}
}

ENTRY %main.21_spmd (param.1: f32[8192,4096], param: f32[2048,8192]) -> (f32[2048,4096], f32[1,1], f32[1,1]) {
  %par