diff --git a/flytekit/extras/sqlite3/task.py b/flytekit/extras/sqlite3/task.py index 45c23da0ae..0284440da3 100644 --- a/flytekit/extras/sqlite3/task.py +++ b/flytekit/extras/sqlite3/task.py @@ -97,6 +97,9 @@ def __init__( outputs=outputs, **kwargs, ) + # Sanitize query by removing the newlines at the end of the query. Keep in mind + # that the query can be a multiline string. + self._query_template = query_template.replace("\n", " ") @property def output_columns(self) -> typing.Optional[typing.List[str]]: diff --git a/tests/flytekit/unit/extras/sqlite3/test_task.py b/tests/flytekit/unit/extras/sqlite3/test_task.py index cc0f2deee1..ef7ea491e6 100644 --- a/tests/flytekit/unit/extras/sqlite3/test_task.py +++ b/tests/flytekit/unit/extras/sqlite3/test_task.py @@ -1,4 +1,5 @@ import pandas +import pytest from flytekit import kwtypes, task, workflow from flytekit.configuration import DefaultImages @@ -108,3 +109,36 @@ def test_task_serialization(): sql_task._container_image = image tt = sql_task.serialize_to_model(sql_task.SERIALIZE_SETTINGS) assert tt.container.image == image + + +@pytest.mark.parametrize( + "query_template, expected_query", + [ + ( + """ +select * +from tracks +limit {{.inputs.limit}}""", + " select * from tracks limit {{.inputs.limit}}", + ), + ( + """ \ +select * \ +from tracks \ +limit {{.inputs.limit}}""", + " select * from tracks limit {{.inputs.limit}}", + ), + ("select * from abc", "select * from abc"), + ], +) +def test_query_sanitization(query_template, expected_query): + sql_task = SQLite3Task( + "test", + query_template=query_template, + inputs=kwtypes(limit=int), + task_config=SQLite3Config( + uri=EXAMPLE_DB, + compressed=True, + ), + ) + assert sql_task.query_template == expected_query