Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Add M2M tests and some unique support

  • Loading branch information...
commit b139315f1c5e3eb05c76237c2824bdf03bd689b6 1 parent 4a2e80f
@andrewgodwin andrewgodwin authored
View
12 django/db/backends/__init__.py
@@ -427,6 +427,9 @@ class BaseDatabaseFeatures(object):
# Can we issue more than one ALTER COLUMN clause in an ALTER TABLE?
supports_combined_alters = False
+ # What's the maximum length for index names?
+ max_index_name_length = 63
+
def __init__(self, connection):
self.connection = connection
@@ -1056,6 +1059,15 @@ def get_indexes(self, cursor, table_name):
"""
raise NotImplementedError
+ def get_constraints(self, cursor, table_name):
+ """
+ Returns {'cnname': {'columns': set(columns), 'primary_key': bool, 'unique': bool}}
+
+ Both single- and multi-column constraints are introspected.
+ """
+ raise NotImplementedError
+
+
class BaseDatabaseClient(object):
"""
This class encapsulates all backend-specific methods for opening a
View
3  django/db/backends/creation.py
@@ -21,7 +21,8 @@ class BaseDatabaseCreation(object):
def __init__(self, connection):
self.connection = connection
- def _digest(self, *args):
+ @classmethod
+ def _digest(cls, *args):
"""
Generates a 32-bit digest of a set of arguments that can be used to
shorten identifying names.
View
32 django/db/backends/postgresql_psycopg2/introspection.py
@@ -88,3 +88,35 @@ def get_indexes(self, cursor, table_name):
continue
indexes[row[0]] = {'primary_key': row[3], 'unique': row[2]}
return indexes
+
+ def get_constraints(self, cursor, table_name):
+ """
+ Retrieves any constraints (unique, pk, check) across one or more columns.
+ Returns {'cnname': {'columns': set(columns), 'primary_key': bool, 'unique': bool}}
+ """
+ constraints = {}
+ # Loop over the constraint tables, collecting things as constraints
+ ifsc_tables = ["constraint_column_usage", "key_column_usage"]
+ for ifsc_table in ifsc_tables:
+ cursor.execute("""
+ SELECT kc.constraint_name, kc.column_name, c.constraint_type
+ FROM information_schema.%s AS kc
+ JOIN information_schema.table_constraints AS c ON
+ kc.table_schema = c.table_schema AND
+ kc.table_name = c.table_name AND
+ kc.constraint_name = c.constraint_name
+ WHERE
+ kc.table_schema = %%s AND
+ kc.table_name = %%s
+ """ % ifsc_table, ["public", table_name])
+ for constraint, column, kind in cursor.fetchall():
+ # If we're the first column, make the record
+ if constraint not in constraints:
+ constraints[constraint] = {
+ "columns": set(),
+ "primary_key": kind.lower() == "primary key",
+ "unique": kind.lower() in ["primary key", "unique"],
+ }
+ # Record the details
+ constraints[constraint]['columns'].add(column)
+ return constraints
View
82 django/db/backends/schema.py
@@ -4,6 +4,8 @@
from django.conf import settings
from django.db import transaction
from django.db.utils import load_backend
+from django.db.backends.creation import BaseDatabaseCreation
+from django.db.backends.util import truncate_name
from django.utils.log import getLogger
from django.db.models.fields.related import ManyToManyField
@@ -294,7 +296,23 @@ def alter_field(self, model, old_field, new_field):
old_field,
new_field,
))
- # First, have they renamed the column?
+ # Has unique been removed?
+ if old_field.unique and not new_field.unique:
+ # Find the unique constraint for this field
+ constraint_names = self._constraint_names(model, [old_field.column], unique=True)
+ if len(constraint_names) != 1:
+ raise ValueError("Found wrong number (%s) of constraints for %s.%s" % (
+ len(constraint_names),
+ model._meta.db_table,
+ old_field.column,
+ ))
+ self.execute(
+ self.sql_delete_unique % {
+ "table": self.quote_name(model._meta.db_table),
+ "name": constraint_names[0],
+ },
+ )
+ # Have they renamed the column?
if old_field.column != new_field.column:
self.execute(self.sql_rename_column % {
"table": self.quote_name(model._meta.db_table),
@@ -347,16 +365,58 @@ def alter_field(self, model, old_field, new_field):
},
[],
))
- # Combine actions together if we can (e.g. postgres)
- if self.connection.features.supports_combined_alters:
- sql, params = tuple(zip(*actions))
- actions = [(", ".join(sql), params)]
- # Apply those actions
- for sql, params in actions:
+ if actions:
+ # Combine actions together if we can (e.g. postgres)
+ if self.connection.features.supports_combined_alters:
+ sql, params = tuple(zip(*actions))
+ actions = [(", ".join(sql), params)]
+ # Apply those actions
+ for sql, params in actions:
+ self.execute(
+ self.sql_alter_column % {
+ "table": self.quote_name(model._meta.db_table),
+ "changes": sql,
+ },
+ params,
+ )
+ # Added a unique?
+ if not old_field.unique and new_field.unique:
self.execute(
- self.sql_alter_column % {
+ self.sql_create_unique % {
"table": self.quote_name(model._meta.db_table),
- "changes": sql,
- },
- params,
+ "name": self._create_index_name(model, [new_field.column], suffix="_uniq"),
+ "columns": self.quote_name(new_field.column),
+ }
)
+
+ def _create_index_name(self, model, column_names, suffix=""):
+ "Generates a unique name for an index/unique constraint."
+ # If there is just one column in the index, use a default algorithm from Django
+ if len(column_names) == 1 and not suffix:
+ return truncate_name(
+ '%s_%s' % (model._meta.db_table, BaseDatabaseCreation._digest(column_names[0])),
+ self.connection.ops.max_name_length()
+ )
+ # Else generate the name for the index by South
+ table_name = model._meta.db_table.replace('"', '').replace('.', '_')
+ index_unique_name = '_%x' % abs(hash((table_name, ','.join(column_names))))
+ # If the index name is too long, truncate it
+ index_name = ('%s_%s%s%s' % (table_name, column_names[0], index_unique_name, suffix)).replace('"', '').replace('.', '_')
+ if len(index_name) > self.connection.features.max_index_name_length:
+ part = ('_%s%s%s' % (column_names[0], index_unique_name, suffix))
+ index_name = '%s%s' % (table_name[:(self.connection.features.max_index_name_length - len(part))], part)
+ return index_name
+
+ def _constraint_names(self, model, column_names, unique=None, primary_key=None):
+ "Returns all constraint names matching the columns and conditions"
+ column_names = set(column_names)
+ constraints = self.connection.introspection.get_constraints(self.connection.cursor(), model._meta.db_table)
+ result = []
+ for name, infodict in constraints.items():
+ if column_names == infodict['columns']:
+ if unique is not None and infodict['unique'] != unique:
+ continue
+ if primary_key is not None and infodict['primary_key'] != unique:
+ continue
+ result.append(name)
+ return result
View
13 tests/modeltests/schema/models.py
@@ -12,10 +12,23 @@ class Meta:
managed = False
+class AuthorWithM2M(models.Model):
+ name = models.CharField(max_length=255)
+
+ class Meta:
+ managed = False
+
+
class Book(models.Model):
author = models.ForeignKey(Author)
title = models.CharField(max_length=100)
pub_date = models.DateTimeField()
+ #tags = models.ManyToManyField("Tag", related_name="books")
class Meta:
managed = False
+
+
+class Tag(models.Model):
+ title = models.CharField(max_length=255)
+ slug = models.SlugField(unique=True)
View
132 tests/modeltests/schema/tests.py
@@ -3,9 +3,10 @@
import datetime
from django.test import TestCase
from django.db import connection, DatabaseError, IntegrityError
-from django.db.models.fields import IntegerField, TextField
+from django.db.models.fields import IntegerField, TextField, CharField, SlugField
+from django.db.models.fields.related import ManyToManyField
from django.db.models.loading import cache
-from .models import Author, Book
+from .models import Author, Book, AuthorWithM2M, Tag
class SchemaTests(TestCase):
@@ -17,7 +18,7 @@ class SchemaTests(TestCase):
as the code it is testing.
"""
- models = [Author, Book]
+ models = [Author, Book, AuthorWithM2M, Tag]
# Utility functions
@@ -39,6 +40,17 @@ def tearDown(self):
# Delete any tables made for our models
cursor = connection.cursor()
for model in self.models:
+ # Remove any M2M tables first
+ for field in model._meta.local_many_to_many:
+ try:
+ cursor.execute("DROP TABLE %s CASCADE" % (
+ connection.ops.quote_name(field.rel.through._meta.db_table),
+ ))
+ except DatabaseError:
+ connection.rollback()
+ else:
+ connection.commit()
+ # Then remove the main tables
try:
cursor.execute("DROP TABLE %s CASCADE" % (
connection.ops.quote_name(model._meta.db_table),
@@ -172,3 +184,117 @@ def test_alter(self):
columns = self.column_classes(Author)
self.assertEqual(columns['name'][0], "TextField")
self.assertEqual(columns['name'][1][6], True)
+
+ def test_rename(self):
+ """
+ Tests simple altering of fields
+ """
+ # Create the table
+ editor = connection.schema_editor()
+ editor.start()
+ editor.create_model(Author)
+ editor.commit()
+ # Ensure the field is right to begin with
+ columns = self.column_classes(Author)
+ self.assertEqual(columns['name'][0], "CharField")
+ self.assertEqual(columns['name'][1][3], 255)
+ self.assertNotIn("display_name", columns)
+ # Alter the name field's name
+ new_field = CharField(max_length=254)
+ new_field.set_attributes_from_name("display_name")
+ editor = connection.schema_editor()
+ editor.start()
+ editor.alter_field(
+ Author,
+ Author._meta.get_field_by_name("name")[0],
+ new_field,
+ )
+ editor.commit()
+ # Ensure the field is right afterwards
+ columns = self.column_classes(Author)
+ self.assertEqual(columns['display_name'][0], "CharField")
+ self.assertEqual(columns['display_name'][1][3], 254)
+ self.assertNotIn("name", columns)
+
+ def test_m2m(self):
+ """
+ Tests adding/removing M2M fields on models
+ """
+ # Create the tables
+ editor = connection.schema_editor()
+ editor.start()
+ editor.create_model(AuthorWithM2M)
+ editor.create_model(Tag)
+ editor.commit()
+ # Create an M2M field
+ new_field = ManyToManyField("schema.Tag", related_name="authors")
+ new_field.contribute_to_class(AuthorWithM2M, "tags")
+ # Ensure there's no m2m table there
+ self.assertRaises(DatabaseError, self.column_classes, new_field.rel.through)
+ connection.rollback()
+ # Add the field
+ editor = connection.schema_editor()
+ editor.start()
+ editor.create_field(
+ Author,
+ new_field,
+ )
+ editor.commit()
+ # Ensure there is now an m2m table there
+ columns = self.column_classes(new_field.rel.through)
+ self.assertEqual(columns['tag_id'][0], "IntegerField")
+ # Remove the M2M table again
+ editor = connection.schema_editor()
+ editor.start()
+ editor.delete_field(
+ Author,
+ new_field,
+ )
+ editor.commit()
+ # Ensure there's no m2m table there
+ self.assertRaises(DatabaseError, self.column_classes, new_field.rel.through)
+ connection.rollback()
+
+ def test_unique(self):
+ """
+ Tests removing and adding unique constraints to a single column.
+ """
+ # Create the table
+ editor = connection.schema_editor()
+ editor.start()
+ editor.create_model(Tag)
+ editor.commit()
+ # Ensure the field is unique to begin with
+ Tag.objects.create(title="foo", slug="foo")
+ self.assertRaises(IntegrityError, Tag.objects.create, title="bar", slug="foo")
+ connection.rollback()
+ # Alter the slug field to be non-unique
+ new_field = SlugField(unique=False)
+ new_field.set_attributes_from_name("slug")
+ editor = connection.schema_editor()
+ editor.start()
+ editor.alter_field(
+ Tag,
+ Tag._meta.get_field_by_name("slug")[0],
+ new_field,
+ )
+ editor.commit()
+ # Ensure the field is no longer unique
+ Tag.objects.create(title="foo", slug="foo")
+ Tag.objects.create(title="bar", slug="foo")
+ connection.rollback()
+ # Alter the slug field to be non-unique
+ new_new_field = SlugField(unique=True)
+ new_new_field.set_attributes_from_name("slug")
+ editor = connection.schema_editor()
+ editor.start()
+ editor.alter_field(
+ Tag,
+ new_field,
+ new_new_field,
+ )
+ editor.commit()
+ # Ensure the field is unique again
+ Tag.objects.create(title="foo", slug="foo")
+ self.assertRaises(IntegrityError, Tag.objects.create, title="bar", slug="foo")
+ connection.rollback()

0 comments on commit b139315

Please sign in to comment.
Something went wrong with that request. Please try again.