Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion src/main/scala/com/twitter/finagle/postgres/Client.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,27 @@ class Client(factory: ServiceFactory[PgRequest, PgResponse], id:String) {
private[this] lazy val customTypes = CustomOIDProxy.serviceOIDMap(id)
private[this] val logger = Logger(getClass.getName)

/*
* Execute some actions inside of a transaction using a single connection
*/
def inTransaction[T](fn: Client => Future[T]) = for {
service <- factory()
constFactory = ServiceFactory.const(service)
transactionalClient = new Client(constFactory, Random.alphanumeric.take(28).mkString)
_ <- transactionalClient.query("BEGIN")
result <- fn(transactionalClient).rescue {
case err => for {
_ <- transactionalClient.query("ROLLBACK")
_ <- constFactory.close()
_ <- service.close()
_ <- Future.exception(err)
} yield null.asInstanceOf[T]
}
_ <- transactionalClient.query("COMMIT")
_ <- constFactory.close()
_ <- service.close()
} yield result

/*
* Issue an arbitrary SQL query and get the response.
*/
Expand Down Expand Up @@ -201,7 +222,7 @@ class Client(factory: ServiceFactory[PgRequest, PgResponse], id:String) {
}
}

private[this] def genName() = "fin-pg-" + counter.incrementAndGet
private[this] def genName() = s"fin-pg-$id-" + counter.incrementAndGet
}

/*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package com.twitter.finagle.postgres.integration

import com.twitter.finagle.postgres.{Client, Spec}
import com.twitter.util.{Await, Future}

class TransactionSpec extends Spec {
for {
hostPort <- sys.env.get("PG_HOST_PORT")
user <- sys.env.get("PG_USER")
password = sys.env.get("PG_PASSWORD")
dbname <- sys.env.get("PG_DBNAME")
useSsl = sys.env.getOrElse("USE_PG_SSL", "0") == "1"
} yield {

val client = Client(hostPort, user, password, dbname, useSsl)
Await.result(client.query(
"""
|DROP TABLE IF EXISTS transaction_test;
|CREATE TABLE transaction_test(id integer primary key);
""".stripMargin))

"A postgres transaction" should {

"commit if the transaction future is successful" in {
Await.result {
client.inTransaction {
c => for {
_ <- c.prepareAndExecute("DELETE FROM transaction_test")
_ <- c.prepareAndExecute("INSERT INTO transaction_test VALUES(1)")
_ <- c.prepareAndExecute("INSERT INTO transaction_test VALUES(2)")
} yield ()
}
}
val count = Await.result(client.prepareAndQuery("SELECT COUNT(*)::int4 AS count FROM transaction_test WHERE id IN (1,2)") {
row => row.get[Int]("count")
}.map(_.head))
assert(count == 2)
}

"rollback the transaction if the transaction future fails" in {
val failed = client.inTransaction {
c => for {
_ <- c.prepareAndExecute("DELETE FROM transaction_test")
_ <- c.prepareAndExecute("INSERT INTO transaction_test VALUES(3)")
_ <- c.prepareAndExecute("INSERT INTO transaction_test VALUES(4)")
_ <- Future.exception(new Exception("Roll it back!"))
_ <- c.prepareAndExecute("INSERT INTO transaction_test VALUES(5)")
} yield ()
}.liftToTry

val failedResult = Await.result(failed)

val inTable = Await.result(client.prepareAndQuery("SELECT * FROM transaction_test") {
row => row.get[Int]("id")
}).toList.sorted
assert(inTable == List(1, 2))
}

}
}
}