Permalink
Browse files

Fixed #17003 - prefetch_related should support foreign keys/one-to-one

Support for `GenericForeignKey` is also included.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@16939 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
1 parent 672f2db commit 052a011ee6122482a471795c1994bbcfdb069611 @spookylukey spookylukey committed Oct 7, 2011
@@ -2,7 +2,10 @@
Classes allowing "generic" relations through ContentType and object-id fields.
"""
+from collections import defaultdict
from functools import partial
+from operator import attrgetter
+
from django.core.exceptions import ObjectDoesNotExist
from django.db import connection
from django.db.models import signals
@@ -59,6 +62,49 @@ def get_content_type(self, obj=None, id=None, using=None):
# This should never happen. I love comments like this, don't you?
raise Exception("Impossible arguments to GFK.get_content_type!")
+ def get_prefetch_query_set(self, instances):
+ # For efficiency, group the instances by content type and then do one
+ # query per model
+ fk_dict = defaultdict(list)
+ # We need one instance for each group in order to get the right db:
+ instance_dict = {}
+ ct_attname = self.model._meta.get_field(self.ct_field).get_attname()
+ for instance in instances:
+ # We avoid looking for values if either ct_id or fkey value is None
+ ct_id = getattr(instance, ct_attname)
+ if ct_id is not None:
+ fk_val = getattr(instance, self.fk_field)
+ if fk_val is not None:
+ fk_dict[ct_id].append(fk_val)
+ instance_dict[ct_id] = instance
+
+ ret_val = []
+ for ct_id, fkeys in fk_dict.items():
+ instance = instance_dict[ct_id]
+ ct = self.get_content_type(id=ct_id, using=instance._state.db)
+ ret_val.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))
+
+ # For doing the join in Python, we have to match both the FK val and the
+ # content type, so the 'attr' vals we return need to be callables that
+ # will return a (fk, class) pair.
+ def gfk_key(obj):
+ ct_id = getattr(obj, ct_attname)
+ if ct_id is None:
+ return None
+ else:
+ return (getattr(obj, self.fk_field),
+ self.get_content_type(id=ct_id,
+ using=obj._state.db).model_class())
+
+ return (ret_val,
+ lambda obj: (obj._get_pk_val(), obj.__class__),
+ gfk_key,
+ True,
+ self.cache_attr)
+
+ def is_cached(self, instance):
+ return hasattr(instance, self.cache_attr)
+
def __get__(self, instance, instance_type=None):
if instance is None:
return self
@@ -282,7 +328,11 @@ def get_prefetch_query_set(self, instances):
[obj._get_pk_val() for obj in instances]
}
qs = super(GenericRelatedObjectManager, self).get_query_set().using(db).filter(**query)
- return (qs, self.object_id_field_name, 'pk')
+ return (qs,
+ attrgetter(self.object_id_field_name),
+ lambda obj: obj._get_pk_val(),
+ False,
+ self.prefetch_cache_name)
def add(self, *objs):
for obj in objs:
@@ -113,5 +113,11 @@ def get_object_for_this_type(self, **kwargs):
"""
return self.model_class()._base_manager.using(self._state.db).get(**kwargs)
+ def get_all_objects_for_this_type(self, **kwargs):
+ """
+ Returns all objects of this type for the keyword arguments given.
+ """
+ return self.model_class()._base_manager.using(self._state.db).filter(**kwargs)
+
def natural_key(self):
return (self.app_label, self.model)
@@ -1,3 +1,5 @@
+from operator import attrgetter
+
from django.db import connection, router
from django.db.backends import util
from django.db.models import signals, get_model
@@ -227,15 +229,30 @@ def __init__(self, related):
self.related = related
self.cache_name = related.get_cache_name()
+ def is_cached(self, instance):
+ return hasattr(instance, self.cache_name)
+
+ def get_query_set(self, **db_hints):
+ db = router.db_for_read(self.related.model, **db_hints)
+ return self.related.model._base_manager.using(db)
+
+ def get_prefetch_query_set(self, instances):
+ vals = [instance._get_pk_val() for instance in instances]
+ params = {'%s__pk__in' % self.related.field.name: vals}
+ return (self.get_query_set(),
+ attrgetter(self.related.field.attname),
+ lambda obj: obj._get_pk_val(),
+ True,
+ self.cache_name)
+
def __get__(self, instance, instance_type=None):
if instance is None:
return self
try:
return getattr(instance, self.cache_name)
except AttributeError:
params = {'%s__pk' % self.related.field.name: instance._get_pk_val()}
- db = router.db_for_read(self.related.model, instance=instance)
- rel_obj = self.related.model._base_manager.using(db).get(**params)
+ rel_obj = self.get_query_set(instance=instance).get(**params)
setattr(instance, self.cache_name, rel_obj)
return rel_obj
@@ -283,14 +300,40 @@ class ReverseSingleRelatedObjectDescriptor(object):
# ReverseSingleRelatedObjectDescriptor instance.
def __init__(self, field_with_rel):
self.field = field_with_rel
+ self.cache_name = self.field.get_cache_name()
+
+ def is_cached(self, instance):
+ return hasattr(instance, self.cache_name)
+
+ def get_query_set(self, **db_hints):
+ db = router.db_for_read(self.field.rel.to, **db_hints)
+ rel_mgr = self.field.rel.to._default_manager
+ # If the related manager indicates that it should be used for
+ # related fields, respect that.
+ if getattr(rel_mgr, 'use_for_related_fields', False):
+ return rel_mgr.using(db)
+ else:
+ return QuerySet(self.field.rel.to).using(db)
+
+ def get_prefetch_query_set(self, instances):
+ vals = [getattr(instance, self.field.attname) for instance in instances]
+ other_field = self.field.rel.get_related_field()
+ if other_field.rel:
+ params = {'%s__pk__in' % self.field.rel.field_name: vals}
+ else:
+ params = {'%s__in' % self.field.rel.field_name: vals}
+ return (self.get_query_set().filter(**params),
+ attrgetter(self.field.rel.field_name),
+ attrgetter(self.field.attname),
+ True,
+ self.cache_name)
def __get__(self, instance, instance_type=None):
if instance is None:
return self
- cache_name = self.field.get_cache_name()
try:
- return getattr(instance, cache_name)
+ return getattr(instance, self.cache_name)
except AttributeError:
val = getattr(instance, self.field.attname)
if val is None:
@@ -303,16 +346,9 @@ def __get__(self, instance, instance_type=None):
params = {'%s__pk' % self.field.rel.field_name: val}
else:
params = {'%s__exact' % self.field.rel.field_name: val}
-
- # If the related manager indicates that it should be used for
- # related fields, respect that.
- rel_mgr = self.field.rel.to._default_manager
- db = router.db_for_read(self.field.rel.to, instance=instance)
- if getattr(rel_mgr, 'use_for_related_fields', False):
- rel_obj = rel_mgr.using(db).get(**params)
- else:
- rel_obj = QuerySet(self.field.rel.to).using(db).get(**params)
- setattr(instance, cache_name, rel_obj)
+ qs = self.get_query_set(instance=instance)
+ rel_obj = qs.get(**params)
+ setattr(instance, self.cache_name, rel_obj)
return rel_obj
def __set__(self, instance, value):
@@ -425,15 +461,15 @@ def get_query_set(self):
return super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
def get_prefetch_query_set(self, instances):
- """
- Return a queryset that does the bulk lookup needed
- by prefetch_related functionality.
- """
db = self._db or router.db_for_read(self.model)
query = {'%s__%s__in' % (rel_field.name, attname):
[getattr(obj, attname) for obj in instances]}
qs = super(RelatedManager, self).get_query_set().using(db).filter(**query)
- return (qs, rel_field.get_attname(), attname)
+ return (qs,
+ attrgetter(rel_field.get_attname()),
+ attrgetter(attname),
+ False,
+ rel_field.related_query_name())
def add(self, *objs):
for obj in objs:
@@ -507,12 +543,6 @@ def get_query_set(self):
return super(ManyRelatedManager, self).get_query_set().using(db)._next_is_sticky().filter(**self.core_filters)
def get_prefetch_query_set(self, instances):
- """
- Returns a tuple:
- (queryset of instances of self.model that are related to passed in instances
- attr of returned instances needed for matching
- attr of passed in instances needed for matching)
- """
from django.db import connections
db = self._db or router.db_for_read(self.model)
query = {'%s__pk__in' % self.query_field_name:
@@ -534,7 +564,11 @@ def get_prefetch_query_set(self, instances):
qs = qs.extra(select={'_prefetch_related_val':
'%s.%s' % (qn(join_table), qn(source_col))})
select_attname = fk.rel.get_related_field().get_attname()
- return (qs, '_prefetch_related_val', select_attname)
+ return (qs,
+ attrgetter('_prefetch_related_val'),
+ attrgetter(select_attname),
+ False,
+ self.prefetch_cache_name)
# If the ManyToMany relation has an intermediary model,
# the add and remove methods do not exist.
Oops, something went wrong.

0 comments on commit 052a011

Please sign in to comment.