diff --git a/lux/_config/config.py b/lux/_config/config.py index dda6f5c0..d30858d0 100644 --- a/lux/_config/config.py +++ b/lux/_config/config.py @@ -294,9 +294,10 @@ def set_executor_type(self, exe): self.executor = SQLExecutor() elif exe == "Pandas": from lux.executor.PandasExecutor import PandasExecutor + self.SQLconnection = "" self.executor = PandasExecutor() def warning_format(message, category, filename, lineno, file=None, line=None): - return "%s:%s: %s:%s\n" % (filename, lineno, category.__name__, message) \ No newline at end of file + return "%s:%s: %s:%s\n" % (filename, lineno, category.__name__, message) diff --git a/lux/executor/SQLExecutor.py b/lux/executor/SQLExecutor.py index ca780b55..215d50fe 100644 --- a/lux/executor/SQLExecutor.py +++ b/lux/executor/SQLExecutor.py @@ -89,7 +89,11 @@ def execute_scatter(view: Vis, ldf: LuxDataFrame): ) # SQLExecutor.execute_2D_binning(view, ldf) + def add_quotes(var_name): + return '"' + var_name + '"' + required_variables = attributes | set(filterVars) + required_variables = map(add_quotes, required_variables) required_variables = ",".join(required_variables) row_count = list( pandas.read_sql( @@ -164,7 +168,7 @@ def execute_aggregate(view: Vis, ldf: LuxDataFrame, isFiltered=True): ) # generates query for colored barchart case if has_color: - count_query = "SELECT {}, {}, COUNT({}) FROM {} {} GROUP BY {}, {}".format( + count_query = 'SELECT "{}", "{}", COUNT("{}") FROM {} {} GROUP BY "{}", "{}"'.format( groupby_attr.attribute, color_attr.attribute, groupby_attr.attribute, @@ -178,7 +182,7 @@ def execute_aggregate(view: Vis, ldf: LuxDataFrame, isFiltered=True): view._vis_data = utils.pandas_to_lux(view._vis_data) # generates query for normal barchart case else: - count_query = "SELECT {}, COUNT({}) FROM {} {} GROUP BY {}".format( + count_query = 'SELECT "{}", COUNT("{}") FROM {} {} GROUP BY "{}"'.format( groupby_attr.attribute, groupby_attr.attribute, ldf.table_name, @@ -200,49 +204,55 @@ def execute_aggregate(view: Vis, ldf: LuxDataFrame, isFiltered=True): # generates query for colored barchart case if has_color: if agg_func == "mean": - agg_query = "SELECT {}, {}, AVG({}) as {} FROM {} {} GROUP BY {}, {}".format( - groupby_attr.attribute, - color_attr.attribute, - measure_attr.attribute, - measure_attr.attribute, - ldf.table_name, - where_clause, - groupby_attr.attribute, - color_attr.attribute, + agg_query = ( + 'SELECT "{}", "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( + groupby_attr.attribute, + color_attr.attribute, + measure_attr.attribute, + measure_attr.attribute, + ldf.table_name, + where_clause, + groupby_attr.attribute, + color_attr.attribute, + ) ) view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection) view._vis_data = utils.pandas_to_lux(view._vis_data) if agg_func == "sum": - agg_query = "SELECT {}, {}, SUM({}) as {} FROM {} {} GROUP BY {}, {}".format( - groupby_attr.attribute, - color_attr.attribute, - measure_attr.attribute, - measure_attr.attribute, - ldf.table_name, - where_clause, - groupby_attr.attribute, - color_attr.attribute, + agg_query = ( + 'SELECT "{}", "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( + groupby_attr.attribute, + color_attr.attribute, + measure_attr.attribute, + measure_attr.attribute, + ldf.table_name, + where_clause, + groupby_attr.attribute, + color_attr.attribute, + ) ) view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection) view._vis_data = utils.pandas_to_lux(view._vis_data) if agg_func == "max": - agg_query = "SELECT {}, {}, MAX({}) as {} FROM {} {} GROUP BY {}, {}".format( - groupby_attr.attribute, - color_attr.attribute, - measure_attr.attribute, - measure_attr.attribute, - ldf.table_name, - where_clause, - groupby_attr.attribute, - color_attr.attribute, + agg_query = ( + 'SELECT "{}", "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}", "{}"'.format( + groupby_attr.attribute, + color_attr.attribute, + measure_attr.attribute, + measure_attr.attribute, + ldf.table_name, + where_clause, + groupby_attr.attribute, + color_attr.attribute, + ) ) view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection) view._vis_data = utils.pandas_to_lux(view._vis_data) # generates query for normal barchart case else: if agg_func == "mean": - agg_query = "SELECT {}, AVG({}) as {} FROM {} {} GROUP BY {}".format( + agg_query = 'SELECT "{}", AVG("{}") as "{}" FROM {} {} GROUP BY "{}"'.format( groupby_attr.attribute, measure_attr.attribute, measure_attr.attribute, @@ -253,7 +263,7 @@ def execute_aggregate(view: Vis, ldf: LuxDataFrame, isFiltered=True): view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection) view._vis_data = utils.pandas_to_lux(view._vis_data) if agg_func == "sum": - agg_query = "SELECT {}, SUM({}) as {} FROM {} {} GROUP BY {}".format( + agg_query = 'SELECT "{}", SUM("{}") as "{}" FROM {} {} GROUP BY "{}"'.format( groupby_attr.attribute, measure_attr.attribute, measure_attr.attribute, @@ -264,7 +274,7 @@ def execute_aggregate(view: Vis, ldf: LuxDataFrame, isFiltered=True): view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection) view._vis_data = utils.pandas_to_lux(view._vis_data) if agg_func == "max": - agg_query = "SELECT {}, MAX({}) as {} FROM {} {} GROUP BY {}".format( + agg_query = 'SELECT "{}", MAX("{}") as "{}" FROM {} {} GROUP BY "{}"'.format( groupby_attr.attribute, measure_attr.attribute, measure_attr.attribute, @@ -367,7 +377,7 @@ def execute_binning(view: Vis, ldf: LuxDataFrame): upper_edges.append(str(curr_edge)) upper_edges = ",".join(upper_edges) view_filter, filter_vars = SQLExecutor.execute_filter(view) - bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket({}, '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format( + bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket(\"{}\", '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format( bin_attribute.attribute, "{" + upper_edges + "}", ldf.table_name, @@ -474,15 +484,19 @@ def execute_2D_binning(view: Vis, ldf: LuxDataFrame): bin_where_clause = "WHERE " if c == 0: lower_bound = x_attr_min - lower_bound_clause = x_attribute.attribute + " >= " + "'" + str(lower_bound) + "'" + lower_bound_clause = ( + '"' + x_attribute.attribute + '"' + " >= " + "'" + str(lower_bound) + "'" + ) else: lower_bound = x_upper_edges[c - 1] - lower_bound_clause = x_attribute.attribute + " >= " + "'" + str(lower_bound) + "'" + lower_bound_clause = ( + '"' + x_attribute.attribute + '"' + " >= " + "'" + str(lower_bound) + "'" + ) upper_bound = x_upper_edges[c] - upper_bound_clause = x_attribute.attribute + " < " + "'" + str(upper_bound) + "'" + upper_bound_clause = '"' + x_attribute.attribute + '"' + " < " + "'" + str(upper_bound) + "'" bin_where_clause = bin_where_clause + lower_bound_clause + " AND " + upper_bound_clause - bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket({}, '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format( + bin_count_query = "SELECT width_bucket, COUNT(width_bucket) FROM (SELECT width_bucket(\"{}\", '{}') FROM {} {}) as Buckets GROUP BY width_bucket ORDER BY width_bucket".format( y_attribute.attribute, "{" + y_upper_edges_string + "}", ldf.table_name, @@ -534,7 +548,7 @@ def execute_filter(view: Vis): where_clause.append("AND") where_clause.extend( [ - str(filters[f].attribute), + '"' + str(filters[f].attribute) + '"', str(filters[f].filter_op), "'" + str(filters[f].value) + "'", ] @@ -627,7 +641,7 @@ def compute_stats(self, ldf: LuxDataFrame): for attribute in ldf.columns: if ldf.data_type[attribute] == "quantitative": min_max_query = pandas.read_sql( - "SELECT MIN({}) as min, MAX({}) as max FROM {}".format( + 'SELECT MIN("{}") as min, MAX("{}") as max FROM {}'.format( attribute, attribute, ldf.table_name ), lux.config.SQLconnection, @@ -653,12 +667,12 @@ def get_cardinality(self, ldf: LuxDataFrame): """ cardinality = {} for attr in list(ldf.columns): - card_query = pandas.read_sql( - "SELECT Count(Distinct({})) FROM {}".format(attr, ldf.table_name), + card_query = 'SELECT Count(Distinct("{}")) FROM {}'.format(attr, ldf.table_name) + card_data = pandas.read_sql( + card_query, lux.config.SQLconnection, ) - - cardinality[attr] = list(card_query["count"])[0] + cardinality[attr] = list(card_data["count"])[0] ldf.cardinality = cardinality def get_unique_values(self, ldf: LuxDataFrame): @@ -677,12 +691,12 @@ def get_unique_values(self, ldf: LuxDataFrame): """ unique_vals = {} for attr in list(ldf.columns): - unique_query = pandas.read_sql( - "SELECT Distinct({}) FROM {}".format(attr, ldf.table_name), + unique_query = 'SELECT Distinct("{}") FROM {}'.format(attr, ldf.table_name) + unique_data = pandas.read_sql( + unique_query, lux.config.SQLconnection, ) - - unique_vals[attr] = list(unique_query[attr]) + unique_vals[attr] = list(unique_data[attr]) ldf.unique_values = unique_vals def compute_data_type(self, ldf: LuxDataFrame): @@ -714,7 +728,6 @@ def compute_data_type(self, ldf: LuxDataFrame): datatype = list(pandas.read_sql(datatype_query, lux.config.SQLconnection)["data_type"])[0] sql_dtypes[attr] = datatype - for attr in list(ldf.columns): if str(attr).lower() in ["month", "year"]: data_type[attr] = "temporal" @@ -735,6 +748,7 @@ def compute_data_type(self, ldf: LuxDataFrame): "smallint", "smallserial", "serial", + "double precision", ]: if ldf.cardinality[attr] < 13: data_type[attr] = "nominal" diff --git a/tests/test_pandas_coverage.py b/tests/test_pandas_coverage.py index 83a26d92..3945cb0b 100644 --- a/tests/test_pandas_coverage.py +++ b/tests/test_pandas_coverage.py @@ -679,4 +679,4 @@ def test_read_sas(global_var): df = pd.read_sas(url, format="sas7bdat") df._repr_html_() assert list(df.recommendation.keys()) == ["Correlation", "Distribution", "Temporal"] - assert len(df.data_type) == 6 \ No newline at end of file + assert len(df.data_type) == 6 diff --git a/tests/test_series.py b/tests/test_series.py index 75a93691..62a4697f 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -50,4 +50,4 @@ def test_print_dtypes(global_var): df = pytest.college_df with warnings.catch_warnings(record=True) as w: print(df.dtypes) - assert len(w) == 0, "Warning displayed when printing dtypes" \ No newline at end of file + assert len(w) == 0, "Warning displayed when printing dtypes" diff --git a/tests/test_sql_executor.py b/tests/test_sql_executor.py index 62cdcf3f..41bd2cf3 100644 --- a/tests/test_sql_executor.py +++ b/tests/test_sql_executor.py @@ -146,7 +146,7 @@ def test_filter(): vis = Vis(intent, sql_df) vis._vis_data = sql_df filter_output = SQLExecutor.execute_filter(vis) - assert filter_output[0] == "WHERE origin = 'USA'" + assert filter_output[0] == "WHERE \"origin\" = 'USA'" assert filter_output[1] == ["origin"] @@ -164,7 +164,7 @@ def test_inequalityfilter(): ) vis._vis_data = sql_df filter_output = SQLExecutor.execute_filter(vis) - assert filter_output[0] == "WHERE horsepower > '50'" + assert filter_output[0] == "WHERE \"horsepower\" > '50'" assert filter_output[1] == ["horsepower"] intent = [ @@ -174,7 +174,7 @@ def test_inequalityfilter(): vis = Vis(intent, sql_df) vis._vis_data = sql_df filter_output = SQLExecutor.execute_filter(vis) - assert filter_output[0] == "WHERE horsepower <= '100'" + assert filter_output[0] == "WHERE \"horsepower\" <= '100'" assert filter_output[1] == ["horsepower"]