Skip to content
37 changes: 30 additions & 7 deletions src/firebolt/async_db/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ def parse_datetime(datetime_string: str) -> datetime:
from firebolt.utils.util import cached_property

_NoneType = type(None)
_col_types = (int, float, str, datetime, date, bool, list, Decimal, _NoneType)
_col_types = (int, float, str, datetime, date, bool, list, Decimal, _NoneType, bytes)
# duplicating this since 3.7 can't unpack Union
ColType = Union[int, float, str, datetime, date, bool, list, Decimal, _NoneType]
ColType = Union[int, float, str, datetime, date, bool, list, Decimal, _NoneType, bytes]
RawColType = Union[int, float, str, bool, list, _NoneType]
ParameterType = Union[int, float, str, datetime, date, bool, Decimal, Sequence]
ParameterType = Union[int, float, str, datetime, date, bool, Decimal, Sequence, bytes]

# These definitions are required by PEP-249
Date = date
Expand All @@ -78,12 +78,13 @@ def TimeFromTicks(t: int) -> None:
TimestampFromTicks = datetime.fromtimestamp


def Binary(value: str) -> str:
"""Return the string unchanged: no binary conversion is done for Firebolt DB."""
return value
def Binary(value: Union[str, bytes, bytearray]) -> bytes:
    """Return *value* as a bytes object, encoding text as UTF-8.

    Args:
        value: Text to encode, or data that is already binary.

    Returns:
        The UTF-8 encoding of a ``str``. ``bytes`` and ``bytearray`` input
        is returned as ``bytes`` with its content unchanged.
    """
    if isinstance(value, (bytes, bytearray)):
        # Already binary: there is nothing to encode, only normalize the type.
        return bytes(value)
    return value.encode("utf-8")


STRING = BINARY = str
STRING = str
BINARY = bytes
NUMBER = int
DATETIME = datetime
ROWID = int
Expand Down Expand Up @@ -169,6 +170,8 @@ class _InternalType(Enum):

Boolean = "boolean"

Bytea = "bytea"

Nothing = "Nothing"

@cached_property
Expand All @@ -188,6 +191,7 @@ def python_type(self) -> type:
_InternalType.TimestampNtz: datetime,
_InternalType.TimestampTz: datetime,
_InternalType.Boolean: bool,
_InternalType.Bytea: bytes,
# For simplicity, this could happen only during 'select null' query
_InternalType.Nothing: str,
}
Expand Down Expand Up @@ -221,6 +225,18 @@ def parse_type(raw_type: str) -> Union[type, ARRAY, DECIMAL]: # noqa: C901
return str


BYTEA_PREFIX = "\\x"


def _parse_bytea(str_value: str) -> bytes:
    """Decode a hex-encoded bytea literal into raw bytes.

    Args:
        str_value: Text that must start with ``BYTEA_PREFIX`` followed by
            pairs of hexadecimal digits.

    Returns:
        The bytes described by the hex digits that follow the prefix.

    Raises:
        ValueError: If the prefix is missing, or if ``bytes.fromhex``
            rejects the text after the prefix.
    """
    # startswith also covers inputs that are shorter than the prefix.
    if not str_value.startswith(BYTEA_PREFIX):
        raise ValueError(f"Invalid bytea value format: {BYTEA_PREFIX} prefix expected")
    return bytes.fromhex(str_value[len(BYTEA_PREFIX) :])


def parse_value(
value: RawColType,
ctype: Union[type, ARRAY, DECIMAL],
Expand All @@ -244,6 +260,10 @@ def parse_value(
if not isinstance(value, (bool, int)):
raise DataError(f"Invalid boolean value {value}: bool or int expected")
return bool(value)
if ctype is bytes:
if not isinstance(value, str):
raise DataError(f"Invalid bytea value {value}: str expected")
return _parse_bytea(value)
if isinstance(ctype, DECIMAL):
assert isinstance(value, (str, int))
return Decimal(value)
Expand Down Expand Up @@ -274,6 +294,9 @@ def format_value(value: ParameterType) -> str:
return f"'{value.strftime('%Y-%m-%d %H:%M:%S')}'"
elif isinstance(value, date):
return f"'{value.isoformat()}'"
elif isinstance(value, bytes):
# Encode each byte into hex
return "'" + "".join(f"\\x{b:02x}" for b in value) + "'"
if value is None:
return "NULL"
elif isinstance(value, Sequence):
Expand Down
32 changes: 31 additions & 1 deletion tests/integration/dbapi/async/test_queries_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@

from pytest import mark, raises

from firebolt.async_db import Connection, Cursor, DataError, OperationalError
from firebolt.async_db import (
Binary,
Connection,
Cursor,
DataError,
OperationalError,
)
from firebolt.async_db._types import ColType, Column
from firebolt.async_db.cursor import QueryStatus

Expand Down Expand Up @@ -487,3 +493,27 @@ async def test_server_side_async_execution_get_status(
# assert (
# type(status) is QueryStatus,
# ), "get_status() did not return a QueryStatus object."


async def test_bytea_roundtrip(
    connection: Connection,
) -> None:
    """A bytea value written to a table is read back without corruption."""
    original_text = "bytea_123\n\tヽ༼ຈل͜ຈ༽ノ"

    with connection.cursor() as cursor:
        # Start from a clean table so earlier runs cannot affect the result.
        await cursor.execute("DROP TABLE IF EXISTS test_bytea_roundtrip")
        await cursor.execute(
            "CREATE FACT TABLE test_bytea_roundtrip(id int, b bytea) primary index id"
        )

        encoded = Binary(original_text)
        await cursor.execute(
            "INSERT INTO test_bytea_roundtrip VALUES (1, ?)", (encoded,)
        )
        await cursor.execute("SELECT b FROM test_bytea_roundtrip")

        row = await cursor.fetchone()
        fetched = row[0]

        assert (
            fetched.decode("utf-8") == original_text
        ), "Invalid bytea data returned after roundtrip"
5 changes: 4 additions & 1 deletion tests/integration/dbapi/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def all_types_query() -> str:
'true as "boolean", '
"[1,2,3,4] as \"array\", cast('1231232.123459999990457054844258706536' as "
'decimal(38,30)) as "decimal", '
'cast(null as int) as "nullable"'
'cast(null as int) as "nullable", '
"'abc123'::bytea as \"bytea\""
)


Expand Down Expand Up @@ -104,6 +105,7 @@ def all_types_query_description() -> List[Column]:
Column("array", ARRAY(int), None, None, None, None, None),
Column("decimal", DECIMAL(38, 30), None, None, None, None, None),
Column("nullable", int, None, None, None, None, None),
Column("bytea", bytes, None, None, None, None, None),
]


Expand Down Expand Up @@ -142,6 +144,7 @@ def all_types_query_response(timezone_offset_seconds: int) -> List[ColType]:
[1, 2, 3, 4],
Decimal("1231232.123459999990457054844258706536"),
None,
b"abc123",
]
]

Expand Down
23 changes: 23 additions & 0 deletions tests/integration/dbapi/sync/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from firebolt.async_db.cursor import QueryStatus
from firebolt.client.auth import Auth
from firebolt.db import (
Binary,
Connection,
Cursor,
DataError,
Expand Down Expand Up @@ -519,3 +520,25 @@ def run_query():

connection.close()
assert not exceptions


def test_bytea_roundtrip(
    connection: Connection,
) -> None:
    """A bytea value written to a table is read back without corruption."""
    original_text = "bytea_123\n\tヽ༼ຈل͜ຈ༽ノ"

    with connection.cursor() as cursor:
        # Start from a clean table so earlier runs cannot affect the result.
        cursor.execute("DROP TABLE IF EXISTS test_bytea_roundtrip")
        cursor.execute(
            "CREATE FACT TABLE test_bytea_roundtrip(id int, b bytea) primary index id"
        )

        encoded = Binary(original_text)
        cursor.execute("INSERT INTO test_bytea_roundtrip VALUES (1, ?)", (encoded,))
        cursor.execute("SELECT b FROM test_bytea_roundtrip")

        row = cursor.fetchone()
        fetched = row[0]

        assert (
            fetched.decode("utf-8") == original_text
        ), "Invalid bytea data returned after roundtrip"
1 change: 1 addition & 0 deletions tests/unit/async_db/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def types_map() -> Dict[str, type]:
"Decimal(38)": str,
"boolean": bool,
"SomeRandomNotExistingType": str,
"bytea": bytes,
}
array_types = {f"array({k})": ARRAY(v) for k, v in base_types.items()}
nullable_arrays = {f"{k} null": v for k, v in array_types.items()}
Expand Down
13 changes: 12 additions & 1 deletion tests/unit/async_db/test_typing_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from sqlparse import parse
from sqlparse.sql import Statement

from firebolt.async_db import DataError, InterfaceError, NotSupportedError
from firebolt.async_db import (
Binary,
DataError,
InterfaceError,
NotSupportedError,
)
from firebolt.async_db._types import (
SetParameter,
format_statement,
Expand Down Expand Up @@ -44,6 +49,8 @@
(("a", "b", "c"), "['a', 'b', 'c']"),
# None
(None, "NULL"),
# Bytea
(b"abc", "'\\x61\\x62\\x63'"),
],
)
def test_format_value(value: str, result: str) -> None:
Expand Down Expand Up @@ -188,3 +195,7 @@ def test_statement_to_set(statement: Statement, result: Optional[SetParameter])
def test_statement_to_set_errors(statement: Statement, error: Exception) -> None:
with raises(error):
statement_to_set(statement)


def test_binary() -> None:
    """Binary turns an ASCII string into the matching bytes object."""
    encoded = Binary("abc")
    assert encoded == b"abc"
19 changes: 19 additions & 0 deletions tests/unit/async_db/test_typing_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,22 @@ def test_parse_value_bool() -> None:

with raises(DataError):
parse_value("true", bool)


def test_parse_value_bytes() -> None:
    """parse_value parses all bytes values correctly."""
    assert (
        parse_value("\\x616263", bytes) == b"abc"
    ), "Error parsing bytes: provided str"
    assert parse_value(None, bytes) is None, "Error parsing bytes: provided None"

    # Odd number of hex digits after the prefix
    with raises(ValueError):
        parse_value("\\xabc", bytes)

    # Missing prefix
    with raises(ValueError):
        parse_value("616263", bytes)

    # Raw values that are not str are rejected with DataError
    for val in (1, True, Exception()):
        with raises(DataError):
            parse_value(val, bytes)
4 changes: 4 additions & 0 deletions tests/unit/db_conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def query_description() -> List[Column]:
Column("bool", "boolean", None, None, None, None, None),
Column("array", "array(int)", None, None, None, None, None),
Column("decimal", "Decimal(12, 34)", None, None, None, None, None),
Column("bytea", "bytea", None, None, None, None, None),
]


Expand All @@ -54,6 +55,7 @@ def python_query_description() -> List[Column]:
Column("bool", bool, None, None, None, None, None),
Column("array", ARRAY(int), None, None, None, None, None),
Column("decimal", DECIMAL(12, 34), None, None, None, None, None),
Column("bytea", bytes, None, None, None, None, None),
]


Expand All @@ -77,6 +79,7 @@ def query_data() -> List[List[ColType]]:
1,
[1, 2, 3, 4],
"123456789.123456789123456789123456789",
"\\x616263",
]
for i in range(QUERY_ROW_COUNT)
]
Expand All @@ -102,6 +105,7 @@ def python_query_data() -> List[List[ColType]]:
1,
[1, 2, 3, 4],
Decimal("123456789.123456789123456789123456789"),
b"abc",
]
for i in range(QUERY_ROW_COUNT)
]
Expand Down