Skip to content

Commit

Permalink
Wrap saving operations in a transaction
Browse files Browse the repository at this point in the history
Wrap saving operations in a transaction to prevent inconsistency between main model and closure model
  • Loading branch information
ykiu committed Nov 14, 2017
1 parent d6e4255 commit 3cb162d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 2 deletions.
13 changes: 12 additions & 1 deletion closuretree/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# Public methods are useful!
# pylint: disable=R0904

from django.db import models
from django.db import models, transaction
from django.db.models.base import ModelBase
from django.db.models.signals import post_save, pre_delete
from django.dispatch import receiver
Expand Down Expand Up @@ -116,6 +116,17 @@ def __setattr__(self, name, value):
self._closure_change_init()
super(ClosureModel, self).__setattr__(name, value)

def save_base(self, *args, **kwargs):
"""Wrap the saving operation of this model and that of the closure
table in one transaction.
"""
# the superclass save_base() sends post_save() signal.
# post_save() is then received by closure_model_save() function,
# which saves closure model.
# we're going to wrap this series of operations in a transaction.
with transaction.atomic():
super(ClosureModel, self).save_base(*args, **kwargs)

@classmethod
def _toplevel(cls):
"""Find the top level of the chain we're in.
Expand Down
51 changes: 50 additions & 1 deletion closuretree/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
# pylint: disable=

from django import VERSION
from django.test import TestCase
from django.test import TestCase, TransactionTestCase
from django.db import models
from django.db.utils import IntegrityError
from closuretree.models import ClosureModel
import uuid

Expand Down Expand Up @@ -118,6 +119,54 @@ def test_deletion(self):
self.b.delete()
self.failUnlessEqual(self.closure_model.objects.count(), 2)

class DirtyParentTestCase(TransactionTestCase):

normal_model = TC
closure_model = TCClosure

def setUp(self):
self.a = self.normal_model.objects.create(name="a")
self.b = self.normal_model.objects.create(name="b")
self.c = self.normal_model.objects.create(name="c")
self.d = self.normal_model.objects.create(name="d")

def test_selfreferencing_parent(self):
"""Tests that instances with self-referencing parents are not saved"""
self.b.parent2 = self.a
self.b.save()
self.c.parent2 = self.b
self.c.save()

# this is the prerequisite of this test:
# closure model contains 7 rows before we save dirty data
self.assertEqual(self.closure_model.objects.count(), 7)

# dirty data
self.b.parent2 = self.b
# save dirty data
self.assertRaises(IntegrityError, self.b.save)

# closure model contains the same number of data as before
self.assertEqual(self.closure_model.objects.count(), 7)

def test_parent_referencing_child(self):
"""Tests instances with circular-referencing parents are not saved"""
self.b.parent2 = self.a
self.b.save()
self.c.parent2 = self.b
self.c.save()

# this is the prerequisite of this test:
# closure model contains 7 rows before we save dirty data
self.assertEqual(self.closure_model.objects.count(), 7)

# dirty data
self.b.parent2 = self.c
# save dirty data
self.assertRaises(IntegrityError, self.b.save)

# closure model contains the same number of data as before
self.assertEqual(self.closure_model.objects.count(), 7)

if VERSION >= (1, 8):
class UUIDTC(ClosureModel):
Expand Down

0 comments on commit 3cb162d

Please sign in to comment.