Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Added a ManyToManyField(db_constraint=False) option, this allows not …

…creating constraints on the intermediary models.
  • Loading branch information...
commit bbbd698c7a4dd19e6394660bece7e6e907b0a824 1 parent 4cccb85
Alex Gaynor authored
View
17 django/db/models/fields/related.py
@@ -955,7 +955,9 @@ def __init__(self, to, field_name, related_name=None, limit_choices_to=None,
class ManyToManyRel(object):
def __init__(self, to, related_name=None, limit_choices_to=None,
- symmetrical=True, through=None):
+ symmetrical=True, through=None, db_constraint=True):
+ if through and not db_constraint:
+ raise ValueError("Can't supply a through model and db_constraint=False")
self.to = to
self.related_name = related_name
if limit_choices_to is None:
@@ -964,6 +966,7 @@ def __init__(self, to, related_name=None, limit_choices_to=None,
self.symmetrical = symmetrical
self.multiple = True
self.through = through
+ self.db_constraint = db_constraint
def is_hidden(self):
"Should the related object be hidden?"
@@ -1196,15 +1199,15 @@ def set_managed(field, model, cls):
return type(name, (models.Model,), {
'Meta': meta,
'__module__': klass.__module__,
- from_: models.ForeignKey(klass, related_name='%s+' % name, db_tablespace=field.db_tablespace),
- to: models.ForeignKey(to_model, related_name='%s+' % name, db_tablespace=field.db_tablespace)
+ from_: models.ForeignKey(klass, related_name='%s+' % name, db_tablespace=field.db_tablespace, db_constraint=field.rel.db_constraint),
+ to: models.ForeignKey(to_model, related_name='%s+' % name, db_tablespace=field.db_tablespace, db_constraint=field.rel.db_constraint)
})
class ManyToManyField(RelatedField, Field):
description = _("Many-to-many relationship")
- def __init__(self, to, **kwargs):
+ def __init__(self, to, db_constraint=True, **kwargs):
try:
assert not to._meta.abstract, "%s cannot define a relation with abstract class %s" % (self.__class__.__name__, to._meta.object_name)
except AttributeError: # to._meta doesn't exist, so it must be RECURSIVE_RELATIONSHIP_CONSTANT
@@ -1219,13 +1222,15 @@ def __init__(self, to, **kwargs):
related_name=kwargs.pop('related_name', None),
limit_choices_to=kwargs.pop('limit_choices_to', None),
symmetrical=kwargs.pop('symmetrical', to == RECURSIVE_RELATIONSHIP_CONSTANT),
- through=kwargs.pop('through', None))
+ through=kwargs.pop('through', None),
+ db_constraint=db_constraint,
+ )
self.db_table = kwargs.pop('db_table', None)
if kwargs['rel'].through is not None:
assert self.db_table is None, "Cannot specify a db_table if an intermediary model is used."
- Field.__init__(self, **kwargs)
+ super(ManyToManyField, self).__init__(**kwargs)
msg = _('Hold down "Control", or "Command" on a Mac, to select more than one.')
self.help_text = string_concat(self.help_text, ' ', msg)
View
14 docs/ref/models/fields.txt
@@ -1227,6 +1227,20 @@ that control how the relationship functions.
the table for the model defining the relationship and the name of the field
itself.
+.. attribute:: ManyToManyField.db_constraint
+
+ Controls whether or not constraints should be created in the database for
+ the foreign keys in the intermediary table. The default is ``True``, and
+ that's almost certainly what you want; setting this to ``False`` can be
+ very bad for data integrity. That said, here are some scenarios where you
+ might want to do this:
+
+ * You have legacy data that is not valid.
+ * You're sharding your database.
+
+ It is an error to pass both ``db_constraint`` and ``through``.
+
+
.. _ref-onetoone:
``OneToOneField``
View
4 docs/releases/1.6.txt
@@ -113,8 +113,8 @@ Minor features
* The ``MemcachedCache`` cache backend now uses the latest :mod:`pickle`
protocol available.
-* Added the :attr:`django.db.models.ForeignKey.db_constraint`
- option.
+* Added the :attr:`django.db.models.ForeignKey.db_constraint` and
+ :attr:`django.db.models.ManyToManyField.db_constraint` options.
* The jQuery library embedded in the admin has been upgraded to version 1.9.1.
View
5 tests/backends/models.py
@@ -90,7 +90,10 @@ def __str__(self):
@python_2_unicode_compatible
class Object(models.Model):
- pass
+ related_objects = models.ManyToManyField("self", db_constraint=False, symmetrical=False)
+
+ def __str__(self):
+ return str(self.id)
@python_2_unicode_compatible
View
47 tests/backends/tests.py
@@ -12,13 +12,12 @@
from django.db.backends.signals import connection_created
from django.db.backends.postgresql_psycopg2 import version as pg_version
from django.db.models import Sum, Avg, Variance, StdDev
-from django.db.utils import ConnectionHandler, DatabaseError
+from django.db.utils import ConnectionHandler
from django.test import (TestCase, skipUnlessDBFeature, skipIfDBFeature,
TransactionTestCase)
from django.test.utils import override_settings, str_prefix
-from django.utils import six
+from django.utils import six, unittest
from django.utils.six.moves import xrange
-from django.utils import unittest
from . import models
@@ -52,7 +51,7 @@ def test_dbms_session(self):
convert_unicode = backend.convert_unicode
cursor = connection.cursor()
cursor.callproc(convert_unicode('DBMS_SESSION.SET_IDENTIFIER'),
- [convert_unicode('_django_testing!'),])
+ [convert_unicode('_django_testing!')])
@unittest.skipUnless(connection.vendor == 'oracle',
"No need to check Oracle cursor semantics")
@@ -72,7 +71,7 @@ def test_long_string(self):
c = connection.cursor()
c.execute('CREATE TABLE ltext ("TEXT" NCLOB)')
long_str = ''.join([six.text_type(x) for x in xrange(4000)])
- c.execute('INSERT INTO ltext VALUES (%s)',[long_str])
+ c.execute('INSERT INTO ltext VALUES (%s)', [long_str])
c.execute('SELECT text FROM ltext')
row = c.fetchone()
self.assertEqual(long_str, row[0].read())
@@ -99,6 +98,7 @@ def test_order_of_nls_parameters(self):
c.execute(query)
self.assertEqual(c.fetchone()[0], 1)
+
class MySQLTests(TestCase):
@unittest.skipUnless(connection.vendor == 'mysql',
"Test valid only for MySQL")
@@ -117,7 +117,7 @@ def test_autoincrement(self):
found_reset = False
for sql in statements:
found_reset = found_reset or 'ALTER TABLE' in sql
- if connection.mysql_version < (5,0,13):
+ if connection.mysql_version < (5, 0, 13):
self.assertTrue(found_reset)
else:
self.assertFalse(found_reset)
@@ -182,6 +182,7 @@ def test_no_interpolation_on_sqlite(self):
self.assertEqual(connection.queries[-1]['sql'],
str_prefix("QUERY = %(_)s\"SELECT strftime('%%Y', 'now');\" - PARAMS = ()"))
+
class ParameterHandlingTest(TestCase):
def test_bad_parameter_count(self):
"An executemany call with too many/not enough parameters will raise an exception (Refs #12612)"
@@ -191,8 +192,9 @@ def test_bad_parameter_count(self):
connection.ops.quote_name('root'),
connection.ops.quote_name('square')
))
- self.assertRaises(Exception, cursor.executemany, query, [(1,2,3),])
- self.assertRaises(Exception, cursor.executemany, query, [(1,),])
+ self.assertRaises(Exception, cursor.executemany, query, [(1, 2, 3)])
+ self.assertRaises(Exception, cursor.executemany, query, [(1,)])
+
# Unfortunately, the following tests would be a good test to run on all
# backends, but it breaks MySQL hard. Until #13711 is fixed, it can't be run
@@ -240,6 +242,7 @@ def test_sequence_name_length_limits_flush(self):
for statement in connection.ops.sql_flush(no_style(), tables, sequences):
cursor.execute(statement)
+
class SequenceResetTest(TestCase):
def test_generic_relation(self):
"Sequence names are correct when resetting generic relations (Ref #13941)"
@@ -257,6 +260,7 @@ def test_generic_relation(self):
obj = models.Post.objects.create(name='New post', text='goodbye world')
self.assertTrue(obj.pk > 10)
+
class PostgresVersionTest(TestCase):
def assert_parses(self, version_string, version):
self.assertEqual(pg_version._parse_version(version_string), version)
@@ -291,6 +295,7 @@ def cursor(self):
conn = OlderConnectionMock()
self.assertEqual(pg_version.get_version(conn), 80300)
+
class PostgresNewConnectionTest(TestCase):
"""
#17062: PostgreSQL shouldn't roll back SET TIME ZONE, even if the first
@@ -338,17 +343,18 @@ class ConnectionCreatedSignalTest(TestCase):
@skipUnlessDBFeature('test_db_allows_multiple_connections')
def test_signal(self):
data = {}
+
def receiver(sender, connection, **kwargs):
data["connection"] = connection
connection_created.connect(receiver)
connection.close()
- cursor = connection.cursor()
+ connection.cursor()
self.assertTrue(data["connection"].connection is connection.connection)
connection_created.disconnect(receiver)
data.clear()
- cursor = connection.cursor()
+ connection.cursor()
self.assertTrue(data == {})
@@ -443,7 +449,7 @@ def test_unicode_password(self):
old_password = connection.settings_dict['PASSWORD']
connection.settings_dict['PASSWORD'] = "françois"
try:
- cursor = connection.cursor()
+ connection.cursor()
except DatabaseError:
# As password is probably wrong, a database exception is expected
pass
@@ -470,6 +476,7 @@ def test_duplicate_table_error(self):
with self.assertRaises(DatabaseError):
cursor.execute(query)
+
# We don't make these tests conditional because that means we would need to
# check and differentiate between:
# * MySQL+InnoDB, MySQL+MYISAM (something we currently can't do).
@@ -477,7 +484,6 @@ def test_duplicate_table_error(self):
# on or not, something that would be controlled by runtime support and user
# preference.
# verify if its type is django.database.db.IntegrityError.
-
class FkConstraintsTests(TransactionTestCase):
def setUp(self):
@@ -581,6 +587,7 @@ def test_default_connection_thread_local(self):
connections_dict = {}
connection.cursor()
connections_dict[id(connection)] = connection
+
def runner():
# Passing django.db.connection between threads doesn't work while
# connections[DEFAULT_DB_ALIAS] does.
@@ -602,7 +609,7 @@ def runner():
# Finish by closing the connections opened by the other threads (the
# connection opened in the main thread will automatically be closed on
# teardown).
- for conn in connections_dict.values() :
+ for conn in connections_dict.values():
if conn is not connection:
conn.close()
@@ -616,6 +623,7 @@ def test_connections_thread_local(self):
connections_dict = {}
for conn in connections.all():
connections_dict[id(conn)] = conn
+
def runner():
from django.db import connections
for conn in connections.all():
@@ -682,6 +690,7 @@ def test_closing_non_shared_connections(self):
"""
# First, without explicitly enabling the connection for sharing.
exceptions = set()
+
def runner1():
def runner2(other_thread_connection):
try:
@@ -699,6 +708,7 @@ def runner2(other_thread_connection):
# Then, with explicitly enabling the connection for sharing.
exceptions = set()
+
def runner1():
def runner2(other_thread_connection):
try:
@@ -746,3 +756,14 @@ def test_can_reference_non_existant(self):
with self.assertRaises(models.Object.DoesNotExist):
ref.obj
+
+ def test_many_to_many(self):
+ obj = models.Object.objects.create()
+ obj.related_objects.create()
+ self.assertEqual(models.Object.objects.count(), 2)
+ self.assertEqual(obj.related_objects.count(), 1)
+
+ intermediary_model = models.Object._meta.get_field_by_name("related_objects")[0].rel.through
+ intermediary_model.objects.create(from_object_id=obj.id, to_object_id=12345)
+ self.assertEqual(obj.related_objects.count(), 1)
+ self.assertEqual(intermediary_model.objects.count(), 2)
Please sign in to comment.
Something went wrong with that request. Please try again.