diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py
index 331c59ce95da..4cdb880ef373 100644
--- a/jax/experimental/array_serialization/serialization.py
+++ b/jax/experimental/array_serialization/serialization.py
@@ -90,16 +90,18 @@ def _get_kvstore_for_gcs(ckpt_path: str):
   return {'driver': 'gcs', 'bucket': gcs_bucket, 'path': path_without_bucket}
 
 def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False):
-  is_gcs_path = ckpt_path.startswith('gs://')
   # Normalize path to exclude trailing '/'. In GCS path case, we will need to
   # fix the path prefix to add back the stripped '/'.
   ckpt_path = os.path.normpath(ckpt_path).replace('gs:/', 'gs://')
+  is_gcs_path = ckpt_path.startswith('gs://')
   spec = {'driver': 'zarr', 'kvstore': {}}
   if ocdbt:
-    prefix = 'gs' if is_gcs_path else 'file'
+    if not is_gcs_path and not os.path.isabs(ckpt_path):
+      raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}')
+    base_path = os.path.dirname(ckpt_path)
     spec['kvstore'] = {
         'driver': 'ocdbt',
-        'base': f'{prefix}://{os.path.dirname(ckpt_path)}',
+        'base': base_path if is_gcs_path else f'file://{base_path}',
         'path': os.path.basename(ckpt_path),
     }
   else:
diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py
index e260b57d05e4..e90e771659ec 100644
--- a/jax/experimental/array_serialization/serialization_test.py
+++ b/jax/experimental/array_serialization/serialization_test.py
@@ -21,6 +21,7 @@ import tracemalloc as tm
 
 
 from absl.testing import absltest
+from absl.testing import parameterized
 import jax
 from jax._src import test_util as jtu
 from jax import config
@@ -291,5 +292,26 @@ def test_empty_spec_has_no_metadata(self):
     spec = {}
     self.assertFalse(serialization._spec_has_metadata(spec))
 
+  @parameterized.named_parameters(
+      ('gcs', 'gs://my/ckpt/dir/path'),
+      ('file', '/my/ckpt/dir/path')
+  )
+  def test_get_tensorstore_spec_ocdbt(self, path):
+    spec = serialization.get_tensorstore_spec(path, ocdbt=True)
+    is_gcs_path = path.startswith('gs://')
+    if is_gcs_path:
+      self.assertEqual(spec['kvstore']['base'], os.path.dirname(path))
+    else:
+      self.assertEqual(spec['kvstore']['base'],
+                       f'file://{os.path.dirname(path)}')
+    self.assertEqual(spec['kvstore']['path'], 'path')
+
+  def test_get_tensorstore_spec_not_absolute_path(self):
+    path = 'my/ckpt/path'
+    with self.assertRaisesRegex(ValueError,
+                                "Checkpoint path should be absolute"):
+      serialization.get_tensorstore_spec(path, ocdbt=True)
+
+
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())