Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Fixed #20946 -- model inheritance + m2m failure

Cleaned up the internal implementation of m2m fields by removing
related.py _get_fk_val(). The _get_fk_val() was doing the wrong thing
if asked for the foreign key value on foreign key to parent model's
primary key when child model had different primary key field.
  • Loading branch information...
commit b065aeb17f9daf395e22d4d5f9f49c0e2c7f4522 1 parent 83e434a
@akaariai akaariai authored
View
38 django/db/models/fields/related.py
@@ -503,8 +503,6 @@ def __init__(self, model=None, query_field_name=None, instance=None, symmetrical
self.through = through
self.prefetch_cache_name = prefetch_cache_name
self.related_val = source_field.get_foreign_related_value(instance)
- # Used for single column related auto created models
- self._fk_val = self.related_val[0]
if None in self.related_val:
raise ValueError('"%r" needs to have a value for field "%s" before '
'this many-to-many relationship can be used.' %
@@ -517,18 +515,6 @@ def __init__(self, model=None, query_field_name=None, instance=None, symmetrical
"a many-to-many relationship can be used." %
instance.__class__.__name__)
- def _get_fk_val(self, obj, field_name):
- """
- Returns the correct value for this relationship's foreign key. This
- might be something else than pk value when to_field is used.
- """
- fk = self.through._meta.get_field(field_name)
- if fk.rel.field_name and fk.rel.field_name != fk.rel.to._meta.pk.attname:
- attname = fk.rel.get_related_field().get_attname()
- return fk.get_prep_lookup('exact', getattr(obj, attname))
- else:
- return obj.pk
-
def get_queryset(self):
try:
return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
@@ -626,11 +612,12 @@ def _add_items(self, source_field_name, target_field_name, *objs):
if not router.allow_relation(obj, self.instance):
raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' %
(obj, self.instance._state.db, obj._state.db))
- fk_val = self._get_fk_val(obj, target_field_name)
+ fk_val = self.through._meta.get_field(
+ target_field_name).get_foreign_related_value(obj)[0]
if fk_val is None:
raise ValueError('Cannot add "%r": the value for field "%s" is None' %
(obj, target_field_name))
- new_ids.add(self._get_fk_val(obj, target_field_name))
+ new_ids.add(fk_val)
elif isinstance(obj, Model):
raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj))
else:
@@ -638,7 +625,7 @@ def _add_items(self, source_field_name, target_field_name, *objs):
db = router.db_for_write(self.through, instance=self.instance)
vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True)
vals = vals.filter(**{
- source_field_name: self._fk_val,
+ source_field_name: self.related_val[0],
'%s__in' % target_field_name: new_ids,
})
new_ids = new_ids - set(vals)
@@ -652,7 +639,7 @@ def _add_items(self, source_field_name, target_field_name, *objs):
# Add the ones that aren't there already
self.through._default_manager.using(db).bulk_create([
self.through(**{
- '%s_id' % source_field_name: self._fk_val,
+ '%s_id' % source_field_name: self.related_val[0],
'%s_id' % target_field_name: obj_id,
})
for obj_id in new_ids
@@ -676,7 +663,9 @@ def _remove_items(self, source_field_name, target_field_name, *objs):
old_ids = set()
for obj in objs:
if isinstance(obj, self.model):
- old_ids.add(self._get_fk_val(obj, target_field_name))
+ fk_val = self.through._meta.get_field(
+ target_field_name).get_foreign_related_value(obj)[0]
+ old_ids.add(fk_val)
else:
old_ids.add(obj)
# Work out what DB we're operating on
@@ -690,7 +679,7 @@ def _remove_items(self, source_field_name, target_field_name, *objs):
model=self.model, pk_set=old_ids, using=db)
# Remove the specified objects from the join table
self.through._default_manager.using(db).filter(**{
- source_field_name: self._fk_val,
+ source_field_name: self.related_val[0],
'%s__in' % target_field_name: old_ids
}).delete()
if self.reverse or source_field_name == self.source_field_name:
@@ -994,10 +983,13 @@ def get_instance_value_for_fields(instance, fields):
# Gotcha: in some cases (like fixture loading) a model can have
# different values in parent_ptr_id and parent's id. So, use
# instance.pk (that is, parent_ptr_id) when asked for instance.id.
+ opts = instance._meta
if field.primary_key:
- ret.append(instance.pk)
- else:
- ret.append(getattr(instance, field.attname))
+ possible_parent_link = opts.get_ancestor_link(field.model)
+ if not possible_parent_link or possible_parent_link.primary_key:
+ ret.append(instance.pk)
+ continue
+ ret.append(getattr(instance, field.attname))
return tuple(ret)
def get_attname_column(self):
View
6 tests/model_inheritance/models.py
@@ -162,3 +162,9 @@ def __init__(self):
class MixinModel(models.Model, Mixin):
pass
+
+class Base(models.Model):
+ titles = models.ManyToManyField(Title)
+
+class SubBase(Base):
+ sub_id = models.IntegerField(primary_key=True)
View
16 tests/model_inheritance/tests.py
@@ -10,7 +10,8 @@
from .models import (
Chef, CommonInfo, ItalianRestaurant, ParkingLot, Place, Post,
- Restaurant, Student, StudentWorker, Supplier, Worker, MixinModel)
+ Restaurant, Student, StudentWorker, Supplier, Worker, MixinModel,
+ Title, Base, SubBase)
class ModelInheritanceTests(TestCase):
@@ -357,3 +358,16 @@ def test_ticket_12567(self):
[Place.objects.get(pk=s.pk)],
lambda x: x
)
+
+ def test_custompk_m2m(self):
+ b = Base.objects.create()
+ b.titles.add(Title.objects.create(title="foof"))
+ s = SubBase.objects.create(sub_id=b.id)
+ b = Base.objects.get(pk=s.id)
+ self.assertNotEqual(b.pk, s.pk)
+ # Low-level test for related_val
+ self.assertEqual(s.titles.related_val, (s.id,))
+ # Higher level test for correct query values (title foof not
+ # accidentally found).
+ self.assertQuerysetEqual(
+ s.titles.all(), [])
Please sign in to comment.
Something went wrong with that request. Please try again.