Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Add an Executor for end-to-end running

  • Loading branch information...
commit e6f7f4533c183800c2a9ac526d8ee8887e96ac5d 1 parent 7f9a0b7
@andrewgodwin andrewgodwin authored
View
68 django/db/migrations/executor.py
@@ -0,0 +1,68 @@
+from .loader import MigrationLoader
+from .recorder import MigrationRecorder
+
+
+class MigrationExecutor(object):
+ """
+ End-to-end migration execution - loads migrations, and runs them
+ up or down to a specified set of targets.
+ """
+
+ def __init__(self, connection):
+ self.connection = connection
+ self.loader = MigrationLoader(self.connection)
+ self.recorder = MigrationRecorder(self.connection)
+
+ def migration_plan(self, targets):
+ """
+ Given a set of targets, returns a list of (Migration instance, backwards?).
+ """
+ plan = []
+ applied = self.recorder.applied_migrations()
+ for target in targets:
+ # If the migration is already applied, do backwards mode,
+ # otherwise do forwards mode.
+ if target in applied:
+ for migration in self.loader.graph.backwards_plan(target)[:-1]:
+ if migration in applied:
+ plan.append((self.loader.graph.nodes[migration], True))
+ applied.remove(migration)
+ else:
+ for migration in self.loader.graph.forwards_plan(target):
+ if migration not in applied:
+ plan.append((self.loader.graph.nodes[migration], False))
+ applied.add(migration)
+ return plan
+
+ def migrate(self, targets):
+ """
+ Migrates the database up to the given targets.
+ """
+ plan = self.migration_plan(targets)
+ for migration, backwards in plan:
+ if not backwards:
+ self.apply_migration(migration)
+ else:
+ self.unapply_migration(migration)
+
+ def apply_migration(self, migration):
+ """
+ Runs a migration forwards.
+ """
+ print "Applying %s" % migration
+ with self.connection.schema_editor() as schema_editor:
+ project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
+ migration.apply(project_state, schema_editor)
+ self.recorder.record_applied(migration.app_label, migration.name)
+ print "Finished %s" % migration
+
+ def unapply_migration(self, migration):
+ """
+ Runs a migration backwards.
+ """
+ print "Unapplying %s" % migration
+ with self.connection.schema_editor() as schema_editor:
+ project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
+ migration.unapply(project_state, schema_editor)
+ self.recorder.record_unapplied(migration.app_label, migration.name)
+ print "Finished %s" % migration
View
48 django/db/migrations/migration.py
@@ -36,6 +36,17 @@ def __init__(self, name, app_label):
self.name = name
self.app_label = app_label
+ def __eq__(self, other):
+ if not isinstance(other, Migration):
+ return False
+ return (self.name == other.name) and (self.app_label == other.app_label)
+
+ def __ne__(self, other):
+ return not (self == other)
+
+ def __repr__(self):
+ return "<Migration %s.%s>" % (self.app_label, self.name)
+
def mutate_state(self, project_state):
"""
Takes a ProjectState and returns a new one with the migration's
@@ -45,3 +56,40 @@ def mutate_state(self, project_state):
for operation in self.operations:
operation.state_forwards(self.app_label, new_state)
return new_state
+
+ def apply(self, project_state, schema_editor):
+ """
+ Takes a project_state representing all migrations prior to this one
+ and a schema_editor for a live database and applies the migration
+ in a forwards order.
+
+ Returns the resulting project state for efficient re-use by following
+ Migrations.
+ """
+ for operation in self.operations:
+ # Get the state after the operation has run
+ new_state = project_state.clone()
+ operation.state_forwards(self.app_label, new_state)
+ # Run the operation
+ operation.database_forwards(self.app_label, schema_editor, project_state, new_state)
+ # Switch states
+ project_state = new_state
+ return project_state
+
+ def unapply(self, project_state, schema_editor):
+ """
+ Takes a project_state representing all migrations prior to this one
+ and a schema_editor for a live database and applies the migration
+ in a reverse order.
+ """
+ # We need to pre-calculate the stack of project states
+ to_run = []
+ for operation in self.operations:
+ new_state = project_state.clone()
+ operation.state_forwards(self.app_label, new_state)
+ to_run.append((operation, project_state, new_state))
+ project_state = new_state
+ # Now run them in reverse
+ to_run.reverse()
+ for operation, to_state, from_state in to_run:
+ operation.database_backwards(self.app_label, schema_editor, from_state, to_state)
View
16 django/db/migrations/operations/fields.py
@@ -16,13 +16,13 @@ def state_forwards(self, app_label, state):
def database_forwards(self, app_label, schema_editor, from_state, to_state):
app_cache = to_state.render()
- model = app_cache.get_model(app_label, self.name)
- schema_editor.add_field(model, model._meta.get_field_by_name(self.name))
+ model = app_cache.get_model(app_label, self.model_name)
+ schema_editor.add_field(model, model._meta.get_field_by_name(self.name)[0])
def database_backwards(self, app_label, schema_editor, from_state, to_state):
app_cache = from_state.render()
- model = app_cache.get_model(app_label, self.name)
- schema_editor.remove_field(model, model._meta.get_field_by_name(self.name))
+ model = app_cache.get_model(app_label, self.model_name)
+ schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)[0])
class RemoveField(Operation):
@@ -43,10 +43,10 @@ def state_forwards(self, app_label, state):
def database_forwards(self, app_label, schema_editor, from_state, to_state):
app_cache = from_state.render()
- model = app_cache.get_model(app_label, self.name)
- schema_editor.remove_field(model, model._meta.get_field_by_name(self.name))
+ model = app_cache.get_model(app_label, self.model_name)
+ schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)[0])
def database_backwards(self, app_label, schema_editor, from_state, to_state):
app_cache = to_state.render()
- model = app_cache.get_model(app_label, self.name)
- schema_editor.add_field(model, model._meta.get_field_by_name(self.name))
+ model = app_cache.get_model(app_label, self.model_name)
+ schema_editor.add_field(model, model._meta.get_field_by_name(self.name)[0])
View
2  tests/migrations/migrations/0002_second.py
@@ -11,7 +11,7 @@ class Migration(migrations.Migration):
migrations.RemoveField("Author", "silly_field"),
- migrations.AddField("Author", "important", models.BooleanField()),
+ migrations.AddField("Author", "rating", models.IntegerField(default=0)),
migrations.CreateModel(
"Book",
View
35 tests/migrations/test_executor.py
@@ -0,0 +1,35 @@
+from django.test import TransactionTestCase
+from django.db import connection
+from django.db.migrations.executor import MigrationExecutor
+
+
+class ExecutorTests(TransactionTestCase):
+ """
+ Tests the migration executor (full end-to-end running).
+
+ Bear in mind that if these are failing you should fix the other
+ test failures first, as they may be propagating into here.
+ """
+
+ def test_run(self):
+ """
+ Tests running a simple set of migrations.
+ """
+ executor = MigrationExecutor(connection)
+ # Let's look at the plan first and make sure it's up to scratch
+ plan = executor.migration_plan([("migrations", "0002_second")])
+ self.assertEqual(
+ plan,
+ [
+ (executor.loader.graph.nodes["migrations", "0001_initial"], False),
+ (executor.loader.graph.nodes["migrations", "0002_second"], False),
+ ],
+ )
+ # Were the tables there before?
+ self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
+ self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
+ # Alright, let's try running it
+ executor.migrate([("migrations", "0002_second")])
+ # Are the tables there now?
+ self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
+ self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
View
2  tests/migrations/test_loader.py
@@ -54,7 +54,7 @@ def test_load(self):
author_state = project_state.models["migrations", "author"]
self.assertEqual(
[x for x, y in author_state.fields],
- ["id", "name", "slug", "age", "important"]
+ ["id", "name", "slug", "age", "rating"]
)
book_state = project_state.models["migrations", "book"]
View
28 tests/migrations/test_operations.py
@@ -1,6 +1,6 @@
from django.test import TransactionTestCase
from django.db import connection, models, migrations
-from django.db.migrations.state import ProjectState, ModelState
+from django.db.migrations.state import ProjectState
class OperationTests(TransactionTestCase):
@@ -16,6 +16,12 @@ def assertTableExists(self, table):
def assertTableNotExists(self, table):
self.assertNotIn(table, connection.introspection.get_table_list(connection.cursor()))
+ def assertColumnExists(self, table, column):
+ self.assertIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
+
+ def assertColumnNotExists(self, table, column):
+ self.assertNotIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
+
def set_up_test_model(self, app_label):
"""
Creates a test model state and database table.
@@ -82,3 +88,23 @@ def test_delete_model(self):
with connection.schema_editor() as editor:
operation.database_backwards("test_dlmo", editor, new_state, project_state)
self.assertTableExists("test_dlmo_pony")
+
+ def test_add_field(self):
+ """
+ Tests the AddField operation.
+ """
+ project_state = self.set_up_test_model("test_adfl")
+ # Test the state alteration
+ operation = migrations.AddField("Pony", "height", models.FloatField(null=True))
+ new_state = project_state.clone()
+ operation.state_forwards("test_adfl", new_state)
+ self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 3)
+ # Test the database alteration
+ self.assertColumnNotExists("test_adfl_pony", "height")
+ with connection.schema_editor() as editor:
+ operation.database_forwards("test_adfl", editor, project_state, new_state)
+ self.assertColumnExists("test_adfl_pony", "height")
+ # And test reversal
+ with connection.schema_editor() as editor:
+ operation.database_backwards("test_adfl", editor, new_state, project_state)
+ self.assertColumnNotExists("test_adfl_pony", "height")
Please sign in to comment.
Something went wrong with that request. Please try again.