CloudSolrServers: warmup newly active servers before propagating the change.

CloudSolrServers can be configured with a WarmupQueries instance
that provides the warmup queries per collection and the number of times
they should be run.

Closes #6
magro committed May 5, 2015
1 parent f9261d2 commit f97c402
Showing 9 changed files with 206 additions and 57 deletions.
6 changes: 5 additions & 1 deletion README.md
@@ -92,6 +92,9 @@ Solr Cloud is supported with the following properties / restrictions:
* No Collection Aliases supported (should be implemented next)
* Can use a default collection, if this is not provided, per request the `SolrQuery` must specify
the collection via the "collection" parameter.
* New solr servers, or solr servers whose state changed from inactive (e.g. down) to active, can be tested
with warmup queries before they're used for load balancing. To enable this, a `WarmupQueries` instance
can be set (see the configuration example below).
* Querying solr is possible when ZooKeeper is temporarily not available
* Construction of `CloudSolrServers` is possible while ZooKeeper is not available. When ZK becomes
available `CloudSolrServers` will be connected to ZK. As interval for trying to connect the
@@ -122,7 +125,8 @@ val servers = new CloudSolrServers(zkHost = "host1:2181,host2:2181",
zkClientTimeout = 15 seconds,
zkConnectTimeout = 10 seconds,
clusterStateUpdateInterval = 1 second,
defaultCollection = Some("collection1"))
defaultCollection = Some("collection1"),
warmupQueries = Some(WarmupQueries(queriesByCollection = _ => Seq(new SolrQuery("*:*")), count = 10)))
val solr = AsyncSolrClient.Builder(RoundRobinLB(servers)).build
```

23 changes: 20 additions & 3 deletions src/main/scala/io/ino/solrs/AsyncSolrClient.scala
@@ -88,8 +88,12 @@ object AsyncSolrClient {

protected def createMetrics: Metrics = NoopMetrics

// the load balancer might need to access this instance, extracted as protected method to be overridable from tests
protected def setOnLoadBalancer(solr: AsyncSolrClient): Unit = {
// the load balancer and others might need to access this instance, extracted as protected method to be overridable from tests
protected def setOnAsyncSolrClientAwares(solr: AsyncSolrClient): Unit = {
// the solr servers should be able to probe servers before the load balancer gets the handle...
// it's also set here (instead of letting the LoadBalancer pass the solrs instance to SolrServers),
// so that a LoadBalancer subclass cannot break this by forgetting to invoke super
loadBalancer.solrServers.setAsyncSolrClient(solr)
loadBalancer.setAsyncSolrClient(solr)
}

@@ -104,7 +108,7 @@
serverStateObservation,
retryPolicy
)
setOnLoadBalancer(res)
setOnAsyncSolrClientAwares(res)
res
}
}
@@ -358,6 +362,19 @@ class AsyncSolrClient private (val loadBalancer: LoadBalancer,

}

trait AsyncSolrClientAware {

/**
* On creation of AsyncSolrClient this method is invoked with the created instance if the
* concrete component is "supported"; right now these are SolrServers and LoadBalancer.
* Subclasses can override this method to get access to the solr client.
*/
def setAsyncSolrClient(solr: AsyncSolrClient): Unit = {
// empty default
}

}
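
As an illustration only (not part of this commit), here is a minimal sketch of a component hooking into this callback. The class name is made up, and as the scaladoc notes, the client currently only invokes the callback on the LoadBalancer and its SolrServers:

```scala
import io.ino.solrs.{AsyncSolrClient, AsyncSolrClientAware}

// Hypothetical component that just captures the client handed over on construction.
// Only the LoadBalancer and its SolrServers are wired via setOnAsyncSolrClientAwares,
// so a component would have to be one of those to actually receive this callback.
class ClientCapturingComponent extends AsyncSolrClientAware {

  @volatile private var solrs: Option[AsyncSolrClient] = None

  override def setAsyncSolrClient(solr: AsyncSolrClient): Unit = {
    solrs = Some(solr)
  }

  def client: Option[AsyncSolrClient] = solrs
}
```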

/**
* Subclass of SolrException that allows us to capture an arbitrary HTTP
* status code that may have been returned by the remote server or a
10 changes: 1 addition & 9 deletions src/main/scala/io/ino/solrs/LoadBalancer.scala
@@ -16,7 +16,7 @@ import scala.concurrent.Future
import scala.concurrent.duration._
import scala.language.postfixOps

trait LoadBalancer extends RequestInterceptor {
trait LoadBalancer extends RequestInterceptor with AsyncSolrClientAware {

val solrServers: SolrServers

@@ -35,14 +35,6 @@ trait LoadBalancer extends RequestInterceptor {
f(solrServer, q)
}

/**
* On creation of AsyncSolrClient this method is invoked with the created instance.
* Subclasses can override this method to get access to the solr client.
*/
def setAsyncSolrClient(solr: AsyncSolrClient): Unit = {
// empty default
}

def shutdown(): Unit = {
// empty default
}
2 changes: 2 additions & 0 deletions src/main/scala/io/ino/solrs/SolrServer.scala
@@ -10,6 +10,8 @@ class SolrServer(val baseUrl: String) {
@volatile
var status: ServerStatus = Enabled

def isEnabled = status == Enabled

override def toString(): String = s"SolrServer($baseUrl, $status)"

override def equals(other: Any): Boolean = other match {
96 changes: 80 additions & 16 deletions src/main/scala/io/ino/solrs/SolrServers.scala
@@ -2,18 +2,22 @@ package io.ino.solrs

import java.util.concurrent.{Executors, ScheduledExecutorService, ThreadFactory}

import com.ning.http.client.{Response, AsyncCompletionHandler, AsyncHttpClient}
import org.apache.solr.client.solrj.{SolrServerException, SolrQuery}
import com.ning.http.client.{AsyncCompletionHandler, AsyncHttpClient, Response}
import io.ino.solrs.CloudSolrServers.WarmupQueries
import org.apache.solr.client.solrj.response.QueryResponse
import org.apache.solr.client.solrj.{SolrQuery, SolrServerException}
import org.apache.solr.common.cloud._
import org.slf4j.LoggerFactory

import scala.concurrent.Future
import scala.collection.immutable.Iterable
import scala.concurrent.{Promise, Future}
import scala.util.control.NonFatal
import scala.util.{Failure, Success, Try}

/**
* Provides the list of solr servers.
*/
trait SolrServers {
trait SolrServers extends AsyncSolrClientAware {
/**
* The currently known solr servers.
*/
@@ -32,9 +36,10 @@ object StaticSolrServers {
def apply(baseUrls: IndexedSeq[String]): StaticSolrServers = new StaticSolrServers(baseUrls.map(SolrServer(_)))
}

import java.util.concurrent.TimeUnit

import scala.concurrent.duration._
import scala.languageFeature.postfixOps
import java.util.concurrent.TimeUnit

private object ZkClusterStateUpdateTF {
private val tg = new ThreadGroup("solrs-CloudSolrServersUpdate")
@@ -45,7 +50,7 @@ private class ZkClusterStateUpdateTF extends ThreadFactory {
override def newThread(r: Runnable): Thread = {
val td: Thread = new Thread(tg, r, "solrs-CloudSolrServersUpdateThread-" + tg.activeCount() + 1)
td.setDaemon(true)
return td
td
}
}

@@ -66,18 +71,24 @@ class CloudSolrServers(zkHost: String,
zkClientTimeout: Duration = 15 seconds, // default from Solr Core, see also SOLR-5221
zkConnectTimeout: Duration = 10 seconds, // default from solrj CloudSolrServer
clusterStateUpdateInterval: Duration = 1 second,
defaultCollection: Option[String] = None) extends SolrServers {
defaultCollection: Option[String] = None,
warmupQueries: Option[WarmupQueries] = None) extends SolrServers {

private val logger = LoggerFactory.getLogger(getClass())
private val logger = LoggerFactory.getLogger(getClass)

private var maybeZk: Option[ZkStateReader] = None

private var lastClusterState: ClusterState = _
private var collectionToServers = Map.empty[String, IndexedSeq[SolrServer]]
private var collectionToServers = Map.empty[String, IndexedSeq[SolrServer]].withDefaultValue(IndexedSeq.empty)

private var scheduledExecutor: ScheduledExecutorService = Executors.newScheduledThreadPool(1, new ZkClusterStateUpdateTF)

createZkStateReader()
private var asyncSolrClient: AsyncSolrClient = _

override def setAsyncSolrClient(client: AsyncSolrClient): Unit = {
asyncSolrClient = client
createZkStateReader()
}

private def createZkStateReader(): Unit = {
// Setup ZkStateReader, schedule retry if ZK is unavailable
@@ -120,6 +131,8 @@ class CloudSolrServers(zkHost: String,
* Updates the server list when the ZkStateReader clusterState changed
*/
private def updateFromClusterState(zkStateReader: ZkStateReader): Unit = {
import scala.concurrent.ExecutionContext.Implicits.global

val clusterState = zkStateReader.getClusterState
if(clusterState != lastClusterState) {

@@ -137,15 +150,55 @@

lastClusterState = clusterState

if(newCollectionToServers != collectionToServers) {
// take the new collection-to-servers map as a parameter just for better readability on the usage side...
def set(newCollectionToServers: Map[String, IndexedSeq[SolrServer]]): Unit = {
collectionToServers = newCollectionToServers
if(logger.isDebugEnabled) logger.debug(s"Updated server map: $collectionToServers from ClusterState $clusterState")
else logger.info(s"Updated server map: $collectionToServers")
if (logger.isDebugEnabled) logger.debug (s"Updated server map: $collectionToServers from ClusterState $clusterState")
else logger.info (s"Updated server map: $collectionToServers")
}

if(newCollectionToServers != collectionToServers) warmupQueries match {
case Some(warmup) => warmupNewServers(newCollectionToServers, warmup)
.onComplete(_ => set(newCollectionToServers))
case None => set(newCollectionToServers)
}

} catch {
case NonFatal(e) =>
logger.error(s"Could not process cluster state, server list might get outdated. Cluster state: ${clusterState}", e)
logger.error(s"Could not process cluster state, server list might get outdated. Cluster state: $clusterState", e)
}
}
}

protected def warmupNewServers(newCollectionToServers: Map[String, IndexedSeq[SolrServer]],
warmup: WarmupQueries): Future[Iterable[Try[QueryResponse]]] = {
import scala.concurrent.ExecutionContext.Implicits.global

val perCollectionResponses = newCollectionToServers.flatMap { case (collection, solrServers) =>
val existingServers = collectionToServers(collection)
// SolrServer.equals checks both baseUrl and status, therefore we can just use contains
val newActiveServers = solrServers.filter(s => s.isEnabled && !existingServers.contains(s))
newActiveServers.map(warmupNewServer(collection, _, warmup.queriesByCollection(collection), warmup.count))
}

Future.sequence(perCollectionResponses).map(_.flatten)
}

protected def warmupNewServer(collection: String, s: SolrServer, queries: Seq[SolrQuery], count: Int): Future[Seq[Try[QueryResponse]]] = {
import scala.concurrent.ExecutionContext.Implicits.global
// within each round the queries are run in parallel, and rounds are run one after the other
(1 to count).foldLeft(Future.successful(Seq.empty[Try[QueryResponse]])) { (res, round) =>
res.flatMap { _ =>
val warmupResponses = queries.map(q =>
asyncSolrClient.doQuery(s, q)
.map(Success(_))
.recover {
case NonFatal(e) =>
logger.warn(s"Warmup query $q failed", e)
Failure(e)
}
)
Future.sequence(warmupResponses)
}
}
}
@@ -189,7 +242,7 @@

class ReloadingSolrServers(url: String, extractor: Array[Byte] => IndexedSeq[SolrServer], httpClient: AsyncHttpClient) extends SolrServers {

private val logger = LoggerFactory.getLogger(getClass())
private val logger = LoggerFactory.getLogger(getClass)

private var solrServers = IndexedSeq.empty[SolrServer]

@@ -218,7 +271,7 @@ class ReloadingSolrServers(url: String, extractor: Array[Byte] => IndexedSeq[Sol
}

protected def loadUrl(): Future[Array[Byte]] = {
val promise = scala.concurrent.promise[Array[Byte]]
val promise = Promise[Array[Byte]]()
httpClient.prepareGet(url).execute(new AsyncCompletionHandler[Response]() {
override def onCompleted(response: Response): Response = {
promise.success(response.getResponseBodyAsBytes)
@@ -233,4 +286,15 @@ class ReloadingSolrServers(url: String, extractor: Array[Byte] => IndexedSeq[Sol
}


}

object CloudSolrServers {

/**
* Specifies how newly added servers / servers that changed from down to active are warmed up before they're put under load.
* @param queriesByCollection a function that returns warmup queries for a given collection.
* @param count the number of times that the queries shall be run.
*/
case class WarmupQueries(queriesByCollection: String => Seq[SolrQuery], count: Int)

}
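
For illustration, a sketch of constructing such an instance with per-collection warmup queries; the collection name and query strings below are made-up examples:

```scala
import io.ino.solrs.CloudSolrServers.WarmupQueries
import org.apache.solr.client.solrj.SolrQuery

// Warm up "collection1" with two queries and any other collection with a match-all
// query; each round of queries is run 5 times per newly active server.
val warmup = WarmupQueries(queriesByCollection = {
  case "collection1" => Seq(new SolrQuery("*:*"), new SolrQuery("price:[* TO 100]"))
  case _             => Seq(new SolrQuery("*:*"))
}, count = 5)
```

This would be passed to `CloudSolrServers` via `warmupQueries = Some(warmup)`, as in the README example above.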
48 changes: 48 additions & 0 deletions src/test/scala/io/ino/solrs/AsyncSolrClientMocks.scala
@@ -0,0 +1,48 @@
package io.ino.solrs

import io.ino.time.Clock.MutableClock
import org.apache.solr.client.solrj.response.QueryResponse
import org.mockito.Matchers._
import org.mockito.Mockito._

import scala.concurrent._
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.{Success, Try}

object AsyncSolrClientMocks {

def mockDoQuery(mock: AsyncSolrClient,
solrServer: => SolrServer = any[SolrServer](),
responseDelay: Duration = 1 milli)
(implicit clock: MutableClock): AsyncSolrClient = {
// for spies doReturn should be used...
doReturn(delayedResponse(responseDelay.toMillis)).when(mock).doQuery(solrServer, any())
mock
}

def mockDoQuery(mock: AsyncSolrClient,
futureResponse: Future[QueryResponse]): AsyncSolrClient = {
// for spies doReturn should be used...
doReturn(futureResponse).when(mock).doQuery(any[SolrServer](), any())
mock
}

def delayedResponse(delay: Long)(implicit clock: MutableClock): Future[QueryResponse] = {
val response = new QueryResponse()
new Future[QueryResponse] {
override def onComplete[U](func: (Try[QueryResponse]) => U)(implicit executor: ExecutionContext): Unit = {
clock.advance(delay)
func(Success(response))
}
override def isCompleted: Boolean = true
override def value: Option[Try[QueryResponse]] = Some(Success(response))
@throws(classOf[Exception])
override def result(atMost: Duration)(implicit permit: CanAwait): QueryResponse = response
@throws(classOf[InterruptedException])
@throws(classOf[TimeoutException])
override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = this
}
}

}