Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Fixed handling of multiple fields in a model pointing to the same rel…

…ated model.

Thanks to ElliotM, mk and oyvind for some excellent test cases for this. Fixed #7110, #7125.


git-svn-id: http://code.djangoproject.com/svn/django/trunk@7778 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
commit bb2182453b49157fb6fba4de6d3c53a09f73d74b 1 parent d800c0b
Malcolm Tredinnick authored June 29, 2008
11  django/db/models/fields/related.py
@@ -692,6 +692,11 @@ def flatten_data(self, follow, obj=None):
@@ -826,6 +831,12 @@ def contribute_to_class(self, cls, name):
19  django/db/models/options.py
@@ -44,6 +44,7 @@ def __init__(self, meta, app_label=None):
44 44
         self.one_to_one_field = None
45 45
         self.abstract = False
46 46
         self.parents = SortedDict()
  47
+        self.duplicate_targets = {}
47 48
 
48 49
     def contribute_to_class(self, cls, name):
49 50
         from django.db import connection
@@ -115,6 +116,24 @@ def _prepare(self, model):
115 116
                         auto_created=True)
116 117
                 model.add_to_class('id', auto)
117 118
 
  119
+        # Determine any sets of fields that are pointing to the same targets
  120
+        # (e.g. two ForeignKeys to the same remote model). The query
  121
+        # construction code needs to know this. At the end of this,
  122
+        # self.duplicate_targets will map each duplicate field column to the
  123
+        # columns it duplicates.
  124
+        collections = {}
  125
+        for column, target in self.duplicate_targets.iteritems():
  126
+            try:
  127
+                collections[target].add(column)
  128
+            except KeyError:
  129
+                collections[target] = set([column])
  130
+        self.duplicate_targets = {}
  131
+        for elt in collections.itervalues():
  132
+            if len(elt) == 1:
  133
+                continue
  134
+            for column in elt:
  135
+                self.duplicate_targets[column] = elt.difference(set([column]))
  136
+
118 137
     def add_field(self, field):
119 138
         # Insert the given field in the order in which it was created, using
120 139
         # the "creation_counter" attribute of the field.
99  django/db/models/sql/query.py
@@ -57,6 +57,7 @@ def __init__(self, model, connection, where=WhereNode):
57 57
         self.start_meta = None
58 58
         self.select_fields = []
59 59
         self.related_select_fields = []
  60
+        self.dupe_avoidance = {}
60 61
 
61 62
         # SQL-related attributes
62 63
         self.select = []
@@ -165,6 +166,7 @@ def clone(self, klass=None, **kwargs):
165 166
         obj.start_meta = self.start_meta
166 167
         obj.select_fields = self.select_fields[:]
167 168
         obj.related_select_fields = self.related_select_fields[:]
  169
+        obj.dupe_avoidance = self.dupe_avoidance.copy()
168 170
         obj.select = self.select[:]
169 171
         obj.tables = self.tables[:]
170 172
         obj.where = deepcopy(self.where)
@@ -830,8 +832,8 @@ def join(self, connection, always_create=False, exclusions=(),
830 832
 
831 833
         if reuse and always_create and table in self.table_map:
832 834
             # Convert the 'reuse' to case to be "exclude everything but the
833  
-            # reusable set for this table".
834  
-            exclusions = set(self.table_map[table]).difference(reuse)
  835
+            # reusable set, minus exclusions, for this table".
  836
+            exclusions = set(self.table_map[table]).difference(reuse).union(set(exclusions))
835 837
             always_create = False
836 838
         t_ident = (lhs_table, table, lhs_col, col)
837 839
         if not always_create:
@@ -866,7 +868,8 @@ def join(self, connection, always_create=False, exclusions=(),
866 868
         return alias
867 869
 
868 870
     def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
869  
-            used=None, requested=None, restricted=None, nullable=None):
  871
+            used=None, requested=None, restricted=None, nullable=None,
  872
+            dupe_set=None):
870 873
         """
871 874
         Fill in the information needed for a select_related query. The current
872 875
         depth is measured as the number of connections away from the root model
@@ -876,6 +879,7 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
876 879
         if not restricted and self.max_depth and cur_depth > self.max_depth:
877 880
             # We've recursed far enough; bail out.
878 881
             return
  882
+
879 883
         if not opts:
880 884
             opts = self.get_meta()
881 885
             root_alias = self.get_initial_alias()
@@ -883,6 +887,10 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
883 887
             self.related_select_fields = []
884 888
         if not used:
885 889
             used = set()
  890
+        if dupe_set is None:
  891
+            dupe_set = set()
  892
+        orig_dupe_set = dupe_set
  893
+        orig_used = used
886 894
 
887 895
         # Setup for the case when only particular related fields should be
888 896
         # included in the related selection.
@@ -897,6 +905,8 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
897 905
             if (not f.rel or (restricted and f.name not in requested) or
898 906
                     (not restricted and f.null) or f.rel.parent_link):
899 907
                 continue
  908
+            dupe_set = orig_dupe_set.copy()
  909
+            used = orig_used.copy()
900 910
             table = f.rel.to._meta.db_table
901 911
             if nullable or f.null:
902 912
                 promote = True
@@ -907,12 +917,26 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
907 917
                 alias = root_alias
908 918
                 for int_model in opts.get_base_chain(model):
909 919
                     lhs_col = int_opts.parents[int_model].column
  920
+                    dedupe = lhs_col in opts.duplicate_targets
  921
+                    if dedupe:
  922
+                        used.update(self.dupe_avoidance.get(id(opts), lhs_col),
  923
+                                ())
  924
+                        dupe_set.add((opts, lhs_col))
910 925
                     int_opts = int_model._meta
911 926
                     alias = self.join((alias, int_opts.db_table, lhs_col,
912 927
                             int_opts.pk.column), exclusions=used,
913 928
                             promote=promote)
  929
+                    for (dupe_opts, dupe_col) in dupe_set:
  930
+                        self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
914 931
             else:
915 932
                 alias = root_alias
  933
+
  934
+            dedupe = f.column in opts.duplicate_targets
  935
+            if dupe_set or dedupe:
  936
+                used.update(self.dupe_avoidance.get((id(opts), f.column), ()))
  937
+                if dedupe:
  938
+                    dupe_set.add((opts, f.column))
  939
+
916 940
             alias = self.join((alias, table, f.column,
917 941
                     f.rel.get_related_field().column), exclusions=used,
918 942
                     promote=promote)
@@ -928,8 +952,10 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
928 952
                 new_nullable = f.null
929 953
             else:
930 954
                 new_nullable = None
  955
+            for dupe_opts, dupe_col in dupe_set:
  956
+                self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
931 957
             self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1,
932  
-                    used, next, restricted, new_nullable)
  958
+                    used, next, restricted, new_nullable, dupe_set)
933 959
 
934 960
     def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
935 961
             can_reuse=None):
@@ -1128,7 +1154,9 @@ def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
1128 1154
         (which gives the table we are joining to), 'alias' is the alias for the
1129 1155
         table we are joining to. If dupe_multis is True, any many-to-many or
1130 1156
         many-to-one joins will always create a new alias (necessary for
1131  
-        disjunctive filters).
  1157
+        disjunctive filters). If can_reuse is not None, it's a list of aliases
  1158
+        that can be reused in these joins (nothing else can be reused in this
  1159
+        case).
1132 1160
 
1133 1161
         Returns the final field involved in the join, the target database
1134 1162
         column (used for any 'where' constraint), the final 'opts' value and the
@@ -1136,7 +1164,14 @@ def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
1136 1164
         """
1137 1165
         joins = [alias]
1138 1166
         last = [0]
  1167
+        dupe_set = set()
  1168
+        exclusions = set()
1139 1169
         for pos, name in enumerate(names):
  1170
+            try:
  1171
+                exclusions.add(int_alias)
  1172
+            except NameError:
  1173
+                pass
  1174
+            exclusions.add(alias)
1140 1175
             last.append(len(joins))
1141 1176
             if name == 'pk':
1142 1177
                 name = opts.pk.name
@@ -1155,6 +1190,7 @@ def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
1155 1190
                     names = opts.get_all_field_names()
1156 1191
                     raise FieldError("Cannot resolve keyword %r into field. "
1157 1192
                             "Choices are: %s" % (name, ", ".join(names)))
  1193
+
1158 1194
             if not allow_many and (m2m or not direct):
1159 1195
                 for alias in joins:
1160 1196
                     self.unref_alias(alias)
@@ -1164,12 +1200,27 @@ def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
1164 1200
                 alias_list = []
1165 1201
                 for int_model in opts.get_base_chain(model):
1166 1202
                     lhs_col = opts.parents[int_model].column
  1203
+                    dedupe = lhs_col in opts.duplicate_targets
  1204
+                    if dedupe:
  1205
+                        exclusions.update(self.dupe_avoidance.get(
  1206
+                                (id(opts), lhs_col), ()))
  1207
+                        dupe_set.add((opts, lhs_col))
1167 1208
                     opts = int_model._meta
1168 1209
                     alias = self.join((alias, opts.db_table, lhs_col,
1169  
-                            opts.pk.column), exclusions=joins)
  1210
+                            opts.pk.column), exclusions=exclusions)
1170 1211
                     joins.append(alias)
  1212
+                    exclusions.add(alias)
  1213
+                    for (dupe_opts, dupe_col) in dupe_set:
  1214
+                        self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
1171 1215
             cached_data = opts._join_cache.get(name)
1172 1216
             orig_opts = opts
  1217
+            dupe_col = direct and field.column or field.field.column
  1218
+            dedupe = dupe_col in opts.duplicate_targets
  1219
+            if dupe_set or dedupe:
  1220
+                if dedupe:
  1221
+                    dupe_set.add((opts, dupe_col))
  1222
+                exclusions.update(self.dupe_avoidance.get((id(opts), dupe_col),
  1223
+                        ()))
1173 1224
 
1174 1225
             if direct:
1175 1226
                 if m2m:
@@ -1191,9 +1242,11 @@ def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
1191 1242
                                 target)
1192 1243
 
1193 1244
                     int_alias = self.join((alias, table1, from_col1, to_col1),
1194  
-                            dupe_multis, joins, nullable=True, reuse=can_reuse)
  1245
+                            dupe_multis, exclusions, nullable=True,
  1246
+                            reuse=can_reuse)
1195 1247
                     alias = self.join((int_alias, table2, from_col2, to_col2),
1196  
-                            dupe_multis, joins, nullable=True, reuse=can_reuse)
  1248
+                            dupe_multis, exclusions, nullable=True,
  1249
+                            reuse=can_reuse)
1197 1250
                     joins.extend([int_alias, alias])
1198 1251
                 elif field.rel:
1199 1252
                     # One-to-one or many-to-one field
@@ -1209,7 +1262,7 @@ def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
1209 1262
                                 opts, target)
1210 1263
 
1211 1264
                     alias = self.join((alias, table, from_col, to_col),
1212  
-                            exclusions=joins, nullable=field.null)
  1265
+                            exclusions=exclusions, nullable=field.null)
1213 1266
                     joins.append(alias)
1214 1267
                 else:
1215 1268
                     # Non-relation fields.
@@ -1237,9 +1290,11 @@ def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
1237 1290
                                 target)
1238 1291
 
1239 1292
                     int_alias = self.join((alias, table1, from_col1, to_col1),
1240  
-                            dupe_multis, joins, nullable=True, reuse=can_reuse)
  1293
+                            dupe_multis, exclusions, nullable=True,
  1294
+                            reuse=can_reuse)
1241 1295
                     alias = self.join((int_alias, table2, from_col2, to_col2),
1242  
-                            dupe_multis, joins, nullable=True, reuse=can_reuse)
  1296
+                            dupe_multis, exclusions, nullable=True,
  1297
+                            reuse=can_reuse)
1243 1298
                     joins.extend([int_alias, alias])
1244 1299
                 else:
1245 1300
                     # One-to-many field (ForeignKey defined on the target model)
@@ -1257,14 +1312,34 @@ def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
1257 1312
                                 opts, target)
1258 1313
 
1259 1314
                     alias = self.join((alias, table, from_col, to_col),
1260  
-                            dupe_multis, joins, nullable=True, reuse=can_reuse)
  1315
+                            dupe_multis, exclusions, nullable=True,
  1316
+                            reuse=can_reuse)
1261 1317
                     joins.append(alias)
1262 1318
 
  1319
+            for (dupe_opts, dupe_col) in dupe_set:
  1320
+                try:
  1321
+                    self.update_dupe_avoidance(dupe_opts, dupe_col, int_alias)
  1322
+                except NameError:
  1323
+                    self.update_dupe_avoidance(dupe_opts, dupe_col, alias)
  1324
+
1263 1325
         if pos != len(names) - 1:
1264 1326
             raise FieldError("Join on field %r not permitted." % name)
1265 1327
 
1266 1328
         return field, target, opts, joins, last
1267 1329
 
  1330
+    def update_dupe_avoidance(self, opts, col, alias):
  1331
+        """
  1332
+        For a column that is one of multiple pointing to the same table, update
  1333
+        the internal data structures to note that this alias shouldn't be used
  1334
+        for those other columns.
  1335
+        """
  1336
+        ident = id(opts)
  1337
+        for name in opts.duplicate_targets[col]:
  1338
+            try:
  1339
+                self.dupe_avoidance[ident, name].add(alias)
  1340
+            except KeyError:
  1341
+                self.dupe_avoidance[ident, name] = set([alias])
  1342
+
1268 1343
     def split_exclude(self, filter_expr, prefix):
1269 1344
         """
1270 1345
         When doing an exclude against any kind of N-to-many relation, we need
40  tests/regressiontests/many_to_one_regress/models.py
@@ -28,6 +28,24 @@ class Child(models.Model):
28 28
     parent = models.ForeignKey(Parent)
29 29
 
30 30
 
  31
+# Multiple paths to the same model (#7110, #7125)
  32
+class Category(models.Model):
  33
+    name = models.CharField(max_length=20)
  34
+
  35
+    def __unicode__(self):
  36
+        return self.name
  37
+
  38
+class Record(models.Model):
  39
+    category = models.ForeignKey(Category)
  40
+
  41
+class Relation(models.Model):
  42
+    left = models.ForeignKey(Record, related_name='left_set')
  43
+    right = models.ForeignKey(Record, related_name='right_set')
  44
+
  45
+    def __unicode__(self):
  46
+        return u"%s - %s" % (self.left.category.name, self.right.category.name)
  47
+
  48
+
31 49
 __test__ = {'API_TESTS':"""
32 50
 >>> Third.objects.create(id='3', name='An example')
33 51
 <Third: Third object>
@@ -73,4 +91,26 @@ class Child(models.Model):
73 91
     ...
74 92
 ValueError: Cannot assign "<First: First object>": "Child.parent" must be a "Parent" instance.
75 93
 
  94
+# Test of multiple ForeignKeys to the same model (bug #7125)
  95
+
  96
+>>> c1 = Category.objects.create(name='First')
  97
+>>> c2 = Category.objects.create(name='Second')
  98
+>>> c3 = Category.objects.create(name='Third')
  99
+>>> r1 = Record.objects.create(category=c1)
  100
+>>> r2 = Record.objects.create(category=c1)
  101
+>>> r3 = Record.objects.create(category=c2)
  102
+>>> r4 = Record.objects.create(category=c2)
  103
+>>> r5 = Record.objects.create(category=c3)
  104
+>>> r = Relation.objects.create(left=r1, right=r2)
  105
+>>> r = Relation.objects.create(left=r3, right=r4)
  106
+>>> r = Relation.objects.create(left=r1, right=r3)
  107
+>>> r = Relation.objects.create(left=r5, right=r2)
  108
+>>> r = Relation.objects.create(left=r3, right=r2)
  109
+
  110
+>>> Relation.objects.filter(left__category__name__in=['First'], right__category__name__in=['Second'])
  111
+[<Relation: First - Second>]
  112
+
  113
+>>> Category.objects.filter(record__left_set__right__category__name='Second').order_by('name')
  114
+[<Category: First>, <Category: Second>]
  115
+
76 116
 """}
0  tests/regressiontests/select_related_regress/__init__.py
No changes.
60  tests/regressiontests/select_related_regress/models.py
... ...
@@ -0,0 +1,60 @@

0 notes on commit bb21824

Please sign in to comment.
Something went wrong with that request. Please try again.