From f2ee5b63b4c2b5ba0db699384eff485a57044ea6 Mon Sep 17 00:00:00 2001 From: Nikolay Shebanov Date: Wed, 6 May 2020 21:28:18 +0200 Subject: [PATCH] Fix #245: create column aliases in the version model --- sqlalchemy_continuum/builder.py | 42 ++++++++++++++++--- sqlalchemy_continuum/model_builder.py | 1 + .../test_single_table_inheritance.py | 8 ++++ 3 files changed, 45 insertions(+), 6 deletions(-) diff --git a/sqlalchemy_continuum/builder.py b/sqlalchemy_continuum/builder.py index 47bbef35..8a2f219c 100644 --- a/sqlalchemy_continuum/builder.py +++ b/sqlalchemy_continuum/builder.py @@ -144,14 +144,16 @@ def build_transaction_class(self): def configure_versioned_classes(self): """ Configures all versioned classes that were collected during - instrumentation process. The configuration has 4 steps: + instrumentation process. The configuration has 6 steps: 1. Build tables for version models. 2. Build the actual version model declarative classes. 3. Build relationships between these models. 4. Empty pending_classes list so that consecutive mapper configuration does not create multiple version classes - 5. Assign all versioned attributes to use active history. + 5. Build aliases for columns. + 6. Assign all versioned attributes to use active history. + """ if not self.manager.options['versioning']: return @@ -168,11 +170,39 @@ def configure_versioned_classes(self): # Create copy of all pending versioned classes so that we can inspect # them later when creating relationships. - pending_copy = copy(self.manager.pending_classes) + pending_classes_copies = copy(self.manager.pending_classes) self.manager.pending_classes = [] - self.build_relationships(pending_copy) + self.build_relationships(pending_classes_copies) + self.enable_active_history(pending_classes_copies) + self.create_column_aliases(pending_classes_copies) - for cls in pending_copy: - # set the "active_history" flag + def enable_active_history(self, version_classes): + """ + Assign all versioned attributes to use active history. + """ + for cls in version_classes: for prop in sa.inspect(cls).iterate_properties: getattr(cls, prop.key).impl.active_history = True + + def create_column_aliases(self, version_classes): + """ + Create aliases for the columns from the original model. + + This, for example, imitates the behavior of @declared_attr columns. + """ + for cls in version_classes: + model_mapper = sa.inspect(cls) + version_class = self.manager.version_class_map.get(cls) + if not version_class: + continue + + version_class_mapper = sa.inspect(version_class) + + for key, column in model_mapper.columns.items(): + if key != column.key: + version_class_column = version_class.__table__.c.get(column.key) + + if version_class_column is None: + continue + + version_class_mapper.add_property(key, sa.orm.column_property(version_class_column)) diff --git a/sqlalchemy_continuum/model_builder.py b/sqlalchemy_continuum/model_builder.py index 2be6e63b..1d29d68c 100644 --- a/sqlalchemy_continuum/model_builder.py +++ b/sqlalchemy_continuum/model_builder.py @@ -261,6 +261,7 @@ def mapper_args(cls): name = '%sVersion' % (self.model.__name__,) return type(name, self.base_classes(), args) + def __call__(self, table, tx_class): """ Build history model and relationships to parent model, transaction diff --git a/tests/inheritance/test_single_table_inheritance.py b/tests/inheritance/test_single_table_inheritance.py index 9b723c15..9a259a54 100644 --- a/tests/inheritance/test_single_table_inheritance.py +++ b/tests/inheritance/test_single_table_inheritance.py @@ -1,4 +1,5 @@ import sqlalchemy as sa +from sqlalchemy.ext.declarative import declared_attr from sqlalchemy_continuum import versioning_manager, version_class from tests import TestCase, create_test_cases @@ -25,6 +26,10 @@ class Article(TextItem): __mapper_args__ = {'polymorphic_identity': u'article'} name = sa.Column(sa.Unicode(255)) + @sa.ext.declarative.declared_attr + def status(cls): + return sa.Column("_status", sa.Unicode(255)) + class BlogPost(TextItem): __mapper_args__ = {'polymorphic_identity': u'blog_post'} title = sa.Column(sa.Unicode(255)) @@ -79,5 +84,8 @@ def test_transaction_changed_entities(self): assert transaction.entity_names == [u'Article'] assert transaction.changed_entities + def test_declared_attr_inheritance(self): + assert self.ArticleVersion.status + create_test_cases(SingleTableInheritanceTestCase)