From 224311a0a05247d9856828917dfe425f6c1d9972 Mon Sep 17 00:00:00 2001 From: Tero Vuotila Date: Tue, 8 Dec 2015 14:23:15 +0200 Subject: [PATCH] Fix some relationship changes not counted as modifications --- sqlalchemy_continuum/utils.py | 7 ++++--- tests/relationships/test_non_versioned_classes.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sqlalchemy_continuum/utils.py b/sqlalchemy_continuum/utils.py index c027a868..372cc317 100644 --- a/sqlalchemy_continuum/utils.py +++ b/sqlalchemy_continuum/utils.py @@ -203,7 +203,7 @@ def versioned_column_properties(obj_or_class): yield getattr(mapper.attrs, key) -def versioned_relationships(obj): +def versioned_relationships(obj, versioned_column_keys): """ Return all versioned relationships for given versioned SQLAlchemy declarative model object. @@ -211,7 +211,7 @@ def versioned_relationships(obj): :param obj: SQLAlchemy declarative model object """ for prop in sa.inspect(obj.__class__).relationships: - if is_versioned(prop.mapper.class_): + if any(c.key in versioned_column_keys for c in prop.local_columns): yield prop @@ -310,7 +310,8 @@ def is_modified(obj): prop.key for prop in versioned_column_properties(obj) ] versioned_relationship_keys = [ - prop.key for prop in versioned_relationships(obj) + prop.key + for prop in versioned_relationships(obj, versioned_column_keys) ] for key, attr in sa.inspect(obj).attrs.items(): if key in column_names: diff --git a/tests/relationships/test_non_versioned_classes.py b/tests/relationships/test_non_versioned_classes.py index 3e4d915b..cf2dad53 100644 --- a/tests/relationships/test_non_versioned_classes.py +++ b/tests/relationships/test_non_versioned_classes.py @@ -36,6 +36,20 @@ def test_single_insert(self): assert isinstance(article.versions[0].author, self.User) + def test_change_relationship(self): + article = self.Article() + article.name = u'Some article' + article.content = u'Some content' + user = self.User(name=u'Some user') + self.session.add(article) + self.session.add(user) + self.session.commit() + + assert article.versions.count() == 1 + article.author = user + self.session.commit() + assert article.versions.count() == 2 + class TestManyToManyRelationshipToNonVersionedClass(TestCase): def create_models(self):