Skip to content

Commit

Permalink
Used Queryset._iterable_class to type cast model on iteration.
Browse files Browse the repository at this point in the history
  • Loading branch information
charettes committed Jan 21, 2017
1 parent 8f22a37 commit 99451d4
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 23 deletions.
53 changes: 30 additions & 23 deletions polymodels/managers.py
Expand Up @@ -7,11 +7,28 @@
from .compat import is_model_iterable


try:
from django.db.models.query import ModelIterable
except ImportError:
# TODO: Remove when dropping support for Django 1.8.
class PolymorphicModelIterable(object):
def __init__(self, iterable):
self.iterable = iterable

def __iter__(self):
for instance in self.iterable:
yield instance.type_cast()
else:
class PolymorphicModelIterable(ModelIterable):
def __iter__(self):
for instance in super(PolymorphicModelIterable, self).__iter__():
yield instance.type_cast()


class PolymorphicQuerySet(models.query.QuerySet):
def select_subclasses(self, *models):
# TODO: Set a different _iterable_class instead of the type_cast flag
# when dropping support for Django 1.8.
self.type_cast = True
if is_model_iterable(self):
self._iterable_class = PolymorphicModelIterable
relateds = set()
accessors = self.model.subclass_accessors
if models:
Expand Down Expand Up @@ -46,27 +63,17 @@ def select_subclasses(self, *models):
def exclude_subclasses(self):
return self.filter(**self.model.content_type_lookup())

def _clone(self, *args, **kwargs):
kwargs.update(type_cast=getattr(self, 'type_cast', False))
return super(PolymorphicQuerySet, self)._clone(*args, **kwargs)

# TODO: Remove all this support code an use _iterable_class when dropping
# support for Django 1.8.
def _type_cast_iterator(self, iterator):
if is_model_iterable(self) and getattr(self, 'type_cast', False):
iterator = (obj.type_cast() for obj in iterator)
# yield from iterator
for obj in iterator:
yield obj
# TODO: Remove when dropping support for Django 1.8.
if django.VERSION < (1, 9):
def _clone(self, *args, **kwargs):
kwargs.update(_iterable_class=getattr(self, '_iterable_class', None))
return super(PolymorphicQuerySet, self)._clone(*args, **kwargs)

def iterator(self, *args, **kwargs):
iterator = super(PolymorphicQuerySet, self).iterator(*args, **kwargs)
return self._type_cast_iterator(iterator)

if django.VERSION >= (1, 9):
def __iter__(self):
iterator = super(PolymorphicQuerySet, self).__iter__()
return self._type_cast_iterator(iterator)
def iterator(self):
iterator = super(PolymorphicQuerySet, self).iterator()
if getattr(self, '_iterable_class', None) is PolymorphicModelIterable:
return self._iterable_class(iterator)
return iterator


class PolymorphicManager(models.Manager.from_queryset(PolymorphicQuerySet)):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_managers.py
Expand Up @@ -77,6 +77,10 @@ def test_select_subclasses(self):
['<BigSnake: big snake>',
'<HugeSnake: huge snake>'])

def test_select_subclasses_get(self):
snake = Snake.objects.create(name='snake', length=10)
self.assertEqual(Animal.objects.select_subclasses().get(), snake)

def test_select_subclasses_values(self):
Animal.objects.create(name='animal')
self.assertQuerysetEqual(
Expand Down

0 comments on commit 99451d4

Please sign in to comment.