Skip to content

Commit

Permalink
Merge pull request #116 from tvuotila/hotfix/detect-changed-relations…
Browse files Browse the repository at this point in the history
…hips-correctly

Fix some relationship changes not counted as modifications
  • Loading branch information
kvesteri committed Dec 8, 2015
2 parents c2cd53a + 224311a commit 596da3e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
7 changes: 4 additions & 3 deletions sqlalchemy_continuum/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,15 @@ def versioned_column_properties(obj_or_class):
yield getattr(mapper.attrs, key)


def versioned_relationships(obj):
def versioned_relationships(obj, versioned_column_keys):
"""
Return all versioned relationships for given versioned SQLAlchemy
declarative model object.
:param obj: SQLAlchemy declarative model object
"""
for prop in sa.inspect(obj.__class__).relationships:
if is_versioned(prop.mapper.class_):
if any(c.key in versioned_column_keys for c in prop.local_columns):
yield prop


Expand Down Expand Up @@ -310,7 +310,8 @@ def is_modified(obj):
prop.key for prop in versioned_column_properties(obj)
]
versioned_relationship_keys = [
prop.key for prop in versioned_relationships(obj)
prop.key
for prop in versioned_relationships(obj, versioned_column_keys)
]
for key, attr in sa.inspect(obj).attrs.items():
if key in column_names:
Expand Down
14 changes: 14 additions & 0 deletions tests/relationships/test_non_versioned_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ def test_single_insert(self):

assert isinstance(article.versions[0].author, self.User)

def test_change_relationship(self):
article = self.Article()
article.name = u'Some article'
article.content = u'Some content'
user = self.User(name=u'Some user')
self.session.add(article)
self.session.add(user)
self.session.commit()

assert article.versions.count() == 1
article.author = user
self.session.commit()
assert article.versions.count() == 2


class TestManyToManyRelationshipToNonVersionedClass(TestCase):
def create_models(self):
Expand Down

0 comments on commit 596da3e

Please sign in to comment.