Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Add appendPostTransaction method on DB, refs #1102

  • Loading branch information...
commit b6af3b942309de6894bc79adbe9554b5604a8a0a 1 parent 026a9f4
@jeppenejsum jeppenejsum authored
View
56 persistence/db/src/main/scala/net/liftweb/db/DB.scala
@@ -30,7 +30,7 @@ import javax.naming.{Context, InitialContext}
trait DB1
object DB1 {
- implicit def db1ToDb(in: DB1): DB = DB.theDB
+ implicit def db1ToDb(in: DB1): DB = DB.theDB
}
object DB extends DB1 {
@@ -58,16 +58,16 @@ trait DB extends Loggable {
/**
* queryCollector can be used to collect all statements executed in a single request when passed to addLogFunc
- *
+ *
* Use S.queryLog to get the list of (statement, duration) entries or set an analyzer function using
* S.addAnalyzer
*/
@volatile var queryCollector: LogFunc = {
- case (query:DBLog, time) =>
+ case (query:DBLog, time) =>
}
-
-
+
+
/**
* can we get a JDBC connection from JNDI?
*/
@@ -79,7 +79,6 @@ trait DB extends Loggable {
}
}
- // var connectionManager: Box[ConnectionManager] = Empty
private val connectionManagers = new HashMap[ConnectionIdentifier, ConnectionManager]
private val threadLocalConnectionManagers = new ThreadGlobal[Map[ConnectionIdentifier, ConnectionManager]]
@@ -97,7 +96,7 @@ trait DB extends Loggable {
threadLocalConnectionManagers.doWith(newMap)(f)
}
- case class ConnectionHolder(conn: SuperConnection, cnt: Int, postCommit: List[() => Unit])
+ case class ConnectionHolder(conn: SuperConnection, cnt: Int, postTransaction: List[Boolean => Unit])
private def info: HashMap[ConnectionIdentifier, ConnectionHolder] = {
threadStore.get match {
@@ -122,9 +121,10 @@ trait DB extends Loggable {
private def postCommit_=(lst: List[() => Unit]): Unit = _postCommitFuncs.set(lst)
/**
- * perform this function post-commit. THis is helpful for sending messages to Actors after we know
+ * perform this function after transaction has ended. THis is helpful for sending messages to Actors after we know
* a transaction has committed
*/
+ @deprecated("Use appendPostTransaction")
def performPostCommit(f: => Unit) {
postCommit = (() => f) :: postCommit
}
@@ -219,7 +219,7 @@ trait DB extends Loggable {
throw e
}
}
-
+
} finally {
clearThread(success)
}
@@ -257,19 +257,19 @@ trait DB extends Loggable {
CurrentConnectionSet.is.map(_.use(conn)) openOr 0
private def getConnection(name: ConnectionIdentifier): SuperConnection = {
- logger.trace("Acquiring connection " + name + " On thread " + Thread.currentThread)
+ logger.trace("Acquiring " + name + " On thread " + Thread.currentThread)
var ret = info.get(name) match {
case None => ConnectionHolder(newConnection(name), calcBaseCount(name) + 1, Nil)
case Some(ConnectionHolder(conn, cnt, post)) => ConnectionHolder(conn, cnt + 1, post)
}
info(name) = ret
- logger.trace("Acquired connection " + name + " on thread " + Thread.currentThread +
+ logger.trace("Acquired " + name + " on thread " + Thread.currentThread +
" count " + ret.cnt)
ret.conn
}
private def releaseConnectionNamed(name: ConnectionIdentifier, rollback: Boolean) {
- logger.trace("Request to release connection: " + name + " on thread " + Thread.currentThread)
+ logger.trace("Request to release " + name + " on thread " + Thread.currentThread)
(info.get(name): @unchecked) match {
case Some(ConnectionHolder(c, 1, post)) => {
if (! c.getAutoCommit()) {
@@ -278,25 +278,39 @@ trait DB extends Loggable {
}
tryo(c.releaseFunc())
info -= name
- post.reverse.foreach(f => tryo(f()))
- logger.trace("Released connection " + name + " on thread " + Thread.currentThread)
+ logger.trace("Invoking %d postTransaction functions ".format(post.size))
+ post.reverse.foreach(f => tryo(f(!rollback)))
+ logger.trace("Released " + name + " on thread " + Thread.currentThread)
}
case Some(ConnectionHolder(c, n, post)) =>
- logger.trace("Did not release connection: " + name + " on thread " + Thread.currentThread + " count " + (n - 1))
+ logger.trace("Did not release " + name + " on thread " + Thread.currentThread + " count " + (n - 1))
info(name) = ConnectionHolder(c, n - 1, post)
-
- case _ =>
+ case x =>
// ignore
}
}
/**
- * Append a function to be invoked after the commit has taken place for the given connection identifier
+ * Append a function to be invoked after the transaction has ended for the given connection identifier
*/
+ @deprecated("Use appendPostTransaction")
def appendPostFunc(name: ConnectionIdentifier, func: () => Unit) {
+ appendPostTransaction(name, dontUse => func())
+ }
+
+ /**
+ * Append a function to be invoked after the transaction on the specified connection identifier has ended.
+ * The value passed to the function indicates true for success/commit or false for failure/rollback.
+ *
+ * Note: the function will only be called when automatic transaction management is in effect, either by executing within
+ * the context of a buildLoanWrapper or a DB.use {}
+ */
+ def appendPostTransaction(name: ConnectionIdentifier, func: Boolean => Unit) {
info.get(name) match {
- case Some(ConnectionHolder(c, n, post)) => info(name) = ConnectionHolder(c, n, func :: post)
- case _ =>
+ case Some(ConnectionHolder(c, n, post)) =>
+ info(name) = ConnectionHolder(c, n, func :: post)
+ logger.trace("Appended postTransaction function on %s, new count=%d".format(name, post.size+1))
+ case _ => throw new IllegalStateException("Tried to append postTransaction function on illegal ConnectionIdentifer or outside transaction context")
}
}
@@ -1175,7 +1189,7 @@ trait ProtoDBVendor extends ConnectionManager {
else {
pool.foreach {c => tryo(c.close); poolSize -= 1}
pool = Nil
-
+
if (poolSize > 0) wait(250)
_closeAllConnections_!(cnt + 1)
View
132 persistence/db/src/test/scala/net/liftweb/db/DBSpec.scala
@@ -0,0 +1,132 @@
+package net.liftweb
+package db
+
+import org.specs.Specification
+import org.specs.runner._
+import org.specs.mock.Mockito
+import org.mockito.Matchers._
+
+import net.liftweb.common._
+import net.liftweb.db._
+import net.liftweb.util.ControlHelpers._
+
+import java.sql._
+
+class DBSpec extends Specification with Mockito {
+ trait CommitFunc {
+ def f(success: Boolean): Unit
+ }
+
+ var activeConnection: Connection = _
+
+ def dBVendor = new ProtoDBVendor {
+ def createOne = {
+ val connection = mock[Connection]
+ connection.createStatement returns mock[PreparedStatement]
+ activeConnection = connection
+ Full(connection)
+ }
+ }
+
+ DB.defineConnectionManager(DefaultConnectionIdentifier, dBVendor)
+
+ "eager buildLoanWrapper" should {
+ "call postTransaction functions with true if transaction is committed" in {
+ val m = mock[CommitFunc]
+
+ DB.buildLoanWrapper(true) {
+ DB.appendPostTransaction(DefaultConnectionIdentifier, m.f _)
+ DB.currentConnection.map{c => DB.exec(c, "stuff") {dummy => }}
+ }
+ there was one(activeConnection).commit
+ there was one(m).f(true)
+ }
+
+ "call postCommit functions with false if transaction is rolledback" in {
+ val m = mock[CommitFunc]
+
+ val lw = DB.buildLoanWrapper(true)
+
+ tryo(lw.apply {
+ DB.appendPostTransaction(DefaultConnectionIdentifier, m.f _)
+ DB.currentConnection.map{c => DB.exec(c, "stuff") {dummy => }}
+ throw new RuntimeException("oh no")
+ 42
+ })
+ there was one(activeConnection).rollback
+ there was one(m).f(false)
+ }
+ }
+
+ "lazy buildLoanWrapper" should {
+ "call postTransaction functions with true if transaction is committed" in {
+ val m = mock[CommitFunc]
+
+ DB.buildLoanWrapper(false) {
+ DB.use(DefaultConnectionIdentifier) {c =>
+ DB.appendPostTransaction(DefaultConnectionIdentifier, m.f _)
+ DB.exec(c, "stuff") {
+ dummy =>
+ }
+ }
+ DB.use(DefaultConnectionIdentifier) {c =>
+ DB.exec(c, "more stuff") { dummy => }
+ }
+ }
+ there was one(activeConnection).commit
+ there was one(m).f(true)
+ }
+
+ "call postCommit functions with false if transaction is rolledback" in {
+ val m = mock[CommitFunc]
+
+ val lw = DB.buildLoanWrapper(false)
+
+ tryo(lw.apply {
+ DB.use(DefaultConnectionIdentifier) {c =>
+ DB.exec(c, "more stuff") { dummy => }
+ }
+ DB.use(DefaultConnectionIdentifier) {c =>
+ DB.appendPostTransaction(DefaultConnectionIdentifier, m.f _)
+ DB.exec(c, "stuff") {dummy => throw new RuntimeException("oh no")}
+ }
+ 42
+ })
+ there was one(activeConnection).rollback
+ there was one(m).f(false)
+ }
+ }
+
+ "DB.use" should {
+ "call postTransaction functions with true if transaction is committed" in {
+ val m = mock[CommitFunc]
+
+ DB.use(DefaultConnectionIdentifier) {c =>
+ DB.appendPostTransaction(DefaultConnectionIdentifier, m.f _)
+ DB.exec(c, "stuff") {dummy => }
+ }
+
+ there was one(activeConnection).commit
+ there was one(m).f(true)
+ }
+
+ "call postTransaction functions with false if transaction is committed" in {
+ val m = mock[CommitFunc]
+
+ tryo(DB.use(DefaultConnectionIdentifier) {c =>
+ DB.appendPostTransaction(DefaultConnectionIdentifier, m.f _)
+ DB.exec(c, "stuff") {dummy => throw new RuntimeException("Oh no")}
+ 42
+ })
+
+ there was one(activeConnection).rollback
+ there was one(m).f(false)
+ }
+ }
+
+ "appendPostTransaction" should {
+ "throw if called outside tx context" in {
+ DB.appendPostTransaction(DefaultConnectionIdentifier, d => ()) must throwA[IllegalStateException]
+ }
+ }
+}
View
2  project/build/LiftFrameworkProject.scala
@@ -46,7 +46,7 @@ class LiftFrameworkProject(info: ProjectInfo) extends ParentProject(info) with L
// Persistence projects
// --------------------
- lazy val db = persistenceProject("db")(util)
+ lazy val db = persistenceProject("db",logback,TestScope.mockito_all)(util)
lazy val proto = persistenceProject("proto")(webkit)
lazy val jpa = persistenceProject("jpa", scalajpa, persistence_api)(webkit)
lazy val mapper = persistenceProject("mapper", RuntimeScope.h2database, RuntimeScope.derby)(db, proto)
Please sign in to comment.
Something went wrong with that request. Please try again.