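# Exploring PJIT (sharded `jax.jit`) today

A small two-layer MLP, `relu(X @ W1) @ W2`, is compiled under several sharding strategies (replicated, data parallel, FSDP, tensor parallel) on 8 virtual CPU devices, and the resulting compiled HLO is inspected for each.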

In [2]:
import os
os.environ['JAX_PLATFORMS'] = 'cpu'
import jax
jax.config.update('jax_num_cpu_devices', 8)

In [3]:
import jax
import numpy as np
import jax.numpy as jnp
from jax import sharding as shd
from jax._src.mesh import get_concrete_mesh

def P(*args):
    return shd.NamedSharding(get_concrete_mesh(), shd.PartitionSpec(*args))

def see_compiled(fn, *args, **kwargs):
    """Print the compiled HLO text of ``fn`` lowered at the given example arguments.

    Args:
        fn: a ``jax.jit``-wrapped function (anything exposing ``.lower``).
            A plain Python callable is wrapped in ``jax.jit`` automatically,
            so ``see_compiled(f, ...)`` and ``see_compiled(jax.jit(f), ...)``
            are equivalent.
        *args, **kwargs: example arguments forwarded to ``.lower``; their
            shapes, dtypes and shardings determine the compiled program
            (on a sharded mesh the HLO shows per-device shapes).

    Returns:
        None. The text is printed rather than returned so that calling this as
        the last line of a notebook cell does not also echo the string as the
        cell's output.
    """
    if not hasattr(fn, "lower"):
        # Plain functions have no AOT `.lower` method; jit them first.
        fn = jax.jit(fn)
    compiled = fn.lower(*args, **kwargs).compile()
    print(compiled.as_text())

# Replicated

In [44]:
from jax.experimental.shard import reshard  # used by the FSDP section below

# Toy MLP dimensions: batch size, model width, hidden width.
bs = 16
d = 32
m = 128

# 1-D mesh over the 8 virtual CPU devices with a single Explicit axis 'X'.
shd.set_mesh(jax.make_mesh(axis_shapes=(8,), axis_names=('X',), axis_types=(shd.AxisType.Explicit,)))

# Fully replicated: every device holds complete copies of X, W1 and W2.
X = jnp.zeros((bs, d), device=P())
W1 = jnp.zeros((d, m), device=P())
W2 = jnp.zeros((m, d), device=P())

def f(X, W1, W2):
    """Two-layer MLP: relu(X @ W1) @ W2."""
    return jax.nn.relu(X @ W1) @ W2

# With everything replicated, the compiled program is just two dots and a
# ReLU fusion (annotated excerpt below). The call is kept live so the excerpt
# is reproducible on Restart & Run All.
see_compiled(jax.jit(f), X, W1, W2)

```text
ENTRY %main.18_spmd (param: f32[16,32], param.1: f32[32,128], param.2: f32[128,32]) -> f32[16,32] {
  # Parameters
  %param = f32[16,32]{1,0} parameter(0), sharding={replicated}, metadata={op_name="X"}
  %param.1 = f32[32,128]{1,0} parameter(1), sharding={replicated}, metadata={op_name="W1"}
  %param.2 = f32[128,32]{1,0} parameter(2), sharding={replicated}, metadata={op_name="W2"}
  # First matmul
  %dot = f32[16,128]{1,0} dot(f32[16,32]{1,0} %param, f32[32,128]{1,0} %param.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f)/jit(main)/dot_general" source_file="/tmp/ipykernel_234850/1341430039.py" source_line=14}
  # ReLU
  %broadcast_maximum_fusion = f32[16,128]{1,0} fusion(f32[16,128]{1,0} %dot), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(f)/jit(main)/jit(relu)/max" source_file="/tmp/ipykernel_234850/1341430039.py" source_line=14}
  # Second matmul
  ROOT %dot.1 = f32[16,32]{1,0} dot(f32[16,128]{1,0} %broadcast_maximum_fusion, f32[128,32]{1,0} %param.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f)/jit(main)/dot_general" source_file="/tmp/ipykernel_234850/1341430039.py" source_line=14}
}
```

# Data Parallel

In [11]:

# Data parallel: the batch dimension of X is split across mesh axis 'X'
# (2 rows per device, as the per-device f32[2,32] parameter below shows),
# while both weight matrices stay fully replicated.
X, W1, W2 = (
    jnp.zeros((bs, d), device=P('X')),
    jnp.zeros((d, m), device=P()),
    jnp.zeros((m, d), device=P()),
)

def f(X, W1, W2):
    """Two-layer MLP: relu(X @ W1) @ W2."""
    hidden = jax.nn.relu(X @ W1)
    return hidden @ W2

see_compiled(jax.jit(f), X, W1, W2)

HloModule jit_f, is_scheduled=true, entry_computation_layout={(f32[2,32]{1,0}, f32[32,128]{1,0}, f32[128,32]{1,0})->f32[2,32]{1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false}, num_partitions=8

%fused_computation (param_0: f32[2,128]) -> f32[2,128] {
  %param_0 = f32[2,128]{1,0} parameter(0)
  %constant.2 = f32[] constant(0)
  %broadcast.2 = f32[2,128]{1,0} broadcast(f32[] %constant.2), dimensions={}, metadata={op_name="jit(f)/jit(main)/jit(relu)/max" source_file="/tmp/ipykernel_234850/1296888503.py" source_line=6}
  ROOT %maximum.2 = f32[2,128]{1,0} maximum(f32[2,128]{1,0} %param_0, f32[2,128]{1,0} %broadcast.2), metadata={op_name="jit(f)/jit(main)/jit(relu)/max" source_file="/tmp/ipykernel_234850/1296888503.py" source_line=6}
}

ENTRY %main.18_spmd (param: f32[2,32], param.1: f32[32,128], param.2: f32[128,32]) -> f32[2,32] {
  %param = f32[2,32]{1,0} parameter(0), sharding={devices=[8,1]<=[8]}, metadata={op_name="X"}
  %param.1 = f32[32,128]{1,0} parameter(1),

# FSDP (Fully Sharded Data Parallel)

In [28]:
# FSDP: X is batch-sharded as in data parallelism, and the weights are also
# sharded over 'X' (W1 along its hidden columns, W2 along its hidden rows).
X, W1, W2 = (
    jnp.zeros((bs, d), device=P('X')),
    jnp.zeros((d, m), device=P(None, 'X')),
    jnp.zeros((m, d), device=P('X')),
)

def f(X, W1, W2):
    """MLP that reshards each weight to fully replicated right before its matmul."""
    W1_full = reshard(W1, P())
    hidden = jax.nn.relu(X @ W1_full)
    W2_full = reshard(W2, P())
    return hidden @ W2_full

see_compiled(jax.jit(f), X, W1, W2)

HloModule jit_f, is_scheduled=true, entry_computation_layout={(f32[2,32]{1,0}, f32[32,16]{1,0}, f32[16,32]{1,0})->f32[2,32]{1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false}, num_partitions=8

%fused_computation (param_0: f32[2,128]) -> f32[2,128] {
  %param_0 = f32[2,128]{1,0} parameter(0)
  %constant.2 = f32[] constant(0)
  %broadcast.2 = f32[2,128]{1,0} broadcast(f32[] %constant.2), dimensions={}, metadata={op_name="jit(f)/jit(main)/jit(relu)/max" source_file="/tmp/ipykernel_234850/1296888503.py" source_line=6}
  ROOT %maximum.2 = f32[2,128]{1,0} maximum(f32[2,128]{1,0} %param_0, f32[2,128]{1,0} %broadcast.2), metadata={op_name="jit(f)/jit(main)/jit(relu)/max" source_file="/tmp/ipykernel_234850/1296888503.py" source_line=6}
}

ENTRY %main.20_spmd (param: f32[2,32], param.1: f32[32,16], param.2: f32[16,32]) -> f32[2,32] {
  %param = f32[2,32]{1,0} parameter(0), sharding={devices=[8,1]<=[8]}, metadata={op_name="X"}
  %param.1 = f32[32,16]{1,0} parameter(1), shar

# Tensor Parallel (Model Parallel)

In [43]:
# Tensor (model) parallel: X is replicated, W1 is split along its hidden
# columns and W2 along its hidden rows, so each device owns a slice of the
# hidden activations (float32[16,128@X] in the trace output below).
X, W1, W2 = (
    jnp.zeros((bs, d), device=P()),
    jnp.zeros((d, m), device=P(None, 'X')),
    jnp.zeros((m, d), device=P('X')),
)

def f(X, W1, W2):
    # NOTE: these prints run while jax.jit traces f (they appear once, before
    # the HLO dump). The `=` in the f-strings echoes the expression text, so
    # `mid` and `out` must keep these names to keep the printed output stable.
    mid = jax.nn.relu(X @ W1)
    print(f"{jax.typeof(mid)=} {jax.typeof(W2)=}")
    # Contracting over the sharded m axis leaves per-device partial results;
    # asking for a replicated out_sharding requires combining them (the
    # compiled HLO contains an add reduction computation for this).
    out = jnp.einsum('bm,md->bd', mid, W2, out_sharding=P())
    print(f"{jax.typeof(out)=}")
    return out

see_compiled(jax.jit(f), X, W1, W2)

jax.typeof(mid)=ShapedArray(float32[16,128@X]) jax.typeof(W2)=ShapedArray(float32[128@X,32])
jax.typeof(out)=ShapedArray(float32[16,32])
HloModule jit_f, is_scheduled=true, entry_computation_layout={(f32[16,32]{1,0}, f32[32,16]{1,0}, f32[16,32]{1,0})->f32[16,32]{1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false}, num_partitions=8

%add.clone (x.1: f32[], y.1: f32[]) -> f32[] {
  %x.1 = f32[] parameter(0)
  %y.1 = f32[] parameter(1)
  ROOT %add.1 = f32[] add(f32[] %x.1, f32[] %y.1)
}

%fused_computation (param_0: f32[16,16]) -> f32[16,16] {
  %param_0 = f32[16,16]{1,0} parameter(0)
  %constant.4 = f32[] constant(0)
  %broadcast.2 = f32[16,16]{1,0} broadcast(f32[] %constant.4), dimensions={}, metadata={op_name="jit(f)/jit(main)/jit(relu)/max" source_file="/tmp/ipykernel_234850/3672491555.py" source_line=6}
  ROOT %maximum.2 = f32[16,16]{1,0} maximum(f32[16,16]{1,0} %param_0, f32[16,16]{1,0} %broadcast.2), metadata={op_name="jit(f)/jit(main)/jit(relu)/max" source_fil