From aa919c6a5a38e3af67457245c06fd87f3d646880 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Fri, 30 Aug 2024 13:20:04 +0200 Subject: [PATCH 1/2] Fix missing EOL when formatting SQL files --- src/databricks/labs/lsql/dashboards.py | 3 ++- tests/unit/test_dashboards.py | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/databricks/labs/lsql/dashboards.py b/src/databricks/labs/lsql/dashboards.py index fb9ce2c1..ad29f8e0 100644 --- a/src/databricks/labs/lsql/dashboards.py +++ b/src/databricks/labs/lsql/dashboards.py @@ -429,6 +429,7 @@ def format(content: str, *, max_text_width: int = 120, normalize_case: bool = Tr normalize_case : bool, optional (default: True) If the query identifiers should be normalized to lower case """ + has_eol = content.endswith("\n") try: parsed_query = sqlglot.parse(content, dialect=_SQL_DIALECT) except sqlglot.ParseError: @@ -453,7 +454,7 @@ def format(content: str, *, max_text_width: int = 120, normalize_case: bool = Tr if "$" in content: # replace ${x} with $x, because we use it in UCX view definitions for now formatted_query = re.sub(r"\${(\w+)}", r"$\1", formatted_query) - return formatted_query + return formatted_query + ("\n" if has_eol else "") def _get_abstract_syntax_tree(self) -> sqlglot.Expression | None: try: diff --git a/tests/unit/test_dashboards.py b/tests/unit/test_dashboards.py index 205729f8..f6868be2 100644 --- a/tests/unit/test_dashboards.py +++ b/tests/unit/test_dashboards.py @@ -832,6 +832,11 @@ def test_query_formats(query, query_formatted): assert QueryTile.format(query) == query_formatted +def test_query_format_preserves_eol(): + assert not QueryTile.format("SELECT x, y FROM a, b").endswith("\n") + assert QueryTile.format("SELECT x, y FROM a, b\n").endswith("\n") + + def test_query_formats_no_normalize(): query = """ select a.request_params.clusterId, a.request_params.notebookId From ec9012d391ee8d007acca2432a45f2123df90a48 Mon Sep 17 00:00:00 2001 From: Eric Vergnaud Date: Fri, 30 Aug 2024 17:30:41 +0200 Subject: [PATCH 2/2] fix failing tests --- tests/unit/test_dashboards.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_dashboards.py b/tests/unit/test_dashboards.py index f6868be2..1b825fa4 100644 --- a/tests/unit/test_dashboards.py +++ b/tests/unit/test_dashboards.py @@ -829,7 +829,7 @@ def test_query_tile_keeps_original_query(tmp_path): ], ) def test_query_formats(query, query_formatted): - assert QueryTile.format(query) == query_formatted + assert QueryTile.format(query).strip() == query_formatted.strip() def test_query_format_preserves_eol():