Merge pull request #713 from ioam/xarray_interface
Adding a grid-based xarray data interface
jlstevens committed Jul 14, 2016
2 parents 3e0d726 + d7aabb8 commit d4250e6
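
In practice this lets a gridded xarray.Dataset back a HoloViews Dataset directly. A minimal sketch (the coordinate and variable names here are illustrative, not taken from the diff below):

    import numpy as np
    import xarray as xr
    import holoviews as hv

    # A small 2D gridded xarray Dataset: coordinates 'x'/'y', value variable 'z'.
    xrds = xr.Dataset({'z': (('y', 'x'), np.random.rand(4, 3))},
                      coords={'x': np.arange(3), 'y': np.arange(4)})

    # With xarray installed the interface is picked up automatically; it can
    # also be requested explicitly via the datatype list.
    ds = hv.Dataset(xrds, kdims=['x', 'y'], vdims=['z'], datatype=['xarray'])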
Showing 6 changed files with 258 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -24,7 +24,7 @@ install:
- conda update -q conda
# Useful for debugging any issues with conda
- conda info -a
- conda create -q -c scitools -n test-environment python=$TRAVIS_PYTHON_VERSION scipy numpy freetype nose matplotlib bokeh pandas jupyter ipython param iris
- conda create -q -c scitools -n test-environment python=$TRAVIS_PYTHON_VERSION scipy numpy freetype nose matplotlib bokeh pandas jupyter ipython=4.2.0 param iris xarray
- source activate test-environment
- if [[ "$TRAVIS_PYTHON_VERSION" == "3.4" ]]; then
conda install python=3.4.3;
10 changes: 9 additions & 1 deletion holoviews/core/data/__init__.py
@@ -9,6 +9,7 @@
import numpy as np
import param

from ..dimension import replace_dimensions
from .interface import Interface
from .array import ArrayInterface
from .dictionary import DictInterface
@@ -38,7 +39,14 @@
param.main.warning('Iris interface failed to import with '
'following error: %s' % e)

from ..dimension import Dimension, replace_dimensions
try:
import xarray # noqa (Availability import)
from .xarray import XArrayInterface # noqa (Conditional API import)
datatypes.append('xarray')
except ImportError:
pass

from ..dimension import Dimension
from ..element import Element
from ..spaces import HoloMap
from .. import util
2 changes: 1 addition & 1 deletion holoviews/core/data/iris.py
@@ -131,7 +131,7 @@ def values(cls, dataset, dim, expanded=True, flat=True):
dim_inds = [coord_names.index(d.name) for d in dataset.kdims]
dim_inds += [i for i in range(len(dataset.data.dim_coords))
if i not in dim_inds]
data = data.transpose(dim_inds)
data = data.transpose(dim_inds[::-1])
elif expanded:
idx = dataset.get_dimension_index(dim)
data = util.cartesian_product([dataset.data.coords(d.name)[0].points
223 changes: 223 additions & 0 deletions holoviews/core/data/xarray.py
@@ -0,0 +1,223 @@
from __future__ import absolute_import
import sys
import types

import numpy as np
import xarray as xr

from .. import util
from ..dimension import Dimension
from ..ndmapping import NdMapping, item_check, sorted_context
from ..element import Element
from .grid import GridInterface
from .interface import Interface


class XArrayInterface(GridInterface):

types = (xr.Dataset if xr else None,)

datatype = 'xarray'

@classmethod
def dimension_type(cls, dataset, dim):
name = dataset.get_dimension(dim).name
idx = list(dataset.data.keys()).index(name)
return dataset.data[name].dtype.type


@classmethod
def dtype(cls, dataset, dim):
name = dataset.get_dimension(dim).name
idx = list(dataset.data.keys()).index(name)
return dataset.data[name].dtype


@classmethod
def init(cls, eltype, data, kdims, vdims):
element_params = eltype.params()
kdim_param = element_params['kdims']
vdim_param = element_params['vdims']

if kdims:
kdim_names = [kd.name if isinstance(kd, Dimension) else kd for kd in kdims]
else:
kdim_names = [kd.name for kd in eltype.kdims]

if not isinstance(data, xr.Dataset):
ndims = len(kdim_names)
kdims = [kd if isinstance(kd, Dimension) else Dimension(kd)
for kd in kdims]
vdim = vdims[0].name if isinstance(vdims[0], Dimension) else vdims[0]
if isinstance(data, tuple):
value_array = np.array(data[-1])
data = {d: vals for d, vals in zip(kdim_names + [vdim], data)}
elif isinstance(data, dict):
value_array = np.array(data[vdim])
if value_array.ndim > 1:
value_array = value_array.T
dims, coords = zip(*[(kd.name, data[kd.name])
for kd in kdims])
try:
arr = xr.DataArray(value_array, coords=coords, dims=dims)
data = xr.Dataset({vdim: arr})
except:
pass
if not isinstance(data, xr.Dataset):
                raise TypeError('Data must be an xarray Dataset type.')

if isinstance(data, xr.Dataset):
if vdims is None:
vdims = list(data.data_vars.keys())
if kdims is None:
kdims = list(data.dims.keys())
return data, {'kdims': kdims, 'vdims': vdims}, {}


@classmethod
def range(cls, dataset, dimension):
dim = dataset.get_dimension(dimension).name
if dim in dataset.data:
data = dataset.data[dim]
return data.min().item(), data.max().item()
else:
return np.NaN, np.NaN


@classmethod
def groupby(cls, dataset, dimensions, container_type, group_type, **kwargs):
index_dims = [dataset.get_dimension(d) for d in dimensions]
element_dims = [kdim for kdim in dataset.kdims
if kdim not in index_dims]

group_kwargs = {}
if group_type != 'raw' and issubclass(group_type, Element):
group_kwargs = dict(util.get_param_values(dataset),
kdims=element_dims)
group_kwargs.update(kwargs)

# XArray 0.7.2 does not support multi-dimensional groupby
# Replace custom implementation when
# https://github.com/pydata/xarray/pull/818 is merged.
if len(dimensions) == 1:
data = [(k, group_type(v, **group_kwargs)) for k, v in
dataset.data.groupby(dimensions[0])]
else:
unique_iters = [cls.values(dataset, d, False) for d in dimensions]
indexes = zip(*[vals.flat for vals in util.cartesian_product(unique_iters)])
data = [(k, group_type(dataset.data.sel(**dict(zip(dimensions, k))),
**group_kwargs))
for k in indexes]

if issubclass(container_type, NdMapping):
with item_check(False), sorted_context(False):
return container_type(data, kdims=index_dims)
else:
return container_type(data)


@classmethod
def values(cls, dataset, dim, expanded=True, flat=True):
data = dataset.data[dim].data
if dim in dataset.vdims:
if data.ndim == 1:
return np.array(data)
else:
data = data.T
return data.flatten() if flat else data
elif not expanded:
return data
else:
arrays = [dataset.data[d.name].data for d in dataset.kdims]
product = util.cartesian_product(arrays)[dataset.get_dimension_index(dim)]
return product.flatten() if flat else product


@classmethod
def aggregate(cls, dataset, dimensions, function, **kwargs):
if len(dimensions) > 1:
raise NotImplementedError('Multi-dimensional aggregation not '
'supported as of xarray <=0.7.2.')
elif not dimensions:
return dataset.data.apply(function)
else:
return dataset.data.groupby(dimensions[0]).apply(function)


@classmethod
def unpack_scalar(cls, dataset, data):
"""
Given a dataset object and data in the appropriate format for
the interface, return a simple scalar.
"""
if (len(data.data_vars) == 1 and
len(data[dataset.vdims[0].name].shape) == 0):
return data[dataset.vdims[0].name].item()
return data


@classmethod
def concat(cls, dataset_objs):
#cast_objs = cls.cast(dataset_objs)
# Reimplement concat to automatically add dimensions
# once multi-dimensional concat has been added to xarray.
return xr.concat([col.data for col in dataset_objs], dim='concat_dim')

@classmethod
def redim(cls, dataset, dimensions):
renames = {k: v.name for k, v in dimensions.items()}
return dataset.data.rename(renames)

@classmethod
def reindex(cls, dataset, kdims=None, vdims=None):
return dataset.data

@classmethod
def sort(cls, dataset, by=[]):
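        # Sorting is not supported for the gridded xarray format; the data is returned unchanged.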
return dataset

@classmethod
def select(cls, dataset, selection_mask=None, **selection):
validated = {}
for k, v in selection.items():
if isinstance(v, slice):
v = (v.start, v.stop)
if isinstance(v, set):
validated[k] = list(v)
elif isinstance(v, tuple):
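                # Nudge the upper bound down so xarray's inclusive label slicing approximates an exclusive stop.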
validated[k] = slice(v[0], v[1]-sys.float_info.epsilon*10)
elif isinstance(v, types.FunctionType):
validated[k] = v(dataset[k])
else:
validated[k] = v
data = dataset.data.sel(**validated)
indexed = cls.indexed(dataset, selection)
if (indexed and len(data.data_vars) == 1 and
len(data[dataset.vdims[0].name].shape) == 0):
return data[dataset.vdims[0].name].item()
return data

@classmethod
def length(cls, dataset):
return np.product(dataset[dataset.vdims[0].name].shape)

@classmethod
def dframe(cls, dataset, dimensions):
if dimensions:
return dataset.reindex(columns=dimensions)
else:
return dataset.data.to_dataframe().reset_index(dimensions)

@classmethod
    def sample(cls, dataset, samples=[]):
        raise NotImplementedError

    @classmethod
    def add_dimension(cls, dataset, dimension, dim_pos, values, vdim):
        if not vdim:
            raise Exception("Cannot add key dimension to a dense representation.")
        dim = dimension.name if isinstance(dimension, Dimension) else dimension
        # Assign the new value dimension on the underlying xarray Dataset.
        return dataset.data.assign(**{dim: values})


Interface.register(XArrayInterface)
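
Once registered, the interface is driven through the ordinary Dataset API. A hedged sketch of the operations it backs (same illustrative names as in the sketch above, repeated so the snippet stands alone):

    import numpy as np
    import xarray as xr
    import holoviews as hv

    xrds = xr.Dataset({'z': (('y', 'x'), np.random.rand(4, 3))},
                      coords={'x': np.arange(3), 'y': np.arange(4)})
    ds = hv.Dataset(xrds, kdims=['x', 'y'], vdims=['z'], datatype=['xarray'])

    grouped = ds.groupby('y')            # groupby(): one Dataset per 'y' coordinate value
    subset = ds.select(x=(0, 2))         # select(): the tuple is translated into an xarray .sel slice
    mean_z = ds.aggregate('x', np.mean)  # aggregate(): backed by xarray's groupby/apply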
23 changes: 23 additions & 0 deletions tests/testdataset.py
@@ -524,3 +524,26 @@ def test_dataset_2D_aggregate_partial_hm(self):
def test_dataset_sample_hm(self):
pass



class XArrayDatasetTest(GridDatasetTest):
"""
    Tests for the XArray interface
"""

def setUp(self):
import xarray
self.restore_datatype = Dataset.datatype
Dataset.datatype = ['xarray']
self.data_instance_type = xarray.Dataset
self.init_data()

# Disabled tests for NotImplemented methods
def test_dataset_add_dimensions_values_hm(self):
pass

def test_dataset_sort_vdim_hm(self):
pass

def test_dataset_sample_hm(self):
pass
2 changes: 1 addition & 1 deletion tests/testirisinterface.py
@@ -54,7 +54,7 @@ def test_dimension_values_vdim(self):
np.array([[ 0, 4, 8],
[ 1, 5, 9],
[ 2, 6, 10],
[ 3, 7, 11]], dtype=np.int32))
[ 3, 7, 11]], dtype=np.int32).T)

def test_range_kdim(self):
cube = Dataset(self.cube, kdims=['longitude', 'latitude'])
