Skip to content

Commit

Permalink
fix: custom types with extra annotation. Closes #598 (#601)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcosschroh committed Apr 12, 2024
1 parent 9bd8b45 commit 2766f1f
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 14 deletions.
14 changes: 10 additions & 4 deletions dataclasses_avroschema/fields/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,12 +845,18 @@ def field_factory(
if native_type is None:
native_type = type(None)

if utils.is_annotated(native_type) and native_type not in ALL_TYPES_FIELD_CLASSES:
if utils.is_annotated(native_type):
a_type, *extra_args = get_args(native_type)
field_info = next((arg for arg in extra_args if isinstance(arg, types.FieldInfo)), None)
# it means that it is a custom type defined by us `Int32`, `Float32`,`TimeMicro` or `DateTimeMicro`
# or a known type Annotated with the end user
native_type = a_type

if field_info is not None:
# it means that it is a custom type defined by us `Int32`, `Float32`,`TimeMicro`, `DateTimeMicro`
# confixed or condecimal
native_type = utils.rebuild_annotation(a_type, field_info)

if native_type not in ALL_TYPES_FIELD_CLASSES:
# type Annotated with the end user
native_type = a_type

if native_type in IMMUTABLE_FIELDS_CLASSES:
klass = IMMUTABLE_FIELDS_CLASSES[native_type]
Expand Down
20 changes: 16 additions & 4 deletions dataclasses_avroschema/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ def __repr__(self) -> str:
return f"FixedFieldInfo(size={self.size}, aliases={self.aliases}, namespace={self.namespace})"


class Int32FieldInfo(FieldInfo): ...


class Float32FieldInfo(FieldInfo): ...


class TimeMicroFieldInfo(FieldInfo): ...


class DateTimeMicro2FieldInfo(FieldInfo): ...


def confixed(
*,
size,
Expand All @@ -67,10 +79,10 @@ def condecimal(*, max_digits: int, decimal_places: int) -> typing.Type[decimal.D
] # type: ignore[return-value]


Int32 = Annotated[int, "Int32"]
Float32 = Annotated[float, "Float32"]
TimeMicro = Annotated[datetime.time, "TimeMicro"]
DateTimeMicro = Annotated[datetime.datetime, "DateTimeMicro"]
Int32 = Annotated[int, Int32FieldInfo()]
Float32 = Annotated[float, Float32FieldInfo()]
TimeMicro = Annotated[datetime.time, TimeMicroFieldInfo()]
DateTimeMicro = Annotated[datetime.datetime, DateTimeMicro2FieldInfo()]

CUSTOM_TYPES = (
Int32,
Expand Down
6 changes: 5 additions & 1 deletion dataclasses_avroschema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from typing_extensions import Annotated, get_origin

from .types import JsonDict
from .types import FieldInfo, JsonDict

try:
import pydantic # pragma: no cover
Expand Down Expand Up @@ -64,6 +64,10 @@ def is_annotated(a_type: typing.Any) -> bool:
return origin is not None and isinstance(origin, type) and issubclass(origin, Annotated) # type: ignore[arg-type]


def rebuild_annotation(a_type: typing.Any, field_info: FieldInfo) -> typing.Type:
return Annotated[a_type, field_info] # type: ignore[return-value]


def standardize_custom_type(value: typing.Any) -> typing.Any:
if isinstance(value, dict):
return {k: standardize_custom_type(v) for k, v in value.items()}
Expand Down
20 changes: 15 additions & 5 deletions tests/fields/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
from typing_extensions import Annotated

from dataclasses_avroschema import types
from dataclasses_avroschema.fields import field_utils

PY_VER = sys.version_info
Expand All @@ -21,6 +22,10 @@
(bytes, field_utils.BYTES),
(None, field_utils.NULL),
(type(None), field_utils.NULL),
(types.Int32, field_utils.INT),
(Annotated[types.Int32, "ExtraAnnotation"], field_utils.INT),
(types.Float32, field_utils.FLOAT),
(Annotated[types.Float32, "ExtraAnnotation"], field_utils.FLOAT),
(Annotated[str, "string"], field_utils.STRING),
(Annotated[int, "integer"], field_utils.LONG),
(Annotated[bool, "boolean"], field_utils.BOOLEAN),
Expand Down Expand Up @@ -83,6 +88,8 @@
(bytes, b"test"),
(None, None),
(type(None), None),
(types.Int32, 10),
(types.Float32, 10.7),
(Annotated[str, "string"], "test"),
(Annotated[int, "int"], 1),
(Annotated[bool, "boolean"], True),
Expand All @@ -105,7 +112,9 @@
LOGICAL_TYPES = (
(datetime.date, field_utils.LOGICAL_DATE, now.date()),
(datetime.time, field_utils.LOGICAL_TIME_MILIS, now.time()),
(types.TimeMicro, field_utils.LOGICAL_TIME_MICROS, now.time()),
(datetime.datetime, field_utils.LOGICAL_DATETIME_MILIS, now),
(types.DateTimeMicro, field_utils.LOGICAL_DATETIME_MICROS, now),
(uuid.UUID, field_utils.LOGICAL_UUID, uuid.uuid4()),
(Annotated[datetime.date, "date"], field_utils.LOGICAL_DATE, now.date()),
(Annotated[datetime.time, "time"], field_utils.LOGICAL_TIME_MILIS, now.time()),
Expand Down Expand Up @@ -399,14 +408,15 @@ def xfail_annotation(typ):
# Represent the logical types
# (python_type, avro_type)
LOGICAL_TYPES = (
(datetime.date, {"type": field_utils.INT, "logicalType": field_utils.DATE}),
(datetime.time, {"type": field_utils.INT, "logicalType": field_utils.TIME_MILLIS}),
(datetime.date, field_utils.LOGICAL_DATE),
(datetime.time, field_utils.LOGICAL_TIME_MILIS),
(types.TimeMicro, field_utils.LOGICAL_TIME_MICROS),
(
datetime.datetime,
{"type": field_utils.LONG, "logicalType": field_utils.TIMESTAMP_MILLIS},
field_utils.LOGICAL_DATETIME_MILIS,
),
(uuid.uuid4, {"type": field_utils.STRING, "logicalType": field_utils.UUID}),
(uuid.UUID, {"type": field_utils.STRING, "logicalType": field_utils.UUID}),
(uuid.uuid4, field_utils.LOGICAL_UUID),
(uuid.UUID, field_utils.LOGICAL_UUID),
)

LOGICAL_TYPES_AND_INVALID_DEFAULTS = (
Expand Down

0 comments on commit 2766f1f

Please sign in to comment.