Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Fixed #13003 -- Ensured that ._state.db is set correctly for select_r…

…elated() queries. Thanks to Alex Gaynor for the report.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@12701 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
commit 18983f0ee73c9b3708b10b90e0a37bd17a8a1729 1 parent 3508a86
Russell Keith-Magee authored March 07, 2010
24  django/db/models/query.py
@@ -267,7 +267,7 @@ def iterator(self):
267 267
         for row in compiler.results_iter():
268 268
             if fill_cache:
269 269
                 obj, _ = get_cached_row(self.model, row,
270  
-                            index_start, max_depth,
  270
+                            index_start, using=self.db, max_depth=max_depth,
271 271
                             requested=requested, offset=len(aggregate_select),
272 272
                             only_load=only_load)
273 273
             else:
@@ -279,6 +279,9 @@ def iterator(self):
279 279
                     # Omit aggregates in object creation.
280 280
                     obj = self.model(*row[index_start:aggregate_start])
281 281
 
  282
+                # Store the source database of the object
  283
+                obj._state.db = self.db
  284
+
282 285
             for i, k in enumerate(extra_select):
283 286
                 setattr(obj, k, row[i])
284 287
 
@@ -286,9 +289,6 @@ def iterator(self):
286 289
             for i, aggregate in enumerate(aggregate_select):
287 290
                 setattr(obj, aggregate, row[i+aggregate_start])
288 291
 
289  
-            # Store the source database of the object
290  
-            obj._state.db = self.db
291  
-
292 292
             yield obj
293 293
 
294 294
     def aggregate(self, *args, **kwargs):
@@ -1112,7 +1112,7 @@ def update(self, **kwargs):
1112 1112
     value_annotation = False
1113 1113
 
1114 1114
 
1115  
-def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
  1115
+def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
1116 1116
                    requested=None, offset=0, only_load=None):
1117 1117
     """
1118 1118
     Helper function that recursively returns an object with the specified
@@ -1126,6 +1126,7 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
1126 1126
      * row - the row of data returned by the database cursor
1127 1127
      * index_start - the index of the row at which data for this
1128 1128
        object is known to start
  1129
+     * using - the database alias on which the query is being executed.
1129 1130
      * max_depth - the maximum depth to which a select_related()
1130 1131
        relationship should be explored.
1131 1132
      * cur_depth - the current depth in the select_related() tree.
@@ -1170,6 +1171,7 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
1170 1171
             obj = klass(**dict(zip(init_list, fields)))
1171 1172
         else:
1172 1173
             obj = klass(*fields)
  1174
+
1173 1175
     else:
1174 1176
         # Load all fields on klass
1175 1177
         field_count = len(klass._meta.fields)
@@ -1182,6 +1184,10 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
1182 1184
         else:
1183 1185
             obj = klass(*fields)
1184 1186
 
  1187
+    # If an object was retrieved, set the database state.
  1188
+    if obj:
  1189
+        obj._state.db = using
  1190
+
1185 1191
     index_end = index_start + field_count + offset
1186 1192
     # Iterate over each related object, populating any
1187 1193
     # select_related() fields
@@ -1193,8 +1199,8 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
1193 1199
         else:
1194 1200
             next = None
1195 1201
         # Recursively retrieve the data for the related object
1196  
-        cached_row = get_cached_row(f.rel.to, row, index_end, max_depth,
1197  
-                cur_depth+1, next)
  1202
+        cached_row = get_cached_row(f.rel.to, row, index_end, using,
  1203
+                max_depth, cur_depth+1, next)
1198 1204
         # If the recursive descent found an object, populate the
1199 1205
         # descriptor caches relevant to the object
1200 1206
         if cached_row:
@@ -1222,8 +1228,8 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
1222 1228
                 continue
1223 1229
             next = requested[f.related_query_name()]
1224 1230
             # Recursively retrieve the data for the related object
1225  
-            cached_row = get_cached_row(model, row, index_end, max_depth,
1226  
-                cur_depth+1, next)
  1231
+            cached_row = get_cached_row(model, row, index_end, using,
  1232
+                max_depth, cur_depth+1, next)
1227 1233
             # If the recursive descent found an object, populate the
1228 1234
             # descriptor caches relevant to the object
1229 1235
             if cached_row:
14  tests/regressiontests/multiple_database/tests.py
@@ -641,6 +641,20 @@ def test_raw(self):
641 641
         val = Book.objects.raw('SELECT id FROM "multiple_database_book"').using('other')
642 642
         self.assertEqual(map(lambda o: o.pk, val), [dive.pk])
643 643
 
  644
+    def test_select_related(self):
  645
+        "Database assignment is retained if an object is retrieved with select_related()"
  646
+        # Create a book and author on the other database
  647
+        mark = Person.objects.using('other').create(name="Mark Pilgrim")
  648
+        dive = Book.objects.using('other').create(title="Dive into Python",
  649
+                                                  published=datetime.date(2009, 5, 4),
  650
+                                                  editor=mark)
  651
+
  652
+        # Retrieve the Person using select_related()
  653
+        book = Book.objects.using('other').select_related('editor').get(title="Dive into Python")
  654
+
  655
+        # The editor instance should have a db state
  656
+        self.assertEqual(book.editor._state.db, 'other')
  657
+
644 658
 class TestRouter(object):
645 659
     # A test router. The behaviour is vaguely master/slave, but the
646 660
     # databases aren't assumed to propagate changes.

0 notes on commit 18983f0

Please sign in to comment.
Something went wrong with that request. Please try again.