Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 40 additions & 7 deletions src/firebolt/async_db/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@

from sqlparse import parse as parse_sql # type: ignore
from sqlparse.sql import ( # type: ignore
Comment,
Comparison,
Statement,
Token,
TokenList,
)
from sqlparse.tokens import Comparison as ComparisonType # type: ignore
from sqlparse.tokens import Newline # type: ignore
from sqlparse.tokens import Whitespace # type: ignore
from sqlparse.tokens import Token as TokenType # type: ignore

try:
Expand Down Expand Up @@ -349,11 +353,18 @@ def statement_to_set(statement: Statement) -> Optional[SetParameter]:
Return `None` if it's not a `SET` command.
"""
# Filter out meaningless tokens like Punctuation and Whitespaces
skip_types = [Whitespace, Newline]
tokens = [
token
for token in statement.tokens
if token.ttype == TokenType.Keyword or isinstance(token, Comparison)
if token.ttype not in skip_types and not isinstance(token, Comment)
]
# Trim tail punctuation
right_idx = len(tokens) - 1
while str(tokens[right_idx]) == ";":
right_idx -= 1

tokens = tokens[: right_idx + 1]

# Check if it's a SET statement by checking if it starts with set
if (
Expand All @@ -362,13 +373,35 @@ def statement_to_set(statement: Statement) -> Optional[SetParameter]:
and tokens[0].value.lower() == "set"
):
# Check if set statement has a valid format
if len(tokens) != 2 or not isinstance(tokens[1], Comparison):
raise InterfaceError(
f"Invalid set statement format: {statement_to_sql(statement)},"
" expected SET <param> = <value>"
if len(tokens) == 2 and isinstance(tokens[1], Comparison):
return SetParameter(
statement_to_sql(tokens[1].left),
statement_to_sql(tokens[1].right).strip("'"),
)
return SetParameter(
statement_to_sql(tokens[1].left), statement_to_sql(tokens[1].right)
# Or if at least there is a comparison
cmp_idx = next(
(
i
for i, token in enumerate(tokens)
if token.ttype == ComparisonType or isinstance(token, Comparison)
),
None,
)
if cmp_idx:
left_tokens, right_tokens = tokens[1:cmp_idx], tokens[cmp_idx + 1 :]
if isinstance(tokens[cmp_idx], Comparison):
left_tokens = left_tokens + [tokens[cmp_idx].left]
right_tokens = [tokens[cmp_idx].right] + right_tokens

if left_tokens and right_tokens:
return SetParameter(
"".join(statement_to_sql(t) for t in left_tokens),
"".join(statement_to_sql(t) for t in right_tokens).strip("'"),
)

raise InterfaceError(
f"Invalid set statement format: {statement_to_sql(statement)},"
" expected SET <param> = <value>"
)
return None

Expand Down
12 changes: 11 additions & 1 deletion tests/unit/async_db/test_typing_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,17 @@ def test_split_format_error() -> None:
(to_statement("set a = b"), SetParameter("a", "b")),
(to_statement("set a=b"), SetParameter("a", "b")),
(to_statement("set \t\na = \t\n b ;"), SetParameter("a", "b")),
(to_statement("set /*comment*/a=b"), SetParameter("a", "b")),
(to_statement("set a='some 'string'"), SetParameter("a", "some 'string")),
(
to_statement(
'set query_parameters={"name":"param1","value":"Hello, world!"}'
),
SetParameter(
"query_parameters", '{"name":"param1","value":"Hello, world!"}'
),
),
(to_statement("UPDATE t SET a=50 WHERE a>b"), None),
],
)
def test_statement_to_set(statement: Statement, result: Optional[SetParameter]) -> None:
Expand All @@ -189,7 +200,6 @@ def test_statement_to_set(statement: Statement, result: Optional[SetParameter])
(to_statement("set"), InterfaceError),
(to_statement("set a"), InterfaceError),
(to_statement("set a ="), InterfaceError),
(to_statement("set a = '"), InterfaceError),
],
)
def test_statement_to_set_errors(statement: Statement, error: Exception) -> None:
Expand Down