diff --git a/typedmodels/models.py b/typedmodels/models.py index ae31a9b..883814d 100644 --- a/typedmodels/models.py +++ b/typedmodels/models.py @@ -3,7 +3,8 @@ import django import types -from django.core.serializers.python import Serializer +from django.core.serializers.python import Serializer as _PythonSerializer +from django.core.serializers.xml_serializer import Serializer as _XmlSerializer from django.db import models from django.db.models.base import ModelBase from django.db.models.fields import Field @@ -355,13 +356,36 @@ def save(self, *args, **kwargs): return super(TypedModel, self).save(*args, **kwargs) -# Monkey patching Python serializer class in Django to use model name from base class. +# Monkey patching Python and XML serializers in Django to use model name from base class. # This should be preferably done by changing __unicode__ method for ._meta attribute in each model, # but it doesn’t work. -def get_dump_object(self, obj): - return { - "pk": smart_text(obj._get_pk_val(), strings_only=True), - "model": smart_text(getattr(obj, 'base_class', obj)._meta), - "fields": self._current - } -Serializer.get_dump_object = get_dump_object +_python_serializer_get_dump_object = _PythonSerializer.get_dump_object +def _get_dump_object(self, obj): + if isinstance(obj, TypedModel): + return { + "pk": smart_text(obj._get_pk_val(), strings_only=True), + "model": smart_text(getattr(obj, 'base_class', obj)._meta), + "fields": self._current + } + else: + return _python_serializer_get_dump_object(self, obj) +_PythonSerializer.get_dump_object = _get_dump_object + +_xml_serializer_start_object = _XmlSerializer.start_object +def _start_object(self, obj): + if isinstance(obj, TypedModel): + self.indent(1) + obj_pk = obj._get_pk_val() + modelname = smart_text(getattr(obj, 'base_class', obj)._meta) + if obj_pk is None: + attrs = {"model": modelname,} + else: + attrs = { + "pk": smart_text(obj._get_pk_val()), + "model": modelname, + } + + self.xml.startElement("object", attrs) + else: + return _xml_serializer_start_object(self, obj) +_XmlSerializer.start_object = _start_object diff --git a/typedmodels/tests.py b/typedmodels/tests.py index 3c4c9b2..38f988e 100644 --- a/typedmodels/tests.py +++ b/typedmodels/tests.py @@ -145,13 +145,11 @@ def test_related_names(self): def _check_serialization(self, serialization_format): """Helper function used to check serialization and deserialization for concrete format.""" - animals = Animal.objects.order_by('pk') serialized_animals = serializers.serialize(serialization_format, animals) deserialized_animals = [wrapper.object for wrapper in serializers.deserialize(serialization_format, serialized_animals)] self.assertEqual(set(deserialized_animals), set(animals)) - @unittest.expectedFailure def test_xml_serialization(self): self._check_serialization('xml')