From b1b56d7a9f8b39db3f4bfe9958af07f3a3754eb6 Mon Sep 17 00:00:00 2001 From: Sophia Huang <6860749+sophiahhuang@users.noreply.github.com> Date: Wed, 17 Mar 2021 12:02:48 +0800 Subject: [PATCH 1/4] Test Commit of SQLTable --- lux/core/sqltable.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/lux/core/sqltable.py b/lux/core/sqltable.py index 3884f2cd..a6e31dab 100644 --- a/lux/core/sqltable.py +++ b/lux/core/sqltable.py @@ -109,7 +109,7 @@ def set_SQL_table(self, t_name): ) def _repr_html_(self): - from IPython.display import display + from IPython.display import HTML, display from IPython.display import clear_output import ipywidgets as widgets @@ -143,7 +143,19 @@ def _repr_html_(self): ) self.output = widgets.Output() lux.config.executor.execute_preview(self) - display(button, self.output) + notification = HTML(""" + Note: data displayed is just a preview of the database table. No Pandas functionality on this data is supported.""") + + display(button, notification, self.output) def on_button_clicked(b): with self.output: From 34b5c45885f82fa701f8e603d02dc3eeb81cdc28 Mon Sep 17 00:00:00 2001 From: Sophia Huang <6860749+sophiahhuang@users.noreply.github.com> Date: Sat, 20 Mar 2021 01:32:43 +0800 Subject: [PATCH 2/4] Updated LuxSQLTable notification - Text -> [Toggle Table/Lux] - Caption's appearance with preview table only - Caption has table name --- lux/_config/config.py | 21 ++--- lux/action/generalize.py | 7 +- lux/action/row_group.py | 8 +- lux/core/frame.py | 6 +- lux/core/series.py | 3 +- lux/core/sqltable.py | 26 ++---- lux/executor/PandasExecutor.py | 10 +-- lux/executor/SQLExecutor.py | 100 +++++++++-------------- lux/interestingness/interestingness.py | 6 +- lux/interestingness/similarity.py | 3 +- lux/processor/Compiler.py | 9 +-- lux/vislib/altair/AltairRenderer.py | 3 +- lux/vislib/altair/Choropleth.py | 4 +- lux/vislib/altair/Heatmap.py | 5 +- tests/test_action.py | 23 ++---- tests/test_compiler.py | 105 +++++-------------------- tests/test_dates.py | 5 +- tests/test_interestingness.py | 10 +-- tests/test_pandas_coverage.py | 8 +- tests/test_parser.py | 3 +- tests/test_vis.py | 5 +- 21 files changed, 98 insertions(+), 272 deletions(-) diff --git a/lux/_config/config.py b/lux/_config/config.py index 12db034a..65c778ef 100644 --- a/lux/_config/config.py +++ b/lux/_config/config.py @@ -55,8 +55,7 @@ def topk(self, k: Union[int, bool]): self._topk = k else: warnings.warn( - "Parameter to lux.config.topk must be an integer or a boolean.", - stacklevel=2, + "Parameter to lux.config.topk must be an integer or a boolean.", stacklevel=2, ) @property @@ -99,8 +98,7 @@ def pandas_fallback(self, fallback: bool) -> None: self._pandas_fallback = fallback else: warnings.warn( - "The flag for Pandas fallback must be a boolean.", - stacklevel=2, + "The flag for Pandas fallback must be a boolean.", stacklevel=2, ) @property @@ -120,8 +118,7 @@ def interestingness_fallback(self, fallback: bool) -> None: self._interestingness_fallback = fallback else: warnings.warn( - "The flag for interestingness fallback must be a boolean.", - stacklevel=2, + "The flag for interestingness fallback must be a boolean.", stacklevel=2, ) @property @@ -147,8 +144,7 @@ def sampling_cap(self, sample_number: int) -> None: self._sampling_cap = sample_number else: warnings.warn( - "The cap on the number samples must be an integer.", - stacklevel=2, + "The cap on the number samples must be an integer.", stacklevel=2, ) @property @@ -176,8 +172,7 @@ def sampling_start(self, sample_number: int) -> None: self._sampling_start = sample_number else: warnings.warn( - "The sampling starting point must be an integer.", - stacklevel=2, + "The sampling starting point must be an integer.", stacklevel=2, ) @property @@ -202,8 +197,7 @@ def sampling(self, sample_flag: bool) -> None: self._sampling_flag = sample_flag else: warnings.warn( - "The flag for sampling must be a boolean.", - stacklevel=2, + "The flag for sampling must be a boolean.", stacklevel=2, ) @property @@ -228,8 +222,7 @@ def heatmap(self, heatmap_flag: bool) -> None: self._heatmap_flag = heatmap_flag else: warnings.warn( - "The flag for enabling/disabling heatmaps must be a boolean.", - stacklevel=2, + "The flag for enabling/disabling heatmaps must be a boolean.", stacklevel=2, ) @property diff --git a/lux/action/generalize.py b/lux/action/generalize.py index 385891d4..84504132 100644 --- a/lux/action/generalize.py +++ b/lux/action/generalize.py @@ -78,12 +78,7 @@ def generalize(ldf): for clause in filters: # new_spec = ldf._intent.copy() # new_spec.remove_column_from_spec(new_spec.attribute) - temp_vis = Vis( - ldf.current_vis[0]._inferred_intent.copy(), - source=ldf, - title="Overall", - score=0, - ) + temp_vis = Vis(ldf.current_vis[0]._inferred_intent.copy(), source=ldf, title="Overall", score=0,) temp_vis.remove_filter_from_spec(clause.value) output.append(temp_vis) diff --git a/lux/action/row_group.py b/lux/action/row_group.py index 5555c01f..7407ab5b 100644 --- a/lux/action/row_group.py +++ b/lux/action/row_group.py @@ -50,13 +50,7 @@ def row_group(ldf): # rowdf.cardinality["index"]=len(rowdf) # if isinstance(ldf.columns,pd.DatetimeIndex): # rowdf.data_type_lookup[dim_name]="temporal" - vis = Vis( - [ - dim_name, - lux.Clause(row.name, data_model="measure", aggregation=None), - ], - rowdf, - ) + vis = Vis([dim_name, lux.Clause(row.name, data_model="measure", aggregation=None),], rowdf,) collection.append(vis) vlst = VisList(collection) # Note that we are not computing interestingness score here because we want to preserve the arrangement of the aggregated data diff --git a/lux/core/frame.py b/lux/core/frame.py index 5f584d4c..94e8f406 100644 --- a/lux/core/frame.py +++ b/lux/core/frame.py @@ -491,8 +491,7 @@ def exported(self) -> Union[Dict[str, VisList], VisList]: exported_vis = VisList( list( map( - self._recommendation[export_action].__getitem__, - exported_vis_lst[export_action], + self._recommendation[export_action].__getitem__, exported_vis_lst[export_action], ) ) ) @@ -562,8 +561,7 @@ def _repr_html_(self): self._widget.observe(self.set_intent_on_click, names="selectedIntentIndex") button = widgets.Button( - description="Toggle Pandas/Lux", - layout=widgets.Layout(width="140px", top="5px"), + description="Toggle Pandas/Lux", layout=widgets.Layout(width="140px", top="5px"), ) self.output = widgets.Output() display(button, self.output) diff --git a/lux/core/series.py b/lux/core/series.py index 83b2c42b..f5fd31ca 100644 --- a/lux/core/series.py +++ b/lux/core/series.py @@ -154,8 +154,7 @@ def __repr__(self): # box = widgets.Box(layout=widgets.Layout(display='inline')) button = widgets.Button( - description="Toggle Pandas/Lux", - layout=widgets.Layout(width="140px", top="5px"), + description="Toggle Pandas/Lux", layout=widgets.Layout(width="140px", top="5px"), ) ldf.output = widgets.Output() # box.children = [button,output] diff --git a/lux/core/sqltable.py b/lux/core/sqltable.py index a6e31dab..e2469675 100644 --- a/lux/core/sqltable.py +++ b/lux/core/sqltable.py @@ -104,8 +104,7 @@ def set_SQL_table(self, t_name): error_str = str(error) if f'relation "{t_name}" does not exist' in error_str: warnings.warn( - f"\nThe table '{t_name}' does not exist in your database./", - stacklevel=2, + f"\nThe table '{t_name}' does not exist in your database./", stacklevel=2, ) def _repr_html_(self): @@ -138,24 +137,12 @@ def _repr_html_(self): self._widget.observe(self.set_intent_on_click, names="selectedIntentIndex") button = widgets.Button( - description="Toggle Data Preview/Lux", - layout=widgets.Layout(width="200px", top="5px"), + description="Toggle Table/Lux", + layout=widgets.Layout(width="200px", top="6px", bottom="6px"), ) self.output = widgets.Output() lux.config.executor.execute_preview(self) - notification = HTML(""" - Note: data displayed is just a preview of the database table. No Pandas functionality on this data is supported.""") - - display(button, notification, self.output) + display(button, self.output) def on_button_clicked(b): with self.output: @@ -163,7 +150,10 @@ def on_button_clicked(b): self._toggle_pandas_display = not self._toggle_pandas_display clear_output() if self._toggle_pandas_display: - display(self._sampled.display_pandas()) + notification = widgets.Label( + value="Preview of the database table: " + self.table_name + ) + display(notification, self._sampled.display_pandas()) else: # b.layout.display = "none" display(self._widget) diff --git a/lux/executor/PandasExecutor.py b/lux/executor/PandasExecutor.py index e5534821..8983876a 100644 --- a/lux/executor/PandasExecutor.py +++ b/lux/executor/PandasExecutor.py @@ -216,10 +216,7 @@ def execute_aggregate(vis: Vis, isFiltered=True): } ) vis._vis_data = vis.data.merge( - df, - on=[columns[0], columns[1]], - how="right", - suffixes=["", "_right"], + df, on=[columns[0], columns[1]], how="right", suffixes=["", "_right"], ) for col in columns[2:]: vis.data[col] = vis.data[col].fillna(0) # Triggers __setitem__ @@ -367,10 +364,7 @@ def execute_2D_binning(vis: Vis): if color_attr.data_type == "nominal": # Compute mode and count. Mode aggregates each cell by taking the majority vote for the category variable. In cases where there is ties across categories, pick the first item (.iat[0]) result = groups.agg( - [ - ("count", "count"), - (color_attr.attribute, lambda x: pd.Series.mode(x).iat[0]), - ] + [("count", "count"), (color_attr.attribute, lambda x: pd.Series.mode(x).iat[0]),] ).reset_index() elif color_attr.data_type == "quantitative" or color_attr.data_type == "temporal": # Compute the average of all values in the bin diff --git a/lux/executor/SQLExecutor.py b/lux/executor/SQLExecutor.py index fd063c0a..96112bd2 100644 --- a/lux/executor/SQLExecutor.py +++ b/lux/executor/SQLExecutor.py @@ -38,8 +38,7 @@ def execute_sampling(tbl: LuxSQLTable): SAMPLE_FRAC = 0.2 length_query = pandas.read_sql( - "SELECT COUNT(*) as length FROM {}".format(tbl.table_name), - lux.config.SQLconnection, + "SELECT COUNT(*) as length FROM {}".format(tbl.table_name), lux.config.SQLconnection, ) limit = int(list(length_query["length"])[0]) * SAMPLE_FRAC tbl._sampled = pandas.read_sql( @@ -123,8 +122,7 @@ def add_quotes(var_name): required_variables = ",".join(required_variables) row_count = list( pandas.read_sql( - f"SELECT COUNT(*) FROM {tbl.table_name} {where_clause}", - lux.config.SQLconnection, + f"SELECT COUNT(*) FROM {tbl.table_name} {where_clause}", lux.config.SQLconnection, )["count"] )[0] if row_count > lux.config.sampling_cap: @@ -229,48 +227,42 @@ def execute_aggregate(view: Vis, tbl: LuxSQLTable, isFiltered=True): # generates query for colored barchart case if has_color: if agg_func == "mean": - agg_query = ( - 'SELECT "{}", "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( - groupby_attr.attribute, - color_attr.attribute, - measure_attr.attribute, - measure_attr.attribute, - tbl.table_name, - where_clause, - groupby_attr.attribute, - color_attr.attribute, - ) + agg_query = 'SELECT "{}", "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( + groupby_attr.attribute, + color_attr.attribute, + measure_attr.attribute, + measure_attr.attribute, + tbl.table_name, + where_clause, + groupby_attr.attribute, + color_attr.attribute, ) view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection) view._vis_data = utils.pandas_to_lux(view._vis_data) if agg_func == "sum": - agg_query = ( - 'SELECT "{}", "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( - groupby_attr.attribute, - color_attr.attribute, - measure_attr.attribute, - measure_attr.attribute, - tbl.table_name, - where_clause, - groupby_attr.attribute, - color_attr.attribute, - ) + agg_query = 'SELECT "{}", "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( + groupby_attr.attribute, + color_attr.attribute, + measure_attr.attribute, + measure_attr.attribute, + tbl.table_name, + where_clause, + groupby_attr.attribute, + color_attr.attribute, ) view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection) view._vis_data = utils.pandas_to_lux(view._vis_data) if agg_func == "max": - agg_query = ( - 'SELECT "{}", "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( - groupby_attr.attribute, - color_attr.attribute, - measure_attr.attribute, - measure_attr.attribute, - tbl.table_name, - where_clause, - groupby_attr.attribute, - color_attr.attribute, - ) + agg_query = 'SELECT "{}", "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( + groupby_attr.attribute, + color_attr.attribute, + measure_attr.attribute, + measure_attr.attribute, + tbl.table_name, + where_clause, + groupby_attr.attribute, + color_attr.attribute, ) view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection) view._vis_data = utils.pandas_to_lux(view._vis_data) @@ -330,10 +322,7 @@ def execute_aggregate(view: Vis, tbl: LuxSQLTable, isFiltered=True): } ) view._vis_data = view._vis_data.merge( - df, - on=[columns[0], columns[1]], - how="right", - suffixes=["", "_right"], + df, on=[columns[0], columns[1]], how="right", suffixes=["", "_right"], ) for col in columns[2:]: view._vis_data[col] = view._vis_data[col].fillna(0) # Triggers __setitem__ @@ -402,10 +391,7 @@ def execute_binning(view: Vis, tbl: LuxSQLTable): upper_edges = ",".join(upper_edges) view_filter, filter_vars = SQLExecutor.execute_filter(view) bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket(CAST (\"{}\" AS FLOAT), '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format( - bin_attribute.attribute, - "{" + upper_edges + "}", - tbl.table_name, - where_clause, + bin_attribute.attribute, "{" + upper_edges + "}", tbl.table_name, where_clause, ) bin_count_data = pandas.read_sql(bin_count_query, lux.config.SQLconnection) @@ -420,13 +406,11 @@ def execute_binning(view: Vis, tbl: LuxSQLTable): else: bin_centers = np.array([(attr_min + attr_min + bin_width) / 2]) bin_centers = np.append( - bin_centers, - np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0), + bin_centers, np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0), ) if attr_type == int: bin_centers = np.append( - bin_centers, - math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2), + bin_centers, math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2), ) else: bin_centers = np.append(bin_centers, (upper_edges[len(upper_edges) - 1] + attr_max) / 2) @@ -567,10 +551,7 @@ def execute_filter(view: Vis): else: where_clause.append("AND") where_clause.extend( - [ - '"' + str(a.attribute) + '"', - "IS NOT NULL", - ] + ['"' + str(a.attribute) + '"', "IS NOT NULL",] ) if where_clause == []: @@ -649,8 +630,7 @@ def compute_stats(self, tbl: LuxSQLTable): tbl.unique_values = {} tbl._min_max = {} length_query = pandas.read_sql( - "SELECT COUNT(*) as length FROM {}".format(tbl.table_name), - lux.config.SQLconnection, + "SELECT COUNT(*) as length FROM {}".format(tbl.table_name), lux.config.SQLconnection, ) tbl.length = list(length_query["length"])[0] @@ -687,10 +667,7 @@ def get_cardinality(self, tbl: LuxSQLTable): card_query = 'SELECT Count(Distinct("{}")) FROM {} WHERE "{}" IS NOT NULL'.format( attr, tbl.table_name, attr ) - card_data = pandas.read_sql( - card_query, - lux.config.SQLconnection, - ) + card_data = pandas.read_sql(card_query, lux.config.SQLconnection,) cardinality[attr] = list(card_data["count"])[0] tbl.cardinality = cardinality @@ -713,10 +690,7 @@ def get_unique_values(self, tbl: LuxSQLTable): unique_query = 'SELECT Distinct("{}") FROM {} WHERE "{}" IS NOT NULL'.format( attr, tbl.table_name, attr ) - unique_data = pandas.read_sql( - unique_query, - lux.config.SQLconnection, - ) + unique_data = pandas.read_sql(unique_query, lux.config.SQLconnection,) unique_vals[attr] = list(unique_data[attr]) tbl.unique_values = unique_vals diff --git a/lux/interestingness/interestingness.py b/lux/interestingness/interestingness.py index 32b0f1db..a95efea9 100644 --- a/lux/interestingness/interestingness.py +++ b/lux/interestingness/interestingness.py @@ -198,11 +198,7 @@ def weighted_correlation(x, y, w): def deviation_from_overall( - vis: Vis, - ldf: LuxDataFrame, - filter_specs: list, - msr_attribute: str, - exclude_nan: bool = True, + vis: Vis, ldf: LuxDataFrame, filter_specs: list, msr_attribute: str, exclude_nan: bool = True, ) -> int: """ Difference in bar chart/histogram shape from overall chart diff --git a/lux/interestingness/similarity.py b/lux/interestingness/similarity.py index 8d810909..017b97b5 100644 --- a/lux/interestingness/similarity.py +++ b/lux/interestingness/similarity.py @@ -68,8 +68,7 @@ def interpolate(vis, length): yVals[count - 1] + (interpolated_x - xVals[count - 1]) / x_diff * yDiff ) vis.data = pd.DataFrame( - list(zip(interpolated_x_vals, interpolated_y_vals)), - columns=[xAxis, yAxis], + list(zip(interpolated_x_vals, interpolated_y_vals)), columns=[xAxis, yAxis], ) diff --git a/lux/processor/Compiler.py b/lux/processor/Compiler.py index 96d95d10..85c49cfb 100644 --- a/lux/processor/Compiler.py +++ b/lux/processor/Compiler.py @@ -281,10 +281,7 @@ def line_or_bar_or_geo(ldf, dimension: Clause, measure: Clause): # ShowMe logic + additional heuristics # count_col = Clause( attribute="count()", data_model="measure") count_col = Clause( - attribute="Record", - aggregation="count", - data_model="measure", - data_type="quantitative", + attribute="Record", aggregation="count", data_model="measure", data_type="quantitative", ) auto_channel = {} if ndim == 0 and nmsr == 1: @@ -475,9 +472,7 @@ def populate_wildcard_options(_inferred_intent: List[Clause], ldf: LuxDataFrame) options = ldf.unique_values[attr] specInd = _inferred_intent.index(clause) _inferred_intent[specInd] = Clause( - attribute=clause.attribute, - filter_op="=", - value=list(options), + attribute=clause.attribute, filter_op="=", value=list(options), ) else: options.extend(convert_to_list(clause.value)) diff --git a/lux/vislib/altair/AltairRenderer.py b/lux/vislib/altair/AltairRenderer.py index 2c1ab206..deba1592 100644 --- a/lux/vislib/altair/AltairRenderer.py +++ b/lux/vislib/altair/AltairRenderer.py @@ -128,8 +128,7 @@ def create_vis(self, vis, standalone=True): found_variable = "df" if standalone: chart.code = chart.code.replace( - "placeholder_variable", - f"pd.DataFrame({str(vis.data.to_dict())})", + "placeholder_variable", f"pd.DataFrame({str(vis.data.to_dict())})", ) else: # TODO: Placeholder (need to read dynamically via locals()) diff --git a/lux/vislib/altair/Choropleth.py b/lux/vislib/altair/Choropleth.py index bf71b010..c2675883 100644 --- a/lux/vislib/altair/Choropleth.py +++ b/lux/vislib/altair/Choropleth.py @@ -68,9 +68,7 @@ def initialize_chart(self): points = ( alt.Chart(geo_map) .mark_geoshape() - .encode( - color=f"{y_attr_abv}:Q", - ) + .encode(color=f"{y_attr_abv}:Q",) .transform_lookup(lookup="id", from_=alt.LookupData(self.data, x_attr_abv, [y_attr_abv])) .project(type=map_type) .properties( diff --git a/lux/vislib/altair/Heatmap.py b/lux/vislib/altair/Heatmap.py index f83a3bbb..2f743b32 100644 --- a/lux/vislib/altair/Heatmap.py +++ b/lux/vislib/altair/Heatmap.py @@ -71,10 +71,7 @@ def initialize_chart(self): ), y2=alt.Y2("yBinEnd"), opacity=alt.Opacity( - "count", - type="quantitative", - scale=alt.Scale(type="log"), - legend=None, + "count", type="quantitative", scale=alt.Scale(type="log"), legend=None, ), ) ) diff --git a/tests/test_action.py b/tests/test_action.py index 893c13b7..2cb49e86 100644 --- a/tests/test_action.py +++ b/tests/test_action.py @@ -184,8 +184,7 @@ def test_year_filter_value(global_var): lambda vis: len( list( filter( - lambda clause: clause.value != "" and clause.attribute == "Year", - vis._intent, + lambda clause: clause.value != "" and clause.attribute == "Year", vis._intent, ) ) ) @@ -215,16 +214,10 @@ def test_similarity(global_var): ranked_list = df.recommendation["Similarity"] japan_vis = list( - filter( - lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Japan", - ranked_list, - ) + filter(lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Japan", ranked_list,) )[0] europe_vis = list( - filter( - lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Europe", - ranked_list, - ) + filter(lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Europe", ranked_list,) )[0] assert japan_vis.score > europe_vis.score df.clear_intent() @@ -247,16 +240,10 @@ def test_similarity2(): ranked_list = df.recommendation["Similarity"] morrisville_vis = list( - filter( - lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Morrisville", - ranked_list, - ) + filter(lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Morrisville", ranked_list,) )[0] watertown_vis = list( - filter( - lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Watertown", - ranked_list, - ) + filter(lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Watertown", ranked_list,) )[0] assert morrisville_vis.score > watertown_vis.score diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 28e80f8a..d30b824f 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -168,10 +168,7 @@ def test_underspecified_vis_collection_zval(global_var): # check if the number of charts is correct df = pytest.car_df vlst = VisList( - [ - lux.Clause(attribute="Origin", filter_op="=", value="?"), - lux.Clause(attribute="MilesPerGal"), - ], + [lux.Clause(attribute="Origin", filter_op="=", value="?"), lux.Clause(attribute="MilesPerGal"),], df, ) assert len(vlst) == 3 @@ -185,10 +182,7 @@ def test_underspecified_vis_collection_zval(global_var): lux.config.set_SQL_connection(connection) sql_df = lux.LuxSQLTable(table_name="cars") vlst = VisList( - [ - lux.Clause(attribute="origin", filter_op="=", value="?"), - lux.Clause(attribute="milespergal"), - ], + [lux.Clause(attribute="origin", filter_op="=", value="?"), lux.Clause(attribute="milespergal"),], sql_df, ) assert len(vlst) == 3 @@ -285,10 +279,7 @@ def test_specified_channel_enforced_vis_collection(global_var): df = pytest.car_df # change pandas dtype for the column "Year" to datetype df["Year"] = pd.to_datetime(df["Year"], format="%Y") - visList = VisList( - [lux.Clause(attribute="?"), lux.Clause(attribute="MilesPerGal", channel="x")], - df, - ) + visList = VisList([lux.Clause(attribute="?"), lux.Clause(attribute="MilesPerGal", channel="x")], df,) for vis in visList: check_attribute_on_channel(vis, "MilesPerGal", "x") @@ -304,22 +295,13 @@ def test_autoencoding_scatter(global_var): check_attribute_on_channel(vis, "Weight", "y") # Partial channel specified - vis = Vis( - [ - lux.Clause(attribute="MilesPerGal", channel="y"), - lux.Clause(attribute="Weight"), - ], - df, - ) + vis = Vis([lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight"),], df,) check_attribute_on_channel(vis, "MilesPerGal", "y") check_attribute_on_channel(vis, "Weight", "x") # Full channel specified vis = Vis( - [ - lux.Clause(attribute="MilesPerGal", channel="y"), - lux.Clause(attribute="Weight", channel="x"), - ], + [lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight", channel="x"),], df, ) check_attribute_on_channel(vis, "MilesPerGal", "y") @@ -339,8 +321,7 @@ def test_autoencoding_scatter(global_var): lux.config.set_SQL_connection(connection) sql_df = lux.LuxSQLTable(table_name="cars") visList = VisList( - [lux.Clause(attribute="?"), lux.Clause(attribute="milespergal", channel="x")], - sql_df, + [lux.Clause(attribute="?"), lux.Clause(attribute="milespergal", channel="x")], sql_df, ) for vis in visList: check_attribute_on_channel(vis, "milespergal", "x") @@ -358,22 +339,13 @@ def test_autoencoding_scatter(): check_attribute_on_channel(vis, "Weight", "y") # Partial channel specified - vis = Vis( - [ - lux.Clause(attribute="MilesPerGal", channel="y"), - lux.Clause(attribute="Weight"), - ], - df, - ) + vis = Vis([lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight"),], df,) check_attribute_on_channel(vis, "MilesPerGal", "y") check_attribute_on_channel(vis, "Weight", "x") # Full channel specified vis = Vis( - [ - lux.Clause(attribute="MilesPerGal", channel="y"), - lux.Clause(attribute="Weight", channel="x"), - ], + [lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight", channel="x"),], df, ) check_attribute_on_channel(vis, "MilesPerGal", "y") @@ -398,21 +370,14 @@ def test_autoencoding_scatter(): # Partial channel specified vis = Vis( - [ - lux.Clause(attribute="milespergal", channel="y"), - lux.Clause(attribute="weight"), - ], - sql_df, + [lux.Clause(attribute="milespergal", channel="y"), lux.Clause(attribute="weight"),], sql_df, ) check_attribute_on_channel(vis, "milespergal", "y") check_attribute_on_channel(vis, "weight", "x") # Full channel specified vis = Vis( - [ - lux.Clause(attribute="milespergal", channel="y"), - lux.Clause(attribute="weight", channel="x"), - ], + [lux.Clause(attribute="milespergal", channel="y"), lux.Clause(attribute="weight", channel="x"),], sql_df, ) check_attribute_on_channel(vis, "milespergal", "y") @@ -464,22 +429,13 @@ def test_autoencoding_line_chart(global_var): check_attribute_on_channel(vis, "Acceleration", "y") # Partial channel specified - vis = Vis( - [ - lux.Clause(attribute="Year", channel="y"), - lux.Clause(attribute="Acceleration"), - ], - df, - ) + vis = Vis([lux.Clause(attribute="Year", channel="y"), lux.Clause(attribute="Acceleration"),], df,) check_attribute_on_channel(vis, "Year", "y") check_attribute_on_channel(vis, "Acceleration", "x") # Full channel specified vis = Vis( - [ - lux.Clause(attribute="Year", channel="y"), - lux.Clause(attribute="Acceleration", channel="x"), - ], + [lux.Clause(attribute="Year", channel="y"), lux.Clause(attribute="Acceleration", channel="x"),], df, ) check_attribute_on_channel(vis, "Year", "y") @@ -505,21 +461,14 @@ def test_autoencoding_line_chart(global_var): # Partial channel specified vis = Vis( - [ - lux.Clause(attribute="year", channel="y"), - lux.Clause(attribute="acceleration"), - ], - sql_df, + [lux.Clause(attribute="year", channel="y"), lux.Clause(attribute="acceleration"),], sql_df, ) check_attribute_on_channel(vis, "year", "y") check_attribute_on_channel(vis, "acceleration", "x") # Full channel specified vis = Vis( - [ - lux.Clause(attribute="year", channel="y"), - lux.Clause(attribute="acceleration", channel="x"), - ], + [lux.Clause(attribute="year", channel="y"), lux.Clause(attribute="acceleration", channel="x"),], sql_df, ) check_attribute_on_channel(vis, "year", "y") @@ -628,10 +577,7 @@ def test_populate_options(global_var): assert list_equal(list(col_set), list(df.columns)) df.set_intent( - [ - lux.Clause(attribute="?", data_model="measure"), - lux.Clause(attribute="MilesPerGal"), - ] + [lux.Clause(attribute="?", data_model="measure"), lux.Clause(attribute="MilesPerGal"),] ) df._repr_html_() col_set = set() @@ -639,8 +585,7 @@ def test_populate_options(global_var): for clause in specOptions: col_set.add(clause.attribute) assert list_equal( - list(col_set), - ["Acceleration", "Weight", "Horsepower", "MilesPerGal", "Displacement"], + list(col_set), ["Acceleration", "Weight", "Horsepower", "MilesPerGal", "Displacement"], ) df.clear_intent() @@ -656,10 +601,7 @@ def test_populate_options(global_var): assert list_equal(list(col_set), list(sql_df.columns)) sql_df.set_intent( - [ - lux.Clause(attribute="?", data_model="measure"), - lux.Clause(attribute="milespergal"), - ] + [lux.Clause(attribute="?", data_model="measure"), lux.Clause(attribute="milespergal"),] ) sql_df._repr_html_() col_set = set() @@ -667,8 +609,7 @@ def test_populate_options(global_var): for clause in specOptions: col_set.add(clause.attribute) assert list_equal( - list(col_set), - ["acceleration", "weight", "horsepower", "milespergal", "displacement"], + list(col_set), ["acceleration", "weight", "horsepower", "milespergal", "displacement"], ) @@ -678,10 +619,7 @@ def test_remove_all_invalid(global_var): df["Year"] = pd.to_datetime(df["Year"], format="%Y") # with pytest.warns(UserWarning,match="duplicate attribute specified in the intent"): df.set_intent( - [ - lux.Clause(attribute="Origin", filter_op="=", value="USA"), - lux.Clause(attribute="Origin"), - ] + [lux.Clause(attribute="Origin", filter_op="=", value="USA"), lux.Clause(attribute="Origin"),] ) df._repr_html_() assert len(df.current_vis) == 0 @@ -693,10 +631,7 @@ def test_remove_all_invalid(global_var): sql_df = lux.LuxSQLTable(table_name="cars") # with pytest.warns(UserWarning,match="duplicate attribute specified in the intent"): sql_df.set_intent( - [ - lux.Clause(attribute="origin", filter_op="=", value="USA"), - lux.Clause(attribute="origin"), - ] + [lux.Clause(attribute="origin", filter_op="=", value="USA"), lux.Clause(attribute="origin"),] ) sql_df._repr_html_() assert len(sql_df.current_vis) == 0 diff --git a/tests/test_dates.py b/tests/test_dates.py index dc530fc7..48d514b9 100644 --- a/tests/test_dates.py +++ b/tests/test_dates.py @@ -44,10 +44,7 @@ def test_period_selection(global_var): ldf["Year"] = pd.DatetimeIndex(ldf["Year"]).to_period(freq="A") ldf.set_intent( - [ - lux.Clause(attribute=["Horsepower", "Weight", "Acceleration"]), - lux.Clause(attribute="Year"), - ] + [lux.Clause(attribute=["Horsepower", "Weight", "Acceleration"]), lux.Clause(attribute="Year"),] ) lux.config.executor.execute(ldf.current_vis, ldf) diff --git a/tests/test_interestingness.py b/tests/test_interestingness.py index 7e6036f9..1ef929e3 100644 --- a/tests/test_interestingness.py +++ b/tests/test_interestingness.py @@ -64,10 +64,7 @@ def test_interestingness_1_0_1(global_var): df["Year"] = pd.to_datetime(df["Year"], format="%Y") df.set_intent( - [ - lux.Clause(attribute="Origin", filter_op="=", value="USA"), - lux.Clause(attribute="Cylinders"), - ] + [lux.Clause(attribute="Origin", filter_op="=", value="USA"), lux.Clause(attribute="Cylinders"),] ) df._repr_html_() assert df.current_vis[0].score == 0 @@ -124,10 +121,7 @@ def test_interestingness_0_1_1(global_var): df["Year"] = pd.to_datetime(df["Year"], format="%Y") df.set_intent( - [ - lux.Clause(attribute="Origin", filter_op="=", value="?"), - lux.Clause(attribute="MilesPerGal"), - ] + [lux.Clause(attribute="Origin", filter_op="=", value="?"), lux.Clause(attribute="MilesPerGal"),] ) df._repr_html_() assert interestingness(df.recommendation["Current Vis"][0], df) != None diff --git a/tests/test_pandas_coverage.py b/tests/test_pandas_coverage.py index 21014f60..e30b0aac 100644 --- a/tests/test_pandas_coverage.py +++ b/tests/test_pandas_coverage.py @@ -173,8 +173,7 @@ def test_groupby_agg_big(global_var): assert len(new_df.cardinality) == 8 year_vis = list( filter( - lambda vis: vis.get_attr_by_attr_name("Year") != [], - new_df.recommendation["Column Groups"], + lambda vis: vis.get_attr_by_attr_name("Year") != [], new_df.recommendation["Column Groups"], ) )[0] assert year_vis.mark == "bar" @@ -182,10 +181,7 @@ def test_groupby_agg_big(global_var): new_df = new_df.T new_df._repr_html_() year_vis = list( - filter( - lambda vis: vis.get_attr_by_attr_name("Year") != [], - new_df.recommendation["Row Groups"], - ) + filter(lambda vis: vis.get_attr_by_attr_name("Year") != [], new_df.recommendation["Row Groups"],) )[0] assert year_vis.mark == "bar" assert year_vis.get_attr_by_channel("x")[0].attribute == "Year" diff --git a/tests/test_parser.py b/tests/test_parser.py index 333977aa..d274a1f4 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -121,7 +121,6 @@ def test_validator_invalid_attribute(global_var): df = pytest.college_df with pytest.raises(KeyError, match="'blah'"): with pytest.warns( - UserWarning, - match="The input attribute 'blah' does not exist in the DataFrame.", + UserWarning, match="The input attribute 'blah' does not exist in the DataFrame.", ): df.intent = ["blah"] diff --git a/tests/test_vis.py b/tests/test_vis.py index 55e10b39..0f210d82 100644 --- a/tests/test_vis.py +++ b/tests/test_vis.py @@ -154,10 +154,7 @@ def test_vis_list_custom_title_override(global_var): vcLst = [] for attribute in ["Sport", "Year", "Height", "HostRegion", "SportType"]: - vis = Vis( - [lux.Clause("Weight"), lux.Clause(attribute)], - title="overriding dummy title", - ) + vis = Vis([lux.Clause("Weight"), lux.Clause(attribute)], title="overriding dummy title",) vcLst.append(vis) vlist = VisList(vcLst, df) for v in vlist: From d68b16d712a8203157a5a00ef6b0f8fcd7c50eaf Mon Sep 17 00:00:00 2001 From: Sophia Huang <6860749+sophiahhuang@users.noreply.github.com> Date: Sat, 20 Mar 2021 02:10:55 +0800 Subject: [PATCH 3/4] Remove Out of Date LuxSQLTableNotice --- lux/core/sqltable.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/lux/core/sqltable.py b/lux/core/sqltable.py index 6e557a0a..842e0fb9 100644 --- a/lux/core/sqltable.py +++ b/lux/core/sqltable.py @@ -122,17 +122,7 @@ def _repr_html_(self): ) self.output = widgets.Output() lux.config.executor.execute_preview(self) - notice = HTML( - """ - Please note, the data shown here is just a preview of the database table. You will be unable to perform Pandas functionality on this data.""" - ) - - display(button, notice, self.output) + display(button, self.output) def on_button_clicked(b): with self.output: From f190a55d4e26bec8766ab7f178b0298bb8fca329 Mon Sep 17 00:00:00 2001 From: 19thyneb Date: Fri, 19 Mar 2021 11:29:55 -0700 Subject: [PATCH 4/4] Black Reformatting --- lux/_config/config.py | 21 +++-- lux/action/generalize.py | 7 +- lux/action/row_group.py | 8 +- lux/core/frame.py | 6 +- lux/core/series.py | 3 +- lux/core/sqltable.py | 3 +- lux/executor/PandasExecutor.py | 10 ++- lux/executor/SQLExecutor.py | 97 ++++++++++++++--------- lux/interestingness/interestingness.py | 6 +- lux/interestingness/similarity.py | 3 +- lux/processor/Compiler.py | 9 ++- lux/vislib/altair/AltairRenderer.py | 3 +- lux/vislib/altair/Choropleth.py | 4 +- lux/vislib/altair/Heatmap.py | 5 +- tests/test_action.py | 23 ++++-- tests/test_compiler.py | 105 ++++++++++++++++++++----- tests/test_dates.py | 5 +- tests/test_interestingness.py | 10 ++- tests/test_pandas_coverage.py | 8 +- tests/test_parser.py | 3 +- tests/test_vis.py | 5 +- 21 files changed, 254 insertions(+), 90 deletions(-) diff --git a/lux/_config/config.py b/lux/_config/config.py index 65c778ef..12db034a 100644 --- a/lux/_config/config.py +++ b/lux/_config/config.py @@ -55,7 +55,8 @@ def topk(self, k: Union[int, bool]): self._topk = k else: warnings.warn( - "Parameter to lux.config.topk must be an integer or a boolean.", stacklevel=2, + "Parameter to lux.config.topk must be an integer or a boolean.", + stacklevel=2, ) @property @@ -98,7 +99,8 @@ def pandas_fallback(self, fallback: bool) -> None: self._pandas_fallback = fallback else: warnings.warn( - "The flag for Pandas fallback must be a boolean.", stacklevel=2, + "The flag for Pandas fallback must be a boolean.", + stacklevel=2, ) @property @@ -118,7 +120,8 @@ def interestingness_fallback(self, fallback: bool) -> None: self._interestingness_fallback = fallback else: warnings.warn( - "The flag for interestingness fallback must be a boolean.", stacklevel=2, + "The flag for interestingness fallback must be a boolean.", + stacklevel=2, ) @property @@ -144,7 +147,8 @@ def sampling_cap(self, sample_number: int) -> None: self._sampling_cap = sample_number else: warnings.warn( - "The cap on the number samples must be an integer.", stacklevel=2, + "The cap on the number samples must be an integer.", + stacklevel=2, ) @property @@ -172,7 +176,8 @@ def sampling_start(self, sample_number: int) -> None: self._sampling_start = sample_number else: warnings.warn( - "The sampling starting point must be an integer.", stacklevel=2, + "The sampling starting point must be an integer.", + stacklevel=2, ) @property @@ -197,7 +202,8 @@ def sampling(self, sample_flag: bool) -> None: self._sampling_flag = sample_flag else: warnings.warn( - "The flag for sampling must be a boolean.", stacklevel=2, + "The flag for sampling must be a boolean.", + stacklevel=2, ) @property @@ -222,7 +228,8 @@ def heatmap(self, heatmap_flag: bool) -> None: self._heatmap_flag = heatmap_flag else: warnings.warn( - "The flag for enabling/disabling heatmaps must be a boolean.", stacklevel=2, + "The flag for enabling/disabling heatmaps must be a boolean.", + stacklevel=2, ) @property diff --git a/lux/action/generalize.py b/lux/action/generalize.py index 84504132..385891d4 100644 --- a/lux/action/generalize.py +++ b/lux/action/generalize.py @@ -78,7 +78,12 @@ def generalize(ldf): for clause in filters: # new_spec = ldf._intent.copy() # new_spec.remove_column_from_spec(new_spec.attribute) - temp_vis = Vis(ldf.current_vis[0]._inferred_intent.copy(), source=ldf, title="Overall", score=0,) + temp_vis = Vis( + ldf.current_vis[0]._inferred_intent.copy(), + source=ldf, + title="Overall", + score=0, + ) temp_vis.remove_filter_from_spec(clause.value) output.append(temp_vis) diff --git a/lux/action/row_group.py b/lux/action/row_group.py index 7407ab5b..5555c01f 100644 --- a/lux/action/row_group.py +++ b/lux/action/row_group.py @@ -50,7 +50,13 @@ def row_group(ldf): # rowdf.cardinality["index"]=len(rowdf) # if isinstance(ldf.columns,pd.DatetimeIndex): # rowdf.data_type_lookup[dim_name]="temporal" - vis = Vis([dim_name, lux.Clause(row.name, data_model="measure", aggregation=None),], rowdf,) + vis = Vis( + [ + dim_name, + lux.Clause(row.name, data_model="measure", aggregation=None), + ], + rowdf, + ) collection.append(vis) vlst = VisList(collection) # Note that we are not computing interestingness score here because we want to preserve the arrangement of the aggregated data diff --git a/lux/core/frame.py b/lux/core/frame.py index f5fe5ca5..4514ed4d 100644 --- a/lux/core/frame.py +++ b/lux/core/frame.py @@ -496,7 +496,8 @@ def exported(self) -> Union[Dict[str, VisList], VisList]: exported_vis = VisList( list( map( - self._recommendation[export_action].__getitem__, exported_vis_lst[export_action], + self._recommendation[export_action].__getitem__, + exported_vis_lst[export_action], ) ) ) @@ -566,7 +567,8 @@ def _repr_html_(self): self._widget.observe(self.set_intent_on_click, names="selectedIntentIndex") button = widgets.Button( - description="Toggle Pandas/Lux", layout=widgets.Layout(width="140px", top="5px"), + description="Toggle Pandas/Lux", + layout=widgets.Layout(width="140px", top="5px"), ) self.output = widgets.Output() display(button, self.output) diff --git a/lux/core/series.py b/lux/core/series.py index 7cea6639..3c717642 100644 --- a/lux/core/series.py +++ b/lux/core/series.py @@ -154,7 +154,8 @@ def __repr__(self): # box = widgets.Box(layout=widgets.Layout(display='inline')) button = widgets.Button( - description="Toggle Pandas/Lux", layout=widgets.Layout(width="140px", top="5px"), + description="Toggle Pandas/Lux", + layout=widgets.Layout(width="140px", top="5px"), ) ldf.output = widgets.Output() # box.children = [button,output] diff --git a/lux/core/sqltable.py b/lux/core/sqltable.py index 842e0fb9..e51b8c09 100644 --- a/lux/core/sqltable.py +++ b/lux/core/sqltable.py @@ -84,7 +84,8 @@ def set_SQL_table(self, t_name): error_str = str(error) if f'relation "{t_name}" does not exist' in error_str: warnings.warn( - f"\nThe table '{t_name}' does not exist in your database./", stacklevel=2, + f"\nThe table '{t_name}' does not exist in your database./", + stacklevel=2, ) def _repr_html_(self): diff --git a/lux/executor/PandasExecutor.py b/lux/executor/PandasExecutor.py index 8983876a..e5534821 100644 --- a/lux/executor/PandasExecutor.py +++ b/lux/executor/PandasExecutor.py @@ -216,7 +216,10 @@ def execute_aggregate(vis: Vis, isFiltered=True): } ) vis._vis_data = vis.data.merge( - df, on=[columns[0], columns[1]], how="right", suffixes=["", "_right"], + df, + on=[columns[0], columns[1]], + how="right", + suffixes=["", "_right"], ) for col in columns[2:]: vis.data[col] = vis.data[col].fillna(0) # Triggers __setitem__ @@ -364,7 +367,10 @@ def execute_2D_binning(vis: Vis): if color_attr.data_type == "nominal": # Compute mode and count. Mode aggregates each cell by taking the majority vote for the category variable. In cases where there is ties across categories, pick the first item (.iat[0]) result = groups.agg( - [("count", "count"), (color_attr.attribute, lambda x: pd.Series.mode(x).iat[0]),] + [ + ("count", "count"), + (color_attr.attribute, lambda x: pd.Series.mode(x).iat[0]), + ] ).reset_index() elif color_attr.data_type == "quantitative" or color_attr.data_type == "temporal": # Compute the average of all values in the bin diff --git a/lux/executor/SQLExecutor.py b/lux/executor/SQLExecutor.py index 93f010c0..abf43f06 100644 --- a/lux/executor/SQLExecutor.py +++ b/lux/executor/SQLExecutor.py @@ -38,7 +38,8 @@ def execute_sampling(tbl: LuxSQLTable): SAMPLE_FRAC = 0.2 length_query = pandas.read_sql( - "SELECT COUNT(*) as length FROM {}".format(tbl.table_name), lux.config.SQLconnection, + "SELECT COUNT(*) as length FROM {}".format(tbl.table_name), + lux.config.SQLconnection, ) limit = int(list(length_query["length"])[0]) * SAMPLE_FRAC tbl._sampled = pandas.read_sql( @@ -122,7 +123,8 @@ def add_quotes(var_name): required_variables = ",".join(required_variables) row_count = list( pandas.read_sql( - f"SELECT COUNT(*) FROM {tbl.table_name} {where_clause}", lux.config.SQLconnection, + f"SELECT COUNT(*) FROM {tbl.table_name} {where_clause}", + lux.config.SQLconnection, )["count"] )[0] if row_count > lux.config.sampling_cap: @@ -227,42 +229,48 @@ def execute_aggregate(view: Vis, tbl: LuxSQLTable, isFiltered=True): # generates query for colored barchart case if has_color: if agg_func == "mean": - agg_query = 'SELECT "{}", "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( - groupby_attr.attribute, - color_attr.attribute, - measure_attr.attribute, - measure_attr.attribute, - tbl.table_name, - where_clause, - groupby_attr.attribute, - color_attr.attribute, + agg_query = ( + 'SELECT "{}", "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( + groupby_attr.attribute, + color_attr.attribute, + measure_attr.attribute, + measure_attr.attribute, + tbl.table_name, + where_clause, + groupby_attr.attribute, + color_attr.attribute, + ) ) view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection) view._vis_data = utils.pandas_to_lux(view._vis_data) if agg_func == "sum": - agg_query = 'SELECT "{}", "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( - groupby_attr.attribute, - color_attr.attribute, - measure_attr.attribute, - measure_attr.attribute, - tbl.table_name, - where_clause, - groupby_attr.attribute, - color_attr.attribute, + agg_query = ( + 'SELECT "{}", "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( + groupby_attr.attribute, + color_attr.attribute, + measure_attr.attribute, + measure_attr.attribute, + tbl.table_name, + where_clause, + groupby_attr.attribute, + color_attr.attribute, + ) ) view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection) view._vis_data = utils.pandas_to_lux(view._vis_data) if agg_func == "max": - agg_query = 'SELECT "{}", "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( - groupby_attr.attribute, - color_attr.attribute, - measure_attr.attribute, - measure_attr.attribute, - tbl.table_name, - where_clause, - groupby_attr.attribute, - color_attr.attribute, + agg_query = ( + 'SELECT "{}", "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( + groupby_attr.attribute, + color_attr.attribute, + measure_attr.attribute, + measure_attr.attribute, + tbl.table_name, + where_clause, + groupby_attr.attribute, + color_attr.attribute, + ) ) view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection) view._vis_data = utils.pandas_to_lux(view._vis_data) @@ -322,7 +330,10 @@ def execute_aggregate(view: Vis, tbl: LuxSQLTable, isFiltered=True): } ) view._vis_data = view._vis_data.merge( - df, on=[columns[0], columns[1]], how="right", suffixes=["", "_right"], + df, + on=[columns[0], columns[1]], + how="right", + suffixes=["", "_right"], ) for col in columns[2:]: view._vis_data[col] = view._vis_data[col].fillna(0) # Triggers __setitem__ @@ -391,7 +402,10 @@ def execute_binning(view: Vis, tbl: LuxSQLTable): upper_edges = ",".join(upper_edges) view_filter, filter_vars = SQLExecutor.execute_filter(view) bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket(CAST (\"{}\" AS FLOAT), '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format( - bin_attribute.attribute, "{" + upper_edges + "}", tbl.table_name, where_clause, + bin_attribute.attribute, + "{" + upper_edges + "}", + tbl.table_name, + where_clause, ) bin_count_data = pandas.read_sql(bin_count_query, lux.config.SQLconnection) @@ -406,11 +420,13 @@ def execute_binning(view: Vis, tbl: LuxSQLTable): else: bin_centers = np.array([(attr_min + attr_min + bin_width) / 2]) bin_centers = np.append( - bin_centers, np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0), + bin_centers, + np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0), ) if attr_type == int: bin_centers = np.append( - bin_centers, math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2), + bin_centers, + math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2), ) else: bin_centers = np.append(bin_centers, (upper_edges[len(upper_edges) - 1] + attr_max) / 2) @@ -551,7 +567,10 @@ def execute_filter(view: Vis): else: where_clause.append("AND") where_clause.extend( - ['"' + str(a.attribute) + '"', "IS NOT NULL",] + [ + '"' + str(a.attribute) + '"', + "IS NOT NULL", + ] ) if where_clause == []: @@ -668,7 +687,10 @@ def get_cardinality(self, tbl: LuxSQLTable): card_query = 'SELECT Count(Distinct("{}")) FROM {} WHERE "{}" IS NOT NULL'.format( attr, tbl.table_name, attr ) - card_data = pandas.read_sql(card_query, lux.config.SQLconnection,) + card_data = pandas.read_sql( + card_query, + lux.config.SQLconnection, + ) cardinality[attr] = list(card_data["count"])[0] tbl.cardinality = cardinality @@ -691,7 +713,10 @@ def get_unique_values(self, tbl: LuxSQLTable): unique_query = 'SELECT Distinct("{}") FROM {} WHERE "{}" IS NOT NULL'.format( attr, tbl.table_name, attr ) - unique_data = pandas.read_sql(unique_query, lux.config.SQLconnection,) + unique_data = pandas.read_sql( + unique_query, + lux.config.SQLconnection, + ) unique_vals[attr] = list(unique_data[attr]) tbl.unique_values = unique_vals diff --git a/lux/interestingness/interestingness.py b/lux/interestingness/interestingness.py index a95efea9..32b0f1db 100644 --- a/lux/interestingness/interestingness.py +++ b/lux/interestingness/interestingness.py @@ -198,7 +198,11 @@ def weighted_correlation(x, y, w): def deviation_from_overall( - vis: Vis, ldf: LuxDataFrame, filter_specs: list, msr_attribute: str, exclude_nan: bool = True, + vis: Vis, + ldf: LuxDataFrame, + filter_specs: list, + msr_attribute: str, + exclude_nan: bool = True, ) -> int: """ Difference in bar chart/histogram shape from overall chart diff --git a/lux/interestingness/similarity.py b/lux/interestingness/similarity.py index 017b97b5..8d810909 100644 --- a/lux/interestingness/similarity.py +++ b/lux/interestingness/similarity.py @@ -68,7 +68,8 @@ def interpolate(vis, length): yVals[count - 1] + (interpolated_x - xVals[count - 1]) / x_diff * yDiff ) vis.data = pd.DataFrame( - list(zip(interpolated_x_vals, interpolated_y_vals)), columns=[xAxis, yAxis], + list(zip(interpolated_x_vals, interpolated_y_vals)), + columns=[xAxis, yAxis], ) diff --git a/lux/processor/Compiler.py b/lux/processor/Compiler.py index 85c49cfb..96d95d10 100644 --- a/lux/processor/Compiler.py +++ b/lux/processor/Compiler.py @@ -281,7 +281,10 @@ def line_or_bar_or_geo(ldf, dimension: Clause, measure: Clause): # ShowMe logic + additional heuristics # count_col = Clause( attribute="count()", data_model="measure") count_col = Clause( - attribute="Record", aggregation="count", data_model="measure", data_type="quantitative", + attribute="Record", + aggregation="count", + data_model="measure", + data_type="quantitative", ) auto_channel = {} if ndim == 0 and nmsr == 1: @@ -472,7 +475,9 @@ def populate_wildcard_options(_inferred_intent: List[Clause], ldf: LuxDataFrame) options = ldf.unique_values[attr] specInd = _inferred_intent.index(clause) _inferred_intent[specInd] = Clause( - attribute=clause.attribute, filter_op="=", value=list(options), + attribute=clause.attribute, + filter_op="=", + value=list(options), ) else: options.extend(convert_to_list(clause.value)) diff --git a/lux/vislib/altair/AltairRenderer.py b/lux/vislib/altair/AltairRenderer.py index deba1592..2c1ab206 100644 --- a/lux/vislib/altair/AltairRenderer.py +++ b/lux/vislib/altair/AltairRenderer.py @@ -128,7 +128,8 @@ def create_vis(self, vis, standalone=True): found_variable = "df" if standalone: chart.code = chart.code.replace( - "placeholder_variable", f"pd.DataFrame({str(vis.data.to_dict())})", + "placeholder_variable", + f"pd.DataFrame({str(vis.data.to_dict())})", ) else: # TODO: Placeholder (need to read dynamically via locals()) diff --git a/lux/vislib/altair/Choropleth.py b/lux/vislib/altair/Choropleth.py index c2675883..bf71b010 100644 --- a/lux/vislib/altair/Choropleth.py +++ b/lux/vislib/altair/Choropleth.py @@ -68,7 +68,9 @@ def initialize_chart(self): points = ( alt.Chart(geo_map) .mark_geoshape() - .encode(color=f"{y_attr_abv}:Q",) + .encode( + color=f"{y_attr_abv}:Q", + ) .transform_lookup(lookup="id", from_=alt.LookupData(self.data, x_attr_abv, [y_attr_abv])) .project(type=map_type) .properties( diff --git a/lux/vislib/altair/Heatmap.py b/lux/vislib/altair/Heatmap.py index 2f743b32..f83a3bbb 100644 --- a/lux/vislib/altair/Heatmap.py +++ b/lux/vislib/altair/Heatmap.py @@ -71,7 +71,10 @@ def initialize_chart(self): ), y2=alt.Y2("yBinEnd"), opacity=alt.Opacity( - "count", type="quantitative", scale=alt.Scale(type="log"), legend=None, + "count", + type="quantitative", + scale=alt.Scale(type="log"), + legend=None, ), ) ) diff --git a/tests/test_action.py b/tests/test_action.py index 2cb49e86..893c13b7 100644 --- a/tests/test_action.py +++ b/tests/test_action.py @@ -184,7 +184,8 @@ def test_year_filter_value(global_var): lambda vis: len( list( filter( - lambda clause: clause.value != "" and clause.attribute == "Year", vis._intent, + lambda clause: clause.value != "" and clause.attribute == "Year", + vis._intent, ) ) ) @@ -214,10 +215,16 @@ def test_similarity(global_var): ranked_list = df.recommendation["Similarity"] japan_vis = list( - filter(lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Japan", ranked_list,) + filter( + lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Japan", + ranked_list, + ) )[0] europe_vis = list( - filter(lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Europe", ranked_list,) + filter( + lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Europe", + ranked_list, + ) )[0] assert japan_vis.score > europe_vis.score df.clear_intent() @@ -240,10 +247,16 @@ def test_similarity2(): ranked_list = df.recommendation["Similarity"] morrisville_vis = list( - filter(lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Morrisville", ranked_list,) + filter( + lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Morrisville", + ranked_list, + ) )[0] watertown_vis = list( - filter(lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Watertown", ranked_list,) + filter( + lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Watertown", + ranked_list, + ) )[0] assert morrisville_vis.score > watertown_vis.score diff --git a/tests/test_compiler.py b/tests/test_compiler.py index d30b824f..28e80f8a 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -168,7 +168,10 @@ def test_underspecified_vis_collection_zval(global_var): # check if the number of charts is correct df = pytest.car_df vlst = VisList( - [lux.Clause(attribute="Origin", filter_op="=", value="?"), lux.Clause(attribute="MilesPerGal"),], + [ + lux.Clause(attribute="Origin", filter_op="=", value="?"), + lux.Clause(attribute="MilesPerGal"), + ], df, ) assert len(vlst) == 3 @@ -182,7 +185,10 @@ def test_underspecified_vis_collection_zval(global_var): lux.config.set_SQL_connection(connection) sql_df = lux.LuxSQLTable(table_name="cars") vlst = VisList( - [lux.Clause(attribute="origin", filter_op="=", value="?"), lux.Clause(attribute="milespergal"),], + [ + lux.Clause(attribute="origin", filter_op="=", value="?"), + lux.Clause(attribute="milespergal"), + ], sql_df, ) assert len(vlst) == 3 @@ -279,7 +285,10 @@ def test_specified_channel_enforced_vis_collection(global_var): df = pytest.car_df # change pandas dtype for the column "Year" to datetype df["Year"] = pd.to_datetime(df["Year"], format="%Y") - visList = VisList([lux.Clause(attribute="?"), lux.Clause(attribute="MilesPerGal", channel="x")], df,) + visList = VisList( + [lux.Clause(attribute="?"), lux.Clause(attribute="MilesPerGal", channel="x")], + df, + ) for vis in visList: check_attribute_on_channel(vis, "MilesPerGal", "x") @@ -295,13 +304,22 @@ def test_autoencoding_scatter(global_var): check_attribute_on_channel(vis, "Weight", "y") # Partial channel specified - vis = Vis([lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight"),], df,) + vis = Vis( + [ + lux.Clause(attribute="MilesPerGal", channel="y"), + lux.Clause(attribute="Weight"), + ], + df, + ) check_attribute_on_channel(vis, "MilesPerGal", "y") check_attribute_on_channel(vis, "Weight", "x") # Full channel specified vis = Vis( - [lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight", channel="x"),], + [ + lux.Clause(attribute="MilesPerGal", channel="y"), + lux.Clause(attribute="Weight", channel="x"), + ], df, ) check_attribute_on_channel(vis, "MilesPerGal", "y") @@ -321,7 +339,8 @@ def test_autoencoding_scatter(global_var): lux.config.set_SQL_connection(connection) sql_df = lux.LuxSQLTable(table_name="cars") visList = VisList( - [lux.Clause(attribute="?"), lux.Clause(attribute="milespergal", channel="x")], sql_df, + [lux.Clause(attribute="?"), lux.Clause(attribute="milespergal", channel="x")], + sql_df, ) for vis in visList: check_attribute_on_channel(vis, "milespergal", "x") @@ -339,13 +358,22 @@ def test_autoencoding_scatter(): check_attribute_on_channel(vis, "Weight", "y") # Partial channel specified - vis = Vis([lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight"),], df,) + vis = Vis( + [ + lux.Clause(attribute="MilesPerGal", channel="y"), + lux.Clause(attribute="Weight"), + ], + df, + ) check_attribute_on_channel(vis, "MilesPerGal", "y") check_attribute_on_channel(vis, "Weight", "x") # Full channel specified vis = Vis( - [lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight", channel="x"),], + [ + lux.Clause(attribute="MilesPerGal", channel="y"), + lux.Clause(attribute="Weight", channel="x"), + ], df, ) check_attribute_on_channel(vis, "MilesPerGal", "y") @@ -370,14 +398,21 @@ def test_autoencoding_scatter(): # Partial channel specified vis = Vis( - [lux.Clause(attribute="milespergal", channel="y"), lux.Clause(attribute="weight"),], sql_df, + [ + lux.Clause(attribute="milespergal", channel="y"), + lux.Clause(attribute="weight"), + ], + sql_df, ) check_attribute_on_channel(vis, "milespergal", "y") check_attribute_on_channel(vis, "weight", "x") # Full channel specified vis = Vis( - [lux.Clause(attribute="milespergal", channel="y"), lux.Clause(attribute="weight", channel="x"),], + [ + lux.Clause(attribute="milespergal", channel="y"), + lux.Clause(attribute="weight", channel="x"), + ], sql_df, ) check_attribute_on_channel(vis, "milespergal", "y") @@ -429,13 +464,22 @@ def test_autoencoding_line_chart(global_var): check_attribute_on_channel(vis, "Acceleration", "y") # Partial channel specified - vis = Vis([lux.Clause(attribute="Year", channel="y"), lux.Clause(attribute="Acceleration"),], df,) + vis = Vis( + [ + lux.Clause(attribute="Year", channel="y"), + lux.Clause(attribute="Acceleration"), + ], + df, + ) check_attribute_on_channel(vis, "Year", "y") check_attribute_on_channel(vis, "Acceleration", "x") # Full channel specified vis = Vis( - [lux.Clause(attribute="Year", channel="y"), lux.Clause(attribute="Acceleration", channel="x"),], + [ + lux.Clause(attribute="Year", channel="y"), + lux.Clause(attribute="Acceleration", channel="x"), + ], df, ) check_attribute_on_channel(vis, "Year", "y") @@ -461,14 +505,21 @@ def test_autoencoding_line_chart(global_var): # Partial channel specified vis = Vis( - [lux.Clause(attribute="year", channel="y"), lux.Clause(attribute="acceleration"),], sql_df, + [ + lux.Clause(attribute="year", channel="y"), + lux.Clause(attribute="acceleration"), + ], + sql_df, ) check_attribute_on_channel(vis, "year", "y") check_attribute_on_channel(vis, "acceleration", "x") # Full channel specified vis = Vis( - [lux.Clause(attribute="year", channel="y"), lux.Clause(attribute="acceleration", channel="x"),], + [ + lux.Clause(attribute="year", channel="y"), + lux.Clause(attribute="acceleration", channel="x"), + ], sql_df, ) check_attribute_on_channel(vis, "year", "y") @@ -577,7 +628,10 @@ def test_populate_options(global_var): assert list_equal(list(col_set), list(df.columns)) df.set_intent( - [lux.Clause(attribute="?", data_model="measure"), lux.Clause(attribute="MilesPerGal"),] + [ + lux.Clause(attribute="?", data_model="measure"), + lux.Clause(attribute="MilesPerGal"), + ] ) df._repr_html_() col_set = set() @@ -585,7 +639,8 @@ def test_populate_options(global_var): for clause in specOptions: col_set.add(clause.attribute) assert list_equal( - list(col_set), ["Acceleration", "Weight", "Horsepower", "MilesPerGal", "Displacement"], + list(col_set), + ["Acceleration", "Weight", "Horsepower", "MilesPerGal", "Displacement"], ) df.clear_intent() @@ -601,7 +656,10 @@ def test_populate_options(global_var): assert list_equal(list(col_set), list(sql_df.columns)) sql_df.set_intent( - [lux.Clause(attribute="?", data_model="measure"), lux.Clause(attribute="milespergal"),] + [ + lux.Clause(attribute="?", data_model="measure"), + lux.Clause(attribute="milespergal"), + ] ) sql_df._repr_html_() col_set = set() @@ -609,7 +667,8 @@ def test_populate_options(global_var): for clause in specOptions: col_set.add(clause.attribute) assert list_equal( - list(col_set), ["acceleration", "weight", "horsepower", "milespergal", "displacement"], + list(col_set), + ["acceleration", "weight", "horsepower", "milespergal", "displacement"], ) @@ -619,7 +678,10 @@ def test_remove_all_invalid(global_var): df["Year"] = pd.to_datetime(df["Year"], format="%Y") # with pytest.warns(UserWarning,match="duplicate attribute specified in the intent"): df.set_intent( - [lux.Clause(attribute="Origin", filter_op="=", value="USA"), lux.Clause(attribute="Origin"),] + [ + lux.Clause(attribute="Origin", filter_op="=", value="USA"), + lux.Clause(attribute="Origin"), + ] ) df._repr_html_() assert len(df.current_vis) == 0 @@ -631,7 +693,10 @@ def test_remove_all_invalid(global_var): sql_df = lux.LuxSQLTable(table_name="cars") # with pytest.warns(UserWarning,match="duplicate attribute specified in the intent"): sql_df.set_intent( - [lux.Clause(attribute="origin", filter_op="=", value="USA"), lux.Clause(attribute="origin"),] + [ + lux.Clause(attribute="origin", filter_op="=", value="USA"), + lux.Clause(attribute="origin"), + ] ) sql_df._repr_html_() assert len(sql_df.current_vis) == 0 diff --git a/tests/test_dates.py b/tests/test_dates.py index 48d514b9..dc530fc7 100644 --- a/tests/test_dates.py +++ b/tests/test_dates.py @@ -44,7 +44,10 @@ def test_period_selection(global_var): ldf["Year"] = pd.DatetimeIndex(ldf["Year"]).to_period(freq="A") ldf.set_intent( - [lux.Clause(attribute=["Horsepower", "Weight", "Acceleration"]), lux.Clause(attribute="Year"),] + [ + lux.Clause(attribute=["Horsepower", "Weight", "Acceleration"]), + lux.Clause(attribute="Year"), + ] ) lux.config.executor.execute(ldf.current_vis, ldf) diff --git a/tests/test_interestingness.py b/tests/test_interestingness.py index 1ef929e3..7e6036f9 100644 --- a/tests/test_interestingness.py +++ b/tests/test_interestingness.py @@ -64,7 +64,10 @@ def test_interestingness_1_0_1(global_var): df["Year"] = pd.to_datetime(df["Year"], format="%Y") df.set_intent( - [lux.Clause(attribute="Origin", filter_op="=", value="USA"), lux.Clause(attribute="Cylinders"),] + [ + lux.Clause(attribute="Origin", filter_op="=", value="USA"), + lux.Clause(attribute="Cylinders"), + ] ) df._repr_html_() assert df.current_vis[0].score == 0 @@ -121,7 +124,10 @@ def test_interestingness_0_1_1(global_var): df["Year"] = pd.to_datetime(df["Year"], format="%Y") df.set_intent( - [lux.Clause(attribute="Origin", filter_op="=", value="?"), lux.Clause(attribute="MilesPerGal"),] + [ + lux.Clause(attribute="Origin", filter_op="=", value="?"), + lux.Clause(attribute="MilesPerGal"), + ] ) df._repr_html_() assert interestingness(df.recommendation["Current Vis"][0], df) != None diff --git a/tests/test_pandas_coverage.py b/tests/test_pandas_coverage.py index e30b0aac..21014f60 100644 --- a/tests/test_pandas_coverage.py +++ b/tests/test_pandas_coverage.py @@ -173,7 +173,8 @@ def test_groupby_agg_big(global_var): assert len(new_df.cardinality) == 8 year_vis = list( filter( - lambda vis: vis.get_attr_by_attr_name("Year") != [], new_df.recommendation["Column Groups"], + lambda vis: vis.get_attr_by_attr_name("Year") != [], + new_df.recommendation["Column Groups"], ) )[0] assert year_vis.mark == "bar" @@ -181,7 +182,10 @@ def test_groupby_agg_big(global_var): new_df = new_df.T new_df._repr_html_() year_vis = list( - filter(lambda vis: vis.get_attr_by_attr_name("Year") != [], new_df.recommendation["Row Groups"],) + filter( + lambda vis: vis.get_attr_by_attr_name("Year") != [], + new_df.recommendation["Row Groups"], + ) )[0] assert year_vis.mark == "bar" assert year_vis.get_attr_by_channel("x")[0].attribute == "Year" diff --git a/tests/test_parser.py b/tests/test_parser.py index d274a1f4..333977aa 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -121,6 +121,7 @@ def test_validator_invalid_attribute(global_var): df = pytest.college_df with pytest.raises(KeyError, match="'blah'"): with pytest.warns( - UserWarning, match="The input attribute 'blah' does not exist in the DataFrame.", + UserWarning, + match="The input attribute 'blah' does not exist in the DataFrame.", ): df.intent = ["blah"] diff --git a/tests/test_vis.py b/tests/test_vis.py index 2a088ddb..4514be42 100644 --- a/tests/test_vis.py +++ b/tests/test_vis.py @@ -153,7 +153,10 @@ def test_vis_list_custom_title_override(global_var): vcLst = [] for attribute in ["Sport", "Year", "Height", "HostRegion", "SportType"]: - vis = Vis([lux.Clause("Weight"), lux.Clause(attribute)], title="overriding dummy title",) + vis = Vis( + [lux.Clause("Weight"), lux.Clause(attribute)], + title="overriding dummy title", + ) vcLst.append(vis) vlist = VisList(vcLst, df) for v in vlist: