Skip to content

Commit

Permalink
Fixed #10847 -- Modified handling of extra() to use a masking strateg…
Browse files Browse the repository at this point in the history
…y, rather than last-minute trimming. Thanks to Tai Lee for the report, and Alex Gaynor for his work on the patch.

This enables querysets with an extra clause to be used in an __in filter; as a side effect, it also means that as_sql() now returns the correct result for any query with an extra clause.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@10648 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information
freakboy3742 committed Apr 30, 2009
1 parent 17958fa commit 5e2d384
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 31 deletions.
17 changes: 9 additions & 8 deletions django/db/models/query.py
Expand Up @@ -715,9 +715,6 @@ def __init__(self, *args, **kwargs):

def iterator(self):
# Purge any extra columns that haven't been explicitly asked for
if self.extra_names is not None:
self.query.trim_extra_select(self.extra_names)

extra_names = self.query.extra_select.keys()
field_names = self.field_names
aggregate_names = self.query.aggregate_select.keys()
Expand All @@ -741,13 +738,18 @@ def _setup_query(self):
if self._fields:
self.extra_names = []
self.aggregate_names = []
if not self.query.extra_select and not self.query.aggregate_select:
if not self.query.extra and not self.query.aggregates:
# Short cut - if there are no extra or aggregates, then
# the values() clause must be just field names.
self.field_names = list(self._fields)
else:
self.query.default_cols = False
self.field_names = []
for f in self._fields:
if self.query.extra_select.has_key(f):
# we inspect the full extra_select list since we might
# be adding back an extra select item that we hadn't
# had selected previously.
if self.query.extra.has_key(f):
self.extra_names.append(f)
elif self.query.aggregate_select.has_key(f):
self.aggregate_names.append(f)
Expand All @@ -760,6 +762,8 @@ def _setup_query(self):
self.aggregate_names = None

self.query.select = []
if self.extra_names is not None:
self.query.set_extra_mask(self.extra_names)
self.query.add_fields(self.field_names, False)
if self.aggregate_names is not None:
self.query.set_aggregate_mask(self.aggregate_names)
Expand Down Expand Up @@ -816,9 +820,6 @@ def _as_sql(self):

class ValuesListQuerySet(ValuesQuerySet):
def iterator(self):
if self.extra_names is not None:
self.query.trim_extra_select(self.extra_names)

if self.flat and len(self._fields) == 1:
for row in self.query.results_iter():
yield row[0]
Expand Down
75 changes: 55 additions & 20 deletions django/db/models/sql/query.py
Expand Up @@ -88,7 +88,10 @@ def __init__(self, model, connection, where=WhereNode):

# These are for extensions. The contents are more or less appended
# verbatim to the appropriate clause.
self.extra_select = SortedDict() # Maps col_alias -> (col_sql, params).
self.extra = SortedDict() # Maps col_alias -> (col_sql, params).
self.extra_select_mask = None
self._extra_select_cache = None

self.extra_tables = ()
self.extra_where = ()
self.extra_params = ()
Expand Down Expand Up @@ -214,13 +217,21 @@ def clone(self, klass=None, **kwargs):
if self.aggregate_select_mask is None:
obj.aggregate_select_mask = None
else:
obj.aggregate_select_mask = self.aggregate_select_mask[:]
obj.aggregate_select_mask = self.aggregate_select_mask.copy()
if self._aggregate_select_cache is None:
obj._aggregate_select_cache = None
else:
obj._aggregate_select_cache = self._aggregate_select_cache.copy()
obj.max_depth = self.max_depth
obj.extra_select = self.extra_select.copy()
obj.extra = self.extra.copy()
if self.extra_select_mask is None:
obj.extra_select_mask = None
else:
obj.extra_select_mask = self.extra_select_mask.copy()
if self._extra_select_cache is None:
obj._extra_select_cache = None
else:
obj._extra_select_cache = self._extra_select_cache.copy()
obj.extra_tables = self.extra_tables
obj.extra_where = self.extra_where
obj.extra_params = self.extra_params
Expand Down Expand Up @@ -325,7 +336,7 @@ def get_aggregation(self):
query = self
self.select = []
self.default_cols = False
self.extra_select = {}
self.extra = {}
self.remove_inherited_models()

query.clear_ordering(True)
Expand Down Expand Up @@ -540,13 +551,20 @@ def combine(self, rhs, connector):
# It would be nice to be able to handle this, but the queries don't
# really make sense (or return consistent value sets). Not worth
# the extra complexity when you can write a real query instead.
if self.extra_select and rhs.extra_select:
if self.extra and rhs.extra:
raise ValueError("When merging querysets using 'or', you "
"cannot have extra(select=...) on both sides.")
if self.extra_where and rhs.extra_where:
raise ValueError("When merging querysets using 'or', you "
"cannot have extra(where=...) on both sides.")
self.extra_select.update(rhs.extra_select)
self.extra.update(rhs.extra)
extra_select_mask = set()
if self.extra_select_mask is not None:
extra_select_mask.update(self.extra_select_mask)
if rhs.extra_select_mask is not None:
extra_select_mask.update(rhs.extra_select_mask)
if extra_select_mask:
self.set_extra_mask(extra_select_mask)
self.extra_tables += rhs.extra_tables
self.extra_where += rhs.extra_where
self.extra_params += rhs.extra_params
Expand Down Expand Up @@ -2011,7 +2029,7 @@ def add_fields(self, field_names, allow_m2m=True):
except MultiJoin:
raise FieldError("Invalid field name: '%s'" % name)
except FieldError:
names = opts.get_all_field_names() + self.extra_select.keys() + self.aggregate_select.keys()
names = opts.get_all_field_names() + self.extra.keys() + self.aggregate_select.keys()
names.sort()
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names)))
Expand Down Expand Up @@ -2139,7 +2157,7 @@ def add_extra(self, select, select_params, where, params, tables, order_by):
pos = entry.find("%s", pos + 2)
select_pairs[name] = (entry, entry_params)
# This is order preserving, since self.extra_select is a SortedDict.
self.extra_select.update(select_pairs)
self.extra.update(select_pairs)
if where:
self.extra_where += tuple(where)
if params:
Expand Down Expand Up @@ -2213,22 +2231,26 @@ def get_loaded_field_names_cb(self, target, model, fields):
"""
target[model] = set([f.name for f in fields])

def trim_extra_select(self, names):
"""
Removes any aliases in the extra_select dictionary that aren't in
'names'.
This is needed if we are selecting certain values that don't incldue
all of the extra_select names.
"""
for key in set(self.extra_select).difference(set(names)):
del self.extra_select[key]

def set_aggregate_mask(self, names):
"Set the mask of aggregates that will actually be returned by the SELECT"
self.aggregate_select_mask = names
if names is None:
self.aggregate_select_mask = None
else:
self.aggregate_select_mask = set(names)
self._aggregate_select_cache = None

def set_extra_mask(self, names):
"""
Set the mask of extra select items that will be returned by SELECT,
we don't actually remove them from the Query since they might be used
later
"""
if names is None:
self.extra_select_mask = None
else:
self.extra_select_mask = set(names)
self._extra_select_cache = None

def _aggregate_select(self):
"""The SortedDict of aggregate columns that are not masked, and should
be used in the SELECT clause.
Expand All @@ -2247,6 +2269,19 @@ def _aggregate_select(self):
return self.aggregates
aggregate_select = property(_aggregate_select)

def _extra_select(self):
if self._extra_select_cache is not None:
return self._extra_select_cache
elif self.extra_select_mask is not None:
self._extra_select_cache = SortedDict([
(k,v) for k,v in self.extra.items()
if k in self.extra_select_mask
])
return self._extra_select_cache
else:
return self.extra
extra_select = property(_extra_select)

def set_start(self, start):
"""
Sets the table from which to start joining. The start position is
Expand Down
4 changes: 2 additions & 2 deletions django/db/models/sql/subqueries.py
Expand Up @@ -178,7 +178,7 @@ def pre_sql_setup(self):
# from other tables.
query = self.clone(klass=Query)
query.bump_prefix()
query.extra_select = {}
query.extra = {}
query.select = []
query.add_fields([query.model._meta.pk.name])
must_pre_select = count > 1 and not self.connection.features.update_can_self_select
Expand Down Expand Up @@ -409,7 +409,7 @@ def add_date_select(self, field, lookup_type, order='ASC'):
self.select = [select]
self.select_fields = [None]
self.select_related = False # See #7097.
self.extra_select = {}
self.extra = {}
self.distinct = True
self.order_by = order == 'ASC' and [1] or [-1]

Expand Down
18 changes: 17 additions & 1 deletion tests/regressiontests/extra_regress/models.py
Expand Up @@ -35,6 +35,9 @@ class TestObject(models.Model):
second = models.CharField(max_length=20)
third = models.CharField(max_length=20)

def __unicode__(self):
return u'TestObject: %s,%s,%s' % (self.first,self.second,self.third)

__test__ = {"API_TESTS": """
# Regression tests for #7314 and #7372
Expand Down Expand Up @@ -189,6 +192,19 @@ class TestObject(models.Model):
>>> TestObject.objects.extra(select=SortedDict((('foo','first'),('bar','second'),('whiz','third')))).values_list('whiz', 'first', 'bar', 'id')
[(u'third', u'first', u'second', 1)]
"""}
# Regression for #10847: the list of extra columns can always be accurately evaluated.
# Using an inner query ensures that as_sql() is producing correct output
# without requiring full evaluation and execution of the inner query.
>>> TestObject.objects.extra(select={'extra': 1}).values('pk')
[{'pk': 1}]
>>> TestObject.objects.filter(pk__in=TestObject.objects.extra(select={'extra': 1}).values('pk'))
[<TestObject: TestObject: first,second,third>]
>>> TestObject.objects.values('pk').extra(select={'extra': 1})
[{'pk': 1}]
>>> TestObject.objects.filter(pk__in=TestObject.objects.values('pk').extra(select={'extra': 1}))
[<TestObject: TestObject: first,second,third>]
"""}

0 comments on commit 5e2d384

Please sign in to comment.