Hi Orbax team,
I was looking at the Orbax code in the latest version (0.7.0) and found that some fairly heavy logic around TensorStore spec creation appears to be duplicated.
I'd like to know whether this duplication is intentional or whether I'm welcome to submit a PR.
Here, `get_tensorstore_spec` is part of the public API, but I can't find any place where orbax-checkpoint itself uses it:
https://github.com/google/orbax/blob/8b4e90d573082a5c7caa5f99c51db376f62a6995/checkpoint/orbax/checkpoint/serialization.py#L97C5-L124
And here is a very similar piece of code, `build_kvstore_tspec`, in the internal `_src` package. `build_kvstore_tspec` is used heavily by `type_handlers.py`:
`orbax/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py` (lines 62 to 135 in 8b4e90d):
```python
def build_kvstore_tspec(
    directory: str,
    name: str | None = None,
    *,
    use_ocdbt: bool = True,
    process_id: int | str | None = None,
) -> JsonSpec:
  """Constructs a spec for a Tensorstore KvStore.

  Args:
    directory: Base path (key prefix) of the KvStore, used by the underlying
      file driver.
    name: Name (filename) of the parameter.
    use_ocdbt: Whether to use OCDBT driver.
    process_id: [only used with OCDBT driver] If provided,
      `{directory}/ocdbt.process_{process_id}` path is used as the base path.
      If a string, must conform to [A-Za-z0-9]+ pattern.

  Returns:
    A Tensorstore KvStore spec in dictionary form.
  """
  default_driver = DEFAULT_DRIVER
  # Normalize path to exclude trailing '/'. In GCS path case, we will need to
  # fix the path prefix to add back the stripped '/'.
  directory = os.path.normpath(directory).replace('gs:/', 'gs://')
  is_gcs_path = directory.startswith('gs://')
  kv_spec = {}
  if use_ocdbt:
    if not is_gcs_path and not os.path.isabs(directory):
      raise ValueError(f'Checkpoint path should be absolute. Got {directory}')
    if process_id is not None:
      process_id = str(process_id)
      if re.fullmatch(_OCDBT_PROCESS_ID_RE, process_id) is None:
        raise ValueError(
            f'process_id must conform to {_OCDBT_PROCESS_ID_RE} pattern'
            f', got {process_id}'
        )
      directory = os.path.join(
          directory, f'{PROCESS_SUBDIR_PREFIX}{process_id}'
      )
    base_driver_spec = (
        directory
        if is_gcs_path
        else {'driver': default_driver, 'path': str(directory)}
    )
    kv_spec.update({
        'driver': 'ocdbt',
        'base': base_driver_spec,
    })
    if name is not None:
      kv_spec['path'] = name
    kv_spec.update({  # pytype: disable=attribute-error
        # Enable read coalescing. This feature merges adjacent read_ops into
        # one, which could reduce I/O ops by a factor of 10. This is especially
        # beneficial for unstacked models.
        'experimental_read_coalescing_threshold_bytes': 1000000,
        'experimental_read_coalescing_merged_bytes': 500000000000,
        'experimental_read_coalescing_interval': '1ms',
        # References the cache specified in ts.Context.
        'cache_pool': 'cache_pool#ocdbt',
    })
  else:
    if name is None:
      path = directory
    else:
      path = os.path.join(directory, name)
    if is_gcs_path:
      kv_spec = _get_kvstore_for_gcs(path)
    else:
      kv_spec = {'driver': default_driver, 'path': path}
  return kv_spec
```
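The path normalization and absolute-path check at the top of `build_kvstore_tspec` is exactly the kind of logic that can quietly drift between two copies. As a minimal sketch, here is that single step pulled into one shared helper; the name `normalize_kvstore_directory` and its signature are hypothetical and not part of Orbax.

```python
import os


def normalize_kvstore_directory(
    directory: str, *, require_absolute: bool
) -> tuple[str, bool]:
  """Hypothetical shared helper: returns (normalized_dir, is_gcs_path)."""
  # On POSIX, os.path.normpath drops the trailing '/' but also collapses the
  # '//' in 'gs://' to '/', so the GCS scheme separator is restored here.
  directory = os.path.normpath(directory).replace('gs:/', 'gs://')
  is_gcs_path = directory.startswith('gs://')
  if require_absolute and not is_gcs_path and not os.path.isabs(directory):
    raise ValueError(f'Checkpoint path should be absolute. Got {directory}')
  return directory, is_gcs_path


print(normalize_kvstore_directory('gs://bucket/ckpt/', require_absolute=True))
print(normalize_kvstore_directory('/tmp/ckpt/', require_absolute=True))
```

With a single helper like this, a fix to the GCS special case only has to land once.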
Would you consider having `get_tensorstore_spec` reuse `build_kvstore_tspec` under the hood?
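A rough sketch of what that delegation could look like, under several assumptions: the public function takes one full path plus an `ocdbt` flag, the array driver is `zarr`, the default kvstore driver is `file`, and GCS paths map to TensorStore's `gcs` kvstore (`bucket` + `path`). All `*_sketch` names are hypothetical, and the builder below is a simplified stand-in for `build_kvstore_tspec` with the read-coalescing options left out.

```python
from __future__ import annotations

import os
import re
from typing import Any

JsonSpec = dict[str, Any]

# Assumed values for this sketch only; the prefix and regex follow the
# build_kvstore_tspec docstring, the 'file' driver is an assumption.
_DEFAULT_DRIVER = 'file'
_PROCESS_SUBDIR_PREFIX = 'ocdbt.process_'
_OCDBT_PROCESS_ID_RE = r'[A-Za-z0-9]+'


def build_kvstore_tspec_sketch(
    directory: str,
    name: str | None = None,
    *,
    use_ocdbt: bool = True,
    process_id: int | str | None = None,
) -> JsonSpec:
  """Simplified stand-in for build_kvstore_tspec (no read coalescing)."""
  directory = os.path.normpath(directory).replace('gs:/', 'gs://')
  is_gcs_path = directory.startswith('gs://')
  if not use_ocdbt:
    path = directory if name is None else os.path.join(directory, name)
    if is_gcs_path:
      # Hypothetical GCS handling: 'gs://bucket/key' -> bucket + key.
      bucket, _, key = path[len('gs://'):].partition('/')
      return {'driver': 'gcs', 'bucket': bucket, 'path': key}
    return {'driver': _DEFAULT_DRIVER, 'path': path}
  if not is_gcs_path and not os.path.isabs(directory):
    raise ValueError(f'Checkpoint path should be absolute. Got {directory}')
  if process_id is not None:
    process_id = str(process_id)
    if re.fullmatch(_OCDBT_PROCESS_ID_RE, process_id) is None:
      raise ValueError(f'Invalid process_id: {process_id}')
    directory = os.path.join(directory, f'{_PROCESS_SUBDIR_PREFIX}{process_id}')
  base = directory if is_gcs_path else {'driver': _DEFAULT_DRIVER, 'path': directory}
  kv_spec: JsonSpec = {'driver': 'ocdbt', 'base': base}
  if name is not None:
    kv_spec['path'] = name
  return kv_spec


def get_tensorstore_spec_sketch(path: str, ocdbt: bool = False) -> JsonSpec:
  """Hypothetical public wrapper that owns no kvstore logic of its own."""
  if ocdbt:
    # OCDBT needs a base directory plus a key, so split off the leaf name.
    directory, name = os.path.split(path.rstrip('/'))
    kvstore = build_kvstore_tspec_sketch(directory, name, use_ocdbt=True)
  else:
    kvstore = build_kvstore_tspec_sketch(path, use_ocdbt=False)
  return {'driver': 'zarr', 'kvstore': kvstore}


print(get_tensorstore_spec_sketch('/tmp/ckpt/params', ocdbt=True))
```

The point of the design is that the public wrapper only chooses the array driver and how to split the path; every kvstore decision stays in one function.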
Also, the default `ts_context` value is somewhat ambiguous, because it is defined in two places:
- In `orbax/checkpoint/serialization.py`, the public `TS_CONTEXT` is used as the default value of `context` in `async_serialize`, `async_serialize_shards`, and `async_deserialize`, and by `StringHandler`. (Orbax itself does not actually use `async_serialize` anywhere and recommends `async_serialize_shards` instead.)
- At the same time, `type_handlers.py` defines `get_ts_context()`, which references `_DEFAULT_OCDBT_TS_CONTEXT`, and `get_ts_context` is used by all the other handler implementations.
So `TS_CONTEXT` from `serialization.py` seems never to be used by the common checkpoint I/O code.
Should we keep only one source of truth for the default `ts_context` value?
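One possible shape for a single source of truth, as a sketch only: keep the default context *specs* as plain JSON dicts in one module and derive everything else from them. The names below are hypothetical. `file_io_concurrency` and `cache_pool` are standard TensorStore context resources, but the limit values are placeholders, not the ones Orbax uses. The only grounded detail is that `build_kvstore_tspec` refers to a `cache_pool#ocdbt` resource, so the OCDBT variant must define it.

```python
from __future__ import annotations

import copy
from typing import Any

# Hypothetical single definition of the default TensorStore context specs.
# Limit values are placeholders chosen for this sketch.
_DEFAULT_TS_CONTEXT_SPEC: dict[str, Any] = {
    'file_io_concurrency': {'limit': 128},
}
_DEFAULT_OCDBT_TS_CONTEXT_SPEC: dict[str, Any] = {
    **_DEFAULT_TS_CONTEXT_SPEC,
    # build_kvstore_tspec points its 'cache_pool' at this named resource.
    'cache_pool#ocdbt': {'total_bytes_limit': 100_000_000},
}


def get_default_ts_context_spec(use_ocdbt: bool = True) -> dict[str, Any]:
  """Returns a deep copy so callers cannot mutate the shared defaults."""
  spec = _DEFAULT_OCDBT_TS_CONTEXT_SPEC if use_ocdbt else _DEFAULT_TS_CONTEXT_SPEC
  return copy.deepcopy(spec)
```

With TensorStore installed, both `serialization.TS_CONTEXT` and `type_handlers.get_ts_context()` could then be built as `ts.Context(get_default_ts_context_spec(...))`, so a change to a default only needs to happen in one place.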