Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Fixed #7270 -- Added the ability to follow reverse OneToOneFields in …

…select_related(). Thanks to George Vilches, Ben Davis, and Alex Gaynor for their work on various stages of this patch.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@12307 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
commit 58cd220f51d5e294cb9e67c12a6e9d08523e282f 1 parent 8e8d4b5
@freakboy3742 freakboy3742 authored
View
4 django/db/models/fields/related.py
@@ -189,7 +189,7 @@ class SingleRelatedObjectDescriptor(object):
# SingleRelatedObjectDescriptor instance.
def __init__(self, related):
self.related = related
- self.cache_name = '_%s_cache' % related.get_accessor_name()
+ self.cache_name = related.get_cache_name()
def __get__(self, instance, instance_type=None):
if instance is None:
@@ -319,7 +319,7 @@ def __set__(self, instance, value):
# cache. This cache also might not exist if the related object
# hasn't been accessed yet.
if related:
- cache_name = '_%s_cache' % self.field.related.get_accessor_name()
+ cache_name = self.field.related.get_cache_name()
try:
delattr(related, cache_name)
except AttributeError:
View
74 django/db/models/query.py
@@ -1116,6 +1116,29 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
"""
Helper function that recursively returns an object with the specified
related attributes already populated.
+
+ This method may be called recursively to populate deep select_related()
+ clauses.
+
+ Arguments:
+ * klass - the class to retrieve (and instantiate)
+ * row - the row of data returned by the database cursor
+ * index_start - the index of the row at which data for this
+ object is known to start
+ * max_depth - the maximum depth to which a select_related()
+ relationship should be explored.
+ * cur_depth - the current depth in the select_related() tree.
+ Used in recursive calls to determin if we should dig deeper.
+ * requested - A dictionary describing the select_related() tree
+ that is to be retrieved. keys are field names; values are
+ dictionaries describing the keys on that related object that
+ are themselves to be select_related().
+ * offset - the number of additional fields that are known to
+ exist in `row` for `klass`. This usually means the number of
+ annotated results on `klass`.
+ * only_load - if the query has had only() or defer() applied,
+ this is the list of field names that will be returned. If None,
+ the full field list for `klass` can be assumed.
"""
if max_depth and requested is None and cur_depth > max_depth:
# We've recursed deeply enough; stop now.
@@ -1127,14 +1150,18 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
# Handle deferred fields.
skip = set()
init_list = []
- pk_val = row[index_start + klass._meta.pk_index()]
+ # Build the list of fields that *haven't* been requested
for field in klass._meta.fields:
if field.name not in load_fields:
skip.add(field.name)
else:
init_list.append(field.attname)
+ # Retrieve all the requested fields
field_count = len(init_list)
fields = row[index_start : index_start + field_count]
+ # If all the select_related columns are None, then the related
+ # object must be non-existent - set the relation to None.
+ # Otherwise, construct the related object.
if fields == (None,) * field_count:
obj = None
elif skip:
@@ -1143,14 +1170,20 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
else:
obj = klass(*fields)
else:
+ # Load all fields on klass
field_count = len(klass._meta.fields)
fields = row[index_start : index_start + field_count]
+ # If all the select_related columns are None, then the related
+ # object must be non-existent - set the relation to None.
+ # Otherwise, construct the related object.
if fields == (None,) * field_count:
obj = None
else:
obj = klass(*fields)
index_end = index_start + field_count + offset
+ # Iterate over each related object, populating any
+ # select_related() fields
for f in klass._meta.fields:
if not select_related_descend(f, restricted, requested):
continue
@@ -1158,12 +1191,51 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
next = requested[f.name]
else:
next = None
+ # Recursively retrieve the data for the related object
cached_row = get_cached_row(f.rel.to, row, index_end, max_depth,
cur_depth+1, next)
+ # If the recursive descent found an object, populate the
+ # descriptor caches relevant to the object
if cached_row:
rel_obj, index_end = cached_row
if obj is not None:
+ # If the base object exists, populate the
+ # descriptor cache
setattr(obj, f.get_cache_name(), rel_obj)
+ if f.unique:
+ # If the field is unique, populate the
+ # reverse descriptor cache on the related object
+ setattr(rel_obj, f.related.get_cache_name(), obj)
+
+ # Now do the same, but for reverse related objects.
+ # Only handle the restricted case - i.e., don't do a depth
+ # descent into reverse relations unless explicitly requested
+ if restricted:
+ related_fields = [
+ (o.field, o.model)
+ for o in klass._meta.get_all_related_objects()
+ if o.field.unique
+ ]
+ for f, model in related_fields:
+ if not select_related_descend(f, restricted, requested, reverse=True):
+ continue
+ next = requested[f.related_query_name()]
+ # Recursively retrieve the data for the related object
+ cached_row = get_cached_row(model, row, index_end, max_depth,
+ cur_depth+1, next)
+ # If the recursive descent found an object, populate the
+ # descriptor caches relevant to the object
+ if cached_row:
+ rel_obj, index_end = cached_row
+ if obj is not None:
+ # If the field is unique, populate the
+ # reverse descriptor cache
+ setattr(obj, f.related.get_cache_name(), rel_obj)
+ if rel_obj is not None:
+ # If the related object exists, populate
+ # the descriptor cache.
+ setattr(rel_obj, f.get_cache_name(), obj)
+
return obj, index_end
def delete_objects(seen_objs, using):
View
18 django/db/models/query_utils.py
@@ -197,19 +197,29 @@ def __set__(self, instance, value):
"""
instance.__dict__[self.field_name] = value
-def select_related_descend(field, restricted, requested):
+def select_related_descend(field, restricted, requested, reverse=False):
"""
Returns True if this field should be used to descend deeper for
select_related() purposes. Used by both the query construction code
(sql.query.fill_related_selections()) and the model instance creation code
(query.get_cached_row()).
+
+ Arguments:
+ * field - the field to be checked
+ * restricted - a boolean field, indicating if the field list has been
+ manually restricted using a requested clause)
+ * requested - The select_related() dictionary.
+ * reverse - boolean, True if we are checking a reverse select related
"""
if not field.rel:
return False
- if field.rel.parent_link:
- return False
- if restricted and field.name not in requested:
+ if field.rel.parent_link and not reverse:
return False
+ if restricted:
+ if reverse and field.related_query_name() not in requested:
+ return False
+ if not reverse and field.name not in requested:
+ return False
if not restricted and field.null:
return False
return True
View
3  django/db/models/related.py
@@ -45,3 +45,6 @@ def get_accessor_name(self):
return self.field.rel.related_name or (self.opts.object_name.lower() + '_set')
else:
return self.field.rel.related_name or (self.opts.object_name.lower())
+
+ def get_cache_name(self):
+ return "_%s_cache" % self.get_accessor_name()
View
68 django/db/models/sql/compiler.py
@@ -520,7 +520,7 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
# Setup for the case when only particular related fields should be
# included in the related selection.
- if requested is None and restricted is not False:
+ if requested is None:
if isinstance(self.query.select_related, dict):
requested = self.query.select_related
restricted = True
@@ -600,6 +600,72 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
used, next, restricted, new_nullable, dupe_set, avoid)
+ if restricted:
+ related_fields = [
+ (o.field, o.model)
+ for o in opts.get_all_related_objects()
+ if o.field.unique
+ ]
+ for f, model in related_fields:
+ if not select_related_descend(f, restricted, requested, reverse=True):
+ continue
+ # The "avoid" set is aliases we want to avoid just for this
+ # particular branch of the recursion. They aren't permanently
+ # forbidden from reuse in the related selection tables (which is
+ # what "used" specifies).
+ avoid = avoid_set.copy()
+ dupe_set = orig_dupe_set.copy()
+ table = model._meta.db_table
+
+ int_opts = opts
+ alias = root_alias
+ alias_chain = []
+ chain = opts.get_base_chain(f.rel.to)
+ if chain is not None:
+ for int_model in chain:
+ # Proxy model have elements in base chain
+ # with no parents, assign the new options
+ # object and skip to the next base in that
+ # case
+ if not int_opts.parents[int_model]:
+ int_opts = int_model._meta
+ continue
+ lhs_col = int_opts.parents[int_model].column
+ dedupe = lhs_col in opts.duplicate_targets
+ if dedupe:
+ avoid.update(self.query.dupe_avoidance.get(id(opts), lhs_col),
+ ())
+ dupe_set.add((opts, lhs_col))
+ int_opts = int_model._meta
+ alias = self.query.join(
+ (alias, int_opts.db_table, lhs_col, int_opts.pk.column),
+ exclusions=used, promote=True, reuse=used
+ )
+ alias_chain.append(alias)
+ for dupe_opts, dupe_col in dupe_set:
+ self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias)
+ dedupe = f.column in opts.duplicate_targets
+ if dupe_set or dedupe:
+ avoid.update(self.query.dupe_avoidance.get((id(opts), f.column), ()))
+ if dedupe:
+ dupe_set.add((opts, f.column))
+ alias = self.query.join(
+ (alias, table, f.rel.get_related_field().column, f.column),
+ exclusions=used.union(avoid),
+ promote=True
+ )
+ used.add(alias)
+ columns, aliases = self.get_default_columns(start_alias=alias,
+ opts=model._meta, as_pairs=True)
+ self.query.related_select_cols.extend(columns)
+ self.query.related_select_fields.extend(model._meta.fields)
+
+ next = requested.get(f.related_query_name(), {})
+ new_nullable = f.null or None
+
+ self.fill_related_selections(model._meta, table, cur_depth+1,
+ used, next, restricted, new_nullable)
+
def deferred_to_columns(self):
"""
Converts the self.deferred_loading data structure to mapping of table
View
24 docs/ref/models/querysets.txt
@@ -619,17 +619,29 @@ This is also valid::
...and would also pull in the ``building`` relation.
-You can only refer to ``ForeignKey`` relations in the list of fields passed to
-``select_related``. You *can* refer to foreign keys that have ``null=True``
-(unlike the default ``select_related()`` call). It's an error to use both a
-list of fields and the ``depth`` parameter in the same ``select_related()``
-call, since they are conflicting options.
+You can refer to any ``ForeignKey`` or ``OneToOneField`` relation in
+the list of fields passed to ``select_related``. Ths includes foreign
+keys that have ``null=True`` (unlike the default ``select_related()``
+call). It's an error to use both a list of fields and the ``depth``
+parameter in the same ``select_related()`` call, since they are
+conflicting options.
.. versionadded:: 1.0
Both the ``depth`` argument and the ability to specify field names in the call
to ``select_related()`` are new in Django version 1.0.
+.. versionchanged:: 1.2
+
+You can also refer to the reverse direction of a ``OneToOneFields`` in
+the list of fields passed to ``select_related`` -- that is, you can traverse
+a ``OneToOneField`` back to the object on which the field is defined. Instead
+of specifying the field name, use the ``related_name`` for the field on the
+related object.
+
+``OneToOneFields`` will not be traversed in the reverse direction if you
+are performing a depth-based ``select_related``.
+
.. _queryset-extra:
``extra(select=None, where=None, params=None, tables=None, order_by=None, select_params=None)``
@@ -1335,7 +1347,7 @@ extract two field values, where only one is expected::
entries = Entry.objects.filter(blog__in=list(values))
Note the ``list()`` call around the Blog ``QuerySet`` to force execution of
- the first query. Without it, a nested query would be executed, because
+ the first query. Without it, a nested query would be executed, because
:ref:`querysets-are-lazy`.
gt
View
0  tests/regressiontests/select_related_onetoone/__init__.py
No changes.
View
46 tests/regressiontests/select_related_onetoone/models.py
@@ -0,0 +1,46 @@
+from django.db import models
+
+
+class User(models.Model):
+ username = models.CharField(max_length=100)
+ email = models.EmailField()
+
+ def __unicode__(self):
+ return self.username
+
+
+class UserProfile(models.Model):
+ user = models.OneToOneField(User)
+ city = models.CharField(max_length=100)
+ state = models.CharField(max_length=2)
+
+ def __unicode__(self):
+ return "%s, %s" % (self.city, self.state)
+
+
+class UserStatResult(models.Model):
+ results = models.CharField(max_length=50)
+
+ def __unicode__(self):
+ return 'UserStatResults, results = %s' % (self.results,)
+
+
+class UserStat(models.Model):
+ user = models.OneToOneField(User, primary_key=True)
+ posts = models.IntegerField()
+ results = models.ForeignKey(UserStatResult)
+
+ def __unicode__(self):
+ return 'UserStat, posts = %s' % (self.posts,)
+
+
+class StatDetails(models.Model):
+ base_stats = models.OneToOneField(UserStat)
+ comments = models.IntegerField()
+
+ def __unicode__(self):
+ return 'StatDetails, comments = %s' % (self.comments,)
+
+
+class AdvancedUserStat(UserStat):
+ pass
View
83 tests/regressiontests/select_related_onetoone/tests.py
@@ -0,0 +1,83 @@
+from django import db
+from django.conf import settings
+from django.test import TestCase
+
+from models import User, UserProfile, UserStat, UserStatResult, StatDetails, AdvancedUserStat
+
+class ReverseSelectRelatedTestCase(TestCase):
+ def setUp(self):
+ # Explicitly enable debug for these tests - we need to count
+ # the queries that have been issued.
+ self.old_debug = settings.DEBUG
+ settings.DEBUG = True
+
+ user = User.objects.create(username="test")
+ userprofile = UserProfile.objects.create(user=user, state="KS",
+ city="Lawrence")
+ results = UserStatResult.objects.create(results='first results')
+ userstat = UserStat.objects.create(user=user, posts=150,
+ results=results)
+ details = StatDetails.objects.create(base_stats=userstat, comments=259)
+
+ user2 = User.objects.create(username="bob")
+ results2 = UserStatResult.objects.create(results='moar results')
+ advstat = AdvancedUserStat.objects.create(user=user2, posts=200,
+ results=results2)
+ StatDetails.objects.create(base_stats=advstat, comments=250)
+
+ db.reset_queries()
+
+ def assertQueries(self, queries):
+ self.assertEqual(len(db.connection.queries), queries)
+
+ def tearDown(self):
+ settings.DEBUG = self.old_debug
+
+ def test_basic(self):
+ u = User.objects.select_related("userprofile").get(username="test")
+ self.assertEqual(u.userprofile.state, "KS")
+ self.assertQueries(1)
+
+ def test_follow_next_level(self):
+ u = User.objects.select_related("userstat__results").get(username="test")
+ self.assertEqual(u.userstat.posts, 150)
+ self.assertEqual(u.userstat.results.results, 'first results')
+ self.assertQueries(1)
+
+ def test_follow_two(self):
+ u = User.objects.select_related("userprofile", "userstat").get(username="test")
+ self.assertEqual(u.userprofile.state, "KS")
+ self.assertEqual(u.userstat.posts, 150)
+ self.assertQueries(1)
+
+ def test_follow_two_next_level(self):
+ u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test")
+ self.assertEqual(u.userstat.results.results, 'first results')
+ self.assertEqual(u.userstat.statdetails.comments, 259)
+ self.assertQueries(1)
+
+ def test_forward_and_back(self):
+ stat = UserStat.objects.select_related("user__userprofile").get(user__username="test")
+ self.assertEqual(stat.user.userprofile.state, 'KS')
+ self.assertEqual(stat.user.userstat.posts, 150)
+ self.assertQueries(1)
+
+ def test_back_and_forward(self):
+ u = User.objects.select_related("userstat").get(username="test")
+ self.assertEqual(u.userstat.user.username, 'test')
+ self.assertQueries(1)
+
+ def test_not_followed_by_default(self):
+ u = User.objects.select_related().get(username="test")
+ self.assertEqual(u.userstat.posts, 150)
+ self.assertQueries(2)
+
+ def test_follow_from_child_class(self):
+ stat = AdvancedUserStat.objects.select_related("statdetails").get(posts=200)
+ self.assertEqual(stat.statdetails.comments, 250)
+ self.assertQueries(1)
+
+ def test_follow_inheritance(self):
+ stat = UserStat.objects.select_related('advanceduserstat').get(posts=200)
+ self.assertEqual(stat.advanceduserstat.posts, 200)
+ self.assertQueries(1)
Please sign in to comment.
Something went wrong with that request. Please try again.