Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/firebolt/async_db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
ARRAY,
BINARY,
DATETIME,
DATETIME64,
DECIMAL,
NUMBER,
ROWID,
STRING,
Expand Down
81 changes: 72 additions & 9 deletions src/firebolt/async_db/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,24 @@
try:
from ciso8601 import parse_datetime # type: ignore
except ImportError:
parse_datetime = datetime.fromisoformat # type: ignore
# Unfortunately, there seems to be no support for optional bits in strptime
def parse_datetime(date_string: str) -> datetime: # type: ignore
format = "%Y-%m-%d %H:%M:%S.%f"
# fromisoformat doesn't support milliseconds
if "." in date_string:
return datetime.strptime(date_string, format)
return datetime.fromisoformat(date_string)


from firebolt.common.exception import DataError, NotSupportedError
from firebolt.common.util import cached_property

_NoneType = type(None)
_col_types = (int, float, str, datetime, date, bool, list, _NoneType)
_col_types = (int, float, str, datetime, date, bool, list, Decimal, _NoneType)
# duplicating this since 3.7 can't unpack Union
ColType = Union[int, float, str, datetime, date, bool, list, _NoneType]
ColType = Union[int, float, str, datetime, date, bool, list, Decimal, _NoneType]
RawColType = Union[int, float, str, bool, list, _NoneType]
ParameterType = Union[int, float, str, datetime, date, bool, Sequence]
ParameterType = Union[int, float, str, datetime, date, bool, Decimal, Sequence]

# These definitions are required by PEP-249
Date = date
Expand Down Expand Up @@ -78,9 +84,9 @@ class ARRAY:

_prefix = "Array("

def __init__(self, subtype: Union[type, ARRAY]):
def __init__(self, subtype: Union[type, ARRAY, DECIMAL, DATETIME64]):
assert (subtype in _col_types and subtype is not list) or isinstance(
subtype, ARRAY
subtype, (ARRAY, DECIMAL, DATETIME64)
), f"Invalid array subtype: {str(subtype)}"
self.subtype = subtype

Expand All @@ -93,6 +99,41 @@ def __eq__(self, other: object) -> bool:
return other.subtype == self.subtype


class DECIMAL:
"""Class for holding imformation about decimal value in firebolt db."""

_prefix = "Decimal("

def __init__(self, precision: int, scale: int):
self.precision = precision
self.scale = scale

def __str__(self) -> str:
return f"Decimal({self.precision}, {self.scale})"

def __eq__(self, other: object) -> bool:
if not isinstance(other, DECIMAL):
return NotImplemented
return other.precision == self.precision and other.scale == self.scale


class DATETIME64:
"""Class for holding imformation about datetime64 value in firebolt db."""

_prefix = "DateTime64("

def __init__(self, precision: int):
self.precision = precision

def __str__(self) -> str:
return f"DateTime64({self.precision})"

def __eq__(self, other: object) -> bool:
if not isinstance(other, DATETIME64):
return NotImplemented
return other.precision == self.precision


NULLABLE_PREFIX = "Nullable("


Expand Down Expand Up @@ -122,6 +163,7 @@ class _InternalType(Enum):

# DATE
Date = "Date"
Date32 = "Date32"

# DATETIME, TIMESTAMP
DateTime = "DateTime"
Expand All @@ -145,20 +187,38 @@ def python_type(self) -> type:
_InternalType.Float64: float,
_InternalType.String: str,
_InternalType.Date: date,
_InternalType.Date32: date,
_InternalType.DateTime: datetime,
# For simplicity, this could happen only during 'select null' query
_InternalType.Nothing: str,
}
return types[self]


def parse_type(raw_type: str) -> Union[type, ARRAY]:
def parse_type(raw_type: str) -> Union[type, ARRAY, DECIMAL, DATETIME64]:
"""Parse typename, provided by query metadata into python type."""
if not isinstance(raw_type, str):
raise DataError(f"Invalid typename {str(raw_type)}: str expected")
# Handle arrays
if raw_type.startswith(ARRAY._prefix) and raw_type.endswith(")"):
return ARRAY(parse_type(raw_type[len(ARRAY._prefix) : -1]))
# Handle decimal
if raw_type.startswith(DECIMAL._prefix) and raw_type.endswith(")"):
try:
prec_scale = raw_type[len(DECIMAL._prefix) : -1].split(",")
precision, scale = int(prec_scale[0]), int(prec_scale[1])
except (ValueError, IndexError):
pass
else:
return DECIMAL(precision, scale)
# Handle detetime64
if raw_type.startswith(DATETIME64._prefix) and raw_type.endswith(")"):
try:
precision = int(raw_type[len(DATETIME64._prefix) : -1])
except (ValueError, IndexError):
pass
else:
return DATETIME64(precision)
# Handle nullable
if raw_type.startswith(NULLABLE_PREFIX) and raw_type.endswith(")"):
return parse_type(raw_type[len(NULLABLE_PREFIX) : -1])
Expand All @@ -173,7 +233,7 @@ def parse_type(raw_type: str) -> Union[type, ARRAY]:

def parse_value(
value: RawColType,
ctype: Union[type, ARRAY],
ctype: Union[type, ARRAY, DECIMAL, DATETIME64],
) -> ColType:
"""Provided raw value and python type, parses first into python value."""
if value is None:
Expand All @@ -186,10 +246,13 @@ def parse_value(
raise DataError(f"Invalid date value {value}: str expected")
assert isinstance(value, str)
return parse_datetime(value).date()
if ctype is datetime:
if ctype is datetime or isinstance(ctype, DATETIME64):
if not isinstance(value, str):
raise DataError(f"Invalid datetime value {value}: str expected")
return parse_datetime(value)
if isinstance(ctype, DECIMAL):
assert isinstance(value, (str, int))
return Decimal(value)
if isinstance(ctype, ARRAY):
assert isinstance(value, list)
return [parse_value(it, ctype.subtype) for it in value]
Expand Down
3 changes: 2 additions & 1 deletion src/firebolt/async_db/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ def _append_query_data(self, response: Response) -> None:
# Empty response is returned for insert query
if response.headers.get("content-length", "") != "0":
try:
query_data = response.json()
# Skip parsing floats to properly parse them later
query_data = response.json(parse_float=str)
rowcount = int(query_data["rows"])
descriptions = [
Column(
Expand Down
2 changes: 2 additions & 0 deletions src/firebolt/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
ARRAY,
BINARY,
DATETIME,
DATETIME64,
DECIMAL,
NUMBER,
ROWID,
STRING,
Expand Down
29 changes: 21 additions & 8 deletions tests/integration/dbapi/async/test_queries_async.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import date, datetime
from decimal import Decimal
from typing import Any, List

from pytest import mark, raises
Expand Down Expand Up @@ -39,8 +40,11 @@ async def test_select(
all_types_query_response: List[ColType],
) -> None:
"""Select handles all data types properly"""
set_params = {"firebolt_use_decimal": 1}
with connection.cursor() as c:
assert await c.execute(all_types_query) == 1, "Invalid row count returned"
assert (
await c.execute(all_types_query, set_parameters=set_params) == 1
), "Invalid row count returned"
assert c.rowcount == 1, "Invalid rowcount value"
data = await c.fetchall()
assert len(data) == c.rowcount, "Invalid data length"
Expand All @@ -50,13 +54,13 @@ async def test_select(
assert len(await c.fetchall()) == 0, "Redundant data returned by fetchall"

# Different fetch types
await c.execute(all_types_query)
await c.execute(all_types_query, set_parameters=set_params)
assert (
await c.fetchone() == all_types_query_response[0]
), "Invalid fetchone data"
assert await c.fetchone() is None, "Redundant data returned by fetchone"

await c.execute(all_types_query)
await c.execute(all_types_query, set_parameters=set_params)
assert len(await c.fetchmany(0)) == 0, "Invalid data size returned by fetchmany"
data = await c.fetchmany()
assert len(data) == 1, "Invalid data size returned by fetchmany"
Expand Down Expand Up @@ -206,8 +210,12 @@ async def test_empty_query(c: Cursor, query: str) -> None:
async def test_parameterized_query(connection: Connection) -> None:
"""Query parameters are handled properly"""

set_params = {"firebolt_use_decimal": 1}

async def test_empty_query(c: Cursor, query: str, params: tuple) -> None:
assert await c.execute(query, params) == -1, "Invalid row count returned"
assert (
await c.execute(query, params, set_params) == -1
), "Invalid row count returned"
assert c.rowcount == -1, "Invalid rowcount value"
assert c.description is None, "Invalid description"
with raises(DataError):
Expand All @@ -223,8 +231,9 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None:
await c.execute("DROP TABLE IF EXISTS test_tb_async_parameterized")
await c.execute(
"CREATE FACT TABLE test_tb_async_parameterized(i int, f float, s string, sn"
" string null, d date, dt datetime, b bool, a array(int), ss string)"
" primary index i"
" string null, d date, dt datetime, b bool, a array(int), "
"dec decimal(38, 3), ss string) primary index i",
set_parameters=set_params,
)

params = [
Expand All @@ -236,12 +245,13 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None:
datetime(2022, 1, 1, 1, 1, 1),
True,
[1, 2, 3],
Decimal("123.456"),
]

await test_empty_query(
c,
"INSERT INTO test_tb_async_parameterized VALUES "
"(?, ?, ?, ?, ?, ?, ?, ?, '\\?')",
"(?, ?, ?, ?, ?, ?, ?, ?, ?, '\\?')",
params,
)

Expand All @@ -252,7 +262,10 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None:
params[6] = 1

assert (
await c.execute("SELECT * FROM test_tb_async_parameterized") == 1
await c.execute(
"SELECT * FROM test_tb_async_parameterized", set_parameters=set_params
)
== 1
), "Invalid data length in table after parameterized insert"

assert_deep_eq(
Expand Down
34 changes: 27 additions & 7 deletions tests/integration/dbapi/conftest.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,39 @@
from datetime import date, datetime
from decimal import Decimal
from logging import getLogger
from typing import List

from pytest import fixture

from firebolt.async_db._types import ColType
from firebolt.async_db.cursor import Column
from firebolt.db import ARRAY
from firebolt.db import ARRAY, DATETIME64, DECIMAL

LOGGER = getLogger(__name__)


@fixture
def all_types_query() -> str:
return (
"select 1 as uint8, -1 as int8, 257 as uint16, -257 as int16, 80000 as uint32,"
" -80000 as int32, 30000000000 as uint64, -30000000000 as int64, cast(1.23 AS"
" FLOAT) as float32, 1.2345678901234 as float64, 'text' as \"string\","
" CAST('2021-03-28' AS DATE) as \"date\", CAST('2019-07-31 01:01:01' AS"
' DATETIME) as "datetime", true as "bool",[1,2,3,4] as "array", cast(null as'
" int) as nullable"
"select 1 as uint8, "
"-1 as int8, "
"257 as uint16, "
"-257 as int16, "
"80000 as uint32, "
"-80000 as int32, "
"30000000000 as uint64, "
"-30000000000 as int64, "
"cast(1.23 AS FLOAT) as float32, "
"1.2345678901234 as float64, "
"'text' as \"string\", "
"CAST('2021-03-28' AS DATE) as \"date\", "
"CAST('1860-03-04' AS DATE_EXT) as \"date32\","
"CAST('2019-07-31 01:01:01' AS DATETIME) as \"datetime\", "
"CAST('2019-07-31 01:01:01.1234' AS TIMESTAMP_EXT(4)) as \"datetime64\", "
'true as "bool",'
'[1,2,3,4] as "array", cast(1231232.123459999990457054844258706536 as '
'decimal(38,30)) as "decimal", '
"cast(null as int) as nullable"
)


Expand All @@ -38,9 +52,12 @@ def all_types_query_description() -> List[Column]:
Column("float64", float, None, None, None, None, None),
Column("string", str, None, None, None, None, None),
Column("date", date, None, None, None, None, None),
Column("date32", date, None, None, None, None, None),
Column("datetime", datetime, None, None, None, None, None),
Column("datetime64", DATETIME64(4), None, None, None, None, None),
Column("bool", int, None, None, None, None, None),
Column("array", ARRAY(int), None, None, None, None, None),
Column("decimal", DECIMAL(38, 30), None, None, None, None, None),
Column("nullable", int, None, None, None, None, None),
]

Expand All @@ -61,9 +78,12 @@ def all_types_query_response() -> List[ColType]:
1.23456789012,
"text",
date(2021, 3, 28),
date(1860, 3, 4),
datetime(2019, 7, 31, 1, 1, 1),
datetime(2019, 7, 31, 1, 1, 1, 123400),
1,
[1, 2, 3, 4],
Decimal("1231232.123459999990457054844258706536"),
None,
]
]
Expand Down
Loading