From 91f965f3a4c2be790b99fa089095c3df40e16867 Mon Sep 17 00:00:00 2001 From: Kunal Agarwal <32151899+westernguy2@users.noreply.github.com> Date: Tue, 8 Dec 2020 19:58:45 -0800 Subject: [PATCH] Fix bug caused by `groupby.agg` on column with many unique values (#174) * update export tutorial to add explanation for standalone argument * minor fixes and remove cell output in notebooks * added contributing doc * fix bugs and uncomment some tests * remove raise warning * remove unnecessary import * split up rename test into two parts * fix setting warning, fix data_type bugs and add relevant tests * remove ordinal data type * add test for small dataframe resetting index * add loc and iloc tests * fix attribute access directly to dataframe * add small changes to code * added test for qcut and cut * add check if dtype is Interval * added qcut test * fix Record KeyError * add tests * take care of reset_index case * small edits * add data_model to column_group Clause * small edits for row_group * fixes to row group Co-authored-by: Doris Lee --- lux/action/column_group.py | 9 +++++++-- lux/action/row_group.py | 2 +- lux/core/series.py | 2 +- lux/executor/PandasExecutor.py | 2 +- tests/test_pandas_coverage.py | 22 ++++++++++++++++++++++ tests/test_parser.py | 3 ++- 6 files changed, 34 insertions(+), 6 deletions(-) diff --git a/lux/action/column_group.py b/lux/action/column_group.py index 29d33d92..880cd422 100644 --- a/lux/action/column_group.py +++ b/lux/action/column_group.py @@ -46,12 +46,17 @@ def column_group(ldf): vis = Vis( [ lux.Clause( - index_column_name, + attribute=index_column_name, data_type="nominal", data_model="dimension", aggregation=None, ), - lux.Clause(str(attribute), data_type="quantitative", aggregation=None), + lux.Clause( + attribute=str(attribute), + data_type="quantitative", + data_model="measure", + aggregation=None, + ), ] ) collection.append(vis) diff --git a/lux/action/row_group.py b/lux/action/row_group.py index 3fca5428..01ab2a32 100644 --- a/lux/action/row_group.py +++ b/lux/action/row_group.py @@ -45,7 +45,7 @@ def row_group(ldf): # rowdf.cardinality["index"]=len(rowdf) # if isinstance(ldf.columns,pd.DatetimeIndex): # rowdf.data_type_lookup[dim_name]="temporal" - vis = Vis([dim_name, lux.Clause(row.name, aggregation=None)], rowdf) + vis = Vis([dim_name, lux.Clause(row.name, data_model="measure", aggregation=None)], rowdf) collection.append(vis) vlst = VisList(collection) # Note that we are not computing interestingness score here because we want to preserve the arrangement of the aggregated data diff --git a/lux/core/series.py b/lux/core/series.py index 77675473..44c05bf7 100644 --- a/lux/core/series.py +++ b/lux/core/series.py @@ -102,7 +102,7 @@ def __repr__(self): ldf._widget.observe(ldf.remove_deleted_recs, names="deletedIndices") ldf._widget.observe(ldf.set_intent_on_click, names="selectedIntentIndex") - if len(ldf._recommendation) > 0: + if len(ldf.recommendation) > 0: # box = widgets.Box(layout=widgets.Layout(display='inline')) button = widgets.Button( description="Toggle Pandas/Lux", diff --git a/lux/executor/PandasExecutor.py b/lux/executor/PandasExecutor.py index 721a36e9..e0c10a90 100644 --- a/lux/executor/PandasExecutor.py +++ b/lux/executor/PandasExecutor.py @@ -152,11 +152,11 @@ def execute_aggregate(vis: Vis, isFiltered=True): has_color = True else: color_cardinality = 1 - if measure_attr != "": if measure_attr.attribute == "Record": vis._vis_data = vis.data.reset_index() # if color is specified, need to group by groupby_attr and color_attr + if has_color: vis._vis_data = ( vis.data.groupby([groupby_attr.attribute, color_attr.attribute]) diff --git a/tests/test_pandas_coverage.py b/tests/test_pandas_coverage.py index bd639197..f5977da5 100644 --- a/tests/test_pandas_coverage.py +++ b/tests/test_pandas_coverage.py @@ -145,6 +145,28 @@ def test_groupby_agg(global_var): assert len(new_df.cardinality) == 7 +def test_groupby_agg_big(global_var): + df = pd.read_csv("lux/data/car.csv") + new_df = df.groupby("Brand").agg(sum) + new_df._repr_html_() + assert list(new_df.recommendation.keys()) == ["Column Groups"] + assert len(new_df.cardinality) == 8 + year_vis = list( + filter( + lambda vis: vis.get_attr_by_attr_name("Year") != [], new_df.recommendation["Column Groups"] + ) + )[0] + assert year_vis.mark == "bar" + assert year_vis.get_attr_by_channel("x")[0].attribute == "Year" + new_df = new_df.T + new_df._repr_html_() + year_vis = list( + filter(lambda vis: vis.get_attr_by_attr_name("Year") != [], new_df.recommendation["Row Groups"]) + )[0] + assert year_vis.mark == "bar" + assert year_vis.get_attr_by_channel("x")[0].attribute == "Year" + + def test_qcut(global_var): df = pd.read_csv("lux/data/car.csv") df["Year"] = pd.to_datetime(df["Year"], format="%Y") diff --git a/tests/test_parser.py b/tests/test_parser.py index a59be1ec..333977aa 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -121,6 +121,7 @@ def test_validator_invalid_attribute(global_var): df = pytest.college_df with pytest.raises(KeyError, match="'blah'"): with pytest.warns( - UserWarning, match="The input attribute 'blah' does not exist in the DataFrame." + UserWarning, + match="The input attribute 'blah' does not exist in the DataFrame.", ): df.intent = ["blah"]