Skip to content

Commit

Permalink
feat: tensor coersion (#1588)
Browse files Browse the repository at this point in the history
Signed-off-by: samsja <sami.jaghouar@hotmail.fr>
  • Loading branch information
samsja committed May 31, 2023
1 parent 5d0e24c commit 5e74fcc
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 4 deletions.
14 changes: 14 additions & 0 deletions docarray/typing/tensor/abstract_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,17 @@ def _docarray_to_json_compatible(self):
:return: a representation of the tensor compatible with orjson
"""
return self

@classmethod
@abc.abstractmethod
def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T:
    """Create a tensor from a numpy array.

    PS: this function is different from `from_ndarray` because it is private
    under the docarray namespace. This allows us to avoid a breaking change if
    one day we introduce a Tensor backend with a `from_ndarray` method.

    :param value: the numpy array to build the tensor from
    :return: a tensor of type `cls`
    """
    ...

@abc.abstractmethod
def _docarray_to_ndarray(self) -> np.ndarray:
    """Cast this tensor to a numpy array.

    :return: the numpy array representation of this tensor
    """
    ...
19 changes: 17 additions & 2 deletions docarray/typing/tensor/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,12 @@ def validate(
return cls._docarray_from_native(value)
elif isinstance(value, NdArray):
return cast(T, value)
elif isinstance(value, AbstractTensor):
return cls._docarray_from_native(value._docarray_to_ndarray())
elif torch_available and isinstance(value, torch.Tensor):
return cls._docarray_from_native(value.detach().cpu().numpy())
elif tf_available and isinstance(value, tf.Tensor):
return cls._docarray_from_native(value.numpy())
elif tf_available and isinstance(value, TensorFlowTensor):
return cls._docarray_from_native(value.tensor.numpy())
elif isinstance(value, list) or isinstance(value, tuple):
try:
arr_from_list: np.ndarray = np.asarray(value)
Expand Down Expand Up @@ -219,3 +219,18 @@ def get_comp_backend() -> 'NumpyCompBackend':
def __class_getitem__(cls, item: Any, *args, **kwargs):
    """Delegate `cls[item]` to the `AbstractTensor` implementation.

    Extra positional and keyword arguments are accepted but not forwarded.

    :param item: the subscript passed to the class
    :return: whatever `AbstractTensor.__class_getitem__` returns for `cls` and `item`
    """
    # The underlying function is invoked explicitly on `cls`;
    # see here for mypy bug: https://github.com/python/mypy/issues/14123
    abstract_getitem = AbstractTensor.__class_getitem__.__func__  # type: ignore
    return abstract_getitem(cls, item)

@classmethod
def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T:
    """Build an instance of this class from a numpy array.

    This hook is deliberately distinct from `from_ndarray`: it is private and
    lives under the docarray namespace, so a future Tensor backend that ships
    its own `from_ndarray` method would not cause a breaking change.

    :param value: the numpy array to wrap
    :return: the result of `cls._docarray_from_native(value)`
    """
    wrapped = cls._docarray_from_native(value)
    return wrapped

def _docarray_to_ndarray(self) -> np.ndarray:
    """Cast itself to a numpy array.

    :return: the numpy array obtained from `self.unwrap()`
    """
    return self.unwrap()
23 changes: 22 additions & 1 deletion docarray/typing/tensor/tensorflow_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from docarray.base_doc.base_node import BaseNode
from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal.misc import import_library
from docarray.utils._internal.misc import import_library, is_torch_available

if TYPE_CHECKING:
import tensorflow as tf # type: ignore
Expand All @@ -17,6 +17,9 @@
else:
tf = import_library('tensorflow', raise_error=True)

torch_available = is_torch_available()
if torch_available:
import torch

T = TypeVar('T', bound='TensorFlowTensor')
ShapeT = TypeVar('ShapeT')
Expand Down Expand Up @@ -202,6 +205,12 @@ def validate(
return cast(T, value)
elif isinstance(value, tf.Tensor):
return cls._docarray_from_native(value)
elif isinstance(value, np.ndarray):
return cls._docarray_from_ndarray(value)
elif isinstance(value, AbstractTensor):
return cls._docarray_from_ndarray(value._docarray_to_ndarray())
elif torch_available and isinstance(value, torch.Tensor):
return cls._docarray_from_native(value.detach().cpu().numpy())
else:
try:
arr: tf.Tensor = tf.constant(value)
Expand Down Expand Up @@ -320,3 +329,15 @@ def unwrap(self) -> tf.Tensor:

def __len__(self) -> int:
return len(self.tensor)

@classmethod
def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T:
    """Build an instance of this class from a numpy array.

    This hook is deliberately distinct from `from_ndarray`: it is private and
    lives under the docarray namespace, so a future Tensor backend that ships
    its own `from_ndarray` method would not cause a breaking change.

    :param value: the numpy array to convert
    :return: the result of `cls.from_ndarray(value)`
    """
    converted = cls.from_ndarray(value)
    return converted

def _docarray_to_ndarray(self) -> np.ndarray:
    """Cast itself to a numpy array.

    :return: the numpy array produced by calling `numpy()` on `self.tensor`
    """
    wrapped_tensor = self.tensor
    return wrapped_tensor.numpy()
23 changes: 22 additions & 1 deletion docarray/typing/tensor/torch_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from docarray.base_doc.base_node import BaseNode
from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal.misc import import_library
from docarray.utils._internal.misc import import_library, is_tf_available

if TYPE_CHECKING:
import torch
Expand All @@ -18,6 +18,9 @@
else:
torch = import_library('torch', raise_error=True)

tf_available = is_tf_available()
if tf_available:
import tensorflow as tf # type: ignore

T = TypeVar('T', bound='TorchTensor')
ShapeT = TypeVar('ShapeT')
Expand Down Expand Up @@ -123,6 +126,12 @@ def validate(
return cast(T, value)
elif isinstance(value, torch.Tensor):
return cls._docarray_from_native(value)
elif isinstance(value, AbstractTensor):
return cls._docarray_from_ndarray(value._docarray_to_ndarray())
elif tf_available and isinstance(value, tf.Tensor):
return cls._docarray_from_ndarray(value.numpy())
elif isinstance(value, np.ndarray):
return cls._docarray_from_ndarray(value)
else:
try:
arr: torch.Tensor = torch.tensor(value)
Expand Down Expand Up @@ -240,3 +249,15 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
torch.Tensor if t in docarray_torch_tensors else t for t in types
)
return super().__torch_function__(func, types_, args, kwargs)

@classmethod
def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T:
    """Build an instance of this class from a numpy array.

    This hook is deliberately distinct from `from_ndarray`: it is private and
    lives under the docarray namespace, so a future Tensor backend that ships
    its own `from_ndarray` method would not cause a breaking change.

    :param value: the numpy array to convert
    :return: the result of `cls.from_ndarray(value)`
    """
    converted = cls.from_ndarray(value)
    return converted

def _docarray_to_ndarray(self) -> np.ndarray:
    """Cast itself to a numpy array.

    :return: the numpy array obtained after `detach()` and `cpu()`
    """
    detached = self.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()
50 changes: 50 additions & 0 deletions tests/units/typing/tensor/test_tensor_coercion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import numpy as np
import pytest
import torch
from pydantic import parse_obj_as

from docarray.typing import NdArray, TorchTensor
from docarray.utils._internal.misc import is_tf_available

tf_available = is_tf_available()
if tf_available:
    import tensorflow as tf

    from docarray.typing import TensorFlowTensor
else:
    # This is needed to fake the import of tensorflow when it is not installed:
    # the module-level lists below call `tf.zeros` and
    # `TensorFlowTensor._docarray_from_native` at import time.

    class TfNotInstalled:
        """Stand-in for the `tensorflow` module that only provides `zeros`."""

        def zeros(self, *args, **kwargs):
            return 0

    class TensorFlowTensor:
        """Stand-in for `docarray.typing.TensorFlowTensor`."""

        @classmethod
        def _docarray_from_native(cls, *args, **kwargs):
            # Declared as a classmethod because it is invoked on the class
            # itself, not on an instance.
            return 0

    tf = TfNotInstalled()


# Native tensors, one per backend.
pure_tensor_to_test = [
    np.zeros((3, 224, 224)),
    torch.zeros(3, 224, 224),
    tf.zeros((3, 224, 224)),
]

# The same kind of tensors, each wrapped in its docarray tensor class.
docarray_tensor_to_test = [
    wrapper_cls._docarray_from_native(native_tensor)
    for wrapper_cls, native_tensor in (
        (NdArray, np.zeros((3, 224, 224))),
        (TorchTensor, torch.zeros(3, 224, 224)),
        (TensorFlowTensor, tf.zeros((3, 224, 224))),
    )
]


@pytest.mark.tensorflow
@pytest.mark.parametrize('tensor', pure_tensor_to_test + docarray_tensor_to_test)
@pytest.mark.parametrize('tensor_cls', [NdArray, TorchTensor, TensorFlowTensor])
def test_torch_tensor_coerse(tensor_cls, tensor):
    """Check that each input tensor is coerced into each docarray tensor class."""
    coerced = parse_obj_as(tensor_cls, tensor)
    assert isinstance(coerced, tensor_cls)

    # The data must survive the coercion unchanged.
    expected = np.zeros((3, 224, 224))
    as_numpy = coerced._docarray_to_ndarray()
    assert as_numpy.shape == (3, 224, 224)
    assert (as_numpy == expected).all()

0 comments on commit 5e74fcc

Please sign in to comment.