Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement 3D Volume regridding #946

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 227 additions & 47 deletions datashader/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
from xarray import DataArray, Dataset
from collections import OrderedDict

from .utils import Dispatcher, ngjit, calc_res, calc_bbox, orient_array, \
compute_coords, dshape_from_xarray_dataset
from .utils import (
Dispatcher, ngjit, calc_res, calc_bbox, orient_array,
compute_coords, calc_res3d, calc_bbox3d, compute_coords3d,
dshape_from_xarray_dataset, orient_array3d
)
from .utils import get_indices, dshape_from_pandas, dshape_from_dask
from .utils import Expr # noqa (API import)
from .resampling import resample_2d, resample_2d_distributed
from .resampling import resample_2d, resample_2d_distributed, resample_3d
from . import reductions as rd

try:
Expand Down Expand Up @@ -149,24 +152,27 @@ class Canvas(object):

Parameters
----------
plot_width, plot_height : int, optional
Width and height of the output aggregate in pixels.
x_range, y_range : tuple, optional
plot_width, plot_height, plot_depth : int, optional
Width and height (and depth) of the output aggregate in pixels.
x_range, y_range, z_range : tuple, optional
A tuple representing the bounds inclusive space ``[min, max]`` along
the axis.
x_axis_type, y_axis_type : str, optional
x_axis_type, y_axis_type, z_axis_type : str, optional
The type of the axis. Valid options are ``'linear'`` [default], and
``'log'``.
"""
def __init__(self, plot_width=600, plot_height=600,
x_range=None, y_range=None,
x_axis_type='linear', y_axis_type='linear'):
def __init__(self, plot_width=600, plot_height=600, plot_depth=None,
x_range=None, y_range=None, z_range=None,
x_axis_type='linear', y_axis_type='linear', z_axis_type='linear'):
self.plot_width = plot_width
self.plot_height = plot_height
self.plot_depth = plot_depth
self.x_range = None if x_range is None else tuple(x_range)
self.y_range = None if y_range is None else tuple(y_range)
self.z_range = None if z_range is None else tuple(z_range)
self.x_axis = _axis_lookup[x_axis_type]
self.y_axis = _axis_lookup[y_axis_type]
self.z_axis = _axis_lookup[z_axis_type]

def points(self, source, x=None, y=None, agg=None, geometry=None):
"""Compute a reduction by pixel, mapping data to pixels as points.
Expand Down Expand Up @@ -888,6 +894,46 @@ def trimesh(self, vertices, simplices, mesh=None, agg=None, interp=True, interpo

return bypixel(source, self, Triangles(x, y, weights, weight_type=verts_have_weights, interp=interp), agg)


def _validate_regrid(self, source, agg, interpolate, upsample_methods, downsample_methods):
if interpolate not in upsample_methods:
raise ValueError('Invalid interpolate method: options include {}'.format(upsample_methods))

if not isinstance(source, (DataArray, Dataset)):
raise ValueError('Expected xarray DataArray or Dataset as '
'the data source, found %s.'
% type(source).__name__)

column = None
if isinstance(agg, rd.Reduction):
agg, column = type(agg), agg.column
if (isinstance(source, DataArray) and column is not None
and source.name != column):
agg_repr = '%s(%r)' % (agg.__name__, column)
raise ValueError('DataArray name %r does not match '
'supplied reduction %s.' %
(source.name, agg_repr))

if isinstance(source, Dataset):
data_vars = list(source.data_vars)
if column is None:
raise ValueError('When supplying a Dataset the agg reduction '
'must specify the variable to aggregate. '
'Available data_vars include: %r.' % data_vars)
elif column not in source.data_vars:
raise KeyError('Supplied reduction column %r not found '
'in Dataset, expected one of the following '
'data variables: %r.' % (column, data_vars))
source = source[column]

if agg not in downsample_methods.keys():
raise ValueError('Invalid aggregation method: options include {}'.format(list(downsample_methods.keys())))

if source.ndim not in [2, 3]:
raise ValueError('Raster aggregation expects a 2D or 3D '
'DataArray, found %s dimensions' % source.ndim)
return source

def raster(self,
source,
layer=None,
Expand Down Expand Up @@ -964,44 +1010,11 @@ def raster(self,
'min':'min', rd.min:'min',
'max':'max', rd.max:'max'}

if interpolate not in upsample_methods:
raise ValueError('Invalid interpolate method: options include {}'.format(upsample_methods))

if not isinstance(source, (DataArray, Dataset)):
raise ValueError('Expected xarray DataArray or Dataset as '
'the data source, found %s.'
% type(source).__name__)

column = None
if isinstance(agg, rd.Reduction):
agg, column = type(agg), agg.column
if (isinstance(source, DataArray) and column is not None
and source.name != column):
agg_repr = '%s(%r)' % (agg.__name__, column)
raise ValueError('DataArray name %r does not match '
'supplied reduction %s.' %
(source.name, agg_repr))

if isinstance(source, Dataset):
data_vars = list(source.data_vars)
if column is None:
raise ValueError('When supplying a Dataset the agg reduction '
'must specify the variable to aggregate. '
'Available data_vars include: %r.' % data_vars)
elif column not in source.data_vars:
raise KeyError('Supplied reduction column %r not found '
'in Dataset, expected one of the following '
'data variables: %r.' % (column, data_vars))
source = source[column]

if agg not in downsample_methods.keys():
raise ValueError('Invalid aggregation method: options include {}'.format(list(downsample_methods.keys())))
source = self._validate_regrid(
source, agg, interpolate, upsample_methods, downsample_methods
)
ds_method = downsample_methods[agg]

if source.ndim not in [2, 3]:
raise ValueError('Raster aggregation expects a 2D or 3D '
'DataArray, found %s dimensions' % source.ndim)

res = calc_res(source)
ydim, xdim = source.dims[-2:]
xvals, yvals = source[xdim].values, source[ydim].values
Expand Down Expand Up @@ -1031,7 +1044,11 @@ def raster(self,
height_ratio = min((ymax - ymin) / (self.y_range[1] - self.y_range[0]), 1)

if np.isclose(width_ratio, 0) or np.isclose(height_ratio, 0):
raise ValueError('Canvas x_range or y_range values do not match closely enough with the data source to be able to accurately rasterize. Please provide ranges that are more accurate.')
raise ValueError(
'Canvas x_range or y_range values do not match closely '
'enough with the data source to be able to accurately '
'rasterize. Please provide ranges that are more accurate.'
)

w = max(int(round(self.plot_width * width_ratio)), 1)
h = max(int(round(self.plot_height * height_ratio)), 1)
Expand Down Expand Up @@ -1131,10 +1148,173 @@ def raster(self,
dims = [layer_dim]+dims
return DataArray(data, coords=coords, dims=dims, attrs=attrs)

def volume(self, source, nan_value=None, agg='mean', interpolate='nearest'):
"""Sample a raster dataset by canvas size and bounds.

Handles 3D xarray DataArrays.

Missing values (those having the value indicated by the
"nodata" attribute of the raster) are replaced with `NaN` if
floats, and 0 if int.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
floats, and 0 if int.
floats, and 0 if int. Thus if 0 is not a suitable missing-value indicator,
for your integer data, convert it to a float type before calling this method.


Parameters
----------
source : xarray.DataArray or xr.Dataset
2D or 3D labelled array (if Dataset, the agg reduction must
define the data variable).
nan_value : int or float, optional
Optional nan_value which will be masked out when applying
the resampling.
Comment on lines +1165 to +1167
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
nan_value : int or float, optional
Optional nan_value which will be masked out when applying
the resampling.
nan_value : int or float, optional
Optional nan_value that will be masked out when applying
the resampling.

How does this interact with the "nodata" attribute of the raster? Does it override it? Should clarify that here.

agg : Reduction, optional default=mean()
Resampling mode when downsampling raster.
options include: first, last, mean, mode, var, std, min, max
Accepts an executable function, function object, or string name.
interpolate : str, optional default=linear
Resampling mode when upsampling raster.
options include: nearest, linear.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
options include: nearest, linear.
Options include: nearest, linear.


Returns
-------
data : xarray.Dataset
"""
upsample_methods = ['nearest']

downsample_methods = {'mean': 'mean', rd.mean:'mean'}

source = self._validate_regrid(
source, agg, interpolate, upsample_methods, downsample_methods
)

ds_method = downsample_methods[agg]

if self.plot_depth is None:
raise ValueError("Supply plot_depth to Canvas to aggregate in 3D.")

res = calc_res3d(source)
zdim, ydim, xdim = source.dims[-3:]
xvals, yvals, zvals = (source[dim].values for dim in (xdim, ydim, zdim))
left, bottom, back, right, top, front = calc_bbox3d(xvals, yvals, zvals, res)
array = orient_array3d(source, res)
dtype = array.dtype

if nan_value is not None:
mask = array==nan_value
array = np.ma.masked_array(array, mask=mask, fill_value=nan_value)
fill_value = nan_value
else:
fill_value = np.NaN

if self.x_range is None: self.x_range = (left,right)
if self.y_range is None: self.y_range = (bottom,top)
if self.z_range is None: self.z_range = (back,front)

# window coordinates
xmin = max(self.x_range[0], left)
ymin = max(self.y_range[0], bottom)
zmin = max(self.z_range[0], back)
xmax = min(self.x_range[1], right)
ymax = min(self.y_range[1], top)
zmax = min(self.z_range[1], front)

width_ratio = min((xmax - xmin) / (self.x_range[1] - self.x_range[0]), 1)
height_ratio = min((ymax - ymin) / (self.y_range[1] - self.y_range[0]), 1)
depth_ratio = min((zmax - zmin) / (self.z_range[1] - self.z_range[0]), 1)

if np.isclose(width_ratio, 0) or np.isclose(height_ratio, 0) or np.isclose(depth_ratio, 0):
raise ValueError('Canvas x_range, y_range or z_range values '
'do not match closely enough with the data '
'source to be able to accurately rasterize. '
'Please provide ranges that are more accurate.')

w = max(int(round(self.plot_width * width_ratio)), 1)
h = max(int(round(self.plot_height * height_ratio)), 1)
d = max(int(round(self.plot_depth * depth_ratio)), 1)
cmin, cmax = get_indices(xmin, xmax, xvals, res[0])
rmin, rmax = get_indices(ymin, ymax, yvals, res[1])
zmin, zmax = get_indices(zmin, zmax, zvals, res[2])

kwargs = dict(w=w, h=h, d=d, fill_value=fill_value)
source_window = array[zmin: zmax+1, rmin:rmax+1, cmin:cmax+1]
data = resample_3d(source_window, **kwargs)

if w != self.plot_width or h != self.plot_height:
num_height = self.plot_height - h
num_width = self.plot_width - w
num_depth = self.plot_depth - d

lpad = xmin - self.x_range[0]
rpad = self.x_range[1] - xmax
lpct = lpad / (lpad + rpad) if lpad + rpad > 0 else 0
left = max(int(np.ceil(num_width * lpct)), 0)
right = max(num_width - left, 0)
lshape, rshape = (self.plot_depth, self.plot_height, left), (self.plot_depth, self.plot_height, right)
left_pad = np.full(lshape, fill_value, source_window.dtype)
right_pad = np.full(rshape, fill_value, source_window.dtype)

tpad = ymin - self.y_range[0]
bpad = self.y_range[1] - ymax
tpct = tpad / (tpad + bpad) if tpad + bpad > 0 else 0
top = max(int(np.ceil(num_height * tpct)), 0)
bottom = max(num_height - top, 0)
tshape, bshape = (self.plot_depth, top, w), (self.plot_depth, bottom, w)
top_pad = np.full(tshape, fill_value, source_window.dtype)
bottom_pad = np.full(bshape, fill_value, source_window.dtype)

bkpad = zmin - self.z_range[0]
frpad = self.z_range[1] - zmax
frpct = bkpad / (frpad + bkpad) if (frpad + bkpad) > 0 else 0
front = max(int(np.ceil(num_depth * frpct)), 0)
back = max(num_depth - front, 0)
bkshape, frshape = (back, h, w), (front, h, w)
back_pad = np.full(bkshape, fill_value, source_window.dtype)
front_pad = np.full(frshape, fill_value, source_window.dtype)

arrays = (back_pad, data) if back_pad.shape[1] > 0 else (data,)
if front_pad.shape[0] > 0:
arrays += (front_pad,)
data = np.concat(arrays, axis=0) if len(arrays) > 1 else arrays[0]

arrays = (top_pad, data) if top_pad.shape[1] > 0 else (data,)
if bottom_pad.shape[1] > 0:
arrays += (bottom_pad,)
data = np.concat(arrays, axis=1) if len(arrays) > 1 else arrays[0]

arrays = (left_pad, data) if left_pad.shape[2] > 0 else (data,)
if right_pad.shape[2] > 0:
arrays += (right_pad,)
data = concat(arrays, axis=2) if len(arrays) > 1 else arrays[0]

# Reorient array to original orientation
if res[2] < 0: data = data[::-1]
if res[1] < 0: data = data[:, ::-1]
if res[0] < 0: data = data[:, :, ::-1]

# Restore nan_value from masked array
if nan_value is not None:
data = data.filled()

# Restore original dtype
if dtype != data.dtype:
data = data.astype(dtype)

# Compute DataArray metadata
xs, ys, zs = compute_coords3d(
self.plot_width, self.plot_height, self.plot_depth,
self.x_range, self.y_range, self.z_range, res
)
coords = {xdim: xs, ydim: ys, zdim: zs}
dims = [zdim, ydim, xdim]
attrs = dict(res=res[0])
if source._file_obj is not None and hasattr(source._file_obj, 'nodata'):
attrs['nodata'] = source._file_obj.nodata
return DataArray(data, coords=coords, dims=dims, attrs=attrs)

def validate(self):
"""Check that parameter settings are valid for this object"""
self.x_axis.validate(self.x_range)
self.y_axis.validate(self.y_range)
if self.plot_depth is not None:
self.z_axis.validate(self.z_range)


def bypixel(source, canvas, glyph, agg):
Expand Down
Loading