diff --git a/tests/integration/dbapi/async/V1/test_queries_async.py b/tests/integration/dbapi/async/V1/test_queries_async.py index 441468ceea..8dfcd1b730 100644 --- a/tests/integration/dbapi/async/V1/test_queries_async.py +++ b/tests/integration/dbapi/async/V1/test_queries_async.py @@ -1,5 +1,7 @@ +import math from datetime import date, datetime from decimal import Decimal +from random import choice from typing import Any, Callable, List from pytest import fixture, mark, raises @@ -169,6 +171,24 @@ async def test_select( ) +async def test_select_inf(connection: Connection) -> None: + with connection.cursor() as c: + await c.execute("SELECT 'inf'::float, '-inf'::float") + data = await c.fetchall() + assert len(data) == 1, "Invalid data size returned by fetchall" + assert data[0][0] == float("inf"), "Invalid data returned by fetchall" + assert data[0][1] == float("-inf"), "Invalid data returned by fetchall" + + +async def test_select_nan(connection: Connection) -> None: + with connection.cursor() as c: + await c.execute("SELECT 'nan'::float, '-nan'::float") + data = await c.fetchall() + assert len(data) == 1, "Invalid data size returned by fetchall" + assert math.isnan(data[0][0]), "Invalid data returned by fetchall" + assert math.isnan(data[0][1]), "Invalid data returned by fetchall" + + @mark.slow @mark.timeout(timeout=1000) async def test_long_query( @@ -485,9 +505,10 @@ async def test_bytea_roundtrip( async def setup_db(connection_no_engine: Connection, use_db_name: str): use_db_name = f"{use_db_name}_async" with connection_no_engine.cursor() as cursor: - await cursor.execute(f"CREATE DATABASE {use_db_name}") + suffix = "".join(choice("0123456789") for _ in range(2)) + await cursor.execute(f"CREATE DATABASE {use_db_name}{suffix}") yield - await cursor.execute(f"DROP DATABASE {use_db_name}") + await cursor.execute(f"DROP DATABASE {use_db_name}{suffix}") @mark.xfail(reason="USE DATABASE is not yet available in 1.0 Firebolt") diff --git a/tests/integration/dbapi/async/V2/test_queries_async.py b/tests/integration/dbapi/async/V2/test_queries_async.py index f915aa30c9..0e7cda8d97 100644 --- a/tests/integration/dbapi/async/V2/test_queries_async.py +++ b/tests/integration/dbapi/async/V2/test_queries_async.py @@ -1,6 +1,8 @@ +import math from datetime import date, datetime from decimal import Decimal from os import environ +from random import choice from typing import Callable, List from pytest import fixture, mark, raises @@ -100,6 +102,24 @@ async def test_select( ) +async def test_select_inf(connection: Connection) -> None: + with connection.cursor() as c: + await c.execute("SELECT 'inf'::float, '-inf'::float") + data = await c.fetchall() + assert len(data) == 1, "Invalid data size returned by fetchall" + assert data[0][0] == float("inf"), "Invalid data returned by fetchall" + assert data[0][1] == float("-inf"), "Invalid data returned by fetchall" + + +async def test_select_nan(connection: Connection) -> None: + with connection.cursor() as c: + await c.execute("SELECT 'nan'::float, '-nan'::float") + data = await c.fetchall() + assert len(data) == 1, "Invalid data size returned by fetchall" + assert math.isnan(data[0][0]), "Invalid data returned by fetchall" + assert math.isnan(data[0][1]), "Invalid data returned by fetchall" + + @mark.slow @mark.timeout(timeout=550) async def test_long_query( @@ -408,13 +428,15 @@ async def test_bytea_roundtrip( ), "Invalid bytea data returned after roundtrip" -@fixture(scope="session") +@fixture async def setup_db(connection_system_engine_no_db: Connection, use_db_name: str): use_db_name = use_db_name + "_async" with connection_system_engine_no_db.cursor() as cursor: - await cursor.execute(f"CREATE DATABASE {use_db_name}") + # randomize the db name to avoid conflicts + suffix = "".join(choice("0123456789") for _ in range(2)) + await cursor.execute(f"CREATE DATABASE {use_db_name}{suffix}") yield - await cursor.execute(f"DROP DATABASE {use_db_name}") + await cursor.execute(f"DROP DATABASE {use_db_name}{suffix}") @mark.xfail("dev" not in environ[API_ENDPOINT_ENV], reason="Only works on dev") diff --git a/tests/integration/dbapi/sync/V1/test_queries.py b/tests/integration/dbapi/sync/V1/test_queries.py index 071eee48ff..1ac61eb262 100644 --- a/tests/integration/dbapi/sync/V1/test_queries.py +++ b/tests/integration/dbapi/sync/V1/test_queries.py @@ -1,3 +1,4 @@ +import math from datetime import date, datetime from decimal import Decimal from threading import Thread @@ -120,6 +121,24 @@ def test_select( ) +def test_select_inf(connection: Connection) -> None: + with connection.cursor() as c: + c.execute("SELECT 'inf'::float, '-inf'::float") + data = c.fetchall() + assert len(data) == 1, "Invalid data size returned by fetchall" + assert data[0][0] == float("inf"), "Invalid data returned by fetchall" + assert data[0][1] == float("-inf"), "Invalid data returned by fetchall" + + +def test_select_nan(connection: Connection) -> None: + with connection.cursor() as c: + c.execute("SELECT 'nan'::float, '-nan'::float") + data = c.fetchall() + assert len(data) == 1, "Invalid data size returned by fetchall" + assert math.isnan(data[0][0]), "Invalid data returned by fetchall" + assert math.isnan(data[0][1]), "Invalid data returned by fetchall" + + @mark.slow @mark.timeout(timeout=1000) def test_long_query( diff --git a/tests/integration/dbapi/sync/V2/test_queries.py b/tests/integration/dbapi/sync/V2/test_queries.py index 819189588a..416c71369b 100644 --- a/tests/integration/dbapi/sync/V2/test_queries.py +++ b/tests/integration/dbapi/sync/V2/test_queries.py @@ -1,3 +1,4 @@ +import math from datetime import date, datetime from decimal import Decimal from os import environ @@ -106,6 +107,24 @@ def test_select( ) +def test_select_inf(connection: Connection) -> None: + with connection.cursor() as c: + c.execute("SELECT 'inf'::float, '-inf'::float") + data = c.fetchall() + assert len(data) == 1, "Invalid data size returned by fetchall" + assert data[0][0] == float("inf"), "Invalid data returned by fetchall" + assert data[0][1] == float("-inf"), "Invalid data returned by fetchall" + + +def test_select_nan(connection: Connection) -> None: + with connection.cursor() as c: + c.execute("SELECT 'nan'::float, '-nan'::float") + data = c.fetchall() + assert len(data) == 1, "Invalid data size returned by fetchall" + assert math.isnan(data[0][0]), "Invalid data returned by fetchall" + assert math.isnan(data[0][1]), "Invalid data returned by fetchall" + + @mark.slow @mark.timeout(timeout=550) def test_long_query( diff --git a/tests/unit/async_db/V1/conftest.py b/tests/unit/async_db/V1/conftest.py index ed8d3f9991..5a6ab99650 100644 --- a/tests/unit/async_db/V1/conftest.py +++ b/tests/unit/async_db/V1/conftest.py @@ -1,12 +1,10 @@ -from datetime import date, datetime from json import loads from re import Pattern, compile -from typing import Dict import httpx from pytest import fixture -from firebolt.async_db import ARRAY, DECIMAL, Connection, Cursor, connect +from firebolt.async_db import Connection, Cursor, connect from firebolt.client import AsyncClientV1 as Client from firebolt.client.auth.base import Auth from firebolt.client.auth.username_password import UsernamePassword @@ -156,31 +154,3 @@ def check_credentials( ) return check_credentials - - -@fixture -def types_map() -> Dict[str, type]: - base_types = { - "int": int, - "long": int, - "float": float, - "double": float, - "text": str, - "date": date, - "pgdate": date, - "timestamp": datetime, - "timestampntz": datetime, - "timestamptz": datetime, - "Nothing null": str, - "Decimal(123, 4)": DECIMAL(123, 4), - "Decimal(38,0)": DECIMAL(38, 0), - # Invalid decimal format - "Decimal(38)": str, - "boolean": bool, - "SomeRandomNotExistingType": str, - "bytea": bytes, - } - array_types = {f"array({k})": ARRAY(v) for k, v in base_types.items()} - nullable_arrays = {f"{k} null": v for k, v in array_types.items()} - nested_arrays = {f"array({k})": ARRAY(v) for k, v in array_types.items()} - return {**base_types, **array_types, **nullable_arrays, **nested_arrays} diff --git a/tests/unit/async_db/V2/conftest.py b/tests/unit/async_db/V2/conftest.py index 3eaa5f12da..d928c0e114 100644 --- a/tests/unit/async_db/V2/conftest.py +++ b/tests/unit/async_db/V2/conftest.py @@ -1,9 +1,6 @@ -from datetime import date, datetime -from typing import Dict - from pytest import fixture -from firebolt.async_db import ARRAY, DECIMAL, Connection, Cursor, connect +from firebolt.async_db import Connection, Cursor, connect from firebolt.client.auth import Auth from tests.unit.db_conftest import * # noqa @@ -33,33 +30,3 @@ async def connection( @fixture async def cursor(connection: Connection) -> Cursor: return connection.cursor() - - -@fixture -def types_map() -> Dict[str, type]: - base_types = { - "int": int, - "long": int, - "float": float, - "double": float, - "text": str, - "date": date, - "date_ext": date, - "pgdate": date, - "timestamp": datetime, - "timestamp_ext": datetime, - "timestampntz": datetime, - "timestamptz": datetime, - "Nothing null": str, - "Decimal(123, 4)": DECIMAL(123, 4), - "Decimal(38,0)": DECIMAL(38, 0), - # Invalid decimal format - "Decimal(38)": str, - "boolean": bool, - "SomeRandomNotExistingType": str, - "bytea": bytes, - } - array_types = {f"array({k})": ARRAY(v) for k, v in base_types.items()} - nullable_arrays = {f"{k} null": v for k, v in array_types.items()} - nested_arrays = {f"array({k})": ARRAY(v) for k, v in array_types.items()} - return {**base_types, **array_types, **nullable_arrays, **nested_arrays} diff --git a/tests/unit/async_db/V2/test_typing_format.py b/tests/unit/async_db/V2/test_typing_format.py deleted file mode 100644 index c04027ee6c..0000000000 --- a/tests/unit/async_db/V2/test_typing_format.py +++ /dev/null @@ -1,211 +0,0 @@ -from datetime import date, datetime, timedelta, timezone -from decimal import Decimal -from typing import List, Optional - -from pytest import mark, raises -from sqlparse import parse -from sqlparse.sql import Statement - -from firebolt.async_db import ( - Binary, - DataError, - InterfaceError, - NotSupportedError, -) -from firebolt.common._types import ( - SetParameter, - format_statement, - format_value, - split_format_sql, - statement_to_set, -) - - -@mark.parametrize( - "value,result", - [ # Strings - ("", "''"), - ("abcd", "'abcd'"), - ("test' OR '1' == '1", "'test\\' OR \\'1\\' == \\'1'"), - ("test\\", "'test\\\\'"), - ("some\0value", "'some\\0value'"), - # Numbers - (1, "1"), - (1.123, "1.123"), - (Decimal("1.123"), "1.123"), - (Decimal(1.123), "1.1229999999999999982236431605997495353221893310546875"), - (True, "true"), - (False, "false"), - # Date, datetime - (date(2022, 1, 10), "'2022-01-10'"), - (datetime(2022, 1, 10, 1, 1, 1), "'2022-01-10 01:01:01'"), - ( - datetime(2022, 1, 10, 1, 1, 1, tzinfo=timezone(timedelta(hours=1))), - "'2022-01-10 00:01:01'", - ), - # List, tuple - ([], "[]"), - ([1, 2, 3], "[1, 2, 3]"), - (("a", "b", "c"), "['a', 'b', 'c']"), - # None - (None, "NULL"), - # Bytea - (b"abc", "'\\x61\\x62\\x63'"), - ], -) -def test_format_value(value: str, result: str) -> None: - assert format_value(value) == result, "Invalid format_value result" - - -def test_format_value_errors() -> None: - with raises(DataError) as exc_info: - format_value(Exception()) - - assert str(exc_info.value) == "unsupported parameter type " - - -def to_statement(sql: str) -> Statement: - return parse(sql)[0] - - -@mark.parametrize( - "statement,params,result", - [ - (to_statement("select * from table"), (), "select * from table"), - ( - to_statement("select * from table where id == ?"), - (1,), - "select * from table where id == 1", - ), - ( - to_statement("select * from table where id == '?'"), - (), - "select * from table where id == '?'", - ), - ( - to_statement("insert into table values (?, ?, '?')"), - (1, "1"), - "insert into table values (1, '1', '?')", - ), - ( - to_statement("select * from t where /*comment ?*/ id == ?"), - ("*/ 1 == 1 or /*",), - "select * from t where /*comment ?*/ id == '*/ 1 == 1 or /*'", - ), - ( - to_statement("select * from t where id == ?"), - ("' or '' == '",), - r"select * from t where id == '\' or \'\' == \''", - ), - ], -) -def test_format_statement(statement: Statement, params: tuple, result: str) -> None: - assert format_statement(statement, params) == result, "Invalid format sql result" - - -def test_format_statement_errors() -> None: - with raises(DataError) as exc_info: - format_statement(to_statement("?"), []) - assert ( - str(exc_info.value) - == "not enough parameters provided for substitution: given 0, found one more" - ), "Invalid not enought parameters error" - - with raises(DataError) as exc_info: - format_statement(to_statement("?"), (1, 2)) - assert ( - str(exc_info.value) - == "too many parameters provided for substitution: given 2, used only 1" - ), "Invalid not enought parameters error" - - -@mark.parametrize( - "query,params,result", - [ - ("", (), [""]), - ("select * from t", (), ["select * from t"]), - ("select * from t;", (), ["select * from t"]), - ("select * from t where id == ?", ((1,),), ["select * from t where id == 1"]), - ("select * from t where id == ?;", ((1,),), ["select * from t where id == 1"]), - ( - "select * from t;insert into t values (1, 2)", - (), - ["select * from t", "insert into t values (1, 2)"], - ), - ( - "insert into t values (1, 2);select * from t;", - (), - ["insert into t values (1, 2)", "select * from t"], - ), - ( - "select * from t where id == ?", - ((1,), (2,)), - ["select * from t where id == 1", "select * from t where id == 2"], - ), - ( - "select * from t; set a = b;", - (), - ["select * from t", SetParameter("a", "b")], - ), - ( - "set \t\na = \t\n b ; set c=d;", - (), - [SetParameter("a", "b"), SetParameter("c", "d")], - ), - ], -) -def test_split_format_sql(query: str, params: tuple, result: List[str]) -> None: - assert ( - split_format_sql(query, params) == result - ), "Invalid split and format sql result" - - -def test_split_format_error() -> None: - with raises(NotSupportedError): - split_format_sql( - "select * from t where id == ?; insert into t values (?, ?)", ((1, 2, 3),) - ) - - with raises(NotSupportedError): - split_format_sql("set a = ?", ((1,),)) - - -@mark.parametrize( - "statement,result", - [ - (to_statement("select 1"), None), - (to_statement("set a = b"), SetParameter("a", "b")), - (to_statement("set a=b"), SetParameter("a", "b")), - (to_statement("set \t\na = \t\n b ;"), SetParameter("a", "b")), - (to_statement("set /*comment*/a=b"), SetParameter("a", "b")), - (to_statement("set a='some 'string'"), SetParameter("a", "some 'string")), - ( - to_statement( - 'set query_parameters={"name":"param1","value":"Hello, world!"}' - ), - SetParameter( - "query_parameters", '{"name":"param1","value":"Hello, world!"}' - ), - ), - (to_statement("UPDATE t SET a=50 WHERE a>b"), None), - ], -) -def test_statement_to_set(statement: Statement, result: Optional[SetParameter]) -> None: - assert statement_to_set(statement) == result, "Invalid statement_to_set output" - - -@mark.parametrize( - "statement,error", - [ - (to_statement("set"), InterfaceError), - (to_statement("set a"), InterfaceError), - (to_statement("set a ="), InterfaceError), - ], -) -def test_statement_to_set_errors(statement: Statement, error: Exception) -> None: - with raises(error): - statement_to_set(statement) - - -def test_binary() -> None: - assert Binary("abc") == b"abc" diff --git a/tests/unit/async_db/V2/test_typing_parse.py b/tests/unit/async_db/V2/test_typing_parse.py deleted file mode 100644 index 76528a96c7..0000000000 --- a/tests/unit/async_db/V2/test_typing_parse.py +++ /dev/null @@ -1,269 +0,0 @@ -from datetime import date, datetime, timedelta, timezone -from decimal import Decimal -from typing import Dict, Optional - -from pytest import mark, raises - -from firebolt.async_db import ( - ARRAY, - DECIMAL, - DateFromTicks, - TimeFromTicks, - TimestampFromTicks, -) -from firebolt.common._types import parse_type, parse_value -from firebolt.utils.exception import DataError, NotSupportedError - - -def test_parse_type(types_map: Dict[str, type]) -> None: - """parse_type function parses all internal types correctly.""" - for type_name, t in types_map.items(): - parsed = parse_type(type_name) - assert ( - parsed == t - ), f"Error parsing type {type_name}: expected {str(t)}, got {str(parsed)}" - - with raises(DataError) as exc_info: - parse_type(1) - - assert ( - str(exc_info.value) == "Invalid typename 1: str expected" - ), "Invalid type parsing error message" - - -def test_parse_value_int() -> None: - """parse_value parses all int values correctly.""" - assert parse_value(1, int) == 1, "Error parsing integer: provided int" - assert parse_value("1", int) == 1, "Error parsing integer: provided str" - assert parse_value(1.1, int) == 1, "Error parsing integer: provided float" - assert parse_value(None, int) is None, "Error parsing integer: provided None" - - with raises(ValueError): - parse_value("a", int) - - for val in ((1,), [1], Exception()): - with raises(TypeError): - parse_value(val, int) - - -def test_parse_value_float() -> None: - """parse_value parses all float values correctly.""" - assert parse_value(1, float) == 1.0, "Error parsing float: provided int" - assert parse_value("1", float) == 1.0, "Error parsing float: provided str" - assert parse_value("1.1", float) == 1.1, "Error parsing float: provided str" - assert parse_value(1.1, float) == 1.1, "Error parsing float: provided float" - assert parse_value(None, float) is None, "Error parsing float: provided None" - - with raises(ValueError): - parse_value("a", float) - - for val in ((1.1,), [1.1], Exception()): - with raises(TypeError): - parse_value(val, float) - - -def test_parse_value_str() -> None: - """parse_value parses all str values correctly.""" - assert parse_value(1, str) == "1", "Error parsing str: provided int" - assert parse_value("a", str) == "a", "Error parsing str: provided str" - assert parse_value(1.1, str) == "1.1", "Error parsing str: provided float" - assert parse_value(("a",), str) == "('a',)", "Error parsing str: provided tuple" - assert parse_value(["a"], str) == "['a']", "Error parsing str: provided list" - assert parse_value(None, str) is None, "Error parsing str: provided None" - - -@mark.parametrize( - "value,expected,case", - [ - ("2021-12-31", date(2021, 12, 31), "str provided"), - ("0001-01-01", date(1, 1, 1), "range low provided"), - ("9999-12-31", date(9999, 12, 31), "range high provided"), - (None, None, "None provided"), - ("2021-12-31 23:59:59", date(2021, 12, 31), "datetime provided"), - ], -) -def test_parse_value_date(value: Optional[str], expected: Optional[date], case: str): - """parse_value parses all date values correctly.""" - assert parse_value(value, date) == expected, f"Error parsing date: {case}" - - -@mark.parametrize( - "value,expected,case", - [ - ( - "2021-12-31 23:59:59.1234", - datetime(2021, 12, 31, 23, 59, 59, 123400), - "str provided", - ), - ( - "0001-01-01 00:00:00.000000", - datetime(1, 1, 1, 0, 0, 0, 0), - "range low provided", - ), - ( - "9999-12-31 23:59:59.999999", - datetime(9999, 12, 31, 23, 59, 59, 999999), - "range high provided", - ), - ( - "2021-12-31 23:59:59.1234-03", - datetime( - 2021, 12, 31, 23, 59, 59, 123400, tzinfo=timezone(timedelta(hours=-3)) - ), - "timezone provided", - ), - ( - "2021-12-31 23:59:59.1234+05:30:12", - datetime( - 2021, - 12, - 31, - 23, - 59, - 59, - 123400, - tzinfo=timezone(timedelta(hours=5, minutes=30, seconds=12)), - ), - "timezone with seconds provided", - ), - (None, None, "None provided"), - ("2021-12-31", datetime(2021, 12, 31), "date provided"), - ], -) -def test_parse_value_datetime( - value: Optional[str], expected: Optional[date], case: str -): - """parse_value parses all date values correctly.""" - assert parse_value(value, datetime) == expected, f"Error parsing datetime: {case}" - - -def test_parse_value_datetime_errors() -> None: - """parse_value parses all date and datetime values correctly.""" - with raises(ValueError): - parse_value("abd", date) - - for value in ([2021, 12, 31], (2021, 12, 31)): - with raises(DataError) as exc_info: - parse_value(value, date) - - assert str(exc_info.value) == f"Invalid date value {value}: str expected" - - # Datetime - assert parse_value("2021-12-31 23:59:59.1234", datetime) == datetime( - 2021, 12, 31, 23, 59, 59, 123400 - ), "Error parsing datetime: str provided" - assert parse_value(None, datetime) is None, "Error parsing datetime: None provided" - - assert parse_value("2021-12-31", datetime) == datetime( - 2021, 12, 31 - ), "Error parsing datetime: date string provided" - - with raises(ValueError): - parse_value("abd", datetime) - - for value in ([2021, 12, 31], (2021, 12, 31)): - with raises(DataError) as exc_info: - parse_value(value, datetime) - - assert str(exc_info.value) == f"Invalid datetime value {value}: str expected" - - -def test_parse_decimal() -> None: - assert parse_value("123.456", DECIMAL(38, 3)) == Decimal( - "123.456" - ), "Error parsing decimal(38, 3): str provided" - assert parse_value(123, DECIMAL(38, 3)) == Decimal( - "123" - ), "Error parsing decimal(38, 3): int provided" - assert ( - parse_value(None, DECIMAL(38, 3)) is None - ), "Error parsing decimal(38, 3): None provided" - - -def test_parse_arrays() -> None: - """parse_value parses all array values correctly.""" - assert parse_value([1, 2], ARRAY(int)) == [ - 1, - 2, - ], "Error parsing array(int): list[int] provided" - assert parse_value([1, "2"], ARRAY(int)) == [ - 1, - 2, - ], "Error parsing array(int): mixed list provided" - assert parse_value(["1", "2"], ARRAY(int)) == [ - 1, - 2, - ], "Error parsing array(int): list[str] provided" - - assert parse_value([1, 2], ARRAY(float)) == [ - 1.0, - 2.0, - ], "Error parsing array(float): list[int] provided" - - assert parse_value(["2021-12-31 23:59:59", "2000-01-01 01:01:01"], ARRAY(str)) == [ - "2021-12-31 23:59:59", - "2000-01-01 01:01:01", - ], "Error parsing array(str): list[str] provided" - - assert parse_value( - ["2021-12-31 23:59:59", "2000-01-01 01:01:01"], ARRAY(datetime) - ) == [ - datetime(2021, 12, 31, 23, 59, 59), - datetime(2000, 1, 1, 1, 1, 1), - ], "Error parsing array(datetime): list[str] provided" - - assert parse_value(["2021-12-31", "2000-01-01"], ARRAY(date)) == [ - date(2021, 12, 31), - date(2000, 1, 1), - ], "Error parsing array(datetime): list[str] provided" - - for t in (int, float, str, date, datetime, ARRAY(int)): - assert ( - parse_value(None, ARRAY(t)) is None - ), f"Error parsing array({str(t)}): None provided" - - -def test_helpers() -> None: - """All provided helper functions work properly.""" - d = date(2021, 12, 31) - dts = datetime(d.year, d.month, d.day).timestamp() - assert DateFromTicks(dts) == d, "Error running DateFromTicks" - - dt = datetime(2021, 12, 31, 23, 59, 59) - assert ( - TimestampFromTicks(datetime.timestamp(dt)) == dt - ), "Error running TimestampFromTicks" - - with raises(NotSupportedError): - TimeFromTicks(0) - - -def test_parse_value_bool() -> None: - """parse_value parses all int values correctly.""" - assert parse_value(True, bool) == True, "Error parsing boolean: provided true" - assert parse_value(False, bool) == False, "Error parsing boolean: provided false" - assert parse_value(2, bool) == True, "Error parsing boolean: provided 2" - assert parse_value(0, bool) == False, "Error parsing boolean: provided 0" - assert parse_value(None, bool) is None, "Error parsing boolean: provided None" - - with raises(DataError): - parse_value("true", bool) - - -def test_parse_value_bytes() -> None: - """parse_value parses all int values correctly.""" - assert ( - parse_value("\\x616263", bytes) == b"abc" - ), "Error parsing bytes: provided str" - assert parse_value(None, bytes) is None, "Error parsing bytes: provided None" - - with raises(ValueError): - parse_value("\\xabc", bytes) - - # Missing prefix - with raises(ValueError): - parse_value("616263", bytes) - - for val in (1, True, Exception()): - with raises(DataError): - parse_value(val, bytes) diff --git a/tests/unit/async_db/V1/test_typing_format.py b/tests/unit/common/test_typing_format.py similarity index 100% rename from tests/unit/async_db/V1/test_typing_format.py rename to tests/unit/common/test_typing_format.py diff --git a/tests/unit/async_db/V1/test_typing_parse.py b/tests/unit/common/test_typing_parse.py similarity index 52% rename from tests/unit/async_db/V1/test_typing_parse.py rename to tests/unit/common/test_typing_parse.py index 76528a96c7..789fab571a 100644 --- a/tests/unit/async_db/V1/test_typing_parse.py +++ b/tests/unit/common/test_typing_parse.py @@ -1,3 +1,4 @@ +import math from datetime import date, datetime, timedelta, timezone from decimal import Decimal from typing import Dict, Optional @@ -31,45 +32,80 @@ def test_parse_type(types_map: Dict[str, type]) -> None: ), "Invalid type parsing error message" -def test_parse_value_int() -> None: +@mark.parametrize( + "value,expected,error", + [ + (1, 1, None), + ("1", 1, None), + (1.1, 1, None), + (None, None, None), + ("a", None, ValueError), + ((1,), None, TypeError), + ([1], None, TypeError), + (Exception(), None, TypeError), + ], +) +def test_parse_value_int(value, expected, error) -> None: """parse_value parses all int values correctly.""" - assert parse_value(1, int) == 1, "Error parsing integer: provided int" - assert parse_value("1", int) == 1, "Error parsing integer: provided str" - assert parse_value(1.1, int) == 1, "Error parsing integer: provided float" - assert parse_value(None, int) is None, "Error parsing integer: provided None" - - with raises(ValueError): - parse_value("a", int) - - for val in ((1,), [1], Exception()): - with raises(TypeError): - parse_value(val, int) + if error: + with raises(error): + parse_value(value, int) + else: + assert ( + parse_value(value, int) == expected + ), f"Error parsing integer: provided {value}, expected {expected}" -def test_parse_value_float() -> None: +@mark.parametrize( + "value,expected,error", + [ + (1, 1.0, None), + ("1", 1.0, None), + ("1.1", 1.1, None), + (1.1, 1.1, None), + (None, None, None), + ("inf", float("inf"), None), + ("-inf", float("-inf"), None), + ("nan", float("nan"), None), + ("-nan", float("nan"), None), + ("a", None, ValueError), + ((1.1,), None, TypeError), + ([1.1], None, TypeError), + (Exception(), None, TypeError), + ], +) +def test_parse_value_float(value, expected, error) -> None: """parse_value parses all float values correctly.""" - assert parse_value(1, float) == 1.0, "Error parsing float: provided int" - assert parse_value("1", float) == 1.0, "Error parsing float: provided str" - assert parse_value("1.1", float) == 1.1, "Error parsing float: provided str" - assert parse_value(1.1, float) == 1.1, "Error parsing float: provided float" - assert parse_value(None, float) is None, "Error parsing float: provided None" - - with raises(ValueError): - parse_value("a", float) + if error: + with raises(error): + parse_value(value, float) + else: + if expected and math.isnan(expected): + assert math.isnan( + parse_value(value, float) + ), f"Error parsing float: provided {value}, expected {expected}" + else: + assert ( + parse_value(value, float) == expected + ), f"Error parsing float: provided {value}, expected {expected}" - for val in ((1.1,), [1.1], Exception()): - with raises(TypeError): - parse_value(val, float) - -def test_parse_value_str() -> None: +@mark.parametrize( + "value,expected", + [ + (1, "1"), + ("a", "a"), + (1.1, "1.1"), + (("a",), "('a',)"), + (["a"], "['a']"), + (None, None), + ], +) +def test_parse_value_str(value, expected) -> None: """parse_value parses all str values correctly.""" - assert parse_value(1, str) == "1", "Error parsing str: provided int" - assert parse_value("a", str) == "a", "Error parsing str: provided str" - assert parse_value(1.1, str) == "1.1", "Error parsing str: provided float" - assert parse_value(("a",), str) == "('a',)", "Error parsing str: provided tuple" - assert parse_value(["a"], str) == "['a']", "Error parsing str: provided list" - assert parse_value(None, str) is None, "Error parsing str: provided None" + assert ( + parse_value(value, str) == expected + ), f"Error parsing str: provided {value}, expected {expected}" @mark.parametrize( @@ -168,59 +204,50 @@ def test_parse_value_datetime_errors() -> None: assert str(exc_info.value) == f"Invalid datetime value {value}: str expected" -def test_parse_decimal() -> None: - assert parse_value("123.456", DECIMAL(38, 3)) == Decimal( - "123.456" - ), "Error parsing decimal(38, 3): str provided" - assert parse_value(123, DECIMAL(38, 3)) == Decimal( - "123" - ), "Error parsing decimal(38, 3): int provided" +@mark.parametrize( + "value,expected", + [ + ("123.456", Decimal("123.456")), + (123, Decimal("123")), + (None, None), + ], +) +def test_parse_decimal(value, expected) -> None: assert ( - parse_value(None, DECIMAL(38, 3)) is None - ), "Error parsing decimal(38, 3): None provided" - - -def test_parse_arrays() -> None: - """parse_value parses all array values correctly.""" - assert parse_value([1, 2], ARRAY(int)) == [ - 1, - 2, - ], "Error parsing array(int): list[int] provided" - assert parse_value([1, "2"], ARRAY(int)) == [ - 1, - 2, - ], "Error parsing array(int): mixed list provided" - assert parse_value(["1", "2"], ARRAY(int)) == [ - 1, - 2, - ], "Error parsing array(int): list[str] provided" - - assert parse_value([1, 2], ARRAY(float)) == [ - 1.0, - 2.0, - ], "Error parsing array(float): list[int] provided" - - assert parse_value(["2021-12-31 23:59:59", "2000-01-01 01:01:01"], ARRAY(str)) == [ - "2021-12-31 23:59:59", - "2000-01-01 01:01:01", - ], "Error parsing array(str): list[str] provided" - - assert parse_value( - ["2021-12-31 23:59:59", "2000-01-01 01:01:01"], ARRAY(datetime) - ) == [ - datetime(2021, 12, 31, 23, 59, 59), - datetime(2000, 1, 1, 1, 1, 1), - ], "Error parsing array(datetime): list[str] provided" - - assert parse_value(["2021-12-31", "2000-01-01"], ARRAY(date)) == [ - date(2021, 12, 31), - date(2000, 1, 1), - ], "Error parsing array(datetime): list[str] provided" - - for t in (int, float, str, date, datetime, ARRAY(int)): - assert ( - parse_value(None, ARRAY(t)) is None - ), f"Error parsing array({str(t)}): None provided" + parse_value(value, DECIMAL(38, 3)) == expected + ), "Error parsing decimal(38, 3): provided {value}, expected {expected}" + + +@mark.parametrize( + "value,expected,type", + [ + ([1, 2], [1, 2], int), + ([1, "2"], [1, 2], int), + (["1", "2"], [1, 2], int), + ([1, 2], [1.0, 2.0], float), + ( + ["2021-12-31 23:59:59", "2000-01-01 00:00:00"], + ["2021-12-31 23:59:59", "2000-01-01 00:00:00"], + str, + ), + ( + ["2021-12-31 23:59:59", "2000-01-01 00:00:00"], + [datetime(2021, 12, 31, 23, 59, 59), datetime(2000, 1, 1, 0, 0, 0)], + datetime, + ), + (["2021-12-31", "2000-01-01"], [date(2021, 12, 31), date(2000, 1, 1)], date), + (None, None, int), + (None, None, float), + (None, None, str), + (None, None, datetime), + (None, None, date), + (None, None, ARRAY(int)), + ], +) +def test_parse_arrays(value, expected, type) -> None: + assert ( + parse_value(value, ARRAY(type)) == expected + ), f"Error parsing array({type}): provided {value}, expected {expected}" def test_helpers() -> None: @@ -238,32 +265,44 @@ def test_helpers() -> None: TimeFromTicks(0) -def test_parse_value_bool() -> None: +@mark.parametrize( + "value,expected,error", + [ + (True, True, None), + (False, False, None), + (2, True, None), + (0, False, None), + (None, None, None), + ("true", None, DataError), + ], +) +def test_parse_value_bool(value, expected, error) -> None: """parse_value parses all int values correctly.""" - assert parse_value(True, bool) == True, "Error parsing boolean: provided true" - assert parse_value(False, bool) == False, "Error parsing boolean: provided false" - assert parse_value(2, bool) == True, "Error parsing boolean: provided 2" - assert parse_value(0, bool) == False, "Error parsing boolean: provided 0" - assert parse_value(None, bool) is None, "Error parsing boolean: provided None" - - with raises(DataError): - parse_value("true", bool) + if error: + with raises(error): + parse_value(value, bool) + else: + assert ( + parse_value(value, bool) == expected + ), f"Error parsing boolean: provided {value}" -def test_parse_value_bytes() -> None: +@mark.parametrize( + "value,expected,error", + [ + ("\\x616263", b"abc", None), + (None, None, None), + ("\\xabc", None, ValueError), + ("616263", None, ValueError), + (1, None, DataError), + ], +) +def test_parse_value_bytes(value, expected, error) -> None: """parse_value parses all int values correctly.""" - assert ( - parse_value("\\x616263", bytes) == b"abc" - ), "Error parsing bytes: provided str" - assert parse_value(None, bytes) is None, "Error parsing bytes: provided None" - - with raises(ValueError): - parse_value("\\xabc", bytes) - - # Missing prefix - with raises(ValueError): - parse_value("616263", bytes) - - for val in (1, True, Exception()): - with raises(DataError): - parse_value(val, bytes) + if error: + with raises(error): + parse_value(value, bytes) + else: + assert ( + parse_value(value, bytes) == expected + ), f"Error parsing bytes: provided {value}" diff --git a/tests/unit/db_conftest.py b/tests/unit/db_conftest.py index 5ace52d183..2afc6df9c6 100644 --- a/tests/unit/db_conftest.py +++ b/tests/unit/db_conftest.py @@ -580,3 +580,31 @@ def inner() -> None: httpx_mock.add_callback(insert_query_callback, url=query_url) return inner + + +@fixture +def types_map() -> Dict[str, type]: + base_types = { + "int": int, + "long": int, + "float": float, + "double": float, + "text": str, + "date": date, + "pgdate": date, + "timestamp": datetime, + "timestampntz": datetime, + "timestamptz": datetime, + "Nothing null": str, + "Decimal(123, 4)": DECIMAL(123, 4), + "Decimal(38,0)": DECIMAL(38, 0), + # Invalid decimal format + "Decimal(38)": str, + "boolean": bool, + "SomeRandomNotExistingType": str, + "bytea": bytes, + } + array_types = {f"array({k})": ARRAY(v) for k, v in base_types.items()} + nullable_arrays = {f"{k} null": v for k, v in array_types.items()} + nested_arrays = {f"array({k})": ARRAY(v) for k, v in array_types.items()} + return {**base_types, **array_types, **nullable_arrays, **nested_arrays}