Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Merge pull request #1298 from lift/jnm_1295

Correctly propagate the fact that DB.rollback was called
  • Loading branch information...
commit 001d956d7db37e3bb3a01d3653d12ff3e8010334 2 parents 1a360c9 + 83cb3a8
@dpp dpp authored
View
42 persistence/db/src/main/scala/net/liftweb/db/DB.scala
@@ -119,7 +119,7 @@ trait DB extends Loggable {
threadLocalConnectionManagers.doWith(newMap)(f)
}
- case class ConnectionHolder(conn: SuperConnection, cnt: Int, postTransaction: List[Boolean => Unit])
+ case class ConnectionHolder(conn: SuperConnection, cnt: Int, postTransaction: List[Boolean => Unit], rolledBack: Boolean)
private def info: HashMap[ConnectionIdentifier, ConnectionHolder] = {
threadStore.get match {
@@ -292,8 +292,8 @@ trait DB extends Loggable {
private def getConnection(name: ConnectionIdentifier): SuperConnection = {
logger.trace("Acquiring " + name + " On thread " + Thread.currentThread)
var ret = info.get(name) match {
- case None => ConnectionHolder(newConnection(name), calcBaseCount(name) + 1, Nil)
- case Some(ConnectionHolder(conn, cnt, post)) => ConnectionHolder(conn, cnt + 1, post)
+ case None => ConnectionHolder(newConnection(name), calcBaseCount(name) + 1, Nil, false)
+ case Some(ConnectionHolder(conn, cnt, post, rb)) => ConnectionHolder(conn, cnt + 1, post, rb)
}
info(name) = ret
logger.trace("Acquired " + name + " on thread " + Thread.currentThread +
@@ -302,22 +302,24 @@ trait DB extends Loggable {
}
private def releaseConnectionNamed(name: ConnectionIdentifier, rollback: Boolean) {
- logger.trace("Request to release " + name + " on thread " + Thread.currentThread)
+ logger.trace("Request to release %s on thread %s, auto rollback=%s".format(name,Thread.currentThread, rollback))
+
(info.get(name): @unchecked) match {
- case Some(ConnectionHolder(c, 1, post)) => {
- if (! c.getAutoCommit()) {
+ case Some(ConnectionHolder(c, 1, post, manualRollback)) => {
+ if (! (c.getAutoCommit() || manualRollback)) {
if (rollback) tryo{c.rollback}
else c.commit
}
tryo(c.releaseFunc())
info -= name
- logger.trace("Invoking %d postTransaction functions ".format(post.size))
- post.reverse.foreach(f => tryo(f(!rollback)))
- logger.trace("Released " + name + " on thread " + Thread.currentThread)
+ val rolledback = rollback | manualRollback
+ logger.trace("Invoking %d postTransaction functions. rollback=%s".format(post.size, rolledback))
+ post.reverse.foreach(f => tryo(f(!rolledback)))
+ logger.trace("Released %s on thread %s".format(name,Thread.currentThread))
}
- case Some(ConnectionHolder(c, n, post)) =>
+ case Some(ConnectionHolder(c, n, post, rb)) =>
logger.trace("Did not release " + name + " on thread " + Thread.currentThread + " count " + (n - 1))
- info(name) = ConnectionHolder(c, n - 1, post)
+ info(name) = ConnectionHolder(c, n - 1, post, rb)
case x =>
// ignore
}
@@ -340,8 +342,8 @@ trait DB extends Loggable {
*/
def appendPostTransaction(name: ConnectionIdentifier, func: Boolean => Unit) {
info.get(name) match {
- case Some(ConnectionHolder(c, n, post)) =>
- info(name) = ConnectionHolder(c, n, func :: post)
+ case Some(ConnectionHolder(c, n, post, rb)) =>
+ info(name) = ConnectionHolder(c, n, func :: post, rb)
logger.trace("Appended postTransaction function on %s, new count=%d".format(name, post.size+1))
case _ => throw new IllegalStateException("Tried to append postTransaction function on illegal ConnectionIdentifer or outside transaction context")
}
@@ -567,7 +569,17 @@ trait DB extends Loggable {
use(DefaultConnectionIdentifier)(conn => exec(conn, query)(resultSetToAny))
- def rollback(name: ConnectionIdentifier) = use(name)(conn => conn.rollback)
+ def rollback(name: ConnectionIdentifier): Unit = {
+ info.get(name) match {
+ case Some(ConnectionHolder(c, n, post, _)) =>
+ info(name) = ConnectionHolder(c, n, post, true)
+ logger.trace("Manual rollback on %s".format(name))
+ use(name)(conn => conn.rollback)
+ case _ => throw new IllegalStateException("Tried to rollback transaction on illegal ConnectionIdentifer or outside transaction context")
+ }
+ }
+
+ def rollback: Unit = rollback(DefaultConnectionIdentifier)
/**
* Executes { @code statement } and converts the { @code ResultSet } to model
@@ -671,6 +683,8 @@ trait DB extends Loggable {
/**
* Executes function { @code f } with the connection named { @code name }. Releases the connection
* before returning.
+ *
+ * Only use within a stateful request
*/
def use[T](name: ConnectionIdentifier)(f: (SuperConnection) => T): T = {
val conn = getConnection(name)
View
17 persistence/db/src/test/scala/net/liftweb/db/DBSpec.scala
@@ -150,4 +150,21 @@ class DBSpec extends Specification with Mockito {
DB.appendPostTransaction {committed => ()} must throwA[IllegalStateException]
}
}
+
+ "DB.rollback" should {
+ "call postTransaction functions with false" in {
+ val m = mock[CommitFunc]
+ val activeConnection = mock[Connection]
+ DB.defineConnectionManager(DefaultConnectionIdentifier, dBVendor(activeConnection))
+
+ tryo(DB.use(DefaultConnectionIdentifier) {c =>
+ DB.appendPostTransaction(DefaultConnectionIdentifier, m.f _)
+ DB.rollback(DefaultConnectionIdentifier)
+ 42
+ })
+
+ there was one(activeConnection).rollback
+ there was one(m).f(false)
+ }
+ }
}
Please sign in to comment.
Something went wrong with that request. Please try again.