Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ def __call__(
class Loading:
"""Options for loading PyTrees.

partial_load: NOT IMPLEMENTED.
partial_load: If True, only restore the parameters that are specified
in the abstract PyTree.
"""

partial_load: bool = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def _set_enable_padding_and_truncation(a):
return base_pytree_checkpoint_handler.BasePyTreeRestoreArgs(
item=abstract_checkpointable,
restore_args=restore_args,
partial_restore=context.pytree_options.loading.partial_load,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def handler_with_options(
ARRAY_METADATA_STORE
),
enable_write_sharding_file: bool = True,
partial_load: bool = False,
):
"""Registers handlers with OCDBT support and resets when done."""
type_handler_registry = copy.deepcopy(
Expand Down Expand Up @@ -233,7 +234,10 @@ def handler_with_options(
saving=options_lib.PyTreeOptions.Saving(
create_array_storage_options_fn=create_array_storage_options_fn,
pytree_metadata_options=pytree_metadata_options,
)
),
loading=options_lib.PyTreeOptions.Loading(
partial_load=partial_load,
),
),
)

Expand Down Expand Up @@ -2006,3 +2010,54 @@ def test_partial_restore_with_placeholder(self, use_ocdbt: bool):
ValueError, 'User-provided restore item and on-disk value'
):
restore_handler.load(self.directory, reference_item)

@parameterized.product(use_ocdbt=(True, False))
def test_partial_restore_with_omission(self, use_ocdbt: bool):
"""Basic save and restore test."""
directory = self.directory / 'partial_restore'

with handler_with_options(
use_ocdbt=use_ocdbt,
) as save_handler:
save_handler.save(directory, self.pytree)

with self.subTest('success'):
with handler_with_options(
use_ocdbt=use_ocdbt,
partial_load=True,
) as restore_handler:
# Create a new pytree structure with the same leaves.
# Leaves (ShapeDtypeStruct) are immutable and can be shared.
reference_item = jax.tree.map(lambda x: x, self.abstract_pytree)
# Omit 'b', 'c.e', and 'x' from the reference item.
del reference_item['b']
del reference_item['c']['e']
del reference_item['x']
expected = {
'a': self.pytree['a'],
'c': {
'a': self.pytree['c']['a'],
},
'y': self.pytree['y'],
}
restored = restore_handler.load(directory, reference_item)
test_utils.assert_tree_equal(self, expected, restored)

with self.subTest('extra_leaf'):
with handler_with_options(
use_ocdbt=use_ocdbt,
partial_load=True,
) as restore_handler:
# Create a new pytree structure with the same leaves.
# Leaves (ShapeDtypeStruct) are immutable and can be shared.
reference_item = jax.tree.map(lambda x: x, self.abstract_pytree)
del reference_item['b']
del reference_item['c']['e']
del reference_item['x']
# Add an extra leaf to the reference item.
reference_item['z'] = jax.ShapeDtypeStruct([0], np.int64)
with self.assertRaisesRegex(
ValueError,
r"Missing 1 keys in structure path \(\), including: \['z'\]",
):
restore_handler.load(directory, reference_item)