Skip to content

Commit

Permalink
Add check constraint support - needed a few Field changes
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgodwin committed Sep 7, 2012
1 parent 375178f commit ca9c3cd
Show file tree
Hide file tree
Showing 12 changed files with 143 additions and 24 deletions.
3 changes: 3 additions & 0 deletions django/db/backends/__init__.py
Expand Up @@ -435,6 +435,9 @@ class BaseDatabaseFeatures(object):
# Does it support foreign keys? # Does it support foreign keys?
supports_foreign_keys = True supports_foreign_keys = True


# Does it support CHECK constraints?
supports_check_constraints = True

def __init__(self, connection): def __init__(self, connection):
self.connection = connection self.connection = connection


Expand Down
1 change: 1 addition & 0 deletions django/db/backends/creation.py
Expand Up @@ -18,6 +18,7 @@ class BaseDatabaseCreation(object):
destruction of test databases. destruction of test databases.
""" """
data_types = {} data_types = {}
data_type_check_constraints = {}


def __init__(self, connection): def __init__(self, connection):
self.connection = connection self.connection = connection
Expand Down
1 change: 1 addition & 0 deletions django/db/backends/mysql/base.py
Expand Up @@ -170,6 +170,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
requires_explicit_null_ordering_when_grouping = True requires_explicit_null_ordering_when_grouping = True
allows_primary_key_0 = False allows_primary_key_0 = False
uses_savepoints = True uses_savepoints = True
supports_check_constraints = False


def __init__(self, connection): def __init__(self, connection):
super(DatabaseFeatures, self).__init__(connection) super(DatabaseFeatures, self).__init__(connection)
Expand Down
9 changes: 7 additions & 2 deletions django/db/backends/postgresql_psycopg2/creation.py
Expand Up @@ -26,14 +26,19 @@ class DatabaseCreation(BaseDatabaseCreation):
'GenericIPAddressField': 'inet', 'GenericIPAddressField': 'inet',
'NullBooleanField': 'boolean', 'NullBooleanField': 'boolean',
'OneToOneField': 'integer', 'OneToOneField': 'integer',
'PositiveIntegerField': 'integer CHECK ("%(column)s" >= 0)', 'PositiveIntegerField': 'integer',
'PositiveSmallIntegerField': 'smallint CHECK ("%(column)s" >= 0)', 'PositiveSmallIntegerField': 'smallint',
'SlugField': 'varchar(%(max_length)s)', 'SlugField': 'varchar(%(max_length)s)',
'SmallIntegerField': 'smallint', 'SmallIntegerField': 'smallint',
'TextField': 'text', 'TextField': 'text',
'TimeField': 'time', 'TimeField': 'time',
} }


data_type_check_constraints = {
'PositiveIntegerField': '"%(column)s" >= 0',
'PositiveSmallIntegerField': '"%(column)s" >= 0',
}

def sql_table_creation_suffix(self): def sql_table_creation_suffix(self):
assert self.connection.settings_dict['TEST_COLLATION'] is None, "PostgreSQL does not support collation setting at database creation time." assert self.connection.settings_dict['TEST_COLLATION'] is None, "PostgreSQL does not support collation setting at database creation time."
if self.connection.settings_dict['TEST_CHARSET']: if self.connection.settings_dict['TEST_CHARSET']:
Expand Down
2 changes: 1 addition & 1 deletion django/db/backends/postgresql_psycopg2/introspection.py
Expand Up @@ -137,7 +137,7 @@ def get_constraints(self, cursor, table_name):
kc.table_schema = %s AND kc.table_schema = %s AND
kc.table_name = %s kc.table_name = %s
""", ["public", table_name]) """, ["public", table_name])
for constraint, column, kind in cursor.fetchall(): for constraint, column in cursor.fetchall():
# If we're the first column, make the record # If we're the first column, make the record
if constraint not in constraints: if constraint not in constraints:
constraints[constraint] = { constraints[constraint] = {
Expand Down
61 changes: 45 additions & 16 deletions django/db/backends/schema.py
Expand Up @@ -19,9 +19,6 @@ class BaseDatabaseSchemaEditor(object):
then the relevant actions, and then commit(). This is necessary to allow then the relevant actions, and then commit(). This is necessary to allow
things like circular foreign key references - FKs will only be created once things like circular foreign key references - FKs will only be created once
commit() is called. commit() is called.
TODO:
- Check constraints (PosIntField)
""" """


# Overrideable SQL templates # Overrideable SQL templates
Expand All @@ -41,7 +38,7 @@ class BaseDatabaseSchemaEditor(object):
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE" sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s CASCADE"
sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s" sql_rename_column = "ALTER TABLE %(table)s RENAME COLUMN %(old_column)s TO %(new_column)s"


sql_create_check = "ADD CONSTRAINT %(name)s CHECK (%(check)s)" sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)"
sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s" sql_delete_check = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"


sql_create_unique = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s UNIQUE (%(columns)s)" sql_create_unique = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s UNIQUE (%(columns)s)"
Expand Down Expand Up @@ -105,7 +102,8 @@ def column_sql(self, model, field, include_default=False):
The field must already have had set_attributes_from_name called. The field must already have had set_attributes_from_name called.
""" """
# Get the column's type and use that as the basis of the SQL # Get the column's type and use that as the basis of the SQL
sql = field.db_type(connection=self.connection) db_params = field.db_parameters(connection=self.connection)
sql = db_params['type']
params = [] params = []
# Check for fields that aren't actually columns (e.g. M2M) # Check for fields that aren't actually columns (e.g. M2M)
if sql is None: if sql is None:
Expand Down Expand Up @@ -169,6 +167,11 @@ def create_model(self, model, force=False):
definition, extra_params = self.column_sql(model, field) definition, extra_params = self.column_sql(model, field)
if definition is None: if definition is None:
continue continue
# Check constraints can go on the column SQL here
db_params = field.db_parameters(connection=self.connection)
if db_params['check']:
definition += " CHECK (%s)" % db_params['check']
# Add the SQL to our big list
column_sqls.append("%s %s" % ( column_sqls.append("%s %s" % (
self.quote_name(field.column), self.quote_name(field.column),
definition, definition,
Expand Down Expand Up @@ -295,6 +298,10 @@ def create_field(self, model, field, keep_default=False):
# It might not actually have a column behind it # It might not actually have a column behind it
if definition is None: if definition is None:
return return
# Check constraints can go on the column SQL here
db_params = field.db_parameters(connection=self.connection)
if db_params['check']:
definition += " CHECK (%s)" % db_params['check']
# Build the SQL and run it # Build the SQL and run it
sql = self.sql_create_column % { sql = self.sql_create_column % {
"table": self.quote_name(model._meta.db_table), "table": self.quote_name(model._meta.db_table),
Expand Down Expand Up @@ -358,8 +365,10 @@ def alter_field(self, model, old_field, new_field, strict=False):
If strict is true, raises errors if the old column does not match old_field precisely. If strict is true, raises errors if the old column does not match old_field precisely.
""" """
# Ensure this field is even column-based # Ensure this field is even column-based
old_type = old_field.db_type(connection=self.connection) old_db_params = old_field.db_parameters(connection=self.connection)
new_type = self._type_for_alter(new_field) old_type = old_db_params['type']
new_db_params = new_field.db_parameters(connection=self.connection)
new_type = new_db_params['type']
if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created): if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created):
return self._alter_many_to_many(model, old_field, new_field, strict) return self._alter_many_to_many(model, old_field, new_field, strict)
elif old_type is None or new_type is None: elif old_type is None or new_type is None:
Expand Down Expand Up @@ -417,6 +426,22 @@ def alter_field(self, model, old_field, new_field, strict=False):
"name": fk_name, "name": fk_name,
} }
) )
# Change check constraints?
if old_db_params['check'] != new_db_params['check'] and old_db_params['check']:
constraint_names = self._constraint_names(model, [old_field.column], check=True)
if strict and len(constraint_names) != 1:
raise ValueError("Found wrong number (%s) of check constraints for %s.%s" % (
len(constraint_names),
model._meta.db_table,
old_field.column,
))
for constraint_name in constraint_names:
self.execute(
self.sql_delete_check % {
"table": self.quote_name(model._meta.db_table),
"name": constraint_name,
}
)
# Have they renamed the column? # Have they renamed the column?
if old_field.column != new_field.column: if old_field.column != new_field.column:
self.execute(self.sql_rename_column % { self.execute(self.sql_rename_column % {
Expand Down Expand Up @@ -543,6 +568,16 @@ def alter_field(self, model, old_field, new_field, strict=False):
"to_column": self.quote_name(new_field.rel.get_related_field().column), "to_column": self.quote_name(new_field.rel.get_related_field().column),
} }
) )
# Does it have check constraints we need to add?
if old_db_params['check'] != new_db_params['check'] and new_db_params['check']:
self.execute(
self.sql_create_check % {
"table": self.quote_name(model._meta.db_table),
"name": self._create_index_name(model, [new_field.column], suffix="_check"),
"column": self.quote_name(new_field.column),
"check": new_db_params['check'],
}
)


def _alter_many_to_many(self, model, old_field, new_field, strict): def _alter_many_to_many(self, model, old_field, new_field, strict):
"Alters M2Ms to repoint their to= endpoints." "Alters M2Ms to repoint their to= endpoints."
Expand All @@ -555,14 +590,6 @@ def _alter_many_to_many(self, model, old_field, new_field, strict):
new_field.rel.through._meta.get_field_by_name(new_field.m2m_reverse_field_name())[0], new_field.rel.through._meta.get_field_by_name(new_field.m2m_reverse_field_name())[0],
) )


def _type_for_alter(self, field):
"""
Returns a field's type suitable for ALTER COLUMN.
By default it just returns field.db_type().
To be overriden by backend specific subclasses
"""
return field.db_type(connection=self.connection)

def _create_index_name(self, model, column_names, suffix=""): def _create_index_name(self, model, column_names, suffix=""):
"Generates a unique name for an index/unique constraint." "Generates a unique name for an index/unique constraint."
# If there is just one column in the index, use a default algorithm from Django # If there is just one column in the index, use a default algorithm from Django
Expand All @@ -581,7 +608,7 @@ def _create_index_name(self, model, column_names, suffix=""):
index_name = '%s%s' % (table_name[:(self.connection.features.max_index_name_length - len(part))], part) index_name = '%s%s' % (table_name[:(self.connection.features.max_index_name_length - len(part))], part)
return index_name return index_name


def _constraint_names(self, model, column_names=None, unique=None, primary_key=None, index=None, foreign_key=None): def _constraint_names(self, model, column_names=None, unique=None, primary_key=None, index=None, foreign_key=None, check=None):
"Returns all constraint names matching the columns and conditions" "Returns all constraint names matching the columns and conditions"
column_names = set(column_names) if column_names else None column_names = set(column_names) if column_names else None
constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table) constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
Expand All @@ -594,6 +621,8 @@ def _constraint_names(self, model, column_names=None, unique=None, primary_key=N
continue continue
if index is not None and infodict['index'] != index: if index is not None and infodict['index'] != index:
continue continue
if check is not None and infodict['check'] != check:
continue
if foreign_key is not None and not infodict['foreign_key']: if foreign_key is not None and not infodict['foreign_key']:
continue continue
result.append(name) result.append(name)
Expand Down
1 change: 1 addition & 0 deletions django/db/backends/sqlite3/base.py
Expand Up @@ -97,6 +97,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
has_bulk_insert = True has_bulk_insert = True
can_combine_inserts_with_and_without_auto_increment_pk = False can_combine_inserts_with_and_without_auto_increment_pk = False
supports_foreign_keys = False supports_foreign_keys = False
supports_check_constraints = False


@cached_property @cached_property
def supports_stddev(self): def supports_stddev(self):
Expand Down
6 changes: 4 additions & 2 deletions django/db/backends/sqlite3/schema.py
Expand Up @@ -99,8 +99,10 @@ def delete_field(self, model, field):


def alter_field(self, model, old_field, new_field, strict=False): def alter_field(self, model, old_field, new_field, strict=False):
# Ensure this field is even column-based # Ensure this field is even column-based
old_type = old_field.db_type(connection=self.connection) old_db_params = old_field.db_parameters(connection=self.connection)
new_type = self._type_for_alter(new_field) old_type = old_db_params['type']
new_db_params = new_field.db_parameters(connection=self.connection)
new_type = new_db_params['type']
if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created): if old_type is None and new_type is None and (old_field.rel.through and new_field.rel.through and old_field.rel.through._meta.auto_created and new_field.rel.through._meta.auto_created):
return self._alter_many_to_many(model, old_field, new_field, strict) return self._alter_many_to_many(model, old_field, new_field, strict)
elif old_type is None or new_type is None: elif old_type is None or new_type is None:
Expand Down
26 changes: 23 additions & 3 deletions django/db/models/fields/__init__.py
Expand Up @@ -232,12 +232,32 @@ def db_type(self, connection):
# mapped to one of the built-in Django field types. In this case, you # mapped to one of the built-in Django field types. In this case, you
# can implement db_type() instead of get_internal_type() to specify # can implement db_type() instead of get_internal_type() to specify
# exactly which wacky database column type you want to use. # exactly which wacky database column type you want to use.
params = self.db_parameters(connection)
if params['type']:
if params['check']:
return "%s CHECK (%s)" % (params['type'], params['check'])
else:
return params['type']
return None

def db_parameters(self, connection):
"""
Replacement for db_type, providing a range of different return
values (type, checks)
"""
data = DictWrapper(self.__dict__, connection.ops.quote_name, "qn_") data = DictWrapper(self.__dict__, connection.ops.quote_name, "qn_")
try: try:
return (connection.creation.data_types[self.get_internal_type()] type_string = connection.creation.data_types[self.get_internal_type()] % data
% data)
except KeyError: except KeyError:
return None type_string = None
try:
check_string = connection.creation.data_type_check_constraints[self.get_internal_type()] % data
except KeyError:
check_string = None
return {
"type": type_string,
"check": check_string,
}


@property @property
def unique(self): def unique(self):
Expand Down
6 changes: 6 additions & 0 deletions django/db/models/fields/related.py
Expand Up @@ -1050,6 +1050,9 @@ def db_type(self, connection):
return IntegerField().db_type(connection=connection) return IntegerField().db_type(connection=connection)
return rel_field.db_type(connection=connection) return rel_field.db_type(connection=connection)


def db_parameters(self, connection):
return {"type": self.db_type(connection), "check": []}

class OneToOneField(ForeignKey): class OneToOneField(ForeignKey):
""" """
A OneToOneField is essentially the same as a ForeignKey, with the exception A OneToOneField is essentially the same as a ForeignKey, with the exception
Expand Down Expand Up @@ -1292,3 +1295,6 @@ def db_type(self, connection):
# A ManyToManyField is not represented by a single column, # A ManyToManyField is not represented by a single column,
# so return None. # so return None.
return None return None

def db_parameters(self, connection):
return {"type": None, "check": None}
1 change: 1 addition & 0 deletions tests/modeltests/schema/models.py
Expand Up @@ -7,6 +7,7 @@


class Author(models.Model): class Author(models.Model):
name = models.CharField(max_length=255) name = models.CharField(max_length=255)
height = models.PositiveIntegerField(null=True, blank=True)


class Meta: class Meta:
managed = False managed = False
Expand Down
50 changes: 50 additions & 0 deletions tests/modeltests/schema/tests.py
Expand Up @@ -347,6 +347,56 @@ def test_m2m_repoint(self):
else: else:
self.fail("No FK constraint for tag_id found") self.fail("No FK constraint for tag_id found")


@skipUnless(connection.features.supports_check_constraints, "No check constraints")
def test_check_constraints(self):
"""
Tests creating/deleting CHECK constraints
"""
# Create the tables
editor = connection.schema_editor()
editor.start()
editor.create_model(Author)
editor.commit()
# Ensure the constraint exists
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
for name, details in constraints.items():
if details['columns'] == set(["height"]) and details['check']:
break
else:
self.fail("No check constraint for height found")
# Alter the column to remove it
new_field = IntegerField(null=True, blank=True)
new_field.set_attributes_from_name("height")
editor = connection.schema_editor()
editor.start()
editor.alter_field(
Author,
Author._meta.get_field_by_name("height")[0],
new_field,
strict = True,
)
editor.commit()
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
for name, details in constraints.items():
if details['columns'] == set(["height"]) and details['check']:
self.fail("Check constraint for height found")
# Alter the column to re-add it
editor = connection.schema_editor()
editor.start()
editor.alter_field(
Author,
new_field,
Author._meta.get_field_by_name("height")[0],
strict = True,
)
editor.commit()
constraints = connection.introspection.get_constraints(connection.cursor(), Author._meta.db_table)
for name, details in constraints.items():
if details['columns'] == set(["height"]) and details['check']:
break
else:
self.fail("No check constraint for height found")

def test_unique(self): def test_unique(self):
""" """
Tests removing and adding unique constraints to a single column. Tests removing and adding unique constraints to a single column.
Expand Down

0 comments on commit ca9c3cd

Please sign in to comment.