-
Notifications
You must be signed in to change notification settings - Fork 761
Open
Copy link
Description
Currently, nnx.eval_shape doesn't propagate the 'sharding' and 'format' fields to the output ShapeDtypeStructs. As a result, checkpoint restore via Orbax ignores the requested sharding.
Consider the following snippet, which does the following:
- Prerequisites: a checkpoint is saved with some sharding (unsharded in this example, but in general it may have been saved with a sharding over a different number of devices). This typically happens after some kind of checkpoint conversion or another training run.
- Target training run: we create an abstract state with nnx.eval_shape using the target sharding, then try to restore the checkpoint using that abstract state. The issue is that the resulting tensors are sharded according to the sharding they were saved with instead of the target one. In a real-world scenario this would obviously lead to OOM.
import jax
import flax.nnx as nnx
import orbax.checkpoint as ocp
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, AxisType
import flax.nnx as nnx
# Device meshes used by the reproduction. The axis sizes of every mesh multiply
# to 8. Each `rulesN` tuple pairs the logical axis names used in Model's
# `sharding=` arguments with the axis names of the corresponding `meshN`.
mesh1 = jax.make_mesh((2, 4), ("a", "b"))
rules1 = (("A", "a"), ("B", "b"))
mesh2 = jax.make_mesh((2, 2, 2), ("x", "y", "z"))
rules2 = (("X", "x"), ("Y", "y"), ("Z", "z"))
mesh3 = jax.make_mesh((8,), ("c",))
rules3 = (("C", "c"),)
# Mesh entered as a context manager around save() in the driver code below;
# it is not referenced by any Param in Model.
mesh_data = jax.make_mesh((4,2), ("data", "context"))
class Model(nnx.Module):
    """Toy module holding three all-ones parameters.

    The parameters have shapes (16, 16), (16, 16, 16) and (16, 16).

    Args:
        enable_sharding: When true, every ``nnx.Param`` is created with
            ``sharding``, ``mesh`` and ``sharding_rules`` metadata (a
            different mesh per parameter). When false, the parameters are
            created from the bare arrays with no extra metadata.
    """

    def __init__(self, enable_sharding: bool = True):
        def placement(logical_axes, mesh, rules):
            # Keyword arguments forwarded to nnx.Param; empty when sharding
            # is disabled so the call reduces to nnx.Param(value).
            if not enable_sharding:
                return {}
            return {
                "sharding": logical_axes,
                "mesh": mesh,
                "sharding_rules": rules,
            }

        self.small_linear1 = nnx.Param(
            jnp.ones((16, 16)),
            **placement(("A", "B"), mesh1, rules1),
        )
        self.small_linear2 = nnx.Param(
            jnp.ones((16, 16, 16)),
            **placement(("X", "Y", "Z"), mesh2, rules2),
        )
        self.small_linear3 = nnx.Param(
            jnp.ones((16, 16)),
            **placement(("C",), mesh3, rules3),
        )
# Checkpoint directory shared by save() and partial_load().
# NOTE(review): save() passes this to ocp.test_utils.erase_and_create_empty,
# which by its name presumably wipes the directory - confirm before pointing
# it at real data.
path = '/ckpt/'
def save():
    """Save the state of an unsharded ``Model`` to ``path`` as step 0.

    Stands in for a checkpoint produced by some other sharding
    configuration: the model is built with ``enable_sharding=False``.
    """
    # Model was saved by some other sharding config
    unsharded_state = nnx.state(Model(enable_sharding=False))
    manager_options = ocp.CheckpointManagerOptions()
    checkpoint_dir = ocp.test_utils.erase_and_create_empty(path)
    with ocp.CheckpointManager(checkpoint_dir, options=manager_options) as manager:
        manager.save(0, args=ocp.args.PyTreeSave(unsharded_state))
def partial_load():
    """Restore step 0 from ``path`` into an abstract, sharded ``Model``.

    An abstract model is built with ``nnx.eval_shape`` and sharding enabled.
    The part of its state whose paths contain one of the selected parameter
    names is used as the restore target, and the restored values are written
    back into the abstract model.

    Returns:
        The model object produced by ``nnx.eval_shape`` after being updated
        with the restored state.
    """
    wanted_names = ["small_linear1", "small_linear2", "small_linear3"]

    # We're using new sharding now
    abstract_model = nnx.eval_shape(lambda: Model(enable_sharding=True))
    name_filters = [nnx.PathContains(name) for name in wanted_names]
    abstract_state = nnx.state(abstract_model, nnx.Any(*name_filters))

    manager_options = ocp.CheckpointManagerOptions()
    with ocp.CheckpointManager(path, options=manager_options) as manager:
        restore_args = ocp.args.PyTreeRestore(abstract_state, partial_restore=True)
        restored_state = manager.restore(0, args=restore_args)
        nnx.update(abstract_model, restored_state)
    return abstract_model
# Driver: save an unsharded checkpoint, restore it through the abstract
# (nnx.eval_shape) model, and print the per-parameter sharding of a concretely
# constructed model next to that of the restored one.
# NOTE(review): indentation was lost in this paste; only save() is certain to
# be inside the `with mesh_data:` block - confirm where the block ends.
# First save the checkpoint
with mesh_data:
save()
# Restore via the abstract state built with the target sharding.
restored_model = partial_load()
# Reference model created directly with the same sharding annotations.
original_model = Model(enable_sharding=True)
# For every parameter print (shape, array class name, sharding).
print("ORIGINAL MODEL:")
print(jax.tree.map(lambda p: (p.shape, p.__class__.__name__, p.sharding), nnx.state(original_model)))
print("RESTORED MODEL:")
print(jax.tree.map(lambda p: (p.shape, p.__class__.__name__, p.sharding), nnx.state(restored_model)))
Output:
ORIGINAL MODEL:
State({
'small_linear1': Param(
value=((16, 16), 'ArrayImpl', NamedSharding(mesh=Mesh('a': 2, 'b': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('a', 'b'), memory_kind=device)),
mesh=Mesh(axis_sizes=(2, 4), axis_names=('a', 'b'), axis_types=(Auto, Auto)),
sharding_names=('A', 'B'),
sharding_rules=(('A', 'a'), ('B', 'b'))
),
'small_linear2': Param(
value=((16, 16, 16), 'ArrayImpl', NamedSharding(mesh=Mesh('x': 2, 'y': 2, 'z': 2, axis_types=(Auto, Auto, Auto)), spec=PartitionSpec('x', 'y', 'z'), memory_kind=device)),
mesh=Mesh(axis_sizes=(2, 2, 2), axis_names=('x', 'y', 'z'), axis_types=(Auto, Auto, Auto)),
sharding_names=('X', 'Y', 'Z'),
sharding_rules=(('X', 'x'), ('Y', 'y'), ('Z', 'z'))
),
'small_linear3': Param(
value=((16, 16), 'ArrayImpl', NamedSharding(mesh=Mesh('c': 8, axis_types=(Auto,)), spec=PartitionSpec('c',), memory_kind=device)),
mesh=Mesh(axis_sizes=(8,), axis_names=('c',), axis_types=(Auto,)),
sharding_names=('C',),
sharding_rules=(('C', 'c'),)
)
})
RESTORED MODEL:
State({
'small_linear1': Param(
value=((16, 16), 'ArrayImpl', SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=device)),
mesh=Mesh(axis_sizes=(2, 4), axis_names=('a', 'b'), axis_types=(Auto, Auto)),
sharding_names=('A', 'B'),
sharding_rules=(('A', 'a'), ('B', 'b'))
),
'small_linear2': Param(
value=((16, 16, 16), 'ArrayImpl', SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=device)),
mesh=Mesh(axis_sizes=(2, 2, 2), axis_names=('x', 'y', 'z'), axis_types=(Auto, Auto, Auto)),
sharding_names=('X', 'Y', 'Z'),
sharding_rules=(('X', 'x'), ('Y', 'y'), ('Z', 'z'))
),
'small_linear3': Param(
value=((16, 16), 'ArrayImpl', SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=device)),
mesh=Mesh(axis_sizes=(8,), axis_names=('c',), axis_types=(Auto,)),
sharding_names=('C',),
sharding_rules=(('C', 'c'),)
)
})
Orbax also outputs the following warning:
/usr/local/lib/python3.11/dist-packages/orbax/checkpoint/_src/serialization/jax_array_handlers.py:701: UserWarning: Sharding info not provided when restoring. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
Metadata
Metadata
Assignees
Labels
No labels