Skip to content

Commit

Permalink
Trimming doesn't happen for multicolumn joins
Browse files Browse the repository at this point in the history
get_path_info holds direction of join off of join_field
removing unneeded values sent from get_path_info
  • Loading branch information
Jeremy Tillman committed Jan 17, 2013
1 parent b04bec9 commit bfa871a
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 31 deletions.
5 changes: 2 additions & 3 deletions django/contrib/contenttypes/generic.py
Expand Up @@ -168,8 +168,7 @@ def get_path_info(self):
# Note that we are using different field for the join_field
# than from_field or to_field. This is a hack, but we need the
# GenericRelation to generate the extra SQL.
return ([PathInfo(from_field, target, self.model._meta, opts, self, True, False)],
opts, target, self)
return [PathInfo(from_field, target, self.model._meta, opts, self, True, False)]

def get_choices_default(self):
return Field.get_choices(self, include_blank=False)
Expand All @@ -182,7 +181,7 @@ def get_joining_columns(self, reverse_join=False):
# Our second join will happen in the extra sql
join_cols = ((self.m2m_target_field_name(), self.m2m_column_name()),)
if not reverse_join:
raise ValueError('Reverse join is not supported on generic relations')
raise ValueError('GenericRelation only supports reverse joins.')

return join_cols

Expand Down
14 changes: 14 additions & 0 deletions django/contrib/multicolumn/tests.py
Expand Up @@ -139,6 +139,20 @@ def test_reverse_query_filters_correctly(self):
attrgetter('name')
)

def test_forard_in_lookup_filters_correctly(self):
Membership.objects.create(membership_country_id=self.usa.id, person_id=self.bob.id, group_id=self.cia.id)
Membership.objects.create(membership_country_id=self.usa.id, person_id=self.jim.id, group_id=self.cia.id)

# Creating an invalid membership
Membership.objects.create(membership_country_id=self.soviet_union.id, person_id=self.george.id, group_id=self.cia.id)

self.assertQuerysetEqual(
Membership.objects.filter(person__in=[self.george, self.jim]),[
self.jim.id,
],
attrgetter('person_id')
)

def test_select_related_foreignkey_forward_works(self):
Membership.objects.create(membership_country=self.usa, person=self.bob, group=self.cia)
Membership.objects.create(membership_country=self.usa, person=self.jim, group=self.democrat)
Expand Down
16 changes: 8 additions & 8 deletions django/db/models/fields/related.py
Expand Up @@ -1077,7 +1077,7 @@ def get_path_info(self):
opts = self.rel.to._meta
target = self.rel.get_related_field()
from_opts = self.model._meta
return [PathInfo(self, target, from_opts, opts, self, False, True)], opts, target, self
return [PathInfo(self, target, from_opts, opts, self, False, True)]

def get_reverse_path_info(self):
"""
Expand All @@ -1086,14 +1086,14 @@ def get_reverse_path_info(self):
opts = self.model._meta
from_field = self.rel.get_related_field()
from_opts = from_field.model._meta
pathinfos = [PathInfo(from_field, self, from_opts, opts, self, not self.unique, False)]
if from_field.model is self.model:
# Recursive foreign key to self.
target = opts.get_field_by_name(
self.rel.field_name)[0]
else:
target = opts.pk
return pathinfos, opts, target, self
pathinfos = [PathInfo(from_field, target, from_opts, opts, self, not self.unique, False)]
return pathinfos

def validate(self, value, model_instance):
if self.rel.parent_link:
Expand Down Expand Up @@ -1298,14 +1298,14 @@ def _get_path_info(self, direct=False):
linkfield1 = int_model._meta.get_field_by_name(self.m2m_field_name())[0]
linkfield2 = int_model._meta.get_field_by_name(self.m2m_reverse_field_name())[0]
if direct:
join1infos, _, _, _ = linkfield1.get_reverse_path_info()
join2infos, opts, target, final_field = linkfield2.get_path_info()
join1infos = linkfield1.get_reverse_path_info()
join2infos = linkfield2.get_path_info()
else:
join1infos, _, _, _ = linkfield2.get_reverse_path_info()
join2infos, opts, target, final_field = linkfield1.get_path_info()
join1infos = linkfield2.get_reverse_path_info()
join2infos = linkfield1.get_path_info()
pathinfos.extend(join1infos)
pathinfos.extend(join2infos)
return pathinfos, opts, target, final_field
return pathinfos

def get_path_info(self):
return self._get_path_info(direct=True)
Expand Down
5 changes: 3 additions & 2 deletions django/db/models/sql/compiler.py
Expand Up @@ -465,11 +465,12 @@ def _final_join_removal(self, col, alias):
if alias:
while 1:
join = self.query.alias_map[alias]
if col != join.join_cols[0][1]:
if len(join.join_cols) != 1 or join.join_cols[0][1] != col:
break

col = join.join_cols[0][0]
self.query.unref_alias(alias)
alias = join.lhs_alias
col = join.join_cols[0][0]
return col, alias

def get_from_clause(self):
Expand Down
36 changes: 18 additions & 18 deletions django/db/models/sql/query.py
Expand Up @@ -949,8 +949,6 @@ def join(self, connection, reuse=None, promote=False,
"""
lhs, table, join_cols = connection
assert lhs is None or join_field is not None
if join_cols is None:
join_cols = ((None, None),)
existing = self.join_map.get(connection, ())
if reuse is None:
reuse = existing
Expand Down Expand Up @@ -981,7 +979,7 @@ def join(self, connection, reuse=None, promote=False,
join_type = self.LOUTER
else:
join_type = self.INNER
join = JoinInfo(table, alias, join_type, lhs, join_cols, nullable,
join = JoinInfo(table, alias, join_type, lhs, join_cols or ((None, None),), nullable,
join_field)
self.alias_map[alias] = join
if connection in self.join_map:
Expand Down Expand Up @@ -1360,11 +1358,14 @@ def names_to_path(self, names, opts, allow_many=False,
final_field = opts.parents[int_model]
target = final_field.rel.get_related_field()
opts = int_model._meta
path.append(PathInfo(final_field, target, final_field.model._meta,
opts, final_field, False, True))
path.append(PathInfo(final_field, target, final_field.model._meta, opts, final_field, False, True))
if hasattr(field, 'get_path_info'):
pathinfos, opts, target, final_field = field.get_path_info()
pathinfos = field.get_path_info()
path.extend(pathinfos)
last = pathinfos[-1]
final_field = last.join_field
opts = last.to_opts
target = last.to_field
else:
# Local non-relational field.
final_field = target = field
Expand Down Expand Up @@ -1450,7 +1451,10 @@ def trim_joins(self, target, joins, path):
the join.
"""
for info in reversed(path):
if info.to_field == target and info.direct:
if len(joins) > 1 and \
info.to_field == target and \
info.direct and \
len(self.alias_map[joins[-1]].join_cols) == 1:
target = info.from_field
self.unref_alias(joins.pop())
else:
Expand Down Expand Up @@ -1593,18 +1597,14 @@ def add_fields(self, field_names, allow_m2m=True):

try:
for name in field_names:
field, target, u2, joins, u3 = self.setup_joins(
field, target, u2, joins, path = self.setup_joins(
name.split(LOOKUP_SEP), opts, alias, None, allow_m2m,
True)
final_alias = joins[-1]
col = target.column
if len(joins) > 1:
join = self.alias_map[final_alias]
if col == join.join_cols[0][1]:
self.unref_alias(final_alias)
final_alias = join.lhs_alias
col = join.join_cols[0][0]
joins = joins[:-1]

# Trim last join if possible
col, final_alias, remaining_joins = self.trim_joins(target, joins[-2:], path)
joins = joins[:-2] + remaining_joins

self.promote_joins(joins[1:])
self.select.append(SelectInfo((final_alias, col), field))
except MultiJoin:
Expand Down Expand Up @@ -1885,7 +1885,7 @@ def set_start(self, start):
"""
opts = self.model._meta
alias = self.get_initial_alias()
field, col, opts, joins, extra = self.setup_joins(
field, target, opts, joins, extra = self.setup_joins(
start.split(LOOKUP_SEP), opts, alias)
select_col = self.alias_map[joins[1]].join_cols[0][0]
select_alias = alias
Expand Down

0 comments on commit bfa871a

Please sign in to comment.