Skip to content

Commit

Permalink
Fix get_tensorstore_spec for GCS paths if ocdbt is enabled
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 538627415
  • Loading branch information
yashk2810 authored and jax authors committed Jun 7, 2023
1 parent 1a3ac88 commit b44f8b4
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
8 changes: 5 additions & 3 deletions jax/experimental/array_serialization/serialization.py
Expand Up @@ -90,16 +90,18 @@ def _get_kvstore_for_gcs(ckpt_path: str):
return {'driver': 'gcs', 'bucket': gcs_bucket, 'path': path_without_bucket}

def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
is_gcs_path = ckpt_path.startswith('gs://')
# Normalize path to exclude trailing '/'. In GCS path case, we will need to
# fix the path prefix to add back the stripped '/'.
ckpt_path = os.path.normpath(ckpt_path).replace('gs:/', 'gs://')
is_gcs_path = ckpt_path.startswith('gs://')
spec = {'driver': 'zarr', 'kvstore': {}}
if ocdbt:
prefix = 'gs' if is_gcs_path else 'file'
if not is_gcs_path and not os.path.isabs(ckpt_path):
raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}')
base_path = os.path.dirname(ckpt_path)
spec['kvstore'] = {
'driver': 'ocdbt',
'base': f'{prefix}://{os.path.dirname(ckpt_path)}',
'base': base_path if is_gcs_path else f'file://{base_path}',
'path': os.path.basename(ckpt_path),
}
else:
Expand Down
22 changes: 22 additions & 0 deletions jax/experimental/array_serialization/serialization_test.py
Expand Up @@ -21,6 +21,7 @@
import tracemalloc as tm

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import test_util as jtu
from jax import config
Expand Down Expand Up @@ -291,5 +292,26 @@ def test_empty_spec_has_no_metadata(self):
spec = {}
self.assertFalse(serialization._spec_has_metadata(spec))

@parameterized.named_parameters(
('gcs', 'gs://my/ckpt/dir/path'),
('file', '/my/ckpt/dir/path')
)
def test_get_tensorstore_spec_ocdbt(self, path):
spec = serialization.get_tensorstore_spec(path, ocdbt=True)
is_gcs_path = path.startswith('gs://')
if is_gcs_path:
self.assertEqual(spec['kvstore']['base'], os.path.dirname(path))
else:
self.assertEqual(spec['kvstore']['base'],
f'file://{os.path.dirname(path)}')
self.assertEqual(spec['kvstore']['path'], 'path')

def test_get_tensorstore_spec_not_absolute_path(self):
path = 'my/ckpt/path'
with self.assertRaisesRegex(ValueError,
"Checkpoint path should be absolute"):
serialization.get_tensorstore_spec(path, ocdbt=True)


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit b44f8b4

Please sign in to comment.