From 3da2721de1d7f5bf562964919137c44fb3e83f27 Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Mon, 14 Nov 2022 12:13:52 +0100 Subject: [PATCH 1/3] Clean up and test SQLTransforms --- lumen/tests/transforms/test_sql.py | 85 ++++++++++++++++++++++++++++++ lumen/transforms/sql.py | 56 +++++++++----------- 2 files changed, 109 insertions(+), 32 deletions(-) create mode 100644 lumen/tests/transforms/test_sql.py diff --git a/lumen/tests/transforms/test_sql.py b/lumen/tests/transforms/test_sql.py new file mode 100644 index 00000000..6d20f497 --- /dev/null +++ b/lumen/tests/transforms/test_sql.py @@ -0,0 +1,85 @@ +import datetime as dt + +from lumen.transforms.sql import ( + SQLColumns, SQLDistinct, SQLFilter, SQLGroupBy, + SQLLimit, SQLMinMax +) + +def test_sql_group_by_single_column(): + assert ( + SQLGroupBy.apply_to('SELECT * FROM TABLE', by=['A'], aggregates={'AVG': 'B'}) == + """SELECT\n A, AVG(B) AS B\nFROM ( SELECT * FROM TABLE )\nGROUP BY A""" + ) + +def test_sql_group_by_multi_columns(): + assert ( + SQLGroupBy.apply_to('SELECT * FROM TABLE', by=['A'], aggregates={'AVG': ['B', 'C']}) == + """SELECT\n A, AVG(B) AS B, AVG(C) AS C\nFROM ( SELECT * FROM TABLE )\nGROUP BY A""" + ) + +def test_sql_limit(): + assert ( + SQLLimit.apply_to('SELECT * FROM TABLE', limit=10) == + """SELECT\n *\nFROM ( SELECT * FROM TABLE )\nLIMIT 10""" + ) + +def test_sql_columns(): + assert ( + SQLColumns.apply_to('SELECT * FROM TABLE', columns=['A', 'B']) == + """SELECT\n A, B\nFROM ( SELECT * FROM TABLE )""" + ) + +def test_sql_distinct(): + assert ( + SQLDistinct.apply_to('SELECT * FROM TABLE', columns=['A', 'B']) == + """SELECT DISTINCT\n A, B\nFROM ( SELECT * FROM TABLE )""" + ) + +def test_sql_min_max(): + assert ( + SQLMinMax.apply_to('SELECT * FROM TABLE', columns=['A', 'B']) == + """SELECT\n MIN(A) as A_min, MAX(A) as A_max, MIN(B) as B_min, MAX(B) as B_max\nFROM ( SELECT * FROM TABLE )""" + ) + +def test_sql_filter_none(): + assert ( + SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', None)]) == + """SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A IS NULL )""" + ) + +def test_sql_filter_scalar(): + assert ( + SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', 1)]) == + """SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A = 1 )""" + ) + + +def test_sql_filter_isin(): + assert ( + SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', ['A', 'B', 'C'])]) == + """SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A IN ('A', 'B', 'C') )""" + ) + +def test_sql_filter_datetime(): + assert ( + SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', dt.datetime(2017, 4, 14))]) == + """SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A = '2017-04-14 00:00:00' )""" + ) + +def test_sql_filter_date(): + assert ( + SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', dt.date(2017, 4, 14))]) == + """SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A BETWEEN '2017-04-14 00:00:00' AND '2017-04-14 23:59:59' )""" + ) + +def test_sql_filter_date_range(): + assert ( + SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', (dt.date(2017, 2, 22), dt.date(2017, 4, 14)))]) == + """SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A BETWEEN '2017-02-22 00:00:00' AND '2017-04-14 23:59:59' )""" + ) + +def test_sql_filter_datetime_range(): + assert ( + SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', (dt.datetime(2017, 2, 22), dt.datetime(2017, 4, 14)))]) == + """SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A BETWEEN '2017-02-22 00:00:00' AND '2017-04-14 00:00:00' )""" + ) diff --git a/lumen/transforms/sql.py b/lumen/transforms/sql.py index a009c5f0..a2286d09 100644 --- a/lumen/transforms/sql.py +++ b/lumen/transforms/sql.py @@ -1,4 +1,5 @@ import datetime as dt +import textwrap import numpy as np import param @@ -47,6 +48,11 @@ def apply(self, sql_in): """ return sql_in + @classmethod + def _render_template(cls, template, **params): + template = textwrap.dedent(template).lstrip() + return Template(template, trim_blocks=True, lstrip_blocks=True).render(**params) + class SQLGroupBy(SQLTransform): """ @@ -64,18 +70,17 @@ class SQLGroupBy(SQLTransform): def apply(self, sql_in): template = """ SELECT - {{by_cols}}, - {{aggs}} + {{by_cols}}, {{aggs}} FROM ( {{sql_in}} ) - GROUP BY {{by_cols}} - """ + GROUP BY {{by_cols}}""" by_cols = ', '.join(self.by) - aggs = ', '.join([ - f'{agg}({col}) AS {col}' for agg, col in self.aggregates.items() - ]) - return Template(template, trim_blocks=True, lstrip_blocks=True).render( - by_cols=by_cols, aggs=aggs, sql_in=sql_in - ) + aggs = [] + for agg, cols in self.aggregates.items(): + if isinstance(cols, str): + cols = [cols] + for col in cols: + aggs.append(f'{agg}({col}) AS {col}') + return self._render_template(template, by_cols=by_cols, aggs=', '.join(aggs), sql_in=sql_in) class SQLLimit(SQLTransform): @@ -94,9 +99,7 @@ def apply(self, sql_in): FROM ( {{sql_in}} ) LIMIT {{limit}} """ - return Template(template, trim_blocks=True, lstrip_blocks=True).render( - limit=self.limit, sql_in=sql_in - ) + return self._render_template(template, sql_in=sql_in, limit=self.limit) class SQLDistinct(SQLTransform): @@ -109,11 +112,8 @@ def apply(self, sql_in): template = """ SELECT DISTINCT {{columns}} - FROM ( {{sql_in}} ) - """ - return Template(template, trim_blocks=True, lstrip_blocks=True).render( - columns=', '.join(self.columns), sql_in=sql_in - ) + FROM ( {{sql_in}} )""" + return self._render_template(template, sql_in=sql_in, columns=', '.join(self.columns)) class SQLMinMax(SQLTransform): @@ -130,11 +130,8 @@ def apply(self, sql_in): template = """ SELECT {{columns}} - FROM ( {{sql_in}} ) - """ - return Template(template, trim_blocks=True, lstrip_blocks=True).render( - columns=', '.join(aggs), sql_in=sql_in - ) + FROM ( {{sql_in}} )""" + return self._render_template(template, sql_in=sql_in, columns=', '.join(aggs)) class SQLColumns(SQLTransform): @@ -149,9 +146,7 @@ def apply(self, sql_in): {{columns}} FROM ( {{sql_in}} ) """ - return Template(template, trim_blocks=True, lstrip_blocks=True).render( - columns=', '.join(self.columns), sql_in=sql_in - ) + return self._render_template(template, sql_in=sql_in, columns=', '.join(self.columns)) class SQLFilter(SQLTransform): @@ -172,7 +167,7 @@ def _range_filter(cls, col, v1, v2): if isinstance(v1, dt.date) and not isinstance(v1, dt.datetime): start += ' 00:00:00' if isinstance(v2, dt.date) and not isinstance(v2, dt.datetime): - end += ' 00:00:00' + end += ' 23:59:59' return f'{col} BETWEEN {start!r} AND {end!r}' def apply(self, sql_in): @@ -220,8 +215,5 @@ def apply(self, sql_in): SELECT * FROM ( {{sql_in}} ) - WHERE ( {{conditions}} ) - """ - return Template(template, trim_blocks=True, lstrip_blocks=True).render( - conditions=' AND '.join(conditions), sql_in=sql_in - ) + WHERE ( {{conditions}} )""" + return self._render_template(template, sql_in=sql_in, conditions=' AND '.join(conditions)) From ee3e1e098dbbc450d53c618f926347da98346b87 Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Mon, 14 Nov 2022 12:14:50 +0100 Subject: [PATCH 2/3] Error if Pipeline.transforms given SQLTransform --- lumen/pipeline.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lumen/pipeline.py b/lumen/pipeline.py index 82ebe40d..42168520 100644 --- a/lumen/pipeline.py +++ b/lumen/pipeline.py @@ -104,6 +104,8 @@ class Pipeline(Component): def __init__(self, *, source, table, **params): if 'schema' not in params: params['schema'] = source.get_schema(table) + if any(isinstance(t, SQLTransform) for t in params.get('transforms', [])): + raise TypeError('Pipeline.transforms must be regular Transform components, not SQLTransform.') super().__init__(source=source, table=table, **params) self._update_widget = pn.Param(self.param['update'], widgets={'update': {'button_type': 'success'}})[0] self._init_callbacks() From 110d03b54ecf36aaf5ae8cb1ba3a1db93109a4c9 Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Mon, 14 Nov 2022 12:27:50 +0100 Subject: [PATCH 3/3] Apply isort fixes --- lumen/tests/transforms/test_sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lumen/tests/transforms/test_sql.py b/lumen/tests/transforms/test_sql.py index 6d20f497..819d4723 100644 --- a/lumen/tests/transforms/test_sql.py +++ b/lumen/tests/transforms/test_sql.py @@ -1,10 +1,10 @@ import datetime as dt from lumen.transforms.sql import ( - SQLColumns, SQLDistinct, SQLFilter, SQLGroupBy, - SQLLimit, SQLMinMax + SQLColumns, SQLDistinct, SQLFilter, SQLGroupBy, SQLLimit, SQLMinMax, ) + def test_sql_group_by_single_column(): assert ( SQLGroupBy.apply_to('SELECT * FROM TABLE', by=['A'], aggregates={'AVG': 'B'}) ==