Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Fixed #20874 -- bump_prefix() in nested subqueries

Also made some cleanup to build_filter() code by introducing submethods
solve_lookup_type() and prepare_lookup_value().
  • Loading branch information...
commit dcdc579d162b750ee3449e34efd772703592faca 1 parent 6c12cd1
@akaariai akaariai authored
View
5 django/db/models/sql/compiler.py
@@ -167,7 +167,6 @@ def as_nested_sql(self):
if obj.low_mark == 0 and obj.high_mark is None:
# If there is no slicing in use, then we can safely drop all ordering
obj.clear_ordering(True)
- obj.bump_prefix()
return obj.get_compiler(connection=self.connection).as_sql()
def get_columns(self, with_aliases=False):
@@ -808,13 +807,14 @@ def execute_sql(self, result_type=MULTI):
return result
def as_subquery_condition(self, alias, columns, qn):
+ inner_qn = self.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
if len(columns) == 1:
sql, params = self.as_sql()
return '%s.%s IN (%s)' % (qn(alias), qn2(columns[0]), sql), params
for index, select_col in enumerate(self.query.select):
- lhs = '%s.%s' % (qn(select_col.col[0]), qn2(select_col.col[1]))
+ lhs = '%s.%s' % (inner_qn(select_col.col[0]), qn2(select_col.col[1]))
rhs = '%s.%s' % (qn(alias), qn2(columns[index]))
self.query.where.add(
QueryWrapper('%s = %s' % (lhs, rhs), []), 'AND')
@@ -1010,7 +1010,6 @@ def pre_sql_setup(self):
# We need to use a sub-select in the where clause to filter on things
# from other tables.
query = self.query.clone(klass=Query)
- query.bump_prefix()
query.extra = {}
query.select = []
query.add_fields([query.get_meta().pk.name])
View
144 django/db/models/sql/query.py
@@ -97,6 +97,7 @@ class Query(object):
LOUTER = 'LEFT OUTER JOIN'
alias_prefix = 'T'
+ subq_aliases = frozenset([alias_prefix])
query_terms = QUERY_TERMS
aggregates_module = base_aggregates_module
@@ -273,6 +274,10 @@ def clone(self, klass=None, memo=None, **kwargs):
else:
obj.used_aliases = set()
obj.filter_is_sticky = False
+ if 'alias_prefix' in self.__dict__:
+ obj.alias_prefix = self.alias_prefix
+ if 'subq_aliases' in self.__dict__:
+ obj.subq_aliases = self.subq_aliases.copy()
obj.__dict__.update(kwargs)
if hasattr(obj, '_setup_query'):
@@ -780,28 +785,22 @@ def relabel_column(col):
data = data._replace(lhs_alias=change_map[lhs])
self.alias_map[alias] = data
- def bump_prefix(self, exceptions=()):
+ def bump_prefix(self, outer_query):
"""
- Changes the alias prefix to the next letter in the alphabet and
- relabels all the aliases. Even tables that previously had no alias will
- get an alias after this call (it's mostly used for nested queries and
- the outer query will already be using the non-aliased table name).
-
- Subclasses who create their own prefix should override this method to
- produce a similar result (a new prefix and relabelled aliases).
-
- The 'exceptions' parameter is a container that holds alias names which
- should not be changed.
+ Changes the alias prefix to the next letter in the alphabet in a way
+ that the outer query's aliases and this query's aliases will not
+ conflict. Even tables that previously had no alias will get an alias
+ after this call.
"""
- current = ord(self.alias_prefix)
- assert current < ord('Z')
- prefix = chr(current + 1)
- self.alias_prefix = prefix
+ self.alias_prefix = chr(ord(self.alias_prefix) + 1)
+ while self.alias_prefix in self.subq_aliases:
+ self.alias_prefix = chr(ord(self.alias_prefix) + 1)
+ assert self.alias_prefix < 'Z'
+ self.subq_aliases = self.subq_aliases.union([self.alias_prefix])
+ outer_query.subq_aliases = outer_query.subq_aliases.union(self.subq_aliases)
change_map = OrderedDict()
for pos, alias in enumerate(self.tables):
- if alias in exceptions:
- continue
- new_alias = '%s%d' % (prefix, pos)
+ new_alias = '%s%d' % (self.alias_prefix, pos)
change_map[alias] = new_alias
self.tables[pos] = new_alias
self.change_aliases(change_map)
@@ -1005,6 +1004,65 @@ def add_aggregate(self, aggregate, model, alias, is_summary):
# Add the aggregate to the query
aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
+ def prepare_lookup_value(self, value, lookup_type, can_reuse):
+ # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
+ # uses of None as a query value.
+ if value is None:
+ if lookup_type != 'exact':
+ raise ValueError("Cannot use None as a query value")
+ lookup_type = 'isnull'
+ value = True
+ elif callable(value):
+ value = value()
+ elif isinstance(value, ExpressionNode):
+ # If value is a query expression, evaluate it
+ value = SQLEvaluator(value, self, reuse=can_reuse)
+ if hasattr(value, 'query') and hasattr(value.query, 'bump_prefix'):
+ value = value._clone()
+ value.query.bump_prefix(self)
+ if hasattr(value, 'bump_prefix'):
+ value = value.clone()
+ value.bump_prefix(self)
+ # For Oracle '' is equivalent to null. The check needs to be done
+ # at this stage because join promotion can't be done at compiler
+ # stage. Using DEFAULT_DB_ALIAS isn't nice, but it is the best we
+ # can do here. Similar thing is done in is_nullable(), too.
+ if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and
+ lookup_type == 'exact' and value == ''):
+ value = True
+ lookup_type = 'isnull'
+ return value, lookup_type
+
+ def solve_lookup_type(self, lookup):
+ """
+ Solve the lookup type from the lookup (eg: 'foobar__id__icontains')
+ """
+ lookup_type = 'exact' # Default lookup type
+ lookup_parts = lookup.split(LOOKUP_SEP)
+ num_parts = len(lookup_parts)
+ if (len(lookup_parts) > 1 and lookup_parts[-1] in self.query_terms
+ and lookup not in self.aggregates):
+ # Traverse the lookup query to distinguish related fields from
+ # lookup types.
+ lookup_model = self.model
+ for counter, field_name in enumerate(lookup_parts):
+ try:
+ lookup_field = lookup_model._meta.get_field(field_name)
+ except FieldDoesNotExist:
+ # Not a field. Bail out.
+ lookup_type = lookup_parts.pop()
+ break
+ # Unless we're at the end of the list of lookups, let's attempt
+ # to continue traversing relations.
+ if (counter + 1) < num_parts:
+ try:
+ lookup_model = lookup_field.rel.to
+ except AttributeError:
+ # Not a related field. Bail out.
+ lookup_type = lookup_parts.pop()
+ break
+ return lookup_type, lookup_parts
+
def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
can_reuse=None):
"""
@@ -1033,58 +1091,15 @@ def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
is responsible for unreffing the joins used.
"""
arg, value = filter_expr
- parts = arg.split(LOOKUP_SEP)
+ lookup_type, parts = self.solve_lookup_type(arg)
if not parts:
raise FieldError("Cannot parse keyword query %r" % arg)
# Work out the lookup type and remove it from the end of 'parts',
# if necessary.
- lookup_type = 'exact' # Default lookup type
- num_parts = len(parts)
- if (len(parts) > 1 and parts[-1] in self.query_terms
- and arg not in self.aggregates):
- # Traverse the lookup query to distinguish related fields from
- # lookup types.
- lookup_model = self.model
- for counter, field_name in enumerate(parts):
- try:
- lookup_field = lookup_model._meta.get_field(field_name)
- except FieldDoesNotExist:
- # Not a field. Bail out.
- lookup_type = parts.pop()
- break
- # Unless we're at the end of the list of lookups, let's attempt
- # to continue traversing relations.
- if (counter + 1) < num_parts:
- try:
- lookup_model = lookup_field.rel.to
- except AttributeError:
- # Not a related field. Bail out.
- lookup_type = parts.pop()
- break
+ value, lookup_type = self.prepare_lookup_value(value, lookup_type, can_reuse)
clause = self.where_class()
- # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
- # uses of None as a query value.
- if value is None:
- if lookup_type != 'exact':
- raise ValueError("Cannot use None as a query value")
- lookup_type = 'isnull'
- value = True
- elif callable(value):
- value = value()
- elif isinstance(value, ExpressionNode):
- # If value is a query expression, evaluate it
- value = SQLEvaluator(value, self, reuse=can_reuse)
- # For Oracle '' is equivalent to null. The check needs to be done
- # at this stage because join promotion can't be done at compiler
- # stage. Using DEFAULT_DB_ALIAS isn't nice, but it is the best we
- # can do here. Similar thing is done in is_nullable(), too.
- if (connections[DEFAULT_DB_ALIAS].features.interprets_empty_strings_as_nulls and
- lookup_type == 'exact' and value == ''):
- value = True
- lookup_type = 'isnull'
-
for alias, aggregate in self.aggregates.items():
if alias in (parts[0], LOOKUP_SEP.join(parts)):
clause.add((aggregate, lookup_type, value), AND)
@@ -1096,7 +1111,7 @@ def build_filter(self, filter_expr, branch_negated=False, current_negated=False,
try:
field, sources, opts, join_list, path = self.setup_joins(
- parts, opts, alias, can_reuse, allow_many,)
+ parts, opts, alias, can_reuse, allow_many,)
if can_reuse is not None:
can_reuse.update(join_list)
except MultiJoin as e:
@@ -1404,7 +1419,6 @@ def split_exclude(self, filter_expr, prefix, can_reuse, names_with_path):
# Generate the inner query.
query = Query(self.model)
query.where.add(query.build_filter(filter_expr), AND)
- query.bump_prefix()
query.clear_ordering(True)
# Try to have as simple as possible subquery -> trim leading joins from
# the subquery.
View
19 tests/foreign_object/tests.py
@@ -132,7 +132,6 @@ def test_forward_in_lookup_filters_correctly(self):
],
attrgetter('person_id')
)
-
self.assertQuerysetEqual(
Membership.objects.filter(person__in=Person.objects.filter(name='Jim')), [
self.jim.id,
@@ -140,6 +139,24 @@ def test_forward_in_lookup_filters_correctly(self):
attrgetter('person_id')
)
+ def test_double_nested_query(self):
+ m1 = Membership.objects.create(membership_country_id=self.usa.id, person_id=self.bob.id,
+ group_id=self.cia.id)
+ m2 = Membership.objects.create(membership_country_id=self.usa.id, person_id=self.jim.id,
+ group_id=self.cia.id)
+ Friendship.objects.create(from_friend_country_id=self.usa.id, from_friend_id=self.bob.id,
+ to_friend_country_id=self.usa.id, to_friend_id=self.jim.id)
+ self.assertQuerysetEqual(Membership.objects.filter(
+ person__in=Person.objects.filter(
+ from_friend__in=Friendship.objects.filter(
+ to_friend__in=Person.objects.all()))),
+ [m1], lambda x: x)
+ self.assertQuerysetEqual(Membership.objects.exclude(
+ person__in=Person.objects.filter(
+ from_friend__in=Friendship.objects.filter(
+ to_friend__in=Person.objects.all()))),
+ [m2], lambda x: x)
+
def test_select_related_foreignkey_forward_works(self):
Membership.objects.create(membership_country=self.usa, person=self.bob, group=self.cia)
Membership.objects.create(membership_country=self.usa, person=self.jim, group=self.democrat)
View
16 tests/queries/tests.py
@@ -27,7 +27,6 @@
BaseA, FK1, Identifier, Program, Channel, Page, Paragraph, Chapter, Book,
MyObject, Order, OrderItem)
-
class BaseQuerysetTest(TestCase):
def assertValueQuerysetEqual(self, qs, values):
return self.assertQuerysetEqual(qs, values, transform=lambda x: x)
@@ -84,6 +83,19 @@ def setUp(self):
Cover.objects.create(title="first", item=i4)
Cover.objects.create(title="second", item=self.i2)
+ def test_subquery_condition(self):
+ qs1 = Tag.objects.filter(pk__lte=0)
+ qs2 = Tag.objects.filter(parent__in=qs1)
+ qs3 = Tag.objects.filter(parent__in=qs2)
+ self.assertEqual(qs3.query.subq_aliases, set(['T', 'U', 'V']))
+ self.assertIn('V0', str(qs3.query))
+ qs4 = qs3.filter(parent__in=qs1)
+ self.assertEqual(qs4.query.subq_aliases, set(['T', 'U', 'V']))
+ # It is possible to reuse U for the second subquery, no need to use W.
+ self.assertNotIn('W0', str(qs4.query))
+ # So, 'U0."id"' is referenced twice.
+ self.assertTrue(str(qs4.query).count('U0."id"'), 2)
+
def test_ticket1050(self):
self.assertQuerysetEqual(
Item.objects.filter(tags__isnull=True),
@@ -810,7 +822,7 @@ def test_ticket9411(self):
# Make sure bump_prefix() (an internal Query method) doesn't (re-)break. It's
# sufficient that this query runs without error.
qs = Tag.objects.values_list('id', flat=True).order_by('id')
- qs.query.bump_prefix()
+ qs.query.bump_prefix(qs.query)
first = qs[0]
self.assertEqual(list(qs), list(range(first, first+5)))
Please sign in to comment.
Something went wrong with that request. Please try again.