Skip to content

Commit

Permalink
Refactor LuxGroupBy using multiple inheritance (#309)
Browse files Browse the repository at this point in the history
* add series equality and value counts test

* black formatting

* fix old value counts test instead

* add pandas tests

* remove str from column group

* save work on groupby bugs so far

* fix merge conflict again

* add new tests and add groupby bug fixes

* remove tests for staging

* update series tests

* add back getitem

* fix merge conflicts for staging

* remove print statements

* revert Makefile

* add test for name column case

* add test to ensure column is not all None

* re-add tests and commit new groupby implementation

* make parent LuxGroupby

* reformat LuxGroupBy and fix SeriesGroupby issue

* remove pandas tests

* run black

* finish seriesgroupby metadata propagation and add test

* run black
  • Loading branch information
westernguy2 committed Mar 16, 2021
1 parent 1c93057 commit 0748949
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 11 deletions.
5 changes: 3 additions & 2 deletions lux/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import pandas as pd
from .frame import LuxDataFrame
from .groupby import LuxDataFrameGroupBy
from .groupby import LuxDataFrameGroupBy, LuxSeriesGroupBy
from .series import LuxSeries

global originalDF
Expand Down Expand Up @@ -58,8 +58,9 @@ def setOption(overridePandas=True):
) = (
pd.io.spss.DataFrame
) = pd.io.stata.DataFrame = pd.io.api.DataFrame = pd.core.frame.DataFrame = LuxDataFrame
pd.Series = pd.core.series.Series = pd.core.groupby.ops.Series = LuxSeries
pd.Series = pd.core.series.Series = pd.core.groupby.ops.Series = pd._testing.Series = LuxSeries
pd.core.groupby.generic.DataFrameGroupBy = LuxDataFrameGroupBy
pd.core.groupby.generic.SeriesGroupBy = LuxSeriesGroupBy
else:
pd.DataFrame = pd.io.parsers.DataFrame = pd.core.frame.DataFrame = originalDF
pd.Series = originalSeries
Expand Down
35 changes: 26 additions & 9 deletions lux/core/groupby.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pandas as pd


class LuxDataFrameGroupBy(pd.core.groupby.generic.DataFrameGroupBy):
class LuxGroupBy(pd.core.groupby.groupby.GroupBy):

_metadata = [
"_intent",
Expand All @@ -26,35 +26,42 @@ class LuxDataFrameGroupBy(pd.core.groupby.generic.DataFrameGroupBy):
]

def __init__(self, *args, **kwargs):
super(LuxDataFrameGroupBy, self).__init__(*args, **kwargs)
super(LuxGroupBy, self).__init__(*args, **kwargs)

def aggregate(self, *args, **kwargs):
ret_val = super(LuxDataFrameGroupBy, self).aggregate(*args, **kwargs)
ret_val = super(LuxGroupBy, self).aggregate(*args, **kwargs)
for attr in self._metadata:
ret_val.__dict__[attr] = getattr(self, attr, None)
return ret_val

def _agg_general(self, *args, **kwargs):
ret_val = super(LuxDataFrameGroupBy, self)._agg_general(*args, **kwargs)
ret_val = super(LuxGroupBy, self)._agg_general(*args, **kwargs)
for attr in self._metadata:
ret_val.__dict__[attr] = getattr(self, attr, None)
return ret_val

def _cython_agg_general(self, *args, **kwargs):
ret_val = super(LuxDataFrameGroupBy, self)._cython_agg_general(*args, **kwargs)
ret_val = super(LuxGroupBy, self)._cython_agg_general(*args, **kwargs)
for attr in self._metadata:
ret_val.__dict__[attr] = getattr(self, attr, None)
return ret_val

def get_group(self, *args, **kwargs):
ret_val = super(LuxDataFrameGroupBy, self).get_group(*args, **kwargs)
ret_val = super(LuxGroupBy, self).get_group(*args, **kwargs)
for attr in self._metadata:
ret_val.__dict__[attr] = getattr(self, attr, None)
ret_val.pre_aggregated = False # Returned LuxDataFrame isn't pre_aggregated
return ret_val

def filter(self, *args, **kwargs):
ret_val = super(LuxDataFrameGroupBy, self).filter(*args, **kwargs)
ret_val = super(LuxGroupBy, self).filter(*args, **kwargs)
for attr in self._metadata:
ret_val.__dict__[attr] = getattr(self, attr, None)
ret_val.pre_aggregated = False # Returned LuxDataFrame isn't pre_aggregated
return ret_val

def apply(self, *args, **kwargs):
ret_val = super(LuxGroupBy, self).apply(*args, **kwargs)
for attr in self._metadata:
ret_val.__dict__[attr] = getattr(self, attr, None)
ret_val.pre_aggregated = False # Returned LuxDataFrame isn't pre_aggregated
Expand All @@ -68,15 +75,25 @@ def apply(self, *args, **kwargs):
return ret_val

def size(self, *args, **kwargs):
ret_val = super(LuxDataFrameGroupBy, self).size(*args, **kwargs)
ret_val = super(LuxGroupBy, self).size(*args, **kwargs)
for attr in self._metadata:
ret_val.__dict__[attr] = getattr(self, attr, None)
return ret_val

def __getitem__(self, *args, **kwargs):
ret_val = super(LuxDataFrameGroupBy, self).__getitem__(*args, **kwargs)
ret_val = super(LuxGroupBy, self).__getitem__(*args, **kwargs)
for attr in self._metadata:
ret_val.__dict__[attr] = getattr(self, attr, None)
return ret_val

agg = aggregate


class LuxDataFrameGroupBy(LuxGroupBy, pd.core.groupby.generic.DataFrameGroupBy):
def __init__(self, *args, **kwargs):
super(LuxDataFrameGroupBy, self).__init__(*args, **kwargs)


class LuxSeriesGroupBy(LuxGroupBy, pd.core.groupby.generic.SeriesGroupBy):
def __init__(self, *args, **kwargs):
super(LuxSeriesGroupBy, self).__init__(*args, **kwargs)
15 changes: 15 additions & 0 deletions lux/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,18 @@ def exported(self) -> Union[Dict[str, VisList], VisList]:
When the exported vis is from the different tabs, return a dictionary with the action name as key and selected visualizations in the VisList. -> {"Enhance": VisList(v1, v2...), "Filter": VisList(v5, v7...), ..}
"""
return self._ldf.exported

def groupby(self, *args, **kwargs):
history_flag = False
if "history" not in kwargs or ("history" in kwargs and kwargs["history"]):
history_flag = True
if "history" in kwargs:
del kwargs["history"]
groupby_obj = super(LuxSeries, self).groupby(*args, **kwargs)
for attr in self._metadata:
groupby_obj.__dict__[attr] = getattr(self, attr, None)
if history_flag:
groupby_obj._history = groupby_obj._history.copy()
groupby_obj._history.append_event("groupby", *args, **kwargs)
groupby_obj.pre_aggregated = True
return groupby_obj
8 changes: 8 additions & 0 deletions tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,11 @@ def test_get_group(global_var):
new_df._repr_html_()
assert new_df.history[0].name == "groupby"
assert not new_df.pre_aggregated


def test_series_groupby(global_var):
df = pytest.car_df
df._repr_html_()
new_ser = df.set_index("Brand")["Displacement"].groupby(level=0).agg("mean")
assert new_ser._history[0].name == "groupby"
assert new_ser.pre_aggregated

0 comments on commit 0748949

Please sign in to comment.