diff --git a/django/contrib/contenttypes/fields.py b/django/contrib/contenttypes/fields.py index 6c371557c58a5..85b39221288ad 100644 --- a/django/contrib/contenttypes/fields.py +++ b/django/contrib/contenttypes/fields.py @@ -1,3 +1,4 @@ +import itertools from collections import defaultdict from django.contrib.contenttypes.models import ContentType @@ -568,19 +569,28 @@ def get_prefetch_queryset(self, instances, queryset=None): queryset._add_hints(instance=instances[0]) queryset = queryset.using(queryset._db or self._db) - - query = { - '%s__pk' % self.content_type_field_name: self.content_type.id, - '%s__in' % self.object_id_field_name: {obj.pk for obj in instances} - } - + # Group instances by content types. + content_type_queries = iter( + models.Q(**{ + '%s__pk' % self.content_type_field_name: content_type_id, + '%s__in' % self.object_id_field_name: {obj.pk for obj in objs} + }) + for content_type_id, objs in + itertools.groupby(instances, lambda obj: self.get_content_type(obj).pk) + ) + query = next(content_type_queries) + for q in content_type_queries: + query |= q # We (possibly) need to convert object IDs to the type of the # instances' PK in order to match up instances: object_id_converter = instances[0]._meta.pk.to_python return ( - queryset.filter(**query), - lambda relobj: object_id_converter(getattr(relobj, self.object_id_field_name)), - lambda obj: obj.pk, + queryset.filter(query), + lambda relobj: ( + object_id_converter(getattr(relobj, self.object_id_field_name)), + relobj.content_type_id + ), + lambda obj: (obj.pk, self.get_content_type(obj).pk), False, self.prefetch_cache_name, False, diff --git a/tests/generic_relations/tests.py b/tests/generic_relations/tests.py index 2e99c5b5cf41a..7c0db95908999 100644 --- a/tests/generic_relations/tests.py +++ b/tests/generic_relations/tests.py @@ -546,6 +546,24 @@ def test_add_then_remove_after_prefetch(self): platypus.tags.remove(weird_tag) self.assertSequenceEqual(platypus.tags.all(), [furry_tag]) + def test_prefetch_related_different_content_types(self): + TaggedItem.objects.create(content_object=self.platypus, tag='prefetch_tag_1') + TaggedItem.objects.create( + content_object=Vegetable.objects.create(name='Broccoli'), + tag='prefetch_tag_2', + ) + TaggedItem.objects.create( + content_object=Animal.objects.create(common_name='Bear'), + tag='prefetch_tag_3', + ) + qs = TaggedItem.objects.filter( + tag__startswith='prefetch_tag_', + ).prefetch_related('content_object', 'content_object__tags') + with self.assertNumQueries(4): + tags = list(qs) + for tag in tags: + self.assertSequenceEqual(tag.content_object.tags.all(), [tag]) + class ProxyRelatedModelTest(TestCase): def test_default_behavior(self):