Skip to content

Commit

Permalink
Merge pull request #2396 from loic/ticket21893
Browse files Browse the repository at this point in the history
Fixed #21893 -- ModelState didn't account for MTI parents inherited from abstract models.
  • Loading branch information
andrewgodwin committed Mar 4, 2014
2 parents 6fe22b3 + 6436f1f commit 8fcc014
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 6 deletions.
22 changes: 19 additions & 3 deletions django/db/migrations/state.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -151,19 +151,35 @@ def from_model(cls, model):
options[name] = set(normalize_together(it)) options[name] = set(normalize_together(it))
else: else:
options[name] = model._meta.original_attrs[name] options[name] = model._meta.original_attrs[name]

def flatten_bases(model):
bases = []
for base in model.__bases__:
if hasattr(base, "_meta") and base._meta.abstract:
bases.extend(flatten_bases(base))
else:
bases.append(base)
return bases

# We can't rely on __mro__ directly because we only want to flatten
# abstract models and not the whole tree. However by recursing on
# __bases__ we may end up with duplicates and ordering issues, we
# therefore discard any duplicates and reorder the bases according
# to their index in the MRO.
flattened_bases = sorted(set(flatten_bases(model)), key=lambda x:model.__mro__.index(x))

# Make our record # Make our record
bases = tuple( bases = tuple(
( (
"%s.%s" % (base._meta.app_label, base._meta.model_name) "%s.%s" % (base._meta.app_label, base._meta.model_name)
if hasattr(base, "_meta") else if hasattr(base, "_meta") else
base base
) )
for base in model.__bases__ for base in flattened_bases
if (not hasattr(base, "_meta") or not base._meta.abstract)
) )
# Ensure at least one base inherits from models.Model # Ensure at least one base inherits from models.Model
if not any((isinstance(base, six.string_types) or issubclass(base, models.Model)) for base in bases): if not any((isinstance(base, six.string_types) or issubclass(base, models.Model)) for base in bases):
bases = (models.Model, ) bases = (models.Model,)
return cls( return cls(
model._meta.app_label, model._meta.app_label,
model._meta.object_name, model._meta.object_name,
Expand Down
53 changes: 50 additions & 3 deletions tests/migrations/test_operations.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class OperationTests(MigrationTestBase):
both forwards and backwards. both forwards and backwards.
""" """


def set_up_test_model(self, app_label, second_model=False, related_model=False): def set_up_test_model(self, app_label, second_model=False, related_model=False, mti_model=False):
""" """
Creates a test model state and database table. Creates a test model state and database table.
""" """
Expand All @@ -38,7 +38,12 @@ def set_up_test_model(self, app_label, second_model=False, related_model=False):
], ],
)] )]
if second_model: if second_model:
operations.append(migrations.CreateModel("Stable", [("id", models.AutoField(primary_key=True))])) operations.append(migrations.CreateModel(
"Stable",
[
("id", models.AutoField(primary_key=True)),
]
))
if related_model: if related_model:
operations.append(migrations.CreateModel( operations.append(migrations.CreateModel(
"Rider", "Rider",
Expand All @@ -47,6 +52,21 @@ def set_up_test_model(self, app_label, second_model=False, related_model=False):
("pony", models.ForeignKey("Pony")), ("pony", models.ForeignKey("Pony")),
], ],
)) ))
if mti_model:
operations.append(migrations.CreateModel(
"ShetlandPony",
fields=[
('pony_ptr', models.OneToOneField(
auto_created=True,
primary_key=True,
to_field='id',
serialize=False,
to='Pony',
)),
("cuteness", models.IntegerField(default=1)),
],
bases=['%s.Pony' % app_label],
))
project_state = ProjectState() project_state = ProjectState()
for operation in operations: for operation in operations:
operation.state_forwards(app_label, project_state) operation.state_forwards(app_label, project_state)
Expand Down Expand Up @@ -495,7 +515,7 @@ def test_run_python(self):
Tests the RunPython operation Tests the RunPython operation
""" """


project_state = self.set_up_test_model("test_runpython") project_state = self.set_up_test_model("test_runpython", mti_model=True)


# Create the operation # Create the operation
def inner_method(models, schema_editor): def inner_method(models, schema_editor):
Expand Down Expand Up @@ -533,7 +553,34 @@ def inner_method_reverse(models, schema_editor):
no_reverse_operation.database_forwards("test_runpython", editor, project_state, new_state) no_reverse_operation.database_forwards("test_runpython", editor, project_state, new_state)
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
no_reverse_operation.database_backwards("test_runpython", editor, new_state, project_state) no_reverse_operation.database_backwards("test_runpython", editor, new_state, project_state)
self.assertEqual(project_state.render().get_model("test_runpython", "Pony").objects.count(), 2)


def create_ponies(models, schema_editor):
Pony = models.get_model("test_runpython", "Pony")
pony1 = Pony.objects.create(pink=1, weight=3.55)
self.assertIsNot(pony1.pk, None)
pony2 = Pony.objects.create(weight=5)
self.assertIsNot(pony2.pk, None)
self.assertNotEqual(pony1.pk, pony2.pk)

operation = migrations.RunPython(create_ponies)
with connection.schema_editor() as editor:
operation.database_forwards("test_runpython", editor, project_state, new_state)
self.assertEqual(project_state.render().get_model("test_runpython", "Pony").objects.count(), 4)

def create_shetlandponies(models, schema_editor):
ShetlandPony = models.get_model("test_runpython", "ShetlandPony")
pony1 = ShetlandPony.objects.create(weight=4.0)
self.assertIsNot(pony1.pk, None)
pony2 = ShetlandPony.objects.create(weight=5.0)
self.assertIsNot(pony2.pk, None)
self.assertNotEqual(pony1.pk, pony2.pk)

operation = migrations.RunPython(create_shetlandponies)
with connection.schema_editor() as editor:
operation.database_forwards("test_runpython", editor, project_state, new_state)
self.assertEqual(project_state.render().get_model("test_runpython", "Pony").objects.count(), 6)
self.assertEqual(project_state.render().get_model("test_runpython", "ShetlandPony").objects.count(), 2)


class MigrateNothingRouter(object): class MigrateNothingRouter(object):
""" """
Expand Down
15 changes: 15 additions & 0 deletions tests/migrations/test_state.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -166,6 +166,16 @@ class Meta:
app_label = "migrations" app_label = "migrations"
apps = Apps() apps = Apps()


class AbstractSubFooBar(FooBar):
class Meta:
abstract = True
apps = Apps()

class SubFooBar(AbstractSubFooBar):
class Meta:
app_label = "migrations"
apps = Apps()

apps = Apps(["migrations"]) apps = Apps(["migrations"])


# We shouldn't be able to render yet # We shouldn't be able to render yet
Expand All @@ -175,8 +185,13 @@ class Meta:


# Once the parent models are in the app registry, it should be fine # Once the parent models are in the app registry, it should be fine
ModelState.from_model(Foo).render(apps) ModelState.from_model(Foo).render(apps)
self.assertSequenceEqual(ModelState.from_model(Foo).bases, [models.Model])
ModelState.from_model(Bar).render(apps) ModelState.from_model(Bar).render(apps)
self.assertSequenceEqual(ModelState.from_model(Bar).bases, [models.Model])
ModelState.from_model(FooBar).render(apps) ModelState.from_model(FooBar).render(apps)
self.assertSequenceEqual(ModelState.from_model(FooBar).bases, ['migrations.foo', 'migrations.bar'])
ModelState.from_model(SubFooBar).render(apps)
self.assertSequenceEqual(ModelState.from_model(SubFooBar).bases, ['migrations.foobar'])


def test_render_project_dependencies(self): def test_render_project_dependencies(self):
""" """
Expand Down

0 comments on commit 8fcc014

Please sign in to comment.