Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Add an Executor for end-to-end running

  • Loading branch information...
commit e6f7f4533c183800c2a9ac526d8ee8887e96ac5d 1 parent 7f9a0b7
Andrew Godwin authored May 30, 2013
68  django/db/migrations/executor.py
... ...
@@ -0,0 +1,68 @@
  1
+from .loader import MigrationLoader
  2
+from .recorder import MigrationRecorder
  3
+
  4
+
  5
+class MigrationExecutor(object):
  6
+    """
  7
+    End-to-end migration execution - loads migrations, and runs them
  8
+    up or down to a specified set of targets.
  9
+    """
  10
+
  11
+    def __init__(self, connection):
  12
+        self.connection = connection
  13
+        self.loader = MigrationLoader(self.connection)
  14
+        self.recorder = MigrationRecorder(self.connection)
  15
+
  16
+    def migration_plan(self, targets):
  17
+        """
  18
+        Given a set of targets, returns a list of (Migration instance, backwards?).
  19
+        """
  20
+        plan = []
  21
+        applied = self.recorder.applied_migrations()
  22
+        for target in targets:
  23
+            # If the migration is already applied, do backwards mode,
  24
+            # otherwise do forwards mode.
  25
+            if target in applied:
  26
+                for migration in self.loader.graph.backwards_plan(target)[:-1]:
  27
+                    if migration in applied:
  28
+                        plan.append((self.loader.graph.nodes[migration], True))
  29
+                        applied.remove(migration)
  30
+            else:
  31
+                for migration in self.loader.graph.forwards_plan(target):
  32
+                    if migration not in applied:
  33
+                        plan.append((self.loader.graph.nodes[migration], False))
  34
+                        applied.add(migration)
  35
+        return plan
  36
+
  37
+    def migrate(self, targets):
  38
+        """
  39
+        Migrates the database up to the given targets.
  40
+        """
  41
+        plan = self.migration_plan(targets)
  42
+        for migration, backwards in plan:
  43
+            if not backwards:
  44
+                self.apply_migration(migration)
  45
+            else:
  46
+                self.unapply_migration(migration)
  47
+
  48
+    def apply_migration(self, migration):
  49
+        """
  50
+        Runs a migration forwards.
  51
+        """
  52
+        print "Applying %s" % migration
  53
+        with self.connection.schema_editor() as schema_editor:
  54
+            project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
  55
+            migration.apply(project_state, schema_editor)
  56
+        self.recorder.record_applied(migration.app_label, migration.name)
  57
+        print "Finished %s" % migration
  58
+
  59
+    def unapply_migration(self, migration):
  60
+        """
  61
+        Runs a migration backwards.
  62
+        """
  63
+        print "Unapplying %s" % migration
  64
+        with self.connection.schema_editor() as schema_editor:
  65
+            project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
  66
+            migration.unapply(project_state, schema_editor)
  67
+        self.recorder.record_unapplied(migration.app_label, migration.name)
  68
+        print "Finished %s" % migration
48  django/db/migrations/migration.py
@@ -36,6 +36,17 @@ def __init__(self, name, app_label):
36 36
         self.name = name
37 37
         self.app_label = app_label
38 38
 
  39
+    def __eq__(self, other):
  40
+        if not isinstance(other, Migration):
  41
+            return False
  42
+        return (self.name == other.name) and (self.app_label == other.app_label)
  43
+
  44
+    def __ne__(self, other):
  45
+        return not (self == other)
  46
+
  47
+    def __repr__(self):
  48
+        return "<Migration %s.%s>" % (self.app_label, self.name)
  49
+
39 50
     def mutate_state(self, project_state):
40 51
         """
41 52
         Takes a ProjectState and returns a new one with the migration's
@@ -45,3 +56,40 @@ def mutate_state(self, project_state):
45 56
         for operation in self.operations:
46 57
             operation.state_forwards(self.app_label, new_state)
47 58
         return new_state
  59
+
  60
+    def apply(self, project_state, schema_editor):
  61
+        """
  62
+        Takes a project_state representing all migrations prior to this one
  63
+        and a schema_editor for a live database and applies the migration
  64
+        in a forwards order.
  65
+
  66
+        Returns the resulting project state for efficient re-use by following
  67
+        Migrations.
  68
+        """
  69
+        for operation in self.operations:
  70
+            # Get the state after the operation has run
  71
+            new_state = project_state.clone()
  72
+            operation.state_forwards(self.app_label, new_state)
  73
+            # Run the operation
  74
+            operation.database_forwards(self.app_label, schema_editor, project_state, new_state)
  75
+            # Switch states
  76
+            project_state = new_state
  77
+        return project_state
  78
+
  79
+    def unapply(self, project_state, schema_editor):
  80
+        """
  81
+        Takes a project_state representing all migrations prior to this one
  82
+        and a schema_editor for a live database and applies the migration
  83
+        in a reverse order.
  84
+        """
  85
+        # We need to pre-calculate the stack of project states
  86
+        to_run = []
  87
+        for operation in self.operations:
  88
+            new_state = project_state.clone()
  89
+            operation.state_forwards(self.app_label, new_state)
  90
+            to_run.append((operation, project_state, new_state))
  91
+            project_state = new_state
  92
+        # Now run them in reverse
  93
+        to_run.reverse()
  94
+        for operation, to_state, from_state in to_run:
  95
+            operation.database_backwards(self.app_label, schema_editor, from_state, to_state)
16  django/db/migrations/operations/fields.py
@@ -16,13 +16,13 @@ def state_forwards(self, app_label, state):
16 16
 
17 17
     def database_forwards(self, app_label, schema_editor, from_state, to_state):
18 18
         app_cache = to_state.render()
19  
-        model = app_cache.get_model(app_label, self.name)
20  
-        schema_editor.add_field(model, model._meta.get_field_by_name(self.name))
  19
+        model = app_cache.get_model(app_label, self.model_name)
  20
+        schema_editor.add_field(model, model._meta.get_field_by_name(self.name)[0])
21 21
 
22 22
     def database_backwards(self, app_label, schema_editor, from_state, to_state):
23 23
         app_cache = from_state.render()
24  
-        model = app_cache.get_model(app_label, self.name)
25  
-        schema_editor.remove_field(model, model._meta.get_field_by_name(self.name))
  24
+        model = app_cache.get_model(app_label, self.model_name)
  25
+        schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)[0])
26 26
 
27 27
 
28 28
 class RemoveField(Operation):
@@ -43,10 +43,10 @@ def state_forwards(self, app_label, state):
43 43
 
44 44
     def database_forwards(self, app_label, schema_editor, from_state, to_state):
45 45
         app_cache = from_state.render()
46  
-        model = app_cache.get_model(app_label, self.name)
47  
-        schema_editor.remove_field(model, model._meta.get_field_by_name(self.name))
  46
+        model = app_cache.get_model(app_label, self.model_name)
  47
+        schema_editor.remove_field(model, model._meta.get_field_by_name(self.name)[0])
48 48
 
49 49
     def database_backwards(self, app_label, schema_editor, from_state, to_state):
50 50
         app_cache = to_state.render()
51  
-        model = app_cache.get_model(app_label, self.name)
52  
-        schema_editor.add_field(model, model._meta.get_field_by_name(self.name))
  51
+        model = app_cache.get_model(app_label, self.model_name)
  52
+        schema_editor.add_field(model, model._meta.get_field_by_name(self.name)[0])
2  tests/migrations/migrations/0002_second.py
@@ -11,7 +11,7 @@ class Migration(migrations.Migration):
11 11
 
12 12
         migrations.RemoveField("Author", "silly_field"),
13 13
 
14  
-        migrations.AddField("Author", "important", models.BooleanField()),
  14
+        migrations.AddField("Author", "rating", models.IntegerField(default=0)),
15 15
 
16 16
         migrations.CreateModel(
17 17
             "Book",
35  tests/migrations/test_executor.py
... ...
@@ -0,0 +1,35 @@
  1
+from django.test import TransactionTestCase
  2
+from django.db import connection
  3
+from django.db.migrations.executor import MigrationExecutor
  4
+
  5
+
  6
+class ExecutorTests(TransactionTestCase):
  7
+    """
  8
+    Tests the migration executor (full end-to-end running).
  9
+
  10
+    Bear in mind that if these are failing you should fix the other
  11
+    test failures first, as they may be propagating into here.
  12
+    """
  13
+
  14
+    def test_run(self):
  15
+        """
  16
+        Tests running a simple set of migrations.
  17
+        """
  18
+        executor = MigrationExecutor(connection)
  19
+        # Let's look at the plan first and make sure it's up to scratch
  20
+        plan = executor.migration_plan([("migrations", "0002_second")])
  21
+        self.assertEqual(
  22
+            plan,
  23
+            [
  24
+                (executor.loader.graph.nodes["migrations", "0001_initial"], False),
  25
+                (executor.loader.graph.nodes["migrations", "0002_second"], False),
  26
+            ],
  27
+        )
  28
+        # Were the tables there before?
  29
+        self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
  30
+        self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
  31
+        # Alright, let's try running it
  32
+        executor.migrate([("migrations", "0002_second")])
  33
+        # Are the tables there now?
  34
+        self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
  35
+        self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
2  tests/migrations/test_loader.py
@@ -54,7 +54,7 @@ def test_load(self):
54 54
         author_state = project_state.models["migrations", "author"]
55 55
         self.assertEqual(
56 56
             [x for x, y in author_state.fields],
57  
-            ["id", "name", "slug", "age", "important"]
  57
+            ["id", "name", "slug", "age", "rating"]
58 58
         )
59 59
 
60 60
         book_state = project_state.models["migrations", "book"]
28  tests/migrations/test_operations.py
... ...
@@ -1,6 +1,6 @@
1 1
 from django.test import TransactionTestCase
2 2
 from django.db import connection, models, migrations
3  
-from django.db.migrations.state import ProjectState, ModelState
  3
+from django.db.migrations.state import ProjectState
4 4
 
5 5
 
6 6
 class OperationTests(TransactionTestCase):
@@ -16,6 +16,12 @@ def assertTableExists(self, table):
16 16
     def assertTableNotExists(self, table):
17 17
         self.assertNotIn(table, connection.introspection.get_table_list(connection.cursor()))
18 18
 
  19
+    def assertColumnExists(self, table, column):
  20
+        self.assertIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
  21
+
  22
+    def assertColumnNotExists(self, table, column):
  23
+        self.assertNotIn(column, [c.name for c in connection.introspection.get_table_description(connection.cursor(), table)])
  24
+
19 25
     def set_up_test_model(self, app_label):
20 26
         """
21 27
         Creates a test model state and database table.
@@ -82,3 +88,23 @@ def test_delete_model(self):
82 88
         with connection.schema_editor() as editor:
83 89
             operation.database_backwards("test_dlmo", editor, new_state, project_state)
84 90
         self.assertTableExists("test_dlmo_pony")
  91
+
  92
+    def test_add_field(self):
  93
+        """
  94
+        Tests the AddField operation.
  95
+        """
  96
+        project_state = self.set_up_test_model("test_adfl")
  97
+        # Test the state alteration
  98
+        operation = migrations.AddField("Pony", "height", models.FloatField(null=True))
  99
+        new_state = project_state.clone()
  100
+        operation.state_forwards("test_adfl", new_state)
  101
+        self.assertEqual(len(new_state.models["test_adfl", "pony"].fields), 3)
  102
+        # Test the database alteration
  103
+        self.assertColumnNotExists("test_adfl_pony", "height")
  104
+        with connection.schema_editor() as editor:
  105
+            operation.database_forwards("test_adfl", editor, project_state, new_state)
  106
+        self.assertColumnExists("test_adfl_pony", "height")
  107
+        # And test reversal
  108
+        with connection.schema_editor() as editor:
  109
+            operation.database_backwards("test_adfl", editor, new_state, project_state)
  110
+        self.assertColumnNotExists("test_adfl_pony", "height")

0 notes on commit e6f7f45

Please sign in to comment.
Something went wrong with that request. Please try again.