From 3cb162d14cc04156865ccc11ffce1a85d3e60507 Mon Sep 17 00:00:00 2001 From: ykiu <32252655+ykiu@users.noreply.github.com> Date: Tue, 14 Nov 2017 19:57:46 +0900 Subject: [PATCH] Wrap saving operations in a transaction Wrap saving operations in a transaction to prevent inconsistency between main model and closure model --- closuretree/models.py | 13 ++++++++++- closuretree/tests.py | 51 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/closuretree/models.py b/closuretree/models.py index 46bcc5e..2d86ed2 100644 --- a/closuretree/models.py +++ b/closuretree/models.py @@ -26,7 +26,7 @@ # Public methods are useful! # pylint: disable=R0904 -from django.db import models +from django.db import models, transaction from django.db.models.base import ModelBase from django.db.models.signals import post_save, pre_delete from django.dispatch import receiver @@ -116,6 +116,17 @@ def __setattr__(self, name, value): self._closure_change_init() super(ClosureModel, self).__setattr__(name, value) + def save_base(self, *args, **kwargs): + """Wrap the saving operation of this model and that of the closure + table in one transaction. + """ + # the superclass save_base() sends post_save() signal. + # post_save() is then received by closure_model_save() function, + # which saves closure model. + # we're going to wrap this series of operations in a transaction. + with transaction.atomic(): + super(ClosureModel, self).save_base(*args, **kwargs) + @classmethod def _toplevel(cls): """Find the top level of the chain we're in. diff --git a/closuretree/tests.py b/closuretree/tests.py index fd2b1a4..157d46f 100644 --- a/closuretree/tests.py +++ b/closuretree/tests.py @@ -24,8 +24,9 @@ # pylint: disable= from django import VERSION -from django.test import TestCase +from django.test import TestCase, TransactionTestCase from django.db import models +from django.db.utils import IntegrityError from closuretree.models import ClosureModel import uuid @@ -118,6 +119,54 @@ def test_deletion(self): self.b.delete() self.failUnlessEqual(self.closure_model.objects.count(), 2) +class DirtyParentTestCase(TransactionTestCase): + + normal_model = TC + closure_model = TCClosure + + def setUp(self): + self.a = self.normal_model.objects.create(name="a") + self.b = self.normal_model.objects.create(name="b") + self.c = self.normal_model.objects.create(name="c") + self.d = self.normal_model.objects.create(name="d") + + def test_selfreferencing_parent(self): + """Tests that instances with self-referencing parents are not saved""" + self.b.parent2 = self.a + self.b.save() + self.c.parent2 = self.b + self.c.save() + + # this is the prerequisite of this test: + # closure model contains 7 rows before we save dirty data + self.assertEqual(self.closure_model.objects.count(), 7) + + # dirty data + self.b.parent2 = self.b + # save dirty data + self.assertRaises(IntegrityError, self.b.save) + + # closure model contains the same number of data as before + self.assertEqual(self.closure_model.objects.count(), 7) + + def test_parent_referencing_child(self): + """Tests instances with circular-referencing parents are not saved""" + self.b.parent2 = self.a + self.b.save() + self.c.parent2 = self.b + self.c.save() + + # this is the prerequisite of this test: + # closure model contains 7 rows before we save dirty data + self.assertEqual(self.closure_model.objects.count(), 7) + + # dirty data + self.b.parent2 = self.c + # save dirty data + self.assertRaises(IntegrityError, self.b.save) + + # closure model contains the same number of data as before + self.assertEqual(self.closure_model.objects.count(), 7) if VERSION >= (1, 8): class UUIDTC(ClosureModel):