# core

> Fill in a module description here

In [1]:
#| default_exp core
#| export
import builtins
import types
import sys
import inspect

from functools import partial
from fastcore.basics import *
from fastcore.meta import *
from typing import Union

try: from types import UnionType
except ImportError: UnionType = None

In [2]:
#| hide
from nbdev.showdoc import *
from fastcore.test import *
# allow multiple output from one cell
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

## Utilities

In [3]:
#| export
def quote_symbol(quote):
    """generate quote symbol to use for tables and columns"""
    if type(quote) == str:
        return quote
    else:
        if quote == True:
            return '"'
        else:
            return ''


def to_sql(l):
    if type(l) == list:
        return '(' + ', '.join([f"{item}" if type(item) in (int, float) else f"'{item}'" for item in l]) + ')'
    else:
        raise ValueError(f"type {type(l)} for l is not implemented!")


def exec(obj, **kwargs):
    if hasattr(obj, 'exec'):
        return obj.exec(**kwargs)
    else:
        return str(obj)

## Delay the Execution of Functions

In [4]:
#| export
def _exec(obj, **kwargs):
    if hasattr(obj, 'exec'):
        return obj.exec(**kwargs)
    else:
        return obj


class DelayedFunc:
    """Delay the execution of stored function until exec is run."""
    def __init__(self, func, args, kwargs, order=None) -> None: 
        store_attr()

    def exec(self, **kwargs):
        """keyword arguments can be overwritten with any provided new kwargs."""
        self.kwargs.update(kwargs)
        # recursively resolve all delayed functions
        args = (_exec(arg, **self.kwargs) for arg in self.args) 
        return self.func(*args, **self.kwargs)


def delayed_func(func):
    """return Delayed function"""
    def wrapper(*args, **kwargs):
        return DelayedFunc(func, args, kwargs)
    return wrapper

## Delay the Executionf of Methods

However, this is not enough for our purpose. We also need the ability to delay all instance methods until `.exec` is called. To expand this functionality, we introduce `DelayedPipeline` and the decorator `@delayed_methods`.

In [5]:
@delayed_func
def add_months(column, num, dialect='sql'):
    if dialect=='sql':
        return f'DATE_ADD(month, {num}, {column})'
    elif dialect=='snowflake':
        return f'MONTH_ADD({column}, {num})'


@delayed_func
def to_date(date, format=None, dialect='sql'):
    if dialect in ('sql', 'snowflake'):
        if format is None:
            return f"TO_DATE('{date}')"
        else:
            return f"TO_DATE('{date}', '{format}')"
    

test_eq(
    add_months(to_date('2020-01-01'), 1).exec(dialect='snowflake'),
    "MONTH_ADD(TO_DATE('2020-01-01'), 1)")
test_eq(
    add_months(to_date('2020-01-01'), 1).exec(dialect='sql'),
    "DATE_ADD(month, 1, TO_DATE('2020-01-01'))")

In [6]:
#| export
# class DelayedMethod:
#     """Delay the execution of a method until exec is run."""

#     def __init__(self,
#                  method, 
#                  _self,  # "self" for the Object the method belongs to
#                  args, 
#                  kwargs, 
#                  order=None
#                  ) -> None:
#         store_attr()

#     def exec(self, **kwargs):
#         """keyword arguments can be overwritten with any provided new kwargs."""
#         self.kwargs.update(kwargs)
#         # recursively resolve all delayed functions
#         args = (_exec(arg, **self.kwargs) for arg in self.args)
#         return self.method(self._self, *args, **self.kwargs)


# def delayed_method(func, order=None):
#     """wrap func in DelayedFunc and append it to self.dp"""
#     def wrapper(self, *args, **kwargs):
#         self.dp.append(DelayedMethod(func, self, args, kwargs, order=order))
#         return self
#     return wrapper


# def no_delay(self, func):
#     def inner(*args, **kwargs):
#         return func(self, *args, **kwargs)
#     return inner


def as_(self, alias):
    if self.res:
        self.alias = alias
        return self.res + f" AS {alias}"
    else:
        raise ValueError(f"self.res is {self.res}, can not append AS statement!")


class DelayedPipeline:
    """Execute delayed methods in self.dp in order when .exec is called."""
    dic = {
        'as_': {
            'order': None,
            'dialect': {
                'sql': as_
            }
        }
    }

    def __init__(self) -> None: 
        self.dp = []  # record all delayed functions
        self.ress = []  # save all intermediate results
        self.alias_allowed=True

    def __getattr__(self, attr):
        def _func(*args, dialect='sql', **kwargs):
            return self.dic[attr]['dialect'][dialect](self, *args, **kwargs)

        def wrapper(*args, **kwargs):
            dlf = DelayedFunc(_func, args, kwargs)
            dlf.order = self.dic[attr]['order']
            self.dp.append(dlf)
            return self
        return wrapper

    @property
    def res(self):
        if self.ress:
            return self.ress[-1]

    def _order_dp(self):
        dp_none = [d for d in self.dp if d.order is None]
        dp_order = [d for d in self.dp if d.order is not None]
        dp_order.sort(key=lambda x: x.order)
        self.dp = dp_order + dp_none

    def exec(self, **kwargs):
        self._order_dp()
        if self.dp:
            for d in self.dp:
                # save all intermediate steps in resolving the sql query
                self.ress.append(d.exec(**kwargs))
            self.dp = [] # clear the pipeline
        return self.res


def eval_type(t):
    "`eval` a type or collection of types, if needed, for annotations in py3.10+"
    if isinstance(t,str):
        if '|' in t: return Union[eval_type(tuple(t.split('|')))]
        return eval(t)
    if isinstance(t,(tuple,list)): return type(t)([eval_type(c) for c in t])
    return t


def union2tuple(t):
    if (getattr(t, '__origin__', None) is Union
        or (UnionType and isinstance(t, UnionType))): return t.__args__
    return t


def patch_to(func, dialect='sql'):
    ann = getattr(func, '__annotations__', None)
    nm = func.__name__
    cls = union2tuple(eval_type(next(iter(ann.values()))))
    if not isinstance(cls, (tuple,list)): cls=(cls,)
    for c_ in cls: 
        dic = getattr(c_, 'dic', None)
        if dic is not None:
            if nm not in dic:
                dic[nm] = {'order': None, 'dialect': {}}
            subdic = dic[nm]['dialect']
            subdic.update({dialect: func})
            dic[nm]['dialect'] = subdic
            c_.dic = dic
        else:
            raise ValueError(f"{c_} does not have class attribute `dic`, and can not be patched by path_method!")
    return func

        
def patch_method(func=None, dialect='sql'):
    if func is None: return partial(patch_method, dialect=dialect)
    return patch_to(func, dialect)

In [7]:
p = DelayedPipeline()
p.ress.append('field')
test_eq(p.as_('f').exec(dialect='sql'), 'field AS f')

In [8]:
@patch_method(dialect='snowflake')
def as_(self: DelayedPipeline, alias):
    self.alias = alias
    return f"{self.res} AS {alias}2"

dic2 = DelayedPipeline.dic
test_eq(dic2['as_']['dialect'].get('snowflake', None), as_)

p = DelayedPipeline()
p.ress.append('field')
test_eq(p.as_('f').exec(dialect='snowflake'), 'field AS f2')

`order_dp` re-orders self.dp by each DelayedFunc's `.order` attribute.
If `.order=None`, they are appended at the end without any re-ordering.

In [9]:
def f(*args, **kwargs):
    return args[0]

a = DelayedFunc(f, ('a',0), {}, order=2)
b = DelayedFunc(f, ('b',0), {}, order=1)
c = DelayedFunc(f, ('c',0), {}, order=None)
d = DelayedFunc(f, ('d',0), {})

p = DelayedPipeline()
p.dp = [a, b, c, d]
p._order_dp()

test_eq(
    [item.args for item in p.dp],
    [('b', 0), ('a', 0), ('c', 0), ('d', 0)]
) 


## Arithmetic Expression

In [10]:
#| export
def _over(self, partition_by):
    return OverClause(self).over(partition_by)


class Field(DelayedPipeline):
    def __init__(self, name=None, window_func=True) -> None: 
        super().__init__()
        self.alias_allowed=True
        if window_func: 
            Field.over=_over
        if name:
            self.alias = name
            self.ress.append(name)
        self.sql = self.res
    
    def __add__(self, other):
        return ArithmeticExpression(self, '+', other)

    def __radd__(self, other):
        return ArithmeticExpression(other, '+', self)

    def __sub__(self, other):
        return ArithmeticExpression(self, '-', other)

    def __rsub__(self, other):
        return ArithmeticExpression(other, '-', self)

    def __mul__(self, other):
        return ArithmeticExpression(self, '*', other)
        
    def __rmul__(self, other):
        return ArithmeticExpression(other, '*', self)

    def __truediv__(self, other):
        return ArithmeticExpression(self, '/', other)

    def __rtruediv__(self, other):
        return ArithmeticExpression(other, '/', self)

    def __gt__(self, other):
        return Criteria(self, '>', other)

    def __ge__(self, other):
        return Criteria(self, '>=', other)
        
    def __eq__(self, other):
        return Criteria(self, '==', other)

    def __ne__(self, other):
        return Criteria(self, '!=', other)

    def __le__(self, other):
        return Criteria(self, '<=', other)

    def __lt__(self, other):
        return Criteria(self, '<', other)


class ArithmeticExpression(Field):
# class ArithmeticExpression:
    add_order = ["+", "-"]

    def __init__(self, this, op, other) -> None:
        super().__init__()
        self.this, self.op, self.other = this, op, other

    def __add__(self, other):
        return ArithmeticExpression(self, '+', other)

    def __radd__(self, other):
        return ArithmeticExpression(other, '+', self)

    def __sub__(self, other):
        return ArithmeticExpression(self, '-', other)

    def __rsub__(self, other):
        return ArithmeticExpression(other, '-', self)

    def __mul__(self, other):
        return ArithmeticExpression(self, '*', other)
        
    def __rmul__(self, other):
        return ArithmeticExpression(other, '*', self)

    def __truediv__(self, other):
        return ArithmeticExpression(self, '/', other)

    def __rtruediv__(self, other):
        return ArithmeticExpression(other, '/', self)
    
    def __gt__(self, other):
        return Criteria(self, '>', other)

    def __ge__(self, other):
        return Criteria(self, '>=', other)
        
    def __eq__(self, other):
        return Criteria(self, '==', other)

    def __ne__(self, other):
        return Criteria(self, '!=', other)

    def __le__(self, other):
        return Criteria(self, '<=', other)

    def __lt__(self, other):
        return Criteria(self, '<', other)

    def left_needs_parens(self, left_op, curr_op) -> bool:
        """
        Returns true if the expression on the left of the current operator needs to be enclosed in parentheses.
        :param current_op:
            The current operator.
        :param left_op:
            The highest level operator of the left expression.
        """
        if left_op is None or curr_op in self.add_order:
            # If the left expression is a single item.
            # or if the current operator is '+' or '-'.
            return False
        
        # The current operator is '*' or '/'. 
        # If the left operator is '+' or '-', we need to add parentheses:
        # e.g. (A + B) / ..., (A - B) / ...
        # Otherwise, no parentheses are necessary:
        # e.g. A * B / ..., A / B / ...
        return left_op in self.add_order

    def right_needs_parens(self, curr_op, right_op) -> bool:
        """
        Returns true if the expression on the right of the current operator needs to be enclosed in parentheses.
        :param current_op:
            The current operator.
        :param right_op:
            The highest level operator of the right expression.
        """
        if right_op is None:
            # If the right expression is a single item.
            return False
        # If the right operator is '+' or '-', we always add parentheses:
        # e.g. ... - (A + B), ... - (A - B), ... + (A + B)
        # Otherwise, no parentheses are necessary:
        # e.g. ... - A / B, ... - A * B
        return right_op in self.add_order

    def exec(self, **kwargs):
        if self.this.__class__ is ArithmeticExpression:
            this = self.this.exec(**kwargs)
            this = f"({this})" if self.left_needs_parens(self.this.op, self.op) else this
        elif getattr(self.this, 'exec', None):
            this = self.this.exec(**kwargs)
        else:
            this = str(self.this)

        if self.other.__class__ is ArithmeticExpression:
            other = self.other.exec(**kwargs)
            other = f"({other})" if self.right_needs_parens(self.op, self.other.op) else other
        elif getattr(self.other, 'exec', None):
            other = self.other.exec(**kwargs)
        else:
            other = str(self.other)
        return f"{this} {self.op} {other}"


class Criteria(Field):
    compose_ops = ('and', 'or')

    def __init__(self, this, op, other) -> None:
        super().__init__()
        self.this, self.op, self.other = this, op, other
        self.add_parentheses = False

    def compose_criteria(self, op, other):
        if other.__class__ is Criteria:
            if self.op in self.compose_ops and op in self.compose_ops and self.op != op:
                self.add_parentheses = True
            if other.op in self.compose_ops and op in self.compose_ops and other.op != op:
                other.add_parentheses = True
        return Criteria(self, op, other)

    @staticmethod
    def resolve(obj, **kwargs):
        if obj.__class__ is Criteria:
            obj_c = obj.exec(**kwargs)
            obj_c = f"({obj_c})" if obj.add_parentheses == True else obj_c
        elif getattr(obj, 'exec', None):
            obj_c = obj.exec(**kwargs)
        else:
            obj_c = str(obj)
        return obj_c
    
    def exec(self, **kwargs):
        this = self.resolve(self.this, **kwargs)
        other = self.resolve(self.other, **kwargs)
        return f"{this} {self.op} {other}"

    def __and__(self, __o):
        return self.compose_criteria('and', __o)

    def __or__(self, __o):
        return self.compose_criteria('or', __o)


class OverClause:
    def __init__(self, expression) -> None:
        if not (isinstance(expression, Field) or type(expression) is str):
            raise ValueError(f"Expression has to be of the Field type (including ArithmeticExpression)!!")
        self.expr = expression
        self.alias = None
        self.rows_flag = False
        self.range_flag = False
        self.d = {}

    def over(self, q):
        self.d['PARTITION BY'] = q
        return self

    def orderby(self, q):
        self.d['ORDER BY'] = q
        return self

    def _check_rows_or_range(self):
        if self.rows_flag==True: 
            raise ValueError(f"ROWS already set!")
        if self.range_flag==True:
            raise ValueError(f"RANGE already set!")

    def rows(self, start, end):
        self._check_rows_or_range()
        self.rows = True
        self.d['ROWS'] = (start, end)
        return self

    def range(self, start, end):
        self._check_rows_or_range()
        self.range = True
        self.d['RANGE'] = (start, end)
        return self

    def as_(self, alias):
        self.alias = alias
        return self

    def _resolve_over_statement(self, **kwargs):
        sql = []
        for k in ['PARTITION BY', 'ORDER BY', 'ROWS', 'RANGE']:
            if k in self.d:
                if k in ['PARTITION BY', 'ORDER BY']:
                    rslvd = f"{k} {exec(self.d[k], **kwargs)}"
                else:
                    start, end = self.d[k]
                    rslvd = f"{k} BETWEEN {exec(start, **kwargs)} AND {exec(end, **kwargs)}"
                sql.append(rslvd)
        return f"{exec(self.expr, **kwargs)} OVER ({' '.join(sql)})"

    def exec(self, **kwargs):
        sql = self._resolve_over_statement(**kwargs)
        if self.alias:
            return f"{sql} AS {self.alias}"
        else:
            return sql


class Preceding:
    def __init__(self, N=None) -> None:
        self.N = N

    def exec(self, **kwargs):
        if self.N:
            return f"{self.N} PRECEDING"
        else:
            return "UNBOUNDED PRECEDING"


class Following:
    def __init__(self, N=None) -> None:
        self.N = N

    def exec(self, **kwargs):
        if self.N:
            return f"{self.N} FOLLOWING"
        else:
            return "UNBOUNDED FOLLOWING"


CURRENT_ROW = "CURRENT_ROW"



In [11]:
a = Field('a')
b = Field('b')
aa = (a + 1>2) & ((b-1<=10) | (b>100))
aa.exec()
(a+1<13).exec()

'a + 1 > 2 and (b - 1 <= 10 or b > 100)'

'a + 1 < 13'

In [12]:
a = Field('a')
b = Field('b')
test_eq(((a + 1)/3).exec(), '(a + 1) / 3')
test_eq(((a + 1)/(b - 4)).exec(), '(a + 1) / (b - 4)')
test_eq(((a + 1>2) & ((b-1<10) | (b>23)) ).exec(), 
        'a + 1 > 2 and (b - 1 < 10 or b > 23)')

In [13]:
test_eq(
    OverClause('SUM(col1)').over('col0').orderby('col2').rows(Preceding(3), CURRENT_ROW).exec(), 
    'SUM(col1) OVER (PARTITION BY col0 ORDER BY col2 ROWS BETWEEN 3 PRECEDING AND CURRENT_ROW)')

test_eq(
    ((Field('col2')+2)/10).over('col1').orderby('col3').range(Preceding(), Following(2)).exec(),
    '(col2 + 2) / 10 OVER (PARTITION BY col1 ORDER BY col3 RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)'
)

## Table

In [14]:
class Table:
    def __init__(self, name) -> None:
        store_attr()
        self.alias = name
        # self.name = self.alias = name

    def as_(self, alias):
        self.alias = alias
        return self

    def __getattr__(self, __name: str):
        if __name.startswith('__'):
            raise AttributeError
            # return super().__getattr__(__name)
        else:
            return Field(f"{self.alias}.{__name}")
    
    def exec(self, **kwargs):
        if self.alias != self.name:
            return f"{self.name} as {self.alias}"
        else:
            return self.name

In [15]:
vw = Table('vw')
vw.as_('a').exec()
(vw.column + 2 > 1).exec()

'vw as a'

'a.column + 2 > 1'

## Functions

In [16]:
def _kwargs_func(func, *args, **kwargs):
    "Allow arbitrary kwargs. Only pass those kwargs that are specified in func to func."
    sig = inspect.signature(func)
    param = sig.parameters
    func_kwargs = {k:v for k, v in param.items() if v.default!=inspect._empty}
    kwargs = {k:v for k, v in kwargs.items() if k in func_kwargs}
    return func(*args, **kwargs)


def custom_func(func=None, window_func=False, dialect=None):
    """return Field"""
    if func is None: 
        return partial(custom_func, window_func=window_func, dialect=dialect)
    else:
        if dialect is None:
            def wrapper(*args, **kwargs):
                dlf = DelayedFunc(func, args, kwargs)
                f = Field(window_func=window_func)
                f.dp.append(dlf)
                return f
        else:
            def wrapper(*args, **kwargs):
                # get the function
                func_name = func.__name__
                ori_func = globals()[func_name] 
                func

                # make new function
                def new_func(*args, **kwargs):
                    if kwargs['dialect'] == dialect:
                        # return func(*args, **kwargs)
                        return _kwargs_func(func, *args, **kwargs)
                    else:
                        func_o = ori_func.dp[0].func
                        return _kwargs_func(func_o, *args, **kwargs)

                # make new delayed function
                dlf = DelayedFunc(new_func, args, kwargs)

                # append the delayed function to Field.dp
                f = Field(window_func=window_func)
                f.dp.append(dlf)
                return f
        return wrapper


class CustomFunction(Field):
    def __init__(self, func_name, args) -> None:
        super().__init__()
        self.func_name = func_name
        self.args = args
    
    def __call__(self, *args):
        if len(args) != len(self.args):
            raise ValueError(f"The number of args provided {len(args)} is not the same as the number of args expected by this function ({len(self.args)})!")
        def func(*args):
            return f"{self.func_name}({', '.join(args)})"

        dlf = DelayedFunc(func, args, {})
        self.dp.append(dlf)
        return self


@custom_func
def add_months(column, num, dialect='sql'):
    if dialect=='sql':
        return f'DATE_ADD(month, {num}, {column})'
    elif dialect=='snowflake':
        return f'MONTH_ADD({column}, {num})'

In [17]:
date_diff = CustomFunction('DATE_DIFF', ['interval', 'start_date', 'end_date'])
test_eq(date_diff('month', Field('date1'), Field('date2')).exec(), 'DATE_DIFF(month, date1, date2)')

In [18]:
test_eq((add_months("col1", 3)-2 > 2).exec(), 'DATE_ADD(month, 3, col1) - 2 > 2')
test_eq((add_months("col1", 3)-2 > 2).exec(dialect='snowflake'), 'MONTH_ADD(col1, 3) - 2 > 2')

In [19]:
def test_func(s, se, ksew=2):
    return 0

import inspect
sig = inspect.signature(test_func)
param = sig.parameters
{k:v for k, v in param.items() if v.default!=inspect._empty}


{'ksew': <Parameter "ksew=2">}

In [20]:
@custom_func(dialect='athena')
def add_months(column, num):
    return f"DATE_ADD('month', {num}, {column})"

(add_months("col1", 5)).exec(dialect='athena')

"DATE_ADD('month', 5, col1)"

## Case

In [21]:
class Case:
    def __init__(self) -> None:
        self.dp = []
        self.alias = None

    def check_prev(self, statement):
        if self.dp:
            prev = self.dp[0][0]
            if prev == statement:
                return True
        return False

    def when(self, q, then):
        if self.check_prev('ELSE'):
            raise ValueError(f"'WHEN' can not follow 'ELSE'!")
        self.dp.append(('WHEN', q, then))
        return self

    def else_(self, q):
        self.dp.append(('ELSE', q))
        return self

    def _as(self, alias):
        self.alias = alias
        return self

    def exec(self, **kwargs):
        sql = ["CASE"]
        for item in self.dp:
            if item[0] == 'WHEN':
                q_resolved = f"WHEN {exec(item[1], **kwargs)} THEN {exec(item[2], **kwargs)}"
            else:
                q_resolved = f"ELSE {exec(item[1], **kwargs)}"
            sql.append(q_resolved)
        if self.alias:
            sql.append(f"END AS {self.alias}")
        else:
            sql.append("END")
        return '\n'.join(sql)


In [22]:
a = Table('tbl')
test_eq(Case().when(a.column1>3, True).else_(False).exec(), 
        'CASE\nWHEN tbl.column1 > 3 THEN True\nELSE False\nEND')
test_eq(Case().when(a.column1>3, 1).when(a.column1<1, -1).else_(0).exec(), 
        'CASE\nWHEN tbl.column1 > 3 THEN 1\nWHEN tbl.column1 < 1 THEN -1\nELSE 0\nEND')

## Query

In [23]:
class Query:
    keys_simple = ['from', 'select', 'group by', 'order by', 'where', 'having', 'join', 'on', 'limit']
    keys_parse = ['with', 'from', 'join', 'on', 'where', 'group by', 'having', 'order by', 'select', 'limit']
    keys_sql = ['with', 'select', 'from', 'join', 'on', 'where', 'group by', 'having', 'order by', 'limit']
    key_translate = {
        'groupby': 'group by',
        'orderby': 'order by'
    }

    def __init__(self) -> None:
        super().__init__()
        self.dic = {}

    @staticmethod
    def _resolve(q, **kwargs):
        if getattr(q, 'exec', None):
            return q.exec(**kwargs)
        else:
            return str(q)

    def __getattr__(self, __name):
        __name = __name.strip('_')

        if __name.startswith('exec'):
            # for exec_{key} methods
            key = __name.split('_')[-1]

            def inner(**kwargs):
                if key in self.keys_simple:
                    q = self.dic[key]
                    if q.__class__ is Query:
                        return f"{self.key_translate.get(key, key)} ({self._resolve(q, **kwargs)})"
                    else:
                        return f"{self.key_translate.get(key, key)} {self._resolve(q, **kwargs)}"
                elif key == 'with':
                    s = self.dic[key]
                    qq = [f"{a} as ({q.exec(**kwargs)})" for q, a in s]
                    return f"{key} " + ",\n".join(qq)
                else:
                    raise AttributeError
        elif __name in self.keys_simple:
            def inner(query):
                self.dic[__name] = query
                return self
        elif __name in ['with']:
            def inner(query, alias):
                l = self.dic.get('with', [])
                l.append((query, alias))
                return self
        else:
            raise AttributeError


        return inner


    def exec(self, **kwargs):
        dic_sql = {}
        # keys = ['with', 'from']
        # keys_sql = ['with', 'from']

        for key in self.keys_parse:
            if key in self.dic:
                dic_sql[key] = getattr(self, f'exec_{key}')(**kwargs)

        sql = '\n'.join([dic_sql[key] for key in self.keys_sql if key in dic_sql])
        return sql
    
    def union(self, query):
        if not isinstance(query, Query):
            raise TypeError(f"{query} needs to be of Query type, but it is {type(query)} instead!")
        return UnionedQuery(self, query, union_type='UNION')

    def __add__(self, query):
        return self.union(query)

    def union_all(self, query):
        if not isinstance(query, Query):
            raise TypeError(f"{query} needs to be of Query type, but it is {type(query)} instead!")
        return UnionedQuery(self, query, union_type='UNION ALL')

    def __mul__(self, query):
        return self.union_all(query)


class UnionedQuery:
    def __init__(self, q1, q2, union_type='UNION') -> None:
        store_attr()

    def exec(self, **kwargs):
        return f"{exec(self.q1, **kwargs)} {self.union_type} {exec(self.q2, **kwargs)}"

In [24]:
q = Query().from_(Table('tbl').as_('a')).select('col1').where((Field('col2')-100>2) & (Field('col3')/9<=1)).limit(100)
test_eq(q.exec(), 'select col1\nfrom tbl as a\nwhere col2 - 100 > 2 and col3 / 9 <= 1\nlimit 100')

q = Query().from_(Table('tbl').as_('b')).select('*').union(Query().from_(Table('tbl2')).select('*').where(Field('col1')*23/12>=4))
test_eq(q.exec(), 'select *\nfrom tbl as b UNION select *\nfrom tbl2\nwhere col1 * 23 / 12 >= 4')

#TODO: 
- table
- query
    - UNION
    - WHERE
    - LIMIT
- case
- analytics

In [25]:
class QueryMeta(type):
    keys_simple = ['from', 'select', 'group by', 'order by', 'where', 'having', 'join', 'on', 'limit']
    # keys in the order to be parsed
    keys_parse = ['with', 'from', 'join', 'on', 'where', 'group by', 'having', 'order by', 'select', 'limit']
    # keys in the order of the final sql statement
    keys_sql = ['with', 'select', 'from', 'join', 'on', 'where', 'group by', 'having', 'order by', 'limit']
    key_translate = {
        'groupby': 'group by',
        'orderby': 'order by'
    }
    dic = {}

    @staticmethod
    def _resolve(q, **kwargs):
        if getattr(q, 'exec', None):
            return q.exec(**kwargs)
        else:
            return str(q)

    def __getattr__(self, __name):
        __name = __name.strip('_')

        if __name.startswith('exec'):
            # for exec_{key} methods
            key = __name.split('_')[-1]

            def inner(**kwargs):
                if key == 'with':
                    s = self.dic[key]
                    qq = [f"{a} as ({q.exec(**kwargs)})" for q, a in s]
                    return f"{key} " + ",\n".join(qq)
                elif key in self.keys_parse:
                    q = self.dic[key]
                    if q.__class__ is Query:
                        return f"{self.key_translate.get(key, key)} ({self._resolve(q, **kwargs)})"
                    else:
                        return f"{self.key_translate.get(key, key)} {self._resolve(q, **kwargs)}"
                else:
                    raise AttributeError
        elif __name in ['with']:
            def inner(query, alias):
                "Append query and alias to dic['with']"
                l = self.dic.get('with', [])
                l.append((query, alias))
                self.dic['with'] = l
                return self
        elif __name in self.keys_parse:
            def inner(query):
                self.dic[__name] = query
                return self
        else:
            raise AttributeError

        return inner

    def exec(self, **kwargs):
        dic_sql = {}
        # keys = ['with', 'from']
        # keys_sql = ['with', 'from']

        for key in self.keys_parse:
            if key in self.dic:
                dic_sql[key] = getattr(self, f'exec_{key}')(**kwargs)

        sql = '\n'.join([dic_sql[key] for key in self.keys_sql if key in dic_sql])
        return sql
    
    def union(self, query):
        if not isinstance(query, QueryMeta):
            raise TypeError(f"{query} needs to be of Query type, but it is {type(query)} instead!")
        return UnionedQuery(self, query, union_type='UNION')

    def __add__(self, query):
        return self.union(query)

    def union_all(self, query):
        if not isinstance(query, QueryMeta):
            raise TypeError(f"{query} needs to be of Query type, but it is {type(query)} instead!")
        return UnionedQuery(self, query, union_type='UNION ALL')

    def __mul__(self, query):
        return self.union_all(query)


class Query(metaclass=QueryMeta):
    pass

In [34]:
# .union(Query.from_(Table('tbl2')).select('*').where(Field('col1')*23/12>=4))
q = Query.from_(Table('tblslsfei').as_('b'))
q.exec()

'select *\nfrom tblslsfei as b\nwhere col1 * 23 / 12 >= 4\nlimit 100'

In [33]:
Query.dic

{'from': <__main__.Table at 0x7fe3008adac0>,
 'select': '*',
 'where': <__main__.Criteria at 0x7fe2f87b7970>,
 'limit': 100}

In [26]:
q = Query.from_(Table('tbl').as_('a')).select('col1').where((Field('col2')-100>2) & (Field('col3')/9<=1)).limit(100)
test_eq(q.exec(), 'select col1\nfrom tbl as a\nwhere col2 - 100 > 2 and col3 / 9 <= 1\nlimit 100')

q = Query.from_(Table('tbl').as_('b')).select('*').union(Query.from_(Table('tbl2')).select('*').where(Field('col1')*23/12>=4))
test_eq(q.exec(), 'select *\nfrom tbl as b UNION select *\nfrom tbl2\nwhere col1 * 23 / 12 >= 4\nlimit 100')

AssertionError: ==:
select *
from tbl2
where col1 * 23 / 12 >= 4
limit 100 UNION select *
from tbl2
where col1 * 23 / 12 >= 4
limit 100
select *
from tbl as b UNION select *
from tbl2
where col1 * 23 / 12 >= 4