diff --git a/README.md b/README.md index 1cdbceee..ca311e64 100644 --- a/README.md +++ b/README.md @@ -329,8 +329,9 @@ User.fake() ## Features -* [X] Primitive types: int, long, double, float, boolean, string and null support -* [X] Complex types: enum, array, map, fixed, unions and records support +* [x] Primitive types: int, long, double, float, boolean, string and null support +* [x] Complex types: enum, array, map, fixed, unions and records support +* [x] `typing.Annotated` supported * [x] Logical Types: date, time (millis and micro), datetime (millis and micro), uuid support * [X] Schema relations (oneToOne, oneToMany) * [X] Recursive Schemas diff --git a/dataclasses_avroschema/field_utils.py b/dataclasses_avroschema/field_utils.py index 575383da..365ffeab 100644 --- a/dataclasses_avroschema/field_utils.py +++ b/dataclasses_avroschema/field_utils.py @@ -91,7 +91,17 @@ } # excluding tuple because is a container -PYTHON_INMUTABLE_TYPES = (str, int, types.Int32, types.Float32, bool, float, bytes, type(None)) +PYTHON_INMUTABLE_TYPES = ( + str, + int, + types.Int32, + types.Float32, + bool, + float, + bytes, + type(None), +) + PYTHON_PRIMITIVE_CONTAINERS = (list, tuple, dict) PYTHON_LOGICAL_TYPES = ( diff --git a/dataclasses_avroschema/fields.py b/dataclasses_avroschema/fields.py index fb4d62ae..c3e64094 100644 --- a/dataclasses_avroschema/fields.py +++ b/dataclasses_avroschema/fields.py @@ -16,6 +16,7 @@ import inflect from faker import Faker from pytz import utc +from typing_extensions import get_args from dataclasses_avroschema import schema_generator, serialization, types, utils @@ -49,6 +50,7 @@ class BaseField: type: typing.Any # store the python primitive type default: typing.Any parent: typing.Any + field_info: typing.Optional[types.FieldInfo] = None metadata: typing.Optional[typing.Mapping] = None model_metadata: typing.Optional[utils.SchemaMetadata] = None @@ -122,11 +124,12 @@ def get_default_value(self) -> typing.Any: return self.default def validate_default(self) -> bool: + a_type = self.type msg = f"Invalid default type. Default should be {self.type}" - if getattr(self.type, "__metadata__", [None])[0] in types.CUSTOM_TYPES: - assert isinstance(self.default, self.type.__origin__) - else: - assert isinstance(self.default, self.type), msg + if utils.is_annotated(self.type): + a_type, _ = get_args(self.type) + + assert isinstance(self.default, a_type), msg return True @@ -929,12 +932,21 @@ def field_factory( *, default: typing.Any = dataclasses.MISSING, default_factory: typing.Any = dataclasses.MISSING, - metadata: typing.Optional[typing.Mapping] = None, + metadata: typing.Optional[typing.Dict[str, typing.Any]] = None, model_metadata: typing.Optional[utils.SchemaMetadata] = None, ) -> FieldType: if metadata is None: metadata = {} + field_info = None + + if utils.is_annotated(native_type): + a_type, *extra_args = get_args(native_type) + field_info = next((arg for arg in extra_args if isinstance(arg, types.FieldInfo)), None) + + if field_info is None: + native_type = a_type + if native_type in field_utils.PYTHON_INMUTABLE_TYPES: klass = INMUTABLE_FIELDS_CLASSES[native_type] return klass( @@ -944,6 +956,7 @@ def field_factory( metadata=metadata, model_metadata=model_metadata, parent=parent, + field_info=field_info, ) elif utils.is_self_referenced(native_type): return SelfReferenceField( @@ -953,6 +966,7 @@ def field_factory( metadata=metadata, model_metadata=model_metadata, parent=parent, + field_info=field_info, ) elif native_type is types.Fixed: return FixedField( @@ -962,6 +976,7 @@ def field_factory( metadata=metadata, model_metadata=model_metadata, parent=parent, + field_info=field_info, ) elif native_type in field_utils.PYTHON_LOGICAL_TYPES: klass = LOGICAL_TYPES_FIELDS_CLASSES[native_type] # type: ignore @@ -973,6 +988,7 @@ def field_factory( metadata=metadata, model_metadata=model_metadata, parent=parent, + field_info=field_info, ) elif isinstance(native_type, GenericAlias): # type: ignore origin = native_type.__origin__ @@ -1003,6 +1019,7 @@ def field_factory( default_factory=default_factory, model_metadata=model_metadata, parent=parent, + field_info=field_info, ) elif inspect.isclass(native_type) and issubclass(native_type, enum.Enum): return EnumField( @@ -1012,6 +1029,7 @@ def field_factory( metadata=metadata, model_metadata=model_metadata, parent=parent, + field_info=field_info, ) elif types.UnionType is not None and isinstance(native_type, types.UnionType): # we need to check whether types.UnionType because it works only in @@ -1026,6 +1044,7 @@ def field_factory( default_factory=default_factory, model_metadata=model_metadata, parent=parent, + field_info=field_info, ) elif inspect.isclass(native_type) and issubclass(native_type, schema_generator.AvroModel): return RecordField( @@ -1035,6 +1054,7 @@ def field_factory( metadata=metadata, model_metadata=model_metadata, parent=parent, + field_info=field_info, ) else: msg = ( diff --git a/dataclasses_avroschema/schema_definition.py b/dataclasses_avroschema/schema_definition.py index 103cf65f..3f9dcf1b 100644 --- a/dataclasses_avroschema/schema_definition.py +++ b/dataclasses_avroschema/schema_definition.py @@ -68,7 +68,7 @@ def parse_fields(self) -> typing.List[FieldType]: dataclass_field.type, default=dataclass_field.default, default_factory=dataclass_field.default_factory, # type: ignore # TODO: resolve mypy - metadata=dataclass_field.metadata, + metadata=dict(dataclass_field.metadata), model_metadata=self.metadata, parent=self.parent, ) @@ -100,7 +100,7 @@ def parse_faust_fields(self) -> typing.List[FieldType]: dataclass_field.type, default=default, default_factory=default_factory, - metadata=metadata, + metadata=dict(metadata), model_metadata=self.metadata, parent=self.parent, ) diff --git a/dataclasses_avroschema/types.py b/dataclasses_avroschema/types.py index 0c7e5510..44150897 100644 --- a/dataclasses_avroschema/types.py +++ b/dataclasses_avroschema/types.py @@ -28,6 +28,21 @@ __all__ = CUSTOM_TYPES +class FieldInfo: + def __init__(self, **kwargs) -> None: + self.type = kwargs.get("type") + self.max_digits = kwargs.get("max_digits") + self.decimal_places = kwargs.get("decimal_places") + + @property + def metadata(self): + return { + "type": self.type, + "max_digits": self.max_digits, + "decimal_places": self.decimal_places, + } + + class MissingSentinel(typing.Generic[T]): """ Class to detect when a field is not initialized @@ -80,7 +95,7 @@ def __repr__(self) -> str: return f"Decimal('{self.default}')" -Int32 = Annotated[int, "Int32"] -Float32 = Annotated[float, "Float32"] -TimeMicro = Annotated[datetime.time, "TimeMicro"] -DateTimeMicro = Annotated[datetime.datetime, "DateTimeMicro"] +Int32 = Annotated[int, FieldInfo(type="Int32")] +Float32 = Annotated[float, FieldInfo(type="Float32")] +TimeMicro = Annotated[datetime.time, FieldInfo(type="TimeMicro")] +DateTimeMicro = Annotated[datetime.datetime, FieldInfo(type="DateTimeMicro")] diff --git a/dataclasses_avroschema/utils.py b/dataclasses_avroschema/utils.py index b5c511a8..4008d3ce 100644 --- a/dataclasses_avroschema/utils.py +++ b/dataclasses_avroschema/utils.py @@ -3,6 +3,7 @@ from datetime import datetime from pytz import utc +from typing_extensions import Annotated, get_origin from .types import JsonDict @@ -69,6 +70,11 @@ def is_self_referenced(a_type: type) -> bool: ) +def is_annotated(a_type: typing.Any) -> bool: + origin = get_origin(a_type) + return origin is not None and isinstance(origin, type) and issubclass(origin, Annotated) # type: ignore[arg-type] + + @dataclasses.dataclass class SchemaMetadata: schema_name: typing.Optional[str] = None diff --git a/docs/fields_specification.md b/docs/fields_specification.md index e8895ce3..45e256f4 100644 --- a/docs/fields_specification.md +++ b/docs/fields_specification.md @@ -137,6 +137,62 @@ Python Type | Avro Type | Logical Type | | uuid.uuid4 | string | uuid | | uuid.UUID | string | uuid | +## typing.Annotated + +All the types can be [Annotated](https://docs.python.org/3/library/typing.html#typing.Annotated) so `metadata` can be added to the fields. This library will use the `python type` to generate the `avro field` and it will ignore the extra `metadata`. + +```python title="Annotated" +import dataclasses +import enum +import typing + +from dataclasses_avroschema import AvroModel + + +class FavoriteColor(str, enum.Enum): + BLUE = "BLUE" + YELLOW = "YELLOW" + GREEN = "GREEN" + + +@dataclasses.dataclass +class UserAdvance(AvroModel): + name: typing.Annotated[str, "string"] + age: typing.Annotated[int, "integer"] + pets: typing.List[typing.Annotated[str, "string"]] + accounts: typing.Dict[str, typing.Annotated[int, "integer"]] + favorite_colors: typing.Annotated[FavoriteColor, "a color enum"] + has_car: typing.Annotated[bool, "boolean"] = False + country: str = "Argentina" + address: typing.Optional[typing.Annotated[str, "string"]] = None + + class Meta: + schema_doc = False + + +UserAdvance.avro_schema() +``` + +resulting in + +```json +{ + "type": "record", + "name": "UserAdvance", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "age", "type": "long"}, + {"name": "pets", "type": {"type": "array", "items": "string", "name": "pet"}}, + {"name": "accounts", "type": {"type": "map", "values": "long", "name": "account"}}, + {"name": "favorite_colors", "type": {"type": "enum", "name": "FavoriteColor", "symbols": ["BLUE", "YELLOW", "GREEN"]}}, + {"name": "has_car", "type": "boolean", "default": false}, + {"name": "country", "type": "string", "default": "Argentina"}, + {"name": "address", "type": ["null", "string"], "default": null}] +}' +``` + +*(This script is complete, it should run "as is")* + ## Adding Custom Field-level Attributes You may want to add field-level attributes which are not automatically populated according to the typing semantics @@ -149,8 +205,6 @@ to all fields such as `"name"` and others are specific to the datatype (e.g. `ar In order to add custom fields, you can use the `field` descriptor of the built-in `dataclasses` package and provide a `dict` of key-value pairs to the `metadata` parameter as in `dataclasses.field(metadata={'doc': 'foo'})`. -### Examples - ```python title="Adding a doc attribute to fields" from dataclasses import dataclass, field from dataclasses_avroschema import AvroModel, types diff --git a/tests/conftest.py b/tests/conftest.py index befaa37c..dbcaa48d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ import typing import pytest +from typing_extensions import Annotated from dataclasses_avroschema import AvroModel, types @@ -154,7 +155,7 @@ class UserAdvance(AvroModel): favorite_colors: color_enum has_car: bool = False country: str = "Argentina" - address: str = None + address: typing.Optional[str] = None user_type: typing.Union[int, user_type_enum] = -1 md5: types.Fixed = types.Fixed(16) @@ -164,6 +165,27 @@ class Meta: return UserAdvance +@pytest.fixture +def user_advance_dataclass_with_union_enum_with_annotated(color_enum: type, user_type_enum: type): + @dataclasses.dataclass + class UserAdvance(AvroModel): + name: Annotated[str, "string"] + age: Annotated[int, "integer"] + pets: typing.List[Annotated[str, "string"]] + accounts: typing.Dict[str, Annotated[int, "integer"]] + favorite_colors: Annotated[color_enum, "a color enum"] + has_car: Annotated[bool, "boolean"] = False + country: str = "Argentina" + address: typing.Optional[Annotated[str, "string"]] = None + user_type: typing.Union[Annotated[int, "integer"], user_type_enum] = -1 + md5: types.Fixed = types.Fixed(16) + + class Meta: + schema_doc = False + + return UserAdvance + + @pytest.fixture def user_advance_dataclass_with_sub_record_and_enum(color_enum: type, user_type_enum: type): @dataclasses.dataclass diff --git a/tests/fields/consts.py b/tests/fields/consts.py index c0482609..c0eafd6d 100644 --- a/tests/fields/consts.py +++ b/tests/fields/consts.py @@ -4,6 +4,7 @@ import uuid import pytest +from typing_extensions import Annotated from dataclasses_avroschema import field_utils @@ -17,6 +18,11 @@ (bool, field_utils.BOOLEAN), (float, field_utils.DOUBLE), (bytes, field_utils.BYTES), + (Annotated[str, "string"], field_utils.STRING), + (Annotated[int, "integer"], field_utils.LONG), + (Annotated[bool, "boolean"], field_utils.BOOLEAN), + (Annotated[float, "float"], field_utils.DOUBLE), + (Annotated[bytes, "bytes"], field_utils.BYTES), ) PRIMITIVE_TYPES_AND_DEFAULTS = ( @@ -25,6 +31,11 @@ (bool, True), (float, 10.4), (bytes, b"test"), + (Annotated[str, "string"], "test"), + (Annotated[int, "int"], 1), + (Annotated[bool, "boolen"], True), + (Annotated[float, "float"], 10.4), + (Annotated[bytes, "bytes"], b"test"), ) PRIMITIVE_TYPES_AND_INVALID_DEFAULTS = ( @@ -35,19 +46,15 @@ (bytes, "test"), ) -LIST_TYPE_AND_ITEMS_TYPE = ( - (str, "string"), - (int, "long"), - (bool, "boolean"), - (float, "double"), - (bytes, "bytes"), -) - LOGICAL_TYPES = ( (datetime.date, field_utils.LOGICAL_DATE, now.date()), (datetime.time, field_utils.LOGICAL_TIME_MILIS, now.time()), (datetime.datetime, field_utils.LOGICAL_DATETIME_MILIS, now), - (uuid.uuid4, field_utils.LOGICAL_UUID, uuid.uuid4()), + (uuid.UUID, field_utils.LOGICAL_UUID, uuid.uuid4()), + (Annotated[datetime.date, "date"], field_utils.LOGICAL_DATE, now.date()), + (Annotated[datetime.time, "time"], field_utils.LOGICAL_TIME_MILIS, now.time()), + (Annotated[datetime.datetime, "datetime"], field_utils.LOGICAL_DATETIME_MILIS, now), + (Annotated[uuid.UUID, "uuid"], field_utils.LOGICAL_UUID, uuid.uuid4()), ) UNION_PRIMITIVE_ELEMENTS = ( @@ -75,6 +82,10 @@ typing.Union[str, float, int, bool], (field_utils.STRING, field_utils.DOUBLE, field_utils.LONG, field_utils.BOOLEAN), ), + ( + typing.Union[Annotated[str, "string"], int], + (field_utils.STRING, field_utils.LONG), + ), ) UNION_PRIMITIVE_ELEMENTS_DEFAULTS = ( @@ -95,6 +106,7 @@ (field_utils.BOOLEAN, field_utils.STRING, field_utils.DOUBLE, field_utils.LONG), False, ), + (typing.Union[Annotated[str, "string"], int], (field_utils.STRING, field_utils.LONG), "test"), ) UNION_WITH_ARRAY = ( @@ -111,9 +123,13 @@ (field_utils.LOGICAL_DATETIME_MILIS, field_utils.LOGICAL_DATETIME_MILIS), ), ( - typing.Union[typing.List[uuid.uuid4], bytes], + typing.Union[typing.List[uuid.UUID], bytes], (field_utils.LOGICAL_UUID, field_utils.BYTES), ), + ( + typing.Union[typing.List[Annotated[int, "integer"]], str], + (field_utils.LONG, field_utils.STRING), + ), ) UNION_WITH_MAP = ( @@ -130,9 +146,13 @@ (field_utils.LOGICAL_DATETIME_MILIS, field_utils.LOGICAL_DATETIME_MILIS), ), ( - typing.Union[typing.Dict[str, uuid.uuid4], bytes], + typing.Union[typing.Dict[str, uuid.UUID], bytes], (field_utils.LOGICAL_UUID, field_utils.BYTES), ), + ( + typing.Union[typing.Dict[str, Annotated[int, "integer"]], str], + (field_utils.LONG, field_utils.STRING), + ), ) OPTIONAL_UNION_COMPLEX_TYPES = ( diff --git a/tests/fields/test_complex_types.py b/tests/fields/test_complex_types.py index 6e24ece0..1e69902a 100644 --- a/tests/fields/test_complex_types.py +++ b/tests/fields/test_complex_types.py @@ -5,8 +5,9 @@ import pytest from faker import Faker +from typing_extensions import get_args -from dataclasses_avroschema import AvroModel, exceptions, field_utils, fields, types +from dataclasses_avroschema import AvroModel, exceptions, field_utils, fields, types, utils from . import consts @@ -59,11 +60,20 @@ def test_sequence_type(sequence, python_primitive_type, python_type_str): assert expected == field.to_dict() + +@pytest.mark.parametrize("sequence, primitive_type,python_type_str", consts.SEQUENCES_AND_TYPES) +def test_sequence_type_with_default(sequence, primitive_type, python_type_str): + name = "an_array_field" + python_type = sequence[primitive_type] + field = fields.AvroField(name, python_type, default=dataclasses.MISSING) + if python_type_str == field_utils.BYTES: values = [b"hola", b"hi"] default = ["hola", "hi"] else: - values = default = faker.pylist(2, True, python_primitive_type) + if utils.is_annotated(primitive_type): + primitive_type, _ = get_args(primitive_type) + values = default = faker.pylist(2, True, primitive_type) field = fields.AvroField(name, python_type, default=default, default_factory=lambda: values) @@ -77,16 +87,16 @@ def test_sequence_type(sequence, python_primitive_type, python_type_str): @pytest.mark.parametrize( - "sequence,python_primitive_type,python_type_str,value", + "sequence,primitive_type,python_type_str,value", consts.SEQUENCES_LOGICAL_TYPES, ) -def test_sequence_with_logical_type(sequence, python_primitive_type, python_type_str, value): +def test_sequence_with_logical_type(sequence, primitive_type, python_type_str, value): """ When the type is List, the Avro field type should be array with the items attribute present. """ name = "an_array_field" - python_type = sequence[python_primitive_type] + python_type = sequence[primitive_type] field = fields.AvroField(name, python_type, default=dataclasses.MISSING) expected = { @@ -112,7 +122,7 @@ def test_sequence_with_logical_type(sequence, python_primitive_type, python_type expected = { "name": name, "type": {"type": "array", "name": name, "items": python_type_str}, - "default": [fields.LOGICAL_TYPES_FIELDS_CLASSES[python_primitive_type].to_avro(value) for value in values], + "default": field.get_default_value(), } assert expected == field.to_dict() @@ -148,14 +158,14 @@ def test_sequence_with_union_type(union, items, default): assert expected == field.to_dict() -@pytest.mark.parametrize("mapping,python_primitive_type,python_type_str", consts.MAPPING_AND_TYPES) -def test_mapping_type(mapping, python_primitive_type, python_type_str): +@pytest.mark.parametrize("mapping,primitive_type,python_type_str", consts.MAPPING_AND_TYPES) +def test_mapping_type(mapping, primitive_type, python_type_str): """ When the type is Dict, the Avro field type should be map with the values attribute present. The keys are always string type. """ name = "a_map_field" - python_type = mapping[str, python_primitive_type] + python_type = mapping[str, primitive_type] field = fields.AvroField(name, python_type, default=dataclasses.MISSING) expected = { @@ -178,7 +188,9 @@ def test_mapping_type(mapping, python_primitive_type, python_type_str): value = {"hola": b"hi"} default = {"hola": "hi"} else: - value = default = faker.pydict(2, True, python_primitive_type) + if utils.is_annotated(primitive_type): + primitive_type, _ = get_args(primitive_type) + value = default = faker.pydict(2, True, primitive_type) field = fields.AvroField(name, python_type, default=default, default_factory=lambda: value) @@ -202,14 +214,14 @@ def test_invalid_map(): assert msg == str(excinfo.value) -@pytest.mark.parametrize("mapping,python_primitive_type,python_type_str,value", consts.MAPPING_LOGICAL_TYPES) -def test_mapping_logical_type(mapping, python_primitive_type, python_type_str, value): +@pytest.mark.parametrize("mapping,primitive_type,python_type_str,value", consts.MAPPING_LOGICAL_TYPES) +def test_mapping_logical_type(mapping, primitive_type, python_type_str, value): """ When the type is Dict, the Avro field type should be map with the values attribute present. The keys are always string type. """ name = "a_map_field" - python_type = mapping[str, python_primitive_type] + python_type = mapping[str, primitive_type] field = fields.AvroField(name, python_type, default=dataclasses.MISSING) expected = { @@ -234,10 +246,7 @@ def test_mapping_logical_type(mapping, python_primitive_type, python_type_str, v expected = { "name": name, "type": {"type": "map", "name": name, "values": python_type_str}, - "default": { - key: fields.LOGICAL_TYPES_FIELDS_CLASSES[python_primitive_type].to_avro(value) - for key, value in values.items() - }, + "default": field.get_default_value(), } assert expected == field.to_dict() diff --git a/tests/fields/test_primitive_types.py b/tests/fields/test_primitive_types.py index 9d829c6e..0c0d2b09 100644 --- a/tests/fields/test_primitive_types.py +++ b/tests/fields/test_primitive_types.py @@ -12,16 +12,15 @@ def test_primitive_types(primitive_type): name = "a_field" field = fields.AvroField(name, primitive_type, default=dataclasses.MISSING) - avro_type = field_utils.PYTHON_TYPE_TO_AVRO[primitive_type] - assert {"name": name, "type": avro_type} == field.to_dict() + assert {"name": name, "type": field.avro_type} == field.to_dict() @pytest.mark.parametrize("primitive_type", field_utils.PYTHON_INMUTABLE_TYPES) def test_primitive_types_with_default_value_none(primitive_type): name = "a_field" field = fields.AvroField(name, primitive_type, default=None) - avro_type = [field_utils.NULL, field_utils.PYTHON_TYPE_TO_AVRO[primitive_type]] + avro_type = [field_utils.NULL, field.avro_type] assert {"name": name, "type": avro_type, "default": None} == field.to_dict() @@ -30,12 +29,11 @@ def test_primitive_types_with_default_value_none(primitive_type): def test_primitive_types_with_default_value(primitive_type, default): name = "a_field" field = fields.AvroField(name, primitive_type, default=default) - avro_type = field_utils.PYTHON_TYPE_TO_AVRO[primitive_type] - if primitive_type is bytes: + if field.avro_type == field_utils.BYTES: default = default.decode() - assert {"name": name, "type": avro_type, "default": default} == field.to_dict() + assert {"name": name, "type": field.avro_type, "default": default} == field.to_dict() @pytest.mark.parametrize("primitive_type,invalid_default", consts.PRIMITIVE_TYPES_AND_INVALID_DEFAULTS) diff --git a/tests/schemas/test_schema.py b/tests/schemas/test_schema.py index 9582de61..e4daf882 100644 --- a/tests/schemas/test_schema.py +++ b/tests/schemas/test_schema.py @@ -219,6 +219,15 @@ def test_get_enum_type_map_with_unions(user_advance_dataclass_with_union_enum, c } +def test_get_enum_type_map_with_unions_with_annotated( + user_advance_dataclass_with_union_enum_with_annotated, color_enum, user_type_enum +): + assert user_advance_dataclass_with_union_enum_with_annotated._get_enum_type_map() == { + "favorite_colors": color_enum, + "user_type": user_type_enum, + } + + def test_get_enum_type_map_with_sub_record(user_advance_dataclass_with_sub_record_and_enum, color_enum, user_type_enum): assert user_advance_dataclass_with_sub_record_and_enum._get_enum_type_map() == { "favorite_colors": color_enum,