diff --git a/bigtable/google/cloud/bigtable/row_set.py b/bigtable/google/cloud/bigtable/row_set.py index 0d5ae9903473..ab2f15231903 100644 --- a/bigtable/google/cloud/bigtable/row_set.py +++ b/bigtable/google/cloud/bigtable/row_set.py @@ -29,6 +29,27 @@ def __init__(self): self.row_keys = [] self.row_ranges = [] + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + + if len(other.row_keys) != len(self.row_keys): + return False + + if len(other.row_ranges) != len(self.row_ranges): + return False + + if not set(other.row_keys) == set(self.row_keys): + return False + + if not set(other.row_ranges) == set(self.row_ranges): + return False + + return True + + def __ne__(self, other): + return not self == other + def add_row_key(self, row_key): """Add row key to row_keys list. @@ -112,6 +133,32 @@ def __init__(self, start_key=None, end_key=None, self.end_key = end_key self.end_inclusive = end_inclusive + def _key(self): + """A tuple key that uniquely describes this field. + + Used to compute this instance's hashcode and evaluate equality. + + Returns: + Tuple[str]: The contents of this :class:`.RowRange`. + """ + return ( + self.start_key, + self.start_inclusive, + self.end_key, + self.end_inclusive, + ) + + def __hash__(self): + return hash(self._key()) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return self._key() == other._key() + + def __ne__(self, other): + return not self == other + def get_range_kwargs(self): """ Convert row range object to dict which can be passed to google.bigtable.v2.RowRange add method. diff --git a/bigtable/tests/unit/test_row_set.py b/bigtable/tests/unit/test_row_set.py index 84640b616f98..990173b376c1 100644 --- a/bigtable/tests/unit/test_row_set.py +++ b/bigtable/tests/unit/test_row_set.py @@ -32,6 +32,115 @@ def test_constructor(self): self.assertEqual([], row_set.row_keys) self.assertEqual([], row_set.row_ranges) + def test__eq__(self): + row_key1 = b"row_key1" + row_key2 = b"row_key1" + row_range1 = RowRange(b"row_key4", b"row_key9") + row_range2 = RowRange(b"row_key4", b"row_key9") + + row_set1 = self._make_one() + row_set2 = self._make_one() + + row_set1.add_row_key(row_key1) + row_set2.add_row_key(row_key2) + row_set1.add_row_range(row_range1) + row_set2.add_row_range(row_range2) + + self.assertEqual(row_set1, row_set2) + + def test__eq__type_differ(self): + row_set1 = self._make_one() + row_set2 = object() + self.assertNotEqual(row_set1, row_set2) + + def test__eq__len_row_keys_differ(self): + row_key1 = b"row_key1" + row_key2 = b"row_key1" + + row_set1 = self._make_one() + row_set2 = self._make_one() + + row_set1.add_row_key(row_key1) + row_set1.add_row_key(row_key2) + row_set2.add_row_key(row_key2) + + self.assertNotEqual(row_set1, row_set2) + + def test__eq__len_row_ranges_differ(self): + row_range1 = RowRange(b"row_key4", b"row_key9") + row_range2 = RowRange(b"row_key4", b"row_key9") + + row_set1 = self._make_one() + row_set2 = self._make_one() + + row_set1.add_row_range(row_range1) + row_set1.add_row_range(row_range2) + row_set2.add_row_range(row_range2) + + self.assertNotEqual(row_set1, row_set2) + + def test__eq__row_keys_differ(self): + row_set1 = self._make_one() + row_set2 = self._make_one() + + row_set1.add_row_key(b"row_key1") + row_set1.add_row_key(b"row_key2") + row_set1.add_row_key(b"row_key3") + row_set2.add_row_key(b"row_key1") + row_set2.add_row_key(b"row_key2") + row_set2.add_row_key(b"row_key4") + + self.assertNotEqual(row_set1, row_set2) + + def test__eq__row_ranges_differ(self): + row_range1 = RowRange(b"row_key4", b"row_key9") + row_range2 = RowRange(b"row_key14", b"row_key19") + row_range3 = RowRange(b"row_key24", b"row_key29") + + row_set1 = self._make_one() + row_set2 = self._make_one() + + row_set1.add_row_range(row_range1) + row_set1.add_row_range(row_range2) + row_set1.add_row_range(row_range3) + row_set2.add_row_range(row_range1) + row_set2.add_row_range(row_range2) + + self.assertNotEqual(row_set1, row_set2) + + def test__ne__(self): + row_key1 = b"row_key1" + row_key2 = b"row_key1" + row_range1 = RowRange(b"row_key4", b"row_key9") + row_range2 = RowRange(b"row_key5", b"row_key9") + + row_set1 = self._make_one() + row_set2 = self._make_one() + + row_set1.add_row_key(row_key1) + row_set2.add_row_key(row_key2) + row_set1.add_row_range(row_range1) + row_set2.add_row_range(row_range2) + + self.assertNotEqual(row_set1, row_set2) + + def test__ne__same_value(self): + row_key1 = b"row_key1" + row_key2 = b"row_key1" + row_range1 = RowRange(b"row_key4", b"row_key9") + row_range2 = RowRange(b"row_key4", b"row_key9") + + row_set1 = self._make_one() + row_set2 = self._make_one() + + row_set1.add_row_key(row_key1) + row_set2.add_row_key(row_key2) + row_set1.add_row_range(row_range1) + row_set2.add_row_range(row_range2) + + comparison_val = (row_set1 != row_set2) + self.assertFalse(comparison_val) + def test_add_row_key(self): row_set = self._make_one() row_set.add_row_key("row_key1") @@ -92,6 +201,56 @@ def test_constructor(self): self.assertTrue(row_range.start_inclusive) self.assertFalse(row_range.end_inclusive) + def test___hash__set_equality(self): + row_range1 = self._make_one('row_key1', 'row_key9') + row_range2 = self._make_one('row_key1', 'row_key9') + set_one = {row_range1, row_range2} + set_two = {row_range1, row_range2} + self.assertEqual(set_one, set_two) + + def test___hash__not_equals(self): + row_range1 = self._make_one('row_key1', 'row_key9') + row_range2 = self._make_one('row_key1', 'row_key19') + set_one = {row_range1} + set_two = {row_range2} + self.assertNotEqual(set_one, set_two) + + def test__eq__(self): + start_key = b"row_key1" + end_key = b"row_key9" + row_range1 = self._make_one(start_key, end_key, + True, False) + row_range2 = self._make_one(start_key, end_key, + True, False) + self.assertEqual(row_range1, row_range2) + + def test___eq__type_differ(self): + start_key = b"row_key1" + end_key = b"row_key9" + row_range1 = self._make_one(start_key, end_key, + True, False) + row_range2 = object() + self.assertNotEqual(row_range1, row_range2) + + def test__ne__(self): + start_key = b"row_key1" + end_key = b"row_key9" + row_range1 = self._make_one(start_key, end_key, + True, False) + row_range2 = self._make_one(start_key, end_key, + False, True) + self.assertNotEqual(row_range1, row_range2) + + def test__ne__same_value(self): + start_key = b"row_key1" + end_key = b"row_key9" + row_range1 = self._make_one(start_key, end_key, + True, False) + row_range2 = self._make_one(start_key, end_key, + True, False) + comparison_val = (row_range1 != row_range2) + self.assertFalse(comparison_val) + def test_get_range_kwargs_closed_open(self): start_key = b"row_key1" end_key = b"row_key9"