Skip to content

Commit

Permalink
Fixed #17003 - prefetch_related should support foreign keys/one-to-one
Browse files Browse the repository at this point in the history
Support for `GenericForeignKey` is also included.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@16939 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information
spookylukey committed Oct 7, 2011
1 parent 672f2db commit 052a011
Show file tree
Hide file tree
Showing 8 changed files with 367 additions and 120 deletions.
52 changes: 51 additions & 1 deletion django/contrib/contenttypes/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
Classes allowing "generic" relations through ContentType and object-id fields.
"""

from collections import defaultdict
from functools import partial
from operator import attrgetter

from django.core.exceptions import ObjectDoesNotExist
from django.db import connection
from django.db.models import signals
Expand Down Expand Up @@ -59,6 +62,49 @@ def get_content_type(self, obj=None, id=None, using=None):
# This should never happen. I love comments like this, don't you?
raise Exception("Impossible arguments to GFK.get_content_type!")

def get_prefetch_query_set(self, instances):
# For efficiency, group the instances by content type and then do one
# query per model
fk_dict = defaultdict(list)
# We need one instance for each group in order to get the right db:
instance_dict = {}
ct_attname = self.model._meta.get_field(self.ct_field).get_attname()
for instance in instances:
# We avoid looking for values if either ct_id or fkey value is None
ct_id = getattr(instance, ct_attname)
if ct_id is not None:
fk_val = getattr(instance, self.fk_field)
if fk_val is not None:
fk_dict[ct_id].append(fk_val)
instance_dict[ct_id] = instance

ret_val = []
for ct_id, fkeys in fk_dict.items():
instance = instance_dict[ct_id]
ct = self.get_content_type(id=ct_id, using=instance._state.db)
ret_val.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))

# For doing the join in Python, we have to match both the FK val and the
# content type, so the 'attr' vals we return need to be callables that
# will return a (fk, class) pair.
def gfk_key(obj):
ct_id = getattr(obj, ct_attname)
if ct_id is None:
return None
else:
return (getattr(obj, self.fk_field),
self.get_content_type(id=ct_id,
using=obj._state.db).model_class())

return (ret_val,
lambda obj: (obj._get_pk_val(), obj.__class__),
gfk_key,
True,
self.cache_attr)

def is_cached(self, instance):
return hasattr(instance, self.cache_attr)

def __get__(self, instance, instance_type=None):
if instance is None:
return self
Expand Down Expand Up @@ -282,7 +328,11 @@ def get_prefetch_query_set(self, instances):
[obj._get_pk_val() for obj in instances]
}
qs = super(GenericRelatedObjectManager, self).get_query_set().using(db).filter(**query)
return (qs, self.object_id_field_name, 'pk')
return (qs,
attrgetter(self.object_id_field_name),
lambda obj: obj._get_pk_val(),
False,
self.prefetch_cache_name)

def add(self, *objs):
for obj in objs:
Expand Down
6 changes: 6 additions & 0 deletions django/contrib/contenttypes/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,11 @@ def get_object_for_this_type(self, **kwargs):
"""
return self.model_class()._base_manager.using(self._state.db).get(**kwargs)

def get_all_objects_for_this_type(self, **kwargs):
"""
Returns all objects of this type for the keyword arguments given.
"""
return self.model_class()._base_manager.using(self._state.db).filter(**kwargs)

def natural_key(self):
return (self.app_label, self.model)
86 changes: 60 additions & 26 deletions django/db/models/fields/related.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from operator import attrgetter

from django.db import connection, router
from django.db.backends import util
from django.db.models import signals, get_model
Expand Down Expand Up @@ -227,15 +229,30 @@ def __init__(self, related):
self.related = related
self.cache_name = related.get_cache_name()

def is_cached(self, instance):
return hasattr(instance, self.cache_name)

def get_query_set(self, **db_hints):
db = router.db_for_read(self.related.model, **db_hints)
return self.related.model._base_manager.using(db)

def get_prefetch_query_set(self, instances):
vals = [instance._get_pk_val() for instance in instances]
params = {'%s__pk__in' % self.related.field.name: vals}
return (self.get_query_set(),
attrgetter(self.related.field.attname),
lambda obj: obj._get_pk_val(),
True,
self.cache_name)

def __get__(self, instance, instance_type=None):
if instance is None:
return self
try:
return getattr(instance, self.cache_name)
except AttributeError:
params = {'%s__pk' % self.related.field.name: instance._get_pk_val()}
db = router.db_for_read(self.related.model, instance=instance)
rel_obj = self.related.model._base_manager.using(db).get(**params)
rel_obj = self.get_query_set(instance=instance).get(**params)
setattr(instance, self.cache_name, rel_obj)
return rel_obj

Expand Down Expand Up @@ -283,14 +300,40 @@ class ReverseSingleRelatedObjectDescriptor(object):
# ReverseSingleRelatedObjectDescriptor instance.
def __init__(self, field_with_rel):
self.field = field_with_rel
self.cache_name = self.field.get_cache_name()

def is_cached(self, instance):
return hasattr(instance, self.cache_name)

def get_query_set(self, **db_hints):
db = router.db_for_read(self.field.rel.to, **db_hints)
rel_mgr = self.field.rel.to._default_manager
# If the related manager indicates that it should be used for
# related fields, respect that.
if getattr(rel_mgr, 'use_for_related_fields', False):
return rel_mgr.using(db)
else:
return QuerySet(self.field.rel.to).using(db)

def get_prefetch_query_set(self, instances):
vals = [getattr(instance, self.field.attname) for instance in instances]
other_field = self.field.rel.get_related_field()
if other_field.rel:
params = {'%s__pk__in' % self.field.rel.field_name: vals}
else:
params = {'%s__in' % self.field.rel.field_name: vals}
return (self.get_query_set().filter(**params),
attrgetter(self.field.rel.field_name),
attrgetter(self.field.attname),
True,
self.cache_name)

def __get__(self, instance, instance_type=None):
if instance is None:
return self

cache_name = self.field.get_cache_name()
try:
return getattr(instance, cache_name)
return getattr(instance, self.cache_name)
except AttributeError:
val = getattr(instance, self.field.attname)
if val is None:
Expand All @@ -303,16 +346,9 @@ def __get__(self, instance, instance_type=None):
params = {'%s__pk' % self.field.rel.field_name: val}
else:
params = {'%s__exact' % self.field.rel.field_name: val}

# If the related manager indicates that it should be used for
# related fields, respect that.
rel_mgr = self.field.rel.to._default_manager
db = router.db_for_read(self.field.rel.to, instance=instance)
if getattr(rel_mgr, 'use_for_related_fields', False):
rel_obj = rel_mgr.using(db).get(**params)
else:
rel_obj = QuerySet(self.field.rel.to).using(db).get(**params)
setattr(instance, cache_name, rel_obj)
qs = self.get_query_set(instance=instance)
rel_obj = qs.get(**params)
setattr(instance, self.cache_name, rel_obj)
return rel_obj

def __set__(self, instance, value):
Expand Down Expand Up @@ -425,15 +461,15 @@ def get_query_set(self):
return super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)

def get_prefetch_query_set(self, instances):
"""
Return a queryset that does the bulk lookup needed
by prefetch_related functionality.
"""
db = self._db or router.db_for_read(self.model)
query = {'%s__%s__in' % (rel_field.name, attname):
[getattr(obj, attname) for obj in instances]}
qs = super(RelatedManager, self).get_query_set().using(db).filter(**query)
return (qs, rel_field.get_attname(), attname)
return (qs,
attrgetter(rel_field.get_attname()),
attrgetter(attname),
False,
rel_field.related_query_name())

def add(self, *objs):
for obj in objs:
Expand Down Expand Up @@ -507,12 +543,6 @@ def get_query_set(self):
return super(ManyRelatedManager, self).get_query_set().using(db)._next_is_sticky().filter(**self.core_filters)

def get_prefetch_query_set(self, instances):
"""
Returns a tuple:
(queryset of instances of self.model that are related to passed in instances
attr of returned instances needed for matching
attr of passed in instances needed for matching)
"""
from django.db import connections
db = self._db or router.db_for_read(self.model)
query = {'%s__pk__in' % self.query_field_name:
Expand All @@ -534,7 +564,11 @@ def get_prefetch_query_set(self, instances):
qs = qs.extra(select={'_prefetch_related_val':
'%s.%s' % (qn(join_table), qn(source_col))})
select_attname = fk.rel.get_related_field().get_attname()
return (qs, '_prefetch_related_val', select_attname)
return (qs,
attrgetter('_prefetch_related_val'),
attrgetter(select_attname),
False,
self.prefetch_cache_name)

# If the ManyToMany relation has an intermediary model,
# the add and remove methods do not exist.
Expand Down
Loading

0 comments on commit 052a011

Please sign in to comment.