Skip to content

Commit

Permalink
Add xarray Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jun 15, 2018
1 parent 0169cc2 commit 6474ab5
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -40,7 +40,7 @@ conda install -c conda-forge traittypes

## Usage

`traittypes` extends the `traitlets` library with an implementation of trait types for numpy arrays, pandas dataframes and pandas series.
`traittypes` extends the `traitlets` library with an implementation of trait types for numpy arrays, pandas dataframes, pandas series, and xarray datasets.
- `traittypes` works around some limitations with numpy array comparison to only trigger change events when necessary.
- `traittypes` also extends the traitlets API for adding custom validators to constained proposed values for the attribute.

Expand Down
4 changes: 4 additions & 0 deletions docs/source/api_documentation.rst
Expand Up @@ -20,3 +20,7 @@ The ``DataFrame`` trait type holds a pandas DataFrame.
The ``Series`` trait type holds a pandas Series.

.. autoclass:: traittypes.traittypes.Series

The ``Dataset`` trait type holds an xarray Dataset.

.. autoclass:: traittypes.traittypes.Dataset
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -85,6 +85,7 @@
'test': [
'numpy',
'pandas',
'xarray',
'pytest', # traitlets[test] require this
]
}
Expand Down
42 changes: 40 additions & 2 deletions traittypes/tests/test_traittypes.py
Expand Up @@ -7,9 +7,10 @@
from unittest import TestCase
from traitlets import HasTraits, TraitError, observe, Undefined
from traitlets.tests.test_traitlets import TraitTestBase
from traittypes import Array, DataFrame, Series
from traittypes import Array, DataFrame, Series, Dataset
import numpy as np
import pandas as pd
import xarray as xr


# Good / Bad value trait test cases
Expand Down Expand Up @@ -178,4 +179,41 @@ class Foo(HasTraits):
foo = Foo()
with self.assertRaises(TraitError):
foo.bar = None
foo.baz = None
foo.baz = None


class TestDataset(TestCase):

def test_ds_equal(self):
notifications = []
class Foo(HasTraits):
bar = Dataset({'foo': xr.DataArray([[0, 1, 2], [3, 4, 5]], coords={'x': ['a', 'b']}, dims=('x', 'y')), 'bar': ('x', [1, 2]), 'baz': 3.14})
@observe('bar')
def _(self, change):
notifications.append(change)
foo = Foo()
foo.bar = {'foo': xr.DataArray([[0, 1, 2], [3, 4, 5]], coords={'x': ['a', 'b']}, dims=('x', 'y')), 'bar': ('x', [1, 2]), 'baz': 3.14}
self.assertEqual(notifications, [])
foo.bar = {'foo': xr.DataArray([[0, 1, 2], [3, 4, 5]], coords={'x': ['a', 'b']}, dims=('x', 'y')), 'bar': ('x', [1, 2]), 'baz': 3.15}
self.assertEqual(len(notifications), 1)

def test_initial_values(self):
class Foo(HasTraits):
a = Dataset()
b = Dataset(None, allow_none=True)
c = Dataset([])
d = Dataset(Undefined)
foo = Foo()
self.assertTrue(foo.a.equals(xr.Dataset()))
self.assertTrue(foo.b is None)
self.assertTrue(foo.c.equals(xr.Dataset([])))
self.assertTrue(foo.d is Undefined)

def test_allow_none(self):
class Foo(HasTraits):
bar = Dataset()
baz = Dataset(allow_none=True)
foo = Foo()
with self.assertRaises(TraitError):
foo.bar = None
foo.baz = None
70 changes: 69 additions & 1 deletion traittypes/traittypes.py
Expand Up @@ -19,6 +19,10 @@ def __getattribute__(self, name):
import pandas as pd
except ImportError:
pd = _DelayedImportError('pandas')
try:
import xarray as xr
except ImportError:
xr = _DelayedImportError('xarray')


Empty = Sentinel('Empty', 'traittypes',
Expand All @@ -30,7 +34,7 @@ def __getattribute__(self, name):

class SciType(TraitType):

"""A base trait type for numpy arrays, pandas dataframes and series."""
"""A base trait type for numpy arrays, pandas dataframes, pandas series and xarray datasets."""

def __init__(self, **kwargs):
super(SciType, self).__init__(**kwargs)
Expand Down Expand Up @@ -206,3 +210,67 @@ def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
kwargs['klass'] = pd.Series
super(Series, self).__init__(
default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs)


class XarrayType(SciType):

"""An xarray dataset trait type."""

info_text = 'an xarray dataset'

klass = None

def validate(self, obj, value):
if value is None and not self.allow_none:
self.error(obj, value)
if value is None or value is Undefined:
return super(XarrayType, self).validate(obj, value)
try:
value = self.klass(value)
except (ValueError, TypeError) as e:
raise TraitError(e)
return super(XarrayType, self).validate(obj, value)

def set(self, obj, value):
new_value = self._validate(obj, value)
old_value = obj._trait_values.get(self.name, self.default_value)
obj._trait_values[self.name] = new_value
if ((old_value is None and new_value is not None) or
(old_value is Undefined and new_value is not Undefined) or
not old_value.equals(new_value)):
obj._notify_trait(self.name, old_value, new_value)

def __init__(self, default_value=Empty, allow_none=False, dtype=None, klass=None, **kwargs):
if klass is None:
klass = self.klass
if (klass is not None) and inspect.isclass(klass):
self.klass = klass
else:
raise TraitError('The klass attribute must be a class'
' not: %r' % klass)
self.dtype = dtype
if default_value is Empty:
default_value = klass()
elif default_value is not None and default_value is not Undefined:
default_value = klass(default_value)
super(XarrayType, self).__init__(default_value=default_value, allow_none=allow_none, **kwargs)

def make_dynamic_default(self):
if self.default_value is None or self.default_value is Undefined:
return self.default_value
else:
return self.default_value.copy()


class Dataset(XarrayType):

"""An xarray dataset trait type."""

info_text = 'an xarray dataset'

def __init__(self, default_value=Empty, allow_none=False, dtype=None, **kwargs):
if 'klass' not in kwargs and self.klass is None:
import xarray as xr
kwargs['klass'] = xr.Dataset
super(Dataset, self).__init__(
default_value=default_value, allow_none=allow_none, dtype=dtype, **kwargs)

0 comments on commit 6474ab5

Please sign in to comment.