Code duplication around tensorstore_spec logic in orbax-checkpoint #1241

@minotru

Hi Orbax team,

I was looking at the Orbax code at the latest version (0.7.0) and found that some fairly heavy logic around tensorstore_spec creation appears to be duplicated.

I'd like to know whether this duplication is intended by design, or whether I am welcome to submit a PR.

get_tensorstore_spec is part of the public API, but I can't find any usage of it within orbax-checkpoint itself:
https://github.com/google/orbax/blob/8b4e90d573082a5c7caa5f99c51db376f62a6995/checkpoint/orbax/checkpoint/serialization.py#L97C5-L124

A very similar piece of code lives in build_kvstore_tspec in the _internal package, and build_kvstore_tspec is used heavily by type_handlers.py:

def build_kvstore_tspec(
    directory: str,
    name: str | None = None,
    *,
    use_ocdbt: bool = True,
    process_id: int | str | None = None,
) -> JsonSpec:
  """Constructs a spec for a Tensorstore KvStore.
  Args:
    directory: Base path (key prefix) of the KvStore, used by the underlying
      file driver.
    name: Name (filename) of the parameter.
    use_ocdbt: Whether to use OCDBT driver.
    process_id: [only used with OCDBT driver] If provided,
      `{directory}/ocdbt.process_{process_id}` path is used as the base path.
      If a string, must conform to [A-Za-z0-9]+ pattern.
  Returns:
    A Tensorstore KvStore spec in dictionary form.
  """
  default_driver = DEFAULT_DRIVER
  # Normalize path to exclude trailing '/'. In GCS path case, we will need to
  # fix the path prefix to add back the stripped '/'.
  directory = os.path.normpath(directory).replace('gs:/', 'gs://')
  is_gcs_path = directory.startswith('gs://')
  kv_spec = {}
  if use_ocdbt:
    if not is_gcs_path and not os.path.isabs(directory):
      raise ValueError(f'Checkpoint path should be absolute. Got {directory}')
    if process_id is not None:
      process_id = str(process_id)
      if re.fullmatch(_OCDBT_PROCESS_ID_RE, process_id) is None:
        raise ValueError(
            f'process_id must conform to {_OCDBT_PROCESS_ID_RE} pattern'
            f', got {process_id}'
        )
      directory = os.path.join(
          directory, f'{PROCESS_SUBDIR_PREFIX}{process_id}'
      )
    base_driver_spec = (
        directory
        if is_gcs_path
        else {'driver': default_driver, 'path': str(directory)}
    )
    kv_spec.update({
        'driver': 'ocdbt',
        'base': base_driver_spec,
    })
    if name is not None:
      kv_spec['path'] = name
    kv_spec.update({  # pytype: disable=attribute-error
        # Enable read coalescing. This feature merges adjacent read_ops into
        # one, which could reduce I/O ops by a factor of 10. This is especially
        # beneficial for unstacked models.
        'experimental_read_coalescing_threshold_bytes': 1000000,
        'experimental_read_coalescing_merged_bytes': 500000000000,
        'experimental_read_coalescing_interval': '1ms',
        # References the cache specified in ts.Context.
        'cache_pool': 'cache_pool#ocdbt',
    })
  else:
    if name is None:
      path = directory
    else:
      path = os.path.join(directory, name)
    if is_gcs_path:
      kv_spec = _get_kvstore_for_gcs(path)
    else:
      kv_spec = {'driver': default_driver, 'path': path}
  return kv_spec

Would you consider having get_tensorstore_spec reuse build_kvstore_tspec under the hood?
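To make the delegation idea concrete, here is a rough sketch, not the actual Orbax code. Everything in it is an assumption for illustration: build_kvstore_tspec_stub is a simplified, local-file-only stand-in for the function quoted above, get_tensorstore_spec_sketch is a hypothetical wrapper with made-up parameter names, and both the 'file' base driver and the outer 'zarr' driver are placeholders.

```python
import os
from typing import Any, Dict, Optional

JsonSpec = Dict[str, Any]


def build_kvstore_tspec_stub(
    directory: str,
    name: Optional[str] = None,
    *,
    use_ocdbt: bool = True,
) -> JsonSpec:
    """Simplified stand-in for build_kvstore_tspec (local paths only)."""
    directory = os.path.normpath(directory)
    if use_ocdbt:
        if not os.path.isabs(directory):
            raise ValueError(
                f'Checkpoint path should be absolute. Got {directory}'
            )
        # 'file' is an assumed value for the base driver.
        spec: JsonSpec = {
            'driver': 'ocdbt',
            'base': {'driver': 'file', 'path': directory},
        }
        if name is not None:
            spec['path'] = name
        return spec
    path = directory if name is None else os.path.join(directory, name)
    return {'driver': 'file', 'path': path}


def get_tensorstore_spec_sketch(ckpt_path: str, ocdbt: bool = False) -> JsonSpec:
    """Hypothetical public wrapper that delegates all kvstore logic."""
    ckpt_path = os.path.normpath(ckpt_path)
    if ocdbt:
        # With OCDBT, the last path component becomes the key inside the store.
        directory, name = os.path.split(ckpt_path)
        kvstore = build_kvstore_tspec_stub(directory, name, use_ocdbt=True)
    else:
        kvstore = build_kvstore_tspec_stub(ckpt_path, use_ocdbt=False)
    # The outer array driver is an assumption made for this sketch.
    return {'driver': 'zarr', 'kvstore': kvstore}
```

With this shape, path normalization and the OCDBT checks would live in one place, and the public function would only add the array-level wrapper around the kvstore spec.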


Also, there seems to be some ambiguity around the default ts_context value.

  • In orbax/checkpoint/serialization.py, the public TS_CONTEXT is used as the default value of context in async_serialize, async_serialize_shards, and async_deserialize, and by StringHandler. (Orbax itself does not call async_serialize anywhere and recommends using async_serialize_shards instead.)
  • At the same time, in type_handlers.py, there is get_ts_context() (it references _DEFAULT_OCDBT_TS_CONTEXT), and get_ts_context is used by all other handler implementations.

So TS_CONTEXT from serialization.py seems never to be used by the common checkpoint I/O code.

Should we keep a single source of truth for the default ts_context value?
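One possible shape for a single source of truth, as a sketch only: one lazily created default that both the serialization entry points and the handlers resolve through the same accessor. The spec values below are placeholders, serialize_sketch and handler_sketch are hypothetical stand-ins, and the factory defaults to dict so the snippet runs without TensorStore installed; in real code the factory would presumably be ts.Context.

```python
from typing import Any, Callable, Optional

# Placeholder defaults; the limits are illustrative, not Orbax's values.
_DEFAULT_TS_CONTEXT_SPEC = {
    'file_io_concurrency': {'limit': 128},
    'cache_pool#ocdbt': {'total_bytes_limit': 100_000_000},
}

_shared_context: Optional[Any] = None


def get_ts_context(factory: Callable[[dict], Any] = dict) -> Any:
    """Returns the one shared default context, creating it on first use."""
    global _shared_context
    if _shared_context is None:
        _shared_context = factory(_DEFAULT_TS_CONTEXT_SPEC)
    return _shared_context


def serialize_sketch(data: Any, context: Optional[Any] = None) -> Any:
    """Stand-in for a serialization entry point.

    The default is resolved at call time instead of being bound to a
    module-level constant in the function signature.
    """
    context = get_ts_context() if context is None else context
    return context  # A real implementation would write `data` using it.


def handler_sketch() -> Any:
    """Stand-in for a type handler that needs the default context."""
    return get_ts_context()
```

Here serialize_sketch(b'x') and handler_sketch() return the same object, while an explicitly passed context still takes precedence.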
