Skip to content

Commit

Permalink
Merge 5213d1b into e76fb6b
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Feb 14, 2019
2 parents e76fb6b + 5213d1b commit 2c7b727
Showing 1 changed file with 47 additions and 18 deletions.
65 changes: 47 additions & 18 deletions holoviews/operation/datashader.py
Expand Up @@ -537,7 +537,7 @@ class regrid(AggregationOperation):
being overlaid on a much larger background.""")

interpolation = param.ObjectSelector(default='nearest',
objects=['linear', 'nearest'], doc="""
objects=['linear', 'nearest', 'bilinear', None, False], doc="""
Interpolation method""")

upsample = param.Boolean(default=False, doc="""
Expand Down Expand Up @@ -595,7 +595,9 @@ def _process(self, element, key=None):
# Disable upsampling by clipping size and ranges
(xstart, xend), (ystart, yend) = (x_range, y_range)
xspan, yspan = (xend-xstart), (yend-ystart)
if not self.p.upsample and self.p.target is None:
interp = self.p.interpolation or None
if interp == 'bilinear': interp = 'linear'
if not (self.p.upsample or interp is None) and self.p.target is None:
(x0, x1), (y0, y1) = element.range(0), element.range(1)
if isinstance(x0, datetime_types):
x0, x1 = dt_to_int(x0, 'ns'), dt_to_int(x1, 'ns')
Expand Down Expand Up @@ -638,7 +640,7 @@ def _process(self, element, key=None):
arrays = self._get_xarrays(element, coords, xtype, ytype)
agg_fn = self._get_aggregator(element, add_field=False)
for vd, xarr in arrays.items():
rarray = cvs.raster(xarr, upsample_method=self.p.interpolation,
rarray = cvs.raster(xarr, upsample_method=interp,
downsample_method=agg_fn)

# Convert datetime coordinates
Expand Down Expand Up @@ -683,7 +685,7 @@ class trimesh_rasterize(aggregate):
class_=(ds.reductions.Reduction, basestring))

interpolation = param.ObjectSelector(default='bilinear',
objects=['bilinear', None], doc="""
objects=['bilinear', 'linear', None, False], doc="""
The interpolation method to apply during rasterization.""")

def _precompute(self, element, agg):
Expand All @@ -700,6 +702,16 @@ def _precompute(self, element, agg):
return {'mesh': mesh(verts, simplices), 'simplices': simplices,
'vertices': verts}

def _precompute_wireframe(self, element, agg):
if hasattr(element, '_wireframe'):
segments = element._wireframe.data
else:
simplexes = element.array([0, 1, 2, 0]).astype('int')
verts = element.nodes.array([0, 1])
segments = pd.DataFrame(verts[simplexes].reshape(len(simplexes), -1),
columns=['x0', 'y0', 'x1', 'y1', 'x2', 'y2', 'x3', 'y3'])
element._wireframe = Dataset(segments, datatype=['dataframe', 'dask'])
return {'segments': segments}

def _process(self, element, key=None):
if isinstance(element, TriMesh):
Expand All @@ -710,26 +722,35 @@ def _process(self, element, key=None):
(x_range, y_range), (xs, ys), (width, height), (xtype, ytype) = info

agg = self.p.aggregator
if getattr(agg, 'column', None):
interp = self.p.interpolation or None
precompute = self.p.precompute
if interp == 'linear': interp = 'bilinear'
wireframe = False
if (not interp and (isinstance(agg, (ds.any, ds.count)) or
agg in ['any', 'count'] or not (element.vdims or element.nodes.vdims))):
wireframe = True
precompute = False # TriMesh itself caches wireframe
agg = self._get_aggregator(element) if isinstance(agg, (ds.any, ds.count)) else ds.any()
vdim = 'Count' if isinstance(agg, ds.count) else 'Any'
elif getattr(agg, 'column', None):
if agg.column in element.vdims:
vdim = element.get_dimension(agg.column)
elif isinstance(element, TriMesh) and agg.column in element.nodes.vdims:
vdim = element.nodes.get_dimension(agg.column)
else:
raise ValueError("Aggregation column %s not found on TriMesh element."
% agg.column)
elif not (element.vdims or (isinstance(element, TriMesh) and element.nodes.vdims)):
self.p.aggregator = ds.count() if not isinstance(agg, ds.any) else agg
return aggregate._process(self, element, key)
else:
if isinstance(element, TriMesh) and element.nodes.vdims:
vdim = element.nodes.vdims[0]
else:
vdim = element.vdims[0]
agg = self._get_aggregator(element)

if element._plot_id in self._precomputed:
precomputed = self._precomputed[element._plot_id]
elif wireframe:
precomputed = self._precompute_wireframe(element, agg)
else:
precomputed = self._precompute(element, agg)

Expand All @@ -742,17 +763,25 @@ def _process(self, element, key=None):
bounds = (x_range[0], y_range[0], x_range[1], y_range[1])
return Image((xs, ys, np.zeros((height, width))), bounds=bounds, **params)

simplices = precomputed['simplices']
pts = precomputed['vertices']
mesh = precomputed['mesh']
if self.p.precompute:
if wireframe:
segments = precomputed['segments']
else:
simplices = precomputed['simplices']
pts = precomputed['vertices']
mesh = precomputed['mesh']
if precompute:
self._precomputed = {element._plot_id: precomputed}

cvs = ds.Canvas(plot_width=width, plot_height=height,
x_range=x_range, y_range=y_range)
interpolate = bool(self.p.interpolation)
agg = cvs.trimesh(pts, simplices, agg=agg,
interp=interpolate, mesh=mesh)
if wireframe:
agg = cvs.line(segments, x=['x0', 'x1', 'x2', 'x3'],
y=['y0', 'y1', 'y2', 'y3'], axis=1,
agg=agg)
else:
interpolate = bool(self.p.interpolation)
agg = cvs.trimesh(pts, simplices, agg=agg,
interp=interpolate, mesh=mesh)
return Image(agg, **params)


Expand Down Expand Up @@ -795,8 +824,8 @@ class rasterize(AggregationOperation):
aggregator = param.ClassSelector(class_=(ds.reductions.Reduction, basestring),
default=None)

interpolation = param.ObjectSelector(default='bilinear',
objects=['bilinear', None], doc="""
interpolation = param.ObjectSelector(
default='bilinear', objects=['linear', 'nearest', 'bilinear', None, False], doc="""
The interpolation method to apply during rasterization.""")

_transforms = [(Image, regrid),
Expand Down

0 comments on commit 2c7b727

Please sign in to comment.