Skip to content

Commit

Permalink
Improved rasterize API
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Mar 21, 2018
1 parent dde5367 commit 84ce627
Showing 1 changed file with 107 additions and 56 deletions.
163 changes: 107 additions & 56 deletions holoviews/operation/datashader.py
Expand Up @@ -9,6 +9,7 @@
import pandas as pd
import xarray as xr
import datashader as ds
import datashader.reductions as rd
import datashader.transfer_functions as tf
import dask.dataframe as dd
from param.parameterized import bothmethod
Expand Down Expand Up @@ -175,7 +176,58 @@ def _get_sampling(self, element, x, y):



class aggregate(ResamplingOperation):
class AggregationOperation(ResamplingOperation):
"""
AggregationOperation extends the ResamplingOperation defining an
aggregator parameter used to define a datashader Reduction.
"""

aggregator = param.ClassSelector(class_=(ds.reductions.Reduction, basestring),
default=ds.count(), doc="""
Datashader reduction function used for aggregating the data.
The aggregator may also define a column to aggregate, if
no column is defined the first value dimension of the element
will be used. May also be defined as a string.""")

_agg_methods = {'first': rd.first,
'last': rd.last,
'mode': rd.mode,
'mean': rd.mean,
'var': rd.var,
'std': rd.std,
'min': rd.min,
'max': rd.max}

def _get_aggregator(self, element):
agg = self.p.aggregator
if isinstance(agg, basestring):
if agg not in self._agg_methods:
agg_methods = sorted(agg)
raise ValueError('Aggregation method %r is not know, '
'aggregator must be one of: %r' %
(agg, agg_methods))
agg = self._agg_methods[agg]()

elements = element.traverse(lambda x: x, [Element])
if agg.column is None and not isinstance(agg, rd.count):
if not elements:
raise ValueError('Could not find any elements to apply '
'%s operation to.' % type(self).__name__)
inner_element = elements[0]
if inner_element.vdims:
field = inner_element.vdims[0].name
elif isinstance(element, NdOverlay):
field = element.kdims[0].name
else:
raise ValueError('Could not determine dimension to apply '
'%s operation to. Declare the dimension '
'to aggregate as part of the datashader '
'aggregator.' % type(self).__name__)
agg = type(agg)(field)
return agg


class aggregate(AggregationOperation):
"""
aggregate implements 2D binning for any valid HoloViews Element
type using datashader. I.e., this operation turns a HoloViews
Expand All @@ -200,8 +252,6 @@ class aggregate(ResamplingOperation):
the linked plot.
"""

aggregator = param.ClassSelector(class_=ds.reductions.Reduction,
default=ds.count())

@classmethod
def get_agg_data(cls, obj, category=None):
Expand Down Expand Up @@ -341,7 +391,7 @@ def _aggregate_ndoverlay(self, element, agg_fn):


def _process(self, element, key=None):
agg_fn = self.p.aggregator
agg_fn = self._get_aggregator(element)
category = agg_fn.column if isinstance(agg_fn, ds.count_cat) else None

if (isinstance(element, NdOverlay) and
Expand Down Expand Up @@ -384,7 +434,7 @@ def _process(self, element, key=None):
datatype=['xarray'], vdims=vdims)

dfdata = PandasInterface.as_dframe(data)
agg = getattr(cvs, glyph)(dfdata, x.name, y.name, self.p.aggregator)
agg = getattr(cvs, glyph)(dfdata, x.name, y.name, agg_fn)
if 'x_axis' in agg.coords and 'y_axis' in agg.coords:
agg = agg.rename({'x_axis': x, 'y_axis': y})
if xtype == 'datetime':
Expand All @@ -406,7 +456,7 @@ def _process(self, element, key=None):



class regrid(ResamplingOperation):
class regrid(AggregationOperation):
"""
regrid allows resampling a HoloViews Image type using specified
up- and downsampling functions defined using the aggregator and
Expand All @@ -417,10 +467,13 @@ class regrid(ResamplingOperation):
with nan values.
"""

aggregator = param.ObjectSelector(default='mean',
objects=['first', 'last', 'mean', 'mode', 'std', 'var', 'min', 'max'], doc="""
Aggregation method.
""")
aggregator = param.ClassSelector(default=ds.mean(),
class_=(ds.reductions.Reduction, basestring),
doc="""
Datashader reduction function used for aggregating the data.
The aggregator may also define a column to aggregate, if
no column is defined the first value dimension of the element
will be used. May also be defined as a string.""")

expand = param.Boolean(default=False, doc="""
Whether the x_range and y_range should be allowed to expand
Expand Down Expand Up @@ -482,14 +535,13 @@ def _process(self, element, key=None):
if ds_version <= '0.5.0':
raise RuntimeError('regrid operation requires datashader>=0.6.0')

# Compute coords, anges and size
x, y = element.kdims
coords = tuple(element.dimension_values(d, expanded=False)
for d in [x, y])
coords = tuple(element.dimension_values(d, expanded=False) for d in [x, y])
info = self._get_sampling(element, x, y)
(x_range, y_range), _, (width, height), (xtype, ytype) = info
arrays = self._get_xarrays(element, coords, xtype, ytype)

# Disable upsampling if requested
# Disable upsampling by clipping size and ranges
(xstart, xend), (ystart, yend) = (x_range, y_range)
xspan, yspan = (xend-xstart), (yend-ystart)
if not self.p.upsample and self.p.target is None:
Expand All @@ -503,20 +555,26 @@ def _process(self, element, key=None):
height = min([int((yspan/eyspan) * len(coords[1])), height])
width, height = max([width, 1]), max([height, 1])

# Get expanded or bounded ranges
# Instantiate Canvas
cvs = ds.Canvas(plot_width=width, plot_height=height,
x_range=x_range, y_range=y_range)

# Apply regridding to each value dimension
regridded = {}
arrays = self._get_xarrays(element, coords, xtype, ytype)
for vd, xarr in arrays.items():
rarray = cvs.raster(xarr, upsample_method=self.p.interpolation,
downsample_method=self.p.aggregator)
downsample_method=self._get_aggregator(element))

# Convert datetime coordinates
if xtype == "datetime":
rarray[x.name] = (rarray[x.name]/10e5).astype('datetime64[us]')
if ytype == "datetime":
rarray[y.name] = (rarray[y.name]/10e5).astype('datetime64[us]')
regridded[vd] = rarray

regridded = xr.Dataset(regridded)

# Compute bounds (converting datetimes)
if xtype == 'datetime':
xstart, xend = (np.array([xstart, xend])/10e5).astype('datetime64[us]')
if ytype == 'datetime':
Expand All @@ -534,8 +592,13 @@ class trimesh_rasterize(aggregate):
data.
"""

aggregator = param.ClassSelector(class_=ds.reductions.Reduction,
default=None)
aggregator = param.ClassSelector(default=ds.mean(),
class_=(ds.reductions.Reduction, basestring),
doc="""
Datashader reduction function used for aggregating the data.
The aggregator may also define a column to aggregate, if
no column is defined the first value dimension of the element
will be used. May also be defined as a string.""")

interpolation = param.ObjectSelector(default='bilinear',
objects=['bilinear', None], doc="""
Expand Down Expand Up @@ -580,7 +643,7 @@ def _process(self, element, key=None):

vdim = element.vdims[0] if element.vdims else element.nodes.vdims[0]
interpolate = bool(self.p.interpolation)
agg = cvs.trimesh(pts, simplices, agg=self.p.aggregator,
agg = cvs.trimesh(pts, simplices, agg=self._get_aggregator(element),
interp=interpolate, mesh=mesh)
params = dict(get_param_values(element), kdims=[x, y],
datatype=['xarray'], vdims=[vdim])
Expand All @@ -600,7 +663,7 @@ def _precompute(self, element):



class rasterize(ResamplingOperation):
class rasterize(AggregationOperation):
"""
Rasterize is a high-level operation which will rasterize any
Element or combination of Elements aggregating it with the supplied
Expand All @@ -624,48 +687,36 @@ class rasterize(ResamplingOperation):
"""

aggregator = param.ClassSelector(class_=ds.reductions.Reduction,
default=None)
default=None, doc="""
Datashader reduction function used for aggregating the data.
The aggregator may also define a column to aggregate, if
no column is defined the first value dimension of the element
will be used. May also be defined as a string.""")

interpolation = param.ObjectSelector(default='bilinear',
objects=['bilinear', None], doc="""
The interpolation method to apply during rasterization.""")

def _process(self, element, key=None):
# Get input Images to avoid multiple rasterization
imgs = element.traverse(lambda x: x, [Image])

# Rasterize TriMeshes
tri_params = dict({k: v for k, v in self.p.items()
if k in aggregate.params()}, dynamic=False)
trirasterize = trimesh_rasterize.instance(**tri_params)
trirasterize._precomputed = self._precomputed
element = element.map(trirasterize, TriMesh)
self._precomputed = trirasterize._precomputed

# Rasterize QuadMesh
quad_params = dict({k: v for k, v in self.p.items()
if k in aggregate.params()}, dynamic=False)
quadrasterize = quadmesh_rasterize.instance(**quad_params)
quadrasterize._precomputed = self._precomputed
element = element.map(quadrasterize, QuadMesh)
self._precomputed = quadrasterize._precomputed

# Rasterize NdOverlay of objects
agg_params = dict({k: v for k, v in self.p.items()
if k in aggregate.params()}, dynamic=False)
dsrasterize = aggregate.instance(**agg_params)
dsrasterize._precomputed = self._precomputed
predicate = lambda x: (isinstance(x, NdOverlay) and
_transforms = [(Image, regrid),
(TriMesh, trimesh_rasterize),
(QuadMesh, quadmesh_rasterize),
(lambda x: (isinstance(x, NdOverlay) and
issubclass(x.type, Dataset)
and not issubclass(x.type, Image))
element = element.map(dsrasterize, predicate)

# Rasterize other Dataset types
predicate = lambda x: (isinstance(x, Dataset) and
(not isinstance(x, Image) or x in imgs))
element = element.map(dsrasterize, predicate)
self._precomputed = dsrasterize._precomputed
and not issubclass(x.type, Image)),
aggregate),
(lambda x: (isinstance(x, Dataset) and
(not isinstance(x, Image))),
aggregate)]

def _process(self, element, key=None):
for predicate, transform in self._transforms:
op_params = dict({k: v for k, v in self.p.items()
if k in transform.params() and v is not None},
dynamic=False)
op = transform.instance(**op_params)
op._precomputed = self._precomputed
element = element.map(op, predicate)
self._precomputed = op._precomputed
return element


Expand Down

0 comments on commit 84ce627

Please sign in to comment.