Skip to content

Commit

Permalink
Fixed #31910 -- Fixed crash of GIS aggregations over subqueries.
Browse files Browse the repository at this point in the history
Regression was introduced by fff5186 but was due a long standing issue.

AggregateQuery was abusing Query.subquery: bool by stashing its
compiled inner query's SQL for later use in its compiler which made
select_format checks for Query.subquery wrongly assume the provide
query was a subquery.

This patch prevents that from happening by using a dedicated
inner_query attribute which is compiled at a later time by
SQLAggregateCompiler.

Moving the inner query's compilation to SQLAggregateCompiler.compile
had the side effect of addressing a long standing issue with
aggregation subquery pushdown which prevented converters from being
run. This is now fixed as the aggregation_regress adjustments
demonstrate.

Refs #25367.

Thanks Eran Keydar for the report.
  • Loading branch information
charettes authored and felixxm committed Nov 4, 2020
1 parent 789c47e commit c2d4926
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 17 deletions.
7 changes: 5 additions & 2 deletions django/db/models/sql/compiler.py
Expand Up @@ -1596,8 +1596,11 @@ def as_sql(self):
sql = ', '.join(sql)
params = tuple(params)

sql = 'SELECT %s FROM (%s) subquery' % (sql, self.query.subquery)
params = params + self.query.sub_params
inner_query_sql, inner_query_params = self.query.inner_query.get_compiler(
self.using
).as_sql(with_col_aliases=True)
sql = 'SELECT %s FROM (%s) subquery' % (sql, inner_query_sql)
params = params + inner_query_params
return sql, params


Expand Down
14 changes: 3 additions & 11 deletions django/db/models/sql/query.py
Expand Up @@ -17,9 +17,7 @@
from itertools import chain, count, product
from string import ascii_uppercase

from django.core.exceptions import (
EmptyResultSet, FieldDoesNotExist, FieldError,
)
from django.core.exceptions import FieldDoesNotExist, FieldError
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections
from django.db.models.aggregates import Count
from django.db.models.constants import LOOKUP_SEP
Expand Down Expand Up @@ -449,8 +447,9 @@ def get_aggregation(self, using, added_aggregate_names):
if (isinstance(self.group_by, tuple) or self.is_sliced or existing_annotations or
self.distinct or self.combinator):
from django.db.models.sql.subqueries import AggregateQuery
outer_query = AggregateQuery(self.model)
inner_query = self.clone()
inner_query.subquery = True
outer_query = AggregateQuery(self.model, inner_query)
inner_query.select_for_update = False
inner_query.select_related = False
inner_query.set_annotation_mask(self.annotation_select)
Expand Down Expand Up @@ -492,13 +491,6 @@ def get_aggregation(self, using, added_aggregate_names):
# field selected in the inner query, yet we must use a subquery.
# So, make sure at least one field is selected.
inner_query.select = (self.model._meta.pk.get_col(inner_query.get_initial_alias()),)
try:
outer_query.add_subquery(inner_query, using)
except EmptyResultSet:
return {
alias: None
for alias in outer_query.annotation_select
}
else:
outer_query = self
self.select = ()
Expand Down
6 changes: 3 additions & 3 deletions django/db/models/sql/subqueries.py
Expand Up @@ -157,6 +157,6 @@ class AggregateQuery(Query):

compiler = 'SQLAggregateCompiler'

def add_subquery(self, query, using):
query.subquery = True
self.subquery, self.sub_params = query.get_compiler(using).as_sql(with_col_aliases=True)
def __init__(self, model, inner_query):
self.inner_query = inner_query
super().__init__(model)
2 changes: 1 addition & 1 deletion tests/aggregation_regress/tests.py
Expand Up @@ -974,7 +974,7 @@ def test_empty_filter_count(self):
def test_empty_filter_aggregate(self):
self.assertEqual(
Author.objects.filter(id__in=[]).annotate(Count("friends")).aggregate(Count("pk")),
{"pk__count": None}
{"pk__count": 0}
)

def test_none_call_before_aggregate(self):
Expand Down
14 changes: 14 additions & 0 deletions tests/gis_tests/geoapp/tests.py
Expand Up @@ -12,6 +12,7 @@
from django.db import DatabaseError, NotSupportedError, connection
from django.db.models import F, OuterRef, Subquery
from django.test import TestCase, skipUnlessDBFeature
from django.test.utils import CaptureQueriesContext

from ..utils import (
mariadb, mysql, oracle, postgis, skipUnlessGISLookup, spatialite,
Expand Down Expand Up @@ -593,6 +594,19 @@ def test_unionagg(self):
qs = City.objects.filter(name='NotACity')
self.assertIsNone(qs.aggregate(Union('point'))['point__union'])

@skipUnlessDBFeature('supports_union_aggr')
def test_geoagg_subquery(self):
ks = State.objects.get(name='Kansas')
union = GEOSGeometry('MULTIPOINT(-95.235060 38.971823)')
# Use distinct() to force the usage of a subquery for aggregation.
with CaptureQueriesContext(connection) as ctx:
self.assertIs(union.equals(
City.objects.filter(point__within=ks.poly).distinct().aggregate(
Union('point'),
)['point__union'],
), True)
self.assertIn('subquery', ctx.captured_queries[0]['sql'])

@unittest.skipUnless(
connection.vendor == 'oracle',
'Oracle supports tolerance parameter.',
Expand Down

0 comments on commit c2d4926

Please sign in to comment.