# Terms

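> The small building blocks of a SQL query: values, fields, arithmetic expressions, criteria, custom functions, window clauses and CASE statements.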

In [None]:
#| default_exp terms
#| hide
from nbdev.showdoc import *
from fastcore.test import *
# allow multiple output from one cell
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
# IPython autoreload: pick up edits to imported modules (e.g. pikaQ.utils)
# without restarting the kernel
%reload_ext autoreload
%autoreload 2

In [None]:
#|export
import inspect
from functools import partial
from fastcore.foundation import patch
from pikaQ.utils import execute, delegates

## Value

In [None]:
#| export
class Term:
    """Base class for every query element that can carry an output alias."""

    def __init__(self) -> None:
        # No alias until one is assigned through as_().
        self.alias = None

    def as_(self, alias):
        """Attach *alias* to this term and hand the same object back for chaining."""
        self.alias = alias
        return self


class Value(Term):
    """A simple wrapper for a value that will be used in a query.

    Rendering (``get_sql`` / ``execute``):

    - ``str`` -> single-quoted SQL literal. Embedded single quotes are doubled
      (``O'Brien`` -> ``'O''Brien'``) so the literal cannot be terminated early.
    - ``list`` / ``tuple`` -> ``(item, item, ...)`` for use with ``IN``.
      ``int``/``float`` items are emitted bare; every other item is quoted as above.
    - objects with a ``get_sql`` method -> rendering is delegated to that object.
    - anything else -> the value's string form, unquoted.

    An alias set through ``as_`` is appended as ``<sql> as <alias>`` for the
    ``str`` and fallback cases only.
    """
    def __init__(self, value) -> None:
        super().__init__()
        self.value = value
        self.get_sql = self.execute

    @staticmethod
    def _quote(item) -> str:
        """Return *item* as a single-quoted SQL string literal, escaping ``'`` as ``''``."""
        escaped = f"{item}".replace("'", "''")
        return f"'{escaped}'"

    def execute(self, **kwargs):
        if type(self.value) is str:
            literal = self._quote(self.value)
            return f"{literal} as {self.alias}" if self.alias else literal
        elif type(self.value) in (list, tuple):
            items = [f"{item}" if type(item) in (int, float) else self._quote(item)
                     for item in self.value]
            return '(' + ', '.join(items) + ')'
        elif hasattr(self.value, 'get_sql'):
            return self.value.get_sql(**kwargs)
        else:
            return f"{self.value} as {self.alias}" if self.alias else f"{self.value}"


class NullValue(Term):
    """The SQL ``NULL`` literal, optionally followed by an alias."""

    def __init__(self) -> None:
        super().__init__()
        self.get_sql = self.execute

    def execute(self, **kwargs):
        """Render ``NULL``, or ``NULL as <alias>`` once an alias has been set."""
        if not self.alias:
            return "NULL"
        return f"NULL as {self.alias}"

In [None]:
# Numbers render bare, strings single-quoted; aliases use lowercase `as`.
test_eq(Value(2).get_sql(), '2')
test_eq(Value('abc').get_sql(), "'abc'")
test_eq(Value(2).as_('col1').get_sql(), '2 as col1')
test_eq(Value('abc').as_('col1').get_sql(), "'abc' as col1")
# Lists/tuples become a parenthesised list (used by IN / NOT IN).
test_eq(Value([1,2,3]).get_sql(), '(1, 2, 3)')
test_eq(Value(['col1', 'col2', 'col3']).get_sql(), "('col1', 'col2', 'col3')")
test_eq(NullValue().get_sql(), 'NULL')
test_eq(NullValue().as_('col1').get_sql(), 'NULL as col1')
# A Value wrapping something with get_sql delegates rendering to it.
test_eq(Value(Value(2)).get_sql(), '2')

## Field

In [None]:
#| export
class FieldBase(Term):
    """Collection of methods to convert to ArithmeticExpression and Criteria.

    Arithmetic operators (``+ - * /`` and their reflected forms) build an
    ``ArithmeticExpression``. Comparison operators and the predicate helpers
    (``isnull``, ``like``, ``isin``, ...) build a ``Criteria``. Because ``==`` is
    overloaded to return a ``Criteria``, instances are not hashable.
    """

    def _make_arith(self, op, other, reflected=False):
        """Build ``self <op> other``; with *reflected* the operands are swapped."""
        left, right = (other, self) if reflected else (self, other)
        return ArithmeticExpression(left, op, right)

    def _make_criteria(self, op, other):
        """Build the condition ``self <op> other``."""
        return Criteria(self, op, other)

    # arithmetic operators ----------------------------------------------------
    def __add__(self, other):
        return self._make_arith('+', other)

    def __radd__(self, other):
        return self._make_arith('+', other, reflected=True)

    def __sub__(self, other):
        return self._make_arith('-', other)

    def __rsub__(self, other):
        return self._make_arith('-', other, reflected=True)

    def __mul__(self, other):
        return self._make_arith('*', other)

    def __rmul__(self, other):
        return self._make_arith('*', other, reflected=True)

    def __truediv__(self, other):
        return self._make_arith('/', other)

    def __rtruediv__(self, other):
        return self._make_arith('/', other, reflected=True)

    # comparison operators ----------------------------------------------------
    def __gt__(self, other):
        return self._make_criteria('>', other)

    def __ge__(self, other):
        return self._make_criteria('>=', other)

    def __eq__(self, other):
        return self._make_criteria('=', other)

    def __ne__(self, other):
        return self._make_criteria('<>', other)

    def __le__(self, other):
        return self._make_criteria('<=', other)

    def __lt__(self, other):
        return self._make_criteria('<', other)

    # named aliases so comparisons can also be written as method calls
    gt = __gt__
    ge = __ge__
    eq = __eq__
    ne = __ne__
    le = __le__
    lt = __lt__

    # predicates ----------------------------------------------------------------
    def isnull(self):
        return self._make_criteria('IS', NullValue())

    def notnull(self):
        return self._make_criteria('IS NOT', NullValue())

    # The right-hand side is wrapped in Value so plain Python strings/lists
    # render as SQL literals.
    def like(self, other):
        return self._make_criteria('LIKE', Value(other))

    def ilike(self, other):
        return self._make_criteria('ILIKE', Value(other))

    def not_like(self, other):
        return self._make_criteria('NOT LIKE', Value(other))

    def not_ilike(self, other):
        return self._make_criteria('NOT ILIKE', Value(other))

    def isin(self, other:list):
        return self._make_criteria('IN', Value(other))

    def notin(self, other):
        return self._make_criteria('NOT IN', Value(other))

    #TODO: implement new Criteria, and execute for list


class Field(FieldBase):
    """A simple wrapper for a field that will be used in a query. Quotes can be added to the field name by setting the `quote_char` parameter in get_sql() method.

    Dotted names such as ``tbl.col`` are quoted part by part (``"tbl"."col"``).
    The wildcard ``*`` is never quoted, so ``tbl.*`` renders as ``"tbl".*``.
    A quote character occurring inside a name part is escaped by doubling it.
    """
    def __init__(self, name) -> None:
        super().__init__()
        self.name = name
        self.get_sql = self.execute

    def quoted_name(self, quote_char):
        """Return the field name with each dot-separated part wrapped in *quote_char*."""
        def quote(part):
            # `"*"` would name a column literally called `*`, not "all columns".
            if not quote_char or part == '*':
                return part
            escaped = part.replace(quote_char, quote_char * 2)
            return f"{quote_char}{escaped}{quote_char}"

        return '.'.join(quote(part) for part in self.name.split('.'))

    def execute(self, **kwargs):
        """Render the (optionally quoted) name, plus ``AS <alias>`` when aliased."""
        q = kwargs.get('quote_char', '') or ''
        quoted_name = self.quoted_name(q)
        if self.alias:
            return f"{quoted_name} AS {q}{self.alias}{q}"
        else:
            return quoted_name
    

class ArithmeticExpression(FieldBase):
    """Constructor for arithmetic expressions from two terms and automatically adds parentheses in appropriate places.

    ``this`` and ``other`` may be fields, values, functions or nested
    ``ArithmeticExpression`` objects; ``op`` is one of ``+ - * /``. Only the
    outermost expression renders its alias; nested expressions are rendered bare.
    """
    add_order = ["+", "-"]

    def __init__(self, this, op, other) -> None:
        super().__init__()
        self.this, self.op, self.other = this, op, other
        self.alias = None
        self.get_sql = self.execute

    def left_needs_parens(self, left_op, curr_op) -> bool:
        """
        Returns true if the expression on the left of the current operator needs to be enclosed in parentheses.
        :param left_op:
            The highest level operator of the left expression.
        :param curr_op:
            The current operator.
        """
        if left_op is None or curr_op in self.add_order:
            # If the left expression is a single item.
            # or if the current operator is '+' or '-'.
            return False

        # The current operator is '*' or '/'.
        # If the left operator is '+' or '-', we need to add parentheses:
        # e.g. (A + B) / ..., (A - B) / ...
        # Otherwise, no parentheses are necessary:
        # e.g. A * B / ..., A / B / ...
        return left_op in self.add_order

    def right_needs_parens(self, curr_op, right_op) -> bool:
        """
        Returns true if the expression on the right of the current operator needs to be enclosed in parentheses.
        :param curr_op:
            The current operator.
        :param right_op:
            The highest level operator of the right expression.
        """
        if right_op is None:
            # If the right expression is a single item.
            return False
        if curr_op == '/':
            # Division is not associative: A / (B * C) != A / B * C and
            # A / (B / C) != A / B / C, so a compound divisor is always wrapped.
            return True
        # If the right operator is '+' or '-', we add parentheses:
        # e.g. ... - (A + B), ... * (A - B), ... + (A + B)
        # Otherwise, no parentheses are necessary:
        # e.g. ... - A / B, ... - A * B
        return right_op in self.add_order

    def _operand_sql(self, operand, **kwargs):
        """Render one side of the expression, without surrounding parentheses."""
        if operand.__class__ is ArithmeticExpression:
            # Render the bare expression: `AS alias` is only valid at the top level.
            return operand._expression_sql(**kwargs)
        if operand.__class__ is Field and operand.alias:
            # An aliased field is referenced by its alias, quoted like the alias
            # itself is quoted in Field.execute.
            q = kwargs.get('quote_char', '') or ''
            return f"{q}{operand.alias}{q}"
        return execute(operand, **kwargs)

    def _expression_sql(self, **kwargs):
        """Render ``this op other`` with precedence parentheses but without the alias."""
        this = self._operand_sql(self.this, **kwargs)
        if self.this.__class__ is ArithmeticExpression and self.left_needs_parens(self.this.op, self.op):
            this = f"({this})"

        other = self._operand_sql(self.other, **kwargs)
        if self.other.__class__ is ArithmeticExpression and self.right_needs_parens(self.op, self.other.op):
            other = f"({other})"

        return f"{this} {self.op} {other}"

    def execute(self, **kwargs):
        """Render the full expression, appending ``AS <alias>`` when an alias is set."""
        sql = self._expression_sql(**kwargs)
        if not self.alias:
            return sql
        q = kwargs.get('quote_char', '') or ''
        return f"{sql} AS {q}{self.alias}{q}"


class Criteria:
    """A boolean SQL condition rendered as ``<this> <op> <other>``.

    Conditions combine with ``&`` (``and``) and ``|`` (``or``). When an operand
    built with one connective is combined using the other, that operand is
    flagged so it renders inside parentheses, keeping the grouping of the Python
    expression. ``negate`` renders the whole condition as ``NOT (...)``.
    """
    compose_ops = ('and', 'or')

    def __init__(self, this, op, other) -> None:
        super().__init__()
        self.this = this
        self.op = op
        self.other = other
        self.add_parentheses = False  # set by compose_criteria when grouping is needed
        self.not_ = False             # set by negate()
        self.get_sql = self.execute

    def compose_criteria(self, op, other):
        """Return ``self <op> other``, flagging operands that need parentheses.

        The flags are stored on the operand objects themselves.
        """
        if other.__class__ is Criteria and op in self.compose_ops:
            for operand in (self, other):
                if operand.op in self.compose_ops and operand.op != op:
                    operand.add_parentheses = True
        return Criteria(self, op, other)

    @staticmethod
    def resolve(obj, **kwargs):
        """Render *obj*; nested criteria flagged for grouping get parentheses."""
        if obj.__class__ is not Criteria:
            return execute(obj, **kwargs)
        rendered = obj.execute(**kwargs)
        return f"({rendered})" if obj.add_parentheses == True else rendered

    def execute(self, **kwargs):
        """Render the condition (also cached on ``self.result``)."""
        left, right = (self.resolve(side, **kwargs) for side in (self.this, self.other))
        self.result = f"{left} {self.op} {right}"
        return f"NOT ({self.result})" if self.not_ else self.result

    def negate(self):
        """Mark this condition as negated (in place) and return it."""
        self.not_ = True
        return self

    def __and__(self, other):
        return self.compose_criteria('and', other)

    def __or__(self, other):
        return self.compose_criteria('or', other)

In [None]:
a = Field('a')
b = Field('b')
c = Field('c')

# Test Field
test_eq(a.get_sql(), 'a')
# NOTE: as_() mutates `c` in place, so from here on `c` carries the alias 'c1'.
test_eq(c.as_('c1').get_sql(), 'c AS c1')

# Test ArithmeticExpression
# An aliased Field operand is referenced by its alias, hence 'c1 - 1'.
test_eq((c-1).as_('new_c').get_sql(), 'c1 - 1 AS new_c')
test_eq((a+1<13).get_sql(), 'a + 1 < 13')
# Parentheses are added only where operator precedence requires them.
test_eq(((a + 1)/3).get_sql(), '(a + 1) / 3')
test_eq(((a + 1)/(b - 4)).get_sql(), '(a + 1) / (b - 4)')
# Mixing `&` and `|` parenthesises the sub-criteria built with the other connective.
test_eq(((a + 1>2) & ((b-1<10) | (b>23)) ).get_sql(), 
        'a + 1 > 2 and (b - 1 < 10 or b > 23)')
test_eq((((a + 1>2) & (b-1<=10)) | (b>100)).get_sql(), '(a + 1 > 2 and b - 1 <= 10) or b > 100')

# Test Criteria
test_eq(a.eq(b).get_sql(), 'a = b')
test_eq(a.ne(b).get_sql(), 'a <> b')
test_eq(a.gt(b).get_sql(), 'a > b')
test_eq(a.ge(b).get_sql(), 'a >= b')
test_eq(a.lt(b).get_sql(), 'a < b')
test_eq(a.le(b).get_sql(), 'a <= b')
# like/ilike wrap the pattern in Value: a Field renders as-is, a str is quoted.
test_eq(a.like(b).get_sql(), 'a LIKE b')
test_eq(a.not_like(b).get_sql(), 'a NOT LIKE b')
test_eq(a.ilike('%what').get_sql(), "a ILIKE '%what'")
test_eq(a.not_ilike('%hh%').get_sql(), "a NOT ILIKE '%hh%'")
test_eq(a.isin([2, 3, 5]).get_sql(), 'a IN (2, 3, 5)')
test_eq(a.notin([2, 3, 5]).get_sql(), 'a NOT IN (2, 3, 5)')
test_eq(a.isnull().get_sql(), 'a IS NULL')
test_eq(a.notnull().get_sql(), 'a IS NOT NULL')

# Test negate
test_eq((a-1>1).negate().get_sql(), 'NOT (a - 1 > 1)')

# Test quoted field
# quote_char quotes each dotted part of a field name, and the alias.
d = Field('tbl.d')
test_eq(a.get_sql(quote_char='"'), '"a"')
test_eq(d.get_sql(quote_char='"'), '"tbl"."d"')
test_eq((d - 1).get_sql(quote_char='"'), '"tbl"."d" - 1')
test_eq((d - 1 > 2).get_sql(quote_char='"'), '"tbl"."d" - 1 > 2')
test_eq((d - 3).as_('new_d').get_sql(quote_char='"'), '"tbl"."d" - 3 AS "new_d"')

## Custom Functions

In [None]:
#| export
def kwargs_func(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)``, dropping keyword arguments *func* cannot take.

    This lets callers pass one shared bag of render options (``dialect``,
    ``quote_char``, ...) to functions that declare only some of them.

    - If *func* declares ``**kwargs`` itself, every keyword argument is passed on.
    - Otherwise only keywords that name a keyword-only parameter, or a
      positional-or-keyword parameter with a default, are passed. Required
      positional parameters are expected through *args*, so keywords naming them
      are dropped rather than risk colliding with the positional values.
    """
    P = inspect.Parameter
    params = inspect.signature(func).parameters.values()
    if any(p.kind is P.VAR_KEYWORD for p in params):
        return func(*args, **kwargs)
    accepted = {
        p.name for p in params
        if p.kind is P.KEYWORD_ONLY
        or (p.kind is P.POSITIONAL_OR_KEYWORD and p.default is not P.empty)
    }
    return func(*args, **{k: v for k, v in kwargs.items() if k in accepted})


class DelayedFunc(FieldBase):
    """Delay the execution of stored function until exec is run.

    Holds a SQL-rendering function together with the positional and keyword
    arguments it was called with. ``execute``/``get_sql`` renders the call.
    """
    def __init__(self, 
                 func, 
                 args,
                 kwargs, 
                 window_func=True # whether this function is a window function
                 ) -> None: 
        super().__init__()
        self.func = func
        self.args = args
        self.kwargs = kwargs
        self.window_func = window_func
        self.get_sql = self.execute

    @staticmethod
    def _render_kwargs_for(func, render_kwargs):
        """Keep only the render-time keyword arguments that *func* can accept.

        Render options such as ``quote_char`` or ``dialect`` reach every term.
        Passing one that *func* does not declare would raise TypeError.
        """
        try:
            params = list(inspect.signature(func).parameters.values())
        except (TypeError, ValueError):
            # signature not introspectable: pass everything through unchanged
            return dict(render_kwargs)
        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
            return dict(render_kwargs)
        keyword_names = {
            p.name for p in params
            if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        }
        return {k: v for k, v in render_kwargs.items() if k in keyword_names}

    def execute(self, **kwargs):
        """keyword arguments can be overwritten with any provided new kwargs.

        The override applies to this call only: ``self.kwargs`` is not modified,
        so rendering the same object twice with different options is independent.
        Render options that the stored function does not declare are not passed
        to it, but they are still forwarded to the arguments being rendered.
        """
        options = {**self.kwargs, **kwargs}
        # recursively resolve all delayed arguments with the full option set
        args = [execute(arg, **options) for arg in self.args]
        call_kwargs = {**self.kwargs, **self._render_kwargs_for(self.func, kwargs)}
        sql = self.func(*args, **call_kwargs)
        if self.alias:
            return f"{sql} AS {self.alias}"
        return sql


def custom_func(func=None, window_func=False, dialect=None):
    """Decorator turning a SQL-rendering function into a lazily rendered term.

    Calling the decorated function returns a ``DelayedFunc`` that captures the
    arguments. The SQL text is produced only when ``get_sql``/``execute`` runs.

    Usage: ``@custom_func``, ``@custom_func(window_func=True)`` or
    ``@custom_func(dialect='athena')``.

    With ``dialect`` set, the decorated function handles only that dialect.
    For any other dialect, or when no dialect is given at render time, the
    previously defined function of the same name is used, if there is one.
    """
    if func is None:
        # called with keyword arguments only: return the actual decorator
        return partial(custom_func, window_func=window_func, dialect=dialect)

    if dialect is None:
        @delegates(func)
        def wrapper(*args, **kwargs):
            return DelayedFunc(func, args, kwargs, window_func=window_func)
        return wrapper

    # Get the previously defined func. The decorator runs before the name is
    # rebound, so the defining module still maps it to the older definition.
    # Fall back to this module's globals for backward compatibility.
    func_name = func.__name__
    defining_globals = getattr(func, '__globals__', {})
    ori_func = defining_globals.get(func_name) or globals().get(func_name)

    @delegates(func)
    def wrapper(*args, **kwargs):
        def new_func(*args, **kwargs):
            # If the requested dialect differs from the one this override was
            # written for, fall back to the original func. `dialect` may be
            # absent when get_sql() is called without it; that counts as
            # "not this dialect" instead of raising KeyError.
            if kwargs.get('dialect') != dialect and ori_func:
                f = ori_func().func
            else:
                f = func
            return kwargs_func(f, *args, **kwargs)

        # make new delayed function
        return DelayedFunc(new_func, args, kwargs, window_func=window_func)
    return wrapper

For functions that must be rendered differently for different SQL dialects, use the `custom_func` decorator. The `dialect` passed to `get_sql()` is forwarded to the decorated function when it is rendered.

In [None]:
@custom_func
def add_months(column, num, dialect='sql'):
    # Chooses the SQL form from the `dialect` option passed to get_sql().
    # NOTE(review): dialects other than 'sql'/'snowflake' fall through and return None.
    if dialect=='sql':
        return f'DATE_ADD(month, {num}, {column})'
    elif dialect=='snowflake':
        return f'MONTH_ADD({column}, {num})'


# Without an explicit dialect the function's default ('sql') is used.
test_eq((add_months("col1", 3)-2 > 2).get_sql(), 'DATE_ADD(month, 3, col1) - 2 > 2')
test_eq((add_months("col1", 3)-2 > 2).get_sql(dialect='snowflake'), 'MONTH_ADD(col1, 3) - 2 > 2')

You can also override or extend an existing function for a single dialect. Decorate a new function with the same name using `custom_func(dialect=...)`. Rendering with any other dialect falls back to the previous definition.

In [None]:
# Athena-only override: other dialects fall back to the add_months defined above.
@custom_func(dialect='athena')
def add_months(column, num):
    return f"DATE_ADD('month', {num}, {column})"

test_eq((add_months("col1", 5).as_('new_date')).get_sql(dialect='athena'),
        "DATE_ADD('month', 5, col1) AS new_date")

## Window Clause

In [None]:
#| export
class OverClause:
    """Constructor for OVER clause.

    Wraps a window function (an object whose ``window_func`` is True) or a
    pre-rendered SQL string, and renders ``<expr> OVER (<spec>)``. The spec is
    built fluently from ``over`` (PARTITION BY), ``orderby`` (ORDER BY) and at
    most one frame clause: either ``rows`` or ``range``.
    """
    def __init__(self, expression) -> None:
        if not (
            (hasattr(expression, 'window_func')
             and 
             expression.window_func==True)
            or 
            (type(expression) is str)
        ):
            raise ValueError(f"Expression has to be of the ArithmeticExpression type or a Window Function (DelayedFunc with window_func=True)!!")
        self.expr = expression
        self.alias = None
        self.rows_flag = False   # True once a ROWS frame has been set
        self.range_flag = False  # True once a RANGE frame has been set
        self.d = {}              # clause keyword -> argument(s)
        self.get_sql = self.execute

    def over(self, q):
        """Set the PARTITION BY expression."""
        self.d['PARTITION BY'] = q
        return self

    def orderby(self, q):
        """Set the ORDER BY expression."""
        self.d['ORDER BY'] = q
        return self

    def _check_rows_or_range(self):
        """Raise ValueError if a frame clause (ROWS or RANGE) was already set."""
        if self.rows_flag==True: 
            raise ValueError(f"ROWS already set!")
        if self.range_flag==True:
            raise ValueError(f"RANGE already set!")

    def rows(self, start, end):
        """Set a ``ROWS BETWEEN start AND end`` frame."""
        self._check_rows_or_range()
        # Record this in the flag attribute: `self.rows` is this method.
        self.rows_flag = True
        self.d['ROWS'] = (start, end)
        return self

    def range(self, start, end):
        """Set a ``RANGE BETWEEN start AND end`` frame."""
        self._check_rows_or_range()
        # Record this in the flag attribute: `self.range` is this method.
        self.range_flag = True
        self.d['RANGE'] = (start, end)
        return self

    def as_(self, alias):
        self.alias = alias
        return self

    def _resolve_over_statement(self, **kwargs):
        """Render ``<expr> OVER (...)`` with the clauses in SQL order."""
        sql = []
        for k in ['PARTITION BY', 'ORDER BY', 'ROWS', 'RANGE']:
            if k in self.d:
                if k in ['PARTITION BY', 'ORDER BY']:
                    rslvd = f"{k} {execute(self.d[k], **kwargs)}"
                else:
                    start, end = self.d[k]
                    rslvd = f"{k} BETWEEN {execute(start, **kwargs)} AND {execute(end, **kwargs)}"
                sql.append(rslvd)
        return f"{execute(self.expr, **kwargs)} OVER ({' '.join(sql)})"

    def execute(self, **kwargs):
        sql = self._resolve_over_statement(**kwargs)
        if self.alias:
            return f"{sql} AS {self.alias}"
        else:
            return sql


class Preceding:
    """Frame bound ``N PRECEDING``, or ``UNBOUNDED PRECEDING`` when N is omitted."""
    def __init__(self, N=None) -> None:
        self.N = N
        self.get_sql = self.execute

    def execute(self, **kwargs):
        # Compare against None so that an offset of 0 stays "0 PRECEDING"
        # instead of silently becoming unbounded.
        if self.N is not None:
            return f"{self.N} PRECEDING"
        else:
            return "UNBOUNDED PRECEDING"


class Following:
    """Frame bound ``N FOLLOWING``, or ``UNBOUNDED FOLLOWING`` when N is omitted."""
    def __init__(self, N=None) -> None:
        self.N = N
        self.get_sql = self.execute

    def execute(self, **kwargs):
        # Compare against None so that an offset of 0 stays "0 FOLLOWING"
        # instead of silently becoming unbounded.
        if self.N is not None:
            return f"{self.N} FOLLOWING"
        else:
            return "UNBOUNDED FOLLOWING"


# Frame bound for OverClause.rows()/range(), passed as the start or end argument.
# NOTE(review): standard SQL spells this bound "CURRENT ROW" (two words). The
# window test in this notebook expects the underscore form; confirm that the
# target dialects accept it.
CURRENT_ROW = "CURRENT_ROW"

In [None]:
#| export
@patch
def over(self:ArithmeticExpression, partition_by):
    "Wrap this expression in an OverClause partitioned by `partition_by`."
    # NOTE(review): OverClause only accepts strings or objects whose
    # `window_func` is True, and ArithmeticExpression does not define
    # `window_func`. As written, this call therefore raises ValueError;
    # confirm the intent.
    clause = OverClause(self)
    return clause.over(partition_by)

@patch
def over(self:DelayedFunc, partition_by):
    "Wrap this window function in an OverClause partitioned by `partition_by`."
    # Only functions declared with custom_func(window_func=True) may take OVER.
    if not self.window_func == True:
        raise ValueError(f"This function is not a window function!!")
    return OverClause(self).over(partition_by)


In [None]:
# OverClause also accepts a pre-rendered SQL string as its expression.
test_eq(
    OverClause('SUM(col1)').over('col0').orderby('col2').rows(Preceding(3), CURRENT_ROW).get_sql(), 
    'SUM(col1) OVER (PARTITION BY col0 ORDER BY col2 ROWS BETWEEN 3 PRECEDING AND CURRENT_ROW)')

# window_func=True is what allows `.over(...)` to be called on the result.
@custom_func(window_func=True)
def LAG(column, offset=1, default=None):
    if default is None:
        return f"LAG({column}, {offset})"
    else:
        return f"LAG({column}, {offset}, {default})"

# Preceding() with no argument renders as UNBOUNDED PRECEDING.
test_eq(
    LAG(Field('col2'), 1).over('col1').orderby('col3').range(Preceding(), Following(2)).get_sql(),
    'LAG(col2, 1) OVER (PARTITION BY col1 ORDER BY col3 RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)'
)

## Case

In [None]:
#| export
class Case(Term):
    """Constructor for CASE statement.

    Clauses are added in order with ``when(condition, result)``. An optional
    ``else_(result)`` must come last and may appear only once. The statement
    renders one clause per line and ends with ``END`` (or ``END AS <alias>``
    after ``as_``).
    """
    def __init__(self) -> None:
        self.dp = []  # ordered clauses: ('WHEN', condition, result) or ('ELSE', result)
        self.alias = None
        self.get_sql = self.execute

    def check_prev(self, statement):
        """Return True if the most recently added clause is *statement* ('WHEN'/'ELSE')."""
        # Inspect the LAST clause: ELSE is only valid in the final position.
        return bool(self.dp) and self.dp[-1][0] == statement

    def when(self, q, then):
        """Append ``WHEN q THEN then``; not allowed after ELSE."""
        if self.check_prev('ELSE'):
            raise ValueError(f"'WHEN' can not follow 'ELSE'!")
        self.dp.append(('WHEN', q, then))
        return self

    def else_(self, q):
        """Append the final ``ELSE q`` clause; may be used only once."""
        if self.check_prev('ELSE'):
            raise ValueError("'ELSE' can only be used once!")
        self.dp.append(('ELSE', q))
        return self

    def execute(self, **kwargs):
        sql = ["CASE"]
        for item in self.dp:
            if item[0] == 'WHEN':
                q_resolved = f"WHEN {execute(item[1], **kwargs)} THEN {execute(item[2], **kwargs)}"
            else:
                q_resolved = f"ELSE {execute(item[1], **kwargs)}"
            sql.append(q_resolved)
        if self.alias:
            sql.append(f"END AS {self.alias}")
        else:
            sql.append("END")
        return '\n'.join(sql)

In [None]:
# Clauses render one per line; Python booleans appear as True/False.
test_eq(Case().when(Field('column1')>3, True).else_(False).get_sql(), 
        'CASE\nWHEN column1 > 3 THEN True\nELSE False\nEND')
test_eq(Case().when(Field('column1')>3, 1).when(Field('column1')<1, -1).else_(0).as_('b').get_sql(), 
        'CASE\nWHEN column1 > 3 THEN 1\nWHEN column1 < 1 THEN -1\nELSE 0\nEND AS b')

In [None]:
#|hide
# export the cells marked `#| export` in this notebook to the library module
import nbdev; nbdev.nbdev_export()