Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Fixed #19547 -- Caching of related instances.

When &'ing or |'ing querysets, wrong values could be cached, and crashes
could happen.

Thanks Marc Tamlyn for figuring out the problem and writing the patch.
  • Loading branch information...
commit 07fbc6ae0e3b7742915b785c737b7e6e8a0e3503 1 parent 695b208
@aaugustin aaugustin authored
View
2  django/db/models/fields/related.py
@@ -496,7 +496,7 @@ def get_query_set(self):
except (AttributeError, KeyError):
db = self._db or router.db_for_read(self.model, instance=self.instance)
qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
- qs._known_related_object = (rel_field.name, self.instance)
+ qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}}
return qs
def get_prefetch_query_set(self, instances):
View
29 django/db/models/query.py
@@ -44,7 +44,7 @@ def __init__(self, model=None, query=None, using=None):
self._for_write = False
self._prefetch_related_lookups = []
self._prefetch_done = False
- self._known_related_object = None # (attname, rel_obj)
+ self._known_related_objects = {} # {rel_field, {pk: rel_obj}}
########################
# PYTHON MAGIC METHODS #
@@ -221,6 +221,7 @@ def __and__(self, other):
if isinstance(other, EmptyQuerySet):
return other._clone()
combined = self._clone()
+ combined._merge_known_related_objects(other)
combined.query.combine(other.query, sql.AND)
return combined
@@ -229,6 +230,7 @@ def __or__(self, other):
combined = self._clone()
if isinstance(other, EmptyQuerySet):
return combined
+ combined._merge_known_related_objects(other)
combined.query.combine(other.query, sql.OR)
return combined
@@ -289,10 +291,9 @@ def iterator(self):
init_list.append(field.attname)
model_cls = deferred_class_factory(self.model, skip)
- # Cache db, model and known_related_object outside the loop
+ # Cache db and model outside the loop
db = self.db
model = self.model
- kro_attname, kro_instance = self._known_related_object or (None, None)
compiler = self.query.get_compiler(using=db)
if fill_cache:
klass_info = get_klass_info(model, max_depth=max_depth,
@@ -323,9 +324,16 @@ def iterator(self):
for i, aggregate in enumerate(aggregate_select):
setattr(obj, aggregate, row[i + aggregate_start])
- # Add the known related object to the model, if there is one
- if kro_instance:
- setattr(obj, kro_attname, kro_instance)
+ # Add the known related objects to the model, if there are any
+ if self._known_related_objects:
+ for field, rel_objs in self._known_related_objects.items():
+ pk = getattr(obj, field.get_attname())
+ try:
+ rel_obj = rel_objs[pk]
+ except KeyError:
+ pass # may happen in qs1 | qs2 scenarios
+ else:
+ setattr(obj, field.name, rel_obj)
yield obj
@@ -902,7 +910,7 @@ def _clone(self, klass=None, setup=False, **kwargs):
c = klass(model=self.model, query=query, using=self._db)
c._for_write = self._for_write
c._prefetch_related_lookups = self._prefetch_related_lookups[:]
- c._known_related_object = self._known_related_object
+ c._known_related_objects = self._known_related_objects
c.__dict__.update(kwargs)
if setup and hasattr(c, '_setup_query'):
c._setup_query()
@@ -942,6 +950,13 @@ def _merge_sanity_check(self, other):
"""
pass
+ def _merge_known_related_objects(self, other):
+ """
+ Keep track of all known related objects from either QuerySet instance.
+ """
+ for field, objects in other._known_related_objects.items():
+ self._known_related_objects.setdefault(field, {}).update(objects)
+
def _setup_aggregate_query(self, aggregates):
"""
Prepare the query for computing a result that contains aggregate annotations.
View
11 tests/modeltests/known_related_objects/fixtures/tournament.json
@@ -15,9 +15,17 @@
},
{
"pk": 1,
+ "model": "known_related_objects.organiser",
+ "fields": {
+ "name": "Organiser 1"
+ }
+ },
+ {
+ "pk": 1,
"model": "known_related_objects.pool",
"fields": {
"tournament": 1,
+ "organiser": 1,
"name": "T1 Pool 1"
}
},
@@ -26,6 +34,7 @@
"model": "known_related_objects.pool",
"fields": {
"tournament": 1,
+ "organiser": 1,
"name": "T1 Pool 2"
}
},
@@ -34,6 +43,7 @@
"model": "known_related_objects.pool",
"fields": {
"tournament": 2,
+ "organiser": 1,
"name": "T2 Pool 1"
}
},
@@ -42,6 +52,7 @@
"model": "known_related_objects.pool",
"fields": {
"tournament": 2,
+ "organiser": 1,
"name": "T2 Pool 2"
}
},
View
4 tests/modeltests/known_related_objects/models.py
@@ -9,9 +9,13 @@
class Tournament(models.Model):
name = models.CharField(max_length=30)
+class Organiser(models.Model):
+ name = models.CharField(max_length=30)
+
class Pool(models.Model):
name = models.CharField(max_length=30)
tournament = models.ForeignKey(Tournament)
+ organiser = models.ForeignKey(Organiser)
class PoolStyle(models.Model):
name = models.CharField(max_length=30)
View
42 tests/modeltests/known_related_objects/tests.py
@@ -2,7 +2,7 @@
from django.test import TestCase
-from .models import Tournament, Pool, PoolStyle
+from .models import Tournament, Organiser, Pool, PoolStyle
class ExistingRelatedInstancesTests(TestCase):
fixtures = ['tournament.json']
@@ -27,6 +27,46 @@ def test_foreign_key_multiple_prefetch(self):
pool2 = tournaments[1].pool_set.all()[0]
self.assertIs(tournaments[1], pool2.tournament)
+ def test_queryset_or(self):
+ tournament_1 = Tournament.objects.get(pk=1)
+ tournament_2 = Tournament.objects.get(pk=2)
+ with self.assertNumQueries(1):
+ pools = tournament_1.pool_set.all() | tournament_2.pool_set.all()
+ related_objects = set(pool.tournament for pool in pools)
+ self.assertEqual(related_objects, set((tournament_1, tournament_2)))
+
+ def test_queryset_or_different_cached_items(self):
+ tournament = Tournament.objects.get(pk=1)
+ organiser = Organiser.objects.get(pk=1)
+ with self.assertNumQueries(1):
+ pools = tournament.pool_set.all() | organiser.pool_set.all()
+ first = pools.filter(pk=1)[0]
+ self.assertIs(first.tournament, tournament)
+ self.assertIs(first.organiser, organiser)
+
+ def test_queryset_or_only_one_with_precache(self):
+ tournament_1 = Tournament.objects.get(pk=1)
+ tournament_2 = Tournament.objects.get(pk=2)
+ # 2 queries here as pool id 3 has tournament 2, which is not cached
+ with self.assertNumQueries(2):
+ pools = tournament_1.pool_set.all() | Pool.objects.filter(pk=3)
+ related_objects = set(pool.tournament for pool in pools)
+ self.assertEqual(related_objects, set((tournament_1, tournament_2)))
+ # and the other direction
+ with self.assertNumQueries(2):
+ pools = Pool.objects.filter(pk=3) | tournament_1.pool_set.all()
+ related_objects = set(pool.tournament for pool in pools)
+ self.assertEqual(related_objects, set((tournament_1, tournament_2)))
+
+ def test_queryset_and(self):
+ tournament = Tournament.objects.get(pk=1)
+ organiser = Organiser.objects.get(pk=1)
+ with self.assertNumQueries(1):
+ pools = tournament.pool_set.all() & organiser.pool_set.all()
+ first = pools.filter(pk=1)[0]
+ self.assertIs(first.tournament, tournament)
+ self.assertIs(first.organiser, organiser)
+
def test_one_to_one(self):
with self.assertNumQueries(2):
style = PoolStyle.objects.get(pk=1)

0 comments on commit 07fbc6a

Please sign in to comment.
Something went wrong with that request. Please try again.