diff --git a/polymodels/managers.py b/polymodels/managers.py index 9fbd9b6..ed7c91f 100644 --- a/polymodels/managers.py +++ b/polymodels/managers.py @@ -7,11 +7,28 @@ from .compat import is_model_iterable +try: + from django.db.models.query import ModelIterable +except ImportError: + # TODO: Remove when dropping support for Django 1.8. + class PolymorphicModelIterable(object): + def __init__(self, iterable): + self.iterable = iterable + + def __iter__(self): + for instance in self.iterable: + yield instance.type_cast() +else: + class PolymorphicModelIterable(ModelIterable): + def __iter__(self): + for instance in super(PolymorphicModelIterable, self).__iter__(): + yield instance.type_cast() + + class PolymorphicQuerySet(models.query.QuerySet): def select_subclasses(self, *models): - # TODO: Set a different _iterable_class instead of the type_cast flag - # when dropping support for Django 1.8. - self.type_cast = True + if is_model_iterable(self): + self._iterable_class = PolymorphicModelIterable relateds = set() accessors = self.model.subclass_accessors if models: @@ -46,27 +63,17 @@ def select_subclasses(self, *models): def exclude_subclasses(self): return self.filter(**self.model.content_type_lookup()) - def _clone(self, *args, **kwargs): - kwargs.update(type_cast=getattr(self, 'type_cast', False)) - return super(PolymorphicQuerySet, self)._clone(*args, **kwargs) - - # TODO: Remove all this support code an use _iterable_class when dropping - # support for Django 1.8. - def _type_cast_iterator(self, iterator): - if is_model_iterable(self) and getattr(self, 'type_cast', False): - iterator = (obj.type_cast() for obj in iterator) - # yield from iterator - for obj in iterator: - yield obj + # TODO: Remove when dropping support for Django 1.8. + if django.VERSION < (1, 9): + def _clone(self, *args, **kwargs): + kwargs.update(_iterable_class=getattr(self, '_iterable_class', None)) + return super(PolymorphicQuerySet, self)._clone(*args, **kwargs) - def iterator(self, *args, **kwargs): - iterator = super(PolymorphicQuerySet, self).iterator(*args, **kwargs) - return self._type_cast_iterator(iterator) - - if django.VERSION >= (1, 9): - def __iter__(self): - iterator = super(PolymorphicQuerySet, self).__iter__() - return self._type_cast_iterator(iterator) + def iterator(self): + iterator = super(PolymorphicQuerySet, self).iterator() + if getattr(self, '_iterable_class', None) is PolymorphicModelIterable: + return self._iterable_class(iterator) + return iterator class PolymorphicManager(models.Manager.from_queryset(PolymorphicQuerySet)): diff --git a/tests/test_managers.py b/tests/test_managers.py index bb3ea73..e4b0745 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -77,6 +77,10 @@ def test_select_subclasses(self): ['', '']) + def test_select_subclasses_get(self): + snake = Snake.objects.create(name='snake', length=10) + self.assertEqual(Animal.objects.select_subclasses().get(), snake) + def test_select_subclasses_values(self): Animal.objects.create(name='animal') self.assertQuerysetEqual(