Skip to content

Commit

Permalink
Fix checkpoint GCS bucket/path substitution to work with tensorstore>=0.1.14
Browse files Browse the repository at this point in the history
Fixes #78

PiperOrigin-RevId: 415029234
  • Loading branch information
T5X Team authored and t5-copybara committed Dec 8, 2021
1 parent c4fb3fe commit a3510b1
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 20 deletions.
34 changes: 20 additions & 14 deletions t5x/checkpoints.py
Expand Up @@ -203,25 +203,31 @@ def maybe_cast(x):
def _update_ts_path_from_relative_to_absolute(
ckpt_dir: str, ts_spec_dict: MutableMapping[str, Any]):
"""Update (in-place) the path and gcs bucket (if applicable) in a TS Spec."""
# Update the path with `ckpt_dir`
if 'path' in ts_spec_dict:
# GCS driver
if 'gs://' not in ckpt_dir:
raise ValueError(
f'`ckpt_dir` should start with "gs://" prefix. Got {ckpt_dir}')

bucket, stripped_ckpt_dir = re.findall('gs://(.*?)/(.*)', ckpt_dir)[0]
ts_spec_dict['path'] = os.path.join(stripped_ckpt_dir, ts_spec_dict['path'])
# Dynamically update the dummy bucket to the bucket of `ckpt_dir`.

# Handle `gs://` paths.
m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_dir, re.DOTALL)
if m is not None:
if ts_spec_dict['kvstore']['driver'] != 'gcs':
raise ValueError(f'Incorrect TensorStore Spec. '
f'Expects kvstore driver to be "gcs" for {ckpt_dir}. '
f'Got {ts_spec_dict}')
bucket = m.group(1)
ckpt_dir = m.group(2)
ts_spec_dict['kvstore']['bucket'] = bucket
elif 'path' in ts_spec_dict['kvstore']:
# Internal gfile driver

# Update the path with `ckpt_dir`

if 'path' in ts_spec_dict['kvstore']:
# tensorstore>=0.1.14 format
ts_spec_dict['kvstore']['path'] = os.path.join(
ckpt_dir, ts_spec_dict['kvstore']['path'])
elif 'path' in ts_spec_dict:
# tensorstore<0.1.14 format
ts_spec_dict['path'] = os.path.join(ckpt_dir, ts_spec_dict['path'])
else:
raise ValueError(
'Incorrect TensorStore Spec. Expects "path" to be a key of '
f'`spec["kvstore"]` or `spec` Got {ts_spec_dict}')
'Incorrect TensorStore Spec. Expects "path" to be a key of spec or '
f'`spec["kvstore"]`. Got {ts_spec_dict}')


def _maybe_update_ts_from_file_to_gcs(ckpt_contents):
Expand Down
30 changes: 24 additions & 6 deletions t5x/checkpoints_test.py
Expand Up @@ -1330,6 +1330,24 @@ def test_update_ts_from_gcs_to_file(self):
actual = checkpoints._maybe_update_ts_from_gcs_to_file(ckpt_contents)
jax.tree_multimap(np.testing.assert_array_equal, actual, expected)

def assert_update_ts_path_from_relative_to_absolute(self, ts_spec_dict,
expected, ckpt_dir):
"""Tests that `ts_spec_dict` gets updated with `ckpt_dir` to `expected`."""

# Test with normalization (corresponds to tensorstore>=0.1.14)
normalized_ts_spec_dict = ts.Spec(ts_spec_dict).to_json()
checkpoints._update_ts_path_from_relative_to_absolute(
ckpt_dir, normalized_ts_spec_dict)
normalized_ts_spec_dict = ts.Spec(normalized_ts_spec_dict).to_json()
normalized_expected = ts.Spec(expected).to_json()
jax.tree_multimap(np.testing.assert_array_equal, normalized_ts_spec_dict,
normalized_expected)

# Test without normalization (corresponds to tensorstore<0.1.14)
checkpoints._update_ts_path_from_relative_to_absolute(
ckpt_dir, ts_spec_dict)
jax.tree_multimap(np.testing.assert_array_equal, ts_spec_dict, expected)

def test_update_ts_path_from_relative_to_absolute_gfile(self):
ts_spec_dict = {
'driver': 'zarr',
Expand Down Expand Up @@ -1367,9 +1385,8 @@ def test_update_ts_path_from_relative_to_absolute_gfile(self):
}
ckpt_dir = '/dir1/dir2'

checkpoints._update_ts_path_from_relative_to_absolute(
ckpt_dir, ts_spec_dict)
jax.tree_multimap(np.testing.assert_array_equal, ts_spec_dict, expected)
self.assert_update_ts_path_from_relative_to_absolute(
ts_spec_dict, expected, ckpt_dir)

def test_update_ts_path_from_relative_to_absolute_gcs(self):
ts_spec_dict = {
Expand All @@ -1394,6 +1411,7 @@ def test_update_ts_path_from_relative_to_absolute_gcs(self):
'input_inclusive_min': [0, 0]
}
}

expected = {
'driver': 'zarr',
'dtype': 'float32',
Expand All @@ -1417,11 +1435,11 @@ def test_update_ts_path_from_relative_to_absolute_gcs(self):
'input_inclusive_min': [0, 0]
}
}

ckpt_dir = 'gs://test-bucket/dir1/dir2'

checkpoints._update_ts_path_from_relative_to_absolute(
ckpt_dir, ts_spec_dict)
jax.tree_multimap(np.testing.assert_array_equal, ts_spec_dict, expected)
self.assert_update_ts_path_from_relative_to_absolute(
ts_spec_dict, expected, ckpt_dir)

def test_restore_tf_checkpoint(self):
self.verify_restore_checkpoint_from_path(
Expand Down

0 comments on commit a3510b1

Please sign in to comment.