
feat: add pydantic v2 support. Closes #415
marcosschroh committed Nov 3, 2023
1 parent 0b88f4b commit 36c0db3
Showing 36 changed files with 2,765 additions and 433 deletions.
7 changes: 4 additions & 3 deletions dataclasses_avroschema/fields/base.py
@@ -41,9 +41,10 @@ def avro_type(self) -> typing.Union[str, typing.Dict]:

     @staticmethod
     def _get_self_reference_type(a_type: typing.Any) -> str:
-        internal_type = a_type.__args__[0]
-
-        return internal_type.__forward_arg__
+        if getattr(a_type, "__args__", None):
+            internal_type = a_type.__args__[0]
+            return internal_type.__forward_arg__
+        return a_type.__name__
 
     @staticmethod
     def get_singular_name(name: str) -> str:
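For context, a standalone sketch of the two annotation shapes the patched `_get_self_reference_type` now accepts: a `typing.Type["User"]`-style annotation (which carries `__args__`) and an already-resolved class, as pydantic v2 hands over. The helper below is a stripped-down stand-in, not the library code:

import typing


def self_reference_name(a_type: typing.Any) -> str:
    # unwrap typing.Type["User"] when __args__ is present; otherwise the
    # annotation is already the class itself and its name is enough
    if getattr(a_type, "__args__", None):
        return a_type.__args__[0].__forward_arg__
    return a_type.__name__


class User:  # illustrative model, only the name matters here
    ...


print(self_reference_name(typing.Type["User"]))  # User
print(self_reference_name(User))                 # User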
45 changes: 27 additions & 18 deletions dataclasses_avroschema/fields/fields.py
@@ -476,6 +476,13 @@ def get_default_value(self) -> typing.Union[dataclasses._MISSING_TYPE, None]:
return None
return dataclasses.MISSING

def fake(self) -> typing.Any:
if getattr(self.type, "__args__", None):
# It means that self.type is `typing.Type['AType']`, and the argument is a string
# then we return None
return None
return self.type.fake()


@dataclasses.dataclass
class DateField(ImmutableField):
@@ -743,6 +750,7 @@ def get_avro_type(self) -> typing.Union[str, typing.List, typing.Dict]:

if self.default is None:
return [field_utils.NULL, record_type]

return record_type

def default_to_avro(self, value: "schema_generator.AvroModel") -> typing.Dict:
@@ -757,14 +765,15 @@ def fake(self) -> typing.Any:


from .mapper import (
ALL_TYPES_FIELD_CLASSES,
CONTAINER_FIELDS_CLASSES,
IMMUTABLE_FIELDS_CLASSES,
LOGICAL_TYPES_FIELDS_CLASSES,
SPECIAL_ANNOTATED_TYPES,
)

LOGICAL_CLASSES = LOGICAL_TYPES_FIELDS_CLASSES.keys()
PYDANTIC_CUSTOM_CLASS_METHOD_NAMES = {"__get_validators__", "validate"}
PYDANTIC_CUSTOM_CLASS_METHOD_NAMES = {"__get_validators__", "__get_pydantic_core_schema__"}


def field_factory(
@@ -783,18 +792,15 @@ def field_factory(
metadata = {}

field_info = None

if native_type is None:
native_type = type(None)

if native_type not in types.CUSTOM_TYPES and utils.is_annotated(native_type):
if utils.is_annotated(native_type) and native_type not in ALL_TYPES_FIELD_CLASSES:
a_type, *extra_args = get_args(native_type)
field_info = next((arg for arg in extra_args if isinstance(arg, types.FieldInfo)), None)
if field_info is None or a_type in (decimal.Decimal, types.Fixed):
# it means that it is a custom type defined by us
# `Int32`, `Float32`,`TimeMicro` or `DateTimeMicro`
# or a type Annotated with the end user
native_type = a_type
# it means that it is a custom type defined by us `Int32`, `Float32`,`TimeMicro` or `DateTimeMicro`
# or a known type Annotated with the end user
native_type = a_type

if native_type in IMMUTABLE_FIELDS_CLASSES:
klass = IMMUTABLE_FIELDS_CLASSES[native_type]
@@ -823,7 +829,7 @@ def field_factory(
parent=parent,
)

elif utils.is_self_referenced(native_type):
elif utils.is_self_referenced(native_type, parent):
return SelfReferenceField(
name=name,
type=native_type,
@@ -917,16 +923,19 @@ def field_factory(
elif (
inspect.isclass(native_type)
and not is_pydantic_model(native_type)
and all(method_name in dir(native_type) for method_name in PYDANTIC_CUSTOM_CLASS_METHOD_NAMES)
and any(method_name in dir(native_type) for method_name in PYDANTIC_CUSTOM_CLASS_METHOD_NAMES)
):
try:
# Build a field for the encoded type since that's what will be serialized
encoded_type = parent.__config__.json_encoders[native_type]
except KeyError:
raise ValueError(
f"Type {native_type} for field {name} must be listed in the pydantic 'json_encoders' config for {parent}"
" (or for one of the classes in its inheritance tree since pydantic configs are inherited)"
)
if getattr(parent, "__config__", None):
try:
# Build a field for the encoded type since that's what will be serialized
encoded_type = parent.__config__.json_encoders[native_type]
except KeyError:
raise ValueError(
f"Type {native_type} for field {name} must be listed in the pydantic 'json_encoders' config for {parent}"
" (or for one of the classes in its inheritance tree since pydantic configs are inherited)"
)
else:
encoded_type = parent.model_config["json_encoders"][native_type]

# default_factory is not schema-friendly for Custom Classes since it could be returning
# dynamically constructed values that should not be treated as defaults. For example,
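The custom-class branch above now accepts either pydantic hook (`__get_validators__` or `__get_pydantic_core_schema__`) and, when the v1-style `__config__` is absent, reads the encoder from `model_config["json_encoders"]`. A hedged sketch of a user model exercising that path; the class, the field names and the `dataclasses_avroschema.pydantic` import path are illustrative assumptions, not taken from this diff:

from typing import Any

from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema

from dataclasses_avroschema.pydantic import AvroBaseModel


class CustomerId:
    def __init__(self, value: str) -> None:
        self.value = value

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        # accept an existing CustomerId or anything coercible to str
        return core_schema.no_info_plain_validator_function(
            lambda value: value if isinstance(value, cls) else cls(str(value))
        )


def encode_customer_id(value: CustomerId) -> str:
    return value.value


class Order(AvroBaseModel):
    customer: CustomerId
    amount: float

    model_config = {
        # the field factory builds the Avro field from the encoded type (str here)
        "json_encoders": {CustomerId: encode_customer_id},
    }

With a pydantic v1 model the same lookup still goes through `parent.__config__.json_encoders`, which is the branch kept for backwards compatibility.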
7 changes: 7 additions & 0 deletions dataclasses_avroschema/fields/mapper.py
@@ -47,3 +47,10 @@
decimal.Decimal: fields.DecimalField,
types.Fixed: fields.FixedField,
}

ALL_TYPES_FIELD_CLASSES = { # type: ignore
**IMMUTABLE_FIELDS_CLASSES,
**CONTAINER_FIELDS_CLASSES,
**LOGICAL_TYPES_FIELDS_CLASSES,
**SPECIAL_ANNOTATED_TYPES,
}
4 changes: 3 additions & 1 deletion dataclasses_avroschema/pydantic/__init__.py
@@ -2,5 +2,7 @@
from .mapper import PYDANTIC_INMUTABLE_FIELDS_CLASSES, PYDANTIC_LOGICAL_TYPES_FIELDS_CLASSES
from dataclasses_avroschema.fields import mapper

mapper.IMMUTABLE_FIELDS_CLASSES.update(PYDANTIC_INMUTABLE_FIELDS_CLASSES)
mapper.IMMUTABLE_FIELDS_CLASSES.update(PYDANTIC_INMUTABLE_FIELDS_CLASSES) # type: ignore
mapper.LOGICAL_TYPES_FIELDS_CLASSES.update(PYDANTIC_LOGICAL_TYPES_FIELDS_CLASSES) # type: ignore
mapper.ALL_TYPES_FIELD_CLASSES.update(PYDANTIC_INMUTABLE_FIELDS_CLASSES)
mapper.ALL_TYPES_FIELD_CLASSES.update(PYDANTIC_LOGICAL_TYPES_FIELDS_CLASSES)
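These module-level `update()` calls are the registration mechanism: importing `dataclasses_avroschema.pydantic` mutates the shared mappers in place so the generic `field_factory` can resolve pydantic types with a single lookup. A compact sketch of the pattern with stand-in dictionaries (the values below are placeholders, not the real field classes):

# stand-ins mirroring mapper.IMMUTABLE_FIELDS_CLASSES / mapper.ALL_TYPES_FIELD_CLASSES
IMMUTABLE_FIELDS_CLASSES = {int: "IntField", str: "StringField"}
ALL_TYPES_FIELD_CLASSES = {**IMMUTABLE_FIELDS_CLASSES}

# what the pydantic package contributes at import time
PYDANTIC_INMUTABLE_FIELDS_CLASSES = {"PositiveInt": "PositiveIntField"}

IMMUTABLE_FIELDS_CLASSES.update(PYDANTIC_INMUTABLE_FIELDS_CLASSES)
ALL_TYPES_FIELD_CLASSES.update(PYDANTIC_INMUTABLE_FIELDS_CLASSES)

assert "PositiveInt" in ALL_TYPES_FIELD_CLASSES  # one membership test in field_factory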
2 changes: 1 addition & 1 deletion dataclasses_avroschema/pydantic/fields.py
@@ -157,7 +157,7 @@ class PositiveFloatField(PydanticField):
avro_type: typing.ClassVar[typing.Dict[str, str]] = {"type": DOUBLE, "pydantic-class": "PositiveFloat"}

def fake(self) -> float:
return fake.pyfloat(positive=True)
return fake.pyfloat(positive=True, min_value=1)


class NegativeIntField(PydanticField):
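A quick check of the adjusted `PositiveFloatField.fake` above: pydantic's `PositiveFloat` requires a value strictly greater than zero, and pinning `min_value=1` keeps the generated value away from the `0.0` boundary (sketch, assuming the `faker` package that backs the module's `fake` instance):

from faker import Faker

fake = Faker()

value = fake.pyfloat(positive=True, min_value=1)
assert value > 0  # always satisfies PositiveFloat's gt=0 constraint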
25 changes: 12 additions & 13 deletions dataclasses_avroschema/pydantic/main.py
@@ -1,4 +1,5 @@
from typing import Any, Callable, Optional, Type, TypeVar
import json
from typing import Any, Callable, Dict, Optional, Type, TypeVar

from fastavro.validation import validate

@@ -24,7 +25,7 @@ def generate_dataclass(cls: Type[CT]) -> Type[CT]:

@classmethod
def json_schema(cls: Type[CT], *args: Any, **kwargs: Any) -> str:
return cls.schema_json(*args, **kwargs)
return json.dumps(cls.model_json_schema(*args, **kwargs))

@classmethod
def standardize_type(cls: Type[CT], data: dict) -> Any:
@@ -33,14 +34,9 @@ def standardize_type(cls: Type[CT], data: dict) -> Any:
         user-defined pydantic json_encoders prior to passing values
         to the standard type conversion factory
         """
-        encoders = cls.__config__.json_encoders
-        for k, v in data.items():
-            v_type = type(v)
-            if v_type in encoders:
-                encode_method = encoders[v_type]
-                data[k] = encode_method(v)
-            elif isinstance(v, dict):
-                cls.standardize_type(v)
+        for value in data.values():
+            if isinstance(value, dict):
+                cls.standardize_type(value)
 
         return standardize_custom_type(data)
 
@@ -51,14 +47,17 @@ def asdict(self, standardize_factory: Optional[Callable[..., Any]] = None) -> JsonDict:
It also doesn't provide the exclude, include, by_alias, etc.
parameters that dict provides.
"""
data = dict(self)

data = self.model_dump()
standardize_method = standardize_factory or self.standardize_type

# the standardize callable can be replaced if we have a custom implementation of asdict
# for now I think it is better to use the native implementation
return standardize_method(data)

@classmethod
def parse_obj(cls: Type[CT], data: Dict) -> CT:
return cls.model_validate(obj=data)

def serialize(self, serialization_type: str = AVRO) -> bytes:
"""
Overrides the base AvroModel's serialize method to inject this
@@ -91,7 +90,7 @@ def fake(cls: Type[CT], **data: Any) -> CT:
payload = {field.name: field.fake() for field in cls.get_fields() if field.name not in data.keys()}
payload.update(data)

return cls.parse_obj(payload)
return cls.model_validate(payload)

@classmethod
def _generate_parser(cls: Type[CT]) -> PydanticParser:
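A usage sketch tying together the pydantic v2 overrides above; the import path follows this package layout and the model itself is illustrative:

import typing

from dataclasses_avroschema.pydantic import AvroBaseModel


class UserAdvance(AvroBaseModel):
    name: str
    age: int
    pets: typing.List[str] = ["dog", "cat"]


user = UserAdvance.fake()         # per-field fakes, validated via model_validate()
data = user.asdict()              # model_dump() followed by Avro standardization
event = user.serialize()          # Avro-encoded bytes for the generated schema
print(UserAdvance.json_schema())  # json.dumps(model_json_schema(...)) under pydantic v2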
4 changes: 3 additions & 1 deletion dataclasses_avroschema/pydantic/mapper.py
@@ -1,4 +1,5 @@
import pydantic
from pydantic.v1 import ConstrainedInt

from . import fields

@@ -25,12 +26,13 @@
pydantic.PositiveFloat: fields.PositiveFloatField,
pydantic.NegativeInt: fields.NegativeIntField,
pydantic.PositiveInt: fields.PositiveIntField,
pydantic.ConstrainedInt: fields.ConstrainedIntField,
ConstrainedInt: fields.ConstrainedIntField,
# ConstrainedIntValue is a dynamic type that needs to be referenced by qualified name
# and cannot be imported directly
"ConstrainedIntValue": fields.ConstrainedIntField,
}


PYDANTIC_LOGICAL_TYPES_FIELDS_CLASSES = {
pydantic.UUID1: fields.UUID1Field,
pydantic.UUID3: fields.UUID3Field,
16 changes: 8 additions & 8 deletions dataclasses_avroschema/pydantic/parser.py
@@ -10,16 +10,16 @@ class PydanticParser(Parser):
     def parse_fields(self, exclude: typing.List) -> typing.List[Field]:
         return [
             AvroField(
-                model_field.name,
-                model_field.annotation,
+                field_name,
+                field_info.rebuild_annotation(),
                 default=dataclasses.MISSING
-                if model_field.required or model_field.default_factory
-                else model_field.default,
-                default_factory=model_field.default_factory,
-                metadata=model_field.field_info.extra.get("metadata", {}),
+                if field_info.is_required() or field_info.default_factory
+                else field_info.default,
+                default_factory=field_info.default_factory,
+                metadata=field_info.json_schema_extra.get("metadata", {}) if field_info.json_schema_extra else {},
                 model_metadata=self.metadata,
                 parent=self.parent,
             )
-            for model_field in self.type.__fields__.values()
-            if model_field.name not in exclude
+            for field_name, field_info in self.type.model_fields.items()
+            if field_name not in exclude and field_name != "model_config"
         ]
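The parser now iterates pydantic v2's `model_fields` mapping of `FieldInfo` objects instead of v1's `__fields__`. A small standalone sketch of the introspection calls it relies on (the `Address` model is illustrative):

from pydantic import BaseModel, Field


class Address(BaseModel):
    street: str
    zip_code: str = Field(default="0000", json_schema_extra={"metadata": {"doc": "postal code"}})


for field_name, field_info in Address.model_fields.items():
    print(field_name, field_info.rebuild_annotation())  # annotation, incl. Annotated metadata
    print(field_info.is_required(), field_info.default, field_info.default_factory)
    print((field_info.json_schema_extra or {}).get("metadata", {}))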
6 changes: 6 additions & 0 deletions dataclasses_avroschema/pydantic/v1/__init__.py
@@ -0,0 +1,6 @@
from .main import AvroBaseModel # noqa: F401 I001
from .mapper import PYDANTIC_INMUTABLE_FIELDS_CLASSES, PYDANTIC_LOGICAL_TYPES_FIELDS_CLASSES
from dataclasses_avroschema.fields import mapper

mapper.IMMUTABLE_FIELDS_CLASSES.update(PYDANTIC_INMUTABLE_FIELDS_CLASSES)
mapper.LOGICAL_TYPES_FIELDS_CLASSES.update(PYDANTIC_LOGICAL_TYPES_FIELDS_CLASSES) # type: ignore
99 changes: 99 additions & 0 deletions dataclasses_avroschema/pydantic/v1/main.py
@@ -0,0 +1,99 @@
from typing import Any, Callable, Optional, Type, TypeVar

from fastavro.validation import validate

from dataclasses_avroschema import serialization
from dataclasses_avroschema.schema_generator import AVRO, AvroModel
from dataclasses_avroschema.types import JsonDict
from dataclasses_avroschema.utils import standardize_custom_type

from .parser import PydanticV1Parser

try:
from pydantic.v1 import BaseModel # pragma: no cover
except ImportError as ex: # pragma: no cover
raise Exception("pydantic must be installed in order to use AvroBaseModel") from ex # pragma: no cover

CT = TypeVar("CT", bound="AvroBaseModel")


class AvroBaseModel(BaseModel, AvroModel): # type: ignore
@classmethod
def generate_dataclass(cls: Type[CT]) -> Type[CT]:
return cls

@classmethod
def json_schema(cls: Type[CT], *args: Any, **kwargs: Any) -> str:
return cls.schema_json(*args, **kwargs)

@classmethod
def standardize_type(cls: Type[CT], data: dict) -> Any:
"""
Standardization factory that converts data according to the
user-defined pydantic json_encoders prior to passing values
to the standard type conversion factory
"""
encoders = cls.__config__.json_encoders
for k, v in data.items():
v_type = type(v)
if v_type in encoders:
encode_method = encoders[v_type]
data[k] = encode_method(v)
elif isinstance(v, dict):
cls.standardize_type(v)

return standardize_custom_type(data)

def asdict(self, standardize_factory: Optional[Callable[..., Any]] = None) -> JsonDict:
"""
Returns this model in dictionary form. This method differs from
pydantic's dict by converting all values to their Avro representation.
It also doesn't provide the exclude, include, by_alias, etc.
parameters that dict provides.
"""
data = dict(self)

standardize_method = standardize_factory or self.standardize_type

# the standardize callable can be replaced if we have a custom implementation of asdict
# for now I think it is better to use the native implementation
return standardize_method(data)

def serialize(self, serialization_type: str = AVRO) -> bytes:
"""
Overrides the base AvroModel's serialize method to inject this
class's standardization factory method
"""
schema = self.avro_schema_to_python()

return serialization.serialize(
self.asdict(standardize_factory=self.standardize_type),
schema,
serialization_type=serialization_type,
)

def validate_avro(self) -> bool:
"""
Validate that instance matches the avro schema
"""
schema = self.avro_schema_to_python()
return validate(self.asdict(), schema)

@classmethod
def fake(cls: Type[CT], **data: Any) -> CT:
"""
Creates a fake instance of the model.
Attributes:
data: Dict[str, Any] represent the user values to use in the instance
"""
# only generate fakes for fields that were not provided in data
payload = {field.name: field.fake() for field in cls.get_fields() if field.name not in data.keys()}
payload.update(data)

return cls.parse_obj(payload)

@classmethod
def _generate_parser(cls: Type[CT]) -> PydanticV1Parser:
cls._metadata = cls.generate_metadata()
return PydanticV1Parser(type=cls._klass, metadata=cls._metadata, parent=cls._parent or cls)
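A brief usage sketch of the pydantic v1 compatibility model added here; the `dataclasses_avroschema.pydantic.v1` import path mirrors the new package and the model is illustrative:

from dataclasses_avroschema.pydantic.v1 import AvroBaseModel


class User(AvroBaseModel):
    name: str
    age: int


user = User.fake()                            # validated through pydantic v1
assert User.parse_obj(user.asdict()) == user  # round-trip via v1's parse_obj
event = user.serialize()                      # Avro-encoded bytes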
40 changes: 40 additions & 0 deletions dataclasses_avroschema/pydantic/v1/mapper.py
@@ -0,0 +1,40 @@
from pydantic import v1

from dataclasses_avroschema.pydantic import fields

PYDANTIC_INMUTABLE_FIELDS_CLASSES = {
v1.FilePath: fields.FilePathField,
v1.DirectoryPath: fields.DirectoryPathField,
v1.EmailStr: fields.EmailStrField,
v1.NameEmail: fields.NameEmailField,
v1.AnyUrl: fields.AnyUrlField,
v1.AnyHttpUrl: fields.AnyHttpUrlField,
v1.HttpUrl: fields.HttpUrlField,
v1.FileUrl: fields.FileUrlField,
v1.PostgresDsn: fields.PostgresDsnField,
v1.CockroachDsn: fields.CockroachDsnField,
v1.AmqpDsn: fields.AmqpDsnField,
v1.RedisDsn: fields.RedisDsnField,
v1.MongoDsn: fields.MongoDsnField,
v1.KafkaDsn: fields.KafkaDsnField,
v1.SecretStr: fields.SecretStrField,
v1.IPvAnyAddress: fields.IPvAnyAddressField,
v1.IPvAnyInterface: fields.IPvAnyInterfaceField,
v1.IPvAnyNetwork: fields.IPvAnyNetworkField,
v1.NegativeFloat: fields.NegativeFloatField,
v1.PositiveFloat: fields.PositiveFloatField,
v1.NegativeInt: fields.NegativeIntField,
v1.PositiveInt: fields.PositiveIntField,
v1.ConstrainedInt: fields.ConstrainedIntField,
# ConstrainedIntValue is a dynamic type that needs to be referenced by qualified name
# and cannot be imported directly
"ConstrainedIntValue": fields.ConstrainedIntField,
}


PYDANTIC_LOGICAL_TYPES_FIELDS_CLASSES = {
v1.UUID1: fields.UUID1Field,
v1.UUID3: fields.UUID3Field,
v1.UUID4: fields.UUID4Field,
v1.UUID5: fields.UUID5Field,
}
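These mappings let pydantic v1 constrained and special types resolve to the existing pydantic field classes. A hedged sketch using types from the table above (the model and field names are illustrative):

from pydantic import v1

from dataclasses_avroschema.pydantic.v1 import AvroBaseModel


class Account(AvroBaseModel):
    account_id: v1.UUID4
    balance: v1.PositiveFloat
    loyalty_points: v1.PositiveInt


print(Account.avro_schema())  # v1 types resolve through the mappings above
print(Account.fake())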
