Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions seal/managers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import unicode_literals

from django.db import models
from django.db.models.constants import LOOKUP_SEP
from django.utils.six import string_types


Expand All @@ -23,14 +24,34 @@ def _clone(self, **kwargs):
clone._sealed = sealed
return clone

def _unsealed_prefetch_lookup(self, prefetch_lookup, to_attr=None):
"""
Turn a string prefetch lookup or a Prefetch instance without an
explicit queryset into a Prefetch object with an explicit queryset
to prevent the prefetching logic from accessing sealed related
managers and triggering a SealedObject exception.
"""
if isinstance(prefetch_lookup, string_types):
parts = prefetch_lookup.split(LOOKUP_SEP, 1)
if len(parts) > 1:
head, tail = parts
else:
head, tail = parts[0], None
queryset = self.model._meta.get_field(head).remote_field.model._default_manager.all()
if tail:
queryset = queryset.prefetch_related(tail)
return models.Prefetch(head, queryset, to_attr=to_attr)
elif isinstance(prefetch_lookup, models.Prefetch) and prefetch_lookup.queryset is None:
return self._unsealed_prefetch_lookup(
prefetch_lookup.prefetch_through,
to_attr=prefetch_lookup.to_attr,
)
return prefetch_lookup

def seal(self):
clone = self._clone(_sealed=True)
clone._prefetch_related_lookups = tuple(
models.Prefetch(
lookup,
self.model._meta.get_field(lookup).remote_field.model._default_manager.all(),
) if isinstance(lookup, string_types) else lookup
for lookup in clone._prefetch_related_lookups
self._unsealed_prefetch_lookup(looukp) for looukp in clone._prefetch_related_lookups
)
if issubclass(clone._iterable_class, models.query.ModelIterable):
clone._iterable_class = SealedModelIterable
Expand Down
50 changes: 48 additions & 2 deletions tests/test_managers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from django.db.models import Prefetch
from django.test import TestCase
from seal.exceptions import SealedObject

Expand Down Expand Up @@ -48,9 +49,54 @@ def test_not_sealed_many_to_many(self):
instance = SeaLion.objects.get()
self.assertSequenceEqual(instance.previous_locations.all(), [self.location])

def test_sealed_prefetched_many_to_many(self):
def test_sealed_string_prefetched_many_to_many(self):
instance = SeaLion.objects.prefetch_related('previous_locations').seal().get()
self.assertSequenceEqual(instance.previous_locations.all(), [self.location])
with self.assertNumQueries(0):
self.assertSequenceEqual(instance.previous_locations.all(), [self.location])

def test_sealed_prefetch_prefetched_many_to_many(self):
instance = SeaLion.objects.prefetch_related(
Prefetch('previous_locations'),
).seal().get()
with self.assertNumQueries(0):
self.assertSequenceEqual(instance.previous_locations.all(), [self.location])

def test_sealed_prefetch_queryset_prefetched_many_to_many(self):
instance = SeaLion.objects.prefetch_related(
Prefetch('previous_locations', Location.objects.all()),
).seal().get()
with self.assertNumQueries(0):
self.assertSequenceEqual(instance.previous_locations.all(), [self.location])

def test_sealed_string_prefetched_nested_many_to_many(self):
instance = SeaLion.objects.prefetch_related('previous_locations__previous_visitors').seal().get()
with self.assertNumQueries(0):
self.assertSequenceEqual(instance.previous_locations.all(), [self.location])
self.assertSequenceEqual(
instance.previous_locations.all()[0].previous_visitors.all(), [self.sealion]
)

def test_sealed_prefetch_prefetched_nested_many_to_many(self):
instance = SeaLion.objects.prefetch_related(
Prefetch('previous_locations__previous_visitors'),
).seal().get()
with self.assertNumQueries(0):
self.assertSequenceEqual(instance.previous_locations.all(), [self.location])
self.assertSequenceEqual(
instance.previous_locations.all()[0].previous_visitors.all(), [self.sealion]
)

def test_prefetched_sealed_many_to_many(self):
instance = SeaLion.objects.prefetch_related(
Prefetch('previous_locations', Location.objects.seal()),
).get()
with self.assertNumQueries(0):
self.assertSequenceEqual(instance.previous_locations.all(), [self.location])
message = 'Cannot fetch many-to-many field previous_visitors on a sealed object.'
with self.assertRaisesMessage(SealedObject, message):
self.assertSequenceEqual(
instance.previous_locations.all()[0].previous_visitors.all(), [self.sealion]
)

def test_sealed_deferred_parent_link(self):
instance = GreatSeaLion.objects.only('pk').seal().get()
Expand Down