Skip to content

Commit

Permalink
Stop writing msgpack file for new checkpoints and update empty nodes handling so that it no longer depends on this file.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 650306765
  • Loading branch information
dubey authored and pax authors committed Jul 8, 2024
1 parent 5306149 commit 9863f27
Showing 1 changed file with 46 additions and 34 deletions.
80 changes: 46 additions & 34 deletions paxml/checkpoints.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -26,7 +26,6 @@

from absl import logging
from etils import epath
from etils.epath import abstract_path
import flax.serialization
import jax
from jax.experimental import multihost_utils
Expand Down Expand Up @@ -442,26 +441,23 @@ class PaxCheckpointHandler(ocp.PyTreeCheckpointHandler):
from a state dict.
"""

_handler_impl: PaxCheckpointHandlerImpl

def __init__(
self,
*args,
enforce_restore_shape_check: bool = False,
use_ocdbt: bool = False,
**kwargs,
):
handler_impl = PaxCheckpointHandlerImpl(
*args, use_ocdbt=use_ocdbt, **kwargs
)

super().__init__(
*args,
use_ocdbt=use_ocdbt,
handler_impl=handler_impl,
handler_impl=PaxCheckpointHandlerImpl(
*args, use_ocdbt=use_ocdbt, **kwargs
),
**kwargs,
)
self._enforce_restore_shape_check = enforce_restore_shape_check
self._handler_impl = cast(PaxCheckpointHandlerImpl, self._handler_impl)

async def async_save(
self,
Expand Down Expand Up @@ -546,7 +542,7 @@ def _create_sharded_restore_args(shape_struct, pspec):
if reference_state_specs is None:
logging.warning(
'Found `None` for `state_specs` during restoration. If not restoring'
' using pmap or `pmap_use_tensorstore`, this may indicate an error.'
' using PMAP and `pmap_use_tensorstore`, this may indicate an error.'
)
restore_args = jax.tree_util.tree_map(
_create_restore_args, reference_train_state
Expand Down Expand Up @@ -577,23 +573,11 @@ def _create_sharded_restore_args(shape_struct, pspec):

return restored_train_state

def _read_aggregate_file(self, directory: epath.Path) -> PyTree:
if ocp.type_handlers.is_ocdbt_checkpoint(directory):
raise FileNotFoundError(
'OCDBT format checkpoint cannot depend on aggregate file as metadata.'
)
del directory
# Otherwise, rely on hacked structure.
return jax.tree_util.tree_map(
ocp.utils.leaf_placeholder,
self._handler_impl.get_param_names(None),
)


class PaxCheckpointHandlerImpl(ocp.BasePyTreeCheckpointHandler):
"""Implementation of PaxCheckpointHandler."""

_param_names: dict[str, str] = None
_param_names: PyTree = None

async def _write_metadata_file(
self,
Expand All @@ -614,28 +598,36 @@ def _read_metadata_file(
if self._use_ocdbt:
return super()._read_metadata_file(directory)
else:
raise FileNotFoundError(
'Metadata read not expected for non-OCDBT checkpoint. Ensure that if'
' your checkpoint is not in OCDBT format, there is no _METADATA file.'
)
# Explicitly ignore metadata with non-OCDBT, otherwise we will get the
# wrong tree structure.
raise FileNotFoundError('Metadata file is ignored for Pax.')

def set_param_names(self, param_names: PyTree):
self._param_names = param_names

def get_param_names(self, item: PyTree | None) -> PyTree:
def get_param_names(self, item: PyTree) -> PyTree:
if self._param_names is None:
if item is None:
raise AssertionError('Must provide item to get param names.')
return super().get_param_names(item)
return self._param_names

async def _maybe_deserialize(
async def _write_aggregate_file(
self,
directory: epath.Path,
item: PyTree,
metadata: PyTree,
param_infos: PyTree,
restore_args: PyTree,
save_args: PyTree,
):
"""Skip writing msgpack file for Pax since this file would be unused."""
if self._use_ocdbt:
return await super()._write_aggregate_file(
directory, item, param_infos, save_args
)
return ocp.future.NoopFuture()

async def _maybe_deserialize(
self, structure: PyTree, param_infos: PyTree, restore_args: PyTree
) -> PyTree:

def _replace_param_info_name(info, name):
return dataclasses.replace(info, name=name, path=info.path.parent / name)

Expand All @@ -649,7 +641,19 @@ def _replace_param_info_name(info, name):
self._param_names,
)
return await super()._maybe_deserialize(
item, metadata, param_infos, restore_args
structure, param_infos, restore_args
)

def _read_aggregate_file(self, directory: epath.Path) -> PyTree:
# Use msgpack file if it exists.
# Check for _use_ocdbt, since the msgpack file should only exist if the
# checkpoint was written with OCDBT.
if self._use_ocdbt and (directory / self._aggregate_filename).exists():
return super()._read_aggregate_file(directory)
# Otherwise, rely on hacked structure.
return jax.tree_util.tree_map(
ocp.utils.leaf_placeholder,
self._param_names,
)


Expand Down Expand Up @@ -714,7 +718,15 @@ async def async_save(
# mismatch caused by different versions of saver/restorer.
'str_pytree_state': str(pytree_state),
}
param_infos = self._handler_impl._get_param_infos(item, directory) # pylint: disable=protected-access
save_args = jax.tree_util.tree_map(
lambda _: ocp.SaveArgs(aggregate=True),
item,
is_leaf=ocp.utils.is_empty_or_leaf,
)
param_infos, all_params_aggregated = self._handler_impl._get_param_infos( # pylint: disable=protected-access
item, directory, save_args
)
assert all_params_aggregated
aggregate_file_write_start_time = time.time()
aggregate_commit_future = await self._aggregate_handler.serialize(
directory / self._aggregate_filename,
Expand Down

0 comments on commit 9863f27

Please sign in to comment.