diff --git a/akka/src/main/scala/cakesolutions/kafka/akka/KafkaConsumerActor.scala b/akka/src/main/scala/cakesolutions/kafka/akka/KafkaConsumerActor.scala
index 713dda2..a130675 100644
--- a/akka/src/main/scala/cakesolutions/kafka/akka/KafkaConsumerActor.scala
+++ b/akka/src/main/scala/cakesolutions/kafka/akka/KafkaConsumerActor.scala
@@ -97,8 +97,18 @@ object KafkaConsumerActor {
      * The client should ensure that received records are confirmed with 'commit = true' to ensure kafka tracks the commit point.
      *
      * @param topics the topics to subscribe to start consuming from
+     * @param assignedListener Optionally provide a callback when partitions are assigned. Can be used if any initialisation is
+     *                         required prior to receiving messages for the partition, such as to populate a cache. Default implementation
+     *                         is to do nothing.
+     * @param revokedListener Optionally provide a callback when partitions are revoked. Can be used if any cleanup is
+     *                        required after a partition assignment is revoked. Default implementation
+     *                        is to do nothing.
      */
-    final case class AutoPartition(topics: Iterable[String]) extends Subscribe
+    final case class AutoPartition(
+      topics: Iterable[String] = List(),
+      assignedListener: List[TopicPartition] => Unit = _ => (),
+      revokedListener: List[TopicPartition] => Unit = _ => ()
+    ) extends Subscribe
 
     /**
      * Subscribe to topics in auto assigned partition mode with client managed offset commit positions for each partition.
@@ -658,9 +668,9 @@ private final class KafkaConsumerActorImpl[K: TypeTag, V: TypeTag](
   }
 
   private def subscribe(s: Subscribe): Unit = s match {
-    case Subscribe.AutoPartition(topics) =>
+    case Subscribe.AutoPartition(topics, assignedListener, revokedListener) =>
       log.info(s"Subscribing in auto partition assignment mode to topics [{}].", topics.mkString(","))
-      trackPartitions = new TrackPartitionsCommitMode(consumer, context.self)
+      trackPartitions = new TrackPartitionsCommitMode(consumer, context.self, assignedListener, revokedListener)
       consumer.subscribe(topics.toList.asJava, trackPartitions)
 
     case Subscribe.AutoPartitionWithManualOffset(topics, assignedListener, revokedListener) =>
diff --git a/akka/src/main/scala/cakesolutions/kafka/akka/TrackPartitions.scala b/akka/src/main/scala/cakesolutions/kafka/akka/TrackPartitions.scala
index 4f1e9f8..3d308a1 100644
--- a/akka/src/main/scala/cakesolutions/kafka/akka/TrackPartitions.scala
+++ b/akka/src/main/scala/cakesolutions/kafka/akka/TrackPartitions.scala
@@ -13,6 +13,9 @@ sealed trait TrackPartitions extends ConsumerRebalanceListener {
   def isRevoked: Boolean
 
   def reset(): Unit
+
+  def offsetsToTopicPartitions(offsets: Map[TopicPartition, Long]): List[TopicPartition] =
+    offsets.map { case (tp, _) => tp }.toList
 }
 
 /**
@@ -25,8 +28,10 @@ sealed trait TrackPartitions extends ConsumerRebalanceListener {
  * @param consumer The client driver
  * @param consumerActor Tha KafkaConsumerActor to notify of partition change events
  */
-private final class TrackPartitionsCommitMode(consumer: KafkaConsumer[_, _], consumerActor: ActorRef)
-  extends TrackPartitions {
+private final class TrackPartitionsCommitMode(
+  consumer: KafkaConsumer[_, _], consumerActor: ActorRef,
+  assignedListener: List[TopicPartition] => Unit,
+  revokedListener: List[TopicPartition] => Unit) extends TrackPartitions {
 
   private val log = LoggerFactory.getLogger(getClass)
 
@@ -38,6 +43,8 @@ private final class TrackPartitionsCommitMode(consumer: KafkaConsumer[_, _], con
 
     _revoked = true
 
+    revokedListener(partitions.asScala.toList)
+
     // If partitions have been revoked, keep a record of our current position within them.
     if (!partitions.isEmpty) {
       _offsets = partitions.asScala.map(partition => partition -> consumer.position(partition)).toMap
@@ -55,6 +62,7 @@ private final class TrackPartitionsCommitMode(consumer: KafkaConsumer[_, _], con
     val allExisting = _offsets.forall { case (partition, _) => partitions.contains(partition) }
 
     if (allExisting) {
+      assignedListener(partitions.asScala.toList)
       for {
         partition <- partitions.asScala
         offset <- _offsets.get(partition)
@@ -66,6 +74,9 @@ private final class TrackPartitionsCommitMode(consumer: KafkaConsumer[_, _], con
 
     } else {
       consumerActor ! KafkaConsumerActor.RevokeReset
+
+      // Invoke client callback to notify revocation of all existing partitions.
+      revokedListener(offsetsToTopicPartitions(_offsets))
     }
   }
 
@@ -113,9 +124,6 @@ private final class TrackPartitionsManualOffset(
 
     log.debug("onPartitionsAssigned: " + partitions.toString)
 
-    def offsetsToTopicPartitions(offsets: Map[TopicPartition, Long]): List[TopicPartition] =
-      offsets.map { case (tp, _) => tp }.toList
-
     def assign(partitions: List[TopicPartition]) = {
       val offsets = assignedListener(partitions)
       for {
diff --git a/testkit/src/main/scala/cakesolutions.kafka/testkit/KafkaServer.scala b/testkit/src/main/scala/cakesolutions.kafka/testkit/KafkaServer.scala
index f194255..0749587 100644
--- a/testkit/src/main/scala/cakesolutions.kafka/testkit/KafkaServer.scala
+++ b/testkit/src/main/scala/cakesolutions.kafka/testkit/KafkaServer.scala
@@ -158,7 +158,7 @@ final class KafkaServer(
     val collected = ArrayBuffer.empty[(Option[Key], Value)]
     val start = System.currentTimeMillis()
 
-    while (total <= expectedNumOfRecords && System.currentTimeMillis() < start + timeout) {
+    while (total < expectedNumOfRecords && System.currentTimeMillis() < start + timeout) {
      val records = consumer.poll(100)
      val kvs = records.asScala.map(r => (Option(r.key()), r.value()))
      collected ++= kvs
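Usage note (illustrative sketch, not part of the patch): the new assignedListener and revokedListener parameters on Subscribe.AutoPartition let a client perform per-partition initialisation or cleanup around rebalances, mirroring what AutoPartitionWithManualOffset already exposes. Only the Subscribe.AutoPartition shape below comes from the patch; warmCache and evictCache are hypothetical helpers standing in for whatever a client would do, such as populating or evicting a cache.

import akka.actor.ActorRef
import cakesolutions.kafka.akka.KafkaConsumerActor.Subscribe
import org.apache.kafka.common.TopicPartition

object AutoPartitionListenerExample {

  // Hypothetical per-partition initialisation/cleanup hooks.
  def warmCache(partitions: List[TopicPartition]): Unit = ()
  def evictCache(partitions: List[TopicPartition]): Unit = ()

  // Subscribe in auto partition (commit) mode, wiring in the new rebalance callbacks.
  def subscribe(consumerActor: ActorRef): Unit =
    consumerActor ! Subscribe.AutoPartition(
      topics = List("topic1"),
      assignedListener = partitions => warmCache(partitions),
      revokedListener = partitions => evictCache(partitions)
    )
}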