diff --git a/dbmigrator/tests/test_cli.py b/dbmigrator/tests/test_cli.py index 7624952..086ba73 100644 --- a/dbmigrator/tests/test_cli.py +++ b/dbmigrator/tests/test_cli.py @@ -10,16 +10,22 @@ import unittest import pkg_resources +import psycopg2 -from .testing import captured_output +from . import testing class MainTestCase(unittest.TestCase): + def tearDown(self): + with psycopg2.connect(testing.db_connection_string) as db_conn: + with db_conn.cursor() as cursor: + cursor.execute('DROP TABLE IF EXISTS schema_migrations') + def test_version(self): from ..cli import main version = pkg_resources.get_distribution('db-migrator').version - with captured_output() as (out, err): + with testing.captured_output() as (out, err): self.assertRaises(SystemExit, main, ['-V']) stdout = out.getvalue() @@ -30,3 +36,21 @@ def test_version(self): else: self.assertEqual(stdout, '') self.assertEqual(stderr.strip(), version) + + def test_list_no_migrations_directory(self): + from ..cli import main + + cmd = ['--db-connection-string', testing.db_connection_string] + main(cmd + ['init']) + with testing.captured_output() as (out, err): + main(cmd + ['list']) + + stdout = out.getvalue() + stderr = err.getvalue() + + self.assertEqual(stdout, """\ +name | is applied | date applied +----------------------------------------------------------------------\n""") + self.assertEqual(stderr, """\ +context undefined, using current directory name "['db-migrator']" +migrations directory undefined\n""") diff --git a/dbmigrator/tests/test_utils.py b/dbmigrator/tests/test_utils.py index 8b50dc0..cee6b35 100644 --- a/dbmigrator/tests/test_utils.py +++ b/dbmigrator/tests/test_utils.py @@ -147,6 +147,13 @@ def test_get_migrations(self): ('20160228210326', 'initial_data'), ('20160228212456', 'cool_stuff')]) + def test_get_migrations_no_migrations_directories(self): + from ..utils import get_migrations + + migrations = get_migrations([]) + + self.assertEqual(list(migrations), []) + def test_get_pending_migrations(self): from ..utils import get_pending_migrations diff --git a/dbmigrator/utils.py b/dbmigrator/utils.py index 803314a..49eada3 100644 --- a/dbmigrator/utils.py +++ b/dbmigrator/utils.py @@ -94,7 +94,7 @@ def get_migrations(migration_directories, import_modules=False, reverse=False): paths = [os.path.join(md, '*.py') for md in migration_directories] python_files = functools.reduce( lambda a, b: a + b, - [glob.glob(path) for path in paths]) + [glob.glob(path) for path in paths], []) for path in sorted(python_files, key=lambda path: os.path.basename(path), reverse=reverse):