Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP

Loading…

Correctly propagate the fact that DB.rollback was called #1298

Merged
merged 1 commit into from

2 participants

@jeppenejsum
Owner

to any functions registered with appendPostTransaction. Closes #1295

@dpp dpp merged commit 001d956 into master
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Commits on Jul 20, 2012
  1. @jeppenejsum

    Correctly propagate the fact that DB.rollback was called to any funct…

    jeppenejsum authored
    …ions registered with appendPostTransaction. Closes #1295
This page is out of date. Refresh to see the latest.
View
42 persistence/db/src/main/scala/net/liftweb/db/DB.scala
@@ -119,7 +119,7 @@ trait DB extends Loggable {
threadLocalConnectionManagers.doWith(newMap)(f)
}
- case class ConnectionHolder(conn: SuperConnection, cnt: Int, postTransaction: List[Boolean => Unit])
+ case class ConnectionHolder(conn: SuperConnection, cnt: Int, postTransaction: List[Boolean => Unit], rolledBack: Boolean)
private def info: HashMap[ConnectionIdentifier, ConnectionHolder] = {
threadStore.get match {
@@ -292,8 +292,8 @@ trait DB extends Loggable {
private def getConnection(name: ConnectionIdentifier): SuperConnection = {
logger.trace("Acquiring " + name + " On thread " + Thread.currentThread)
var ret = info.get(name) match {
- case None => ConnectionHolder(newConnection(name), calcBaseCount(name) + 1, Nil)
- case Some(ConnectionHolder(conn, cnt, post)) => ConnectionHolder(conn, cnt + 1, post)
+ case None => ConnectionHolder(newConnection(name), calcBaseCount(name) + 1, Nil, false)
+ case Some(ConnectionHolder(conn, cnt, post, rb)) => ConnectionHolder(conn, cnt + 1, post, rb)
}
info(name) = ret
logger.trace("Acquired " + name + " on thread " + Thread.currentThread +
@@ -302,22 +302,24 @@ trait DB extends Loggable {
}
private def releaseConnectionNamed(name: ConnectionIdentifier, rollback: Boolean) {
- logger.trace("Request to release " + name + " on thread " + Thread.currentThread)
+ logger.trace("Request to release %s on thread %s, auto rollback=%s".format(name,Thread.currentThread, rollback))
+
(info.get(name): @unchecked) match {
- case Some(ConnectionHolder(c, 1, post)) => {
- if (! c.getAutoCommit()) {
+ case Some(ConnectionHolder(c, 1, post, manualRollback)) => {
+ if (! (c.getAutoCommit() || manualRollback)) {
if (rollback) tryo{c.rollback}
else c.commit
}
tryo(c.releaseFunc())
info -= name
- logger.trace("Invoking %d postTransaction functions ".format(post.size))
- post.reverse.foreach(f => tryo(f(!rollback)))
- logger.trace("Released " + name + " on thread " + Thread.currentThread)
+ val rolledback = rollback | manualRollback
+ logger.trace("Invoking %d postTransaction functions. rollback=%s".format(post.size, rolledback))
+ post.reverse.foreach(f => tryo(f(!rolledback)))
+ logger.trace("Released %s on thread %s".format(name,Thread.currentThread))
}
- case Some(ConnectionHolder(c, n, post)) =>
+ case Some(ConnectionHolder(c, n, post, rb)) =>
logger.trace("Did not release " + name + " on thread " + Thread.currentThread + " count " + (n - 1))
- info(name) = ConnectionHolder(c, n - 1, post)
+ info(name) = ConnectionHolder(c, n - 1, post, rb)
case x =>
// ignore
}
@@ -340,8 +342,8 @@ trait DB extends Loggable {
*/
def appendPostTransaction(name: ConnectionIdentifier, func: Boolean => Unit) {
info.get(name) match {
- case Some(ConnectionHolder(c, n, post)) =>
- info(name) = ConnectionHolder(c, n, func :: post)
+ case Some(ConnectionHolder(c, n, post, rb)) =>
+ info(name) = ConnectionHolder(c, n, func :: post, rb)
logger.trace("Appended postTransaction function on %s, new count=%d".format(name, post.size+1))
case _ => throw new IllegalStateException("Tried to append postTransaction function on illegal ConnectionIdentifer or outside transaction context")
}
@@ -567,7 +569,17 @@ trait DB extends Loggable {
use(DefaultConnectionIdentifier)(conn => exec(conn, query)(resultSetToAny))
- def rollback(name: ConnectionIdentifier) = use(name)(conn => conn.rollback)
+ def rollback(name: ConnectionIdentifier): Unit = {
+ info.get(name) match {
+ case Some(ConnectionHolder(c, n, post, _)) =>
+ info(name) = ConnectionHolder(c, n, post, true)
+ logger.trace("Manual rollback on %s".format(name))
+ use(name)(conn => conn.rollback)
+ case _ => throw new IllegalStateException("Tried to rollback transaction on illegal ConnectionIdentifer or outside transaction context")
+ }
+ }
+
+ def rollback: Unit = rollback(DefaultConnectionIdentifier)
/**
* Executes { @code statement } and converts the { @code ResultSet } to model
@@ -671,6 +683,8 @@ trait DB extends Loggable {
/**
* Executes function { @code f } with the connection named { @code name }. Releases the connection
* before returning.
+ *
+ * Only use within a stateful request
*/
def use[T](name: ConnectionIdentifier)(f: (SuperConnection) => T): T = {
val conn = getConnection(name)
View
17 persistence/db/src/test/scala/net/liftweb/db/DBSpec.scala
@@ -150,4 +150,21 @@ class DBSpec extends Specification with Mockito {
DB.appendPostTransaction {committed => ()} must throwA[IllegalStateException]
}
}
+
+ "DB.rollback" should {
+ "call postTransaction functions with false" in {
+ val m = mock[CommitFunc]
+ val activeConnection = mock[Connection]
+ DB.defineConnectionManager(DefaultConnectionIdentifier, dBVendor(activeConnection))
+
+ tryo(DB.use(DefaultConnectionIdentifier) {c =>
+ DB.appendPostTransaction(DefaultConnectionIdentifier, m.f _)
+ DB.rollback(DefaultConnectionIdentifier)
+ 42
+ })
+
+ there was one(activeConnection).rollback
+ there was one(m).f(false)
+ }
+ }
}
Something went wrong with that request. Please try again.