Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Added savepoint support to the transaction code.

This is a no-op for most databases. Only necessary on PostgreSQL so that we can
do things which will possibly intentionally raise an IntegrityError and not
have to rollback the entire transaction. Not supported for PostgreSQL versions
prior to 8.0, so should be used sparingly in internal Django code.


git-svn-id: http://code.djangoproject.com/svn/django/trunk@8314 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
commit 220993bcc52e78520d8e6cc8aa022608eac10b2a 1 parent e73bf2b
@malcolmt malcolmt authored
View
51 django/db/backends/__init__.py
@@ -9,7 +9,7 @@
except NameError:
# Python 2.3 compat
from sets import Set as set
-
+
from django.db.backends import util
from django.utils import datetime_safe
@@ -31,6 +31,21 @@ def _rollback(self):
if self.connection is not None:
return self.connection.rollback()
+ def _savepoint(self, sid):
+ if not self.features.uses_savepoints:
+ return
+ self.connection.cursor().execute(self.ops.savepoint_create_sql(sid))
+
+ def _savepoint_rollback(self, sid):
+ if not self.features.uses_savepoints:
+ return
+ self.connection.cursor().execute(self.ops.savepoint_rollback_sql(sid))
+
+ def _savepoint_commit(self, sid):
+ if not self.features.uses_savepoints:
+ return
+ self.connection.cursor().execute(self.ops.savepoint_commit_sql(sid))
+
def close(self):
if self.connection is not None:
self.connection.close()
@@ -55,6 +70,7 @@ class BaseDatabaseFeatures(object):
update_can_self_select = True
interprets_empty_strings_as_nulls = False
can_use_chunked_reads = True
+ uses_savepoints = False
class BaseDatabaseOperations(object):
"""
@@ -226,6 +242,26 @@ def regex_lookup(self, lookup_type):
"""
raise NotImplementedError
+ def savepoint_create_sql(self, sid):
+ """
+ Returns the SQL for starting a new savepoint. Only required if the
+ "uses_savepoints" feature is True. The "sid" parameter is a string
+ for the savepoint id.
+ """
+ raise NotImplementedError
+
+ def savepoint_commit_sql(self, sid):
+ """
+ Returns the SQL for committing the given savepoint.
+ """
+ raise NotImplementedError
+
+ def savepoint_rollback_sql(self, sid):
+ """
+ Returns the SQL for rolling back the given savepoint.
+ """
+ raise NotImplementedError
+
def sql_flush(self, style, tables, sequences):
"""
Returns a list of SQL statements required to remove all data from
@@ -259,7 +295,7 @@ def sql_for_tablespace(self, tablespace, inline=False):
a tablespace. Returns '' if the backend doesn't use tablespaces.
"""
return ''
-
+
def prep_for_like_query(self, x):
"""Prepares a value for use in a LIKE query."""
from django.utils.encoding import smart_unicode
@@ -336,11 +372,11 @@ def __init__(self, connection):
def table_name_converter(self, name):
"""Apply a conversion to the name for the purposes of comparison.
-
+
The default table name converter is for case sensitive comparison.
"""
return name
-
+
def table_names(self):
"Returns a list of names of all tables that exist in the database."
cursor = self.connection.cursor()
@@ -371,10 +407,10 @@ def installed_models(self, tables):
for app in models.get_apps():
for model in models.get_models(app):
all_models.append(model)
- return set([m for m in all_models
+ return set([m for m in all_models
if self.table_name_converter(m._meta.db_table) in map(self.table_name_converter, tables)
])
-
+
def sequence_list(self):
"Returns a list of information about all DB sequences for all models in all apps."
from django.db import models
@@ -393,8 +429,7 @@ def sequence_list(self):
sequence_list.append({'table': f.m2m_db_table(), 'column': None})
return sequence_list
-
-
+
class BaseDatabaseClient(object):
"""
This class encapsualtes all backend-specific methods for opening a
View
7 django/db/backends/postgresql/base.py
@@ -63,6 +63,9 @@ def __getattr__(self, attr):
def __iter__(self):
return iter(self.cursor)
+class DatabaseFeatures(BaseDatabaseFeatures):
+ uses_savepoints = True
+
class DatabaseWrapper(BaseDatabaseWrapper):
operators = {
'exact': '= %s',
@@ -83,8 +86,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs)
-
- self.features = BaseDatabaseFeatures()
+
+ self.features = DatabaseFeatures()
self.ops = DatabaseOperations()
self.client = DatabaseClient()
self.creation = DatabaseCreation(self)
View
10 django/db/backends/postgresql/operations.py
@@ -124,3 +124,13 @@ def sequence_reset_sql(self, style, model_list):
style.SQL_KEYWORD('FROM'),
style.SQL_TABLE(qn(f.m2m_db_table()))))
return output
+
+ def savepoint_create_sql(self, sid):
+ return "SAVEPOINT %s" % sid
+
+ def savepoint_commit_sql(self, sid):
+ return "RELEASE SAVEPOINT %s" % sid
+
+ def savepoint_rollback_sql(self, sid):
+ return "ROLLBACK TO SAVEPOINT %s" % sid
+
View
1  django/db/backends/postgresql_psycopg2/base.py
@@ -26,6 +26,7 @@
class DatabaseFeatures(BaseDatabaseFeatures):
needs_datetime_string_cast = False
+ uses_savepoints = True
class DatabaseOperations(PostgresqlDatabaseOperations):
def last_executed_query(self, cursor, sql, params):
View
35 django/db/transaction.py
@@ -19,7 +19,7 @@
try:
from functools import wraps
except ImportError:
- from django.utils.functional import wraps # Python 2.3, 2.4 fallback.
+ from django.utils.functional import wraps # Python 2.3, 2.4 fallback.
from django.db import connection
from django.conf import settings
@@ -30,9 +30,10 @@ class TransactionManagementError(Exception):
"""
pass
-# The state is a dictionary of lists. The key to the dict is the current
+# The states are dictionaries of lists. The key to the dict is the current
# thread and the list is handled as a stack of values.
state = {}
+savepoint_state = {}
# The dirty flag is set by *_unless_managed functions to denote that the
# code under transaction management has changed things to require a
@@ -164,6 +165,36 @@ def rollback():
connection._rollback()
set_clean()
+def savepoint():
+ """
+ Creates a savepoint (if supported and required by the backend) inside the
+ current transaction. Returns an identifier for the savepoint that will be
+ used for the subsequent rollback or commit.
+ """
+ thread_ident = thread.get_ident()
+ if thread_ident in savepoint_state:
+ savepoint_state[thread_ident].append(None)
+ else:
+ savepoint_state[thread_ident] = [None]
+ tid = str(thread_ident).replace('-', '')
+ sid = "s%s_x%d" % (tid, len(savepoint_state[thread_ident]))
+ connection._savepoint(sid)
+ return sid
+
+def savepoint_rollback(sid):
+ """
+ Rolls back the most recent savepoint (if one exists). Does nothing if
+ savepoints are not supported.
+ """
+ connection._savepoint_rollback(sid)
+
+def savepoint_commit(sid):
+ """
+ Commits the most recent savepoint (if one exists). Does nothing if
+ savepoints are not supported.
+ """
+ connection._savepoint_commit(sid)
+
##############
# DECORATORS #
##############
View
8 tests/modeltests/force_insert_update/models.py
@@ -2,7 +2,7 @@
Tests for forcing insert and update queries (instead of Django's normal
automatic behaviour).
"""
-from django.db import models
+from django.db import models, transaction
class Counter(models.Model):
name = models.CharField(max_length = 10)
@@ -40,15 +40,13 @@ class WithCustomPK(models.Model):
>>> c1.save(force_insert=True)
# Won't work because we can't insert a pk of the same value.
+>>> sid = transaction.savepoint()
>>> c.value = 5
>>> c.save(force_insert=True)
Traceback (most recent call last):
...
IntegrityError: ...
-
-# Work around transaction failure cleaning up for PostgreSQL.
->>> from django.db import connection
->>> connection.close()
+>>> transaction.savepoint_rollback(sid)
# Trying to update should still fail, even with manual primary keys, if the
# data isn't in the database already.
View
8 tests/modeltests/one_to_one/models.py
@@ -6,7 +6,7 @@
In this example, a ``Place`` optionally can be a ``Restaurant``.
"""
-from django.db import models, connection
+from django.db import models, transaction
class Place(models.Model):
name = models.CharField(max_length=50)
@@ -178,13 +178,11 @@ def __unicode__(self):
# This will fail because each one-to-one field must be unique (and link2=o1 was
# used for x1, above).
+>>> sid = transaction.savepoint()
>>> MultiModel(link1=p2, link2=o1, name="x1").save()
Traceback (most recent call last):
...
IntegrityError: ...
+>>> transaction.savepoint_rollback(sid)
-# Because the unittests all use a single connection, we need to force a
-# reconnect here to ensure the connection is clean (after the previous
-# IntegrityError).
->>> connection.close()
"""}
Please sign in to comment.
Something went wrong with that request. Please try again.