Skip to content

Commit

Permalink
Fixed #20946 -- model inheritance + m2m failure
Browse files Browse the repository at this point in the history
Cleaned up the internal implementation of m2m fields by removing
related.py _get_fk_val(). The _get_fk_val() was doing the wrong thing
if asked for the foreign key value on foreign key to parent model's
primary key when child model had different primary key field.
  • Loading branch information
akaariai authored and andrewgodwin committed Aug 21, 2013
1 parent 7775ced commit 244e2b7
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 24 deletions.
38 changes: 15 additions & 23 deletions django/db/models/fields/related.py
Expand Up @@ -501,8 +501,6 @@ def __init__(self, model=None, query_field_name=None, instance=None, symmetrical
self.through = through self.through = through
self.prefetch_cache_name = prefetch_cache_name self.prefetch_cache_name = prefetch_cache_name
self.related_val = source_field.get_foreign_related_value(instance) self.related_val = source_field.get_foreign_related_value(instance)
# Used for single column related auto created models
self._fk_val = self.related_val[0]
if None in self.related_val: if None in self.related_val:
raise ValueError('"%r" needs to have a value for field "%s" before ' raise ValueError('"%r" needs to have a value for field "%s" before '
'this many-to-many relationship can be used.' % 'this many-to-many relationship can be used.' %
Expand All @@ -515,18 +513,6 @@ def __init__(self, model=None, query_field_name=None, instance=None, symmetrical
"a many-to-many relationship can be used." % "a many-to-many relationship can be used." %
instance.__class__.__name__) instance.__class__.__name__)


def _get_fk_val(self, obj, field_name):
"""
Returns the correct value for this relationship's foreign key. This
might be something else than pk value when to_field is used.
"""
fk = self.through._meta.get_field(field_name)
if fk.rel.field_name and fk.rel.field_name != fk.rel.to._meta.pk.attname:
attname = fk.rel.get_related_field().get_attname()
return fk.get_prep_lookup('exact', getattr(obj, attname))
else:
return obj.pk

def get_queryset(self): def get_queryset(self):
try: try:
return self.instance._prefetched_objects_cache[self.prefetch_cache_name] return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
Expand Down Expand Up @@ -624,19 +610,20 @@ def _add_items(self, source_field_name, target_field_name, *objs):
if not router.allow_relation(obj, self.instance): if not router.allow_relation(obj, self.instance):
raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' % raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' %
(obj, self.instance._state.db, obj._state.db)) (obj, self.instance._state.db, obj._state.db))
fk_val = self._get_fk_val(obj, target_field_name) fk_val = self.through._meta.get_field(
target_field_name).get_foreign_related_value(obj)[0]
if fk_val is None: if fk_val is None:
raise ValueError('Cannot add "%r": the value for field "%s" is None' % raise ValueError('Cannot add "%r": the value for field "%s" is None' %
(obj, target_field_name)) (obj, target_field_name))
new_ids.add(self._get_fk_val(obj, target_field_name)) new_ids.add(fk_val)
elif isinstance(obj, Model): elif isinstance(obj, Model):
raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj)) raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj))
else: else:
new_ids.add(obj) new_ids.add(obj)
db = router.db_for_write(self.through, instance=self.instance) db = router.db_for_write(self.through, instance=self.instance)
vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True) vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True)
vals = vals.filter(**{ vals = vals.filter(**{
source_field_name: self._fk_val, source_field_name: self.related_val[0],
'%s__in' % target_field_name: new_ids, '%s__in' % target_field_name: new_ids,
}) })
new_ids = new_ids - set(vals) new_ids = new_ids - set(vals)
Expand All @@ -650,7 +637,7 @@ def _add_items(self, source_field_name, target_field_name, *objs):
# Add the ones that aren't there already # Add the ones that aren't there already
self.through._default_manager.using(db).bulk_create([ self.through._default_manager.using(db).bulk_create([
self.through(**{ self.through(**{
'%s_id' % source_field_name: self._fk_val, '%s_id' % source_field_name: self.related_val[0],
'%s_id' % target_field_name: obj_id, '%s_id' % target_field_name: obj_id,
}) })
for obj_id in new_ids for obj_id in new_ids
Expand All @@ -674,7 +661,9 @@ def _remove_items(self, source_field_name, target_field_name, *objs):
old_ids = set() old_ids = set()
for obj in objs: for obj in objs:
if isinstance(obj, self.model): if isinstance(obj, self.model):
old_ids.add(self._get_fk_val(obj, target_field_name)) fk_val = self.through._meta.get_field(
target_field_name).get_foreign_related_value(obj)[0]
old_ids.add(fk_val)
else: else:
old_ids.add(obj) old_ids.add(obj)
# Work out what DB we're operating on # Work out what DB we're operating on
Expand All @@ -688,7 +677,7 @@ def _remove_items(self, source_field_name, target_field_name, *objs):
model=self.model, pk_set=old_ids, using=db) model=self.model, pk_set=old_ids, using=db)
# Remove the specified objects from the join table # Remove the specified objects from the join table
self.through._default_manager.using(db).filter(**{ self.through._default_manager.using(db).filter(**{
source_field_name: self._fk_val, source_field_name: self.related_val[0],
'%s__in' % target_field_name: old_ids '%s__in' % target_field_name: old_ids
}).delete() }).delete()
if self.reverse or source_field_name == self.source_field_name: if self.reverse or source_field_name == self.source_field_name:
Expand Down Expand Up @@ -994,10 +983,13 @@ def get_instance_value_for_fields(instance, fields):
# Gotcha: in some cases (like fixture loading) a model can have # Gotcha: in some cases (like fixture loading) a model can have
# different values in parent_ptr_id and parent's id. So, use # different values in parent_ptr_id and parent's id. So, use
# instance.pk (that is, parent_ptr_id) when asked for instance.id. # instance.pk (that is, parent_ptr_id) when asked for instance.id.
opts = instance._meta
if field.primary_key: if field.primary_key:
ret.append(instance.pk) possible_parent_link = opts.get_ancestor_link(field.model)
else: if not possible_parent_link or possible_parent_link.primary_key:
ret.append(getattr(instance, field.attname)) ret.append(instance.pk)
continue
ret.append(getattr(instance, field.attname))
return tuple(ret) return tuple(ret)


def get_attname_column(self): def get_attname_column(self):
Expand Down
6 changes: 6 additions & 0 deletions tests/model_inheritance/models.py
Expand Up @@ -162,3 +162,9 @@ def __init__(self):


class MixinModel(models.Model, Mixin): class MixinModel(models.Model, Mixin):
pass pass

class Base(models.Model):
titles = models.ManyToManyField(Title)

class SubBase(Base):
sub_id = models.IntegerField(primary_key=True)
16 changes: 15 additions & 1 deletion tests/model_inheritance/tests.py
Expand Up @@ -10,7 +10,8 @@


from .models import ( from .models import (
Chef, CommonInfo, ItalianRestaurant, ParkingLot, Place, Post, Chef, CommonInfo, ItalianRestaurant, ParkingLot, Place, Post,
Restaurant, Student, StudentWorker, Supplier, Worker, MixinModel) Restaurant, Student, StudentWorker, Supplier, Worker, MixinModel,
Title, Base, SubBase)




class ModelInheritanceTests(TestCase): class ModelInheritanceTests(TestCase):
Expand Down Expand Up @@ -357,3 +358,16 @@ def test_ticket_12567(self):
[Place.objects.get(pk=s.pk)], [Place.objects.get(pk=s.pk)],
lambda x: x lambda x: x
) )

def test_custompk_m2m(self):
b = Base.objects.create()
b.titles.add(Title.objects.create(title="foof"))
s = SubBase.objects.create(sub_id=b.id)
b = Base.objects.get(pk=s.id)
self.assertNotEqual(b.pk, s.pk)
# Low-level test for related_val
self.assertEqual(s.titles.related_val, (s.id,))
# Higher level test for correct query values (title foof not
# accidentally found).
self.assertQuerysetEqual(
s.titles.all(), [])

0 comments on commit 244e2b7

Please sign in to comment.