Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for fast QuadMesh rasterization #4020

Merged
merged 7 commits into from Oct 4, 2019
@@ -990,13 +990,17 @@ def to(self):
return self._conversion_interface(self)


def clone(self, data=None, shared_data=True, new_type=None, *args, **overrides):
def clone(self, data=None, shared_data=True, new_type=None, link=True,
*args, **overrides):
"""Clones the object, overriding data and parameters.
Args:
data: New data replacing the existing data
shared_data (bool, optional): Whether to use existing data
new_type (optional): Type to cast object to
link (bool, optional): Whether clone should be linked
Determines whether Streams and Links attached to
original object will be inherited.
*args: Additional arguments to pass to constructor
**overrides: New keyword arguments to pass to constructor
@@ -1010,6 +1014,12 @@ def clone(self, data=None, shared_data=True, new_type=None, *args, **overrides):
if data is None:
overrides['_validate_vdims'] = False

# Allows datatype conversions
if shared_data:
data = self
if link:
overrides['plot_id'] = self._plot_id

if 'dataset' not in overrides:
overrides['dataset'] = self.dataset

@@ -7,7 +7,7 @@

from .. import util
from ..element import Element
from ..ndmapping import OrderedDict, NdMapping
from ..ndmapping import NdMapping


def get_array_types():
@@ -225,14 +225,19 @@ def initialize(cls, eltype, data, kdims, vdims, datatype=None):
if not datatype:
datatype = eltype.datatype

if data.interface.datatype in datatype and data.interface.datatype in eltype.datatype:
interface = data.interface
if interface.datatype in datatype and interface.datatype in eltype.datatype:
data = data.data
elif data.interface.gridded and any(cls.interfaces[dt].gridded for dt in datatype):
gridded = OrderedDict([(kd.name, data.dimension_values(kd.name, expanded=False))
for kd in data.kdims])
elif interface.gridded and any(cls.interfaces[dt].gridded for dt in datatype):
new_data = []
for kd in data.kdims:
irregular = interface.irregular(data, kd)
coords = data.dimension_values(kd.name, expanded=irregular,
flat=not irregular)
new_data.append(coords)
for vd in data.vdims:
gridded[vd.name] = data.dimension_values(vd, flat=False)
data = tuple(gridded.values())
new_data.append(interface.values(data, vd, flat=False, compute=False))
data = tuple(new_data)
else:
data = tuple(data.columns().values())
elif isinstance(data, Element):
@@ -1017,7 +1017,48 @@ class quadmesh_rasterize(trimesh_rasterize):
"""

def _precompute(self, element, agg):
return super(quadmesh_rasterize, self)._precompute(element.trimesh(), agg)
if ds_version <= '0.7.0':
return super(quadmesh_rasterize, self)._precompute(element.trimesh(), agg)

def _process(self, element, key=None):
if ds_version <= '0.7.0':
return super(quadmesh_rasterize, self)._process(element, key)

if element.interface.datatype != 'xarray':
element = element.clone(datatype=['xarray'])
data = element.data

x, y = element.kdims
agg_fn = self._get_aggregator(element)
info = self._get_sampling(element, x, y)
(x_range, y_range), (xs, ys), (width, height), (xtype, ytype) = info
if xtype == 'datetime':
data[x.name] = data[x.name].astype('datetime64[us]').astype('int64')
if ytype == 'datetime':
data[y.name] = data[y.name].astype('datetime64[us]').astype('int64')

# Compute bounds (converting datetimes)
((x0, x1), (y0, y1)), (xs, ys) = self._dt_transform(
x_range, y_range, xs, ys, xtype, ytype
)
params = dict(get_param_values(element), datatype=['xarray'],
bounds=(x0, y0, x1, y1))

if width == 0 or height == 0:
return self._empty_agg(element, x, y, width, height, xs, ys, agg_fn, **params)

cvs = ds.Canvas(plot_width=width, plot_height=height,
x_range=x_range, y_range=y_range)

vdim = getattr(agg_fn, 'column', element.vdims[0].name)
agg = cvs.quadmesh(data[vdim], x.name, y.name, agg_fn)
xdim, ydim = list(agg.dims)[:2][::-1]
if xtype == "datetime":
agg[xdim] = (agg[xdim]/1e3).astype('datetime64[us]')
if ytype == "datetime":
agg[ydim] = (agg[ydim]/1e3).astype('datetime64[us]')

return Image(agg, **params)



@@ -626,14 +626,14 @@ def test_rasterize_trimesh_string_aggregator(self):
def test_rasterize_quadmesh(self):
qmesh = QuadMesh(([0, 1], [0, 1], np.array([[0, 1], [2, 3]])))
img = rasterize(qmesh, width=3, height=3, dynamic=False, aggregator=ds.mean('z'))
image = Image(np.array([[2., 3., np.NaN], [0, 1, np.NaN], [np.NaN, np.NaN, np.NaN]]),
image = Image(np.array([[2, 3, 3], [2, 3, 3], [0, 1, 1]]),
bounds=(-.5, -.5, 1.5, 1.5))
self.assertEqual(img, image)

def test_rasterize_quadmesh_string_aggregator(self):
qmesh = QuadMesh(([0, 1], [0, 1], np.array([[0, 1], [2, 3]])))
img = rasterize(qmesh, width=3, height=3, dynamic=False, aggregator='mean')
image = Image(np.array([[2., 3., np.NaN], [0, 1, np.NaN], [np.NaN, np.NaN, np.NaN]]),
image = Image(np.array([[2, 3, 3], [2, 3, 3], [0, 1, 1]]),
bounds=(-.5, -.5, 1.5, 1.5))
self.assertEqual(img, image)

ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.