Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Correctly propagate the fact that DB.rollback was called to any funct…

…ions registered with appendPostTransaction. Closes #1295
  • Loading branch information...
commit 83cb3a80357aeaef9eb9034d599ca82d39673f2c 1 parent 1a360c9
@jeppenejsum jeppenejsum authored
View
42 persistence/db/src/main/scala/net/liftweb/db/DB.scala
@@ -119,7 +119,7 @@ trait DB extends Loggable {
threadLocalConnectionManagers.doWith(newMap)(f)
}
- case class ConnectionHolder(conn: SuperConnection, cnt: Int, postTransaction: List[Boolean => Unit])
+ case class ConnectionHolder(conn: SuperConnection, cnt: Int, postTransaction: List[Boolean => Unit], rolledBack: Boolean)
private def info: HashMap[ConnectionIdentifier, ConnectionHolder] = {
threadStore.get match {
@@ -292,8 +292,8 @@ trait DB extends Loggable {
private def getConnection(name: ConnectionIdentifier): SuperConnection = {
logger.trace("Acquiring " + name + " On thread " + Thread.currentThread)
var ret = info.get(name) match {
- case None => ConnectionHolder(newConnection(name), calcBaseCount(name) + 1, Nil)
- case Some(ConnectionHolder(conn, cnt, post)) => ConnectionHolder(conn, cnt + 1, post)
+ case None => ConnectionHolder(newConnection(name), calcBaseCount(name) + 1, Nil, false)
+ case Some(ConnectionHolder(conn, cnt, post, rb)) => ConnectionHolder(conn, cnt + 1, post, rb)
}
info(name) = ret
logger.trace("Acquired " + name + " on thread " + Thread.currentThread +
@@ -302,22 +302,24 @@ trait DB extends Loggable {
}
private def releaseConnectionNamed(name: ConnectionIdentifier, rollback: Boolean) {
- logger.trace("Request to release " + name + " on thread " + Thread.currentThread)
+ logger.trace("Request to release %s on thread %s, auto rollback=%s".format(name,Thread.currentThread, rollback))
+
(info.get(name): @unchecked) match {
- case Some(ConnectionHolder(c, 1, post)) => {
- if (! c.getAutoCommit()) {
+ case Some(ConnectionHolder(c, 1, post, manualRollback)) => {
+ if (! (c.getAutoCommit() || manualRollback)) {
if (rollback) tryo{c.rollback}
else c.commit
}
tryo(c.releaseFunc())
info -= name
- logger.trace("Invoking %d postTransaction functions ".format(post.size))
- post.reverse.foreach(f => tryo(f(!rollback)))
- logger.trace("Released " + name + " on thread " + Thread.currentThread)
+ val rolledback = rollback | manualRollback
+ logger.trace("Invoking %d postTransaction functions. rollback=%s".format(post.size, rolledback))
+ post.reverse.foreach(f => tryo(f(!rolledback)))
+ logger.trace("Released %s on thread %s".format(name,Thread.currentThread))
}
- case Some(ConnectionHolder(c, n, post)) =>
+ case Some(ConnectionHolder(c, n, post, rb)) =>
logger.trace("Did not release " + name + " on thread " + Thread.currentThread + " count " + (n - 1))
- info(name) = ConnectionHolder(c, n - 1, post)
+ info(name) = ConnectionHolder(c, n - 1, post, rb)
case x =>
// ignore
}
@@ -340,8 +342,8 @@ trait DB extends Loggable {
*/
def appendPostTransaction(name: ConnectionIdentifier, func: Boolean => Unit) {
info.get(name) match {
- case Some(ConnectionHolder(c, n, post)) =>
- info(name) = ConnectionHolder(c, n, func :: post)
+ case Some(ConnectionHolder(c, n, post, rb)) =>
+ info(name) = ConnectionHolder(c, n, func :: post, rb)
logger.trace("Appended postTransaction function on %s, new count=%d".format(name, post.size+1))
case _ => throw new IllegalStateException("Tried to append postTransaction function on illegal ConnectionIdentifer or outside transaction context")
}
@@ -567,7 +569,17 @@ trait DB extends Loggable {
use(DefaultConnectionIdentifier)(conn => exec(conn, query)(resultSetToAny))
- def rollback(name: ConnectionIdentifier) = use(name)(conn => conn.rollback)
+ def rollback(name: ConnectionIdentifier): Unit = {
+ info.get(name) match {
+ case Some(ConnectionHolder(c, n, post, _)) =>
+ info(name) = ConnectionHolder(c, n, post, true)
+ logger.trace("Manual rollback on %s".format(name))
+ use(name)(conn => conn.rollback)
+ case _ => throw new IllegalStateException("Tried to rollback transaction on illegal ConnectionIdentifer or outside transaction context")
+ }
+ }
+
+ def rollback: Unit = rollback(DefaultConnectionIdentifier)
/**
* Executes { @code statement } and converts the { @code ResultSet } to model
@@ -671,6 +683,8 @@ trait DB extends Loggable {
/**
* Executes function { @code f } with the connection named { @code name }. Releases the connection
* before returning.
+ *
+ * Only use within a stateful request
*/
def use[T](name: ConnectionIdentifier)(f: (SuperConnection) => T): T = {
val conn = getConnection(name)
View
17 persistence/db/src/test/scala/net/liftweb/db/DBSpec.scala
@@ -150,4 +150,21 @@ class DBSpec extends Specification with Mockito {
DB.appendPostTransaction {committed => ()} must throwA[IllegalStateException]
}
}
+
+ "DB.rollback" should {
+ "call postTransaction functions with false" in {
+ val m = mock[CommitFunc]
+ val activeConnection = mock[Connection]
+ DB.defineConnectionManager(DefaultConnectionIdentifier, dBVendor(activeConnection))
+
+ tryo(DB.use(DefaultConnectionIdentifier) {c =>
+ DB.appendPostTransaction(DefaultConnectionIdentifier, m.f _)
+ DB.rollback(DefaultConnectionIdentifier)
+ 42
+ })
+
+ there was one(activeConnection).rollback
+ there was one(m).f(false)
+ }
+ }
}
Please sign in to comment.
Something went wrong with that request. Please try again.