Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Merge branch 'only_timeout_work' into new_version
  • Loading branch information
freels committed Sep 8, 2010
2 parents 15241db + 48e3a97 commit 51526e7
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 17 deletions.
42 changes: 42 additions & 0 deletions src/main/scala/com/twitter/querulous/ConnectionDestroying.scala
@@ -0,0 +1,42 @@
package com.twitter.querulous.query

import java.sql.Connection
import org.apache.commons.dbcp.{DelegatingConnection => DBCPConnection}
import com.mysql.jdbc.{ConnectionImpl => MySQLConnection}


// Emergency connection destruction toolkit

trait ConnectionDestroying {
def destroyConnection(conn: Connection) {
if ( !conn.isClosed )
conn match {
case c: DBCPConnection =>
destroyDbcpWrappedConnection(c)
case c: MySQLConnection =>
destroyMysqlConnection(c)
case _ => error("Unsupported driver type, cannot reliably timeout.")
}
}

def destroyDbcpWrappedConnection(conn: DBCPConnection) {
val inner = conn.getInnermostDelegate

if ( inner != null ) {
destroyConnection(inner)
} else {
// this should never happen if we use our own ApachePoolingDatabase to get connections.
error("Could not get access to the delegate connection. Make sure the dbcp connection pool allows access to underlying connections.")
}

// "close" the wrapper so that it updates its internal bookkeeping, just do it
try { conn.close } catch { case _ => }
}

def destroyMysqlConnection(conn: MySQLConnection) {
val abort = Class.forName("com.mysql.jdbc.ConnectionImpl").getDeclaredMethod("abortInternal")
abort.setAccessible(true)

abort.invoke(conn)
}
}
12 changes: 8 additions & 4 deletions src/main/scala/com/twitter/querulous/Timeout.scala
Expand Up @@ -7,12 +7,12 @@ import com.twitter.xrayspecs.Duration
class TimeoutException extends Exception class TimeoutException extends Exception


object Timeout { object Timeout {
val timer = new Timer("Timer thread", true) val defaultTimer = new Timer("Timer thread", true)


def apply[T](timeout: Duration)(f: => T)(onTimeout: => Unit): T = { def apply[T](timer: Timer, timeout: Duration)(f: => T)(onTimeout: => Unit): T = {
@volatile var cancelled = false @volatile var cancelled = false
val task = if (timeout.inMillis > 0) val task = if (timeout.inMillis > 0)
Some(schedule(timeout, { cancelled = true; onTimeout })) Some(schedule(timer, timeout, { cancelled = true; onTimeout }))
else None else None


try { try {
Expand All @@ -26,7 +26,11 @@ object Timeout {
} }
} }


private def schedule(timeout: Duration, f: => Unit) = { def apply[T](timeout: Duration)(f: => T)(onTimeout: => Unit): T = {
apply(defaultTimer, timeout)(f)(onTimeout)
}

private def schedule(timer: Timer, timeout: Duration, f: => Unit) = {
val task = new TimerTask() { val task = new TimerTask() {
override def run() { f } override def run() { f }
} }
Expand Down
Expand Up @@ -66,6 +66,7 @@ class ApachePoolingDatabase(
false, false,
true) true)
private val poolingDataSource = new PoolingDataSource(connectionPool) private val poolingDataSource = new PoolingDataSource(connectionPool)
poolingDataSource.setAccessToUnderlyingConnectionAllowed(true)


def close(connection: Connection) { def close(connection: Connection) {
try { try {
Expand Down
44 changes: 37 additions & 7 deletions src/main/scala/com/twitter/querulous/query/TimingOutQuery.scala
Expand Up @@ -2,47 +2,60 @@ package com.twitter.querulous.query


import java.sql.{SQLException, Connection} import java.sql.{SQLException, Connection}
import com.twitter.xrayspecs.Duration import com.twitter.xrayspecs.Duration
import com.twitter.xrayspecs.TimeConversions._




class SqlQueryTimeoutException(val timeout: Duration) extends SQLException("Query timeout: " + timeout.inMillis + " msec") class SqlQueryTimeoutException(val timeout: Duration) extends SQLException("Query timeout: " + timeout.inMillis + " msec")


/** /**
* A {@code QueryFactory} that creates {@link Query}s that execute subject to a {@code timeout}. An * A {@code QueryFactory} that creates {@link Query}s that execute subject to a {@code timeout}. An
* attempt to {@link Query#cancel} the query is made if the timeout expires. * attempt to {@link Query#cancel} the query is made if the timeout expires.
* *
* <p>Note that queries timing out promptly is based upon {@link java.sql.Statement#cancel} working * <p>Note that queries timing out promptly is based upon {@link java.sql.Statement#cancel} working
* and executing promptly for the JDBC driver in use. * and executing promptly for the JDBC driver in use.
*/ */
class TimingOutQueryFactory(queryFactory: QueryFactory, timeout: Duration) extends QueryFactory { class TimingOutQueryFactory(queryFactory: QueryFactory, timeout: Duration, cancelTimeout: Duration) extends QueryFactory {
def this(queryFactory: QueryFactory, timeout: Duration) = this(queryFactory, timeout, 0.millis)

def apply(connection: Connection, query: String, params: Any*) = { def apply(connection: Connection, query: String, params: Any*) = {
new TimingOutQuery(queryFactory(connection, query, params: _*), timeout) new TimingOutQuery(queryFactory(connection, query, params: _*), connection, timeout, cancelTimeout)
} }
} }


/** /**
* A {@code QueryFactory} that creates {@link Query}s that execute subject to the {@code timeouts} * A {@code QueryFactory} that creates {@link Query}s that execute subject to the {@code timeouts}
* specified for individual queries. An attempt to {@link Query#cancel} a query is made if the * specified for individual queries. An attempt to {@link Query#cancel} a query is made if the
* timeout expires. * timeout expires.
* *
* <p>Note that queries timing out promptly is based upon {@link java.sql.Statement#cancel} working * <p>Note that queries timing out promptly is based upon {@link java.sql.Statement#cancel} working
* and executing promptly for the JDBC driver in use. * and executing promptly for the JDBC driver in use.
*/ */
class PerQueryTimingOutQueryFactory(queryFactory: QueryFactory, timeouts: Map[String, Duration]) class PerQueryTimingOutQueryFactory(queryFactory: QueryFactory, timeouts: Map[String, Duration], cancelTimeout: Duration)
extends QueryFactory { extends QueryFactory {


def this(queryFactory: QueryFactory, timeouts: Map[String, Duration]) = this(queryFactory, timeouts, 0.millis)

def apply(connection: Connection, query: String, params: Any*) = { def apply(connection: Connection, query: String, params: Any*) = {
new TimingOutQuery(queryFactory(connection, query, params: _*), timeouts(query)) new TimingOutQuery(queryFactory(connection, query, params: _*), connection, timeouts(query), cancelTimeout)
} }
} }


private object QueryCancellation {
val cancelTimer = new java.util.Timer("global query cancellation timer", true)
}

/** /**
* A {@code Query} that executes subject to the {@code timeout} specified. An attempt to * A {@code Query} that executes subject to the {@code timeout} specified. An attempt to
* {@link #cancel} the query is made if the timeout expires before the query completes. * {@link #cancel} the query is made if the timeout expires before the query completes.
* *
* <p>Note that the query timing out promptly is based upon {@link java.sql.Statement#cancel} * <p>Note that the query timing out promptly is based upon {@link java.sql.Statement#cancel}
* working and executing promptly for the JDBC driver in use. * working and executing promptly for the JDBC driver in use.
*/ */
class TimingOutQuery(query: Query, timeout: Duration) extends QueryProxy(query) { class TimingOutQuery(query: Query, connection: Connection, timeout: Duration, cancelTimeout: Duration)
extends QueryProxy(query) with ConnectionDestroying {

import QueryCancellation._

override def delegate[A](f: => A) = { override def delegate[A](f: => A) = {
try { try {
Timeout(timeout) { Timeout(timeout) {
Expand All @@ -55,4 +68,21 @@ class TimingOutQuery(query: Query, timeout: Duration) extends QueryProxy(query)
throw new SqlQueryTimeoutException(timeout) throw new SqlQueryTimeoutException(timeout)
} }
} }

override def cancel() {
val cancelThread = new Thread("query cancellation") {
override def run() {
try {
Timeout(cancelTimer, cancelTimeout) {
// start by trying the nice way
query.cancel()
} {
// if the cancel times out, destroy the underlying connection
destroyConnection(connection)
}
} catch { case e: TimeoutException => }
}
}
cancelThread.start()
}
} }
Expand Up @@ -5,6 +5,7 @@ import scala.collection.Map
import scala.util.matching.Regex import scala.util.matching.Regex
import scala.collection.Map import scala.collection.Map
import com.twitter.xrayspecs.Duration import com.twitter.xrayspecs.Duration
import com.twitter.xrayspecs.TimeConversions._
import net.lag.extensions._ import net.lag.extensions._




Expand All @@ -22,13 +23,25 @@ object TimingOutStatsCollectingQueryFactory {
} }


class TimingOutStatsCollectingQueryFactory(queryFactory: QueryFactory, class TimingOutStatsCollectingQueryFactory(queryFactory: QueryFactory,
queryInfo: Map[String, (String, Duration)], queryInfo: Map[String, (String, Duration)],
defaultTimeout: Duration, stats: StatsCollector) defaultTimeout: Duration, cancelTimeout: Duration, stats: StatsCollector)
extends QueryFactory { extends QueryFactory {

def this(queryFactory: QueryFactory, queryInfo: Map[String, (String, Duration)], defaultTimeout: Duration, stats: StatsCollector) =
this(queryFactory, queryInfo, defaultTimeout, 0.millis, stats)

def apply(connection: Connection, query: String, params: Any*) = { def apply(connection: Connection, query: String, params: Any*) = {
val simplifiedQueryString = TimingOutStatsCollectingQueryFactory.simplifiedQuery(query) val simplifiedQueryString = TimingOutStatsCollectingQueryFactory.simplifiedQuery(query)
val (name, timeout) = queryInfo.getOrElse(simplifiedQueryString, ("default", defaultTimeout)) val (name, timeout) = queryInfo.getOrElse(simplifiedQueryString, ("default", defaultTimeout))
new TimingOutStatsCollectingQuery(new TimingOutQuery(queryFactory(connection, query, params: _*), timeout), name, stats)
new TimingOutStatsCollectingQuery(
new TimingOutQuery(
queryFactory(connection, query, params: _*),
connection,
timeout,
cancelTimeout),
name,
stats)
} }
} }


Expand Down
@@ -1,6 +1,7 @@
package com.twitter.querulous.unit package com.twitter.querulous.unit


import java.sql.ResultSet import java.sql.ResultSet
import net.lag.configgy.Configgy
import org.specs.Specification import org.specs.Specification
import org.specs.mock.{JMocker, ClassMocker} import org.specs.mock.{JMocker, ClassMocker}
import com.twitter.querulous.test.FakeQuery import com.twitter.querulous.test.FakeQuery
Expand All @@ -12,7 +13,10 @@ import java.util.concurrent.{CountDownLatch, TimeUnit}


class TimingOutQuerySpec extends Specification with JMocker with ClassMocker { class TimingOutQuerySpec extends Specification with JMocker with ClassMocker {
"TimingOutQuery" should { "TimingOutQuery" should {
val config = Configgy.config.configMap("db")
val connection = TestEvaluator.testDatabaseFactory(List("localhost"), config("username"), config("password")).open()
val timeout = 1.second val timeout = 1.second
val cancelTimeout = 0.millis
val resultSet = mock[ResultSet] val resultSet = mock[ResultSet]


"timeout" in { "timeout" in {
Expand All @@ -25,7 +29,7 @@ class TimingOutQuerySpec extends Specification with JMocker with ClassMocker {
super.select(f) super.select(f)
} }
} }
val timingOutQuery = new TimingOutQuery(query, timeout) val timingOutQuery = new TimingOutQuery(query, connection, timeout, cancelTimeout)


timingOutQuery.select { r => 1 } must throwA[SqlQueryTimeoutException] timingOutQuery.select { r => 1 } must throwA[SqlQueryTimeoutException]
latch.getCount mustEqual 0 latch.getCount mustEqual 0
Expand All @@ -36,7 +40,7 @@ class TimingOutQuerySpec extends Specification with JMocker with ClassMocker {
val query = new FakeQuery(List(resultSet)) { val query = new FakeQuery(List(resultSet)) {
override def cancel() = { latch.countDown() } override def cancel() = { latch.countDown() }
} }
val timingOutQuery = new TimingOutQuery(query, timeout) val timingOutQuery = new TimingOutQuery(query, connection, timeout, cancelTimeout)


timingOutQuery.select { r => 1 } timingOutQuery.select { r => 1 }
latch.getCount mustEqual 1 latch.getCount mustEqual 1
Expand Down

0 comments on commit 51526e7

Please sign in to comment.