Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Reconciling where- and having-clause behaviour.

Extricated the code that works directly with SQL columns (standard
"where" stuff) from the the code that takes SQL fragments and combines
it with lookup types and values. The latter portion is now more
generally reusable. Any existing code that was poking at Query.having
will now break in very visible ways (no subtle miscalculations, which is
a good thing).

This patch, en passant, removes the existing "having" test, since the
new implementation requires more setting up than previously. The
aggregates support (currently in a separate codebase) has tests for this
functionality that work as a replacement for the removed test.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@9700 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
commit 062a94ef45e71f525de1c9095aeb3f376feb8232 1 parent ff4b844
Malcolm Tredinnick authored January 05, 2009
40  django/db/models/sql/query.py
@@ -16,7 +16,7 @@
16 16
 from django.db.models import signals
17 17
 from django.db.models.fields import FieldDoesNotExist
18 18
 from django.db.models.query_utils import select_related_descend
19  
-from django.db.models.sql.where import WhereNode, EverythingNode, AND, OR
  19
+from django.db.models.sql.where import WhereNode, Constraint, EverythingNode, AND, OR
20 20
 from django.db.models.sql.datastructures import Count
21 21
 from django.core.exceptions import FieldError
22 22
 from datastructures import EmptyResultSet, Empty, MultiJoin
@@ -66,7 +66,7 @@ def __init__(self, model, connection, where=WhereNode):
66 66
         self.where = where()
67 67
         self.where_class = where
68 68
         self.group_by = []
69  
-        self.having = []
  69
+        self.having = where()
70 70
         self.order_by = []
71 71
         self.low_mark, self.high_mark = 0, None  # Used for offset/limit
72 72
         self.distinct = False
@@ -172,7 +172,7 @@ def clone(self, klass=None, **kwargs):
172 172
         obj.where = deepcopy(self.where)
173 173
         obj.where_class = self.where_class
174 174
         obj.group_by = self.group_by[:]
175  
-        obj.having = self.having[:]
  175
+        obj.having = deepcopy(self.having)
176 176
         obj.order_by = self.order_by[:]
177 177
         obj.low_mark, obj.high_mark = self.low_mark, self.high_mark
178 178
         obj.distinct = self.distinct
@@ -261,7 +261,9 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
261 261
         # get_from_clause() for details.
262 262
         from_, f_params = self.get_from_clause()
263 263
 
264  
-        where, w_params = self.where.as_sql(qn=self.quote_name_unless_alias)
  264
+        qn = self.quote_name_unless_alias
  265
+        where, w_params = self.where.as_sql(qn=qn)
  266
+        having, h_params = self.having.as_sql(qn=qn)
265 267
         params = []
266 268
         for val in self.extra_select.itervalues():
267 269
             params.extend(val[1])
@@ -291,9 +293,8 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
291 293
             if not ordering:
292 294
                 ordering = self.connection.ops.force_no_ordering()
293 295
 
294  
-        if self.having:
295  
-            having, h_params = self.get_having()
296  
-            result.append('HAVING %s' % ', '.join(having))
  296
+        if having:
  297
+            result.append('HAVING %s' % having)
297 298
             params.extend(h_params)
298 299
 
299 300
         if ordering:
@@ -577,24 +578,6 @@ def get_grouping(self):
577 578
                 result.append(str(col))
578 579
         return result
579 580
 
580  
-    def get_having(self):
581  
-        """
582  
-        Returns a tuple representing the SQL elements in the "having" clause.
583  
-        By default, the elements of self.having have their as_sql() method
584  
-        called or are returned unchanged (if they don't have an as_sql()
585  
-        method).
586  
-        """
587  
-        result = []
588  
-        params = []
589  
-        for elt in self.having:
590  
-            if hasattr(elt, 'as_sql'):
591  
-                sql, params = elt.as_sql()
592  
-                result.append(sql)
593  
-                params.extend(params)
594  
-            else:
595  
-                result.append(elt)
596  
-        return result, params
597  
-
598 581
     def get_ordering(self):
599 582
         """
600 583
         Returns list representing the SQL elements in the "order by" clause.
@@ -1197,7 +1180,8 @@ def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
1197 1180
             self.promote_alias_chain(join_it, join_promote)
1198 1181
             self.promote_alias_chain(table_it, table_promote)
1199 1182
 
1200  
-        self.where.add((alias, col, field, lookup_type, value), connector)
  1183
+        self.where.add((Constraint(alias, col, field), lookup_type, value),
  1184
+            connector)
1201 1185
 
1202 1186
         if negate:
1203 1187
             self.promote_alias_chain(join_list)
@@ -1207,7 +1191,7 @@ def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
1207 1191
                         if self.alias_map[alias][JOIN_TYPE] == self.LOUTER:
1208 1192
                             j_col = self.alias_map[alias][RHS_JOIN_COL]
1209 1193
                             entry = self.where_class()
1210  
-                            entry.add((alias, j_col, None, 'isnull', True), AND)
  1194
+                            entry.add((Constraint(alias, j_col, None), 'isnull', True), AND)
1211 1195
                             entry.negate()
1212 1196
                             self.where.add(entry, AND)
1213 1197
                             break
@@ -1216,7 +1200,7 @@ def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
1216 1200
                     # exclude the "foo__in=[]" case from this handling, because
1217 1201
                     # it's short-circuited in the Where class.
1218 1202
                     entry = self.where_class()
1219  
-                    entry.add((alias, col, None, 'isnull', True), AND)
  1203
+                    entry.add((Constraint(alias, col, None), 'isnull', True), AND)
1220 1204
                     entry.negate()
1221 1205
                     self.where.add(entry, AND)
1222 1206
 
15  django/db/models/sql/subqueries.py
@@ -6,7 +6,7 @@
6 6
 from django.db.models.sql.constants import *
7 7
 from django.db.models.sql.datastructures import Date
8 8
 from django.db.models.sql.query import Query
9  
-from django.db.models.sql.where import AND
  9
+from django.db.models.sql.where import AND, Constraint
10 10
 
11 11
 __all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery',
12 12
         'CountQuery']
@@ -48,8 +48,9 @@ def delete_batch_related(self, pk_list):
48 48
             if not isinstance(related.field, generic.GenericRelation):
49 49
                 for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
50 50
                     where = self.where_class()
51  
-                    where.add((None, related.field.m2m_reverse_name(),
52  
-                            related.field, 'in',
  51
+                    where.add((Constraint(None,
  52
+                            related.field.m2m_reverse_name(), related.field),
  53
+                            'in',
53 54
                             pk_list[offset : offset+GET_ITERATOR_CHUNK_SIZE]),
54 55
                             AND)
55 56
                     self.do_query(related.field.m2m_db_table(), where)
@@ -59,11 +60,11 @@ def delete_batch_related(self, pk_list):
59 60
             if isinstance(f, generic.GenericRelation):
60 61
                 from django.contrib.contenttypes.models import ContentType
61 62
                 field = f.rel.to._meta.get_field(f.content_type_field_name)
62  
-                w1.add((None, field.column, field, 'exact',
  63
+                w1.add((Constraint(None, field.column, field), 'exact',
63 64
                         ContentType.objects.get_for_model(cls).id), AND)
64 65
             for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
65 66
                 where = self.where_class()
66  
-                where.add((None, f.m2m_column_name(), f, 'in',
  67
+                where.add((Constraint(None, f.m2m_column_name(), f), 'in',
67 68
                         pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
68 69
                         AND)
69 70
                 if w1:
@@ -81,7 +82,7 @@ def delete_batch(self, pk_list):
81 82
         for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
82 83
             where = self.where_class()
83 84
             field = self.model._meta.pk
84  
-            where.add((None, field.column, field, 'in',
  85
+            where.add((Constraint(None, field.column, field), 'in',
85 86
                     pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND)
86 87
             self.do_query(self.model._meta.db_table, where)
87 88
 
@@ -212,7 +213,7 @@ def clear_related(self, related_field, pk_list):
212 213
         for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
213 214
             self.where = self.where_class()
214 215
             f = self.model._meta.pk
215  
-            self.where.add((None, f.column, f, 'in',
  216
+            self.where.add((Constraint(None, f.column, f), 'in',
216 217
                     pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
217 218
                     AND)
218 219
             self.values = [(related_field.column, None, '%s')]
110  django/db/models/sql/where.py
@@ -13,6 +13,13 @@
13 13
 AND = 'AND'
14 14
 OR = 'OR'
15 15
 
  16
+class EmptyShortCircuit(Exception):
  17
+    """
  18
+    Internal exception used to indicate that a "matches nothing" node should be
  19
+    added to the where-clause.
  20
+    """
  21
+    pass
  22
+
16 23
 class WhereNode(tree.Node):
17 24
     """
18 25
     Used to represent the SQL where-clause.
@@ -35,36 +42,35 @@ def add(self, data, connector):
35 42
         storing any reference to field objects). Otherwise, the 'data' is
36 43
         stored unchanged and can be anything with an 'as_sql()' method.
37 44
         """
38  
-        # Because of circular imports, we need to import this here.
39  
-        from django.db.models.base import ObjectDoesNotExist
40  
-
41 45
         if not isinstance(data, (list, tuple)):
42 46
             super(WhereNode, self).add(data, connector)
43 47
             return
44 48
 
45  
-        alias, col, field, lookup_type, value = data
46  
-        try:
47  
-            if field:
48  
-                params = field.get_db_prep_lookup(lookup_type, value)
49  
-                db_type = field.db_type()
50  
-            else:
51  
-                # This is possible when we add a comparison to NULL sometimes
52  
-                # (we don't really need to waste time looking up the associated
53  
-                # field object).
54  
-                params = Field().get_db_prep_lookup(lookup_type, value)
55  
-                db_type = None
56  
-        except ObjectDoesNotExist:
57  
-            # This can happen when trying to insert a reference to a null pk.
58  
-            # We break out of the normal path and indicate there's nothing to
59  
-            # match.
60  
-            super(WhereNode, self).add(NothingNode(), connector)
61  
-            return
  49
+        obj, lookup_type, value = data
  50
+        if hasattr(obj, "process"):
  51
+            try:
  52
+                obj, params = obj.process(lookup_type, value)
  53
+            except EmptyShortCircuit:
  54
+                # There are situations where we want to short-circuit any
  55
+                # comparisons and make sure that nothing is returned. One
  56
+                # example is when checking for a NULL pk value, or the
  57
+                # equivalent.
  58
+                super(WhereNode, self).add(NothingNode(), connector)
  59
+                return
  60
+        else:
  61
+            params = Field().get_db_prep_lookup(lookup_type, value)
  62
+
  63
+        # The "annotation" parameter is used to pass auxilliary information
  64
+        # about the value(s) to the query construction. Specifically, datetime
  65
+        # and empty values need special handling. Other types could be used
  66
+        # here in the future (using Python types is suggested for consistency).
62 67
         if isinstance(value, datetime.datetime):
63 68
             annotation = datetime.datetime
64 69
         else:
65 70
             annotation = bool(value)
66  
-        super(WhereNode, self).add((alias, col, db_type, lookup_type,
67  
-                annotation, params), connector)
  71
+
  72
+        super(WhereNode, self).add((obj, lookup_type, annotation, params),
  73
+                connector)
68 74
 
69 75
     def as_sql(self, qn=None):
70 76
         """
@@ -130,12 +136,13 @@ def make_atom(self, child, qn):
130 136
         Returns the string for the SQL fragment and the parameters to use for
131 137
         it.
132 138
         """
133  
-        table_alias, name, db_type, lookup_type, value_annot, params = child
134  
-        if table_alias:
135  
-            lhs = '%s.%s' % (qn(table_alias), qn(name))
  139
+        lvalue, lookup_type, value_annot, params = child
  140
+        if isinstance(lvalue, tuple):
  141
+            # A direct database column lookup.
  142
+            field_sql = self.sql_for_columns(lvalue, qn)
136 143
         else:
137  
-            lhs = qn(name)
138  
-        field_sql = connection.ops.field_cast_sql(db_type) % lhs
  144
+            # A smart object with an as_sql() method.
  145
+            field_sql = lvalue.as_sql(quote_func=qn)
139 146
 
140 147
         if value_annot is datetime.datetime:
141 148
             cast_sql = connection.ops.datetime_cast_sql()
@@ -175,6 +182,19 @@ def make_atom(self, child, qn):
175 182
 
176 183
         raise TypeError('Invalid lookup_type: %r' % lookup_type)
177 184
 
  185
+    def sql_for_columns(self, data, qn):
  186
+        """
  187
+        Returns the SQL fragment used for the left-hand side of a column
  188
+        constraint (for example, the "T1.foo" portion in the clause
  189
+        "WHERE ... T1.foo = 6").
  190
+        """
  191
+        table_alias, name, db_type = data
  192
+        if table_alias:
  193
+            lhs = '%s.%s' % (qn(table_alias), qn(name))
  194
+        else:
  195
+            lhs = qn(name)
  196
+        return connection.ops.field_cast_sql(db_type) % lhs
  197
+
178 198
     def relabel_aliases(self, change_map, node=None):
179 199
         """
180 200
         Relabels the alias values of any children. 'change_map' is a dictionary
@@ -188,8 +208,10 @@ def relabel_aliases(self, change_map, node=None):
188 208
             elif isinstance(child, tree.Node):
189 209
                 self.relabel_aliases(change_map, child)
190 210
             else:
191  
-                if child[0] in change_map:
192  
-                    node.children[pos] = (change_map[child[0]],) + child[1:]
  211
+                elt = list(child[0])
  212
+                if elt[0] in change_map:
  213
+                    elt[0] = change_map[elt[0]]
  214
+                    node.children[pos] = (tuple(elt),) + child[1:]
193 215
 
194 216
 class EverythingNode(object):
195 217
     """
@@ -211,3 +233,33 @@ def as_sql(self, qn=None):
211 233
     def relabel_aliases(self, change_map, node=None):
212 234
         return
213 235
 
  236
+class Constraint(object):
  237
+    """
  238
+    An object that can be passed to WhereNode.add() and knows how to
  239
+    pre-process itself prior to including in the WhereNode.
  240
+    """
  241
+    def __init__(self, alias, col, field):
  242
+        self.alias, self.col, self.field = alias, col, field
  243
+
  244
+    def process(self, lookup_type, value):
  245
+        """
  246
+        Returns a tuple of data suitable for inclusion in a WhereNode
  247
+        instance.
  248
+        """
  249
+        # Because of circular imports, we need to import this here.
  250
+        from django.db.models.base import ObjectDoesNotExist
  251
+        try:
  252
+            if self.field:
  253
+                params = self.field.get_db_prep_lookup(lookup_type, value)
  254
+                db_type = self.field.db_type()
  255
+            else:
  256
+                # This branch is used at times when we add a comparison to NULL
  257
+                # (we don't really want to waste time looking up the associated
  258
+                # field object at the calling location).
  259
+                params = Field().get_db_prep_lookup(lookup_type, value)
  260
+                db_type = None
  261
+        except ObjectDoesNotExist:
  262
+            raise EmptyShortCircuit
  263
+
  264
+        return (self.alias, self.col, db_type), params
  265
+
13  tests/regressiontests/queries/models.py
@@ -973,19 +973,6 @@ class PointerB(models.Model):
973 973
 >>> len([x[2] for x in q.alias_map.values() if x[2] == q.LOUTER and q.alias_refcount[x[1]]])
974 974
 1
975 975
 
976  
-A check to ensure we don't break the internal query construction of GROUP BY
977  
-and HAVING. These aren't supported in the public API, but the Query class knows
978  
-about them and shouldn't do bad things.
979  
->>> qs = Tag.objects.values_list('parent_id', flat=True).order_by()
980  
->>> qs.query.group_by = ['parent_id']
981  
->>> qs.query.having = ['count(parent_id) > 1']
982  
->>> expected = [t3.parent_id, t4.parent_id]
983  
->>> expected.sort()
984  
->>> result = list(qs)
985  
->>> result.sort()
986  
->>> expected == result
987  
-True
988  
-
989 976
 Make sure bump_prefix() (an internal Query method) doesn't (re-)break. It's
990 977
 sufficient that this query runs without error.
991 978
 >>> qs = Tag.objects.values_list('id', flat=True).order_by('id')

0 notes on commit 062a94e

Please sign in to comment.
Something went wrong with that request. Please try again.