From b1b56d7a9f8b39db3f4bfe9958af07f3a3754eb6 Mon Sep 17 00:00:00 2001
From: Sophia Huang <6860749+sophiahhuang@users.noreply.github.com>
Date: Wed, 17 Mar 2021 12:02:48 +0800
Subject: [PATCH 1/4] Test Commit of SQLTable
---
lux/core/sqltable.py | 16 ++++++++++++++--
1 file changed, 14 insertions(+), 2 deletions(-)
diff --git a/lux/core/sqltable.py b/lux/core/sqltable.py
index 3884f2cd..a6e31dab 100644
--- a/lux/core/sqltable.py
+++ b/lux/core/sqltable.py
@@ -109,7 +109,7 @@ def set_SQL_table(self, t_name):
)
def _repr_html_(self):
- from IPython.display import display
+ from IPython.display import HTML, display
from IPython.display import clear_output
import ipywidgets as widgets
@@ -143,7 +143,19 @@ def _repr_html_(self):
)
self.output = widgets.Output()
lux.config.executor.execute_preview(self)
- display(button, self.output)
+ notification = HTML("""
+ Note: data displayed is just a preview of the database table. No Pandas functionality on this data is supported.""")
+
+ display(button, notification, self.output)
def on_button_clicked(b):
with self.output:
From 34b5c45885f82fa701f8e603d02dc3eeb81cdc28 Mon Sep 17 00:00:00 2001
From: Sophia Huang <6860749+sophiahhuang@users.noreply.github.com>
Date: Sat, 20 Mar 2021 01:32:43 +0800
Subject: [PATCH 2/4] Updated LuxSQLTable notification
- Rename the toggle button text from "Toggle Data Preview/Lux" to "Toggle Table/Lux"
- Show the caption only when the preview table is displayed
- Include the table name in the caption
---
lux/_config/config.py | 21 ++---
lux/action/generalize.py | 7 +-
lux/action/row_group.py | 8 +-
lux/core/frame.py | 6 +-
lux/core/series.py | 3 +-
lux/core/sqltable.py | 26 ++----
lux/executor/PandasExecutor.py | 10 +--
lux/executor/SQLExecutor.py | 100 +++++++++--------------
lux/interestingness/interestingness.py | 6 +-
lux/interestingness/similarity.py | 3 +-
lux/processor/Compiler.py | 9 +--
lux/vislib/altair/AltairRenderer.py | 3 +-
lux/vislib/altair/Choropleth.py | 4 +-
lux/vislib/altair/Heatmap.py | 5 +-
tests/test_action.py | 23 ++----
tests/test_compiler.py | 105 +++++--------------------
tests/test_dates.py | 5 +-
tests/test_interestingness.py | 10 +--
tests/test_pandas_coverage.py | 8 +-
tests/test_parser.py | 3 +-
tests/test_vis.py | 5 +-
21 files changed, 98 insertions(+), 272 deletions(-)
diff --git a/lux/_config/config.py b/lux/_config/config.py
index 12db034a..65c778ef 100644
--- a/lux/_config/config.py
+++ b/lux/_config/config.py
@@ -55,8 +55,7 @@ def topk(self, k: Union[int, bool]):
self._topk = k
else:
warnings.warn(
- "Parameter to lux.config.topk must be an integer or a boolean.",
- stacklevel=2,
+ "Parameter to lux.config.topk must be an integer or a boolean.", stacklevel=2,
)
@property
@@ -99,8 +98,7 @@ def pandas_fallback(self, fallback: bool) -> None:
self._pandas_fallback = fallback
else:
warnings.warn(
- "The flag for Pandas fallback must be a boolean.",
- stacklevel=2,
+ "The flag for Pandas fallback must be a boolean.", stacklevel=2,
)
@property
@@ -120,8 +118,7 @@ def interestingness_fallback(self, fallback: bool) -> None:
self._interestingness_fallback = fallback
else:
warnings.warn(
- "The flag for interestingness fallback must be a boolean.",
- stacklevel=2,
+ "The flag for interestingness fallback must be a boolean.", stacklevel=2,
)
@property
@@ -147,8 +144,7 @@ def sampling_cap(self, sample_number: int) -> None:
self._sampling_cap = sample_number
else:
warnings.warn(
- "The cap on the number samples must be an integer.",
- stacklevel=2,
+ "The cap on the number samples must be an integer.", stacklevel=2,
)
@property
@@ -176,8 +172,7 @@ def sampling_start(self, sample_number: int) -> None:
self._sampling_start = sample_number
else:
warnings.warn(
- "The sampling starting point must be an integer.",
- stacklevel=2,
+ "The sampling starting point must be an integer.", stacklevel=2,
)
@property
@@ -202,8 +197,7 @@ def sampling(self, sample_flag: bool) -> None:
self._sampling_flag = sample_flag
else:
warnings.warn(
- "The flag for sampling must be a boolean.",
- stacklevel=2,
+ "The flag for sampling must be a boolean.", stacklevel=2,
)
@property
@@ -228,8 +222,7 @@ def heatmap(self, heatmap_flag: bool) -> None:
self._heatmap_flag = heatmap_flag
else:
warnings.warn(
- "The flag for enabling/disabling heatmaps must be a boolean.",
- stacklevel=2,
+ "The flag for enabling/disabling heatmaps must be a boolean.", stacklevel=2,
)
@property
diff --git a/lux/action/generalize.py b/lux/action/generalize.py
index 385891d4..84504132 100644
--- a/lux/action/generalize.py
+++ b/lux/action/generalize.py
@@ -78,12 +78,7 @@ def generalize(ldf):
for clause in filters:
# new_spec = ldf._intent.copy()
# new_spec.remove_column_from_spec(new_spec.attribute)
- temp_vis = Vis(
- ldf.current_vis[0]._inferred_intent.copy(),
- source=ldf,
- title="Overall",
- score=0,
- )
+ temp_vis = Vis(ldf.current_vis[0]._inferred_intent.copy(), source=ldf, title="Overall", score=0,)
temp_vis.remove_filter_from_spec(clause.value)
output.append(temp_vis)
diff --git a/lux/action/row_group.py b/lux/action/row_group.py
index 5555c01f..7407ab5b 100644
--- a/lux/action/row_group.py
+++ b/lux/action/row_group.py
@@ -50,13 +50,7 @@ def row_group(ldf):
# rowdf.cardinality["index"]=len(rowdf)
# if isinstance(ldf.columns,pd.DatetimeIndex):
# rowdf.data_type_lookup[dim_name]="temporal"
- vis = Vis(
- [
- dim_name,
- lux.Clause(row.name, data_model="measure", aggregation=None),
- ],
- rowdf,
- )
+ vis = Vis([dim_name, lux.Clause(row.name, data_model="measure", aggregation=None),], rowdf,)
collection.append(vis)
vlst = VisList(collection)
# Note that we are not computing interestingness score here because we want to preserve the arrangement of the aggregated data
diff --git a/lux/core/frame.py b/lux/core/frame.py
index 5f584d4c..94e8f406 100644
--- a/lux/core/frame.py
+++ b/lux/core/frame.py
@@ -491,8 +491,7 @@ def exported(self) -> Union[Dict[str, VisList], VisList]:
exported_vis = VisList(
list(
map(
- self._recommendation[export_action].__getitem__,
- exported_vis_lst[export_action],
+ self._recommendation[export_action].__getitem__, exported_vis_lst[export_action],
)
)
)
@@ -562,8 +561,7 @@ def _repr_html_(self):
self._widget.observe(self.set_intent_on_click, names="selectedIntentIndex")
button = widgets.Button(
- description="Toggle Pandas/Lux",
- layout=widgets.Layout(width="140px", top="5px"),
+ description="Toggle Pandas/Lux", layout=widgets.Layout(width="140px", top="5px"),
)
self.output = widgets.Output()
display(button, self.output)
diff --git a/lux/core/series.py b/lux/core/series.py
index 83b2c42b..f5fd31ca 100644
--- a/lux/core/series.py
+++ b/lux/core/series.py
@@ -154,8 +154,7 @@ def __repr__(self):
# box = widgets.Box(layout=widgets.Layout(display='inline'))
button = widgets.Button(
- description="Toggle Pandas/Lux",
- layout=widgets.Layout(width="140px", top="5px"),
+ description="Toggle Pandas/Lux", layout=widgets.Layout(width="140px", top="5px"),
)
ldf.output = widgets.Output()
# box.children = [button,output]
diff --git a/lux/core/sqltable.py b/lux/core/sqltable.py
index a6e31dab..e2469675 100644
--- a/lux/core/sqltable.py
+++ b/lux/core/sqltable.py
@@ -104,8 +104,7 @@ def set_SQL_table(self, t_name):
error_str = str(error)
if f'relation "{t_name}" does not exist' in error_str:
warnings.warn(
- f"\nThe table '{t_name}' does not exist in your database./",
- stacklevel=2,
+ f"\nThe table '{t_name}' does not exist in your database./", stacklevel=2,
)
def _repr_html_(self):
@@ -138,24 +137,12 @@ def _repr_html_(self):
self._widget.observe(self.set_intent_on_click, names="selectedIntentIndex")
button = widgets.Button(
- description="Toggle Data Preview/Lux",
- layout=widgets.Layout(width="200px", top="5px"),
+ description="Toggle Table/Lux",
+ layout=widgets.Layout(width="200px", top="6px", bottom="6px"),
)
self.output = widgets.Output()
lux.config.executor.execute_preview(self)
- notification = HTML("""
- Note: data displayed is just a preview of the database table. No Pandas functionality on this data is supported.""")
-
- display(button, notification, self.output)
+ display(button, self.output)
def on_button_clicked(b):
with self.output:
@@ -163,7 +150,10 @@ def on_button_clicked(b):
self._toggle_pandas_display = not self._toggle_pandas_display
clear_output()
if self._toggle_pandas_display:
- display(self._sampled.display_pandas())
+ notification = widgets.Label(
+ value="Preview of the database table: " + self.table_name
+ )
+ display(notification, self._sampled.display_pandas())
else:
# b.layout.display = "none"
display(self._widget)
diff --git a/lux/executor/PandasExecutor.py b/lux/executor/PandasExecutor.py
index e5534821..8983876a 100644
--- a/lux/executor/PandasExecutor.py
+++ b/lux/executor/PandasExecutor.py
@@ -216,10 +216,7 @@ def execute_aggregate(vis: Vis, isFiltered=True):
}
)
vis._vis_data = vis.data.merge(
- df,
- on=[columns[0], columns[1]],
- how="right",
- suffixes=["", "_right"],
+ df, on=[columns[0], columns[1]], how="right", suffixes=["", "_right"],
)
for col in columns[2:]:
vis.data[col] = vis.data[col].fillna(0) # Triggers __setitem__
@@ -367,10 +364,7 @@ def execute_2D_binning(vis: Vis):
if color_attr.data_type == "nominal":
# Compute mode and count. Mode aggregates each cell by taking the majority vote for the category variable. In cases where there is ties across categories, pick the first item (.iat[0])
result = groups.agg(
- [
- ("count", "count"),
- (color_attr.attribute, lambda x: pd.Series.mode(x).iat[0]),
- ]
+ [("count", "count"), (color_attr.attribute, lambda x: pd.Series.mode(x).iat[0]),]
).reset_index()
elif color_attr.data_type == "quantitative" or color_attr.data_type == "temporal":
# Compute the average of all values in the bin
diff --git a/lux/executor/SQLExecutor.py b/lux/executor/SQLExecutor.py
index fd063c0a..96112bd2 100644
--- a/lux/executor/SQLExecutor.py
+++ b/lux/executor/SQLExecutor.py
@@ -38,8 +38,7 @@ def execute_sampling(tbl: LuxSQLTable):
SAMPLE_FRAC = 0.2
length_query = pandas.read_sql(
- "SELECT COUNT(*) as length FROM {}".format(tbl.table_name),
- lux.config.SQLconnection,
+ "SELECT COUNT(*) as length FROM {}".format(tbl.table_name), lux.config.SQLconnection,
)
limit = int(list(length_query["length"])[0]) * SAMPLE_FRAC
tbl._sampled = pandas.read_sql(
@@ -123,8 +122,7 @@ def add_quotes(var_name):
required_variables = ",".join(required_variables)
row_count = list(
pandas.read_sql(
- f"SELECT COUNT(*) FROM {tbl.table_name} {where_clause}",
- lux.config.SQLconnection,
+ f"SELECT COUNT(*) FROM {tbl.table_name} {where_clause}", lux.config.SQLconnection,
)["count"]
)[0]
if row_count > lux.config.sampling_cap:
@@ -229,48 +227,42 @@ def execute_aggregate(view: Vis, tbl: LuxSQLTable, isFiltered=True):
# generates query for colored barchart case
if has_color:
if agg_func == "mean":
- agg_query = (
- 'SELECT "{}", "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
- groupby_attr.attribute,
- color_attr.attribute,
- measure_attr.attribute,
- measure_attr.attribute,
- tbl.table_name,
- where_clause,
- groupby_attr.attribute,
- color_attr.attribute,
- )
+ agg_query = 'SELECT "{}", "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
+ groupby_attr.attribute,
+ color_attr.attribute,
+ measure_attr.attribute,
+ measure_attr.attribute,
+ tbl.table_name,
+ where_clause,
+ groupby_attr.attribute,
+ color_attr.attribute,
)
view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)
view._vis_data = utils.pandas_to_lux(view._vis_data)
if agg_func == "sum":
- agg_query = (
- 'SELECT "{}", "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
- groupby_attr.attribute,
- color_attr.attribute,
- measure_attr.attribute,
- measure_attr.attribute,
- tbl.table_name,
- where_clause,
- groupby_attr.attribute,
- color_attr.attribute,
- )
+ agg_query = 'SELECT "{}", "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
+ groupby_attr.attribute,
+ color_attr.attribute,
+ measure_attr.attribute,
+ measure_attr.attribute,
+ tbl.table_name,
+ where_clause,
+ groupby_attr.attribute,
+ color_attr.attribute,
)
view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)
view._vis_data = utils.pandas_to_lux(view._vis_data)
if agg_func == "max":
- agg_query = (
- 'SELECT "{}", "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
- groupby_attr.attribute,
- color_attr.attribute,
- measure_attr.attribute,
- measure_attr.attribute,
- tbl.table_name,
- where_clause,
- groupby_attr.attribute,
- color_attr.attribute,
- )
+ agg_query = 'SELECT "{}", "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
+ groupby_attr.attribute,
+ color_attr.attribute,
+ measure_attr.attribute,
+ measure_attr.attribute,
+ tbl.table_name,
+ where_clause,
+ groupby_attr.attribute,
+ color_attr.attribute,
)
view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)
view._vis_data = utils.pandas_to_lux(view._vis_data)
@@ -330,10 +322,7 @@ def execute_aggregate(view: Vis, tbl: LuxSQLTable, isFiltered=True):
}
)
view._vis_data = view._vis_data.merge(
- df,
- on=[columns[0], columns[1]],
- how="right",
- suffixes=["", "_right"],
+ df, on=[columns[0], columns[1]], how="right", suffixes=["", "_right"],
)
for col in columns[2:]:
view._vis_data[col] = view._vis_data[col].fillna(0) # Triggers __setitem__
@@ -402,10 +391,7 @@ def execute_binning(view: Vis, tbl: LuxSQLTable):
upper_edges = ",".join(upper_edges)
view_filter, filter_vars = SQLExecutor.execute_filter(view)
bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket(CAST (\"{}\" AS FLOAT), '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format(
- bin_attribute.attribute,
- "{" + upper_edges + "}",
- tbl.table_name,
- where_clause,
+ bin_attribute.attribute, "{" + upper_edges + "}", tbl.table_name, where_clause,
)
bin_count_data = pandas.read_sql(bin_count_query, lux.config.SQLconnection)
@@ -420,13 +406,11 @@ def execute_binning(view: Vis, tbl: LuxSQLTable):
else:
bin_centers = np.array([(attr_min + attr_min + bin_width) / 2])
bin_centers = np.append(
- bin_centers,
- np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0),
+ bin_centers, np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0),
)
if attr_type == int:
bin_centers = np.append(
- bin_centers,
- math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2),
+ bin_centers, math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2),
)
else:
bin_centers = np.append(bin_centers, (upper_edges[len(upper_edges) - 1] + attr_max) / 2)
@@ -567,10 +551,7 @@ def execute_filter(view: Vis):
else:
where_clause.append("AND")
where_clause.extend(
- [
- '"' + str(a.attribute) + '"',
- "IS NOT NULL",
- ]
+ ['"' + str(a.attribute) + '"', "IS NOT NULL",]
)
if where_clause == []:
@@ -649,8 +630,7 @@ def compute_stats(self, tbl: LuxSQLTable):
tbl.unique_values = {}
tbl._min_max = {}
length_query = pandas.read_sql(
- "SELECT COUNT(*) as length FROM {}".format(tbl.table_name),
- lux.config.SQLconnection,
+ "SELECT COUNT(*) as length FROM {}".format(tbl.table_name), lux.config.SQLconnection,
)
tbl.length = list(length_query["length"])[0]
@@ -687,10 +667,7 @@ def get_cardinality(self, tbl: LuxSQLTable):
card_query = 'SELECT Count(Distinct("{}")) FROM {} WHERE "{}" IS NOT NULL'.format(
attr, tbl.table_name, attr
)
- card_data = pandas.read_sql(
- card_query,
- lux.config.SQLconnection,
- )
+ card_data = pandas.read_sql(card_query, lux.config.SQLconnection,)
cardinality[attr] = list(card_data["count"])[0]
tbl.cardinality = cardinality
@@ -713,10 +690,7 @@ def get_unique_values(self, tbl: LuxSQLTable):
unique_query = 'SELECT Distinct("{}") FROM {} WHERE "{}" IS NOT NULL'.format(
attr, tbl.table_name, attr
)
- unique_data = pandas.read_sql(
- unique_query,
- lux.config.SQLconnection,
- )
+ unique_data = pandas.read_sql(unique_query, lux.config.SQLconnection,)
unique_vals[attr] = list(unique_data[attr])
tbl.unique_values = unique_vals
diff --git a/lux/interestingness/interestingness.py b/lux/interestingness/interestingness.py
index 32b0f1db..a95efea9 100644
--- a/lux/interestingness/interestingness.py
+++ b/lux/interestingness/interestingness.py
@@ -198,11 +198,7 @@ def weighted_correlation(x, y, w):
def deviation_from_overall(
- vis: Vis,
- ldf: LuxDataFrame,
- filter_specs: list,
- msr_attribute: str,
- exclude_nan: bool = True,
+ vis: Vis, ldf: LuxDataFrame, filter_specs: list, msr_attribute: str, exclude_nan: bool = True,
) -> int:
"""
Difference in bar chart/histogram shape from overall chart
diff --git a/lux/interestingness/similarity.py b/lux/interestingness/similarity.py
index 8d810909..017b97b5 100644
--- a/lux/interestingness/similarity.py
+++ b/lux/interestingness/similarity.py
@@ -68,8 +68,7 @@ def interpolate(vis, length):
yVals[count - 1] + (interpolated_x - xVals[count - 1]) / x_diff * yDiff
)
vis.data = pd.DataFrame(
- list(zip(interpolated_x_vals, interpolated_y_vals)),
- columns=[xAxis, yAxis],
+ list(zip(interpolated_x_vals, interpolated_y_vals)), columns=[xAxis, yAxis],
)
diff --git a/lux/processor/Compiler.py b/lux/processor/Compiler.py
index 96d95d10..85c49cfb 100644
--- a/lux/processor/Compiler.py
+++ b/lux/processor/Compiler.py
@@ -281,10 +281,7 @@ def line_or_bar_or_geo(ldf, dimension: Clause, measure: Clause):
# ShowMe logic + additional heuristics
# count_col = Clause( attribute="count()", data_model="measure")
count_col = Clause(
- attribute="Record",
- aggregation="count",
- data_model="measure",
- data_type="quantitative",
+ attribute="Record", aggregation="count", data_model="measure", data_type="quantitative",
)
auto_channel = {}
if ndim == 0 and nmsr == 1:
@@ -475,9 +472,7 @@ def populate_wildcard_options(_inferred_intent: List[Clause], ldf: LuxDataFrame)
options = ldf.unique_values[attr]
specInd = _inferred_intent.index(clause)
_inferred_intent[specInd] = Clause(
- attribute=clause.attribute,
- filter_op="=",
- value=list(options),
+ attribute=clause.attribute, filter_op="=", value=list(options),
)
else:
options.extend(convert_to_list(clause.value))
diff --git a/lux/vislib/altair/AltairRenderer.py b/lux/vislib/altair/AltairRenderer.py
index 2c1ab206..deba1592 100644
--- a/lux/vislib/altair/AltairRenderer.py
+++ b/lux/vislib/altair/AltairRenderer.py
@@ -128,8 +128,7 @@ def create_vis(self, vis, standalone=True):
found_variable = "df"
if standalone:
chart.code = chart.code.replace(
- "placeholder_variable",
- f"pd.DataFrame({str(vis.data.to_dict())})",
+ "placeholder_variable", f"pd.DataFrame({str(vis.data.to_dict())})",
)
else:
# TODO: Placeholder (need to read dynamically via locals())
diff --git a/lux/vislib/altair/Choropleth.py b/lux/vislib/altair/Choropleth.py
index bf71b010..c2675883 100644
--- a/lux/vislib/altair/Choropleth.py
+++ b/lux/vislib/altair/Choropleth.py
@@ -68,9 +68,7 @@ def initialize_chart(self):
points = (
alt.Chart(geo_map)
.mark_geoshape()
- .encode(
- color=f"{y_attr_abv}:Q",
- )
+ .encode(color=f"{y_attr_abv}:Q",)
.transform_lookup(lookup="id", from_=alt.LookupData(self.data, x_attr_abv, [y_attr_abv]))
.project(type=map_type)
.properties(
diff --git a/lux/vislib/altair/Heatmap.py b/lux/vislib/altair/Heatmap.py
index f83a3bbb..2f743b32 100644
--- a/lux/vislib/altair/Heatmap.py
+++ b/lux/vislib/altair/Heatmap.py
@@ -71,10 +71,7 @@ def initialize_chart(self):
),
y2=alt.Y2("yBinEnd"),
opacity=alt.Opacity(
- "count",
- type="quantitative",
- scale=alt.Scale(type="log"),
- legend=None,
+ "count", type="quantitative", scale=alt.Scale(type="log"), legend=None,
),
)
)
diff --git a/tests/test_action.py b/tests/test_action.py
index 893c13b7..2cb49e86 100644
--- a/tests/test_action.py
+++ b/tests/test_action.py
@@ -184,8 +184,7 @@ def test_year_filter_value(global_var):
lambda vis: len(
list(
filter(
- lambda clause: clause.value != "" and clause.attribute == "Year",
- vis._intent,
+ lambda clause: clause.value != "" and clause.attribute == "Year", vis._intent,
)
)
)
@@ -215,16 +214,10 @@ def test_similarity(global_var):
ranked_list = df.recommendation["Similarity"]
japan_vis = list(
- filter(
- lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Japan",
- ranked_list,
- )
+ filter(lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Japan", ranked_list,)
)[0]
europe_vis = list(
- filter(
- lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Europe",
- ranked_list,
- )
+ filter(lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Europe", ranked_list,)
)[0]
assert japan_vis.score > europe_vis.score
df.clear_intent()
@@ -247,16 +240,10 @@ def test_similarity2():
ranked_list = df.recommendation["Similarity"]
morrisville_vis = list(
- filter(
- lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Morrisville",
- ranked_list,
- )
+ filter(lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Morrisville", ranked_list,)
)[0]
watertown_vis = list(
- filter(
- lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Watertown",
- ranked_list,
- )
+ filter(lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Watertown", ranked_list,)
)[0]
assert morrisville_vis.score > watertown_vis.score
diff --git a/tests/test_compiler.py b/tests/test_compiler.py
index 28e80f8a..d30b824f 100644
--- a/tests/test_compiler.py
+++ b/tests/test_compiler.py
@@ -168,10 +168,7 @@ def test_underspecified_vis_collection_zval(global_var):
# check if the number of charts is correct
df = pytest.car_df
vlst = VisList(
- [
- lux.Clause(attribute="Origin", filter_op="=", value="?"),
- lux.Clause(attribute="MilesPerGal"),
- ],
+ [lux.Clause(attribute="Origin", filter_op="=", value="?"), lux.Clause(attribute="MilesPerGal"),],
df,
)
assert len(vlst) == 3
@@ -185,10 +182,7 @@ def test_underspecified_vis_collection_zval(global_var):
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
vlst = VisList(
- [
- lux.Clause(attribute="origin", filter_op="=", value="?"),
- lux.Clause(attribute="milespergal"),
- ],
+ [lux.Clause(attribute="origin", filter_op="=", value="?"), lux.Clause(attribute="milespergal"),],
sql_df,
)
assert len(vlst) == 3
@@ -285,10 +279,7 @@ def test_specified_channel_enforced_vis_collection(global_var):
df = pytest.car_df
# change pandas dtype for the column "Year" to datetype
df["Year"] = pd.to_datetime(df["Year"], format="%Y")
- visList = VisList(
- [lux.Clause(attribute="?"), lux.Clause(attribute="MilesPerGal", channel="x")],
- df,
- )
+ visList = VisList([lux.Clause(attribute="?"), lux.Clause(attribute="MilesPerGal", channel="x")], df,)
for vis in visList:
check_attribute_on_channel(vis, "MilesPerGal", "x")
@@ -304,22 +295,13 @@ def test_autoencoding_scatter(global_var):
check_attribute_on_channel(vis, "Weight", "y")
# Partial channel specified
- vis = Vis(
- [
- lux.Clause(attribute="MilesPerGal", channel="y"),
- lux.Clause(attribute="Weight"),
- ],
- df,
- )
+ vis = Vis([lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight"),], df,)
check_attribute_on_channel(vis, "MilesPerGal", "y")
check_attribute_on_channel(vis, "Weight", "x")
# Full channel specified
vis = Vis(
- [
- lux.Clause(attribute="MilesPerGal", channel="y"),
- lux.Clause(attribute="Weight", channel="x"),
- ],
+ [lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight", channel="x"),],
df,
)
check_attribute_on_channel(vis, "MilesPerGal", "y")
@@ -339,8 +321,7 @@ def test_autoencoding_scatter(global_var):
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
visList = VisList(
- [lux.Clause(attribute="?"), lux.Clause(attribute="milespergal", channel="x")],
- sql_df,
+ [lux.Clause(attribute="?"), lux.Clause(attribute="milespergal", channel="x")], sql_df,
)
for vis in visList:
check_attribute_on_channel(vis, "milespergal", "x")
@@ -358,22 +339,13 @@ def test_autoencoding_scatter():
check_attribute_on_channel(vis, "Weight", "y")
# Partial channel specified
- vis = Vis(
- [
- lux.Clause(attribute="MilesPerGal", channel="y"),
- lux.Clause(attribute="Weight"),
- ],
- df,
- )
+ vis = Vis([lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight"),], df,)
check_attribute_on_channel(vis, "MilesPerGal", "y")
check_attribute_on_channel(vis, "Weight", "x")
# Full channel specified
vis = Vis(
- [
- lux.Clause(attribute="MilesPerGal", channel="y"),
- lux.Clause(attribute="Weight", channel="x"),
- ],
+ [lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight", channel="x"),],
df,
)
check_attribute_on_channel(vis, "MilesPerGal", "y")
@@ -398,21 +370,14 @@ def test_autoencoding_scatter():
# Partial channel specified
vis = Vis(
- [
- lux.Clause(attribute="milespergal", channel="y"),
- lux.Clause(attribute="weight"),
- ],
- sql_df,
+ [lux.Clause(attribute="milespergal", channel="y"), lux.Clause(attribute="weight"),], sql_df,
)
check_attribute_on_channel(vis, "milespergal", "y")
check_attribute_on_channel(vis, "weight", "x")
# Full channel specified
vis = Vis(
- [
- lux.Clause(attribute="milespergal", channel="y"),
- lux.Clause(attribute="weight", channel="x"),
- ],
+ [lux.Clause(attribute="milespergal", channel="y"), lux.Clause(attribute="weight", channel="x"),],
sql_df,
)
check_attribute_on_channel(vis, "milespergal", "y")
@@ -464,22 +429,13 @@ def test_autoencoding_line_chart(global_var):
check_attribute_on_channel(vis, "Acceleration", "y")
# Partial channel specified
- vis = Vis(
- [
- lux.Clause(attribute="Year", channel="y"),
- lux.Clause(attribute="Acceleration"),
- ],
- df,
- )
+ vis = Vis([lux.Clause(attribute="Year", channel="y"), lux.Clause(attribute="Acceleration"),], df,)
check_attribute_on_channel(vis, "Year", "y")
check_attribute_on_channel(vis, "Acceleration", "x")
# Full channel specified
vis = Vis(
- [
- lux.Clause(attribute="Year", channel="y"),
- lux.Clause(attribute="Acceleration", channel="x"),
- ],
+ [lux.Clause(attribute="Year", channel="y"), lux.Clause(attribute="Acceleration", channel="x"),],
df,
)
check_attribute_on_channel(vis, "Year", "y")
@@ -505,21 +461,14 @@ def test_autoencoding_line_chart(global_var):
# Partial channel specified
vis = Vis(
- [
- lux.Clause(attribute="year", channel="y"),
- lux.Clause(attribute="acceleration"),
- ],
- sql_df,
+ [lux.Clause(attribute="year", channel="y"), lux.Clause(attribute="acceleration"),], sql_df,
)
check_attribute_on_channel(vis, "year", "y")
check_attribute_on_channel(vis, "acceleration", "x")
# Full channel specified
vis = Vis(
- [
- lux.Clause(attribute="year", channel="y"),
- lux.Clause(attribute="acceleration", channel="x"),
- ],
+ [lux.Clause(attribute="year", channel="y"), lux.Clause(attribute="acceleration", channel="x"),],
sql_df,
)
check_attribute_on_channel(vis, "year", "y")
@@ -628,10 +577,7 @@ def test_populate_options(global_var):
assert list_equal(list(col_set), list(df.columns))
df.set_intent(
- [
- lux.Clause(attribute="?", data_model="measure"),
- lux.Clause(attribute="MilesPerGal"),
- ]
+ [lux.Clause(attribute="?", data_model="measure"), lux.Clause(attribute="MilesPerGal"),]
)
df._repr_html_()
col_set = set()
@@ -639,8 +585,7 @@ def test_populate_options(global_var):
for clause in specOptions:
col_set.add(clause.attribute)
assert list_equal(
- list(col_set),
- ["Acceleration", "Weight", "Horsepower", "MilesPerGal", "Displacement"],
+ list(col_set), ["Acceleration", "Weight", "Horsepower", "MilesPerGal", "Displacement"],
)
df.clear_intent()
@@ -656,10 +601,7 @@ def test_populate_options(global_var):
assert list_equal(list(col_set), list(sql_df.columns))
sql_df.set_intent(
- [
- lux.Clause(attribute="?", data_model="measure"),
- lux.Clause(attribute="milespergal"),
- ]
+ [lux.Clause(attribute="?", data_model="measure"), lux.Clause(attribute="milespergal"),]
)
sql_df._repr_html_()
col_set = set()
@@ -667,8 +609,7 @@ def test_populate_options(global_var):
for clause in specOptions:
col_set.add(clause.attribute)
assert list_equal(
- list(col_set),
- ["acceleration", "weight", "horsepower", "milespergal", "displacement"],
+ list(col_set), ["acceleration", "weight", "horsepower", "milespergal", "displacement"],
)
@@ -678,10 +619,7 @@ def test_remove_all_invalid(global_var):
df["Year"] = pd.to_datetime(df["Year"], format="%Y")
# with pytest.warns(UserWarning,match="duplicate attribute specified in the intent"):
df.set_intent(
- [
- lux.Clause(attribute="Origin", filter_op="=", value="USA"),
- lux.Clause(attribute="Origin"),
- ]
+ [lux.Clause(attribute="Origin", filter_op="=", value="USA"), lux.Clause(attribute="Origin"),]
)
df._repr_html_()
assert len(df.current_vis) == 0
@@ -693,10 +631,7 @@ def test_remove_all_invalid(global_var):
sql_df = lux.LuxSQLTable(table_name="cars")
# with pytest.warns(UserWarning,match="duplicate attribute specified in the intent"):
sql_df.set_intent(
- [
- lux.Clause(attribute="origin", filter_op="=", value="USA"),
- lux.Clause(attribute="origin"),
- ]
+ [lux.Clause(attribute="origin", filter_op="=", value="USA"), lux.Clause(attribute="origin"),]
)
sql_df._repr_html_()
assert len(sql_df.current_vis) == 0
diff --git a/tests/test_dates.py b/tests/test_dates.py
index dc530fc7..48d514b9 100644
--- a/tests/test_dates.py
+++ b/tests/test_dates.py
@@ -44,10 +44,7 @@ def test_period_selection(global_var):
ldf["Year"] = pd.DatetimeIndex(ldf["Year"]).to_period(freq="A")
ldf.set_intent(
- [
- lux.Clause(attribute=["Horsepower", "Weight", "Acceleration"]),
- lux.Clause(attribute="Year"),
- ]
+ [lux.Clause(attribute=["Horsepower", "Weight", "Acceleration"]), lux.Clause(attribute="Year"),]
)
lux.config.executor.execute(ldf.current_vis, ldf)
diff --git a/tests/test_interestingness.py b/tests/test_interestingness.py
index 7e6036f9..1ef929e3 100644
--- a/tests/test_interestingness.py
+++ b/tests/test_interestingness.py
@@ -64,10 +64,7 @@ def test_interestingness_1_0_1(global_var):
df["Year"] = pd.to_datetime(df["Year"], format="%Y")
df.set_intent(
- [
- lux.Clause(attribute="Origin", filter_op="=", value="USA"),
- lux.Clause(attribute="Cylinders"),
- ]
+ [lux.Clause(attribute="Origin", filter_op="=", value="USA"), lux.Clause(attribute="Cylinders"),]
)
df._repr_html_()
assert df.current_vis[0].score == 0
@@ -124,10 +121,7 @@ def test_interestingness_0_1_1(global_var):
df["Year"] = pd.to_datetime(df["Year"], format="%Y")
df.set_intent(
- [
- lux.Clause(attribute="Origin", filter_op="=", value="?"),
- lux.Clause(attribute="MilesPerGal"),
- ]
+ [lux.Clause(attribute="Origin", filter_op="=", value="?"), lux.Clause(attribute="MilesPerGal"),]
)
df._repr_html_()
assert interestingness(df.recommendation["Current Vis"][0], df) != None
diff --git a/tests/test_pandas_coverage.py b/tests/test_pandas_coverage.py
index 21014f60..e30b0aac 100644
--- a/tests/test_pandas_coverage.py
+++ b/tests/test_pandas_coverage.py
@@ -173,8 +173,7 @@ def test_groupby_agg_big(global_var):
assert len(new_df.cardinality) == 8
year_vis = list(
filter(
- lambda vis: vis.get_attr_by_attr_name("Year") != [],
- new_df.recommendation["Column Groups"],
+ lambda vis: vis.get_attr_by_attr_name("Year") != [], new_df.recommendation["Column Groups"],
)
)[0]
assert year_vis.mark == "bar"
@@ -182,10 +181,7 @@ def test_groupby_agg_big(global_var):
new_df = new_df.T
new_df._repr_html_()
year_vis = list(
- filter(
- lambda vis: vis.get_attr_by_attr_name("Year") != [],
- new_df.recommendation["Row Groups"],
- )
+ filter(lambda vis: vis.get_attr_by_attr_name("Year") != [], new_df.recommendation["Row Groups"],)
)[0]
assert year_vis.mark == "bar"
assert year_vis.get_attr_by_channel("x")[0].attribute == "Year"
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 333977aa..d274a1f4 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -121,7 +121,6 @@ def test_validator_invalid_attribute(global_var):
df = pytest.college_df
with pytest.raises(KeyError, match="'blah'"):
with pytest.warns(
- UserWarning,
- match="The input attribute 'blah' does not exist in the DataFrame.",
+ UserWarning, match="The input attribute 'blah' does not exist in the DataFrame.",
):
df.intent = ["blah"]
diff --git a/tests/test_vis.py b/tests/test_vis.py
index 55e10b39..0f210d82 100644
--- a/tests/test_vis.py
+++ b/tests/test_vis.py
@@ -154,10 +154,7 @@ def test_vis_list_custom_title_override(global_var):
vcLst = []
for attribute in ["Sport", "Year", "Height", "HostRegion", "SportType"]:
- vis = Vis(
- [lux.Clause("Weight"), lux.Clause(attribute)],
- title="overriding dummy title",
- )
+ vis = Vis([lux.Clause("Weight"), lux.Clause(attribute)], title="overriding dummy title",)
vcLst.append(vis)
vlist = VisList(vcLst, df)
for v in vlist:
From d68b16d712a8203157a5a00ef6b0f8fcd7c50eaf Mon Sep 17 00:00:00 2001
From: Sophia Huang <6860749+sophiahhuang@users.noreply.github.com>
Date: Sat, 20 Mar 2021 02:10:55 +0800
Subject: [PATCH 3/4] Remove out-of-date LuxSQLTable preview notice
---
lux/core/sqltable.py | 12 +-----------
1 file changed, 1 insertion(+), 11 deletions(-)
diff --git a/lux/core/sqltable.py b/lux/core/sqltable.py
index 6e557a0a..842e0fb9 100644
--- a/lux/core/sqltable.py
+++ b/lux/core/sqltable.py
@@ -122,17 +122,7 @@ def _repr_html_(self):
)
self.output = widgets.Output()
lux.config.executor.execute_preview(self)
- notice = HTML(
- """
- Please note, the data shown here is just a preview of the database table. You will be unable to perform Pandas functionality on this data."""
- )
-
- display(button, notice, self.output)
+ display(button, self.output)
def on_button_clicked(b):
with self.output:
From f190a55d4e26bec8766ab7f178b0298bb8fca329 Mon Sep 17 00:00:00 2001
From: 19thyneb
Date: Fri, 19 Mar 2021 11:29:55 -0700
Subject: [PATCH 4/4] Reformat code with Black
---
lux/_config/config.py | 21 +++--
lux/action/generalize.py | 7 +-
lux/action/row_group.py | 8 +-
lux/core/frame.py | 6 +-
lux/core/series.py | 3 +-
lux/core/sqltable.py | 3 +-
lux/executor/PandasExecutor.py | 10 ++-
lux/executor/SQLExecutor.py | 97 ++++++++++++++---------
lux/interestingness/interestingness.py | 6 +-
lux/interestingness/similarity.py | 3 +-
lux/processor/Compiler.py | 9 ++-
lux/vislib/altair/AltairRenderer.py | 3 +-
lux/vislib/altair/Choropleth.py | 4 +-
lux/vislib/altair/Heatmap.py | 5 +-
tests/test_action.py | 23 ++++--
tests/test_compiler.py | 105 ++++++++++++++++++++-----
tests/test_dates.py | 5 +-
tests/test_interestingness.py | 10 ++-
tests/test_pandas_coverage.py | 8 +-
tests/test_parser.py | 3 +-
tests/test_vis.py | 5 +-
21 files changed, 254 insertions(+), 90 deletions(-)
diff --git a/lux/_config/config.py b/lux/_config/config.py
index 65c778ef..12db034a 100644
--- a/lux/_config/config.py
+++ b/lux/_config/config.py
@@ -55,7 +55,8 @@ def topk(self, k: Union[int, bool]):
self._topk = k
else:
warnings.warn(
- "Parameter to lux.config.topk must be an integer or a boolean.", stacklevel=2,
+ "Parameter to lux.config.topk must be an integer or a boolean.",
+ stacklevel=2,
)
@property
@@ -98,7 +99,8 @@ def pandas_fallback(self, fallback: bool) -> None:
self._pandas_fallback = fallback
else:
warnings.warn(
- "The flag for Pandas fallback must be a boolean.", stacklevel=2,
+ "The flag for Pandas fallback must be a boolean.",
+ stacklevel=2,
)
@property
@@ -118,7 +120,8 @@ def interestingness_fallback(self, fallback: bool) -> None:
self._interestingness_fallback = fallback
else:
warnings.warn(
- "The flag for interestingness fallback must be a boolean.", stacklevel=2,
+ "The flag for interestingness fallback must be a boolean.",
+ stacklevel=2,
)
@property
@@ -144,7 +147,8 @@ def sampling_cap(self, sample_number: int) -> None:
self._sampling_cap = sample_number
else:
warnings.warn(
- "The cap on the number samples must be an integer.", stacklevel=2,
+ "The cap on the number samples must be an integer.",
+ stacklevel=2,
)
@property
@@ -172,7 +176,8 @@ def sampling_start(self, sample_number: int) -> None:
self._sampling_start = sample_number
else:
warnings.warn(
- "The sampling starting point must be an integer.", stacklevel=2,
+ "The sampling starting point must be an integer.",
+ stacklevel=2,
)
@property
@@ -197,7 +202,8 @@ def sampling(self, sample_flag: bool) -> None:
self._sampling_flag = sample_flag
else:
warnings.warn(
- "The flag for sampling must be a boolean.", stacklevel=2,
+ "The flag for sampling must be a boolean.",
+ stacklevel=2,
)
@property
@@ -222,7 +228,8 @@ def heatmap(self, heatmap_flag: bool) -> None:
self._heatmap_flag = heatmap_flag
else:
warnings.warn(
- "The flag for enabling/disabling heatmaps must be a boolean.", stacklevel=2,
+ "The flag for enabling/disabling heatmaps must be a boolean.",
+ stacklevel=2,
)
@property
diff --git a/lux/action/generalize.py b/lux/action/generalize.py
index 84504132..385891d4 100644
--- a/lux/action/generalize.py
+++ b/lux/action/generalize.py
@@ -78,7 +78,12 @@ def generalize(ldf):
for clause in filters:
# new_spec = ldf._intent.copy()
# new_spec.remove_column_from_spec(new_spec.attribute)
- temp_vis = Vis(ldf.current_vis[0]._inferred_intent.copy(), source=ldf, title="Overall", score=0,)
+ temp_vis = Vis(
+ ldf.current_vis[0]._inferred_intent.copy(),
+ source=ldf,
+ title="Overall",
+ score=0,
+ )
temp_vis.remove_filter_from_spec(clause.value)
output.append(temp_vis)
diff --git a/lux/action/row_group.py b/lux/action/row_group.py
index 7407ab5b..5555c01f 100644
--- a/lux/action/row_group.py
+++ b/lux/action/row_group.py
@@ -50,7 +50,13 @@ def row_group(ldf):
# rowdf.cardinality["index"]=len(rowdf)
# if isinstance(ldf.columns,pd.DatetimeIndex):
# rowdf.data_type_lookup[dim_name]="temporal"
- vis = Vis([dim_name, lux.Clause(row.name, data_model="measure", aggregation=None),], rowdf,)
+ vis = Vis(
+ [
+ dim_name,
+ lux.Clause(row.name, data_model="measure", aggregation=None),
+ ],
+ rowdf,
+ )
collection.append(vis)
vlst = VisList(collection)
# Note that we are not computing interestingness score here because we want to preserve the arrangement of the aggregated data
diff --git a/lux/core/frame.py b/lux/core/frame.py
index f5fe5ca5..4514ed4d 100644
--- a/lux/core/frame.py
+++ b/lux/core/frame.py
@@ -496,7 +496,8 @@ def exported(self) -> Union[Dict[str, VisList], VisList]:
exported_vis = VisList(
list(
map(
- self._recommendation[export_action].__getitem__, exported_vis_lst[export_action],
+ self._recommendation[export_action].__getitem__,
+ exported_vis_lst[export_action],
)
)
)
@@ -566,7 +567,8 @@ def _repr_html_(self):
self._widget.observe(self.set_intent_on_click, names="selectedIntentIndex")
button = widgets.Button(
- description="Toggle Pandas/Lux", layout=widgets.Layout(width="140px", top="5px"),
+ description="Toggle Pandas/Lux",
+ layout=widgets.Layout(width="140px", top="5px"),
)
self.output = widgets.Output()
display(button, self.output)
diff --git a/lux/core/series.py b/lux/core/series.py
index 7cea6639..3c717642 100644
--- a/lux/core/series.py
+++ b/lux/core/series.py
@@ -154,7 +154,8 @@ def __repr__(self):
# box = widgets.Box(layout=widgets.Layout(display='inline'))
button = widgets.Button(
- description="Toggle Pandas/Lux", layout=widgets.Layout(width="140px", top="5px"),
+ description="Toggle Pandas/Lux",
+ layout=widgets.Layout(width="140px", top="5px"),
)
ldf.output = widgets.Output()
# box.children = [button,output]
diff --git a/lux/core/sqltable.py b/lux/core/sqltable.py
index 842e0fb9..e51b8c09 100644
--- a/lux/core/sqltable.py
+++ b/lux/core/sqltable.py
@@ -84,7 +84,8 @@ def set_SQL_table(self, t_name):
error_str = str(error)
if f'relation "{t_name}" does not exist' in error_str:
warnings.warn(
- f"\nThe table '{t_name}' does not exist in your database./", stacklevel=2,
+ f"\nThe table '{t_name}' does not exist in your database./",
+ stacklevel=2,
)
def _repr_html_(self):
diff --git a/lux/executor/PandasExecutor.py b/lux/executor/PandasExecutor.py
index 8983876a..e5534821 100644
--- a/lux/executor/PandasExecutor.py
+++ b/lux/executor/PandasExecutor.py
@@ -216,7 +216,10 @@ def execute_aggregate(vis: Vis, isFiltered=True):
}
)
vis._vis_data = vis.data.merge(
- df, on=[columns[0], columns[1]], how="right", suffixes=["", "_right"],
+ df,
+ on=[columns[0], columns[1]],
+ how="right",
+ suffixes=["", "_right"],
)
for col in columns[2:]:
vis.data[col] = vis.data[col].fillna(0) # Triggers __setitem__
@@ -364,7 +367,10 @@ def execute_2D_binning(vis: Vis):
if color_attr.data_type == "nominal":
# Compute mode and count. Mode aggregates each cell by taking the majority vote for the category variable. In cases where there is ties across categories, pick the first item (.iat[0])
result = groups.agg(
- [("count", "count"), (color_attr.attribute, lambda x: pd.Series.mode(x).iat[0]),]
+ [
+ ("count", "count"),
+ (color_attr.attribute, lambda x: pd.Series.mode(x).iat[0]),
+ ]
).reset_index()
elif color_attr.data_type == "quantitative" or color_attr.data_type == "temporal":
# Compute the average of all values in the bin
diff --git a/lux/executor/SQLExecutor.py b/lux/executor/SQLExecutor.py
index 93f010c0..abf43f06 100644
--- a/lux/executor/SQLExecutor.py
+++ b/lux/executor/SQLExecutor.py
@@ -38,7 +38,8 @@ def execute_sampling(tbl: LuxSQLTable):
SAMPLE_FRAC = 0.2
length_query = pandas.read_sql(
- "SELECT COUNT(*) as length FROM {}".format(tbl.table_name), lux.config.SQLconnection,
+ "SELECT COUNT(*) as length FROM {}".format(tbl.table_name),
+ lux.config.SQLconnection,
)
limit = int(list(length_query["length"])[0]) * SAMPLE_FRAC
tbl._sampled = pandas.read_sql(
@@ -122,7 +123,8 @@ def add_quotes(var_name):
required_variables = ",".join(required_variables)
row_count = list(
pandas.read_sql(
- f"SELECT COUNT(*) FROM {tbl.table_name} {where_clause}", lux.config.SQLconnection,
+ f"SELECT COUNT(*) FROM {tbl.table_name} {where_clause}",
+ lux.config.SQLconnection,
)["count"]
)[0]
if row_count > lux.config.sampling_cap:
@@ -227,42 +229,48 @@ def execute_aggregate(view: Vis, tbl: LuxSQLTable, isFiltered=True):
# generates query for colored barchart case
if has_color:
if agg_func == "mean":
- agg_query = 'SELECT "{}", "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
- groupby_attr.attribute,
- color_attr.attribute,
- measure_attr.attribute,
- measure_attr.attribute,
- tbl.table_name,
- where_clause,
- groupby_attr.attribute,
- color_attr.attribute,
+ agg_query = (
+ 'SELECT "{}", "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
+ groupby_attr.attribute,
+ color_attr.attribute,
+ measure_attr.attribute,
+ measure_attr.attribute,
+ tbl.table_name,
+ where_clause,
+ groupby_attr.attribute,
+ color_attr.attribute,
+ )
)
view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)
view._vis_data = utils.pandas_to_lux(view._vis_data)
if agg_func == "sum":
- agg_query = 'SELECT "{}", "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
- groupby_attr.attribute,
- color_attr.attribute,
- measure_attr.attribute,
- measure_attr.attribute,
- tbl.table_name,
- where_clause,
- groupby_attr.attribute,
- color_attr.attribute,
+ agg_query = (
+ 'SELECT "{}", "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
+ groupby_attr.attribute,
+ color_attr.attribute,
+ measure_attr.attribute,
+ measure_attr.attribute,
+ tbl.table_name,
+ where_clause,
+ groupby_attr.attribute,
+ color_attr.attribute,
+ )
)
view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)
view._vis_data = utils.pandas_to_lux(view._vis_data)
if agg_func == "max":
- agg_query = 'SELECT "{}", "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
- groupby_attr.attribute,
- color_attr.attribute,
- measure_attr.attribute,
- measure_attr.attribute,
- tbl.table_name,
- where_clause,
- groupby_attr.attribute,
- color_attr.attribute,
+ agg_query = (
+ 'SELECT "{}", "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
+ groupby_attr.attribute,
+ color_attr.attribute,
+ measure_attr.attribute,
+ measure_attr.attribute,
+ tbl.table_name,
+ where_clause,
+ groupby_attr.attribute,
+ color_attr.attribute,
+ )
)
view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)
view._vis_data = utils.pandas_to_lux(view._vis_data)
@@ -322,7 +330,10 @@ def execute_aggregate(view: Vis, tbl: LuxSQLTable, isFiltered=True):
}
)
view._vis_data = view._vis_data.merge(
- df, on=[columns[0], columns[1]], how="right", suffixes=["", "_right"],
+ df,
+ on=[columns[0], columns[1]],
+ how="right",
+ suffixes=["", "_right"],
)
for col in columns[2:]:
view._vis_data[col] = view._vis_data[col].fillna(0) # Triggers __setitem__
@@ -391,7 +402,10 @@ def execute_binning(view: Vis, tbl: LuxSQLTable):
upper_edges = ",".join(upper_edges)
view_filter, filter_vars = SQLExecutor.execute_filter(view)
bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket(CAST (\"{}\" AS FLOAT), '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format(
- bin_attribute.attribute, "{" + upper_edges + "}", tbl.table_name, where_clause,
+ bin_attribute.attribute,
+ "{" + upper_edges + "}",
+ tbl.table_name,
+ where_clause,
)
bin_count_data = pandas.read_sql(bin_count_query, lux.config.SQLconnection)
@@ -406,11 +420,13 @@ def execute_binning(view: Vis, tbl: LuxSQLTable):
else:
bin_centers = np.array([(attr_min + attr_min + bin_width) / 2])
bin_centers = np.append(
- bin_centers, np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0),
+ bin_centers,
+ np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0),
)
if attr_type == int:
bin_centers = np.append(
- bin_centers, math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2),
+ bin_centers,
+ math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2),
)
else:
bin_centers = np.append(bin_centers, (upper_edges[len(upper_edges) - 1] + attr_max) / 2)
@@ -551,7 +567,10 @@ def execute_filter(view: Vis):
else:
where_clause.append("AND")
where_clause.extend(
- ['"' + str(a.attribute) + '"', "IS NOT NULL",]
+ [
+ '"' + str(a.attribute) + '"',
+ "IS NOT NULL",
+ ]
)
if where_clause == []:
@@ -668,7 +687,10 @@ def get_cardinality(self, tbl: LuxSQLTable):
card_query = 'SELECT Count(Distinct("{}")) FROM {} WHERE "{}" IS NOT NULL'.format(
attr, tbl.table_name, attr
)
- card_data = pandas.read_sql(card_query, lux.config.SQLconnection,)
+ card_data = pandas.read_sql(
+ card_query,
+ lux.config.SQLconnection,
+ )
cardinality[attr] = list(card_data["count"])[0]
tbl.cardinality = cardinality
@@ -691,7 +713,10 @@ def get_unique_values(self, tbl: LuxSQLTable):
unique_query = 'SELECT Distinct("{}") FROM {} WHERE "{}" IS NOT NULL'.format(
attr, tbl.table_name, attr
)
- unique_data = pandas.read_sql(unique_query, lux.config.SQLconnection,)
+ unique_data = pandas.read_sql(
+ unique_query,
+ lux.config.SQLconnection,
+ )
unique_vals[attr] = list(unique_data[attr])
tbl.unique_values = unique_vals
diff --git a/lux/interestingness/interestingness.py b/lux/interestingness/interestingness.py
index a95efea9..32b0f1db 100644
--- a/lux/interestingness/interestingness.py
+++ b/lux/interestingness/interestingness.py
@@ -198,7 +198,11 @@ def weighted_correlation(x, y, w):
def deviation_from_overall(
- vis: Vis, ldf: LuxDataFrame, filter_specs: list, msr_attribute: str, exclude_nan: bool = True,
+ vis: Vis,
+ ldf: LuxDataFrame,
+ filter_specs: list,
+ msr_attribute: str,
+ exclude_nan: bool = True,
) -> int:
"""
Difference in bar chart/histogram shape from overall chart
diff --git a/lux/interestingness/similarity.py b/lux/interestingness/similarity.py
index 017b97b5..8d810909 100644
--- a/lux/interestingness/similarity.py
+++ b/lux/interestingness/similarity.py
@@ -68,7 +68,8 @@ def interpolate(vis, length):
yVals[count - 1] + (interpolated_x - xVals[count - 1]) / x_diff * yDiff
)
vis.data = pd.DataFrame(
- list(zip(interpolated_x_vals, interpolated_y_vals)), columns=[xAxis, yAxis],
+ list(zip(interpolated_x_vals, interpolated_y_vals)),
+ columns=[xAxis, yAxis],
)
diff --git a/lux/processor/Compiler.py b/lux/processor/Compiler.py
index 85c49cfb..96d95d10 100644
--- a/lux/processor/Compiler.py
+++ b/lux/processor/Compiler.py
@@ -281,7 +281,10 @@ def line_or_bar_or_geo(ldf, dimension: Clause, measure: Clause):
# ShowMe logic + additional heuristics
# count_col = Clause( attribute="count()", data_model="measure")
count_col = Clause(
- attribute="Record", aggregation="count", data_model="measure", data_type="quantitative",
+ attribute="Record",
+ aggregation="count",
+ data_model="measure",
+ data_type="quantitative",
)
auto_channel = {}
if ndim == 0 and nmsr == 1:
@@ -472,7 +475,9 @@ def populate_wildcard_options(_inferred_intent: List[Clause], ldf: LuxDataFrame)
options = ldf.unique_values[attr]
specInd = _inferred_intent.index(clause)
_inferred_intent[specInd] = Clause(
- attribute=clause.attribute, filter_op="=", value=list(options),
+ attribute=clause.attribute,
+ filter_op="=",
+ value=list(options),
)
else:
options.extend(convert_to_list(clause.value))
diff --git a/lux/vislib/altair/AltairRenderer.py b/lux/vislib/altair/AltairRenderer.py
index deba1592..2c1ab206 100644
--- a/lux/vislib/altair/AltairRenderer.py
+++ b/lux/vislib/altair/AltairRenderer.py
@@ -128,7 +128,8 @@ def create_vis(self, vis, standalone=True):
found_variable = "df"
if standalone:
chart.code = chart.code.replace(
- "placeholder_variable", f"pd.DataFrame({str(vis.data.to_dict())})",
+ "placeholder_variable",
+ f"pd.DataFrame({str(vis.data.to_dict())})",
)
else:
# TODO: Placeholder (need to read dynamically via locals())
diff --git a/lux/vislib/altair/Choropleth.py b/lux/vislib/altair/Choropleth.py
index c2675883..bf71b010 100644
--- a/lux/vislib/altair/Choropleth.py
+++ b/lux/vislib/altair/Choropleth.py
@@ -68,7 +68,9 @@ def initialize_chart(self):
points = (
alt.Chart(geo_map)
.mark_geoshape()
- .encode(color=f"{y_attr_abv}:Q",)
+ .encode(
+ color=f"{y_attr_abv}:Q",
+ )
.transform_lookup(lookup="id", from_=alt.LookupData(self.data, x_attr_abv, [y_attr_abv]))
.project(type=map_type)
.properties(
diff --git a/lux/vislib/altair/Heatmap.py b/lux/vislib/altair/Heatmap.py
index 2f743b32..f83a3bbb 100644
--- a/lux/vislib/altair/Heatmap.py
+++ b/lux/vislib/altair/Heatmap.py
@@ -71,7 +71,10 @@ def initialize_chart(self):
),
y2=alt.Y2("yBinEnd"),
opacity=alt.Opacity(
- "count", type="quantitative", scale=alt.Scale(type="log"), legend=None,
+ "count",
+ type="quantitative",
+ scale=alt.Scale(type="log"),
+ legend=None,
),
)
)
diff --git a/tests/test_action.py b/tests/test_action.py
index 2cb49e86..893c13b7 100644
--- a/tests/test_action.py
+++ b/tests/test_action.py
@@ -184,7 +184,8 @@ def test_year_filter_value(global_var):
lambda vis: len(
list(
filter(
- lambda clause: clause.value != "" and clause.attribute == "Year", vis._intent,
+ lambda clause: clause.value != "" and clause.attribute == "Year",
+ vis._intent,
)
)
)
@@ -214,10 +215,16 @@ def test_similarity(global_var):
ranked_list = df.recommendation["Similarity"]
japan_vis = list(
- filter(lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Japan", ranked_list,)
+ filter(
+ lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Japan",
+ ranked_list,
+ )
)[0]
europe_vis = list(
- filter(lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Europe", ranked_list,)
+ filter(
+ lambda vis: vis.get_attr_by_attr_name("Origin")[0].value == "Europe",
+ ranked_list,
+ )
)[0]
assert japan_vis.score > europe_vis.score
df.clear_intent()
@@ -240,10 +247,16 @@ def test_similarity2():
ranked_list = df.recommendation["Similarity"]
morrisville_vis = list(
- filter(lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Morrisville", ranked_list,)
+ filter(
+ lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Morrisville",
+ ranked_list,
+ )
)[0]
watertown_vis = list(
- filter(lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Watertown", ranked_list,)
+ filter(
+ lambda vis: vis.get_attr_by_attr_name("City")[0].value == "Watertown",
+ ranked_list,
+ )
)[0]
assert morrisville_vis.score > watertown_vis.score
diff --git a/tests/test_compiler.py b/tests/test_compiler.py
index d30b824f..28e80f8a 100644
--- a/tests/test_compiler.py
+++ b/tests/test_compiler.py
@@ -168,7 +168,10 @@ def test_underspecified_vis_collection_zval(global_var):
# check if the number of charts is correct
df = pytest.car_df
vlst = VisList(
- [lux.Clause(attribute="Origin", filter_op="=", value="?"), lux.Clause(attribute="MilesPerGal"),],
+ [
+ lux.Clause(attribute="Origin", filter_op="=", value="?"),
+ lux.Clause(attribute="MilesPerGal"),
+ ],
df,
)
assert len(vlst) == 3
@@ -182,7 +185,10 @@ def test_underspecified_vis_collection_zval(global_var):
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
vlst = VisList(
- [lux.Clause(attribute="origin", filter_op="=", value="?"), lux.Clause(attribute="milespergal"),],
+ [
+ lux.Clause(attribute="origin", filter_op="=", value="?"),
+ lux.Clause(attribute="milespergal"),
+ ],
sql_df,
)
assert len(vlst) == 3
@@ -279,7 +285,10 @@ def test_specified_channel_enforced_vis_collection(global_var):
df = pytest.car_df
# change pandas dtype for the column "Year" to datetype
df["Year"] = pd.to_datetime(df["Year"], format="%Y")
- visList = VisList([lux.Clause(attribute="?"), lux.Clause(attribute="MilesPerGal", channel="x")], df,)
+ visList = VisList(
+ [lux.Clause(attribute="?"), lux.Clause(attribute="MilesPerGal", channel="x")],
+ df,
+ )
for vis in visList:
check_attribute_on_channel(vis, "MilesPerGal", "x")
@@ -295,13 +304,22 @@ def test_autoencoding_scatter(global_var):
check_attribute_on_channel(vis, "Weight", "y")
# Partial channel specified
- vis = Vis([lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight"),], df,)
+ vis = Vis(
+ [
+ lux.Clause(attribute="MilesPerGal", channel="y"),
+ lux.Clause(attribute="Weight"),
+ ],
+ df,
+ )
check_attribute_on_channel(vis, "MilesPerGal", "y")
check_attribute_on_channel(vis, "Weight", "x")
# Full channel specified
vis = Vis(
- [lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight", channel="x"),],
+ [
+ lux.Clause(attribute="MilesPerGal", channel="y"),
+ lux.Clause(attribute="Weight", channel="x"),
+ ],
df,
)
check_attribute_on_channel(vis, "MilesPerGal", "y")
@@ -321,7 +339,8 @@ def test_autoencoding_scatter(global_var):
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
visList = VisList(
- [lux.Clause(attribute="?"), lux.Clause(attribute="milespergal", channel="x")], sql_df,
+ [lux.Clause(attribute="?"), lux.Clause(attribute="milespergal", channel="x")],
+ sql_df,
)
for vis in visList:
check_attribute_on_channel(vis, "milespergal", "x")
@@ -339,13 +358,22 @@ def test_autoencoding_scatter():
check_attribute_on_channel(vis, "Weight", "y")
# Partial channel specified
- vis = Vis([lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight"),], df,)
+ vis = Vis(
+ [
+ lux.Clause(attribute="MilesPerGal", channel="y"),
+ lux.Clause(attribute="Weight"),
+ ],
+ df,
+ )
check_attribute_on_channel(vis, "MilesPerGal", "y")
check_attribute_on_channel(vis, "Weight", "x")
# Full channel specified
vis = Vis(
- [lux.Clause(attribute="MilesPerGal", channel="y"), lux.Clause(attribute="Weight", channel="x"),],
+ [
+ lux.Clause(attribute="MilesPerGal", channel="y"),
+ lux.Clause(attribute="Weight", channel="x"),
+ ],
df,
)
check_attribute_on_channel(vis, "MilesPerGal", "y")
@@ -370,14 +398,21 @@ def test_autoencoding_scatter():
# Partial channel specified
vis = Vis(
- [lux.Clause(attribute="milespergal", channel="y"), lux.Clause(attribute="weight"),], sql_df,
+ [
+ lux.Clause(attribute="milespergal", channel="y"),
+ lux.Clause(attribute="weight"),
+ ],
+ sql_df,
)
check_attribute_on_channel(vis, "milespergal", "y")
check_attribute_on_channel(vis, "weight", "x")
# Full channel specified
vis = Vis(
- [lux.Clause(attribute="milespergal", channel="y"), lux.Clause(attribute="weight", channel="x"),],
+ [
+ lux.Clause(attribute="milespergal", channel="y"),
+ lux.Clause(attribute="weight", channel="x"),
+ ],
sql_df,
)
check_attribute_on_channel(vis, "milespergal", "y")
@@ -429,13 +464,22 @@ def test_autoencoding_line_chart(global_var):
check_attribute_on_channel(vis, "Acceleration", "y")
# Partial channel specified
- vis = Vis([lux.Clause(attribute="Year", channel="y"), lux.Clause(attribute="Acceleration"),], df,)
+ vis = Vis(
+ [
+ lux.Clause(attribute="Year", channel="y"),
+ lux.Clause(attribute="Acceleration"),
+ ],
+ df,
+ )
check_attribute_on_channel(vis, "Year", "y")
check_attribute_on_channel(vis, "Acceleration", "x")
# Full channel specified
vis = Vis(
- [lux.Clause(attribute="Year", channel="y"), lux.Clause(attribute="Acceleration", channel="x"),],
+ [
+ lux.Clause(attribute="Year", channel="y"),
+ lux.Clause(attribute="Acceleration", channel="x"),
+ ],
df,
)
check_attribute_on_channel(vis, "Year", "y")
@@ -461,14 +505,21 @@ def test_autoencoding_line_chart(global_var):
# Partial channel specified
vis = Vis(
- [lux.Clause(attribute="year", channel="y"), lux.Clause(attribute="acceleration"),], sql_df,
+ [
+ lux.Clause(attribute="year", channel="y"),
+ lux.Clause(attribute="acceleration"),
+ ],
+ sql_df,
)
check_attribute_on_channel(vis, "year", "y")
check_attribute_on_channel(vis, "acceleration", "x")
# Full channel specified
vis = Vis(
- [lux.Clause(attribute="year", channel="y"), lux.Clause(attribute="acceleration", channel="x"),],
+ [
+ lux.Clause(attribute="year", channel="y"),
+ lux.Clause(attribute="acceleration", channel="x"),
+ ],
sql_df,
)
check_attribute_on_channel(vis, "year", "y")
@@ -577,7 +628,10 @@ def test_populate_options(global_var):
assert list_equal(list(col_set), list(df.columns))
df.set_intent(
- [lux.Clause(attribute="?", data_model="measure"), lux.Clause(attribute="MilesPerGal"),]
+ [
+ lux.Clause(attribute="?", data_model="measure"),
+ lux.Clause(attribute="MilesPerGal"),
+ ]
)
df._repr_html_()
col_set = set()
@@ -585,7 +639,8 @@ def test_populate_options(global_var):
for clause in specOptions:
col_set.add(clause.attribute)
assert list_equal(
- list(col_set), ["Acceleration", "Weight", "Horsepower", "MilesPerGal", "Displacement"],
+ list(col_set),
+ ["Acceleration", "Weight", "Horsepower", "MilesPerGal", "Displacement"],
)
df.clear_intent()
@@ -601,7 +656,10 @@ def test_populate_options(global_var):
assert list_equal(list(col_set), list(sql_df.columns))
sql_df.set_intent(
- [lux.Clause(attribute="?", data_model="measure"), lux.Clause(attribute="milespergal"),]
+ [
+ lux.Clause(attribute="?", data_model="measure"),
+ lux.Clause(attribute="milespergal"),
+ ]
)
sql_df._repr_html_()
col_set = set()
@@ -609,7 +667,8 @@ def test_populate_options(global_var):
for clause in specOptions:
col_set.add(clause.attribute)
assert list_equal(
- list(col_set), ["acceleration", "weight", "horsepower", "milespergal", "displacement"],
+ list(col_set),
+ ["acceleration", "weight", "horsepower", "milespergal", "displacement"],
)
@@ -619,7 +678,10 @@ def test_remove_all_invalid(global_var):
df["Year"] = pd.to_datetime(df["Year"], format="%Y")
# with pytest.warns(UserWarning,match="duplicate attribute specified in the intent"):
df.set_intent(
- [lux.Clause(attribute="Origin", filter_op="=", value="USA"), lux.Clause(attribute="Origin"),]
+ [
+ lux.Clause(attribute="Origin", filter_op="=", value="USA"),
+ lux.Clause(attribute="Origin"),
+ ]
)
df._repr_html_()
assert len(df.current_vis) == 0
@@ -631,7 +693,10 @@ def test_remove_all_invalid(global_var):
sql_df = lux.LuxSQLTable(table_name="cars")
# with pytest.warns(UserWarning,match="duplicate attribute specified in the intent"):
sql_df.set_intent(
- [lux.Clause(attribute="origin", filter_op="=", value="USA"), lux.Clause(attribute="origin"),]
+ [
+ lux.Clause(attribute="origin", filter_op="=", value="USA"),
+ lux.Clause(attribute="origin"),
+ ]
)
sql_df._repr_html_()
assert len(sql_df.current_vis) == 0
diff --git a/tests/test_dates.py b/tests/test_dates.py
index 48d514b9..dc530fc7 100644
--- a/tests/test_dates.py
+++ b/tests/test_dates.py
@@ -44,7 +44,10 @@ def test_period_selection(global_var):
ldf["Year"] = pd.DatetimeIndex(ldf["Year"]).to_period(freq="A")
ldf.set_intent(
- [lux.Clause(attribute=["Horsepower", "Weight", "Acceleration"]), lux.Clause(attribute="Year"),]
+ [
+ lux.Clause(attribute=["Horsepower", "Weight", "Acceleration"]),
+ lux.Clause(attribute="Year"),
+ ]
)
lux.config.executor.execute(ldf.current_vis, ldf)
diff --git a/tests/test_interestingness.py b/tests/test_interestingness.py
index 1ef929e3..7e6036f9 100644
--- a/tests/test_interestingness.py
+++ b/tests/test_interestingness.py
@@ -64,7 +64,10 @@ def test_interestingness_1_0_1(global_var):
df["Year"] = pd.to_datetime(df["Year"], format="%Y")
df.set_intent(
- [lux.Clause(attribute="Origin", filter_op="=", value="USA"), lux.Clause(attribute="Cylinders"),]
+ [
+ lux.Clause(attribute="Origin", filter_op="=", value="USA"),
+ lux.Clause(attribute="Cylinders"),
+ ]
)
df._repr_html_()
assert df.current_vis[0].score == 0
@@ -121,7 +124,10 @@ def test_interestingness_0_1_1(global_var):
df["Year"] = pd.to_datetime(df["Year"], format="%Y")
df.set_intent(
- [lux.Clause(attribute="Origin", filter_op="=", value="?"), lux.Clause(attribute="MilesPerGal"),]
+ [
+ lux.Clause(attribute="Origin", filter_op="=", value="?"),
+ lux.Clause(attribute="MilesPerGal"),
+ ]
)
df._repr_html_()
assert interestingness(df.recommendation["Current Vis"][0], df) != None
diff --git a/tests/test_pandas_coverage.py b/tests/test_pandas_coverage.py
index e30b0aac..21014f60 100644
--- a/tests/test_pandas_coverage.py
+++ b/tests/test_pandas_coverage.py
@@ -173,7 +173,8 @@ def test_groupby_agg_big(global_var):
assert len(new_df.cardinality) == 8
year_vis = list(
filter(
- lambda vis: vis.get_attr_by_attr_name("Year") != [], new_df.recommendation["Column Groups"],
+ lambda vis: vis.get_attr_by_attr_name("Year") != [],
+ new_df.recommendation["Column Groups"],
)
)[0]
assert year_vis.mark == "bar"
@@ -181,7 +182,10 @@ def test_groupby_agg_big(global_var):
new_df = new_df.T
new_df._repr_html_()
year_vis = list(
- filter(lambda vis: vis.get_attr_by_attr_name("Year") != [], new_df.recommendation["Row Groups"],)
+ filter(
+ lambda vis: vis.get_attr_by_attr_name("Year") != [],
+ new_df.recommendation["Row Groups"],
+ )
)[0]
assert year_vis.mark == "bar"
assert year_vis.get_attr_by_channel("x")[0].attribute == "Year"
diff --git a/tests/test_parser.py b/tests/test_parser.py
index d274a1f4..333977aa 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -121,6 +121,7 @@ def test_validator_invalid_attribute(global_var):
df = pytest.college_df
with pytest.raises(KeyError, match="'blah'"):
with pytest.warns(
- UserWarning, match="The input attribute 'blah' does not exist in the DataFrame.",
+ UserWarning,
+ match="The input attribute 'blah' does not exist in the DataFrame.",
):
df.intent = ["blah"]
diff --git a/tests/test_vis.py b/tests/test_vis.py
index 2a088ddb..4514be42 100644
--- a/tests/test_vis.py
+++ b/tests/test_vis.py
@@ -153,7 +153,10 @@ def test_vis_list_custom_title_override(global_var):
vcLst = []
for attribute in ["Sport", "Year", "Height", "HostRegion", "SportType"]:
- vis = Vis([lux.Clause("Weight"), lux.Clause(attribute)], title="overriding dummy title",)
+ vis = Vis(
+ [lux.Clause("Weight"), lux.Clause(attribute)],
+ title="overriding dummy title",
+ )
vcLst.append(vis)
vlist = VisList(vcLst, df)
for v in vlist: