Merge pull request #236 from thyneb19/Database-Executor
Fixed Issue with SQLExecutor's Column handling
thyneb19 committed Jan 19, 2021
2 parents 522c616 + 97d1281 commit 395dfd6
Showing 5 changed files with 69 additions and 54 deletions.
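The common thread across these files: every place the SQL executor interpolates a column name into a query now wraps it in double quotes. PostgreSQL folds unquoted identifiers to lowercase, so columns with uppercase letters, spaces, or reserved words previously failed to resolve. A standalone sketch of the failure mode (column and table names here are hypothetical, not taken from the diff):

# PostgreSQL folds unquoted identifiers to lowercase; quoting preserves the stored name.
attribute = "Horsepower"   # a mixed-case column
table_name = "cars"

unquoted = "SELECT MIN({}) as min, MAX({}) as max FROM {}".format(attribute, attribute, table_name)
quoted = 'SELECT MIN("{}") as min, MAX("{}") as max FROM {}'.format(attribute, attribute, table_name)

print(unquoted)  # MIN(Horsepower)   -> resolved as min(horsepower), which may not exist
print(quoted)    # MIN("Horsepower") -> matches the stored column name exactly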
3 changes: 2 additions & 1 deletion lux/_config/config.py
@@ -294,9 +294,10 @@ def set_executor_type(self, exe):
self.executor = SQLExecutor()
elif exe == "Pandas":
from lux.executor.PandasExecutor import PandasExecutor

self.SQLconnection = ""
self.executor = PandasExecutor()


def warning_format(message, category, filename, lineno, file=None, line=None):
return "%s:%s: %s:%s\n" % (filename, lineno, category.__name__, message)
return "%s:%s: %s:%s\n" % (filename, lineno, category.__name__, message)
110 changes: 62 additions & 48 deletions lux/executor/SQLExecutor.py
@@ -89,7 +89,11 @@ def execute_scatter(view: Vis, ldf: LuxDataFrame):
)

# SQLExecutor.execute_2D_binning(view, ldf)
def add_quotes(var_name):
return '"' + var_name + '"'

required_variables = attributes | set(filterVars)
required_variables = map(add_quotes, required_variables)
required_variables = ",".join(required_variables)
row_count = list(
pandas.read_sql(
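A standalone sketch of what the new add_quotes helper contributes to the projection list; the attribute and filter names below are hypothetical:

# Quote every required column before joining them into the SELECT list.
def add_quotes(var_name):
    return '"' + var_name + '"'

attributes = {"MilesPerGal", "Horsepower"}
filterVars = ["Origin"]

required_variables = attributes | set(filterVars)
required_variables = ",".join(map(add_quotes, required_variables))
print(required_variables)  # e.g. "Origin","MilesPerGal","Horsepower" (set order varies)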
@@ -164,7 +168,7 @@ def execute_aggregate(view: Vis, ldf: LuxDataFrame, isFiltered=True):
)
# generates query for colored barchart case
if has_color:
count_query = "SELECT {}, {}, COUNT({}) FROM {} {} GROUP BY {}, {}".format(
count_query = 'SELECT "{}", "{}", COUNT("{}") FROM {} {} GROUP BY "{}", "{}"'.format(
groupby_attr.attribute,
color_attr.attribute,
groupby_attr.attribute,
@@ -178,7 +182,7 @@ def execute_aggregate(view: Vis, ldf: LuxDataFrame, isFiltered=True):
view._vis_data = utils.pandas_to_lux(view._vis_data)
# generates query for normal barchart case
else:
count_query = "SELECT {}, COUNT({}) FROM {} {} GROUP BY {}".format(
count_query = 'SELECT "{}", COUNT("{}") FROM {} {} GROUP BY "{}"'.format(
groupby_attr.attribute,
groupby_attr.attribute,
ldf.table_name,
@@ -200,49 +204,55 @@ def execute_aggregate(view: Vis, ldf: LuxDataFrame, isFiltered=True):
# generates query for colored barchart case
if has_color:
if agg_func == "mean":
agg_query = "SELECT {}, {}, AVG({}) as {} FROM {} {} GROUP BY {}, {}".format(
groupby_attr.attribute,
color_attr.attribute,
measure_attr.attribute,
measure_attr.attribute,
ldf.table_name,
where_clause,
groupby_attr.attribute,
color_attr.attribute,
agg_query = (
'SELECT "{}", "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
groupby_attr.attribute,
color_attr.attribute,
measure_attr.attribute,
measure_attr.attribute,
ldf.table_name,
where_clause,
groupby_attr.attribute,
color_attr.attribute,
)
)
view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)

view._vis_data = utils.pandas_to_lux(view._vis_data)
if agg_func == "sum":
agg_query = "SELECT {}, {}, SUM({}) as {} FROM {} {} GROUP BY {}, {}".format(
groupby_attr.attribute,
color_attr.attribute,
measure_attr.attribute,
measure_attr.attribute,
ldf.table_name,
where_clause,
groupby_attr.attribute,
color_attr.attribute,
agg_query = (
'SELECT "{}", "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
groupby_attr.attribute,
color_attr.attribute,
measure_attr.attribute,
measure_attr.attribute,
ldf.table_name,
where_clause,
groupby_attr.attribute,
color_attr.attribute,
)
)
view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)
view._vis_data = utils.pandas_to_lux(view._vis_data)
if agg_func == "max":
agg_query = "SELECT {}, {}, MAX({}) as {} FROM {} {} GROUP BY {}, {}".format(
groupby_attr.attribute,
color_attr.attribute,
measure_attr.attribute,
measure_attr.attribute,
ldf.table_name,
where_clause,
groupby_attr.attribute,
color_attr.attribute,
agg_query = (
'SELECT "{}", "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
groupby_attr.attribute,
color_attr.attribute,
measure_attr.attribute,
measure_attr.attribute,
ldf.table_name,
where_clause,
groupby_attr.attribute,
color_attr.attribute,
)
)
view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)
view._vis_data = utils.pandas_to_lux(view._vis_data)
# generates query for normal barchart case
else:
if agg_func == "mean":
agg_query = "SELECT {}, AVG({}) as {} FROM {} {} GROUP BY {}".format(
agg_query = 'SELECT "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}"'.format(
groupby_attr.attribute,
measure_attr.attribute,
measure_attr.attribute,
@@ -253,7 +263,7 @@ def execute_aggregate(view: Vis, ldf: LuxDataFrame, isFiltered=True):
view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)
view._vis_data = utils.pandas_to_lux(view._vis_data)
if agg_func == "sum":
agg_query = "SELECT {}, SUM({}) as {} FROM {} {} GROUP BY {}".format(
agg_query = 'SELECT "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}"'.format(
groupby_attr.attribute,
measure_attr.attribute,
measure_attr.attribute,
@@ -264,7 +274,7 @@ def execute_aggregate(view: Vis, ldf: LuxDataFrame, isFiltered=True):
view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)
view._vis_data = utils.pandas_to_lux(view._vis_data)
if agg_func == "max":
agg_query = "SELECT {}, MAX({}) as {} FROM {} {} GROUP BY {}".format(
agg_query = 'SELECT "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}"'.format(
groupby_attr.attribute,
measure_attr.attribute,
measure_attr.attribute,
@@ -367,7 +377,7 @@ def execute_binning(view: Vis, ldf: LuxDataFrame):
upper_edges.append(str(curr_edge))
upper_edges = ",".join(upper_edges)
view_filter, filter_vars = SQLExecutor.execute_filter(view)
bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket({}, '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format(
bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket(\"{}\", '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format(
bin_attribute.attribute,
"{" + upper_edges + "}",
ldf.table_name,
@@ -474,15 +484,19 @@ def execute_2D_binning(view: Vis, ldf: LuxDataFrame):
bin_where_clause = "WHERE "
if c == 0:
lower_bound = x_attr_min
lower_bound_clause = x_attribute.attribute + " >= " + "'" + str(lower_bound) + "'"
lower_bound_clause = (
'"' + x_attribute.attribute + '"' + " >= " + "'" + str(lower_bound) + "'"
)
else:
lower_bound = x_upper_edges[c - 1]
lower_bound_clause = x_attribute.attribute + " >= " + "'" + str(lower_bound) + "'"
lower_bound_clause = (
'"' + x_attribute.attribute + '"' + " >= " + "'" + str(lower_bound) + "'"
)
upper_bound = x_upper_edges[c]
upper_bound_clause = x_attribute.attribute + " < " + "'" + str(upper_bound) + "'"
upper_bound_clause = '"' + x_attribute.attribute + '"' + " < " + "'" + str(upper_bound) + "'"
bin_where_clause = bin_where_clause + lower_bound_clause + " AND " + upper_bound_clause

bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket({}, '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format(
bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket(\"{}\", '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format(
y_attribute.attribute,
"{" + y_upper_edges_string + "}",
ldf.table_name,
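A standalone sketch of the quoted bin-boundary predicate assembled above, with hypothetical attribute and edge values:

# Build the WHERE predicate for one x-axis bin, quoting the column name.
x_attribute = "Weight"
lower_bound, upper_bound = 1500, 2500

lower_bound_clause = '"' + x_attribute + '"' + " >= " + "'" + str(lower_bound) + "'"
upper_bound_clause = '"' + x_attribute + '"' + " < " + "'" + str(upper_bound) + "'"
print("WHERE " + lower_bound_clause + " AND " + upper_bound_clause)
# WHERE "Weight" >= '1500' AND "Weight" < '2500'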
@@ -534,7 +548,7 @@ def execute_filter(view: Vis):
where_clause.append("AND")
where_clause.extend(
[
str(filters[f].attribute),
'"' + str(filters[f].attribute) + '"',
str(filters[f].filter_op),
"'" + str(filters[f].value) + "'",
]
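A standalone sketch of the filter clause this change assembles (filter values are hypothetical): the attribute is now double-quoted while the value keeps its single quotes, which is what the updated tests below assert.

# Assemble a single-filter WHERE clause the way execute_filter now does.
attribute, filter_op, value = "origin", "=", "USA"

where_clause = ["WHERE", '"' + str(attribute) + '"', str(filter_op), "'" + str(value) + "'"]
print(" ".join(where_clause))  # WHERE "origin" = 'USA'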
@@ -627,7 +641,7 @@ def compute_stats(self, ldf: LuxDataFrame):
for attribute in ldf.columns:
if ldf.data_type[attribute] == "quantitative":
min_max_query = pandas.read_sql(
"SELECT MIN({}) as min, MAX({}) as max FROM {}".format(
'SELECT MIN("{}") as min, MAX("{}") as max FROM {}'.format(
attribute, attribute, ldf.table_name
),
lux.config.SQLconnection,
@@ -653,12 +667,12 @@ def get_cardinality(self, ldf: LuxDataFrame):
"""
cardinality = {}
for attr in list(ldf.columns):
card_query = pandas.read_sql(
"SELECT Count(Distinct({})) FROM {}".format(attr, ldf.table_name),
card_query = 'SELECT Count(Distinct("{}")) FROM {}'.format(attr, ldf.table_name)
card_data = pandas.read_sql(
card_query,
lux.config.SQLconnection,
)

cardinality[attr] = list(card_query["count"])[0]
cardinality[attr] = list(card_data["count"])[0]
ldf.cardinality = cardinality

def get_unique_values(self, ldf: LuxDataFrame):
@@ -677,12 +691,12 @@ def get_unique_values(self, ldf: LuxDataFrame):
"""
unique_vals = {}
for attr in list(ldf.columns):
unique_query = pandas.read_sql(
"SELECT Distinct({}) FROM {}".format(attr, ldf.table_name),
unique_query = 'SELECT Distinct("{}") FROM {}'.format(attr, ldf.table_name)
unique_data = pandas.read_sql(
unique_query,
lux.config.SQLconnection,
)

unique_vals[attr] = list(unique_query[attr])
unique_vals[attr] = list(unique_data[attr])
ldf.unique_values = unique_vals

def compute_data_type(self, ldf: LuxDataFrame):
@@ -714,7 +728,6 @@ def compute_data_type(self, ldf: LuxDataFrame):
datatype = list(pandas.read_sql(datatype_query, lux.config.SQLconnection)["data_type"])[0]

sql_dtypes[attr] = datatype

for attr in list(ldf.columns):
if str(attr).lower() in ["month", "year"]:
data_type[attr] = "temporal"
@@ -735,6 +748,7 @@ def compute_data_type(self, ldf: LuxDataFrame):
"smallint",
"smallserial",
"serial",
"double precision",
]:
if ldf.cardinality[attr] < 13:
data_type[attr] = "nominal"
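The aggregate branches above all follow the same pattern; here is a standalone sketch of the colored bar-chart query using the updated format string (attribute and table names are hypothetical stand-ins for the Vis intent):

# Format the colored bar-chart aggregate with quoted identifiers.
groupby_attr, color_attr, measure_attr = "Origin", "Cylinders", "Horsepower"
table_name, where_clause = "cars", ""

agg_query = (
    'SELECT "{}", "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format(
        groupby_attr, color_attr, measure_attr, measure_attr,
        table_name, where_clause, groupby_attr, color_attr,
    )
)
print(agg_query)
# SELECT "Origin", "Cylinders", AVG("Horsepower") as "Horsepower" FROM cars  GROUP BY "Origin", "Cylinders"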
2 changes: 1 addition & 1 deletion tests/test_pandas_coverage.py
@@ -679,4 +679,4 @@ def test_read_sas(global_var):
df = pd.read_sas(url, format="sas7bdat")
df._repr_html_()
assert list(df.recommendation.keys()) == ["Correlation", "Distribution", "Temporal"]
assert len(df.data_type) == 6
assert len(df.data_type) == 6
2 changes: 1 addition & 1 deletion tests/test_series.py
@@ -50,4 +50,4 @@ def test_print_dtypes(global_var):
df = pytest.college_df
with warnings.catch_warnings(record=True) as w:
print(df.dtypes)
assert len(w) == 0, "Warning displayed when printing dtypes"
assert len(w) == 0, "Warning displayed when printing dtypes"
6 changes: 3 additions & 3 deletions tests/test_sql_executor.py
@@ -146,7 +146,7 @@ def test_filter():
vis = Vis(intent, sql_df)
vis._vis_data = sql_df
filter_output = SQLExecutor.execute_filter(vis)
assert filter_output[0] == "WHERE origin = 'USA'"
assert filter_output[0] == "WHERE \"origin\" = 'USA'"
assert filter_output[1] == ["origin"]


@@ -164,7 +164,7 @@ def test_inequalityfilter():
)
vis._vis_data = sql_df
filter_output = SQLExecutor.execute_filter(vis)
assert filter_output[0] == "WHERE horsepower > '50'"
assert filter_output[0] == "WHERE \"horsepower\" > '50'"
assert filter_output[1] == ["horsepower"]

intent = [
@@ -174,7 +174,7 @@ def test_inequalityfilter():
vis = Vis(intent, sql_df)
vis._vis_data = sql_df
filter_output = SQLExecutor.execute_filter(vis)
assert filter_output[0] == "WHERE horsepower <= '100'"
assert filter_output[0] == "WHERE \"horsepower\" <= '100'"
assert filter_output[1] == ["horsepower"]


