Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat sql over #412

Merged
merged 7 commits into from
May 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions siuba/siu/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,5 +320,6 @@ def exit___call__(self, node):
*node.args[1:],
**node.kwargs
)
return node


4 changes: 2 additions & 2 deletions siuba/sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .verbs import LazyTbl, sql_raw, SqlFunctionLookupError
from .translate import SqlColumn, SqlColumnAgg
from .verbs import LazyTbl, sql_raw
from .translate import SqlColumn, SqlColumnAgg, SqlFunctionLookupError

# preceed w/ underscore so it isn't exported by default
# we just want to register the singledispatch funcs
Expand Down
96 changes: 92 additions & 4 deletions siuba/sql/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@

from sqlalchemy import sql

from siuba.siu import FunctionLookupBound
from siuba.siu import FunctionLookupBound, FunctionLookupError


# warning for when sql defaults differ from pandas ============================
import warnings


class SqlFunctionLookupError(FunctionLookupError): pass


class SiubaSqlRuntimeWarning(UserWarning): pass

def warn_arg_default(func_name, arg_name, arg, correct):
Expand Down Expand Up @@ -57,19 +61,32 @@ class SqlColumnAgg(SqlBase): pass
class CustomOverClause(Over):
"""Base class for custom window clauses in SQL translation."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)


def set_over(self, group_by, order_by):
raise NotImplementedError()

def has_over(self):
return self.order_by is not None or self.group_by is not None


@classmethod
def func(cls, name):
raise NotImplementedError()





class AggOver(CustomOverClause):
"""Over clause for uses of functions min, max, avg, that return one value.

Note that this class does not set order by, which is how these functions
generally become their cumulative versions.

E.g. mean(x) -> AVG(x) OVER (partition_by <group vars>)
"""

def set_over(self, group_by, order_by = None):
Expand All @@ -90,6 +107,8 @@ class RankOver(CustomOverClause):

Note that in python we might call rank(col), but in SQL the ranking column
is defined using order by.

E.g. rank(y) -> rank() OVER (partition by <group vars> order by y)
"""
def set_over(self, group_by, order_by = None):
crnt_partition = getattr(self.partition_by, 'clauses', tuple())
Expand All @@ -111,6 +130,9 @@ class CumlOver(CustomOverClause):
Note that this class is also currently used for aggregates that might require
ordering, like nth, first, etc..

e.g. cumsum(x) -> SUM(x) OVER (partition by <group vars> order by <order vars>)
e.g. nth(0) -> NTH_VALUE(1) OVER (partition by <group vars> order by <order vars>)

"""
def set_over(self, group_by, order_by):
self.partition_by = group_by
Expand Down Expand Up @@ -164,7 +186,15 @@ def f(codata, col, *args) -> SqlColumn:
return f

def sql_ordered_set(name, is_analytic=False):
# Ordered and theoretical set aggregates
"""Generate function for ordered and hypothetical set aggregates.

Hypothetical-set aggregates take an argument, and return a value for each
element of the argument. For example: rank(2) WITHIN GROUP (order by x).
In this case, the hypothetical ranks 2 relative to x.

Ordered set aggregates are like percentil_cont(.5) WITHIN GROUP (order by x),
which calculates the median of x.
"""
sa_func = getattr(sql.func, name)

if is_analytic:
Expand Down Expand Up @@ -217,15 +247,16 @@ def wrapper(*args, **kwargs):

# Translator =================================================================

from siuba.ops.translate import create_pandas_translator


def extend_base(cls, **kwargs):
"""Register concrete methods onto generic functions for pandas Series methods."""
from siuba.ops import ALL_OPS
for meth_name, f in kwargs.items():
ALL_OPS[meth_name].register(cls, f)


from siuba.ops.translate import create_pandas_translator

# TODO: should inherit from a ITranslate class (w/ abstract translate method)
class SqlTranslator:
"""Translates symbolic column operations to sqlalchemy clauses.
Expand Down Expand Up @@ -262,11 +293,68 @@ def __init__(self, window, aggregate):
self.aggregate = aggregate

def translate(self, expr, window = True):
"""Convert an AST of method chains to an AST of function calls."""

if window:
return self.window.translate(expr)

return self.aggregate.translate(expr)


def shape_call(
self,
call, window = True, str_accessors = False,
verb_name = None, arg_name = None,
):
"""Return a siu Call that creates dialect specific SQL when called."""

from siuba.siu import Call, MetaArg, strip_symbolic, Lazy, str_to_getitem_call
from siuba.siu.visitors import CodataVisitor

call = strip_symbolic(call)

if isinstance(call, Call):
pass
elif str_accessors and isinstance(call, str):
# verbs that can use strings as accessors, like group_by, or
# arrange, need to convert those strings into a getitem call
return str_to_getitem_call(call)
elif isinstance(call, sql.elements.ColumnClause):
return Lazy(call)
elif callable(call):
#TODO: should not happen here
return Call("__call__", call, MetaArg('_'))

else:
# verbs that use literal strings, need to convert them to a call
# that returns a sqlalchemy "literal" object
return Lazy(sql.literal(call))

# raise informative error message if missing translation
try:
# TODO: MC-NOTE -- scaffolding in to verify prior behavior works
shaped_call = self.translate(call, window = window)
if window:
trans = self.window
else:
trans = self.aggregate

# TODO: MC-NOTE - once all sql singledispatch funcs are annotated
# with return types, then switch object back out
# alternatively, could register a bounding class, and remove
# the result type check
v = CodataVisitor(trans.dispatch_cls, object)
return v.visit(shaped_call)

except FunctionLookupError as err:
raise SqlFunctionLookupError.from_verb(
verb_name or "Unknown",
arg_name or "Unknown",
err,
short = True
)


def from_mappings(WinCls, AggCls):
from siuba.ops import ALL_OPS

Expand Down
96 changes: 49 additions & 47 deletions siuba/sql/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

from sqlalchemy import sql
import sqlalchemy
from siuba.siu import Call, str_to_getitem_call, Lazy, FunctionLookupError, singledispatch2
from siuba.siu import Call, Lazy, FunctionLookupError, singledispatch2
# TODO: currently needed for select, but can we remove pandas?
from pandas import Series

Expand Down Expand Up @@ -148,10 +148,53 @@ def _get_over_clauses(clause):
return windows


#def track_call_windows(call, columns, group_by, order_by, window_cte = None):
# listener = WindowReplacer(columns, group_by, order_by, window_cte)
# col = listener.enter(call)
# return col, listener.windows, listener.window_cte


def track_call_windows(call, columns, group_by, order_by, window_cte = None):
listener = WindowReplacer(columns, group_by, order_by, window_cte)
col = listener.enter(call)
return col, listener.windows, listener.window_cte
col_expr = call(columns)

crnt_group_by = sql.elements.ClauseList(
*[columns[name] for name in group_by]
)
crnt_order_by = sql.elements.ClauseList(
*_create_order_by_clause(columns, *order_by)
)
return replace_call_windows(col_expr, crnt_group_by, crnt_order_by, window_cte)



def replace_call_windows(col_expr, group_by, order_by, window_cte = None):

if not isinstance(col_expr, sql.elements.ClauseElement):
return col_expr

over_clauses = WindowReplacer._get_over_clauses(col_expr)

for over in over_clauses:
# TODO: shouldn't mutate these over clauses
over.set_over(group_by, order_by)

if len(over_clauses) and window_cte is not None:
# custom name, or parameters like "%(...)s" may nest and break psycopg2
# with columns you can set a key to fix this, but it doesn't seem to
# be an option with labels
name = WindowReplacer._get_unique_name('win', lift_inner_cols(window_cte))
label = col_expr.label(name)

# put into CTE, and return its resulting column, so that subsequent
# operations will refer to the window column on window_cte. Note that
# the operations will use the actual column, so may need to use the
# ClauseAdaptor to make it a reference to the label
window_cte = _sql_add_columns(window_cte, [label])
win_col = lift_inner_cols(window_cte).values()[-1]

return win_col, over_clauses, window_cte

return col_expr, over_clauses, window_cte


def lift_inner_cols(tbl):
Expand Down Expand Up @@ -263,49 +306,7 @@ def shape_call(
call, window = True, str_accessors = False,
verb_name = None, arg_name = None,
):
if isinstance(call, Call):
pass
elif str_accessors and isinstance(call, str):
# verbs that can use strings as accessors, like group_by, or
# arrange, need to convert those strings into a getitem call
return str_to_getitem_call(call)
elif isinstance(call, sql.elements.ColumnClause):
return Lazy(call)
elif callable(call):
#TODO: should not happen here
from siuba.siu import MetaArg
return Call("__call__", call, MetaArg('_'))

else:
# verbs that use literal strings, need to convert them to a call
# that returns a sqlalchemy "literal" object
return Lazy(sql.literal(call))

# raise informative error message if missing translation
try:
# TODO: MC-NOTE -- scaffolding in to verify prior behavior works
from siuba.siu.visitors import CodataVisitor
shaped_call = self.translator.translate(call, window = window)
if window:
trans = self.translator.window
else:
trans = self.translator.aggregate

# TODO: MC-NOTE - once all sql singledispatch funcs are annotated
# with return types, then switch object back out
# alternatively, could register a bounding class, and remove
# the result type check
v = CodataVisitor(trans.dispatch_cls, object)
return v.visit(shaped_call)

except FunctionLookupError as err:
raise SqlFunctionLookupError.from_verb(
verb_name or "Unknown",
arg_name or "Unknown",
err,
short = True
)

return self.translator.shape_call(call, window, str_accessors, verb_name, arg_name)

def track_call_windows(self, call, columns = None, window_cte = None):
"""Returns tuple of (new column expression, list of window exprs)"""
Expand Down Expand Up @@ -460,6 +461,7 @@ def _show_query(tbl, simplify = False):

return tbl


# collect ----------

@collect.register(LazyTbl)
Expand Down
10 changes: 10 additions & 0 deletions siuba/tests/test_sql_verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ def test_lazy_tbl_shape_call_error(db):
assert err.__suppress_context__ == True


# track_call_windows ----------------------------------------------------------

from siuba.sql.verbs import track_call_windows
from siuba.sql.translate import win_over

def test_track_call_windows_basic():
pass




# TODO: remove these old tests? should be redundant ===========================

Expand Down
12 changes: 11 additions & 1 deletion siuba/tests/test_verb_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,21 @@ def df(backend):
return backend.load_df(DATA)


def test_group_by_no_add(df):
def test_group_by_two(df):
gdf = group_by(df, _.x, _.y)
assert gdf.group_by == ("x", "y")

def test_group_by_override(df):
gdf = df >> group_by(_.x, _.y) >> group_by(_.g)
assert gdf.group_by == ("g",)

def test_group_by_no_add(df):
# without add argument, group_by overwrites prev grouping
gdf1 = group_by(df, _.x)
gdf2 = group_by(gdf1, _.y)

assert gdf2.group_by == ("y",)

def test_group_by_add(df):
gdf = group_by(df, _.x) >> group_by(_.y, add = True)

Expand All @@ -40,6 +47,9 @@ def test_group_by_ungroup(df):
q2 = q1 >> ungroup()
assert q2.group_by == tuple()

def test_group_by_using_string(df):
gdf = group_by(df, "g") >> summarize(res = _.x.mean())


@pytest.mark.skip("TODO: need to test / validate joins first")
def test_group_by_before_joins(df):
Expand Down