Skip to content

Commit

Permalink
Fix alterfield, indexes issue
Browse files Browse the repository at this point in the history
Issue #77
  • Loading branch information
absci committed Dec 7, 2021
1 parent 9758247 commit d1cc2d8
Showing 1 changed file with 28 additions and 18 deletions.
46 changes: 28 additions & 18 deletions mssql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
from django import VERSION as django_version
from django.db.models import Index, UniqueConstraint
from django.db.models.fields import AutoField, BigAutoField, TextField
from django.db.models.fields import AutoField, BigAutoField
from django.db.models.sql.where import AND
from django.db.transaction import TransactionManagementError
from django.utils.encoding import force_str
Expand All @@ -27,6 +27,7 @@
from django.db.models.sql import Query
from django.db.backends.ddl_references import Expressions


class Statement(DjStatement):
def __hash__(self):
return hash((self.template, str(self.parts['name'])))
Expand All @@ -42,6 +43,7 @@ def rename_column_references(self, table, old_column, new_column):
if condition:
self.parts['condition'] = condition.replace(f'[{old_column}]', f'[{new_column}]')


class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):

_sql_check_constraint = " CONSTRAINT %(name)s CHECK (%(check)s)"
Expand Down Expand Up @@ -389,7 +391,8 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
columns_to_recreate_index = ', '.join(['%s' % self.quote_name(column[0]) for column in result])
filter_definition = result[0][1]
sql_restore_index += f'CREATE UNIQUE INDEX {index_name} ON {model._meta.db_table} ({columns_to_recreate_index}) WHERE {filter_definition};'
self.execute(self._db_table_delete_constraint_sql(self.sql_delete_index, model._meta.db_table, index_name))
self.execute(self._db_table_delete_constraint_sql(
self.sql_delete_index, model._meta.db_table, index_name))
self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type))
# Restore indexes for altered table
if(sql_restore_index):
Expand Down Expand Up @@ -440,7 +443,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
self._delete_unique_constraints(model, old_field, new_field, strict)
# Drop indexes, SQL Server requires explicit deletion
self._delete_indexes(model, old_field, new_field)
if not isinstance(new_field, TextField):
if not new_field.get_internal_type() in ("JSONField", "TextField") and not (old_field.db_index and new_field.db_index):
post_actions.append((self._create_index_sql(model, [new_field]), ()))
# Only if we have a default and there is a change from NULL to NOT NULL
four_way_default_alteration = (
Expand Down Expand Up @@ -562,7 +565,10 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
index_columns.append(columns)
if index_columns:
for columns in index_columns:
self.execute(self._create_index_sql(model, columns, suffix='_idx'))
create_index_sql_statement = self._create_index_sql(model, columns)
if create_index_sql_statement.__str__() not in [sql.__str__() for sql in self.deferred_sql]:
self.execute(create_index_sql_statement)

# Type alteration on primary key? Then we need to alter the column
# referring to us.
rels_to_update = []
Expand Down Expand Up @@ -592,7 +598,8 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
# Drop related_model indexes, so it can be altered
index_names = self._db_table_constraint_names(old_rel.related_model._meta.db_table, index=True)
for index_name in index_names:
self.execute(self._db_table_delete_constraint_sql(self.sql_delete_index, old_rel.related_model._meta.db_table, index_name))
self.execute(self._db_table_delete_constraint_sql(
self.sql_delete_index, old_rel.related_model._meta.db_table, index_name))
self.execute(
self.sql_alter_column % {
"table": self.quote_name(new_rel.related_model._meta.db_table),
Expand Down Expand Up @@ -646,10 +653,10 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,

def _delete_indexes(self, model, old_field, new_field):
index_columns = []
if old_field.null != new_field.null:
index_columns.append([old_field.column])
if old_field.db_index and new_field.db_index:
index_columns.append([old_field.column])
elif old_field.null != new_field.null:
index_columns.append([old_field.column])
for fields in model._meta.index_together:
columns = [model._meta.get_field(field).column for field in fields]
if old_field.column in columns:
Expand Down Expand Up @@ -754,12 +761,13 @@ def add_field(self, model, field):
self.connection.close()

if django_version >= (4, 0):
def _create_unique_sql(self, model, fields , name=None, condition=None, deferrable=None, include=None, opclasses=None, expressions=None):
if (deferrable and not getattr(self.connection.features, 'supports_deferrable_unique_constraints', False)
or
def _create_unique_sql(self, model, fields,
name=None, condition=None, deferrable=None,
include=None, opclasses=None, expressions=None):
if (deferrable and not getattr(self.connection.features, 'supports_deferrable_unique_constraints', False) or
(condition and not self.connection.features.supports_partial_indexes) or
(include and not self.connection.features.supports_covering_indexes) or
(expressions and not self.connection.features.supports_expression_indexes)):
(expressions and not self.connection.features.supports_expression_indexes)):
return None

def create_unique_name(*args, **kwargs):
Expand Down Expand Up @@ -803,12 +811,13 @@ def create_unique_name(*args, **kwargs):
include=include,
)
else:
def _create_unique_sql(self, model, columns , name=None, condition=None, deferrable=None, include=None, opclasses=None, expressions=None):
if (deferrable and not getattr(self.connection.features, 'supports_deferrable_unique_constraints', False)
or
def _create_unique_sql(self, model, columns,
name=None, condition=None, deferrable=None,
include=None, opclasses=None, expressions=None):
if (deferrable and not getattr(self.connection.features, 'supports_deferrable_unique_constraints', False) or
(condition and not self.connection.features.supports_partial_indexes) or
(include and not self.connection.features.supports_covering_indexes) or
(expressions and not self.connection.features.supports_expression_indexes)):
(expressions and not self.connection.features.supports_expression_indexes)):
return None

def create_unique_name(*args, **kwargs):
Expand All @@ -823,7 +832,7 @@ def create_unique_name(*args, **kwargs):
statement_args = {
"deferrable": self._deferrable_constraint_sql(deferrable)
} if django_version >= (3, 1) else {}
include = self._index_include_sql(model, include) if django_version >=(3, 2) else ''
include = self._index_include_sql(model, include) if django_version >= (3, 2) else ''

if condition:
return Statement(
Expand Down Expand Up @@ -973,7 +982,8 @@ def _delete_unique_sql(
if condition or include or opclasses:
sql = self.sql_delete_index
with self.connection.cursor() as cursor:
cursor.execute("SELECT 1 FROM INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE WHERE CONSTRAINT_NAME = '%s'" % name)
cursor.execute(
"SELECT 1 FROM INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE WHERE CONSTRAINT_NAME = '%s'" % name)
row = cursor.fetchone()
if row:
sql = self.sql_delete_unique
Expand Down Expand Up @@ -1113,6 +1123,6 @@ def _create_index_name(self, table_name, column_names, suffix=""):
index_name = super()._create_index_name(table_name, column_names, suffix)
# Check if the db_table specified a user-defined schema
if('].[' in index_name):
new_index_name = index_name.replace('[','').replace(']','').replace('.', '_')
new_index_name = index_name.replace('[', '').replace(']', '').replace('.', '_')
return new_index_name
return index_name

0 comments on commit d1cc2d8

Please sign in to comment.