Skip to content

Commit

Permalink
Fixed #14876 -- Ensure that join promotion works correctly when there…
Browse files Browse the repository at this point in the history
… are nullable related fields. Thanks to simonpercivall for the report, oinopion and Aleksandra Sendecka for the original patch, and to Malcolm for helping me wrestle the edge cases to the ground.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@16648 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information
freakboy3742 committed Aug 22, 2011
1 parent 5edf1aa commit 3afe409
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 19 deletions.
50 changes: 36 additions & 14 deletions django/db/models/sql/query.py
Expand Up @@ -445,8 +445,6 @@ def combine(self, rhs, connector):
"Cannot combine a unique query with a non-unique query." "Cannot combine a unique query with a non-unique query."


self.remove_inherited_models() self.remove_inherited_models()
l_tables = set([a for a in self.tables if self.alias_refcount[a]])
r_tables = set([a for a in rhs.tables if rhs.alias_refcount[a]])
# Work out how to relabel the rhs aliases, if necessary. # Work out how to relabel the rhs aliases, if necessary.
change_map = {} change_map = {}
used = set() used = set()
Expand All @@ -471,16 +469,27 @@ def combine(self, rhs, connector):
# all joins exclusive to either the lhs or the rhs must be converted # all joins exclusive to either the lhs or the rhs must be converted
# to an outer join. # to an outer join.
if not conjunction: if not conjunction:
l_tables = set(self.tables)
r_tables = set(rhs.tables)
# Update r_tables aliases. # Update r_tables aliases.
for alias in change_map: for alias in change_map:
if alias in r_tables: if alias in r_tables:
r_tables.remove(alias) # r_tables may contain entries that have a refcount of 0
r_tables.add(change_map[alias]) # if the query has references to a table that can be
# trimmed because only the foreign key is used.
# We only need to fix the aliases for the tables that
# actually have aliases.
if rhs.alias_refcount[alias]:
r_tables.remove(alias)
r_tables.add(change_map[alias])
# Find aliases that are exclusive to rhs or lhs. # Find aliases that are exclusive to rhs or lhs.
# These are promoted to outer joins. # These are promoted to outer joins.
outer_aliases = (l_tables | r_tables) - (l_tables & r_tables) outer_tables = (l_tables | r_tables) - (l_tables & r_tables)
for alias in outer_aliases: for alias in outer_tables:
self.promote_alias(alias, True) # Again, some of the tables won't have aliases due to
# the trimming of unnecessary tables.
if self.alias_refcount.get(alias) or rhs.alias_refcount.get(alias):
self.promote_alias(alias, True)


# Now relabel a copy of the rhs where-clause and add it to the current # Now relabel a copy of the rhs where-clause and add it to the current
# one. # one.
Expand Down Expand Up @@ -668,7 +677,7 @@ def promote_alias(self, alias, unconditional=False):
False, the join is only promoted if it is nullable, otherwise it is False, the join is only promoted if it is nullable, otherwise it is
always promoted. always promoted.
Returns True if the join was promoted. Returns True if the join was promoted by this call.
""" """
if ((unconditional or self.alias_map[alias][NULLABLE]) and if ((unconditional or self.alias_map[alias][NULLABLE]) and
self.alias_map[alias][JOIN_TYPE] != self.LOUTER): self.alias_map[alias][JOIN_TYPE] != self.LOUTER):
Expand Down Expand Up @@ -1076,17 +1085,20 @@ def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
can_reuse) can_reuse)
return return


table_promote = False
join_promote = False

if (lookup_type == 'isnull' and value is True and not negate and if (lookup_type == 'isnull' and value is True and not negate and
len(join_list) > 1): len(join_list) > 1):
# If the comparison is against NULL, we may need to use some left # If the comparison is against NULL, we may need to use some left
# outer joins when creating the join chain. This is only done when # outer joins when creating the join chain. This is only done when
# needed, as it's less efficient at the database level. # needed, as it's less efficient at the database level.
self.promote_alias_chain(join_list) self.promote_alias_chain(join_list)
join_promote = True


# Process the join list to see if we can remove any inner joins from # Process the join list to see if we can remove any inner joins from
# the far end (fewer tables in a query is better). # the far end (fewer tables in a query is better).
col, alias, join_list = self.trim_joins(target, join_list, last, trim) col, alias, join_list = self.trim_joins(target, join_list, last, trim)

if connector == OR: if connector == OR:
# Some joins may need to be promoted when adding a new filter to a # Some joins may need to be promoted when adding a new filter to a
# disjunction. We walk the list of new joins and where it diverges # disjunction. We walk the list of new joins and where it diverges
Expand All @@ -1096,19 +1108,29 @@ def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
join_it = iter(join_list) join_it = iter(join_list)
table_it = iter(self.tables) table_it = iter(self.tables)
join_it.next(), table_it.next() join_it.next(), table_it.next()
table_promote = False unconditional = False
join_promote = False
for join in join_it: for join in join_it:
table = table_it.next() table = table_it.next()
# Once we hit an outer join, all subsequent joins must
# also be promoted, regardless of whether they have been
# promoted as a result of this pass through the tables.
unconditional = (unconditional or
self.alias_map[join][JOIN_TYPE] == self.LOUTER)
if join == table and self.alias_refcount[join] > 1: if join == table and self.alias_refcount[join] > 1:
# We have more than one reference to this join table.
# This means that we are dealing with two different query
# subtrees, so we don't need to do any join promotion.
continue continue
join_promote = self.promote_alias(join) join_promote = join_promote or self.promote_alias(join, unconditional)
if table != join: if table != join:
table_promote = self.promote_alias(table) table_promote = self.promote_alias(table)
# We only get here if we have found a table that exists
# in the join list, but isn't on the original tables list.
# This means we've reached the point where we only have
# new tables, so we can break out of this promotion loop.
break break
self.promote_alias_chain(join_it, join_promote) self.promote_alias_chain(join_it, join_promote)
self.promote_alias_chain(table_it, table_promote) self.promote_alias_chain(table_it, table_promote or join_promote)



if having_clause or force_having: if having_clause or force_having:
if (alias, col) not in self.group_by: if (alias, col) not in self.group_by:
Expand Down
34 changes: 29 additions & 5 deletions tests/regressiontests/queries/tests.py
Expand Up @@ -959,12 +959,36 @@ def setUp(self):
e1 = ExtraInfo.objects.create(info='e1', note=n1) e1 = ExtraInfo.objects.create(info='e1', note=n1)
e2 = ExtraInfo.objects.create(info='e2', note=n2) e2 = ExtraInfo.objects.create(info='e2', note=n2)


a1 = Author.objects.create(name='a1', num=1001, extra=e1) self.a1 = Author.objects.create(name='a1', num=1001, extra=e1)
a3 = Author.objects.create(name='a3', num=3003, extra=e2) self.a3 = Author.objects.create(name='a3', num=3003, extra=e2)


Report.objects.create(name='r1', creator=a1) self.r1 = Report.objects.create(name='r1', creator=self.a1)
Report.objects.create(name='r2', creator=a3) self.r2 = Report.objects.create(name='r2', creator=self.a3)
Report.objects.create(name='r3') self.r3 = Report.objects.create(name='r3')

Item.objects.create(name='i1', created=datetime.datetime.now(), note=n1, creator=self.a1)
Item.objects.create(name='i2', created=datetime.datetime.now(), note=n1, creator=self.a3)

def test_ticket14876(self):
q1 = Report.objects.filter(Q(creator__isnull=True) | Q(creator__extra__info='e1'))
q2 = Report.objects.filter(Q(creator__isnull=True)) | Report.objects.filter(Q(creator__extra__info='e1'))
self.assertQuerysetEqual(q1, ["<Report: r1>", "<Report: r3>"])
self.assertEqual(str(q1.query), str(q2.query))

q1 = Report.objects.filter(Q(creator__extra__info='e1') | Q(creator__isnull=True))
q2 = Report.objects.filter(Q(creator__extra__info='e1')) | Report.objects.filter(Q(creator__isnull=True))
self.assertQuerysetEqual(q1, ["<Report: r1>", "<Report: r3>"])
self.assertEqual(str(q1.query), str(q2.query))

q1 = Item.objects.filter(Q(creator=self.a1) | Q(creator__report__name='r1')).order_by()
q2 = Item.objects.filter(Q(creator=self.a1)).order_by() | Item.objects.filter(Q(creator__report__name='r1')).order_by()
self.assertQuerysetEqual(q1, ["<Item: i1>"])
self.assertEqual(str(q1.query), str(q2.query))

q1 = Item.objects.filter(Q(creator__report__name='e1') | Q(creator=self.a1)).order_by()
q2 = Item.objects.filter(Q(creator__report__name='e1')).order_by() | Item.objects.filter(Q(creator=self.a1)).order_by()
self.assertQuerysetEqual(q1, ["<Item: i1>"])
self.assertEqual(str(q1.query), str(q2.query))


def test_ticket7095(self): def test_ticket7095(self):
# Updates that are filtered on the model being updated are somewhat # Updates that are filtered on the model being updated are somewhat
Expand Down

0 comments on commit 3afe409

Please sign in to comment.