Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Fixed #13781 -- Improved select_related in inheritance situations

The select_related code got confused when it needed to travel a
reverse relation to a model which had different parent than the
originally travelled relation.

Thanks to Trac aliases shauncutts for report and ungenio for original
patch (committed patch is somewhat modified version of that).
  • Loading branch information...
commit f51e409a5fb34020e170494320a421503689aea0 1 parent 92d7f54
@akaariai akaariai authored
View
6 django/db/models/options.py
@@ -75,6 +75,7 @@ def contribute_to_class(self, cls, name):
from django.db.backends.util import truncate_name
cls._meta = self
+ self.model = cls
self.installed = re.sub('\.models$', '', cls.__module__) in settings.INSTALLED_APPS
# First, construct the default values for these options.
self.object_name = cls.__name__
@@ -464,7 +465,7 @@ def get_base_chain(self, model):
a granparent or even more distant relation.
"""
if not self.parents:
- return
+ return None
if model in self.parents:
return [model]
for parent in self.parents:
@@ -472,8 +473,7 @@ def get_base_chain(self, model):
if res:
res.insert(0, parent)
return res
- raise TypeError('%r is not an ancestor of this model'
- % model._meta.module_name)
+ return None
def get_parent_list(self):
"""
View
83 django/db/models/query.py
@@ -1300,7 +1300,7 @@ def values_list(self, *fields, **kwargs):
value_annotation = False
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
- only_load=None, local_only=False):
+ only_load=None, from_parent=None):
"""
Helper function that recursively returns an information for a klass, to be
used in get_cached_row. It exists just to compute this information only
@@ -1320,8 +1320,10 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
* only_load - if the query has had only() or defer() applied,
this is the list of field names that will be returned. If None,
the full field list for `klass` can be assumed.
- * local_only - Only populate local fields. This is used when
- following reverse select-related relations
+ * from_parent - the parent model used to get to this model
+
+ Note that when travelling from parent to child, we will only load child
+ fields which aren't in the parent.
"""
if max_depth and requested is None and cur_depth > max_depth:
# We've recursed deeply enough; stop now.
@@ -1347,7 +1349,9 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
for field, model in klass._meta.get_fields_with_model():
if field.name not in load_fields:
skip.add(field.attname)
- elif local_only and model is not None:
+ elif from_parent and issubclass(from_parent, model.__class__):
+ # Avoid loading fields already loaded for parent model for
+ # child models.
continue
else:
init_list.append(field.attname)
@@ -1361,16 +1365,22 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
else:
# Load all fields on klass
- # We trying to not populate field_names variable for perfomance reason.
- # If field_names variable is set, it is used to instantiate desired fields,
- # by passing **dict(zip(field_names, fields)) as kwargs to Model.__init__ method.
- # But kwargs version of Model.__init__ is slower, so we should avoid using
- # it when it is not really neccesary.
- if local_only and len(klass._meta.local_fields) != len(klass._meta.fields):
- field_count = len(klass._meta.local_fields)
- field_names = [f.attname for f in klass._meta.local_fields]
- else:
- field_count = len(klass._meta.fields)
+ field_count = len(klass._meta.fields)
+ # Check if we need to skip some parent fields.
+ if from_parent and len(klass._meta.local_fields) != len(klass._meta.fields):
+ # Only load those fields which haven't been already loaded into
+ # 'from_parent'.
+ non_seen_models = [p for p in klass._meta.get_parent_list()
+ if not issubclass(from_parent, p)]
+ # Load local fields, too...
+ non_seen_models.append(klass)
+ field_names = [f.attname for f in klass._meta.fields
+ if f.model in non_seen_models]
+ field_count = len(field_names)
+ # Try to avoid populating field_names variable for perfomance reasons.
+ # If field_names variable is set, we use **kwargs based model init
+ # which is slower than normal init.
+ if field_count == len(klass._meta.fields):
field_names = ()
restricted = requested is not None
@@ -1392,8 +1402,9 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
if o.field.unique and select_related_descend(o.field, restricted, requested,
only_load.get(o.model), reverse=True):
next = requested[o.field.related_query_name()]
+ parent = klass if issubclass(o.model, klass) else None
klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1,
- requested=next, only_load=only_load, local_only=True)
+ requested=next, only_load=only_load, from_parent=parent)
reverse_related_fields.append((o.field, klass_info))
if field_names:
pk_idx = field_names.index(klass._meta.pk.attname)
@@ -1403,7 +1414,8 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx
-def get_cached_row(row, index_start, using, klass_info, offset=0):
+def get_cached_row(row, index_start, using, klass_info, offset=0,
+ parent_data=()):
"""
Helper function that recursively returns an object with the specified
related attributes already populated.
@@ -1418,13 +1430,16 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
* offset - the number of additional fields that are known to
exist in row for `klass`. This usually means the number of
annotated results on `klass`.
- * using - the database alias on which the query is being executed.
+ * using - the database alias on which the query is being executed.
* klass_info - result of the get_klass_info function
+ * parent_data - parent model data in format (field, value). Used
+ to populate the non-local fields of child models.
"""
if klass_info is None:
return None
klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx = klass_info
+
fields = row[index_start : index_start + field_count]
# If the pk column is None (or the Oracle equivalent ''), then the related
# object must be non-existent - set the relation to None.
@@ -1434,7 +1449,6 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
obj = klass(**dict(zip(field_names, fields)))
else:
obj = klass(*fields)
-
# If an object was retrieved, set the database state.
if obj:
obj._state.db = using
@@ -1464,34 +1478,35 @@ def get_cached_row(row, index_start, using, klass_info, offset=0):
# Only handle the restricted case - i.e., don't do a depth
# descent into reverse relations unless explicitly requested
for f, klass_info in reverse_related_fields:
+ # Transfer data from this object to childs.
+ parent_data = []
+ for rel_field, rel_model in klass_info[0]._meta.get_fields_with_model():
+ if rel_model is not None and isinstance(obj, rel_model):
+ parent_data.append((rel_field, getattr(obj, rel_field.attname)))
# Recursively retrieve the data for the related object
- cached_row = get_cached_row(row, index_end, using, klass_info)
+ cached_row = get_cached_row(row, index_end, using, klass_info,
+ parent_data=parent_data)
# If the recursive descent found an object, populate the
# descriptor caches relevant to the object
if cached_row:
rel_obj, index_end = cached_row
if obj is not None:
- # If the field is unique, populate the
- # reverse descriptor cache
+ # populate the reverse descriptor cache
setattr(obj, f.related.get_cache_name(), rel_obj)
if rel_obj is not None:
# If the related object exists, populate
# the descriptor cache.
setattr(rel_obj, f.get_cache_name(), obj)
- # Now populate all the non-local field values
- # on the related object
- for rel_field, rel_model in rel_obj._meta.get_fields_with_model():
- if rel_model is not None:
+ # Populate related object caches using parent data.
+ for rel_field, _ in parent_data:
+ if rel_field.rel:
setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
- # populate the field cache for any related object
- # that has already been retrieved
- if rel_field.rel:
- try:
- cached_obj = getattr(obj, rel_field.get_cache_name())
- setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
- except AttributeError:
- # Related object hasn't been cached yet
- pass
+ try:
+ cached_obj = getattr(obj, rel_field.get_cache_name())
+ setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
+ except AttributeError:
+ # Related object hasn't been cached yet
+ pass
return obj, index_end
View
13 django/db/models/sql/compiler.py
@@ -240,7 +240,7 @@ def get_columns(self, with_aliases=False):
return result
def get_default_columns(self, with_aliases=False, col_aliases=None,
- start_alias=None, opts=None, as_pairs=False, local_only=False):
+ start_alias=None, opts=None, as_pairs=False, from_parent=None):
"""
Computes the default columns for selecting every field in the base
model. Will sometimes be called to pull in related models (e.g. via
@@ -265,7 +265,8 @@ def get_default_columns(self, with_aliases=False, col_aliases=None,
if start_alias:
seen = {None: start_alias}
for field, model in opts.get_fields_with_model():
- if local_only and model is not None:
+ if from_parent and model is not None and issubclass(from_parent, model):
+ # Avoid loading data for already loaded parents.
continue
if start_alias:
try:
@@ -686,11 +687,13 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
(alias, table, f.rel.get_related_field().column, f.column),
promote=True
)
+ from_parent = (opts.model if issubclass(model, opts.model)
+ else None)
columns, aliases = self.get_default_columns(start_alias=alias,
- opts=model._meta, as_pairs=True, local_only=True)
+ opts=model._meta, as_pairs=True, from_parent=from_parent)
self.query.related_select_cols.extend(
- SelectInfo(col, field) for col, field in zip(columns, model._meta.fields))
-
+ SelectInfo(col, field) for col, field
+ in zip(columns, model._meta.fields))
next = requested.get(f.related_query_name(), {})
# Use True here because we are looking at the _reverse_ side of
# the relation, which is always nullable.
View
42 tests/regressiontests/select_related_onetoone/models.py
@@ -51,6 +51,7 @@ def __str__(self):
class AdvancedUserStat(UserStat):
karma = models.IntegerField()
+
class Image(models.Model):
name = models.CharField(max_length=100)
@@ -58,3 +59,44 @@ class Image(models.Model):
class Product(models.Model):
name = models.CharField(max_length=100)
image = models.OneToOneField(Image, null=True)
+
+
+@python_2_unicode_compatible
+class Parent1(models.Model):
+ name1 = models.CharField(max_length=50)
+
+ def __str__(self):
+ return self.name1
+
+
+@python_2_unicode_compatible
+class Parent2(models.Model):
+ # Avoid having two "id" fields in the Child1 subclass
+ id2 = models.AutoField(primary_key=True)
+ name2 = models.CharField(max_length=50)
+
+ def __str__(self):
+ return self.name2
+
+
+@python_2_unicode_compatible
+class Child1(Parent1, Parent2):
+ value = models.IntegerField()
+
+ def __str__(self):
+ return self.name1
+
+
+@python_2_unicode_compatible
+class Child2(Parent1):
+ parent2 = models.OneToOneField(Parent2)
+ value = models.IntegerField()
+
+ def __str__(self):
+ return self.name1
+
+class Child3(Child2):
+ value3 = models.IntegerField()
+
+class Child4(Child1):
+ value4 = models.IntegerField()
View
102 tests/regressiontests/select_related_onetoone/tests.py
@@ -1,9 +1,11 @@
from __future__ import absolute_import
from django.test import TestCase
+from django.utils import unittest
from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
- AdvancedUserStat, Image, Product)
+ AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2, Child3,
+ Child4)
class ReverseSelectRelatedTestCase(TestCase):
@@ -21,6 +23,14 @@ def setUp(self):
advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5,
results=results2)
StatDetails.objects.create(base_stats=advstat, comments=250)
+ p1 = Parent1(name1="Only Parent1")
+ p1.save()
+ c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2", value=1)
+ c1.save()
+ p2 = Parent2(name2="Child2 Parent2")
+ p2.save()
+ c2 = Child2(name1="Child2 Parent1", parent2=p2, value=2)
+ c2.save()
def test_basic(self):
with self.assertNumQueries(1):
@@ -108,3 +118,93 @@ def test_nullable_missing_reverse(self):
image = Image.objects.select_related('product').get()
with self.assertRaises(Product.DoesNotExist):
image.product
+
+ def test_parent_only(self):
+ with self.assertNumQueries(1):
+ p = Parent1.objects.select_related('child1').get(name1="Only Parent1")
+ with self.assertNumQueries(0):
+ with self.assertRaises(Child1.DoesNotExist):
+ p.child1
+
+ def test_multiple_subclass(self):
+ with self.assertNumQueries(1):
+ p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1")
+ self.assertEqual(p.child1.name2, 'Child1 Parent2')
+
+ def test_onetoone_with_subclass(self):
+ with self.assertNumQueries(1):
+ p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2")
+ self.assertEqual(p.child2.name1, 'Child2 Parent1')
+
+ def test_onetoone_with_two_subclasses(self):
+ with self.assertNumQueries(1):
+ p = Parent2.objects.select_related('child2', "child2__child3").get(name2="Child2 Parent2")
+ self.assertEqual(p.child2.name1, 'Child2 Parent1')
+ with self.assertRaises(Child3.DoesNotExist):
+ p.child2.child3
+ p3 = Parent2(name2="Child3 Parent2")
+ p3.save()
+ c2 = Child3(name1="Child3 Parent1", parent2=p3, value=2, value3=3)
+ c2.save()
+ with self.assertNumQueries(1):
+ p = Parent2.objects.select_related('child2', "child2__child3").get(name2="Child3 Parent2")
+ self.assertEqual(p.child2.name1, 'Child3 Parent1')
+ self.assertEqual(p.child2.child3.value3, 3)
+ self.assertEqual(p.child2.child3.value, p.child2.value)
+ self.assertEqual(p.child2.name1, p.child2.child3.name1)
+
+ def test_multiinheritance_two_subclasses(self):
+ with self.assertNumQueries(1):
+ p = Parent1.objects.select_related('child1', 'child1__child4').get(name1="Child1 Parent1")
+ self.assertEqual(p.child1.name2, 'Child1 Parent2')
+ self.assertEqual(p.child1.name1, p.name1)
+ with self.assertRaises(Child4.DoesNotExist):
+ p.child1.child4
+ Child4(name1='n1', name2='n2', value=1, value4=4).save()
+ with self.assertNumQueries(1):
+ p = Parent2.objects.select_related('child1', 'child1__child4').get(name2="n2")
+ self.assertEqual(p.name2, 'n2')
+ self.assertEqual(p.child1.name1, 'n1')
+ self.assertEqual(p.child1.name2, p.name2)
+ self.assertEqual(p.child1.value, 1)
+ self.assertEqual(p.child1.child4.name1, p.child1.name1)
+ self.assertEqual(p.child1.child4.name2, p.child1.name2)
+ self.assertEqual(p.child1.child4.value, p.child1.value)
+ self.assertEqual(p.child1.child4.value4, 4)
+
+ @unittest.expectedFailure
+ def test_inheritance_deferred(self):
+ c = Child4.objects.create(name1='n1', name2='n2', value=1, value4=4)
+ with self.assertNumQueries(1):
+ p = Parent2.objects.select_related('child1').only(
+ 'id2', 'child1__value').get(name2="n2")
+ self.assertEqual(p.id2, c.id2)
+ self.assertEqual(p.child1.value, 1)
+ p = Parent2.objects.select_related('child1').only(
+ 'id2', 'child1__value').get(name2="n2")
+ with self.assertNumQueries(1):
+ self.assertEquals(p.name2, 'n2')
+ p = Parent2.objects.select_related('child1').only(
+ 'id2', 'child1__value').get(name2="n2")
+ with self.assertNumQueries(1):
+ self.assertEquals(p.child1.name2, 'n2')
+
+ @unittest.expectedFailure
+ def test_inheritance_deferred2(self):
+ c = Child4.objects.create(name1='n1', name2='n2', value=1, value4=4)
+ qs = Parent2.objects.select_related('child1', 'child4').only(
+ 'id2', 'child1__value', 'child1__child4__value4')
+ with self.assertNumQueries(1):
+ p = qs.get(name2="n2")
+ self.assertEqual(p.id2, c.id2)
+ self.assertEqual(p.child1.value, 1)
+ self.assertEqual(p.child1.child4.value4, 4)
+ self.assertEqual(p.child1.child4.id2, c.id2)
+ p = qs.get(name2="n2")
+ with self.assertNumQueries(1):
+ self.assertEquals(p.child1.name2, 'n2')
+ p = qs.get(name2="n2")
+ with self.assertNumQueries(1):
+ self.assertEquals(p.child1.name1, 'n1')
+ with self.assertNumQueries(1):
+ self.assertEquals(p.child1.child4.name1, 'n1')
Please sign in to comment.
Something went wrong with that request. Please try again.