Skip to content

Commit

Permalink
Refs #29898 -- Moved state_forwards()'s logic from migration operatio…
Browse files Browse the repository at this point in the history
…ns to ProjectState.

Thanks Simon Charette and Markus Holtermann for reviews.
  • Loading branch information
manav014 authored and felixxm committed Jun 29, 2021
1 parent 594d6e9 commit 503ee41
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 138 deletions.
95 changes: 16 additions & 79 deletions django/db/migrations/operations/fields.py
@@ -1,7 +1,4 @@
from django.core.exceptions import FieldDoesNotExist
from django.db.migrations.utils import (
field_is_referenced, field_references, get_references,
)
from django.db.migrations.utils import field_references
from django.db.models import NOT_PROVIDED
from django.utils.functional import cached_property

Expand Down Expand Up @@ -85,16 +82,13 @@ def deconstruct(self):
)

def state_forwards(self, app_label, state):
# If preserve default is off, don't use the default for future state
if not self.preserve_default:
field = self.field.clone()
field.default = NOT_PROVIDED
else:
field = self.field
state.models[app_label, self.model_name_lower].fields[self.name] = field
# Delay rendering of relationships if it's not a relational field
delay = not field.is_relation
state.reload_model(app_label, self.model_name_lower, delay=delay)
state.add_field(
app_label,
self.model_name_lower,
self.name,
self.field,
self.preserve_default,
)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
Expand Down Expand Up @@ -160,11 +154,7 @@ def deconstruct(self):
)

def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.model_name_lower]
old_field = model_state.fields.pop(self.name)
# Delay rendering of relationships if it's not a relational field
delay = not old_field.is_relation
state.reload_model(app_label, self.model_name_lower, delay=delay)
state.remove_field(app_label, self.model_name_lower, self.name)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.apps.get_model(app_label, self.model_name)
Expand Down Expand Up @@ -216,24 +206,13 @@ def deconstruct(self):
)

def state_forwards(self, app_label, state):
if not self.preserve_default:
field = self.field.clone()
field.default = NOT_PROVIDED
else:
field = self.field
model_state = state.models[app_label, self.model_name_lower]
model_state.fields[self.name] = field
# TODO: investigate if old relational fields must be reloaded or if it's
# sufficient if the new field is (#27737).
# Delay rendering of relationships if it's not a relational field and
# not referenced by a foreign key.
delay = (
not field.is_relation and
not field_is_referenced(
state, (app_label, self.model_name_lower), (self.name, field),
)
state.alter_field(
app_label,
self.model_name_lower,
self.name,
self.field,
self.preserve_default,
)
state.reload_model(app_label, self.model_name_lower, delay=delay)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
Expand Down Expand Up @@ -301,49 +280,7 @@ def deconstruct(self):
)

def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.model_name_lower]
# Rename the field
fields = model_state.fields
try:
found = fields.pop(self.old_name)
except KeyError:
raise FieldDoesNotExist(
"%s.%s has no field named '%s'" % (app_label, self.model_name, self.old_name)
)
fields[self.new_name] = found
for field in fields.values():
# Fix from_fields to refer to the new field.
from_fields = getattr(field, 'from_fields', None)
if from_fields:
field.from_fields = tuple([
self.new_name if from_field_name == self.old_name else from_field_name
for from_field_name in from_fields
])
# Fix index/unique_together to refer to the new field
options = model_state.options
for option in ('index_together', 'unique_together'):
if option in options:
options[option] = [
[self.new_name if n == self.old_name else n for n in together]
for together in options[option]
]
# Fix to_fields to refer to the new field.
delay = True
references = get_references(
state, (app_label, self.model_name_lower), (self.old_name, found),
)
for *_, field, reference in references:
delay = False
if reference.to:
remote_field, to_fields = reference.to
if getattr(remote_field, 'field_name', None) == self.old_name:
remote_field.field_name = self.new_name
if to_fields:
field.to_fields = tuple([
self.new_name if to_field_name == self.old_name else to_field_name
for to_field_name in to_fields
])
state.reload_model(app_label, self.model_name_lower, delay=delay)
state.rename_field(app_label, self.model_name_lower, self.old_name, self.new_name)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
Expand Down
83 changes: 24 additions & 59 deletions django/db/migrations/operations/models.py
@@ -1,9 +1,7 @@
from django.db import models
from django.db.migrations.operations.base import Operation
from django.db.migrations.state import ModelState
from django.db.migrations.utils import (
field_references, get_references, resolve_relation,
)
from django.db.migrations.utils import field_references, resolve_relation
from django.db.models.options import normalize_together
from django.utils.functional import cached_property

Expand Down Expand Up @@ -316,31 +314,7 @@ def deconstruct(self):
)

def state_forwards(self, app_label, state):
# Add a new model.
renamed_model = state.models[app_label, self.old_name_lower].clone()
renamed_model.name = self.new_name
state.models[app_label, self.new_name_lower] = renamed_model
# Repoint all fields pointing to the old model to the new one.
old_model_tuple = (app_label, self.old_name_lower)
new_remote_model = '%s.%s' % (app_label, self.new_name)
to_reload = set()
for model_state, name, field, reference in get_references(state, old_model_tuple):
changed_field = None
if reference.to:
changed_field = field.clone()
changed_field.remote_field.model = new_remote_model
if reference.through:
if changed_field is None:
changed_field = field.clone()
changed_field.remote_field.through = new_remote_model
if changed_field:
model_state.fields[name] = changed_field
to_reload.add((model_state.app_label, model_state.name_lower))
# Reload models related to old model before removing the old model.
state.reload_models(to_reload, delay=True)
# Remove the old model.
state.remove_model(app_label, self.old_name_lower)
state.reload_model(app_label, self.new_name_lower, delay=True)
state.rename_model(app_label, self.old_name, self.new_name)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.new_name)
Expand Down Expand Up @@ -458,8 +432,7 @@ def deconstruct(self):
)

def state_forwards(self, app_label, state):
state.models[app_label, self.name_lower].options["db_table"] = self.table
state.reload_model(app_label, self.name_lower, delay=True)
state.alter_model_options(app_label, self.name_lower, {'db_table': self.table})

def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.name)
Expand Down Expand Up @@ -518,9 +491,11 @@ def deconstruct(self):
)

def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.name_lower]
model_state.options[self.option_name] = self.option_value
state.reload_model(app_label, self.name_lower, delay=True)
state.alter_model_options(
app_label,
self.name_lower,
{self.option_name: self.option_value},
)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.name)
Expand Down Expand Up @@ -596,9 +571,11 @@ def deconstruct(self):
)

def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.name_lower]
model_state.options['order_with_respect_to'] = self.order_with_respect_to
state.reload_model(app_label, self.name_lower, delay=True)
state.alter_model_options(
app_label,
self.name_lower,
{self.option_name: self.order_with_respect_to},
)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.name)
Expand Down Expand Up @@ -676,12 +653,12 @@ def deconstruct(self):
)

def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.name_lower]
model_state.options = {**model_state.options, **self.options}
for key in self.ALTER_OPTION_KEYS:
if key not in self.options:
model_state.options.pop(key, False)
state.reload_model(app_label, self.name_lower, delay=True)
state.alter_model_options(
app_label,
self.name_lower,
self.options,
self.ALTER_OPTION_KEYS,
)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
pass
Expand Down Expand Up @@ -714,9 +691,7 @@ def deconstruct(self):
)

def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.name_lower]
model_state.managers = list(self.managers)
state.reload_model(app_label, self.name_lower, delay=True)
state.alter_model_managers(app_label, self.name_lower, self.managers)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
pass
Expand Down Expand Up @@ -753,9 +728,7 @@ def __init__(self, model_name, index):
self.index = index

def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.model_name_lower]
model_state.options[self.option_name] = [*model_state.options[self.option_name], self.index.clone()]
state.reload_model(app_label, self.model_name_lower, delay=True)
state.add_index(app_label, self.model_name_lower, self.index)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
Expand Down Expand Up @@ -804,10 +777,7 @@ def __init__(self, model_name, name):
self.name = name

def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.model_name_lower]
indexes = model_state.options[self.option_name]
model_state.options[self.option_name] = [idx for idx in indexes if idx.name != self.name]
state.reload_model(app_label, self.model_name_lower, delay=True)
state.remove_index(app_label, self.model_name_lower, self.name)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = from_state.apps.get_model(app_label, self.model_name)
Expand Down Expand Up @@ -850,9 +820,7 @@ def __init__(self, model_name, constraint):
self.constraint = constraint

def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.model_name_lower]
model_state.options[self.option_name] = [*model_state.options[self.option_name], self.constraint]
state.reload_model(app_label, self.model_name_lower, delay=True)
state.add_constraint(app_label, self.model_name_lower, self.constraint)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
Expand Down Expand Up @@ -886,10 +854,7 @@ def __init__(self, model_name, name):
self.name = name

def state_forwards(self, app_label, state):
model_state = state.models[app_label, self.model_name_lower]
constraints = model_state.options[self.option_name]
model_state.options[self.option_name] = [c for c in constraints if c.name != self.name]
state.reload_model(app_label, self.model_name_lower, delay=True)
state.remove_constraint(app_label, self.model_name_lower, self.name)

def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
Expand Down

0 comments on commit 503ee41

Please sign in to comment.