From 12804485d2aab09a2b8a5a82a2307bebc7d98201 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Fri, 25 Apr 2025 17:05:45 +0100 Subject: [PATCH] fix: Hashability of extended types --- src/firebolt/common/_types.py | 9 ++++ tests/unit/common/test_types.py | 79 +++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+) diff --git a/src/firebolt/common/_types.py b/src/firebolt/common/_types.py index dba92eb2325..d3653f93ea1 100644 --- a/src/firebolt/common/_types.py +++ b/src/firebolt/common/_types.py @@ -87,6 +87,9 @@ class ExtendedType: def is_valid_type(type_: Any) -> bool: return type_ in _col_types or isinstance(type_, ExtendedType) + # Remember to override this method in subclasses + # if __eq__ is overridden + # https://docs.python.org/3/reference/datamodel.html#object.__hash__ def __hash__(self) -> int: return hash(str(self)) @@ -110,6 +113,8 @@ def __eq__(self, other: object) -> bool: return NotImplemented return other.subtype == self.subtype + __hash__ = ExtendedType.__hash__ + class DECIMAL(ExtendedType): """Class for holding `decimal` value information in Firebolt DB.""" @@ -129,6 +134,8 @@ def __eq__(self, other: object) -> bool: return NotImplemented return other.precision == self.precision and other.scale == self.scale + __hash__ = ExtendedType.__hash__ + class STRUCT(ExtendedType): __name__ = "Struct" @@ -146,6 +153,8 @@ def __str__(self) -> str: def __eq__(self, other: Any) -> bool: return isinstance(other, STRUCT) and other.fields == self.fields + __hash__ = ExtendedType.__hash__ + NULLABLE_SUFFIX = "null" diff --git a/tests/unit/common/test_types.py b/tests/unit/common/test_types.py index 62f007105db..e63a5ebcd2a 100644 --- a/tests/unit/common/test_types.py +++ b/tests/unit/common/test_types.py @@ -1,5 +1,6 @@ from dataclasses import fields +from firebolt.common._types import ARRAY, DECIMAL, STRUCT from firebolt.common.row_set.types import Column @@ -15,3 +16,81 @@ def test_columns_supports_indexing(): ) for i, field in enumerate(fields(column)): assert getattr(column, field.name) == column[i] + + +def test_array_is_hashable(): + """Test that ARRAY type is hashable and can be used in dictionaries and sets.""" + # Create ARRAY types + array_of_int = ARRAY(int) + array_of_str = ARRAY(str) + array_of_array = ARRAY(ARRAY(int)) + + # Test hash function works + assert isinstance(hash(array_of_int), int) + assert isinstance(hash(array_of_str), int) + assert isinstance(hash(array_of_array), int) + + # Test equality with same hash values + assert hash(array_of_int) == hash(ARRAY(int)) + assert hash(array_of_str) == hash(ARRAY(str)) + + # Test usage in dictionary + d = {array_of_int: "array_of_int", array_of_str: "array_of_str"} + assert d[array_of_int] == "array_of_int" + assert d[ARRAY(int)] == "array_of_int" + + # Test usage in set + s = {array_of_int, array_of_str, array_of_array, ARRAY(int)} + assert len(s) == 3 # array_of_int and ARRAY(int) are equal + + +def test_decimal_is_hashable(): + """Test that DECIMAL type is hashable and can be used in dictionaries and sets.""" + # Create DECIMAL types + dec1 = DECIMAL(10, 2) + dec2 = DECIMAL(5, 0) + dec3 = DECIMAL(10, 2) # Same as dec1 + + # Test hash function works + assert isinstance(hash(dec1), int) + assert isinstance(hash(dec2), int) + + # Test equality with same hash values + assert hash(dec1) == hash(dec3) + assert dec1 == dec3 + + # Test usage in dictionary + d = {dec1: "dec1", dec2: "dec2"} + assert d[dec1] == "dec1" + assert d[DECIMAL(10, 2)] == "dec1" + + # Test usage in set + s = {dec1, dec2, dec3} + assert len(s) == 2 # dec1 and dec3 are the same + + +def test_struct_is_hashable(): + """Test that STRUCT type is hashable and can be used in dictionaries and sets.""" + # Create STRUCT types + struct1 = STRUCT({"name": str, "age": int}) + struct2 = STRUCT({"value": DECIMAL(10, 2)}) + struct3 = STRUCT({"name": str, "age": int}) # Same as struct1 + nested_struct = STRUCT({"person": struct1, "balance": float}) + + # Test hash function works + assert isinstance(hash(struct1), int) + assert isinstance(hash(struct2), int) + assert isinstance(hash(nested_struct), int) + + # Test equality with same hash values + assert hash(struct1) == hash(struct3) + assert struct1 == struct3 + + # Test usage in dictionary + d = {struct1: "struct1", struct2: "struct2"} + assert d[struct1] == "struct1" + assert d[STRUCT({"name": str, "age": int})] == "struct1" + + # Test usage in set + s = {struct1, struct2, struct3, nested_struct} + assert len(s) == 3 # struct1 and struct3 are the same