diff --git a/src/firebolt/common/_types.py b/src/firebolt/common/_types.py index 3efbf7b3926..67ed2b49c20 100644 --- a/src/firebolt/common/_types.py +++ b/src/firebolt/common/_types.py @@ -341,12 +341,8 @@ def parse_value( raise DataError(f"Invalid bytea value {value}: str expected") return _parse_bytea(value) if isinstance(ctype, DECIMAL): - if not isinstance(value, (str, int, float)): + if not isinstance(value, (str, int)): raise DataError(f"Invalid decimal value {value}: str or int expected") - if isinstance(value, float): - # Decimal constructor doesn't support float - # so we need to convert it to string first - value = str(value) return Decimal(value) if isinstance(ctype, ARRAY): if not isinstance(value, list): diff --git a/src/firebolt/common/row_set/streaming_common.py b/src/firebolt/common/row_set/streaming_common.py index 7a44f913208..86416d537cb 100644 --- a/src/firebolt/common/row_set/streaming_common.py +++ b/src/firebolt/common/row_set/streaming_common.py @@ -94,7 +94,8 @@ def _next_json_lines_record_from_line( return None try: - record = json.loads(next_line) + # Skip parsing floats to properly parse them later + record = json.loads(next_line, parse_float=str) except json.JSONDecodeError as err: raise OperationalError( f"Invalid JSON line response format: {next_line}" diff --git a/src/firebolt/common/row_set/synchronous/in_memory.py b/src/firebolt/common/row_set/synchronous/in_memory.py index b36b80b80da..ddf0a4f78c9 100644 --- a/src/firebolt/common/row_set/synchronous/in_memory.py +++ b/src/firebolt/common/row_set/synchronous/in_memory.py @@ -43,7 +43,8 @@ def append_response_stream(self, stream: Iterator[bytes]) -> None: self.append_empty_response() else: try: - query_data = json.loads(content) + # Skip parsing floats to properly parse them later + query_data = json.loads(content, parse_float=str) if "errors" in query_data and len(query_data["errors"]) > 0: raise FireboltStructuredError(query_data) diff --git a/tests/integration/dbapi/async/V1/conftest.py b/tests/integration/dbapi/async/V1/conftest.py index b3cd40edacd..895d4b2d568 100644 --- a/tests/integration/dbapi/async/V1/conftest.py +++ b/tests/integration/dbapi/async/V1/conftest.py @@ -1,11 +1,7 @@ -from decimal import Decimal -from typing import List - from pytest import fixture from firebolt.async_db import Connection, connect from firebolt.client.auth.base import Auth -from firebolt.common._types import ColType @fixture @@ -78,14 +74,3 @@ async def connection_no_engine( api_endpoint=api_endpoint, ) as connection: yield connection - - -@fixture -def all_types_query_response_v1( - all_types_query_response: List[List[ColType]], -) -> List[List[ColType]]: - """ - V1 still returns decimals as floats, despite overflow. That's why it's not fully accurate. 
- """ - all_types_query_response[0][18] = Decimal("1231232.1234599999152123928070068359375") - return all_types_query_response diff --git a/tests/integration/dbapi/async/V1/test_queries_async.py b/tests/integration/dbapi/async/V1/test_queries_async.py index fd32ab3789c..d8d5906c470 100644 --- a/tests/integration/dbapi/async/V1/test_queries_async.py +++ b/tests/integration/dbapi/async/V1/test_queries_async.py @@ -78,7 +78,7 @@ async def test_connect_engine_name( connection_engine_name: Connection, all_types_query: str, all_types_query_description: List[Column], - all_types_query_response_v1: List[ColType], + all_types_query_response: List[ColType], timezone_name: str, ) -> None: """Connecting with engine name is handled properly.""" @@ -86,7 +86,7 @@ async def test_connect_engine_name( connection_engine_name, all_types_query, all_types_query_description, - all_types_query_response_v1, + all_types_query_response, timezone_name, ) @@ -95,7 +95,7 @@ async def test_connect_no_engine( connection_no_engine: Connection, all_types_query: str, all_types_query_description: List[Column], - all_types_query_response_v1: List[ColType], + all_types_query_response: List[ColType], timezone_name: str, ) -> None: """Connecting with engine name is handled properly.""" @@ -103,7 +103,7 @@ async def test_connect_no_engine( connection_no_engine, all_types_query, all_types_query_description, - all_types_query_response_v1, + all_types_query_response, timezone_name, ) @@ -112,7 +112,7 @@ async def test_select( connection: Connection, all_types_query: str, all_types_query_description: List[Column], - all_types_query_response_v1: List[ColType], + all_types_query_response: List[ColType], timezone_name: str, ) -> None: """Select handles all data types properly.""" @@ -130,7 +130,7 @@ async def test_select( assert c.rowcount == 1, "Invalid rowcount value" data = await c.fetchall() assert len(data) == c.rowcount, "Invalid data length" - assert_deep_eq(data, all_types_query_response_v1, "Invalid data") + assert_deep_eq(data, all_types_query_response, "Invalid data") assert c.description == all_types_query_description, "Invalid description value" assert len(data[0]) == len(c.description), "Invalid description length" assert len(await c.fetchall()) == 0, "Redundant data returned by fetchall" @@ -138,7 +138,7 @@ async def test_select( # Different fetch types await c.execute(all_types_query) assert ( - await c.fetchone() == all_types_query_response_v1[0] + await c.fetchone() == all_types_query_response[0] ), "Invalid fetchone data" assert await c.fetchone() is None, "Redundant data returned by fetchone" @@ -147,7 +147,7 @@ async def test_select( data = await c.fetchmany() assert len(data) == 1, "Invalid data size returned by fetchmany" assert_deep_eq( - data, all_types_query_response_v1, "Invalid data returned by fetchmany" + data, all_types_query_response, "Invalid data returned by fetchmany" ) @@ -328,7 +328,7 @@ async def test_empty_query(c: Cursor, query: str, params: tuple) -> None: datetime(2022, 1, 1, 1, 1, 1), True, [1, 2, 3], - Decimal(123.456), + Decimal("123.456"), ] await test_empty_query( diff --git a/tests/integration/dbapi/sync/V1/conftest.py b/tests/integration/dbapi/sync/V1/conftest.py index c05d4d42ba1..9beea7962f7 100644 --- a/tests/integration/dbapi/sync/V1/conftest.py +++ b/tests/integration/dbapi/sync/V1/conftest.py @@ -1,10 +1,6 @@ -from decimal import Decimal -from typing import List - from pytest import fixture from firebolt.client.auth.base import Auth -from firebolt.common._types import ColType from 
firebolt.db import Connection, connect @@ -97,14 +93,3 @@ def connection_system_engine( ) yield connection connection.close() - - -@fixture -def all_types_query_response_v1( - all_types_query_response: List[List[ColType]], -) -> List[List[ColType]]: - """ - V1 still returns decimals as floats, despite overflow. That's why it's not fully accurate. - """ - all_types_query_response[0][18] = Decimal("1231232.1234599999152123928070068359375") - return all_types_query_response diff --git a/tests/integration/dbapi/sync/V1/test_queries.py b/tests/integration/dbapi/sync/V1/test_queries.py index aaae6dbdd5d..b331e7ac8c9 100644 --- a/tests/integration/dbapi/sync/V1/test_queries.py +++ b/tests/integration/dbapi/sync/V1/test_queries.py @@ -30,7 +30,7 @@ def test_connect_engine_name( connection_engine_name: Connection, all_types_query: str, all_types_query_description: List[Column], - all_types_query_response_v1: List[ColType], + all_types_query_response: List[ColType], timezone_name: str, ) -> None: """Connecting with engine name is handled properly.""" @@ -38,7 +38,7 @@ def test_connect_engine_name( connection_engine_name, all_types_query, all_types_query_description, - all_types_query_response_v1, + all_types_query_response, timezone_name, ) @@ -47,7 +47,7 @@ def test_connect_no_engine( connection_no_engine: Connection, all_types_query: str, all_types_query_description: List[Column], - all_types_query_response_v1: List[ColType], + all_types_query_response: List[ColType], timezone_name: str, ) -> None: """Connecting with engine name is handled properly.""" @@ -55,7 +55,7 @@ def test_connect_no_engine( connection_no_engine, all_types_query, all_types_query_description, - all_types_query_response_v1, + all_types_query_response, timezone_name, ) @@ -64,7 +64,7 @@ def test_select( connection: Connection, all_types_query: str, all_types_query_description: List[Column], - all_types_query_response_v1: List[ColType], + all_types_query_response: List[ColType], timezone_name: str, ) -> None: """Select handles all data types properly.""" @@ -82,14 +82,14 @@ def test_select( assert c.rowcount == 1, "Invalid rowcount value" data = c.fetchall() assert len(data) == c.rowcount, "Invalid data length" - assert_deep_eq(data, all_types_query_response_v1, "Invalid data") + assert_deep_eq(data, all_types_query_response, "Invalid data") assert c.description == all_types_query_description, "Invalid description value" assert len(data[0]) == len(c.description), "Invalid description length" assert len(c.fetchall()) == 0, "Redundant data returned by fetchall" # Different fetch types c.execute(all_types_query) - assert c.fetchone() == all_types_query_response_v1[0], "Invalid fetchone data" + assert c.fetchone() == all_types_query_response[0], "Invalid fetchone data" assert c.fetchone() is None, "Redundant data returned by fetchone" c.execute(all_types_query) @@ -97,7 +97,7 @@ def test_select( data = c.fetchmany() assert len(data) == 1, "Invalid data size returned by fetchmany" assert_deep_eq( - data, all_types_query_response_v1, "Invalid data returned by fetchmany" + data, all_types_query_response, "Invalid data returned by fetchmany" ) @@ -273,7 +273,7 @@ def test_empty_query(c: Cursor, query: str, params: tuple) -> None: datetime(2022, 1, 1, 1, 1, 1), True, [1, 2, 3], - Decimal(123.456), + Decimal("123.456"), ] test_empty_query( diff --git a/tests/unit/common/row_set/asynchronous/test_in_memory.py b/tests/unit/common/row_set/asynchronous/test_in_memory.py index 55492d71589..93fe55b31e7 100644 --- 
a/tests/unit/common/row_set/asynchronous/test_in_memory.py +++ b/tests/unit/common/row_set/asynchronous/test_in_memory.py @@ -1,4 +1,5 @@ import json +from decimal import Decimal from unittest.mock import MagicMock, patch import pytest @@ -173,6 +174,7 @@ def test_append_empty_response(self, in_memory_rowset): async def test_append_response(self, in_memory_rowset, mock_response): """Test appending a response with data.""" + # Create a proper aclose method async def mock_aclose(): mock_response.is_closed = True @@ -207,6 +209,7 @@ async def test_append_response_empty_content( self, in_memory_rowset, mock_empty_response ): """Test appending a response with empty content.""" + # Create a proper aclose method async def mock_aclose(): mock_empty_response.is_closed = True @@ -226,6 +229,7 @@ async def test_append_response_invalid_json( self, in_memory_rowset, mock_invalid_json_response ): """Test appending a response with invalid JSON.""" + # Create a proper aclose method async def mock_aclose(): mock_invalid_json_response.is_closed = True @@ -245,6 +249,7 @@ async def test_append_response_missing_meta( self, in_memory_rowset, mock_missing_meta_response ): """Test appending a response with missing meta field.""" + # Create a proper aclose method async def mock_aclose(): mock_missing_meta_response.is_closed = True @@ -264,6 +269,7 @@ async def test_append_response_missing_data( self, in_memory_rowset, mock_missing_data_response ): """Test appending a response with missing data field.""" + # Create a proper aclose method async def mock_aclose(): mock_missing_data_response.is_closed = True @@ -281,6 +287,7 @@ async def mock_aclose(): async def test_nextset_no_more_sets(self, in_memory_rowset, mock_response): """Test nextset when there are no more result sets.""" + # Create a proper aclose method async def mock_aclose(): pass @@ -296,6 +303,7 @@ async def test_nextset_with_more_sets(self, in_memory_rowset, mock_response): The implementation seems to add rowsets correctly, but behaves differently than expected when accessing them via nextset. """ + # Create a proper aclose method async def mock_aclose(): pass @@ -322,6 +330,7 @@ async def mock_aclose(): async def test_iteration(self, in_memory_rowset, mock_response): """Test row iteration.""" + # Create a proper aclose method async def mock_aclose(): pass @@ -347,6 +356,7 @@ async def test_iteration_after_nextset(self, in_memory_rowset, mock_response): This test is tricky because in the mock setup, the second row set is actually empty despite us adding the same mock response. 
""" + # Create a proper aclose method async def mock_aclose(): pass @@ -410,6 +420,7 @@ async def test_empty_rowset_iteration(self, in_memory_rowset): async def test_aclose(self, in_memory_rowset, mock_response): """Test aclose method.""" + # Create a proper aclose method async def mock_aclose(): pass @@ -423,3 +434,23 @@ async def mock_aclose(): # Verify sync close was called mock_close.assert_called_once() + + async def test_append_response_with_decimals( + self, in_memory_rowset: InMemoryAsyncRowSet, mock_decimal_bytes_stream: Response + ): + + await in_memory_rowset.append_response(mock_decimal_bytes_stream) + + # Verify basic properties + assert in_memory_rowset.row_count == 2 + assert len(in_memory_rowset.columns) == 3 + + # Get the row values and check decimal values are equal + rows = [row async for row in in_memory_rowset] + + # Verify the decimal value is correctly parsed + for row in rows: + assert isinstance(row[2], Decimal), "Expected Decimal type" + assert ( + str(row[2]) == "1231232.123459999990457054844258706536" + ), "Decimal value mismatch" diff --git a/tests/unit/common/row_set/asynchronous/test_streaming.py b/tests/unit/common/row_set/asynchronous/test_streaming.py index a5bfe64d213..754ce219913 100644 --- a/tests/unit/common/row_set/asynchronous/test_streaming.py +++ b/tests/unit/common/row_set/asynchronous/test_streaming.py @@ -1,4 +1,5 @@ import json +from decimal import Decimal from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -564,6 +565,7 @@ async def test_iteration_stops_after_response_consumed(self, streaming_rowset): async def test_pop_data_record_from_record_unexpected_end(self): """Test _pop_data_record_from_record behavior with unexpected end of stream.""" + # Create a simple subclass to access protected method directly class TestableStreamingAsyncRowSet(StreamingAsyncRowSet): def pop_data_record_from_record_exposed(self, record): @@ -993,3 +995,28 @@ def patched_reset(): # Internal state should be reset assert streaming_rowset._responses == [] + + async def test_streaming_with_decimals( + self, + streaming_rowset: StreamingAsyncRowSet, + mock_decimal_response_streaming: Response, + ): + + # Add the response to the row set + await streaming_rowset.append_response(mock_decimal_response_streaming) + + # Verify columns are correctly identified + assert len(streaming_rowset.columns) == 3 + + # Get all rows + rows = [row async for row in streaming_rowset] + + # Verify we got the expected number of rows + assert len(rows) == 2 + + # Verify decimal values are correctly parsed in both formats (string and float) + for row in rows: + assert isinstance(row[2], Decimal), "Expected Decimal type" + assert ( + str(row[2]) == "1231232.123459999990457054844258706536" + ), "Decimal value mismatch" diff --git a/tests/unit/common/row_set/conftest.py b/tests/unit/common/row_set/conftest.py new file mode 100644 index 00000000000..3e3b6c87311 --- /dev/null +++ b/tests/unit/common/row_set/conftest.py @@ -0,0 +1,117 @@ +import json +from unittest.mock import MagicMock + +import pytest +from httpx import Response + + +@pytest.fixture +def mock_decimal_response_streaming() -> Response: + """Create a mock Response with decimal data.""" + mock = MagicMock(spec=Response) + + # Create JSON record strings with properly formatted data + start_record_data = { + "message_type": "START", + "result_columns": [ + {"name": "col1", "type": "int"}, + {"name": "col2", "type": "string"}, + {"name": "col3", "type": "numeric(10, 2)"}, + ], + "query_id": "q1", + "query_label": 
"l1", + "request_id": "r1", + } + + data_record_data = { + "message_type": "DATA", + "data": [ + [1, "one", "1231232.123459999990457054844258706536"], + [2, "two", "1231232.123459999990457054844258706536"], + ], + } + + success_record_data = { + "message_type": "FINISH_SUCCESSFULLY", + "statistics": { + "elapsed": 0.1, + "rows_read": 10, + "bytes_read": 100, + "time_before_execution": 0.01, + "time_to_execute": 0.09, + }, + } + + # Generate the JSON strings + start_record = json.dumps(start_record_data) + data_record = json.dumps(data_record_data) + + # Replace the decimal string with a float to simulate the behavior of FB 1.0 + # for one of the rows + data_record = data_record.replace( + '"1231232.123459999990457054844258706536"', + "1231232.123459999990457054844258706536", + 1, + ) + + success_record = json.dumps(success_record_data) + + mock.iter_lines.return_value = iter([start_record, data_record, success_record]) + + async def async_iter(): + for item in [ + start_record.encode("utf-8"), + data_record.encode("utf-8"), + success_record.encode("utf-8"), + ]: + yield item + + mock.aiter_lines.side_effect = async_iter + mock.is_closed = False + return mock + + +@pytest.fixture +def mock_decimal_bytes_stream() -> Response: + """Create a mock bytes stream with decimal data.""" + mock = MagicMock(spec=Response) + data = iter( + [ + json.dumps( + { + "meta": [ + {"name": "col1", "type": "int"}, + {"name": "col2", "type": "string"}, + {"name": "col3", "type": "Decimal(10, 2)"}, + ], + "data": [ + [1, "one", "1231232.123459999990457054844258706536"], + [2, "two", "1231232.123459999990457054844258706536"], + ], + "statistics": { + "elapsed": 0.1, + "rows_read": 10, + "bytes_read": 100, + "time_before_execution": 0.01, + "time_to_execute": 0.09, + }, + } + ) + # Replace the decimal string with a float to simulate the behavior of FB 1.0 + # for one of the rows + .replace( + '"1231232.123459999990457054844258706536"', + "1231232.123459999990457054844258706536", + 1, + ).encode("utf-8") + ] + ) + mock.iter_bytes.return_value = data + + async def async_iter(): + for item in data: + yield item + + mock.aiter_bytes.side_effect = async_iter + mock.is_closed = False + return mock diff --git a/tests/unit/common/row_set/synchronous/test_in_memory.py b/tests/unit/common/row_set/synchronous/test_in_memory.py index baf22b1204a..09ad02a06be 100644 --- a/tests/unit/common/row_set/synchronous/test_in_memory.py +++ b/tests/unit/common/row_set/synchronous/test_in_memory.py @@ -1,4 +1,5 @@ import json +from decimal import Decimal from unittest.mock import MagicMock import pytest @@ -12,7 +13,7 @@ class TestInMemoryRowSet: """Tests for InMemoryRowSet functionality.""" @pytest.fixture - def in_memory_rowset(self): + def in_memory_rowset(self) -> InMemoryRowSet: """Create a fresh InMemoryRowSet instance.""" return InMemoryRowSet() @@ -174,6 +175,28 @@ def test_append_response_stream(self, in_memory_rowset, mock_bytes_stream): assert len(in_memory_rowset.columns) == 2 assert in_memory_rowset.statistics is not None + def test_append_response_stream_with_decimals( + self, + in_memory_rowset: InMemoryRowSet, + mock_decimal_bytes_stream: Response, + ): + """Test appending a stream with decimal data type.""" + in_memory_rowset.append_response(mock_decimal_bytes_stream) + + assert len(in_memory_rowset._row_sets) == 1 + assert in_memory_rowset.row_count == 2 + assert len(in_memory_rowset.columns) == 3 + + # Get the row values and check decimal values are equal + rows = list(in_memory_rowset) + + # Verify the decimal value is 
correctly parsed + for row in rows: + assert isinstance(row[2], Decimal), "Expected Decimal type" + assert ( + str(row[2]) == "1231232.123459999990457054844258706536" + ), "Decimal value mismatch" + def test_append_response_stream_multi_chunk( self, in_memory_rowset, mock_multi_chunk_bytes_stream ): diff --git a/tests/unit/common/row_set/synchronous/test_streaming.py b/tests/unit/common/row_set/synchronous/test_streaming.py index ba36764ab56..deae15cd412 100644 --- a/tests/unit/common/row_set/synchronous/test_streaming.py +++ b/tests/unit/common/row_set/synchronous/test_streaming.py @@ -1,4 +1,5 @@ import json +from decimal import Decimal from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -611,6 +612,7 @@ def test_corrupted_json_line(self, streaming_rowset): def test_pop_data_record_from_record_unexpected_end(self): """Test _pop_data_record_from_record behavior with unexpected end of stream.""" + # Create a simple subclass to access protected method directly class TestableStreamingRowSet(StreamingRowSet): def pop_data_record_from_record_exposed(self, record): @@ -989,3 +991,28 @@ def patched_reset(): # Internal state should be reset assert streaming_rowset._responses == [] + + def test_streaming_with_decimals( + self, + streaming_rowset: StreamingRowSet, + mock_decimal_response_streaming: Response, + ): + """Test handling of decimal data types in streaming responses.""" + # Add the response to the row set + streaming_rowset.append_response(mock_decimal_response_streaming) + + # Verify columns are correctly identified + assert len(streaming_rowset.columns) == 3 + + # Get all rows + rows = list(streaming_rowset) + + # Verify we got the expected number of rows + assert len(rows) == 2 + + # Verify decimal values are correctly parsed in both formats (string and float) + for row in rows: + assert isinstance(row[2], Decimal), "Expected Decimal type" + assert ( + str(row[2]) == "1231232.123459999990457054844258706536" + ), "Decimal value mismatch" diff --git a/tests/unit/common/test_typing_parse.py b/tests/unit/common/test_typing_parse.py index 54e3281a969..67289b09c29 100644 --- a/tests/unit/common/test_typing_parse.py +++ b/tests/unit/common/test_typing_parse.py @@ -219,7 +219,6 @@ def test_parse_value_datetime_errors() -> None: [ ("123.456", Decimal("123.456")), (123, Decimal("123")), - (123.456, Decimal("123.456")), (None, None), ], )
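
Note (not part of the patch): a minimal, standalone sketch of the precision behavior the changes above rely on, using the same sample value as the new fixtures. Everything here is standard-library Python; parse_value() refers to the function modified in src/firebolt/common/_types.py.

import json
from decimal import Decimal

raw = '{"data": [[1, "one", 1231232.123459999990457054844258706536]]}'

# Default json.loads() turns the numeric literal into a Python float,
# which only holds ~15-17 significant digits.
lossy = json.loads(raw)["data"][0][2]
print(Decimal(str(lossy)))  # Decimal('1231232.12346') -- trailing digits lost

# With parse_float=str the literal survives as a string, so the DECIMAL
# branch in parse_value() can build a Decimal with every digit intact.
exact = json.loads(raw, parse_float=str)["data"][0][2]
print(Decimal(exact))  # Decimal('1231232.123459999990457054844258706536')

# The same pitfall motivates the test change from Decimal(123.456) to
# Decimal("123.456"): a Decimal built from a float captures the binary
# approximation of the value, not the written literal.
print(Decimal(123.456))    # Decimal('123.4560000000000030695446184836328029632568359375')
print(Decimal("123.456"))  # Decimal('123.456')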