From 2e6075456a147713ccd38e3b2df9a23772dc55b5 Mon Sep 17 00:00:00 2001
From: Oskar Persson
Date: Wed, 11 Dec 2019 15:59:36 +0100
Subject: [PATCH] Add support for nullable fields in unique_together (#24)

---
 setup.cfg                          |   2 +-
 sql_server/pyodbc/base.py          |   8 +-
 sql_server/pyodbc/features.py      |   5 +-
 sql_server/pyodbc/schema.py        | 196 +++++++++++++++++++++++++++--
 testapp/migrations/0001_initial.py |  28 +++++
 testapp/models.py                  |  16 +++
 testapp/tests/test_expressions.py  |  24 +++-
 7 files changed, 254 insertions(+), 25 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 1064d458..918b972d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,5 +1,5 @@
 [flake8]
-exclude = .git,__pycache__,
+exclude = .git,__pycache__,migrations
 # W504 is mutually exclusive with W503
 ignore = W504
 max-line-length = 119
diff --git a/sql_server/pyodbc/base.py b/sql_server/pyodbc/base.py
index d239b2ec..5cc90ea9 100644
--- a/sql_server/pyodbc/base.py
+++ b/sql_server/pyodbc/base.py
@@ -480,17 +480,13 @@ def check_constraints(self, table_names=None):
                                             table_names)

     def disable_constraint_checking(self):
-        # Azure SQL Database doesn't support sp_msforeachtable
-        # cursor.execute('EXEC sp_msforeachtable "ALTER TABLE ? NOCHECK CONSTRAINT ALL"')
         if not self.needs_rollback:
-            self._execute_foreach('ALTER TABLE %s NOCHECK CONSTRAINT ALL')
+            self.cursor().execute('EXEC sp_msforeachtable "ALTER TABLE ? NOCHECK CONSTRAINT ALL"')
         return not self.needs_rollback

     def enable_constraint_checking(self):
-        # Azure SQL Database doesn't support sp_msforeachtable
-        # cursor.execute('EXEC sp_msforeachtable "ALTER TABLE ? WITH CHECK CHECK CONSTRAINT ALL"')
         if not self.needs_rollback:
-            self.check_constraints()
+            self.cursor().execute('EXEC sp_msforeachtable "ALTER TABLE ? WITH CHECK CHECK CONSTRAINT ALL"')


 class CursorWrapper(object):
diff --git a/sql_server/pyodbc/features.py b/sql_server/pyodbc/features.py
index 21a71207..de166d37 100644
--- a/sql_server/pyodbc/features.py
+++ b/sql_server/pyodbc/features.py
@@ -23,7 +23,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     supports_ignore_conflicts = False
     supports_index_on_text_field = False
     supports_paramstyle_pyformat = False
-    supports_partially_nullable_unique_constraints = False
     supports_regex_backreferencing = False
     supports_sequence_reset = False
     supports_subqueries_in_group_by = False
@@ -41,6 +40,10 @@ def has_bulk_insert(self):
     def supports_nullable_unique_constraints(self):
         return self.connection.sql_server_version > 2005

+    @cached_property
+    def supports_partially_nullable_unique_constraints(self):
+        return self.connection.sql_server_version > 2005
+
     @cached_property
     def supports_partial_indexes(self):
         return self.connection.sql_server_version > 2005
diff --git a/sql_server/pyodbc/schema.py b/sql_server/pyodbc/schema.py
index 03a57f2e..08861ab2 100644
--- a/sql_server/pyodbc/schema.py
+++ b/sql_server/pyodbc/schema.py
@@ -5,7 +5,7 @@
     BaseDatabaseSchemaEditor, logger, _is_relevant_relation, _related_non_m2m_objects,
 )
 from django.db.backends.ddl_references import (
-    Statement,
+    Columns, IndexName, Statement as DjStatement, Table,
 )
 from django.db.models import Index
 from django.db.models.fields import AutoField, BigAutoField
@@ -13,6 +13,14 @@
 from django.utils.encoding import force_text


+class Statement(DjStatement):
+    def __hash__(self):
+        return hash((self.template, str(self.parts['name'])))
+
+    def __eq__(self, other):
+        return self.template == other.template and str(self.parts['name']) == str(other.parts['name'])
+
+
 class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):

     _sql_check_constraint = " CONSTRAINT %(name)s CHECK (%(check)s)"
@@ -123,6 +131,105 @@ def _alter_column_type_sql(self, model, old_field, new_field, new_type):
         new_type = self._set_field_new_type_null_status(old_field, new_type)
         return super()._alter_column_type_sql(model, old_field, new_field, new_type)

+    def alter_unique_together(self, model, old_unique_together, new_unique_together):
+        """
+        Deal with a model changing its unique_together. The input
+        unique_togethers must be doubly-nested, not the single-nested
+        ["foo", "bar"] format.
+        """
+        olds = {tuple(fields) for fields in old_unique_together}
+        news = {tuple(fields) for fields in new_unique_together}
+        # Deleted uniques
+        for fields in olds.difference(news):
+            self._delete_composed_index(model, fields, {'unique': True}, self.sql_delete_index)
+        # Created uniques
+        for fields in news.difference(olds):
+            columns = [model._meta.get_field(field).column for field in fields]
+            condition = ' AND '.join(["[%s] IS NOT NULL" % col for col in columns])
+            sql = self._create_unique_sql(model, columns, condition=condition)
+            self.execute(sql)
+
+    def _model_indexes_sql(self, model):
+        """
+        Return a list of all index SQL statements (field indexes,
+        index_together, Meta.indexes) for the specified model.
+        """
+        if not model._meta.managed or model._meta.proxy or model._meta.swapped:
+            return []
+        output = []
+        for field in model._meta.local_fields:
+            output.extend(self._field_indexes_sql(model, field))
+
+        for field_names in model._meta.index_together:
+            fields = [model._meta.get_field(field) for field in field_names]
+            output.append(self._create_index_sql(model, fields, suffix="_idx"))
+
+        for field_names in model._meta.unique_together:
+            columns = [model._meta.get_field(field).column for field in field_names]
+            condition = ' AND '.join(["[%s] IS NOT NULL" % col for col in columns])
+            sql = self._create_unique_sql(model, columns, condition=condition)
+            output.append(sql)
+
+        for index in model._meta.indexes:
+            output.append(index.create_sql(model, self))
+        return output
+
+    def _alter_many_to_many(self, model, old_field, new_field, strict):
+        """Alter M2Ms to repoint their to= endpoints."""
+
+        for idx in self._constraint_names(old_field.remote_field.through, index=True, unique=True):
+            self.execute(self.sql_delete_index % {'name': idx, 'table': old_field.remote_field.through._meta.db_table})
+
+        return super()._alter_many_to_many(model, old_field, new_field, strict)
+
+    def _db_table_constraint_names(self, db_table, column_names=None, unique=None,
+                                   primary_key=None, index=None, foreign_key=None,
+                                   check=None, type_=None, exclude=None):
+        """Return all constraint names matching the columns and conditions."""
+        if column_names is not None:
+            column_names = [
+                self.connection.introspection.identifier_converter(name)
+                for name in column_names
+            ]
+        with self.connection.cursor() as cursor:
+            constraints = self.connection.introspection.get_constraints(cursor, db_table)
+        result = []
+        for name, infodict in constraints.items():
+            if column_names is None or column_names == infodict['columns']:
+                if unique is not None and infodict['unique'] != unique:
+                    continue
+                if primary_key is not None and infodict['primary_key'] != primary_key:
+                    continue
+                if index is not None and infodict['index'] != index:
+                    continue
+                if check is not None and infodict['check'] != check:
+                    continue
+                if foreign_key is not None and not infodict['foreign_key']:
+                    continue
+                if type_ is not None and infodict['type'] != type_:
+                    continue
+                if not exclude or name not in exclude:
+                    result.append(name)
+        return result
+
+    def _db_table_delete_constraint_sql(self, template, db_table, name):
+        return Statement(
+            template,
+            table=Table(db_table, self.quote_name),
+            name=self.quote_name(name),
+        )
+
+    def alter_db_table(self, model, old_db_table, new_db_table):
+        index_names = self._db_table_constraint_names(old_db_table, index=True)
+        for index_name in index_names:
+            self.execute(self._db_table_delete_constraint_sql(self.sql_delete_index, old_db_table, index_name))
+
+        index_names = self._db_table_constraint_names(new_db_table, index=True)
+        for index_name in index_names:
+            self.execute(self._db_table_delete_constraint_sql(self.sql_delete_index, new_db_table, index_name))
+
+        return super().alter_db_table(model, old_db_table, new_db_table)
+
     def _alter_field(self, model, old_field, new_field, old_type, new_type,
                      old_db_params, new_db_params, strict=False):
         """Actually perform a "physical" (non-ManyToMany) field update."""
@@ -224,11 +331,15 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
                 self.execute(self._delete_constraint_sql(self.sql_delete_check, model, constraint_name))
         # Have they renamed the column?
         if old_field.column != new_field.column:
+            # remove old indices
+            self._delete_indexes(model, old_field, new_field)
+
             self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type))
             # Rename all references to the renamed column.
             for sql in self.deferred_sql:
-                if isinstance(sql, Statement):
+                if isinstance(sql, DjStatement):
                     sql.rename_column_references(model._meta.db_table, old_field.column, new_field.column)
+
         # Next, start accumulating actions to do
         actions = []
         null_actions = []
@@ -286,6 +397,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
             actions = [(", ".join(sql), sum(params, []))]
         # Apply those actions
         for sql, params in actions:
+            self._delete_indexes(model, old_field, new_field)
             self.execute(
                 self.sql_alter_column % {
                     "table": self.quote_name(model._meta.db_table),
@@ -438,6 +550,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
                 "changes": changes_sql,
             }
             self.execute(sql, params)
+
         # Reset connection if required
         if self.connection.features.connection_persists_old_columns:
             self.connection.close()
@@ -446,11 +559,15 @@ def _delete_indexes(self, model, old_field, new_field):
         index_columns = []
         if old_field.db_index and new_field.db_index:
             index_columns.append([old_field.column])
-        else:
-            for fields in model._meta.index_together:
-                columns = [model._meta.get_field(field).column for field in fields]
-                if old_field.column in columns:
-                    index_columns.append(columns)
+        for fields in model._meta.index_together:
+            columns = [model._meta.get_field(field).column for field in fields]
+            if old_field.column in columns:
+                index_columns.append(columns)
+
+        for fields in model._meta.unique_together:
+            columns = [model._meta.get_field(field).column for field in fields]
+            if old_field.column in columns:
+                index_columns.append(columns)
         if index_columns:
             for columns in index_columns:
                 index_names = self._constraint_names(model, columns, index=True)
@@ -461,11 +578,6 @@ def _delete_unique_constraints(self, model, old_field, new_field, strict=False):
         unique_columns = []
         if old_field.unique and new_field.unique:
             unique_columns.append([old_field.column])
-        else:
-            for fields in model._meta.unique_together:
-                columns = [model._meta.get_field(field).column for field in fields]
-                if old_field.column in columns:
-                    unique_columns.append(columns)
         if unique_columns:
             for columns in unique_columns:
                 constraint_names = self._constraint_names(model, columns, unique=True)
@@ -544,6 +656,61 @@ def add_field(self, model, field):
         if self.connection.features.connection_persists_old_columns:
             self.connection.close()

+    def _create_unique_sql(self, model, columns, name=None, condition=None):
+        def create_unique_name(*args, **kwargs):
+            return self.quote_name(self._create_index_name(*args, **kwargs))
+
+        table = Table(model._meta.db_table, self.quote_name)
+        if name is None:
+            name = IndexName(model._meta.db_table, columns, '_uniq', create_unique_name)
+        else:
+            name = self.quote_name(name)
+        columns = Columns(table, columns, self.quote_name)
+        if condition:
+            return Statement(
+                self.sql_create_unique_index,
+                table=table,
+                name=name,
+                columns=columns,
+                condition=' WHERE ' + condition,
+            ) if self.connection.features.supports_partial_indexes else None
+        else:
+            return Statement(
+                self.sql_create_unique,
+                table=table,
+                name=name,
+                columns=columns,
+            )
+
+    def _create_index_sql(self, model, fields, *, name=None, suffix='', using='',
+                          db_tablespace=None, col_suffixes=(), sql=None, opclasses=(),
+                          condition=None):
+        """
+        Return the SQL statement to create the index for one or several fields.
+        `sql` can be specified if the syntax differs from the standard (GIS
+        indexes, ...).
+        """
+        tablespace_sql = self._get_index_tablespace_sql(model, fields, db_tablespace=db_tablespace)
+        columns = [field.column for field in fields]
+        sql_create_index = sql or self.sql_create_index
+        table = model._meta.db_table
+
+        def create_index_name(*args, **kwargs):
+            nonlocal name
+            if name is None:
+                name = self._create_index_name(*args, **kwargs)
+            return self.quote_name(name)
+
+        return Statement(
+            sql_create_index,
+            table=Table(table, self.quote_name),
+            name=IndexName(table, columns, suffix, create_index_name),
+            using=using,
+            columns=self._index_columns(table, columns, col_suffixes, opclasses),
+            extra=tablespace_sql,
+            condition=(' WHERE ' + condition) if condition else '',
+        )
+
     def create_model(self, model):
         """
         Takes a model and creates a table for it in the database.
@@ -605,7 +772,9 @@ def create_model(self, model):
         # created afterwards, like geometry fields with some backends)
         for fields in model._meta.unique_together:
             columns = [model._meta.get_field(field).column for field in fields]
-            self.deferred_sql.append(self._create_unique_sql(model, columns))
+            condition = ' AND '.join(["[%s] IS NOT NULL" % col for col in columns])
+            self.deferred_sql.append(self._create_unique_sql(model, columns, condition=condition))
+
         # Make the table
         sql = self.sql_create_table % {
             "table": self.quote_name(model._meta.db_table),
@@ -620,6 +789,7 @@ def create_model(self, model):
         # Add any field index and index_together's (deferred as SQLite3 _remake_table needs it)
         self.deferred_sql.extend(self._model_indexes_sql(model))
+        self.deferred_sql = list(set(self.deferred_sql))

         # Make M2M tables
         for field in model._meta.local_many_to_many:
diff --git a/testapp/migrations/0001_initial.py b/testapp/migrations/0001_initial.py
index f0afdcb6..6d898a4d 100644
--- a/testapp/migrations/0001_initial.py
+++ b/testapp/migrations/0001_initial.py
@@ -14,6 +14,20 @@ class Migration(migrations.Migration):
     ]

     operations = [
+        migrations.CreateModel(
+            name='Author',
+            fields=[
+                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
+                ('name', models.CharField(max_length=100)),
+            ],
+        ),
+        migrations.CreateModel(
+            name='Editor',
+            fields=[
+                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
+                ('name', models.CharField(max_length=100)),
+            ],
+        ),
         migrations.CreateModel(
             name='Post',
             fields=[
@@ -21,6 +35,20 @@ class Migration(migrations.Migration):
                 ('title', models.CharField(max_length=255, verbose_name='title')),
             ],
         ),
+        migrations.AddField(
+            model_name='post',
+            name='alt_editor',
+            field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='testapp.Editor'),
+        ),
+        migrations.AddField(
+            model_name='post',
+            name='author',
+            field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='testapp.Author'),
+        ),
+        migrations.AlterUniqueTogether(
+            name='post',
+            unique_together={('author', 'title', 'alt_editor')},
+        ),
         migrations.CreateModel(
             name='Comment',
             fields=[
diff --git a/testapp/models.py b/testapp/models.py
index 7c31bc85..d177d706 100644
--- a/testapp/models.py
+++ b/testapp/models.py
@@ -4,8 +4,24 @@
 from django.utils import timezone


+class Author(models.Model):
+    name = models.CharField(max_length=100)
+
+
+class Editor(models.Model):
+    name = models.CharField(max_length=100)
+
+
 class Post(models.Model):
     title = models.CharField('title', max_length=255)
+    author = models.ForeignKey(Author, models.CASCADE)
+    # Optional secondary author
+    alt_editor = models.ForeignKey(Editor, models.SET_NULL, blank=True, null=True)
+
+    class Meta:
+        unique_together = (
+            ('author', 'title', 'alt_editor'),
+        )

     def __str__(self):
         return self.title
diff --git a/testapp/tests/test_expressions.py b/testapp/tests/test_expressions.py
index 8ec8aed6..18da063a 100644
--- a/testapp/tests/test_expressions.py
+++ b/testapp/tests/test_expressions.py
@@ -1,12 +1,14 @@
 from django.db.models.expressions import Exists, OuterRef, Subquery
-from django.test import TestCase
+from django.db.utils import IntegrityError
+from django.test import TestCase, skipUnlessDBFeature

-from ..models import Comment, Post
+from ..models import Author, Comment, Editor, Post


 class TestSubquery(TestCase):
     def setUp(self):
-        self.post = Post.objects.create(title="foo")
+        self.author = Author.objects.create(name="author")
+        self.post = Post.objects.create(title="foo", author=self.author)

     def test_with_count(self):
         newest = Comment.objects.filter(post=OuterRef('pk')).order_by('-created_at')
@@ -17,9 +19,23 @@ def test_with_count(self):

 class TestExists(TestCase):
     def setUp(self):
-        self.post = Post.objects.create(title="foo")
+        self.author = Author.objects.create(name="author")
+        self.post = Post.objects.create(title="foo", author=self.author)

     def test_with_count(self):
         Post.objects.annotate(
             post_exists=Exists(Post.objects.all())
         ).filter(post_exists=True).count()
+
+
+@skipUnlessDBFeature('supports_partially_nullable_unique_constraints')
+class TestPartiallyNullableUniqueTogether(TestCase):
+    def test_partially_nullable(self):
+        author = Author.objects.create(name="author")
+        Post.objects.create(title="foo", author=author)
+        Post.objects.create(title="foo", author=author)
+
+        editor = Editor.objects.create(name="editor")
+        Post.objects.create(title="foo", author=author, alt_editor=editor)
+        with self.assertRaises(IntegrityError):
+            Post.objects.create(title="foo", author=author, alt_editor=editor)
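
For readers skimming the patch, here is a minimal sketch (not part of the patch) of the filtered unique index that the deferred _create_unique_sql(..., condition=...) statement is expected to render for the Post model above on SQL Server versions that support partial indexes. The index name and exact DDL template are illustrative assumptions; the real name is produced by _create_index_name.

    # Sketch only: mirrors the IS NOT NULL condition built in
    # alter_unique_together()/create_model() and shows the approximate
    # T-SQL it should produce for testapp_post. The index name below is
    # a hypothetical placeholder, not the generated one.
    columns = ['author_id', 'title', 'alt_editor_id']
    condition = ' AND '.join(["[%s] IS NOT NULL" % col for col in columns])
    ddl = (
        "CREATE UNIQUE INDEX [post_author_title_alt_editor_uniq] "
        "ON [testapp_post] ([author_id], [title], [alt_editor_id])"
        " WHERE " + condition
    )
    print(ddl)
    # Rows where any listed column is NULL are excluded from the index, so
    # duplicate (author, title) pairs with alt_editor=NULL are allowed, while
    # fully populated duplicates still raise IntegrityError, which is exactly
    # what TestPartiallyNullableUniqueTogether exercises.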