# Model Sharding Experiment

In [1]:
from typing import *

import os
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec, NamedSharding
import jax.numpy as jnp
from flax import nnx
import optax

In [None]:
# Emulate multiple devices
# os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" 
print(jax.devices())


[CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3), CudaDevice(id=4), CudaDevice(id=5), CudaDevice(id=6), CudaDevice(id=7)]


In [3]:
devices = np.array(jax.devices()).reshape(2, 4)
print(devices)
mesh = Mesh(devices=devices, axis_names=('data', 'model'))
print(mesh)

[[CudaDevice(id=0) CudaDevice(id=1) CudaDevice(id=2) CudaDevice(id=3)]
 [CudaDevice(id=4) CudaDevice(id=5) CudaDevice(id=6) CudaDevice(id=7)]]
Mesh('data': 2, 'model': 4)


In [4]:
class DotReluDot(nnx.Module):
    def __init__(self, depth: int, rngs: nnx.Rngs):
        init_fn = nnx.initializers.lecun_normal()

        self.dot1 = nnx.Linear(
            depth, depth,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
            use_bias=False,
            rngs=rngs
        )

        self.w2 = nnx.Param(
            init_fn(rngs.Params(), (depth, depth)),
            sharding=("model", None),
        )

    def __call__(self, x: jax.Array):
        y = self.dot1(x)
        y = jax.nn.relu(y)
        z = jnp.dot(y, self.w2.value)
        return z


In [5]:
unsharded_model = DotReluDot(1024, rngs=nnx.Rngs(0))
print(unsharded_model.dot1.kernel.sharding)
print(unsharded_model.w2.sharding)

print(unsharded_model.dot1.kernel.value.sharding)
print(unsharded_model.w2.value.sharding)

(None, 'model')
('model', None)
SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=device)
SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=device)


In [6]:
@nnx.jit
def create_sharded_model():
    model = DotReluDot(1024, rngs=nnx.Rngs(0))
    state = nnx.state(model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, sharded_state)
    return model

with mesh:
    sharded_model = create_sharded_model()

    print(sharded_model.dot1.kernel.value.sharding)
    print(sharded_model.w2.value.sharding)

NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=device)
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model',), memory_kind=device)


In [7]:
jax.debug.visualize_array_sharding(sharded_model.dot1.kernel.value)

In [8]:
jax.debug.visualize_array_sharding(sharded_model.w2.value)

In [9]:
data_sharding = NamedSharding(mesh, PartitionSpec('data', None))
input = jax.device_put(jnp.ones((8, 1024)), data_sharding)
jax.debug.visualize_array_sharding(input)

with mesh:
    output = sharded_model(input)
    print(output.shape)
    print(output.sharding)
    jax.debug.visualize_array_sharding(output)

(8, 1024)
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('data',), memory_kind=device)


In [10]:

@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model, x, y):
        y_pred = model(x)
        return jnp.mean((y_pred - y)**2)
    
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    optimizer.update(grads)

    return loss

In [12]:
inputs = jax.device_put(jax.random.normal(jax.random.key(1), (8, 1024)), data_sharding)
labels = jax.device_put(jax.random.normal(jax.random.key(2), (8, 1024)), data_sharding)

optimizer = nnx.Optimizer(sharded_model, optax.adam(1e-3))

with mesh:
    for i in range(100):
        loss = train_step(sharded_model, optimizer, inputs, labels)
        print(loss)

0.20979019
0.13867
0.06545616
0.048815664
0.04634791
0.036297187
0.032754853
0.03459078
0.03334333
0.028454082
0.02301151
0.018247345
0.014421618
0.01179364
0.010173613
0.009169059
0.008489858
0.007960102
0.0076076966
0.007357572
0.006922085
0.0062350994
0.0054418044
0.0046596257
0.003994516
0.0034763152
0.0030522803
0.0026858547
0.0023597206
0.002083327
0.0019132653
0.0018321353
0.0017673839
0.001657441
0.0014845124
0.0013212756
0.0011983521
0.001073835
0.0009347295
0.00080625794
0.00071486074
0.0006466589
0.0005860225
0.00053393084
0.0004858976
0.0004500544
0.00041767934
0.00038260067
0.0003457353
0.00031027163
0.00027989544
0.00025062694
0.00022279198
0.00019747886
0.0001774224
0.00016117783
0.00014645421
0.00013169038
0.000118317184
0.000108385255
0.000100392834
9.1668284e-05
8.208952e-05
7.422385e-05
6.7065e-05
6.0272832e-05
5.381253e-05
4.822365e-05
4.275275e-05
3.841902e-05
3.5163204e-05
3.2181717e-05
2.904359e-05
2.6534422e-05
2.4554938e-05
2.2363045e-05
1.9770552e-05
1.7589611