Permalink
Browse files

adding support for prefetch_related

  • Loading branch information...
1 parent 673286e commit fb6501f1a7e88b48117ee1966d3536055569fe6a @marcinossowski marcinossowski committed Feb 18, 2014
Showing with 66 additions and 7 deletions.
  1. +44 −6 sortedm2m/fields.py
  2. +22 −1 sortedm2m_tests/sortedm2m_field/tests.py
View
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
+from operator import attrgetter
import sys
+from django.db import connections
from django.db import router
from django.db.models import signals
from django.db.models.fields.related import add_lazy_relation, create_many_related_manager
@@ -81,18 +83,54 @@ def get_query_set(self):
# the extra sorting field of the intermediary model. The fields
# are hidden for joins because we set ``auto_created`` on the
# intermediary's meta options.
- return super(SortedRelatedManager, self).\
- get_query_set().\
- extra(order_by=['%s.%s' % (
- rel.through._meta.db_table,
- rel.through._sort_field_name,
- )])
+ try:
+ return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
+ except (AttributeError, KeyError):
+ return super(SortedRelatedManager, self).\
+ get_query_set().\
+ extra(order_by=['%s.%s' % (
+ rel.through._meta.db_table,
+ rel.through._sort_field_name,
+ )])
if not hasattr(RelatedManager, '_get_fk_val'):
@property
def _fk_val(self):
return self._pk_val
+ def get_prefetch_query_set(self, instances):
+ # mostly a copy of get_prefetch_query_set from ManyRelatedManager
+ # but with addition of proper ordering
+ db = self._db or router.db_for_read(instances[0].__class__, instance=instances[0])
+ query = {'%s__pk__in' % self.query_field_name:
+ set(obj._get_pk_val() for obj in instances)}
+ qs = super(RelatedManager, self).get_query_set().using(db)._next_is_sticky().filter(**query)
+
+ # M2M: need to annotate the query in order to get the primary model
+ # that the secondary model was actually related to. We know that
+ # there will already be a join on the join table, so we can just add
+ # the select.
+
+ # For non-autocreated 'through' models, can't assume we are
+ # dealing with PK values.
+ fk = self.through._meta.get_field(self.source_field_name)
+ source_col = fk.column
+ join_table = self.through._meta.db_table
+ connection = connections[db]
+ qn = connection.ops.quote_name
+ qs = qs.extra(select={'_prefetch_related_val':
+ '%s.%s' % (qn(join_table), qn(source_col))},
+ order_by=['%s.%s' % (
+ rel.through._meta.db_table,
+ rel.through._sort_field_name,
+ )])
+ select_attname = fk.rel.get_related_field().get_attname()
+ return (qs,
+ attrgetter('_prefetch_related_val'),
+ attrgetter(select_attname),
+ False,
+ self.prefetch_cache_name)
+
def _add_items(self, source_field_name, target_field_name, *objs):
# source_field_name: the PK fieldname in join_table for the source object
# target_field_name: the PK fieldname in join_table for the target object
@@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
+from django.db import connection
from django.db.models.fields import FieldDoesNotExist
from django.test import TestCase
+from django.test.utils import override_settings
from django.utils import six
from sortedm2m_tests.models import Book, Shelf, DoItYourselfShelf, Store, \
MessyStore, SelfReference
-
str_ = six.text_type
@@ -157,6 +158,26 @@ def test_remove_items_by_pk(self):
# self.books[3],
# self.books[4]])
+ # to enable population of connection.queries
+ @override_settings(DEBUG=True)
+ def test_prefetch_related_queries_num(self):
+ shelf = self.model.objects.create()
+ shelf.books.add(self.books[0])
+
+ shelf = self.model.objects.filter(pk=shelf.pk).prefetch_related('books')[0]
+ queries_num = len(connection.queries)
+ name = shelf.books.all()[0].name
+ self.assertEqual(queries_num, len(connection.queries))
+
+ def test_prefetch_related_sorting(self):
+ shelf = self.model.objects.create()
+ books = [self.books[0], self.books[2], self.books[1]]
+ shelf.books = books
+
+ shelf = self.model.objects.filter(pk=shelf.pk).prefetch_related('books')[0]
+ def get_ids(queryset):
+ return [obj.id for obj in queryset]
+ self.assertEqual(get_ids(shelf.books.all()), get_ids(books))
class TestStringReference(TestSortedManyToManyField):
'''

0 comments on commit fb6501f

Please sign in to comment.