diff --git a/flytekit/core/base_sql_task.py b/flytekit/core/base_sql_task.py index d2e4838ed8..7fcdc15a50 100644 --- a/flytekit/core/base_sql_task.py +++ b/flytekit/core/base_sql_task.py @@ -41,7 +41,7 @@ def __init__( task_config=task_config, **kwargs, ) - self._query_template = query_template.replace("\n", "\\n").replace("\t", "\\t") + self._query_template = re.sub(r"\s+", " ", query_template.replace("\n", " ").replace("\t", " ")).strip() @property def query_template(self) -> str: diff --git a/flytekit/extras/sqlite3/task.py b/flytekit/extras/sqlite3/task.py index 0284440da3..8e7d8b3b29 100644 --- a/flytekit/extras/sqlite3/task.py +++ b/flytekit/extras/sqlite3/task.py @@ -92,14 +92,13 @@ def __init__( container_image=container_image or DefaultImages.default_image(), executor_type=SQLite3TaskExecutor, task_type=self._SQLITE_TASK_TYPE, + # The query is sanitized by the parent initializer, which collapses whitespace runs + # (including newlines) into single spaces; the query can be a multiline string. query_template=query_template, inputs=inputs, outputs=outputs, **kwargs, ) - # Sanitize query by removing the newlines at the end of the query. Keep in mind - # that the query can be a multiline string. 
- self._query_template = query_template.replace("\n", " ") @property def output_columns(self) -> typing.Optional[typing.List[str]]: diff --git a/plugins/flytekit-snowflake/tests/test_snowflake.py b/plugins/flytekit-snowflake/tests/test_snowflake.py index a012e38d99..672f4a19ad 100644 --- a/plugins/flytekit-snowflake/tests/test_snowflake.py +++ b/plugins/flytekit-snowflake/tests/test_snowflake.py @@ -70,7 +70,7 @@ def test_local_exec(): ) assert len(snowflake_task.interface.inputs) == 1 - assert snowflake_task.query_template == "select 1\\n" + assert snowflake_task.query_template == "select 1" assert len(snowflake_task.interface.outputs) == 1 # will not run locally @@ -86,4 +86,4 @@ def test_sql_template(): custom where column = 1""", output_schema_type=FlyteSchema, ) - assert snowflake_task.query_template == "select 1 from\\t\\n custom where column = 1" + assert snowflake_task.query_template == "select 1 from custom where column = 1" diff --git a/plugins/flytekit-sqlalchemy/tests/test_task.py b/plugins/flytekit-sqlalchemy/tests/test_task.py index 6d20027b2a..7537a3a1de 100644 --- a/plugins/flytekit-sqlalchemy/tests/test_task.py +++ b/plugins/flytekit-sqlalchemy/tests/test_task.py @@ -70,7 +70,23 @@ def test_task_schema(sql_server): assert df is not None -def test_workflow(sql_server): +@pytest.mark.parametrize( + "query_template", + [ + "select * from tracks limit {{.inputs.limit}}", + """ + select * from tracks + limit {{.inputs.limit}} + """, + """select * from tracks + limit {{.inputs.limit}} + """, + """ + select * from tracks + limit {{.inputs.limit}}""", + ], +) +def test_workflow(sql_server, query_template): @task def my_task(df: pandas.DataFrame) -> int: return len(df[df.columns[0]]) @@ -84,7 +100,7 @@ def my_task(df: pandas.DataFrame) -> int: sql_task = SQLAlchemyTask( "test", - query_template="select * from tracks limit {{.inputs.limit}}", + query_template=query_template, inputs=kwtypes(limit=int), task_config=SQLAlchemyConfig(uri=sql_server), ) diff 
--git a/tests/flytekit/unit/extras/sqlite3/test_task.py b/tests/flytekit/unit/extras/sqlite3/test_task.py index ef7ea491e6..40fc94a3d2 100644 --- a/tests/flytekit/unit/extras/sqlite3/test_task.py +++ b/tests/flytekit/unit/extras/sqlite3/test_task.py @@ -119,14 +119,14 @@ def test_task_serialization(): select * from tracks limit {{.inputs.limit}}""", - " select * from tracks limit {{.inputs.limit}}", + "select * from tracks limit {{.inputs.limit}}", ), ( """ \ select * \ from tracks \ limit {{.inputs.limit}}""", - " select * from tracks limit {{.inputs.limit}}", + "select * from tracks limit {{.inputs.limit}}", ), ("select * from abc", "select * from abc"), ],