In [64]:
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices
import jax
jax.config.update("jax_debug_nans", False)



import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map
from functools import partial

In [56]:
device_array = np.array(jax.devices())
mesh = Mesh(device_array, ('fsdp',))
mesh

Mesh(device_ids=array([0, 1, 2, 3, 4, 5, 6, 7]), axis_names=('fsdp',), axis_types=(Auto,))

In [None]:
class TestNet(nn.Module):
  features: int = 8

  @nn.compact
  def __call__(self, x: Array):
    
    if not self.is_mutable_collection("params"):
      params = self.scope.get_variable("params", "Dense_0")
      params['kernel'] = jax.lax.all_gather(params['kernel'], "fsdp", axis=-1, tiled=True)
      out = x @ params['kernel'] + params['bias']
    else:
      out = nn.Dense(features=self.features)(x)
    return out
class TestNet2(nn.Module):
  features: int = 8

  @nn.compact
  def __call__(self, x: Array):
    
    out = nn.Dense(features=self.features)(x)
    return out 

In [165]:
dense = TestNet()
dense2 = TestNet2()
key = jax.random.PRNGKey(23)
x = jax.random.normal(key, (8,8))

key = jax.random.PRNGKey(0)
variables = dense.init(key, x)
variables2 = dense2.init(key, x)

def get_p_spec(x): 
    if x.ndim <= 1:
        return P()
    else: 
        specs = [None] * (x.ndim -1) + ['fsdp']  
        return P(*specs)
       

var_spec =  jax.tree.map(lambda x: get_p_spec(x), variables)
sharded_vars = jax.tree.map(lambda x, y: jax.device_put(x, jax.sharding.NamedSharding(mesh, y)), variables, var_spec)



@partial(shard_map, mesh=mesh, in_specs=(var_spec, P()), out_specs=var_spec)
def apply_fn_sharded(params, x):
    """
    This function now correctly handles the gradient calculation for a sharded model.
    """
    def step(params, x_batch):
        # The dense.apply now uses the mutable 'params' collection
        out = dense.apply(params, x_batch)
        loss = out.mean() # Using mean() is generally more stable than sum()
        
        # **THE FIX**: Synchronize the loss across all devices before returning.
        # This makes the loss a replicated value, which jax.grad expects for the backward pass.
        return jax.lax.pmean(loss, axis_name='fsdp')

    # We only want gradients with respect to the 'params'
    grads = jax.grad(step)(params, x)
    jax.tree.map(lambda x: print(x.shape), grads )
    
    # The pmean on the gradients is now correctly placed here to average the
    # final computed gradients across all devices before the update step.
    # Note: In a real FSDP implementation, the gradient reduction happens
    # layer-by-layer during the backward pass itself for efficiency. This
    # full pmean at the end is simpler but less performant.
    # grads = jax.tree.map(lambda g: jax.lax.pmean(g, 'fsdp'), grads)

    return grads# print(out.shape)
out = apply_fn_sharded(sharded_vars, x)

(8,)
(8, 1)


In [166]:
jax.tree.map(lambda x: jax.debug.visualize_array_sharding(x), out)

{'params': {'Dense_0': {'bias': None, 'kernel': None}}}

In [167]:
out

{'params': {'Dense_0': {'bias': Array([0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125], dtype=float32),
   'kernel': Array([[-0.00487862, -0.00487862, -0.00487862, -0.00487862, -0.00487862,
           -0.00487862, -0.00487862, -0.00487862],
          [ 0.00338491,  0.00338491,  0.00338491,  0.00338491,  0.00338491,
            0.00338491,  0.00338491,  0.00338491],
          [-0.08750482, -0.08750482, -0.08750482, -0.08750482, -0.08750482,
           -0.08750482, -0.08750482, -0.08750482],
          [-0.0490033 , -0.0490033 , -0.0490033 , -0.0490033 , -0.0490033 ,
           -0.0490033 , -0.0490033 , -0.0490033 ],
          [-0.00665506, -0.00665506, -0.00665506, -0.00665506, -0.00665506,
           -0.00665506, -0.00665506, -0.00665506],
          [ 0.02282774,  0.02282774,  0.02282774,  0.02282774,  0.02282774,
            0.02282774,  0.02282774,  0.02282774],
          [-0.02082107, -0.02082107, -0.02082107, -0.02082107, -0.02082107,
           -0.02082107, -0.02082107, -0.

In [120]:
jax.tree.map(lambda x: jax.debug.visualize_array_sharding(x), sharded_vars)

{'params': {'Dense_0': {'bias': None, 'kernel': None}}}

In [None]:
def device_put(x): 
    if x.ndim <= 1:
        x = jax.device_put(x, jax.sharding.NamedSharding(mesh, P()))
    else: 
        specs = [None] * (x.ndim -1) + ['fsdp']  
        sharding = jax.sharding.NamedSharding(mesh, P(*specs))
        x = jax.device_put(x, sharding)
    return x

sharded_vars = jax.tree.map(lambda x: device_put(x), variables)

In [70]:
jax.tree.map(lambda x: jax.debug.visualize_array_sharding(x), sharded_vars)


{'params': {'Dense_0': {'bias': None, 'kernel': None}}}

In [61]:
jax.debug.visualize_array_sharding(out)

In [140]:
dense.apply(variables, x)

NameError: unbound axis name: fsdp

In [169]:
grad_b = jax.grad(lambda params, x: dense2.apply(params, x).mean())(variables2, x)

In [170]:
jax.tree.map(lambda x, y: x - y, grad_b, out)

{'params': {'Dense_0': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
   'kernel': Array([[ 0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00,
            0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00],
          [ 0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00,
            0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00],
          [-7.450581e-09, -7.450581e-09, -7.450581e-09, -7.450581e-09,
           -7.450581e-09, -7.450581e-09, -7.450581e-09, -7.450581e-09],
          [ 0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00,
            0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00],
          [-9.313226e-10, -9.313226e-10, -9.313226e-10, -9.313226e-10,
           -9.313226e-10, -9.313226e-10, -9.313226e-10, -9.313226e-10],
          [ 0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00,
            0.000000e+00,  0.000000e+00,  0.000000e+00,  0.000000e+00],
          [ 0.000000e+00,  0.000000e+00,  0.