Skip to content

Commit

Permalink
Fixed missing handling of nodata for count aggregator with column (#4951
Browse files Browse the repository at this point in the history
)
  • Loading branch information
jlstevens committed May 27, 2021
1 parent 2b1b9b8 commit e02402b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
17 changes: 12 additions & 5 deletions holoviews/operation/datashader.py
Expand Up @@ -229,7 +229,7 @@ class AggregationOperation(ResamplingOperation):
no column is defined the first value dimension of the element
will be used. May also be defined as a string.""")

vdim_prefix = param.String(default='{kdims} ', doc="""
vdim_prefix = param.String(default='{kdims} ', allow_None=True, doc="""
Prefix to prepend to value dimension name where {kdims}
templates in the names of the input element key dimensions.""")

Expand Down Expand Up @@ -297,8 +297,11 @@ def _get_agg_params(self, element, x, y, agg_fn, bounds):
params = dict(get_param_values(element), kdims=[x, y],
datatype=['xarray'], bounds=bounds)

kdim_list = '_'.join(str(kd) for kd in params['kdims'])
vdim_prefix = self.vdim_prefix.format(kdims=kdim_list)
if self.vdim_prefix:
kdim_list = '_'.join(str(kd) for kd in params['kdims'])
vdim_prefix = self.vdim_prefix.format(kdims=kdim_list)
else:
vdim_prefix = ''

category = None
if hasattr(agg_fn, 'reduction'):
Expand All @@ -311,8 +314,12 @@ def _get_agg_params(self, element, x, y, agg_fn, bounds):
raise ValueError("Aggregation column '%s' not found on '%s' element. "
"Ensure the aggregator references an existing "
"dimension." % (column,element))
if isinstance(agg_fn, ds.count_cat):
vdims = dims[0].clone('%s %s Count' % (vdim_prefix, column), nodata=0)
if isinstance(agg_fn, (ds.count, ds.count_cat)):
if vdim_prefix:
vdim_name = '%s%s Count' % (vdim_prefix, column)
else:
vdim_name = '%s Count' % column
vdims = dims[0].clone(vdim_name, nodata=0)
else:
vdims = dims[0].clone(vdim_prefix + column)
elif category:
Expand Down
8 changes: 8 additions & 0 deletions holoviews/tests/operation/test_datashader.py
Expand Up @@ -61,6 +61,14 @@ def test_aggregate_points(self):
vdims=[Dimension('Count', nodata=0)])
self.assertEqual(img, expected)

def test_aggregate_points_count_column(self):
points = Points([(0.2, 0.3, np.NaN), (0.4, 0.7, 22), (0, 0.99,np.NaN)], vdims='z')
img = aggregate(points, dynamic=False, x_range=(0, 1), y_range=(0, 1),
width=2, height=2, aggregator=ds.count('z'))
expected = Image(([0.25, 0.75], [0.25, 0.75], [[0, 0], [1, 0]]),
vdims=[Dimension('z Count', nodata=0)])
self.assertEqual(img, expected)

@cudf_skip
def test_aggregate_points_cudf(self):
points = Points([(0.2, 0.3), (0.4, 0.7), (0, 0.99)], datatype=['cuDF'])
Expand Down

0 comments on commit e02402b

Please sign in to comment.