Skip to content

Commit

Permalink
Merge pull request #148 from ermakov-oleg/support-custom-types
Browse files Browse the repository at this point in the history
Added support custom types
  • Loading branch information
ermakov-oleg committed May 6, 2024
2 parents 77504f4 + 8086a11 commit 30c2f8f
Show file tree
Hide file tree
Showing 17 changed files with 380 additions and 48 deletions.
87 changes: 87 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -453,3 +453,90 @@ ser = Serializer(A)
print(ser.load_query_params(MultiDict(parse_qsl('foo=1&bar=2'))))
>> A(foo=1, bar='2')
```

## Custom Type Support

In `serpyco-rs`, you can add support for your own types by using the `custom_type_resolver` parameter and the `CustomType` class. This allows you to define how your custom types should be serialized and deserialized.

### CustomType

The `CustomType` class is a way to define how a custom type should be serialized and deserialized. It is a generic class that takes two type parameters: the type of the object to be serialized/deserialized and the type of the serialized/deserialized object.

Here is an example of a `CustomType` for `IPv4Address`:

```python
from serpyco_rs import CustomType
from ipaddress import IPv4Address, AddressValueError

class IPv4AddressType(CustomType[IPv4Address, str]):
def serialize(self, obj: IPv4Address) -> str:
return str(obj)

def deserialize(self, data: str) -> IPv4Address:
try:
return IPv4Address(data)
except AddressValueError:
raise ValueError(f"Invalid IPv4 address: {data}")

def get_json_schema(self) -> dict:
return {"type": "string", "format": "ipv4"}
```

In this example, `IPv4AddressType` is a `CustomType` that serializes `IPv4Address` objects to strings and deserializes strings to `IPv4Address` objects. The `get_json_schema` method returns the JSON schema for the custom type.

### custom_type_resolver

The `custom_type_resolver` is a function that takes a type as input and returns an instance of `CustomType` if the type is supported, or `None` otherwise. This function is passed to the `Serializer` constructor.

Here is an example of a `custom_type_resolver` that supports `IPv4Address`:

```python
def custom_type_resolver(t: type) -> CustomType | None
if t is IPv4Address:
return IPv4AddressType()
return None

ser = Serializer(MyDataclass, custom_type_resolver=custom_type_resolver)
```

In this example, the `custom_type_resolver` function checks if the type is `IPv4Address` and returns an instance of `IPv4AddressType` if it is. Otherwise, it returns `None`. This function is then passed to the `Serializer` constructor, which uses it to handle `IPv4Address` fields in the dataclass.

### Full Example

```python
from dataclasses import dataclass
from ipaddress import IPv4Address
from serpyco_rs import Serializer, CustomType

# Define custom type for IPv4Address
class IPv4AddressType(CustomType[IPv4Address, str]):
def serialize(self, value: IPv4Address) -> str:
return str(value)

def deserialize(self, value: str) -> IPv4Address:
return IPv4Address(value)

def get_json_schema(self):
return {
'type': 'string',
'format': 'ipv4',
}

# Defining custom_type_resolver
def custom_type_resolver(t: type) -> CustomType | None:
if t is IPv4Address:
return IPv4AddressType()
return None

@dataclass
class Data:
ip: IPv4Address

# Use custom_type_resolver in Serializer
serializer = Serializer(Data, custom_type_resolver=custom_type_resolver)

# Example usage
data = Data(ip=IPv4Address('1.1.1.1'))
serialized_data = serializer.dump(data) # {'ip': '1.1.1.1'}
deserialized_data = serializer.load(serialized_data) # Data(ip=IPv4Address('1.1.1.1'))
```
3 changes: 2 additions & 1 deletion python/serpyco_rs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ._custom_types import CustomType
from ._json_schema import JsonSchemaBuilder
from ._main import Serializer
from .exceptions import ErrorItem, SchemaValidationError, ValidationError


__all__ = ['Serializer', 'ErrorItem', 'SchemaValidationError', 'ValidationError', 'JsonSchemaBuilder']
__all__ = ['Serializer', 'ErrorItem', 'SchemaValidationError', 'ValidationError', 'JsonSchemaBuilder', 'CustomType']
17 changes: 17 additions & 0 deletions python/serpyco_rs/_custom_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import abc
from typing import Any, Generic, TypeVar


_I = TypeVar('_I')
_O = TypeVar('_O')


class CustomType(abc.ABC, Generic[_I, _O]):
@abc.abstractmethod
def serialize(self, value: _I) -> _O: ...

@abc.abstractmethod
def deserialize(self, value: _O) -> _I: ...

@abc.abstractmethod
def get_json_schema(self) -> dict[str, Any]: ...
48 changes: 39 additions & 9 deletions python/serpyco_rs/_describe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from typing_extensions import NotRequired, Required, assert_never, get_args, is_typeddict

from ._custom_types import CustomType as CustomTypeMeta
from ._impl import (
NOT_SET,
AnyType,
Expand All @@ -29,6 +30,7 @@
BooleanType,
BytesType,
CustomEncoder,
CustomType,
DateTimeType,
DateType,
DecimalType,
Expand Down Expand Up @@ -93,7 +95,11 @@
_T = TypeVar('_T')


def describe_type(t: Any, meta: Optional[Meta] = None) -> BaseType:
def describe_type(
t: Any,
meta: Optional[Meta] = None,
custom_type_resolver: Optional[Callable[[Any], Optional[CustomTypeMeta[Any, Any]]]] = None,
) -> BaseType:
args: tuple[Any, ...] = ()
metadata = _get_annotated_metadata(t)
type_repr = repr(t)
Expand Down Expand Up @@ -139,6 +145,14 @@ def describe_type(t: Any, meta: Optional[Meta] = None) -> BaseType:
custom_encoder=None,
)

if custom_type_resolver and (custom_type := custom_type_resolver(t)):
if custom_encoder is None:
custom_encoder = CustomEncoder(
serialize=custom_type.serialize,
deserialize=custom_type.deserialize,
)
return CustomType(custom_encoder=custom_encoder, json_schema=custom_type.get_json_schema())

if t is Any:
return AnyType(custom_encoder=custom_encoder)

Expand Down Expand Up @@ -189,14 +203,26 @@ def describe_type(t: Any, meta: Optional[Meta] = None) -> BaseType:

if t in {Sequence, list}:
return ArrayType(
item_type=(describe_type(annotation_wrapper(args[0]), meta) if args else AnyType(custom_encoder=None)),
item_type=(
describe_type(annotation_wrapper(args[0]), meta, custom_type_resolver)
if args
else AnyType(custom_encoder=None)
),
custom_encoder=custom_encoder,
)

if t in {Mapping, dict}:
return DictionaryType(
key_type=(describe_type(annotation_wrapper(args[0]), meta) if args else AnyType(custom_encoder=None)),
value_type=(describe_type(annotation_wrapper(args[1]), meta) if args else AnyType(custom_encoder=None)),
key_type=(
describe_type(annotation_wrapper(args[0]), meta, custom_type_resolver)
if args
else AnyType(custom_encoder=None)
),
value_type=(
describe_type(annotation_wrapper(args[1]), meta, custom_type_resolver)
if args
else AnyType(custom_encoder=None)
),
omit_none=none_format.omit,
custom_encoder=custom_encoder,
)
Expand All @@ -205,7 +231,7 @@ def describe_type(t: Any, meta: Optional[Meta] = None) -> BaseType:
if not args or Ellipsis in args:
raise RuntimeError('Variable length tuples are not supported')
return TupleType(
item_types=[describe_type(annotation_wrapper(arg), meta) for arg in args],
item_types=[describe_type(annotation_wrapper(arg), meta, custom_type_resolver) for arg in args],
custom_encoder=custom_encoder,
)

Expand All @@ -222,6 +248,7 @@ def describe_type(t: Any, meta: Optional[Meta] = None) -> BaseType:
custom_encoder=custom_encoder,
cls_none_as_default_for_optional=none_as_default_for_optional,
meta=meta,
custom_type_resolver=custom_type_resolver,
)
meta.add_to_state(meta_key, entity_type)
return entity_type
Expand All @@ -236,14 +263,14 @@ def describe_type(t: Any, meta: Optional[Meta] = None) -> BaseType:
new_args = tuple(arg for arg in args if arg is not _NoneType)
new_t = Union[new_args] if len(new_args) > 1 else new_args[0] # type: ignore[unused-ignore]
return OptionalType(
inner=describe_type(annotation_wrapper(new_t), meta),
inner=describe_type(annotation_wrapper(new_t), meta, custom_type_resolver),
custom_encoder=None,
)

discriminator = _find_metadata(metadata, Discriminator)
if not discriminator:
return UnionType(
item_types=[describe_type(annotation_wrapper(arg), meta) for arg in args],
item_types=[describe_type(annotation_wrapper(arg), meta, custom_type_resolver) for arg in args],
union_repr=type_repr.removeprefix('typing.'),
custom_encoder=custom_encoder,
)
Expand All @@ -256,7 +283,9 @@ def describe_type(t: Any, meta: Optional[Meta] = None) -> BaseType:
meta = dataclasses.replace(meta, discriminator_field=discriminator.name)
return DiscriminatedUnionType(
item_types={
_get_discriminator_value(arg, discriminator.name): describe_type(annotation_wrapper(arg), meta)
_get_discriminator_value(arg, discriminator.name): describe_type(
annotation_wrapper(arg), meta, custom_type_resolver
)
for arg in args
},
dump_discriminator=discriminator.name,
Expand Down Expand Up @@ -286,6 +315,7 @@ def _describe_entity(
cls_none_as_default_for_optional: NoneAsDefaultForOptional,
custom_encoder: Optional[CustomEncoder[Any, Any]],
meta: Meta,
custom_type_resolver: Optional[Callable[[Any], Optional[CustomTypeMeta[Any, Any]]]],
) -> Union[EntityType, TypedDictType]:
# PEP-484: Replace all unfilled type parameters with Any
if not hasattr(original_t, '__origin__') and getattr(original_t, '__parameters__', None):
Expand All @@ -306,7 +336,7 @@ def _describe_entity(
type_ = Annotated[type_, cls_filed_format, cls_none_format, cls_none_as_default_for_optional]

metadata = _get_annotated_metadata(type_)
field_type = describe_type(type_, meta)
field_type = describe_type(type_, meta, custom_type_resolver)
alias = _find_metadata(metadata, Alias)
none_as_default_for_optional = _find_metadata(metadata, NoneAsDefaultForOptional)

Expand Down
1 change: 1 addition & 0 deletions python/serpyco_rs/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
UnionType,
UUIDType,
ValidationError,
CustomType,
)


Expand Down
5 changes: 5 additions & 0 deletions python/serpyco_rs/_impl.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,8 @@ class RecursionHolder(BaseType):
self, name: str, state_key: MetaStateKey, meta: Meta, custom_encoder: CustomEncoder[Any, Any] | None = None
): ...
def get_type(self) -> BaseType: ...

class CustomType(BaseType):
json_schema: dict[str, Any]

def __init__(self, custom_encoder: CustomEncoder[Any, Any], json_schema: dict[str, Any]): ...
5 changes: 5 additions & 0 deletions python/serpyco_rs/_json_schema/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,8 @@ def _check_unions_schema_types(schema: Schema) -> TypeGuard[Union[ObjectType, Re
if isinstance(schema, (ObjectType, RefType)):
return True
raise RuntimeError(f'Unions schema items must be ObjectType or RefType. Current: {schema}')


@to_json_schema.register
def _(arg: describe.CustomType, doc: Optional[str] = None, *, config: Config) -> Schema:
return Schema(additionalArgs=arg.json_schema, description=doc, config=config)
10 changes: 8 additions & 2 deletions python/serpyco_rs/_main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import abc
from typing import Annotated, Any, Generic, Protocol, TypeVar, Union, cast, overload
from collections.abc import Callable
from typing import Annotated, Any, Generic, Optional, Protocol, TypeVar, Union, cast, overload

from ._custom_types import CustomType
from ._describe import BaseType, describe_type
from ._impl import Serializer as _Serializer
from ._json_schema import get_json_schema
Expand Down Expand Up @@ -38,6 +40,7 @@ def __init__(
omit_none: bool = False,
force_default_for_optional: bool = False,
naive_datetime_to_utc: bool = False,
custom_type_resolver: Optional[Callable[[Any], Optional[CustomType[Any, Any]]]] = None,
) -> None:
"""
Create a serializer for the given type.
Expand All @@ -47,14 +50,17 @@ def __init__(
:param omit_none: If True, the serializer will omit None values from the output.
:param force_default_for_optional: If True, the serializer will force default values for optional fields.
:param naive_datetime_to_utc: If True, the serializer will convert naive datetimes to UTC.
:param custom_type_resolver: An optional callable that allows users to add support for their own types.
This parameter should be a function that takes a type as input and returns an instance of CustomType
if the user-defined type is supported, or None otherwise.
"""
if camelcase_fields:
t = cast(type[_T], Annotated[t, CamelCase])
if omit_none:
t = cast(type(_T), Annotated[t, OmitNone]) # type: ignore
if force_default_for_optional:
t = cast(type(_T), Annotated[t, ForceDefaultForOptional]) # type: ignore
self._type_info = describe_type(t)
self._type_info = describe_type(t, custom_type_resolver=custom_type_resolver)
self._schema = get_json_schema(self._type_info)
self._encoder: _Serializer[_T] = _Serializer(self._type_info, naive_datetime_to_utc)

Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ fn _serpyco_rs(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<types::DiscriminatedUnionType>()?;
m.add_class::<types::LiteralType>()?;
m.add_class::<types::RecursionHolder>()?;
m.add_class::<types::CustomType>()?;

// Errors
m.add(
Expand Down
10 changes: 6 additions & 4 deletions src/python/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use pyo3::Bound;
use pyo3::{PyAny, PyResult};

use crate::validator::types::{
AnyType, ArrayType, BaseType, BooleanType, BytesType, DateTimeType, DateType, DecimalType,
DictionaryType, DiscriminatedUnionType, EntityType, EnumType, FloatType, IntegerType,
LiteralType, OptionalType, RecursionHolder, StringType, TimeType, TupleType, TypedDictType,
UUIDType, UnionType,
AnyType, ArrayType, BaseType, BooleanType, BytesType, CustomType, DateTimeType, DateType,
DecimalType, DictionaryType, DiscriminatedUnionType, EntityType, EnumType, FloatType,
IntegerType, LiteralType, OptionalType, RecursionHolder, StringType, TimeType, TupleType,
TypedDictType, UUIDType, UnionType,
};

#[derive(Clone, Debug)]
Expand All @@ -33,6 +33,7 @@ pub enum Type<'a, Base = Bound<'a, BaseType>> {
Literal(Bound<'a, LiteralType>, Base),
Any(Bound<'a, AnyType>, Base),
RecursionHolder(Bound<'a, RecursionHolder>, Base),
Custom(Bound<'a, CustomType>, Base),
}

pub fn get_object_type<'a>(type_info: &Bound<'a, PyAny>) -> PyResult<Type<'a>> {
Expand Down Expand Up @@ -62,6 +63,7 @@ pub fn get_object_type<'a>(type_info: &Bound<'a, PyAny>) -> PyResult<Type<'a>> {
check_type!(type_info, base_type, Literal, LiteralType);
check_type!(type_info, base_type, Bytes, BytesType);
check_type!(type_info, base_type, RecursionHolder, RecursionHolder);
check_type!(type_info, base_type, Custom, CustomType);

if let Ok(t) = type_info.extract::<Bound<'_, EntityType>>() {
let python_object_id = type_info.as_ptr() as *const _ as usize;
Expand Down
Loading

0 comments on commit 30c2f8f

Please sign in to comment.