Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Fixed #10182 -- Corrected realiasing and the process of evaluating values() for queries with aggregate clauses. This means that aggregate queries can now be used as subqueries (such as in an __in clause). Thanks to omat for the report.

This involves a slight change to the interaction of annotate() and values() clauses that specify a list of columns. See the docs for details.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@9888 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
commit 542709d0d1796326dd1edacf32fc1198cfad2869 1 parent 4bd2447
Russell Keith-Magee authored February 23, 2009
2  django/db/models/aggregates.py
@@ -46,7 +46,7 @@ def add_to_query(self, query, alias, col, source, is_summary):
46 46
         # Validate that the backend has a fully supported, correct
47 47
         # implementation of this aggregate
48 48
         query.connection.ops.check_aggregate_support(aggregate)
49  
-        query.aggregate_select[alias] = aggregate
  49
+        query.aggregates[alias] = aggregate
50 50
 
51 51
 class Avg(Aggregate):
52 52
     name = 'Avg'
16  django/db/models/query.py
@@ -596,7 +596,7 @@ def annotate(self, *args, **kwargs):
596 596
 
597 597
         obj = self._clone()
598 598
 
599  
-        obj._setup_aggregate_query()
  599
+        obj._setup_aggregate_query(kwargs.keys())
600 600
 
601 601
         # Add the aggregates to the query
602 602
         for (alias, aggregate_expr) in kwargs.items():
@@ -693,7 +693,7 @@ def _merge_sanity_check(self, other):
693 693
         """
694 694
         pass
695 695
 
696  
-    def _setup_aggregate_query(self):
  696
+    def _setup_aggregate_query(self, aggregates):
697 697
         """
698 698
         Prepare the query for computing a result that contains aggregate annotations.
699 699
         """
@@ -773,6 +773,8 @@ def _setup_query(self):
773 773
 
774 774
         self.query.select = []
775 775
         self.query.add_fields(self.field_names, False)
  776
+        if self.aggregate_names is not None:
  777
+            self.query.set_aggregate_mask(self.aggregate_names)
776 778
 
777 779
     def _clone(self, klass=None, setup=False, **kwargs):
778 780
         """
@@ -798,13 +800,17 @@ def _merge_sanity_check(self, other):
798 800
             raise TypeError("Merging '%s' classes must involve the same values in each case."
799 801
                     % self.__class__.__name__)
800 802
 
801  
-    def _setup_aggregate_query(self):
  803
+    def _setup_aggregate_query(self, aggregates):
802 804
         """
803 805
         Prepare the query for computing a result that contains aggregate annotations.
804 806
         """
805 807
         self.query.set_group_by()
806 808
 
807  
-        super(ValuesQuerySet, self)._setup_aggregate_query()
  809
+        if self.aggregate_names is not None:
  810
+            self.aggregate_names.extend(aggregates)
  811
+            self.query.set_aggregate_mask(self.aggregate_names)
  812
+
  813
+        super(ValuesQuerySet, self)._setup_aggregate_query(aggregates)
808 814
 
809 815
     def as_sql(self):
810 816
         """
@@ -824,6 +830,7 @@ class ValuesListQuerySet(ValuesQuerySet):
824 830
     def iterator(self):
825 831
         if self.extra_names is not None:
826 832
             self.query.trim_extra_select(self.extra_names)
  833
+
827 834
         if self.flat and len(self._fields) == 1:
828 835
             for row in self.query.results_iter():
829 836
                 yield row[0]
@@ -837,6 +844,7 @@ def iterator(self):
837 844
             extra_names = self.query.extra_select.keys()
838 845
             field_names = self.field_names
839 846
             aggregate_names = self.query.aggregate_select.keys()
  847
+
840 848
             names = extra_names + field_names + aggregate_names
841 849
 
842 850
             # If a field list has been specified, use it. Otherwise, use the
70  django/db/models/sql/query.py
@@ -77,7 +77,9 @@ def __init__(self, model, connection, where=WhereNode):
77 77
         self.related_select_cols = []
78 78
 
79 79
         # SQL aggregate-related attributes
80  
-        self.aggregate_select = SortedDict() # Maps alias -> SQL aggregate function
  80
+        self.aggregates = SortedDict() # Maps alias -> SQL aggregate function
  81
+        self.aggregate_select_mask = None
  82
+        self._aggregate_select_cache = None
81 83
 
82 84
         # Arbitrary maximum limit for select_related. Prevents infinite
83 85
         # recursion. Can be changed by the depth parameter to select_related().
@@ -187,7 +189,15 @@ def clone(self, klass=None, **kwargs):
187 189
         obj.distinct = self.distinct
188 190
         obj.select_related = self.select_related
189 191
         obj.related_select_cols = []
190  
-        obj.aggregate_select = self.aggregate_select.copy()
  192
+        obj.aggregates = self.aggregates.copy()
  193
+        if self.aggregate_select_mask is None:
  194
+            obj.aggregate_select_mask = None
  195
+        else:
  196
+            obj.aggregate_select_mask = self.aggregate_select_mask[:]
  197
+        if self._aggregate_select_cache is None:
  198
+            obj._aggregate_select_cache = None
  199
+        else:
  200
+            obj._aggregate_select_cache = self._aggregate_select_cache.copy()
191 201
         obj.max_depth = self.max_depth
192 202
         obj.extra_select = self.extra_select.copy()
193 203
         obj.extra_tables = self.extra_tables
@@ -940,14 +950,17 @@ def change_aliases(self, change_map):
940 950
         """
941 951
         assert set(change_map.keys()).intersection(set(change_map.values())) == set()
942 952
 
943  
-        # 1. Update references in "select" and "where".
  953
+        # 1. Update references in "select" (normal columns plus aliases),
  954
+        # "group by", "where" and "having".
944 955
         self.where.relabel_aliases(change_map)
945  
-        for pos, col in enumerate(self.select):
946  
-            if isinstance(col, (list, tuple)):
947  
-                old_alias = col[0]
948  
-                self.select[pos] = (change_map.get(old_alias, old_alias), col[1])
949  
-            else:
950  
-                col.relabel_aliases(change_map)
  956
+        self.having.relabel_aliases(change_map)
  957
+        for columns in (self.select, self.aggregates.values(), self.group_by or []):
  958
+            for pos, col in enumerate(columns):
  959
+                if isinstance(col, (list, tuple)):
  960
+                    old_alias = col[0]
  961
+                    columns[pos] = (change_map.get(old_alias, old_alias), col[1])
  962
+                else:
  963
+                    col.relabel_aliases(change_map)
951 964
 
952 965
         # 2. Rename the alias in the internal table/alias datastructures.
953 966
         for old_alias, new_alias in change_map.iteritems():
@@ -1205,11 +1218,11 @@ def add_aggregate(self, aggregate, model, alias, is_summary):
1205 1218
         opts = model._meta
1206 1219
         field_list = aggregate.lookup.split(LOOKUP_SEP)
1207 1220
         if (len(field_list) == 1 and
1208  
-            aggregate.lookup in self.aggregate_select.keys()):
  1221
+            aggregate.lookup in self.aggregates.keys()):
1209 1222
             # Aggregate is over an annotation
1210 1223
             field_name = field_list[0]
1211 1224
             col = field_name
1212  
-            source = self.aggregate_select[field_name]
  1225
+            source = self.aggregates[field_name]
1213 1226
         elif (len(field_list) > 1 or
1214 1227
             field_list[0] not in [i.name for i in opts.fields]):
1215 1228
             field, source, opts, join_list, last, _ = self.setup_joins(
@@ -1299,7 +1312,7 @@ def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
1299 1312
             value = SQLEvaluator(value, self)
1300 1313
             having_clause = value.contains_aggregate
1301 1314
 
1302  
-        for alias, aggregate in self.aggregate_select.items():
  1315
+        for alias, aggregate in self.aggregates.items():
1303 1316
             if alias == parts[0]:
1304 1317
                 entry = self.where_class()
1305 1318
                 entry.add((aggregate, lookup_type, value), AND)
@@ -1824,8 +1837,8 @@ def set_group_by(self):
1824 1837
         self.group_by = []
1825 1838
         if self.connection.features.allows_group_by_pk:
1826 1839
             if len(self.select) == len(self.model._meta.fields):
1827  
-                self.group_by.append('.'.join([self.model._meta.db_table,
1828  
-                                               self.model._meta.pk.column]))
  1840
+                self.group_by.append((self.model._meta.db_table,
  1841
+                                      self.model._meta.pk.column))
1829 1842
                 return
1830 1843
 
1831 1844
         for sel in self.select:
@@ -1858,7 +1871,11 @@ def add_count_column(self):
1858 1871
             # Distinct handling is done in Count(), so don't do it at this
1859 1872
             # level.
1860 1873
             self.distinct = False
1861  
-        self.aggregate_select = {None: count}
  1874
+
  1875
+        # Set only aggregate to be the count column.
  1876
+        # Clear out the select cache to reflect the new unmasked aggregates.
  1877
+        self.aggregates = {None: count}
  1878
+        self.set_aggregate_mask(None)
1862 1879
 
1863 1880
     def add_select_related(self, fields):
1864 1881
         """
@@ -1920,6 +1937,29 @@ def trim_extra_select(self, names):
1920 1937
         for key in set(self.extra_select).difference(set(names)):
1921 1938
             del self.extra_select[key]
1922 1939
 
  1940
+    def set_aggregate_mask(self, names):
  1941
+        "Set the mask of aggregates that will actually be returned by the SELECT"
  1942
+        self.aggregate_select_mask = names
  1943
+        self._aggregate_select_cache = None
  1944
+
  1945
+    def _aggregate_select(self):
  1946
+        """The SortedDict of aggregate columns that are not masked, and should
  1947
+        be used in the SELECT clause.
  1948
+
  1949
+        This result is cached for optimization purposes.
  1950
+        """
  1951
+        if self._aggregate_select_cache is not None:
  1952
+            return self._aggregate_select_cache
  1953
+        elif self.aggregate_select_mask is not None:
  1954
+            self._aggregate_select_cache = SortedDict([
  1955
+                (k,v) for k,v in self.aggregates.items()
  1956
+                if k in self.aggregate_select_mask
  1957
+            ])
  1958
+            return self._aggregate_select_cache
  1959
+        else:
  1960
+            return self.aggregates
  1961
+    aggregate_select = property(_aggregate_select)
  1962
+
1923 1963
     def set_start(self, start):
1924 1964
         """
1925 1965
         Sets the table from which to start joining. The start position is
12  django/db/models/sql/where.py
@@ -213,10 +213,14 @@ def relabel_aliases(self, change_map, node=None):
213 213
             elif isinstance(child, tree.Node):
214 214
                 self.relabel_aliases(change_map, child)
215 215
             else:
216  
-                elt = list(child[0])
217  
-                if elt[0] in change_map:
218  
-                    elt[0] = change_map[elt[0]]
219  
-                    node.children[pos] = (tuple(elt),) + child[1:]
  216
+                if isinstance(child[0], (list, tuple)):
  217
+                    elt = list(child[0])
  218
+                    if elt[0] in change_map:
  219
+                        elt[0] = change_map[elt[0]]
  220
+                        node.children[pos] = (tuple(elt),) + child[1:]
  221
+                else:
  222
+                    child[0].relabel_aliases(change_map)
  223
+
220 224
                 # Check if the query value also requires relabelling
221 225
                 if hasattr(child[3], 'relabel_aliases'):
222 226
                     child[3].relabel_aliases(change_map)
14  docs/topics/db/aggregation.txt
@@ -284,9 +284,6 @@ two authors with the same name, their results will be merged into a single
284 284
 result in the output of the query; the average will be computed as the
285 285
 average over the books written by both authors.
286 286
 
287  
-The annotation name will be added to the fields returned
288  
-as part of the ``ValuesQuerySet``.
289  
-
290 287
 Order of ``annotate()`` and ``values()`` clauses
291 288
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
292 289
 
@@ -303,12 +300,21 @@ output.
303 300
 For example, if we reverse the order of the ``values()`` and ``annotate()``
304 301
 clause from our previous example::
305 302
 
306  
-    >>> Author.objects.annotate(average_rating=Avg('book__rating')).values('name')
  303
+    >>> Author.objects.annotate(average_rating=Avg('book__rating')).values('name', 'average_rating')
307 304
 
308 305
 This will now yield one unique result for each author; however, only
309 306
 the author's name and the ``average_rating`` annotation will be returned
310 307
 in the output data.
311 308
 
  309
+You should also note that ``average_rating`` has been explicitly included
  310
+in the list of values to be returned. This is required because of the
  311
+ordering of the ``values()`` and ``annotate()`` clause.
  312
+
  313
+If the ``values()`` clause precedes the ``annotate()`` clause, any annotations
  314
+will be automatically added to the result set. However, if the ``values()``
  315
+clause is applied after the ``annotate()`` clause, you need to explicitly
  316
+include the aggregate column.
  317
+
312 318
 Aggregating annotations
313 319
 -----------------------
314 320
 
5  tests/modeltests/aggregation/models.py
@@ -207,10 +207,9 @@ class Clues(models.Model):
207 207
 >>> Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age')).values('pk', 'isbn', 'mean_age')
208 208
 [{'pk': 1, 'isbn': u'159059725', 'mean_age': 34.5}]
209 209
 
210  
-# Calling it with paramters reduces the output but does not remove the
211  
-# annotation.
  210
+# Calling values() with parameters reduces the output
212 211
 >>> Book.objects.filter(pk=1).annotate(mean_age=Avg('authors__age')).values('name')
213  
-[{'name': u'The Definitive Guide to Django: Web Development Done Right', 'mean_age': 34.5}]
  212
+[{'name': u'The Definitive Guide to Django: Web Development Done Right'}]
214 213
 
215 214
 # An empty values() call before annotating has the same effect as an
216 215
 # empty values() call after annotating
15  tests/regressiontests/aggregation_regress/models.py
@@ -95,10 +95,18 @@ def __unicode__(self):
95 95
 >>> sorted(Book.objects.all().values().annotate(mean_auth_age=Avg('authors__age')).extra(select={'manufacture_cost' : 'price * .5'}).get(pk=2).items())
96 96
 [('contact_id', 3), ('id', 2), ('isbn', u'067232959'), ('manufacture_cost', ...11.545...), ('mean_auth_age', 45.0), ('name', u'Sams Teach Yourself Django in 24 Hours'), ('pages', 528), ('price', Decimal("23.09")), ('pubdate', datetime.date(2008, 3, 3)), ('publisher_id', 2), ('rating', 3.0)]
97 97
 
98  
-# A values query that selects specific columns reduces the output
  98
+# If the annotation precedes the values clause, it won't be included
  99
+# unless it is explicitly named
99 100
 >>> sorted(Book.objects.all().annotate(mean_auth_age=Avg('authors__age')).extra(select={'price_per_page' : 'price / pages'}).values('name').get(pk=1).items())
  101
+[('name', u'The Definitive Guide to Django: Web Development Done Right')]
  102
+
  103
+>>> sorted(Book.objects.all().annotate(mean_auth_age=Avg('authors__age')).extra(select={'price_per_page' : 'price / pages'}).values('name','mean_auth_age').get(pk=1).items())
100 104
 [('mean_auth_age', 34.5), ('name', u'The Definitive Guide to Django: Web Development Done Right')]
101 105
 
  106
+# If an annotation isn't included in the values, it can still be used in a filter
  107
+>>> Book.objects.annotate(n_authors=Count('authors')).values('name').filter(n_authors__gt=2)
  108
+[{'name': u'Python Web Development with Django'}]
  109
+
102 110
 # The annotations are added to values output if values() precedes annotate()
103 111
 >>> sorted(Book.objects.all().values('name').annotate(mean_auth_age=Avg('authors__age')).extra(select={'price_per_page' : 'price / pages'}).get(pk=1).items())
104 112
 [('mean_auth_age', 34.5), ('name', u'The Definitive Guide to Django: Web Development Done Right')]
@@ -207,6 +215,11 @@ def __unicode__(self):
207 215
 >>> Book.objects.extra(select={'pub':'publisher_id','foo':'pages'}).values('pub').annotate(Count('id')).order_by('pub')
208 216
 [{'pub': 1, 'id__count': 2}, {'pub': 2, 'id__count': 1}, {'pub': 3, 'id__count': 2}, {'pub': 4, 'id__count': 1}]
209 217
 
  218
+# Regression for #10182 - Queries with aggregate calls are correctly realiased when used in a subquery
  219
+>>> ids = Book.objects.filter(pages__gt=100).annotate(n_authors=Count('authors')).filter(n_authors__gt=2).order_by('n_authors')
  220
+>>> Book.objects.filter(id__in=ids)
  221
+[<Book: Python Web Development with Django>]
  222
+
210 223
 # Regression for #10199 - Aggregate calls clone the original query so the original query can still be used
211 224
 >>> books = Book.objects.all()
212 225
 >>> _ = books.aggregate(Avg('authors__age'))

0 notes on commit 542709d

Please sign in to comment.
Something went wrong with that request. Please try again.