Skip to content

Commit

Permalink
Merge d7fe839 into 00e4473
Browse files Browse the repository at this point in the history
  • Loading branch information
karenc committed Aug 16, 2017
2 parents 00e4473 + d7fe839 commit 04c441e
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 42 deletions.
11 changes: 11 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,17 @@ To write a repeat migration, make sure your migration has ``should_run`` defined
The above migration will run **every time** ``migrate`` is called, except if it
is marked as "deferred". ``up`` is run if ``should_run`` returns True.

To write a deferrable migration, add ``@deferrable`` to the up function::

from dbmigrator import deferrable


@deferrable
def up(cursor):
# this migration is automatically deferred

The above migration will not run unless you use ``migrate --run-defers``.

rollback
--------

Expand Down
4 changes: 2 additions & 2 deletions dbmigrator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

from .utils import logger
from .utils import logger, deferrable


__all__ = ('logger',)
__all__ = ('logger', 'deferrable')
14 changes: 8 additions & 6 deletions dbmigrator/commands/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ def cli_command(cursor, migrations_directory='', db_connection_string='',
migrated_versions = dict(list(
utils.get_schema_versions(cursor, versions_only=False,
raise_error=False)))
migrations = utils.get_migrations(migrations_directory)
migrations = utils.get_migrations(migrations_directory,
import_modules=True)

if wide:
migrations = list(migrations)
name_width = max([len(name) for _, name in migrations])
name_width = max([len(name) for _, name, _ in migrations])
else:
name_width = 15

Expand All @@ -35,14 +36,15 @@ def cli_command(cursor, migrations_directory='', db_connection_string='',
print('-' * 70)

name_format = '{: <%s}' % (name_width,)
for version, migration_name in migrations:
for version, migration_name, migration in migrations:
applied_timestamp = migrated_versions.get(version, '')
deferred = applied_timestamp is None
deferred = utils.is_deferred(
version, migration, migrated_versions)
is_applied = deferred and 'deferred' or \
bool(version in migrated_versions)
bool(migrated_versions.get(version))
print('{} {} {!s: <10} {}'.format(
version, name_format.format(migration_name[:name_width]),
is_applied, migrated_versions.get(version, '')))
is_applied, migrated_versions.get(version) or ''))


def cli_loader(parser):
Expand Down
23 changes: 11 additions & 12 deletions dbmigrator/commands/mark.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,17 @@ def cli_command(cursor, migrations_directory='', migration_timestamp='',
if completed is None:
raise Exception('-t, -f or -d must be supplied.')

migrations = utils.get_migrations(
migrations_directory, import_modules=False)
for version, _ in migrations:
if version == migration_timestamp:
break
else:
migrated_versions = list(utils.get_schema_versions(cursor))
if migration_timestamp not in migrated_versions:
logger.warning(
'Migration {} not found'.format(migration_timestamp))

utils.mark_migration(cursor, migration_timestamp, completed)
migrations = {version: migration
for version, _, migration in utils.get_migrations(
migrations_directory, import_modules=True)}
migration = migrations.get(migration_timestamp)
if migration is None:
raise SystemExit(
'Migration {} not found'.format(migration_timestamp))

utils.mark_migration(
cursor, migration_timestamp, completed,
deferrable=hasattr(migration.up, 'dbmigrator_deferrable'))
if not completed:
message = 'not been run'
elif completed == 'deferred':
Expand Down
11 changes: 8 additions & 3 deletions dbmigrator/commands/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,22 @@

@utils.with_cursor
def cli_command(cursor, migrations_directory='', version='',
db_connection_string='', **kwargs):
db_connection_string='', run_defers=False, **kwargs):
pending_migrations = utils.get_pending_migrations(
migrations_directory, cursor, import_modules=True,
up_to_version=version)
up_to_version=version, include_defers=True)

migrated = False
for version, migration_name, migration in pending_migrations:
print('pending_migration: {} {}'.format(version, migration_name))
migrated = True
utils.compare_schema(db_connection_string,
utils.run_migration,
cursor,
version,
migration_name,
migration)
migration,
run_defers)

if not migrated:
print('No pending migrations. Database is up to date.')
Expand All @@ -37,4 +39,7 @@ def cli_command(cursor, migrations_directory='', version='',
def cli_loader(parser):
parser.add_argument('--version',
help='Migrate database up to this version')
parser.add_argument('--run-defers',
action='store_true',
help='Also run the deferred migrations')
return cli_command
3 changes: 3 additions & 0 deletions dbmigrator/tests/data/md/20170810124056_empty.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# -*- coding: utf-8 -*-

from dbmigrator import deferrable


# Uncomment should_run if this is a repeat migration
# def should_run(cursor):
# # TODO return True if migration should run


@deferrable
def up(cursor):
# TODO migration code
pass
Expand Down
117 changes: 106 additions & 11 deletions dbmigrator/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,24 +217,26 @@ def test_mark_as_true_already_true(self):
self.assertIn('20160228202637 add_table True', stdout)
self.assertIn('20160228212456 cool_stuff True', stdout)

@mock.patch('dbmigrator.logger.warning')
def test_migration_not_found(self, warning):
def test_migration_not_found(self):
testing.install_test_packages()
cmd = ['--db-connection-string', testing.db_connection_string]

self.target(cmd + ['--context', 'package-a', 'init', '--version', '0'])

self.target(cmd + ['mark', '-t', '012345'])
warning.assert_called_with('Migration 012345 not found')
with self.assertRaises(SystemExit) as cm:
self.target(cmd + ['mark', '-t', '012345'])
self.assertEqual('Migration 012345 not found', str(cm.exception))

self.target(cmd + ['mark', '-f', '012345'])
warning.assert_called_with('Migration 012345 not found')
with self.assertRaises(SystemExit) as cm:
self.target(cmd + ['mark', '-f', '012345'])
self.assertEqual('Migration 012345 not found', str(cm.exception))

def test_mark_as_false(self):
testing.install_test_packages()
cmd = ['--db-connection-string', testing.db_connection_string]
cmd = ['--db-connection-string', testing.db_connection_string,
'--context', 'package-a']

self.target(cmd + ['--context', 'package-a', 'init'])
self.target(cmd + ['init'])

with testing.captured_output() as (out, err):
self.target(cmd + ['mark', '-f', '20160228202637'])
Expand All @@ -244,7 +246,7 @@ def test_mark_as_false(self):
stdout)

with testing.captured_output() as (out, err):
self.target(cmd + ['--context', 'package-a', 'list'])
self.target(cmd + ['list'])

stdout = out.getvalue()
self.assertIn('20160228202637 add_table False', stdout)
Expand Down Expand Up @@ -325,6 +327,48 @@ def test_mark_as_deferred(self):
self.assertIn('20160228202637 add_table deferred', stdout)
self.assertIn('20160228212456 cool_stuff False', stdout)

def test_deferrable(self):
md = os.path.join(testing.test_data_path, 'md')
cmd = ['--db-connection-string', testing.db_connection_string,
'--migrations-directory', md]

self.target(cmd + ['init', '--version', '0'])

# check list output
with testing.captured_output() as (out, err):
self.target(cmd + ['list'])

stdout = out.getvalue()
self.assertIn('20170810124056 empty deferred', stdout)

# mark a deferrable migration as not deferred
with testing.captured_output() as (out, err):
self.target(cmd + ['mark', '-f', '20170810124056'])

stdout = out.getvalue()
self.assertEqual('Migration 20170810124056 marked as not been run\n',
stdout)

with testing.captured_output() as (out, err):
self.target(cmd + ['list'])

stdout = out.getvalue()
self.assertIn('20170810124056 empty False', stdout)

# mark a deferrable migration as completed
with testing.captured_output() as (out, err):
self.target(cmd + ['mark', '-t', '20170810124056'])

stdout = out.getvalue()
self.assertEqual('Migration 20170810124056 marked as completed\n',
stdout)

with testing.captured_output() as (out, err):
self.target(cmd + ['list'])

stdout = out.getvalue()
self.assertIn('20170810124056 empty True', stdout)


class GenerateTestCase(BaseTestCase):
@mock.patch('dbmigrator.utils.timestamp')
Expand Down Expand Up @@ -477,7 +521,58 @@ def cleanup():
self.target(cmd + ['migrate'])

stdout = out.getvalue()
self.assertIn('No pending migrations', stdout)
self.assertIn('Skipping deferred migration 20170810124056 empty',
stdout)

def test_deferrable(self):
md = os.path.join(testing.test_data_path, 'md')
cmd = ['--db-connection-string', testing.db_connection_string,
'--migrations-directory', md]

def cleanup():
if os.path.exists('insert_data.txt'):
os.remove('insert_data.txt')
with psycopg2.connect(testing.db_connection_string) as db_conn:
with db_conn.cursor() as cursor:
cursor.execute('DROP TABLE IF EXISTS a_table')

self.addCleanup(cleanup)

self.target(cmd + ['init', '--version', '0'])
with testing.captured_output() as (out, err):
self.target(cmd + ['migrate'])

stdout = out.getvalue()
self.assertIn('+CREATE TABLE a_table', stdout)
self.assertIn('Skipping deferred migration 20170810124056 empty',
stdout)

# Run the repeat migration by creating this file
with open('insert_data.txt', 'w') as f:
f.write('三好')

with testing.captured_output() as (out, err):
self.target(cmd + ['migrate'])

stdout = out.getvalue()
self.assertIn('Running migration 20170810093943', stdout)

# Mark the deferrable as not been run
self.target(cmd + ['mark', '-f', '20170810124056'])

with testing.captured_output() as (out, err):
self.target(cmd + ['migrate'])
stdout = out.getvalue()
self.assertIn('Running migration 20170810124056', stdout)

# Mark the deferable as deferred
self.target(cmd + ['mark', '-d', '20170810124056'])

with testing.captured_output() as (out, err):
self.target(cmd + ['migrate', '--run-defers'])

stdout = out.getvalue()
self.assertIn('Running migration 20170810124056', stdout)


class RollbackTestCase(BaseTestCase):
Expand Down Expand Up @@ -568,7 +663,7 @@ def cleanup():

self.target(cmd + ['init', '--version', '0'])
with testing.captured_output() as (out, err):
self.target(cmd + ['migrate'])
self.target(cmd + ['migrate', '--run-defers'])

# Run the repeat migration by creating this file
with open('insert_data.txt', 'w') as f:
Expand Down

0 comments on commit 04c441e

Please sign in to comment.