diff --git a/conftest.py b/conftest.py index c8ed6d3..af7de76 100644 --- a/conftest.py +++ b/conftest.py @@ -17,4 +17,4 @@ try: app.run(lambda argv: None) except SystemExit: - pass \ No newline at end of file + pass diff --git a/xarray_tensorstore.py b/xarray_tensorstore.py index d1d0803..9743d38 100644 --- a/xarray_tensorstore.py +++ b/xarray_tensorstore.py @@ -99,6 +99,18 @@ def __getitem__(self, key: indexing.ExplicitIndexer) -> _TensorStoreAdapter: translated = indexed[tensorstore.d[:].translate_to[0]] return type(self)(translated) + def __setitem__(self, key: indexing.ExplicitIndexer, value) -> None: + index_tuple = tuple(map(_numpy_to_tensorstore_index, key.tuple, self.shape)) + if isinstance(key, indexing.OuterIndexer): + self.array.oindex[index_tuple] = value + elif isinstance(key, indexing.VectorizedIndexer): + self.array.vindex[index_tuple] = value + else: + assert isinstance(key, indexing.BasicIndexer) + self.array[index_tuple] = value + # Invalidate the future so that the next read will pick up the new value + object.__setattr__(self, 'future', None) + # xarray>2024.02.0 uses oindex and vindex properties, which are expected to # return objects whose __getitem__ method supports the appropriate form of # indexing. @@ -200,6 +212,7 @@ def open_zarr( *, context: tensorstore.Context | None = None, mask_and_scale: bool = True, + write: bool = False, ) -> xarray.Dataset: """Open an xarray.Dataset from Zarr using TensorStore. @@ -228,6 +241,7 @@ def open_zarr( mask_and_scale: if True (default), attempt to apply masking and scaling like xarray.open_zarr(). This is only supported for coordinate variables and otherwise will raise an error. + write: Allow write access. Defaults to False. Returns: Dataset with all data variables opened via TensorStore. @@ -259,7 +273,7 @@ def open_zarr( specs = {k: _zarr_spec_from_path(os.path.join(path, k)) for k in ds} array_futures = { - k: tensorstore.open(spec, read=True, write=False, context=context) + k: tensorstore.open(spec, read=True, write=write, context=context) for k, spec in specs.items() } arrays = {k: v.result() for k, v in array_futures.items()} diff --git a/xarray_tensorstore_test.py b/xarray_tensorstore_test.py index d6f33dc..c40ab9b 100644 --- a/xarray_tensorstore_test.py +++ b/xarray_tensorstore_test.py @@ -1,13 +1,13 @@ # Copyright 2023 Google LLC # -# Licensed under the Apache License, Version 2.0 (the "License"); +# Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, +# distributed under the License is distributed on an 'AS IS' BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. @@ -15,8 +15,10 @@ from absl.testing import parameterized import numpy as np import pandas as pd +import pytest import tensorstore import xarray +from xarray.core import indexing import xarray_tensorstore @@ -136,9 +138,7 @@ def test_compute(self): self.assertNotIsInstance(computed_data, tensorstore.TensorStore) def test_open_zarr_from_uri(self): - source = xarray.Dataset( - {'baz': (('x', 'y', 'z'), np.arange(24).reshape(2, 3, 4))} - ) + source = xarray.Dataset({'baz': (('x', 'y', 'z'), np.arange(24).reshape(2, 3, 4))}) path = self.create_tempdir().full_path source.chunk().to_zarr(path) @@ -221,6 +221,61 @@ def test_mask_and_scale(self): xarray.testing.assert_identical(actual, source) self.assertEqual(actual.coords['x'].encoding['add_offset'], -1) + @parameterized.named_parameters( + { + 'testcase_name': 'basic_indexing', + 'key': (slice(1, None), slice(None), slice(None)), + 'value': np.full((1, 2, 3), -1), + }, + { + 'testcase_name': 'outer_indexing', + 'key': (np.array([0]), np.array([1]), slice(None)), + 'value': np.full((1, 1, 3), -2), + }, + { + 'testcase_name': 'vectorized_indexing', + 'key': (np.array([0]), np.array([0, 1]), slice(None)), + 'value': np.full((2, 3), -3), + }, + ) + def test_setitem(self, key, value): + source_data = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]) + source = xarray.DataArray( + source_data, + dims=('x', 'y', 'z'), + name='baz', + ) + path = self.create_tempdir().full_path + source.to_dataset().chunk().to_zarr(path) + + opened = xarray_tensorstore.open_zarr(path, write=True)['baz'] + + opened[key] = value + read = xarray_tensorstore.read(opened) + + expected_data = source_data.copy() + expected_data[key] = value + expected = xarray.DataArray( + expected_data, + dims=('x', 'y', 'z'), + name='baz', + ) + + xarray.testing.assert_equal(read, expected) + + def test_setitem_readonly(self): + source = xarray.DataArray( + np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]), + dims=('x', 'y', 'z'), + name='baz', + ) + path = self.create_tempdir().full_path + source.to_dataset().chunk().to_zarr(path) + + opened = xarray_tensorstore.open_zarr(path)['baz'] + with pytest.raises(ValueError): + opened[1:, ...] = np.full((1, 2, 3), -1) + if __name__ == '__main__': absltest.main()