15 changes: 14 additions & 1 deletion README.md
@@ -20,7 +20,7 @@ pip install xarray-tensorstore

## Usage

Open your Zarr files into an `xarray.Dataset` using `open_zarr()`, and then use
Open a Zarr file into an `xarray.Dataset` using `open_zarr()`, and then use
`read()` to start reading data in the background:

```python
@@ -39,6 +39,19 @@ read_example = xarray_tensorstore.read(example)
numpy_example = read_example.compute()
```

Open a list of Zarr files and concatenate them along a single dimension using
`open_concatenated_zarrs()`. The returned `xarray.Dataset` behaves exactly as above.
This function requires the Dask package to be installed.

```python
import xarray_tensorstore

ds = xarray_tensorstore.open_concatenated_zarrs(
paths=[path1, path2],
concat_dim="time",
)
```
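
The concatenated dataset supports the same lazy workflow shown above. A minimal
sketch, reusing `read()` from the earlier example (the variable names here are
illustrative):

```python
# Start reads in the background, then materialize as NumPy-backed data,
# exactly as in the open_zarr() example above.
read_ds = xarray_tensorstore.read(ds)
numpy_ds = read_ds.compute()
```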

## Limitations

- Xarray-TensorStore still uses Zarr-Python under the covers to open Zarr
4 changes: 2 additions & 2 deletions setup.py
@@ -22,9 +22,9 @@
license='Apache-2.0',
author='Google LLC',
author_email='noreply@google.com',
install_requires=['numpy', 'xarray', 'zarr', 'tensorstore'],
install_requires=['numpy', 'xarray', 'zarr', 'tensorstore',],
extras_require={
'tests': ['absl-py', 'dask', 'pandas', 'pytest'],
'tests': ['absl-py', 'pandas', 'pytest', 'dask'],
},
url={'source': 'https://github.com/google/xarray-tensorstore'},
py_modules=['xarray_tensorstore'],
80 changes: 80 additions & 0 deletions xarray_tensorstore.py
@@ -293,3 +293,83 @@ def open_zarr(
new_data = {k: _TensorStoreAdapter(v) for k, v in arrays.items()}

return ds.copy(data=new_data)


def _tensorstore_open_concatenated_zarrs(
paths: list[str],
data_vars: list[str],
concat_axes: list[int],
context: tensorstore.Context,
) -> dict[str, tensorstore.TensorStore]:
"""Open multiple zarrs with TensorStore.

Args:
paths: List of paths to zarr stores.
data_vars: List of data variable names to open.
concat_axes: List of axes along which to concatenate the data variables.
context: TensorStore context.
"""
# Open all arrays in all datasets using tensorstore
arrays_list = []
for path in paths:
zarr_format = _get_zarr_format(path)
specs = {k: _zarr_spec_from_path(os.path.join(path, k), zarr_format) for k in data_vars}
array_futures = {
k: tensorstore.open(spec, read=True, write=False, context=context)
for k, spec in specs.items()
}
arrays_list.append(array_futures)

# Concatenate the tensorstore arrays
arrays = {}
for k, axis in zip(data_vars, concat_axes, strict=True):
datasets = [array_futures[k].result() for array_futures in arrays_list]
arrays[k] = tensorstore.concat(datasets, axis=axis)

return arrays


def open_concatenated_zarrs(
paths: list[str],
concat_dim: str,
*,
context: tensorstore.Context | None = None,
mask_and_scale: bool = True,
) -> xarray.Dataset:
"""Open an xarray.Dataset whilst concatenating multiple Zarr using TensorStore.

Notes:
This function depends on the Dask package.

Args:
paths: List of paths to zarr stores.
concat_dim: Dimension along which to concatenate the data variables.
context: TensorStore context.
mask_and_scale: Whether to mask and scale the data.

Returns:
    Concatenated Dataset with all data variables opened via TensorStore.
"""
if context is None:
context = tensorstore.Context()

ds = xarray.open_mfdataset(
paths,
concat_dim=concat_dim,
combine="nested",
mask_and_scale=mask_and_scale,
engine="zarr"
)

if mask_and_scale:
# Data variables get replaced below with _TensorStoreAdapter arrays, which
# don't get masked or scaled. Raising an error avoids surprising users with
# incorrect data values.
_raise_if_mask_and_scale_used_for_data_vars(ds)

data_vars = list(ds.data_vars)
concat_axes = [ds[v].dims.index(concat_dim) for v in data_vars]
arrays = _tensorstore_open_concatenated_zarrs(paths, data_vars, concat_axes, context)
new_data = {k: _TensorStoreAdapter(v) for k, v in arrays.items()}

return ds.copy(data=new_data)
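
The new helper builds on `tensorstore.concat`, which creates a lazy, virtual
concatenation of TensorStore arrays. Below is a minimal sketch with small
in-memory arrays; it is illustrative only and assumes `tensorstore.array` and
`tensorstore.concat` behave as they are used in the code above.

```python
import numpy as np
import tensorstore

# Two small in-memory TensorStore arrays sharing their trailing dimension.
a = tensorstore.array(np.zeros((2, 3)))
b = tensorstore.array(np.ones((3, 3)))

# Lazily concatenate along axis 0; no data is copied until the view is read.
combined = tensorstore.concat([a, b], axis=0)

# Materialize the virtual view as a NumPy array of shape (5, 3).
result = combined.read().result()
print(result.shape)
```
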
44 changes: 37 additions & 7 deletions xarray_tensorstore_test.py
@@ -25,12 +25,7 @@

_USING_ZARR_PYTHON_3 = packaging.version.parse(zarr.__version__).major >= 3


class XarrayTensorstoreTest(parameterized.TestCase):

@parameterized.named_parameters(
# TODO(shoyer): consider using hypothesis to convert these into
# property-based tests
test_cases = [
{
'testcase_name': 'base',
'transform': lambda ds: ds,
@@ -88,7 +83,14 @@ class XarrayTensorstoreTest(parameterized.TestCase):
'testcase_name': 'select_a_variable',
'transform': lambda ds: ds['foo'],
},
)
]


class XarrayTensorstoreTest(parameterized.TestCase):

# TODO(shoyer): consider using hypothesis to convert these into
# property-based tests
@parameterized.named_parameters(test_cases)
def test_open_zarr(self, transform):
source = xarray.Dataset(
{
@@ -110,6 +112,34 @@ def test_open_zarr(self, transform):
actual = transform(xarray_tensorstore.open_zarr(path)).compute()
xarray.testing.assert_identical(actual, expected)

@parameterized.named_parameters(test_cases)
def test_open_concatenated_zarrs(self, transform):
sources = [
xarray.Dataset(
{
'foo': (('x',), x, {'local': 'local metadata'}),
'bar': (('x', 'y'), np.arange(6).reshape(2, 3)),
'baz': (('x', 'y', 'z'), np.arange(24).reshape(2, 3, 4)),
},
coords={
'x': [1, 2],
'y': pd.to_datetime(['2000-01-01', '2000-01-02', '2000-01-03']),
'z': ['a', 'b', 'c', 'd'],
},
attrs={'global': 'global metadata'},
)
for x in [range(0,2), range(3, 5)]
]

zarr_dir = self.create_tempdir().full_path
paths = [f"{zarr_dir}/{i}" for i in range(len(sources))]
for source, path in zip(sources, paths, strict=True):
source.chunk().to_zarr(path)

expected = transform(xarray.concat(sources, dim="x"))
actual = transform(xarray_tensorstore.open_concatenated_zarrs(paths, concat_dim="x")).compute()
xarray.testing.assert_identical(actual, expected)

@parameterized.parameters(
{'deep': True},
{'deep': False},