Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Merge branch 'only_timeout_work' into new_version

  • Loading branch information...
commit 51526e72ecdd0c2d8c86f1aa46fe9298680e8743 2 parents 15241db + 48e3a97
Matt Freels freels authored
42 src/main/scala/com/twitter/querulous/ConnectionDestroying.scala
View
@@ -0,0 +1,42 @@
+package com.twitter.querulous.query
+
+import java.sql.Connection
+import org.apache.commons.dbcp.{DelegatingConnection => DBCPConnection}
+import com.mysql.jdbc.{ConnectionImpl => MySQLConnection}
+
+
+// Emergency connection destruction toolkit
+
+trait ConnectionDestroying {
+ def destroyConnection(conn: Connection) {
+ if ( !conn.isClosed )
+ conn match {
+ case c: DBCPConnection =>
+ destroyDbcpWrappedConnection(c)
+ case c: MySQLConnection =>
+ destroyMysqlConnection(c)
+ case _ => error("Unsupported driver type, cannot reliably timeout.")
+ }
+ }
+
+ def destroyDbcpWrappedConnection(conn: DBCPConnection) {
+ val inner = conn.getInnermostDelegate
+
+ if ( inner != null ) {
+ destroyConnection(inner)
+ } else {
+ // this should never happen if we use our own ApachePoolingDatabase to get connections.
+ error("Could not get access to the delegate connection. Make sure the dbcp connection pool allows access to underlying connections.")
+ }
+
+ // "close" the wrapper so that it updates its internal bookkeeping, just do it
+ try { conn.close } catch { case _ => }
+ }
+
+ def destroyMysqlConnection(conn: MySQLConnection) {
+ val abort = Class.forName("com.mysql.jdbc.ConnectionImpl").getDeclaredMethod("abortInternal")
+ abort.setAccessible(true)
+
+ abort.invoke(conn)
+ }
+}
12 src/main/scala/com/twitter/querulous/Timeout.scala
View
@@ -7,12 +7,12 @@ import com.twitter.xrayspecs.Duration
class TimeoutException extends Exception
object Timeout {
- val timer = new Timer("Timer thread", true)
+ val defaultTimer = new Timer("Timer thread", true)
- def apply[T](timeout: Duration)(f: => T)(onTimeout: => Unit): T = {
+ def apply[T](timer: Timer, timeout: Duration)(f: => T)(onTimeout: => Unit): T = {
@volatile var cancelled = false
val task = if (timeout.inMillis > 0)
- Some(schedule(timeout, { cancelled = true; onTimeout }))
+ Some(schedule(timer, timeout, { cancelled = true; onTimeout }))
else None
try {
@@ -26,7 +26,11 @@ object Timeout {
}
}
- private def schedule(timeout: Duration, f: => Unit) = {
+ def apply[T](timeout: Duration)(f: => T)(onTimeout: => Unit): T = {
+ apply(defaultTimer, timeout)(f)(onTimeout)
+ }
+
+ private def schedule(timer: Timer, timeout: Duration, f: => Unit) = {
val task = new TimerTask() {
override def run() { f }
}
1  src/main/scala/com/twitter/querulous/database/ApachePoolingDatabase.scala
View
@@ -66,6 +66,7 @@ class ApachePoolingDatabase(
false,
true)
private val poolingDataSource = new PoolingDataSource(connectionPool)
+ poolingDataSource.setAccessToUnderlyingConnectionAllowed(true)
def close(connection: Connection) {
try {
44 src/main/scala/com/twitter/querulous/query/TimingOutQuery.scala
View
@@ -2,6 +2,7 @@ package com.twitter.querulous.query
import java.sql.{SQLException, Connection}
import com.twitter.xrayspecs.Duration
+import com.twitter.xrayspecs.TimeConversions._
class SqlQueryTimeoutException(val timeout: Duration) extends SQLException("Query timeout: " + timeout.inMillis + " msec")
@@ -9,13 +10,15 @@ class SqlQueryTimeoutException(val timeout: Duration) extends SQLException("Quer
/**
* A {@code QueryFactory} that creates {@link Query}s that execute subject to a {@code timeout}. An
* attempt to {@link Query#cancel} the query is made if the timeout expires.
- *
+ *
* <p>Note that queries timing out promptly is based upon {@link java.sql.Statement#cancel} working
* and executing promptly for the JDBC driver in use.
*/
-class TimingOutQueryFactory(queryFactory: QueryFactory, timeout: Duration) extends QueryFactory {
+class TimingOutQueryFactory(queryFactory: QueryFactory, timeout: Duration, cancelTimeout: Duration) extends QueryFactory {
+ def this(queryFactory: QueryFactory, timeout: Duration) = this(queryFactory, timeout, 0.millis)
+
def apply(connection: Connection, query: String, params: Any*) = {
- new TimingOutQuery(queryFactory(connection, query, params: _*), timeout)
+ new TimingOutQuery(queryFactory(connection, query, params: _*), connection, timeout, cancelTimeout)
}
}
@@ -23,18 +26,24 @@ class TimingOutQueryFactory(queryFactory: QueryFactory, timeout: Duration) exten
* A {@code QueryFactory} that creates {@link Query}s that execute subject to the {@code timeouts}
* specified for individual queries. An attempt to {@link Query#cancel} a query is made if the
* timeout expires.
- *
+ *
* <p>Note that queries timing out promptly is based upon {@link java.sql.Statement#cancel} working
* and executing promptly for the JDBC driver in use.
*/
-class PerQueryTimingOutQueryFactory(queryFactory: QueryFactory, timeouts: Map[String, Duration])
+class PerQueryTimingOutQueryFactory(queryFactory: QueryFactory, timeouts: Map[String, Duration], cancelTimeout: Duration)
extends QueryFactory {
+ def this(queryFactory: QueryFactory, timeouts: Map[String, Duration]) = this(queryFactory, timeouts, 0.millis)
+
def apply(connection: Connection, query: String, params: Any*) = {
- new TimingOutQuery(queryFactory(connection, query, params: _*), timeouts(query))
+ new TimingOutQuery(queryFactory(connection, query, params: _*), connection, timeouts(query), cancelTimeout)
}
}
+private object QueryCancellation {
+ val cancelTimer = new java.util.Timer("global query cancellation timer", true)
+}
+
/**
* A {@code Query} that executes subject to the {@code timeout} specified. An attempt to
* {@link #cancel} the query is made if the timeout expires before the query completes.
@@ -42,7 +51,11 @@ class PerQueryTimingOutQueryFactory(queryFactory: QueryFactory, timeouts: Map[St
* <p>Note that the query timing out promptly is based upon {@link java.sql.Statement#cancel}
* working and executing promptly for the JDBC driver in use.
*/
-class TimingOutQuery(query: Query, timeout: Duration) extends QueryProxy(query) {
+class TimingOutQuery(query: Query, connection: Connection, timeout: Duration, cancelTimeout: Duration)
+ extends QueryProxy(query) with ConnectionDestroying {
+
+ import QueryCancellation._
+
override def delegate[A](f: => A) = {
try {
Timeout(timeout) {
@@ -55,4 +68,21 @@ class TimingOutQuery(query: Query, timeout: Duration) extends QueryProxy(query)
throw new SqlQueryTimeoutException(timeout)
}
}
+
+ override def cancel() {
+ val cancelThread = new Thread("query cancellation") {
+ override def run() {
+ try {
+ Timeout(cancelTimer, cancelTimeout) {
+ // start by trying the nice way
+ query.cancel()
+ } {
+ // if the cancel times out, destroy the underlying connection
+ destroyConnection(connection)
+ }
+ } catch { case e: TimeoutException => }
+ }
+ }
+ cancelThread.start()
+ }
}
21 src/main/scala/com/twitter/querulous/query/TimingOutStatsCollectingQuery.scala
View
@@ -5,6 +5,7 @@ import scala.collection.Map
import scala.util.matching.Regex
import scala.collection.Map
import com.twitter.xrayspecs.Duration
+import com.twitter.xrayspecs.TimeConversions._
import net.lag.extensions._
@@ -22,13 +23,25 @@ object TimingOutStatsCollectingQueryFactory {
}
class TimingOutStatsCollectingQueryFactory(queryFactory: QueryFactory,
- queryInfo: Map[String, (String, Duration)],
- defaultTimeout: Duration, stats: StatsCollector)
- extends QueryFactory {
+ queryInfo: Map[String, (String, Duration)],
+ defaultTimeout: Duration, cancelTimeout: Duration, stats: StatsCollector)
+ extends QueryFactory {
+
+ def this(queryFactory: QueryFactory, queryInfo: Map[String, (String, Duration)], defaultTimeout: Duration, stats: StatsCollector) =
+ this(queryFactory, queryInfo, defaultTimeout, 0.millis, stats)
+
def apply(connection: Connection, query: String, params: Any*) = {
val simplifiedQueryString = TimingOutStatsCollectingQueryFactory.simplifiedQuery(query)
val (name, timeout) = queryInfo.getOrElse(simplifiedQueryString, ("default", defaultTimeout))
- new TimingOutStatsCollectingQuery(new TimingOutQuery(queryFactory(connection, query, params: _*), timeout), name, stats)
+
+ new TimingOutStatsCollectingQuery(
+ new TimingOutQuery(
+ queryFactory(connection, query, params: _*),
+ connection,
+ timeout,
+ cancelTimeout),
+ name,
+ stats)
}
}
8 src/test/scala/com/twitter/querulous/unit/TimingOutQuerySpec.scala
View
@@ -1,6 +1,7 @@
package com.twitter.querulous.unit
import java.sql.ResultSet
+import net.lag.configgy.Configgy
import org.specs.Specification
import org.specs.mock.{JMocker, ClassMocker}
import com.twitter.querulous.test.FakeQuery
@@ -12,7 +13,10 @@ import java.util.concurrent.{CountDownLatch, TimeUnit}
class TimingOutQuerySpec extends Specification with JMocker with ClassMocker {
"TimingOutQuery" should {
+ val config = Configgy.config.configMap("db")
+ val connection = TestEvaluator.testDatabaseFactory(List("localhost"), config("username"), config("password")).open()
val timeout = 1.second
+ val cancelTimeout = 0.millis
val resultSet = mock[ResultSet]
"timeout" in {
@@ -25,7 +29,7 @@ class TimingOutQuerySpec extends Specification with JMocker with ClassMocker {
super.select(f)
}
}
- val timingOutQuery = new TimingOutQuery(query, timeout)
+ val timingOutQuery = new TimingOutQuery(query, connection, timeout, cancelTimeout)
timingOutQuery.select { r => 1 } must throwA[SqlQueryTimeoutException]
latch.getCount mustEqual 0
@@ -36,7 +40,7 @@ class TimingOutQuerySpec extends Specification with JMocker with ClassMocker {
val query = new FakeQuery(List(resultSet)) {
override def cancel() = { latch.countDown() }
}
- val timingOutQuery = new TimingOutQuery(query, timeout)
+ val timingOutQuery = new TimingOutQuery(query, connection, timeout, cancelTimeout)
timingOutQuery.select { r => 1 }
latch.getCount mustEqual 1
Please sign in to comment.
Something went wrong with that request. Please try again.