Skip to content

Commit

Permalink
Fixes cudf tests (#5166)
Browse files Browse the repository at this point in the history
  • Loading branch information
jlstevens committed Feb 14, 2022
1 parent fa9498a commit 92bf472
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
2 changes: 1 addition & 1 deletion holoviews/core/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,7 @@ def sample(self, samples=[], bounds=None, closest=True, **kwargs):
# may be replaced with more general handling
# see https://github.com/ioam/holoviews/issues/1173
from ...element import Table, Curve
datatype = ['dataframe', 'dictionary', 'dask', 'ibis']
datatype = ['dataframe', 'dictionary', 'dask', 'ibis', 'cuDF']
if len(samples) == 1:
sel = {kd.name: s for kd, s in zip(self.kdims, samples[0])}
dims = [kd for kd, v in sel.items() if not np.isscalar(v)]
Expand Down
10 changes: 8 additions & 2 deletions holoviews/core/data/cudf.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def select(cls, dataset, selection_mask=None, **selection):

indexed = cls.indexed(dataset, selection)
if selection_mask is not None:
df = df.loc[selection_mask]
df = df.iloc[selection_mask]
if indexed and len(df) == 1 and len(dataset.vdims) == 1:
return df[dataset.vdims[0].name].iloc[0]
return df
Expand Down Expand Up @@ -284,7 +284,13 @@ def aggregate(cls, dataset, dimensions, function, **kwargs):
if not hasattr(reindexed, agg):
raise ValueError('%s aggregation is not supported on cudf DataFrame.' % agg)
agg = getattr(reindexed, agg)()
data = dict(((col, [v]) for col, v in zip(agg.index.values_host, agg.to_array())))
try:
data = dict(((col, [v]) for col, v in zip(agg.index.values_host, agg.to_numpy())))
except Exception:
# Give FutureWarning: 'The to_array method will be removed in a future cuDF release.
# Consider using `to_numpy` instead.'
# Seen in cudf=21.12.01
data = dict(((col, [v]) for col, v in zip(agg.index.values_host, agg.to_array())))
df = util.pd.DataFrame(data, columns=list(agg.index.values_host))

dropped = []
Expand Down
9 changes: 7 additions & 2 deletions holoviews/tests/core/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,8 +463,13 @@ def test_dataset_dataset_ht_dtypes(self):
# Test literal formats

def test_dataset_expanded_dimvals_ht(self):
self.assertEqual(self.table.dimension_values('Gender', expanded=False),
np.array(['M', 'F']))
# This will run unique(), which for pandas return
# in order of appearance, but can be sorted for other
# interfaces like cudf.
# pd.Series(["M", "M", "F"]).unique() -> ["M", "F"]
# cudf.Series(["M", "M", "F"]).unique() -> ["F", "M"]
data = self.table.dimension_values('Gender', expanded=False)
self.assertEqual(np.sort(data), np.array(['F', 'M']))

def test_dataset_implicit_indexing_init(self):
dataset = Scatter(self.ys, kdims=['x'], vdims=['y'])
Expand Down

0 comments on commit 92bf472

Please sign in to comment.