Skip to content

Commit

Permalink
Add support for nullable fields in unique_together (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
OskarPersson committed Dec 11, 2019
1 parent 8bf0154 commit 2e60754
Show file tree
Hide file tree
Showing 7 changed files with 254 additions and 25 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,5 +1,5 @@
[flake8]
exclude = .git,__pycache__,
exclude = .git,__pycache__,migrations
# W504 is mutually exclusive with W503
ignore = W504
max-line-length = 119
8 changes: 2 additions & 6 deletions sql_server/pyodbc/base.py
Expand Up @@ -480,17 +480,13 @@ def check_constraints(self, table_names=None):
table_names)

def disable_constraint_checking(self):
# Azure SQL Database doesn't support sp_msforeachtable
# cursor.execute('EXEC sp_msforeachtable "ALTER TABLE ? NOCHECK CONSTRAINT ALL"')
if not self.needs_rollback:
self._execute_foreach('ALTER TABLE %s NOCHECK CONSTRAINT ALL')
self.cursor().execute('EXEC sp_msforeachtable "ALTER TABLE ? NOCHECK CONSTRAINT ALL"')
return not self.needs_rollback

def enable_constraint_checking(self):
# Azure SQL Database doesn't support sp_msforeachtable
# cursor.execute('EXEC sp_msforeachtable "ALTER TABLE ? WITH CHECK CHECK CONSTRAINT ALL"')
if not self.needs_rollback:
self.check_constraints()
self.cursor().execute('EXEC sp_msforeachtable "ALTER TABLE ? WITH CHECK CHECK CONSTRAINT ALL"')


class CursorWrapper(object):
Expand Down
5 changes: 4 additions & 1 deletion sql_server/pyodbc/features.py
Expand Up @@ -23,7 +23,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_ignore_conflicts = False
supports_index_on_text_field = False
supports_paramstyle_pyformat = False
supports_partially_nullable_unique_constraints = False
supports_regex_backreferencing = False
supports_sequence_reset = False
supports_subqueries_in_group_by = False
Expand All @@ -41,6 +40,10 @@ def has_bulk_insert(self):
def supports_nullable_unique_constraints(self):
return self.connection.sql_server_version > 2005

@cached_property
def supports_partially_nullable_unique_constraints(self):
return self.connection.sql_server_version > 2005

@cached_property
def supports_partial_indexes(self):
return self.connection.sql_server_version > 2005
Expand Down
196 changes: 183 additions & 13 deletions sql_server/pyodbc/schema.py
Expand Up @@ -5,14 +5,22 @@
BaseDatabaseSchemaEditor, logger, _is_relevant_relation, _related_non_m2m_objects,
)
from django.db.backends.ddl_references import (
Statement,
Columns, IndexName, Statement as DjStatement, Table,
)
from django.db.models import Index
from django.db.models.fields import AutoField, BigAutoField
from django.db.transaction import TransactionManagementError
from django.utils.encoding import force_text


class Statement(DjStatement):
def __hash__(self):
return hash((self.template, str(self.parts['name'])))

def __eq__(self, other):
return self.template == other.template and str(self.parts['name']) == str(other.parts['name'])


class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):

_sql_check_constraint = " CONSTRAINT %(name)s CHECK (%(check)s)"
Expand Down Expand Up @@ -123,6 +131,105 @@ def _alter_column_type_sql(self, model, old_field, new_field, new_type):
new_type = self._set_field_new_type_null_status(old_field, new_type)
return super()._alter_column_type_sql(model, old_field, new_field, new_type)

def alter_unique_together(self, model, old_unique_together, new_unique_together):
"""
Deal with a model changing its unique_together. The input
unique_togethers must be doubly-nested, not the single-nested
["foo", "bar"] format.
"""
olds = {tuple(fields) for fields in old_unique_together}
news = {tuple(fields) for fields in new_unique_together}
# Deleted uniques
for fields in olds.difference(news):
self._delete_composed_index(model, fields, {'unique': True}, self.sql_delete_index)
# Created uniques
for fields in news.difference(olds):
columns = [model._meta.get_field(field).column for field in fields]
condition = ' AND '.join(["[%s] IS NOT NULL" % col for col in columns])
sql = self._create_unique_sql(model, columns, condition=condition)
self.execute(sql)

def _model_indexes_sql(self, model):
"""
Return a list of all index SQL statements (field indexes,
index_together, Meta.indexes) for the specified model.
"""
if not model._meta.managed or model._meta.proxy or model._meta.swapped:
return []
output = []
for field in model._meta.local_fields:
output.extend(self._field_indexes_sql(model, field))

for field_names in model._meta.index_together:
fields = [model._meta.get_field(field) for field in field_names]
output.append(self._create_index_sql(model, fields, suffix="_idx"))

for field_names in model._meta.unique_together:
columns = [model._meta.get_field(field).column for field in field_names]
condition = ' AND '.join(["[%s] IS NOT NULL" % col for col in columns])
sql = self._create_unique_sql(model, columns, condition=condition)
output.append(sql)

for index in model._meta.indexes:
output.append(index.create_sql(model, self))
return output

def _alter_many_to_many(self, model, old_field, new_field, strict):
"""Alter M2Ms to repoint their to= endpoints."""

for idx in self._constraint_names(old_field.remote_field.through, index=True, unique=True):
self.execute(self.sql_delete_index % {'name': idx, 'table': old_field.remote_field.through._meta.db_table})

return super()._alter_many_to_many(model, old_field, new_field, strict)

def _db_table_constraint_names(self, db_table, column_names=None, unique=None,
primary_key=None, index=None, foreign_key=None,
check=None, type_=None, exclude=None):
"""Return all constraint names matching the columns and conditions."""
if column_names is not None:
column_names = [
self.connection.introspection.identifier_converter(name)
for name in column_names
]
with self.connection.cursor() as cursor:
constraints = self.connection.introspection.get_constraints(cursor, db_table)
result = []
for name, infodict in constraints.items():
if column_names is None or column_names == infodict['columns']:
if unique is not None and infodict['unique'] != unique:
continue
if primary_key is not None and infodict['primary_key'] != primary_key:
continue
if index is not None and infodict['index'] != index:
continue
if check is not None and infodict['check'] != check:
continue
if foreign_key is not None and not infodict['foreign_key']:
continue
if type_ is not None and infodict['type'] != type_:
continue
if not exclude or name not in exclude:
result.append(name)
return result

def _db_table_delete_constraint_sql(self, template, db_table, name):
return Statement(
template,
table=Table(db_table, self.quote_name),
name=self.quote_name(name),
)

def alter_db_table(self, model, old_db_table, new_db_table):
index_names = self._db_table_constraint_names(old_db_table, index=True)
for index_name in index_names:
self.execute(self._db_table_delete_constraint_sql(self.sql_delete_index, old_db_table, index_name))

index_names = self._db_table_constraint_names(new_db_table, index=True)
for index_name in index_names:
self.execute(self._db_table_delete_constraint_sql(self.sql_delete_index, new_db_table, index_name))

return super().alter_db_table(model, old_db_table, new_db_table)

def _alter_field(self, model, old_field, new_field, old_type, new_type,
old_db_params, new_db_params, strict=False):
"""Actually perform a "physical" (non-ManyToMany) field update."""
Expand Down Expand Up @@ -224,11 +331,15 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
self.execute(self._delete_constraint_sql(self.sql_delete_check, model, constraint_name))
# Have they renamed the column?
if old_field.column != new_field.column:
# remove old indices
self._delete_indexes(model, old_field, new_field)

self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type))
# Rename all references to the renamed column.
for sql in self.deferred_sql:
if isinstance(sql, Statement):
if isinstance(sql, DjStatement):
sql.rename_column_references(model._meta.db_table, old_field.column, new_field.column)

# Next, start accumulating actions to do
actions = []
null_actions = []
Expand Down Expand Up @@ -286,6 +397,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
actions = [(", ".join(sql), sum(params, []))]
# Apply those actions
for sql, params in actions:
self._delete_indexes(model, old_field, new_field)
self.execute(
self.sql_alter_column % {
"table": self.quote_name(model._meta.db_table),
Expand Down Expand Up @@ -438,6 +550,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
"changes": changes_sql,
}
self.execute(sql, params)

# Reset connection if required
if self.connection.features.connection_persists_old_columns:
self.connection.close()
Expand All @@ -446,11 +559,15 @@ def _delete_indexes(self, model, old_field, new_field):
index_columns = []
if old_field.db_index and new_field.db_index:
index_columns.append([old_field.column])
else:
for fields in model._meta.index_together:
columns = [model._meta.get_field(field).column for field in fields]
if old_field.column in columns:
index_columns.append(columns)
for fields in model._meta.index_together:
columns = [model._meta.get_field(field).column for field in fields]
if old_field.column in columns:
index_columns.append(columns)

for fields in model._meta.unique_together:
columns = [model._meta.get_field(field).column for field in fields]
if old_field.column in columns:
index_columns.append(columns)
if index_columns:
for columns in index_columns:
index_names = self._constraint_names(model, columns, index=True)
Expand All @@ -461,11 +578,6 @@ def _delete_unique_constraints(self, model, old_field, new_field, strict=False):
unique_columns = []
if old_field.unique and new_field.unique:
unique_columns.append([old_field.column])
else:
for fields in model._meta.unique_together:
columns = [model._meta.get_field(field).column for field in fields]
if old_field.column in columns:
unique_columns.append(columns)
if unique_columns:
for columns in unique_columns:
constraint_names = self._constraint_names(model, columns, unique=True)
Expand Down Expand Up @@ -544,6 +656,61 @@ def add_field(self, model, field):
if self.connection.features.connection_persists_old_columns:
self.connection.close()

def _create_unique_sql(self, model, columns, name=None, condition=None):
def create_unique_name(*args, **kwargs):
return self.quote_name(self._create_index_name(*args, **kwargs))

table = Table(model._meta.db_table, self.quote_name)
if name is None:
name = IndexName(model._meta.db_table, columns, '_uniq', create_unique_name)
else:
name = self.quote_name(name)
columns = Columns(table, columns, self.quote_name)
if condition:
return Statement(
self.sql_create_unique_index,
table=table,
name=name,
columns=columns,
condition=' WHERE ' + condition,
) if self.connection.features.supports_partial_indexes else None
else:
return Statement(
self.sql_create_unique,
table=table,
name=name,
columns=columns,
)

def _create_index_sql(self, model, fields, *, name=None, suffix='', using='',
db_tablespace=None, col_suffixes=(), sql=None, opclasses=(),
condition=None):
"""
Return the SQL statement to create the index for one or several fields.
`sql` can be specified if the syntax differs from the standard (GIS
indexes, ...).
"""
tablespace_sql = self._get_index_tablespace_sql(model, fields, db_tablespace=db_tablespace)
columns = [field.column for field in fields]
sql_create_index = sql or self.sql_create_index
table = model._meta.db_table

def create_index_name(*args, **kwargs):
nonlocal name
if name is None:
name = self._create_index_name(*args, **kwargs)
return self.quote_name(name)

return Statement(
sql_create_index,
table=Table(table, self.quote_name),
name=IndexName(table, columns, suffix, create_index_name),
using=using,
columns=self._index_columns(table, columns, col_suffixes, opclasses),
extra=tablespace_sql,
condition=(' WHERE ' + condition) if condition else '',
)

def create_model(self, model):
"""
Takes a model and creates a table for it in the database.
Expand Down Expand Up @@ -605,7 +772,9 @@ def create_model(self, model):
# created afterwards, like geometry fields with some backends)
for fields in model._meta.unique_together:
columns = [model._meta.get_field(field).column for field in fields]
self.deferred_sql.append(self._create_unique_sql(model, columns))
condition = ' AND '.join(["[%s] IS NOT NULL" % col for col in columns])
self.deferred_sql.append(self._create_unique_sql(model, columns, condition=condition))

# Make the table
sql = self.sql_create_table % {
"table": self.quote_name(model._meta.db_table),
Expand All @@ -620,6 +789,7 @@ def create_model(self, model):

# Add any field index and index_together's (deferred as SQLite3 _remake_table needs it)
self.deferred_sql.extend(self._model_indexes_sql(model))
self.deferred_sql = list(set(self.deferred_sql))

# Make M2M tables
for field in model._meta.local_many_to_many:
Expand Down
28 changes: 28 additions & 0 deletions testapp/migrations/0001_initial.py
Expand Up @@ -14,13 +14,41 @@ class Migration(migrations.Migration):
]

operations = [
migrations.CreateModel(
name='Author',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(max_length=100)),
],
),
migrations.CreateModel(
name='Editor',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(max_length=100)),
],
),
migrations.CreateModel(
name='Post',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('title', models.CharField(max_length=255, verbose_name='title')),
],
),
migrations.AddField(
model_name='post',
name='alt_editor',
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='testapp.Editor'),
),
migrations.AddField(
model_name='post',
name='author',
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='testapp.Author'),
),
migrations.AlterUniqueTogether(
name='post',
unique_together={('author', 'title', 'alt_editor')},
),
migrations.CreateModel(
name='Comment',
fields=[
Expand Down
16 changes: 16 additions & 0 deletions testapp/models.py
Expand Up @@ -4,8 +4,24 @@
from django.utils import timezone


class Author(models.Model):
name = models.CharField(max_length=100)


class Editor(models.Model):
name = models.CharField(max_length=100)


class Post(models.Model):
title = models.CharField('title', max_length=255)
author = models.ForeignKey(Author, models.CASCADE)
# Optional secondary author
alt_editor = models.ForeignKey(Editor, models.SET_NULL, blank=True, null=True)

class Meta:
unique_together = (
('author', 'title', 'alt_editor'),
)

def __str__(self):
return self.title
Expand Down

0 comments on commit 2e60754

Please sign in to comment.