Skip to content

Commit

Permalink
Add set and OrderedSet serializers (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
renan-r-santos committed Feb 25, 2024
1 parent 06d66a9 commit 99190aa
Show file tree
Hide file tree
Showing 10 changed files with 506 additions and 236 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ jobs:
test-${{ matrix.python-version }}
test_fastapi-${{ matrix.python-version }}
test_numpy-${{ matrix.python-version }}
test_ordered_set-${{ matrix.python-version }}
- name: Store coverage
uses: actions/upload-artifact@v3
with:
Expand Down
9 changes: 9 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ def test_numpy(s: Session):
s.run("python", "-m", "pytest", "--cov", "serialite", "tests/test_numpy.py")


@session(python=["3.10", "3.11", "3.12"])
def test_ordered_set(s: Session):
s.install(".[ordered-set]", "pytest", "pytest-cov")
s.env["COVERAGE_FILE"] = f".coverage.ordered_set.{s.python}"
s.run(
"python", "-m", "pytest", "--cov", "serialite", "tests/implementations/test_ordered_set.py"
)


@session(venv_backend="none")
def coverage(s: Session):
s.run("coverage", "combine")
Expand Down
493 changes: 257 additions & 236 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ python = "^3.10"
typing_extensions = "^4.3"
fastapi = { version = "^0.100", optional = true }
pydantic = { version = "^1.10", optional = true }
ordered-set = { version = "^4.1", optional = true }
# Lie about numpy only being needed below 3.13 in order to satisfy its Python ceiling
numpy = { version = "^1.25", optional = true, python = "<3.13" }

Expand All @@ -38,6 +39,7 @@ ruff = ">=0.0.275"
[tool.poetry.extras]
fastapi = ["fastapi", "pydantic"]
numpy = ["numpy"]
ordered-set = ["ordered-set"]


[tool.coverage.run]
Expand Down
20 changes: 20 additions & 0 deletions src/serialite/_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,13 @@ def list_to_data(cls):
return ListSerializer(serializer(cls.__args__[0]))


@serializer.register(set)
def set_to_data(cls):
from ._implementations import SetSerializer

return SetSerializer(serializer(cls.__args__[0]))


@serializer.register(tuple)
def tuple_to_data(cls):
from ._implementations import TupleSerializer
Expand Down Expand Up @@ -261,3 +268,16 @@ def array_to_data(cls):
from ._implementations import ArraySerializer

return ArraySerializer(dtype=float)


try:
from ordered_set import OrderedSet
except ImportError:
pass
else:

@serializer.register(OrderedSet)
def ordered_set_to_data(cls):
from ._implementations import OrderedSetSerializer

return OrderedSetSerializer(serializer(cls.__args__[0]))
6 changes: 6 additions & 0 deletions src/serialite/_implementations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ._none import NoneSerializer
from ._path import PathSerializer
from ._reserved import ReservedSerializer
from ._set import SetSerializer
from ._string import StringSerializer
from ._tuple import TupleSerializer
from ._union import OptionalSerializer, TryUnionSerializer
Expand All @@ -22,3 +23,8 @@
from ._array import ArraySerializer
except ImportError:
pass

try:
from ._ordered_set import OrderedSetSerializer
except ImportError:
pass
56 changes: 56 additions & 0 deletions src/serialite/_implementations/_ordered_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
__all__ = ["OrderedSetSerializer"]


from typing import Generic, TypeVar

from ordered_set import OrderedSet

from .._base import Serializer
from .._result import DeserializationFailure, DeserializationResult, DeserializationSuccess
from .._stable_set import StableSet

Element = TypeVar("Element")


class OrderedSetSerializer(Generic[Element], Serializer[OrderedSet[Element]]):
def __init__(self, element_serializer: Serializer[Element]):
self.element_serializer = element_serializer

def from_data(self, data) -> DeserializationResult[OrderedSet[Element]]:
# Return early if the data isn't even a list
if not isinstance(data, list):
return DeserializationFailure(f"Not a valid list: {data!r}")

# Validate values
errors = {}
values = OrderedSet()
for i, value in enumerate(data):
value_or_error = self.element_serializer.from_data(value)
if isinstance(value_or_error, DeserializationFailure):
errors[str(i)] = value_or_error.error
elif value_or_error.value in values:
errors[str(i)] = (
f"Duplicated value found: {value_or_error.value!r}. "
"Expected a list of unique values."
)
else:
values.add(value_or_error.value)

if len(errors) > 0:
return DeserializationFailure(errors)
else:
return DeserializationSuccess(values)

def to_data(self, value: OrderedSet[Element]):
if not isinstance(value, OrderedSet):
raise ValueError(f"Not an OrderedSet: {value!r}")

return [self.element_serializer.to_data(item) for item in value]

def collect_openapi_models(
self, parent_models: StableSet[Serializer]
) -> StableSet[Serializer]:
return self.element_serializer.collect_openapi_models(parent_models)

def to_openapi_schema(self, refs: dict[Serializer, str], force: bool = False):
return {"type": "array", "items": self.element_serializer.to_openapi_schema(refs)}
54 changes: 54 additions & 0 deletions src/serialite/_implementations/_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
__all__ = ["SetSerializer"]


from typing import Generic, TypeVar

from .._base import Serializer
from .._result import DeserializationFailure, DeserializationResult, DeserializationSuccess
from .._stable_set import StableSet

Element = TypeVar("Element")


class SetSerializer(Generic[Element], Serializer[set[Element]]):
def __init__(self, element_serializer: Serializer[Element]):
self.element_serializer = element_serializer

def from_data(self, data) -> DeserializationResult[set[Element]]:
# Return early if the data isn't even a list
if not isinstance(data, list):
return DeserializationFailure(f"Not a valid list: {data!r}")

# Validate values
errors = {}
values = set()
for i, value in enumerate(data):
value_or_error = self.element_serializer.from_data(value)
if isinstance(value_or_error, DeserializationFailure):
errors[str(i)] = value_or_error.error
elif value_or_error.value in values:
errors[str(i)] = (
f"Duplicated value found: {value_or_error.value!r}. "
"Expected a list of unique values."
)
else:
values.add(value_or_error.value)

if len(errors) > 0:
return DeserializationFailure(errors)
else:
return DeserializationSuccess(values)

def to_data(self, value: set[Element]):
if not isinstance(value, set):
raise ValueError(f"Not a set: {value!r}")

return [self.element_serializer.to_data(item) for item in value]

def collect_openapi_models(
self, parent_models: StableSet[Serializer]
) -> StableSet[Serializer]:
return self.element_serializer.collect_openapi_models(parent_models)

def to_openapi_schema(self, refs: dict[Serializer, str], force: bool = False):
return {"type": "array", "items": self.element_serializer.to_openapi_schema(refs)}
54 changes: 54 additions & 0 deletions tests/implementations/test_ordered_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest

try:
from ordered_set import OrderedSet
except ImportError:
pytest.skip("ordered-set not available", allow_module_level=True)

from serialite import (
DeserializationFailure,
DeserializationSuccess,
FloatSerializer,
OrderedSetSerializer,
)

ordered_set_serializer = OrderedSetSerializer(FloatSerializer())


def test_valid_inputs():
data = [12.3, 15.5, 16.0]
value = OrderedSet([12.3, 15.5, 16.0])

assert ordered_set_serializer.from_data(data) == DeserializationSuccess(value)
assert ordered_set_serializer.to_data(value) == data


def test_from_data_failure_top_level():
data = 12.5
assert ordered_set_serializer.from_data(data) == DeserializationFailure(
"Not a valid list: 12.5"
)


def test_from_data_failure_element():
data = ["str1", 15.5, "str2"]
actual = ordered_set_serializer.from_data(data)
expected_msg = {"0": "Not a valid float: 'str1'", "2": "Not a valid float: 'str2'"}
assert actual == DeserializationFailure(expected_msg)


def test_from_data_failure_uniqueness():
data = [12.3, 15.5, 16.0, 12.3]
actual = ordered_set_serializer.from_data(data)
expected_msg = {"3": "Duplicated value found: 12.3. Expected a list of unique values."}
assert actual == DeserializationFailure(expected_msg)


def test_to_data_failure_top_level():
with pytest.raises(ValueError):
_ = ordered_set_serializer.to_data(12.5)


def test_to_data_failure_element():
with pytest.raises(ValueError):
_ = ordered_set_serializer.to_data(OrderedSet([12.5, "a"]))
47 changes: 47 additions & 0 deletions tests/implementations/test_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest

from serialite import (
DeserializationFailure,
DeserializationSuccess,
FloatSerializer,
SetSerializer,
)

set_serializer = SetSerializer(FloatSerializer())


def test_valid_inputs():
data = [12.3, 15.5, 16.0]
value = {12.3, 15.5, 16.0}

assert set_serializer.from_data(data) == DeserializationSuccess(value)
assert sorted(set_serializer.to_data(value)) == sorted(data)


def test_from_data_failure_top_level():
data = 12.5
assert set_serializer.from_data(data) == DeserializationFailure("Not a valid list: 12.5")


def test_from_data_failure_element():
data = ["str1", 15.5, "str2"]
actual = set_serializer.from_data(data)
expected_msg = {"0": "Not a valid float: 'str1'", "2": "Not a valid float: 'str2'"}
assert actual == DeserializationFailure(expected_msg)


def test_from_data_failure_uniqueness():
data = [12.3, 15.5, 16.0, 12.3]
actual = set_serializer.from_data(data)
expected_msg = {"3": "Duplicated value found: 12.3. Expected a list of unique values."}
assert actual == DeserializationFailure(expected_msg)


def test_to_data_failure_top_level():
with pytest.raises(ValueError):
_ = set_serializer.to_data(12.5)


def test_to_data_failure_element():
with pytest.raises(ValueError):
_ = set_serializer.to_data({12.5, "a"})

0 comments on commit 99190aa

Please sign in to comment.