Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Added support for parameters in SELECT clauses.

  • Loading branch information...
commit 924a144ef8a80ba4daeeafbe9efaa826566e9d02 1 parent b4351d2
Aymeric Augustin authored February 13, 2013
7  django/contrib/gis/db/backends/mysql/operations.py
@@ -56,12 +56,13 @@ def spatial_lookup_sql(self, lvalue, lookup_type, value, field, qn):
56 56
 
57 57
         lookup_info = self.geometry_functions.get(lookup_type, False)
58 58
         if lookup_info:
59  
-            return "%s(%s, %s)" % (lookup_info, geo_col,
60  
-                                   self.get_geom_placeholder(value, field.srid))
  59
+            sql = "%s(%s, %s)" % (lookup_info, geo_col,
  60
+                                  self.get_geom_placeholder(value, field.srid))
  61
+            return sql, []
61 62
 
62 63
         # TODO: Is this really necessary? MySQL can't handle NULL geometries
63 64
         #  in its spatial indexes anyways.
64 65
         if lookup_type == 'isnull':
65  
-            return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or ''))
  66
+            return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
66 67
 
67 68
         raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))
4  django/contrib/gis/db/backends/oracle/operations.py
@@ -262,7 +262,7 @@ def spatial_lookup_sql(self, lvalue, lookup_type, value, field, qn):
262 262
                 return lookup_info.as_sql(geo_col, self.get_geom_placeholder(field, value))
263 263
         elif lookup_type == 'isnull':
264 264
             # Handling 'isnull' lookup type
265  
-            return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or ''))
  265
+            return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
266 266
 
267 267
         raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))
268 268
 
@@ -288,7 +288,7 @@ def geometry_columns(self):
288 288
     def spatial_ref_sys(self):
289 289
         from django.contrib.gis.db.backends.oracle.models import SpatialRefSys
290 290
         return SpatialRefSys
291  
-    
  291
+
292 292
     def modify_insert_params(self, placeholders, params):
293 293
         """Drop out insert parameters for NULL placeholder. Needed for Oracle Spatial
294 294
         backend due to #10888
2  django/contrib/gis/db/backends/postgis/operations.py
@@ -560,7 +560,7 @@ def spatial_lookup_sql(self, lvalue, lookup_type, value, field, qn):
560 560
 
561 561
         elif lookup_type == 'isnull':
562 562
             # Handling 'isnull' lookup type
563  
-            return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or ''))
  563
+            return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
564 564
 
565 565
         raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))
566 566
 
2  django/contrib/gis/db/backends/spatialite/operations.py
@@ -358,7 +358,7 @@ def spatial_lookup_sql(self, lvalue, lookup_type, value, field, qn):
358 358
             return op.as_sql(geo_col, self.get_geom_placeholder(field, geom))
359 359
         elif lookup_type == 'isnull':
360 360
             # Handling 'isnull' lookup type
361  
-            return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or ''))
  361
+            return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
362 362
 
363 363
         raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))
364 364
 
2  django/contrib/gis/db/backends/util.py
@@ -16,7 +16,7 @@ def __init__(self, function='', operator='', result='', **kwargs):
16 16
         self.extra = kwargs
17 17
 
18 18
     def as_sql(self, geo_col, geometry='%s'):
19  
-        return self.sql_template % self.params(geo_col, geometry)
  19
+        return self.sql_template % self.params(geo_col, geometry), []
20 20
 
21 21
     def params(self, geo_col, geometry):
22 22
         params = {'function' : self.function,
12  django/contrib/gis/db/models/sql/aggregates.py
@@ -22,13 +22,15 @@ def __init__(self, col, source=None, is_summary=False, tolerance=0.05, **extra):
22 22
             raise ValueError('Geospatial aggregates only allowed on geometry fields.')
23 23
 
24 24
     def as_sql(self, qn, connection):
25  
-        "Return the aggregate, rendered as SQL."
  25
+        "Return the aggregate, rendered as SQL with parameters."
26 26
 
27 27
         if connection.ops.oracle:
28 28
             self.extra['tolerance'] = self.tolerance
29 29
 
  30
+        params = []
  31
+
30 32
         if hasattr(self.col, 'as_sql'):
31  
-            field_name = self.col.as_sql(qn, connection)
  33
+            field_name, params = self.col.as_sql(qn, connection)
32 34
         elif isinstance(self.col, (list, tuple)):
33 35
             field_name = '.'.join([qn(c) for c in self.col])
34 36
         else:
@@ -36,13 +38,13 @@ def as_sql(self, qn, connection):
36 38
 
37 39
         sql_template, sql_function = connection.ops.spatial_aggregate_sql(self)
38 40
 
39  
-        params = {
  41
+        substitutions = {
40 42
             'function': sql_function,
41 43
             'field': field_name
42 44
         }
43  
-        params.update(self.extra)
  45
+        substitutions.update(self.extra)
44 46
 
45  
-        return sql_template % params
  47
+        return sql_template % substitutions, params
46 48
 
47 49
 class Collect(GeoAggregate):
48 50
     pass
23  django/contrib/gis/db/models/sql/compiler.py
@@ -33,6 +33,7 @@ def get_columns(self, with_aliases=False):
33 33
         qn2 = self.connection.ops.quote_name
34 34
         result = ['(%s) AS %s' % (self.get_extra_select_format(alias) % col[0], qn2(alias))
35 35
                   for alias, col in six.iteritems(self.query.extra_select)]
  36
+        params = []
36 37
         aliases = set(self.query.extra_select.keys())
37 38
         if with_aliases:
38 39
             col_aliases = aliases.copy()
@@ -63,7 +64,9 @@ def get_columns(self, with_aliases=False):
63 64
                         aliases.add(r)
64 65
                         col_aliases.add(col[1])
65 66
                 else:
66  
-                    result.append(col.as_sql(qn, self.connection))
  67
+                    col_sql, col_params = col.as_sql(qn, self.connection)
  68
+                    result.append(col_sql)
  69
+                    params.extend(col_params)
67 70
 
68 71
                     if hasattr(col, 'alias'):
69 72
                         aliases.add(col.alias)
@@ -76,15 +79,13 @@ def get_columns(self, with_aliases=False):
76 79
             aliases.update(new_aliases)
77 80
 
78 81
         max_name_length = self.connection.ops.max_name_length()
79  
-        result.extend([
80  
-                '%s%s' % (
81  
-                    self.get_extra_select_format(alias) % aggregate.as_sql(qn, self.connection),
82  
-                    alias is not None
83  
-                        and ' AS %s' % qn(truncate_name(alias, max_name_length))
84  
-                        or ''
85  
-                    )
86  
-                for alias, aggregate in self.query.aggregate_select.items()
87  
-        ])
  82
+        for alias, aggregate in self.query.aggregate_select.items():
  83
+            agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
  84
+            if alias is None:
  85
+                result.append(agg_sql)
  86
+            else:
  87
+                result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length))))
  88
+            params.extend(agg_params)
88 89
 
89 90
         # This loop customized for GeoQuery.
90 91
         for (table, col), field in self.query.related_select_cols:
@@ -100,7 +101,7 @@ def get_columns(self, with_aliases=False):
100 101
                 col_aliases.add(col)
101 102
 
102 103
         self._select_aliases = aliases
103  
-        return result
  104
+        return result, params
104 105
 
105 106
     def get_default_columns(self, with_aliases=False, col_aliases=None,
106 107
             start_alias=None, opts=None, as_pairs=False, from_parent=None):
5  django/contrib/gis/db/models/sql/where.py
@@ -44,8 +44,9 @@ def make_atom(self, child, qn, connection):
44 44
         lvalue, lookup_type, value_annot, params_or_value = child
45 45
         if isinstance(lvalue, GeoConstraint):
46 46
             data, params = lvalue.process(lookup_type, params_or_value, connection)
47  
-            spatial_sql = connection.ops.spatial_lookup_sql(data, lookup_type, params_or_value, lvalue.field, qn)
48  
-            return spatial_sql, params
  47
+            spatial_sql, spatial_params = connection.ops.spatial_lookup_sql(
  48
+                    data, lookup_type, params_or_value, lvalue.field, qn)
  49
+            return spatial_sql, spatial_params + params
49 50
         else:
50 51
             return super(GeoWhereNode, self).make_atom(child, qn, connection)
51 52
 
2  django/db/models/query_utils.py
@@ -25,7 +25,7 @@ class QueryWrapper(object):
25 25
     parameters. Can be used to pass opaque data to a where-clause, for example.
26 26
     """
27 27
     def __init__(self, sql, params):
28  
-        self.data = sql, params
  28
+        self.data = sql, list(params)
29 29
 
30 30
     def as_sql(self, qn=None, connection=None):
31 31
         return self.data
11  django/db/models/sql/aggregates.py
@@ -73,22 +73,23 @@ def relabel_aliases(self, change_map):
73 73
             self.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
74 74
 
75 75
     def as_sql(self, qn, connection):
76  
-        "Return the aggregate, rendered as SQL."
  76
+        "Return the aggregate, rendered as SQL with parameters."
  77
+        params = []
77 78
 
78 79
         if hasattr(self.col, 'as_sql'):
79  
-            field_name = self.col.as_sql(qn, connection)
  80
+            field_name, params = self.col.as_sql(qn, connection)
80 81
         elif isinstance(self.col, (list, tuple)):
81 82
             field_name = '.'.join([qn(c) for c in self.col])
82 83
         else:
83 84
             field_name = self.col
84 85
 
85  
-        params = {
  86
+        substitutions = {
86 87
             'function': self.sql_function,
87 88
             'field': field_name
88 89
         }
89  
-        params.update(self.extra)
  90
+        substitutions.update(self.extra)
90 91
 
91  
-        return self.sql_template % params
  92
+        return self.sql_template % substitutions, params
92 93
 
93 94
 
94 95
 class Avg(Aggregate):
57  django/db/models/sql/compiler.py
@@ -74,7 +74,7 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
74 74
         # as the pre_sql_setup will modify query state in a way that forbids
75 75
         # another run of it.
76 76
         self.refcounts_before = self.query.alias_refcount.copy()
77  
-        out_cols = self.get_columns(with_col_aliases)
  77
+        out_cols, s_params = self.get_columns(with_col_aliases)
78 78
         ordering, ordering_group_by = self.get_ordering()
79 79
 
80 80
         distinct_fields = self.get_distinct()
@@ -97,6 +97,7 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
97 97
             result.append(self.connection.ops.distinct_sql(distinct_fields))
98 98
 
99 99
         result.append(', '.join(out_cols + self.query.ordering_aliases))
  100
+        params.extend(s_params)
100 101
 
101 102
         result.append('FROM')
102 103
         result.extend(from_)
@@ -164,9 +165,10 @@ def as_nested_sql(self):
164 165
 
165 166
     def get_columns(self, with_aliases=False):
166 167
         """
167  
-        Returns the list of columns to use in the select statement. If no
168  
-        columns have been specified, returns all columns relating to fields in
169  
-        the model.
  168
+        Returns the list of columns to use in the select statement, as well as
  169
+        a list any extra parameters that need to be included. If no columns
  170
+        have been specified, returns all columns relating to fields in the
  171
+        model.
170 172
 
171 173
         If 'with_aliases' is true, any column names that are duplicated
172 174
         (without the table names) are given unique aliases. This is needed in
@@ -175,6 +177,7 @@ def get_columns(self, with_aliases=False):
175 177
         qn = self.quote_name_unless_alias
176 178
         qn2 = self.connection.ops.quote_name
177 179
         result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)]
  180
+        params = []
178 181
         aliases = set(self.query.extra_select.keys())
179 182
         if with_aliases:
180 183
             col_aliases = aliases.copy()
@@ -204,7 +207,9 @@ def get_columns(self, with_aliases=False):
204 207
                         aliases.add(r)
205 208
                         col_aliases.add(col[1])
206 209
                 else:
207  
-                    result.append(col.as_sql(qn, self.connection))
  210
+                    col_sql, col_params = col.as_sql(qn, self.connection)
  211
+                    result.append(col_sql)
  212
+                    params.extend(col_params)
208 213
 
209 214
                     if hasattr(col, 'alias'):
210 215
                         aliases.add(col.alias)
@@ -217,15 +222,13 @@ def get_columns(self, with_aliases=False):
217 222
             aliases.update(new_aliases)
218 223
 
219 224
         max_name_length = self.connection.ops.max_name_length()
220  
-        result.extend([
221  
-            '%s%s' % (
222  
-                aggregate.as_sql(qn, self.connection),
223  
-                alias is not None
224  
-                    and ' AS %s' % qn(truncate_name(alias, max_name_length))
225  
-                    or ''
226  
-            )
227  
-            for alias, aggregate in self.query.aggregate_select.items()
228  
-        ])
  225
+        for alias, aggregate in self.query.aggregate_select.items():
  226
+            agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
  227
+            if alias is None:
  228
+                result.append(agg_sql)
  229
+            else:
  230
+                result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length))))
  231
+            params.extend(agg_params)
229 232
 
230 233
         for (table, col), _ in self.query.related_select_cols:
231 234
             r = '%s.%s' % (qn(table), qn(col))
@@ -240,7 +243,7 @@ def get_columns(self, with_aliases=False):
240 243
                 col_aliases.add(col)
241 244
 
242 245
         self._select_aliases = aliases
243  
-        return result
  246
+        return result, params
244 247
 
245 248
     def get_default_columns(self, with_aliases=False, col_aliases=None,
246 249
             start_alias=None, opts=None, as_pairs=False, from_parent=None):
@@ -545,14 +548,16 @@ def get_grouping(self, ordering_group_by):
545 548
             seen = set()
546 549
             cols = self.query.group_by + select_cols
547 550
             for col in cols:
  551
+                col_params = ()
548 552
                 if isinstance(col, (list, tuple)):
549 553
                     sql = '%s.%s' % (qn(col[0]), qn(col[1]))
550 554
                 elif hasattr(col, 'as_sql'):
551  
-                    sql = col.as_sql(qn, self.connection)
  555
+                    sql, col_params = col.as_sql(qn, self.connection)
552 556
                 else:
553 557
                     sql = '(%s)' % str(col)
554 558
                 if sql not in seen:
555 559
                     result.append(sql)
  560
+                    params.extend(col_params)
556 561
                     seen.add(sql)
557 562
 
558 563
             # Still, we need to add all stuff in ordering (except if the backend can
@@ -991,15 +996,17 @@ def as_sql(self, qn=None):
991 996
         if qn is None:
992 997
             qn = self.quote_name_unless_alias
993 998
 
994  
-        sql = ('SELECT %s FROM (%s) subquery' % (
995  
-            ', '.join([
996  
-                aggregate.as_sql(qn, self.connection)
997  
-                for aggregate in self.query.aggregate_select.values()
998  
-            ]),
999  
-            self.query.subquery)
1000  
-        )
1001  
-        params = self.query.sub_params
1002  
-        return (sql, params)
  999
+        sql, params = [], []
  1000
+        for aggregate in self.query.aggregate_select.values():
  1001
+            agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
  1002
+            sql.append(agg_sql)
  1003
+            params.extend(agg_params)
  1004
+        sql = ', '.join(sql)
  1005
+        params = tuple(params)
  1006
+
  1007
+        sql = 'SELECT %s FROM (%s) subquery' % (sql, self.query.subquery)
  1008
+        params = params + self.query.sub_params
  1009
+        return sql, params
1003 1010
 
1004 1011
 class SQLDateCompiler(SQLCompiler):
1005 1012
     def results_iter(self):
2  django/db/models/sql/datastructures.py
@@ -42,7 +42,7 @@ def as_sql(self, qn, connection):
42 42
             col = '%s.%s' % tuple([qn(c) for c in self.col])
43 43
         else:
44 44
             col = self.col
45  
-        return getattr(connection.ops, self.trunc_func)(self.lookup_type, col)
  45
+        return getattr(connection.ops, self.trunc_func)(self.lookup_type, col), []
46 46
 
47 47
 class DateTime(Date):
48 48
     """
4  django/db/models/sql/expressions.py
@@ -94,9 +94,9 @@ def evaluate_leaf(self, node, qn, connection):
94 94
         if col is None:
95 95
             raise ValueError("Given node not found")
96 96
         if hasattr(col, 'as_sql'):
97  
-            return col.as_sql(qn, connection), ()
  97
+            return col.as_sql(qn, connection)
98 98
         else:
99  
-            return '%s.%s' % (qn(col[0]), qn(col[1])), ()
  99
+            return '%s.%s' % (qn(col[0]), qn(col[1])), []
100 100
 
101 101
     def evaluate_date_modifier_node(self, node, qn, connection):
102 102
         timedelta = node.children.pop()
10  django/db/models/sql/where.py
@@ -172,10 +172,10 @@ def make_atom(self, child, qn, connection):
172 172
 
173 173
         if isinstance(lvalue, tuple):
174 174
             # A direct database column lookup.
175  
-            field_sql = self.sql_for_columns(lvalue, qn, connection)
  175
+            field_sql, field_params = self.sql_for_columns(lvalue, qn, connection), []
176 176
         else:
177 177
             # A smart object with an as_sql() method.
178  
-            field_sql = lvalue.as_sql(qn, connection)
  178
+            field_sql, field_params = lvalue.as_sql(qn, connection)
179 179
 
180 180
         is_datetime_field = value_annotation is datetime.datetime
181 181
         cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s'
@@ -186,6 +186,8 @@ def make_atom(self, child, qn, connection):
186 186
         else:
187 187
             extra = ''
188 188
 
  189
+        params = field_params + params
  190
+
189 191
         if (len(params) == 1 and params[0] == '' and lookup_type == 'exact'
190 192
             and connection.features.interprets_empty_strings_as_nulls):
191 193
             lookup_type = 'isnull'
@@ -245,7 +247,7 @@ def sql_for_columns(self, data, qn, connection):
245 247
         """
246 248
         Returns the SQL fragment used for the left-hand side of a column
247 249
         constraint (for example, the "T1.foo" portion in the clause
248  
-        "WHERE ... T1.foo = 6").
  250
+        "WHERE ... T1.foo = 6") and a list of parameters.
249 251
         """
250 252
         table_alias, name, db_type = data
251 253
         if table_alias:
@@ -338,7 +340,7 @@ def __init__(self, sqls, params):
338 340
 
339 341
     def as_sql(self, qn=None, connection=None):
340 342
         sqls = ["(%s)" % sql for sql in self.sqls]
341  
-        return " AND ".join(sqls), tuple(self.params or ())
  343
+        return " AND ".join(sqls), list(self.params or ())
342 344
 
343 345
     def clone(self):
344 346
         return self

0 notes on commit 924a144

Please sign in to comment.
Something went wrong with that request. Please try again.