/
xarray.py
281 lines (237 loc) · 10.2 KB
/
xarray.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
from __future__ import absolute_import
import sys
import types
import numpy as np
import xarray as xr
try:
import dask
except ImportError:
dask = None
from .. import util
from ..dimension import Dimension
from ..ndmapping import NdMapping, item_check, sorted_context
from ..element import Element
from .grid import GridInterface
from .interface import Interface
class XArrayInterface(GridInterface):
    """
    Interface between holoviews gridded Datasets and xarray, storing
    the data as an xarray Dataset containing one DataArray per value
    dimension, indexed by the key dimension coordinates.
    """

    types = (xr.Dataset, xr.DataArray)

    datatype = 'xarray'

    @classmethod
    def dimension_type(cls, dataset, dim):
        "Returns the numpy scalar type of the named dimension's data."
        name = dataset.get_dimension(dim, strict=True).name
        return dataset.data[name].dtype.type

    @classmethod
    def dtype(cls, dataset, dim):
        "Returns the dtype of the named dimension's data."
        name = dataset.get_dimension(dim, strict=True).name
        return dataset.data[name].dtype

    @classmethod
    def init(cls, eltype, data, kdims, vdims):
        """
        Coerces the supplied data (DataArray, Dataset, tuple, dict or
        empty list) into an xarray Dataset, resolving key and value
        dimensions, and returns (data, {'kdims': ..., 'vdims': ...}, {}).
        """
        element_params = eltype.params()
        kdim_param = element_params['kdims']
        vdim_param = element_params['vdims']
        if isinstance(data, xr.DataArray):
            # Determine the value dimension from the array name, the
            # supplied vdims, or an unambiguous element default.
            if data.name:
                vdim = Dimension(data.name)
            elif vdims:
                vdim = vdims[0]
            elif len(vdim_param.default) == 1:
                vdim = vdim_param.default[0]
            else:
                # Fix: previously this case fell through leaving vdim
                # unbound, raising an opaque NameError below.
                raise ValueError("If xarray DataArray does not define a "
                                 "name, an explicit vdim must be supplied.")
            vdims = [vdim]
            # xarray orders dims slowest-varying first; reverse for holoviews
            kdims = [Dimension(d) for d in data.dims[::-1]]
            data = data.to_dataset(name=vdim.name)
        elif not isinstance(data, xr.Dataset):
            if kdims is None:
                kdims = kdim_param.default
            if vdims is None:
                vdims = vdim_param.default
            kdims = [kd if isinstance(kd, Dimension) else Dimension(kd)
                     for kd in kdims]
            vdims = [vd if isinstance(vd, Dimension) else Dimension(vd)
                     for vd in vdims]
            if isinstance(data, tuple):
                # Columnar tuple: one array per dimension, in kdim+vdim order
                data = {d.name: vals for d, vals in zip(kdims + vdims, data)}
            elif isinstance(data, list) and data == []:
                # Empty dataset: empty coord arrays and 0-sized value arrays
                ndims = len(kdims)
                dimensions = [d.name if isinstance(d, Dimension) else
                              d for d in kdims + vdims]
                data = {d: np.array([]) for d in dimensions[:ndims]}
                data.update({d: np.empty((0,) * ndims) for d in dimensions[ndims:]})
            if not isinstance(data, dict):
                raise TypeError('XArrayInterface could not interpret data type')
            coords = [(kd.name, data[kd.name]) for kd in kdims][::-1]
            arrays = {}
            for vdim in vdims:
                arr = data[vdim.name]
                if not isinstance(arr, xr.DataArray):
                    arr = xr.DataArray(arr, coords=coords)
                arrays[vdim.name] = arr
            data = xr.Dataset(arrays)
        else:
            # Already an xarray Dataset: infer any undeclared dimensions
            # from the data variables and (numpy-backed) indexes.
            if vdims is None:
                vdims = list(data.data_vars.keys())
            if kdims is None:
                kdims = [name for name in data.indexes.keys()
                         if isinstance(data[name].data, np.ndarray)]
        if not isinstance(data, xr.Dataset):
            raise TypeError('Data must be an xarray Dataset type.')
        return data, {'kdims': kdims, 'vdims': vdims}, {}

    @classmethod
    def range(cls, dataset, dimension):
        """
        Returns the (min, max) of the named dimension as python scalars,
        or (nan, nan) if the dimension is not present in the data.
        """
        dim = dataset.get_dimension(dimension, strict=True).name
        if dim in dataset.data:
            data = dataset.data[dim]
            dmin, dmax = data.min().data, data.max().data
            # Unwrap 0d arrays (including computed dask results) to scalars
            dmin = dmin if np.isscalar(dmin) else dmin.item()
            dmax = dmax if np.isscalar(dmax) else dmax.item()
            return dmin, dmax
        else:
            # np.nan rather than the np.NaN alias (removed in NumPy 2.0)
            return np.nan, np.nan

    @classmethod
    def groupby(cls, dataset, dimensions, container_type, group_type, **kwargs):
        """
        Groups the dataset along the supplied dimensions, returning a
        container_type of group_type elements indexed by the group keys.
        """
        index_dims = [dataset.get_dimension(d, strict=True) for d in dimensions]
        element_dims = [kdim for kdim in dataset.kdims
                        if kdim not in index_dims]

        group_kwargs = {}
        if group_type != 'raw' and issubclass(group_type, Element):
            group_kwargs = dict(util.get_param_values(dataset),
                                kdims=element_dims)
        group_kwargs.update(kwargs)
        # If any element dim was dropped from the group kdims the grouped
        # selections must be flattened to a dataframe representation.
        drop_dim = any(d not in group_kwargs['kdims'] for d in element_dims)

        # XArray 0.7.2 does not support multi-dimensional groupby
        # Replace custom implementation when
        # https://github.com/pydata/xarray/pull/818 is merged.
        group_by = [d.name for d in index_dims]
        data = []
        if len(dimensions) == 1:
            for k, v in dataset.data.groupby(index_dims[0].name):
                if drop_dim:
                    v = v.to_dataframe().reset_index()
                data.append((k, group_type(v, **group_kwargs)))
        else:
            # Emulate nd-groupby by selecting each combination of the
            # unique coordinate values along the grouped dimensions.
            unique_iters = [cls.values(dataset, d, False) for d in group_by]
            indexes = zip(*util.cartesian_product(unique_iters))
            for k in indexes:
                sel = dataset.data.sel(**dict(zip(group_by, k)))
                if drop_dim:
                    sel = sel.to_dataframe().reset_index()
                data.append((k, group_type(sel, **group_kwargs)))

        if issubclass(container_type, NdMapping):
            with item_check(False), sorted_context(False):
                return container_type(data, kdims=index_dims)
        else:
            return container_type(data)

    @classmethod
    def coords(cls, dataset, dim, ordered=False, expanded=False):
        """
        Returns the coordinate values along the supplied dimension,
        optionally expanded to the full grid or re-ordered ascending.
        """
        dim = dataset.get_dimension(dim, strict=True).name
        if expanded:
            return util.expand_grid_coords(dataset, dim)
        data = np.atleast_1d(dataset.data[dim].data)
        # Reverse strictly descending coordinates when ordering is requested
        if ordered and data.shape and np.all(data[1:] < data[:-1]):
            data = data[::-1]
        return data

    @classmethod
    def values(cls, dataset, dim, expanded=True, flat=True):
        """
        Returns the values of the supplied dimension, canonicalized for
        vdims and optionally expanded/flattened for kdims.
        """
        dim = dataset.get_dimension(dim, strict=True)
        data = dataset.data[dim.name].data
        if dim in dataset.vdims:
            coord_dims = dataset.data[dim.name].dims
            # Materialize lazy dask arrays before canonicalizing
            if dask and isinstance(data, dask.array.Array):
                data = data.compute()
            data = cls.canonicalize(dataset, data, coord_dims=coord_dims)
            return data.T.flatten() if flat else data
        elif expanded:
            data = cls.coords(dataset, dim.name, expanded=True)
            return data.flatten() if flat else data
        else:
            return cls.coords(dataset, dim.name, ordered=True)

    @classmethod
    def aggregate(cls, dataset, dimensions, function, **kwargs):
        """
        Reduces the dataset with the supplied function along all key
        dimensions not listed in dimensions.
        """
        reduce_dims = [d.name for d in dataset.kdims if d not in dimensions]
        return dataset.data.reduce(function, dim=reduce_dims)

    @classmethod
    def unpack_scalar(cls, dataset, data):
        """
        Given a dataset object and data in the appropriate format for
        the interface, return a simple scalar.
        """
        if (len(data.data_vars) == 1 and
            len(data[dataset.vdims[0].name].shape) == 0):
            return data[dataset.vdims[0].name].item()
        return data

    @classmethod
    def concat(cls, dataset_objs):
        "Concatenates the underlying xarray Datasets along a new dimension."
        #cast_objs = cls.cast(dataset_objs)
        # Reimplement concat to automatically add dimensions
        # once multi-dimensional concat has been added to xarray.
        return xr.concat([col.data for col in dataset_objs], dim='concat_dim')

    @classmethod
    def redim(cls, dataset, dimensions):
        "Renames the data variables/coords to the new Dimension names."
        renames = {k: v.name for k, v in dimensions.items()}
        return dataset.data.rename(renames)

    @classmethod
    def reindex(cls, dataset, kdims=None, vdims=None):
        """
        Drops constant key dimensions by selecting their single value;
        falls back to a columnar representation if a non-constant kdim
        is dropped.
        """
        dropped_kdims = [kd for kd in dataset.kdims if kd not in kdims]
        constant = {}
        for kd in dropped_kdims:
            vals = cls.values(dataset, kd.name, expanded=False)
            if len(vals) == 1:
                constant[kd.name] = vals[0]
        if len(constant) == len(dropped_kdims):
            return dataset.data.sel(**constant)
        elif dropped_kdims:
            return tuple(dataset.columns(kdims+vdims).values())
        return dataset.data

    @classmethod
    def sort(cls, dataset, by=[]):
        # Gridded data is inherently sorted along its coordinates: no-op
        return dataset

    @classmethod
    def select(cls, dataset, selection_mask=None, **selection):
        """
        Applies the supplied selection (scalars, sets, tuples/slices or
        predicate functions) to the xarray Dataset, unpacking scalar
        results for fully indexed selections.
        """
        validated = {}
        for k, v in selection.items():
            dim = dataset.get_dimension(k, strict=True).name
            if isinstance(v, slice):
                v = (v.start, v.stop)
            if isinstance(v, set):
                validated[dim] = list(v)
            elif isinstance(v, tuple):
                # Make the upper bound exclusive by nudging it down.
                # NOTE(review): sys.float_info.epsilon is an absolute
                # offset (~2.2e-15); for bounds much larger than 1 the
                # subtraction is a float no-op — confirm intended.
                upper = None if v[1] is None else v[1]-sys.float_info.epsilon*10
                validated[dim] = slice(v[0], upper)
            elif isinstance(v, types.FunctionType):
                validated[dim] = v(dataset[k])
            else:
                validated[dim] = v
        data = dataset.data.sel(**validated)

        # Restore constant dimensions
        indexed = cls.indexed(dataset, selection)
        dropped = {d.name: np.atleast_1d(data[d.name])
                   for d in dataset.kdims
                   if not data[d.name].data.shape}
        if dropped and not indexed:
            data = data.assign_coords(**dropped)

        if (indexed and len(data.data_vars) == 1 and
            len(data[dataset.vdims[0].name].shape) == 0):
            return data[dataset.vdims[0].name].item()
        elif indexed:
            return np.array([data[vd.name].item() for vd in dataset.vdims])
        return data

    @classmethod
    def length(cls, dataset):
        "Returns the expanded grid length (product of coord lengths)."
        # np.prod rather than the np.product alias (removed in NumPy 2.0)
        return np.prod([len(dataset.data[d.name]) for d in dataset.kdims])

    @classmethod
    def dframe(cls, dataset, dimensions):
        "Returns the data as a pandas DataFrame over the given dimensions."
        if dimensions:
            return dataset.reindex(columns=dimensions).data.to_dataframe().reset_index(dimensions)
        else:
            return dataset.data.to_dataframe().reset_index(dimensions)

    @classmethod
    def sample(cls, columns, samples=[]):
        # Sampling not supported on the gridded xarray interface
        raise NotImplementedError

    @classmethod
    def add_dimension(cls, dataset, dimension, dim_pos, values, vdim):
        """
        Adds a new value dimension with the supplied values; key
        dimensions cannot be added to a dense grid.
        """
        if not vdim:
            raise Exception("Cannot add key dimension to a dense representation.")
        dim = dimension.name if isinstance(dimension, Dimension) else dimension
        arr = xr.DataArray(values, coords=dataset.data.coords, name=dim,
                           dims=dataset.data.indexes)
        return dataset.data.assign(**{dim: arr})
Interface.register(XArrayInterface)