diff --git a/seal/managers.py b/seal/managers.py index 50dba70..aaba021 100644 --- a/seal/managers.py +++ b/seal/managers.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals from django.db import models +from django.db.models.constants import LOOKUP_SEP from django.utils.six import string_types @@ -23,14 +24,34 @@ def _clone(self, **kwargs): clone._sealed = sealed return clone + def _unsealed_prefetch_lookup(self, prefetch_lookup, to_attr=None): + """ + Turn a string prefetch lookup or a Prefetch instance without an + explicit queryset into a Prefetch object with an explicit queryset + to prevent the prefetching logic from accessing sealed related + managers and triggering a SealedObject exception. + """ + if isinstance(prefetch_lookup, string_types): + parts = prefetch_lookup.split(LOOKUP_SEP, 1) + if len(parts) > 1: + head, tail = parts + else: + head, tail = parts[0], None + queryset = self.model._meta.get_field(head).remote_field.model._default_manager.all() + if tail: + queryset = queryset.prefetch_related(tail) + return models.Prefetch(head, queryset, to_attr=to_attr) + elif isinstance(prefetch_lookup, models.Prefetch) and prefetch_lookup.queryset is None: + return self._unsealed_prefetch_lookup( + prefetch_lookup.prefetch_through, + to_attr=prefetch_lookup.to_attr, + ) + return prefetch_lookup + def seal(self): clone = self._clone(_sealed=True) clone._prefetch_related_lookups = tuple( - models.Prefetch( - lookup, - self.model._meta.get_field(lookup).remote_field.model._default_manager.all(), - ) if isinstance(lookup, string_types) else lookup - for lookup in clone._prefetch_related_lookups + self._unsealed_prefetch_lookup(looukp) for looukp in clone._prefetch_related_lookups ) if issubclass(clone._iterable_class, models.query.ModelIterable): clone._iterable_class = SealedModelIterable diff --git a/tests/test_managers.py b/tests/test_managers.py index d200f1a..b76e57b 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -1,3 +1,4 @@ +from django.db.models import Prefetch from django.test import TestCase from seal.exceptions import SealedObject @@ -48,9 +49,54 @@ def test_not_sealed_many_to_many(self): instance = SeaLion.objects.get() self.assertSequenceEqual(instance.previous_locations.all(), [self.location]) - def test_sealed_prefetched_many_to_many(self): + def test_sealed_string_prefetched_many_to_many(self): instance = SeaLion.objects.prefetch_related('previous_locations').seal().get() - self.assertSequenceEqual(instance.previous_locations.all(), [self.location]) + with self.assertNumQueries(0): + self.assertSequenceEqual(instance.previous_locations.all(), [self.location]) + + def test_sealed_prefetch_prefetched_many_to_many(self): + instance = SeaLion.objects.prefetch_related( + Prefetch('previous_locations'), + ).seal().get() + with self.assertNumQueries(0): + self.assertSequenceEqual(instance.previous_locations.all(), [self.location]) + + def test_sealed_prefetch_queryset_prefetched_many_to_many(self): + instance = SeaLion.objects.prefetch_related( + Prefetch('previous_locations', Location.objects.all()), + ).seal().get() + with self.assertNumQueries(0): + self.assertSequenceEqual(instance.previous_locations.all(), [self.location]) + + def test_sealed_string_prefetched_nested_many_to_many(self): + instance = SeaLion.objects.prefetch_related('previous_locations__previous_visitors').seal().get() + with self.assertNumQueries(0): + self.assertSequenceEqual(instance.previous_locations.all(), [self.location]) + self.assertSequenceEqual( + instance.previous_locations.all()[0].previous_visitors.all(), [self.sealion] + ) + + def test_sealed_prefetch_prefetched_nested_many_to_many(self): + instance = SeaLion.objects.prefetch_related( + Prefetch('previous_locations__previous_visitors'), + ).seal().get() + with self.assertNumQueries(0): + self.assertSequenceEqual(instance.previous_locations.all(), [self.location]) + self.assertSequenceEqual( + instance.previous_locations.all()[0].previous_visitors.all(), [self.sealion] + ) + + def test_prefetched_sealed_many_to_many(self): + instance = SeaLion.objects.prefetch_related( + Prefetch('previous_locations', Location.objects.seal()), + ).get() + with self.assertNumQueries(0): + self.assertSequenceEqual(instance.previous_locations.all(), [self.location]) + message = 'Cannot fetch many-to-many field previous_visitors on a sealed object.' + with self.assertRaisesMessage(SealedObject, message): + self.assertSequenceEqual( + instance.previous_locations.all()[0].previous_visitors.all(), [self.sealion] + ) def test_sealed_deferred_parent_link(self): instance = GreatSeaLion.objects.only('pk').seal().get()