Skip to content

Commit

Permalink
Added query parameter to Vis objects
Browse files Browse the repository at this point in the history
To make Lux' SQLExecutor more transparent, added a query parameter to Vis objects so that users can see what query was used to gather the data for that visualization.

Updated syntax in SQLExecutor tests to reflect the LuxSQLTable Changes
  • Loading branch information
thyneb19 committed Mar 13, 2021
1 parent 04591af commit fa917fb
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 106 deletions.
6 changes: 6 additions & 0 deletions lux/executor/SQLExecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def add_quotes(var_name):
query = f"SELECT {required_variables} FROM {lst.table_name} {where_clause} ORDER BY random() LIMIT 10000"
else:
query = "SELECT {} FROM {} {}".format(required_variables, lst.table_name, where_clause)
view._query = query
data = pandas.read_sql(query, lux.config.SQLconnection)
view._vis_data = utils.pandas_to_lux(data)
view._vis_data.length = list(length_query["length"])[0]
Expand Down Expand Up @@ -218,6 +219,7 @@ def execute_aggregate(view: Vis, lst: LuxSQLTable, isFiltered=True):
view._vis_data = view._vis_data.rename(columns={"count": "Record"})
view._vis_data = utils.pandas_to_lux(view._vis_data)
view._vis_data.length = list(length_query["length"])[0]
view._query = count_query
# aggregate barchart case, need aggregate data (mean, sum, max) for each group
else:
where_clause, filterVars = SQLExecutor.execute_filter(view)
Expand Down Expand Up @@ -309,6 +311,7 @@ def execute_aggregate(view: Vis, lst: LuxSQLTable, isFiltered=True):
)
view._vis_data = pandas.read_sql(agg_query, lux.config.SQLconnection)
view._vis_data = utils.pandas_to_lux(view._vis_data)
view._query = agg_query
result_vals = list(view._vis_data[groupby_attr.attribute])
# create existing group by attribute combinations if color is specified
# this is needed to check what combinations of group_by_attr and color_attr values have a non-zero number of elements in them
Expand Down Expand Up @@ -407,6 +410,7 @@ def execute_binning(view: Vis, lst: LuxSQLTable):
lst.table_name,
where_clause,
)
view._query = bin_count_query

bin_count_data = pandas.read_sql(bin_count_query, lux.config.SQLconnection)
if not bin_count_data["width_bucket"].isnull().values.any():
Expand Down Expand Up @@ -497,6 +501,8 @@ def execute_2D_binning(view: Vis, lst: LuxSQLTable):
where_clause,
)

view._query = bin_count_query

# data = pandas.read_sql(bin_count_query, lux.config.SQLconnection)

data = pandas.read_sql(bin_count_query, lux.config.SQLconnection)
Expand Down
5 changes: 5 additions & 0 deletions lux/vis/Vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self, intent, source=None, title="", score=0.0):
self._mark = ""
self._min_max = {}
self._postbin = None
self._query = ""
self.title = title
self.score = score
self.refresh_source(self._source)
Expand Down Expand Up @@ -95,6 +96,10 @@ def min_max(self):
def intent(self):
return self._intent

@property
def query(self):
return self._query

@intent.setter
def intent(self, intent: List[Clause]) -> None:
self.set_intent(intent)
Expand Down
5 changes: 5 additions & 0 deletions lux/vislib/altair/AltairRenderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ def create_vis(self, vis, standalone=True):
"placeholder_variable",
f"pd.DataFrame({str(vis.data.to_dict())})",
)
elif lux.config.executor.name == "SQLExecutor":
chart.code = chart.code.replace(
"placeholder_variable",
f"pd.read_sql({str(vis._query)}, lux.config.SQLconnection)",
)
else:
# TODO: Placeholder (need to read dynamically via locals())
chart.code = chart.code.replace("placeholder_variable", found_variable)
Expand Down
120 changes: 60 additions & 60 deletions tests/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ def test_underspecified_no_vis(global_var, test_recs):
# test for sql executor
connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
sql_lst = lux.LuxSQLTable(table_name="cars")

test_recs(sql_df, no_vis_actions)
assert len(sql_df.current_vis) == 0
test_recs(sql_lst, no_vis_actions)
assert len(sql_lst.current_vis) == 0

# test only one filter context case.
sql_df.set_intent([lux.Clause(attribute="origin", filter_op="=", value="USA")])
test_recs(sql_df, no_vis_actions)
assert len(sql_df.current_vis) == 0
sql_lst.set_intent([lux.Clause(attribute="origin", filter_op="=", value="USA")])
test_recs(sql_lst, no_vis_actions)
assert len(sql_lst.current_vis) == 0


def test_underspecified_single_vis(global_var, test_recs):
Expand All @@ -62,14 +62,14 @@ def test_underspecified_single_vis(global_var, test_recs):

connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
sql_df.set_intent([lux.Clause(attribute="milespergal"), lux.Clause(attribute="weight")])
test_recs(sql_df, one_vis_actions)
assert len(sql_df.current_vis) == 1
assert sql_df.current_vis[0].mark == "scatter"
for attr in sql_df.current_vis[0]._inferred_intent:
sql_lst = lux.LuxSQLTable(table_name="cars")
sql_lst.set_intent([lux.Clause(attribute="milespergal"), lux.Clause(attribute="weight")])
test_recs(sql_lst, one_vis_actions)
assert len(sql_lst.current_vis) == 1
assert sql_lst.current_vis[0].mark == "scatter"
for attr in sql_lst.current_vis[0]._inferred_intent:
assert attr.data_model == "measure"
for attr in sql_df.current_vis[0]._inferred_intent:
for attr in sql_lst.current_vis[0]._inferred_intent:
assert attr.data_type == "quantitative"


Expand Down Expand Up @@ -117,12 +117,12 @@ def test_set_intent_as_vis(global_var, test_recs):

connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
sql_df._repr_html_()
vis = sql_df.recommendation["Correlation"][0]
sql_df.intent = vis
sql_df._repr_html_()
test_recs(sql_df, ["Enhance", "Filter", "Generalize"])
sql_lst = lux.LuxSQLTable(table_name="cars")
sql_lst._repr_html_()
vis = sql_lst.recommendation["Correlation"][0]
sql_lst.intent = vis
sql_lst._repr_html_()
test_recs(sql_lst, ["Enhance", "Filter", "Generalize"])


@pytest.fixture
Expand Down Expand Up @@ -152,14 +152,14 @@ def test_parse(global_var):

connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
vlst = VisList([lux.Clause("origin=?"), lux.Clause(attribute="milespergal")], sql_df)
sql_lst = lux.LuxSQLTable(table_name="cars")
vlst = VisList([lux.Clause("origin=?"), lux.Clause(attribute="milespergal")], sql_lst)
assert len(vlst) == 3

connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
vlst = VisList([lux.Clause("origin=?"), lux.Clause("milespergal")], sql_df)
sql_lst = lux.LuxSQLTable(table_name="cars")
vlst = VisList([lux.Clause("origin=?"), lux.Clause("milespergal")], sql_lst)
assert len(vlst) == 3


Expand All @@ -183,13 +183,13 @@ def test_underspecified_vis_collection_zval(global_var):

connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
sql_lst = lux.LuxSQLTable(table_name="cars")
vlst = VisList(
[
lux.Clause(attribute="origin", filter_op="=", value="?"),
lux.Clause(attribute="milespergal"),
],
sql_df,
sql_lst,
)
assert len(vlst) == 3

Expand Down Expand Up @@ -223,26 +223,26 @@ def test_sort_bar(global_var):

connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
sql_lst = lux.LuxSQLTable(table_name="cars")
vis = Vis(
[
lux.Clause(attribute="acceleration", data_model="measure", data_type="quantitative"),
lux.Clause(attribute="origin", data_model="dimension", data_type="nominal"),
],
sql_df,
sql_lst,
)
assert vis.mark == "bar"
assert vis._inferred_intent[1].sort == ""

connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
sql_lst = lux.LuxSQLTable(table_name="cars")
vis = Vis(
[
lux.Clause(attribute="acceleration", data_model="measure", data_type="quantitative"),
lux.Clause(attribute="name", data_model="dimension", data_type="nominal"),
],
sql_df,
sql_lst,
)
assert vis.mark == "bar"
assert vis._inferred_intent[1].sort == "ascending"
Expand Down Expand Up @@ -337,10 +337,10 @@ def test_autoencoding_scatter(global_var):

connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
sql_lst = lux.LuxSQLTable(table_name="cars")
visList = VisList(
[lux.Clause(attribute="?"), lux.Clause(attribute="milespergal", channel="x")],
sql_df,
sql_lst,
)
for vis in visList:
check_attribute_on_channel(vis, "milespergal", "x")
Expand Down Expand Up @@ -391,8 +391,8 @@ def test_autoencoding_scatter():
# test for sql executor
connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
vis = Vis([lux.Clause(attribute="milespergal"), lux.Clause(attribute="weight")], sql_df)
sql_lst = lux.LuxSQLTable(table_name="cars")
vis = Vis([lux.Clause(attribute="milespergal"), lux.Clause(attribute="weight")], sql_lst)
check_attribute_on_channel(vis, "milespergal", "x")
check_attribute_on_channel(vis, "weight", "y")

Expand All @@ -402,7 +402,7 @@ def test_autoencoding_scatter():
lux.Clause(attribute="milespergal", channel="y"),
lux.Clause(attribute="weight"),
],
sql_df,
sql_lst,
)
check_attribute_on_channel(vis, "milespergal", "y")
check_attribute_on_channel(vis, "weight", "x")
Expand All @@ -413,14 +413,14 @@ def test_autoencoding_scatter():
lux.Clause(attribute="milespergal", channel="y"),
lux.Clause(attribute="weight", channel="x"),
],
sql_df,
sql_lst,
)
check_attribute_on_channel(vis, "milespergal", "y")
check_attribute_on_channel(vis, "weight", "x")
# Duplicate channel specified
with pytest.raises(ValueError):
# Should throw error because there should not be columns with the same channel specified
sql_df.set_intent(
sql_lst.set_intent(
[
lux.Clause(attribute="milespergal", channel="x"),
lux.Clause(attribute="weight", channel="x"),
Expand All @@ -445,11 +445,11 @@ def test_autoencoding_histogram(global_var):
# test for sql executor
connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
vis = Vis([lux.Clause(attribute="milespergal", channel="y")], sql_df)
sql_lst = lux.LuxSQLTable(table_name="cars")
vis = Vis([lux.Clause(attribute="milespergal", channel="y")], sql_lst)
check_attribute_on_channel(vis, "milespergal", "y")

vis = Vis([lux.Clause(attribute="milespergal", channel="x")], sql_df)
vis = Vis([lux.Clause(attribute="milespergal", channel="x")], sql_lst)
assert vis.get_attr_by_channel("x")[0].attribute == "milespergal"
assert vis.get_attr_by_channel("y")[0].attribute == "Record"

Expand Down Expand Up @@ -498,8 +498,8 @@ def test_autoencoding_line_chart(global_var):
# test for sql executor
connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
vis = Vis([lux.Clause(attribute="year"), lux.Clause(attribute="acceleration")], sql_df)
sql_lst = lux.LuxSQLTable(table_name="cars")
vis = Vis([lux.Clause(attribute="year"), lux.Clause(attribute="acceleration")], sql_lst)
check_attribute_on_channel(vis, "year", "x")
check_attribute_on_channel(vis, "acceleration", "y")

Expand All @@ -509,7 +509,7 @@ def test_autoencoding_line_chart(global_var):
lux.Clause(attribute="year", channel="y"),
lux.Clause(attribute="acceleration"),
],
sql_df,
sql_lst,
)
check_attribute_on_channel(vis, "year", "y")
check_attribute_on_channel(vis, "acceleration", "x")
Expand All @@ -520,14 +520,14 @@ def test_autoencoding_line_chart(global_var):
lux.Clause(attribute="year", channel="y"),
lux.Clause(attribute="acceleration", channel="x"),
],
sql_df,
sql_lst,
)
check_attribute_on_channel(vis, "year", "y")
check_attribute_on_channel(vis, "acceleration", "x")

with pytest.raises(ValueError):
# Should throw error because there should not be columns with the same channel specified
sql_df.set_intent(
sql_lst.set_intent(
[
lux.Clause(attribute="year", channel="x"),
lux.Clause(attribute="acceleration", channel="x"),
Expand All @@ -553,13 +553,13 @@ def test_autoencoding_color_line_chart(global_var):
# test for sql executor
connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
sql_lst = lux.LuxSQLTable(table_name="cars")
intent = [
lux.Clause(attribute="year"),
lux.Clause(attribute="acceleration"),
lux.Clause(attribute="origin"),
]
vis = Vis(intent, sql_df)
vis = Vis(intent, sql_lst)
check_attribute_on_channel(vis, "year", "x")
check_attribute_on_channel(vis, "acceleration", "y")
check_attribute_on_channel(vis, "origin", "color")
Expand Down Expand Up @@ -593,14 +593,14 @@ def test_autoencoding_color_scatter_chart(global_var):
# test for sql executor
connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
sql_lst = lux.LuxSQLTable(table_name="cars")
vis = Vis(
[
lux.Clause(attribute="horsepower"),
lux.Clause(attribute="acceleration"),
lux.Clause(attribute="origin"),
],
sql_df,
sql_lst,
)
check_attribute_on_channel(vis, "origin", "color")

Expand All @@ -610,7 +610,7 @@ def test_autoencoding_color_scatter_chart(global_var):
lux.Clause(attribute="acceleration", channel="color"),
lux.Clause(attribute="origin"),
],
sql_df,
sql_lst,
)
check_attribute_on_channel(vis, "acceleration", "color")

Expand Down Expand Up @@ -647,23 +647,23 @@ def test_populate_options(global_var):
# test for sql executor
connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
sql_df.set_intent([lux.Clause(attribute="?"), lux.Clause(attribute="milespergal")])
sql_lst = lux.LuxSQLTable(table_name="cars")
sql_lst.set_intent([lux.Clause(attribute="?"), lux.Clause(attribute="milespergal")])
col_set = set()
for specOptions in Compiler.populate_wildcard_options(sql_df._intent, sql_df)["attributes"]:
for specOptions in Compiler.populate_wildcard_options(sql_lst._intent, sql_lst)["attributes"]:
for clause in specOptions:
col_set.add(clause.attribute)
assert list_equal(list(col_set), list(sql_df.columns))
assert list_equal(list(col_set), list(sql_lst.columns))

sql_df.set_intent(
sql_lst.set_intent(
[
lux.Clause(attribute="?", data_model="measure"),
lux.Clause(attribute="milespergal"),
]
)
sql_df._repr_html_()
sql_lst._repr_html_()
col_set = set()
for specOptions in Compiler.populate_wildcard_options(sql_df._intent, sql_df)["attributes"]:
for specOptions in Compiler.populate_wildcard_options(sql_lst._intent, sql_lst)["attributes"]:
for clause in specOptions:
col_set.add(clause.attribute)
assert list_equal(
Expand All @@ -690,16 +690,16 @@ def test_remove_all_invalid(global_var):
# test for sql executor
connection = psycopg2.connect("host=localhost dbname=postgres user=postgres password=lux")
lux.config.set_SQL_connection(connection)
sql_df = lux.LuxSQLTable(table_name="cars")
sql_lst = lux.LuxSQLTable(table_name="cars")
# with pytest.warns(UserWarning,match="duplicate attribute specified in the intent"):
sql_df.set_intent(
sql_lst.set_intent(
[
lux.Clause(attribute="origin", filter_op="=", value="USA"),
lux.Clause(attribute="origin"),
]
)
sql_df._repr_html_()
assert len(sql_df.current_vis) == 0
sql_lst._repr_html_()
assert len(sql_lst.current_vis) == 0


def list_equal(l1, l2):
Expand Down
Loading

0 comments on commit fa917fb

Please sign in to comment.