Permalink
Browse files

Fixed #17001 -- Custom querysets for prefetch_related.

This patch introduces the Prefetch object which allows customizing prefetch
operations.

This enables things like filtering prefetched relations, calling select_related
from a prefetched relation, or prefetching the same relation multiple times
with different querysets.

When a Prefetch instance specifies a to_attr argument, the result is stored
in a list rather than a QuerySet. This has the fortunate consequence of being
significantly faster. The performance improvement is due to the fact that we
save the costly creation of a QuerySet instance.

Thanks @akaariai for the original patch and @bmispelon and @timgraham
for the reviews.
  • Loading branch information...
loic authored and akaariai committed Nov 6, 2013
1 parent b1b04df commit f51c1f590085556abca44fd2a49618162203b2ec
@@ -76,7 +76,10 @@ def get_content_type(self, obj=None, id=None, using=None):
# This should never happen. I love comments like this, don't you?
raise Exception("Impossible arguments to GFK.get_content_type!")
- def get_prefetch_queryset(self, instances):
+ def get_prefetch_queryset(self, instances, queryset=None):
+ if queryset is not None:
+ raise ValueError("Custom queryset can't be used for this lookup.")
+
# For efficiency, group the instances by content type and then do one
# query per model
fk_dict = defaultdict(set)
@@ -348,17 +351,22 @@ def get_queryset(self):
db = self._db or router.db_for_read(self.model, instance=self.instance)
return super(GenericRelatedObjectManager, self).get_queryset().using(db).filter(**self.core_filters)
- def get_prefetch_queryset(self, instances):
- db = self._db or router.db_for_read(self.model, instance=instances[0])
+ def get_prefetch_queryset(self, instances, queryset=None):
+ if queryset is None:
+ queryset = super(GenericRelatedObjectManager, self).get_queryset()
+
+ queryset._add_hints(instance=instances[0])
+ queryset = queryset.using(queryset._db or self._db)
+
query = {
'%s__pk' % self.content_type_field_name: self.content_type.id,
'%s__in' % self.object_id_field_name: set(obj._get_pk_val() for obj in instances)
}
- qs = super(GenericRelatedObjectManager, self).get_queryset().using(db).filter(**query)
+
# We (possibly) need to convert object IDs to the type of the
# instances' PK in order to match up instances:
object_id_converter = instances[0]._meta.pk.to_python
- return (qs,
+ return (queryset.filter(**query),
lambda relobj: object_id_converter(getattr(relobj, self.object_id_field_name)),
lambda obj: obj._get_pk_val(),
False,
@@ -4,7 +4,7 @@
from django.db.models.loading import ( # NOQA
get_apps, get_app_path, get_app_paths, get_app, get_models, get_model,
register_models, UnavailableApp)
-from django.db.models.query import Q, QuerySet # NOQA
+from django.db.models.query import Q, QuerySet, Prefetch # NOQA
from django.db.models.expressions import F # NOQA
from django.db.models.manager import Manager # NOQA
from django.db.models.base import Model # NOQA
@@ -162,7 +162,10 @@ def is_cached(self, instance):
def get_queryset(self, **hints):
return self.related.model._base_manager.db_manager(hints=hints)
- def get_prefetch_queryset(self, instances):
+ def get_prefetch_queryset(self, instances, queryset=None):
+ if queryset is not None:
+ raise ValueError("Custom queryset can't be used for this lookup.")
+
rel_obj_attr = attrgetter(self.related.field.attname)
instance_attr = lambda obj: obj._get_pk_val()
instances_dict = dict((instance_attr(inst), inst) for inst in instances)
@@ -264,7 +267,10 @@ def get_queryset(self, **hints):
else:
return QuerySet(self.field.rel.to, hints=hints)
- def get_prefetch_queryset(self, instances):
+ def get_prefetch_queryset(self, instances, queryset=None):
+ if queryset is not None:
+ raise ValueError("Custom queryset can't be used for this lookup.")
+
rel_obj_attr = self.field.get_foreign_related_value
instance_attr = self.field.get_local_related_value
instances_dict = dict((instance_attr(inst), inst) for inst in instances)
@@ -397,23 +403,26 @@ def get_queryset(self):
qs._known_related_objects = {rel_field: {self.instance.pk: self.instance}}
return qs
def get_prefetch_queryset(self, instances, queryset=None):
    """
    Return everything needed to prefetch this reverse FK relation for
    `instances`: a 5-tuple of (queryset, rel_obj_attr, instance_attr,
    single, cache_name), as expected by prefetch_one_level().
    """
    if queryset is None:
        queryset = super(RelatedManager, self).get_queryset()

    # Route the query against the instances' database and record them
    # as router hints; a custom queryset's own .using() wins if set.
    queryset._add_hints(instance=instances[0])
    queryset = queryset.using(queryset._db or self._db)

    rel_obj_attr = rel_field.get_local_related_value
    instance_attr = rel_field.get_foreign_related_value
    instances_dict = dict((instance_attr(inst), inst) for inst in instances)
    queryset = queryset.filter(**{'%s__in' % rel_field.name: instances})

    # Since we just bypassed this class' get_queryset(), we must manage
    # the reverse relation manually.
    for rel_obj in queryset:
        instance = instances_dict[rel_obj_attr(rel_obj)]
        setattr(rel_obj, rel_field.name, instance)
    cache_name = rel_field.related_query_name()
    return queryset, rel_obj_attr, instance_attr, False, cache_name
def add(self, *objs):
objs = list(objs)
@@ -563,15 +572,15 @@ def get_queryset(self):
qs = qs.using(self._db)
return qs._next_is_sticky().filter(**self.core_filters)
- def get_prefetch_queryset(self, instances):
- instance = instances[0]
- db = self._db or router.db_for_read(instance.__class__, instance=instance)
+ def get_prefetch_queryset(self, instances, queryset=None):
+ if queryset is None:
+ queryset = super(ManyRelatedManager, self).get_queryset()
+
+ queryset._add_hints(instance=instances[0])
+ queryset = queryset.using(queryset._db or self._db)
+
query = {'%s__in' % self.query_field_name: instances}
- qs = super(ManyRelatedManager, self).get_queryset()
- qs._add_hints(instance=instance)
- if self._db:
- qs = qs.using(db)
- qs = qs._next_is_sticky().filter(**query)
+ queryset = queryset._next_is_sticky().filter(**query)
# M2M: need to annotate the query in order to get the primary model
# that the secondary model was actually related to. We know that
@@ -582,12 +591,12 @@ def get_prefetch_queryset(self, instances):
# dealing with PK values.
fk = self.through._meta.get_field(self.source_field_name)
join_table = self.through._meta.db_table
- connection = connections[db]
+ connection = connections[queryset.db]
qn = connection.ops.quote_name
- qs = qs.extra(select=dict(
+ queryset = queryset.extra(select=dict(
('_prefetch_related_val_%s' % f.attname,
'%s.%s' % (qn(join_table), qn(f.column))) for f in fk.local_related_fields))
- return (qs,
+ return (queryset,
lambda result: tuple(getattr(result, '_prefetch_related_val_%s' % f.attname) for f in fk.local_related_fields),
lambda inst: tuple(getattr(inst, f.attname) for f in fk.foreign_related_fields),
False,
View
@@ -1619,46 +1619,103 @@ def model_fields(self):
return self._model_fields
class Prefetch(object):
    """
    Describe a prefetch_related() operation, optionally with a custom
    queryset and/or a custom attribute (``to_attr``) on which to store
    the prefetched objects.
    """
    def __init__(self, lookup, queryset=None, to_attr=None):
        # `prefetch_through` is the path we traverse to perform the prefetch.
        self.prefetch_through = lookup
        # `prefetch_to` is the path to the attribute that stores the result.
        self.prefetch_to = lookup
        if to_attr:
            # Store under `to_attr` instead of the lookup's last component.
            self.prefetch_to = LOOKUP_SEP.join(lookup.split(LOOKUP_SEP)[:-1] + [to_attr])

        self.queryset = queryset
        self.to_attr = to_attr

    def add_prefix(self, prefix):
        # Re-root both paths under `prefix` (used when a nested lookup is
        # discovered while traversing an outer one).
        self.prefetch_through = LOOKUP_SEP.join([prefix, self.prefetch_through])
        self.prefetch_to = LOOKUP_SEP.join([prefix, self.prefetch_to])

    def get_current_prefetch_through(self, level):
        # Traversal path truncated to the first `level` + 1 components.
        return LOOKUP_SEP.join(self.prefetch_through.split(LOOKUP_SEP)[:level + 1])

    def get_current_prefetch_to(self, level):
        # Storage path truncated to the first `level` + 1 components.
        return LOOKUP_SEP.join(self.prefetch_to.split(LOOKUP_SEP)[:level + 1])

    def get_current_to_attr(self, level):
        # Return (attribute name at `level`, whether results go in a list).
        # A plain list is only used on the final level of a to_attr lookup.
        parts = self.prefetch_to.split(LOOKUP_SEP)
        to_attr = parts[level]
        to_list = self.to_attr and level == len(parts) - 1
        return to_attr, to_list

    def get_current_queryset(self, level):
        # The custom queryset only applies at the last level of the lookup.
        if self.get_current_prefetch_to(level) == self.prefetch_to:
            return self.queryset
        return None

    def __eq__(self, other):
        if isinstance(other, Prefetch):
            return self.prefetch_to == other.prefetch_to
        return False

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__; keep them consistent.
        return not self.__eq__(other)

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable on Python 3;
        # hash on the same key __eq__ compares.
        return hash(self.__class__) ^ hash(self.prefetch_to)
+
+
def normalize_prefetch_lookups(lookups, prefix=None):
    """
    Normalize each lookup into a Prefetch object, optionally re-rooting
    it under `prefix`.
    """
    normalized = []
    for item in lookups:
        prefetch = item if isinstance(item, Prefetch) else Prefetch(item)
        if prefix:
            prefetch.add_prefix(prefix)
        normalized.append(prefetch)
    return normalized
+
+
def prefetch_related_objects(result_cache, related_lookups):
"""
Helper function for prefetch_related functionality
Populates prefetched objects caches for a list of results
from a QuerySet
"""
+
if len(result_cache) == 0:
return # nothing to do
+ related_lookups = normalize_prefetch_lookups(related_lookups)
+
# We need to be able to dynamically add to the list of prefetch_related
# lookups that we look up (see below). So we need some book keeping to
# ensure we don't do duplicate work.
- done_lookups = set() # list of lookups like foo__bar__baz
done_queries = {} # dictionary of things like 'foo__bar': [results]
auto_lookups = [] # we add to this as we go through.
followed_descriptors = set() # recursion protection
all_lookups = itertools.chain(related_lookups, auto_lookups)
for lookup in all_lookups:
- if lookup in done_lookups:
- # We've done exactly this already, skip the whole thing
+ if lookup.prefetch_to in done_queries:
+ if lookup.queryset:
+ raise ValueError("'%s' lookup was already seen with a different queryset. "
+ "You may need to adjust the ordering of your lookups." % lookup.prefetch_to)
+
continue
- done_lookups.add(lookup)
# Top level, the list of objects to decorate is the result cache
# from the primary QuerySet. It won't be for deeper levels.
obj_list = result_cache
- attrs = lookup.split(LOOKUP_SEP)
- for level, attr in enumerate(attrs):
+ through_attrs = lookup.prefetch_through.split(LOOKUP_SEP)
+ for level, through_attr in enumerate(through_attrs):
# Prepare main instances
if len(obj_list) == 0:
break
- current_lookup = LOOKUP_SEP.join(attrs[:level + 1])
- if current_lookup in done_queries:
+ prefetch_to = lookup.get_current_prefetch_to(level)
+ if prefetch_to in done_queries:
# Skip any prefetching, and any object preparation
- obj_list = done_queries[current_lookup]
+ obj_list = done_queries[prefetch_to]
continue
# Prepare objects:
@@ -1685,34 +1742,40 @@ def prefetch_related_objects(result_cache, related_lookups):
# We assume that objects retrieved are homogenous (which is the premise
# of prefetch_related), so what applies to first object applies to all.
first_obj = obj_list[0]
- prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(first_obj, attr)
+ prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(first_obj, through_attr)
if not attr_found:
raise AttributeError("Cannot find '%s' on %s object, '%s' is an invalid "
"parameter to prefetch_related()" %
- (attr, first_obj.__class__.__name__, lookup))
+ (through_attr, first_obj.__class__.__name__, lookup.prefetch_through))
- if level == len(attrs) - 1 and prefetcher is None:
+ if level == len(through_attrs) - 1 and prefetcher is None:
# Last one, this *must* resolve to something that supports
# prefetching, otherwise there is no point adding it and the
# developer asking for it has made a mistake.
raise ValueError("'%s' does not resolve to a item that supports "
"prefetching - this is an invalid parameter to "
- "prefetch_related()." % lookup)
+ "prefetch_related()." % lookup.prefetch_through)
if prefetcher is not None and not is_fetched:
- obj_list, additional_prl = prefetch_one_level(obj_list, prefetcher, attr)
+ obj_list, additional_lookups = prefetch_one_level(obj_list, prefetcher, lookup, level)
# We need to ensure we don't keep adding lookups from the
# same relationships to stop infinite recursion. So, if we
# are already on an automatically added lookup, don't add
# the new lookups from relationships we've seen already.
- if not (lookup in auto_lookups and
- descriptor in followed_descriptors):
- for f in additional_prl:
- new_prl = LOOKUP_SEP.join([current_lookup, f])
- auto_lookups.append(new_prl)
- done_queries[current_lookup] = obj_list
+ if not (lookup in auto_lookups and descriptor in followed_descriptors):
+ done_queries[prefetch_to] = obj_list
+ auto_lookups.extend(normalize_prefetch_lookups(additional_lookups, prefetch_to))
followed_descriptors.add(descriptor)
+ elif isinstance(getattr(first_obj, through_attr), list):
+ # The current part of the lookup relates to a custom Prefetch.
+ # This means that obj.attr is a list of related objects, and
+ # thus we must turn the obj.attr lists into a single related
+ # object list.
+ new_list = []
+ for obj in obj_list:
+ new_list.extend(getattr(obj, through_attr))
+ obj_list = new_list
else:
# Either a singly related object that has already been fetched
# (e.g. via select_related), or hopefully some other property
@@ -1724,7 +1787,7 @@ def prefetch_related_objects(result_cache, related_lookups):
new_obj_list = []
for obj in obj_list:
try:
- new_obj = getattr(obj, attr)
+ new_obj = getattr(obj, through_attr)
except exceptions.ObjectDoesNotExist:
continue
if new_obj is None:
@@ -1755,6 +1818,11 @@ def get_prefetcher(instance, attr):
try:
rel_obj = getattr(instance, attr)
attr_found = True
+ # If we are following a lookup path which leads us through a previous
+ # fetch from a custom Prefetch then we might end up into a list
+ # instead of related qs. This means the objects are already fetched.
+ if isinstance(rel_obj, list):
+ is_fetched = True
except AttributeError:
pass
else:
@@ -1776,7 +1844,7 @@ def get_prefetcher(instance, attr):
return prefetcher, rel_obj_descriptor, attr_found, is_fetched
-def prefetch_one_level(instances, prefetcher, attname):
+def prefetch_one_level(instances, prefetcher, lookup, level):
"""
Helper function for prefetch_related_objects
@@ -1799,14 +1867,14 @@ def prefetch_one_level(instances, prefetcher, attname):
# The 'values to be matched' must be hashable as they will be used
# in a dictionary.
- rel_qs, rel_obj_attr, instance_attr, single, cache_name =\
- prefetcher.get_prefetch_queryset(instances)
+ rel_qs, rel_obj_attr, instance_attr, single, cache_name = (
+ prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level)))
# We have to handle the possibility that the default manager itself added
# prefetch_related lookups to the QuerySet we just got back. We don't want to
# trigger the prefetch_related functionality by evaluating the query.
# Rather, we need to merge in the prefetch_related lookups.
- additional_prl = getattr(rel_qs, '_prefetch_related_lookups', [])
- if additional_prl:
+ additional_lookups = getattr(rel_qs, '_prefetch_related_lookups', [])
+ if additional_lookups:
# Don't need to clone because the manager should have given us a fresh
# instance, so we access an internal instead of using public interface
# for performance reasons.
@@ -1826,12 +1894,15 @@ def prefetch_one_level(instances, prefetcher, attname):
# Need to assign to single cache on instance
setattr(obj, cache_name, vals[0] if vals else None)
else:
- # Multi, attribute represents a manager with an .all() method that
- # returns a QuerySet
- qs = getattr(obj, attname).all()
- qs._result_cache = vals
- # We don't want the individual qs doing prefetch_related now, since we
- # have merged this into the current work.
- qs._prefetch_done = True
- obj._prefetched_objects_cache[cache_name] = qs
- return all_related_objects, additional_prl
+ to_attr, to_list = lookup.get_current_to_attr(level)
+ if to_list:
+ setattr(obj, to_attr, vals)
+ else:
+ # Cache in the QuerySet.all().
+ qs = getattr(obj, to_attr).all()
+ qs._result_cache = vals
+ # We don't want the individual qs doing prefetch_related now,
+ # since we have merged this into the current work.
+ qs._prefetch_done = True
+ obj._prefetched_objects_cache[cache_name] = qs
+ return all_related_objects, additional_lookups
Oops, something went wrong.

0 comments on commit f51c1f5

Please sign in to comment.