diff --git a/src/firebolt/async_db/__init__.py b/src/firebolt/async_db/__init__.py index cd21ba20036..b8b5b7b03d0 100644 --- a/src/firebolt/async_db/__init__.py +++ b/src/firebolt/async_db/__init__.py @@ -2,6 +2,8 @@ ARRAY, BINARY, DATETIME, + DATETIME64, + DECIMAL, NUMBER, ROWID, STRING, diff --git a/src/firebolt/async_db/_types.py b/src/firebolt/async_db/_types.py index f0b1d8567c7..fe7d9d193b3 100644 --- a/src/firebolt/async_db/_types.py +++ b/src/firebolt/async_db/_types.py @@ -13,18 +13,24 @@ try: from ciso8601 import parse_datetime # type: ignore except ImportError: - parse_datetime = datetime.fromisoformat # type: ignore + # Unfortunately, there seems to be no support for optional bits in strptime + def parse_datetime(date_string: str) -> datetime: # type: ignore + format = "%Y-%m-%d %H:%M:%S.%f" + # fromisoformat doesn't support milliseconds + if "." in date_string: + return datetime.strptime(date_string, format) + return datetime.fromisoformat(date_string) from firebolt.common.exception import DataError, NotSupportedError from firebolt.common.util import cached_property _NoneType = type(None) -_col_types = (int, float, str, datetime, date, bool, list, _NoneType) +_col_types = (int, float, str, datetime, date, bool, list, Decimal, _NoneType) # duplicating this since 3.7 can't unpack Union -ColType = Union[int, float, str, datetime, date, bool, list, _NoneType] +ColType = Union[int, float, str, datetime, date, bool, list, Decimal, _NoneType] RawColType = Union[int, float, str, bool, list, _NoneType] -ParameterType = Union[int, float, str, datetime, date, bool, Sequence] +ParameterType = Union[int, float, str, datetime, date, bool, Decimal, Sequence] # These definitions are required by PEP-249 Date = date @@ -78,9 +84,9 @@ class ARRAY: _prefix = "Array(" - def __init__(self, subtype: Union[type, ARRAY]): + def __init__(self, subtype: Union[type, ARRAY, DECIMAL, DATETIME64]): assert (subtype in _col_types and subtype is not list) or isinstance( - subtype, 
ARRAY + subtype, (ARRAY, DECIMAL, DATETIME64) ), f"Invalid array subtype: {str(subtype)}" self.subtype = subtype @@ -93,6 +99,41 @@ def __eq__(self, other: object) -> bool: return other.subtype == self.subtype +class DECIMAL: + """Class for holding information about decimal value in firebolt db.""" + + _prefix = "Decimal(" + + def __init__(self, precision: int, scale: int): + self.precision = precision + self.scale = scale + + def __str__(self) -> str: + return f"Decimal({self.precision}, {self.scale})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DECIMAL): + return NotImplemented + return other.precision == self.precision and other.scale == self.scale + + +class DATETIME64: + """Class for holding information about datetime64 value in firebolt db.""" + + _prefix = "DateTime64(" + + def __init__(self, precision: int): + self.precision = precision + + def __str__(self) -> str: + return f"DateTime64({self.precision})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DATETIME64): + return NotImplemented + return other.precision == self.precision + + NULLABLE_PREFIX = "Nullable(" @@ -122,6 +163,7 @@ class _InternalType(Enum): # DATE Date = "Date" + Date32 = "Date32" # DATETIME, TIMESTAMP DateTime = "DateTime" @@ -145,6 +187,7 @@ def python_type(self) -> type: _InternalType.Float64: float, _InternalType.String: str, _InternalType.Date: date, + _InternalType.Date32: date, _InternalType.DateTime: datetime, # For simplicity, this could happen only during 'select null' query _InternalType.Nothing: str, @@ -152,13 +195,30 @@ def python_type(self) -> type: return types[self] -def parse_type(raw_type: str) -> Union[type, ARRAY]: +def parse_type(raw_type: str) -> Union[type, ARRAY, DECIMAL, DATETIME64]: """Parse typename, provided by query metadata into python type.""" if not isinstance(raw_type, str): raise DataError(f"Invalid typename {str(raw_type)}: str expected") # Handle arrays if raw_type.startswith(ARRAY._prefix) and 
raw_type.endswith(")"): return ARRAY(parse_type(raw_type[len(ARRAY._prefix) : -1])) + # Handle decimal + if raw_type.startswith(DECIMAL._prefix) and raw_type.endswith(")"): + try: + prec_scale = raw_type[len(DECIMAL._prefix) : -1].split(",") + precision, scale = int(prec_scale[0]), int(prec_scale[1]) + except (ValueError, IndexError): + pass + else: + return DECIMAL(precision, scale) + # Handle datetime64 + if raw_type.startswith(DATETIME64._prefix) and raw_type.endswith(")"): + try: + precision = int(raw_type[len(DATETIME64._prefix) : -1]) + except (ValueError, IndexError): + pass + else: + return DATETIME64(precision) # Handle nullable if raw_type.startswith(NULLABLE_PREFIX) and raw_type.endswith(")"): return parse_type(raw_type[len(NULLABLE_PREFIX) : -1]) @@ -173,7 +233,7 @@ def parse_type(raw_type: str) -> Union[type, ARRAY]: def parse_value( value: RawColType, - ctype: Union[type, ARRAY], + ctype: Union[type, ARRAY, DECIMAL, DATETIME64], ) -> ColType: """Provided raw value and python type, parses first into python value.""" if value is None: @@ -186,10 +246,13 @@ def parse_value( raise DataError(f"Invalid date value {value}: str expected") assert isinstance(value, str) return parse_datetime(value).date() - if ctype is datetime: + if ctype is datetime or isinstance(ctype, DATETIME64): if not isinstance(value, str): raise DataError(f"Invalid datetime value {value}: str expected") return parse_datetime(value) + if isinstance(ctype, DECIMAL): + assert isinstance(value, (str, int)) + return Decimal(value) if isinstance(ctype, ARRAY): assert isinstance(value, list) return [parse_value(it, ctype.subtype) for it in value] diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 35cc69d36ad..6a7900f7201 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -182,7 +182,8 @@ def _append_query_data(self, response: Response) -> None: # Empty response is returned for insert query if 
response.headers.get("content-length", "") != "0": try: - query_data = response.json() + # Skip parsing floats to properly parse them later + query_data = response.json(parse_float=str) rowcount = int(query_data["rows"]) descriptions = [ Column( diff --git a/src/firebolt/db/__init__.py b/src/firebolt/db/__init__.py index 88175c5156a..50baf76555c 100644 --- a/src/firebolt/db/__init__.py +++ b/src/firebolt/db/__init__.py @@ -2,6 +2,8 @@ ARRAY, BINARY, DATETIME, + DATETIME64, + DECIMAL, NUMBER, ROWID, STRING, diff --git a/tests/integration/dbapi/async/test_queries_async.py b/tests/integration/dbapi/async/test_queries_async.py index a14abec7d30..7174b314f0e 100644 --- a/tests/integration/dbapi/async/test_queries_async.py +++ b/tests/integration/dbapi/async/test_queries_async.py @@ -1,4 +1,5 @@ from datetime import date, datetime +from decimal import Decimal from typing import Any, List from pytest import mark, raises @@ -39,8 +40,11 @@ async def test_select( all_types_query_response: List[ColType], ) -> None: """Select handles all data types properly""" + set_params = {"firebolt_use_decimal": 1} with connection.cursor() as c: - assert await c.execute(all_types_query) == 1, "Invalid row count returned" + assert ( + await c.execute(all_types_query, set_parameters=set_params) == 1 + ), "Invalid row count returned" assert c.rowcount == 1, "Invalid rowcount value" data = await c.fetchall() assert len(data) == c.rowcount, "Invalid data length" @@ -50,13 +54,13 @@ async def test_select( assert len(await c.fetchall()) == 0, "Redundant data returned by fetchall" # Different fetch types - await c.execute(all_types_query) + await c.execute(all_types_query, set_parameters=set_params) assert ( await c.fetchone() == all_types_query_response[0] ), "Invalid fetchone data" assert await c.fetchone() is None, "Redundant data returned by fetchone" - await c.execute(all_types_query) + await c.execute(all_types_query, set_parameters=set_params) assert len(await c.fetchmany(0)) == 0, 
"Invalid data size returned by fetchmany" data = await c.fetchmany() assert len(data) == 1, "Invalid data size returned by fetchmany" @@ -206,8 +210,12 @@ async def test_empty_query(c: Cursor, query: str) -> None: async def test_parameterized_query(connection: Connection) -> None: """Query parameters are handled properly""" + set_params = {"firebolt_use_decimal": 1} + async def test_empty_query(c: Cursor, query: str, params: tuple) -> None: - assert await c.execute(query, params) == -1, "Invalid row count returned" + assert ( + await c.execute(query, params, set_params) == -1 + ), "Invalid row count returned" assert c.rowcount == -1, "Invalid rowcount value" assert c.description is None, "Invalid description" with raises(DataError): @@ -223,8 +231,9 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None: await c.execute("DROP TABLE IF EXISTS test_tb_async_parameterized") await c.execute( "CREATE FACT TABLE test_tb_async_parameterized(i int, f float, s string, sn" - " string null, d date, dt datetime, b bool, a array(int), ss string)" - " primary index i" + " string null, d date, dt datetime, b bool, a array(int), " + "dec decimal(38, 3), ss string) primary index i", + set_parameters=set_params, ) params = [ @@ -236,12 +245,13 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None: datetime(2022, 1, 1, 1, 1, 1), True, [1, 2, 3], + Decimal("123.456"), ] await test_empty_query( c, "INSERT INTO test_tb_async_parameterized VALUES " - "(?, ?, ?, ?, ?, ?, ?, ?, '\\?')", + "(?, ?, ?, ?, ?, ?, ?, ?, ?, '\\?')", params, ) @@ -252,7 +262,10 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None: params[6] = 1 assert ( - await c.execute("SELECT * FROM test_tb_async_parameterized") == 1 + await c.execute( + "SELECT * FROM test_tb_async_parameterized", set_parameters=set_params + ) + == 1 ), "Invalid data length in table after parameterized insert" assert_deep_eq( diff --git a/tests/integration/dbapi/conftest.py 
b/tests/integration/dbapi/conftest.py index 3399eafd3b6..c63fff5eda1 100644 --- a/tests/integration/dbapi/conftest.py +++ b/tests/integration/dbapi/conftest.py @@ -1,4 +1,5 @@ from datetime import date, datetime +from decimal import Decimal from logging import getLogger from typing import List @@ -6,7 +7,7 @@ from firebolt.async_db._types import ColType from firebolt.async_db.cursor import Column -from firebolt.db import ARRAY +from firebolt.db import ARRAY, DATETIME64, DECIMAL LOGGER = getLogger(__name__) @@ -14,12 +15,25 @@ @fixture def all_types_query() -> str: return ( - "select 1 as uint8, -1 as int8, 257 as uint16, -257 as int16, 80000 as uint32," - " -80000 as int32, 30000000000 as uint64, -30000000000 as int64, cast(1.23 AS" - " FLOAT) as float32, 1.2345678901234 as float64, 'text' as \"string\"," - " CAST('2021-03-28' AS DATE) as \"date\", CAST('2019-07-31 01:01:01' AS" - ' DATETIME) as "datetime", true as "bool",[1,2,3,4] as "array", cast(null as' - " int) as nullable" + "select 1 as uint8, " + "-1 as int8, " + "257 as uint16, " + "-257 as int16, " + "80000 as uint32, " + "-80000 as int32, " + "30000000000 as uint64, " + "-30000000000 as int64, " + "cast(1.23 AS FLOAT) as float32, " + "1.2345678901234 as float64, " + "'text' as \"string\", " + "CAST('2021-03-28' AS DATE) as \"date\", " + "CAST('1860-03-04' AS DATE_EXT) as \"date32\"," + "CAST('2019-07-31 01:01:01' AS DATETIME) as \"datetime\", " + "CAST('2019-07-31 01:01:01.1234' AS TIMESTAMP_EXT(4)) as \"datetime64\", " + 'true as "bool",' + '[1,2,3,4] as "array", cast(1231232.123459999990457054844258706536 as ' + 'decimal(38,30)) as "decimal", ' + "cast(null as int) as nullable" ) @@ -38,9 +52,12 @@ def all_types_query_description() -> List[Column]: Column("float64", float, None, None, None, None, None), Column("string", str, None, None, None, None, None), Column("date", date, None, None, None, None, None), + Column("date32", date, None, None, None, None, None), Column("datetime", datetime, None, None, 
None, None, None), + Column("datetime64", DATETIME64(4), None, None, None, None, None), Column("bool", int, None, None, None, None, None), Column("array", ARRAY(int), None, None, None, None, None), + Column("decimal", DECIMAL(38, 30), None, None, None, None, None), Column("nullable", int, None, None, None, None, None), ] @@ -61,9 +78,12 @@ def all_types_query_response() -> List[ColType]: 1.23456789012, "text", date(2021, 3, 28), + date(1860, 3, 4), datetime(2019, 7, 31, 1, 1, 1), + datetime(2019, 7, 31, 1, 1, 1, 123400), 1, [1, 2, 3, 4], + Decimal("1231232.123459999990457054844258706536"), None, ] ] diff --git a/tests/integration/dbapi/sync/test_queries.py b/tests/integration/dbapi/sync/test_queries.py index 72ecaebd4fb..b164d7c3dd8 100644 --- a/tests/integration/dbapi/sync/test_queries.py +++ b/tests/integration/dbapi/sync/test_queries.py @@ -1,4 +1,5 @@ from datetime import date, datetime +from decimal import Decimal from typing import Any, List from pytest import mark, raises @@ -38,8 +39,11 @@ def test_select( all_types_query_response: List[ColType], ) -> None: """Select handles all data types properly""" + set_params = {"firebolt_use_decimal": 1} with connection.cursor() as c: - assert c.execute(all_types_query) == 1, "Invalid row count returned" + assert ( + c.execute(all_types_query, set_parameters=set_params) == 1 + ), "Invalid row count returned" assert c.rowcount == 1, "Invalid rowcount value" data = c.fetchall() assert len(data) == c.rowcount, "Invalid data length" @@ -49,11 +53,11 @@ def test_select( assert len(c.fetchall()) == 0, "Redundant data returned by fetchall" # Different fetch types - c.execute(all_types_query) + c.execute(all_types_query, set_parameters=set_params) assert c.fetchone() == all_types_query_response[0], "Invalid fetchone data" assert c.fetchone() is None, "Redundant data returned by fetchone" - c.execute(all_types_query) + c.execute(all_types_query, set_parameters=set_params) assert len(c.fetchmany(0)) == 0, "Invalid data size 
returned by fetchmany" data = c.fetchmany() assert len(data) == 1, "Invalid data size returned by fetchmany" @@ -194,8 +198,10 @@ def test_empty_query(c: Cursor, query: str) -> None: def test_parameterized_query(connection: Connection) -> None: """Query parameters are handled properly""" + set_params = {"firebolt_use_decimal": 1} + def test_empty_query(c: Cursor, query: str, params: tuple) -> None: - assert c.execute(query, params) == -1, "Invalid row count returned" + assert c.execute(query, params, set_params) == -1, "Invalid row count returned" assert c.rowcount == -1, "Invalid rowcount value" assert c.description is None, "Invalid description" with raises(DataError): @@ -211,8 +217,9 @@ def test_empty_query(c: Cursor, query: str, params: tuple) -> None: c.execute("DROP TABLE IF EXISTS test_tb_parameterized") c.execute( "CREATE FACT TABLE test_tb_parameterized(i int, f float, s string, sn" - " string null, d date, dt datetime, b bool, a array(int), ss string)" - " primary index i" + " string null, d date, dt datetime, b bool, a array(int), " + "dec decimal(38, 3), ss string) primary index i", + set_parameters=set_params, ) params = [ @@ -224,11 +231,13 @@ def test_empty_query(c: Cursor, query: str, params: tuple) -> None: datetime(2022, 1, 1, 1, 1, 1), True, [1, 2, 3], + Decimal("123.456"), ] test_empty_query( c, - "INSERT INTO test_tb_parameterized VALUES (?, ?, ?, ?, ?, ?, ?, ?, '\\?')", + "INSERT INTO test_tb_parameterized VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?," + " '\\?')", params, ) @@ -239,7 +248,8 @@ def test_empty_query(c: Cursor, query: str, params: tuple) -> None: params[6] = 1 assert ( - c.execute("SELECT * FROM test_tb_parameterized") == 1 + c.execute("SELECT * FROM test_tb_parameterized", set_parameters=set_params) + == 1 ), "Invalid data length in table after parameterized insert" assert_deep_eq( diff --git a/tests/unit/async_db/conftest.py b/tests/unit/async_db/conftest.py index f7973810ce6..cf4594ad777 100644 --- a/tests/unit/async_db/conftest.py +++ 
b/tests/unit/async_db/conftest.py @@ -1,11 +1,19 @@ from datetime import date, datetime +from decimal import Decimal from json import dumps as jdumps from typing import Any, Callable, Dict, List from httpx import URL, Request, Response, codes from pytest import fixture -from firebolt.async_db import ARRAY, Connection, Cursor, connect +from firebolt.async_db import ( + ARRAY, + DATETIME64, + DECIMAL, + Connection, + Cursor, + connect, +) from firebolt.async_db.cursor import JSON_OUTPUT_FORMAT, ColType, Column from firebolt.common.settings import Settings @@ -25,9 +33,12 @@ def query_description() -> List[Column]: Column("float64", "Float64", None, None, None, None, None), Column("string", "String", None, None, None, None, None), Column("date", "Date", None, None, None, None, None), + Column("date32", "Date32", None, None, None, None, None), Column("datetime", "DateTime", None, None, None, None, None), + Column("datetime64", "DateTime64(4)", None, None, None, None, None), Column("bool", "UInt8", None, None, None, None, None), Column("array", "Array(UInt8)", None, None, None, None, None), + Column("decimal", "Decimal(12, 34)", None, None, None, None, None), ] @@ -44,9 +55,12 @@ def python_query_description() -> List[Column]: Column("float64", float, None, None, None, None, None), Column("string", str, None, None, None, None, None), Column("date", date, None, None, None, None, None), + Column("date32", date, None, None, None, None, None), Column("datetime", datetime, None, None, None, None, None), + Column("datetime64", DATETIME64(4), None, None, None, None, None), Column("bool", int, None, None, None, None, None), Column("array", ARRAY(int), None, None, None, None, None), + Column("decimal", DECIMAL(12, 34), None, None, None, None, None), ] @@ -61,12 +75,15 @@ def query_data() -> List[List[ColType]]: 922337203685477580, -922337203685477580, 1, - 1.0387398573, + "1.0387398573", "some text", "2019-07-31", + "1860-01-31", "2019-07-31 01:01:01", + "2020-07-31 
01:01:01.1234", 1, [1, 2, 3, 4], + "123456789.123456789123456789123456789", ] for i in range(QUERY_ROW_COUNT) ] @@ -86,9 +103,12 @@ def python_query_data() -> List[List[ColType]]: 1.0387398573, "some text", date(2019, 7, 31), + date(1860, 1, 31), datetime(2019, 7, 31, 1, 1, 1), + datetime(2020, 7, 31, 1, 1, 1, 123400), 1, [1, 2, 3, 4], + Decimal("123456789.123456789123456789123456789"), ] for i in range(QUERY_ROW_COUNT) ] @@ -217,8 +237,14 @@ def types_map() -> Dict[str, type]: "Float64": float, "String": str, "Date": date, + "Date32": date, "DateTime": datetime, + "DateTime64(7)": DATETIME64(7), "Nullable(Nothing)": str, + "Decimal(123, 4)": DECIMAL(123, 4), + "Decimal(38,0)": DECIMAL(38, 0), + # Invalid decimal format + "Decimal(38)": str, "SomeRandomNotExistingType": str, } array_types = {f"Array({k})": ARRAY(v) for k, v in base_types.items()} diff --git a/tests/unit/async_db/test_typing_parse.py b/tests/unit/async_db/test_typing_parse.py index e1618265425..da2008f1d07 100644 --- a/tests/unit/async_db/test_typing_parse.py +++ b/tests/unit/async_db/test_typing_parse.py @@ -1,10 +1,12 @@ from datetime import date, datetime +from decimal import Decimal from typing import Dict from pytest import raises from firebolt.async_db import ( ARRAY, + DECIMAL, DateFromTicks, TimeFromTicks, TimestampFromTicks, @@ -76,6 +78,10 @@ def test_parse_value_datetime() -> None: assert parse_value("2021-12-31", date) == date( 2021, 12, 31 ), "Error parsing date: str provided" + assert parse_value("1860-12-31", date) == date( + 1860, 12, 31 + ), "Error parsing extended date: str provided" + assert parse_value(None, date) is None, "Error parsing date: None provided" assert parse_value("2021-12-31 23:59:59", date) == date( @@ -92,8 +98,8 @@ def test_parse_value_datetime() -> None: assert str(exc_info.value) == f"Invalid date value {value}: str expected" # Datetime - assert parse_value("2021-12-31 23:59:59", datetime) == datetime( - 2021, 12, 31, 23, 59, 59 + assert parse_value("2021-12-31 
23:59:59.1234", datetime) == datetime( + 2021, 12, 31, 23, 59, 59, 123400 ), "Error parsing datetime: str provided" assert parse_value(None, datetime) is None, "Error parsing datetime: None provided" @@ -111,6 +117,18 @@ def test_parse_value_datetime() -> None: assert str(exc_info.value) == f"Invalid datetime value {value}: str expected" +def test_parse_decimal() -> None: + assert parse_value("123.456", DECIMAL(38, 3)) == Decimal( + "123.456" + ), "Error parsing decimal(38, 3): str provided" + assert parse_value(123, DECIMAL(38, 3)) == Decimal( + "123" + ), "Error parsing decimal(38, 3): int provided" + assert ( + parse_value(None, DECIMAL(38, 3)) is None + ), "Error parsing decimal(38, 3): None provided" + + def test_parse_arrays() -> None: """parse_value parses all array values correctly.""" assert parse_value([1, 2], ARRAY(int)) == [