diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ProduceResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/ProduceResponse.java
index a8c2a801e4ce..186ad9b80a19 100644
--- a/clients/src/main/java/org/apache/kafka/common/requests/ProduceResponse.java
+++ b/clients/src/main/java/org/apache/kafka/common/requests/ProduceResponse.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.common.requests;
 
+import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.message.ProduceResponseData;
 import org.apache.kafka.common.message.ProduceResponseData.LeaderIdAndEpoch;
@@ -72,7 +73,7 @@ public ProduceResponse(ProduceResponseData produceResponseData) {
      */
     @Deprecated
     public ProduceResponse(Map<TopicPartition, PartitionResponse> responses) {
-        this(responses, DEFAULT_THROTTLE_TIME);
+        this(responses, DEFAULT_THROTTLE_TIME, Collections.emptyList());
     }
 
     /**
@@ -83,10 +84,23 @@ public ProduceResponse(Map<TopicPartition, PartitionResponse> responses) {
      */
     @Deprecated
     public ProduceResponse(Map<TopicPartition, PartitionResponse> responses, int throttleTimeMs) {
-        this(toData(responses, throttleTimeMs));
+        this(toData(responses, throttleTimeMs, Collections.emptyList()));
     }
 
-    private static ProduceResponseData toData(Map<TopicPartition, PartitionResponse> responses, int throttleTimeMs) {
+    /**
+     * Constructor for the latest version
+     * This is deprecated in favor of using the ProduceResponseData constructor, KafkaApis should switch to that
+     * in KAFKA-10730
+     * @param responses Produced data grouped by topic-partition
+     * @param throttleTimeMs Time in milliseconds the response was throttled
+     * @param nodeEndpoints List of node endpoints
+     */
+    @Deprecated
+    public ProduceResponse(Map<TopicPartition, PartitionResponse> responses, int throttleTimeMs, List<Node> nodeEndpoints) {
+        this(toData(responses, throttleTimeMs, nodeEndpoints));
+    }
+
+    private static ProduceResponseData toData(Map<TopicPartition, PartitionResponse> responses, int throttleTimeMs, List<Node> nodeEndpoints) {
         ProduceResponseData data = new ProduceResponseData().setThrottleTimeMs(throttleTimeMs);
         responses.forEach((tp, response) -> {
             ProduceResponseData.TopicProduceResponse tpr = data.responses().find(tp.topic());
@@ -110,6 +124,12 @@ private static ProduceResponseData toData(Map<TopicPartition, PartitionResponse
                         .setBatchIndexErrorMessage(e.message))
                     .collect(Collectors.toList())));
         });
+        nodeEndpoints.forEach(endpoint -> data.nodeEndpoints()
+                .add(new ProduceResponseData.NodeEndpoint()
+                        .setNodeId(endpoint.id())
+                        .setHost(endpoint.host())
+                        .setPort(endpoint.port())
+                        .setRack(endpoint.rack())));
         return data;
     }
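(Not part of the patch — a minimal sketch of how the new overload could be exercised and how the piggy-backed endpoints surface on the response object. The topic name, node id, host, port and rack below are made-up values, and the overload is already marked deprecated in favour of building ProduceResponseData directly.)

import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.apache.kafka.common.Node;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.requests.ProduceResponse;
import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse;

public class ProduceResponseEndpointSketch {
    public static void main(String[] args) {
        // A partition that failed with NOT_LEADER_OR_FOLLOWER ("payments-0" is illustrative).
        Map<TopicPartition, PartitionResponse> responses = Collections.singletonMap(
                new TopicPartition("payments", 0),
                new PartitionResponse(Errors.NOT_LEADER_OR_FOLLOWER));

        // Endpoint of the broker now believed to be the leader (values are illustrative).
        List<Node> endpoints = Collections.singletonList(new Node(2, "broker2", 9092, null));

        // The new overload copies the endpoints into data().nodeEndpoints().
        ProduceResponse response = new ProduceResponse(responses, 0, endpoints);

        response.data().nodeEndpoints().forEach(e ->
                System.out.printf("node %d -> %s:%d%n", e.nodeId(), e.host(), e.port()));
    }
}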
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala
index bd0959d40d2f..3f5a435d1471 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -562,6 +562,23 @@ class KafkaApis(val requestChannel: RequestChannel,
     }
   }
 
+  case class LeaderNode(leaderId: Int, leaderEpoch: Int, node: Option[Node])
+
+  private def getCurrentLeader(tp: TopicPartition, ln: ListenerName): LeaderNode = {
+    val partitionInfoOrError = replicaManager.getPartitionOrError(tp)
+    val (leaderId, leaderEpoch) = partitionInfoOrError match {
+      case Right(x) =>
+        (x.leaderReplicaIdOpt.getOrElse(-1), x.getLeaderEpoch)
+      case Left(x) =>
+        debug(s"Unable to retrieve local leaderId and Epoch with error $x, falling back to metadata cache")
+        metadataCache.getPartitionInfo(tp.topic, tp.partition) match {
+          case Some(pinfo) => (pinfo.leader(), pinfo.leaderEpoch())
+          case None => (-1, -1)
+        }
+    }
+    LeaderNode(leaderId, leaderEpoch, metadataCache.getAliveBrokerNode(leaderId, ln))
+  }
+
   /**
    * Handle a produce request
    */
@@ -614,6 +631,7 @@ class KafkaApis(val requestChannel: RequestChannel,
       val mergedResponseStatus = responseStatus ++ unauthorizedTopicResponses ++ nonExistingTopicResponses ++ invalidRequestResponses
       var errorInResponse = false
 
+      val nodeEndpoints = new mutable.HashMap[Int, Node]
       mergedResponseStatus.forKeyValue { (topicPartition, status) =>
         if (status.error != Errors.NONE) {
           errorInResponse = true
@@ -622,6 +640,20 @@ class KafkaApis(val requestChannel: RequestChannel,
             request.header.clientId,
             topicPartition,
             status.error.exceptionName))
+
+          if (request.header.apiVersion >= 10) {
+            status.error match {
+              case Errors.NOT_LEADER_OR_FOLLOWER =>
+                val leaderNode = getCurrentLeader(topicPartition, request.context.listenerName)
+                leaderNode.node.foreach { node =>
+                  nodeEndpoints.put(node.id(), node)
+                }
+                status.currentLeader
+                  .setLeaderId(leaderNode.leaderId)
+                  .setLeaderEpoch(leaderNode.leaderEpoch)
+              case _ =>
+            }
+          }
         }
       }
 
@@ -665,7 +697,7 @@ class KafkaApis(val requestChannel: RequestChannel,
           requestHelper.sendNoOpResponseExemptThrottle(request)
         }
       } else {
-        requestChannel.sendResponse(request, new ProduceResponse(mergedResponseStatus.asJava, maxThrottleTimeMs), None)
+        requestChannel.sendResponse(request, new ProduceResponse(mergedResponseStatus.asJava, maxThrottleTimeMs, nodeEndpoints.values.toList.asJava), None)
       }
     }
 
@@ -843,6 +875,7 @@ class KafkaApis(val requestChannel: RequestChannel,
               .setRecords(unconvertedRecords)
               .setPreferredReadReplica(partitionData.preferredReadReplica)
               .setDivergingEpoch(partitionData.divergingEpoch)
+              .setCurrentLeader(partitionData.currentLeader())
           }
         }
       }
@@ -851,6 +884,7 @@ class KafkaApis(val requestChannel: RequestChannel,
     def processResponseCallback(responsePartitionData: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
       val partitions = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
       val reassigningPartitions = mutable.Set[TopicIdPartition]()
+      val nodeEndpoints = new mutable.HashMap[Int, Node]
       responsePartitionData.foreach { case (tp, data) =>
         val abortedTransactions = data.abortedTransactions.orElse(null)
         val lastStableOffset: Long = data.lastStableOffset.orElse(FetchResponse.INVALID_LAST_STABLE_OFFSET)
@@ -864,6 +898,21 @@ class KafkaApis(val requestChannel: RequestChannel,
           .setAbortedTransactions(abortedTransactions)
          .setRecords(data.records)
           .setPreferredReadReplica(data.preferredReadReplica.orElse(FetchResponse.INVALID_PREFERRED_REPLICA_ID))
+
+        if (versionId >= 16) {
+          data.error match {
+            case Errors.NOT_LEADER_OR_FOLLOWER | Errors.FENCED_LEADER_EPOCH =>
+              val leaderNode = getCurrentLeader(tp.topicPartition(), request.context.listenerName)
+              leaderNode.node.foreach { node =>
+                nodeEndpoints.put(node.id(), node)
+              }
+              partitionData.currentLeader()
+                .setLeaderId(leaderNode.leaderId)
+                .setLeaderEpoch(leaderNode.leaderEpoch)
+            case _ =>
+          }
+        }
+
         data.divergingEpoch.ifPresent(partitionData.setDivergingEpoch(_))
         partitions.put(tp, partitionData)
       }
@@ -887,7 +936,7 @@ class KafkaApis(val requestChannel: RequestChannel,
 
       // Prepare fetch response from converted data
       val response =
-        FetchResponse.of(unconvertedFetchResponse.error, throttleTimeMs, unconvertedFetchResponse.sessionId, convertedData)
+        FetchResponse.of(unconvertedFetchResponse.error, throttleTimeMs, unconvertedFetchResponse.sessionId, convertedData, nodeEndpoints.values.toList.asJava)
       // record the bytes out metrics only when the response is being sent
       response.data.responses.forEach { topicResponse =>
         topicResponse.partitions.forEach { data =>
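(Commentary, not part of the patch: with these KafkaApis changes the broker piggy-backs the likely new leader — id and epoch on the failed partition, plus that broker's endpoint at the response top level — on Produce v10+ responses for NOT_LEADER_OR_FOLLOWER and on Fetch v16+ responses for NOT_LEADER_OR_FOLLOWER or FENCED_LEADER_EPOCH, falling back from the local Partition object to the metadata cache when the replica has already moved off this broker; this appears to be the server side of KIP-951. The sketch below shows what a client could do with the two new produce-side fields. It is not code from this PR, and the class and helper names are invented.)

import java.util.HashMap;
import java.util.Map;

import org.apache.kafka.common.Node;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.message.ProduceResponseData;
import org.apache.kafka.common.protocol.Errors;

public class NewLeaderExtractionSketch {
    /** Collects the advertised endpoints keyed by node id. */
    static Map<Integer, Node> endpointsById(ProduceResponseData data) {
        Map<Integer, Node> byId = new HashMap<>();
        data.nodeEndpoints().forEach(e ->
                byId.put(e.nodeId(), new Node(e.nodeId(), e.host(), e.port(), e.rack())));
        return byId;
    }

    /** Returns the advertised new leader for partitions that failed with NOT_LEADER_OR_FOLLOWER. */
    static Map<TopicPartition, Node> newLeaders(ProduceResponseData data) {
        Map<Integer, Node> byId = endpointsById(data);
        Map<TopicPartition, Node> leaders = new HashMap<>();
        data.responses().forEach(topic ->
            topic.partitionResponses().forEach(partition -> {
                if (partition.errorCode() == Errors.NOT_LEADER_OR_FOLLOWER.code()
                        && byId.containsKey(partition.currentLeader().leaderId())) {
                    leaders.put(new TopicPartition(topic.name(), partition.index()),
                            byId.get(partition.currentLeader().leaderId()));
                }
            }));
        return leaders;
    }
}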
diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
index 17abdb0472d2..41e67e61f5e5 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
@@ -24,9 +24,10 @@ import java.util.Arrays.asList
 import java.util.concurrent.{CompletableFuture, TimeUnit}
 import java.util.{Collections, Optional, OptionalInt, OptionalLong, Properties}
 import kafka.api.LeaderAndIsr
-import kafka.cluster.Broker
+import kafka.cluster.{Broker, Partition}
 import kafka.controller.{ControllerContext, KafkaController}
 import kafka.coordinator.transaction.{InitProducerIdResult, TransactionCoordinator}
+import kafka.log.UnifiedLog
 import kafka.metrics.ClientMetricsTestUtils
 import kafka.network.{RequestChannel, RequestMetrics}
 import kafka.server.QuotaFactory.QuotaManagers
@@ -98,7 +99,7 @@ import org.apache.kafka.coordinator.group.GroupCoordinator
 import org.apache.kafka.server.common.{Features, MetadataVersion}
 import org.apache.kafka.server.common.MetadataVersion.{IBP_0_10_2_IV0, IBP_2_2_IV1}
 import org.apache.kafka.server.util.MockTime
-import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchParams, FetchPartitionData}
+import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchParams, FetchPartitionData, LogConfig}
 
 class KafkaApisTest {
   private val requestChannel: RequestChannel = mock(classOf[RequestChannel])
@@ -2475,6 +2476,204 @@ class KafkaApisTest {
     }
   }
 
+  @Test
+  def testProduceResponseContainsNewLeaderOnNotLeaderOrFollower(): Unit = {
+    val topic = "topic"
+    addTopicToMetadataCache(topic, numPartitions = 2, numBrokers = 3)
+
+    for (version <- 10 to ApiKeys.PRODUCE.latestVersion) {
+
+      reset(replicaManager, clientQuotaManager, clientRequestQuotaManager, requestChannel, txnCoordinator)
+
+      val responseCallback: ArgumentCaptor[Map[TopicPartition, PartitionResponse] => Unit] = ArgumentCaptor.forClass(classOf[Map[TopicPartition, PartitionResponse] => Unit])
+
+      val tp = new TopicPartition(topic, 0)
+      val partition = mock(classOf[Partition])
+      val newLeaderId = 2
+      val newLeaderEpoch = 5
+
+      val produceRequest = ProduceRequest.forCurrentMagic(new ProduceRequestData()
+        .setTopicData(new ProduceRequestData.TopicProduceDataCollection(
+          Collections.singletonList(new ProduceRequestData.TopicProduceData()
+            .setName(tp.topic).setPartitionData(Collections.singletonList(
+            new ProduceRequestData.PartitionProduceData()
+              .setIndex(tp.partition)
+              .setRecords(MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("test".getBytes))))))
+          .iterator))
+        .setAcks(1.toShort)
+        .setTimeoutMs(5000))
+        .build(version.toShort)
+      val request = buildRequest(produceRequest)
+
+      when(replicaManager.appendRecords(anyLong,
+        anyShort,
+        ArgumentMatchers.eq(false),
+        ArgumentMatchers.eq(AppendOrigin.CLIENT),
+        any(),
+        responseCallback.capture(),
+        any(),
+        any(),
+        any(),
+        any(),
+        any())
+      ).thenAnswer(_ => responseCallback.getValue.apply(Map(tp -> new PartitionResponse(Errors.NOT_LEADER_OR_FOLLOWER))))
+
+      when(replicaManager.getPartitionOrError(tp)).thenAnswer(_ => Right(partition))
+      when(partition.leaderReplicaIdOpt).thenAnswer(_ => Some(newLeaderId))
+      when(partition.getLeaderEpoch).thenAnswer(_ => newLeaderEpoch)
+
+      when(clientRequestQuotaManager.maybeRecordAndGetThrottleTimeMs(any[RequestChannel.Request](),
+        any[Long])).thenReturn(0)
+      when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
+        any[RequestChannel.Request](), anyDouble, anyLong)).thenReturn(0)
+
+      createKafkaApis().handleProduceRequest(request, RequestLocal.withThreadConfinedCaching)
+
+      val response = verifyNoThrottling[ProduceResponse](request)
+
+      assertEquals(1, response.data.responses.size)
+      val topicProduceResponse = response.data.responses.asScala.head
+      assertEquals(1, topicProduceResponse.partitionResponses.size)
+      val partitionProduceResponse = topicProduceResponse.partitionResponses.asScala.head
+      assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, Errors.forCode(partitionProduceResponse.errorCode))
+      assertEquals(newLeaderId, partitionProduceResponse.currentLeader.leaderId())
+      assertEquals(newLeaderEpoch, partitionProduceResponse.currentLeader.leaderEpoch())
+      assertEquals(1, response.data.nodeEndpoints.size)
+      val node = response.data.nodeEndpoints.asScala.head
+      assertEquals(2, node.nodeId)
+      assertEquals("broker2", node.host)
+    }
+  }
+
+  @Test
+  def testProduceResponseReplicaManagerLookupErrorOnNotLeaderOrFollower(): Unit = {
+    val topic = "topic"
+    addTopicToMetadataCache(topic, numPartitions = 2, numBrokers = 3)
+
+    for (version <- 10 to ApiKeys.PRODUCE.latestVersion) {
+
+      reset(replicaManager, clientQuotaManager, clientRequestQuotaManager, requestChannel, txnCoordinator)
+
+      val responseCallback: ArgumentCaptor[Map[TopicPartition, PartitionResponse] => Unit] = ArgumentCaptor.forClass(classOf[Map[TopicPartition, PartitionResponse] => Unit])
+
+      val tp = new TopicPartition(topic, 0)
+
+      val produceRequest = ProduceRequest.forCurrentMagic(new ProduceRequestData()
+        .setTopicData(new ProduceRequestData.TopicProduceDataCollection(
+          Collections.singletonList(new ProduceRequestData.TopicProduceData()
+            .setName(tp.topic).setPartitionData(Collections.singletonList(
+            new ProduceRequestData.PartitionProduceData()
+              .setIndex(tp.partition)
+              .setRecords(MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("test".getBytes))))))
+          .iterator))
+        .setAcks(1.toShort)
+        .setTimeoutMs(5000))
+        .build(version.toShort)
+      val request = buildRequest(produceRequest)
+
+      when(replicaManager.appendRecords(anyLong,
+        anyShort,
+        ArgumentMatchers.eq(false),
+        ArgumentMatchers.eq(AppendOrigin.CLIENT),
+        any(),
+        responseCallback.capture(),
+        any(),
+        any(),
+        any(),
+        any(),
+        any())
+      ).thenAnswer(_ => responseCallback.getValue.apply(Map(tp -> new PartitionResponse(Errors.NOT_LEADER_OR_FOLLOWER))))
+
+      when(replicaManager.getPartitionOrError(tp)).thenAnswer(_ => Left(Errors.UNKNOWN_TOPIC_OR_PARTITION))
+
+      when(clientRequestQuotaManager.maybeRecordAndGetThrottleTimeMs(any[RequestChannel.Request](),
+        any[Long])).thenReturn(0)
+      when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
+        any[RequestChannel.Request](), anyDouble, anyLong)).thenReturn(0)
+
+      createKafkaApis().handleProduceRequest(request, RequestLocal.withThreadConfinedCaching)
+
+      val response = verifyNoThrottling[ProduceResponse](request)
+
+      assertEquals(1, response.data.responses.size)
+      val topicProduceResponse = response.data.responses.asScala.head
+      assertEquals(1, topicProduceResponse.partitionResponses.size)
+      val partitionProduceResponse = topicProduceResponse.partitionResponses.asScala.head
+      assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, Errors.forCode(partitionProduceResponse.errorCode))
+      // LeaderId and epoch should be the same values inserted into the metadata cache
+      assertEquals(0, partitionProduceResponse.currentLeader.leaderId())
+      assertEquals(1, partitionProduceResponse.currentLeader.leaderEpoch())
+      assertEquals(1, response.data.nodeEndpoints.size)
+      val node = response.data.nodeEndpoints.asScala.head
+      assertEquals(0, node.nodeId)
+      assertEquals("broker0", node.host)
+    }
+  }
+
+  @Test
+  def testProduceResponseMetadataLookupErrorOnNotLeaderOrFollower(): Unit = {
+    val topic = "topic"
+    metadataCache = mock(classOf[ZkMetadataCache])
+
+    for (version <- 10 to ApiKeys.PRODUCE.latestVersion) {
+
+      reset(replicaManager, clientQuotaManager, clientRequestQuotaManager, requestChannel, txnCoordinator)
+
+      val responseCallback: ArgumentCaptor[Map[TopicPartition, PartitionResponse] => Unit] = ArgumentCaptor.forClass(classOf[Map[TopicPartition, PartitionResponse] => Unit])
+
+      val tp = new TopicPartition(topic, 0)
+
+      val produceRequest = ProduceRequest.forCurrentMagic(new ProduceRequestData()
+        .setTopicData(new ProduceRequestData.TopicProduceDataCollection(
+          Collections.singletonList(new ProduceRequestData.TopicProduceData()
+            .setName(tp.topic).setPartitionData(Collections.singletonList(
+            new ProduceRequestData.PartitionProduceData()
+              .setIndex(tp.partition)
+              .setRecords(MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("test".getBytes))))))
+          .iterator))
+        .setAcks(1.toShort)
+        .setTimeoutMs(5000))
+        .build(version.toShort)
+      val request = buildRequest(produceRequest)
+
+      when(replicaManager.appendRecords(anyLong,
+        anyShort,
+        ArgumentMatchers.eq(false),
+        ArgumentMatchers.eq(AppendOrigin.CLIENT),
+        any(),
+        responseCallback.capture(),
+        any(),
+        any(),
+        any(),
+        any(),
+        any())
+      ).thenAnswer(_ => responseCallback.getValue.apply(Map(tp -> new PartitionResponse(Errors.NOT_LEADER_OR_FOLLOWER))))
+
+      when(replicaManager.getPartitionOrError(tp)).thenAnswer(_ => Left(Errors.UNKNOWN_TOPIC_OR_PARTITION))
+
+      when(clientRequestQuotaManager.maybeRecordAndGetThrottleTimeMs(any[RequestChannel.Request](),
+        any[Long])).thenReturn(0)
+      when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
+        any[RequestChannel.Request](), anyDouble, anyLong)).thenReturn(0)
+      when(metadataCache.contains(tp)).thenAnswer(_ => true)
+      when(metadataCache.getPartitionInfo(tp.topic(), tp.partition())).thenAnswer(_ => Option.empty)
+      when(metadataCache.getAliveBrokerNode(any(), any())).thenReturn(Option.empty)
+
+      createKafkaApis().handleProduceRequest(request, RequestLocal.withThreadConfinedCaching)
+
+      val response = verifyNoThrottling[ProduceResponse](request)
+
+      assertEquals(1, response.data.responses.size)
+      val topicProduceResponse = response.data.responses.asScala.head
+      assertEquals(1, topicProduceResponse.partitionResponses.size)
+      val partitionProduceResponse = topicProduceResponse.partitionResponses.asScala.head
+      assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, Errors.forCode(partitionProduceResponse.errorCode))
+      assertEquals(-1, partitionProduceResponse.currentLeader.leaderId())
+      assertEquals(-1, partitionProduceResponse.currentLeader.leaderEpoch())
+      assertEquals(0, response.data.nodeEndpoints.size)
+    }
+  }
+
   @Test
   def testTransactionalParametersSetCorrectly(): Unit = {
     val topic = "topic"
@@ -3786,6 +3985,73 @@ class KafkaApisTest {
     assertEquals(MemoryRecords.EMPTY, FetchResponse.recordsOrFail(partitionData))
   }
 
+  @Test
+  def testFetchResponseContainsNewLeaderOnNotLeaderOrFollower(): Unit = {
+    val topicId = Uuid.randomUuid()
+    val tidp = new TopicIdPartition(topicId, new TopicPartition("foo", 0))
+    val tp = tidp.topicPartition
+    addTopicToMetadataCache(tp.topic, numPartitions = 1, numBrokers = 3, topicId)
+
+    when(replicaManager.getLogConfig(ArgumentMatchers.eq(tp))).thenReturn(Some(LogConfig.fromProps(
+      Collections.emptyMap(),
+      new Properties()
+    )))
+
+    val partition = mock(classOf[Partition])
+    val newLeaderId = 2
+    val newLeaderEpoch = 5
+
+    when(replicaManager.getPartitionOrError(tp)).thenAnswer(_ => Right(partition))
+    when(partition.leaderReplicaIdOpt).thenAnswer(_ => Some(newLeaderId))
+    when(partition.getLeaderEpoch).thenAnswer(_ => newLeaderEpoch)
+
+    when(replicaManager.fetchMessages(
+      any[FetchParams],
+      any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]],
+      any[ReplicaQuota],
+      any[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]()
+    )).thenAnswer(invocation => {
+      val callback = invocation.getArgument(3).asInstanceOf[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]
+      callback(Seq(tidp -> new FetchPartitionData(Errors.NOT_LEADER_OR_FOLLOWER, UnifiedLog.UnknownOffset, UnifiedLog.UnknownOffset, MemoryRecords.EMPTY,
+        Optional.empty(), OptionalLong.empty(), Optional.empty(), OptionalInt.empty(), false)))
+    })
+
+    val fetchData = Map(tidp -> new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 0, 1000,
+      Optional.empty())).asJava
+    val fetchDataBuilder = Map(tp -> new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 0, 1000,
+      Optional.empty())).asJava
+    val fetchMetadata = new JFetchMetadata(0, 0)
+    val fetchContext = new FullFetchContext(time, new FetchSessionCache(1000, 100),
+      fetchMetadata, fetchData, false, false)
+    when(fetchManager.newContext(
+      any[Short],
+      any[JFetchMetadata],
+      any[Boolean],
+      any[util.Map[TopicIdPartition, FetchRequest.PartitionData]],
+      any[util.List[TopicIdPartition]],
+      any[util.Map[Uuid, String]])).thenReturn(fetchContext)
+
+    when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
+      any[RequestChannel.Request](), anyDouble, anyLong)).thenReturn(0)
+
+    val fetchRequest = new FetchRequest.Builder(16, 16, -1, -1, 100, 0, fetchDataBuilder)
+      .build()
+    val request = buildRequest(fetchRequest)
+
+    createKafkaApis().handleFetchRequest(request)
+
+    val response = verifyNoThrottling[FetchResponse](request)
+    val responseData = response.responseData(metadataCache.topicIdsToNames(), 16)
+
+    val partitionData = responseData.get(tp)
+    assertEquals(Errors.NOT_LEADER_OR_FOLLOWER.code, partitionData.errorCode)
+    assertEquals(newLeaderId, partitionData.currentLeader.leaderId())
+    assertEquals(newLeaderEpoch, partitionData.currentLeader.leaderEpoch())
+    val node = response.data.nodeEndpoints.asScala.head
+    assertEquals(2, node.nodeId)
+    assertEquals("broker2", node.host)
+  }
+
   @ParameterizedTest
   @ApiKeyVersionsSource(apiKey = ApiKeys.JOIN_GROUP)
   def testHandleJoinGroupRequest(version: Short): Unit = {
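(Commentary, not part of the patch: the fetch-side analog of the produce sketch above. A fetching client could resolve the endpoint advertised for a partition's reported leader as below; the class and method names are invented.)

import java.util.Optional;

import org.apache.kafka.common.Node;
import org.apache.kafka.common.message.FetchResponseData;
import org.apache.kafka.common.protocol.Errors;

public class FetchNewLeaderSketch {
    /**
     * Looks up the endpoint advertised for a partition's current leader after a
     * NOT_LEADER_OR_FOLLOWER or FENCED_LEADER_EPOCH error, if the broker attached one.
     */
    static Optional<Node> advertisedLeader(FetchResponseData data,
                                           FetchResponseData.PartitionData partition) {
        short code = partition.errorCode();
        if (code != Errors.NOT_LEADER_OR_FOLLOWER.code() && code != Errors.FENCED_LEADER_EPOCH.code()) {
            return Optional.empty();
        }
        int leaderId = partition.currentLeader().leaderId();
        return data.nodeEndpoints().stream()
                .filter(e -> e.nodeId() == leaderId)
                .findFirst()
                .map(e -> new Node(e.nodeId(), e.host(), e.port(), e.rack()));
    }
}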