From 7e0b46aba7c48f7f944c0fca0cb394551b8d60c1 Mon Sep 17 00:00:00 2001 From: rahul2393 Date: Tue, 12 Mar 2024 19:45:35 +0530 Subject: [PATCH] feat: add support of float32 type (#1113) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add support of float32 type * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * incorporate changes * incorporate changes * handle case for infinity * fix build * fix tests --------- Co-authored-by: Owl Bot --- google/cloud/spanner_v1/_helpers.py | 5 ++++ google/cloud/spanner_v1/param_types.py | 1 + google/cloud/spanner_v1/streamed.py | 1 + tests/system/_sample_data.py | 3 ++ tests/system/test_session_api.py | 41 ++++++++++++++++++++++++++ tests/unit/test__helpers.py | 21 +++++++++++++ tests/unit/test_param_types.py | 21 +++++++------ 7 files changed, 82 insertions(+), 11 deletions(-) diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index e0e2bfdbd0..d6b10dba18 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -228,6 +228,11 @@ def _parse_value_pb(value_pb, field_type): return float(value_pb.string_value) else: return value_pb.number_value + elif type_code == TypeCode.FLOAT32: + if value_pb.HasField("string_value"): + return float(value_pb.string_value) + else: + return value_pb.number_value elif type_code == TypeCode.DATE: return _date_from_iso8601_date(value_pb.string_value) elif type_code == TypeCode.TIMESTAMP: diff --git a/google/cloud/spanner_v1/param_types.py b/google/cloud/spanner_v1/param_types.py index 0c03f7ecc6..9b1910244d 100644 --- a/google/cloud/spanner_v1/param_types.py +++ b/google/cloud/spanner_v1/param_types.py @@ -26,6 +26,7 @@ BOOL = Type(code=TypeCode.BOOL) INT64 = Type(code=TypeCode.INT64) FLOAT64 = Type(code=TypeCode.FLOAT64) +FLOAT32 = Type(code=TypeCode.FLOAT32) DATE = Type(code=TypeCode.DATE) TIMESTAMP = Type(code=TypeCode.TIMESTAMP) NUMERIC = Type(code=TypeCode.NUMERIC) diff --git a/google/cloud/spanner_v1/streamed.py b/google/cloud/spanner_v1/streamed.py index ac8fc71ce6..d2c2b6216f 100644 --- a/google/cloud/spanner_v1/streamed.py +++ b/google/cloud/spanner_v1/streamed.py @@ -332,6 +332,7 @@ def _merge_struct(lhs, rhs, type_): TypeCode.BYTES: _merge_string, TypeCode.DATE: _merge_string, TypeCode.FLOAT64: _merge_float64, + TypeCode.FLOAT32: _merge_float64, TypeCode.INT64: _merge_string, TypeCode.STRING: _merge_string, TypeCode.STRUCT: _merge_struct, diff --git a/tests/system/_sample_data.py b/tests/system/_sample_data.py index 9c83f42224..d9c269c27f 100644 --- a/tests/system/_sample_data.py +++ b/tests/system/_sample_data.py @@ -90,5 +90,8 @@ def _check_cell_data(found_cell, expected_cell, recurse_into_lists=True): for found_item, expected_item in zip(found_cell, expected_cell): _check_cell_data(found_item, expected_item) + elif isinstance(found_cell, float) and not math.isinf(found_cell): + assert abs(found_cell - expected_cell) < 0.00001 + else: assert found_cell == expected_cell diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 29d196b011..6f1844faa9 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -2216,6 +2216,47 @@ def test_execute_sql_w_float_bindings_transfinite(sessions_database, database_di ) +def test_execute_sql_w_float32_bindings(sessions_database, database_dialect): + pytest.skip("float32 is not yet supported in production.") + _bind_test_helper( + sessions_database, + database_dialect, + spanner_v1.param_types.FLOAT32, + 42.3, + [12.3, 456.0, 7.89], + ) + + +def test_execute_sql_w_float32_bindings_transfinite( + sessions_database, database_dialect +): + pytest.skip("float32 is not yet supported in production.") + key = "p1" if database_dialect == DatabaseDialect.POSTGRESQL else "neg_inf" + placeholder = "$1" if database_dialect == DatabaseDialect.POSTGRESQL else f"@{key}" + + # Find -inf + _check_sql_results( + sessions_database, + sql=f"SELECT {placeholder}", + params={key: NEG_INF}, + param_types={key: spanner_v1.param_types.FLOAT32}, + expected=[(NEG_INF,)], + order=False, + ) + + key = "p1" if database_dialect == DatabaseDialect.POSTGRESQL else "pos_inf" + placeholder = "$1" if database_dialect == DatabaseDialect.POSTGRESQL else f"@{key}" + # Find +inf + _check_sql_results( + sessions_database, + sql=f"SELECT {placeholder}", + params={key: POS_INF}, + param_types={key: spanner_v1.param_types.FLOAT32}, + expected=[(POS_INF,)], + order=False, + ) + + def test_execute_sql_w_bytes_bindings(sessions_database, database_dialect): _bind_test_helper( sessions_database, diff --git a/tests/unit/test__helpers.py b/tests/unit/test__helpers.py index 0e0ec903a2..cb2372406f 100644 --- a/tests/unit/test__helpers.py +++ b/tests/unit/test__helpers.py @@ -466,6 +466,27 @@ def test_w_float_str(self): self.assertEqual(self._callFUT(value_pb, field_type), expected_value) + def test_w_float32(self): + from google.cloud.spanner_v1 import Type, TypeCode + from google.protobuf.struct_pb2 import Value + + VALUE = 3.14159 + field_type = Type(code=TypeCode.FLOAT32) + value_pb = Value(number_value=VALUE) + + self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + + def test_w_float32_str(self): + from google.cloud.spanner_v1 import Type, TypeCode + from google.protobuf.struct_pb2 import Value + + VALUE = "3.14159" + field_type = Type(code=TypeCode.FLOAT32) + value_pb = Value(string_value=VALUE) + expected_value = 3.14159 + + self.assertEqual(self._callFUT(value_pb, field_type), expected_value) + def test_w_date(self): import datetime from google.protobuf.struct_pb2 import Value diff --git a/tests/unit/test_param_types.py b/tests/unit/test_param_types.py index 02f41c1f25..645774d79b 100644 --- a/tests/unit/test_param_types.py +++ b/tests/unit/test_param_types.py @@ -18,9 +18,7 @@ class Test_ArrayParamType(unittest.TestCase): def test_it(self): - from google.cloud.spanner_v1 import Type - from google.cloud.spanner_v1 import TypeCode - from google.cloud.spanner_v1 import param_types + from google.cloud.spanner_v1 import Type, TypeCode, param_types expected = Type( code=TypeCode.ARRAY, array_element_type=Type(code=TypeCode.INT64) @@ -33,15 +31,13 @@ def test_it(self): class Test_Struct(unittest.TestCase): def test_it(self): - from google.cloud.spanner_v1 import Type - from google.cloud.spanner_v1 import TypeCode - from google.cloud.spanner_v1 import StructType - from google.cloud.spanner_v1 import param_types + from google.cloud.spanner_v1 import StructType, Type, TypeCode, param_types struct_type = StructType( fields=[ StructType.Field(name="name", type_=Type(code=TypeCode.STRING)), StructType.Field(name="count", type_=Type(code=TypeCode.INT64)), + StructType.Field(name="float32", type_=Type(code=TypeCode.FLOAT32)), ] ) expected = Type(code=TypeCode.STRUCT, struct_type=struct_type) @@ -50,6 +46,7 @@ def test_it(self): [ param_types.StructField("name", param_types.STRING), param_types.StructField("count", param_types.INT64), + param_types.StructField("float32", param_types.FLOAT32), ] ) @@ -58,10 +55,12 @@ def test_it(self): class Test_JsonbParamType(unittest.TestCase): def test_it(self): - from google.cloud.spanner_v1 import Type - from google.cloud.spanner_v1 import TypeCode - from google.cloud.spanner_v1 import TypeAnnotationCode - from google.cloud.spanner_v1 import param_types + from google.cloud.spanner_v1 import ( + Type, + TypeAnnotationCode, + TypeCode, + param_types, + ) expected = Type( code=TypeCode.JSON,