Skip to content

Commit

Permalink
Merge pull request #267 from thyneb19/Database-Executor
Browse files — Browse the repository at this point in the history
Bug Fixes in Histogram Binning
  • Loading branch information
thyneb19 committed Feb 18, 2021
2 parents 4fd9a00 + 3e649c2 commit b985393
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 42 deletions.
100 changes: 62 additions & 38 deletions lux/executor/SQLExecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@ def execute(view_collection: VisList, ldf: LuxDataFrame):
if view.mark == "":
view.refresh_source(ldf)
elif view.mark == "scatter":
if len(view.get_attr_by_channel("color")) == 1:
where_clause, filterVars = SQLExecutor.execute_filter(view)
length_query = pandas.read_sql(
"SELECT COUNT(*) as length FROM {} {}".format(ldf.table_name, where_clause),
lux.config.SQLconnection,
)
view_data_length = list(length_query["length"])[0]
if len(view.get_attr_by_channel("color")) == 1 or view_data_length <= 10000000:
# NOTE: might want to have a check somewhere to not use categorical variables with greater than some number of categories as a Color variable----------------
has_color = True
SQLExecutor.execute_scatter(view, ldf)
Expand All @@ -49,9 +55,6 @@ def execute(view_collection: VisList, ldf: LuxDataFrame):
SQLExecutor.execute_aggregate(view, ldf)
elif view.mark == "histogram":
SQLExecutor.execute_binning(view, ldf)
# this is weird, somewhere in the SQL executor the lux.config.executor is being set to a PandasExecutor
# temporary fix here
# lux.config.executor = SQLExecutor()

@staticmethod
def execute_scatter(view: Vis, ldf: LuxDataFrame):
Expand Down Expand Up @@ -100,7 +103,7 @@ def add_quotes(var_name):
lux.config.SQLconnection,
)["count"]
)[0]
if row_count > 10000:
if row_count > lux.config.sampling_cap:
query = f"SELECT {required_variables} FROM {ldf.table_name} {where_clause} ORDER BY random() LIMIT 10000"
else:
query = "SELECT {} FROM {} {}".format(required_variables, ldf.table_name, where_clause)
Expand Down Expand Up @@ -352,8 +355,8 @@ def execute_binning(view: Vis, ldf: LuxDataFrame):
bin_attribute = list(filter(lambda x: x.bin_size != 0, view._inferred_intent))[0]

num_bins = bin_attribute.bin_size
attr_min = min(ldf.unique_values[bin_attribute.attribute])
attr_max = max(ldf.unique_values[bin_attribute.attribute])
attr_min = ldf._min_max[bin_attribute.attribute][0]
attr_max = ldf._min_max[bin_attribute.attribute][1]
attr_type = type(ldf.unique_values[bin_attribute.attribute][0])

# get filters if available
Expand All @@ -374,46 +377,49 @@ def execute_binning(view: Vis, ldf: LuxDataFrame):
upper_edges.append(str(curr_edge))
upper_edges = ",".join(upper_edges)
view_filter, filter_vars = SQLExecutor.execute_filter(view)
bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket(\"{}\", '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format(
bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket(CAST (\"{}\" AS FLOAT), '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format(
bin_attribute.attribute,
"{" + upper_edges + "}",
ldf.table_name,
where_clause,
)
print(bin_count_query)
bin_count_data = pandas.read_sql(bin_count_query, lux.config.SQLconnection)
if not bin_count_data["width_bucket"].isnull().values.any():
# np.histogram breaks if data contain NaN

# counts,binEdges = np.histogram(ldf[bin_attribute.attribute],bins=bin_attribute.bin_size)
# binEdges of size N+1, so need to compute binCenter as the bin location
upper_edges = [float(i) for i in upper_edges.split(",")]
if attr_type == int:
bin_centers = np.array([math.ceil((attr_min + attr_min + bin_width) / 2)])
else:
bin_centers = np.array([(attr_min + attr_min + bin_width) / 2])
bin_centers = np.append(
bin_centers,
np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0),
)
if attr_type == int:
# counts,binEdges = np.histogram(ldf[bin_attribute.attribute],bins=bin_attribute.bin_size)
# binEdges of size N+1, so need to compute binCenter as the bin location
upper_edges = [float(i) for i in upper_edges.split(",")]
if attr_type == int:
bin_centers = np.array([math.ceil((attr_min + attr_min + bin_width) / 2)])
else:
bin_centers = np.array([(attr_min + attr_min + bin_width) / 2])
bin_centers = np.append(
bin_centers,
math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2),
np.mean(np.vstack([upper_edges[0:-1], upper_edges[1:]]), axis=0),
)
else:
bin_centers = np.append(bin_centers, (upper_edges[len(upper_edges) - 1] + attr_max) / 2)

if len(bin_centers) > len(bin_count_data):
bucket_lables = bin_count_data["width_bucket"].unique()
for i in range(0, len(bin_centers)):
if i not in bucket_lables:
bin_count_data = bin_count_data.append(
pandas.DataFrame([[i, 0]], columns=bin_count_data.columns)
)
view._vis_data = pandas.DataFrame(
np.array([bin_centers, list(bin_count_data["count"])]).T,
columns=[bin_attribute.attribute, "Number of Records"],
)
view._vis_data = utils.pandas_to_lux(view.data)
view._vis_data.length = list(length_query["length"])[0]
if attr_type == int:
bin_centers = np.append(
bin_centers,
math.ceil((upper_edges[len(upper_edges) - 1] + attr_max) / 2),
)
else:
bin_centers = np.append(bin_centers, (upper_edges[len(upper_edges) - 1] + attr_max) / 2)

if len(bin_centers) > len(bin_count_data):
bucket_lables = bin_count_data["width_bucket"].unique()
for i in range(0, len(bin_centers)):
if i not in bucket_lables:
bin_count_data = bin_count_data.append(
pandas.DataFrame([[i, 0]], columns=bin_count_data.columns)
)
view._vis_data = pandas.DataFrame(
np.array([bin_centers, list(bin_count_data["count"])]).T,
columns=[bin_attribute.attribute, "Number of Records"],
)
view._vis_data = utils.pandas_to_lux(view.data)
view._vis_data.length = list(length_query["length"])[0]

@staticmethod
def execute_2D_binning(view: Vis, ldf: LuxDataFrame):
Expand Down Expand Up @@ -506,7 +512,7 @@ def execute_2D_binning(view: Vis, ldf: LuxDataFrame):
# bin_where_clause,
# )

bin_count_query = 'SELECT width_bucket("{}", {}) as width_bucket, count(*) FROM {} {} GROUP BY width_bucket'.format(
bin_count_query = 'SELECT width_bucket(CAST ("{}" AS FLOAT), {}) as width_bucket, count(*) FROM {} {} GROUP BY width_bucket'.format(
y_attribute.attribute,
str(y_attr_min) + "," + str(y_attr_max) + "," + str(num_bins - 1),
ldf.table_name,
Expand Down Expand Up @@ -565,6 +571,24 @@ def execute_filter(view: Vis):
)
if filters[f].attribute not in filter_vars:
filter_vars.append(filters[f].attribute)

attributes = utils.get_attrs_specs(view._inferred_intent)

# need to ensure that no null values are included in the data
# null values breaks binning queries
for a in attributes:
if a.attribute != "Record":
if where_clause == []:
where_clause.append("WHERE")
else:
where_clause.append("AND")
where_clause.extend(
[
'"' + str(a.attribute) + '"',
"IS NOT NULL",
]
)

if where_clause == []:
return ("", [])
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_underspecified_single_vis(global_var, test_recs):
sql_df.set_intent([lux.Clause(attribute="milespergal"), lux.Clause(attribute="weight")])
test_recs(sql_df, one_vis_actions)
assert len(sql_df.current_vis) == 1
assert sql_df.current_vis[0].mark == "heatmap"
assert sql_df.current_vis[0].mark == "scatter"
for attr in sql_df.current_vis[0]._inferred_intent:
assert attr.data_model == "measure"
for attr in sql_df.current_vis[0]._inferred_intent:
Expand Down
9 changes: 6 additions & 3 deletions tests/test_sql_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,10 @@ def test_filter():
vis = Vis(intent, sql_df)
vis._vis_data = sql_df
filter_output = SQLExecutor.execute_filter(vis)
assert filter_output[0] == "WHERE \"origin\" = 'USA'"
assert (
filter_output[0]
== 'WHERE "origin" = \'USA\' AND "year" IS NOT NULL AND "horsepower" IS NOT NULL'
)
assert filter_output[1] == ["origin"]


Expand All @@ -164,7 +167,7 @@ def test_inequalityfilter():
)
vis._vis_data = sql_df
filter_output = SQLExecutor.execute_filter(vis)
assert filter_output[0] == "WHERE \"horsepower\" > '50'"
assert filter_output[0] == 'WHERE "horsepower" > \'50\' AND "milespergal" IS NOT NULL'
assert filter_output[1] == ["horsepower"]

intent = [
Expand All @@ -174,7 +177,7 @@ def test_inequalityfilter():
vis = Vis(intent, sql_df)
vis._vis_data = sql_df
filter_output = SQLExecutor.execute_filter(vis)
assert filter_output[0] == "WHERE \"horsepower\" <= '100'"
assert filter_output[0] == 'WHERE "horsepower" <= \'100\' AND "milespergal" IS NOT NULL'
assert filter_output[1] == ["horsepower"]


Expand Down

0 comments on commit b985393

Please sign in to comment.