diff --git a/build.sbt b/build.sbt index bd9bb98..ed587a2 100644 --- a/build.sbt +++ b/build.sbt @@ -20,13 +20,13 @@ lazy val commonSettings = Seq( ) lazy val commonLibrarySettings = libraryDependencies ++= Seq( - "org.scalatest" %% "scalatest" % "3.0.1", - "org.apache.kafka" %% "kafka" % kafkaVersion exclude (slf4jLog4jOrg, slf4jLog4jArtifact), - "org.apache.zookeeper" % "zookeeper" % "3.4.8" exclude (slf4jLog4jOrg, slf4jLog4jArtifact), - "org.apache.avro" % "avro" % "1.8.1" exclude (slf4jLog4jOrg, slf4jLog4jArtifact), - "com.typesafe.akka" %% "akka-actor" % akkaVersion % Test, - "com.typesafe.akka" %% "akka-testkit" % akkaVersion % Test - ) + "org.scalatest" %% "scalatest" % "3.0.1", + "org.apache.kafka" %% "kafka" % kafkaVersion exclude(slf4jLog4jOrg, slf4jLog4jArtifact), + "org.apache.zookeeper" % "zookeeper" % "3.4.8" exclude(slf4jLog4jOrg, slf4jLog4jArtifact), + "org.apache.avro" % "avro" % "1.8.1" exclude(slf4jLog4jOrg, slf4jLog4jArtifact), + "com.typesafe.akka" %% "akka-actor" % akkaVersion % Test, + "com.typesafe.akka" %% "akka-testkit" % akkaVersion % Test +) lazy val publishSettings = Seq( licenses += ("MIT", url("http://opensource.org/licenses/MIT")), @@ -69,6 +69,7 @@ lazy val embeddedKafka = (project in file("embedded-kafka")) .settings(publishSettings: _*) .settings(commonSettings: _*) .settings(commonLibrarySettings) + .settings(libraryDependencies += "org.mockito" % "mockito-core" % "2.7.14" % Test) .settings(releaseSettings: _*) lazy val kafkaStreams = (project in file("kafka-streams")) @@ -78,6 +79,6 @@ lazy val kafkaStreams = (project in file("kafka-streams")) .settings(commonLibrarySettings) .settings(releaseSettings: _*) .settings(libraryDependencies ++= Seq( - "org.apache.kafka" % "kafka-streams" % kafkaVersion exclude (slf4jLog4jOrg, slf4jLog4jArtifact) + "org.apache.kafka" % "kafka-streams" % kafkaVersion exclude(slf4jLog4jOrg, slf4jLog4jArtifact) )) .dependsOn(embeddedKafka) diff --git a/embedded-kafka/src/main/scala/net/manub/embeddedkafka/ConsumerExtensions.scala b/embedded-kafka/src/main/scala/net/manub/embeddedkafka/ConsumerExtensions.scala index a8f6d84..742b873 100644 --- a/embedded-kafka/src/main/scala/net/manub/embeddedkafka/ConsumerExtensions.scala +++ b/embedded-kafka/src/main/scala/net/manub/embeddedkafka/ConsumerExtensions.scala @@ -6,9 +6,9 @@ import org.apache.log4j.Logger import scala.util.Try -/** Method extensions for Kafka's [[KafkaConsumer]] API allowing easy testing.*/ +/** Method extensions for Kafka's [[KafkaConsumer]] API allowing easy testing. */ object ConsumerExtensions { - val MaximumAttempts = 3 + implicit class ConsumerOps[K, V](val consumer: KafkaConsumer[K, V]) { private val logger = Logger.getLogger(classOf[ConsumerOps[K, V]]) @@ -18,14 +18,16 @@ object ConsumerExtensions { * to consume batches from the given topic, until it reaches the number of desired messages or * return otherwise. * - * @param topic the topic from which to consume messages + * @param topic the topic from which to consume messages + * @param maximumAttempts the maximum number of attempts to try and get the batch (defaults to 3) + * @param poll the amount of time, in milliseconds, to wait in the buffer for any messages to be available (defaults to 2000) * @return the stream of consumed messages that you can do `.take(n: Int).toList` * to evaluate the requested number of messages. */ - def consumeLazily(topic: String): Stream[(K, V)] = { - val attempts = 1 to MaximumAttempts + def consumeLazily(topic: String, maximumAttempts: Int = 3, poll: Long = 2000): Stream[(K, V)] = { + val attempts = 1 to maximumAttempts attempts.toStream.flatMap { attempt => - val batch: Seq[(K, V)] = getNextBatch(topic) + val batch: Seq[(K, V)] = getNextBatch(topic, poll) logger.debug(s"----> Batch $attempt ($topic) | ${batch.mkString("|")}") batch } @@ -34,18 +36,20 @@ object ConsumerExtensions { /** Get the next batch of messages from Kafka. * * @param topic the topic to consume + * @param poll the amount of time, in milliseconds, to wait in the buffer for any messages to be available * @return the next batch of messages */ - def getNextBatch(topic: String): Seq[(K, V)] = + private def getNextBatch(topic: String, poll: Long): Seq[(K, V)] = Try { - import scala.collection.JavaConversions._ - consumer.subscribe(List(topic)) + import scala.collection.JavaConverters._ + consumer.subscribe(List(topic).asJava) consumer.partitionsFor(topic) - val records = consumer.poll(2000) + val records = consumer.poll(poll) // use toList to force eager evaluation. toSeq is lazy - records.iterator().toList.map(r => r.key -> r.value) + records.iterator().asScala.toList.map(r => r.key -> r.value) }.recover { case ex: KafkaException => throw new KafkaUnavailableException(ex) }.get } + } diff --git a/embedded-kafka/src/test/java/net/manub/embeddedkafka/ConsumerOpsSpec.scala b/embedded-kafka/src/test/java/net/manub/embeddedkafka/ConsumerOpsSpec.scala new file mode 100644 index 0000000..4d5ba43 --- /dev/null +++ b/embedded-kafka/src/test/java/net/manub/embeddedkafka/ConsumerOpsSpec.scala @@ -0,0 +1,59 @@ +package net.manub.embeddedkafka + +import net.manub.embeddedkafka.ConsumerExtensions._ +import org.apache.kafka.clients.consumer.{ConsumerRecord, ConsumerRecords, KafkaConsumer} +import org.apache.kafka.common.TopicPartition +import org.mockito.Mockito.{times, verify, when} +import org.scalatest.mockito.MockitoSugar + +import scala.collection.JavaConverters._ + +class ConsumerOpsSpec extends EmbeddedKafkaSpecSupport with MockitoSugar { + + "ConsumeLazily " should { + "retry to get messages with the configured maximum number of attempts when poll fails" in { + val consumer = mock[KafkaConsumer[String, String]] + val consumerRecords = + new ConsumerRecords[String, String](Map.empty[TopicPartition, java.util.List[ConsumerRecord[String, String]]].asJava) + + val pollTimeout = 1 + when(consumer.poll(pollTimeout)).thenReturn(consumerRecords) + + val maximumAttempts = 2 + consumer.consumeLazily("topic", maximumAttempts, pollTimeout) + + verify(consumer, times(maximumAttempts)).poll(pollTimeout) + } + + "not retry to get messages with the configured maximum number of attempts when poll succeeds" in { + val consumer = mock[KafkaConsumer[String, String]] + val consumerRecord = mock[ConsumerRecord[String, String]] + val consumerRecords = new ConsumerRecords[String, String]( + Map[TopicPartition, java.util.List[ConsumerRecord[String, String]]](new TopicPartition("topic", 1) -> List(consumerRecord).asJava).asJava + ) + + val pollTimeout = 1 + when(consumer.poll(pollTimeout)).thenReturn(consumerRecords) + + val maximumAttempts = 2 + consumer.consumeLazily("topic", maximumAttempts, pollTimeout) + + verify(consumer).poll(pollTimeout) + } + + "poll to get messages with the configured poll timeout" in { + val consumer = mock[KafkaConsumer[String, String]] + val consumerRecords = + new ConsumerRecords[String, String](Map.empty[TopicPartition, java.util.List[ConsumerRecord[String, String]]].asJava) + + val pollTimeout = 10 + when(consumer.poll(pollTimeout)).thenReturn(consumerRecords) + + val maximumAttempts = 1 + consumer.consumeLazily("topic", maximumAttempts, pollTimeout) + + verify(consumer).poll(pollTimeout) + } + } + +}