diff --git a/src/firebolt/async_db/_types.py b/src/firebolt/async_db/_types.py
index 7a3d17bd2a7..1881a07fe10 100644
--- a/src/firebolt/async_db/_types.py
+++ b/src/firebolt/async_db/_types.py
@@ -9,11 +9,15 @@
from sqlparse import parse as parse_sql # type: ignore
from sqlparse.sql import ( # type: ignore
+ Comment,
Comparison,
Statement,
Token,
TokenList,
)
+from sqlparse.tokens import Comparison as ComparisonType # type: ignore
+from sqlparse.tokens import Newline # type: ignore
+from sqlparse.tokens import Whitespace # type: ignore
from sqlparse.tokens import Token as TokenType # type: ignore
try:
@@ -349,11 +353,18 @@ def statement_to_set(statement: Statement) -> Optional[SetParameter]:
Return `None` if it's not a `SET` command.
"""
# Filter out meaningless tokens like Punctuation and Whitespaces
+ skip_types = [Whitespace, Newline]
tokens = [
token
for token in statement.tokens
- if token.ttype == TokenType.Keyword or isinstance(token, Comparison)
+ if token.ttype not in skip_types and not isinstance(token, Comment)
]
+ # Trim tail punctuation
+ right_idx = len(tokens) - 1
+ while str(tokens[right_idx]) == ";":
+ right_idx -= 1
+
+ tokens = tokens[: right_idx + 1]
# Check if it's a SET statement by checking if it starts with set
if (
@@ -362,13 +373,35 @@ def statement_to_set(statement: Statement) -> Optional[SetParameter]:
and tokens[0].value.lower() == "set"
):
# Check if set statement has a valid format
- if len(tokens) != 2 or not isinstance(tokens[1], Comparison):
- raise InterfaceError(
- f"Invalid set statement format: {statement_to_sql(statement)},"
- " expected SET = "
+ if len(tokens) == 2 and isinstance(tokens[1], Comparison):
+ return SetParameter(
+ statement_to_sql(tokens[1].left),
+ statement_to_sql(tokens[1].right).strip("'"),
)
- return SetParameter(
- statement_to_sql(tokens[1].left), statement_to_sql(tokens[1].right)
+ # Or if at least there is a comparison
+ cmp_idx = next(
+ (
+ i
+ for i, token in enumerate(tokens)
+ if token.ttype == ComparisonType or isinstance(token, Comparison)
+ ),
+ None,
+ )
+ if cmp_idx:
+ left_tokens, right_tokens = tokens[1:cmp_idx], tokens[cmp_idx + 1 :]
+ if isinstance(tokens[cmp_idx], Comparison):
+ left_tokens = left_tokens + [tokens[cmp_idx].left]
+ right_tokens = [tokens[cmp_idx].right] + right_tokens
+
+ if left_tokens and right_tokens:
+ return SetParameter(
+ "".join(statement_to_sql(t) for t in left_tokens),
+ "".join(statement_to_sql(t) for t in right_tokens).strip("'"),
+ )
+
+ raise InterfaceError(
+ f"Invalid set statement format: {statement_to_sql(statement)},"
+ " expected SET = "
)
return None
diff --git a/tests/unit/async_db/test_typing_format.py b/tests/unit/async_db/test_typing_format.py
index 4027b5dd1ff..a5e68e07aed 100644
--- a/tests/unit/async_db/test_typing_format.py
+++ b/tests/unit/async_db/test_typing_format.py
@@ -177,6 +177,17 @@ def test_split_format_error() -> None:
(to_statement("set a = b"), SetParameter("a", "b")),
(to_statement("set a=b"), SetParameter("a", "b")),
(to_statement("set \t\na = \t\n b ;"), SetParameter("a", "b")),
+ (to_statement("set /*comment*/a=b"), SetParameter("a", "b")),
+ (to_statement("set a='some 'string'"), SetParameter("a", "some 'string")),
+ (
+ to_statement(
+ 'set query_parameters={"name":"param1","value":"Hello, world!"}'
+ ),
+ SetParameter(
+ "query_parameters", '{"name":"param1","value":"Hello, world!"}'
+ ),
+ ),
+ (to_statement("UPDATE t SET a=50 WHERE a>b"), None),
],
)
def test_statement_to_set(statement: Statement, result: Optional[SetParameter]) -> None:
@@ -189,7 +200,6 @@ def test_statement_to_set(statement: Statement, result: Optional[SetParameter])
(to_statement("set"), InterfaceError),
(to_statement("set a"), InterfaceError),
(to_statement("set a ="), InterfaceError),
- (to_statement("set a = '"), InterfaceError),
],
)
def test_statement_to_set_errors(statement: Statement, error: Exception) -> None: