Skip to content

Commit

Permalink
Fixed issues with irregular xarray shape (#4188)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Jan 16, 2020
1 parent b78f84a commit e61bf82
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 23 deletions.
29 changes: 6 additions & 23 deletions holoviews/core/data/xarray.py
Expand Up @@ -52,29 +52,13 @@ def packed(cls, dataset):
@classmethod
def shape(cls, dataset, gridded=False):
if cls.packed(dataset):
shape = dataset.data.shape[:-1]
if gridded:
return shape
else:
return (np.product(shape, dtype=np.intp), len(dataset.dimensions()))
array = dataset.data[..., 0]
else:
array = dataset.data[dataset.vdims[0].name]
if not any(cls.irregular(dataset, kd) for kd in dataset.kdims):
names = [kd.name for kd in dataset.kdims
if kd.name in array.dims][::-1]
if not all(d in names for d in array.dims):
array = np.squeeze(array)
if len(names) > 1:
try:
array = array.transpose(*names, transpose_coords=False)
except:
array = array.transpose(*names) # Handle old xarray
shape = array.shape
if gridded:
return shape
else:
return (np.product(shape, dtype=np.intp), len(dataset.dimensions()))

if not gridded:
return (np.product(array.shape, dtype=np.intp), len(dataset.dimensions()))
shape_map = dict(zip(array.dims, array.shape))
return tuple(shape_map.get(kd.name, np.nan) for kd in dataset.kdims[::-1])

@classmethod
def init(cls, eltype, data, kdims, vdims):
Expand Down Expand Up @@ -342,8 +326,7 @@ def coords(cls, dataset, dimension, ordered=False, expanded=False, edges=False):

if dim in dataset.kdims:
idx = dataset.get_dimension_index(dim)
isedges = (dim in dataset.kdims and len(shape) == dataset.ndims
and len(data) == (shape[dataset.ndims-idx-1]+1))
isedges = (len(shape) == dataset.ndims and len(data) == (shape[dataset.ndims-idx-1]+1))
else:
isedges = False
if edges and not isedges:
Expand Down
24 changes: 24 additions & 0 deletions holoviews/tests/core/data/testxarrayinterface.py
Expand Up @@ -5,6 +5,7 @@
import numpy as np

try:
import pandas as pd
import xarray as xr
except:
raise SkipTest("Could not import xarray, skipping XArrayInterface tests.")
Expand Down Expand Up @@ -45,6 +46,29 @@ def get_irregular_dataarray(self, invert_y=True):
return da.assign_coords(**{'xc': xr.DataArray(xs, dims=('y','x')),
'yc': xr.DataArray(ys, dims=('y','x')),})

def get_multi_dim_irregular_dataset(self):
temp = 15 + 8 * np.random.randn(2, 2, 4, 3)
precip = 10 * np.random.rand(2, 2, 4, 3)
lon = [[-99.83, -99.32], [-99.79, -99.23]]
lat = [[42.25, 42.21], [42.63, 42.59]]
return xr.Dataset({'temperature': (['x', 'y', 'z', 'time'], temp),
'precipitation': (['x', 'y', 'z', 'time'], precip)},
coords={'lon': (['x', 'y'], lon),
'lat': (['x', 'y'], lat),
'z': np.arange(4),
'time': pd.date_range('2014-09-06', periods=3),
'reference_time': pd.Timestamp('2014-09-05')})

def test_xarray_dataset_irregular_shape(self):
ds = Dataset(self.get_multi_dim_irregular_dataset())
shape = ds.interface.shape(ds, gridded=True)
self.assertEqual(shape, (np.nan, np.nan, 3, 4))

def test_xarray_irregular_dataset_values(self):
ds = Dataset(self.get_multi_dim_irregular_dataset())
values = ds.dimension_values('z', expanded=False)
self.assertEqual(values, np.array([0, 1, 2, 3]))

def test_xarray_dataset_with_scalar_dim_canonicalize(self):
xs = [0, 1]
ys = [0.1, 0.2, 0.3]
Expand Down

0 comments on commit e61bf82

Please sign in to comment.