Navigation Menu

Skip to content

Commit

Permalink
Refactor all uses of thread locals to be more consistant and sane.
Browse files Browse the repository at this point in the history
git-svn-id: http://code.djangoproject.com/svn/django/trunk@15232 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information
alex committed Jan 17, 2011
1 parent 964cf1b commit fcbf881
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 100 deletions.
24 changes: 10 additions & 14 deletions django/core/urlresolvers.py
Expand Up @@ -8,6 +8,7 @@
""" """


import re import re
from threading import local


from django.http import Http404 from django.http import Http404
from django.conf import settings from django.conf import settings
Expand All @@ -17,18 +18,18 @@
from django.utils.functional import memoize from django.utils.functional import memoize
from django.utils.importlib import import_module from django.utils.importlib import import_module
from django.utils.regex_helper import normalize from django.utils.regex_helper import normalize
from django.utils.thread_support import currentThread


_resolver_cache = {} # Maps URLconf modules to RegexURLResolver instances. _resolver_cache = {} # Maps URLconf modules to RegexURLResolver instances.
_callable_cache = {} # Maps view and url pattern names to their view functions. _callable_cache = {} # Maps view and url pattern names to their view functions.


# SCRIPT_NAME prefixes for each thread are stored here. If there's no entry for # SCRIPT_NAME prefixes for each thread are stored here. If there's no entry for
# the current thread (which is the only one we ever access), it is assumed to # the current thread (which is the only one we ever access), it is assumed to
# be empty. # be empty.
_prefixes = {} _prefixes = local()


# Overridden URLconfs for each thread are stored here. # Overridden URLconfs for each thread are stored here.
_urlconfs = {} _urlconfs = local()



class ResolverMatch(object): class ResolverMatch(object):
def __init__(self, func, args, kwargs, url_name=None, app_name=None, namespaces=None): def __init__(self, func, args, kwargs, url_name=None, app_name=None, namespaces=None):
Expand Down Expand Up @@ -401,35 +402,30 @@ def set_script_prefix(prefix):
""" """
if not prefix.endswith('/'): if not prefix.endswith('/'):
prefix += '/' prefix += '/'
_prefixes[currentThread()] = prefix _prefixes.value = prefix


def get_script_prefix(): def get_script_prefix():
""" """
Returns the currently active script prefix. Useful for client code that Returns the currently active script prefix. Useful for client code that
wishes to construct their own URLs manually (although accessing the request wishes to construct their own URLs manually (although accessing the request
instance is normally going to be a lot cleaner). instance is normally going to be a lot cleaner).
""" """
return _prefixes.get(currentThread(), u'/') return getattr(_prefixes, "value", u'/')


def set_urlconf(urlconf_name): def set_urlconf(urlconf_name):
""" """
Sets the URLconf for the current thread (overriding the default one in Sets the URLconf for the current thread (overriding the default one in
settings). Set to None to revert back to the default. settings). Set to None to revert back to the default.
""" """
thread = currentThread()
if urlconf_name: if urlconf_name:
_urlconfs[thread] = urlconf_name _urlconfs.value = urlconf_name
else: else:
# faster than wrapping in a try/except if hasattr(_urlconfs, "value"):
if thread in _urlconfs: del _urlconfs.value
del _urlconfs[thread]


def get_urlconf(default=None): def get_urlconf(default=None):
""" """
Returns the root URLconf to use for the current thread if it has been Returns the root URLconf to use for the current thread if it has been
changed from the default one. changed from the default one.
""" """
thread = currentThread() return getattr(_urlconfs, "value", default)
if thread in _urlconfs:
return _urlconfs[thread]
return default
5 changes: 5 additions & 0 deletions django/db/backends/__init__.py
Expand Up @@ -25,6 +25,11 @@ def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
self.alias = alias self.alias = alias
self.use_debug_cursor = None self.use_debug_cursor = None


# Transaction related attributes
self.transaction_state = []
self.savepoint_state = 0
self.dirty = None

def __eq__(self, other): def __eq__(self, other):
return self.alias == other.alias return self.alias == other.alias


Expand Down
103 changes: 47 additions & 56 deletions django/db/transaction.py
Expand Up @@ -25,26 +25,14 @@
from django.conf import settings from django.conf import settings
from django.db import connections, DEFAULT_DB_ALIAS from django.db import connections, DEFAULT_DB_ALIAS



class TransactionManagementError(Exception): class TransactionManagementError(Exception):
""" """
This exception is thrown when something bad happens with transaction This exception is thrown when something bad happens with transaction
management. management.
""" """
pass pass


# The states are dictionaries of dictionaries of lists. The key to the outer
# dict is the current thread, and the key to the inner dictionary is the
# connection alias and the list is handled as a stack of values.
state = {}
savepoint_state = {}

# The dirty flag is set by *_unless_managed functions to denote that the
# code under transaction management has changed things to require a
# database commit.
# This is a dictionary mapping thread to a dictionary mapping connection
# alias to a boolean.
dirty = {}

def enter_transaction_management(managed=True, using=None): def enter_transaction_management(managed=True, using=None):
""" """
Enters transaction management for a running thread. It must be balanced with Enters transaction management for a running thread. It must be balanced with
Expand All @@ -58,15 +46,14 @@ def enter_transaction_management(managed=True, using=None):
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
connection = connections[using] connection = connections[using]
thread_ident = thread.get_ident()
if thread_ident in state and state[thread_ident].get(using): if connection.transaction_state:
state[thread_ident][using].append(state[thread_ident][using][-1]) connection.transaction_state.append(connection.transaction_state[-1])
else: else:
state.setdefault(thread_ident, {}) connection.transaction_state.append(settings.TRANSACTIONS_MANAGED)
state[thread_ident][using] = [settings.TRANSACTIONS_MANAGED]
if thread_ident not in dirty or using not in dirty[thread_ident]: if connection.dirty is None:
dirty.setdefault(thread_ident, {}) connection.dirty = False
dirty[thread_ident][using] = False
connection._enter_transaction_management(managed) connection._enter_transaction_management(managed)


def leave_transaction_management(using=None): def leave_transaction_management(using=None):
Expand All @@ -78,16 +65,18 @@ def leave_transaction_management(using=None):
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
connection = connections[using] connection = connections[using]

connection._leave_transaction_management(is_managed(using=using)) connection._leave_transaction_management(is_managed(using=using))
thread_ident = thread.get_ident() if connection.transaction_state:
if thread_ident in state and state[thread_ident].get(using): del connection.transaction_state[-1]
del state[thread_ident][using][-1]
else: else:
raise TransactionManagementError("This code isn't under transaction management") raise TransactionManagementError("This code isn't under transaction "
if dirty.get(thread_ident, {}).get(using, False): "management")
if connection.dirty:
rollback(using=using) rollback(using=using)
raise TransactionManagementError("Transaction managed block ended with pending COMMIT/ROLLBACK") raise TransactionManagementError("Transaction managed block ended with "
dirty[thread_ident][using] = False "pending COMMIT/ROLLBACK")
connection.dirty = False


def is_dirty(using=None): def is_dirty(using=None):
""" """
Expand All @@ -96,7 +85,9 @@ def is_dirty(using=None):
""" """
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
return dirty.get(thread.get_ident(), {}).get(using, False) connection = connections[using]

return connection.dirty


def set_dirty(using=None): def set_dirty(using=None):
""" """
Expand All @@ -106,11 +97,13 @@ def set_dirty(using=None):
""" """
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
thread_ident = thread.get_ident() connection = connections[using]
if thread_ident in dirty and using in dirty[thread_ident]:
dirty[thread_ident][using] = True if connection.dirty is not None:
connection.dirty = True
else: else:
raise TransactionManagementError("This code isn't under transaction management") raise TransactionManagementError("This code isn't under transaction "
"management")


def set_clean(using=None): def set_clean(using=None):
""" """
Expand All @@ -120,30 +113,29 @@ def set_clean(using=None):
""" """
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
thread_ident = thread.get_ident() connection = connections[using]
if thread_ident in dirty and using in dirty[thread_ident]:
dirty[thread_ident][using] = False if connection.dirty is not None:
connection.dirty = False
else: else:
raise TransactionManagementError("This code isn't under transaction management") raise TransactionManagementError("This code isn't under transaction management")
clean_savepoints(using=using) clean_savepoints(using=using)


def clean_savepoints(using=None): def clean_savepoints(using=None):
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
thread_ident = thread.get_ident() connection = connections[using]
if thread_ident in savepoint_state and using in savepoint_state[thread_ident]: connection.savepoint_state = 0
del savepoint_state[thread_ident][using]


def is_managed(using=None): def is_managed(using=None):
""" """
Checks whether the transaction manager is in manual or in auto state. Checks whether the transaction manager is in manual or in auto state.
""" """
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
thread_ident = thread.get_ident() connection = connections[using]
if thread_ident in state and using in state[thread_ident]: if connection.transaction_state:
if state[thread_ident][using]: return connection.transaction_state[-1]
return state[thread_ident][using][-1]
return settings.TRANSACTIONS_MANAGED return settings.TRANSACTIONS_MANAGED


def managed(flag=True, using=None): def managed(flag=True, using=None):
Expand All @@ -156,15 +148,16 @@ def managed(flag=True, using=None):
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
connection = connections[using] connection = connections[using]
thread_ident = thread.get_ident()
top = state.get(thread_ident, {}).get(using, None) top = connection.transaction_state
if top: if top:
top[-1] = flag top[-1] = flag
if not flag and is_dirty(using=using): if not flag and is_dirty(using=using):
connection._commit() connection._commit()
set_clean(using=using) set_clean(using=using)
else: else:
raise TransactionManagementError("This code isn't under transaction management") raise TransactionManagementError("This code isn't under transaction "
"management")


def commit_unless_managed(using=None): def commit_unless_managed(using=None):
""" """
Expand Down Expand Up @@ -221,13 +214,11 @@ def savepoint(using=None):
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
connection = connections[using] connection = connections[using]
thread_ident = thread.get_ident() thread_ident = thread.get_ident()
if thread_ident in savepoint_state and using in savepoint_state[thread_ident]:
savepoint_state[thread_ident][using].append(None) connection.savepoint_state += 1
else:
savepoint_state.setdefault(thread_ident, {})
savepoint_state[thread_ident][using] = [None]
tid = str(thread_ident).replace('-', '') tid = str(thread_ident).replace('-', '')
sid = "s%s_x%d" % (tid, len(savepoint_state[thread_ident][using])) sid = "s%s_x%d" % (tid, connection.savepoint_state)
connection._savepoint(sid) connection._savepoint(sid)
return sid return sid


Expand All @@ -239,8 +230,8 @@ def savepoint_rollback(sid, using=None):
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
connection = connections[using] connection = connections[using]
thread_ident = thread.get_ident()
if thread_ident in savepoint_state and using in savepoint_state[thread_ident]: if connection.savepoint_state:
connection._savepoint_rollback(sid) connection._savepoint_rollback(sid)


def savepoint_commit(sid, using=None): def savepoint_commit(sid, using=None):
Expand All @@ -251,8 +242,8 @@ def savepoint_commit(sid, using=None):
if using is None: if using is None:
using = DEFAULT_DB_ALIAS using = DEFAULT_DB_ALIAS
connection = connections[using] connection = connections[using]
thread_ident = thread.get_ident()
if thread_ident in savepoint_state and using in savepoint_state[thread_ident]: if connection.savepoint_state:
connection._savepoint_commit(sid) connection._savepoint_commit(sid)


############## ##############
Expand Down
12 changes: 0 additions & 12 deletions django/utils/thread_support.py

This file was deleted.

31 changes: 16 additions & 15 deletions django/utils/translation/trans_real.py
Expand Up @@ -7,15 +7,16 @@
import warnings import warnings
import gettext as gettext_module import gettext as gettext_module
from cStringIO import StringIO from cStringIO import StringIO
from threading import local


from django.utils.importlib import import_module from django.utils.importlib import import_module
from django.utils.safestring import mark_safe, SafeData from django.utils.safestring import mark_safe, SafeData
from django.utils.thread_support import currentThread


# Translations are cached in a dictionary for every language+app tuple. # Translations are cached in a dictionary for every language+app tuple.
# The active translations are stored by threadid to make them thread local. # The active translations are stored by threadid to make them thread local.
_translations = {} _translations = {}
_active = {} _active = local()


# The default translation is based on the settings file. # The default translation is based on the settings file.
_default = None _default = None
Expand Down Expand Up @@ -197,28 +198,27 @@ def activate(language):
"Please use the 'nb' translation instead.", "Please use the 'nb' translation instead.",
DeprecationWarning DeprecationWarning
) )
_active[currentThread()] = translation(language) _active.value = translation(language)


def deactivate(): def deactivate():
""" """
Deinstalls the currently active translation object so that further _ calls Deinstalls the currently active translation object so that further _ calls
will resolve against the default translation object, again. will resolve against the default translation object, again.
""" """
global _active if hasattr(_active, "value"):
if currentThread() in _active: del _active.value
del _active[currentThread()]


def deactivate_all(): def deactivate_all():
""" """
Makes the active translation object a NullTranslations() instance. This is Makes the active translation object a NullTranslations() instance. This is
useful when we want delayed translations to appear as the original string useful when we want delayed translations to appear as the original string
for some reason. for some reason.
""" """
_active[currentThread()] = gettext_module.NullTranslations() _active.value = gettext_module.NullTranslations()


def get_language(): def get_language():
"""Returns the currently selected language.""" """Returns the currently selected language."""
t = _active.get(currentThread(), None) t = getattr(_active, "value", None)
if t is not None: if t is not None:
try: try:
return t.to_language() return t.to_language()
Expand Down Expand Up @@ -246,8 +246,9 @@ def catalog():
This can be used if you need to modify the catalog or want to access the This can be used if you need to modify the catalog or want to access the
whole message catalog instead of just translating one string. whole message catalog instead of just translating one string.
""" """
global _default, _active global _default
t = _active.get(currentThread(), None)
t = getattr(_active, "value", None)
if t is not None: if t is not None:
return t return t
if _default is None: if _default is None:
Expand All @@ -262,9 +263,10 @@ def do_translate(message, translation_function):
translation object to use. If no current translation is activated, the translation object to use. If no current translation is activated, the
message will be run through the default translation object. message will be run through the default translation object.
""" """
global _default

eol_message = message.replace('\r\n', '\n').replace('\r', '\n') eol_message = message.replace('\r\n', '\n').replace('\r', '\n')
global _default, _active t = getattr(_active, "value", None)
t = _active.get(currentThread(), None)
if t is not None: if t is not None:
result = getattr(t, translation_function)(eol_message) result = getattr(t, translation_function)(eol_message)
else: else:
Expand Down Expand Up @@ -300,9 +302,9 @@ def gettext_noop(message):
return message return message


def do_ntranslate(singular, plural, number, translation_function): def do_ntranslate(singular, plural, number, translation_function):
global _default, _active global _default


t = _active.get(currentThread(), None) t = getattr(_active, "value", None)
if t is not None: if t is not None:
return getattr(t, translation_function)(singular, plural, number) return getattr(t, translation_function)(singular, plural, number)
if _default is None: if _default is None:
Expand Down Expand Up @@ -587,4 +589,3 @@ def get_partial_date_formats():
if month_day_format == 'MONTH_DAY_FORMAT': if month_day_format == 'MONTH_DAY_FORMAT':
month_day_format = settings.MONTH_DAY_FORMAT month_day_format = settings.MONTH_DAY_FORMAT
return year_month_format, month_day_format return year_month_format, month_day_format

0 comments on commit fcbf881

Please sign in to comment.