# Model Sharding Experiment

In [1]:
from typing import *

import os
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec, NamedSharding
import jax.numpy as jnp
from flax import nnx
import optax

In [2]:
# Emulate 8 CPU devices so the 2x4 ('data', 'model') mesh below can be built
# without accelerators. XLA picks up XLA_FLAGS when the JAX backend is first
# initialised (first jax.devices() call / first array op), so in a fresh kernel
# this cell must run before any JAX computation. Append to any flags that are
# already present instead of overwriting them.
_FORCED_DEVICE_FLAG = "--xla_force_host_platform_device_count=8"
_existing_xla_flags = os.environ.get("XLA_FLAGS", "")
if _FORCED_DEVICE_FLAG not in _existing_xla_flags.split():
    os.environ["XLA_FLAGS"] = f"{_existing_xla_flags} {_FORCED_DEVICE_FLAG}".strip()
print(jax.devices())
# Fail loudly here rather than with an opaque reshape error in the mesh cell.
assert len(jax.devices()) == 8, (
    "Expected 8 emulated CPU devices; the JAX backend was probably initialised "
    "before XLA_FLAGS was set. Restart the kernel and run this cell first."
)


[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]


In [10]:
# Lay the 8 devices out as a 2x4 grid. The first mesh axis ('data', size 2) is
# used below to shard the batch; the second ('model', size 4) shards the weights.
devices = np.asarray(jax.devices()).reshape((2, 4))
print(devices)

mesh = Mesh(devices, ('data', 'model'))
print(mesh)

[[CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)]
 [CpuDevice(id=4) CpuDevice(id=5) CpuDevice(id=6) CpuDevice(id=7)]]
Mesh('data': 2, 'model': 4)


In [11]:
class DotReluDot(nnx.Module):
    """Two bias-free square projections with a ReLU in between: ``relu(x @ W1) @ W2``.

    Both kernels are ``(depth, depth)`` and carry partition annotations:
    ``W1`` is split column-wise over the ``'model'`` mesh axis and ``W2``
    row-wise, so the hidden dimension is the one divided across ``'model'``
    in both weights. The annotations are only metadata until the state is
    constrained/placed under a mesh.

    Args:
        depth: size of the input, hidden and output feature dimensions.
        rngs: random-number streams used to initialise both kernels.
    """

    def __init__(self, depth: int, rngs: nnx.Rngs):
        init_fn = nnx.initializers.lecun_normal()

        # First projection: output (hidden) axis annotated for the 'model' axis.
        self.dot1 = nnx.Linear(
            depth, depth,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
            use_bias=False,
            rngs=rngs,
        )

        # Second projection kept as a raw Param so the matmul is explicit in
        # __call__; its input (hidden) axis is annotated for 'model'.
        # Use the conventional lowercase `params` stream. The previous
        # `rngs.Params()` named a stream that never exists and only worked via
        # the Rngs fallback to the 'default' stream, which breaks when the
        # caller passes named streams such as nnx.Rngs(params=0).
        self.w2 = nnx.Param(
            init_fn(rngs.params(), (depth, depth)),
            sharding=("model", None),
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply ``relu(x @ W1) @ W2``.

        Args:
            x: array whose last dimension is ``depth`` (assumed ``(batch, depth)``
                as in the cells below — TODO confirm for other ranks).

        Returns:
            Array with the same leading dimensions as ``x`` and last dimension ``depth``.
        """
        y = self.dot1(x)
        y = jax.nn.relu(y)
        z = jnp.dot(y, self.w2.value)
        return z


In [12]:
unsharded_model = DotReluDot(1024, rngs=nnx.Rngs(0))

# The partition annotations are stored as metadata on the nnx.Param wrappers...
print(unsharded_model.dot1.kernel.sharding, unsharded_model.w2.sharding, sep="\n")

# ...but eager construction outside a mesh/jit does not apply them: the
# recorded output shows each underlying array on a single device.
print(
    unsharded_model.dot1.kernel.value.sharding,
    unsharded_model.w2.value.sharding,
    sep="\n",
)

(None, 'model')
('model', None)
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)


In [13]:
@nnx.jit
def create_sharded_model():
    """Build a DotReluDot under jit and constrain its state to the annotated shardings.

    The model is constructed inside ``nnx.jit``, and its state is passed through
    ``jax.lax.with_sharding_constraint`` using the PartitionSpecs extracted from
    the variables' metadata. The function is called under ``with mesh:`` below,
    which is the mesh those bare PartitionSpecs are resolved against.

    Returns:
        The DotReluDot instance whose state was updated with the constrained values.
    """
    model = DotReluDot(1024, rngs=nnx.Rngs(0))
    model_state = nnx.state(model)
    # One PartitionSpec per variable, taken from with_partitioning / sharding=.
    constrained_state = jax.lax.with_sharding_constraint(
        model_state, nnx.get_partition_spec(model_state)
    )
    nnx.update(model, constrained_state)
    return model


with mesh:
    sharded_model = create_sharded_model()

    # Both weights now report a NamedSharding over `mesh` (see recorded output).
    print(
        sharded_model.dot1.kernel.value.sharding,
        sharded_model.w2.value.sharding,
        sep="\n",
    )

NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model',), memory_kind=unpinned_host)


In [14]:
# dot1's (1024, 1024) kernel is split column-wise over the 4-way 'model' axis
# (PartitionSpec(None, 'model')), so each shard is 1024 x 256 and is replicated
# across the 2 devices of the 'data' axis.
jax.debug.visualize_array_sharding(sharded_model.dot1.kernel.value)

In [15]:
# w2's (1024, 1024) kernel is split row-wise over the 4-way 'model' axis
# (PartitionSpec('model',)), so each shard is 256 x 1024 and is replicated
# across the 2 devices of the 'data' axis.
jax.debug.visualize_array_sharding(sharded_model.w2.value)

In [16]:
# Shard the batch dimension over the 2-way 'data' axis and replicate the
# feature dimension, so each 'data' row of the mesh holds 4 of the 8 examples.
data_sharding = NamedSharding(mesh, PartitionSpec('data', None))

# Named `sample_input` (not `input`) to avoid shadowing the Python builtin.
sample_input = jax.device_put(jnp.ones((8, 1024)), data_sharding)
jax.debug.visualize_array_sharding(sample_input)

with mesh:
    output = sharded_model(sample_input)
    print(output.shape)
    print(output.sharding)
    jax.debug.visualize_array_sharding(output)

# DotReluDot maps (batch, depth) -> (batch, depth); guard the invariant.
assert output.shape == sample_input.shape, output.shape

(8, 1024)
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('data',), memory_kind=unpinned_host)


In [19]:

@nnx.jit
def train_step(model, optimizer, x, y):
    """Run one optimizer step on a mean-squared-error objective.

    Args:
        model: the nnx module being trained; ``optimizer`` was built around it,
            so the update below changes its parameters.
        optimizer: ``nnx.Optimizer`` wrapping ``model``.
        x: input batch fed to ``model``.
        y: regression targets, same shape as ``model(x)``.

    Returns:
        The MSE loss evaluated before the parameter update.
    """

    def mse_loss(m, batch_x, batch_y):
        # nnx.value_and_grad differentiates w.r.t. the first argument (the model).
        residual = m(batch_x) - batch_y
        return jnp.mean(residual ** 2)

    loss, grads = nnx.value_and_grad(mse_loss)(model, x, y)
    optimizer.update(grads)
    return loss

In [20]:
# A fixed random batch of features (key 1) and regression targets (key 2),
# both sharded along the batch axis with `data_sharding`.
inputs, labels = (
    jax.device_put(jax.random.normal(jax.random.key(seed), (8, 1024)), data_sharding)
    for seed in (1, 2)
)

# Adam (lr 1e-3) over the already-sharded model parameters.
optimizer = nnx.Optimizer(sharded_model, optax.adam(1e-3))

# Fit the same batch for 5 steps; in the recorded run the loss drops every step.
with mesh:
    for i in range(5):
        loss = train_step(sharded_model, optimizer, inputs, labels)
        print(loss)

1.4929407
0.82017606
0.55837417
0.41078538
0.2984159
