Skip to content

Commit

Permalink
Use a RepeatedTaskQueue in SqsJobConsumer (#1071)
Browse files Browse the repository at this point in the history
* Use a RepeatedTaskQueue in SqsJobConsumer

This fixes #1069, which was originally a bug in RepeatedTaskQueue back
in the day. We should prefer using RepeatedTaskQueue instead of hand
rolling threads to prevent repeating the same mistakes. In addition, job
consumers will backoff if there is no work or run into errors.

Now every parallel job consumer runs as a RepeatedTask and each Task
consumes N messages and dispatches them to separate executor service for
running.
  • Loading branch information
Ryan Hall committed Jun 21, 2019
1 parent c3f502b commit 18d41ac
Show file tree
Hide file tree
Showing 10 changed files with 170 additions and 95 deletions.
Expand Up @@ -18,8 +18,7 @@ class AwsSqsJobHandlerModule<T : JobHandler> private constructor(
) : KAbstractModule() {
override fun configure() {
newMapBinder<QueueName, JobHandler>().addBinding(queueName).to(handler.java)
install(ServiceModule<AwsSqsJobHandlerSubscriptionService>()
.dependsOn<SqsJobConsumer>())
install(ServiceModule<AwsSqsJobHandlerSubscriptionService>())
}

companion object {
Expand Down Expand Up @@ -56,4 +55,4 @@ internal class AwsSqsJobHandlerSubscriptionService @Inject constructor(
}

override fun shutDown() {}
}
}
@@ -1,6 +1,7 @@
package misk.jobqueue.sqs

import misk.config.Config
import misk.tasks.RepeatedTaskQueueConfig

/**
* [AwsSqsJobQueueConfig] is the configuration for job queueing backed by Amazon's
Expand All @@ -22,6 +23,8 @@ class AwsSqsJobQueueConfig(
/**
* Number of jobs that can be processed concurrently.
*/
val consumer_thread_pool_size: Int = 4
val consumer_thread_pool_size: Int = 4,

val task_queue: RepeatedTaskQueueConfig = RepeatedTaskQueueConfig()
) : Config

21 changes: 19 additions & 2 deletions misk-aws/src/main/kotlin/misk/jobqueue/sqs/AwsSqsJobQueueModule.kt
Expand Up @@ -3,17 +3,23 @@ package misk.jobqueue.sqs
import com.amazonaws.auth.AWSCredentialsProvider
import com.amazonaws.services.sqs.AmazonSQS
import com.amazonaws.services.sqs.AmazonSQSClientBuilder
import com.google.common.util.concurrent.ThreadFactoryBuilder
import com.google.inject.Provider
import com.google.inject.Provides
import com.google.inject.Singleton
import misk.ServiceModule
import misk.cloud.aws.AwsRegion
import misk.concurrent.ExecutorServiceModule
import misk.inject.KAbstractModule
import misk.inject.keyOf
import misk.jobqueue.JobConsumer
import misk.jobqueue.JobQueue
import misk.jobqueue.QueueName
import misk.jobqueue.TransactionalJobQueue
import misk.tasks.RepeatedTaskQueue
import java.time.Clock
import java.time.Duration
import java.util.concurrent.Executors
import javax.inject.Inject

/** [AwsSqsJobQueueModule] installs job queue support provided by SQS. */
Expand All @@ -30,7 +36,7 @@ class AwsSqsJobQueueModule(
bind<JobQueue>().to<SqsJobQueue>()
bind<TransactionalJobQueue>().to<SqsTransactionalJobQueue>()

install(ServiceModule<SqsJobConsumer>())
install(ServiceModule(keyOf<RepeatedTaskQueue>(ForSqsConsumer::class)))

install(ExecutorServiceModule.withFixedThreadPool(
ForSqsConsumer::class,
Expand Down Expand Up @@ -62,6 +68,17 @@ class AwsSqsJobQueueModule(
.build()
}

@Provides @ForSqsConsumer @Singleton
fun consumerRepeatedTaskQueue(clock: Clock, config : AwsSqsJobQueueConfig): RepeatedTaskQueue {
return RepeatedTaskQueue(
"sqs-consumer-poller",
clock,
Executors.newCachedThreadPool(ThreadFactoryBuilder()
.setNameFormat("sqs-consumer-%d")
.build()),
config.task_queue)
}

private class AmazonSQSProvider(val region: AwsRegion) : Provider<AmazonSQS> {
@Inject lateinit var credentials: AWSCredentialsProvider

Expand All @@ -70,4 +87,4 @@ class AwsSqsJobQueueModule(
.withRegion(region.name)
.build()
}
}
}
3 changes: 2 additions & 1 deletion misk-aws/src/main/kotlin/misk/jobqueue/sqs/ForSqsConsumer.kt
Expand Up @@ -3,4 +3,5 @@ package misk.jobqueue.sqs
import javax.inject.Qualifier

@Qualifier
internal annotation class ForSqsConsumer
@Target(AnnotationTarget.FIELD, AnnotationTarget.FUNCTION, AnnotationTarget.VALUE_PARAMETER)
internal annotation class ForSqsConsumer
124 changes: 63 additions & 61 deletions misk-aws/src/main/kotlin/misk/jobqueue/sqs/SqsJobConsumer.kt
@@ -1,7 +1,6 @@
package misk.jobqueue.sqs

import com.amazonaws.services.sqs.model.ReceiveMessageRequest
import com.google.common.util.concurrent.AbstractIdleService
import com.google.common.util.concurrent.ServiceManager
import io.opentracing.Tracer
import io.opentracing.tag.StringTag
Expand All @@ -10,21 +9,24 @@ import misk.jobqueue.JobConsumer
import misk.jobqueue.JobHandler
import misk.jobqueue.QueueName
import misk.logging.getLogger
import misk.tasks.RepeatedTaskQueue
import misk.tasks.Status
import misk.time.timed
import misk.tracing.traceWithSpan
import java.time.Duration
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutionException
import java.util.concurrent.ExecutorService
import java.util.concurrent.atomic.AtomicBoolean
import javax.inject.Inject
import javax.inject.Provider
import javax.inject.Singleton
import kotlin.concurrent.thread

@Singleton
internal class SqsJobConsumer @Inject internal constructor(
private val config: AwsSqsJobQueueConfig,
private val queues: QueueResolver,
@ForSqsConsumer private val dispatchThreadPool: ExecutorService,
@ForSqsConsumer private val taskQueue: RepeatedTaskQueue,
private val tracer: Tracer,
private val metrics: SqsMetrics,
/**
Expand All @@ -33,16 +35,10 @@ internal class SqsJobConsumer @Inject internal constructor(
* jobs. We use a provider here to avoid a dependency cycle.
*/
private val serviceManagerProvider: Provider<ServiceManager>
) : AbstractIdleService(), JobConsumer {
private val subscriptions = ConcurrentHashMap<QueueName, JobConsumer.Subscription>()
) : JobConsumer {
private val subscriptions = ConcurrentHashMap<QueueName, QueueReceiver>()

override fun startUp() {}

override fun shutDown() {
subscriptions.values.forEach { it.close() }
}

override fun subscribe(queueName: QueueName, handler: JobHandler): JobConsumer.Subscription {
override fun subscribe(queueName: QueueName, handler: JobHandler) {
val receiver = QueueReceiver(queueName, handler)
if (subscriptions.putIfAbsent(queueName, receiver) != null) {
throw IllegalStateException("already subscribed to queue ${queueName.value}")
Expand All @@ -53,70 +49,76 @@ internal class SqsJobConsumer @Inject internal constructor(
}

for (i in (0 until config.concurrent_receivers_per_queue)) {
thread(name = "sqs-receiver-${queueName.value}-$i", start = true) {
log.info { "launching receiver $i for ${queueName.value}" }
receiver.run()
taskQueue.scheduleWithBackoff(Duration.ZERO) {
// Don't call handlers until all services are ready, otherwise handlers will crash because the
// services they might need (databases, etc.) won't be ready.
serviceManagerProvider.get().awaitHealthy()
receiver.runOnce()
}
}
}

return receiver
internal fun getReceiver(queueName: QueueName): QueueReceiver {
return subscriptions[queueName]!!
}

private inner class QueueReceiver(
internal inner class QueueReceiver(
private val queueName: QueueName,
private val handler: JobHandler
) : Runnable, JobConsumer.Subscription {
) {
private val queue = queues[queueName]
private val running = AtomicBoolean(true)

override fun run() {
// Don't call handlers until all services are ready, otherwise handlers will crash because the
// services they might need (databases, etc.) won't be ready.
serviceManagerProvider.get().awaitHealthy()

while (running.get()) {
val messages = queue.call { client ->
client.receiveMessage(ReceiveMessageRequest()
.withAttributeNames("All")
.withMessageAttributeNames("All")
.withQueueUrl(queue.url)
.withMaxNumberOfMessages(10))
.messages
}

messages.map { SqsJob(queueName, queues, metrics, it) }.forEach { message ->
dispatchThreadPool.submit {
metrics.jobsReceived.labels(queueName.value).inc()

tracer.traceWithSpan("handle-job-${queueName.value}") { span ->
// If the incoming job has an original trace id, set that as a tag on the new span.
// We don't turn that into the parent of the current span because that would
// incorrectly include the execution time of the job in the execution time of the
// action that triggered the job
message.attributes[SqsJob.ORIGINAL_TRACE_ID_ATTR]?.let {
ORIGINAL_TRACE_ID_TAG.set(span, it)
}

// Run the handler and record timing
try {
val (duration, _) = timed { handler.handleJob(message) }
metrics.handlerDispatchTime.record(duration.toMillis().toDouble(), queueName.value)
} catch (th: Throwable) {
log.error(th) { "error handling job from ${queueName.value}" }
metrics.handlerFailures.labels(queueName.value).inc()
Tags.ERROR.set(span, true)
}
fun runOnce(): Status {
val messages = queue.call { client ->
client.receiveMessage(ReceiveMessageRequest()
.withAttributeNames("All")
.withMessageAttributeNames("All")
.withQueueUrl(queue.url)
.withMaxNumberOfMessages(10))
.messages
}

if (messages.size == 0) {
return Status.NO_WORK
}

val futures = messages.map { SqsJob(queueName, queues, metrics, it) }.map { message ->
dispatchThreadPool.submit {
metrics.jobsReceived.labels(queueName.value).inc()

tracer.traceWithSpan("handle-job-${queueName.value}") { span ->
// If the incoming job has an original trace id, set that as a tag on the new span.
// We don't turn that into the parent of the current span because that would
// incorrectly include the execution time of the job in the execution time of the
// action that triggered the job q
message.attributes[SqsJob.ORIGINAL_TRACE_ID_ATTR]?.let {
ORIGINAL_TRACE_ID_TAG.set(span, it)
}

// Run the handler and record timing
try {
val (duration, _) = timed { handler.handleJob(message) }
metrics.handlerDispatchTime.record(duration.toMillis().toDouble(), queueName.value)
} catch (th: Throwable) {
log.error(th) { "error handling job from ${queueName.value}" }
metrics.handlerFailures.labels(queueName.value).inc()
Tags.ERROR.set(span, true)
throw th
}
}
}
}
}

override fun close() {
if (!subscriptions.remove(queueName, this)) return
for (future in futures) {
try {
future.get()
} catch (e: ExecutionException) {
// the exception was already logged when the dispatched task failed above
return Status.FAILED
}
}

log.info { "closing subscription to queue ${queueName.value}" }
running.set(false)
return Status.OK
}
}

Expand Down
31 changes: 29 additions & 2 deletions misk-aws/src/test/kotlin/misk/jobqueue/sqs/SqsJobQueueTest.kt
Expand Up @@ -7,6 +7,8 @@ import misk.jobqueue.JobConsumer
import misk.jobqueue.JobQueue
import misk.jobqueue.QueueName
import misk.jobqueue.subscribe
import misk.tasks.RepeatedTaskQueue
import misk.tasks.Status
import misk.testing.MiskExternalDependency
import misk.testing.MiskTest
import misk.testing.MiskTestModule
Expand All @@ -28,6 +30,7 @@ internal class SqsJobQueueTest {
@Inject private lateinit var queue: JobQueue
@Inject private lateinit var consumer: JobConsumer
@Inject private lateinit var sqsMetrics: SqsMetrics
@Inject @ForSqsConsumer lateinit var taskQueue : RepeatedTaskQueue

private lateinit var queueName: QueueName
private lateinit var deadLetterQueueName: QueueName
Expand Down Expand Up @@ -173,13 +176,13 @@ internal class SqsJobQueueTest {

@Test fun stopsDeliveryAfterClose() {
val handledJobs = CopyOnWriteArrayList<Job>()
val subscription = consumer.subscribe(queueName) {
consumer.subscribe(queueName) {
handledJobs.add(it)
it.acknowledge()
}

// Close the subscription and wait for any currently outstanding long-polls to complete
subscription.close()
turnOffTaskQueue()
Thread.sleep(1001)

// Send 10 jobs, then wait again for the long-poll to complete make sure none of them are delivered
Expand Down Expand Up @@ -219,4 +222,28 @@ internal class SqsJobQueueTest {
assertThat(sqsMetrics.jobsDeadLettered.labels(queueName.value).get()).isEqualTo(0.0)
assertThat(sqsMetrics.handlerFailures.labels(queueName.value).get()).isEqualTo(2.0)
}

@Test fun waitsForDispatchedTasksToFail() {
turnOffTaskQueue()
consumer.subscribe(queueName) {
throw IllegalStateException("boom!")
}
queue.enqueue(queueName, "fail away")
val receiver = (consumer as SqsJobConsumer).getReceiver(queueName)
assertThat(receiver.runOnce()).isEqualTo(Status.FAILED)
}

@Test fun noWork() {
turnOffTaskQueue()
consumer.subscribe(queueName) {
throw IllegalStateException("boom!")
}
val receiver = (consumer as SqsJobConsumer).getReceiver(queueName)
assertThat(receiver.runOnce()).isEqualTo(Status.NO_WORK)
}

private fun turnOffTaskQueue() {
taskQueue.stopAsync()
taskQueue.awaitTerminated()
}
}
Expand Up @@ -7,9 +7,10 @@ import misk.MiskTestingServiceModule
import misk.cloud.aws.AwsEnvironmentModule
import misk.cloud.aws.FakeAwsEnvironmentModule
import misk.inject.KAbstractModule
import misk.tasks.RepeatedTaskQueueConfig
import misk.testing.MockTracingBackendModule

class SqsJobQueueTestModule(
class SqsJobQueueTestModule (
private val credentials: AWSCredentialsProvider,
private val client: AmazonSQS
) : KAbstractModule() {
Expand All @@ -20,7 +21,8 @@ class SqsJobQueueTestModule(
install(FakeAwsEnvironmentModule())
install(
Modules.override(
AwsSqsJobQueueModule(AwsSqsJobQueueConfig()))
AwsSqsJobQueueModule(
AwsSqsJobQueueConfig(task_queue = RepeatedTaskQueueConfig(default_jitter_ms = 0))))
.with(SqsTestModule(credentials, client))
)
}
Expand All @@ -34,4 +36,4 @@ class SqsTestModule(
bind<AWSCredentialsProvider>().toInstance(credentials)
bind<AmazonSQS>().toInstance(client)
}
}
}

0 comments on commit 18d41ac

Please sign in to comment.