Skip to content

Commit

Permalink
Add repeat conditional migrations
Browse files Browse the repository at this point in the history
If the migration file has `should_run` defined, it is considered a
repeat migration.  `should_run` is called every time `migrate` is run,
and if it returns true, `up` is run.
  • Loading branch information
karenc committed Aug 16, 2017
1 parent 96665ef commit 00e4473
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 24 deletions.
24 changes: 24 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,16 @@ Example usage::
generates a file called ``migrations/20151217170514_add_id_to_users.py``
with content::

# Uncomment should_run if this is a repeat migration
# def should_run(cursor):
# # TODO return True if migration should run


def up(cursor):
# TODO migration code
pass


def down(cursor):
# TODO rollback code
pass
Expand Down Expand Up @@ -188,6 +194,24 @@ if all migrations have already been run::
$ dbmigrator migrate
No pending migrations. Database is up to date.

To write a repeat migration, make sure your migration has ``should_run`` defined::

def should_run(cursor):
return os.path.exists('data.txt')


def up(cursor):
with open('data.txt') as f:
data = f.read()
cursor.execute('INSERT INTO table VALUES (%s)', (data,))


def down(cursor):
pass

The above migration will run **every time** ``migrate`` is called, except if it
is marked as "deferred". ``up`` is run if ``should_run`` returns True.

rollback
--------

Expand Down
5 changes: 5 additions & 0 deletions dbmigrator/commands/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def cli_command(migration_name='', **kwargs):
# -*- coding: utf-8 -*-
# Uncomment should_run if this is a repeat migration
# def should_run(cursor):
# # TODO return True if migration should run
def up(cursor):
# TODO migration code
pass
Expand Down
31 changes: 16 additions & 15 deletions dbmigrator/commands/rollback.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,29 @@
def cli_command(cursor, migrations_directory='', steps=1,
db_connection_string='', **kwargs):
migrated_versions = list(utils.get_schema_versions(
cursor, include_deferred=False))
cursor, include_deferred=False, order_by='applied'))
logger.debug('migrated_versions: {}'.format(migrated_versions))
if not migrated_versions:
print('No migrations to roll back.')
return
migrations = utils.get_migrations(
migrations_directory, import_modules=True, reverse=True)
migrations = {version: (name, migration)
for version, name, migration in utils.get_migrations(
migrations_directory, import_modules=True,
reverse=True)}

rolled_back = 0
for version, migration_name, migration in migrations:
if not migrated_versions:
for version in reversed(migrated_versions):
if version not in migrations:
print('Migration {} not found.'.format(version))
break
last_version = migrated_versions[-1]
if version == last_version:
utils.compare_schema(db_connection_string,
utils.rollback_migration,
cursor,
version,
migration_name,
migration)
rolled_back += 1
migrated_versions.pop()
migration_name, migration = migrations[version]
utils.compare_schema(db_connection_string,
utils.rollback_migration,
cursor,
version,
migration_name,
migration)
rolled_back += 1
if rolled_back >= steps:
break

Expand Down
14 changes: 14 additions & 0 deletions dbmigrator/tests/data/md/20170810093842_create_a_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-


# Uncomment should_run if this is a repeat migration
# def should_run(cursor):
# # TODO return True if migration should run


def up(cursor):
cursor.execute('CREATE TABLE a_table (name TEXT)')


def down(cursor):
cursor.execute('DROP TABLE a_table')
19 changes: 19 additions & 0 deletions dbmigrator/tests/data/md/20170810093943_repeat_insert_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# -*- coding: utf-8 -*-

import os


def should_run(cursor):
return os.path.exists('insert_data.txt')


def up(cursor):
with open('insert_data.txt') as f:
data = f.read()
cursor.execute('INSERT INTO a_table VALUES (%s)', (data,))


def down(cursor):
with open('insert_data.txt') as f:
data = f.read()
cursor.execute('DELETE FROM a_table WHERE name = %s', (data,))
16 changes: 16 additions & 0 deletions dbmigrator/tests/data/md/20170810124056_empty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-


# Uncomment should_run if this is a repeat migration
# def should_run(cursor):
# # TODO return True if migration should run


def up(cursor):
# TODO migration code
pass


def down(cursor):
# TODO rollback code
pass
109 changes: 109 additions & 0 deletions dbmigrator/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,61 @@ def cleanup():
WHERE table_name = 'a_table'""")
self.assertEqual([('a_table',)], cursor.fetchall())

def test_repeat(self):
md = os.path.join(testing.test_data_path, 'md')
cmd = ['--db-connection-string', testing.db_connection_string,
'--migrations-directory', md]

def cleanup():
if os.path.exists('insert_data.txt'):
os.remove('insert_data.txt')
with psycopg2.connect(testing.db_connection_string) as db_conn:
with db_conn.cursor() as cursor:
cursor.execute('DROP TABLE IF EXISTS a_table')

self.addCleanup(cleanup)

self.target(cmd + ['init', '--version', '0'])
with testing.captured_output() as (out, err):
self.target(cmd + ['migrate'])

stdout = out.getvalue()
self.assertIn('+CREATE TABLE a_table', stdout)
self.assertIn('Skipping migration 20170810093943', stdout)

# Run the repeat migration by creating this file
with open('insert_data.txt', 'w') as f:
f.write('三好')

with testing.captured_output() as (out, err):
self.target(cmd + ['migrate'])

stdout = out.getvalue()
self.assertIn('Running migration 20170810093943', stdout)

with testing.captured_output() as (out, err):
self.target(cmd + ['migrate'])

stdout = out.getvalue()
self.assertIn('Running migration 20170810093943', stdout)

# Make sure data has been inserted twice
with psycopg2.connect(testing.db_connection_string) as db_conn:
with db_conn.cursor() as cursor:
cursor.execute('SELECT name FROM a_table')
self.assertEqual([('三好',), ('三好',)],
cursor.fetchall())

# Mark the repeat migration as deferred
self.target(cmd + ['mark', '-d', '20170810093943'])

# The deferred repeat migration should not run
with testing.captured_output() as (out, err):
self.target(cmd + ['migrate'])

stdout = out.getvalue()
self.assertIn('No pending migrations', stdout)


class RollbackTestCase(BaseTestCase):
@mock.patch('dbmigrator.utils.timestamp')
Expand Down Expand Up @@ -486,9 +541,63 @@ def cleanup():
# Rollback three migrations
with testing.captured_output() as (out, err):
self.target(cmd + ['rollback', '--steps', '3'])
stdout = out.getvalue()
self.assertIn('Rolling back migration 20160228212456', stdout)
self.assertIn('Rolling back migration 20160228202637', stdout)

with psycopg2.connect(testing.db_connection_string) as db_conn:
with db_conn.cursor() as cursor:
cursor.execute("""\
SELECT table_name FROM information_schema.tables
WHERE table_name = 'a_table'""")
self.assertEqual(None, cursor.fetchone())

def test_repeat(self):
md = os.path.join(testing.test_data_path, 'md')
cmd = ['--db-connection-string', testing.db_connection_string,
'--migrations-directory', md]

def cleanup():
if os.path.exists('insert_data.txt'):
os.remove('insert_data.txt')
with psycopg2.connect(testing.db_connection_string) as db_conn:
with db_conn.cursor() as cursor:
cursor.execute('DROP TABLE IF EXISTS a_table')

self.addCleanup(cleanup)

self.target(cmd + ['init', '--version', '0'])
with testing.captured_output() as (out, err):
self.target(cmd + ['migrate'])

# Run the repeat migration by creating this file
with open('insert_data.txt', 'w') as f:
f.write('三好')

with testing.captured_output() as (out, err):
self.target(cmd + ['migrate'])

with open('insert_data.txt', 'w') as f:
f.write('ジョーカーゲーム')

with testing.captured_output() as (out, err):
self.target(cmd + ['migrate'])

# Version wise, the empty migration is more recent, but the repeat
# migration is the last migration applied, so rollback should rollback
# the repeat migration.
with testing.captured_output() as (out, err):
self.target(cmd + ['rollback'])
stdout = out.getvalue()
self.assertIn('Rolling back migration 20170810093943', stdout)

with psycopg2.connect(testing.db_connection_string) as db_conn:
with db_conn.cursor() as cursor:
cursor.execute('SELECT name FROM a_table')
self.assertEqual([('三好',)], cursor.fetchall())

# Next, rollback the empty migration
with testing.captured_output() as (out, err):
self.target(cmd + ['rollback'])
stdout = out.getvalue()
self.assertIn('Rolling back migration 20170810124056', stdout)
36 changes: 27 additions & 9 deletions dbmigrator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,10 @@ def get_migrations(migration_directories, import_modules=False, reverse=False):


def get_schema_versions(cursor, versions_only=True, raise_error=True,
include_deferred=True):
include_deferred=True, order_by='version'):
try:
cursor.execute('SELECT * FROM schema_migrations ORDER BY version')
cursor.execute('SELECT * FROM schema_migrations ORDER BY {}'
.format(order_by))
for i in cursor.fetchall():
if not include_deferred and i[1] is None:
continue
Expand All @@ -164,20 +165,29 @@ def get_schema_versions(cursor, versions_only=True, raise_error=True,

def get_pending_migrations(migration_directories, cursor, import_modules=False,
up_to_version=None):
migrated_versions = list(get_schema_versions(cursor))
if up_to_version and up_to_version in migrated_versions:
raise StopIteration
migrations = list(get_migrations(migration_directories, import_modules))
migrated_versions = {i[0]: i[1] or 'deferred'
for i in get_schema_versions(
cursor, versions_only=False)}

migrations = list(get_migrations(migration_directories,
import_modules=True))
versions = [m[0] for m in migrations]
if up_to_version:
try:
migrations = migrations[:versions.index(up_to_version) + 1]
except ValueError:
raise Exception('Version "{}" not found'.format(up_to_version))

for migration in migrations:
version = migration[0]
if version not in migrated_versions:
yield migration
version, migration_name, mod = migration
if not import_modules:
migration = migration[:-1]
if migrated_versions.get(version) != 'deferred':
if hasattr(mod, 'should_run'):
# repeat migrations are always included
yield migration
elif version not in migrated_versions:
yield migration


def compare_schema(db_connection_string, callback, *args, **kwargs):
Expand All @@ -193,6 +203,14 @@ def compare_schema(db_connection_string, callback, *args, **kwargs):


def run_migration(cursor, version, migration_name, migration):
try:
if not migration.should_run(cursor):
print('Skipping migration {} {}: should_run is false'
.format(version, migration_name))
return
except AttributeError:
# not a repeat migration
pass
print('Running migration {} {}'.format(version, migration_name))
migration.up(cursor)
mark_migration(cursor, version, True)
Expand Down

0 comments on commit 00e4473

Please sign in to comment.