From f190a55d4e26bec8766ab7f178b0298bb8fca329 Mon Sep 17 00:00:00 2001
From: 19thyneb
Date: Fri, 19 Mar 2021 11:29:55 -0700
Subject: [PATCH] Black Reformatting

---
 lux/_config/config.py                  |  21 +++--
 lux/action/generalize.py               |   7 +-
 lux/action/row_group.py                |   8 +-
 lux/core/frame.py                      |   6 +-
 lux/core/series.py                     |   3 +-
 lux/core/sqltable.py                   |   3 +-
 lux/executor/PandasExecutor.py         |  10 ++-
 lux/executor/SQLExecutor.py            |  97 ++++++++++++++---------
 lux/interestingness/interestingness.py |   6 +-
 lux/interestingness/similarity.py      |   3 +-
 lux/processor/Compiler.py              |   9 ++-
 lux/vislib/altair/AltairRenderer.py    |   3 +-
 lux/vislib/altair/Choropleth.py        |   4 +-
 lux/vislib/altair/Heatmap.py           |   5 +-
 tests/test_action.py                   |  23 ++++--
 tests/test_compiler.py                 | 105 ++++++++++++++++++++-----
 tests/test_dates.py                    |   5 +-
 tests/test_interestingness.py          |  10 ++-
 tests/test_pandas_coverage.py          |   8 +-
 tests/test_parser.py                   |   3 +-
 tests/test_vis.py                      |   5 +-
 21 files changed, 254 insertions(+), 90 deletions(-)

diff --git a/lux/_config/config.py b/lux/_config/config.py
index 65c778ef..12db034a 100644
--- a/lux/_config/config.py
+++ b/lux/_config/config.py
@@ -55,7 +55,8 @@ def topk(self, k: Union[int, bool]):
             self._topk = k
         else:
             warnings.warn(
-                "Parameter to lux.config.topk must be an integer or a boolean.", stacklevel=2,
+                "Parameter to lux.config.topk must be an integer or a boolean.",
+                stacklevel=2,
             )
 
     @property
@@ -98,7 +99,8 @@ def pandas_fallback(self, fallback: bool) -> None:
             self._pandas_fallback = fallback
         else:
             warnings.warn(
-                "The flag for Pandas fallback must be a boolean.", stacklevel=2,
+                "The flag for Pandas fallback must be a boolean.",
+                stacklevel=2,
             )
 
     @property
@@ -118,7 +120,8 @@ def interestingness_fallback(self, fallback: bool) -> None:
             self._interestingness_fallback = fallback
         else:
             warnings.warn(
-                "The flag for interestingness fallback must be a boolean.", stacklevel=2,
+                "The flag for interestingness fallback must be a boolean.",
+                stacklevel=2,
             )
 
     @property
@@ -144,7 +147,8 @@ def sampling_cap(self, sample_number: int) -> None:
             self._sampling_cap = sample_number
         else:
             warnings.warn(
-                "The cap on the number samples must be an integer.", stacklevel=2,
+                "The cap on the number samples must be an integer.",
+                stacklevel=2,
             )
 
     @property
@@ -172,7 +176,8 @@ def sampling_start(self, sample_number: int) -> None:
             self._sampling_start = sample_number
         else:
             warnings.warn(
-                "The sampling starting point must be an integer.", stacklevel=2,
+                "The sampling starting point must be an integer.",
+                stacklevel=2,
             )
 
     @property
@@ -197,7 +202,8 @@ def sampling(self, sample_flag: bool) -> None:
             self._sampling_flag = sample_flag
         else:
             warnings.warn(
-                "The flag for sampling must be a boolean.", stacklevel=2,
+                "The flag for sampling must be a boolean.",
+                stacklevel=2,
             )
 
     @property
@@ -222,7 +228,8 @@ def heatmap(self, heatmap_flag: bool) -> None:
             self._heatmap_flag = heatmap_flag
         else:
             warnings.warn(
-                "The flag for enabling/disabling heatmaps must be a boolean.", stacklevel=2,
+                "The flag for enabling/disabling heatmaps must be a boolean.",
+                stacklevel=2,
             )
 
     @property
diff --git a/lux/action/generalize.py b/lux/action/generalize.py
index 84504132..385891d4 100644
--- a/lux/action/generalize.py
+++ b/lux/action/generalize.py
@@ -78,7 +78,12 @@ def generalize(ldf):
     for clause in filters:
         # new_spec = ldf._intent.copy()
         # new_spec.remove_column_from_spec(new_spec.attribute)
-        temp_vis = Vis(ldf.current_vis[0]._inferred_intent.copy(), source=ldf, title="Overall", score=0,)
+        temp_vis = Vis(
+            ldf.current_vis[0]._inferred_intent.copy(),
+            source=ldf,
+            title="Overall",
+            score=0,
+        )
         temp_vis.remove_filter_from_spec(clause.value)
         output.append(temp_vis)
 
diff --git a/lux/action/row_group.py b/lux/action/row_group.py
index 7407ab5b..5555c01f 100644
--- a/lux/action/row_group.py
+++ b/lux/action/row_group.py
@@ -50,7 +50,13 @@ def row_group(ldf):
             # rowdf.cardinality["index"]=len(rowdf)
             # if isinstance(ldf.columns,pd.DatetimeIndex):
             # rowdf.data_type_lookup[dim_name]="temporal"
-            vis = Vis([dim_name, lux.Clause(row.name, data_model="measure", aggregation=None),], rowdf,)
+            vis = Vis(
+                [
+                    dim_name,
+                    lux.Clause(row.name, data_model="measure", aggregation=None),
+                ],
+                rowdf,
+            )
             collection.append(vis)
     vlst = VisList(collection)
     # Note that we are not computing interestingness score here because we want to preserve the arrangement of the aggregated data
diff --git a/lux/core/frame.py b/lux/core/frame.py
index f5fe5ca5..4514ed4d 100644
--- a/lux/core/frame.py
+++ b/lux/core/frame.py
@@ -496,7 +496,8 @@ def exported(self) -> Union[Dict[str, VisList], VisList]:
             exported_vis = VisList(
                 list(
                     map(
-                        self._recommendation[export_action].__getitem__, exported_vis_lst[export_action],
+                        self._recommendation[export_action].__getitem__,
+                        exported_vis_lst[export_action],
                     )
                 )
             )
@@ -566,7 +567,8 @@ def _repr_html_(self):
             self._widget.observe(self.set_intent_on_click, names="selectedIntentIndex")
 
             button = widgets.Button(
-                description="Toggle Pandas/Lux", layout=widgets.Layout(width="140px", top="5px"),
+                description="Toggle Pandas/Lux",
+                layout=widgets.Layout(width="140px", top="5px"),
             )
             self.output = widgets.Output()
             display(button, self.output)
diff --git a/lux/core/series.py b/lux/core/series.py
index 7cea6639..3c717642 100644
--- a/lux/core/series.py
+++ b/lux/core/series.py
@@ -154,7 +154,8 @@ def __repr__(self):
             # box = widgets.Box(layout=widgets.Layout(display='inline'))
 
             button = widgets.Button(
-                description="Toggle Pandas/Lux", layout=widgets.Layout(width="140px", top="5px"),
+                description="Toggle Pandas/Lux",
+                layout=widgets.Layout(width="140px", top="5px"),
             )
             ldf.output = widgets.Output()
             # box.children = [button,output]
diff --git a/lux/core/sqltable.py b/lux/core/sqltable.py
index 842e0fb9..e51b8c09 100644
--- a/lux/core/sqltable.py
+++ b/lux/core/sqltable.py
@@ -84,7 +84,8 @@ def set_SQL_table(self, t_name):
             error_str = str(error)
             if f'relation "{t_name}" does not exist' in error_str:
                 warnings.warn(
-                    f"\nThe table '{t_name}' does not exist in your database./", stacklevel=2,
+                    f"\nThe table '{t_name}' does not exist in your database./",
+                    stacklevel=2,
                 )
 
     def _repr_html_(self):
diff --git a/lux/executor/PandasExecutor.py b/lux/executor/PandasExecutor.py
index 8983876a..e5534821 100644
--- a/lux/executor/PandasExecutor.py
+++ b/lux/executor/PandasExecutor.py
@@ -216,7 +216,10 @@ def execute_aggregate(vis: Vis, isFiltered=True):
                         }
                     )
                     vis._vis_data = vis.data.merge(
-                        df, on=[columns[0], columns[1]], how="right", suffixes=["", "_right"],
+                        df,
+                        on=[columns[0], columns[1]],
+                        how="right",
+                        suffixes=["", "_right"],
                     )
                     for col in columns[2:]:
                         vis.data[col] = vis.data[col].fillna(0)  # Triggers __setitem__
@@ -364,7 +367,10 @@ def execute_2D_binning(vis: Vis):
             if color_attr.data_type == "nominal":
                 # Compute mode and count. Mode aggregates each cell by taking the majority vote for the category variable. In cases where there is ties across categories, pick the first item (.iat[0])
                 result = groups.agg(
-                    [("count", "count"), (color_attr.attribute, lambda x: pd.Series.mode(x).iat[0]),]
+                    [
+                        ("count", "count"),
+                        (color_attr.attribute, lambda x: pd.Series.mode(x).iat[0]),
+                    ]
                 ).reset_index()
             elif color_attr.data_type == "quantitative" or color_attr.data_type == "temporal":
                 # Compute the average of all values in the bin
diff --git a/lux/executor/SQLExecutor.py b/lux/executor/SQLExecutor.py
index 93f010c0..abf43f06 100644
--- a/lux/executor/SQLExecutor.py
+++ b/lux/executor/SQLExecutor.py
@@ -38,7 +38,8 @@ def execute_sampling(tbl: LuxSQLTable):
         SAMPLE_FRAC = 0.2
 
         length_query = pandas.read_sql(
-            "SELECT COUNT(*) as length FROM {}".format(tbl.table_name), lux.config.SQLconnection,
+            "SELECT COUNT(*) as length FROM {}".format(tbl.table_name),
+            lux.config.SQLconnection,
         )
         limit = int(list(length_query["length"])[0]) * SAMPLE_FRAC
         tbl._sampled = pandas.read_sql(
@@ -122,7 +123,8 @@ def add_quotes(var_name):
         required_variables = ",".join(required_variables)
         row_count = list(
             pandas.read_sql(
-                f"SELECT COUNT(*) FROM {tbl.table_name} {where_clause}", lux.config.SQLconnection,
+                f"SELECT COUNT(*) FROM {tbl.table_name} {where_clause}",
+                lux.config.SQLconnection,
             )["count"]
         )[0]
         if row_count > lux.config.sampling_cap:
@@ -227,42 +229,48 @@ def execute_aggregate(view: Vis, tbl: LuxSQLTable, isFiltered=True):
         # generates query for colored barchart case
         if has_color:
             if agg_func == "mean":
-                agg_query = 'SELECT "{}", "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
-                    groupby_attr.attribute,
-                    color_attr.attribute,
-                    measure_attr.attribute,
-                    measure_attr.attribute,
-                    tbl.table_name,
-                    where_clause,
-                    groupby_attr.attribute,
-                    color_attr.attribute,
+                agg_query = (
+                    'SELECT "{}", "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
+                        groupby_attr.attribute,
+                        color_attr.attribute,
+                        measure_attr.attribute,
+                        measure_attr.attribute,
+                        tbl.table_name,
+                        where_clause,
+                        groupby_attr.attribute,
+                        color_attr.attribute,
+                    )
                 )
                 view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)
                 view._vis_data = utils.pandas_to_lux(view._vis_data)
             if agg_func == "sum":
-                agg_query = 'SELECT "{}", "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
-                    groupby_attr.attribute,
-                    color_attr.attribute,
-                    measure_attr.attribute,
-                    measure_attr.attribute,
-                    tbl.table_name,
-                    where_clause,
-                    groupby_attr.attribute,
-                    color_attr.attribute,
+                agg_query = (
+                    'SELECT "{}", "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
+                        groupby_attr.attribute,
+                        color_attr.attribute,
+                        measure_attr.attribute,
+                        measure_attr.attribute,
+                        tbl.table_name,
+                        where_clause,
+                        groupby_attr.attribute,
+                        color_attr.attribute,
+                    )
                 )
                 view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)
                 view._vis_data = utils.pandas_to_lux(view._vis_data)
             if agg_func == "max":
-                agg_query = 'SELECT "{}", "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
-                    groupby_attr.attribute,
-                    color_attr.attribute,
-                    measure_attr.attribute,
-                    measure_attr.attribute,
-                    tbl.table_name,
-                    where_clause,
-                    groupby_attr.attribute,
-                    color_attr.attribute,
+                agg_query = (
+                    'SELECT "{}", "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
+                        groupby_attr.attribute,
+                        color_attr.attribute,
+                        measure_attr.attribute,
+                        measure_attr.attribute,
+                        tbl.table_name,
+                        where_clause,
+                        groupby_attr.attribute,
+                        color_attr.attribute,
+                    )
                 )
                 view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)
                 view._vis_data = utils.pandas_to_lux(view._vis_data)
@@ -322,7 +330,10 @@ def execute_aggregate(view: Vis, tbl: LuxSQLTable, isFiltered=True):
                     }
                 )
                 view._vis_data = view._vis_data.merge(
-                    df, on=[columns[0], columns[1]], how="right", suffixes=["", "_right"],
+                    df,
+                    on=[columns[0], columns[1]],
+                    how="right",
+                    suffixes=["", "_right"],
                 )
                 for col in columns[2:]:
                     view._vis_data[col] = view._vis_data[col].fillna(0)  # Triggers __setitem__
@@ -391,7 +402,10 @@ def execute_binning(view: Vis, tbl: LuxSQLTable):
        upper_edges = ",".join(upper_edges)
        view_filter, filter_vars = SQLExecutor.execute_filter(view)
        bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket(CAST (\"{}\" AS FLOAT), '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format(
-            bin_attribute.attribute, "{" + upper_edges + "}", tbl.table_name, where_clause,
+            bin_attribute.attribute,
+            "{" + upper_edges + "}",
+            tbl.table_name,
+            where_clause,
        )
 
        bin_count_data = pandas.read_sql(bin_count_query, lux.config.SQLconnection)
@@ -406,11 +420,13 @@ def execute_binning(view: Vis, tbl: LuxSQLTable):
        else:
            bin_centers = np.array([(attr_min + attr_min + bin_width) / 2])
        bin_centers = np.append(
-            bin_centers, np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0),
+            bin_centers,
+            np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0),
        )
        if attr_type == int:
            bin_centers = np.append(
-                bin_centers, math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2),
+                bin_centers,
+                math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2),
            )
        else:
            bin_centers = np.append(bin_centers, (upper_edges[len(upper_edges) - 1] + attr_max) / 2)
@@ -551,7 +567,10 @@ def execute_filter(view: Vis):
                else:
                    where_clause.append("AND")
                where_clause.extend(
-                    ['"' + str(a.attribute) + '"', "IS NOT NULL",]
+                    [
+                        '"' + str(a.attribute) + '"',
+                        "IS NOT NULL",
+                    ]
                )
 
        if where_clause == []:
@@ -668,7 +687,10 @@ def get_cardinality(self, tbl: LuxSQLTable):
            card_query = 'SELECT Count(Distinct("{}")) FROM {} WHERE "{}" IS NOT NULL'.format(
                attr, tbl.table_name, attr
            )
-            card_data = pandas.read_sql(card_query, lux.config.SQLconnection,)
+            card_data = pandas.read_sql(
+                card_query,
+                lux.config.SQLconnection,
+            )
            cardinality[attr] = list(card_data["count"])[0]
        tbl.cardinality = cardinality
 
@@ -691,7 +713,10 @@ def get_unique_values(self, tbl: LuxSQLTable):
            unique_query = 'SELECT Distinct("{}") FROM {} WHERE "{}" IS NOT NULL'.format(
                attr, tbl.table_name, attr
            )
-            unique_data = pandas.read_sql(unique_query, lux.config.SQLconnection,)
+            unique_data = pandas.read_sql(
+                unique_query,
+                lux.config.SQLconnection,
+            )
            unique_vals[attr] = list(unique_data[attr])
        tbl.unique_values = unique_vals
 
diff --git a/lux/interestingness/interestingness.py b/lux/interestingness/interestingness.py
index a95efea9..32b0f1db 100644
--- a/lux/interestingness/interestingness.py
+++ b/lux/interestingness/interestingness.py
@@ -198,7 +198,11 @@ def weighted_correlation(x, y, w):
 
 
 def deviation_from_overall(
-    vis: Vis, ldf: LuxDataFrame, filter_specs: list, msr_attribute: str, exclude_nan: bool = True,
+    vis: Vis,
+    ldf: LuxDataFrame,
+    filter_specs: list,
+    msr_attribute: str,
+    exclude_nan: bool = True,
 ) -> int:
     """
     Difference in bar chart/histogram shape from overall chart
diff --git a/lux/interestingness/similarity.py b/lux/interestingness/similarity.py
index 017b97b5..8d810909 100644
--- a/lux/interestingness/similarity.py
+++ b/lux/interestingness/similarity.py
@@ -68,7 +68,8 @@ def interpolate(vis, length):
                 yVals[count - 1] + (interpolated_x - xVals[count - 1]) / x_diff * yDiff
             )
     vis.data = pd.DataFrame(
-        list(zip(interpolated_x_vals, interpolated_y_vals)), columns=[xAxis, yAxis],
+        list(zip(interpolated_x_vals, interpolated_y_vals)),
+        columns=[xAxis, yAxis],
     )
 
 
diff --git a/lux/processor/Compiler.py b/lux/processor/Compiler.py
index 85c49cfb..96d95d10 100644
--- a/lux/processor/Compiler.py
+++ b/lux/processor/Compiler.py
@@ -281,7 +281,10 @@ def line_or_bar_or_geo(ldf, dimension: Clause, measure: Clause):
         # ShowMe logic + additional heuristics
         # count_col = Clause( attribute="count()", data_model="measure")
         count_col = Clause(
-            attribute="Record", aggregation="count", data_model="measure", data_type="quantitative",
+            attribute="Record",
+            aggregation="count",
+            data_model="measure",
+            data_type="quantitative",
         )
         auto_channel = {}
         if ndim == 0 and nmsr == 1:
@@ -472,7 +475,9 @@ def populate_wildcard_options(_inferred_intent: List[Clause], ldf: LuxDataFrame)
                 options = ldf.unique_values[attr]
                 specInd = _inferred_intent.index(clause)
                 _inferred_intent[specInd] = Clause(
-                    attribute=clause.attribute, filter_op="=", value=list(options),
+                    attribute=clause.attribute,
+                    filter_op="=",
+                    value=list(options),
                 )
             else:
                 options.extend(convert_to_list(clause.value))
diff --git a/lux/vislib/altair/AltairRenderer.py b/lux/vislib/altair/AltairRenderer.py
index deba1592..2c1ab206 100644
--- a/lux/vislib/altair/AltairRenderer.py
+++ b/lux/vislib/altair/AltairRenderer.py
@@ -128,7 +128,8 @@ def create_vis(self, vis, standalone=True):
                 found_variable = "df"
             if standalone:
                 chart.code = chart.code.replace(
-                    "placeholder_variable", f"pd.DataFrame({str(vis.data.to_dict())})",
+                    "placeholder_variable",
+                    f"pd.DataFrame({str(vis.data.to_dict())})",
                 )
             else:
                 # TODO: Placeholder (need to read dynamically via locals())
diff --git a/lux/vislib/altair/Choropleth.py b/lux/vislib/altair/Choropleth.py
index c2675883..bf71b010 100644
--- a/lux/vislib/altair/Choropleth.py
+++ b/lux/vislib/altair/Choropleth.py
@@ -68,7 +68,9 @@ def initialize_chart(self):
         points = (
             alt.Chart(geo_map)
             .mark_geoshape()
-            .encode(color=f"{y_attr_abv}:Q",)
+            .encode(
+                color=f"{y_attr_abv}:Q",
+            )
             .transform_lookup(lookup="id", from_=alt.LookupData(self.data, x_attr_abv, [y_attr_abv]))
             .project(type=map_type)
             .properties(
diff --git a/lux/vislib/altair/Heatmap.py b/lux/vislib/altair/Heatmap.py
index 2f743b32..f83a3bbb 100644
--- a/lux/vislib/altair/Heatmap.py
+++ b/lux/vislib/altair/Heatmap.py
@@ -71,7 +71,10 @@ def initialize_chart(self):
                 ),
                 y2=alt.Y2("yBinEnd"),
                 opacity=alt.Opacity(
-                    "count", type="quantitative", scale=alt.Scale(type="log"), legend=None,
+                    "count",
+                    type="quantitative",
+                    scale=alt.Scale(type="log"),
+                    legend=None,
                 ),
             )
         )
diff --git a/tests/test_action.py b/tests/test_action.py
index 2cb49e86..893c13b7 100644
--- a/tests/test_action.py
+++ b/tests/test_action.py
@@ -184,7 +184,8 @@ def test_year_filter_value(global_var):
                 lambda vis: len(
                     list(
                         filter(
-                            lambda clause: clause.value != "" and clause.attribute == "Year", vis._intent,
+                            lambda clause: clause.value != "" and clause.attribute == "Year",
+                            vis._intent,
                         )
                     )
                 )
@@ -214,10 +215,16 @@ def test_similarity(global_var):
     ranked_list = df.recommendation["Similarity"]
 
     japan_vis = list(
-        filter(lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Japan", ranked_list,)
+        filter(
+            lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Japan",
+            ranked_list,
+        )
     )[0]
     europe_vis = list(
-        filter(lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Europe", ranked_list,)
+        filter(
+            lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Europe",
+            ranked_list,
+        )
     )[0]
     assert japan_vis.score > europe_vis.score
     df.clear_intent()
@@ -240,10 +247,16 @@ def test_similarity2():
     ranked_list = df.recommendation["Similarity"]
 
     morrisville_vis = list(
-        filter(lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Morrisville", ranked_list,)
+        filter(
+            lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Morrisville",
+            ranked_list,
+        )
     )[0]
     watertown_vis = list(
-        filter(lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Watertown", ranked_list,)
+        filter(
+            lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Watertown",
+            ranked_list,
+        )
     )[0]
     assert morrisville_vis.score > watertown_vis.score
 
diff --git a/tests/test_compiler.py b/tests/test_compiler.py
index d30b824f..28e80f8a 100644
--- a/tests/test_compiler.py
+++ b/tests/test_compiler.py
@@ -168,7 +168,10 @@ def test_underspecified_vis_collection_zval(global_var):
     # check if the number of charts is correct
     df = pytest.car_df
     vlst = VisList(
-        [lux.Clause(attribute="Origin", filter_op="=", value="?"), lux.Clause(attribute="MilesPerGal"),],
+        [
+            lux.Clause(attribute="Origin", filter_op="=", value="?"),
+            lux.Clause(attribute="MilesPerGal"),
+        ],
         df,
     )
     assert len(vlst) == 3
@@ -182,7 +185,10 @@
     lux.config.set_SQL_connection(connection)
     sql_df = lux.LuxSQLTable(table_name="cars")
     vlst = VisList(
-        [lux.Clause(attribute="origin", filter_op="=", value="?"), lux.Clause(attribute="milespergal"),],
+        [
+            lux.Clause(attribute="origin", filter_op="=", value="?"),
+            lux.Clause(attribute="milespergal"),
+        ],
         sql_df,
     )
     assert len(vlst) == 3
@@ -279,7 +285,10 @@ def test_specified_channel_enforced_vis_collection(global_var):
     df = pytest.car_df
     # change pandas dtype for the column "Year" to datetype
     df["Year"] = pd.to_datetime(df["Year"], format="%Y")
-    visList = VisList([lux.Clause(attribute="?"), lux.Clause(attribute="MilesPerGal", channel="x")], df,)
+    visList = VisList(
+        [lux.Clause(attribute="?"), lux.Clause(attribute="MilesPerGal", channel="x")],
+        df,
+    )
     for vis in visList:
         check_attribute_on_channel(vis, "MilesPerGal", "x")
 
@@ -295,13 +304,22 @@ def test_autoencoding_scatter(global_var):
     check_attribute_on_channel(vis, "Weight", "y")
 
     # Partial channel specified
-    vis = Vis([lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight"),], df,)
+    vis = Vis(
+        [
+            lux.Clause(attribute="MilesPerGal", channel="y"),
+            lux.Clause(attribute="Weight"),
+        ],
+        df,
+    )
     check_attribute_on_channel(vis, "MilesPerGal", "y")
     check_attribute_on_channel(vis, "Weight", "x")
 
     # Full channel specified
     vis = Vis(
-        [lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight", channel="x"),],
+        [
+            lux.Clause(attribute="MilesPerGal", channel="y"),
+            lux.Clause(attribute="Weight", channel="x"),
+        ],
         df,
     )
     check_attribute_on_channel(vis, "MilesPerGal", "y")
@@ -321,7 +339,8 @@ def test_autoencoding_scatter(global_var):
     lux.config.set_SQL_connection(connection)
     sql_df = lux.LuxSQLTable(table_name="cars")
     visList = VisList(
-        [lux.Clause(attribute="?"), lux.Clause(attribute="milespergal", channel="x")], sql_df,
+        [lux.Clause(attribute="?"), lux.Clause(attribute="milespergal", channel="x")],
+        sql_df,
     )
     for vis in visList:
         check_attribute_on_channel(vis, "milespergal", "x")
@@ -339,13 +358,22 @@ def test_autoencoding_scatter():
     check_attribute_on_channel(vis, "Weight", "y")
 
     # Partial channel specified
-    vis = Vis([lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight"),], df,)
+    vis = Vis(
+        [
+            lux.Clause(attribute="MilesPerGal", channel="y"),
+            lux.Clause(attribute="Weight"),
+        ],
+        df,
+    )
     check_attribute_on_channel(vis, "MilesPerGal", "y")
     check_attribute_on_channel(vis, "Weight", "x")
 
     # Full channel specified
     vis = Vis(
-        [lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight", channel="x"),],
+        [
+            lux.Clause(attribute="MilesPerGal", channel="y"),
+            lux.Clause(attribute="Weight", channel="x"),
+        ],
         df,
     )
     check_attribute_on_channel(vis, "MilesPerGal", "y")
@@ -370,14 +398,21 @@ def test_autoencoding_scatter():
 
     # Partial channel specified
     vis = Vis(
-        [lux.Clause(attribute="milespergal", channel="y"), lux.Clause(attribute="weight"),], sql_df,
+        [
+            lux.Clause(attribute="milespergal", channel="y"),
+            lux.Clause(attribute="weight"),
+        ],
+        sql_df,
     )
     check_attribute_on_channel(vis, "milespergal", "y")
     check_attribute_on_channel(vis, "weight", "x")
 
     # Full channel specified
     vis = Vis(
-        [lux.Clause(attribute="milespergal", channel="y"), lux.Clause(attribute="weight", channel="x"),],
+        [
+            lux.Clause(attribute="milespergal", channel="y"),
+            lux.Clause(attribute="weight", channel="x"),
+        ],
         sql_df,
     )
     check_attribute_on_channel(vis, "milespergal", "y")
@@ -429,13 +464,22 @@ def test_autoencoding_line_chart(global_var):
     check_attribute_on_channel(vis, "Acceleration", "y")
 
     # Partial channel specified
-    vis = Vis([lux.Clause(attribute="Year", channel="y"), lux.Clause(attribute="Acceleration"),], df,)
+    vis = Vis(
+        [
+            lux.Clause(attribute="Year", channel="y"),
+            lux.Clause(attribute="Acceleration"),
+        ],
+        df,
+    )
     check_attribute_on_channel(vis, "Year", "y")
     check_attribute_on_channel(vis, "Acceleration", "x")
 
     # Full channel specified
     vis = Vis(
-        [lux.Clause(attribute="Year", channel="y"), lux.Clause(attribute="Acceleration", channel="x"),],
+        [
+            lux.Clause(attribute="Year", channel="y"),
+            lux.Clause(attribute="Acceleration", channel="x"),
+        ],
         df,
     )
     check_attribute_on_channel(vis, "Year", "y")
@@ -461,14 +505,21 @@ def test_autoencoding_line_chart(global_var):
 
     # Partial channel specified
     vis = Vis(
-        [lux.Clause(attribute="year", channel="y"), lux.Clause(attribute="acceleration"),], sql_df,
+        [
+            lux.Clause(attribute="year", channel="y"),
+            lux.Clause(attribute="acceleration"),
+        ],
+        sql_df,
    )
     check_attribute_on_channel(vis, "year", "y")
     check_attribute_on_channel(vis, "acceleration", "x")
 
     # Full channel specified
     vis = Vis(
-        [lux.Clause(attribute="year", channel="y"), lux.Clause(attribute="acceleration", channel="x"),],
+        [
+            lux.Clause(attribute="year", channel="y"),
+            lux.Clause(attribute="acceleration", channel="x"),
+        ],
         sql_df,
     )
     check_attribute_on_channel(vis, "year", "y")
@@ -577,7 +628,10 @@ def test_populate_options(global_var):
 
     assert list_equal(list(col_set), list(df.columns))
     df.set_intent(
-        [lux.Clause(attribute="?", data_model="measure"), lux.Clause(attribute="MilesPerGal"),]
+        [
+            lux.Clause(attribute="?", data_model="measure"),
+            lux.Clause(attribute="MilesPerGal"),
+        ]
     )
     df._repr_html_()
     col_set = set()
@@ -585,7 +639,8 @@ def test_populate_options(global_var):
     for clause in specOptions:
         col_set.add(clause.attribute)
     assert list_equal(
-        list(col_set), ["Acceleration", "Weight", "Horsepower", "MilesPerGal", "Displacement"],
+        list(col_set),
+        ["Acceleration", "Weight", "Horsepower", "MilesPerGal", "Displacement"],
     )
     df.clear_intent()
 
@@ -601,7 +656,10 @@ def test_populate_options(global_var):
 
     assert list_equal(list(col_set), list(sql_df.columns))
     sql_df.set_intent(
-        [lux.Clause(attribute="?", data_model="measure"), lux.Clause(attribute="milespergal"),]
+        [
+            lux.Clause(attribute="?", data_model="measure"),
+            lux.Clause(attribute="milespergal"),
+        ]
     )
     sql_df._repr_html_()
     col_set = set()
@@ -609,7 +667,8 @@ def test_populate_options(global_var):
     for clause in specOptions:
         col_set.add(clause.attribute)
     assert list_equal(
-        list(col_set), ["acceleration", "weight", "horsepower", "milespergal", "displacement"],
+        list(col_set),
+        ["acceleration", "weight", "horsepower", "milespergal", "displacement"],
     )
 
 
@@ -619,7 +678,10 @@ def test_remove_all_invalid(global_var):
     df["Year"] = pd.to_datetime(df["Year"], format="%Y")
     # with pytest.warns(UserWarning,match="duplicate attribute specified in the intent"):
     df.set_intent(
-        [lux.Clause(attribute="Origin", filter_op="=", value="USA"), lux.Clause(attribute="Origin"),]
+        [
+            lux.Clause(attribute="Origin", filter_op="=", value="USA"),
+            lux.Clause(attribute="Origin"),
+        ]
     )
     df._repr_html_()
     assert len(df.current_vis) == 0
@@ -631,7 +693,10 @@ def test_remove_all_invalid(global_var):
     sql_df = lux.LuxSQLTable(table_name="cars")
     # with pytest.warns(UserWarning,match="duplicate attribute specified in the intent"):
     sql_df.set_intent(
-        [lux.Clause(attribute="origin", filter_op="=", value="USA"), lux.Clause(attribute="origin"),]
+        [
+            lux.Clause(attribute="origin", filter_op="=", value="USA"),
+            lux.Clause(attribute="origin"),
+        ]
     )
     sql_df._repr_html_()
     assert len(sql_df.current_vis) == 0
diff --git a/tests/test_dates.py b/tests/test_dates.py
index 48d514b9..dc530fc7 100644
--- a/tests/test_dates.py
+++ b/tests/test_dates.py
@@ -44,7 +44,10 @@ def test_period_selection(global_var):
     ldf["Year"] = pd.DatetimeIndex(ldf["Year"]).to_period(freq="A")
 
     ldf.set_intent(
-        [lux.Clause(attribute=["Horsepower", "Weight", "Acceleration"]), lux.Clause(attribute="Year"),]
+        [
+            lux.Clause(attribute=["Horsepower", "Weight", "Acceleration"]),
+            lux.Clause(attribute="Year"),
+        ]
     )
 
     lux.config.executor.execute(ldf.current_vis, ldf)
diff --git a/tests/test_interestingness.py b/tests/test_interestingness.py
index 1ef929e3..7e6036f9 100644
--- a/tests/test_interestingness.py
+++ b/tests/test_interestingness.py
@@ -64,7 +64,10 @@ def test_interestingness_1_0_1(global_var):
     df["Year"] = pd.to_datetime(df["Year"], format="%Y")
 
     df.set_intent(
-        [lux.Clause(attribute="Origin", filter_op="=", value="USA"), lux.Clause(attribute="Cylinders"),]
+        [
+            lux.Clause(attribute="Origin", filter_op="=", value="USA"),
+            lux.Clause(attribute="Cylinders"),
+        ]
     )
     df._repr_html_()
     assert df.current_vis[0].score == 0
@@ -121,7 +124,10 @@ def test_interestingness_0_1_1(global_var):
     df["Year"] = pd.to_datetime(df["Year"], format="%Y")
 
     df.set_intent(
-        [lux.Clause(attribute="Origin", filter_op="=", value="?"), lux.Clause(attribute="MilesPerGal"),]
+        [
+            lux.Clause(attribute="Origin", filter_op="=", value="?"),
+            lux.Clause(attribute="MilesPerGal"),
+        ]
     )
     df._repr_html_()
     assert interestingness(df.recommendation["Current Vis"][0], df) != None
diff --git a/tests/test_pandas_coverage.py b/tests/test_pandas_coverage.py
index e30b0aac..21014f60 100644
--- a/tests/test_pandas_coverage.py
+++ b/tests/test_pandas_coverage.py
@@ -173,7 +173,8 @@ def test_groupby_agg_big(global_var):
     assert len(new_df.cardinality) == 8
     year_vis = list(
         filter(
-            lambda vis: vis.get_attr_by_attr_name("Year") != [], new_df.recommendation["Column Groups"],
+            lambda vis: vis.get_attr_by_attr_name("Year") != [],
+            new_df.recommendation["Column Groups"],
         )
     )[0]
     assert year_vis.mark == "bar"
@@ -181,7 +182,10 @@ def test_groupby_agg_big(global_var):
     new_df = new_df.T
     new_df._repr_html_()
     year_vis = list(
-        filter(lambda vis: vis.get_attr_by_attr_name("Year") != [], new_df.recommendation["Row Groups"],)
+        filter(
+            lambda vis: vis.get_attr_by_attr_name("Year") != [],
+            new_df.recommendation["Row Groups"],
+        )
     )[0]
     assert year_vis.mark == "bar"
     assert year_vis.get_attr_by_channel("x")[0].attribute == "Year"
diff --git a/tests/test_parser.py b/tests/test_parser.py
index d274a1f4..333977aa 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -121,6 +121,7 @@ def test_validator_invalid_attribute(global_var):
     df = pytest.college_df
     with pytest.raises(KeyError, match="'blah'"):
         with pytest.warns(
-            UserWarning, match="The input attribute 'blah' does not exist in the DataFrame.",
+            UserWarning,
+            match="The input attribute 'blah' does not exist in the DataFrame.",
         ):
             df.intent = ["blah"]
diff --git a/tests/test_vis.py b/tests/test_vis.py
index 2a088ddb..4514be42 100644
--- a/tests/test_vis.py
+++ b/tests/test_vis.py
@@ -153,7 +153,10 @@ def test_vis_list_custom_title_override(global_var):
 
     vcLst = []
     for attribute in ["Sport", "Year", "Height", "HostRegion", "SportType"]:
-        vis = Vis([lux.Clause("Weight"), lux.Clause(attribute)], title="overriding dummy title",)
+        vis = Vis(
+            [lux.Clause("Weight"), lux.Clause(attribute)],
+            title="overriding dummy title",
+        )
         vcLst.append(vis)
     vlist = VisList(vcLst, df)
     for v in vlist: