Skip to content

Commit

Permalink
Fix migration planner to fully understand squashed migrations. And test.
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgodwin committed Oct 23, 2013
1 parent 4cfbde7 commit 5ab8b5d
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 37 deletions.
19 changes: 15 additions & 4 deletions django/db/migrations/executor.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ class MigrationExecutor(object):
def __init__(self, connection, progress_callback=None): def __init__(self, connection, progress_callback=None):
self.connection = connection self.connection = connection
self.loader = MigrationLoader(self.connection) self.loader = MigrationLoader(self.connection)
self.loader.load_disk()
self.recorder = MigrationRecorder(self.connection) self.recorder = MigrationRecorder(self.connection)
self.progress_callback = progress_callback self.progress_callback = progress_callback


Expand All @@ -20,7 +19,7 @@ def migration_plan(self, targets):
Given a set of targets, returns a list of (Migration instance, backwards?). Given a set of targets, returns a list of (Migration instance, backwards?).
""" """
plan = [] plan = []
applied = self.recorder.applied_migrations() applied = set(self.loader.applied_migrations)
for target in targets: for target in targets:
# If the target is (appname, None), that means unmigrate everything # If the target is (appname, None), that means unmigrate everything
if target[1] is None: if target[1] is None:
Expand Down Expand Up @@ -87,7 +86,13 @@ def apply_migration(self, migration, fake=False):
with self.connection.schema_editor() as schema_editor: with self.connection.schema_editor() as schema_editor:
project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False) project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
migration.apply(project_state, schema_editor) migration.apply(project_state, schema_editor)
self.recorder.record_applied(migration.app_label, migration.name) # For replacement migrations, record individual statuses
if migration.replaces:
for app_label, name in migration.replaces:
self.recorder.record_applied(app_label, name)
else:
self.recorder.record_applied(migration.app_label, migration.name)
# Report prgress
if self.progress_callback: if self.progress_callback:
self.progress_callback("apply_success", migration) self.progress_callback("apply_success", migration)


Expand All @@ -101,6 +106,12 @@ def unapply_migration(self, migration, fake=False):
with self.connection.schema_editor() as schema_editor: with self.connection.schema_editor() as schema_editor:
project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False) project_state = self.loader.graph.project_state((migration.app_label, migration.name), at_end=False)
migration.unapply(project_state, schema_editor) migration.unapply(project_state, schema_editor)
self.recorder.record_unapplied(migration.app_label, migration.name) # For replacement migrations, record individual statuses
if migration.replaces:
for app_label, name in migration.replaces:
self.recorder.record_unapplied(app_label, name)
else:
self.recorder.record_unapplied(migration.app_label, migration.name)
# Report progress
if self.progress_callback: if self.progress_callback:
self.progress_callback("unapply_success", migration) self.progress_callback("unapply_success", migration)
53 changes: 29 additions & 24 deletions django/db/migrations/loader.py
Original file line number Original file line Diff line number Diff line change
@@ -1,9 +1,10 @@
import os import os
import sys
from importlib import import_module from importlib import import_module
from django.utils.functional import cached_property
from django.db.models.loading import cache from django.db.models.loading import cache
from django.db.migrations.recorder import MigrationRecorder from django.db.migrations.recorder import MigrationRecorder
from django.db.migrations.graph import MigrationGraph from django.db.migrations.graph import MigrationGraph
from django.utils import six
from django.conf import settings from django.conf import settings




Expand Down Expand Up @@ -32,10 +33,12 @@ class MigrationLoader(object):
in memory. in memory.
""" """


def __init__(self, connection): def __init__(self, connection, load=True):
self.connection = connection self.connection = connection
self.disk_migrations = None self.disk_migrations = None
self.applied_migrations = None self.applied_migrations = None
if load:
self.build_graph()


@classmethod @classmethod
def migrations_module(cls, app_label): def migrations_module(cls, app_label):
Expand All @@ -55,6 +58,7 @@ def load_disk(self):
# Get the migrations module directory # Get the migrations module directory
app_label = app.__name__.split(".")[-2] app_label = app.__name__.split(".")[-2]
module_name = self.migrations_module(app_label) module_name = self.migrations_module(app_label)
was_loaded = module_name in sys.modules
try: try:
module = import_module(module_name) module = import_module(module_name)
except ImportError as e: except ImportError as e:
Expand All @@ -71,6 +75,9 @@ def load_disk(self):
# Module is not a package (e.g. migrations.py). # Module is not a package (e.g. migrations.py).
if not hasattr(module, '__path__'): if not hasattr(module, '__path__'):
continue continue
# Force a reload if it's already loaded (tests need this)
if was_loaded:
six.moves.reload_module(module)
self.migrated_apps.add(app_label) self.migrated_apps.add(app_label)
directory = os.path.dirname(module.__file__) directory = os.path.dirname(module.__file__)
# Scan for .py[c|o] files # Scan for .py[c|o] files
Expand Down Expand Up @@ -107,9 +114,6 @@ def get_migration(self, app_label, name_prefix):


def get_migration_by_prefix(self, app_label, name_prefix): def get_migration_by_prefix(self, app_label, name_prefix):
"Returns the migration(s) which match the given app label and name _prefix_" "Returns the migration(s) which match the given app label and name _prefix_"
# Make sure we have the disk data
if self.disk_migrations is None:
self.load_disk()
# Do the search # Do the search
results = [] results = []
for l, n in self.disk_migrations: for l, n in self.disk_migrations:
Expand All @@ -122,18 +126,17 @@ def get_migration_by_prefix(self, app_label, name_prefix):
else: else:
return self.disk_migrations[results[0]] return self.disk_migrations[results[0]]


@cached_property def build_graph(self):
def graph(self):
""" """
Builds a migration dependency graph using both the disk and database. Builds a migration dependency graph using both the disk and database.
You'll need to rebuild the graph if you apply migrations. This isn't
usually a problem as generally migration stuff runs in a one-shot process.
""" """
# Make sure we have the disk data # Load disk data
if self.disk_migrations is None: self.load_disk()
self.load_disk() # Load database data
# And the database data recorder = MigrationRecorder(self.connection)
if self.applied_migrations is None: self.applied_migrations = recorder.applied_migrations()
recorder = MigrationRecorder(self.connection)
self.applied_migrations = recorder.applied_migrations()
# Do a first pass to separate out replacing and non-replacing migrations # Do a first pass to separate out replacing and non-replacing migrations
normal = {} normal = {}
replacing = {} replacing = {}
Expand All @@ -152,12 +155,12 @@ def graph(self):
# Carry out replacements if we can - that is, if all replaced migrations # Carry out replacements if we can - that is, if all replaced migrations
# are either unapplied or missing. # are either unapplied or missing.
for key, migration in replacing.items(): for key, migration in replacing.items():
# Do the check # Ensure this replacement migration is not in applied_migrations
can_replace = True self.applied_migrations.discard(key)
for target in migration.replaces: # Do the check. We can replace if all our replace targets are
if target in self.applied_migrations: # applied, or if all of them are unapplied.
can_replace = False applied_statuses = [(target in self.applied_migrations) for target in migration.replaces]
break can_replace = all(applied_statuses) or (not any(applied_statuses))
if not can_replace: if not can_replace:
continue continue
# Alright, time to replace. Step through the replaced migrations # Alright, time to replace. Step through the replaced migrations
Expand All @@ -171,14 +174,16 @@ def graph(self):
normal[child_key].dependencies.remove(replaced) normal[child_key].dependencies.remove(replaced)
normal[child_key].dependencies.append(key) normal[child_key].dependencies.append(key)
normal[key] = migration normal[key] = migration
# Mark the replacement as applied if all its replaced ones are
if all(applied_statuses):
self.applied_migrations.add(key)
# Finally, make a graph and load everything into it # Finally, make a graph and load everything into it
graph = MigrationGraph() self.graph = MigrationGraph()
for key, migration in normal.items(): for key, migration in normal.items():
graph.add_node(key, migration) self.graph.add_node(key, migration)
for key, migration in normal.items(): for key, migration in normal.items():
for parent in migration.dependencies: for parent in migration.dependencies:
graph.add_dependency(key, parent) self.graph.add_dependency(key, parent)
return graph




class BadMigrationError(Exception): class BadMigrationError(Exception):
Expand Down
5 changes: 5 additions & 0 deletions django/db/migrations/migration.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ class IrreversibleError(RuntimeError):
def __init__(self, name, app_label): def __init__(self, name, app_label):
self.name = name self.name = name
self.app_label = app_label self.app_label = app_label
# Copy dependencies & other attrs as we might mutate them at runtime
self.operations = list(self.__class__.operations)
self.dependencies = list(self.__class__.dependencies)
self.run_before = list(self.__class__.run_before)
self.replaces = list(self.__class__.replaces)


def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, Migration): if not isinstance(other, Migration):
Expand Down
53 changes: 53 additions & 0 deletions tests/migrations/test_executor.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -38,7 +38,58 @@ def test_run(self):
# Are the tables there now? # Are the tables there now?
self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor())) self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
# Rebuild the graph to reflect the new DB state
executor.loader.build_graph()
# Alright, let's undo what we did # Alright, let's undo what we did
plan = executor.migration_plan([("migrations", None)])
self.assertEqual(
plan,
[
(executor.loader.graph.nodes["migrations", "0002_second"], True),
(executor.loader.graph.nodes["migrations", "0001_initial"], True),
],
)
executor.migrate([("migrations", None)])
# Are the tables gone?
self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))

@override_settings(MIGRATION_MODULES={"migrations": "migrations.test_migrations_squashed"})
def test_run_with_squashed(self):
"""
Tests running a squashed migration from zero (should ignore what it replaces)
"""
executor = MigrationExecutor(connection)
executor.recorder.flush()
# Check our leaf node is the squashed one
leaves = [key for key in executor.loader.graph.leaf_nodes() if key[0] == "migrations"]
self.assertEqual(leaves, [("migrations", "0001_squashed_0002")])
# Check the plan
plan = executor.migration_plan([("migrations", "0001_squashed_0002")])
self.assertEqual(
plan,
[
(executor.loader.graph.nodes["migrations", "0001_squashed_0002"], False),
],
)
# Were the tables there before?
self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
self.assertNotIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
# Alright, let's try running it
executor.migrate([("migrations", "0001_squashed_0002")])
# Are the tables there now?
self.assertIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
self.assertIn("migrations_book", connection.introspection.get_table_list(connection.cursor()))
# Rebuild the graph to reflect the new DB state
executor.loader.build_graph()
# Alright, let's undo what we did. Should also just use squashed.
plan = executor.migration_plan([("migrations", None)])
self.assertEqual(
plan,
[
(executor.loader.graph.nodes["migrations", "0001_squashed_0002"], True),
],
)
executor.migrate([("migrations", None)]) executor.migrate([("migrations", None)])
# Are the tables gone? # Are the tables gone?
self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor())) self.assertNotIn("migrations_author", connection.introspection.get_table_list(connection.cursor()))
Expand Down Expand Up @@ -70,6 +121,8 @@ def test_empty_plan(self):
) )
# Fake-apply all migrations # Fake-apply all migrations
executor.migrate([("migrations", "0002_second"), ("sessions", "0001_initial")], fake=True) executor.migrate([("migrations", "0002_second"), ("sessions", "0001_initial")], fake=True)
# Rebuild the graph to reflect the new DB state
executor.loader.build_graph()
# Now plan a second time and make sure it's empty # Now plan a second time and make sure it's empty
plan = executor.migration_plan([("migrations", "0002_second"), ("sessions", "0001_initial")]) plan = executor.migration_plan([("migrations", "0002_second"), ("sessions", "0001_initial")])
self.assertEqual(plan, []) self.assertEqual(plan, [])
Expand Down
31 changes: 22 additions & 9 deletions tests/migrations/test_loader.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -82,21 +82,34 @@ def test_name_match(self):
migration_loader.get_migration_by_prefix("migrations", "blarg") migration_loader.get_migration_by_prefix("migrations", "blarg")


def test_load_import_error(self): def test_load_import_error(self):
migration_loader = MigrationLoader(connection)

with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.import_error"}): with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.import_error"}):
with self.assertRaises(ImportError): with self.assertRaises(ImportError):
migration_loader.load_disk() MigrationLoader(connection)


def test_load_module_file(self): def test_load_module_file(self):
migration_loader = MigrationLoader(connection)

with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.file"}): with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.file"}):
migration_loader.load_disk() MigrationLoader(connection)


@skipIf(six.PY2, "PY2 doesn't load empty dirs.") @skipIf(six.PY2, "PY2 doesn't load empty dirs.")
def test_load_empty_dir(self): def test_load_empty_dir(self):
migration_loader = MigrationLoader(connection)

with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.namespace"}): with override_settings(MIGRATION_MODULES={"migrations": "migrations.faulty_migrations.namespace"}):
migration_loader.load_disk() MigrationLoader(connection)

@override_settings(MIGRATION_MODULES={"migrations": "migrations.test_migrations_squashed"})
def test_loading_squashed(self):
"Tests loading a squashed migration"
migration_loader = MigrationLoader(connection)
recorder = MigrationRecorder(connection)
# Loading with nothing applied should just give us the one node
self.assertEqual(
len(migration_loader.graph.nodes),
1,
)
# However, fake-apply one migration and it should now use the old two
recorder.record_applied("migrations", "0001_initial")
migration_loader.build_graph()
self.assertEqual(
len(migration_loader.graph.nodes),
2,
)
recorder.flush()
27 changes: 27 additions & 0 deletions tests/migrations/test_migrations_squashed/0001_initial.py
Original file line number Original file line Diff line number Diff line change
@@ -0,0 +1,27 @@
from django.db import migrations, models


class Migration(migrations.Migration):

operations = [

migrations.CreateModel(
"Author",
[
("id", models.AutoField(primary_key=True)),
("name", models.CharField(max_length=255)),
("slug", models.SlugField(null=True)),
("age", models.IntegerField(default=0)),
("silly_field", models.BooleanField(default=False)),
],
),

migrations.CreateModel(
"Tribble",
[
("id", models.AutoField(primary_key=True)),
("fluffy", models.BooleanField(default=True)),
],
)

]
32 changes: 32 additions & 0 deletions tests/migrations/test_migrations_squashed/0001_squashed_0002.py
Original file line number Original file line Diff line number Diff line change
@@ -0,0 +1,32 @@
from django.db import migrations, models


class Migration(migrations.Migration):

replaces = [
("migrations", "0001_initial"),
("migrations", "0002_second"),
]

operations = [

migrations.CreateModel(
"Author",
[
("id", models.AutoField(primary_key=True)),
("name", models.CharField(max_length=255)),
("slug", models.SlugField(null=True)),
("age", models.IntegerField(default=0)),
("rating", models.IntegerField(default=0)),
],
),

migrations.CreateModel(
"Book",
[
("id", models.AutoField(primary_key=True)),
("author", models.ForeignKey("migrations.Author", null=True)),
],
),

]
24 changes: 24 additions & 0 deletions tests/migrations/test_migrations_squashed/0002_second.py
Original file line number Original file line Diff line number Diff line change
@@ -0,0 +1,24 @@
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [("migrations", "0001_initial")]

operations = [

migrations.DeleteModel("Tribble"),

migrations.RemoveField("Author", "silly_field"),

migrations.AddField("Author", "rating", models.IntegerField(default=0)),

migrations.CreateModel(
"Book",
[
("id", models.AutoField(primary_key=True)),
("author", models.ForeignKey("migrations.Author", null=True)),
],
)

]
Empty file.

0 comments on commit 5ab8b5d

Please sign in to comment.