# Terms

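> The small building blocks of a SQL query: fields, arithmetic expressions, criteria, custom functions, window clauses and CASE expressions.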

In [None]:
#| default_exp terms
#| hide
from nbdev.showdoc import *
from fastcore.test import *
# allow multiple output from one cell
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
%reload_ext autoreload
%autoreload 2

In [None]:
#|export
import inspect
from functools import partial, wraps
from fastcore.foundation import patch
from pikaQ.utils import exec

## Field

In [None]:
#| export
class FieldBase:
    """Operator overloads shared by SQL terms.

    ``+ - * /`` (including the reflected forms such as ``1 + field``) build an
    ``ArithmeticExpression``; ``> >= == != <= <`` build a ``Criteria``.
    Because ``__eq__`` is overridden without a ``__hash__``, instances are
    not hashable.
    """

    def _arith(self, op, other):
        """Build the expression ``self <op> other``."""
        return ArithmeticExpression(self, op, other)

    def _arith_reflected(self, op, other):
        """Build ``other <op> self`` (the plain Python value was on the left)."""
        return ArithmeticExpression(other, op, self)

    def _compare(self, op, other):
        """Build the condition ``self <op> other``."""
        return Criteria(self, op, other)

    # -- arithmetic ----------------------------------------------------------
    def __add__(self, other):
        return self._arith('+', other)

    def __radd__(self, other):
        return self._arith_reflected('+', other)

    def __sub__(self, other):
        return self._arith('-', other)

    def __rsub__(self, other):
        return self._arith_reflected('-', other)

    def __mul__(self, other):
        return self._arith('*', other)

    def __rmul__(self, other):
        return self._arith_reflected('*', other)

    def __truediv__(self, other):
        return self._arith('/', other)

    def __rtruediv__(self, other):
        return self._arith_reflected('/', other)

    # -- comparisons ---------------------------------------------------------
    def __gt__(self, other):
        return self._compare('>', other)

    def __ge__(self, other):
        return self._compare('>=', other)

    def __eq__(self, other):
        return self._compare('==', other)

    def __ne__(self, other):
        return self._compare('!=', other)

    def __le__(self, other):
        return self._compare('<=', other)

    def __lt__(self, other):
        return self._compare('<', other)


class Field(FieldBase):
    """A named column reference that can optionally carry an ``AS`` alias."""

    def __init__(self, name) -> None:
        self.name, self.alias = name, None
        self.get_sql = self.exec

    def as_(self, alias):
        """Set the alias in place and return ``self`` for chaining."""
        self.alias = alias
        return self

    def exec(self, **kwargs):
        """Render the name via ``exec``, appending ``AS <alias>`` when an alias is set."""
        rendered = f"{exec(self.name, **kwargs)}"
        if not self.alias:
            return rendered
        return f"{rendered} AS {self.alias}"
    

class ArithmeticExpression(FieldBase):
    """Binary arithmetic term ``this op other`` for ``op`` in ``+ - * /``.

    Operands may be nested ``ArithmeticExpression``s, objects with an ``exec``
    method (fields, functions, ...) or plain values rendered with ``str``.
    Nested expressions are wrapped in parentheses when the rendered SQL would
    otherwise regroup them under SQL's left-associative precedence rules.
    """
    add_order = ["+", "-"]

    def __init__(self, this, op, other) -> None:
        super().__init__()
        self.this, self.op, self.other = this, op, other
        self.get_sql = self.exec

    def left_needs_parens(self, left_op, curr_op) -> bool:
        """
        Returns true if the expression on the left of the current operator needs to be enclosed in parentheses.
        :param left_op:
            The highest level operator of the left expression (None for a single item).
        :param curr_op:
            The current operator.
        """
        if left_op is None or curr_op in self.add_order:
            # A single item on the left, or the current operator is '+' / '-':
            # left-associativity keeps the left side grouped.
            return False
        # The current operator is '*' or '/'. A '+' / '-' on the left needs
        # parentheses, e.g. (A + B) / ...; otherwise A * B / ... is already right.
        return left_op in self.add_order

    def right_needs_parens(self, curr_op, right_op) -> bool:
        """
        Returns true if the expression on the right of the current operator needs to be enclosed in parentheses.
        :param curr_op:
            The current operator.
        :param right_op:
            The highest level operator of the right expression (None for a single item).
        """
        if right_op is None:
            # A single item on the right.
            return False
        if right_op in self.add_order:
            # Always parenthesize a '+' / '-' on the right:
            # ... - (A + B), ... * (A - B), ... + (A - B)
            return True
        # The right side is '*' or '/'. After '+' / '-' no parentheses are
        # needed (... - A * B). SQL is left-associative, though, so
        # ... / (A * B) must not be rendered as ... / A * B, and
        # ... * (A / B) must not become ... * A / B (they differ under
        # integer division).
        return curr_op == '/' or (curr_op == '*' and right_op == '/')

    @staticmethod
    def _render_operand(operand, kwargs):
        """Render an operand with its own ``exec`` method if it has one, else ``str``."""
        if getattr(operand, 'exec', None):
            return operand.exec(**kwargs)
        return str(operand)

    def exec(self, **kwargs):
        """Render ``this op other`` as SQL, parenthesizing nested expressions as needed."""
        this = self._render_operand(self.this, kwargs)
        if isinstance(self.this, ArithmeticExpression) and self.left_needs_parens(self.this.op, self.op):
            this = f"({this})"

        other = self._render_operand(self.other, kwargs)
        if isinstance(self.other, ArithmeticExpression) and self.right_needs_parens(self.op, self.other.op):
            other = f"({other})"
        return f"{this} {self.op} {other}"


class Criteria:
    """Boolean SQL condition ``this op other``.

    Comparison operators on ``FieldBase`` terms produce leaf criteria
    (``a > 1``); ``&`` and ``|`` combine criteria into ``and`` / ``or`` nodes.
    Whether a nested ``and`` / ``or`` node needs parentheses is decided when
    rendering, from its parent's operator, so composing never mutates the
    operands and a criteria renders the same no matter where else it was used.
    """
    compose_ops = ('and', 'or')

    def __init__(self, this, op, other) -> None:
        super().__init__()
        self.this, self.op, self.other = this, op, other
        # Manual override: when True, always wrap this criteria in parentheses
        # whenever it is nested inside another Criteria.
        self.add_parentheses = False
        self.get_sql = self.exec

    def compose_criteria(self, op, other):
        """Combine ``self`` and ``other`` with ``op`` ('and' / 'or').

        Parentheses are added at render time (see ``exec``), so neither
        operand is modified.
        """
        return Criteria(self, op, other)

    def _child_needs_parens(self, child):
        """True if nested ``child`` must be parenthesized under this node.

        That is the case when it was forced via ``add_parentheses``, or when
        both nodes are 'and'/'or' with different operators.
        """
        if child.add_parentheses:
            return True
        return (self.op in self.compose_ops
                and child.op in self.compose_ops
                and child.op != self.op)

    @staticmethod
    def resolve(obj, **kwargs):
        """Render ``obj``; a nested ``Criteria`` is wrapped only if its ``add_parentheses`` is set.

        Kept for backward compatibility; ``exec`` uses operator-aware rendering.
        """
        if obj.__class__ is Criteria:
            obj_c = obj.exec(**kwargs)
            return f"({obj_c})" if obj.add_parentheses == True else obj_c
        return exec(obj, **kwargs)

    def _render_side(self, obj, kwargs):
        """Render one operand, parenthesizing nested criteria when precedence requires it."""
        if isinstance(obj, Criteria):
            sql = obj.exec(**kwargs)
            return f"({sql})" if self._child_needs_parens(obj) else sql
        return exec(obj, **kwargs)

    def exec(self, **kwargs):
        """Render ``this op other`` as SQL."""
        this = self._render_side(self.this, kwargs)
        other = self._render_side(self.other, kwargs)
        return f"{this} {self.op} {other}"

    def __and__(self, __o):
        return self.compose_criteria('and', __o)

    def __or__(self, __o):
        return self.compose_criteria('or', __o)



In [None]:
# Operator overloads on Field build ArithmeticExpression / Criteria trees;
# get_sql renders them, adding parentheses around nested sub-expressions.
a = Field('a')
b = Field('b')
test_eq((a+1<13).get_sql(), 'a + 1 < 13')
test_eq(((a + 1)/3).get_sql(), '(a + 1) / 3')
test_eq(((a + 1)/(b - 4)).get_sql(), '(a + 1) / (b - 4)')
# `&` / `|` combine criteria; a nested `or` inside an `and` (and vice versa) is parenthesized.
test_eq(((a + 1>2) & ((b-1<10) | (b>23)) ).get_sql(), 
        'a + 1 > 2 and (b - 1 < 10 or b > 23)')
test_eq((((a + 1>2) & (b-1<=10)) | (b>100)).get_sql(), '(a + 1 > 2 and b - 1 <= 10) or b > 100')

## Custom Functions

In [None]:
#| export
def kwargs_func(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)``, dropping keyword arguments ``func`` cannot accept.

    A keyword is forwarded when it names a parameter of ``func`` that can be
    passed by keyword and is not already filled positionally by ``args``.
    If ``func`` takes ``**kwargs``, every keyword not already filled by
    ``args`` is forwarded. This lets callers pass a shared set of options
    (e.g. ``dialect``) to functions that only use some of them.
    """
    P = inspect.Parameter
    params = inspect.signature(func).parameters.values()
    positional = [p.name for p in params if p.kind in (P.POSITIONAL_ONLY, P.POSITIONAL_OR_KEYWORD)]
    # Parameters already bound by position must not be passed again by keyword
    # (that would raise "got multiple values for argument").
    filled = set(positional[:len(args)])
    takes_var_kw = any(p.kind is P.VAR_KEYWORD for p in params)
    keywordable = {p.name for p in params if p.kind in (P.POSITIONAL_OR_KEYWORD, P.KEYWORD_ONLY)}
    passed = {k: v for k, v in kwargs.items()
              if k not in filled and (takes_var_kw or k in keywordable)}
    return func(*args, **passed)


class CustomFunction(FieldBase):
    """SQL function template defined by its name and positional argument names.

    Calling the template, e.g. ``date_diff('month', a, b)``, checks the
    argument count and returns a *new* bound term that renders as
    ``DATE_DIFF(month, a, b)``. The template itself is left unchanged, so it
    can be used several times in the same query.
    """
    def __init__(self, func_name, arg_names) -> None:
        super().__init__()
        self.func_name = func_name
        self.arg_names = arg_names
        self.get_sql = self.exec

    def __call__(self, *args):
        """Bind ``args`` to a fresh copy of this template.

        Raises ValueError if the number of args differs from ``arg_names``.
        """
        if len(args) != len(self.arg_names):
            raise ValueError(f"The number of args provided {len(args)} is not the same as the number of args expected by this function ({len(self.arg_names)})!")
        func_name = self.func_name

        def func(*args):
            return f"{func_name}({', '.join(args)})"

        # Bind on a new instance: mutating and returning `self` would make all
        # calls of this template share the arguments of the most recent call.
        bound = type(self)(self.func_name, self.arg_names)
        bound.func = func
        bound.args = args
        return bound

    def exec(self, **kwargs):
        """Render the call; each argument is rendered with ``exec``.

        Only valid on a bound term (one returned by calling the template).
        """
        args = [exec(arg, **kwargs) for arg in self.args]
        return self.func(*args)


class DelayedFunc(FieldBase):
    """Delay the execution of a stored function until ``exec`` is run.

    ``func`` is eventually called with the rendered ``args`` and with the
    stored ``kwargs`` merged with whatever keyword arguments ``exec`` receives.
    """
    def __init__(self, 
                 func, 
                 args,
                 kwargs, 
                 window_func=True # whether this function is a window function
                 ) -> None: 
        self.func = func
        self.args = args
        self.kwargs = kwargs
        self.window_func = window_func

    def exec(self, **kwargs):
        """Render the delayed call.

        Keyword arguments given here override the stored ones for this call
        only; ``self.kwargs`` is not modified, so rendering the same term for
        one dialect does not affect later renders.
        """
        merged = {**self.kwargs, **kwargs}
        # Recursively resolve all delayed arguments (nested functions, fields, ...).
        args = [exec(arg, **merged) for arg in self.args]
        return self.func(*args, **merged)


def custom_func(func=None, window_func=False, dialect=None):
    """Decorator turning a SQL-rendering Python function into a lazy term.

    Calling the decorated name returns a ``DelayedFunc``. The wrapped function
    only runs when that term is rendered with ``exec``, and it receives the
    rendered positional args plus the call's and the render's keyword args.

    - ``@custom_func`` / ``@custom_func(window_func=...)``: ``func`` is always
      used and receives every keyword argument (e.g. ``dialect``).
    - ``@custom_func(dialect='x')``: ``func`` is used only when rendering with
      ``dialect='x'``. For any other (or missing) dialect, the previously
      defined function of the same name is used if one exists. It is looked
      up in the namespace where ``func`` is defined, then in this module.
      ``func`` only receives the keyword arguments its signature accepts.

    When called with options only (``func`` is None), returns a decorator.
    """
    if func is None:
        return partial(custom_func, window_func=window_func, dialect=dialect)

    if dialect is None:
        @wraps(func)
        def wrapper(*args, **kwargs):
            return DelayedFunc(func, args, kwargs, window_func=window_func)
        return wrapper

    # Get the previously defined function of the same name. The decorator runs
    # before the new definition is bound to the name, so this finds the old one.
    func_name = func.__name__
    caller_globals = getattr(func, '__globals__', {})
    ori_func = caller_globals.get(func_name, globals().get(func_name))

    def new_func(*args, **kwargs):
        # If rendering for another dialect (or none), fall back to the original func.
        if kwargs.get('dialect') != dialect and ori_func:
            f = ori_func().func
        else:
            f = func
        return kwargs_func(f, *args, **kwargs)

    @wraps(func)
    def wrapper(*args, **kwargs):
        return DelayedFunc(new_func, args, kwargs, window_func=window_func)
    return wrapper

`CustomFunction` is a convenient way to create a custom SQL function from its name and positional argument names, when the function does not need to be rendered differently for different dialects.

In [None]:
# A fixed-name SQL function with three positional arguments; the call renders them in order.
date_diff = CustomFunction('DATE_DIFF', ['interval', 'start_date', 'end_date'])
test_eq(date_diff('month', Field('date1'), Field('date2')).get_sql(), 'DATE_DIFF(month, date1, date2)')

For functions that need to be rendered differently for different dialects, use the `custom_func` decorator.

In [None]:
# Dialect-aware function: the rendering depends on the `dialect` keyword
# (default 'sql'), which is supplied when the term is rendered.
@custom_func
def add_months(column, num, dialect='sql'):
    if dialect=='sql':
        return f'DATE_ADD(month, {num}, {column})'
    elif dialect=='snowflake':
        return f'MONTH_ADD({column}, {num})'


test_eq((add_months("col1", 3)-2 > 2).get_sql(), 'DATE_ADD(month, 3, col1) - 2 > 2')
test_eq((add_months("col1", 3)-2 > 2).get_sql(dialect='snowflake'), 'MONTH_ADD(col1, 3) - 2 > 2')

You can also override or extend an existing function for a single dialect by using the `custom_func` decorator with the `dialect` keyword argument; other dialects keep using the previous definition.

In [None]:
# Override `add_months` for the 'athena' dialect only; rendering for any other
# dialect falls back to the definition from the previous cell.
@custom_func(dialect='athena')
def add_months(column, num):
    return f"DATE_ADD('month', {num}, {column})"

test_eq((add_months("col1", 5)).exec(dialect='athena'),
        "DATE_ADD('month', 5, col1)")

## Window Clause

In [None]:
#| export
class OverClause:
    """Window term rendered as ``<expression> OVER (<window spec>)``.

    The window spec is built from ``over`` (PARTITION BY), ``orderby``
    (ORDER BY) and at most one frame clause, ``rows`` or ``range``.
    """
    def __init__(self, expression) -> None:
        # The order of these checks does not change the outcome.
        is_valid = (
            type(expression) is str
            or isinstance(expression, ArithmeticExpression)
            or (isinstance(expression, DelayedFunc) and expression.window_func == True)
        )
        if not is_valid:
            raise ValueError("Expression has to be of the ArithmeticExpression type or a Window Function (DelayedFunc with window_func=True)!!")
        self.expr = expression
        self.alias = None
        self.rows_flag = False   # True once a ROWS frame has been set
        self.range_flag = False  # True once a RANGE frame has been set
        self.d = {}              # clause keyword -> value(s); rendered in a fixed order

    def over(self, q):
        """Set the PARTITION BY expression."""
        self.d['PARTITION BY'] = q
        return self

    def orderby(self, q):
        """Set the ORDER BY expression."""
        self.d['ORDER BY'] = q
        return self

    def _check_rows_or_range(self):
        """Raise ValueError if a ROWS or RANGE frame has already been set."""
        if self.rows_flag:
            raise ValueError("ROWS already set!")
        if self.range_flag:
            raise ValueError("RANGE already set!")

    def rows(self, start, end):
        """Set a ``ROWS BETWEEN start AND end`` frame (only one frame allowed)."""
        self._check_rows_or_range()
        # Set the flag, not `self.rows`, which would overwrite this method.
        self.rows_flag = True
        self.d['ROWS'] = (start, end)
        return self

    def range(self, start, end):
        """Set a ``RANGE BETWEEN start AND end`` frame (only one frame allowed)."""
        self._check_rows_or_range()
        self.range_flag = True
        self.d['RANGE'] = (start, end)
        return self

    def as_(self, alias):
        self.alias = alias
        return self

    def _resolve_over_statement(self, **kwargs):
        """Render ``<expr> OVER (...)`` with the clauses that were set, in SQL order."""
        sql = []
        for k in ['PARTITION BY', 'ORDER BY', 'ROWS', 'RANGE']:
            if k not in self.d:
                continue
            if k in ['PARTITION BY', 'ORDER BY']:
                sql.append(f"{k} {exec(self.d[k], **kwargs)}")
            else:
                start, end = self.d[k]
                sql.append(f"{k} BETWEEN {exec(start, **kwargs)} AND {exec(end, **kwargs)}")
        return f"{exec(self.expr, **kwargs)} OVER ({' '.join(sql)})"

    def exec(self, **kwargs):
        """Render the window term, appending ``AS <alias>`` when an alias is set."""
        sql = self._resolve_over_statement(**kwargs)
        if self.alias:
            return f"{sql} AS {self.alias}"
        return sql


class Preceding:
    """Window frame bound ``N PRECEDING``, or ``UNBOUNDED PRECEDING`` when N is None."""
    def __init__(self, N=None) -> None:
        self.N = N

    def exec(self, **kwargs):
        # Compare with None so that Preceding(0) renders "0 PRECEDING" instead
        # of being mistaken for an unbounded frame.
        if self.N is None:
            return "UNBOUNDED PRECEDING"
        return f"{self.N} PRECEDING"


class Following:
    """Window frame bound ``N FOLLOWING``, or ``UNBOUNDED FOLLOWING`` when N is None."""
    def __init__(self, N=None) -> None:
        self.N = N

    def exec(self, **kwargs):
        # Compare with None so that Following(0) renders "0 FOLLOWING" instead
        # of being mistaken for an unbounded frame.
        if self.N is None:
            return "UNBOUNDED FOLLOWING"
        return f"{self.N} FOLLOWING"


# Frame bound for `OverClause.rows` / `OverClause.range`; the notebook tests show it rendered as-is.
# NOTE(review): standard SQL spells this frame bound `CURRENT ROW` (two words);
# the underscore form is what this notebook's tests expect — confirm the target
# dialects accept it before relying on it.
CURRENT_ROW = "CURRENT_ROW"

In [None]:
#| export
@patch
def over(self:ArithmeticExpression, partition_by):
    "Wrap this arithmetic expression in an `OverClause` partitioned by `partition_by`."
    window = OverClause(self)
    return window.over(partition_by)

@patch
def over(self:DelayedFunc, partition_by):
    "Wrap this window function in an `OverClause` partitioned by `partition_by`; non-window functions raise ValueError."
    if not (self.window_func == True):
        raise ValueError(f"This function is not a window function!!")
    window = OverClause(self)
    return window.over(partition_by)


In [None]:
# Window clause built from a raw SQL string with a ROWS frame...
test_eq(
    OverClause('SUM(col1)').over('col0').orderby('col2').rows(Preceding(3), CURRENT_ROW).exec(), 
    'SUM(col1) OVER (PARTITION BY col0 ORDER BY col2 ROWS BETWEEN 3 PRECEDING AND CURRENT_ROW)')

# ...and from an arithmetic expression (via the patched `over`) with a RANGE frame.
test_eq(
    ((Field('col2')+2)/10).over('col1').orderby('col3').range(Preceding(), Following(2)).exec(),
    '(col2 + 2) / 10 OVER (PARTITION BY col1 ORDER BY col3 RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)'
)

## Case

In [None]:
#| export
class Case:
    """SQL ``CASE`` expression built fluently from ``when`` / ``else_`` clauses."""
    def __init__(self) -> None:
        # Ordered clauses: ('WHEN', condition, result) or ('ELSE', result).
        self.dp = []
        self.alias = None

    def check_prev(self, statement):
        """Return True if the most recently added clause is of kind ``statement``."""
        return bool(self.dp) and self.dp[-1][0] == statement

    def when(self, q, then):
        """Append ``WHEN q THEN then``; raises ValueError once an ELSE was added."""
        if self.check_prev('ELSE'):
            raise ValueError(f"'WHEN' can not follow 'ELSE'!")
        self.dp.append(('WHEN', q, then))
        return self

    def else_(self, q):
        """Append ``ELSE q``; raises ValueError if an ELSE was already added."""
        if any(clause[0] == 'ELSE' for clause in self.dp):
            raise ValueError("'ELSE' already set!")
        self.dp.append(('ELSE', q))
        return self

    def as_(self, alias):
        """Set the alias used in ``END AS <alias>``; returns ``self`` for chaining."""
        self.alias = alias
        return self

    # Backward-compatible spelling; `as_` matches Field and OverClause.
    _as = as_

    def exec(self, **kwargs):
        """Render as multi-line SQL: ``CASE``, one line per clause, then ``END`` (``END AS alias`` if set)."""
        sql = ["CASE"]
        for item in self.dp:
            if item[0] == 'WHEN':
                sql.append(f"WHEN {exec(item[1], **kwargs)} THEN {exec(item[2], **kwargs)}")
            else:
                sql.append(f"ELSE {exec(item[1], **kwargs)}")
        sql.append(f"END AS {self.alias}" if self.alias else "END")
        return '\n'.join(sql)

In [None]:
# CASE renders one clause per line, terminated by END.
test_eq(Case().when(Field('column1')>3, True).else_(False).exec(), 
        'CASE\nWHEN column1 > 3 THEN True\nELSE False\nEND')
test_eq(Case().when(Field('column1')>3, 1).when(Field('column1')<1, -1).else_(0).exec(), 
        'CASE\nWHEN column1 > 3 THEN 1\nWHEN column1 < 1 THEN -1\nELSE 0\nEND')

In [None]:
#|hide
import nbdev; nbdev.nbdev_export()