Skip to content
This repository

Correctly propagate the fact that DB.rollback was called #1298

Merged
merged 1 commit into from almost 2 years ago

2 participants

Jeppe Nejsum Madsen David Pollak
Jeppe Nejsum Madsen
Owner

to any functions registered with appendPostTransaction. Closes #1295

David Pollak dpp merged commit 001d956 into from
David Pollak dpp closed this
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Showing 1 unique commit by 1 author.

Jul 20, 2012
Jeppe Nejsum Madsen jeppenejsum Correctly propagate the fact that DB.rollback was called to any funct…
…ions registered with appendPostTransaction. Closes #1295
83cb3a8
This page is out of date. Refresh to see the latest.
42 persistence/db/src/main/scala/net/liftweb/db/DB.scala
@@ -119,7 +119,7 @@ trait DB extends Loggable {
119 119 threadLocalConnectionManagers.doWith(newMap)(f)
120 120 }
121 121
122   - case class ConnectionHolder(conn: SuperConnection, cnt: Int, postTransaction: List[Boolean => Unit])
  122 + case class ConnectionHolder(conn: SuperConnection, cnt: Int, postTransaction: List[Boolean => Unit], rolledBack: Boolean)
123 123
124 124 private def info: HashMap[ConnectionIdentifier, ConnectionHolder] = {
125 125 threadStore.get match {
@@ -292,8 +292,8 @@ trait DB extends Loggable {
292 292 private def getConnection(name: ConnectionIdentifier): SuperConnection = {
293 293 logger.trace("Acquiring " + name + " On thread " + Thread.currentThread)
294 294 var ret = info.get(name) match {
295   - case None => ConnectionHolder(newConnection(name), calcBaseCount(name) + 1, Nil)
296   - case Some(ConnectionHolder(conn, cnt, post)) => ConnectionHolder(conn, cnt + 1, post)
  295 + case None => ConnectionHolder(newConnection(name), calcBaseCount(name) + 1, Nil, false)
  296 + case Some(ConnectionHolder(conn, cnt, post, rb)) => ConnectionHolder(conn, cnt + 1, post, rb)
297 297 }
298 298 info(name) = ret
299 299 logger.trace("Acquired " + name + " on thread " + Thread.currentThread +
@@ -302,22 +302,24 @@ trait DB extends Loggable {
302 302 }
303 303
304 304 private def releaseConnectionNamed(name: ConnectionIdentifier, rollback: Boolean) {
305   - logger.trace("Request to release " + name + " on thread " + Thread.currentThread)
  305 + logger.trace("Request to release %s on thread %s, auto rollback=%s".format(name,Thread.currentThread, rollback))
  306 +
306 307 (info.get(name): @unchecked) match {
307   - case Some(ConnectionHolder(c, 1, post)) => {
308   - if (! c.getAutoCommit()) {
  308 + case Some(ConnectionHolder(c, 1, post, manualRollback)) => {
  309 + if (! (c.getAutoCommit() || manualRollback)) {
309 310 if (rollback) tryo{c.rollback}
310 311 else c.commit
311 312 }
312 313 tryo(c.releaseFunc())
313 314 info -= name
314   - logger.trace("Invoking %d postTransaction functions ".format(post.size))
315   - post.reverse.foreach(f => tryo(f(!rollback)))
316   - logger.trace("Released " + name + " on thread " + Thread.currentThread)
  315 + val rolledback = rollback | manualRollback
  316 + logger.trace("Invoking %d postTransaction functions. rollback=%s".format(post.size, rolledback))
  317 + post.reverse.foreach(f => tryo(f(!rolledback)))
  318 + logger.trace("Released %s on thread %s".format(name,Thread.currentThread))
317 319 }
318   - case Some(ConnectionHolder(c, n, post)) =>
  320 + case Some(ConnectionHolder(c, n, post, rb)) =>
319 321 logger.trace("Did not release " + name + " on thread " + Thread.currentThread + " count " + (n - 1))
320   - info(name) = ConnectionHolder(c, n - 1, post)
  322 + info(name) = ConnectionHolder(c, n - 1, post, rb)
321 323 case x =>
322 324 // ignore
323 325 }
@@ -340,8 +342,8 @@ trait DB extends Loggable {
340 342 */
341 343 def appendPostTransaction(name: ConnectionIdentifier, func: Boolean => Unit) {
342 344 info.get(name) match {
343   - case Some(ConnectionHolder(c, n, post)) =>
344   - info(name) = ConnectionHolder(c, n, func :: post)
  345 + case Some(ConnectionHolder(c, n, post, rb)) =>
  346 + info(name) = ConnectionHolder(c, n, func :: post, rb)
345 347 logger.trace("Appended postTransaction function on %s, new count=%d".format(name, post.size+1))
346 348 case _ => throw new IllegalStateException("Tried to append postTransaction function on illegal ConnectionIdentifer or outside transaction context")
347 349 }
@@ -567,7 +569,17 @@ trait DB extends Loggable {
567 569 use(DefaultConnectionIdentifier)(conn => exec(conn, query)(resultSetToAny))
568 570
569 571
570   - def rollback(name: ConnectionIdentifier) = use(name)(conn => conn.rollback)
  572 + def rollback(name: ConnectionIdentifier): Unit = {
  573 + info.get(name) match {
  574 + case Some(ConnectionHolder(c, n, post, _)) =>
  575 + info(name) = ConnectionHolder(c, n, post, true)
  576 + logger.trace("Manual rollback on %s".format(name))
  577 + use(name)(conn => conn.rollback)
  578 + case _ => throw new IllegalStateException("Tried to rollback transaction on illegal ConnectionIdentifer or outside transaction context")
  579 + }
  580 + }
  581 +
  582 + def rollback: Unit = rollback(DefaultConnectionIdentifier)
571 583
572 584 /**
573 585 * Executes { @code statement } and converts the { @code ResultSet } to model
@@ -671,6 +683,8 @@ trait DB extends Loggable {
671 683 /**
672 684 * Executes function { @code f } with the connection named { @code name }. Releases the connection
673 685 * before returning.
  686 + *
  687 + * Only use within a stateful request
674 688 */
675 689 def use[T](name: ConnectionIdentifier)(f: (SuperConnection) => T): T = {
676 690 val conn = getConnection(name)
17 persistence/db/src/test/scala/net/liftweb/db/DBSpec.scala
@@ -150,4 +150,21 @@ class DBSpec extends Specification with Mockito {
150 150 DB.appendPostTransaction {committed => ()} must throwA[IllegalStateException]
151 151 }
152 152 }
  153 +
  154 + "DB.rollback" should {
  155 + "call postTransaction functions with false" in {
  156 + val m = mock[CommitFunc]
  157 + val activeConnection = mock[Connection]
  158 + DB.defineConnectionManager(DefaultConnectionIdentifier, dBVendor(activeConnection))
  159 +
  160 + tryo(DB.use(DefaultConnectionIdentifier) {c =>
  161 + DB.appendPostTransaction(DefaultConnectionIdentifier, m.f _)
  162 + DB.rollback(DefaultConnectionIdentifier)
  163 + 42
  164 + })
  165 +
  166 + there was one(activeConnection).rollback
  167 + there was one(m).f(false)
  168 + }
  169 + }
153 170 }

Tip: You can add notes to lines in a file. Hover to the left of a line to make a note

Something went wrong with that request. Please try again.