diff --git a/src/firebolt/async_db/_types.py b/src/firebolt/async_db/_types.py index 7a3d17bd2a7..1881a07fe10 100644 --- a/src/firebolt/async_db/_types.py +++ b/src/firebolt/async_db/_types.py @@ -9,11 +9,15 @@ from sqlparse import parse as parse_sql # type: ignore from sqlparse.sql import ( # type: ignore + Comment, Comparison, Statement, Token, TokenList, ) +from sqlparse.tokens import Comparison as ComparisonType # type: ignore +from sqlparse.tokens import Newline # type: ignore +from sqlparse.tokens import Whitespace # type: ignore from sqlparse.tokens import Token as TokenType # type: ignore try: @@ -349,11 +353,18 @@ def statement_to_set(statement: Statement) -> Optional[SetParameter]: Return `None` if it's not a `SET` command. """ # Filter out meaningless tokens like Punctuation and Whitespaces + skip_types = [Whitespace, Newline] tokens = [ token for token in statement.tokens - if token.ttype == TokenType.Keyword or isinstance(token, Comparison) + if token.ttype not in skip_types and not isinstance(token, Comment) ] + # Trim tail punctuation + right_idx = len(tokens) - 1 + while str(tokens[right_idx]) == ";": + right_idx -= 1 + + tokens = tokens[: right_idx + 1] # Check if it's a SET statement by checking if it starts with set if ( @@ -362,13 +373,35 @@ def statement_to_set(statement: Statement) -> Optional[SetParameter]: and tokens[0].value.lower() == "set" ): # Check if set statement has a valid format - if len(tokens) != 2 or not isinstance(tokens[1], Comparison): - raise InterfaceError( - f"Invalid set statement format: {statement_to_sql(statement)}," - " expected SET = " + if len(tokens) == 2 and isinstance(tokens[1], Comparison): + return SetParameter( + statement_to_sql(tokens[1].left), + statement_to_sql(tokens[1].right).strip("'"), ) - return SetParameter( - statement_to_sql(tokens[1].left), statement_to_sql(tokens[1].right) + # Or if at least there is a comparison + cmp_idx = next( + ( + i + for i, token in enumerate(tokens) + if token.ttype == ComparisonType or isinstance(token, Comparison) + ), + None, + ) + if cmp_idx: + left_tokens, right_tokens = tokens[1:cmp_idx], tokens[cmp_idx + 1 :] + if isinstance(tokens[cmp_idx], Comparison): + left_tokens = left_tokens + [tokens[cmp_idx].left] + right_tokens = [tokens[cmp_idx].right] + right_tokens + + if left_tokens and right_tokens: + return SetParameter( + "".join(statement_to_sql(t) for t in left_tokens), + "".join(statement_to_sql(t) for t in right_tokens).strip("'"), + ) + + raise InterfaceError( + f"Invalid set statement format: {statement_to_sql(statement)}," + " expected SET = " ) return None diff --git a/tests/unit/async_db/test_typing_format.py b/tests/unit/async_db/test_typing_format.py index 4027b5dd1ff..a5e68e07aed 100644 --- a/tests/unit/async_db/test_typing_format.py +++ b/tests/unit/async_db/test_typing_format.py @@ -177,6 +177,17 @@ def test_split_format_error() -> None: (to_statement("set a = b"), SetParameter("a", "b")), (to_statement("set a=b"), SetParameter("a", "b")), (to_statement("set \t\na = \t\n b ;"), SetParameter("a", "b")), + (to_statement("set /*comment*/a=b"), SetParameter("a", "b")), + (to_statement("set a='some 'string'"), SetParameter("a", "some 'string")), + ( + to_statement( + 'set query_parameters={"name":"param1","value":"Hello, world!"}' + ), + SetParameter( + "query_parameters", '{"name":"param1","value":"Hello, world!"}' + ), + ), + (to_statement("UPDATE t SET a=50 WHERE a>b"), None), ], ) def test_statement_to_set(statement: Statement, result: Optional[SetParameter]) -> None: @@ -189,7 +200,6 @@ def test_statement_to_set(statement: Statement, result: Optional[SetParameter]) (to_statement("set"), InterfaceError), (to_statement("set a"), InterfaceError), (to_statement("set a ="), InterfaceError), - (to_statement("set a = '"), InterfaceError), ], ) def test_statement_to_set_errors(statement: Statement, error: Exception) -> None: