Skip to content

Commit

Permalink
Merge 3cb162d into d6e4255
Browse files Browse the repository at this point in the history
  • Loading branch information
ykiu committed Nov 14, 2017
2 parents d6e4255 + 3cb162d commit e5dafc1
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 2 deletions.
13 changes: 12 additions & 1 deletion closuretree/models.py
Expand Up @@ -26,7 +26,7 @@
# Public methods are useful!
# pylint: disable=R0904

from django.db import models
from django.db import models, transaction
from django.db.models.base import ModelBase
from django.db.models.signals import post_save, pre_delete
from django.dispatch import receiver
Expand Down Expand Up @@ -116,6 +116,17 @@ def __setattr__(self, name, value):
self._closure_change_init()
super(ClosureModel, self).__setattr__(name, value)

def save_base(self, *args, **kwargs):
"""Wrap the saving operation of this model and that of the closure
table in one transaction.
"""
# the superclass save_base() sends post_save() signal.
# post_save() is then received by closure_model_save() function,
# which saves closure model.
# we're going to wrap this series of operations in a transaction.
with transaction.atomic():
super(ClosureModel, self).save_base(*args, **kwargs)

@classmethod
def _toplevel(cls):
"""Find the top level of the chain we're in.
Expand Down
51 changes: 50 additions & 1 deletion closuretree/tests.py
Expand Up @@ -24,8 +24,9 @@
# pylint: disable=

from django import VERSION
from django.test import TestCase
from django.test import TestCase, TransactionTestCase
from django.db import models
from django.db.utils import IntegrityError
from closuretree.models import ClosureModel
import uuid

Expand Down Expand Up @@ -118,6 +119,54 @@ def test_deletion(self):
self.b.delete()
self.failUnlessEqual(self.closure_model.objects.count(), 2)

class DirtyParentTestCase(TransactionTestCase):

normal_model = TC
closure_model = TCClosure

def setUp(self):
self.a = self.normal_model.objects.create(name="a")
self.b = self.normal_model.objects.create(name="b")
self.c = self.normal_model.objects.create(name="c")
self.d = self.normal_model.objects.create(name="d")

def test_selfreferencing_parent(self):
"""Tests that instances with self-referencing parents are not saved"""
self.b.parent2 = self.a
self.b.save()
self.c.parent2 = self.b
self.c.save()

# this is the prerequisite of this test:
# closure model contains 7 rows before we save dirty data
self.assertEqual(self.closure_model.objects.count(), 7)

# dirty data
self.b.parent2 = self.b
# save dirty data
self.assertRaises(IntegrityError, self.b.save)

# closure model contains the same number of data as before
self.assertEqual(self.closure_model.objects.count(), 7)

def test_parent_referencing_child(self):
"""Tests instances with circular-referencing parents are not saved"""
self.b.parent2 = self.a
self.b.save()
self.c.parent2 = self.b
self.c.save()

# this is the prerequisite of this test:
# closure model contains 7 rows before we save dirty data
self.assertEqual(self.closure_model.objects.count(), 7)

# dirty data
self.b.parent2 = self.c
# save dirty data
self.assertRaises(IntegrityError, self.b.save)

# closure model contains the same number of data as before
self.assertEqual(self.closure_model.objects.count(), 7)

if VERSION >= (1, 8):
class UUIDTC(ClosureModel):
Expand Down

0 comments on commit e5dafc1

Please sign in to comment.