Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Fixed handling of multiple fields in a model pointing to the same rel…

…ated model.

Thanks to ElliotM, mk and oyvind for some excellent test cases for this. Fixed #7110, #7125.


git-svn-id: http://code.djangoproject.com/svn/django/trunk@7778 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
commit bb2182453b49157fb6fba4de6d3c53a09f73d74b 1 parent d800c0b
@malcolmt malcolmt authored
View
11 django/db/models/fields/related.py
@@ -692,6 +692,11 @@ def flatten_data(self, follow, obj=None):
def contribute_to_class(self, cls, name):
super(ForeignKey, self).contribute_to_class(cls, name)
setattr(cls, self.name, ReverseSingleRelatedObjectDescriptor(self))
+ if isinstance(self.rel.to, basestring):
+ target = self.rel.to
+ else:
+ target = self.rel.to._meta.db_table
+ cls._meta.duplicate_targets[self.column] = (target, "o2m")
def contribute_to_related_class(self, cls, related):
setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related))
@@ -826,6 +831,12 @@ def contribute_to_class(self, cls, name):
# Set up the accessor for the m2m table name for the relation
self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta)
+ if isinstance(self.rel.to, basestring):
+ target = self.rel.to
+ else:
+ target = self.rel.to._meta.db_table
+ cls._meta.duplicate_targets[self.column] = (target, "m2m")
+
def contribute_to_related_class(self, cls, related):
# m2m relations to self do not have a ManyRelatedObjectsDescriptor,
# as it would be redundant - unless the field is non-symmetrical.
View
19 django/db/models/options.py
@@ -44,6 +44,7 @@ def __init__(self, meta, app_label=None):
self.one_to_one_field = None
self.abstract = False
self.parents = SortedDict()
+ self.duplicate_targets = {}
def contribute_to_class(self, cls, name):
from django.db import connection
@@ -115,6 +116,24 @@ def _prepare(self, model):
auto_created=True)
model.add_to_class('id', auto)
+ # Determine any sets of fields that are pointing to the same targets
+ # (e.g. two ForeignKeys to the same remote model). The query
+ # construction code needs to know this. At the end of this,
+ # self.duplicate_targets will map each duplicate field column to the
+ # columns it duplicates.
+ collections = {}
+ for column, target in self.duplicate_targets.iteritems():
+ try:
+ collections[target].add(column)
+ except KeyError:
+ collections[target] = set([column])
+ self.duplicate_targets = {}
+ for elt in collections.itervalues():
+ if len(elt) == 1:
+ continue
+ for column in elt:
+ self.duplicate_targets[column] = elt.difference(set([column]))
+
def add_field(self, field):
# Insert the given field in the order in which it was created, using
# the "creation_counter" attribute of the field.
View
99 django/db/models/sql/query.py
@@ -57,6 +57,7 @@ def __init__(self, model, connection, where=WhereNode):
self.start_meta = None
self.select_fields = []
self.related_select_fields = []
+ self.dupe_avoidance = {}
# SQL-related attributes
self.select = []
@@ -165,6 +166,7 @@ def clone(self, klass=None, **kwargs):
obj.start_meta = self.start_meta
obj.select_fields = self.select_fields[:]
obj.related_select_fields = self.related_select_fields[:]
+ obj.dupe_avoidance = self.dupe_avoidance.copy()
obj.select = self.select[:]
obj.tables = self.tables[:]
obj.where = deepcopy(self.where)
@@ -830,8 +832,8 @@ def join(self, connection, always_create=False, exclusions=(),
if reuse and always_create and table in self.table_map:
# Convert the 'reuse' to case to be "exclude everything but the
- # reusable set for this table".
- exclusions = set(self.table_map[table]).difference(reuse)
+ # reusable set, minus exclusions, for this table".
+ exclusions = set(self.table_map[table]).difference(reuse).union(set(exclusions))
always_create = False
t_ident = (lhs_table, table, lhs_col, col)
if not always_create:
@@ -866,7 +868,8 @@ def join(self, connection, always_create=False, exclusions=(),
return alias
def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
- used=None, requested=None, restricted=None, nullable=None):
+ used=None, requested=None, restricted=None, nullable=None,
+ dupe_set=None):
"""
Fill in the information needed for a select_related query. The current
depth is measured as the number of connections away from the root model
@@ -876,6 +879,7 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
if not restricted and self.max_depth and cur_depth > self.max_depth:
# We've recursed far enough; bail out.
return
+
if not opts:
opts = self.get_meta()
root_alias = self.get_initial_alias()
@@ -883,6 +887,10 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
self.related_select_fields = []
if not used:
used = set()
+ if dupe_set is None:
+ dupe_set = set()
+ orig_dupe_set = dupe_set
+ orig_used = used
# Setup for the case when only particular related fields should be
# included in the related selection.
@@ -897,6 +905,8 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
if (not f.rel or (restricted and f.name not in requested) or
(not restricted and f.null) or f.rel.parent_link):
continue
+ dupe_set = orig_dupe_set.copy()
+ used = orig_used.copy()
table = f.rel.to._meta.db_table
if nullable or f.null:
promote = True
@@ -907,12 +917,26 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
alias = root_alias
for int_model in opts.get_base_chain(model):
lhs_col = int_opts.parents[int_model].column
+ dedupe = lhs_col in opts.duplicate_targets
+ if dedupe:
+ used.update(self.dupe_avoidance.get(id(opts), lhs_col),
+ ())
+ dupe_set.add((opts, lhs_col))
int_opts = int_model._meta
alias = self.join((alias, int_opts.db_table, lhs_col,
int_opts.pk.column), exclusions=used,
promote=promote)
+ for (dupe_opts, dupe_col) in dupe_set:
+ self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
else:
alias = root_alias
+
+ dedupe = f.column in opts.duplicate_targets
+ if dupe_set or dedupe:
+ used.update(self.dupe_avoidance.get((id(opts), f.column), ()))
+ if dedupe:
+ dupe_set.add((opts, f.column))
+
alias = self.join((alias, table, f.column,
f.rel.get_related_field().column), exclusions=used,
promote=promote)
@@ -928,8 +952,10 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
new_nullable = f.null
else:
new_nullable = None
+ for dupe_opts, dupe_col in dupe_set:
+ self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
- used, next, restricted, new_nullable)
+ used, next, restricted, new_nullable, dupe_set)
def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
can_reuse=None):
@@ -1128,7 +1154,9 @@ def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
(which gives the table we are joining to), 'alias' is the alias for the
table we are joining to. If dupe_multis is True, any many-to-many or
many-to-one joins will always create a new alias (necessary for
- disjunctive filters).
+ disjunctive filters). If can_reuse is not None, it's a list of aliases
+ that can be reused in these joins (nothing else can be reused in this
+ case).
Returns the final field involved in the join, the target database
column (used for any 'where' constraint), the final 'opts' value and the
@@ -1136,7 +1164,14 @@ def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
"""
joins = [alias]
last = [0]
+ dupe_set = set()
+ exclusions = set()
for pos, name in enumerate(names):
+ try:
+ exclusions.add(int_alias)
+ except NameError:
+ pass
+ exclusions.add(alias)
last.append(len(joins))
if name == 'pk':
name = opts.pk.name
@@ -1155,6 +1190,7 @@ def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
names = opts.get_all_field_names()
raise FieldError("Cannot resolve keyword %r into field. "
"Choices are: %s" % (name, ", ".join(names)))
+
if not allow_many and (m2m or not direct):
for alias in joins:
self.unref_alias(alias)
@@ -1164,12 +1200,27 @@ def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
alias_list = []
for int_model in opts.get_base_chain(model):
lhs_col = opts.parents[int_model].column
+ dedupe = lhs_col in opts.duplicate_targets
+ if dedupe:
+ exclusions.update(self.dupe_avoidance.get(
+ (id(opts), lhs_col), ()))
+ dupe_set.add((opts, lhs_col))
opts = int_model._meta
alias = self.join((alias, opts.db_table, lhs_col,
- opts.pk.column), exclusions=joins)
+ opts.pk.column), exclusions=exclusions)
joins.append(alias)
+ exclusions.add(alias)
+ for (dupe_opts, dupe_col) in dupe_set:
+ self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
cached_data = opts._join_cache.get(name)
orig_opts = opts
+ dupe_col = direct and field.column or field.field.column
+ dedupe = dupe_col in opts.duplicate_targets
+ if dupe_set or dedupe:
+ if dedupe:
+ dupe_set.add((opts, dupe_col))
+ exclusions.update(self.dupe_avoidance.get((id(opts), dupe_col),
+ ()))
if direct:
if m2m:
@@ -1191,9 +1242,11 @@ def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
target)
int_alias = self.join((alias, table1, from_col1, to_col1),
- dupe_multis, joins, nullable=True, reuse=can_reuse)
+ dupe_multis, exclusions, nullable=True,
+ reuse=can_reuse)
alias = self.join((int_alias, table2, from_col2, to_col2),
- dupe_multis, joins, nullable=True, reuse=can_reuse)
+ dupe_multis, exclusions, nullable=True,
+ reuse=can_reuse)
joins.extend([int_alias, alias])
elif field.rel:
# One-to-one or many-to-one field
@@ -1209,7 +1262,7 @@ def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
opts, target)
alias = self.join((alias, table, from_col, to_col),
- exclusions=joins, nullable=field.null)
+ exclusions=exclusions, nullable=field.null)
joins.append(alias)
else:
# Non-relation fields.
@@ -1237,9 +1290,11 @@ def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
target)
int_alias = self.join((alias, table1, from_col1, to_col1),
- dupe_multis, joins, nullable=True, reuse=can_reuse)
+ dupe_multis, exclusions, nullable=True,
+ reuse=can_reuse)
alias = self.join((int_alias, table2, from_col2, to_col2),
- dupe_multis, joins, nullable=True, reuse=can_reuse)
+ dupe_multis, exclusions, nullable=True,
+ reuse=can_reuse)
joins.extend([int_alias, alias])
else:
# One-to-many field (ForeignKey defined on the target model)
@@ -1257,14 +1312,34 @@ def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
opts, target)
alias = self.join((alias, table, from_col, to_col),
- dupe_multis, joins, nullable=True, reuse=can_reuse)
+ dupe_multis, exclusions, nullable=True,
+ reuse=can_reuse)
joins.append(alias)
+ for (dupe_opts, dupe_col) in dupe_set:
+ try:
+ self.update_dupe_avoidance(dupe_opts, dupe_col, int_alias)
+ except NameError:
+ self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
+
if pos != len(names) - 1:
raise FieldError("Join on field %r not permitted." % name)
return field, target, opts, joins, last
+ def update_dupe_avoidance(self, opts, col, alias):
+ """
+ For a column that is one of multiple pointing to the same table, update
+ the internal data structures to note that this alias shouldn't be used
+ for those other columns.
+ """
+ ident = id(opts)
+ for name in opts.duplicate_targets[col]:
+ try:
+ self.dupe_avoidance[ident, name].add(alias)
+ except KeyError:
+ self.dupe_avoidance[ident, name] = set([alias])
+
def split_exclude(self, filter_expr, prefix):
"""
When doing an exclude against any kind of N-to-many relation, we need
View
40 tests/regressiontests/many_to_one_regress/models.py
@@ -28,6 +28,24 @@ class Child(models.Model):
parent = models.ForeignKey(Parent)
+# Multiple paths to the same model (#7110, #7125)
+class Category(models.Model):
+ name = models.CharField(max_length=20)
+
+ def __unicode__(self):
+ return self.name
+
+class Record(models.Model):
+ category = models.ForeignKey(Category)
+
+class Relation(models.Model):
+ left = models.ForeignKey(Record, related_name='left_set')
+ right = models.ForeignKey(Record, related_name='right_set')
+
+ def __unicode__(self):
+ return u"%s - %s" % (self.left.category.name, self.right.category.name)
+
+
__test__ = {'API_TESTS':"""
>>> Third.objects.create(id='3', name='An example')
<Third: Third object>
@@ -73,4 +91,26 @@ class Child(models.Model):
...
ValueError: Cannot assign "<First: First object>": "Child.parent" must be a "Parent" instance.
+# Test of multiple ForeignKeys to the same model (bug #7125)
+
+>>> c1 = Category.objects.create(name='First')
+>>> c2 = Category.objects.create(name='Second')
+>>> c3 = Category.objects.create(name='Third')
+>>> r1 = Record.objects.create(category=c1)
+>>> r2 = Record.objects.create(category=c1)
+>>> r3 = Record.objects.create(category=c2)
+>>> r4 = Record.objects.create(category=c2)
+>>> r5 = Record.objects.create(category=c3)
+>>> r = Relation.objects.create(left=r1, right=r2)
+>>> r = Relation.objects.create(left=r3, right=r4)
+>>> r = Relation.objects.create(left=r1, right=r3)
+>>> r = Relation.objects.create(left=r5, right=r2)
+>>> r = Relation.objects.create(left=r3, right=r2)
+
+>>> Relation.objects.filter(left__category__name__in=['First'], right__category__name__in=['Second'])
+[<Relation: First - Second>]
+
+>>> Category.objects.filter(record__left_set__right__category__name='Second').order_by('name')
+[<Category: First>, <Category: Second>]
+
"""}
View
0  tests/regressiontests/select_related_regress/__init__.py
No changes.
View
60 tests/regressiontests/select_related_regress/models.py
@@ -0,0 +1,60 @@
+from django.db import models
+
+class Building(models.Model):
+ name = models.CharField(max_length=10)
+
+ def __unicode__(self):
+ return u"Building: %s" % self.name
+
+class Device(models.Model):
+ building = models.ForeignKey('Building')
+ name = models.CharField(max_length=10)
+
+ def __unicode__(self):
+ return u"device '%s' in building %s" % (self.name, self.building)
+
+class Port(models.Model):
+ device = models.ForeignKey('Device')
+ number = models.CharField(max_length=10)
+
+ def __unicode__(self):
+ return u"%s/%s" % (self.device.name, self.number)
+
+class Connection(models.Model):
+ start = models.ForeignKey(Port, related_name='connection_start',
+ unique=True)
+ end = models.ForeignKey(Port, related_name='connection_end', unique=True)
+
+ def __unicode__(self):
+ return u"%s to %s" % (self.start, self.end)
+
+__test__ = {'API_TESTS': """
+Regression test for bug #7110. When using select_related(), we must query the
+Device and Building tables using two different aliases (each) in order to
+differentiate the start and end Connection fields. The net result is that both
+the "connections = ..." queries here should give the same results.
+
+>>> b=Building.objects.create(name='101')
+>>> dev1=Device.objects.create(name="router", building=b)
+>>> dev2=Device.objects.create(name="switch", building=b)
+>>> dev3=Device.objects.create(name="server", building=b)
+>>> port1=Port.objects.create(number='4',device=dev1)
+>>> port2=Port.objects.create(number='7',device=dev2)
+>>> port3=Port.objects.create(number='1',device=dev3)
+>>> c1=Connection.objects.create(start=port1, end=port2)
+>>> c2=Connection.objects.create(start=port2, end=port3)
+
+>>> connections=Connection.objects.filter(start__device__building=b, end__device__building=b).order_by('id')
+>>> [(c.id, unicode(c.start), unicode(c.end)) for c in connections]
+[(1, u'router/4', u'switch/7'), (2, u'switch/7', u'server/1')]
+
+>>> connections=Connection.objects.filter(start__device__building=b, end__device__building=b).select_related().order_by('id')
+>>> [(c.id, unicode(c.start), unicode(c.end)) for c in connections]
+[(1, u'router/4', u'switch/7'), (2, u'switch/7', u'server/1')]
+
+# This final query should only join seven tables (port, device and building
+# twice each, plus connection once).
+>>> connections.query.count_active_tables()
+7
+
+"""}
Please sign in to comment.
Something went wrong with that request. Please try again.