Skip to content

Commit

Permalink
Fix bug caused by groupby.agg on column with many unique values (#174)
Browse files Browse the repository at this point in the history
* update export tutorial to add explanation for standalone argument

* minor fixes and remove cell output in notebooks

* added contributing doc

* fix bugs and uncomment some tests

* remove raise warning

* remove unnecessary import

* split up rename test into two parts

* fix setting warning, fix data_type bugs and add relevant tests

* remove ordinal data type

* add test for small dataframe resetting index

* add loc and iloc tests

* fix attribute access directly to dataframe

* add small changes to code

* added test for qcut and cut

* add check if dtype is Interval

* added qcut test

* fix Record KeyError

* add tests

* take care of reset_index case

* small edits

* add data_model to column_group Clause

* small edits for row_group

* fixes to row group

Co-authored-by: Doris Lee <dorisjunglinlee@gmail.com>
  • Loading branch information
westernguy2 and dorisjlee committed Dec 9, 2020
1 parent 476f0ac commit 91f965f
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 6 deletions.
9 changes: 7 additions & 2 deletions lux/action/column_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,17 @@ def column_group(ldf):
vis = Vis(
[
lux.Clause(
index_column_name,
attribute=index_column_name,
data_type="nominal",
data_model="dimension",
aggregation=None,
),
lux.Clause(str(attribute), data_type="quantitative", aggregation=None),
lux.Clause(
attribute=str(attribute),
data_type="quantitative",
data_model="measure",
aggregation=None,
),
]
)
collection.append(vis)
Expand Down
2 changes: 1 addition & 1 deletion lux/action/row_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def row_group(ldf):
# rowdf.cardinality["index"]=len(rowdf)
# if isinstance(ldf.columns,pd.DatetimeIndex):
# rowdf.data_type_lookup[dim_name]="temporal"
vis = Vis([dim_name, lux.Clause(row.name, aggregation=None)], rowdf)
vis = Vis([dim_name, lux.Clause(row.name, data_model="measure", aggregation=None)], rowdf)
collection.append(vis)
vlst = VisList(collection)
# Note that we are not computing interestingness score here because we want to preserve the arrangement of the aggregated data
Expand Down
2 changes: 1 addition & 1 deletion lux/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __repr__(self):
ldf._widget.observe(ldf.remove_deleted_recs, names="deletedIndices")
ldf._widget.observe(ldf.set_intent_on_click, names="selectedIntentIndex")

if len(ldf._recommendation) > 0:
if len(ldf.recommendation) > 0:
# box = widgets.Box(layout=widgets.Layout(display='inline'))
button = widgets.Button(
description="Toggle Pandas/Lux",
Expand Down
2 changes: 1 addition & 1 deletion lux/executor/PandasExecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,11 @@ def execute_aggregate(vis: Vis, isFiltered=True):
has_color = True
else:
color_cardinality = 1

if measure_attr != "":
if measure_attr.attribute == "Record":
vis._vis_data = vis.data.reset_index()
# if color is specified, need to group by groupby_attr and color_attr

if has_color:
vis._vis_data = (
vis.data.groupby([groupby_attr.attribute, color_attr.attribute])
Expand Down
22 changes: 22 additions & 0 deletions tests/test_pandas_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,28 @@ def test_groupby_agg(global_var):
assert len(new_df.cardinality) == 7


def test_groupby_agg_big(global_var):
df = pd.read_csv("lux/data/car.csv")
new_df = df.groupby("Brand").agg(sum)
new_df._repr_html_()
assert list(new_df.recommendation.keys()) == ["Column Groups"]
assert len(new_df.cardinality) == 8
year_vis = list(
filter(
lambda vis: vis.get_attr_by_attr_name("Year") != [], new_df.recommendation["Column Groups"]
)
)[0]
assert year_vis.mark == "bar"
assert year_vis.get_attr_by_channel("x")[0].attribute == "Year"
new_df = new_df.T
new_df._repr_html_()
year_vis = list(
filter(lambda vis: vis.get_attr_by_attr_name("Year") != [], new_df.recommendation["Row Groups"])
)[0]
assert year_vis.mark == "bar"
assert year_vis.get_attr_by_channel("x")[0].attribute == "Year"


def test_qcut(global_var):
df = pd.read_csv("lux/data/car.csv")
df["Year"] = pd.to_datetime(df["Year"], format="%Y")
Expand Down
3 changes: 2 additions & 1 deletion tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def test_validator_invalid_attribute(global_var):
df = pytest.college_df
with pytest.raises(KeyError, match="'blah'"):
with pytest.warns(
UserWarning, match="The input attribute 'blah' does not exist in the DataFrame."
UserWarning,
match="The input attribute 'blah' does not exist in the DataFrame.",
):
df.intent = ["blah"]

0 comments on commit 91f965f

Please sign in to comment.