diff --git a/labs.yml b/labs.yml index 897f6879..d83c5e9d 100644 --- a/labs.yml +++ b/labs.yml @@ -2,8 +2,16 @@ name: lsql description: Lightweight SQL execution wrapper only on top of Databricks SDK entrypoint: src/databricks/labs/lsql/cli.py +install: + script: src/databricks/labs/lsql/__about__.py min_python: 3.10 commands: + - name: fmt + is_unauthenticated: true + description: Format SQL files in the given folder + flags: + - name: folder + description: The folder with SQL files. By default, the current working directory. - name: create-dashboard description: Create an unpublished dashboard from code, see [docs](./docs/dashboards.md). flags: diff --git a/src/databricks/labs/lsql/cli.py b/src/databricks/labs/lsql/cli.py index acc6d64d..70377a6f 100644 --- a/src/databricks/labs/lsql/cli.py +++ b/src/databricks/labs/lsql/cli.py @@ -5,7 +5,7 @@ from databricks.labs.blueprint.entrypoint import get_logger from databricks.sdk import WorkspaceClient -from databricks.labs.lsql.dashboards import DashboardMetadata, Dashboards +from databricks.labs.lsql.dashboards import DashboardMetadata, Dashboards, QueryTile logger = get_logger(__file__) lsql = App(__file__) @@ -37,5 +37,17 @@ def create_dashboard( print(sdk_dashboard.dashboard_id) +@lsql.command(is_unauthenticated=True) +def fmt(folder: Path = Path.cwd()): + """Format SQL files in a folder""" + logger.debug("Formatting SQL files ...") + folder = Path(folder) + for sql_file in folder.glob("**/*.sql"): + sql = sql_file.read_text() + formatted_sql = QueryTile.format(sql) + sql_file.write_text(formatted_sql) + logger.debug(f"Formatted {sql_file}") + + if __name__ == "__main__": lsql() diff --git a/src/databricks/labs/lsql/dashboards.py b/src/databricks/labs/lsql/dashboards.py index 8796e048..c82af3e6 100644 --- a/src/databricks/labs/lsql/dashboards.py +++ b/src/databricks/labs/lsql/dashboards.py @@ -376,6 +376,34 @@ class QueryTile(Tile): _DIALECT = sqlglot.dialects.Databricks _FILTER_HEIGHT = 1 + @staticmethod + def format(query: str, max_text_width: int = 120) -> str: + try: + parsed_query = sqlglot.parse(query, dialect="databricks") + except sqlglot.ParseError: + return query + statements = [] + for statement in parsed_query: + if statement is None: + continue + # TODO: CASE .. WHEN .. THEN .. formatting is a bit less readable after reformatting. + # See https://github.com/tobymao/sqlglot/issues/3770 + # see https://sqlglot.com/sqlglot/generator.html#Generator + statements.append( + statement.sql( + dialect="databricks", + normalize=True, # normalize identifiers to lowercase + pretty=True, # format the produced SQL string + normalize_functions="upper", # normalize function names to uppercase + max_text_width=max_text_width, # wrap text at 120 characters + ) + ) + formatted_query = ";\n".join(statements) + if "$" in query: + # replace ${x} with $x, because we use it in UCX view definitions for now + formatted_query = re.sub(r"\${(\w+)}", r"$\1", formatted_query) + return formatted_query + def _get_abstract_syntax_tree(self) -> sqlglot.Expression | None: try: return sqlglot.parse_one(self.content, dialect=self._DIALECT) @@ -822,7 +850,7 @@ def save_to_folder(self, dashboard: Dashboard, local_path: Path) -> Dashboard: local_path.mkdir(parents=True, exist_ok=True) dashboard = self._with_better_names(dashboard) for dataset in dashboard.datasets: - query = self._format_query(dataset.query) + query = QueryTile.format(dataset.query) with (local_path / f"{dataset.name}.sql").open("w") as f: f.write(query) for page in dashboard.pages: @@ -830,29 +858,6 @@ def save_to_folder(self, dashboard: Dashboard, local_path: Path) -> Dashboard: yaml.safe_dump(page.as_dict(), f) return dashboard - @staticmethod - def _format_query(query: str) -> str: - try: - parsed_query = sqlglot.parse(query) - except sqlglot.ParseError: - return query - statements = [] - for statement in parsed_query: - if statement is None: - continue - # see https://sqlglot.com/sqlglot/generator.html#Generator - statements.append( - statement.sql( - dialect="databricks", - normalize=True, # normalize identifiers to lowercase - pretty=True, # format the produced SQL string - normalize_functions="upper", # normalize function names to uppercase - max_text_width=80, # wrap text at 80 characters - ) - ) - formatted_query = ";\n".join(statements) - return formatted_query - def deploy_dashboard( self, lakeview_dashboard: Dashboard, diff --git a/tests/integration/dashboards/one_counter/010_counter.sql b/tests/integration/dashboards/one_counter/010_counter.sql index b2ab78e2..df9bd5e5 100644 --- a/tests/integration/dashboards/one_counter/010_counter.sql +++ b/tests/integration/dashboards/one_counter/010_counter.sql @@ -1 +1,2 @@ -SELECT 6217 AS count \ No newline at end of file +SELECT + 6217 AS count \ No newline at end of file diff --git a/tests/integration/dashboards/one_table/databricks_office_locations.sql b/tests/integration/dashboards/one_table/databricks_office_locations.sql index fbe8a2e4..91c72d8b 100644 --- a/tests/integration/dashboards/one_table/databricks_office_locations.sql +++ b/tests/integration/dashboards/one_table/databricks_office_locations.sql @@ -1,13 +1,12 @@ SELECT - Address, - City, - State, + address, + city, + state, `Zip Code`, - Country -FROM -VALUES + country +FROM VALUES ('160 Spear St 15th Floor', 'San Francisco', 'CA', '94105', 'USA'), ('756 W Peachtree St NW, Suite 03W114', 'Atlanta', 'GA', '30308', 'USA'), ('500 108th Ave NE, Suite 1820', 'Bellevue', 'WA', '98004', 'USA'), ('125 High St, Suite 220', 'Boston', 'MA', '02110', 'USA'), - ('2120 University Ave, Suite 722', 'Berkeley', 'CA', '94704', 'USA') AS tab(Address, City, State, `Zip Code`, Country) + ('2120 University Ave, Suite 722', 'Berkeley', 'CA', '94704', 'USA') AS tab(address, city, state, `Zip Code`, country) \ No newline at end of file diff --git a/tests/integration/views/some.sql b/tests/integration/views/some.sql index 6af77981..9f5ce717 100644 --- a/tests/integration/views/some.sql +++ b/tests/integration/views/some.sql @@ -1 +1,4 @@ -SELECT first AS name, 1 AS id FROM $inventory.foo \ No newline at end of file +SELECT + first AS name, + 1 AS id +FROM $inventory.foo \ No newline at end of file diff --git a/tests/unit/queries/counter.sql b/tests/unit/queries/counter.sql index b2ab78e2..df9bd5e5 100644 --- a/tests/unit/queries/counter.sql +++ b/tests/unit/queries/counter.sql @@ -1 +1,2 @@ -SELECT 6217 AS count \ No newline at end of file +SELECT + 6217 AS count \ No newline at end of file diff --git a/tests/unit/test_deployment.py b/tests/unit/test_deployment.py index cd6fcdbb..1e407a6f 100644 --- a/tests/unit/test_deployment.py +++ b/tests/unit/test_deployment.py @@ -17,7 +17,7 @@ def test_deploys_view(): deployment.deploy_view("some", "some.sql") assert mock_backend.queries == [ - "CREATE OR REPLACE VIEW hive_metastore.inventory.some AS SELECT id, name FROM hive_metastore.inventory.something" + "CREATE OR REPLACE VIEW hive_metastore.inventory.some AS SELECT\n id,\n name\nFROM hive_metastore.inventory.something" ] diff --git a/tests/unit/views/some.sql b/tests/unit/views/some.sql index 166a1660..c76c45a0 100644 --- a/tests/unit/views/some.sql +++ b/tests/unit/views/some.sql @@ -1 +1,4 @@ -SELECT id, name FROM $inventory.something \ No newline at end of file +SELECT + id, + name +FROM $inventory.something \ No newline at end of file