Skip to content

Commit

Permalink
Merge pull request #1896 from marshmallow-code/ordered_set_default
Browse files Browse the repository at this point in the history
Use OrderedSet as default set_class
  • Loading branch information
lafrech committed Jul 20, 2023
2 parents 91147b2 + 6abbfca commit 03889c6
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 70 deletions.
31 changes: 0 additions & 31 deletions docs/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -524,37 +524,6 @@ Note that ``name`` will be automatically formatted as a :class:`String <marshmal
# No need to include 'uppername'
additional = ("name", "email", "created_at")
Ordering Output
---------------

To maintain field ordering, set the ``ordered`` option to `True`. This will instruct marshmallow to serialize data to a `collections.OrderedDict`.

.. code-block:: python
from collections import OrderedDict
from pprint import pprint
from marshmallow import Schema, fields
class UserSchema(Schema):
first_name = fields.String()
last_name = fields.String()
email = fields.Email()
class Meta:
ordered = True
u = User("Charlie", "Stones", "charlie@stones.com")
schema = UserSchema()
result = schema.dump(u)
assert isinstance(result, OrderedDict)
pprint(result, indent=2)
#  OrderedDict([('first_name', 'Charlie'),
# ('last_name', 'Stones'),
# ('email', 'charlie@stones.com')])
Next Steps
----------

Expand Down
4 changes: 0 additions & 4 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ class Field(FieldABC):
# to exist as attributes on the objects to serialize. Set this to False
# for those fields
_CHECK_ATTRIBUTE = True
_creation_index = 0 # Used for sorting

#: Default error messages for various kinds of errors. The keys in this dictionary
#: are passed to `Field.make_error`. The values are error messages passed to
Expand Down Expand Up @@ -227,9 +226,6 @@ def __init__(
stacklevel=2,
)

self._creation_index = Field._creation_index
Field._creation_index += 1

# Collect default error message from self and parent classes
messages = {} # type: dict[str, str]
for cls in reversed(self.__class__.__mro__):
Expand Down
40 changes: 15 additions & 25 deletions src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,21 @@
_T = typing.TypeVar("_T")


def _get_fields(attrs, ordered=False):
"""Get fields from a class. If ordered=True, fields will sorted by creation index.
def _get_fields(attrs):
"""Get fields from a class
:param attrs: Mapping of class attributes
:param bool ordered: Sort fields by creation index
"""
fields = [
return [
(field_name, field_value)
for field_name, field_value in attrs.items()
if is_instance_or_subclass(field_value, base.FieldABC)
]
if ordered:
fields.sort(key=lambda pair: pair[1]._creation_index)
return fields


# This function allows Schemas to inherit from non-Schema classes and ensures
# inheritance according to the MRO
def _get_fields_by_mro(klass, ordered=False):
def _get_fields_by_mro(klass):
"""Collect fields from a class, following its method resolution order. The
class itself is excluded from the search; only its parents are checked. Get
fields from ``_declared_fields`` if available, else use ``__dict__``.
Expand All @@ -73,7 +69,6 @@ class itself is excluded from the search; only its parents are checked. Get
(
_get_fields(
getattr(base, "_declared_fields", base.__dict__),
ordered=ordered,
)
for base in mro[:0:-1]
),
Expand Down Expand Up @@ -102,13 +97,13 @@ def __new__(mcs, name, bases, attrs):
break
else:
ordered = False
cls_fields = _get_fields(attrs, ordered=ordered)
cls_fields = _get_fields(attrs)
# Remove fields from list of class attributes to avoid shadowing
# Schema attributes/methods in case of name conflict
for field_name, _ in cls_fields:
del attrs[field_name]
klass = super().__new__(mcs, name, bases, attrs)
inherited_fields = _get_fields_by_mro(klass, ordered=ordered)
inherited_fields = _get_fields_by_mro(klass)

meta = klass.Meta
# Set klass.opts in __new__ rather than __init__ so that it is accessible in
Expand All @@ -117,13 +112,11 @@ def __new__(mcs, name, bases, attrs):
# Add fields specified in the `include` class Meta option
cls_fields += list(klass.opts.include.items())

dict_cls = OrderedDict if ordered else dict
# Assign _declared_fields on class
klass._declared_fields = mcs.get_declared_fields(
klass=klass,
cls_fields=cls_fields,
inherited_fields=inherited_fields,
dict_cls=dict_cls,
)
return klass

Expand All @@ -133,7 +126,7 @@ def get_declared_fields(
klass: type,
cls_fields: list,
inherited_fields: list,
dict_cls: type,
dict_cls: type = dict,
):
"""Returns a dictionary of field_name => `Field` pairs declared on the class.
This is exposed mainly so that plugins can add additional fields, e.g. fields
Expand All @@ -143,8 +136,7 @@ def get_declared_fields(
:param cls_fields: The fields declared on the class, including those added
by the ``include`` class Meta option.
:param inherited_fields: Inherited fields.
:param dict_cls: Either `dict` or `OrderedDict`, depending on whether
the user specified `ordered=True`.
:param dict_cls: dict-like class to use for dict output Default to ``dict``.
"""
return dict_cls(inherited_fields + cls_fields)

Expand Down Expand Up @@ -319,6 +311,8 @@ class AlbumSchema(Schema):

OPTIONS_CLASS = SchemaOpts # type: type

set_class = OrderedSet

# These get set by SchemaMeta
opts = None # type: SchemaOpts
_declared_fields = {} # type: typing.Dict[str, ma_fields.Field]
Expand Down Expand Up @@ -350,9 +344,7 @@ class Meta:
- ``timeformat``: Default format for `Time <fields.Time>` fields.
- ``render_module``: Module to use for `loads <Schema.loads>` and `dumps <Schema.dumps>`.
Defaults to `json` from the standard library.
- ``ordered``: If `True`, order serialization output according to the
order in which fields were declared. Output of `Schema.dump` will be a
`collections.OrderedDict`.
- ``ordered``: If `True`, output of `Schema.dump` will be a `collections.OrderedDict`.
- ``index_errors``: If `True`, errors dictionaries will include the index
of invalid items in a collection.
- ``load_only``: Tuple or list of fields to exclude from serialized results.
Expand Down Expand Up @@ -386,7 +378,9 @@ def __init__(
self.declared_fields = copy.deepcopy(self._declared_fields)
self.many = many
self.only = only
self.exclude = set(self.opts.exclude) | set(exclude)
self.exclude: set[typing.Any] | typing.MutableSet[typing.Any] = set(
self.opts.exclude
) | set(exclude)
self.ordered = self.opts.ordered
self.load_only = set(load_only) or set(self.opts.load_only)
self.dump_only = set(dump_only) or set(self.opts.dump_only)
Expand Down Expand Up @@ -419,10 +413,6 @@ def __repr__(self) -> str:
def dict_class(self) -> type:
return OrderedDict if self.ordered else dict

@property
def set_class(self) -> type:
return OrderedSet if self.ordered else set

@classmethod
def from_dict(
cls,
Expand Down Expand Up @@ -970,7 +960,7 @@ def _init_fields(self) -> None:

if self.only is not None:
# Return only fields specified in only option
field_names = self.set_class(self.only)
field_names: typing.AbstractSet[typing.Any] = self.set_class(self.only)

invalid_fields |= field_names - available_field_names
else:
Expand Down
2 changes: 1 addition & 1 deletion src/marshmallow/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
"""
import typing

StrSequenceOrSet = typing.Union[typing.Sequence[str], typing.Set[str]]
StrSequenceOrSet = typing.Union[typing.Sequence[str], typing.AbstractSet[str]]
Tag = typing.Union[str, typing.Tuple[str, bool]]
Validator = typing.Callable[[typing.Any], typing.Any]
9 changes: 5 additions & 4 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
RAISE,
missing,
)
from marshmallow.orderedset import OrderedSet
from marshmallow.exceptions import StringNotCollectionError

from tests.base import ALL_FIELDS
Expand Down Expand Up @@ -380,14 +381,14 @@ class MySchema(Schema):
@pytest.mark.parametrize(
("param", "fields_list"), [("only", ["foo"]), ("exclude", ["bar"])]
)
def test_ordered_instanced_nested_schema_only_and_exclude(self, param, fields_list):
def test_nested_schema_only_and_exclude(self, param, fields_list):
class NestedSchema(Schema):
# We mean to test the use of OrderedSet to specify it explicitly
# even if it is default
set_class = OrderedSet
foo = fields.String()
bar = fields.String()

class Meta:
ordered = True

class MySchema(Schema):
nested = fields.Nested(NestedSchema(), **{param: fields_list})

Expand Down
6 changes: 1 addition & 5 deletions tests/test_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,6 @@ def test_nested_field_order_with_only_arg_is_maintained_on_load(self):

def test_nested_field_order_with_exclude_arg_is_maintained(self, user):
class HasNestedExclude(Schema):
class Meta:
ordered = True

user = fields.Nested(KeepOrder, exclude=("birthdate",))

ser = HasNestedExclude()
Expand Down Expand Up @@ -231,7 +228,7 @@ def test_fields_are_added(self):
result = s.load({"name": "Steve", "from": "Oskosh"})
assert result == in_data

def test_ordered_included(self):
def test_included_fields_ordered_after_declared_fields(self):
class AddFieldsOrdered(Schema):
name = fields.Str()
email = fields.Str()
Expand All @@ -242,7 +239,6 @@ class Meta:
"in": fields.Str(),
"@at": fields.Str(),
}
ordered = True

s = AddFieldsOrdered()
in_data = {
Expand Down
13 changes: 13 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2933,3 +2933,16 @@ class Meta:
MySchema(unknown="badval")
else:
MySchema().load({"foo": "bar"}, unknown="badval")


@pytest.mark.parametrize("dict_cls", (dict, OrderedDict))
def test_set_dict_class(dict_cls):
"""Demonstrate how to specify dict_class as class attribute"""

class MySchema(Schema):
dict_class = dict_cls
foo = fields.String()

result = MySchema().dump({"foo": "bar"})
assert result == {"foo": "bar"}
assert isinstance(result, dict_cls)

0 comments on commit 03889c6

Please sign in to comment.