Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Add check constraint support - needed a few Field changes

  • Loading branch information...
commit ca9c3cd39fade827cced1b5198dd37bb80c208b0 1 parent 375178f
@andrewgodwin andrewgodwin authored
View
3  django/db/backends/__init__.py
@@ -435,6 +435,9 @@ class BaseDatabaseFeatures(object):
# Does it support foreign keys?
supports_foreign_keys = True
+ # Does it support CHECK constraints?
+ supports_check_constraints = True
+
def __init__(self, connection):
self.connection = connection
View
1  django/db/backends/creation.py
@@ -18,6 +18,7 @@ class BaseDatabaseCreation(object):
destruction of test databases.
"""
data_types = {}
+ data_type_check_constraints = {}
def __init__(self, connection):
self.connection = connection
View
1  django/db/backends/mysql/base.py
@@ -170,6 +170,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
requires_explicit_null_ordering_when_grouping = True
allows_primary_key_0 = False
uses_savepoints = True
+ supports_check_constraints = False
def __init__(self, connection):
super(DatabaseFeatures, self).__init__(connection)
View
9 django/db/backends/postgresql_psycopg2/creation.py
@@ -26,14 +26,19 @@ class DatabaseCreation(BaseDatabaseCreation):
'GenericIPAddressField': 'inet',
'NullBooleanField': 'boolean',
'OneToOneField': 'integer',
- 'PositiveIntegerField': 'integer CHECK ("%(column)s" >= 0)',
- 'PositiveSmallIntegerField': 'smallint CHECK ("%(column)s" >= 0)',
+ 'PositiveIntegerField': 'integer',
+ 'PositiveSmallIntegerField': 'smallint',
'SlugField': 'varchar(%(max_length)s)',
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
}
+ data_type_check_constraints = {
+ 'PositiveIntegerField': '"%(column)s" >= 0',
+ 'PositiveSmallIntegerField': '"%(column)s" >= 0',
+ }
+
def sql_table_creation_suffix(self):
assert self.connection.settings_dict['TEST_COLLATION'] is None, "PostgreSQL does not support collation setting at database creation time."
if self.connection.settings_dict['TEST_CHARSET']:
View
2  django/db/backends/postgresql_psycopg2/introspection.py
@@ -137,7 +137,7 @@ def get_constraints(self, cursor, table_name):
kc.table_schema = %s AND
kc.table_name = %s
""", ["public", table_name])
- for constraint, column, kind in cursor.fetchall():
+ for constraint, column in cursor.fetchall():
# If we're the first column, make the record
if constraint not in constraints:
constraints[constraint] = {
View
61 django/db/backends/schema.py
@@ -19,9 +19,6 @@ class BaseDatabaseSchemaEditor(object):
then the relevant actions, and then commit(). This is necessary to allow
things like circular foreign key references - FKs will only be created once
commit() is called.
-
- TODO:
- - Check constraints (PosIntField)
"""
# Overrideable SQL templates
@@ -41,7 +38,7 @@ class BaseDatabaseSchemaEditor(object):
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE"
sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"
- sql_create_check = "ADD CONSTRAINT %(name)s CHECK (%(check)s)"
+ sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)"
sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
sql_create_unique = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s UNIQUE (%(columns)s)"
@@ -105,7 +102,8 @@ def column_sql(self, model, field, include_default=False):
The field must already have had set_attributes_from_name called.
"""
# Get the column's type and use that as the basis of the SQL
- sql = field.db_type(connection=self.connection)
+ db_params = field.db_parameters(connection=self.connection)
+ sql = db_params['type']
params = []
# Check for fields that aren't actually columns (e.g. M2M)
if sql is None:
@@ -169,6 +167,11 @@ def create_model(self, model, force=False):
definition, extra_params = self.column_sql(model, field)
if definition is None:
continue
+ # Check constraints can go on the column SQL here
+ db_params = field.db_parameters(connection=self.connection)
+ if db_params['check']:
+ definition += " CHECK (%s)" % db_params['check']
+ # Add the SQL to our big list
column_sqls.append("%s %s" % (
self.quote_name(field.column),
definition,
@@ -295,6 +298,10 @@ def create_field(self, model, field, keep_default=False):
# It might not actually have a column behind it
if definition is None:
return
+ # Check constraints can go on the column SQL here
+ db_params = field.db_parameters(connection=self.connection)
+ if db_params['check']:
+ definition += " CHECK (%s)" % db_params['check']
# Build the SQL and run it
sql = self.sql_create_column % {
"table": self.quote_name(model._meta.db_table),
@@ -358,8 +365,10 @@ def alter_field(self, model, old_field, new_field, strict=False):
If strict is true, raises errors if the old column does not match old_field precisely.
"""
# Ensure this field is even column-based
- old_type = old_field.db_type(connection=self.connection)
- new_type = self._type_for_alter(new_field)
+ old_db_params = old_field.db_parameters(connection=self.connection)
+ old_type = old_db_params['type']
+ new_db_params = new_field.db_parameters(connection=self.connection)
+ new_type = new_db_params['type']
if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created):
return self._alter_many_to_many(model, old_field, new_field, strict)
elif old_type is None or new_type is None:
@@ -417,6 +426,22 @@ def alter_field(self, model, old_field, new_field, strict=False):
"name": fk_name,
}
)
+ # Change check constraints?
+ if old_db_params['check'] != new_db_params['check'] and old_db_params['check']:
+ constraint_names = self._constraint_names(model, [old_field.column], check=True)
+ if strict and len(constraint_names) != 1:
+ raise ValueError("Found wrong number (%s) of check constraints for %s.%s" % (
+ len(constraint_names),
+ model._meta.db_table,
+ old_field.column,
+ ))
+ for constraint_name in constraint_names:
+ self.execute(
+ self.sql_delete_check % {
+ "table": self.quote_name(model._meta.db_table),
+ "name": constraint_name,
+ }
+ )
# Have they renamed the column?
if old_field.column != new_field.column:
self.execute(self.sql_rename_column % {
@@ -543,6 +568,16 @@ def alter_field(self, model, old_field, new_field, strict=False):
"to_column": self.quote_name(new_field.rel.get_related_field().column),
}
)
+ # Does it have check constraints we need to add?
+ if old_db_params['check'] != new_db_params['check'] and new_db_params['check']:
+ self.execute(
+ self.sql_create_check % {
+ "table": self.quote_name(model._meta.db_table),
+ "name": self._create_index_name(model, [new_field.column], suffix="_check"),
+ "column": self.quote_name(new_field.column),
+ "check": new_db_params['check'],
+ }
+ )
def _alter_many_to_many(self, model, old_field, new_field, strict):
"Alters M2Ms to repoint their to= endpoints."
@@ -555,14 +590,6 @@ def _alter_many_to_many(self, model, old_field, new_field, strict):
new_field.rel.through._meta.get_field_by_name(new_field.m2m_reverse_field_name())[0],
)
- def _type_for_alter(self, field):
- """
- Returns a field's type suitable for ALTER COLUMN.
- By default it just returns field.db_type().
- To be overriden by backend specific subclasses
- """
- return field.db_type(connection=self.connection)
-
def _create_index_name(self, model, column_names, suffix=""):
"Generates a unique name for an index/unique constraint."
# If there is just one column in the index, use a default algorithm from Django
@@ -581,7 +608,7 @@ def _create_index_name(self, model, column_names, suffix=""):
index_name = '%s%s' % (table_name[:(self.connection.features.max_index_name_length - len(part))], part)
return index_name
- def _constraint_names(self, model, column_names=None, unique=None, primary_key=None, index=None, foreign_key=None):
+ def _constraint_names(self, model, column_names=None, unique=None, primary_key=None, index=None, foreign_key=None, check=None):
"Returns all constraint names matching the columns and conditions"
column_names = set(column_names) if column_names else None
constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
@@ -594,6 +621,8 @@ def _constraint_names(self, model, column_names=None, unique=None, primary_key=N
continue
if index is not None and infodict['index'] != index:
continue
+ if check is not None and infodict['check'] != check:
+ continue
if foreign_key is not None and not infodict['foreign_key']:
continue
result.append(name)
View
1  django/db/backends/sqlite3/base.py
@@ -97,6 +97,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
has_bulk_insert = True
can_combine_inserts_with_and_without_auto_increment_pk = False
supports_foreign_keys = False
+ supports_check_constraints = False
@cached_property
def supports_stddev(self):
View
6 django/db/backends/sqlite3/schema.py
@@ -99,8 +99,10 @@ def delete_field(self, model, field):
def alter_field(self, model, old_field, new_field, strict=False):
# Ensure this field is even column-based
- old_type = old_field.db_type(connection=self.connection)
- new_type = self._type_for_alter(new_field)
+ old_db_params = old_field.db_parameters(connection=self.connection)
+ old_type = old_db_params['type']
+ new_db_params = new_field.db_parameters(connection=self.connection)
+ new_type = new_db_params['type']
if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created):
return self._alter_many_to_many(model, old_field, new_field, strict)
elif old_type is None or new_type is None:
View
26 django/db/models/fields/__init__.py
@@ -232,12 +232,32 @@ def db_type(self, connection):
# mapped to one of the built-in Django field types. In this case, you
# can implement db_type() instead of get_internal_type() to specify
# exactly which wacky database column type you want to use.
+ params = self.db_parameters(connection)
+ if params['type']:
+ if params['check']:
+ return "%s CHECK (%s)" % (params['type'], params['check'])
+ else:
+ return params['type']
+ return None
+
+ def db_parameters(self, connection):
+ """
+ Replacement for db_type, providing a range of different return
+ values (type, checks)
+ """
data = DictWrapper(self.__dict__, connection.ops.quote_name, "qn_")
try:
- return (connection.creation.data_types[self.get_internal_type()]
- % data)
+ type_string = connection.creation.data_types[self.get_internal_type()] % data
except KeyError:
- return None
+ type_string = None
+ try:
+ check_string = connection.creation.data_type_check_constraints[self.get_internal_type()] % data
+ except KeyError:
+ check_string = None
+ return {
+ "type": type_string,
+ "check": check_string,
+ }
@property
def unique(self):
View
6 django/db/models/fields/related.py
@@ -1050,6 +1050,9 @@ def db_type(self, connection):
return IntegerField().db_type(connection=connection)
return rel_field.db_type(connection=connection)
+ def db_parameters(self, connection):
+ return {"type": self.db_type(connection), "check": []}
+
class OneToOneField(ForeignKey):
"""
A OneToOneField is essentially the same as a ForeignKey, with the exception
@@ -1292,3 +1295,6 @@ def db_type(self, connection):
# A ManyToManyField is not represented by a single column,
# so return None.
return None
+
+ def db_parameters(self, connection):
+ return {"type": None, "check": None}
View
1  tests/modeltests/schema/models.py
@@ -7,6 +7,7 @@
class Author(models.Model):
name = models.CharField(max_length=255)
+ height = models.PositiveIntegerField(null=True, blank=True)
class Meta:
managed = False
View
50 tests/modeltests/schema/tests.py
@@ -347,6 +347,56 @@ def test_m2m_repoint(self):
else:
self.fail("No FK constraint for tag_id found")
+ @skipUnless(connection.features.supports_check_constraints, "No check constraints")
+ def test_check_constraints(self):
+ """
+ Tests creating/deleting CHECK constraints
+ """
+ # Create the tables
+ editor = connection.schema_editor()
+ editor.start()
+ editor.create_model(Author)
+ editor.commit()
+ # Ensure the constraint exists
+ constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
+ for name, details in constraints.items():
+ if details['columns'] == set(["height"]) and details['check']:
+ break
+ else:
+ self.fail("No check constraint for height found")
+ # Alter the column to remove it
+ new_field = IntegerField(null=True, blank=True)
+ new_field.set_attributes_from_name("height")
+ editor = connection.schema_editor()
+ editor.start()
+ editor.alter_field(
+ Author,
+ Author._meta.get_field_by_name("height")[0],
+ new_field,
+ strict = True,
+ )
+ editor.commit()
+ constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
+ for name, details in constraints.items():
+ if details['columns'] == set(["height"]) and details['check']:
+ self.fail("Check constraint for height found")
+ # Alter the column to re-add it
+ editor = connection.schema_editor()
+ editor.start()
+ editor.alter_field(
+ Author,
+ new_field,
+ Author._meta.get_field_by_name("height")[0],
+ strict = True,
+ )
+ editor.commit()
+ constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
+ for name, details in constraints.items():
+ if details['columns'] == set(["height"]) and details['check']:
+ break
+ else:
+ self.fail("No check constraint for height found")
+
def test_unique(self):
"""
Tests removing and adding unique constraints to a single column.
Please sign in to comment.
Something went wrong with that request. Please try again.