# Queries

> Compose the final SQL query.


In [None]:
#| default_exp queries
#| hide
from nbdev.showdoc import *
from fastcore.test import *
# allow multiple outputs from one cell
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
%reload_ext autoreload
%autoreload 2

In [None]:
#|export
from pikaQ.utils import execute
from pikaQ.terms import Field

## Table

In [None]:
#| export
class Table:
    """A (possibly schema-qualified) table name with an optional alias.

    Any non-dunder attribute that is not found on the instance is turned into
    a `Field` named ``<alias>.<attribute>``. `execute` (also reachable as
    `get_sql`) renders the table reference as SQL text.
    """
    def __init__(self, name) -> None:
        # until `as_` is called the alias is simply the table name
        self.name = self.alias = name
        self.get_sql = self.execute

    def as_(self, alias):
        """Set the alias in place and return this same table for chaining."""
        self.alias = alias
        return self

    def __getattr__(self, __name: str):
        # `__getattr__` only runs when normal lookup fails. Dunder probes
        # (copy, pickle, hasattr checks) must not be turned into fields, and
        # a missing `name`/`alias` means the instance was never initialised:
        # building a field would read `self.alias` and recurse forever.
        if __name.startswith('__') or __name in ('name', 'alias'):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {__name!r}"
            )
        # if an alias is assigned, always use it as the table name when accessing a field
        return Field(f"{self.alias}.{__name}")

    def quoted_name(self, quote_char):
        """Return the table name with each dot-separated part wrapped in `quote_char`."""
        return '.'.join(
            f"{quote_char}{part}{quote_char}" for part in self.name.split('.')
        )

    def execute(self, **kwargs):
        """Render the table reference.

        Reads the optional `quote_char` keyword; a missing or falsy value
        means no quoting. The ``as <alias>`` suffix is added only when the
        alias differs from the table name.
        """
        q = kwargs.get('quote_char', '') or ''
        quoted_name = self.quoted_name(q)
        if self.alias == self.name:
            return quoted_name
        return f"{quoted_name} as {q}{self.alias}{q}"

In [None]:
# `as_` sets the alias on the table object itself, so the later `vw.column`
# lookup is rendered with the alias 'a'
vw = Table('vw')
test_eq(vw.as_('a').get_sql(), 'vw as a')
test_eq((vw.column + 2 > 1).get_sql(), 'a.column + 2 > 1')

# dotted names: every dot-separated part is quoted on its own, and the alias
# (once set) is used both in the `as` suffix and as the field prefix
tbl = Table('PROD_BI.PRES.tbl')
test_eq(tbl.get_sql(), 'PROD_BI.PRES.tbl')
test_eq(tbl.as_('t').get_sql(), 'PROD_BI.PRES.tbl as t')
test_eq(tbl.as_('t').get_sql(quote_char='"'), '"PROD_BI"."PRES"."tbl" as "t"')
test_eq((tbl.column + 2 > 1).get_sql(), 't.column + 2 > 1')
test_eq((tbl.column + 2 > 1).get_sql(quote_char='"'), '"t"."column" + 2 > 1')

## Query

In [None]:
#| export
class Query:
    """Entry point for building a query: start with a `from` or a `with` clause."""

    @classmethod
    def from_(cls, query):
        """Return a fresh SelectQuery whose `from` clause is `query`."""
        # SelectQuery builds its `from_` setter dynamically; it stores the
        # clause and returns the query itself.
        return SelectQuery().from_(query)

    @classmethod
    def with_(cls, query, alias):
        """Return a fresh SelectQuery holding one `with` entry `(query, alias)`."""
        # On an empty query the dynamic `with_` setter yields the
        # single-element list [(query, alias)].
        return SelectQuery().with_(query, alias)


class QueryBase:
    """Marker base class: `isinstance(x, QueryBase)` identifies query objects."""


class Joiner(QueryBase):
    """Join clause has to be followed by an on clause"""
    def __init__(self, select_query, query, how=None) -> None:
        self.select_query, self.query, self.how = select_query, query, how

    def on(self, condition):
        """Record `(query, how, condition)` on the owning query and return that query."""
        # create the join list on first use, then extend it in place
        joins = self.select_query.dic.setdefault('join', [])
        joins.append((self.query, self.how, condition))
        return self.select_query


class Selector(QueryBase):
    """Select clause could be followed by a distinct clause

    Creating a Selector stores the selected columns on the wrapped query.
    Attributes not defined here are looked up on that query.
    """
    def __init__(self, select_query, *args) -> None:
        self.select_query = select_query
        self.select_query.dic['select'] = list(args)

    def __getattr__(self, __name: str):
        # `__getattr__` only runs when normal lookup fails. If `select_query`
        # itself is missing (e.g. on the blank instance that copy/pickle
        # create before restoring state), delegating would re-enter this
        # method without end, so fail with a regular AttributeError instead.
        if __name == 'select_query':
            raise AttributeError(__name)
        # selector inherits all methods from the SelectQuery object
        # so that the distinct method is optional
        return getattr(self.select_query, __name)

    def distinct(self):
        """Mark the wrapped query as `select distinct` and return it."""
        self.select_query.dic['distinct'] = True
        return self.select_query

    # Operators are looked up on the type, not through `__getattr__`, so the
    # union operators have to be forwarded explicitly.
    def __add__(self, query):
        return self.select_query.union(query)

    def __mul__(self, query):
        return self.select_query.union_all(query)


class SelectQuery(QueryBase):
    """Builder for one select statement.

    Every clause is stored in `self.dic` under its clause name; `execute`
    (also reachable as `get_sql`) renders the stored clauses as SQL text.

    Clause setters are created on demand by `__getattr__`: the requested
    attribute name is cut at its first underscore and lower-cased, so
    `from_`, `where`, `groupby`, `orderby`, `having`, `limit` and `with_`
    are all available, and each returns the query itself for chaining.
    """
    keys_simple = ['from', 'groupby', 'orderby', 'where', 'having', 'limit']
    # the order to parse different parts of the sql query
    keys_parse = ['with', 'from', 'join', 'where', 'groupby', 'having', 'orderby', 'select', 'limit']
    # the order to put together the final sql query
    keys_sql = ['with', 'select', 'from', 'join', 'where', 'groupby', 'having', 'orderby', 'limit']
    key_translation = {
        'groupby': 'group by',
        'orderby': 'order by'
    }

    def __init__(self) -> None:
        self.dic = {}
        self.get_sql = self.execute

    def join(self, query, how=None):
        """Start a join; the returned Joiner must be completed with `on`."""
        return Joiner(self, query, how)

    def select(self, *args):
        """Set the selected columns and return a Selector wrapping this query."""
        return Selector(self, *args)

    def parse(self, key, **kwargs):
        """Render the clause stored under `key`.

        `kwargs` are forwarded unchanged to `execute` for every term.

        Raises:
            KeyError: `key` is a known clause that has not been set.
            AttributeError: `key` is not a known clause name.
        """
        if key in self.keys_simple:
            return self._parse_simple(key, **kwargs)
        if key == 'with':
            return self._parse_with(**kwargs)
        if key == 'join':
            return self._parse_join(**kwargs)
        if key == 'select':
            return self._parse_select(**kwargs)
        raise AttributeError(f"{key} is not a valid attribute!")

    def _parse_simple(self, key, **kwargs):
        # single-term clause; a nested query is wrapped in parentheses
        term = self.dic[key]
        rendered = execute(term, **kwargs)
        if isinstance(term, QueryBase):
            rendered = f"({rendered})"
        return f"{self.key_translation.get(key, key)} {rendered}"

    def _parse_with(self, **kwargs):
        # one "<alias> as (<query>)" entry per registered `with` query
        entries = [
            f"{alias} as ({execute(query, **kwargs)})"
            for query, alias in self.dic['with']
        ]
        return "with " + ",\n".join(entries)

    def _parse_join(self, **kwargs):
        # one line per join; `how` (if given) is placed before the keyword
        lines = []
        for query, how, condition in self.dic['join']:
            prefix = how + ' ' if how else ''
            target = execute(query, **kwargs)
            if isinstance(query, QueryBase):
                target = f"({target})"
            lines.append(f"{prefix}join {target} on {execute(condition, **kwargs)}")
        return '\n'.join(lines)

    def _parse_select(self, **kwargs):
        # comma-separated column list, optionally `select distinct`
        columns = ', '.join(execute(arg, **kwargs) for arg in self.dic['select'])
        if self.dic.get('distinct'):
            return "select distinct " + columns
        return "select " + columns

    def __getattr__(self, __name):
        # Construct the corresponding method functions
        clause = __name.split('_')[0].lower()

        if clause in self.keys_simple:
            def inner(query):
                self.dic[clause] = query
                return self
        elif clause == 'with':
            def inner(query, alias):
                # repeated `with_` calls accumulate entries in call order
                self.dic.setdefault('with', []).append((query, alias))
                return self
        else:
            # report the name that was actually requested, not its
            # truncated/lower-cased form
            raise AttributeError(f"{__name} is not a valid attribute!")

        return inner

    def execute(self, **kwargs):
        """Render the whole query, one clause per line.

        Clauses are parsed in `keys_parse` order and emitted in `keys_sql`
        order; clauses that were never set are skipped.
        """
        parsed = {
            key: self.parse(key, **kwargs)
            for key in self.keys_parse
            if key in self.dic
        }
        return '\n'.join(parsed[key] for key in self.keys_sql if key in parsed)

    @staticmethod
    def _check_union_operand(query):
        # a UnionQuery is accepted too, so union results can be combined further
        if not isinstance(query, (QueryBase, UnionQuery)):
            raise TypeError(f"{query} is not an instance of QueryBase!")

    def union(self, query):
        """Return a UnionQuery combining this query and `query` with UNION.

        Raises:
            TypeError: `query` is neither a QueryBase nor a UnionQuery.
        """
        self._check_union_operand(query)
        return UnionQuery(self, query, union_type='UNION')

    def __add__(self, query):
        return self.union(query)

    def union_all(self, query):
        """Return a UnionQuery combining this query and `query` with UNION ALL.

        Raises:
            TypeError: `query` is neither a QueryBase nor a UnionQuery.
        """
        self._check_union_operand(query)
        return UnionQuery(self, query, union_type='UNION ALL')

    def __mul__(self, query):
        return self.union_all(query)


class UnionQuery:
    """Two queries joined by a set operator (`UNION` by default).

    `execute` (also reachable as `get_sql`) renders both operands and puts
    `union_type` between them. A UnionQuery can itself be combined further
    with `union` / `union_all` or the `+` / `*` operators.
    """
    def __init__(self, q1, q2, union_type='UNION') -> None:
        self.q1 = q1
        self.q2 = q2
        self.union_type = union_type
        # same alias the other renderable classes in this module expose
        self.get_sql = self.execute

    def execute(self, **kwargs):
        """Render ``<q1> <union_type> <q2>``; `kwargs` are forwarded to both operands."""
        return f"{execute(self.q1, **kwargs)} {self.union_type} {execute(self.q2, **kwargs)}"

    def _combine(self, query, union_type):
        # same operand check as SelectQuery.union, additionally accepting
        # another UnionQuery
        if not (isinstance(query, UnionQuery) or isinstance(query, QueryBase)):
            raise TypeError(f"{query} is not an instance of QueryBase!")
        return UnionQuery(self, query, union_type=union_type)

    def union(self, query):
        """Return a new UnionQuery: this union followed by `query`, joined with UNION."""
        return self._combine(query, 'UNION')

    def __add__(self, query):
        return self.union(query)

    def union_all(self, query):
        """Return a new UnionQuery: this union followed by `query`, joined with UNION ALL."""
        return self._combine(query, 'UNION ALL')

    def __mul__(self, query):
        return self.union_all(query)
    

In [None]:
# basic query: clauses are emitted in SQL order regardless of call order
q0 = (Query
      .from_(Table('tbl').as_('a'))
      .select('col1')
      .where((Field('col2')-100>2) & (Field('col3')/9<=1))
      .limit(100)
)
test_eq(q0.get_sql(), 
"""select col1
from tbl as a
where col2 - 100 > 2 and col3 / 9 <= 1
limit 100""")

# two `with` entries (the second one reuses q0) plus a join with an `on` condition
qj = (Query
     .with_(Query.from_(Table('tbl').as_('a')).select('col1'), 's')     
     .with_(q0, 'm')
     .from_('s')
     .join('m').on('s.col1=m.col1')
     .where(Field('col1')>=10)
     .select('s.col1', 'm.col2', 'm.col3')
)
test_eq(qj.get_sql(), 
"""with s as (select col1
from tbl as a),
m as (select col1
from tbl as a
where col2 - 100 > 2 and col3 / 9 <= 1
limit 100)
select s.col1, m.col2, m.col3
from s
join m on s.col1=m.col1
where col1 >= 10""")

# `distinct()` after `select(...)` turns the clause into `select distinct`
qd = (Query
     .from_('s')
     .where(Field('col1')>=10)
     .select('col1', 'col2', 'col3').distinct()
)
test_eq(qd.get_sql(),
"""select distinct col1, col2, col3
from s
where col1 >= 10""")

In [None]:
#|export
class Exists:
    """Exists statement

    Wraps a query and renders it as ``EXISTS (<query sql>)``.
    """
    def __init__(self, query) -> None:
        super().__init__()
        self.query = query
        self.get_sql = self.execute

    def execute(self, **kwargs):
        """Render the wrapped query inside ``EXISTS (...)``; `kwargs` are forwarded to it."""
        # Prefer `get_sql`, but fall back to `execute` so that objects which
        # only expose `execute` can be wrapped as well.
        render = getattr(self.query, 'get_sql', None)
        if render is None:
            render = self.query.execute
        return f"EXISTS ({render(**kwargs)})"



In [None]:
#|hide
# run nbdev's export step (presumably writes the `#| export` cells to the
# module named by `default_exp` - see the nbdev documentation)
import nbdev; nbdev.nbdev_export()