Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Added support for parameters in SELECT clauses.

  • Loading branch information...
commit 924a144ef8a80ba4daeeafbe9efaa826566e9d02 1 parent b4351d2
@aaugustin aaugustin authored
View
7 django/contrib/gis/db/backends/mysql/operations.py
@@ -56,12 +56,13 @@ def spatial_lookup_sql(self, lvalue, lookup_type, value, field, qn):
lookup_info = self.geometry_functions.get(lookup_type, False)
if lookup_info:
- return "%s(%s, %s)" % (lookup_info, geo_col,
- self.get_geom_placeholder(value, field.srid))
+ sql = "%s(%s, %s)" % (lookup_info, geo_col,
+ self.get_geom_placeholder(value, field.srid))
+ return sql, []
# TODO: Is this really necessary? MySQL can't handle NULL geometries
# in its spatial indexes anyways.
if lookup_type == 'isnull':
- return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or ''))
+ return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))
View
4 django/contrib/gis/db/backends/oracle/operations.py
@@ -262,7 +262,7 @@ def spatial_lookup_sql(self, lvalue, lookup_type, value, field, qn):
return lookup_info.as_sql(geo_col, self.get_geom_placeholder(field, value))
elif lookup_type == 'isnull':
# Handling 'isnull' lookup type
- return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or ''))
+ return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))
@@ -288,7 +288,7 @@ def geometry_columns(self):
def spatial_ref_sys(self):
from django.contrib.gis.db.backends.oracle.models import SpatialRefSys
return SpatialRefSys
-
+
def modify_insert_params(self, placeholders, params):
"""Drop out insert parameters for NULL placeholder. Needed for Oracle Spatial
backend due to #10888
View
2  django/contrib/gis/db/backends/postgis/operations.py
@@ -560,7 +560,7 @@ def spatial_lookup_sql(self, lvalue, lookup_type, value, field, qn):
elif lookup_type == 'isnull':
# Handling 'isnull' lookup type
- return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or ''))
+ return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))
View
2  django/contrib/gis/db/backends/spatialite/operations.py
@@ -358,7 +358,7 @@ def spatial_lookup_sql(self, lvalue, lookup_type, value, field, qn):
return op.as_sql(geo_col, self.get_geom_placeholder(field, geom))
elif lookup_type == 'isnull':
# Handling 'isnull' lookup type
- return "%s IS %sNULL" % (geo_col, (not value and 'NOT ' or ''))
+ return "%s IS %sNULL" % (geo_col, ('' if value else 'NOT ')), []
raise TypeError("Got invalid lookup_type: %s" % repr(lookup_type))
View
2  django/contrib/gis/db/backends/util.py
@@ -16,7 +16,7 @@ def __init__(self, function='', operator='', result='', **kwargs):
self.extra = kwargs
def as_sql(self, geo_col, geometry='%s'):
- return self.sql_template % self.params(geo_col, geometry)
+ return self.sql_template % self.params(geo_col, geometry), []
def params(self, geo_col, geometry):
params = {'function' : self.function,
View
12 django/contrib/gis/db/models/sql/aggregates.py
@@ -22,13 +22,15 @@ def __init__(self, col, source=None, is_summary=False, tolerance=0.05, **extra):
raise ValueError('Geospatial aggregates only allowed on geometry fields.')
def as_sql(self, qn, connection):
- "Return the aggregate, rendered as SQL."
+ "Return the aggregate, rendered as SQL with parameters."
if connection.ops.oracle:
self.extra['tolerance'] = self.tolerance
+ params = []
+
if hasattr(self.col, 'as_sql'):
- field_name = self.col.as_sql(qn, connection)
+ field_name, params = self.col.as_sql(qn, connection)
elif isinstance(self.col, (list, tuple)):
field_name = '.'.join([qn(c) for c in self.col])
else:
@@ -36,13 +38,13 @@ def as_sql(self, qn, connection):
sql_template, sql_function = connection.ops.spatial_aggregate_sql(self)
- params = {
+ substitutions = {
'function': sql_function,
'field': field_name
}
- params.update(self.extra)
+ substitutions.update(self.extra)
- return sql_template % params
+ return sql_template % substitutions, params
class Collect(GeoAggregate):
pass
View
23 django/contrib/gis/db/models/sql/compiler.py
@@ -33,6 +33,7 @@ def get_columns(self, with_aliases=False):
qn2 = self.connection.ops.quote_name
result = ['(%s) AS %s' % (self.get_extra_select_format(alias) % col[0], qn2(alias))
for alias, col in six.iteritems(self.query.extra_select)]
+ params = []
aliases = set(self.query.extra_select.keys())
if with_aliases:
col_aliases = aliases.copy()
@@ -63,7 +64,9 @@ def get_columns(self, with_aliases=False):
aliases.add(r)
col_aliases.add(col[1])
else:
- result.append(col.as_sql(qn, self.connection))
+ col_sql, col_params = col.as_sql(qn, self.connection)
+ result.append(col_sql)
+ params.extend(col_params)
if hasattr(col, 'alias'):
aliases.add(col.alias)
@@ -76,15 +79,13 @@ def get_columns(self, with_aliases=False):
aliases.update(new_aliases)
max_name_length = self.connection.ops.max_name_length()
- result.extend([
- '%s%s' % (
- self.get_extra_select_format(alias) % aggregate.as_sql(qn, self.connection),
- alias is not None
- and ' AS %s' % qn(truncate_name(alias, max_name_length))
- or ''
- )
- for alias, aggregate in self.query.aggregate_select.items()
- ])
+ for alias, aggregate in self.query.aggregate_select.items():
+ agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
+ if alias is None:
+ result.append(agg_sql)
+ else:
+ result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length))))
+ params.extend(agg_params)
# This loop customized for GeoQuery.
for (table, col), field in self.query.related_select_cols:
@@ -100,7 +101,7 @@ def get_columns(self, with_aliases=False):
col_aliases.add(col)
self._select_aliases = aliases
- return result
+ return result, params
def get_default_columns(self, with_aliases=False, col_aliases=None,
start_alias=None, opts=None, as_pairs=False, from_parent=None):
View
5 django/contrib/gis/db/models/sql/where.py
@@ -44,8 +44,9 @@ def make_atom(self, child, qn, connection):
lvalue, lookup_type, value_annot, params_or_value = child
if isinstance(lvalue, GeoConstraint):
data, params = lvalue.process(lookup_type, params_or_value, connection)
- spatial_sql = connection.ops.spatial_lookup_sql(data, lookup_type, params_or_value, lvalue.field, qn)
- return spatial_sql, params
+ spatial_sql, spatial_params = connection.ops.spatial_lookup_sql(
+ data, lookup_type, params_or_value, lvalue.field, qn)
+ return spatial_sql, spatial_params + params
else:
return super(GeoWhereNode, self).make_atom(child, qn, connection)
View
2  django/db/models/query_utils.py
@@ -25,7 +25,7 @@ class QueryWrapper(object):
parameters. Can be used to pass opaque data to a where-clause, for example.
"""
def __init__(self, sql, params):
- self.data = sql, params
+ self.data = sql, list(params)
def as_sql(self, qn=None, connection=None):
return self.data
View
11 django/db/models/sql/aggregates.py
@@ -73,22 +73,23 @@ def relabel_aliases(self, change_map):
self.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
def as_sql(self, qn, connection):
- "Return the aggregate, rendered as SQL."
+ "Return the aggregate, rendered as SQL with parameters."
+ params = []
if hasattr(self.col, 'as_sql'):
- field_name = self.col.as_sql(qn, connection)
+ field_name, params = self.col.as_sql(qn, connection)
elif isinstance(self.col, (list, tuple)):
field_name = '.'.join([qn(c) for c in self.col])
else:
field_name = self.col
- params = {
+ substitutions = {
'function': self.sql_function,
'field': field_name
}
- params.update(self.extra)
+ substitutions.update(self.extra)
- return self.sql_template % params
+ return self.sql_template % substitutions, params
class Avg(Aggregate):
View
57 django/db/models/sql/compiler.py
@@ -74,7 +74,7 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
# as the pre_sql_setup will modify query state in a way that forbids
# another run of it.
self.refcounts_before = self.query.alias_refcount.copy()
- out_cols = self.get_columns(with_col_aliases)
+ out_cols, s_params = self.get_columns(with_col_aliases)
ordering, ordering_group_by = self.get_ordering()
distinct_fields = self.get_distinct()
@@ -97,6 +97,7 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
result.append(self.connection.ops.distinct_sql(distinct_fields))
result.append(', '.join(out_cols + self.query.ordering_aliases))
+ params.extend(s_params)
result.append('FROM')
result.extend(from_)
@@ -164,9 +165,10 @@ def as_nested_sql(self):
def get_columns(self, with_aliases=False):
"""
- Returns the list of columns to use in the select statement. If no
- columns have been specified, returns all columns relating to fields in
- the model.
+ Returns the list of columns to use in the select statement, as well as
+ a list any extra parameters that need to be included. If no columns
+ have been specified, returns all columns relating to fields in the
+ model.
If 'with_aliases' is true, any column names that are duplicated
(without the table names) are given unique aliases. This is needed in
@@ -175,6 +177,7 @@ def get_columns(self, with_aliases=False):
qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in six.iteritems(self.query.extra_select)]
+ params = []
aliases = set(self.query.extra_select.keys())
if with_aliases:
col_aliases = aliases.copy()
@@ -204,7 +207,9 @@ def get_columns(self, with_aliases=False):
aliases.add(r)
col_aliases.add(col[1])
else:
- result.append(col.as_sql(qn, self.connection))
+ col_sql, col_params = col.as_sql(qn, self.connection)
+ result.append(col_sql)
+ params.extend(col_params)
if hasattr(col, 'alias'):
aliases.add(col.alias)
@@ -217,15 +222,13 @@ def get_columns(self, with_aliases=False):
aliases.update(new_aliases)
max_name_length = self.connection.ops.max_name_length()
- result.extend([
- '%s%s' % (
- aggregate.as_sql(qn, self.connection),
- alias is not None
- and ' AS %s' % qn(truncate_name(alias, max_name_length))
- or ''
- )
- for alias, aggregate in self.query.aggregate_select.items()
- ])
+ for alias, aggregate in self.query.aggregate_select.items():
+ agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
+ if alias is None:
+ result.append(agg_sql)
+ else:
+ result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length))))
+ params.extend(agg_params)
for (table, col), _ in self.query.related_select_cols:
r = '%s.%s' % (qn(table), qn(col))
@@ -240,7 +243,7 @@ def get_columns(self, with_aliases=False):
col_aliases.add(col)
self._select_aliases = aliases
- return result
+ return result, params
def get_default_columns(self, with_aliases=False, col_aliases=None,
start_alias=None, opts=None, as_pairs=False, from_parent=None):
@@ -545,14 +548,16 @@ def get_grouping(self, ordering_group_by):
seen = set()
cols = self.query.group_by + select_cols
for col in cols:
+ col_params = ()
if isinstance(col, (list, tuple)):
sql = '%s.%s' % (qn(col[0]), qn(col[1]))
elif hasattr(col, 'as_sql'):
- sql = col.as_sql(qn, self.connection)
+ sql, col_params = col.as_sql(qn, self.connection)
else:
sql = '(%s)' % str(col)
if sql not in seen:
result.append(sql)
+ params.extend(col_params)
seen.add(sql)
# Still, we need to add all stuff in ordering (except if the backend can
@@ -991,15 +996,17 @@ def as_sql(self, qn=None):
if qn is None:
qn = self.quote_name_unless_alias
- sql = ('SELECT %s FROM (%s) subquery' % (
- ', '.join([
- aggregate.as_sql(qn, self.connection)
- for aggregate in self.query.aggregate_select.values()
- ]),
- self.query.subquery)
- )
- params = self.query.sub_params
- return (sql, params)
+ sql, params = [], []
+ for aggregate in self.query.aggregate_select.values():
+ agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
+ sql.append(agg_sql)
+ params.extend(agg_params)
+ sql = ', '.join(sql)
+ params = tuple(params)
+
+ sql = 'SELECT %s FROM (%s) subquery' % (sql, self.query.subquery)
+ params = params + self.query.sub_params
+ return sql, params
class SQLDateCompiler(SQLCompiler):
def results_iter(self):
View
2  django/db/models/sql/datastructures.py
@@ -42,7 +42,7 @@ def as_sql(self, qn, connection):
col = '%s.%s' % tuple([qn(c) for c in self.col])
else:
col = self.col
- return getattr(connection.ops, self.trunc_func)(self.lookup_type, col)
+ return getattr(connection.ops, self.trunc_func)(self.lookup_type, col), []
class DateTime(Date):
"""
View
4 django/db/models/sql/expressions.py
@@ -94,9 +94,9 @@ def evaluate_leaf(self, node, qn, connection):
if col is None:
raise ValueError("Given node not found")
if hasattr(col, 'as_sql'):
- return col.as_sql(qn, connection), ()
+ return col.as_sql(qn, connection)
else:
- return '%s.%s' % (qn(col[0]), qn(col[1])), ()
+ return '%s.%s' % (qn(col[0]), qn(col[1])), []
def evaluate_date_modifier_node(self, node, qn, connection):
timedelta = node.children.pop()
View
10 django/db/models/sql/where.py
@@ -172,10 +172,10 @@ def make_atom(self, child, qn, connection):
if isinstance(lvalue, tuple):
# A direct database column lookup.
- field_sql = self.sql_for_columns(lvalue, qn, connection)
+ field_sql, field_params = self.sql_for_columns(lvalue, qn, connection), []
else:
# A smart object with an as_sql() method.
- field_sql = lvalue.as_sql(qn, connection)
+ field_sql, field_params = lvalue.as_sql(qn, connection)
is_datetime_field = value_annotation is datetime.datetime
cast_sql = connection.ops.datetime_cast_sql() if is_datetime_field else '%s'
@@ -186,6 +186,8 @@ def make_atom(self, child, qn, connection):
else:
extra = ''
+ params = field_params + params
+
if (len(params) == 1 and params[0] == '' and lookup_type == 'exact'
and connection.features.interprets_empty_strings_as_nulls):
lookup_type = 'isnull'
@@ -245,7 +247,7 @@ def sql_for_columns(self, data, qn, connection):
"""
Returns the SQL fragment used for the left-hand side of a column
constraint (for example, the "T1.foo" portion in the clause
- "WHERE ... T1.foo = 6").
+ "WHERE ... T1.foo = 6") and a list of parameters.
"""
table_alias, name, db_type = data
if table_alias:
@@ -338,7 +340,7 @@ def __init__(self, sqls, params):
def as_sql(self, qn=None, connection=None):
sqls = ["(%s)" % sql for sql in self.sqls]
- return " AND ".join(sqls), tuple(self.params or ())
+ return " AND ".join(sqls), list(self.params or ())
def clone(self):
return self
Please sign in to comment.
Something went wrong with that request. Please try again.