Skip to content

Commit

Permalink
LuxGroupby Implementation (#260)
Browse files Browse the repository at this point in the history
* fix Record KeyError

* add tests

* take care of reset_index case

* small edits

* first implementation of groupby extended

* add flag for groupby

* update metadata lists

* pre-agg impl + first pass on tests

* 5 tests failing

* 4 failing

* fix failing tests with pre_aggregate

* extend get_group and filter

* fix final bug and add tests for groupby

* fix get_axis_number bug and added default metadata values

* remove unnecessary computation

* move history out of original df

* add comments and consolidate metadata tests

* add back cached datasets for tests

* add clear_intent to tests

Co-authored-by: Ujjaini Mukhopadhyay <ujjaini@berkeley.edu>
  • Loading branch information
westernguy2 and jinimukh committed Feb 15, 2021
1 parent 144e8f7 commit e2ece28
Show file tree
Hide file tree
Showing 18 changed files with 283 additions and 87 deletions.
4 changes: 2 additions & 2 deletions lux/action/column_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ def column_group(ldf):
attribute=index_column_name,
data_type="nominal",
data_model="dimension",
aggregation=None,
aggregation="",
),
lux.Clause(
attribute=str(attribute),
attribute=attribute,
data_type="quantitative",
data_model="measure",
aggregation=None,
Expand Down
4 changes: 3 additions & 1 deletion lux/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import pandas as pd
from .frame import LuxDataFrame
from .groupby import LuxDataFrameGroupBy
from .series import LuxSeries

global originalDF
Expand Down Expand Up @@ -57,7 +58,8 @@ def setOption(overridePandas=True):
) = (
pd.io.spss.DataFrame
) = pd.io.stata.DataFrame = pd.io.api.DataFrame = pd.core.frame.DataFrame = LuxDataFrame
pd.Series = pd.core.series.Series = LuxSeries
pd.Series = pd.core.series.Series = pd.core.groupby.ops.Series = LuxSeries
pd.core.groupby.generic.DataFrameGroupBy = LuxDataFrameGroupBy
else:
pd.DataFrame = pd.io.parsers.DataFrame = pd.core.frame.DataFrame = originalDF
pd.Series = originalSeries
Expand Down
29 changes: 23 additions & 6 deletions lux/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,14 @@ def _infer_structure(self):
is_multi_index_flag = self.index.nlevels != 1
not_int_index_flag = not pd.api.types.is_integer_dtype(self.index)
small_df_flag = len(self) < 100
self.pre_aggregated = (is_multi_index_flag or not_int_index_flag) and small_df_flag
if "Number of Records" in self.columns:
self.pre_aggregated = True
very_small_df_flag = len(self) <= 10
if very_small_df_flag:
self.pre_aggregated = True
if self.pre_aggregated == None:
self.pre_aggregated = (is_multi_index_flag or not_int_index_flag) and small_df_flag
if "Number of Records" in self.columns:
self.pre_aggregated = True
very_small_df_flag = len(self) <= 10
self.pre_aggregated = "groupby" in [event.name for event in self.history]
# if very_small_df_flag:
# self.pre_aggregated = True

@property
def intent(self):
Expand Down Expand Up @@ -920,3 +922,18 @@ def describe(self, *args, **kwargs):
self._pandas_only = True
self._history.append_event("describe", *args, **kwargs)
return super(LuxDataFrame, self).describe(*args, **kwargs)

def groupby(self, *args, **kwargs):
    """Group the DataFrame while propagating Lux metadata onto the result.

    Accepts one extra keyword over ``pandas.DataFrame.groupby``:

    history : bool, default True
        When True, record a ``"groupby"`` event on a *copy* of the
        history (so the original frame's history is untouched).
        Internal callers (e.g. the Pandas executor) pass
        ``history=False`` to keep executor-driven groupbys out of the
        user-visible history.

    Returns the groupby object with all ``_metadata`` attributes copied
    from this frame and ``pre_aggregated`` set to True, since the result
    of a groupby is treated as pre-aggregated data.
    """
    # Pop the Lux-only flag before delegating: pandas' groupby would
    # reject an unknown `history` keyword. Default is True (record it).
    history_flag = kwargs.pop("history", True)
    groupby_obj = super(LuxDataFrame, self).groupby(*args, **kwargs)
    # Carry every Lux metadata attribute over to the groupby object.
    for attr in self._metadata:
        groupby_obj.__dict__[attr] = getattr(self, attr, None)
    if history_flag:
        # Copy so the event is appended to the groupby object's history
        # only, not to the originating frame's history.
        groupby_obj._history = groupby_obj._history.copy()
        groupby_obj._history.append_event("groupby", *args, **kwargs)
    groupby_obj.pre_aggregated = True
    return groupby_obj
75 changes: 75 additions & 0 deletions lux/core/groupby.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pandas as pd


class LuxDataFrameGroupBy(pd.core.groupby.generic.DataFrameGroupBy):
    """A ``DataFrameGroupBy`` subclass that carries Lux metadata.

    Every aggregation / selection entry point is overridden to copy the
    Lux ``_metadata`` attributes from the groupby object onto whatever
    pandas returns, so recommendations keep working on derived frames.
    """

    # Lux attributes propagated onto every object produced by this groupby.
    _metadata = [
        "_intent",
        "_inferred_intent",
        "_data_type",
        "unique_values",
        "cardinality",
        "_rec_info",
        "_min_max",
        "_current_vis",
        "_widget",
        "_recommendation",
        "_prev",
        "_history",
        "_saved_export",
        "_sampled",
        "_toggle_pandas_display",
        "_message",
        "_pandas_only",
        "pre_aggregated",
        "_type_override",
    ]

    def __init__(self, *args, **kwargs):
        super(LuxDataFrameGroupBy, self).__init__(*args, **kwargs)

    def _lux_copy_metadata(self, ret_val):
        """Copy each Lux metadata attribute onto ``ret_val`` and return it.

        Missing attributes default to None so a partially-initialized
        groupby object never raises here.
        """
        for attr in self._metadata:
            ret_val.__dict__[attr] = getattr(self, attr, None)
        return ret_val

    def aggregate(self, *args, **kwargs):
        # Aggregated output keeps the groupby's metadata (incl. pre_aggregated).
        return self._lux_copy_metadata(
            super(LuxDataFrameGroupBy, self).aggregate(*args, **kwargs)
        )

    def _agg_general(self, *args, **kwargs):
        # pandas-internal aggregation path (e.g. .mean(), .sum()).
        return self._lux_copy_metadata(
            super(LuxDataFrameGroupBy, self)._agg_general(*args, **kwargs)
        )

    def _cython_agg_general(self, *args, **kwargs):
        # pandas-internal cythonized aggregation path.
        return self._lux_copy_metadata(
            super(LuxDataFrameGroupBy, self)._cython_agg_general(*args, **kwargs)
        )

    def get_group(self, *args, **kwargs):
        ret_val = self._lux_copy_metadata(
            super(LuxDataFrameGroupBy, self).get_group(*args, **kwargs)
        )
        ret_val.pre_aggregated = False  # Returned LuxDataFrame isn't pre_aggregated
        return ret_val

    def filter(self, *args, **kwargs):
        ret_val = self._lux_copy_metadata(
            super(LuxDataFrameGroupBy, self).filter(*args, **kwargs)
        )
        ret_val.pre_aggregated = False  # Returned LuxDataFrame isn't pre_aggregated
        return ret_val

    def size(self, *args, **kwargs):
        return self._lux_copy_metadata(
            super(LuxDataFrameGroupBy, self).size(*args, **kwargs)
        )

    def __getitem__(self, *args, **kwargs):
        # Column selection (e.g. gb["col"]) also keeps the metadata.
        return self._lux_copy_metadata(
            super(LuxDataFrameGroupBy, self).__getitem__(*args, **kwargs)
        )

    # pandas convention: agg is an alias of aggregate.
    agg = aggregate
53 changes: 42 additions & 11 deletions lux/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import warnings
import traceback
import numpy as np
from lux.history.history import History
from lux.utils.message import Message


class LuxSeries(pd.Series):
Expand All @@ -26,11 +28,11 @@ class LuxSeries(pd.Series):

_metadata = [
"_intent",
"data_type",
"_inferred_intent",
"_data_type",
"unique_values",
"cardinality",
"_rec_info",
"_pandas_only",
"_min_max",
"plotting_style",
"_current_vis",
Expand All @@ -39,9 +41,34 @@ class LuxSeries(pd.Series):
"_prev",
"_history",
"_saved_export",
"name",
"_sampled",
"_toggle_pandas_display",
"_message",
"_pandas_only",
"pre_aggregated",
"_type_override",
]

_default_metadata = {
"_intent": list,
"_inferred_intent": list,
"_current_vis": list,
"_recommendation": list,
"_toggle_pandas_display": lambda: True,
"_pandas_only": lambda: False,
"_type_override": dict,
"_history": History,
"_message": Message,
}

def __init__(self, *args, **kw):
    """Construct the series, then seed every Lux metadata attribute.

    Attributes listed in ``_default_metadata`` are initialized via their
    registered factory (list, dict, History, ...); all others start as None.
    """
    super(LuxSeries, self).__init__(*args, **kw)
    for attr in self._metadata:
        factory = self._default_metadata.get(attr)
        self.__dict__[attr] = factory() if factory is not None else None

@property
def _constructor(self):
return LuxSeries
Expand All @@ -50,14 +77,18 @@ def _constructor(self):
def _constructor_expanddim(self):
from lux.core.frame import LuxDataFrame

# def f(*args, **kwargs):
# df = LuxDataFrame(*args, **kwargs)
# for attr in self._metadata:
# df.__dict__[attr] = getattr(self, attr, None)
# return df

# f._get_axis_number = super(LuxSeries, self)._get_axis_number
return LuxDataFrame
def f(*args, **kwargs):
df = LuxDataFrame(*args, **kwargs)
for attr in self._metadata:
# if attr in self._default_metadata:
# default = self._default_metadata[attr]
# else:
# default = None
df.__dict__[attr] = getattr(self, attr, None)
return df

f._get_axis_number = LuxDataFrame._get_axis_number
return f

def to_pandas(self) -> pd.Series:
"""
Expand Down
17 changes: 11 additions & 6 deletions lux/executor/PandasExecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,17 +163,20 @@ def execute_aggregate(vis: Vis, isFiltered=True):

vis._vis_data = vis.data.reset_index()
# if color is specified, need to group by groupby_attr and color_attr

if has_color:
vis._vis_data = (
vis.data.groupby([groupby_attr.attribute, color_attr.attribute], dropna=False)
vis.data.groupby(
[groupby_attr.attribute, color_attr.attribute], dropna=False, history=False
)
.count()
.reset_index()
.rename(columns={index_name: "Record"})
)
vis._vis_data = vis.data[[groupby_attr.attribute, color_attr.attribute, "Record"]]
else:
vis._vis_data = (
vis.data.groupby(groupby_attr.attribute, dropna=False)
vis.data.groupby(groupby_attr.attribute, dropna=False, history=False)
.count()
.reset_index()
.rename(columns={index_name: "Record"})
Expand All @@ -183,10 +186,12 @@ def execute_aggregate(vis: Vis, isFiltered=True):
# if color is specified, need to group by groupby_attr and color_attr
if has_color:
groupby_result = vis.data.groupby(
[groupby_attr.attribute, color_attr.attribute], dropna=False
[groupby_attr.attribute, color_attr.attribute], dropna=False, history=False
)
else:
groupby_result = vis.data.groupby(groupby_attr.attribute, dropna=False)
groupby_result = vis.data.groupby(
groupby_attr.attribute, dropna=False, history=False
)
groupby_result = groupby_result.agg(agg_func)
intermediate = groupby_result.reset_index()
vis._vis_data = intermediate.__finalize__(vis.data)
Expand Down Expand Up @@ -358,7 +363,7 @@ def execute_2D_binning(vis: Vis):
color_attr = vis.get_attr_by_channel("color")
if len(color_attr) > 0:
color_attr = color_attr[0]
groups = vis._vis_data.groupby(["xBin", "yBin"])[color_attr.attribute]
groups = vis._vis_data.groupby(["xBin", "yBin"], history=False)[color_attr.attribute]
if color_attr.data_type == "nominal":
# Compute mode and count. Mode aggregates each cell by taking the majority vote for the category variable. In cases where there is ties across categories, pick the first item (.iat[0])
result = groups.agg(
Expand All @@ -374,7 +379,7 @@ def execute_2D_binning(vis: Vis):
).reset_index()
result = result.dropna()
else:
groups = vis._vis_data.groupby(["xBin", "yBin"])[x_attr]
groups = vis._vis_data.groupby(["xBin", "yBin"], history=False)[x_attr]
result = groups.count().reset_index(name=x_attr)
result = result.rename(columns={x_attr: "count"})
result = result[result["count"] != 0]
Expand Down
5 changes: 5 additions & 0 deletions lux/history/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,8 @@ def __repr__(self):
def append_event(self, name, *args, **kwargs):
    """Record a named event (with its call arguments) in this history."""
    # Wrap the call details in an Event and keep events in insertion order.
    event = Event(name, *args, **kwargs)
    self._events.append(event)

def copy(self):
    """Return a shallow copy of this history.

    The new History has its own event list, so appending to the copy does
    not mutate the original; the Event objects themselves are shared.
    """
    duplicate = History()
    for event in self._events:
        duplicate._events.append(event)
    return duplicate
6 changes: 5 additions & 1 deletion lux/interestingness/interestingness.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,11 @@ def unevenness(vis: Vis, ldf: LuxDataFrame, measure_lst: list, dimension_lst: li
v = vis.data[measure_lst[0].attribute]
v = v / v.sum() # normalize by total to get ratio
v = v.fillna(0) # Some bar values may be NaN
C = ldf.cardinality[dimension_lst[0].attribute]
attr = dimension_lst[0].attribute
if isinstance(attr, pd._libs.tslibs.timestamps.Timestamp):
# If timestamp, use the _repr_ (e.g., TimeStamp('2020-04-05 00.000')--> '2020-04-05')
attr = str(attr._date_repr)
C = ldf.cardinality[attr]
D = (0.9) ** C # cardinality-based discounting factor
v_flat = pd.Series([1 / C] * len(v))
if is_datetime(v):
Expand Down
22 changes: 22 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,25 @@ def global_var():
pytest.olympic = pd.read_csv(url)
pytest.car_df = pd.read_csv("lux/data/car.csv")
pytest.college_df = pd.read_csv("lux/data/college.csv")
pytest.metadata = [
"_intent",
"_inferred_intent",
"_data_type",
"unique_values",
"cardinality",
"_rec_info",
"_min_max",
"plotting_style",
"_current_vis",
"_widget",
"_recommendation",
"_prev",
"_history",
"_saved_export",
"_sampled",
"_toggle_pandas_display",
"_message",
"_pandas_only",
"pre_aggregated",
"_type_override",
]
3 changes: 2 additions & 1 deletion tests/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_row_column_group(global_var):
tseries[tseries.columns.max()] = tseries[tseries.columns.max()].fillna(tseries.max(axis=1))
tseries = tseries.interpolate("zero", axis=1)
tseries._repr_html_()
assert list(tseries.recommendation.keys()) == ["Row Groups", "Column Groups"]
assert list(tseries.recommendation.keys()) == ["Temporal"]


def test_groupby(global_var):
Expand Down Expand Up @@ -171,6 +171,7 @@ def test_custom_aggregation(global_var):
df.set_intent(["HighestDegree", lux.Clause("AverageCost", aggregation=np.ptp)])
df._repr_html_()
assert list(df.recommendation.keys()) == ["Enhance", "Filter", "Generalize"]
df.clear_intent()


def test_year_filter_value(global_var):
Expand Down
5 changes: 5 additions & 0 deletions tests/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_underspecified_no_vis(global_var, test_recs):
df.set_intent([lux.Clause(attribute="Origin", filter_op="=", value="USA")])
test_recs(df, no_vis_actions)
assert len(df.current_vis) == 0
df.clear_intent()


def test_underspecified_single_vis(global_var, test_recs):
Expand Down Expand Up @@ -233,6 +234,7 @@ def test_autoencoding_scatter(global_var):
lux.Clause(attribute="Weight", channel="x"),
]
)
df.clear_intent()


def test_autoencoding_histogram(global_var):
Expand Down Expand Up @@ -286,6 +288,7 @@ def test_autoencoding_line_chart(global_var):
lux.Clause(attribute="Acceleration", channel="x"),
]
)
df.clear_intent()


def test_autoencoding_color_line_chart(global_var):
Expand Down Expand Up @@ -354,6 +357,7 @@ def test_populate_options(global_var):
list(col_set),
["Acceleration", "Weight", "Horsepower", "MilesPerGal", "Displacement"],
)
df.clear_intent()


def test_remove_all_invalid(global_var):
Expand All @@ -368,6 +372,7 @@ def test_remove_all_invalid(global_var):
)
df._repr_html_()
assert len(df.current_vis) == 0
df.clear_intent()


def list_equal(l1, l2):
Expand Down

0 comments on commit e2ece28

Please sign in to comment.