Zarr plugin #7

Merged
merged 5 commits into from Apr 27, 2018
2 changes: 2 additions & 0 deletions .gitignore
@@ -3,3 +3,5 @@
.coverage*
.idea/
__pycache__/
.cache/
*egg-info/
87 changes: 79 additions & 8 deletions intake_xarray/__init__.py
@@ -1,18 +1,12 @@
# -*- coding: utf-8 -*-
# -----------------------------------------------------------------------------
# Copyright 2016 Continuum Analytics, Inc.
#
# May be copied and distributed freely only as part of an Anaconda or
# Miniconda installation.
# -----------------------------------------------------------------------------
from intake.source import base
import xarray as xr
from ._version import get_versions
__version__ = get_versions()['version']
del get_versions


class NetCDFPlugin(base.Plugin):
"""Plugin for xarray reader"""
"""Plugin for netcdf->xarray reader"""

def __init__(self):
super(NetCDFPlugin, self).__init__(
@@ -41,3 +35,80 @@ def open(self, urlpath, chunks, **kwargs):
chunks=chunks,
xarray_kwargs=source_kwargs,
metadata=base_kwargs['metadata'])


class ZarrPlugin(base.Plugin):
"""zarr>xarray reader"""

def __init__(self):
super(ZarrPlugin, self).__init__(
name='zarr',
version=__version__,
container='xarray',
partition_access=True
)

def open(self, urlpath, storage_options=None, **kwargs):
"""
Parameters
----------
urlpath: str
Path to source. This can be a local directory or a remote data
service (e.g., with a protocol specifier like ``s3://``).
storage_options: dict
Parameters passed to the backend file-system
kwargs:
Further parameters are passed to xr.open_zarr
"""
from intake_xarray.xzarr import ZarrSource
[Review comment, Member] Can we move this import to the beginning of the file?

[Reply, Member Author] No! We don't want to import until necessary, because `import intake` would then also trigger this import and so take much longer.

[Reply, Member] Okay. I figured that might be why.

base_kwargs, source_kwargs = self.separate_base_kwargs(kwargs)
return ZarrSource(urlpath, storage_options, base_kwargs['metadata'],
**source_kwargs)
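
For orientation, a minimal usage sketch of the new plugin follows. The bucket path and storage options are made-up examples, and the relevant backend library (here s3fs) would need to be installed for remote protocols:

```python
from intake_xarray import ZarrPlugin

plugin = ZarrPlugin()
# urlpath and storage_options below are hypothetical
source = plugin.open('s3://some-bucket/store.zarr',
                     storage_options={'anon': True})
ds = source.to_dask()  # lazy, dask-backed xarray.Dataset
print(ds)
```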


class DataSourceMixin:
"""Common behaviours for plugins in this repo"""

def _get_schema(self):
"""Make schema object, which embeds xarray object and some details"""
if self._ds is None:
self._open_dataset()

metadata = {
'dims': dict(self._ds.dims),
'data_vars': tuple(self._ds.data_vars.keys()),
'coords': tuple(self._ds.coords.keys())
}
metadata.update(self._ds.attrs)
return base.Schema(
datashape=None,
dtype=xr.Dataset,
shape=None,
npartitions=None,
extra_metadata=metadata)

def read(self):
"""Return a version of the xarray with all the data in memory"""
self._load_metadata()
return self._ds.load()

def read_chunked(self):
"""Return xarray object (which will have chunks)"""
self._load_metadata()
return self._ds

def read_partition(self, i):
"""Fetch one chunk of data at tuple index i

(not yet implemented)
"""
raise NotImplementedError

def to_dask(self):
"""Return xarray object where variables are dask arrays"""
return self.read_chunked()

def close(self):
"""Delete open file from memory"""
self._ds.close()
self._ds = None
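
The mixin means a new source class only has to provide `_open_dataset`; schema building, `read`, `read_chunked`, `to_dask` and `close` are all inherited. A hedged sketch of what a hypothetical additional source would look like (the class name and file format are illustrative only):

```python
import xarray as xr
from intake.source import base
from intake_xarray import DataSourceMixin


class ExampleSource(DataSourceMixin, base.DataSource):
    """Hypothetical source: everything except _open_dataset is inherited."""

    def __init__(self, urlpath, metadata=None, **kwargs):
        super(ExampleSource, self).__init__(container='xarray',
                                            metadata=metadata)
        self.urlpath = urlpath
        self.kwargs = kwargs
        self._ds = None

    def _open_dataset(self):
        # any xarray open_* call that produces a Dataset works here
        self._ds = xr.open_dataset(self.urlpath, **self.kwargs)
```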
42 changes: 4 additions & 38 deletions intake_xarray/netcdf.py
@@ -1,9 +1,10 @@
# -*- coding: utf-8 -*-
import xarray as xr
from intake.source import base
from . import DataSourceMixin


class NetCDFSource(base.DataSource):
class NetCDFSource(DataSourceMixin, base.DataSource):
"""Open a xarray file.

Parameters
@@ -27,41 +28,6 @@ def __init__(self, urlpath, chunks, xarray_kwargs=None, metadata=None):

def _open_dataset(self):
url = self.urlpath
if "*" in url:
return xr.open_mfdataset(url, chunks=self.chunks, **self._kwargs)
else:
return xr.open_dataset(url, chunks=self.chunks, **self._kwargs)
_open_dataset = xr.open_mfdataset if "*" in url else xr.open_dataset

def _get_schema(self):
if self._ds is None:
self._ds = self._open_dataset()

metadata = {
'dims': dict(self._ds.dims),
'data_vars': tuple(self._ds.data_vars.keys()),
'coords': tuple(self._ds.coords.keys())
}
metadata.update(self._ds.attrs)
return base.Schema(
datashape=None,
dtype=xr.Dataset,
shape=None,
npartitions=None,
extra_metadata=metadata)

def read(self):
self._load_metadata()
return self._ds.load()

def read_chunked(self):
self._load_metadata()
return self._ds

def read_partition(self, i):
raise NotImplementedError

def to_dask(self):
return self.read_chunked()

def close(self):
self._ds.close()
self._ds = _open_dataset(url, chunks=self.chunks, **self._kwargs)
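
With this refactor the source picks `xr.open_mfdataset` for glob patterns and `xr.open_dataset` otherwise, while all the reading methods now come from the mixin. A small usage sketch (the glob path is a made-up example):

```python
from intake_xarray.netcdf import NetCDFSource

# single file -> xr.open_dataset under the hood
single = NetCDFSource('tests/data/example_1.nc', chunks={'lon': 2})

# glob pattern -> xr.open_mfdataset (hypothetical path)
multi = NetCDFSource('tests/data/example_*.nc', chunks={'lon': 2})

ds = single.to_dask()
print(ds.temp.chunks)
```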
54 changes: 54 additions & 0 deletions intake_xarray/xzarr.py
@@ -0,0 +1,54 @@
import xarray as xr
from intake.source import base
from dask.bytes.core import get_fs, infer_options, update_storage_options
from . import DataSourceMixin


class ZarrSource(DataSourceMixin, base.DataSource):
"""Open a xarray dataset.

Parameters
----------
urlpath: str
Path to source. This can be a local directory or a remote data
service (e.g., with a protocol specifier like ``s3://``).
storage_options: dict
Parameters passed to the backend file-system
kwargs:
Further parameters are passed to xr.open_zarr
"""

def __init__(self, urlpath, storage_options=None, metadata=None, **kwargs):
super(ZarrSource, self).__init__(
container='xarray', metadata=metadata)
self.urlpath = urlpath
self.storage_options = storage_options
self.kwargs = kwargs
self._ds = None

def _open_dataset(self):
urlpath, protocol, options = infer_options(self.urlpath)
update_storage_options(options, self.storage_options)

self._fs, _ = get_fs(protocol, options)
if protocol != 'file':
self._mapper = get_mapper(protocol, self._fs, urlpath)
self._ds = xr.open_zarr(self._mapper, **self.kwargs)
else:
self._ds = xr.open_zarr(self.urlpath, **self.kwargs)

def close(self):
super(ZarrSource, self).close()
self._fs = None
self._mapper = None


def get_mapper(protocol, fs, path):
if protocol == 's3':
[Review comment, Member] How many more protocols do you think there will be?

[Reply, Member Author] hdfs3 has a mapper and I am not aware of any others. You may have noticed that https://github.com/martindurant/filesystem_spec contains a mapper, so with any luck, it should "just work" for any file-system meeting the spec - but that's a long-term goal.

from s3fs.mapping import S3Map
return S3Map(path, fs)
elif protocol == 'gcs':
from gcsfs.mapping import GCSMap
return GCSMap(path, fs)
else:
raise NotImplementedError
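
Following the review exchange above, an hdfs branch could plug in the same way. The sketch below is illustrative only, not part of the PR, and assumes hdfs3 exposes its mapper as `hdfs3.mapping.HDFSMap`:

```python
def get_mapper(protocol, fs, path):
    if protocol == 's3':
        from s3fs.mapping import S3Map
        return S3Map(path, fs)
    elif protocol == 'gcs':
        from gcsfs.mapping import GCSMap
        return GCSMap(path, fs)
    elif protocol == 'hdfs':
        # assumption: hdfs3 provides a MutableMapping over HDFS paths
        from hdfs3.mapping import HDFSMap
        return HDFSMap(fs, path)
    else:
        raise NotImplementedError
```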
36 changes: 19 additions & 17 deletions tests/test_intake_xarray.py
@@ -4,11 +4,12 @@
import pytest
import xarray as xr

from .util import TEST_URLPATH, source, dataset # noqa
from intake_xarray.netcdf import NetCDFSource
from .util import TEST_URLPATH, cdf_source, zarr_source, dataset # noqa


def test_discover(source, dataset):
@pytest.mark.parametrize('source', ['cdf', 'zarr'])
def test_discover(source, cdf_source, zarr_source, dataset):
source = {'cdf': cdf_source, 'zarr': zarr_source}[source]
r = source.discover()

assert r['datashape'] is None
@@ -17,32 +18,33 @@ def test_discover(source, dataset):

assert source.datashape is None
assert source.metadata['dims'] == dict(dataset.dims)
assert source.metadata['data_vars'] == tuple(dataset.data_vars.keys())
assert source.metadata['coords'] == tuple(dataset.coords.keys())
assert set(source.metadata['data_vars']) == set(dataset.data_vars.keys())
assert set(source.metadata['coords']) == set(dataset.coords.keys())


def test_read(source, dataset):
ds = source.read()
@pytest.mark.parametrize('source', ['cdf', 'zarr'])
def test_read(source, cdf_source, zarr_source, dataset):
source = {'cdf': cdf_source, 'zarr': zarr_source}[source]

ds = source.read_chunked()
assert ds.temp.chunks

ds = source.read()
assert ds.dims == dataset.dims
assert np.all(ds.temp == dataset.temp)
assert np.all(ds.rh == dataset.rh)


def test_read_chunked():
source = NetCDFSource(TEST_URLPATH, chunks={'lon': 2})
ds = source.read_chunked()
dataset = xr.open_dataset(TEST_URLPATH, chunks={'lon': 2})

assert ds.temp.chunks == dataset.temp.chunks


def test_read_partition(source):
@pytest.mark.parametrize('source', ['cdf', 'zarr'])
def test_read_partition(source, cdf_source, zarr_source):
source = {'cdf': cdf_source, 'zarr': zarr_source}[source]
with pytest.raises(NotImplementedError):
source.read_partition(None)


def test_to_dask(source, dataset):
@pytest.mark.parametrize('source', ['cdf', 'zarr'])
def test_to_dask(source, cdf_source, zarr_source, dataset):
source = {'cdf': cdf_source, 'zarr': zarr_source}[source]
ds = source.to_dask()

assert ds.dims == dataset.dims
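
The dict-lookup parametrization above works; pytest can also resolve fixtures by name, so an equivalent (untested here) variant of one of these tests could use `request.getfixturevalue`:

```python
import pytest


@pytest.mark.parametrize('source_name', ['cdf_source', 'zarr_source'])
def test_discover_by_name(source_name, request):
    source = request.getfixturevalue(source_name)
    assert source.discover()['datashape'] is None
```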
17 changes: 16 additions & 1 deletion tests/util.py
@@ -2,20 +2,35 @@

import os
import pytest
import shutil
import tempfile
import xarray as xr

from intake_xarray.netcdf import NetCDFSource
from intake_xarray.xzarr import ZarrSource

TEST_DATA_DIR = 'tests/data'
TEST_DATA = 'example_1.nc'
TEST_URLPATH = os.path.join(TEST_DATA_DIR, TEST_DATA)


@pytest.fixture
def source():
def cdf_source():
return NetCDFSource(TEST_URLPATH, {})


@pytest.fixture
def dataset():
return xr.open_dataset(TEST_URLPATH)


@pytest.fixture(scope='module')
def zarr_source():
pytest.importorskip('zarr')
try:
tdir = tempfile.mkdtemp()
data = xr.open_dataset(TEST_URLPATH)
data.to_zarr(tdir)
yield ZarrSource(tdir)
finally:
shutil.rmtree(tdir)
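
Outside of pytest, the roundtrip the `zarr_source` fixture relies on can be reproduced by hand. A short sketch, assuming the zarr package is installed and writing to a temporary directory:

```python
import tempfile

import xarray as xr

from intake_xarray.xzarr import ZarrSource

# write the netCDF test data to a temporary zarr store, then read it back
tdir = tempfile.mkdtemp()
xr.open_dataset('tests/data/example_1.nc').to_zarr(tdir)

source = ZarrSource(tdir)
source.discover()
print(source.metadata['dims'])  # dimension names and sizes from the schema
source.close()
```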