diff --git a/src/firebolt/async_db/_types.py b/src/firebolt/async_db/_types.py index 73f2c109ede..7a3d17bd2a7 100644 --- a/src/firebolt/async_db/_types.py +++ b/src/firebolt/async_db/_types.py @@ -49,11 +49,11 @@ def parse_datetime(datetime_string: str) -> datetime: from firebolt.utils.util import cached_property _NoneType = type(None) -_col_types = (int, float, str, datetime, date, bool, list, Decimal, _NoneType) +_col_types = (int, float, str, datetime, date, bool, list, Decimal, _NoneType, bytes) # duplicating this since 3.7 can't unpack Union -ColType = Union[int, float, str, datetime, date, bool, list, Decimal, _NoneType] +ColType = Union[int, float, str, datetime, date, bool, list, Decimal, _NoneType, bytes] RawColType = Union[int, float, str, bool, list, _NoneType] -ParameterType = Union[int, float, str, datetime, date, bool, Decimal, Sequence] +ParameterType = Union[int, float, str, datetime, date, bool, Decimal, Sequence, bytes] # These definitions are required by PEP-249 Date = date @@ -78,12 +78,13 @@ def TimeFromTicks(t: int) -> None: TimestampFromTicks = datetime.fromtimestamp -def Binary(value: str) -> str: - """Convert string to binary for Firebolt DB does nothing.""" - return value +def Binary(value: str) -> bytes: + """Encode a string into UTF-8.""" + return value.encode("utf-8") -STRING = BINARY = str +STRING = str +BINARY = bytes NUMBER = int DATETIME = datetime ROWID = int @@ -169,6 +170,8 @@ class _InternalType(Enum): Boolean = "boolean" + Bytea = "bytea" + Nothing = "Nothing" @cached_property @@ -188,6 +191,7 @@ def python_type(self) -> type: _InternalType.TimestampNtz: datetime, _InternalType.TimestampTz: datetime, _InternalType.Boolean: bool, + _InternalType.Bytea: bytes, # For simplicity, this could happen only during 'select null' query _InternalType.Nothing: str, } @@ -221,6 +225,18 @@ def parse_type(raw_type: str) -> Union[type, ARRAY, DECIMAL]: # noqa: C901 return str +BYTEA_PREFIX = "\\x" + + +def _parse_bytea(str_value: str) -> bytes: + if ( + len(str_value) < len(BYTEA_PREFIX) + or str_value[: len(BYTEA_PREFIX)] != BYTEA_PREFIX + ): + raise ValueError(f"Invalid bytea value format: {BYTEA_PREFIX} prefix expected") + return bytes.fromhex(str_value[len(BYTEA_PREFIX) :]) + + def parse_value( value: RawColType, ctype: Union[type, ARRAY, DECIMAL], @@ -244,6 +260,10 @@ def parse_value( if not isinstance(value, (bool, int)): raise DataError(f"Invalid boolean value {value}: bool or int expected") return bool(value) + if ctype is bytes: + if not isinstance(value, str): + raise DataError(f"Invalid bytea value {value}: str expected") + return _parse_bytea(value) if isinstance(ctype, DECIMAL): assert isinstance(value, (str, int)) return Decimal(value) @@ -274,6 +294,9 @@ def format_value(value: ParameterType) -> str: return f"'{value.strftime('%Y-%m-%d %H:%M:%S')}'" elif isinstance(value, date): return f"'{value.isoformat()}'" + elif isinstance(value, bytes): + # Encode each byte into hex + return "'" + "".join(f"\\x{b:02x}" for b in value) + "'" if value is None: return "NULL" elif isinstance(value, Sequence): diff --git a/tests/integration/dbapi/async/test_queries_async.py b/tests/integration/dbapi/async/test_queries_async.py index f1ef2f8746f..7de7f7ef56d 100644 --- a/tests/integration/dbapi/async/test_queries_async.py +++ b/tests/integration/dbapi/async/test_queries_async.py @@ -4,7 +4,13 @@ from pytest import mark, raises -from firebolt.async_db import Connection, Cursor, DataError, OperationalError +from firebolt.async_db import ( + Binary, + Connection, + Cursor, + DataError, + OperationalError, +) from firebolt.async_db._types import ColType, Column from firebolt.async_db.cursor import QueryStatus @@ -487,3 +493,27 @@ async def test_server_side_async_execution_get_status( # assert ( # type(status) is QueryStatus, # ), "get_status() did not return a QueryStatus object." + + +async def test_bytea_roundtrip( + connection: Connection, +) -> None: + """Inserted and than selected bytea value doesn't get corrupted.""" + with connection.cursor() as c: + await c.execute("DROP TABLE IF EXISTS test_bytea_roundtrip") + await c.execute( + "CREATE FACT TABLE test_bytea_roundtrip(id int, b bytea) primary index id" + ) + + data = "bytea_123\n\tヽ༼ຈل͜ຈ༽ノ" + + await c.execute( + "INSERT INTO test_bytea_roundtrip VALUES (1, ?)", (Binary(data),) + ) + await c.execute("SELECT b FROM test_bytea_roundtrip") + + bytes_data = (await c.fetchone())[0] + + assert ( + bytes_data.decode("utf-8") == data + ), "Invalid bytea data returned after roundtrip" diff --git a/tests/integration/dbapi/conftest.py b/tests/integration/dbapi/conftest.py index 657449e40d1..1843a22b894 100644 --- a/tests/integration/dbapi/conftest.py +++ b/tests/integration/dbapi/conftest.py @@ -75,7 +75,8 @@ def all_types_query() -> str: 'true as "boolean", ' "[1,2,3,4] as \"array\", cast('1231232.123459999990457054844258706536' as " 'decimal(38,30)) as "decimal", ' - 'cast(null as int) as "nullable"' + 'cast(null as int) as "nullable", ' + "'abc123'::bytea as \"bytea\"" ) @@ -104,6 +105,7 @@ def all_types_query_description() -> List[Column]: Column("array", ARRAY(int), None, None, None, None, None), Column("decimal", DECIMAL(38, 30), None, None, None, None, None), Column("nullable", int, None, None, None, None, None), + Column("bytea", bytes, None, None, None, None, None), ] @@ -142,6 +144,7 @@ def all_types_query_response(timezone_offset_seconds: int) -> List[ColType]: [1, 2, 3, 4], Decimal("1231232.123459999990457054844258706536"), None, + b"abc123", ] ] diff --git a/tests/integration/dbapi/sync/test_queries.py b/tests/integration/dbapi/sync/test_queries.py index f29946602f7..5d3084c94d5 100644 --- a/tests/integration/dbapi/sync/test_queries.py +++ b/tests/integration/dbapi/sync/test_queries.py @@ -9,6 +9,7 @@ from firebolt.async_db.cursor import QueryStatus from firebolt.client.auth import Auth from firebolt.db import ( + Binary, Connection, Cursor, DataError, @@ -519,3 +520,25 @@ def run_query(): connection.close() assert not exceptions + + +def test_bytea_roundtrip( + connection: Connection, +) -> None: + """Inserted and than selected bytea value doesn't get corrupted.""" + with connection.cursor() as c: + c.execute("DROP TABLE IF EXISTS test_bytea_roundtrip") + c.execute( + "CREATE FACT TABLE test_bytea_roundtrip(id int, b bytea) primary index id" + ) + + data = "bytea_123\n\tヽ༼ຈل͜ຈ༽ノ" + + c.execute("INSERT INTO test_bytea_roundtrip VALUES (1, ?)", (Binary(data),)) + c.execute("SELECT b FROM test_bytea_roundtrip") + + bytes_data = (c.fetchone())[0] + + assert ( + bytes_data.decode("utf-8") == data + ), "Invalid bytea data returned after roundtrip" diff --git a/tests/unit/async_db/conftest.py b/tests/unit/async_db/conftest.py index d7064bfa691..fdca43ab438 100644 --- a/tests/unit/async_db/conftest.py +++ b/tests/unit/async_db/conftest.py @@ -51,6 +51,7 @@ def types_map() -> Dict[str, type]: "Decimal(38)": str, "boolean": bool, "SomeRandomNotExistingType": str, + "bytea": bytes, } array_types = {f"array({k})": ARRAY(v) for k, v in base_types.items()} nullable_arrays = {f"{k} null": v for k, v in array_types.items()} diff --git a/tests/unit/async_db/test_typing_format.py b/tests/unit/async_db/test_typing_format.py index 97bd835b22f..4027b5dd1ff 100644 --- a/tests/unit/async_db/test_typing_format.py +++ b/tests/unit/async_db/test_typing_format.py @@ -6,7 +6,12 @@ from sqlparse import parse from sqlparse.sql import Statement -from firebolt.async_db import DataError, InterfaceError, NotSupportedError +from firebolt.async_db import ( + Binary, + DataError, + InterfaceError, + NotSupportedError, +) from firebolt.async_db._types import ( SetParameter, format_statement, @@ -44,6 +49,8 @@ (("a", "b", "c"), "['a', 'b', 'c']"), # None (None, "NULL"), + # Bytea + (b"abc", "'\\x61\\x62\\x63'"), ], ) def test_format_value(value: str, result: str) -> None: @@ -188,3 +195,7 @@ def test_statement_to_set(statement: Statement, result: Optional[SetParameter]) def test_statement_to_set_errors(statement: Statement, error: Exception) -> None: with raises(error): statement_to_set(statement) + + +def test_binary() -> None: + assert Binary("abc") == b"abc" diff --git a/tests/unit/async_db/test_typing_parse.py b/tests/unit/async_db/test_typing_parse.py index 76921d2101d..fad2c93a693 100644 --- a/tests/unit/async_db/test_typing_parse.py +++ b/tests/unit/async_db/test_typing_parse.py @@ -248,3 +248,22 @@ def test_parse_value_bool() -> None: with raises(DataError): parse_value("true", bool) + + +def test_parse_value_bytes() -> None: + """parse_value parses all int values correctly.""" + assert ( + parse_value("\\x616263", bytes) == b"abc" + ), "Error parsing bytes: provided str" + assert parse_value(None, bytes) is None, "Error parsing bytes: provided None" + + with raises(ValueError): + parse_value("\\xabc", bytes) + + # Missing prefix + with raises(ValueError): + parse_value("616263", bytes) + + for val in (1, True, Exception()): + with raises(DataError): + parse_value(val, bytes) diff --git a/tests/unit/db_conftest.py b/tests/unit/db_conftest.py index 2c7505370f7..b9dccba3cf8 100644 --- a/tests/unit/db_conftest.py +++ b/tests/unit/db_conftest.py @@ -32,6 +32,7 @@ def query_description() -> List[Column]: Column("bool", "boolean", None, None, None, None, None), Column("array", "array(int)", None, None, None, None, None), Column("decimal", "Decimal(12, 34)", None, None, None, None, None), + Column("bytea", "bytea", None, None, None, None, None), ] @@ -54,6 +55,7 @@ def python_query_description() -> List[Column]: Column("bool", bool, None, None, None, None, None), Column("array", ARRAY(int), None, None, None, None, None), Column("decimal", DECIMAL(12, 34), None, None, None, None, None), + Column("bytea", bytes, None, None, None, None, None), ] @@ -77,6 +79,7 @@ def query_data() -> List[List[ColType]]: 1, [1, 2, 3, 4], "123456789.123456789123456789123456789", + "\\x616263", ] for i in range(QUERY_ROW_COUNT) ] @@ -102,6 +105,7 @@ def python_query_data() -> List[List[ColType]]: 1, [1, 2, 3, 4], Decimal("123456789.123456789123456789123456789"), + b"abc", ] for i in range(QUERY_ROW_COUNT) ]