Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up and test SQLTransforms #384

Merged
merged 3 commits into from
Nov 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lumen/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class Pipeline(Component):
def __init__(self, *, source, table, **params):
if 'schema' not in params:
params['schema'] = source.get_schema(table)
if any(isinstance(t, SQLTransform) for t in params.get('transforms', [])):
raise TypeError('Pipeline.transforms must be regular Transform components, not SQLTransform.')
super().__init__(source=source, table=table, **params)
self._update_widget = pn.Param(self.param['update'], widgets={'update': {'button_type': 'success'}})[0]
self._init_callbacks()
Expand Down
85 changes: 85 additions & 0 deletions lumen/tests/transforms/test_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import datetime as dt

from lumen.transforms.sql import (
SQLColumns, SQLDistinct, SQLFilter, SQLGroupBy, SQLLimit, SQLMinMax,
)


def test_sql_group_by_single_column():
    """Check the SQL produced by SQLGroupBy when one aggregate maps to a single column name."""
    expected = """SELECT\n A, AVG(B) AS B\nFROM ( SELECT * FROM TABLE )\nGROUP BY A"""
    rendered = SQLGroupBy.apply_to(
        'SELECT * FROM TABLE', by=['A'], aggregates={'AVG': 'B'}
    )
    assert rendered == expected

def test_sql_group_by_multi_columns():
    """Check the SQL produced by SQLGroupBy when one aggregate maps to a list of columns."""
    expected = """SELECT\n A, AVG(B) AS B, AVG(C) AS C\nFROM ( SELECT * FROM TABLE )\nGROUP BY A"""
    rendered = SQLGroupBy.apply_to(
        'SELECT * FROM TABLE', by=['A'], aggregates={'AVG': ['B', 'C']}
    )
    assert rendered == expected

def test_sql_limit():
    """Check that SQLLimit wraps the input query and appends a LIMIT clause."""
    expected = """SELECT\n *\nFROM ( SELECT * FROM TABLE )\nLIMIT 10"""
    rendered = SQLLimit.apply_to('SELECT * FROM TABLE', limit=10)
    assert rendered == expected

def test_sql_columns():
    """Check that SQLColumns selects only the requested columns from the wrapped query."""
    expected = """SELECT\n A, B\nFROM ( SELECT * FROM TABLE )"""
    rendered = SQLColumns.apply_to('SELECT * FROM TABLE', columns=['A', 'B'])
    assert rendered == expected

def test_sql_distinct():
    """Check that SQLDistinct emits SELECT DISTINCT over the requested columns."""
    expected = """SELECT DISTINCT\n A, B\nFROM ( SELECT * FROM TABLE )"""
    rendered = SQLDistinct.apply_to('SELECT * FROM TABLE', columns=['A', 'B'])
    assert rendered == expected

def test_sql_min_max():
    """Check that SQLMinMax emits a MIN/MAX pair, suffixed _min/_max, for every column."""
    expected = """SELECT\n MIN(A) as A_min, MAX(A) as A_max, MIN(B) as B_min, MAX(B) as B_max\nFROM ( SELECT * FROM TABLE )"""
    rendered = SQLMinMax.apply_to('SELECT * FROM TABLE', columns=['A', 'B'])
    assert rendered == expected

def test_sql_filter_none():
    """Check that a None condition value is rendered by SQLFilter as an IS NULL test."""
    expected = """SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A IS NULL )"""
    rendered = SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', None)])
    assert rendered == expected

def test_sql_filter_scalar():
    """Check that a scalar condition value is rendered by SQLFilter as an equality test."""
    expected = """SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A = 1 )"""
    rendered = SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', 1)])
    assert rendered == expected


def test_sql_filter_isin():
    """Check that a list condition value is rendered by SQLFilter as an IN (...) test."""
    expected = """SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A IN ('A', 'B', 'C') )"""
    rendered = SQLFilter.apply_to(
        'SELECT * FROM TABLE', conditions=[('A', ['A', 'B', 'C'])]
    )
    assert rendered == expected

def test_sql_filter_datetime():
    """Check that a single datetime condition is rendered as equality against its timestamp string."""
    moment = dt.datetime(2017, 4, 14)
    expected = """SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A = '2017-04-14 00:00:00' )"""
    rendered = SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', moment)])
    assert rendered == expected

def test_sql_filter_date():
    """Check that a single date condition is rendered as a BETWEEN from 00:00:00 to 23:59:59 of that day."""
    day = dt.date(2017, 4, 14)
    expected = """SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A BETWEEN '2017-04-14 00:00:00' AND '2017-04-14 23:59:59' )"""
    rendered = SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', day)])
    assert rendered == expected

def test_sql_filter_date_range():
    """Check that a (date, date) tuple is rendered as a BETWEEN whose end bound is 23:59:59 of the last day."""
    first_day = dt.date(2017, 2, 22)
    last_day = dt.date(2017, 4, 14)
    expected = """SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A BETWEEN '2017-02-22 00:00:00' AND '2017-04-14 23:59:59' )"""
    rendered = SQLFilter.apply_to(
        'SELECT * FROM TABLE', conditions=[('A', (first_day, last_day))]
    )
    assert rendered == expected

def test_sql_filter_datetime_range():
    """Check that a (datetime, datetime) tuple is rendered as a BETWEEN using both timestamps unchanged."""
    range_start = dt.datetime(2017, 2, 22)
    range_end = dt.datetime(2017, 4, 14)
    expected = """SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A BETWEEN '2017-02-22 00:00:00' AND '2017-04-14 00:00:00' )"""
    rendered = SQLFilter.apply_to(
        'SELECT * FROM TABLE', conditions=[('A', (range_start, range_end))]
    )
    assert rendered == expected
56 changes: 24 additions & 32 deletions lumen/transforms/sql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime as dt
import textwrap

import numpy as np
import param
Expand Down Expand Up @@ -47,6 +48,11 @@ def apply(self, sql_in):
"""
return sql_in

@classmethod
def _render_template(cls, template, **params):
    """
    Dedent the template source, strip its leading whitespace and
    render it with the supplied keyword parameters.
    """
    # Remove the common indentation of the multi-line literal, then any
    # leading blank line/whitespace, before handing it to the Template.
    cleaned_source = textwrap.dedent(template).lstrip()
    compiled = Template(cleaned_source, trim_blocks=True, lstrip_blocks=True)
    return compiled.render(**params)


class SQLGroupBy(SQLTransform):
"""
Expand All @@ -64,18 +70,17 @@ class SQLGroupBy(SQLTransform):
def apply(self, sql_in):
template = """
SELECT
{{by_cols}},
{{aggs}}
{{by_cols}}, {{aggs}}
FROM ( {{sql_in}} )
GROUP BY {{by_cols}}
"""
GROUP BY {{by_cols}}"""
by_cols = ', '.join(self.by)
aggs = ', '.join([
f'{agg}({col}) AS {col}' for agg, col in self.aggregates.items()
])
return Template(template, trim_blocks=True, lstrip_blocks=True).render(
by_cols=by_cols, aggs=aggs, sql_in=sql_in
)
aggs = []
for agg, cols in self.aggregates.items():
if isinstance(cols, str):
cols = [cols]
for col in cols:
aggs.append(f'{agg}({col}) AS {col}')
return self._render_template(template, by_cols=by_cols, aggs=', '.join(aggs), sql_in=sql_in)


class SQLLimit(SQLTransform):
Expand All @@ -94,9 +99,7 @@ def apply(self, sql_in):
FROM ( {{sql_in}} )
LIMIT {{limit}}
"""
return Template(template, trim_blocks=True, lstrip_blocks=True).render(
limit=self.limit, sql_in=sql_in
)
return self._render_template(template, sql_in=sql_in, limit=self.limit)


class SQLDistinct(SQLTransform):
Expand All @@ -109,11 +112,8 @@ def apply(self, sql_in):
template = """
SELECT DISTINCT
{{columns}}
FROM ( {{sql_in}} )
"""
return Template(template, trim_blocks=True, lstrip_blocks=True).render(
columns=', '.join(self.columns), sql_in=sql_in
)
FROM ( {{sql_in}} )"""
return self._render_template(template, sql_in=sql_in, columns=', '.join(self.columns))


class SQLMinMax(SQLTransform):
Expand All @@ -130,11 +130,8 @@ def apply(self, sql_in):
template = """
SELECT
{{columns}}
FROM ( {{sql_in}} )
"""
return Template(template, trim_blocks=True, lstrip_blocks=True).render(
columns=', '.join(aggs), sql_in=sql_in
)
FROM ( {{sql_in}} )"""
return self._render_template(template, sql_in=sql_in, columns=', '.join(aggs))


class SQLColumns(SQLTransform):
Expand All @@ -149,9 +146,7 @@ def apply(self, sql_in):
{{columns}}
FROM ( {{sql_in}} )
"""
return Template(template, trim_blocks=True, lstrip_blocks=True).render(
columns=', '.join(self.columns), sql_in=sql_in
)
return self._render_template(template, sql_in=sql_in, columns=', '.join(self.columns))


class SQLFilter(SQLTransform):
Expand All @@ -172,7 +167,7 @@ def _range_filter(cls, col, v1, v2):
if isinstance(v1, dt.date) and not isinstance(v1, dt.datetime):
start += ' 00:00:00'
if isinstance(v2, dt.date) and not isinstance(v2, dt.datetime):
end += ' 00:00:00'
end += ' 23:59:59'
return f'{col} BETWEEN {start!r} AND {end!r}'

def apply(self, sql_in):
Expand Down Expand Up @@ -220,8 +215,5 @@ def apply(self, sql_in):
SELECT
*
FROM ( {{sql_in}} )
WHERE ( {{conditions}} )
"""
return Template(template, trim_blocks=True, lstrip_blocks=True).render(
conditions=' AND '.join(conditions), sql_in=sql_in
)
WHERE ( {{conditions}} )"""
return self._render_template(template, sql_in=sql_in, conditions=' AND '.join(conditions))