Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sql): add LazyTbl.last_select to simplify queries #449

Merged
merged 2 commits into from
Sep 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 6 additions & 4 deletions siuba/dply/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2447,12 +2447,14 @@ def tbl(src, *args, **kwargs):
You can analyze a mock table

>>> from sqlalchemy import create_mock_engine
>>> from siuba import _

>>> mock_engine = create_mock_engine("postgresql:///", lambda *args, **kwargs: None)
>>> tbl_mock = tbl(mock_engine, "some_table", columns = ["a", "b", "c"])
>>> q = tbl_mock >> count() >> show_query() # doctest: +NORMALIZE_WHITESPACE
SELECT count(*) AS n
FROM (SELECT some_table.a AS a, some_table.b AS b, some_table.c AS c
FROM some_table) AS anon_1 ORDER BY n DESC

>>> q = tbl_mock >> count(_.a) >> show_query() # doctest: +NORMALIZE_WHITESPACE
SELECT some_table_1.a, count(*) AS n
FROM some_table AS some_table_1 GROUP BY some_table_1.a ORDER BY n DESC
"""

return src
Expand Down
18 changes: 0 additions & 18 deletions siuba/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,21 +190,3 @@ def simplify_sel(sel):

return clone_el


@contextmanager
def _use_simple_names():
from sqlalchemy import sql
from sqlalchemy.ext.compiler import compiles, deregister

get_col_name = lambda el, *args, **kwargs: str(el.element.name)
get_lab_name = lambda el, *args, **kwargs: str(el.element.name)
get_col_name = lambda el, *args, **kwargs: str(el.name)
compiles(sql.compiler._CompileLabel)(get_lab_name)
compiles(sql.elements.ColumnClause)(get_col_name)
compiles(sql.schema.Column)(get_col_name)
try:
yield 1
except:
pass
finally:
deregister(sql.compiler._CompileLabel)
69 changes: 41 additions & 28 deletions siuba/sql/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
_sql_add_columns,
_sql_with_only_columns,
_sql_simplify_select,
_use_simple_names,
MockConnection
)

Expand Down Expand Up @@ -297,7 +296,7 @@ def __init__(
self.tbl = self._create_table(tbl, columns, self.source)

# important states the query can be in (e.g. grouped)
self.ops = [self.tbl.select()] if ops is None else ops
self.ops = [self.tbl] if ops is None else ops

self.group_by = group_by
self.order_by = order_by
Expand Down Expand Up @@ -340,8 +339,21 @@ def get_ordered_col_names(self):


@property
def last_op(self):
return self.ops[-1] if len(self.ops) else None
def last_op(self) -> "sql.Table | sql.Select":
last_op = self.ops[-1]

if last_op is None:
raise TypeError()

return last_op

@property
def last_select(self):
last_op = self.last_op
if not isinstance(last_op, sql.selectable.SelectBase):
return last_op.select()

return last_op

@staticmethod
def _create_table(tbl, columns = None, source = None):
Expand Down Expand Up @@ -385,7 +397,7 @@ def _create_table(tbl, columns = None, source = None):

def _get_preview(self):
# need to make prev op a cte, so we don't override any previous limit
new_sel = self.last_op.alias().select().limit(5)
new_sel = self.last_select.limit(5)
tbl_small = self.append_op(new_sel)
return collect(tbl_small)

Expand Down Expand Up @@ -450,13 +462,12 @@ def _show_query(tbl, simplify = False, return_table = True):

if simplify:
# try to strip table names and labels where unnecessary
simple_sel = _sql_simplify_select(tbl.last_op)
simple_sel = _sql_simplify_select(tbl.last_select)

with _use_simple_names():
explained = compile_query(simple_sel)
explained = compile_query(simple_sel)
else:
# use a much more verbose query
explained = compile_query(tbl.last_op)
explained = compile_query(tbl.last_select)

if return_table:
print(str(explained))
Expand All @@ -483,13 +494,13 @@ def _collect(__data, as_df = True):
if _is_dialect_duckdb(__data.source):
# TODO: can be removed once next release of duckdb fixes:
# https://github.com/duckdb/duckdb/issues/2972
query = __data.last_op
query = __data.last_select
compiled = query.compile(
dialect = __data.source.dialect,
compile_kwargs = {"literal_binds": True}
)
else:
compiled = __data.last_op
compiled = __data.last_select

# execute query ----

Expand Down Expand Up @@ -519,8 +530,8 @@ def _select(__data, *args, **kwargs):
"Using kwargs in select not currently supported. "
"Use _.newname == _.oldname instead"
)
last_op = __data.last_op
columns = {c.key: c for c in last_op.inner_columns}
last_sel = __data.last_select
columns = {c.key: c for c in last_sel.inner_columns}

# same as for DataFrame
colnames = Series(list(columns))
Expand All @@ -541,7 +552,7 @@ def _select(__data, *args, **kwargs):
col_list.append(col if v is None else col.label(v))

return __data.append_op(
last_op.with_only_columns(col_list),
last_sel.with_only_columns(col_list),
group_by = group_keys
)

Expand Down Expand Up @@ -610,7 +621,10 @@ def _mutate(__data, **kwargs):
# TODO: verify it can follow a renaming select

# track labeled columns in set
sel = __data.last_op
if not len(kwargs):
return __data.append_op(__data.last_op)

sel = __data.last_select

# evaluate each call
for colname, func in kwargs.items():
Expand Down Expand Up @@ -664,7 +678,7 @@ def _transmute(__data, **kwargs):
# transmute keeps grouping cols, and any defined in kwargs
cols_to_keep = ordered_union(__data.group_by, kwargs)

sel = f_mutate(__data, **kwargs).last_op
sel = f_mutate(__data, **kwargs).last_select

columns = lift_inner_cols(sel)
sel_stripped = sel.with_only_columns([columns[k] for k in cols_to_keep])
Expand All @@ -679,8 +693,8 @@ def _arrange(__data, *args):
# and handle when new columns are named the same as order by vars.
# see: https://dba.stackexchange.com/q/82930

last_op = __data.last_op
cols = lift_inner_cols(last_op)
last_sel = __data.last_select
cols = lift_inner_cols(last_sel)


new_calls = []
Expand All @@ -700,7 +714,7 @@ def _arrange(__data, *args):
sort_cols = _create_order_by_clause(cols, *new_calls)

order_by = __data.order_by + tuple(new_calls)
return __data.append_op(last_op.order_by(*sort_cols), order_by = order_by)
return __data.append_op(last_sel.order_by(*sort_cols), order_by = order_by)


# TODO: consolidate / pull expr handling funcs into own file?
Expand Down Expand Up @@ -746,8 +760,7 @@ def _count(__data, *args, sort = False, wt = None, **kwargs):
)
arg_names.append(name)

tbl_inner = mutate(__data, **kwargs)
sel_inner = tbl_inner.last_op
sel_inner = mutate(__data, **kwargs).last_op
group_cols = arg_names + list(kwargs)

# create outer select ----
Expand All @@ -756,7 +769,7 @@ def _count(__data, *args, sort = False, wt = None, **kwargs):
inner_cols = sel_inner_cte.columns

# apply any group vars from a group_by verb call first
tbl_group_cols = [inner_cols[k] for k in tbl_inner.group_by]
tbl_group_cols = [inner_cols[k] for k in __data.group_by]
count_group_cols = [inner_cols[k] for k in group_cols]

# combine with any defined in the count verb call
Expand All @@ -769,7 +782,7 @@ def _count(__data, *args, sort = False, wt = None, **kwargs):
.group_by(*outer_group_cols)

# count is like summarize, so removes order_by
return tbl_inner.append_op(
return __data.append_op(
sel_outer.order_by(count_col.desc()),
order_by = tuple()
)
Expand All @@ -778,7 +791,7 @@ def _count(__data, *args, sort = False, wt = None, **kwargs):
@add_count.register(LazyTbl)
def _add_count(__data, *args, wt = None, sort = False, **kwargs):
counts = count(__data, *args, wt = wt, sort = sort, **kwargs)
by = list(c.name for c in counts.last_op.inner_columns)[:-1]
by = list(c.name for c in counts.last_select.inner_columns)[:-1]

return inner_join(__data, counts, by = by)

Expand All @@ -789,7 +802,7 @@ def _summarize(__data, **kwargs):
# what if windowed mutate or filter has been done?
# - filter is fine, since it uses a CTE
# - need to detect any window functions...
old_sel = __data.last_op._clone()
old_sel = __data.last_select._clone()

new_calls = {}
for k, expr in kwargs.items():
Expand Down Expand Up @@ -1136,7 +1149,7 @@ def _create_join_conds(left_sel, right_sel, on):

@head.register(LazyTbl)
def _head(__data, n = 5):
sel = __data.last_op
sel = __data.last_select

return __data.append_op(sel.limit(n))

Expand All @@ -1145,7 +1158,7 @@ def _head(__data, n = 5):

@rename.register(LazyTbl)
def _rename(__data, **kwargs):
sel = __data.last_op
sel = __data.last_select
columns = lift_inner_cols(sel)

# old_keys uses dict as ordered set
Expand All @@ -1172,7 +1185,7 @@ def _distinct(__data, *args, _keep_all = False, **kwargs):
if (args or kwargs) and _keep_all:
raise NotImplementedError("Distinct with variables specified in sql requires _keep_all = False")

inner_sel = mutate(__data, **kwargs).last_op if kwargs else __data.last_op
inner_sel = mutate(__data, **kwargs).last_select if kwargs else __data.last_select

# TODO: this is copied from the df distinct version
# cols dict below is used as ordered set
Expand Down
6 changes: 3 additions & 3 deletions siuba/tests/test_verb_show_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ def test_show_query_basic_simplify(df_tiny):
q = df_tiny >> mutate(a = _.x.mean()) >> show_query(return_table = False, simplify=True)

assert rename_source(str(q), df_tiny) == """\
SELECT *, avg(x) OVER () AS a
SELECT *, avg(SRC_TBL.x) OVER () AS a
FROM SRC_TBL"""

def test_show_query_complex_simplify(df_wide):
q = df_wide >> mutate(a = _.x.mean(), b = _.a.mean())
res = q >> show_query(return_table = False, simplify=True)

assert rename_source(str(res), df_wide) == """\
SELECT *, avg(a) OVER () AS b
FROM (SELECT *, avg(x) OVER () AS a
SELECT *, avg(anon_1.a) OVER () AS b
FROM (SELECT *, avg(SRC_TBL.x) OVER () AS a
FROM SRC_TBL) AS anon_1"""