diff --git a/django/db/models/query.py b/django/db/models/query.py index 5567b3227df39..2e02f3f497862 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -715,9 +715,6 @@ def __init__(self, *args, **kwargs): def iterator(self): # Purge any extra columns that haven't been explicitly asked for - if self.extra_names is not None: - self.query.trim_extra_select(self.extra_names) - extra_names = self.query.extra_select.keys() field_names = self.field_names aggregate_names = self.query.aggregate_select.keys() @@ -741,13 +738,18 @@ def _setup_query(self): if self._fields: self.extra_names = [] self.aggregate_names = [] - if not self.query.extra_select and not self.query.aggregate_select: + if not self.query.extra and not self.query.aggregates: + # Short cut - if there are no extra or aggregates, then + # the values() clause must be just field names. self.field_names = list(self._fields) else: self.query.default_cols = False self.field_names = [] for f in self._fields: - if self.query.extra_select.has_key(f): + # we inspect the full extra_select list since we might + # be adding back an extra select item that we hadn't + # had selected previously. + if self.query.extra.has_key(f): self.extra_names.append(f) elif self.query.aggregate_select.has_key(f): self.aggregate_names.append(f) @@ -760,6 +762,8 @@ def _setup_query(self): self.aggregate_names = None self.query.select = [] + if self.extra_names is not None: + self.query.set_extra_mask(self.extra_names) self.query.add_fields(self.field_names, False) if self.aggregate_names is not None: self.query.set_aggregate_mask(self.aggregate_names) @@ -816,9 +820,6 @@ def _as_sql(self): class ValuesListQuerySet(ValuesQuerySet): def iterator(self): - if self.extra_names is not None: - self.query.trim_extra_select(self.extra_names) - if self.flat and len(self._fields) == 1: for row in self.query.results_iter(): yield row[0] diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index f4bf8b2b070f2..bafa1e93ea3ef 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -88,7 +88,10 @@ def __init__(self, model, connection, where=WhereNode): # These are for extensions. The contents are more or less appended # verbatim to the appropriate clause. - self.extra_select = SortedDict() # Maps col_alias -> (col_sql, params). + self.extra = SortedDict() # Maps col_alias -> (col_sql, params). + self.extra_select_mask = None + self._extra_select_cache = None + self.extra_tables = () self.extra_where = () self.extra_params = () @@ -214,13 +217,21 @@ def clone(self, klass=None, **kwargs): if self.aggregate_select_mask is None: obj.aggregate_select_mask = None else: - obj.aggregate_select_mask = self.aggregate_select_mask[:] + obj.aggregate_select_mask = self.aggregate_select_mask.copy() if self._aggregate_select_cache is None: obj._aggregate_select_cache = None else: obj._aggregate_select_cache = self._aggregate_select_cache.copy() obj.max_depth = self.max_depth - obj.extra_select = self.extra_select.copy() + obj.extra = self.extra.copy() + if self.extra_select_mask is None: + obj.extra_select_mask = None + else: + obj.extra_select_mask = self.extra_select_mask.copy() + if self._extra_select_cache is None: + obj._extra_select_cache = None + else: + obj._extra_select_cache = self._extra_select_cache.copy() obj.extra_tables = self.extra_tables obj.extra_where = self.extra_where obj.extra_params = self.extra_params @@ -325,7 +336,7 @@ def get_aggregation(self): query = self self.select = [] self.default_cols = False - self.extra_select = {} + self.extra = {} self.remove_inherited_models() query.clear_ordering(True) @@ -540,13 +551,20 @@ def combine(self, rhs, connector): # It would be nice to be able to handle this, but the queries don't # really make sense (or return consistent value sets). Not worth # the extra complexity when you can write a real query instead. - if self.extra_select and rhs.extra_select: + if self.extra and rhs.extra: raise ValueError("When merging querysets using 'or', you " "cannot have extra(select=...) on both sides.") if self.extra_where and rhs.extra_where: raise ValueError("When merging querysets using 'or', you " "cannot have extra(where=...) on both sides.") - self.extra_select.update(rhs.extra_select) + self.extra.update(rhs.extra) + extra_select_mask = set() + if self.extra_select_mask is not None: + extra_select_mask.update(self.extra_select_mask) + if rhs.extra_select_mask is not None: + extra_select_mask.update(rhs.extra_select_mask) + if extra_select_mask: + self.set_extra_mask(extra_select_mask) self.extra_tables += rhs.extra_tables self.extra_where += rhs.extra_where self.extra_params += rhs.extra_params @@ -2011,7 +2029,7 @@ def add_fields(self, field_names, allow_m2m=True): except MultiJoin: raise FieldError("Invalid field name: '%s'" % name) except FieldError: - names = opts.get_all_field_names() + self.extra_select.keys() + self.aggregate_select.keys() + names = opts.get_all_field_names() + self.extra.keys() + self.aggregate_select.keys() names.sort() raise FieldError("Cannot resolve keyword %r into field. " "Choices are: %s" % (name, ", ".join(names))) @@ -2139,7 +2157,7 @@ def add_extra(self, select, select_params, where, params, tables, order_by): pos = entry.find("%s", pos + 2) select_pairs[name] = (entry, entry_params) # This is order preserving, since self.extra_select is a SortedDict. - self.extra_select.update(select_pairs) + self.extra.update(select_pairs) if where: self.extra_where += tuple(where) if params: @@ -2213,22 +2231,26 @@ def get_loaded_field_names_cb(self, target, model, fields): """ target[model] = set([f.name for f in fields]) - def trim_extra_select(self, names): - """ - Removes any aliases in the extra_select dictionary that aren't in - 'names'. - - This is needed if we are selecting certain values that don't incldue - all of the extra_select names. - """ - for key in set(self.extra_select).difference(set(names)): - del self.extra_select[key] - def set_aggregate_mask(self, names): "Set the mask of aggregates that will actually be returned by the SELECT" - self.aggregate_select_mask = names + if names is None: + self.aggregate_select_mask = None + else: + self.aggregate_select_mask = set(names) self._aggregate_select_cache = None + def set_extra_mask(self, names): + """ + Set the mask of extra select items that will be returned by SELECT, + we don't actually remove them from the Query since they might be used + later + """ + if names is None: + self.extra_select_mask = None + else: + self.extra_select_mask = set(names) + self._extra_select_cache = None + def _aggregate_select(self): """The SortedDict of aggregate columns that are not masked, and should be used in the SELECT clause. @@ -2247,6 +2269,19 @@ def _aggregate_select(self): return self.aggregates aggregate_select = property(_aggregate_select) + def _extra_select(self): + if self._extra_select_cache is not None: + return self._extra_select_cache + elif self.extra_select_mask is not None: + self._extra_select_cache = SortedDict([ + (k,v) for k,v in self.extra.items() + if k in self.extra_select_mask + ]) + return self._extra_select_cache + else: + return self.extra + extra_select = property(_extra_select) + def set_start(self, start): """ Sets the table from which to start joining. The start position is diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 4c62457c578e6..0cd393756d1eb 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -178,7 +178,7 @@ def pre_sql_setup(self): # from other tables. query = self.clone(klass=Query) query.bump_prefix() - query.extra_select = {} + query.extra = {} query.select = [] query.add_fields([query.model._meta.pk.name]) must_pre_select = count > 1 and not self.connection.features.update_can_self_select @@ -409,7 +409,7 @@ def add_date_select(self, field, lookup_type, order='ASC'): self.select = [select] self.select_fields = [None] self.select_related = False # See #7097. - self.extra_select = {} + self.extra = {} self.distinct = True self.order_by = order == 'ASC' and [1] or [-1] diff --git a/tests/regressiontests/extra_regress/models.py b/tests/regressiontests/extra_regress/models.py index fd34982c9afc9..5d22d6cc07d2b 100644 --- a/tests/regressiontests/extra_regress/models.py +++ b/tests/regressiontests/extra_regress/models.py @@ -35,6 +35,9 @@ class TestObject(models.Model): second = models.CharField(max_length=20) third = models.CharField(max_length=20) + def __unicode__(self): + return u'TestObject: %s,%s,%s' % (self.first,self.second,self.third) + __test__ = {"API_TESTS": """ # Regression tests for #7314 and #7372 @@ -189,6 +192,19 @@ class TestObject(models.Model): >>> TestObject.objects.extra(select=SortedDict((('foo','first'),('bar','second'),('whiz','third')))).values_list('whiz', 'first', 'bar', 'id') [(u'third', u'first', u'second', 1)] -"""} +# Regression for #10847: the list of extra columns can always be accurately evaluated. +# Using an inner query ensures that as_sql() is producing correct output +# without requiring full evaluation and execution of the inner query. +>>> TestObject.objects.extra(select={'extra': 1}).values('pk') +[{'pk': 1}] +>>> TestObject.objects.filter(pk__in=TestObject.objects.extra(select={'extra': 1}).values('pk')) +[] +>>> TestObject.objects.values('pk').extra(select={'extra': 1}) +[{'pk': 1}] + +>>> TestObject.objects.filter(pk__in=TestObject.objects.values('pk').extra(select={'extra': 1})) +[] + +"""}