<a href="https://colab.research.google.com/github/kaixih/JAX101/blob/master/pjit_flax_named.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
# Ask XLA's host (CPU) platform to expose 8 devices. The flag is set before
# `import jax` so it is already in place when JAX starts up; the 4x2 device
# mesh built later in this notebook needs 8 devices.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax

from jax import lax, random, numpy as jnp

import flax
from flax import struct, traverse_util, linen as nn
from flax.linen import spmd # Flax Linen SPMD.

In [2]:
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.sharding import Mesh, PartitionSpec
from jax.experimental import mesh_utils

# Arrange the available devices into a 4x2 grid.
mesh_shape = (4, 2)
device_mesh = mesh_utils.create_device_mesh(mesh_shape)

# Name the grid axes: the first (size 4) is 'data', the second (size 2) is
# 'model'. PartitionSpecs below refer to these names.
mesh_axis_names = ('data', 'model')
mesh = Mesh(devices=device_mesh, axis_names=mesh_axis_names)
mesh



Mesh(device_ids=array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7]]), axis_names=('data', 'model'))

In [3]:
class SuperDot(nn.Module):
  """Dot product layer that also tracks running maxima of input and weight.

  Attributes are documented inline below. The layer owns one parameter
  ('params'/'W1') and two mutable variables in the 'fp8_params' collection
  ('x_max_history' and 'w_max_history').
  """
  # Number of output features (second dimension of W1).
  depth: int
  # Length of each max-history buffer.
  max_history_length: int

  @nn.compact
  def __call__(self, x):
    """Computes `x . W1` and applies the (fake) max-history update.

    Args:
      x: input array; its last dimension is the contraction dimension and it
        must have at least two axes, since the max is taken over axes (0, 1).

    Returns:
      A tuple `(y, x_max_history, w_max_history)`: the dot product and the
      updated values of the two history variables.
    """
    # Weight matrix, partitioned along the 'model' mesh axis on its columns.
    kernel_init = nn.with_partitioning(
        nn.initializers.xavier_normal(), (None, 'model'))
    W1 = self.param('W1', kernel_init, (x.shape[-1], self.depth))

    def declare_history(name):
      # Zero-initialised, unpartitioned history buffer in 'fp8_params'.
      # `make_rng` is evaluated on every call (not only at init), so callers
      # of `apply` must also supply an 'fp8_params' RNG.
      history_init = nn.with_partitioning(
          nn.initializers.zeros_init(), (None,))
      return self.variable(
          'fp8_params', name, history_init,
          self.make_rng('fp8_params'), (self.max_history_length,))

    # Declaration order (x first, then w) is kept so RNG draws are unchanged.
    x_max_history = declare_history('x_max_history')
    w_max_history = declare_history('w_max_history')

    # The scales should also be defined in variables and be used in the dot.

    y = jnp.dot(x, W1)

    # NOTE(review): keepdims=True makes these maxima shape (1, 1), so the
    # additions below broadcast the (max_history_length,) buffers to
    # (1, max_history_length) -- consistent with the (1, 16) shapes printed
    # after init. Confirm this rank change is intended given the 1-D
    # partition names above.
    x_amax = jnp.max(x, axis=(0, 1), keepdims=True)
    w_amax = jnp.max(W1, axis=(0, 1), keepdims=True)

    # Fake max_history update. The new scales should also be computed.
    x_max_history.value = x_max_history.value + x_amax
    w_max_history.value = w_max_history.value + w_amax

    return y, x_max_history.value, w_max_history.value

In [4]:
# Derive independent subkeys from a single seed. JAX PRNG keys must not be
# reused, so the input data and the model initialisation each get their own.
data_key, k = random.split(random.PRNGKey(0))
x = random.normal(data_key, (8192, 8192))

# depth=8192 output features, max_history_length=16.
model = SuperDot(8192, 16)

In [5]:
# A functional way of model initialization.
def init_fn(k, x):
  """Initialises all model variables as a pure function of (key, input).

  Args:
    k: PRNG key. It is split so that the 'params' and 'fp8_params' RNG
      streams receive distinct keys instead of sharing one.
    x: example input forwarded to `model.init`.

  Returns:
    The variable collections produced by `model.init`.
  """
  params_key, fp8_key = random.split(k)
  rngs = {'params': params_key, 'fp8_params': fp8_key}
  return model.init(rngs, x)  # Initialize the model.

# Evaluate `init_fn` abstractly to obtain the structure of its output without
# running the real initialisation.
abstract_variables = jax.eval_shape(init_fn, k, x)
# This `state_spec` has the same pytree structure as the output
# of the `init_fn`, with a PartitionSpec in place of each variable
# (see the printed result below).
state_spec = nn.get_partition_spec(abstract_variables)
state_spec

FrozenDict({
    fp8_params: {
        w_max_history: PartitionSpec(None,),
        x_max_history: PartitionSpec(None,),
    },
    params: {
        W1: PartitionSpec(None, 'model'),
    },
})

In [6]:
# Partitioned initialiser: the PRNG key is replicated, x is split along the
# 'data' mesh axis, and the outputs follow the specs derived from the model.
# NOTE(review): newer JAX releases spell these keyword arguments
# `in_shardings` / `out_shardings` -- confirm against the JAX version in use.
pjit_init_fn = pjit(init_fn,
                    in_axis_resources=(PartitionSpec(None), PartitionSpec('data', None)),  # PRNG key and x
                    out_axis_resources=state_spec,  # params and fp8_params
                    )
# The PartitionSpecs only name mesh axes, so the call has to run inside the
# mesh context for those names to resolve to devices.
with mesh:
  initialized_state = pjit_init_fn(k, x)
# `jax.tree_util.tree_map` is the stable spelling; the top-level
# `jax.tree_map` alias is deprecated and absent from newer JAX releases.
jax.tree_util.tree_map(jnp.shape, initialized_state)

FrozenDict({
    fp8_params: {
        w_max_history: Partitioned(value=(1, 16), names=(None,), mesh=None),
        x_max_history: Partitioned(value=(1, 16), names=(None,), mesh=None),
    },
    params: {
        W1: Partitioned(value=(8192, 8192), names=(None, 'model'), mesh=None),
    },
})

In [7]:
# Show how each initialised variable ended up laid out across the devices:
# the weight first, then the two fp8 history buffers.
fp8_state = initialized_state['fp8_params']
for boxed_variable in (initialized_state['params']['W1'],
                       fp8_state['x_max_history'],
                       fp8_state['w_max_history']):
  print(boxed_variable.value.sharding)

GSPMDSharding({devices=[1,2,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate})
GSPMDSharding({replicated})
GSPMDSharding({replicated})


In [8]:
def infer_step(state, x):
  """Runs one forward pass of `model` on `x`.

  Args:
    state: mapping holding the 'params' and 'fp8_params' collections.
    x: input batch.

  Returns:
    The model's output, i.e. the tuple `(y, x_max_history, w_max_history)`
    returned by `SuperDot.__call__`. The updated 'fp8_params' collection that
    `apply` also returns is discarded.
  """
  variables = {name: state[name] for name in ('params', 'fp8_params')}
  # The model calls make_rng('fp8_params') on every call, so a key must be
  # supplied here as well.
  outputs, _ = model.apply(
      variables, x,
      rngs={'fp8_params': jax.random.PRNGKey(0)},
      mutable=['fp8_params'])
  return outputs

# Input layout: variables follow `state_spec`, x is split along 'data'.
step_in_specs = (state_spec, PartitionSpec('data', None))
# Output layout: y is split over both mesh axes, the two max histories are
# replicated.
step_out_specs = (
    PartitionSpec('data', 'model'),
    PartitionSpec(None),
    PartitionSpec(None),
)
pjit_step_fn = pjit(infer_step,
                    in_axis_resources=step_in_specs,
                    out_axis_resources=step_out_specs)
with mesh:
  y, x_max, w_max = pjit_step_fn(initialized_state, x)

# Visualise the device layout of each output in turn.
for caption, sharded_output in (('y sharding:', y),
                                ('x_max sharding:', x_max),
                                ('w_max sharding:', w_max)):
  print(caption)
  jax.debug.visualize_array_sharding(sharded_output)

y sharding:


x_max sharding:


w_max sharding:


In [9]:
# Lower the partitioned step function for these concrete arguments; done
# inside the mesh context, like the call itself.
with mesh:
  lowered = pjit_step_fn.lower(initialized_state, x)
# Compile the lowered computation and keep the compiler's IR. Despite its
# name, `compiled` holds the IR modules (iterated and printed in the next
# cell), not the executable.
compiled = lowered.compile().compiler_ir()

In [10]:
# Dump the textual form of every IR module produced by compilation.
for hlo_module in compiled:
  hlo_text = hlo_module.to_string()
  print(hlo_text)

HloModule pjit_infer_step, entry_computation_layout={(f32[1,16]{1,0},f32[1,16]{1,0},f32[8192,4096]{1,0},f32[2048,8192]{1,0})->(f32[2048,4096]{1,0}, f32[1,16]{1,0}, f32[1,16]{1,0})}, allow_spmd_sharding_propagation_to_output={false,false,false}

%region_0.8 (Arg_0.9: f32[], Arg_1.10: f32[]) -> f32[] {
  %Arg_0.9 = f32[] parameter(0)
  %Arg_1.10 = f32[] parameter(1)
  ROOT %maximum.11 = f32[] maximum(f32[] %Arg_0.9, f32[] %Arg_1.10), metadata={op_name="pjit(infer_step)/jit(main)/SuperDot/reduce_max[axes=(0, 1)]" source_file="<ipython-input-3-db1fa558fa36>" source_line=22}
}

%region_1.14 (Arg_0.15: f32[], Arg_1.16: f32[]) -> f32[] {
  %Arg_0.15 = f32[] parameter(0)
  %Arg_1.16 = f32[] parameter(1)
  ROOT %maximum.17 = f32[] maximum(f32[] %Arg_0.15, f32[] %Arg_1.16), metadata={op_name="pjit(infer_step)/jit(main)/SuperDot/reduce_max[axes=(0, 1)]" source_file="<ipython-input-3-db1fa558fa36>" source_line=22}
}

%fused_computation (param_0: f32[1,16], param_1.1: f32[]) -> f32[1,16] {
  %param