Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Fixed #17258 -- Moved `threading.local` from `DatabaseWrapper` to the…

… `django.db.connections` dictionary. This allows connections to be explicitly shared between multiple threads and is particularly useful for enabling the sharing of in-memory SQLite connections. Many thanks to Anssi Kääriäinen for the excellent suggestions and feedback, and to Alex Gaynor for the reviews. Refs #2879.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@17205 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
commit 503f261cdee962f4cc26dbfc3d4a70165757822e 1 parent 7a3561b
julien authored
View
16 django/db/__init__.py
@@ -22,9 +22,21 @@
# we manually create the dictionary from the settings, passing only the
# settings that the database backends care about. Note that TIME_ZONE is used
# by the PostgreSQL backends.
-# we load all these up for backwards compatibility, you should use
+# We load all these up for backwards compatibility, you should use
# connections['default'] instead.
-connection = connections[DEFAULT_DB_ALIAS]
+class DefaultConnectionProxy(object):
+ """
+ Proxy for accessing the default DatabaseWrapper object's attributes. If you
+ need to access the DatabaseWrapper object itself, use
+ connections[DEFAULT_DB_ALIAS] instead.
+ """
+ def __getattr__(self, item):
+ return getattr(connections[DEFAULT_DB_ALIAS], item)
+
+ def __setattr__(self, name, value):
+ return setattr(connections[DEFAULT_DB_ALIAS], name, value)
+
+connection = DefaultConnectionProxy()
backend = load_backend(connection.settings_dict['ENGINE'])
# Register an event that closes the database connection
View
33 django/db/backends/__init__.py
@@ -1,8 +1,9 @@
+from django.db.utils import DatabaseError
+
try:
import thread
except ImportError:
import dummy_thread as thread
-from threading import local
from contextlib import contextmanager
from django.conf import settings
@@ -13,14 +14,15 @@
from django.utils.timezone import is_aware
-class BaseDatabaseWrapper(local):
+class BaseDatabaseWrapper(object):
"""
Represents a database connection.
"""
ops = None
vendor = 'unknown'
- def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
+ def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS,
+ allow_thread_sharing=False):
# `settings_dict` should be a dictionary containing keys such as
# NAME, USER, etc. It's called `settings_dict` instead of `settings`
# to disambiguate it from Django settings modules.
@@ -34,6 +36,8 @@ def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
self.transaction_state = []
self.savepoint_state = 0
self._dirty = None
+ self._thread_ident = thread.get_ident()
+ self.allow_thread_sharing = allow_thread_sharing
def __eq__(self, other):
return self.alias == other.alias
@@ -116,6 +120,21 @@ def leave_transaction_management(self):
"pending COMMIT/ROLLBACK")
self._dirty = False
+ def validate_thread_sharing(self):
+ """
+ Validates that the connection isn't accessed by another thread than the
+ one which originally created it, unless the connection was explicitly
+ authorized to be shared between threads (via the `allow_thread_sharing`
+ property). Raises an exception if the validation fails.
+ """
+ if (not self.allow_thread_sharing
+ and self._thread_ident != thread.get_ident()):
+ raise DatabaseError("DatabaseWrapper objects created in a "
+ "thread can only be used in that same thread. The object"
+ "with alias '%s' was created in thread id %s and this is "
+ "thread id %s."
+ % (self.alias, self._thread_ident, thread.get_ident()))
+
def is_dirty(self):
"""
Returns True if the current transaction requires a commit for changes to
@@ -179,6 +198,7 @@ def commit_unless_managed(self):
"""
Commits changes if the system is not in managed transaction mode.
"""
+ self.validate_thread_sharing()
if not self.is_managed():
self._commit()
self.clean_savepoints()
@@ -189,6 +209,7 @@ def rollback_unless_managed(self):
"""
Rolls back changes if the system is not in managed transaction mode.
"""
+ self.validate_thread_sharing()
if not self.is_managed():
self._rollback()
else:
@@ -198,6 +219,7 @@ def commit(self):
"""
Does the commit itself and resets the dirty flag.
"""
+ self.validate_thread_sharing()
self._commit()
self.set_clean()
@@ -205,6 +227,7 @@ def rollback(self):
"""
This function does the rollback itself and resets the dirty flag.
"""
+ self.validate_thread_sharing()
self._rollback()
self.set_clean()
@@ -228,6 +251,7 @@ def savepoint_rollback(self, sid):
Rolls back the most recent savepoint (if one exists). Does nothing if
savepoints are not supported.
"""
+ self.validate_thread_sharing()
if self.savepoint_state:
self._savepoint_rollback(sid)
@@ -236,6 +260,7 @@ def savepoint_commit(self, sid):
Commits the most recent savepoint (if one exists). Does nothing if
savepoints are not supported.
"""
+ self.validate_thread_sharing()
if self.savepoint_state:
self._savepoint_commit(sid)
@@ -269,11 +294,13 @@ def check_constraints(self, table_names=None):
pass
def close(self):
+ self.validate_thread_sharing()
if self.connection is not None:
self.connection.close()
self.connection = None
def cursor(self):
+ self.validate_thread_sharing()
if (self.use_debug_cursor or
(self.use_debug_cursor is None and settings.DEBUG)):
cursor = self.make_debug_cursor(self._cursor())
View
17 django/db/backends/sqlite3/base.py
@@ -7,10 +7,10 @@
import datetime
import decimal
+import warnings
import re
import sys
-from django.conf import settings
from django.db import utils
from django.db.backends import *
from django.db.backends.signals import connection_created
@@ -241,6 +241,21 @@ def _cursor(self):
'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
}
kwargs.update(settings_dict['OPTIONS'])
+ # Always allow the underlying SQLite connection to be shareable
+ # between multiple threads. The safe-guarding will be handled at a
+ # higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
+ # property. This is necessary as the shareability is disabled by
+ # default in pysqlite and it cannot be changed once a connection is
+ # opened.
+ if 'check_same_thread' in kwargs and kwargs['check_same_thread']:
+ warnings.warn(
+ 'The `check_same_thread` option was provided and set to '
+ 'True. It will be overriden with False. Use the '
+ '`DatabaseWrapper.allow_thread_sharing` property instead '
+ 'for controlling thread shareability.',
+ RuntimeWarning
+ )
+ kwargs.update({'check_same_thread': False})
self.connection = Database.connect(**kwargs)
# Register extract, date_trunc, and regexp functions.
self.connection.create_function("django_extract", 2, _sqlite_extract)
View
12 django/db/utils.py
@@ -1,4 +1,5 @@
import os
+from threading import local
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
@@ -50,7 +51,7 @@ class ConnectionDoesNotExist(Exception):
class ConnectionHandler(object):
def __init__(self, databases):
self.databases = databases
- self._connections = {}
+ self._connections = local()
def ensure_defaults(self, alias):
"""
@@ -73,16 +74,19 @@ def ensure_defaults(self, alias):
conn.setdefault(setting, None)
def __getitem__(self, alias):
- if alias in self._connections:
- return self._connections[alias]
+ if hasattr(self._connections, alias):
+ return getattr(self._connections, alias)
self.ensure_defaults(alias)
db = self.databases[alias]
backend = load_backend(db['ENGINE'])
conn = backend.DatabaseWrapper(db, alias)
- self._connections[alias] = conn
+ setattr(self._connections, alias, conn)
return conn
+ def __setitem__(self, key, value):
+ setattr(self._connections, key, value)
+
def __iter__(self):
return iter(self.databases)
View
26 docs/releases/1.4.txt
@@ -673,6 +673,32 @@ datetimes are now stored without time zone information in SQLite. When
:setting:`USE_TZ` is ``False``, if you attempt to save an aware datetime
object, Django raises an exception.
+Database connection's thread-locality
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+``DatabaseWrapper`` objects (i.e. the connection objects referenced by
+``django.db.connection`` and ``django.db.connections["some_alias"]``) used to
+be thread-local. They are now global objects in order to be potentially shared
+between multiple threads. While the individual connection objects are now
+global, the ``django.db.connections`` dictionary referencing those objects is
+still thread-local. Therefore if you just use the ORM or
+``DatabaseWrapper.cursor()`` then the behavior is still the same as before.
+Note, however, that ``django.db.connection`` does not directly reference the
+default ``DatabaseWrapper`` object any more and is now a proxy to access that
+object's attributes. If you need to access the actual ``DatabaseWrapper``
+object, use ``django.db.connections[DEFAULT_DB_ALIAS]`` instead.
+
+As part of this change, all underlying SQLite connections are now enabled for
+potential thread-sharing (by passing the ``check_same_thread=False`` attribute
+to pysqlite). ``DatabaseWrapper`` however preserves the previous behavior by
+disabling thread-sharing by default, so this does not affect any existing
+code that purely relies on the ORM or on ``DatabaseWrapper.cursor()``.
+
+Finally, while it is now possible to pass connections between threads, Django
+does not make any effort to synchronize access to the underlying backend.
+Concurrency behavior is defined by the underlying backend implementation.
+Check their documentation for details.
+
`COMMENTS_BANNED_USERS_GROUP` setting
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
View
94 tests/regressiontests/backends/tests.py
@@ -3,6 +3,7 @@
from __future__ import with_statement, absolute_import
import datetime
+import threading
from django.conf import settings
from django.core.management.color import no_style
@@ -283,7 +284,7 @@ def receiver(sender, connection, **kwargs):
connection_created.connect(receiver)
connection.close()
cursor = connection.cursor()
- self.assertTrue(data["connection"] is connection)
+ self.assertTrue(data["connection"].connection is connection.connection)
connection_created.disconnect(receiver)
data.clear()
@@ -446,3 +447,94 @@ def test_check_constraints(self):
connection.check_constraints()
finally:
transaction.rollback()
+
+
+class ThreadTests(TestCase):
+
+ def test_default_connection_thread_local(self):
+ """
+ Ensure that the default connection (i.e. django.db.connection) is
+ different for each thread.
+ Refs #17258.
+ """
+ connections_set = set()
+ connection.cursor()
+ connections_set.add(connection.connection)
+ def runner():
+ from django.db import connection
+ connection.cursor()
+ connections_set.add(connection.connection)
+ for x in xrange(2):
+ t = threading.Thread(target=runner)
+ t.start()
+ t.join()
+ self.assertEquals(len(connections_set), 3)
+ # Finish by closing the connections opened by the other threads (the
+ # connection opened in the main thread will automatically be closed on
+ # teardown).
+ for conn in connections_set:
+ if conn != connection.connection:
+ conn.close()
+
+ def test_connections_thread_local(self):
+ """
+ Ensure that the connections are different for each thread.
+ Refs #17258.
+ """
+ connections_set = set()
+ for conn in connections.all():
+ connections_set.add(conn)
+ def runner():
+ from django.db import connections
+ for conn in connections.all():
+ connections_set.add(conn)
+ for x in xrange(2):
+ t = threading.Thread(target=runner)
+ t.start()
+ t.join()
+ self.assertEquals(len(connections_set), 6)
+ # Finish by closing the connections opened by the other threads (the
+ # connection opened in the main thread will automatically be closed on
+ # teardown).
+ for conn in connections_set:
+ if conn != connection:
+ conn.close()
+
+ def test_pass_connection_between_threads(self):
+ """
+ Ensure that a connection can be passed from one thread to the other.
+ Refs #17258.
+ """
+ models.Person.objects.create(first_name="John", last_name="Doe")
+
+ def do_thread():
+ def runner(main_thread_connection):
+ from django.db import connections
+ connections['default'] = main_thread_connection
+ try:
+ models.Person.objects.get(first_name="John", last_name="Doe")
+ except DatabaseError, e:
+ exceptions.append(e)
+ t = threading.Thread(target=runner, args=[connections['default']])
+ t.start()
+ t.join()
+
+ # Without touching allow_thread_sharing, which should be False by default.
+ exceptions = []
+ do_thread()
+ # Forbidden!
+ self.assertTrue(isinstance(exceptions[0], DatabaseError))
+
+ # If explicitly setting allow_thread_sharing to False
+ connections['default'].allow_thread_sharing = False
+ exceptions = []
+ do_thread()
+ # Forbidden!
+ self.assertTrue(isinstance(exceptions[0], DatabaseError))
+
+ # If explicitly setting allow_thread_sharing to True
+ connections['default'].allow_thread_sharing = True
+ exceptions = []
+ do_thread()
+ # All good
+ self.assertEqual(len(exceptions), 0)
Please sign in to comment.
Something went wrong with that request. Please try again.