Skip to content

Commit

Permalink
Black Reformatting
Browse files Browse the repository at this point in the history
  • Loading branch information
thyneb19 committed Mar 19, 2021
1 parent d68b16d commit f190a55
Show file tree
Hide file tree
Showing 21 changed files with 254 additions and 90 deletions.
21 changes: 14 additions & 7 deletions lux/_config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def topk(self, k: Union[int, bool]):
self._topk = k
else:
warnings.warn(
"Parameter to lux.config.topk must be an integer or a boolean.", stacklevel=2,
"Parameter to lux.config.topk must be an integer or a boolean.",
stacklevel=2,
)

@property
Expand Down Expand Up @@ -98,7 +99,8 @@ def pandas_fallback(self, fallback: bool) -> None:
self._pandas_fallback = fallback
else:
warnings.warn(
"The flag for Pandas fallback must be a boolean.", stacklevel=2,
"The flag for Pandas fallback must be a boolean.",
stacklevel=2,
)

@property
Expand All @@ -118,7 +120,8 @@ def interestingness_fallback(self, fallback: bool) -> None:
self._interestingness_fallback = fallback
else:
warnings.warn(
"The flag for interestingness fallback must be a boolean.", stacklevel=2,
"The flag for interestingness fallback must be a boolean.",
stacklevel=2,
)

@property
Expand All @@ -144,7 +147,8 @@ def sampling_cap(self, sample_number: int) -> None:
self._sampling_cap = sample_number
else:
warnings.warn(
"The cap on the number samples must be an integer.", stacklevel=2,
"The cap on the number samples must be an integer.",
stacklevel=2,
)

@property
Expand Down Expand Up @@ -172,7 +176,8 @@ def sampling_start(self, sample_number: int) -> None:
self._sampling_start = sample_number
else:
warnings.warn(
"The sampling starting point must be an integer.", stacklevel=2,
"The sampling starting point must be an integer.",
stacklevel=2,
)

@property
Expand All @@ -197,7 +202,8 @@ def sampling(self, sample_flag: bool) -> None:
self._sampling_flag = sample_flag
else:
warnings.warn(
"The flag for sampling must be a boolean.", stacklevel=2,
"The flag for sampling must be a boolean.",
stacklevel=2,
)

@property
Expand All @@ -222,7 +228,8 @@ def heatmap(self, heatmap_flag: bool) -> None:
self._heatmap_flag = heatmap_flag
else:
warnings.warn(
"The flag for enabling/disabling heatmaps must be a boolean.", stacklevel=2,
"The flag for enabling/disabling heatmaps must be a boolean.",
stacklevel=2,
)

@property
Expand Down
7 changes: 6 additions & 1 deletion lux/action/generalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,12 @@ def generalize(ldf):
for clause in filters:
# new_spec = ldf._intent.copy()
# new_spec.remove_column_from_spec(new_spec.attribute)
temp_vis = Vis(ldf.current_vis[0]._inferred_intent.copy(), source=ldf, title="Overall", score=0,)
temp_vis = Vis(
ldf.current_vis[0]._inferred_intent.copy(),
source=ldf,
title="Overall",
score=0,
)
temp_vis.remove_filter_from_spec(clause.value)
output.append(temp_vis)

Expand Down
8 changes: 7 additions & 1 deletion lux/action/row_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,13 @@ def row_group(ldf):
# rowdf.cardinality["index"]=len(rowdf)
# if isinstance(ldf.columns,pd.DatetimeIndex):
# rowdf.data_type_lookup[dim_name]="temporal"
vis = Vis([dim_name, lux.Clause(row.name, data_model="measure", aggregation=None),], rowdf,)
vis = Vis(
[
dim_name,
lux.Clause(row.name, data_model="measure", aggregation=None),
],
rowdf,
)
collection.append(vis)
vlst = VisList(collection)
# Note that we are not computing interestingness score here because we want to preserve the arrangement of the aggregated data
Expand Down
6 changes: 4 additions & 2 deletions lux/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,8 @@ def exported(self) -> Union[Dict[str, VisList], VisList]:
exported_vis = VisList(
list(
map(
self._recommendation[export_action].__getitem__, exported_vis_lst[export_action],
self._recommendation[export_action].__getitem__,
exported_vis_lst[export_action],
)
)
)
Expand Down Expand Up @@ -566,7 +567,8 @@ def _repr_html_(self):
self._widget.observe(self.set_intent_on_click, names="selectedIntentIndex")

button = widgets.Button(
description="Toggle Pandas/Lux", layout=widgets.Layout(width="140px", top="5px"),
description="Toggle Pandas/Lux",
layout=widgets.Layout(width="140px", top="5px"),
)
self.output = widgets.Output()
display(button, self.output)
Expand Down
3 changes: 2 additions & 1 deletion lux/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def __repr__(self):

# box = widgets.Box(layout=widgets.Layout(display='inline'))
button = widgets.Button(
description="Toggle Pandas/Lux", layout=widgets.Layout(width="140px", top="5px"),
description="Toggle Pandas/Lux",
layout=widgets.Layout(width="140px", top="5px"),
)
ldf.output = widgets.Output()
# box.children = [button,output]
Expand Down
3 changes: 2 additions & 1 deletion lux/core/sqltable.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def set_SQL_table(self, t_name):
error_str = str(error)
if f'relation "{t_name}" does not exist' in error_str:
warnings.warn(
f"\nThe table '{t_name}' does not exist in your database./", stacklevel=2,
f"\nThe table '{t_name}' does not exist in your database./",
stacklevel=2,
)

def _repr_html_(self):
Expand Down
10 changes: 8 additions & 2 deletions lux/executor/PandasExecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,10 @@ def execute_aggregate(vis: Vis, isFiltered=True):
}
)
vis._vis_data = vis.data.merge(
df, on=[columns[0], columns[1]], how="right", suffixes=["", "_right"],
df,
on=[columns[0], columns[1]],
how="right",
suffixes=["", "_right"],
)
for col in columns[2:]:
vis.data[col] = vis.data[col].fillna(0) # Triggers __setitem__
Expand Down Expand Up @@ -364,7 +367,10 @@ def execute_2D_binning(vis: Vis):
if color_attr.data_type == "nominal":
# Compute mode and count. Mode aggregates each cell by taking the majority vote for the category variable. In cases where there are ties across categories, pick the first item (.iat[0])
result = groups.agg(
[("count", "count"), (color_attr.attribute, lambda x: pd.Series.mode(x).iat[0]),]
[
("count", "count"),
(color_attr.attribute, lambda x: pd.Series.mode(x).iat[0]),
]
).reset_index()
elif color_attr.data_type == "quantitative" or color_attr.data_type == "temporal":
# Compute the average of all values in the bin
Expand Down
97 changes: 61 additions & 36 deletions lux/executor/SQLExecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def execute_sampling(tbl: LuxSQLTable):
SAMPLE_FRAC = 0.2

length_query = pandas.read_sql(
"SELECT COUNT(*) as length FROM {}".format(tbl.table_name), lux.config.SQLconnection,
"SELECT COUNT(*) as length FROM {}".format(tbl.table_name),
lux.config.SQLconnection,
)
limit = int(list(length_query["length"])[0]) * SAMPLE_FRAC
tbl._sampled = pandas.read_sql(
Expand Down Expand Up @@ -122,7 +123,8 @@ def add_quotes(var_name):
required_variables = ",".join(required_variables)
row_count = list(
pandas.read_sql(
f"SELECT COUNT(*) FROM {tbl.table_name} {where_clause}", lux.config.SQLconnection,
f"SELECT COUNT(*) FROM {tbl.table_name} {where_clause}",
lux.config.SQLconnection,
)["count"]
)[0]
if row_count > lux.config.sampling_cap:
Expand Down Expand Up @@ -227,42 +229,48 @@ def execute_aggregate(view: Vis, tbl: LuxSQLTable, isFiltered=True):
# generates query for colored barchart case
if has_color:
if agg_func == "mean":
agg_query = 'SELECT "{}", "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
groupby_attr.attribute,
color_attr.attribute,
measure_attr.attribute,
measure_attr.attribute,
tbl.table_name,
where_clause,
groupby_attr.attribute,
color_attr.attribute,
agg_query = (
'SELECT "{}", "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
groupby_attr.attribute,
color_attr.attribute,
measure_attr.attribute,
measure_attr.attribute,
tbl.table_name,
where_clause,
groupby_attr.attribute,
color_attr.attribute,
)
)
view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)

view._vis_data = utils.pandas_to_lux(view._vis_data)
if agg_func == "sum":
agg_query = 'SELECT "{}", "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
groupby_attr.attribute,
color_attr.attribute,
measure_attr.attribute,
measure_attr.attribute,
tbl.table_name,
where_clause,
groupby_attr.attribute,
color_attr.attribute,
agg_query = (
'SELECT "{}", "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
groupby_attr.attribute,
color_attr.attribute,
measure_attr.attribute,
measure_attr.attribute,
tbl.table_name,
where_clause,
groupby_attr.attribute,
color_attr.attribute,
)
)
view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)
view._vis_data = utils.pandas_to_lux(view._vis_data)
if agg_func == "max":
agg_query = 'SELECT "{}", "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
groupby_attr.attribute,
color_attr.attribute,
measure_attr.attribute,
measure_attr.attribute,
tbl.table_name,
where_clause,
groupby_attr.attribute,
color_attr.attribute,
agg_query = (
'SELECT "{}", "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
groupby_attr.attribute,
color_attr.attribute,
measure_attr.attribute,
measure_attr.attribute,
tbl.table_name,
where_clause,
groupby_attr.attribute,
color_attr.attribute,
)
)
view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)
view._vis_data = utils.pandas_to_lux(view._vis_data)
Expand Down Expand Up @@ -322,7 +330,10 @@ def execute_aggregate(view: Vis, tbl: LuxSQLTable, isFiltered=True):
}
)
view._vis_data = view._vis_data.merge(
df, on=[columns[0], columns[1]], how="right", suffixes=["", "_right"],
df,
on=[columns[0], columns[1]],
how="right",
suffixes=["", "_right"],
)
for col in columns[2:]:
view._vis_data[col] = view._vis_data[col].fillna(0) # Triggers __setitem__
Expand Down Expand Up @@ -391,7 +402,10 @@ def execute_binning(view: Vis, tbl: LuxSQLTable):
upper_edges = ",".join(upper_edges)
view_filter, filter_vars = SQLExecutor.execute_filter(view)
bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket(CAST (\"{}\" AS FLOAT), '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format(
bin_attribute.attribute, "{" + upper_edges + "}", tbl.table_name, where_clause,
bin_attribute.attribute,
"{" + upper_edges + "}",
tbl.table_name,
where_clause,
)

bin_count_data = pandas.read_sql(bin_count_query, lux.config.SQLconnection)
Expand All @@ -406,11 +420,13 @@ def execute_binning(view: Vis, tbl: LuxSQLTable):
else:
bin_centers = np.array([(attr_min + attr_min + bin_width) / 2])
bin_centers = np.append(
bin_centers, np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0),
bin_centers,
np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0),
)
if attr_type == int:
bin_centers = np.append(
bin_centers, math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2),
bin_centers,
math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2),
)
else:
bin_centers = np.append(bin_centers, (upper_edges[len(upper_edges) - 1] + attr_max) / 2)
Expand Down Expand Up @@ -551,7 +567,10 @@ def execute_filter(view: Vis):
else:
where_clause.append("AND")
where_clause.extend(
['"' + str(a.attribute) + '"', "IS NOT NULL",]
[
'"' + str(a.attribute) + '"',
"IS NOT NULL",
]
)

if where_clause == []:
Expand Down Expand Up @@ -668,7 +687,10 @@ def get_cardinality(self, tbl: LuxSQLTable):
card_query = 'SELECT Count(Distinct("{}")) FROM {} WHERE "{}" IS NOT NULL'.format(
attr, tbl.table_name, attr
)
card_data = pandas.read_sql(card_query, lux.config.SQLconnection,)
card_data = pandas.read_sql(
card_query,
lux.config.SQLconnection,
)
cardinality[attr] = list(card_data["count"])[0]
tbl.cardinality = cardinality

Expand All @@ -691,7 +713,10 @@ def get_unique_values(self, tbl: LuxSQLTable):
unique_query = 'SELECT Distinct("{}") FROM {} WHERE "{}" IS NOT NULL'.format(
attr, tbl.table_name, attr
)
unique_data = pandas.read_sql(unique_query, lux.config.SQLconnection,)
unique_data = pandas.read_sql(
unique_query,
lux.config.SQLconnection,
)
unique_vals[attr] = list(unique_data[attr])
tbl.unique_values = unique_vals

Expand Down
6 changes: 5 additions & 1 deletion lux/interestingness/interestingness.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,11 @@ def weighted_correlation(x, y, w):


def deviation_from_overall(
vis: Vis, ldf: LuxDataFrame, filter_specs: list, msr_attribute: str, exclude_nan: bool = True,
vis: Vis,
ldf: LuxDataFrame,
filter_specs: list,
msr_attribute: str,
exclude_nan: bool = True,
) -> int:
"""
Difference in bar chart/histogram shape from overall chart
Expand Down
3 changes: 2 additions & 1 deletion lux/interestingness/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def interpolate(vis, length):
yVals[count - 1] + (interpolated_x - xVals[count - 1]) / x_diff * yDiff
)
vis.data = pd.DataFrame(
list(zip(interpolated_x_vals, interpolated_y_vals)), columns=[xAxis, yAxis],
list(zip(interpolated_x_vals, interpolated_y_vals)),
columns=[xAxis, yAxis],
)


Expand Down
9 changes: 7 additions & 2 deletions lux/processor/Compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,10 @@ def line_or_bar_or_geo(ldf, dimension: Clause, measure: Clause):
# ShowMe logic + additional heuristics
# count_col = Clause( attribute="count()", data_model="measure")
count_col = Clause(
attribute="Record", aggregation="count", data_model="measure", data_type="quantitative",
attribute="Record",
aggregation="count",
data_model="measure",
data_type="quantitative",
)
auto_channel = {}
if ndim == 0 and nmsr == 1:
Expand Down Expand Up @@ -472,7 +475,9 @@ def populate_wildcard_options(_inferred_intent: List[Clause], ldf: LuxDataFrame)
options = ldf.unique_values[attr]
specInd = _inferred_intent.index(clause)
_inferred_intent[specInd] = Clause(
attribute=clause.attribute, filter_op="=", value=list(options),
attribute=clause.attribute,
filter_op="=",
value=list(options),
)
else:
options.extend(convert_to_list(clause.value))
Expand Down
3 changes: 2 additions & 1 deletion lux/vislib/altair/AltairRenderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def create_vis(self, vis, standalone=True):
found_variable = "df"
if standalone:
chart.code = chart.code.replace(
"placeholder_variable", f"pd.DataFrame({str(vis.data.to_dict())})",
"placeholder_variable",
f"pd.DataFrame({str(vis.data.to_dict())})",
)
else:
# TODO: Placeholder (need to read dynamically via locals())
Expand Down
Loading

0 comments on commit f190a55

Please sign in to comment.