Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/firebolt/common/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ class ExtendedType:
def is_valid_type(type_: Any) -> bool:
return type_ in _col_types or isinstance(type_, ExtendedType)

# Remember to override this method in subclasses
# if __eq__ is overridden
# https://docs.python.org/3/reference/datamodel.html#object.__hash__
def __hash__(self) -> int:
return hash(str(self))

Expand All @@ -110,6 +113,8 @@ def __eq__(self, other: object) -> bool:
return NotImplemented
return other.subtype == self.subtype

__hash__ = ExtendedType.__hash__


class DECIMAL(ExtendedType):
"""Class for holding `decimal` value information in Firebolt DB."""
Expand All @@ -129,6 +134,8 @@ def __eq__(self, other: object) -> bool:
return NotImplemented
return other.precision == self.precision and other.scale == self.scale

__hash__ = ExtendedType.__hash__


class STRUCT(ExtendedType):
__name__ = "Struct"
Expand All @@ -146,6 +153,8 @@ def __str__(self) -> str:
def __eq__(self, other: Any) -> bool:
return isinstance(other, STRUCT) and other.fields == self.fields

__hash__ = ExtendedType.__hash__


NULLABLE_SUFFIX = "null"

Expand Down
79 changes: 79 additions & 0 deletions tests/unit/common/test_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import fields

from firebolt.common._types import ARRAY, DECIMAL, STRUCT
from firebolt.common.row_set.types import Column


Expand All @@ -15,3 +16,81 @@ def test_columns_supports_indexing():
)
for i, field in enumerate(fields(column)):
assert getattr(column, field.name) == column[i]


def test_array_is_hashable():
"""Test that ARRAY type is hashable and can be used in dictionaries and sets."""
# Create ARRAY types
array_of_int = ARRAY(int)
array_of_str = ARRAY(str)
array_of_array = ARRAY(ARRAY(int))

# Test hash function works
assert isinstance(hash(array_of_int), int)
assert isinstance(hash(array_of_str), int)
assert isinstance(hash(array_of_array), int)

# Test equality with same hash values
assert hash(array_of_int) == hash(ARRAY(int))
assert hash(array_of_str) == hash(ARRAY(str))

# Test usage in dictionary
d = {array_of_int: "array_of_int", array_of_str: "array_of_str"}
assert d[array_of_int] == "array_of_int"
assert d[ARRAY(int)] == "array_of_int"

# Test usage in set
s = {array_of_int, array_of_str, array_of_array, ARRAY(int)}
assert len(s) == 3 # array_of_int and ARRAY(int) are equal


def test_decimal_is_hashable():
"""Test that DECIMAL type is hashable and can be used in dictionaries and sets."""
# Create DECIMAL types
dec1 = DECIMAL(10, 2)
dec2 = DECIMAL(5, 0)
dec3 = DECIMAL(10, 2) # Same as dec1

# Test hash function works
assert isinstance(hash(dec1), int)
assert isinstance(hash(dec2), int)

# Test equality with same hash values
assert hash(dec1) == hash(dec3)
assert dec1 == dec3

# Test usage in dictionary
d = {dec1: "dec1", dec2: "dec2"}
assert d[dec1] == "dec1"
assert d[DECIMAL(10, 2)] == "dec1"

# Test usage in set
s = {dec1, dec2, dec3}
assert len(s) == 2 # dec1 and dec3 are the same


def test_struct_is_hashable():
"""Test that STRUCT type is hashable and can be used in dictionaries and sets."""
# Create STRUCT types
struct1 = STRUCT({"name": str, "age": int})
struct2 = STRUCT({"value": DECIMAL(10, 2)})
struct3 = STRUCT({"name": str, "age": int}) # Same as struct1
nested_struct = STRUCT({"person": struct1, "balance": float})

# Test hash function works
assert isinstance(hash(struct1), int)
assert isinstance(hash(struct2), int)
assert isinstance(hash(nested_struct), int)

# Test equality with same hash values
assert hash(struct1) == hash(struct3)
assert struct1 == struct3

# Test usage in dictionary
d = {struct1: "struct1", struct2: "struct2"}
assert d[struct1] == "struct1"
assert d[STRUCT({"name": str, "age": int})] == "struct1"

# Test usage in set
s = {struct1, struct2, struct3, nested_struct}
assert len(s) == 3 # struct1 and struct3 are the same
Loading