Skip to content

Commit

Permalink
fix: validate before (#1806)
Browse files Browse the repository at this point in the history
Signed-off-by: samsja <sami.jaghouar@hotmail.fr>
Co-authored-by: Joan Fontanals <joan.martinez@jina.ai>
  • Loading branch information
samsja and JoanFM committed Sep 27, 2023
1 parent 7209b78 commit 26d776d
Show file tree
Hide file tree
Showing 14 changed files with 153 additions and 86 deletions.
34 changes: 25 additions & 9 deletions docarray/documents/audio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union

import numpy as np

from pydantic import Field

from docarray.base_doc import BaseDoc
Expand All @@ -10,6 +9,10 @@
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.audio.audio_tensor import AudioTensor
from docarray.utils._internal.misc import import_library
from docarray.utils._internal.pydantic import is_pydantic_v2

if is_pydantic_v2:
from pydantic import model_validator

if TYPE_CHECKING:
import tensorflow as tf # type: ignore
Expand Down Expand Up @@ -121,17 +124,30 @@ class MultiModalDoc(BaseDoc):
)

@classmethod
def validate(
cls: Type[T],
value: Union[str, AbstractTensor, Any],
) -> T:
def _validate(cls, value: Any) -> Any:
    """Normalize shorthand input into a field mapping before model validation.

    Invoked as ``cls._validate(value)`` by the model validators.

    :param value: raw input handed to the validator.
    :return: ``{'url': value}`` for a string; ``{'tensor': value}`` for an
        ``AbstractTensor``, an ``np.ndarray``, or a torch / tensorflow tensor
        (each checked only when its module is not ``None``); otherwise
        ``value`` itself, unchanged.
    """
    if isinstance(value, str):
        return {'url': value}

    # Each optional backend is guarded separately so a missing module is
    # never dereferenced; the grouping is explicit rather than relying on
    # ``and`` binding tighter than ``or``.
    if (
        isinstance(value, (AbstractTensor, np.ndarray))
        or (torch is not None and isinstance(value, torch.Tensor))
        or (tf is not None and isinstance(value, tf.Tensor))
    ):
        return {'tensor': value}

    # Unrecognised input is passed through untouched, so the result is not
    # necessarily a dict -- hence the ``Any`` return annotation.
    return value

if is_pydantic_v2:

@model_validator(mode='before')
@classmethod
def validate_model_before(cls, value):
return cls._validate(value)

else:

return super().validate(value)
@classmethod
def validate(
cls: Type[T],
value: Union[str, AbstractTensor, Any],
) -> T:
return super().validate(cls._validate(value))
35 changes: 25 additions & 10 deletions docarray/documents/image.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union

import numpy as np

from pydantic import Field

from docarray.base_doc import BaseDoc
from docarray.typing import AnyEmbedding, ImageBytes, ImageUrl
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.image.image_tensor import ImageTensor
from docarray.utils._internal.misc import import_library
from docarray.utils._internal.pydantic import is_pydantic_v2

if is_pydantic_v2:
from pydantic import model_validator

if TYPE_CHECKING:
import tensorflow as tf # type: ignore
Expand Down Expand Up @@ -115,19 +117,32 @@ class MultiModalDoc(BaseDoc):
)

@classmethod
def validate(
cls: Type[T],
value: Union[str, AbstractTensor, Any],
) -> T:
def _validate(cls, value: Any) -> Any:
    """Normalize shorthand input into a field mapping before model validation.

    Invoked as ``cls._validate(value)`` by the model validators.

    :param value: raw input handed to the validator.
    :return: ``{'url': value}`` for a string; ``{'tensor': value}`` for an
        ``AbstractTensor``, an ``np.ndarray``, or a torch / tensorflow tensor
        (each checked only when its module is not ``None``);
        ``{'byte': value}`` for ``bytes``; otherwise ``value`` itself,
        unchanged.
    """
    if isinstance(value, str):
        return {'url': value}

    # Each optional backend is guarded separately so a missing module is
    # never dereferenced.
    if (
        isinstance(value, (AbstractTensor, np.ndarray))
        or (torch is not None and isinstance(value, torch.Tensor))
        or (tf is not None and isinstance(value, tf.Tensor))
    ):
        return {'tensor': value}

    if isinstance(value, bytes):
        return {'byte': value}

    # Unrecognised input is passed through untouched, so the result is not
    # necessarily a dict -- hence the ``Any`` return annotation.
    return value

if is_pydantic_v2:

@model_validator(mode='before')
@classmethod
def validate_model_before(cls, value):
return cls._validate(value)

else:

return super().validate(value)
@classmethod
def validate(
cls: Type[T],
value: Union[str, AbstractTensor, Any],
) -> T:
return super().validate(cls._validate(value))
30 changes: 22 additions & 8 deletions docarray/documents/mesh/mesh_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from docarray.documents.mesh.vertices_and_faces import VerticesAndFaces
from docarray.typing.tensor.embedding import AnyEmbedding
from docarray.typing.url.url_3d.mesh_url import Mesh3DUrl
from docarray.utils._internal.pydantic import is_pydantic_v2

if is_pydantic_v2:
from pydantic import model_validator

T = TypeVar('T', bound='Mesh3D')

Expand Down Expand Up @@ -125,11 +128,22 @@ class MultiModalDoc(BaseDoc):
default=None,
)

@classmethod
def validate(
cls: Type[T],
value: Union[str, Any],
) -> T:
if isinstance(value, str):
value = cls(url=value)
return super().validate(value)
if is_pydantic_v2:

@model_validator(mode='before')
@classmethod
def validate_model_before(cls, value):
if isinstance(value, str):
return {'url': value}
return value

else:

@classmethod
def validate(
cls: Type[T],
value: Union[str, Any],
) -> T:
if isinstance(value, str):
value = cls(url=value)
return super().validate(value)
32 changes: 24 additions & 8 deletions docarray/documents/point_cloud/point_cloud_3d.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union

import numpy as np

from pydantic import Field

from docarray.base_doc import BaseDoc
from docarray.documents.point_cloud.points_and_colors import PointsAndColors
from docarray.typing import AnyEmbedding, PointCloud3DUrl
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal.misc import import_library
from docarray.utils._internal.pydantic import is_pydantic_v2

if is_pydantic_v2:
from pydantic import model_validator

if TYPE_CHECKING:
import tensorflow as tf # type: ignore
Expand Down Expand Up @@ -130,17 +133,30 @@ class MultiModalDoc(BaseDoc):
)

@classmethod
def validate(
cls: Type[T],
value: Union[str, AbstractTensor, Any],
) -> T:
def _validate(cls, value: Union[str, AbstractTensor, Any]) -> Any:
    """Normalize shorthand input into a field mapping before model validation.

    The first parameter is named ``cls`` because the validators invoke this
    helper on the class, as ``cls._validate(value)``.

    :param value: raw input handed to the validator.
    :return: ``{'url': value}`` for a string; ``{'tensors': ...}`` holding a
        ``PointsAndColors`` built with ``points=value`` for an
        ``AbstractTensor``, an ``np.ndarray``, or a torch / tensorflow tensor
        (each checked only when its module is not ``None``); otherwise
        ``value`` itself, unchanged.
    """
    if isinstance(value, str):
        return {'url': value}

    # Each optional backend is guarded separately so a missing module is
    # never dereferenced; the grouping is explicit rather than relying on
    # ``and`` binding tighter than ``or``.
    if (
        isinstance(value, (AbstractTensor, np.ndarray))
        or (torch is not None and isinstance(value, torch.Tensor))
        or (tf is not None and isinstance(value, tf.Tensor))
    ):
        return {'tensors': PointsAndColors(points=value)}

    # Unrecognised input is passed through untouched.
    return value

if is_pydantic_v2:

@model_validator(mode='before')
@classmethod
def validate_model_before(cls, value):
return cls._validate(value)

else:

return super().validate(value)
@classmethod
def validate(
cls: Type[T],
value: Union[str, AbstractTensor, Any],
) -> T:
return super().validate(cls._validate(value))
32 changes: 24 additions & 8 deletions docarray/documents/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from docarray.base_doc import BaseDoc
from docarray.typing import TextUrl
from docarray.typing.tensor.embedding import AnyEmbedding
from docarray.utils._internal.pydantic import is_pydantic_v2

if is_pydantic_v2:
from pydantic import model_validator

T = TypeVar('T', bound='TextDoc')

Expand Down Expand Up @@ -129,14 +133,26 @@ def __init__(self, text: Optional[str] = None, **kwargs):
kwargs['text'] = text
super().__init__(**kwargs)

@classmethod
def validate(
cls: Type[T],
value: Union[str, Any],
) -> T:
if isinstance(value, str):
value = cls(text=value)
return super().validate(value)
if is_pydantic_v2:

@model_validator(mode='before')
@classmethod
def validate_model_before(cls, values):
if isinstance(values, str):
return {'text': values}
else:
return values

else:

@classmethod
def validate(
cls: Type[T],
value: Union[str, Any],
) -> T:
if isinstance(value, str):
value = cls(text=value)
return super().validate(value)

def __eq__(self, other: Any) -> bool:
if isinstance(other, str):
Expand Down
34 changes: 25 additions & 9 deletions docarray/documents/video.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union

import numpy as np

from pydantic import Field

from docarray.base_doc import BaseDoc
Expand All @@ -11,6 +10,10 @@
from docarray.typing.tensor.video.video_tensor import VideoTensor
from docarray.typing.url.video_url import VideoUrl
from docarray.utils._internal.misc import import_library
from docarray.utils._internal.pydantic import is_pydantic_v2

if is_pydantic_v2:
from pydantic import model_validator

if TYPE_CHECKING:
import tensorflow as tf # type: ignore
Expand Down Expand Up @@ -131,17 +134,30 @@ class MultiModalDoc(BaseDoc):
)

@classmethod
def validate(
cls: Type[T],
value: Union[str, AbstractTensor, Any],
) -> T:
def _validate(cls, value: Any) -> Any:
    """Normalize shorthand input into a field mapping before model validation.

    Invoked as ``cls._validate(value)`` by the model validators.

    :param value: raw input handed to the validator.
    :return: ``{'url': value}`` for a string; ``{'tensor': value}`` for an
        ``AbstractTensor``, an ``np.ndarray``, or a torch / tensorflow tensor
        (each checked only when its module is not ``None``); otherwise
        ``value`` itself, unchanged.
    """
    if isinstance(value, str):
        return {'url': value}

    # Each optional backend is guarded separately so a missing module is
    # never dereferenced; the grouping is explicit rather than relying on
    # ``and`` binding tighter than ``or``.
    if (
        isinstance(value, (AbstractTensor, np.ndarray))
        or (torch is not None and isinstance(value, torch.Tensor))
        or (tf is not None and isinstance(value, tf.Tensor))
    ):
        return {'tensor': value}

    # Unrecognised input is passed through untouched, so the result is not
    # necessarily a dict -- hence the ``Any`` return annotation.
    return value

if is_pydantic_v2:

@model_validator(mode='before')
@classmethod
def validate_model_before(cls, value):
return cls._validate(value)

else:

return super().validate(value)
@classmethod
def validate(
cls: Type[T],
value: Union[str, AbstractTensor, Any],
) -> T:
return super().validate(cls._validate(value))
1 change: 0 additions & 1 deletion docarray/typing/tensor/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ def _docarray_validate(
return cls._docarray_from_native(arr)
except Exception:
pass # handled below
breakpoint()
raise ValueError(f'Expected a numpy.ndarray compatible type, got {type(value)}')

@classmethod
Expand Down
6 changes: 0 additions & 6 deletions tests/integrations/predefined_document/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from docarray.typing import AudioUrl
from docarray.typing.tensor.audio import AudioNdArray, AudioTorchTensor
from docarray.utils._internal.misc import is_tf_available
from docarray.utils._internal.pydantic import is_pydantic_v2
from tests import TOYDATA_DIR

tf_available = is_tf_available()
Expand Down Expand Up @@ -184,32 +183,27 @@ class MyAudio(AudioDoc):


# Validating predefined docs against url or tensor is not yet working with pydantic v2
@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
def test_audio_np():
audio = parse_obj_as(AudioDoc, np.zeros((10, 10, 3)))
assert (audio.tensor == np.zeros((10, 10, 3))).all()


@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
def test_audio_torch():
audio = parse_obj_as(AudioDoc, torch.zeros(10, 10, 3))
assert (audio.tensor == torch.zeros(10, 10, 3)).all()


@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
@pytest.mark.tensorflow
def test_audio_tensorflow():
audio = parse_obj_as(AudioDoc, tf.zeros((10, 10, 3)))
assert tnp.allclose(audio.tensor.tensor, tf.zeros((10, 10, 3)))


@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
def test_audio_bytes():
audio = parse_obj_as(AudioDoc, torch.zeros(10, 10, 3))
audio.bytes_ = audio.tensor.to_bytes()


@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
def test_audio_shortcut_doc():
class MyDoc(BaseDoc):
audio: AudioDoc
Expand Down
Loading

0 comments on commit 26d776d

Please sign in to comment.