Skip to content

nnx.eval_shape should propagate 'sharding' and 'format' to ShapeDtypeStructs #5110

@qGentry

Description

@qGentry

Currently, `nnx.eval_shape` doesn't propagate the `sharding` and `format` fields to the output `ShapeDtypeStruct`s. As a result, restoring a checkpoint with Orbax ignores the requested (target) sharding.

Consider the following snippet. It does the following:

  1. Prerequisite: save a model with some sharding to a checkpoint (unsharded in this example, but in general it could have been saved with a different sharding or on a different number of devices). In practice this checkpoint usually comes from a checkpoint conversion or a previous training run.
  2. Target training run: we create an abstract state with `nnx.eval_shape` using the target sharding, then restore the checkpoint using that abstract state. The problem is that the resulting arrays are sharded according to the sharding they were saved with, not the target one (in the output below they all end up with a `SingleDeviceSharding`). In a real-world scenario this would obviously lead to OOM.
import jax
import flax.nnx as nnx
import orbax.checkpoint as ocp

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, AxisType
import flax.nnx as nnx


# Target device layouts (each spans 8 devices) and, for each, the rules that
# map the Model's logical axis names onto that mesh's physical axis names.
# Rules are built by zipping logical names with mesh axis names; the result is
# the same tuple of (logical, mesh) pairs, e.g. (("A", "a"), ("B", "b")).
mesh1 = jax.make_mesh((2, 4), ("a", "b"))
rules1 = tuple(zip(("A", "B"), ("a", "b")))
mesh2 = jax.make_mesh((2, 2, 2), ("x", "y", "z"))
rules2 = tuple(zip(("X", "Y", "Z"), ("x", "y", "z")))
mesh3 = jax.make_mesh((8,), ("c",))
rules3 = tuple(zip(("C",), ("c",)))

# Mesh entered as a context manager around the save / restore run below.
mesh_data = jax.make_mesh((4, 2), ("data", "context"))


class Model(nnx.Module):
    """Toy module holding three all-ones parameters.

    Args:
      enable_sharding: when True, each parameter is annotated with logical
        sharding names, a mesh, and the rules mapping those names onto the
        mesh's axes (mesh1/rules1, mesh2/rules2, mesh3/rules3 respectively).
        When False, the parameters are created with no sharding metadata.
    """

    def __init__(self, enable_sharding: bool = True):
        def make_param(shape, names, mesh, rules):
            # Unsharded variant is built exactly as nnx.Param(value); the
            # sharding metadata is forwarded only when it was requested.
            if not enable_sharding:
                return nnx.Param(jnp.ones(shape))
            return nnx.Param(
                jnp.ones(shape),
                sharding=names,
                mesh=mesh,
                sharding_rules=rules,
            )

        self.small_linear1 = make_param((16, 16), ("A", "B"), mesh1, rules1)
        self.small_linear2 = make_param(
            (16, 16, 16), ("X", "Y", "Z"), mesh2, rules2
        )
        # Only the leading dimension is named, so only it maps onto mesh3.
        self.small_linear3 = make_param((16, 16), ("C",), mesh3, rules3)




path = '/ckpt/'

def save():
    """Write an unsharded Model's state to ``path`` as checkpoint step 0.

    Simulates a checkpoint produced under a different sharding configuration
    than the one used at restore time. The directory is first passed through
    ``ocp.test_utils.erase_and_create_empty`` (which, as its name suggests,
    resets it) before saving.
    """
    # Built without sharding metadata, i.e. the "old" layout.
    state_to_save = nnx.state(Model(enable_sharding=False))

    ckpt_dir = ocp.test_utils.erase_and_create_empty(path)
    manager_options = ocp.CheckpointManagerOptions()
    with ocp.CheckpointManager(ckpt_dir, options=manager_options) as manager:
        manager.save(0, args=ocp.args.PyTreeSave(state_to_save))
    


def partial_load():
    """Restore saved parameters into an abstract, sharding-annotated Model.

    The target Model is created abstractly via ``nnx.eval_shape`` with sharding
    enabled. Only the parameters whose paths contain one of the names in
    ``wanted`` are requested from step 0 of the checkpoint at ``path``
    (``partial_restore=True``). The restored values are written back into the
    abstract model, which is returned.

    NOTE: this is the reproduction of the reported issue. The abstract leaves
    produced by ``nnx.eval_shape`` carry no ``sharding``, so Orbax falls back to
    the sharding recorded in the checkpoint instead of the target layout.
    """
    wanted = ("small_linear1", "small_linear2", "small_linear3")

    # Target layout: the sharded configuration.
    abs_model = nnx.eval_shape(lambda: Model(enable_sharding=True))
    # Loop variable is `name`, not `path`, so it cannot be confused with the
    # module-level checkpoint directory passed to CheckpointManager below.
    selector = nnx.Any(*(nnx.PathContains(name) for name in wanted))
    abs_state = nnx.state(abs_model, selector)

    restore_args = ocp.args.PyTreeRestore(abs_state, partial_restore=True)
    manager_options = ocp.CheckpointManagerOptions()
    with ocp.CheckpointManager(path, options=manager_options) as manager:
        restored = manager.restore(0, args=restore_args)

    nnx.update(abs_model, restored)
    return abs_model




# Save an unsharded checkpoint, restore it into the sharded abstract model,
# then print (shape, array type name, sharding) for every leaf of a freshly
# initialized sharded model followed by the same for the restored model.
with mesh_data:
    save()
    restored_model = partial_load()
    original_model = Model(enable_sharding=True)
    for title, model in (
        ("ORIGINAL MODEL:", original_model),
        ("RESTORED MODEL:", restored_model),
    ):
        print(title)
        print(
            jax.tree.map(
                lambda p: (p.shape, p.__class__.__name__, p.sharding),
                nnx.state(model),
            )
        )

Output:

ORIGINAL MODEL:
State({
  'small_linear1': Param(
    value=((16, 16), 'ArrayImpl', NamedSharding(mesh=Mesh('a': 2, 'b': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('a', 'b'), memory_kind=device)),
    mesh=Mesh(axis_sizes=(2, 4), axis_names=('a', 'b'), axis_types=(Auto, Auto)),
    sharding_names=('A', 'B'),
    sharding_rules=(('A', 'a'), ('B', 'b'))
  ),
  'small_linear2': Param(
    value=((16, 16, 16), 'ArrayImpl', NamedSharding(mesh=Mesh('x': 2, 'y': 2, 'z': 2, axis_types=(Auto, Auto, Auto)), spec=PartitionSpec('x', 'y', 'z'), memory_kind=device)),
    mesh=Mesh(axis_sizes=(2, 2, 2), axis_names=('x', 'y', 'z'), axis_types=(Auto, Auto, Auto)),
    sharding_names=('X', 'Y', 'Z'),
    sharding_rules=(('X', 'x'), ('Y', 'y'), ('Z', 'z'))
  ),
  'small_linear3': Param(
    value=((16, 16), 'ArrayImpl', NamedSharding(mesh=Mesh('c': 8, axis_types=(Auto,)), spec=PartitionSpec('c',), memory_kind=device)),
    mesh=Mesh(axis_sizes=(8,), axis_names=('c',), axis_types=(Auto,)),
    sharding_names=('C',),
    sharding_rules=(('C', 'c'),)
  )
})
RESTORED MODEL:
State({
  'small_linear1': Param(
    value=((16, 16), 'ArrayImpl', SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=device)),
    mesh=Mesh(axis_sizes=(2, 4), axis_names=('a', 'b'), axis_types=(Auto, Auto)),
    sharding_names=('A', 'B'),
    sharding_rules=(('A', 'a'), ('B', 'b'))
  ),
  'small_linear2': Param(
    value=((16, 16, 16), 'ArrayImpl', SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=device)),
    mesh=Mesh(axis_sizes=(2, 2, 2), axis_names=('x', 'y', 'z'), axis_types=(Auto, Auto, Auto)),
    sharding_names=('X', 'Y', 'Z'),
    sharding_rules=(('X', 'x'), ('Y', 'y'), ('Z', 'z'))
  ),
  'small_linear3': Param(
    value=((16, 16), 'ArrayImpl', SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=device)),
    mesh=Mesh(axis_sizes=(8,), axis_names=('c',), axis_types=(Auto,)),
    sharding_names=('C',),
    sharding_rules=(('C', 'c'),)
  )
})

Orbax also emits the following warning:

/usr/local/lib/python3.11/dist-packages/orbax/checkpoint/_src/serialization/jax_array_handlers.py:701: UserWarning: Sharding info not provided when restoring. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions