Skip to content

Commit

Permalink
Merge EnumSymbol and EnumValue into Enum field
Browse files Browse the repository at this point in the history
  • Loading branch information
lafrech committed Sep 8, 2022
1 parent 8ffc0ea commit fa0ee03
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 78 deletions.
91 changes: 40 additions & 51 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import math
import typing
import warnings
from enum import Enum
from enum import Enum as EnumType
from collections.abc import Mapping as _Mapping

from marshmallow import validate, utils, class_registry, types
Expand Down Expand Up @@ -60,8 +60,7 @@
"IPInterface",
"IPv4Interface",
"IPv6Interface",
"EnumSymbol",
"EnumValue",
"Enum",
"Method",
"Function",
"Str",
Expand Down Expand Up @@ -1856,43 +1855,14 @@ class IPv6Interface(IPInterface):
DESERIALIZATION_CLASS = ipaddress.IPv6Interface


class EnumSymbol(String):
"""An Enum field (de)serializing enum members by symbol (name) as string.
class Enum(Field):
"""An Enum field (de)serializing enum members by symbol (name) as string or by value.
:param enum Enum: Enum class
.. versionadded:: 3.18.0
"""

default_error_messages = {
"unknown": "Must be one of: {choices}.",
}

def __init__(self, enum: type[Enum], **kwargs):
self.enum = enum
self.choices = ", ".join(enum.__members__)
super().__init__(**kwargs)

def _serialize(self, value, attr, obj, **kwargs):
if value is None:
return None
return value.name

def _deserialize(self, value, attr, data, **kwargs):
value = super()._deserialize(value, attr, data, **kwargs)
try:
return getattr(self.enum, value)
except AttributeError as exc:
raise self.make_error("unknown", choices=self.choices) from exc


class EnumValue(Field):
"""An Enum field (de)serializing enum members by value.
A Field must be provided to (de)serialize the value.
:param cls_or_instance: Field class or instance.
:param enum Enum: Enum class
If a field is provided as ``cls_or_instance`` argument, the Enum is (de)serialized by
value using this field. Otherwise, it is (de)serialized by symbol (name) as string.
.. versionadded:: 3.18.0
"""
Expand All @@ -1901,30 +1871,49 @@ class EnumValue(Field):
"unknown": "Must be one of: {choices}.",
}

def __init__(self, cls_or_instance: Field | type, enum: type[Enum], **kwargs):
def __init__(
self,
enum: type[EnumType],
cls_or_instance: Field | type | None = None,
**kwargs,
):
super().__init__(**kwargs)
try:
self.field = resolve_field_instance(cls_or_instance)
except FieldInstanceResolutionError as error:
raise ValueError(
"The enum field must be a subclass or instance of "
"marshmallow.base.FieldABC."
) from error
self.enum = enum
self.choices = ", ".join(
[str(self.field._serialize(m.value, None, None)) for m in enum]
)
if cls_or_instance is not None:
try:
self.field = resolve_field_instance(cls_or_instance)
except FieldInstanceResolutionError as error:
raise ValueError(
"The enum field must be a subclass or instance of "
"marshmallow.base.FieldABC."
) from error
self.by_symbol_or_value = "value"
self.choices = ", ".join(
[str(self.field._serialize(m.value, None, None)) for m in enum]
)
else:
self.field = String()
self.by_symbol_or_value = "symbol"
self.choices = ", ".join(enum.__members__)

def _serialize(self, value, attr, obj, **kwargs):
if value is None:
return None
return self.field._serialize(value.value, attr, obj, **kwargs)
if self.by_symbol_or_value == "value":
return self.field._serialize(value.value, attr, obj, **kwargs)
return value.name

def _deserialize(self, value, attr, data, **kwargs):
if self.by_symbol_or_value == "value":
value = self.field._deserialize(value, attr, data, **kwargs)
try:
return self.enum(value)
except ValueError as exc:
raise self.make_error("unknown", choices=self.choices) from exc
value = self.field._deserialize(value, attr, data, **kwargs)
try:
return self.enum(value)
except ValueError as exc:
return getattr(self.enum, value)
except AttributeError as exc:
raise self.make_error("unknown", choices=self.choices) from exc


Expand Down
6 changes: 3 additions & 3 deletions tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ class DateEnum(Enum):
fields.IPInterface,
fields.IPv4Interface,
fields.IPv6Interface,
functools.partial(fields.EnumSymbol, GenderEnum),
functools.partial(fields.EnumValue, fields.String, HairColorEnum),
functools.partial(fields.EnumValue, fields.Integer, GenderEnum),
functools.partial(fields.Enum, GenderEnum),
functools.partial(fields.Enum, HairColorEnum, fields.String),
functools.partial(fields.Enum, GenderEnum, fields.Integer),
]


Expand Down
36 changes: 18 additions & 18 deletions tests/test_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,54 +1097,54 @@ def test_invalid_ipv6interface_deserialization(self, in_value):

assert excinfo.value.args[0] == "Not a valid IPv6 interface."

def test_enumsymbol_field_deserialization(self):
field = fields.EnumSymbol(GenderEnum)
def test_enum_by_symbol_field_deserialization(self):
field = fields.Enum(GenderEnum)
assert field.deserialize("male") == GenderEnum.male

def test_enumsymbol_field_invalid_value(self):
field = fields.EnumSymbol(GenderEnum)
def test_enum_by_symbol_field_invalid_value(self):
field = fields.Enum(GenderEnum)
with pytest.raises(
ValidationError, match="Must be one of: male, female, non_binary."
):
field.deserialize("dummy")

def test_enumsymbol_field_not_string(self):
field = fields.EnumSymbol(GenderEnum)
def test_enum_by_symbol_field_not_string(self):
field = fields.Enum(GenderEnum)
with pytest.raises(ValidationError, match="Not a valid string."):
field.deserialize(12)

def test_enumvalue_field_deserialization(self):
field = fields.EnumValue(fields.String, HairColorEnum)
def test_enum_by_value_field_deserialization(self):
field = fields.Enum(HairColorEnum, fields.String)
assert field.deserialize("black hair") == HairColorEnum.black
field = fields.EnumValue(fields.Integer, GenderEnum)
field = fields.Enum(GenderEnum, fields.Integer)
assert field.deserialize(1) == GenderEnum.male
field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum)
field = fields.Enum(DateEnum, fields.Date(format="%d/%m/%Y"))
assert field.deserialize("29/02/2004") == DateEnum.date_1

def test_enumvalue_field_invalid_value(self):
field = fields.EnumValue(fields.String, HairColorEnum)
def test_enum_by_value_field_invalid_value(self):
field = fields.Enum(HairColorEnum, fields.String)
with pytest.raises(
ValidationError,
match="Must be one of: black hair, brown hair, blond hair, red hair.",
):
field.deserialize("dummy")
field = fields.EnumValue(fields.Integer, GenderEnum)
field = fields.Enum(GenderEnum, fields.Integer)
with pytest.raises(ValidationError, match="Must be one of: 1, 2, 3."):
field.deserialize(12)
field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum)
field = fields.Enum(DateEnum, fields.Date(format="%d/%m/%Y"))
with pytest.raises(
ValidationError, match="Must be one of: 29/02/2004, 29/02/2008, 29/02/2012."
):
field.deserialize("28/02/2004")

def test_enumvalue_field_wrong_type(self):
field = fields.EnumValue(fields.String, HairColorEnum)
def test_enum_by_value_field_wrong_type(self):
field = fields.Enum(HairColorEnum, fields.String)
with pytest.raises(ValidationError, match="Not a valid string."):
field.deserialize(12)
field = fields.EnumValue(fields.Integer, GenderEnum)
field = fields.Enum(GenderEnum, fields.Integer)
with pytest.raises(ValidationError, match="Not a valid integer."):
field.deserialize("dummy")
field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum)
field = fields.Enum(DateEnum, fields.Date(format="%d/%m/%Y"))
with pytest.raises(ValidationError, match="Not a valid date."):
field.deserialize("30/02/2004")

Expand Down
12 changes: 6 additions & 6 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,20 +255,20 @@ def test_ipv6_interface_field(self, user):
== ipv6interface_exploded_string
)

def test_enumsymbol_field_serialization(self, user):
def test_enum_by_symbol_field_serialization(self, user):
user.sex = GenderEnum.male
field = fields.EnumSymbol(GenderEnum)
field = fields.Enum(GenderEnum)
assert field.serialize("sex", user) == "male"

def test_enumvalue_field_serialization(self, user):
def test_enum_by_value_field_serialization(self, user):
user.hair_color = HairColorEnum.black
field = fields.EnumValue(fields.String, HairColorEnum)
field = fields.Enum(HairColorEnum, fields.String)
assert field.serialize("hair_color", user) == "black hair"
user.sex = GenderEnum.male
field = fields.EnumValue(fields.Integer, GenderEnum)
field = fields.Enum(GenderEnum, fields.Integer)
assert field.serialize("sex", user) == 1
user.some_date = DateEnum.date_1
field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum)
field = fields.Enum(DateEnum, fields.Date(format="%d/%m/%Y"))
assert field.serialize("some_date", user) == "29/02/2004"

def test_decimal_field(self, user):
Expand Down

0 comments on commit fa0ee03

Please sign in to comment.