Skip to content

Commit

Permalink
fix: docvec equality if tensors are involved (#1663)
Browse files Browse the repository at this point in the history
Signed-off-by: Johannes Messner <messnerjo@gmail.com>
  • Loading branch information
JohannesMessner committed Jun 20, 2023
1 parent 4e6bf49 commit c3c8061
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 1 deletion.
10 changes: 9 additions & 1 deletion docarray/array/doc_vec/column_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,15 @@ def __eq__(self, other: Any) -> bool:
for key_self in col_map_self.keys():
if key_self == 'id':
continue
if col_map_self[key_self] != col_map_other[key_self]:

val1, val2 = col_map_self[key_self], col_map_other[key_self]
if isinstance(val1, AbstractTensor):
values_are_equal = val1.get_comp_backend().equal(val1, val2)
elif isinstance(val2, AbstractTensor):
values_are_equal = val2.get_comp_backend().equal(val1, val2)
else:
values_are_equal = val1 == val2
if not values_are_equal:
return False
return True

Expand Down
13 changes: 13 additions & 0 deletions docarray/computation/abstract_comp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,19 @@ def minmax_normalize(
"""
...

@classmethod
@abstractmethod
def equal(cls, tensor1: 'TTensor', tensor2: 'TTensor') -> bool:
    """
    Compare two tensors for equality.

    Concrete backends must implement this method.

    :param tensor1: the first tensor to compare
    :param tensor2: the second tensor to compare
    :return: True if both tensors are equal, False otherwise.
        Also False if at least one of the inputs is not a tensor
        of this framework.
    """
    ...

class Retrieval(ABC, typing.Generic[TTensorRetrieval]):
"""
Abstract class for retrieval and ranking functionalities
Expand Down
15 changes: 15 additions & 0 deletions docarray/computation/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,21 @@ def minmax_normalize(

return np.clip(r, *((a, b) if a < b else (b, a)))

@classmethod
def equal(cls, tensor1: 'np.ndarray', tensor2: 'np.ndarray') -> bool:
    """
    Compare two arrays for equality.

    :param tensor1: the first array
    :param tensor2: the second array
    :return: True if both inputs are ndarrays and `np.array_equal` reports
        them equal, False otherwise.
        If one or more of the inputs is not an ndarray, return False.
    """
    # Anything that is not an ndarray can never compare equal here.
    if not isinstance(tensor1, np.ndarray):
        return False
    if not isinstance(tensor2, np.ndarray):
        return False
    return np.array_equal(tensor1, tensor2)

class Retrieval(AbstractComputationalBackend.Retrieval[np.ndarray]):
"""
Abstract class for retrieval and ranking functionalities
Expand Down
16 changes: 16 additions & 0 deletions docarray/computation/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,22 @@ def minmax_normalize(
normalized = tnp.clip(i, *((a, b) if a < b else (b, a)))
return cls._cast_output(tf.cast(normalized, tensor.tensor.dtype))

@classmethod
def equal(cls, tensor1: 'TensorFlowTensor', tensor2: 'TensorFlowTensor') -> bool:
    """
    Check if two tensors are equal.

    The underlying tensors are read from the `tensor` attribute of each
    input; inputs without such an attribute are treated as "not a tensor".

    :param tensor1: the first tensor
    :param tensor2: the second tensor
    :return: True if both inputs wrap tensors with the same shape and
        element-wise equal values, False otherwise.
        If one or more of the inputs is not a TensorFlowTensor, return False.
    """
    t1, t2 = getattr(tensor1, 'tensor', None), getattr(tensor2, 'tensor', None)
    if not (tf.is_tensor(t1) and tf.is_tensor(t2)):
        return False
    # mypy doesn't know that tf.is_tensor implies that t1, t2 are not None.
    # The shape check comes first so tf.equal never has to broadcast.
    # The comparison must be t1 against t2 (comparing t1 with itself is
    # always True); bool() turns the scalar tensor into the annotated type.
    return t1.shape == t2.shape and bool(  # type: ignore
        tf.math.reduce_all(tf.equal(t1, t2))
    )

class Retrieval(AbstractComputationalBackend.Retrieval[TensorFlowTensor]):
"""
Abstract class for retrieval and ranking functionalities
Expand Down
15 changes: 15 additions & 0 deletions docarray/computation/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,21 @@ def reshape(cls, tensor: 'torch.Tensor', shape: Tuple[int, ...]) -> 'torch.Tenso
"""
return tensor.reshape(shape)

@classmethod
def equal(cls, tensor1: 'torch.Tensor', tensor2: 'torch.Tensor') -> bool:
    """
    Compare two tensors for equality.

    :param tensor1: the first tensor
    :param tensor2: the second tensor
    :return: True if both inputs are torch.Tensors and `torch.equal` reports
        them equal, False otherwise.
        If one or more of the inputs is not a torch.Tensor, return False.
    """
    # Anything that is not a torch.Tensor can never compare equal here.
    if not isinstance(tensor1, torch.Tensor):
        return False
    if not isinstance(tensor2, torch.Tensor):
        return False
    return torch.equal(tensor1, tensor2)

@classmethod
def detach(cls, tensor: 'torch.Tensor') -> 'torch.Tensor':
"""
Expand Down
40 changes: 40 additions & 0 deletions tests/units/array/stack/test_array_stacked.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,46 @@ class Text(BaseDoc):
assert da == da2.to_doc_vec()


@pytest.mark.parametrize('tensor_type', [TorchTensor, NdArray])
def test_doc_vec_equality_tensor(tensor_type):
    """DocVec equality must take the values of tensor columns into account."""

    class Text(BaseDoc):
        tens: tensor_type

    def make_doc_vec(values):
        # Ten documents that all carry a copy of the same tensor values.
        return DocVec[Text](
            [Text(tens=list(values)) for _ in range(10)], tensor_type=tensor_type
        )

    da = make_doc_vec([1, 2, 3, 4])

    # Identical content compares equal.
    assert da == make_doc_vec([1, 2, 3, 4])

    # A different tensor length compares unequal.
    assert da != make_doc_vec([1, 2, 3, 4, 5])

    # Same shape but different values must also compare unequal; a
    # shape-only comparison would wrongly pass the length check above.
    assert da != make_doc_vec([1, 2, 3, 5])


@pytest.mark.tensorflow
def test_doc_vec_equality_tf():
    """DocVec equality with TensorFlowTensor columns: same content vs. different length."""
    from docarray.typing import TensorFlowTensor

    class Text(BaseDoc):
        tens: TensorFlowTensor

    def make_doc_vec(values):
        # Ten documents that all carry a copy of the same tensor values.
        docs = [Text(tens=list(values)) for _ in range(10)]
        return DocVec[Text](docs, tensor_type=TensorFlowTensor)

    reference = make_doc_vec([1, 2, 3, 4])

    same_content = make_doc_vec([1, 2, 3, 4])
    assert reference == same_content

    longer_tensors = make_doc_vec([1, 2, 3, 4, 5])
    assert reference != longer_tensors


def test_doc_vec_nested(batch_nested_doc):
batch, Doc, Inner = batch_nested_doc
batch2 = DocVec[Doc]([Doc(inner=Inner(hello='hello')) for _ in range(10)])
Expand Down

0 comments on commit c3c8061

Please sign in to comment.