Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Fixed #19190 -- Refactored Query select clause attributes

The Query.select and Query.select_fields were collapsed into one list
because the attributes had to be always in sync. Now that they are in
one attribute it is impossible to edit them out of sync.

Similar collapse was done for Query.related_select_cols and
Query.related_select_fields.
  • Loading branch information...
commit 11699ac4b5f98ec11dba02b356a8fd4ab6b4b889 1 parent 789ea33
@akaariai akaariai authored
View
4 django/contrib/gis/db/models/sql/compiler.py
@@ -39,7 +39,7 @@ def get_columns(self, with_aliases=False):
if self.query.select:
only_load = self.deferred_to_columns()
# This loop customized for GeoQuery.
- for col, field in zip(self.query.select, self.query.select_fields):
+ for col, field in self.query.select:
if isinstance(col, (list, tuple)):
alias, column = col
table = self.query.alias_map[alias].table_name
@@ -85,7 +85,7 @@ def get_columns(self, with_aliases=False):
])
# This loop customized for GeoQuery.
- for (table, col), field in zip(self.query.related_select_cols, self.query.related_select_fields):
+ for (table, col), field in self.query.related_select_cols:
r = self.get_field_select(field, table, col)
if with_aliases and col in col_aliases:
c_alias = 'Col%d' % len(col_aliases)
View
28 django/db/models/sql/compiler.py
@@ -6,7 +6,7 @@
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import select_related_descend
from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR,
- GET_ITERATOR_CHUNK_SIZE)
+ GET_ITERATOR_CHUNK_SIZE, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet
from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.query import get_order_dir, Query
@@ -188,7 +188,7 @@ def get_columns(self, with_aliases=False):
col_aliases = set()
if self.query.select:
only_load = self.deferred_to_columns()
- for col in self.query.select:
+ for col, _ in self.query.select:
if isinstance(col, (list, tuple)):
alias, column = col
table = self.query.alias_map[alias].table_name
@@ -233,7 +233,7 @@ def get_columns(self, with_aliases=False):
for alias, aggregate in self.query.aggregate_select.items()
])
- for table, col in self.query.related_select_cols:
+ for (table, col), _ in self.query.related_select_cols:
r = '%s.%s' % (qn(table), qn(col))
if with_aliases and col in col_aliases:
c_alias = 'Col%d' % len(col_aliases)
@@ -557,8 +557,9 @@ def get_grouping(self):
for extra_select, extra_params in six.itervalues(self.query.extra_select):
extra_selects.append(extra_select)
params.extend(extra_params)
- cols = (group_by + self.query.select +
- self.query.related_select_cols + extra_selects)
+ select_cols = [s.col for s in self.query.select]
+ related_select_cols = [s.col for s in self.query.related_select_cols]
+ cols = (group_by + select_cols + related_select_cols + extra_selects)
seen = set()
for col in cols:
if col in seen:
@@ -589,7 +590,6 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
opts = self.query.get_meta()
root_alias = self.query.get_initial_alias()
self.query.related_select_cols = []
- self.query.related_select_fields = []
if not used:
used = set()
if dupe_set is None:
@@ -664,8 +664,8 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
used.add(alias)
columns, aliases = self.get_default_columns(start_alias=alias,
opts=f.rel.to._meta, as_pairs=True)
- self.query.related_select_cols.extend(columns)
- self.query.related_select_fields.extend(f.rel.to._meta.fields)
+ self.query.related_select_cols.extend(
+ SelectInfo(col, field) for col, field in zip(columns, f.rel.to._meta.fields))
if restricted:
next = requested.get(f.name, {})
else:
@@ -734,8 +734,8 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
used.add(alias)
columns, aliases = self.get_default_columns(start_alias=alias,
opts=model._meta, as_pairs=True, local_only=True)
- self.query.related_select_cols.extend(columns)
- self.query.related_select_fields.extend(model._meta.fields)
+ self.query.related_select_cols.extend(
+ SelectInfo(col, field) for col, field in zip(columns, model._meta.fields))
next = requested.get(f.related_query_name(), {})
# Use True here because we are looking at the _reverse_ side of
@@ -772,7 +772,7 @@ def results_iter(self):
if resolve_columns:
if fields is None:
# We only set this up here because
- # related_select_fields isn't populated until
+ # related_select_cols isn't populated until
# execute_sql() has been called.
# We also include types of fields of related models that
@@ -782,11 +782,11 @@ def results_iter(self):
# This code duplicates the logic for the order of fields
# found in get_columns(). It would be nice to clean this up.
- if self.query.select_fields:
- fields = self.query.select_fields
+ if self.query.select:
+ fields = [f.field for f in self.query.select]
else:
fields = self.query.model._meta.fields
- fields = fields + self.query.related_select_fields
+ fields = fields + [f.field for f in self.query.related_select_cols]
# If the field was deferred, exclude it from being passed
# into `resolve_columns` because it wasn't selected.
View
3  django/db/models/sql/constants.py
@@ -25,6 +25,9 @@
'table_name rhs_alias join_type lhs_alias '
'lhs_join_col rhs_join_col nullable')
+# Pairs of column clauses to select, and (possibly None) field for the clause.
+SelectInfo = namedtuple('SelectInfo', 'col field')
+
# How many results to expect from a cursor.execute call
MULTI = 'multi'
SINGLE = 'single'
View
87 django/db/models/sql/query.py
@@ -20,7 +20,7 @@
from django.db.models.fields import FieldDoesNotExist
from django.db.models.sql import aggregates as base_aggregates_module
from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
- ORDER_PATTERN, JoinInfo)
+ ORDER_PATTERN, JoinInfo, SelectInfo)
from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin
from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
@@ -115,17 +115,20 @@ def __init__(self, model, where=WhereNode):
self.default_ordering = True
self.standard_ordering = True
self.ordering_aliases = []
- self.related_select_fields = []
self.dupe_avoidance = {}
self.used_aliases = set()
self.filter_is_sticky = False
self.included_inherited_models = {}
- # SQL-related attributes
+ # SQL-related attributes
+ # Select and related select clauses as SelectInfo instances.
+ # The select is used for cases where we want to set up the select
+ # clause to contain other than default fields (values(), annotate(),
+ # subqueries...)
self.select = []
- # For each to-be-selected field in self.select there must be a
- # corresponding entry in self.select - git seems to need this.
- self.select_fields = []
+ # The related_select_cols is used for columns needed for
+ # select_related - this is populated in compile stage.
+ self.related_select_cols = []
self.tables = [] # Aliases in the order they are created.
self.where = where()
self.where_class = where
@@ -138,7 +141,6 @@ def __init__(self, model, where=WhereNode):
self.select_for_update = False
self.select_for_update_nowait = False
self.select_related = False
- self.related_select_cols = []
# SQL aggregate-related attributes
self.aggregates = SortedDict() # Maps alias -> SQL aggregate function
@@ -191,15 +193,14 @@ def __getstate__(self):
Pickling support.
"""
obj_dict = self.__dict__.copy()
- obj_dict['related_select_fields'] = []
obj_dict['related_select_cols'] = []
# Fields can't be pickled, so if a field list has been
# specified, we pickle the list of field names instead.
# None is also a possible value; that can pass as-is
- obj_dict['select_fields'] = [
- f is not None and f.name or None
- for f in obj_dict['select_fields']
+ obj_dict['select'] = [
+ (s.col, s.field is not None and s.field.name or None)
+ for s in obj_dict['select']
]
return obj_dict
@@ -209,9 +210,9 @@ def __setstate__(self, obj_dict):
"""
# Rebuild list of field instances
opts = obj_dict['model']._meta
- obj_dict['select_fields'] = [
- name is not None and opts.get_field(name) or None
- for name in obj_dict['select_fields']
+ obj_dict['select'] = [
+ SelectInfo(tpl[0], tpl[1] is not None and opts.get_field(tpl[1]) or None)
+ for tpl in obj_dict['select']
]
self.__dict__.update(obj_dict)
@@ -256,10 +257,9 @@ def clone(self, klass=None, memo=None, **kwargs):
obj.standard_ordering = self.standard_ordering
obj.included_inherited_models = self.included_inherited_models.copy()
obj.ordering_aliases = []
- obj.select_fields = self.select_fields[:]
- obj.related_select_fields = self.related_select_fields[:]
obj.dupe_avoidance = self.dupe_avoidance.copy()
obj.select = self.select[:]
+ obj.related_select_cols = []
obj.tables = self.tables[:]
obj.where = copy.deepcopy(self.where, memo=memo)
obj.where_class = self.where_class
@@ -275,7 +275,6 @@ def clone(self, klass=None, memo=None, **kwargs):
obj.select_for_update = self.select_for_update
obj.select_for_update_nowait = self.select_for_update_nowait
obj.select_related = self.select_related
- obj.related_select_cols = []
obj.aggregates = copy.deepcopy(self.aggregates, memo=memo)
if self.aggregate_select_mask is None:
obj.aggregate_select_mask = None
@@ -384,7 +383,6 @@ def get_aggregation(self, using):
query.select_for_update = False
query.select_related = False
query.related_select_cols = []
- query.related_select_fields = []
result = query.get_compiler(using).execute_sql(SINGLE)
if result is None:
@@ -527,14 +525,14 @@ def combine(self, rhs, connector):
# Selection columns and extra extensions are those provided by 'rhs'.
self.select = []
- for col in rhs.select:
+ for col, field in rhs.select:
if isinstance(col, (list, tuple)):
- self.select.append((change_map.get(col[0], col[0]), col[1]))
+ new_col = change_map.get(col[0], col[0]), col[1]
+ self.select.append(SelectInfo(new_col, field))
else:
item = copy.deepcopy(col)
item.relabel_aliases(change_map)
- self.select.append(item)
- self.select_fields = rhs.select_fields[:]
+ self.select.append(SelectInfo(item, field))
if connector == OR:
# It would be nice to be able to handle this, but the queries don't
@@ -750,24 +748,23 @@ def change_aliases(self, change_map):
"""
assert set(change_map.keys()).intersection(set(change_map.values())) == set()
+ def relabel_column(col):
+ if isinstance(col, (list, tuple)):
+ old_alias = col[0]
+ return (change_map.get(old_alias, old_alias), col[1])
+ else:
+ col.relabel_aliases(change_map)
+ return col
# 1. Update references in "select" (normal columns plus aliases),
# "group by", "where" and "having".
self.where.relabel_aliases(change_map)
self.having.relabel_aliases(change_map)
- for columns in [self.select, self.group_by or []]:
- for pos, col in enumerate(columns):
- if isinstance(col, (list, tuple)):
- old_alias = col[0]
- columns[pos] = (change_map.get(old_alias, old_alias), col[1])
- else:
- col.relabel_aliases(change_map)
- for mapping in [self.aggregates]:
- for key, col in mapping.items():
- if isinstance(col, (list, tuple)):
- old_alias = col[0]
- mapping[key] = (change_map.get(old_alias, old_alias), col[1])
- else:
- col.relabel_aliases(change_map)
+ if self.group_by:
+ self.group_by = [relabel_column(col) for col in self.group_by]
+ self.select = [SelectInfo(relabel_column(s.col), s.field)
+ for s in self.select]
+ self.aggregates = SortedDict(
+ (key, relabel_column(col)) for key, col in self.aggregates.items())
# 2. Rename the alias in the internal table/alias datastructures.
for k, aliases in self.join_map.items():
@@ -1570,7 +1567,7 @@ def split_exclude(self, filter_expr, prefix, can_reuse):
# since we are adding a IN <subquery> clause. This prevents the
# database from tripping over IN (...,NULL,...) selects and returning
# nothing
- alias, col = query.select[0]
+ alias, col = query.select[0].col
query.where.add((Constraint(alias, col, None), 'isnull', False), AND)
self.add_filter(('%s__in' % prefix, query), negate=True, trim=True,
@@ -1629,7 +1626,6 @@ def clear_select_clause(self):
Removes all fields from SELECT clause.
"""
self.select = []
- self.select_fields = []
self.default_cols = False
self.select_related = False
self.set_extra_mask(())
@@ -1642,7 +1638,6 @@ def clear_select_fields(self):
columns.
"""
self.select = []
- self.select_fields = []
def add_distinct_fields(self, *field_names):
"""
@@ -1674,8 +1669,7 @@ def add_fields(self, field_names, allow_m2m=True):
col = join.lhs_join_col
joins = joins[:-1]
self.promote_joins(joins[1:])
- self.select.append((final_alias, col))
- self.select_fields.append(field)
+ self.select.append(SelectInfo((final_alias, col), field))
except MultiJoin:
raise FieldError("Invalid field name: '%s'" % name)
except FieldError:
@@ -1731,8 +1725,8 @@ def set_group_by(self):
"""
self.group_by = []
- for sel in self.select:
- self.group_by.append(sel)
+ for col, _ in self.select:
+ self.group_by.append(col)
def add_count_column(self):
"""
@@ -1745,7 +1739,7 @@ def add_count_column(self):
else:
assert len(self.select) == 1, \
"Cannot add count col with multiple cols in 'select': %r" % self.select
- count = self.aggregates_module.Count(self.select[0])
+ count = self.aggregates_module.Count(self.select[0].col)
else:
opts = self.model._meta
if not self.select:
@@ -1757,7 +1751,7 @@ def add_count_column(self):
assert len(self.select) == 1, \
"Cannot add count col with multiple cols in 'select'."
- count = self.aggregates_module.Count(self.select[0], distinct=True)
+ count = self.aggregates_module.Count(self.select[0].col, distinct=True)
# Distinct handling is done in Count(), so don't do it at this
# level.
self.distinct = False
@@ -1781,7 +1775,6 @@ def add_select_related(self, fields):
d = d.setdefault(part, {})
self.select_related = field_dict
self.related_select_cols = []
- self.related_select_fields = []
def add_extra(self, select, select_params, where, params, tables, order_by):
"""
@@ -1975,7 +1968,7 @@ def set_start(self, start):
self.unref_alias(select_alias)
select_alias = join_info.rhs_alias
select_col = join_info.rhs_join_col
- self.select = [(select_alias, select_col)]
+ self.select = [SelectInfo((select_alias, select_col), None)]
self.remove_inherited_models()
def is_nullable(self, field):
View
4 django/db/models/sql/subqueries.py
@@ -76,7 +76,7 @@ def delete_qs(self, query, using):
return
else:
innerq.clear_select_clause()
- innerq.select, innerq.select_fields = [(self.get_initial_alias(), pk.column)], [None]
+ innerq.select = [SelectInfo((self.get_initial_alias(), pk.column), None)]
values = innerq
where = self.where_class()
where.add((Constraint(None, pk.column, pk), 'in', values), AND)
@@ -244,7 +244,7 @@ def add_date_select(self, field_name, lookup_type, order='ASC'):
alias = result[3][-1]
select = Date((alias, field.column), lookup_type)
self.clear_select_clause()
- self.select, self.select_fields = [select], [None]
+ self.select = [SelectInfo(select, None)]
self.distinct = True
self.order_by = order == 'ASC' and [1] or [-1]
Please sign in to comment.
Something went wrong with that request. Please try again.