# Functions

> Some commonly used SQL Functions.

In [None]:
#| default_exp functions
#| hide
from nbdev.showdoc import *
from fastcore.test import *
# allow multiple output from one cell
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
%reload_ext autoreload
%autoreload 2

In [None]:
#|export
from fastcore.all import patch, delegates
from pikaQ.utils import execute
from pikaQ.terms import FieldBase, Field, custom_func, OverClause

# Custom Function

In [None]:
#|export
class CustomFunction(FieldBase):
    """A convenient class for creating custom functions.

    The object returned by the constructor is a reusable *template*. Calling it
    with arguments returns a separate bound copy that carries those arguments,
    so a shared template (e.g. a module-level `Count`) is not modified by a call
    or by `.distinct()` applied to the result of a call.
    """

    def __init__(self, 
                 func_name: str,    # name of the function 
                 arg_names: list,   # list of arg names
                 window_func=False,  # whether this can be used as a window function
                 distinct_option=False # whether this function can be used with the distinct option
        ) -> None:
        import inspect
        super().__init__()
        self.func_name = func_name
        self.arg_names = arg_names
        self.window_func = window_func
        self.distinct_option = distinct_option
        self.distinct_ = False  # whether the distinct option is used
        self.get_sql = self.execute
        # Instance-level metadata so the object presents itself like a function
        # named `func_name` taking `arg_names` (name, signature and docstring).
        self.__qualname__ = func_name
        self.__signature__ = inspect.Signature(parameters=[inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD) for name in arg_names]) 
        self.__doc__ = f"Custom function {func_name} with args {arg_names}"

    def __call__(self, *args):
        """Return a copy of this function bound to `args`.

        Raises `ValueError` when the number of positional arguments differs
        from the number of names in `arg_names`.
        """
        import copy
        if len(args) != len(self.arg_names):
            raise ValueError(f"The number of args provided {len(args)} is not the same as the number of args expected by this function ({len(self.arg_names)})!")
        # Bind the arguments on a shallow copy instead of on `self`: otherwise
        # every call would overwrite the arguments of earlier calls made
        # through the same template, and `.distinct()` on one call would leak
        # into all later ones.
        bound = copy.copy(self)
        bound.args = args
        # The copied instance dict still holds bound methods of the template;
        # point them at the copy so rendering reads the copy's own state.
        bound.func = bound._render
        bound.get_sql = bound.execute
        return bound

    def _render(self, *args):
        """Format already-rendered argument strings as `NAME([DISTINCT ]a, b, ...)`."""
        prefix = "DISTINCT " if self.distinct_ else ""
        return f"{self.func_name}({prefix}{', '.join(args)})"

    def execute(self, **kwargs):
        """Render every bound argument with `execute(arg, **kwargs)` and build the call text."""
        args = [str(execute(arg, **kwargs)) for arg in self.args]
        return self.func(*args)

    def distinct(self):
        """Turn on the DISTINCT option for this object (in place) and return it.

        Raises `ValueError` when the function was not created with
        `distinct_option` enabled.
        """
        if not self.distinct_option:
            raise ValueError("This function does not support the distinct option!!")
        self.distinct_ = True
        return self


@patch
def over(self:CustomFunction, partition_by):
    """Wrap this function in an `OverClause` and apply `.over(partition_by)` to it.

    Raises `ValueError` when `window_func` is not `True` for this function.
    """
    if self.window_func==True:
        return OverClause(self).over(partition_by)
    raise ValueError(f"This function is not a window function!!")

In [None]:
# Render the API documentation (signature, docstring, parameter table) for `CustomFunction`.
show_doc(CustomFunction)

---

[source](https://github.com/feynlee/pikaQ/blob/master/pikaQ/functions.py#L15){target="_blank" style="float:right; font-size:smaller"}

### CustomFunction

>      CustomFunction (func_name:str, arg_names:list, window_func=False,
>                      distinct_option=False)

A convenient class for creating custom functions.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| func_name | str |  | name of the function |
| arg_names | list |  | list of arg names |
| window_func | bool | False | whether this can be used as a window function |
| distinct_option | bool | False | whether this function can be used with the distinct option |
| **Returns** | **None** |  |  |

In [None]:
#|hide
#|eval: false
# Build a sample function and render its docs: the displayed name, signature and
# docstring come from the attributes set on the instance in `CustomFunction.__init__`.
Max = CustomFunction('MAX', ['field'], window_func=True, distinct_option=True)

show_doc(Max)

---

### MAX

>      MAX (field)

Custom function MAX with args ['field']

`CustomFunction` is a convenient way to create a custom SQL function from just its name and a list of positional argument names. Use it when the function does not need to be rendered differently for different dialects.

In [None]:
# Plain string arguments appear unchanged in the output, joined with ", ".
date_diff = CustomFunction('DATE_DIFF', ['interval', 'start_date', 'end_date'])
test_eq(date_diff('day', 'start_date', 'end_date').get_sql(), 'DATE_DIFF(day, start_date, end_date)')

In [None]:
# `Field` arguments can be mixed with plain strings; here each field shows up by its name.
date_diff = CustomFunction('DATE_DIFF', ['interval', 'start_date', 'end_date'])
test_eq(date_diff('month', Field('date1'), Field('date2')).get_sql(), 'DATE_DIFF(month, date1, date2)')

## Commonly Used Functions

In [None]:
#|export
# Aggregates: created with both flags enabled, so `.over(...)` and `.distinct()` are allowed.
Max = CustomFunction('MAX', ['field'], window_func=True, distinct_option=True)
Min = CustomFunction('MIN', ['field'], window_func=True, distinct_option=True)
Sum = CustomFunction('SUM', ['field'], window_func=True, distinct_option=True)
Avg = CustomFunction('AVG', ['field'], window_func=True, distinct_option=True)
Count = CustomFunction('COUNT', ['field'], window_func=True, distinct_option=True)
# Created with the default flags (both False): `.over(...)` and `.distinct()` raise ValueError on these.
Abs = CustomFunction('ABS', ['field'])
Round = CustomFunction('ROUND', ['field', 'decimals'])
First = CustomFunction('FIRST', ['field'])
Last = CustomFunction('LAST', ['field'])

@custom_func
def Cast(field, type):
    """Build the text `CAST(<field> AS <type>)`."""
    cast_body = f"{field} AS {type}"
    return f"CAST({cast_body})"

@custom_func
def Coalesce(*args):
    """Build `COALESCE(a, b, ...)` from the given arguments, joined with ", "."""
    joined = ', '.join(args)
    return f"COALESCE({joined})"

@custom_func
def Concat(*args):
    """Build `CONCAT(a, b, ...)` from the given arguments, joined with ", "."""
    joined = ', '.join(args)
    return f"CONCAT({joined})"


# date functions
def convert_date_format(format, dialect='sql'):
    """Translate the tokens of a date-format string for the given dialect.

    - `'spark'`: `YYYY` -> `yyyy`, `YY` -> `yy`, `DD` -> `dd`.
    - `'snowflake'`: `HH` -> `HH24`, `mm` -> `MI`. An `HH` that is already
      followed by `24` or `12` is left untouched, so a format that is already
      in Snowflake form is returned as-is instead of becoming `HH2424`.
    - any other dialect: the format is returned unchanged.
    """
    import re
    if dialect == 'spark':
        return (format.replace('YYYY', 'yyyy')
                .replace('YY', 'yy')
                .replace('DD', 'dd')
                )
    elif dialect == 'snowflake':
        # Negative lookahead keeps existing HH24 / HH12 tokens intact, which
        # makes the conversion idempotent.
        with_hours = re.sub(r'HH(?!24|12)', 'HH24', format)
        return with_hours.replace('mm', 'MI')
    else:
        return format

@custom_func
def Date(expression, format=None, dialect='sql'):
    """Build a `DATE(...)` call with the expression wrapped in single quotes.

    When `format` is truthy it is first passed through `convert_date_format`
    for `dialect` and added as a second quoted argument; otherwise only the
    expression is emitted.
    """
    if not format:
        return f"DATE('{expression}')"
    format = convert_date_format(format, dialect)
    return f"DATE('{expression}', '{format}')"

@custom_func
def AddMonths(date, months, dialect='sql'):
    """Build the dialect-specific text for adding `months` months to `date`.

    - `'athena'`: `date_add('month', <months>, <date>)`
    - `'snowflake'` / `'spark'`: `ADD_MONTHS(<date>, <months>)`
    - anything else: `DATE_ADD(<date>, INTERVAL <months> MONTH)`
    """
    if dialect == 'athena':
        return f"date_add('month', {months}, {date})"
    if dialect in ('snowflake', 'spark'):
        return f"ADD_MONTHS({date}, {months})"
    return f"DATE_ADD({date}, INTERVAL {months} MONTH)"

@custom_func
def DateDiff(interval, start_date, end_date, dialect='sql'):
    """Build the date-difference call for `dialect`.

    For `'athena'` the interval is single-quoted and the function name is
    lower-case `date_diff`; every other dialect gets `DATEDIFF` with the
    interval unquoted.
    """
    if dialect != 'athena':
        return f"DATEDIFF({interval}, {start_date}, {end_date})"
    return f"date_diff('{interval}', {start_date}, {end_date})"

@custom_func
def DateTrunc(interval, date, dialect='sql'):
    """Build the date-truncation call for `dialect`.

    For `'spark'` this is `TRUNC(<date>, '<interval>')`; every other dialect
    gets `DATE_TRUNC('<interval>', <date>)`. Note the swapped argument order.
    """
    if dialect != 'spark':
        return f"DATE_TRUNC('{interval}', {date})"
    return f"TRUNC({date}, '{interval}')"

# Created with the default flags, so it is neither a window function nor DISTINCT-capable.
MonthsBetween = CustomFunction('MONTHS_BETWEEN', ['start_date', 'end_date'])

# Window Functions
# All created with `window_func=True`, so `.over(...)` is allowed; `distinct_option`
# keeps its default (False), so `.distinct()` raises ValueError on them.
# These take no arguments and are called with empty parentheses, e.g. `RowNumber()`.
RowNumber = CustomFunction('ROW_NUMBER', [], window_func=True)
Rank = CustomFunction('RANK', [], window_func=True)
DenseRank = CustomFunction('DENSE_RANK', [], window_func=True)
PercentRank = CustomFunction('PERCENT_RANK', [], window_func=True)
CumeDist = CustomFunction('CUME_DIST', [], window_func=True)
# These require exactly the listed positional arguments; a different count raises ValueError.
Ntile = CustomFunction('NTILE', ['num_buckets'], window_func=True)
Lag = CustomFunction('LAG', ['field', 'offset'], window_func=True)
Lead = CustomFunction('LEAD', ['field', 'offset'], window_func=True)
FirstValue = CustomFunction('FIRST_VALUE', ['field'], window_func=True)
LastValue = CustomFunction('LAST_VALUE', ['field'], window_func=True)
NthValue = CustomFunction('NTH_VALUE', ['field', 'n'], window_func=True)

In [None]:
# Render the docs of a function created with an empty `arg_names` list.
show_doc(RowNumber)

---

### ROW_NUMBER

>      ROW_NUMBER ()

Custom function ROW_NUMBER with args []

In [None]:
# Each case checks the rendered form `<call> OVER (PARTITION BY <field> ORDER BY <field>)`.
# Zero-argument window functions:
test_eq(RowNumber().over(Field('col1')).orderby(Field('col2')).get_sql(), 'ROW_NUMBER() OVER (PARTITION BY col1 ORDER BY col2)')
test_eq(Rank().over(Field('col1')).orderby(Field('col2')).get_sql(), 
        'RANK() OVER (PARTITION BY col1 ORDER BY col2)')
test_eq(DenseRank().over(Field('col1')).orderby(Field('col2')).get_sql(),
        'DENSE_RANK() OVER (PARTITION BY col1 ORDER BY col2)')
test_eq(PercentRank().over(Field('col1')).orderby(Field('col2')).get_sql(),
        'PERCENT_RANK() OVER (PARTITION BY col1 ORDER BY col2)')
test_eq(CumeDist().over(Field('col1')).orderby(Field('col2')).get_sql(),
        'CUME_DIST() OVER (PARTITION BY col1 ORDER BY col2)')
# Window functions with arguments (non-string arguments such as 5 are rendered via str()):
test_eq(Ntile(5).over(Field('col1')).orderby(Field('col2')).get_sql(),
        'NTILE(5) OVER (PARTITION BY col1 ORDER BY col2)')
test_eq(Lag(Field('col1'), 1).over(Field('col2')).orderby(Field('col3')).get_sql(),
        'LAG(col1, 1) OVER (PARTITION BY col2 ORDER BY col3)')
test_eq(Lead(Field('col1'), 1).over(Field('col2')).orderby(Field('col3')).get_sql(),
        'LEAD(col1, 1) OVER (PARTITION BY col2 ORDER BY col3)')
test_eq(FirstValue(Field('col1')).over(Field('col2')).orderby(Field('col3')).get_sql(),
        'FIRST_VALUE(col1) OVER (PARTITION BY col2 ORDER BY col3)')
test_eq(LastValue(Field('col1')).over(Field('col2')).orderby(Field('col3')).get_sql(),
        'LAST_VALUE(col1) OVER (PARTITION BY col2 ORDER BY col3)')
test_eq(NthValue(Field('col1'), 2).over(Field('col2')).orderby(Field('col3')).get_sql(),
        'NTH_VALUE(col1, 2) OVER (PARTITION BY col2 ORDER BY col3)')
# Aggregates created with `window_func=True` can be used with `.over(...)` as well:
test_eq(Count(Field('col1')).over(Field('col2')).orderby(Field('col3')).get_sql(),
        'COUNT(col1) OVER (PARTITION BY col2 ORDER BY col3)')
test_eq(Sum(Field('col1')).over(Field('col2')).orderby(Field('col3')).get_sql(),
        'SUM(col1) OVER (PARTITION BY col2 ORDER BY col3)')
test_eq(Avg(Field('col1')).over(Field('col2')).orderby(Field('col3')).get_sql(),
        'AVG(col1) OVER (PARTITION BY col2 ORDER BY col3)')
test_eq(Min(Field('col1')).over(Field('col2')).orderby(Field('col3')).get_sql(),
        'MIN(col1) OVER (PARTITION BY col2 ORDER BY col3)')
test_eq(Max(Field('col1')).over(Field('col2')).orderby(Field('col3')).get_sql(),
        'MAX(col1) OVER (PARTITION BY col2 ORDER BY col3)')

In [None]:
# `.distinct()` places the DISTINCT keyword inside the parentheses, before the arguments.
test_eq(Count(Field('col')).distinct().get_sql(), 'COUNT(DISTINCT col)')
test_eq(Sum(Field('col')).distinct().get_sql(), 'SUM(DISTINCT col)')

In [None]:
#|hide
# Run nbdev's export step (presumably writes the `#|export` cells to the module
# named by `default_exp` — see the nbdev docs).
import nbdev; nbdev.nbdev_export()