Skip to content

Commit

Permalink
Fixed some more join and lookup tests.
Browse files Browse the repository at this point in the history
git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@6121 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information
malcolmt committed Sep 13, 2007
1 parent 184a643 commit ff3f6df
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 25 deletions.
4 changes: 2 additions & 2 deletions django/db/models/base.py
Expand Up @@ -334,8 +334,8 @@ def _get_next_or_previous_by_FIELD(self, field, is_next, **kwargs):
qn(self._meta.db_table), qn(self._meta.pk.column), op) qn(self._meta.db_table), qn(self._meta.pk.column), op)
param = smart_str(getattr(self, field.attname)) param = smart_str(getattr(self, field.attname))
q = self.__class__._default_manager.filter(**kwargs).order_by((not is_next and '-' or '') + field.name, (not is_next and '-' or '') + self._meta.pk.name) q = self.__class__._default_manager.filter(**kwargs).order_by((not is_next and '-' or '') + field.name, (not is_next and '-' or '') + self._meta.pk.name)
q._where.append(where) q.extra(where=where, params=[param, param,
q._params.extend([param, param, getattr(self, self._meta.pk.attname)]) getattr(self, self._meta.pk.attname)])
try: try:
return q[0] return q[0]
except IndexError: except IndexError:
Expand Down
56 changes: 45 additions & 11 deletions django/db/models/sql/query.py
Expand Up @@ -10,7 +10,7 @@
import copy import copy


from django.utils import tree from django.utils import tree
from django.db.models.sql.where import WhereNode, AND from django.db.models.sql.where import WhereNode, AND, OR
from django.db.models.sql.datastructures import Count from django.db.models.sql.datastructures import Count
from django.db.models.fields import FieldDoesNotExist from django.db.models.fields import FieldDoesNotExist
from django.contrib.contenttypes import generic from django.contrib.contenttypes import generic
Expand Down Expand Up @@ -233,14 +233,27 @@ def combine(self, rhs, connection):
# Work out how to relabel the rhs aliases, if necessary. # Work out how to relabel the rhs aliases, if necessary.
change_map = {} change_map = {}
used = {} used = {}
first_new_join = True
for alias in rhs.tables: for alias in rhs.tables:
if not rhs.alias_map[alias][ALIAS_REFCOUNT]:
# An unused alias.
continue
promote = (rhs.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] == promote = (rhs.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] ==
self.LOUTER) self.LOUTER)
new_alias = self.join(rhs.rev_join_map[alias], exclusions=used, new_alias = self.join(rhs.rev_join_map[alias], exclusions=used,
promote=promote) promote=promote, outer_if_first=True)
if self.alias_map[alias][ALIAS_REFCOUNT] == 1:
first_new_join = False
used[new_alias] = None used[new_alias] = None
change_map[alias] = new_alias change_map[alias] = new_alias


# So that we don't exclude valid results, the first join that is
# exclusive to the lhs (self) must be converted to an outer join.
for alias in self.tables[1:]:
if self.alias_map[alias][ALIAS_REFCOUNT] == 1:
self.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] = self.LOUTER
break

# Now relabel a copy of the rhs where-clause and add it to the current # Now relabel a copy of the rhs where-clause and add it to the current
# one. # one.
if rhs.where: if rhs.where:
Expand Down Expand Up @@ -380,8 +393,12 @@ def unref_alias(self, alias):
""" Decreases the reference count for this alias. """ """ Decreases the reference count for this alias. """
self.alias_map[alias][ALIAS_REFCOUNT] -= 1 self.alias_map[alias][ALIAS_REFCOUNT] -= 1


def promote_alias(self, alias):
""" Promotes the join type of an alias to an outer join. """
self.alias_map[alias][ALIAS_JOIN][JOIN_TYPE] = self.LOUTER

def join(self, (lhs, table, lhs_col, col), always_create=False, def join(self, (lhs, table, lhs_col, col), always_create=False,
exclusions=(), promote=False): exclusions=(), promote=False, outer_if_first=False):
""" """
Returns an alias for a join between 'table' and 'lhs' on the given Returns an alias for a join between 'table' and 'lhs' on the given
columns, either reusing an existing alias for that join or creating a columns, either reusing an existing alias for that join or creating a
Expand All @@ -398,6 +415,10 @@ def join(self, (lhs, table, lhs_col, col), always_create=False,
If 'promote' is True, the join type for the alias will be LOUTER (if If 'promote' is True, the join type for the alias will be LOUTER (if
the alias previously existed, the join type will be promoted from INNER the alias previously existed, the join type will be promoted from INNER
to LOUTER, if necessary). to LOUTER, if necessary).
If 'outer_if_first' is True and a new join is created, it will have the
LOUTER join type. This is used when joining certain types of querysets
and Q-objects together.
""" """
if lhs not in self.alias_map: if lhs not in self.alias_map:
lhs_table = lhs lhs_table = lhs
Expand All @@ -422,7 +443,7 @@ def join(self, (lhs, table, lhs_col, col), always_create=False,
assert not is_table, \ assert not is_table, \
"Must pass in lhs alias when creating a new join." "Must pass in lhs alias when creating a new join."
alias, _ = self.table_alias(table, True) alias, _ = self.table_alias(table, True)
join_type = promote and self.LOUTER or self.INNER join_type = (promote or outer_if_first) and self.LOUTER or self.INNER
join = [table, alias, join_type, lhs, lhs_col, col] join = [table, alias, join_type, lhs, lhs_col, col]
if not lhs: if not lhs:
# Not all tables need to be joined to anything. No join type # Not all tables need to be joined to anything. No join type
Expand Down Expand Up @@ -487,7 +508,8 @@ def add_filter(self, filter_expr, connection=AND, negate=False):
opts = self.model._meta opts = self.model._meta
alias = self.join((None, opts.db_table, None, None)) alias = self.join((None, opts.db_table, None, None))
dupe_multis = (connection == AND) dupe_multis = (connection == AND)
last = None seen_aliases = []
done_split = not self.where


# FIXME: Using enumerate() here is expensive. We only need 'i' to # FIXME: Using enumerate() here is expensive. We only need 'i' to
# check we aren't joining against a non-joinable field. Find a # check we aren't joining against a non-joinable field. Find a
Expand All @@ -498,23 +520,35 @@ def add_filter(self, filter_expr, connection=AND, negate=False):
if name == 'pk': if name == 'pk':
name = target_field.name name = target_field.name
if joins is not None: if joins is not None:
seen_aliases.extend(joins)
last = joins last = joins
alias = joins[-1] alias = joins[-1]
if connection == OR and not done_split:
if self.alias_map[joins[0]][ALIAS_REFCOUNT] == 1:
done_split = True
self.promote_alias(joins[0])
for t in self.tables[1:]:
if t in seen_aliases:
continue
self.promote_alias(t)
break
else:
seen_aliases.extend(joins)
else: else:
# Normal field lookup must be the last field in the filter. # Normal field lookup must be the last field in the filter.
if i != len(parts) - 1: if i != len(parts) - 1:
raise TypeError("Joins on field %r not permitted." raise TypeError("Join on field %r not permitted."
% name) % name)


col = target_col or target_field.column col = target_col or target_field.column


if target_field is opts.pk and last: if target_field is opts.pk and seen_aliases:
# An optimization: if the final join is against a primary key, # An optimization: if the final join is against a primary key,
# we can go back one step in the join chain and compare against # we can go back one step in the join chain and compare against
# the lhs of the join instead. The result (potentially) involves # the lhs of the join instead. The result (potentially) involves
# one less table join. # one less table join.
self.unref_alias(alias) self.unref_alias(alias)
join = self.alias_map[last[-1]][ALIAS_JOIN] join = self.alias_map[seen_aliases[-1]][ALIAS_JOIN]
alias = join[LHS_ALIAS] alias = join[LHS_ALIAS]
col = join[LHS_JOIN_COL] col = join[LHS_JOIN_COL]


Expand All @@ -523,13 +557,13 @@ def add_filter(self, filter_expr, connection=AND, negate=False):
# join when connecting to the previous model. We make that # join when connecting to the previous model. We make that
# adjustment here. We don't do this unless needed because it's less # adjustment here. We don't do this unless needed because it's less
# efficient at the database level. # efficient at the database level.
self.alias_map[joins[0]][ALIAS_JOIN][JOIN_TYPE] = self.LOUTER self.promote_alias(joins[0])


self.where.add([alias, col, orig_field, lookup_type, value], self.where.add([alias, col, orig_field, lookup_type, value],
connection) connection)
if negate: if negate:
if last: if seen_aliases:
self.alias_map[last[0]][ALIAS_JOIN][JOIN_TYPE] = self.LOUTER self.promote_alias(last[0])
self.where.negate() self.where.negate()


def add_q(self, q_object): def add_q(self, q_object):
Expand Down
18 changes: 13 additions & 5 deletions django/db/models/sql/where.py
Expand Up @@ -64,11 +64,19 @@ def as_sql(self, node=None):
else: else:
format = '(%s)' format = '(%s)'
else: else:
sql = self.make_atom(child) try:
params = child[2].get_db_prep_lookup(child[3], child[4]) sql = self.make_atom(child)
format = '%s' params = child[2].get_db_prep_lookup(child[3], child[4])
result.append(format % sql) format = '%s'
result_params.extend(params) except EmptyResultSet:
if node.negated:
# If this is a "not" atom, being empty means it has no
# effect on the result, so we can ignore it.
continue
raise
if sql:
result.append(format % sql)
result_params.extend(params)
conn = ' %s ' % node.connection conn = ' %s ' % node.connection
return conn.join(result), result_params return conn.join(result), result_params


Expand Down
7 changes: 6 additions & 1 deletion django/utils/tree.py
Expand Up @@ -21,6 +21,10 @@ def __init__(self, children=None, connection=None):
self.subtree_parents = [] self.subtree_parents = []
self.negated = False self.negated = False


def __str__(self):
return '(%s: %s)' % (self.connection, ', '.join([str(c) for c in
self.children]))

def __deepcopy__(self, memodict): def __deepcopy__(self, memodict):
""" """
Utility method used by copy.deepcopy(). Utility method used by copy.deepcopy().
Expand Down Expand Up @@ -59,7 +63,8 @@ def add(self, node, conn_type):
if len(self.children) < 2: if len(self.children) < 2:
self.connection = conn_type self.connection = conn_type
if self.connection == conn_type: if self.connection == conn_type:
if isinstance(node, Node) and node.connection == conn_type: if isinstance(node, Node) and (node.connection == conn_type
or len(node) == 1):
self.children.extend(node.children) self.children.extend(node.children)
else: else:
self.children.append(node) self.children.append(node)
Expand Down
2 changes: 1 addition & 1 deletion tests/modeltests/lookup/models.py
Expand Up @@ -258,7 +258,7 @@ def __unicode__(self):
>>> Article.objects.filter(headline__starts='Article') >>> Article.objects.filter(headline__starts='Article')
Traceback (most recent call last): Traceback (most recent call last):
... ...
TypeError: Cannot resolve keyword 'headline__starts' into field. Choices are: id, headline, pub_date TypeError: Join on field 'headline' not permitted.
# Create some articles with a bit more interesting headlines for testing field lookups: # Create some articles with a bit more interesting headlines for testing field lookups:
>>> now = datetime.now() >>> now = datetime.now()
Expand Down
8 changes: 3 additions & 5 deletions tests/regressiontests/queries/models.py
Expand Up @@ -111,11 +111,9 @@ def __unicode__(self):
>>> Item.objects.filter(tags__in=[t1, t2]).filter(tags=t3) >>> Item.objects.filter(tags__in=[t1, t2]).filter(tags=t3)
[<Item: two>] [<Item: two>]
Bug #2080 Bug #2080, #3592
# FIXME: Still problematic: the join needs to be "left outer" on the reverse >>> Author.objects.filter(Q(name='a3') | Q(item__name='one'))
# fk, but the individual joins only need to be inner. [<Author: a1>, <Author: a3>]
# >>> Author.objects.filter(Q(name='a3') | Q(item__name='one'))
# [<Author: a3>]
Bug #2939 Bug #2939
# FIXME: ValueQuerySets don't work yet. # FIXME: ValueQuerySets don't work yet.
Expand Down

0 comments on commit ff3f6df

Please sign in to comment.