Skip to content

Commit

Permalink
fields.Enum: merge by_value and field arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
lafrech committed Sep 15, 2022
1 parent d013455 commit db37b73
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 58 deletions.
57 changes: 29 additions & 28 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -1856,13 +1856,15 @@ class IPv6Interface(IPInterface):


class Enum(Field):
"""An Enum field (de)serializing enum members by symbol (name) as string or by value.
"""An Enum field (de)serializing enum members by symbol (name) or by value.
:param enum Enum: Enum class
:param boolean by_value: Whether to (de)serialize by value or by name. Defaults to False.
:param field: Field class or instance to use if (de)serializing by value. Defaults to Field.
:param boolean|Schema|Field by_value: Whether to (de)serialize by value or by name,
or Field class or instance to use to (de)serialize by value. Defaults to False.
``field`` argument may only be passed if (de)serializing by value.
If `by_value` is `False` (default), enum members are (de)serialized by symbol (name).
If it is `True`, they are (de)serialized by value using :class:`Field`.
If it is a field instance or class, they are (de)serialized by value using this field.
.. versionadded:: 3.18.0
"""
Expand All @@ -1874,57 +1876,56 @@ class Enum(Field):
def __init__(
self,
enum: type[EnumType],
by_value: bool = False,
field: Field | type | None = None,
*,
by_value: bool | Field | type = False,
**kwargs,
):
super().__init__(**kwargs)
self.enum = enum
self.by_value = by_value

# Serialization by name
if self.by_value is False:
if field is not None:
raise ValueError('"field" can not be passed when serializing by name.')
if by_value is False:
self.field: Field = String()
self.choices = ", ".join(
[str(self.field._serialize(m, None, None)) for m in enum.__members__]
self.choices_text = ", ".join(
str(self.field._serialize(m, None, None)) for m in enum.__members__
)
# Serialization by value
else:
if field is not None:
if by_value is True:
self.field = Field()
else:
try:
self.field = resolve_field_instance(field)
self.field = resolve_field_instance(by_value)
except FieldInstanceResolutionError as error:
raise ValueError(
'"field" must be a subclass or instance of '
'"by_value" must be either a bool or a subclass or instance of '
"marshmallow.base.FieldABC."
) from error
else:
self.field = Field()
self.choices = ", ".join(
[str(self.field._serialize(m.value, None, None)) for m in enum]
self.choices_text = ", ".join(
str(self.field._serialize(m.value, None, None)) for m in enum
)

def _serialize(self, value, attr, obj, **kwargs):
if value is None:
return None
if self.by_value:
return self.field._serialize(value.value, attr, obj, **kwargs)
return value.name
val = value.value
else:
val = value.name
return self.field._serialize(val, attr, obj, **kwargs)

def _deserialize(self, value, attr, data, **kwargs):
val = self.field._deserialize(value, attr, data, **kwargs)
if self.by_value:
value = self.field._deserialize(value, attr, data, **kwargs)
try:
return self.enum(value)
except ValueError as exc:
raise self.make_error("unknown", choices=self.choices) from exc
value = self.field._deserialize(value, attr, data, **kwargs)
return self.enum(val)
except ValueError as error:
raise self.make_error("unknown", choices=self.choices_text) from error
try:
return getattr(self.enum, value)
except AttributeError as exc:
raise self.make_error("unknown", choices=self.choices) from exc
return getattr(self.enum, val)
except AttributeError as error:
raise self.make_error("unknown", choices=self.choices_text) from error


class Method(Field):
Expand Down
4 changes: 2 additions & 2 deletions tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ class DateEnum(Enum):
fields.IPv4Interface,
fields.IPv6Interface,
functools.partial(fields.Enum, GenderEnum),
functools.partial(fields.Enum, HairColorEnum, fields.String),
functools.partial(fields.Enum, GenderEnum, fields.Integer),
functools.partial(fields.Enum, HairColorEnum, by_value=fields.String),
functools.partial(fields.Enum, GenderEnum, by_value=fields.Integer),
]


Expand Down
64 changes: 43 additions & 21 deletions tests/test_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,60 +1097,82 @@ def test_invalid_ipv6interface_deserialization(self, in_value):

assert excinfo.value.args[0] == "Not a valid IPv6 interface."

def test_enum_by_symbol_field_deserialization(self):
def test_enum_field_by_symbol_deserialization(self):
field = fields.Enum(GenderEnum)
assert field.deserialize("male") == GenderEnum.male

def test_enum_by_symbol_field_invalid_value(self):
def test_enum_field_by_symbol_invalid_value(self):
field = fields.Enum(GenderEnum)
with pytest.raises(
ValidationError, match="Must be one of: male, female, non_binary."
):
field.deserialize("dummy")

def test_enum_by_symbol_field_not_string(self):
def test_enum_field_by_symbol_not_string(self):
field = fields.Enum(GenderEnum)
with pytest.raises(ValidationError, match="Not a valid string."):
field.deserialize(12)

def test_enum_by_value_field_deserialization(self):
field = fields.Enum(HairColorEnum, by_value=True, field=fields.String)
def test_enum_field_by_value_true_deserialization(self):
field = fields.Enum(HairColorEnum, by_value=True)
assert field.deserialize("black hair") == HairColorEnum.black
field = fields.Enum(GenderEnum, by_value=True, field=fields.Integer)
field = fields.Enum(GenderEnum, by_value=True)
assert field.deserialize(1) == GenderEnum.male
field = fields.Enum(
DateEnum, by_value=True, field=fields.Date(format="%d/%m/%Y")
)

def test_enum_field_by_value_field_deserialization(self):
field = fields.Enum(HairColorEnum, by_value=fields.String)
assert field.deserialize("black hair") == HairColorEnum.black
field = fields.Enum(GenderEnum, by_value=fields.Integer)
assert field.deserialize(1) == GenderEnum.male
field = fields.Enum(DateEnum, by_value=fields.Date(format="%d/%m/%Y"))
assert field.deserialize("29/02/2004") == DateEnum.date_1

def test_enum_by_value_field_invalid_value(self):
field = fields.Enum(HairColorEnum, by_value=True, field=fields.String)
def test_enum_field_by_value_true_invalid_value(self):
field = fields.Enum(HairColorEnum, by_value=True)
with pytest.raises(
ValidationError,
match="Must be one of: black hair, brown hair, blond hair, red hair.",
):
field.deserialize("dummy")
field = fields.Enum(GenderEnum, by_value=True, field=fields.Integer)
field = fields.Enum(GenderEnum, by_value=True)
with pytest.raises(ValidationError, match="Must be one of: 1, 2, 3."):
field.deserialize(12)
field = fields.Enum(
DateEnum, by_value=True, field=fields.Date(format="%d/%m/%Y")
)

def test_enum_field_by_value_field_invalid_value(self):
field = fields.Enum(HairColorEnum, by_value=fields.String)
with pytest.raises(
ValidationError,
match="Must be one of: black hair, brown hair, blond hair, red hair.",
):
field.deserialize("dummy")
field = fields.Enum(GenderEnum, by_value=fields.Integer)
with pytest.raises(ValidationError, match="Must be one of: 1, 2, 3."):
field.deserialize(12)
field = fields.Enum(DateEnum, by_value=fields.Date(format="%d/%m/%Y"))
with pytest.raises(
ValidationError, match="Must be one of: 29/02/2004, 29/02/2008, 29/02/2012."
):
field.deserialize("28/02/2004")

def test_enum_by_value_field_wrong_type(self):
field = fields.Enum(HairColorEnum, by_value=True, field=fields.String)
def test_enum_field_by_value_true_wrong_type(self):
field = fields.Enum(HairColorEnum, by_value=True)
with pytest.raises(
ValidationError,
match="Must be one of: black hair, brown hair, blond hair, red hair.",
):
field.deserialize("dummy")
field = fields.Enum(GenderEnum, by_value=True)
with pytest.raises(ValidationError, match="Must be one of: 1, 2, 3."):
field.deserialize(12)

def test_enum_field_by_value_field_wrong_type(self):
field = fields.Enum(HairColorEnum, by_value=fields.String)
with pytest.raises(ValidationError, match="Not a valid string."):
field.deserialize(12)
field = fields.Enum(GenderEnum, by_value=True, field=fields.Integer)
field = fields.Enum(GenderEnum, by_value=fields.Integer)
with pytest.raises(ValidationError, match="Not a valid integer."):
field.deserialize("dummy")
field = fields.Enum(
DateEnum, by_value=True, field=fields.Date(format="%d/%m/%Y")
)
field = fields.Enum(DateEnum, by_value=fields.Date(format="%d/%m/%Y"))
with pytest.raises(ValidationError, match="Not a valid date."):
field.deserialize("30/02/2004")

Expand Down
21 changes: 14 additions & 7 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,22 +255,29 @@ def test_ipv6_interface_field(self, user):
== ipv6interface_exploded_string
)

def test_enum_by_symbol_field_serialization(self, user):
def test_enum_field_by_symbol_serialization(self, user):
user.sex = GenderEnum.male
field = fields.Enum(GenderEnum)
assert field.serialize("sex", user) == "male"

def test_enum_by_value_field_serialization(self, user):
def test_enum_field_by_value_true_serialization(self, user):
user.hair_color = HairColorEnum.black
field = fields.Enum(HairColorEnum, by_value=True, field=fields.String)
field = fields.Enum(HairColorEnum, by_value=True)
assert field.serialize("hair_color", user) == "black hair"
user.sex = GenderEnum.male
field = fields.Enum(GenderEnum, by_value=True, field=fields.Integer)
field = fields.Enum(GenderEnum, by_value=True)
assert field.serialize("sex", user) == 1
user.some_date = DateEnum.date_1
field = fields.Enum(
DateEnum, by_value=True, field=fields.Date(format="%d/%m/%Y")
)

def test_enum_field_by_value_field_serialization(self, user):
user.hair_color = HairColorEnum.black
field = fields.Enum(HairColorEnum, by_value=fields.String)
assert field.serialize("hair_color", user) == "black hair"
user.sex = GenderEnum.male
field = fields.Enum(GenderEnum, by_value=fields.Integer)
assert field.serialize("sex", user) == 1
user.some_date = DateEnum.date_1
field = fields.Enum(DateEnum, by_value=fields.Date(format="%d/%m/%Y"))
assert field.serialize("some_date", user) == "29/02/2004"

def test_decimal_field(self, user):
Expand Down

0 comments on commit db37b73

Please sign in to comment.