Add LuxSeriesGroupBy so that Series.groupby results keep Lux metadata.

Summary of the hunks below:
  * lux/core/__init__.py  - import LuxSeriesGroupBy; in setOption, also point
                            pd._testing.Series at LuxSeries and
                            pd.core.groupby.generic.SeriesGroupBy at LuxSeriesGroupBy.
  * lux/core/groupby.py   - rename the wrapper class to LuxGroupBy (now derived from
                            pandas' GroupBy) and add LuxDataFrameGroupBy /
                            LuxSeriesGroupBy subclasses that only forward __init__.
  * lux/core/series.py    - add LuxSeries.groupby, which copies the attributes listed
                            in _metadata onto the groupby object, optionally appends a
                            "groupby" history event, and sets pre_aggregated = True.
  * tests/test_groupby.py - add test_series_groupby.

NOTE(review): the lux/core/groupby.py hunks cover old lines 26-60 and 68-82, and the
second hunk's header context is `def apply`. Old lines 61-67 are therefore untouched,
which suggests a pre-existing `apply` (presumably still calling
`super(LuxDataFrameGroupBy, self)`) remains after the `apply` added here and would
shadow it. Confirm against the full file before merging.

NOTE(review): leading whitespace and blank context lines were re-derived from the hunk
line counts; if context fails to match, retry with `git apply --ignore-whitespace`.

diff --git a/lux/core/__init__.py b/lux/core/__init__.py
index f1f0acf3..41943bff 100644
--- a/lux/core/__init__.py
+++ b/lux/core/__init__.py
@@ -14,7 +14,7 @@
 
 import pandas as pd
 from .frame import LuxDataFrame
-from .groupby import LuxDataFrameGroupBy
+from .groupby import LuxDataFrameGroupBy, LuxSeriesGroupBy
 from .series import LuxSeries
 
 global originalDF
@@ -58,8 +58,9 @@ def setOption(overridePandas=True):
         ) = (
             pd.io.spss.DataFrame
         ) = pd.io.stata.DataFrame = pd.io.api.DataFrame = pd.core.frame.DataFrame = LuxDataFrame
-        pd.Series = pd.core.series.Series = pd.core.groupby.ops.Series = LuxSeries
+        pd.Series = pd.core.series.Series = pd.core.groupby.ops.Series = pd._testing.Series = LuxSeries
         pd.core.groupby.generic.DataFrameGroupBy = LuxDataFrameGroupBy
+        pd.core.groupby.generic.SeriesGroupBy = LuxSeriesGroupBy
     else:
         pd.DataFrame = pd.io.parsers.DataFrame = pd.core.frame.DataFrame = originalDF
         pd.Series = originalSeries
diff --git a/lux/core/groupby.py b/lux/core/groupby.py
index b6a18d74..7956b867 100644
--- a/lux/core/groupby.py
+++ b/lux/core/groupby.py
@@ -1,7 +1,7 @@
 import pandas as pd
 
 
-class LuxDataFrameGroupBy(pd.core.groupby.generic.DataFrameGroupBy):
+class LuxGroupBy(pd.core.groupby.groupby.GroupBy):
 
     _metadata = [
         "_intent",
@@ -26,35 +26,42 @@ class LuxDataFrameGroupBy(pd.core.groupby.generic.DataFrameGroupBy):
     ]
 
     def __init__(self, *args, **kwargs):
-        super(LuxDataFrameGroupBy, self).__init__(*args, **kwargs)
+        super(LuxGroupBy, self).__init__(*args, **kwargs)
 
     def aggregate(self, *args, **kwargs):
-        ret_val = super(LuxDataFrameGroupBy, self).aggregate(*args, **kwargs)
+        ret_val = super(LuxGroupBy, self).aggregate(*args, **kwargs)
         for attr in self._metadata:
             ret_val.__dict__[attr] = getattr(self, attr, None)
         return ret_val
 
     def _agg_general(self, *args, **kwargs):
-        ret_val = super(LuxDataFrameGroupBy, self)._agg_general(*args, **kwargs)
+        ret_val = super(LuxGroupBy, self)._agg_general(*args, **kwargs)
         for attr in self._metadata:
             ret_val.__dict__[attr] = getattr(self, attr, None)
         return ret_val
 
     def _cython_agg_general(self, *args, **kwargs):
-        ret_val = super(LuxDataFrameGroupBy, self)._cython_agg_general(*args, **kwargs)
+        ret_val = super(LuxGroupBy, self)._cython_agg_general(*args, **kwargs)
         for attr in self._metadata:
             ret_val.__dict__[attr] = getattr(self, attr, None)
         return ret_val
 
     def get_group(self, *args, **kwargs):
-        ret_val = super(LuxDataFrameGroupBy, self).get_group(*args, **kwargs)
+        ret_val = super(LuxGroupBy, self).get_group(*args, **kwargs)
         for attr in self._metadata:
             ret_val.__dict__[attr] = getattr(self, attr, None)
         ret_val.pre_aggregated = False  # Returned LuxDataFrame isn't pre_aggregated
         return ret_val
 
     def filter(self, *args, **kwargs):
-        ret_val = super(LuxDataFrameGroupBy, self).filter(*args, **kwargs)
+        ret_val = super(LuxGroupBy, self).filter(*args, **kwargs)
+        for attr in self._metadata:
+            ret_val.__dict__[attr] = getattr(self, attr, None)
+        ret_val.pre_aggregated = False  # Returned LuxDataFrame isn't pre_aggregated
+        return ret_val
+
+    def apply(self, *args, **kwargs):
+        ret_val = super(LuxGroupBy, self).apply(*args, **kwargs)
         for attr in self._metadata:
             ret_val.__dict__[attr] = getattr(self, attr, None)
         ret_val.pre_aggregated = False  # Returned LuxDataFrame isn't pre_aggregated
@@ -68,15 +75,25 @@ def apply(self, *args, **kwargs):
         return ret_val
 
     def size(self, *args, **kwargs):
-        ret_val = super(LuxDataFrameGroupBy, self).size(*args, **kwargs)
+        ret_val = super(LuxGroupBy, self).size(*args, **kwargs)
         for attr in self._metadata:
             ret_val.__dict__[attr] = getattr(self, attr, None)
         return ret_val
 
     def __getitem__(self, *args, **kwargs):
-        ret_val = super(LuxDataFrameGroupBy, self).__getitem__(*args, **kwargs)
+        ret_val = super(LuxGroupBy, self).__getitem__(*args, **kwargs)
         for attr in self._metadata:
             ret_val.__dict__[attr] = getattr(self, attr, None)
         return ret_val
 
     agg = aggregate
+
+
+class LuxDataFrameGroupBy(LuxGroupBy, pd.core.groupby.generic.DataFrameGroupBy):
+    def __init__(self, *args, **kwargs):
+        super(LuxDataFrameGroupBy, self).__init__(*args, **kwargs)
+
+
+class LuxSeriesGroupBy(LuxGroupBy, pd.core.groupby.generic.SeriesGroupBy):
+    def __init__(self, *args, **kwargs):
+        super(LuxSeriesGroupBy, self).__init__(*args, **kwargs)
diff --git a/lux/core/series.py b/lux/core/series.py
index ea2a4a3b..40cb8b5e 100644
--- a/lux/core/series.py
+++ b/lux/core/series.py
@@ -226,3 +226,18 @@ def exported(self) -> Union[Dict[str, VisList], VisList]:
             When the exported vis is from the different tabs, return a dictionary with the action name as key and selected visualizations in the VisList. -> {"Enhance": VisList(v1, v2...), "Filter": VisList(v5, v7...), ..}
         """
         return self._ldf.exported
+
+    def groupby(self, *args, **kwargs):
+        history_flag = False
+        if "history" not in kwargs or ("history" in kwargs and kwargs["history"]):
+            history_flag = True
+        if "history" in kwargs:
+            del kwargs["history"]
+        groupby_obj = super(LuxSeries, self).groupby(*args, **kwargs)
+        for attr in self._metadata:
+            groupby_obj.__dict__[attr] = getattr(self, attr, None)
+        if history_flag:
+            groupby_obj._history = groupby_obj._history.copy()
+            groupby_obj._history.append_event("groupby", *args, **kwargs)
+        groupby_obj.pre_aggregated = True
+        return groupby_obj
diff --git a/tests/test_groupby.py b/tests/test_groupby.py
index 0c1c5929..b5841cc8 100644
--- a/tests/test_groupby.py
+++ b/tests/test_groupby.py
@@ -69,3 +69,11 @@ def test_get_group(global_var):
     new_df._repr_html_()
     assert new_df.history[0].name == "groupby"
     assert not new_df.pre_aggregated
+
+
+def test_series_groupby(global_var):
+    df = pytest.car_df
+    df._repr_html_()
+    new_ser = df.set_index("Brand")["Displacement"].groupby(level=0).agg("mean")
+    assert new_ser._history[0].name == "groupby"
+    assert new_ser.pre_aggregated