Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug Fixes in Histogram Binning #267

Merged
merged 5 commits into from
Feb 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 62 additions & 38 deletions lux/executor/SQLExecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@ def execute(view_collection: VisList, ldf: LuxDataFrame):
if view.mark == "":
view.refresh_source(ldf)
elif view.mark == "scatter":
if len(view.get_attr_by_channel("color")) == 1:
where_clause, filterVars = SQLExecutor.execute_filter(view)
length_query = pandas.read_sql(
"SELECT COUNT(*) as length FROM {} {}".format(ldf.table_name, where_clause),
lux.config.SQLconnection,
)
view_data_length = list(length_query["length"])[0]
if len(view.get_attr_by_channel("color")) == 1 or view_data_length <= 10000000:
# NOTE: consider adding a check so that categorical variables with more than some number of categories are not used as a Color variable
has_color = True
SQLExecutor.execute_scatter(view, ldf)
Expand All @@ -49,9 +55,6 @@ def execute(view_collection: VisList, ldf: LuxDataFrame):
SQLExecutor.execute_aggregate(view, ldf)
elif view.mark == "histogram":
SQLExecutor.execute_binning(view, ldf)
# this is weird, somewhere in the SQL executor the lux.config.executor is being set to a PandasExecutor
# temporary fix here
# lux.config.executor = SQLExecutor()

@staticmethod
def execute_scatter(view: Vis, ldf: LuxDataFrame):
Expand Down Expand Up @@ -100,7 +103,7 @@ def add_quotes(var_name):
lux.config.SQLconnection,
)["count"]
)[0]
if row_count > 10000:
if row_count > lux.config.sampling_cap:
query = f"SELECT {required_variables} FROM {ldf.table_name} {where_clause} ORDER BY random() LIMIT 10000"
else:
query = "SELECT {} FROM {} {}".format(required_variables, ldf.table_name, where_clause)
Expand Down Expand Up @@ -352,8 +355,8 @@ def execute_binning(view: Vis, ldf: LuxDataFrame):
bin_attribute = list(filter(lambda x: x.bin_size != 0, view._inferred_intent))[0]

num_bins = bin_attribute.bin_size
attr_min = min(ldf.unique_values[bin_attribute.attribute])
attr_max = max(ldf.unique_values[bin_attribute.attribute])
attr_min = ldf._min_max[bin_attribute.attribute][0]
attr_max = ldf._min_max[bin_attribute.attribute][1]
attr_type = type(ldf.unique_values[bin_attribute.attribute][0])

# get filters if available
Expand All @@ -374,46 +377,49 @@ def execute_binning(view: Vis, ldf: LuxDataFrame):
upper_edges.append(str(curr_edge))
upper_edges = ",".join(upper_edges)
view_filter, filter_vars = SQLExecutor.execute_filter(view)
bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket(\"{}\", '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format(
bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket(CAST (\"{}\" AS FLOAT), '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format(
bin_attribute.attribute,
"{" + upper_edges + "}",
ldf.table_name,
where_clause,
)
print(bin_count_query)
bin_count_data = pandas.read_sql(bin_count_query, lux.config.SQLconnection)
if not bin_count_data["width_bucket"].isnull().values.any():
# np.histogram breaks if data contain NaN

# counts,binEdges = np.histogram(ldf[bin_attribute.attribute],bins=bin_attribute.bin_size)
# binEdges of size N+1, so need to compute binCenter as the bin location
upper_edges = [float(i) for i in upper_edges.split(",")]
if attr_type == int:
bin_centers = np.array([math.ceil((attr_min + attr_min + bin_width) / 2)])
else:
bin_centers = np.array([(attr_min + attr_min + bin_width) / 2])
bin_centers = np.append(
bin_centers,
np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0),
)
if attr_type == int:
# counts,binEdges = np.histogram(ldf[bin_attribute.attribute],bins=bin_attribute.bin_size)
# binEdges of size N+1, so need to compute binCenter as the bin location
upper_edges = [float(i) for i in upper_edges.split(",")]
if attr_type == int:
bin_centers = np.array([math.ceil((attr_min + attr_min + bin_width) / 2)])
else:
bin_centers = np.array([(attr_min + attr_min + bin_width) / 2])
bin_centers = np.append(
bin_centers,
math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2),
np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0),
)
else:
bin_centers = np.append(bin_centers, (upper_edges[len(upper_edges) - 1] + attr_max) / 2)

if len(bin_centers) > len(bin_count_data):
bucket_lables = bin_count_data["width_bucket"].unique()
for i in range(0, len(bin_centers)):
if i not in bucket_lables:
bin_count_data = bin_count_data.append(
pandas.DataFrame([[i, 0]], columns=bin_count_data.columns)
)
view._vis_data = pandas.DataFrame(
np.array([bin_centers, list(bin_count_data["count"])]).T,
columns=[bin_attribute.attribute, "Number of Records"],
)
view._vis_data = utils.pandas_to_lux(view.data)
view._vis_data.length = list(length_query["length"])[0]
if attr_type == int:
bin_centers = np.append(
bin_centers,
math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2),
)
else:
bin_centers = np.append(bin_centers, (upper_edges[len(upper_edges) - 1] + attr_max) / 2)

if len(bin_centers) > len(bin_count_data):
bucket_lables = bin_count_data["width_bucket"].unique()
for i in range(0, len(bin_centers)):
if i not in bucket_lables:
bin_count_data = bin_count_data.append(
pandas.DataFrame([[i, 0]], columns=bin_count_data.columns)
)
view._vis_data = pandas.DataFrame(
np.array([bin_centers, list(bin_count_data["count"])]).T,
columns=[bin_attribute.attribute, "Number of Records"],
)
view._vis_data = utils.pandas_to_lux(view.data)
view._vis_data.length = list(length_query["length"])[0]

@staticmethod
def execute_2D_binning(view: Vis, ldf: LuxDataFrame):
Expand Down Expand Up @@ -506,7 +512,7 @@ def execute_2D_binning(view: Vis, ldf: LuxDataFrame):
# bin_where_clause,
# )

bin_count_query = 'SELECT width_bucket("{}", {}) as width_bucket, count(*) FROM {} {} GROUP BY width_bucket'.format(
bin_count_query = 'SELECT width_bucket(CAST ("{}" AS FLOAT), {}) as width_bucket, count(*) FROM {} {} GROUP BY width_bucket'.format(
y_attribute.attribute,
str(y_attr_min) + "," + str(y_attr_max) + "," + str(num_bins - 1),
ldf.table_name,
Expand Down Expand Up @@ -565,6 +571,24 @@ def execute_filter(view: Vis):
)
if filters[f].attribute not in filter_vars:
filter_vars.append(filters[f].attribute)

attributes = utils.get_attrs_specs(view._inferred_intent)

# Ensure that no null values are included in the data,
# because null values break binning queries.
for a in attributes:
if a.attribute != "Record":
if where_clause == []:
where_clause.append("WHERE")
else:
where_clause.append("AND")
where_clause.extend(
[
'"' + str(a.attribute) + '"',
"IS NOT NULL",
]
)

if where_clause == []:
return ("", [])
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_underspecified_single_vis(global_var, test_recs):
sql_df.set_intent([lux.Clause(attribute="milespergal"), lux.Clause(attribute="weight")])
test_recs(sql_df, one_vis_actions)
assert len(sql_df.current_vis) == 1
assert sql_df.current_vis[0].mark == "heatmap"
assert sql_df.current_vis[0].mark == "scatter"
for attr in sql_df.current_vis[0]._inferred_intent:
assert attr.data_model == "measure"
for attr in sql_df.current_vis[0]._inferred_intent:
Expand Down
9 changes: 6 additions & 3 deletions tests/test_sql_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,10 @@ def test_filter():
vis = Vis(intent, sql_df)
vis._vis_data = sql_df
filter_output = SQLExecutor.execute_filter(vis)
assert filter_output[0] == "WHERE \"origin\" = 'USA'"
assert (
filter_output[0]
== 'WHERE "origin" = \'USA\' AND "year" IS NOT NULL AND "horsepower" IS NOT NULL'
)
assert filter_output[1] == ["origin"]


Expand All @@ -164,7 +167,7 @@ def test_inequalityfilter():
)
vis._vis_data = sql_df
filter_output = SQLExecutor.execute_filter(vis)
assert filter_output[0] == "WHERE \"horsepower\" > '50'"
assert filter_output[0] == 'WHERE "horsepower" > \'50\' AND "milespergal" IS NOT NULL'
assert filter_output[1] == ["horsepower"]

intent = [
Expand All @@ -174,7 +177,7 @@ def test_inequalityfilter():
vis = Vis(intent, sql_df)
vis._vis_data = sql_df
filter_output = SQLExecutor.execute_filter(vis)
assert filter_output[0] == "WHERE \"horsepower\" <= '100'"
assert filter_output[0] == 'WHERE "horsepower" <= \'100\' AND "milespergal" IS NOT NULL'
assert filter_output[1] == ["horsepower"]


Expand Down