Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions src/firebolt/common/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from decimal import Decimal
from enum import Enum
from io import StringIO
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from sqlparse import parse as parse_sql # type: ignore
from sqlparse.sql import ( # type: ignore
Expand Down Expand Up @@ -234,7 +234,7 @@ def python_type(self) -> type:
def split_struct_fields(raw_struct: str) -> List[str]:
"""Split raw struct inner fields string into a list of field definitions.
>>> split_struct_fields("field1 int, field2 struct(field1 int, field2 text)")
["field1 int", "field2 struct(field1 int, field2 text)"]
['field1 int', 'field2 struct(field1 int, field2 text)']
"""
balance = 0 # keep track of the level of nesting, and only split on level 0
separator = ","
Expand All @@ -246,15 +246,38 @@ def split_struct_fields(raw_struct: str) -> List[str]:
elif ch == ")":
balance -= 1
elif ch == separator and balance == 0:
res.append(current.getvalue())
res.append(current.getvalue().strip())
current = StringIO()
continue
current.write(ch)

res.append(current.getvalue())
res.append(current.getvalue().strip())
return res


def split_struct_field(raw_field: str) -> Tuple[str, str]:
"""Split raw struct field into name and type.
>>> split_struct_field("field int")
('field', 'int')
>>> split_struct_field("`with space` text null")
('with space', 'text null')
>>> split_struct_field("s struct(`a b` int)")
('s', 'struct(`a b` int)')
"""
raw_field = raw_field.strip()
second_tick = (
raw_field.find("`", raw_field.find("`") + 1)
if raw_field.startswith("`")
else -1
)
name, type_ = (
(raw_field[: second_tick + 1], raw_field[second_tick + 1 :])
if second_tick != -1
else raw_field.split(" ", 1)
)
return name.strip(" `"), type_.strip()


def parse_type(raw_type: str) -> Union[type, ExtendedType]: # noqa: C901
"""Parse typename provided by query metadata into Python type."""
if not isinstance(raw_type, str):
Expand All @@ -276,7 +299,7 @@ def parse_type(raw_type: str) -> Union[type, ExtendedType]: # noqa: C901
fields_raw = split_struct_fields(raw_type[len(STRUCT._prefix) : -1])
fields = {}
for f in fields_raw:
name, type_ = f.strip().split(" ", 1)
name, type_ = split_struct_field(f)
fields[name.strip()] = parse_type(type_.strip())
return STRUCT(fields)
except ValueError:
Expand Down
13 changes: 10 additions & 3 deletions tests/unit/common/test_typing_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ def test_parse_type(types_map: Dict[str, type]) -> None:
), "Invalid type parsing error message"


def test_parse_struct_type_with_spaces() -> None:
parsed = parse_type("struct(`a b` int, s struct(`c d` text))")
assert parsed == STRUCT(
{"a b": int, "s": STRUCT({"c d": str})}
), f"Error parsing struct type with spaces"


@mark.parametrize(
"value,expected,error",
[
Expand Down Expand Up @@ -363,9 +370,9 @@ def test_parse_value_struct(value, expected, type_, error) -> None:
@mark.parametrize(
"value,expected",
[
("a int, b text", ["a int", " b text"]),
("a int, s struct(a int, b text)", ["a int", " s struct(a int, b text)"]),
("a int, b array(struct(a int))", ["a int", " b array(struct(a int))"]),
("a int, b text", ["a int", "b text"]),
("a int, s struct(a int, b text)", ["a int", "s struct(a int, b text)"]),
("a int, b array(struct(a int))", ["a int", "b array(struct(a int))"]),
],
)
def test_split_struct_fields(value, expected) -> None:
Expand Down
Loading