Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Fixed #18177 -- Cached known related instances.

This was recently fixed for one-to-one relations; this patch adds
support for foreign keys. Thanks kaiser.yann for the report and
the initial version of the patch.
  • Loading branch information...
commit 1e6c3368f2517f32b0651d68277ea8c9ef81d9b2 1 parent 3b2993e
@aaugustin aaugustin authored
View
63 django/db/models/fields/related.py
@@ -237,13 +237,18 @@ def get_query_set(self, **db_hints):
return self.related.model._base_manager.using(db)
def get_prefetch_query_set(self, instances):
- vals = set(instance._get_pk_val() for instance in instances)
- params = {'%s__pk__in' % self.related.field.name: vals}
- return (self.get_query_set(instance=instances[0]).filter(**params),
- attrgetter(self.related.field.attname),
- lambda obj: obj._get_pk_val(),
- True,
- self.cache_name)
+ rel_obj_attr = attrgetter(self.related.field.attname)
+ instance_attr = lambda obj: obj._get_pk_val()
+ instances_dict = dict((instance_attr(inst), inst) for inst in instances)
+ params = {'%s__pk__in' % self.related.field.name: instances_dict.keys()}
+ qs = self.get_query_set(instance=instances[0]).filter(**params)
+ # Since we're going to assign directly in the cache,
+ # we must manage the reverse relation cache manually.
+ rel_obj_cache_name = self.related.field.get_cache_name()
+ for rel_obj in qs:
+ instance = instances_dict[rel_obj_attr(rel_obj)]
+ setattr(rel_obj, rel_obj_cache_name, instance)
+ return qs, rel_obj_attr, instance_attr, True, self.cache_name
def __get__(self, instance, instance_type=None):
if instance is None:
@@ -324,17 +329,23 @@ def get_query_set(self, **db_hints):
return QuerySet(self.field.rel.to).using(db)
def get_prefetch_query_set(self, instances):
- vals = set(getattr(instance, self.field.attname) for instance in instances)
+ rel_obj_attr = attrgetter(self.field.rel.field_name)
+ instance_attr = attrgetter(self.field.attname)
+ instances_dict = dict((instance_attr(inst), inst) for inst in instances)
other_field = self.field.rel.get_related_field()
if other_field.rel:
- params = {'%s__pk__in' % self.field.rel.field_name: vals}
+ params = {'%s__pk__in' % self.field.rel.field_name: instances_dict.keys()}
else:
- params = {'%s__in' % self.field.rel.field_name: vals}
- return (self.get_query_set(instance=instances[0]).filter(**params),
- attrgetter(self.field.rel.field_name),
- attrgetter(self.field.attname),
- True,
- self.cache_name)
+ params = {'%s__in' % self.field.rel.field_name: instances_dict.keys()}
+ qs = self.get_query_set(instance=instances[0]).filter(**params)
+ # Since we're going to assign directly in the cache,
+ # we must manage the reverse relation cache manually.
+ if not self.field.rel.multiple:
+ rel_obj_cache_name = self.field.related.get_cache_name()
+ for rel_obj in qs:
+ instance = instances_dict[rel_obj_attr(rel_obj)]
+ setattr(rel_obj, rel_obj_cache_name, instance)
+ return qs, rel_obj_attr, instance_attr, True, self.cache_name
def __get__(self, instance, instance_type=None):
if instance is None:
@@ -467,18 +478,24 @@ def get_query_set(self):
return self.instance._prefetched_objects_cache[rel_field.related_query_name()]
except (AttributeError, KeyError):
db = self._db or router.db_for_read(self.model, instance=self.instance)
- return super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
+ qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
+ qs._known_related_object = (rel_field.name, self.instance)
+ return qs
def get_prefetch_query_set(self, instances):
+ rel_obj_attr = attrgetter(rel_field.get_attname())
+ instance_attr = attrgetter(attname)
+ instances_dict = dict((instance_attr(inst), inst) for inst in instances)
db = self._db or router.db_for_read(self.model, instance=instances[0])
- query = {'%s__%s__in' % (rel_field.name, attname):
- set(getattr(obj, attname) for obj in instances)}
+ query = {'%s__%s__in' % (rel_field.name, attname): instances_dict.keys()}
qs = super(RelatedManager, self).get_query_set().using(db).filter(**query)
- return (qs,
- attrgetter(rel_field.get_attname()),
- attrgetter(attname),
- False,
- rel_field.related_query_name())
+ # Since we just bypassed this class' get_query_set(), we must manage
+ # the reverse relation manually.
+ for rel_obj in qs:
+ instance = instances_dict[rel_obj_attr(rel_obj)]
+ setattr(rel_obj, rel_field.name, instance)
+ cache_name = rel_field.related_query_name()
+ return qs, rel_obj_attr, instance_attr, False, cache_name
def add(self, *objs):
for obj in objs:
View
21 django/db/models/query.py
@@ -41,6 +41,7 @@ def __init__(self, model=None, query=None, using=None):
self._for_write = False
self._prefetch_related_lookups = []
self._prefetch_done = False
+ self._known_related_object = None # (attname, rel_obj)
########################
# PYTHON MAGIC METHODS #
@@ -282,9 +283,10 @@ def iterator(self):
init_list.append(field.attname)
model_cls = deferred_class_factory(self.model, skip)
- # Cache db and model outside the loop
+ # Cache db, model and known_related_object outside the loop
db = self.db
model = self.model
+ kro_attname, kro_instance = self._known_related_object or (None, None)
compiler = self.query.get_compiler(using=db)
if fill_cache:
klass_info = get_klass_info(model, max_depth=max_depth,
@@ -294,12 +296,12 @@ def iterator(self):
obj, _ = get_cached_row(row, index_start, db, klass_info,
offset=len(aggregate_select))
else:
+ # Omit aggregates in object creation.
+ row_data = row[index_start:aggregate_start]
if skip:
- row_data = row[index_start:aggregate_start]
obj = model_cls(**dict(zip(init_list, row_data)))
else:
- # Omit aggregates in object creation.
- obj = model(*row[index_start:aggregate_start])
+ obj = model(*row_data)
# Store the source database of the object
obj._state.db = db
@@ -313,7 +315,11 @@ def iterator(self):
# Add the aggregates to the model
if aggregate_select:
for i, aggregate in enumerate(aggregate_select):
- setattr(obj, aggregate, row[i+aggregate_start])
+ setattr(obj, aggregate, row[i + aggregate_start])
+
+ # Add the known related object to the model, if there is one
+ if kro_instance:
+ setattr(obj, kro_attname, kro_instance)
yield obj
@@ -864,6 +870,7 @@ def _clone(self, klass=None, setup=False, **kwargs):
c = klass(model=self.model, query=query, using=self._db)
c._for_write = self._for_write
c._prefetch_related_lookups = self._prefetch_related_lookups[:]
+ c._known_related_object = self._known_related_object
c.__dict__.update(kwargs)
if setup and hasattr(c, '_setup_query'):
c._setup_query()
@@ -1781,9 +1788,7 @@ def prefetch_one_level(instances, prefetcher, attname):
rel_obj_cache = {}
for rel_obj in all_related_objects:
rel_attr_val = rel_obj_attr(rel_obj)
- if rel_attr_val not in rel_obj_cache:
- rel_obj_cache[rel_attr_val] = []
- rel_obj_cache[rel_attr_val].append(rel_obj)
+ rel_obj_cache.setdefault(rel_attr_val, []).append(rel_obj)
for obj in instances:
instance_attr_val = instance_attr(obj)
View
18 docs/releases/1.5.txt
@@ -44,6 +44,24 @@ reasons or when trying to avoid overwriting concurrent changes.
See the :meth:`Model.save() <django.db.models.Model.save()>` documentation for
more details.
+Caching of related model instances
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+When traversing relations, the ORM will avoid re-fetching objects that were
+previously loaded. For example, with the tutorial's models::
+
+ >>> first_poll = Poll.objects.all()[0]
+ >>> first_choice = first_poll.choice_set.all()[0]
+ >>> first_choice.poll is first_poll
+ True
+
+In Django 1.5, the third line no longer triggers a new SQL query to fetch
+``first_choice.poll``; it was set when by the second line.
+
+For one-to-one relationships, both sides can be cached. For many-to-one
+relationships, only the single side of the relationship can be cached. This
+is particularly helpful in combination with ``prefetch_related``.
+
Minor features
~~~~~~~~~~~~~~
View
0  tests/modeltests/known_related_objects/__init__.py
No changes.
View
65 tests/modeltests/known_related_objects/fixtures/tournament.json
@@ -0,0 +1,65 @@
+[
+ {
+ "pk": 1,
+ "model": "known_related_objects.tournament",
+ "fields": {
+ "name": "Tourney 1"
+ }
+ },
+ {
+ "pk": 2,
+ "model": "known_related_objects.tournament",
+ "fields": {
+ "name": "Tourney 2"
+ }
+ },
+ {
+ "pk": 1,
+ "model": "known_related_objects.pool",
+ "fields": {
+ "tournament": 1,
+ "name": "T1 Pool 1"
+ }
+ },
+ {
+ "pk": 2,
+ "model": "known_related_objects.pool",
+ "fields": {
+ "tournament": 1,
+ "name": "T1 Pool 2"
+ }
+ },
+ {
+ "pk": 3,
+ "model": "known_related_objects.pool",
+ "fields": {
+ "tournament": 2,
+ "name": "T2 Pool 1"
+ }
+ },
+ {
+ "pk": 4,
+ "model": "known_related_objects.pool",
+ "fields": {
+ "tournament": 2,
+ "name": "T2 Pool 2"
+ }
+ },
+ {
+ "pk": 1,
+ "model": "known_related_objects.poolstyle",
+ "fields": {
+ "name": "T1 Pool 2 Style",
+ "pool": 2
+ }
+ },
+ {
+ "pk": 2,
+ "model": "known_related_objects.poolstyle",
+ "fields": {
+ "name": "T2 Pool 1 Style",
+ "pool": 3
+ }
+ }
+]
+
View
19 tests/modeltests/known_related_objects/models.py
@@ -0,0 +1,19 @@
+"""
+Existing related object instance caching.
+
+Test that queries are not redone when going back through known relations.
+"""
+
+from django.db import models
+
+class Tournament(models.Model):
+ name = models.CharField(max_length=30)
+
+class Pool(models.Model):
+ name = models.CharField(max_length=30)
+ tournament = models.ForeignKey(Tournament)
+
+class PoolStyle(models.Model):
+ name = models.CharField(max_length=30)
+ pool = models.OneToOneField(Pool)
+
View
88 tests/modeltests/known_related_objects/tests.py
@@ -0,0 +1,88 @@
+from __future__ import absolute_import
+
+from django.test import TestCase
+
+from .models import Tournament, Pool, PoolStyle
+
+class ExistingRelatedInstancesTests(TestCase):
+ fixtures = ['tournament.json']
+
+ def test_foreign_key(self):
+ with self.assertNumQueries(2):
+ tournament = Tournament.objects.get(pk=1)
+ pool = tournament.pool_set.all()[0]
+ self.assertIs(tournament, pool.tournament)
+
+ def test_foreign_key_prefetch_related(self):
+ with self.assertNumQueries(2):
+ tournament = (Tournament.objects.prefetch_related('pool_set').get(pk=1))
+ pool = tournament.pool_set.all()[0]
+ self.assertIs(tournament, pool.tournament)
+
+ def test_foreign_key_multiple_prefetch(self):
+ with self.assertNumQueries(2):
+ tournaments = list(Tournament.objects.prefetch_related('pool_set'))
+ pool1 = tournaments[0].pool_set.all()[0]
+ self.assertIs(tournaments[0], pool1.tournament)
+ pool2 = tournaments[1].pool_set.all()[0]
+ self.assertIs(tournaments[1], pool2.tournament)
+
+ def test_one_to_one(self):
+ with self.assertNumQueries(2):
+ style = PoolStyle.objects.get(pk=1)
+ pool = style.pool
+ self.assertIs(style, pool.poolstyle)
+
+ def test_one_to_one_select_related(self):
+ with self.assertNumQueries(1):
+ style = PoolStyle.objects.select_related('pool').get(pk=1)
+ pool = style.pool
+ self.assertIs(style, pool.poolstyle)
+
+ def test_one_to_one_multi_select_related(self):
+ with self.assertNumQueries(1):
+ poolstyles = list(PoolStyle.objects.select_related('pool'))
+ self.assertIs(poolstyles[0], poolstyles[0].pool.poolstyle)
+ self.assertIs(poolstyles[1], poolstyles[1].pool.poolstyle)
+
+ def test_one_to_one_prefetch_related(self):
+ with self.assertNumQueries(2):
+ style = PoolStyle.objects.prefetch_related('pool').get(pk=1)
+ pool = style.pool
+ self.assertIs(style, pool.poolstyle)
+
+ def test_one_to_one_multi_prefetch_related(self):
+ with self.assertNumQueries(2):
+ poolstyles = list(PoolStyle.objects.prefetch_related('pool'))
+ self.assertIs(poolstyles[0], poolstyles[0].pool.poolstyle)
+ self.assertIs(poolstyles[1], poolstyles[1].pool.poolstyle)
+
+ def test_reverse_one_to_one(self):
+ with self.assertNumQueries(2):
+ pool = Pool.objects.get(pk=2)
+ style = pool.poolstyle
+ self.assertIs(pool, style.pool)
+
+ def test_reverse_one_to_one_select_related(self):
+ with self.assertNumQueries(1):
+ pool = Pool.objects.select_related('poolstyle').get(pk=2)
+ style = pool.poolstyle
+ self.assertIs(pool, style.pool)
+
+ def test_reverse_one_to_one_prefetch_related(self):
+ with self.assertNumQueries(2):
+ pool = Pool.objects.prefetch_related('poolstyle').get(pk=2)
+ style = pool.poolstyle
+ self.assertIs(pool, style.pool)
+
+ def test_reverse_one_to_one_multi_select_related(self):
+ with self.assertNumQueries(1):
+ pools = list(Pool.objects.select_related('poolstyle'))
+ self.assertIs(pools[1], pools[1].poolstyle.pool)
+ self.assertIs(pools[2], pools[2].poolstyle.pool)
+
+ def test_reverse_one_to_one_multi_prefetch_related(self):
+ with self.assertNumQueries(2):
+ pools = list(Pool.objects.prefetch_related('poolstyle'))
+ self.assertIs(pools[1], pools[1].poolstyle.pool)
+ self.assertIs(pools[2], pools[2].poolstyle.pool)
Please sign in to comment.
Something went wrong with that request. Please try again.