Skip to content

Commit

Permalink
Change FieldTracker current for unsaved model
Browse files Browse the repository at this point in the history
Return None values instead of an empty dict
  • Loading branch information
treyhunner committed May 23, 2013
1 parent d190239 commit d28f386
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 22 deletions.
105 changes: 90 additions & 15 deletions model_utils/tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,15 +691,6 @@ def test_pre_save_has_changed(self):
self.instance.name = 'new age'
self.assertHasChanged(name=True, number=True)

def test_pre_save_changed(self):
self.assertChanged()
self.instance.name = 'new age'
self.assertChanged()
self.instance.number = 8
self.assertChanged()
self.instance.name = ''
self.assertChanged()

def test_pre_save_previous(self):
self.assertPrevious(name=None, number=None)
self.instance.name = 'new age'
Expand All @@ -718,24 +709,33 @@ def setUp(self):
def test_descriptor(self):
self.assertTrue(isinstance(self.tracked_class.tracker, FieldTracker))

def test_pre_save_changed(self):
self.assertChanged(name=None, number=None, id=None)
self.instance.name = 'new age'
self.assertChanged(name=None, number=None, id=None)
self.instance.number = 8
self.assertChanged(name=None, number=None, id=None)
self.instance.name = ''
self.assertChanged(name=None, number=None, id=None)

def test_first_save(self):
self.assertHasChanged(name=True, number=True)
self.assertPrevious(name=None, number=None)
self.assertCurrent(name='', number=None, id=None)
self.assertChanged()
self.assertChanged(name=None, number=None, id=None)
self.instance.name = 'retro'
self.instance.number = 4
self.assertHasChanged(name=True, number=True)
self.assertPrevious(name=None, number=None)
self.assertCurrent(name='retro', number=4, id=None)
self.assertChanged()
self.assertChanged(name=None, number=None, id=None)
# Django 1.4 doesn't have update_fields
if django.VERSION >= (1, 5, 0):
self.instance.save(update_fields=[])
self.assertHasChanged(name=True, number=True)
self.assertPrevious(name=None, number=None)
self.assertCurrent(name='retro', number=4, id=None)
self.assertChanged()
self.assertChanged(name=None, number=None, id=None)
self.assertRaises(ValueError, self.instance.save,
update_fields=['number'])

Expand Down Expand Up @@ -804,6 +804,15 @@ def setUp(self):
self.instance = self.tracked_class()
self.tracker = self.instance.name_tracker

def test_pre_save_changed(self):
self.assertChanged(name=None)
self.instance.name = 'new age'
self.assertChanged(name=None)
self.instance.number = 8
self.assertChanged(name=None)
self.instance.name = ''
self.assertChanged(name=None)

def test_post_save_has_changed(self):
self.update_instance(name='retro', number=4)
self.assertHasChanged(name=False, number=None)
Expand Down Expand Up @@ -855,9 +864,20 @@ def test_pre_save_has_changed(self):
super(FieldTrackedModelMultiTests, self).test_pre_save_has_changed()

def test_pre_save_changed(self):
for tracker in self.trackers:
self.tracker = tracker
super(FieldTrackedModelMultiTests, self).test_pre_save_changed()
self.tracker = self.instance.name_tracker
self.assertChanged(name=None)
self.instance.name = 'new age'
self.assertChanged(name=None)
self.instance.number = 8
self.assertChanged(name=None)
self.instance.name = ''
self.assertChanged(name=None)
self.tracker = self.instance.number_tracker
self.assertChanged(number=None)
self.instance.name = 'new age'
self.assertChanged(number=None)
self.instance.number = 8
self.assertChanged(number=None)

def test_pre_save_previous(self):
for tracker in self.trackers:
Expand Down Expand Up @@ -961,16 +981,71 @@ class ModelTrackerTests(FieldTrackerTests):

tracked_class = ModelTracked

def test_pre_save_changed(self):
self.assertChanged()
self.instance.name = 'new age'
self.assertChanged()
self.instance.number = 8
self.assertChanged()
self.instance.name = ''
self.assertChanged()

def test_first_save(self):
self.assertHasChanged(name=True, number=True)
self.assertPrevious(name=None, number=None)
self.assertCurrent(name='', number=None, id=None)
self.assertChanged()
self.instance.name = 'retro'
self.instance.number = 4
self.assertHasChanged(name=True, number=True)
self.assertPrevious(name=None, number=None)
self.assertCurrent(name='retro', number=4, id=None)
self.assertChanged()
# Django 1.4 doesn't have update_fields
if django.VERSION >= (1, 5, 0):
self.instance.save(update_fields=[])
self.assertHasChanged(name=True, number=True)
self.assertPrevious(name=None, number=None)
self.assertCurrent(name='retro', number=4, id=None)
self.assertChanged()
self.assertRaises(ValueError, self.instance.save,
update_fields=['number'])


class ModelTrackedModelCustomTests(FieldTrackedModelCustomTests):

tracked_class = ModelTrackedNotDefault

def test_pre_save_changed(self):
self.assertChanged()
self.instance.name = 'new age'
self.assertChanged()
self.instance.number = 8
self.assertChanged()
self.instance.name = ''
self.assertChanged()


class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests):

tracked_class = ModelTrackedMultiple

def test_pre_save_changed(self):
self.tracker = self.instance.name_tracker
self.assertChanged()
self.instance.name = 'new age'
self.assertChanged()
self.instance.number = 8
self.assertChanged()
self.instance.name = ''
self.assertChanged()
self.tracker = self.instance.number_tracker
self.assertChanged()
self.instance.name = 'new age'
self.assertChanged()
self.instance.number = 8
self.assertChanged()


class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyTests):

Expand Down
21 changes: 14 additions & 7 deletions model_utils/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def has_changed(self, field):
if not self.instance.pk:
return True
elif field in self.saved_data:
return self.saved_data.get(field) != self.get_field(field)
return self.previous(field) != self.get_field_value(field)
else:
raise FieldError('field "%s" not tracked' % field)

Expand All @@ -40,11 +40,11 @@ def previous(self, field):

def changed(self):
"""Returns dict of fields that changed since save (with old values)"""
if not self.instance.pk:
return {}
saved = self.saved_data.items()
current = self.current()
return dict((k, v) for k, v in saved if v != current[k])
return dict(
(field, self.previous(field))
for field in self.fields
if self.has_changed(field)
)


class FieldTracker(object):
Expand Down Expand Up @@ -97,7 +97,14 @@ def __get__(self, instance, owner):


class ModelInstanceTracker(FieldInstanceTracker):
pass

def changed(self):
"""Returns dict of fields that changed since save (with old values)"""
if not self.instance.pk:
return {}
saved = self.saved_data.items()
current = self.current()
return dict((k, v) for k, v in saved if v != current[k])


class ModelTracker(FieldTracker):
Expand Down

0 comments on commit d28f386

Please sign in to comment.