Skip to content
This repository
Browse code

Merge pull request #1298 from lift/jnm_1295

Correctly propagate the fact that DB.rollback was called
  • Loading branch information...
commit 001d956d7db37e3bb3a01d3653d12ff3e8010334 2 parents 1a360c9 + 83cb3a8
David Pollak authored July 20, 2012
42  persistence/db/src/main/scala/net/liftweb/db/DB.scala
@@ -119,7 +119,7 @@ trait DB extends Loggable {
119 119
     threadLocalConnectionManagers.doWith(newMap)(f)
120 120
   }
121 121
 
122  
-  case class ConnectionHolder(conn: SuperConnection, cnt: Int, postTransaction: List[Boolean => Unit])
  122
+  case class ConnectionHolder(conn: SuperConnection, cnt: Int, postTransaction: List[Boolean => Unit], rolledBack: Boolean)
123 123
 
124 124
   private def info: HashMap[ConnectionIdentifier, ConnectionHolder] = {
125 125
     threadStore.get match {
@@ -292,8 +292,8 @@ trait DB extends Loggable {
292 292
   private def getConnection(name: ConnectionIdentifier): SuperConnection = {
293 293
     logger.trace("Acquiring " + name + " On thread " + Thread.currentThread)
294 294
     var ret = info.get(name) match {
295  
-      case None => ConnectionHolder(newConnection(name), calcBaseCount(name) + 1, Nil)
296  
-      case Some(ConnectionHolder(conn, cnt, post)) => ConnectionHolder(conn, cnt + 1, post)
  295
+      case None => ConnectionHolder(newConnection(name), calcBaseCount(name) + 1, Nil, false)
  296
+      case Some(ConnectionHolder(conn, cnt, post, rb)) => ConnectionHolder(conn, cnt + 1, post, rb)
297 297
     }
298 298
     info(name) = ret
299 299
     logger.trace("Acquired " + name + " on thread " + Thread.currentThread +
@@ -302,22 +302,24 @@ trait DB extends Loggable {
302 302
   }
303 303
 
304 304
   private def releaseConnectionNamed(name: ConnectionIdentifier, rollback: Boolean) {
305  
-    logger.trace("Request to release " + name + " on thread " + Thread.currentThread)
  305
+    logger.trace("Request to release %s on thread %s, auto rollback=%s".format(name,Thread.currentThread, rollback))
  306
+
306 307
     (info.get(name): @unchecked) match {
307  
-      case Some(ConnectionHolder(c, 1, post)) => {
308  
-        if (! c.getAutoCommit()) {
  308
+      case Some(ConnectionHolder(c, 1, post, manualRollback)) => {
  309
+        if (! (c.getAutoCommit() || manualRollback)) {
309 310
           if (rollback) tryo{c.rollback}
310 311
           else c.commit
311 312
         }
312 313
         tryo(c.releaseFunc())
313 314
         info -= name
314  
-        logger.trace("Invoking %d postTransaction functions ".format(post.size))
315  
-        post.reverse.foreach(f => tryo(f(!rollback)))
316  
-        logger.trace("Released " + name + " on thread " + Thread.currentThread)
  315
+        val rolledback = rollback | manualRollback
  316
+        logger.trace("Invoking %d postTransaction functions. rollback=%s".format(post.size, rolledback))
  317
+        post.reverse.foreach(f => tryo(f(!rolledback)))
  318
+        logger.trace("Released %s on thread %s".format(name,Thread.currentThread))
317 319
       }
318  
-      case Some(ConnectionHolder(c, n, post)) =>
  320
+      case Some(ConnectionHolder(c, n, post, rb)) =>
319 321
         logger.trace("Did not release " + name + " on thread " + Thread.currentThread + " count " + (n - 1))
320  
-        info(name) = ConnectionHolder(c, n - 1, post)
  322
+        info(name) = ConnectionHolder(c, n - 1, post, rb)
321 323
       case x =>
322 324
         // ignore
323 325
     }
@@ -340,8 +342,8 @@ trait DB extends Loggable {
340 342
    */
341 343
   def appendPostTransaction(name: ConnectionIdentifier, func: Boolean => Unit) {
342 344
     info.get(name) match {
343  
-      case Some(ConnectionHolder(c, n, post)) =>
344  
-        info(name) = ConnectionHolder(c, n, func :: post)
  345
+      case Some(ConnectionHolder(c, n, post, rb)) =>
  346
+        info(name) = ConnectionHolder(c, n, func :: post, rb)
345 347
         logger.trace("Appended postTransaction function on %s, new count=%d".format(name, post.size+1))
346 348
       case _ => throw new IllegalStateException("Tried to append postTransaction function on illegal ConnectionIdentifer or outside transaction context")
347 349
     }
@@ -567,7 +569,17 @@ trait DB extends Loggable {
567 569
     use(DefaultConnectionIdentifier)(conn => exec(conn, query)(resultSetToAny))
568 570
 
569 571
 
570  
-  def rollback(name: ConnectionIdentifier) = use(name)(conn => conn.rollback)
  572
+  def rollback(name: ConnectionIdentifier): Unit = {
  573
+    info.get(name) match {
  574
+      case Some(ConnectionHolder(c, n, post, _)) =>
  575
+        info(name) = ConnectionHolder(c, n, post, true)
  576
+        logger.trace("Manual rollback on %s".format(name))
  577
+        use(name)(conn => conn.rollback)
  578
+      case _ => throw new IllegalStateException("Tried to rollback transaction on illegal ConnectionIdentifer or outside transaction context")
  579
+    }
  580
+  }
  581
+
  582
+  def rollback: Unit = rollback(DefaultConnectionIdentifier)
571 583
 
572 584
   /**
573 585
    * Executes  { @code statement } and converts the  { @code ResultSet } to model
@@ -671,6 +683,8 @@ trait DB extends Loggable {
671 683
   /**
672 684
    * Executes function  { @code f } with the connection named  { @code name }. Releases the connection
673 685
    * before returning.
  686
+   *
  687
+   * Only use within a stateful request
674 688
    */
675 689
   def use[T](name: ConnectionIdentifier)(f: (SuperConnection) => T): T = {
676 690
     val conn = getConnection(name)
17  persistence/db/src/test/scala/net/liftweb/db/DBSpec.scala
@@ -150,4 +150,21 @@ class DBSpec extends Specification with Mockito {
150 150
       DB.appendPostTransaction {committed => ()}  must throwA[IllegalStateException]
151 151
     }
152 152
   }
  153
+
  154
+  "DB.rollback" should {
  155
+    "call postTransaction functions with false" in {
  156
+      val m = mock[CommitFunc]
  157
+      val activeConnection = mock[Connection]
  158
+      DB.defineConnectionManager(DefaultConnectionIdentifier, dBVendor(activeConnection))
  159
+
  160
+      tryo(DB.use(DefaultConnectionIdentifier) {c =>
  161
+        DB.appendPostTransaction(DefaultConnectionIdentifier, m.f _)
  162
+        DB.rollback(DefaultConnectionIdentifier)
  163
+        42
  164
+      })
  165
+
  166
+      there was one(activeConnection).rollback
  167
+      there was one(m).f(false)
  168
+    }
  169
+  }
153 170
 }

0 notes on commit 001d956

Please sign in to comment.
Something went wrong with that request. Please try again.