# Queries

> Compose the final SQL query.


In [None]:
#| default_exp queries
#| hide
from nbdev.showdoc import *
from fastcore.test import test_eq
from fastcore.foundation import patch
# allow multiple output from one cell
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
%reload_ext autoreload
%autoreload 2

In [None]:
#|export
from pikaQ.utils import execute
from pikaQ.terms import Field
import pikaQ.functions as fn

In [None]:
#| export
class Table:
    """A table (or aliased sub-query) reference to be used as a Table/AliasedQuery in a query.

    Besides the explicit methods below, any other attribute access is treated as a
    column of the table and returns a `Field` qualified with the current alias.
    Columns whose names clash with real attributes or methods (`name`, `alias`,
    `star`, ...) must be fetched with `field(...)` instead.
    """
    def __init__(self, name) -> None:
        # until `as_` is called the alias is simply the table name
        self.name = self.alias = name

    def as_(self, alias):
        """Set the alias *in place* and return the same table object."""
        self.alias = alias
        return self

    def star(self):
        """Return the `<alias>.*` field."""
        return Field(f"{self.alias}.*")

    def __getattr__(self, __name: str):
        # Only called when normal lookup fails. `name`/`alias` can only be missing on a
        # half-initialised instance; reading `self.alias` below would then recurse without end.
        if __name.startswith('__') or __name in ('name', 'alias'):
            raise AttributeError(__name)
        # if an alias is assigned, always use it as the table name when accessing a field
        return Field(f"{self.alias}.{__name}")

    def field(self, name):
        """Return the column `name` qualified with the current alias."""
        return Field(f"{self.alias}.{name}")

    def quoted_name(self, quote_char):
        """Return the table name with every dot-separated part wrapped in `quote_char`."""
        return '.'.join(f"{quote_char}{part}{quote_char}" for part in self.name.split('.'))

    def execute(self, **kwargs):
        """Render the table reference, appending `as <alias>` when the alias differs from the name.

        The only keyword argument read here is `quote_char`; a missing or falsy value means no quoting.
        """
        q = kwargs.get('quote_char', '') or ''
        quoted_name = self.quoted_name(q)
        if self.alias != self.name:
            return f"{quoted_name} as {q}{self.alias}{q}"
        return quoted_name

    def get_sql(self, **kwargs):
        """Alias of `execute`.

        Defined as a method rather than bound in `__init__`, so that copies of a table
        render themselves instead of the instance they were copied from.
        """
        return self.execute(**kwargs)


AliasedQuery = Table

In [None]:
# `as_` sets the alias on the table object itself and returns that same object,
# so every field accessed on `vw` afterwards is qualified with the alias `a`.
vw = Table('vw')
test_eq(vw.as_('a').get_sql(), 'vw as a')
test_eq((vw.column + 2 > 1).get_sql(), 'a.column + 2 > 1')
test_eq(vw.star().get_sql(), 'a.*')

# A dotted name is quoted part by part; the alias is quoted as a single identifier.
tbl = Table('PROD_BI.PRES.tbl')
test_eq(tbl.get_sql(), 'PROD_BI.PRES.tbl')
test_eq(tbl.as_('t').get_sql(), 'PROD_BI.PRES.tbl as t')
test_eq(tbl.as_('t').get_sql(quote_char='"'), '"PROD_BI"."PRES"."tbl" as "t"')
# `tbl` keeps the alias `t` set above, so its fields render as `t.<column>`.
test_eq((tbl.column + 2 > 1).get_sql(), 't.column + 2 > 1')
test_eq((tbl.column + 2 > 1).get_sql(quote_char='"'), '"t"."column" + 2 > 1')

In [None]:
#| export
class QueryBase:
    """Marker base class inherited by all query classes.

    It adds no behaviour; it exists so that every query object can be recognised
    with a single `isinstance(obj, QueryBase)` check.
    """


class Joiner(QueryBase):
    """Pending join returned by `join`; it is only recorded once `on` supplies the condition."""
    def __init__(self, select_query, query, how=None) -> None:
        self.select_query, self.query, self.how = select_query, query, how

    def on(self, condition):
        """Append `(query, how, condition)` to the owner's 'join' list and return the owner."""
        owner = self.select_query
        entry = (self.query, self.how, condition)
        # create the list on the first join, extend it on the following ones
        owner.dic.setdefault('join', []).append(entry)
        return owner


class Selector(QueryBase):
    """Select clause, optionally followed by a distinct clause.

    Creating a Selector stores the selected columns on the wrapped select query and
    installs `parse_select` on it. Every attribute the Selector does not define itself
    is forwarded to the wrapped select query, so calling `distinct()` is optional.
    """
    def __init__(self, select_query, *args) -> None:
        self.select_query = select_query
        self.select_query.dic['select'] = {"columns": list(args)}
        # add parse_select method to the select_query object
        self.select_query.parse_select = self.parse_select

    def __getattr__(self, __name: str):
        # selector inherits all methods from the SelectQuery object
        # so that the distinct method is optional
        # executing some of these methods can get back to the select_query object
        if __name == 'select_query':
            # Only reachable on a half-initialised instance (e.g. while being copied);
            # forwarding would read `self.select_query` again and recurse without end.
            raise AttributeError(__name)
        return getattr(self.select_query, __name)

    # Operators are looked up on the type and never go through `__getattr__`,
    # so they have to be forwarded explicitly for `selector + other` to work.
    def __add__(self, query):
        """Forward `self + query` to the wrapped select query."""
        return self.select_query + query

    def __mul__(self, query):
        """Forward `self * query` to the wrapped select query."""
        return self.select_query * query

    def parse_select(self, **kwargs):
        """Render the select clause from the columns stored on the wrapped query."""
        key = 'select'
        dic = self.dic[key]  # `dic` is forwarded to the wrapped select query
        columns = [execute(column, **kwargs) for column in dic['columns']]
        prefix = f"{key} distinct " if dic.get('distinct') else f"{key} "
        return prefix + ', '.join(columns)

    def distinct(self):
        """Mark the select clause as distinct and return the wrapped select query."""
        self.select_query.dic['select']['distinct'] = True
        return self.select_query


class SelectQuery(QueryBase):
    """The class to construct a select query.

    Every clause is stored in `self.dic` under its key ('from', 'where', 'join', ...)
    and `execute` renders the stored clauses in the order given by `sql_keys`.
    It returns a `Joiner` object when called with the method join.
    It returns a `Selector` object when called with the method select.
    The other clause setters (`from_`, `where`, `having`, `limit`, `groupby`, `orderby`,
    `with_`) and the `parse_<key>` renderers are built on demand by `__getattr__`.
    """
    # NOTE(review): not referenced anywhere else in this class.
    keys_simple = ['from', 'groupby', 'orderby', 'where', 'having', 'limit']
    # the order to put together the final sql query
    sql_keys = ['with', 'select', 'from', 'join', 'where', 'groupby', 'having', 'orderby', 'limit']
    # clause keys whose SQL keyword differs from the key itself
    key_translation = {
        'groupby': 'group by',
        'orderby': 'order by'
    }

    def __init__(self) -> None:
        # the dictionary to store the queries for each clause
        self.dic = {}
        # `get_sql` is an alias of `execute`, bound to this instance
        self.get_sql = self.execute

    def join(self, query, how=None):
        """Start a join; nothing is stored until `.on(condition)` is called on the returned `Joiner`."""
        return Joiner(self, query, how)

    def select(self, *args):
        """Store the selected columns and return a `Selector` wrapping this query."""
        return Selector(self, *args)

    def __getattr__(self, __name):
        """Build clause setters and `parse_<key>` renderers on demand.

        Leading/trailing underscores are stripped and the name is lower-cased, so
        `from_` resolves to the 'from' clause. Names starting with 'parse' always
        yield a renderer; an unknown key only raises AttributeError when that
        renderer is called. Any other unknown name raises AttributeError right away.
        """
        # Construct the corresponding method functions
        # remove the leading and trailing underscores
        __name = __name.strip('_').lower()
        # get the string after the first underscore
        key = '_'.join(__name.split('_')[1:])

        if __name.startswith('parse'):
            # Construct parse_xxx functions
            def inner(**kwargs):
                if key in ['from', 'where', 'having', 'limit']:
                    # single-value clauses; a nested query object is wrapped in parentheses
                    q = self.dic[key]
                    if isinstance(q, QueryBase):
                        return f"{self.key_translation.get(key, key)} ({execute(q, **kwargs)})"
                    else:
                        return f"{self.key_translation.get(key, key)} {execute(q, **kwargs)}"
                elif key in ['groupby', 'orderby']:
                    # multi-value clauses rendered as a comma separated list
                    args = self.dic[key]
                    args = [execute(arg, **kwargs) for arg in args]
                    return f"{self.key_translation.get(key, key)} " + ', '.join(args)
                elif key == 'with':
                    # list of (query, alias) pairs, one CTE per pair
                    s = self.dic[key]
                    qq = [f"{a} as (\n{execute(q, **kwargs)})" for q, a in s]
                    return f"{key} " + "\n\n, ".join(qq) + "\n"
                elif key == 'join':
                    # list of (query, how, condition) triples, one line per join
                    l = self.dic[key]
                    parsed = []
                    for q, how, cond in l:
                        if how:
                            how = how + ' '
                        else:
                            how = ''
                        if isinstance(q, QueryBase):
                            sub_q = f"({execute(q, **kwargs)})"
                        else:
                            sub_q = execute(q, **kwargs)
                        parsed.append(f"{how}join {sub_q} on {execute(cond, **kwargs)}")
                    return '\n'.join(parsed)
                else:
                    raise AttributeError(f"parsing function for {key} is not defined!")
        else:
            # Construct other method functions
            if __name in ['from', 'where', 'having', 'limit']:
                # a later call replaces the previously stored value
                def inner(q):
                    self.dic[__name] = q
                    return self
            elif __name in ['groupby', 'orderby']:
                # stores all positional arguments as a tuple
                def inner(*args):
                    self.dic[__name] = args
                    return self
            elif __name == 'with':
                # appends to the existing list, so `with_` can be chained
                def inner(query, alias):
                    l = self.dic.get('with', [])
                    l.append((query, alias))
                    self.dic['with'] = l
                    return self
            else:
                raise AttributeError(f"{__name} is not a valid attribute!")

        return inner

    def execute(self, **kwargs):
        """Render the query as a SQL string.

        Each stored clause is rendered by `parse_<key>`, receiving `kwargs` unchanged;
        the pieces are joined with newlines in `sql_keys` order.
        Raises ValueError when a stored clause key is not listed in `sql_keys`.
        """
        dic_sql = {}

        for k, v in self.dic.items():
            dic_sql[k] = getattr(self, f"parse_{k}")(**kwargs)

        if missing_keys := [key for key in dic_sql if key not in self.sql_keys]:
            raise ValueError(f"The order of {missing_keys} in the final SQL are not specified in self.sql_keys!")

        sql = '\n'.join([dic_sql[key] for key in self.sql_keys if key in dic_sql])
        return sql
    
    def union(self, query):
        """Return a `UnionQuery` of this query and `query` (UNION); raises TypeError unless `query` is a `QueryBase`."""
        if not isinstance(query, QueryBase):
            raise TypeError(f"{query} is not an instance of QueryBase!")
        return UnionQuery(self, query, union_type='UNION')

    def __add__(self, query):
        """`a + b` is shorthand for `a.union(b)`."""
        return self.union(query)

    def union_all(self, query):
        """Return a `UnionQuery` of this query and `query` (UNION ALL); raises TypeError unless `query` is a `QueryBase`."""
        if not isinstance(query, QueryBase):
            raise TypeError(f"{query} is not an instance of QueryBase!")
        return UnionQuery(self, query, union_type='UNION ALL')

    def __mul__(self, query):
        """`a * b` is shorthand for `a.union_all(b)`."""
        return self.union_all(query)


class UnionQuery(QueryBase):
    """The class to construct a union query.

    `q1` and `q2` are the two operand queries and `union_type` is the SQL keyword
    placed between them ('UNION' or 'UNION ALL' when built by `SelectQuery`).
    """
    def __init__(self, q1, q2, union_type='UNION') -> None:
        self.q1 = q1
        self.q2 = q2
        self.union_type = union_type

    def execute(self, **kwargs):
        """Render both operands with the same keyword arguments and join them with the union keyword."""
        return f"{execute(self.q1, **kwargs)} {self.union_type} {execute(self.q2, **kwargs)}"

    def get_sql(self, **kwargs):
        """Alias of `execute`, so a union can be rendered with the same call as the other query objects."""
        return self.execute(**kwargs)


class Query(QueryBase):
    """Entry point for building queries.

    The classmethods `from_` and `with_` create a new object of the class stored in
    the class variable `q` (`SelectQuery` by default) and seed its first clause.
    Extend this class with more classmethods, or reassign `q`, to build other kinds of queries.
    """
    q = SelectQuery

    @classmethod
    def from_(cls, query):
        """Start a new query whose 'from' clause is `query`."""
        builder = cls.q()
        builder.dic.update({'from': query})
        return builder

    @classmethod
    def with_(cls, query, alias):
        """Start a new query whose 'with' clause holds the single pair `(query, alias)`."""
        builder = cls.q()
        builder.dic.update({'with': [(query, alias)]})
        return builder

In [None]:
# Basic chain: `select` returns a Selector, which forwards `where`/`limit` to the underlying query.
q0 = (Query
      .from_(Table('tbl').as_('a'))
      .select('col1')
      .where((Field('col2')-100>2) & (Field('col3')/9<=1))
      .limit(100)
)
test_eq(q0.get_sql(), 
"""select col1
from tbl as a
where col2 - 100 > 2 and col3 / 9 <= 1
limit 100""")

# Two CTEs plus a join. The clauses are called in a different order than they are
# rendered: the output always follows `SelectQuery.sql_keys`.
qj = (Query
     .with_(Query.from_(Table('tbl').as_('a')).select('col1'), 's')     
     .with_(q0, 'm')
     .from_('s')
     .join('m').on('s.col1=m.col1')
     .where(Field('col1')>=10)
     .select('s.col1', 'm.col2', 'm.col3')
)
test_eq(qj.get_sql(), 
"""with s as (
select col1
from tbl as a)

, m as (
select col1
from tbl as a
where col2 - 100 > 2 and col3 / 9 <= 1
limit 100)

select s.col1, m.col2, m.col3
from s
join m on s.col1=m.col1
where col1 >= 10""")

# `distinct()` must directly follow `select(...)`; it hands back the underlying query for `orderby`.
qd = (Query
     .from_('s')
     .where(Field('col1')>=10)
     .select('col1', 'col2', 'col3').distinct()
     .orderby('col1', 'col2')
)
test_eq(qd.get_sql(),
"""select distinct col1, col2, col3
from s
where col1 >= 10
order by col1, col2""")

In [None]:
#|export
class Exists:
    """Exists statement: renders `EXISTS (<query>)` for the wrapped query object."""
    def __init__(self, query: "QueryBase") -> None:
        super().__init__()
        # any object exposing `get_sql(**kwargs)` or `execute(**kwargs)` can be wrapped
        self.query = query

    def execute(self, **kwargs):
        """Render the wrapped query with the same keyword arguments and wrap it in `EXISTS (...)`."""
        # prefer `get_sql`, but fall back to `execute` for query objects that only define the latter
        render = getattr(self.query, "get_sql", None)
        if render is None:
            render = self.query.execute
        return f"EXISTS ({render(**kwargs)})"

    def get_sql(self, **kwargs):
        """Alias of `execute`.

        Defined as a method rather than bound in `__init__`, so that copies render
        their own wrapped query.
        """
        return self.execute(**kwargs)

In [None]:
show_doc(Exists)

---

[source](https://github.com/feynlee/pikaQ/blob/master/pikaQ/queries.py#L236){target="_blank" style="float:right; font-size:smaller"}

### Exists

>      Exists (query:__main__.Query)

Exists statement

In [None]:
tbl = Table('tbl')
# Stand-alone: `quote_char` is passed through to the wrapped query.
test_eq(Exists(Query.from_(tbl)
               .select(tbl.col1)
               .where(tbl.col2>10)).get_sql(quote_char='"'), 
        'EXISTS (select "tbl"."col1"\nfrom "tbl"\nwhere "tbl"."col2" > 10)')
# As a `where` condition: `Exists` is not a `QueryBase`, so the where clause adds no extra parentheses.
test_eq(Query.from_(tbl)
        .select(tbl.col1)
        .where(
            Exists(Query.from_(tbl).
                   select(tbl.col1)
                   .where(tbl.col2>10))
        ).get_sql(quote_char='"'),
        'select "tbl"."col1"\nfrom "tbl"\nwhere EXISTS (select "tbl"."col1"\nfrom "tbl"\nwhere "tbl"."col2" > 10)')

## Extending Query

We can extend `SelectQuery` to support more complex queries. 
There are 3 things we need to implement:
1. A method that records all the information needed for the SQL clause. Let's call it `custom`, for example.
2. A method that renders the SQL clause as a string. Its name must be `parse_` followed by the name of the first method (here, `parse_custom`).
3. The class variable `sql_keys` must be overridden to include the new method's name at the appropriate position. The order of the keys in this list determines the order of the SQL clauses in the final SQL string.

For example, for Snowflake SQL, to generate the [PIVOT](https://docs.snowflake.com/en/sql-reference/constructs/pivot) clause, we need to implement both a `pivot` method and a `parse_pivot` method, and also add `pivot` into `sql_keys` at the right place.


In [None]:
class SFSelectQuery(SelectQuery):
    """`SelectQuery` extended with Snowflake's PIVOT clause."""
    # the order to put together the final sql query
    sql_keys = ['with', 'select', 'from', 'join', 'pivot', 'where', 'groupby', 'having', 'orderby', 'limit']

    def pivot(self, agg_func, pivot_col, value_col, pivot_values, alias):
        """Record a PIVOT clause and return the query for further chaining.

        The parameter names follow Snowflake's PIVOT syntax:
        `agg_func` is a callable applied to `pivot_col` (the column being aggregated),
        `value_col` is the column whose values become the new column names,
        `pivot_values` lists those values, and `alias` names the pivoted result.
        """
        self.dic['pivot'] = {
            'agg_func': agg_func,
            'pivot_col': pivot_col,
            'value_col': value_col,
            'pivot_values': pivot_values,
            'alias': alias
        }
        return self

    def parse_pivot(self, **kwargs):
        """Render `pivot(<agg>(<pivot_col>) for <value_col> in (<values>)) as <alias>`.

        Raises NotImplementedError unless `dialect='snowflake'` is passed.
        """
        dialect = kwargs.get('dialect', None)
        if dialect != 'snowflake':
            raise NotImplementedError(f"dialect {dialect} not implemented")

        spec = self.dic['pivot']
        # Snowflake: PIVOT(<aggregate>(<pivot_column>) FOR <value_column> IN (...))
        aggregated = execute(spec['agg_func'](spec['pivot_col']), **kwargs)
        names_from = execute(spec['value_col'], **kwargs)
        # NOTE: values are wrapped in single quotes as-is; embedded quotes are not escaped
        pivot_values = ', '.join(f"'{v}'" for v in spec['pivot_values'])
        return f"pivot({aggregated} for {names_from} in ({pivot_values})) as {spec['alias']}"
    


To use this new `SFSelectQuery` in `Query`, we simply assign it to the class variable `q`.
And we can now use `.pivot` in our query.

In [None]:
# From now on `Query.from_` / `Query.with_` create `SFSelectQuery` objects.
Query.q = SFSelectQuery

vw = Table('vw')
# The positional arguments follow `pivot(agg_func, pivot_col, value_col, pivot_values, alias)`;
# `parse_pivot` only supports `dialect='snowflake'`.
print(
    Query
    .from_(vw)
    .pivot(fn.Sum, vw.amount, vw.month, ['JAN', 'FEB', 'MAR', 'APR'], 'p')
    .where(vw.column > 1).select(vw.star()).get_sql(dialect='snowflake')
)

select vw.*
from vw
pivot(SUM(vw.month) for vw.amount in ('JAN', 'FEB', 'MAR', 'APR')) as p
where vw.column > 1


In [None]:
#|hide
import nbdev; nbdev.nbdev_export()