From c3c8061f3e22e50fb08404b254660006802f42a0 Mon Sep 17 00:00:00 2001
From: Johannes Messner <44071807+JohannesMessner@users.noreply.github.com>
Date: Tue, 20 Jun 2023 13:41:09 +0200
Subject: [PATCH] fix: docvec equality if tensors are involved (#1663)

Signed-off-by: Johannes Messner
---
 docarray/array/doc_vec/column_storage.py      | 10 ++++-
 docarray/computation/abstract_comp_backend.py | 13 ++++++
 docarray/computation/numpy_backend.py         | 15 +++++++
 docarray/computation/tensorflow_backend.py    | 16 ++++++++
 docarray/computation/torch_backend.py         | 15 +++++++
 tests/units/array/stack/test_array_stacked.py | 40 +++++++++++++++++++
 6 files changed, 108 insertions(+), 1 deletion(-)

diff --git a/docarray/array/doc_vec/column_storage.py b/docarray/array/doc_vec/column_storage.py
index e525c8aee0d..539e9fd42af 100644
--- a/docarray/array/doc_vec/column_storage.py
+++ b/docarray/array/doc_vec/column_storage.py
@@ -102,7 +102,15 @@ def __eq__(self, other: Any) -> bool:
             for key_self in col_map_self.keys():
                 if key_self == 'id':
                     continue
-                if col_map_self[key_self] != col_map_other[key_self]:
+
+                val1, val2 = col_map_self[key_self], col_map_other[key_self]
+                if isinstance(val1, AbstractTensor):
+                    values_are_equal = val1.get_comp_backend().equal(val1, val2)
+                elif isinstance(val2, AbstractTensor):
+                    values_are_equal = val2.get_comp_backend().equal(val1, val2)
+                else:
+                    values_are_equal = val1 == val2
+                if not values_are_equal:
                     return False
         return True
 
diff --git a/docarray/computation/abstract_comp_backend.py b/docarray/computation/abstract_comp_backend.py
index 8e2be24cbfb..afaf4564e61 100644
--- a/docarray/computation/abstract_comp_backend.py
+++ b/docarray/computation/abstract_comp_backend.py
@@ -157,6 +157,19 @@ def minmax_normalize(
         """
         ...
 
+    @classmethod
+    @abstractmethod
+    def equal(cls, tensor1: 'TTensor', tensor2: 'TTensor') -> bool:
+        """
+        Check if two tensors are equal.
+
+        :param tensor1: the first tensor
+        :param tensor2: the second tensor
+        :return: True if two tensors are equal, False otherwise.
+            If one or more of the inputs is not a tensor of this framework, return False.
+        """
+        ...
+
     class Retrieval(ABC, typing.Generic[TTensorRetrieval]):
         """
         Abstract class for retrieval and ranking functionalities
diff --git a/docarray/computation/numpy_backend.py b/docarray/computation/numpy_backend.py
index 30d50cc0174..913f42d429e 100644
--- a/docarray/computation/numpy_backend.py
+++ b/docarray/computation/numpy_backend.py
@@ -111,6 +111,21 @@ def minmax_normalize(
 
         return np.clip(r, *((a, b) if a < b else (b, a)))
 
+    @classmethod
+    def equal(cls, tensor1: 'np.ndarray', tensor2: 'np.ndarray') -> bool:
+        """
+        Check if two tensors are equal.
+
+        :param tensor1: the first array
+        :param tensor2: the second array
+        :return: True if two arrays are equal, False otherwise.
+            If one or more of the inputs is not an ndarray, return False.
+        """
+        are_np_arrays = isinstance(tensor1, np.ndarray) and isinstance(
+            tensor2, np.ndarray
+        )
+        return are_np_arrays and np.array_equal(tensor1, tensor2)
+
     class Retrieval(AbstractComputationalBackend.Retrieval[np.ndarray]):
         """
         Abstract class for retrieval and ranking functionalities
diff --git a/docarray/computation/tensorflow_backend.py b/docarray/computation/tensorflow_backend.py
index fc963cdb48b..27609b737e1 100644
--- a/docarray/computation/tensorflow_backend.py
+++ b/docarray/computation/tensorflow_backend.py
@@ -121,6 +121,22 @@ def minmax_normalize(
         normalized = tnp.clip(i, *((a, b) if a < b else (b, a)))
         return cls._cast_output(tf.cast(normalized, tensor.tensor.dtype))
 
+    @classmethod
+    def equal(cls, tensor1: 'TensorFlowTensor', tensor2: 'TensorFlowTensor') -> bool:
+        """
+        Check if two tensors are equal.
+
+        :param tensor1: the first tensor
+        :param tensor2: the second tensor
+        :return: True if two tensors are equal, False otherwise.
+            If one or more of the inputs is not a TensorFlowTensor, return False.
+        """
+        t1, t2 = getattr(tensor1, 'tensor', None), getattr(tensor2, 'tensor', None)
+        if tf.is_tensor(t1) and tf.is_tensor(t2):
+            # mypy doesn't know that tf.is_tensor implies that t1, t2 are not None
+            return t1.shape == t2.shape and tf.math.reduce_all(tf.equal(t1, t2))  # type: ignore
+        return False
+
     class Retrieval(AbstractComputationalBackend.Retrieval[TensorFlowTensor]):
         """
         Abstract class for retrieval and ranking functionalities
diff --git a/docarray/computation/torch_backend.py b/docarray/computation/torch_backend.py
index be6d4ea03fd..97f0abbb3b5 100644
--- a/docarray/computation/torch_backend.py
+++ b/docarray/computation/torch_backend.py
@@ -113,6 +113,21 @@ def reshape(cls, tensor: 'torch.Tensor', shape: Tuple[int, ...]) -> 'torch.Tenso
         """
         return tensor.reshape(shape)
 
+    @classmethod
+    def equal(cls, tensor1: 'torch.Tensor', tensor2: 'torch.Tensor') -> bool:
+        """
+        Check if two tensors are equal.
+
+        :param tensor1: the first tensor
+        :param tensor2: the second tensor
+        :return: True if two tensors are equal, False otherwise.
+            If one or more of the inputs is not a torch.Tensor, return False.
+        """
+        are_torch = isinstance(tensor1, torch.Tensor) and isinstance(
+            tensor2, torch.Tensor
+        )
+        return are_torch and torch.equal(tensor1, tensor2)
+
     @classmethod
     def detach(cls, tensor: 'torch.Tensor') -> 'torch.Tensor':
         """
diff --git a/tests/units/array/stack/test_array_stacked.py b/tests/units/array/stack/test_array_stacked.py
index 4976aaddd31..35509338068 100644
--- a/tests/units/array/stack/test_array_stacked.py
+++ b/tests/units/array/stack/test_array_stacked.py
@@ -598,6 +598,46 @@ class Text(BaseDoc):
     assert da == da2.to_doc_vec()
 
 
+@pytest.mark.parametrize('tensor_type', [TorchTensor, NdArray])
+def test_doc_vec_equality_tensor(tensor_type):
+    class Text(BaseDoc):
+        tens: tensor_type
+
+    da = DocVec[Text](
+        [Text(tens=[1, 2, 3, 4]) for _ in range(10)], tensor_type=tensor_type
+    )
+    da2 = DocVec[Text](
+        [Text(tens=[1, 2, 3, 4]) for _ in range(10)], tensor_type=tensor_type
+    )
+    assert da == da2
+
+    da2 = DocVec[Text](
+        [Text(tens=[1, 2, 3, 4, 5]) for _ in range(10)], tensor_type=tensor_type
+    )
+    assert da != da2
+
+
+@pytest.mark.tensorflow
+def test_doc_vec_equality_tf():
+    from docarray.typing import TensorFlowTensor
+
+    class Text(BaseDoc):
+        tens: TensorFlowTensor
+
+    da = DocVec[Text](
+        [Text(tens=[1, 2, 3, 4]) for _ in range(10)], tensor_type=TensorFlowTensor
+    )
+    da2 = DocVec[Text](
+        [Text(tens=[1, 2, 3, 4]) for _ in range(10)], tensor_type=TensorFlowTensor
+    )
+    assert da == da2
+
+    da2 = DocVec[Text](
+        [Text(tens=[1, 2, 3, 4, 5]) for _ in range(10)], tensor_type=TensorFlowTensor
+    )
+    assert da != da2
+
+
 def test_doc_vec_nested(batch_nested_doc):
     batch, Doc, Inner = batch_nested_doc
     batch2 = DocVec[Doc]([Doc(inner=Inner(hello='hello')) for _ in range(10)])