Skip to content

Commit

Permalink
Fixed #18823 -- Ensured m2m.clear() works when using through+to_field
Browse files Browse the repository at this point in the history
There was a potential data-loss issue involved -- when clearing
instance's m2m assignments it was possible some other instance's
m2m data was deleted instead.

This commit also improved None handling for to_field cases.
  • Loading branch information
akaariai committed Oct 28, 2012
1 parent 98032f6 commit 611c4d6
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 15 deletions.
45 changes: 36 additions & 9 deletions django/db/models/fields/related.py
Expand Up @@ -573,9 +573,31 @@ def __init__(self, model=None, query_field_name=None, instance=None, symmetrical
self.reverse = reverse self.reverse = reverse
self.through = through self.through = through
self.prefetch_cache_name = prefetch_cache_name self.prefetch_cache_name = prefetch_cache_name
self._pk_val = self.instance.pk self._fk_val = self._get_fk_val(instance, source_field_name)
if self._pk_val is None: if self._fk_val is None:
raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % instance.__class__.__name__) raise ValueError('"%r" needs to have a value for field "%s" before '
'this many-to-many relationship can be used.' %
(instance, source_field_name))
# Even if this relation is not to pk, we require still pk value.
# The wish is that the instance has been already saved to DB,
# although having a pk value isn't a guarantee of that.
if instance.pk is None:
raise ValueError("%r instance needs to have a primary key value before "
"a many-to-many relationship can be used." %
instance.__class__.__name__)


def _get_fk_val(self, obj, field_name):
"""
Returns the correct value for this relationship's foreign key. This
might be something else than pk value when to_field is used.
"""
fk = self.through._meta.get_field(field_name)
if fk.rel.field_name and fk.rel.field_name != fk.rel.to._meta.pk.attname:
attname = fk.rel.get_related_field().get_attname()
return fk.get_prep_lookup('exact', getattr(obj, attname))
else:
return obj.pk


def get_query_set(self): def get_query_set(self):
try: try:
Expand Down Expand Up @@ -677,15 +699,19 @@ def _add_items(self, source_field_name, target_field_name, *objs):
if not router.allow_relation(obj, self.instance): if not router.allow_relation(obj, self.instance):
raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' % raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' %
(obj, self.instance._state.db, obj._state.db)) (obj, self.instance._state.db, obj._state.db))
new_ids.add(obj.pk) fk_val = self._get_fk_val(obj, target_field_name)
if fk_val is None:
raise ValueError('Cannot add "%r": the value for field "%s" is None' %
(obj, target_field_name))
new_ids.add(self._get_fk_val(obj, target_field_name))
elif isinstance(obj, Model): elif isinstance(obj, Model):
raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj)) raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj))
else: else:
new_ids.add(obj) new_ids.add(obj)
db = router.db_for_write(self.through, instance=self.instance) db = router.db_for_write(self.through, instance=self.instance)
vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True) vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True)
vals = vals.filter(**{ vals = vals.filter(**{
source_field_name: self._pk_val, source_field_name: self._fk_val,
'%s__in' % target_field_name: new_ids, '%s__in' % target_field_name: new_ids,
}) })
new_ids = new_ids - set(vals) new_ids = new_ids - set(vals)
Expand All @@ -699,11 +725,12 @@ def _add_items(self, source_field_name, target_field_name, *objs):
# Add the ones that aren't there already # Add the ones that aren't there already
self.through._default_manager.using(db).bulk_create([ self.through._default_manager.using(db).bulk_create([
self.through(**{ self.through(**{
'%s_id' % source_field_name: self._pk_val, '%s_id' % source_field_name: self._fk_val,
'%s_id' % target_field_name: obj_id, '%s_id' % target_field_name: obj_id,
}) })
for obj_id in new_ids for obj_id in new_ids
]) ])

if self.reverse or source_field_name == self.source_field_name: if self.reverse or source_field_name == self.source_field_name:
# Don't send the signal when we are inserting the # Don't send the signal when we are inserting the
# duplicate data row for symmetrical reverse entries. # duplicate data row for symmetrical reverse entries.
Expand All @@ -722,7 +749,7 @@ def _remove_items(self, source_field_name, target_field_name, *objs):
old_ids = set() old_ids = set()
for obj in objs: for obj in objs:
if isinstance(obj, self.model): if isinstance(obj, self.model):
old_ids.add(obj.pk) old_ids.add(self._get_fk_val(obj, target_field_name))
else: else:
old_ids.add(obj) old_ids.add(obj)
# Work out what DB we're operating on # Work out what DB we're operating on
Expand All @@ -736,7 +763,7 @@ def _remove_items(self, source_field_name, target_field_name, *objs):
model=self.model, pk_set=old_ids, using=db) model=self.model, pk_set=old_ids, using=db)
# Remove the specified objects from the join table # Remove the specified objects from the join table
self.through._default_manager.using(db).filter(**{ self.through._default_manager.using(db).filter(**{
source_field_name: self._pk_val, source_field_name: self._fk_val,
'%s__in' % target_field_name: old_ids '%s__in' % target_field_name: old_ids
}).delete() }).delete()
if self.reverse or source_field_name == self.source_field_name: if self.reverse or source_field_name == self.source_field_name:
Expand All @@ -756,7 +783,7 @@ def _clear_items(self, source_field_name):
instance=self.instance, reverse=self.reverse, instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=None, using=db) model=self.model, pk_set=None, using=db)
self.through._default_manager.using(db).filter(**{ self.through._default_manager.using(db).filter(**{
source_field_name: self._pk_val source_field_name: self._fk_val
}).delete() }).delete()
if self.reverse or source_field_name == self.source_field_name: if self.reverse or source_field_name == self.source_field_name:
# Don't send the signal when we are clearing the # Don't send the signal when we are clearing the
Expand Down
8 changes: 4 additions & 4 deletions tests/regressiontests/m2m_through_regress/models.py
Expand Up @@ -62,18 +62,18 @@ class B(models.Model):
# Using to_field on the through model # Using to_field on the through model
@python_2_unicode_compatible @python_2_unicode_compatible
class Car(models.Model): class Car(models.Model):
make = models.CharField(max_length=20, unique=True) make = models.CharField(max_length=20, unique=True, null=True)
drivers = models.ManyToManyField('Driver', through='CarDriver') drivers = models.ManyToManyField('Driver', through='CarDriver')


def __str__(self): def __str__(self):
return self.make return "%s" % self.make


@python_2_unicode_compatible @python_2_unicode_compatible
class Driver(models.Model): class Driver(models.Model):
name = models.CharField(max_length=20, unique=True) name = models.CharField(max_length=20, unique=True, null=True)


def __str__(self): def __str__(self):
return self.name return "%s" % self.name


@python_2_unicode_compatible @python_2_unicode_compatible
class CarDriver(models.Model): class CarDriver(models.Model):
Expand Down
90 changes: 88 additions & 2 deletions tests/regressiontests/m2m_through_regress/tests.py
Expand Up @@ -123,18 +123,104 @@ def setUp(self):
self.car = Car.objects.create(make="Toyota") self.car = Car.objects.create(make="Toyota")
self.driver = Driver.objects.create(name="Ryan Briscoe") self.driver = Driver.objects.create(name="Ryan Briscoe")
CarDriver.objects.create(car=self.car, driver=self.driver) CarDriver.objects.create(car=self.car, driver=self.driver)
# We are testing if wrong objects get deleted due to using wrong
# field value in m2m queries. So, it is essential that the pk
# numberings do not match.
# Create one intentionally unused driver to mix up the autonumbering
self.unused_driver = Driver.objects.create(name="Barney Gumble")
# And two intentionally unused cars.
self.unused_car1 = Car.objects.create(make="Trabant")
self.unused_car2 = Car.objects.create(make="Wartburg")


def test_to_field(self): def test_to_field(self):
self.assertQuerysetEqual( self.assertQuerysetEqual(
self.car.drivers.all(), self.car.drivers.all(),
["<Driver: Ryan Briscoe>"] ["<Driver: Ryan Briscoe>"]
) )


def test_to_field_reverse(self): def test_to_field_reverse(self):
self.assertQuerysetEqual( self.assertQuerysetEqual(
self.driver.car_set.all(), self.driver.car_set.all(),
["<Car: Toyota>"] ["<Car: Toyota>"]
) )

def test_to_field_clear_reverse(self):
self.driver.car_set.clear()
self.assertQuerysetEqual(
self.driver.car_set.all(),[])

def test_to_field_clear(self):
self.car.drivers.clear()
self.assertQuerysetEqual(
self.car.drivers.all(),[])

# Low level tests for _add_items and _remove_items. We test these methods
# because .add/.remove aren't available for m2m fields with through, but
# through is the only way to set to_field currently. We do want to make
# sure these methods are ready if the ability to use .add or .remove with
# to_field relations is added some day.
def test_add(self):
self.assertQuerysetEqual(
self.car.drivers.all(),
["<Driver: Ryan Briscoe>"]
)
# Yikes - barney is going to drive...
self.car.drivers._add_items('car', 'driver', self.unused_driver)
self.assertQuerysetEqual(
self.car.drivers.all(),
["<Driver: Ryan Briscoe>", "<Driver: Barney Gumble>"]
)

def test_add_null(self):
nullcar = Car.objects.create(make=None)
with self.assertRaises(ValueError):
nullcar.drivers._add_items('car', 'driver', self.unused_driver)

def test_add_related_null(self):
nulldriver = Driver.objects.create(name=None)
with self.assertRaises(ValueError):
self.car.drivers._add_items('car', 'driver', nulldriver)

def test_add_reverse(self):
car2 = Car.objects.create(make="Honda")
self.assertQuerysetEqual(
self.driver.car_set.all(),
["<Car: Toyota>"]
)
self.driver.car_set._add_items('driver', 'car', car2)
self.assertQuerysetEqual(
self.driver.car_set.all(),
["<Car: Toyota>", "<Car: Honda>"]
)

def test_add_null_reverse(self):
nullcar = Car.objects.create(make=None)
with self.assertRaises(ValueError):
self.driver.car_set._add_items('driver', 'car', nullcar)

def test_add_null_reverse_related(self):
nulldriver = Driver.objects.create(name=None)
with self.assertRaises(ValueError):
nulldriver.car_set._add_items('driver', 'car', self.car)

def test_remove(self):
self.assertQuerysetEqual(
self.car.drivers.all(),
["<Driver: Ryan Briscoe>"]
)
self.car.drivers._remove_items('car', 'driver', self.driver)
self.assertQuerysetEqual(
self.car.drivers.all(),[])

def test_remove_reverse(self):
self.assertQuerysetEqual(
self.driver.car_set.all(),
["<Car: Toyota>"]
)
self.driver.car_set._remove_items('driver', 'car', self.car)
self.assertQuerysetEqual(
self.driver.car_set.all(),[])



class ThroughLoadDataTestCase(TestCase): class ThroughLoadDataTestCase(TestCase):
fixtures = ["m2m_through"] fixtures = ["m2m_through"]
Expand Down

0 comments on commit 611c4d6

Please sign in to comment.