Skip to content

Commit

Permalink
Add trigger based versioning for PostgreSQL
Browse files Browse the repository at this point in the history
  • Loading branch information
kvesteri committed Aug 14, 2014
1 parent bc984fd commit 838a764
Show file tree
Hide file tree
Showing 18 changed files with 346 additions and 30 deletions.
2 changes: 2 additions & 0 deletions .travis.yml
Expand Up @@ -4,10 +4,12 @@ addons:
env:
- DB=mysql
- DB=postgres
- DB=postgres-native
- DB=sqlite

before_script:
- sh -c "if [ '$DB' = 'postgres' ]; then psql -c 'create database sqlalchemy_continuum_test;' -U postgres; fi"
- sh -c "if [ '$DB' = 'postgres-native' ]; then psql -c 'create database sqlalchemy_continuum_test;' -U postgres; fi"
- sh -c "if [ '$DB' = 'mysql' ]; then mysql -e 'create database sqlalchemy_continuum_test;'; fi"

language: python
Expand Down
6 changes: 6 additions & 0 deletions CHANGES.rst
Expand Up @@ -4,6 +4,12 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Continuum release.


1.1.0 (2014-xx-xx)
^^^^^^^^^^^^^^^^^^

- Add optional native trigger based versioning for PostgreSQL dialect


1.0.3 (2014-07-16)
^^^^^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions README.rst
Expand Up @@ -16,6 +16,7 @@ Features
- Transactions can be queried afterwards using SQLAlchemy query syntax
- Query for changed records at given transaction
- Temporal relationship reflection. Version object's relationship show the parent objects relationships as they where in that point in time.
- Supports native versioning for PostgreSQL database (trigger based versioning)


QuickStart
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -49,7 +49,7 @@ def run(self):

setup(
name='SQLAlchemy-Continuum',
version='1.0.3',
version='1.1.0',
url='https://github.com/kvesteri/sqlalchemy-continuum',
license='BSD',
author='Konsta Vesterinen',
Expand Down
3 changes: 2 additions & 1 deletion sqlalchemy_continuum/__init__.py
Expand Up @@ -8,6 +8,7 @@
changeset,
get_versioning_manager,
is_modified,
is_session_modified,
parent_class,
transaction_class,
tx_column_name,
Expand All @@ -16,7 +17,7 @@
)


__version__ = '1.0.3'
__version__ = '1.1.0'


versioning_manager = VersioningManager()
Expand Down
244 changes: 243 additions & 1 deletion sqlalchemy_continuum/builder.py
Expand Up @@ -4,20 +4,262 @@
from sqlalchemy_utils.functions import get_declarative_base

from .model_builder import ModelBuilder
from .plugins import PropertyModTrackerPlugin
from .relationship_builder import RelationshipBuilder
from .table_builder import TableBuilder
from .table_builder import TableBuilder, ColumnReflector


trigger_sql = """
CREATE TRIGGER {trigger_name}
AFTER INSERT OR UPDATE OR DELETE ON {table_name}
FOR EACH ROW EXECUTE PROCEDURE {procedure_name}()
"""

upsert_cte_sql = """
WITH upsert as
(
UPDATE {version_table_name}
SET {update_columns}
WHERE
{transaction_column} = txid_current() AND
{primary_key_condition}
RETURNING *
)
INSERT INTO {version_table_name}
({transaction_column}, {operation_type_column}, {columns})
SELECT * FROM
(VALUES (txid_current(), {operation_type}, {values})) AS columns
WHERE NOT EXISTS (SELECT 1 FROM upsert);
"""


procedure_sql = """
CREATE OR REPLACE FUNCTION {procedure_name}() RETURNS TRIGGER AS $$
BEGIN
IF (TG_OP = 'INSERT') THEN
{after_insert}
{upsert_insert}
ELSIF (TG_OP = 'UPDATE') THEN
{after_update}
{upsert_update}
ELSIF (TG_OP = 'DELETE') THEN
{after_delete}
{upsert_delete}
END IF;
RETURN NEW;
END;
$$
LANGUAGE plpgsql
"""

validity_sql = """
UPDATE {version_table_name} SET {end_transaction_column} = txid_current()
WHERE
{transaction_column} = (
SELECT MIN({transaction_column}) FROM {version_table_name}
WHERE {end_transaction_column} IS NULL AND {primary_key_condition}
) AND
{primary_key_condition};
"""


def uses_property_mod_tracking(manager):
return any(
isinstance(plugin, PropertyModTrackerPlugin)
for plugin in manager.plugins
)


class Builder(object):
def trigger_ddl(self, cls):
table = cls.__table__
if table.schema:
table_name = '%s."%s"' % (table.schema, table.name)
else:
table_name = '"' + table.name + '"'
return sa.schema.DDL(
trigger_sql.format(
trigger_name='%s_trigger' % table.name,
table_name=table_name,
procedure_name='%s_audit' % table.name
)
)

def upsert_sql(self, cls, operation_type):
table = cls.__table__
version_table_name = (
self.manager.option(cls, 'table_name') % table.name
)
if table.schema:
version_table_name = '%s.%s' % (table.schema, version_table_name)

reflector = ColumnReflector(self.manager, table, cls)
columns = list(reflector.reflected_parent_columns)
columns_without_pks = [c for c in columns if not c.primary_key]
operation_type_column = self.manager.option(
cls, 'operation_type_column_name'
)
column_names = [c.name for c in columns]
if uses_property_mod_tracking(self.manager):
column_names += [
'%s_mod' % c.name for c in columns_without_pks
]

if operation_type == 2:
values = ', '.join('OLD.%s' % c.name for c in columns)
primary_key_condition = ' AND '.join(
'{name} = OLD.{name}'.format(name=c.name)
for c in columns if c.primary_key
)
update_columns = ', '.join(
'{name} = OLD.{name}'.format(name=c.name)
for c in columns
)
else:
values = ['NEW.%s' % c.name for c in columns]
if uses_property_mod_tracking(self.manager):
if operation_type == 1:
values += [
'NOT ((OLD.{0} IS NULL AND NEW.{0} IS NULL) '
'OR (OLD.{0} = NEW.{0}))'.format(c.name)
for c in columns_without_pks
]
else:
values += ['True'] * len(columns_without_pks)
values = ', '.join(values)

primary_key_condition = ' AND '.join(
'{name} = NEW.{name}'.format(name=c.name)
for c in columns if c.primary_key
)
parent_columns = tuple(
'{name} = NEW.{name}'.format(name=c.name)
for c in columns
)
mod_columns = tuple()
if uses_property_mod_tracking(self.manager):
mod_columns = tuple(
'{0}_mod = NOT ((OLD.{0} IS NULL AND NEW.{0} IS NULL) '
'OR (OLD.{0} = NEW.{0}))'.format(c.name)
for c in columns_without_pks
)

update_columns = ', '.join(
('%s = 1' % operation_type_column, ) +
parent_columns +
mod_columns
)

return upsert_cte_sql.format(
version_table_name=version_table_name,
transaction_column=self.manager.option(
cls, 'transaction_column_name'
),
operation_type_column=operation_type_column,
columns=', '.join(column_names),
values=values,
update_columns=update_columns,
operation_type=operation_type,
primary_key_condition=primary_key_condition
)

def get_version_table_name(self, cls, table):
version_table_name = (
self.manager.option(cls, 'table_name') % table.name
)
if table.schema:
version_table_name = '%s.%s' % (table.schema, version_table_name)
return version_table_name

def trigger_function_ddl(self, cls):
table = cls.__table__
reflector = ColumnReflector(self.manager, table, cls)
columns = list(reflector.reflected_parent_columns)

update_primary_key_condition = ' AND '.join(
'{name} = NEW.{name}'.format(name=c.name)
for c in columns if c.primary_key
)
delete_primary_key_condition = ' AND '.join(
'{name} = OLD.{name}'.format(name=c.name)
for c in columns if c.primary_key
)
after_delete = ''
after_insert = ''
after_update = ''

if self.manager.option(cls, 'strategy') == 'validity':
for table in sa.inspect(cls).tables:
version_table_name = self.get_version_table_name(cls, table)
sql = validity_sql.format(
version_table_name=version_table_name,
end_transaction_column=self.manager.option(
cls, 'end_transaction_column_name'
),
transaction_column=self.manager.option(
cls, 'transaction_column_name'
),
primary_key_condition=update_primary_key_condition
)
after_insert += sql
after_update += sql
after_delete += validity_sql.format(
version_table_name=version_table_name,
end_transaction_column=self.manager.option(
cls, 'end_transaction_column_name'
),
transaction_column=self.manager.option(
cls, 'transaction_column_name'
),
primary_key_condition=delete_primary_key_condition
)

sql = procedure_sql.format(
procedure_name='%s_audit' % table.name,
after_insert=after_insert,
after_update=after_update,
after_delete=after_delete,
upsert_insert=self.upsert_sql(cls, 0),
upsert_update=self.upsert_sql(cls, 1),
upsert_delete=self.upsert_sql(cls, 2)
)
return sa.schema.DDL(sql)

def add_native_versioning_triggers(self, cls):
sa.event.listen(
cls.__table__,
'after_create',
self.trigger_function_ddl(cls)
)
sa.event.listen(
cls.__table__,
'after_create',
self.trigger_ddl(cls)
)
sa.event.listen(
cls.__table__,
'after_drop',
sa.schema.DDL(
'DROP FUNCTION IF EXISTS %s()' %
'%s_audit' % cls.__table__.name,
)
)

def build_tables(self):
"""
Build tables for version models based on classes that were collected
during class instrumentation process.
"""
processed_tables = set()
for cls in self.manager.pending_classes:
if not self.manager.option(cls, 'versioning'):
continue

if self.manager.options['native_versioning']:
if cls.__table__ not in processed_tables:
self.add_native_versioning_triggers(cls)
processed_tables.add(cls.__table__)

inherited_table = None
for class_ in self.manager.tables:
if (issubclass(cls, class_) and
Expand Down
1 change: 1 addition & 0 deletions sqlalchemy_continuum/manager.py
Expand Up @@ -82,6 +82,7 @@ def __init__(
'table_name': '%s_version',
'exclude': [],
'include': [],
'native_versioning': False,
'transaction_column_name': 'transaction_id',
'end_transaction_column_name': 'end_transaction_id',
'operation_type_column_name': 'operation_type',
Expand Down
7 changes: 6 additions & 1 deletion sqlalchemy_continuum/table_builder.py
Expand Up @@ -80,7 +80,8 @@ def end_transaction_column(self):
index=True
)

def __iter__(self):
@property
def reflected_parent_columns(self):
for column in self.parent_table.c:
if (
self.model and
Expand All @@ -91,6 +92,10 @@ def __iter__(self):
reflected_column = self.reflect_column(column)
yield reflected_column

def __iter__(self):
for column in self.reflected_parent_columns:
yield column

# Only yield internal version columns if parent model is not using
# single table inheritance
if not self.model or not sa.inspect(self.model).single:
Expand Down
14 changes: 13 additions & 1 deletion sqlalchemy_continuum/transaction.py
Expand Up @@ -18,7 +18,6 @@ def compile_big_integer(element, compiler, **kw):


class TransactionBase(object):
id = sa.Column(sa.types.BigInteger, primary_key=True, autoincrement=True)
issued_at = sa.Column(sa.DateTime, default=datetime.utcnow)

@property
Expand Down Expand Up @@ -73,6 +72,19 @@ class Transaction(
__tablename__ = 'transaction'
__versioning_manager__ = manager

if manager.options['native_versioning']:
id = sa.Column(
sa.types.BigInteger,
primary_key=True,
autoincrement=False
)
else:
id = sa.Column(
sa.types.BigInteger,
primary_key=True,
autoincrement=True
)

if self.remote_addr:
remote_addr = sa.Column(sa.String(50))

Expand Down

0 comments on commit 838a764

Please sign in to comment.