diff --git a/.gitignore b/.gitignore index f149ff4f..2a08cbfd 100644 --- a/.gitignore +++ b/.gitignore @@ -49,3 +49,4 @@ venv/ ENV/ env.bak/ venv.bak/ +.mypy_cache diff --git a/eland/dataframe.py b/eland/dataframe.py index 875cdb1c..f7406876 100644 --- a/eland/dataframe.py +++ b/eland/dataframe.py @@ -1386,9 +1386,8 @@ def aggregate(self, func, axis=0, *args, **kwargs): # ['count', 'mad', 'max', 'mean', 'median', 'min', 'mode', 'quantile', # 'rank', 'sem', 'skew', 'sum', 'std', 'var', 'nunique'] if isinstance(func, str): - # wrap in list - func = [func] - return self._query_compiler.aggs(func) + # Wrap in list + return self._query_compiler.aggs([func]).squeeze().rename(None) elif is_list_like(func): # we have a list! return self._query_compiler.aggs(func) diff --git a/eland/tests/dataframe/test_aggs_pytest.py b/eland/tests/dataframe/test_aggs_pytest.py index 2f1546f3..330fb61a 100644 --- a/eland/tests/dataframe/test_aggs_pytest.py +++ b/eland/tests/dataframe/test_aggs_pytest.py @@ -18,8 +18,8 @@ # File called _pytest for PyCharm compatability import numpy as np -from pandas.testing import assert_frame_equal - +from pandas.testing import assert_frame_equal, assert_series_equal +import pytest from eland.tests.common import TestData @@ -94,3 +94,32 @@ def test_aggs_median_var(self): # TODO - investigate this more pd_aggs = pd_aggs.astype("float64") assert_frame_equal(pd_aggs, ed_aggs, check_exact=False, check_less_precise=2) + + # If Aggregate is given a string then series is returned. + @pytest.mark.parametrize("agg", ["mean", "min", "max"]) + def test_terms_aggs_series(self, agg): + pd_flights = self.pd_flights() + ed_flights = self.ed_flights() + + pd_sum_min_std = pd_flights.select_dtypes(include=[np.number]).agg(agg) + ed_sum_min_std = ed_flights.select_dtypes(include=[np.number]).agg(agg) + + assert_series_equal(pd_sum_min_std, ed_sum_min_std) + + def test_terms_aggs_series_with_single_list_agg(self): + # aggs list with single agg should return dataframe. + pd_flights = self.pd_flights() + ed_flights = self.ed_flights() + + pd_sum_min = pd_flights.select_dtypes(include=[np.number]).agg(["mean"]) + ed_sum_min = ed_flights.select_dtypes(include=[np.number]).agg(["mean"]) + + assert_frame_equal(pd_sum_min, ed_sum_min) + + # If Wrong Aggregate value is given. + def test_terms_wrongaggs(self): + ed_flights = self.ed_flights()[["FlightDelayMin"]] + + match = "('abc', ' not currently implemented')" + with pytest.raises(NotImplementedError, match=match): + ed_flights.select_dtypes(include=[np.number]).agg("abc")