Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Fixed #13781 -- Improved select_related in inheritance situations

The select_related code got confused when it needed to travel a
reverse relation to a model which had different parent than the
originally travelled relation.

Thanks to Trac aliases shauncutts for report and ungenio for original
patch (committed patch is somewhat modified version of that).
  • Loading branch information...
commit f51e409a5fb34020e170494320a421503689aea0 1 parent 92d7f54
Anssi Kääriäinen authored November 09, 2012
6  django/db/models/options.py
@@ -75,6 +75,7 @@ def contribute_to_class(self, cls, name):
75 75
         from django.db.backends.util import truncate_name
76 76
 
77 77
         cls._meta = self
  78
+        self.model = cls
78 79
         self.installed = re.sub('\.models$', '', cls.__module__) in settings.INSTALLED_APPS
79 80
         # First, construct the default values for these options.
80 81
         self.object_name = cls.__name__
@@ -464,7 +465,7 @@ def get_base_chain(self, model):
464 465
         a granparent or even more distant relation.
465 466
         """
466 467
         if not self.parents:
467  
-            return
  468
+            return None
468 469
         if model in self.parents:
469 470
             return [model]
470 471
         for parent in self.parents:
@@ -472,8 +473,7 @@ def get_base_chain(self, model):
472 473
             if res:
473 474
                 res.insert(0, parent)
474 475
                 return res
475  
-        raise TypeError('%r is not an ancestor of this model'
476  
-                % model._meta.module_name)
  476
+        return None
477 477
 
478 478
     def get_parent_list(self):
479 479
         """
83  django/db/models/query.py
@@ -1300,7 +1300,7 @@ def values_list(self, *fields, **kwargs):
1300 1300
     value_annotation = False
1301 1301
 
1302 1302
 def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
1303  
-                   only_load=None, local_only=False):
  1303
+                   only_load=None, from_parent=None):
1304 1304
     """
1305 1305
     Helper function that recursively returns an information for a klass, to be
1306 1306
     used in get_cached_row.  It exists just to compute this information only
@@ -1320,8 +1320,10 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
1320 1320
      * only_load - if the query has had only() or defer() applied,
1321 1321
        this is the list of field names that will be returned. If None,
1322 1322
        the full field list for `klass` can be assumed.
1323  
-     * local_only - Only populate local fields. This is used when
1324  
-       following reverse select-related relations
  1323
+     * from_parent - the parent model used to get to this model
  1324
+
  1325
+    Note that when travelling from parent to child, we will only load child
  1326
+    fields which aren't in the parent.
1325 1327
     """
1326 1328
     if max_depth and requested is None and cur_depth > max_depth:
1327 1329
         # We've recursed deeply enough; stop now.
@@ -1347,7 +1349,9 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
1347 1349
         for field, model in klass._meta.get_fields_with_model():
1348 1350
             if field.name not in load_fields:
1349 1351
                 skip.add(field.attname)
1350  
-            elif local_only and model is not None:
  1352
+            elif from_parent and issubclass(from_parent, model.__class__):
  1353
+                # Avoid loading fields already loaded for parent model for
  1354
+                # child models.
1351 1355
                 continue
1352 1356
             else:
1353 1357
                 init_list.append(field.attname)
@@ -1361,16 +1365,22 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
1361 1365
     else:
1362 1366
         # Load all fields on klass
1363 1367
 
1364  
-        # We trying to not populate field_names variable for perfomance reason.
1365  
-        # If field_names variable is set, it is used to instantiate desired fields,
1366  
-        # by passing **dict(zip(field_names, fields)) as kwargs to Model.__init__ method.
1367  
-        # But kwargs version of Model.__init__ is slower, so we should avoid using
1368  
-        # it when it is not really neccesary.
1369  
-        if local_only and len(klass._meta.local_fields) != len(klass._meta.fields):
1370  
-            field_count = len(klass._meta.local_fields)
1371  
-            field_names = [f.attname for f in klass._meta.local_fields]
1372  
-        else:
1373  
-            field_count = len(klass._meta.fields)
  1368
+        field_count = len(klass._meta.fields)
  1369
+        # Check if we need to skip some parent fields.
  1370
+        if from_parent and len(klass._meta.local_fields) != len(klass._meta.fields):
  1371
+            # Only load those fields which haven't been already loaded into
  1372
+            # 'from_parent'.
  1373
+            non_seen_models = [p for p in klass._meta.get_parent_list()
  1374
+                               if not issubclass(from_parent, p)]
  1375
+            # Load local fields, too...
  1376
+            non_seen_models.append(klass)
  1377
+            field_names = [f.attname for f in klass._meta.fields
  1378
+                           if f.model in non_seen_models]
  1379
+            field_count = len(field_names)
  1380
+        # Try to avoid populating field_names variable for perfomance reasons.
  1381
+        # If field_names variable is set, we use **kwargs based model init
  1382
+        # which is slower than normal init.
  1383
+        if field_count == len(klass._meta.fields):
1374 1384
             field_names = ()
1375 1385
 
1376 1386
     restricted = requested is not None
@@ -1392,8 +1402,9 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
1392 1402
             if o.field.unique and select_related_descend(o.field, restricted, requested,
1393 1403
                                                          only_load.get(o.model), reverse=True):
1394 1404
                 next = requested[o.field.related_query_name()]
  1405
+                parent = klass if issubclass(o.model, klass) else None
1395 1406
                 klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1,
1396  
-                                            requested=next, only_load=only_load, local_only=True)
  1407
+                                            requested=next, only_load=only_load, from_parent=parent)
1397 1408
                 reverse_related_fields.append((o.field, klass_info))
1398 1409
     if field_names:
1399 1410
         pk_idx = field_names.index(klass._meta.pk.attname)
@@ -1403,7 +1414,8 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
1403 1414
     return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx
1404 1415
 
1405 1416
 
1406  
-def get_cached_row(row, index_start, using,  klass_info, offset=0):
  1417
+def get_cached_row(row, index_start, using,  klass_info, offset=0,
  1418
+                   parent_data=()):
1407 1419
     """
1408 1420
     Helper function that recursively returns an object with the specified
1409 1421
     related attributes already populated.
@@ -1418,13 +1430,16 @@ def get_cached_row(row, index_start, using,  klass_info, offset=0):
1418 1430
          * offset - the number of additional fields that are known to
1419 1431
            exist in row for `klass`. This usually means the number of
1420 1432
            annotated results on `klass`.
1421  
-        * using - the database alias on which the query is being executed.
  1433
+         * using - the database alias on which the query is being executed.
1422 1434
          * klass_info - result of the get_klass_info function
  1435
+         * parent_data - parent model data in format (field, value). Used
  1436
+           to populate the non-local fields of child models.
1423 1437
     """
1424 1438
     if klass_info is None:
1425 1439
         return None
1426 1440
     klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx = klass_info
1427 1441
 
  1442
+
1428 1443
     fields = row[index_start : index_start + field_count]
1429 1444
     # If the pk column is None (or the Oracle equivalent ''), then the related
1430 1445
     # object must be non-existent - set the relation to None.
@@ -1434,7 +1449,6 @@ def get_cached_row(row, index_start, using,  klass_info, offset=0):
1434 1449
         obj = klass(**dict(zip(field_names, fields)))
1435 1450
     else:
1436 1451
         obj = klass(*fields)
1437  
-
1438 1452
     # If an object was retrieved, set the database state.
1439 1453
     if obj:
1440 1454
         obj._state.db = using
@@ -1464,34 +1478,35 @@ def get_cached_row(row, index_start, using,  klass_info, offset=0):
1464 1478
     # Only handle the restricted case - i.e., don't do a depth
1465 1479
     # descent into reverse relations unless explicitly requested
1466 1480
     for f, klass_info in reverse_related_fields:
  1481
+        # Transfer data from this object to childs.
  1482
+        parent_data = []
  1483
+        for rel_field, rel_model in klass_info[0]._meta.get_fields_with_model():
  1484
+            if rel_model is not None and isinstance(obj, rel_model):
  1485
+                parent_data.append((rel_field, getattr(obj, rel_field.attname)))
1467 1486
         # Recursively retrieve the data for the related object
1468  
-        cached_row = get_cached_row(row, index_end, using, klass_info)
  1487
+        cached_row = get_cached_row(row, index_end, using, klass_info,
  1488
+                                   parent_data=parent_data)
1469 1489
         # If the recursive descent found an object, populate the
1470 1490
         # descriptor caches relevant to the object
1471 1491
         if cached_row:
1472 1492
             rel_obj, index_end = cached_row
1473 1493
             if obj is not None:
1474  
-                # If the field is unique, populate the
1475  
-                # reverse descriptor cache
  1494
+                # populate the reverse descriptor cache
1476 1495
                 setattr(obj, f.related.get_cache_name(), rel_obj)
1477 1496
             if rel_obj is not None:
1478 1497
                 # If the related object exists, populate
1479 1498
                 # the descriptor cache.
1480 1499
                 setattr(rel_obj, f.get_cache_name(), obj)
1481  
-                # Now populate all the non-local field values
1482  
-                # on the related object
1483  
-                for rel_field, rel_model in rel_obj._meta.get_fields_with_model():
1484  
-                    if rel_model is not None:
  1500
+                # Populate related object caches using parent data.
  1501
+                for rel_field, _ in parent_data:
  1502
+                    if rel_field.rel:
1485 1503
                         setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
1486  
-                        # populate the field cache for any related object
1487  
-                        # that has already been retrieved
1488  
-                        if rel_field.rel:
1489  
-                            try:
1490  
-                                cached_obj = getattr(obj, rel_field.get_cache_name())
1491  
-                                setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
1492  
-                            except AttributeError:
1493  
-                                # Related object hasn't been cached yet
1494  
-                                pass
  1504
+                        try:
  1505
+                            cached_obj = getattr(obj, rel_field.get_cache_name())
  1506
+                            setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
  1507
+                        except AttributeError:
  1508
+                            # Related object hasn't been cached yet
  1509
+                            pass
1495 1510
     return obj, index_end
1496 1511
 
1497 1512
 
13  django/db/models/sql/compiler.py
@@ -240,7 +240,7 @@ def get_columns(self, with_aliases=False):
240 240
         return result
241 241
 
242 242
     def get_default_columns(self, with_aliases=False, col_aliases=None,
243  
-            start_alias=None, opts=None, as_pairs=False, local_only=False):
  243
+            start_alias=None, opts=None, as_pairs=False, from_parent=None):
244 244
         """
245 245
         Computes the default columns for selecting every field in the base
246 246
         model. Will sometimes be called to pull in related models (e.g. via
@@ -265,7 +265,8 @@ def get_default_columns(self, with_aliases=False, col_aliases=None,
265 265
         if start_alias:
266 266
             seen = {None: start_alias}
267 267
         for field, model in opts.get_fields_with_model():
268  
-            if local_only and model is not None:
  268
+            if from_parent and model is not None and issubclass(from_parent, model):
  269
+                # Avoid loading data for already loaded parents.
269 270
                 continue
270 271
             if start_alias:
271 272
                 try:
@@ -686,11 +687,13 @@ def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
686 687
                     (alias, table, f.rel.get_related_field().column, f.column),
687 688
                     promote=True
688 689
                 )
  690
+                from_parent = (opts.model if issubclass(model, opts.model)
  691
+                               else None)
689 692
                 columns, aliases = self.get_default_columns(start_alias=alias,
690  
-                    opts=model._meta, as_pairs=True, local_only=True)
  693
+                    opts=model._meta, as_pairs=True, from_parent=from_parent)
691 694
                 self.query.related_select_cols.extend(
692  
-                    SelectInfo(col, field) for col, field in zip(columns, model._meta.fields))
693  
-
  695
+                    SelectInfo(col, field) for col, field
  696
+                    in zip(columns, model._meta.fields))
694 697
                 next = requested.get(f.related_query_name(), {})
695 698
                 # Use True here because we are looking at the _reverse_ side of
696 699
                 # the relation, which is always nullable.
42  tests/regressiontests/select_related_onetoone/models.py
@@ -51,6 +51,7 @@ def __str__(self):
@@ -58,3 +59,44 @@ class Image(models.Model):
102  tests/regressiontests/select_related_onetoone/tests.py
... ...
@@ -1,9 +1,11 @@
@@ -21,6 +23,14 @@ def setUp(self):
@@ -108,3 +118,93 @@ def test_nullable_missing_reverse(self):

0 notes on commit f51e409

Please sign in to comment.
Something went wrong with that request. Please try again.