diff --git a/kafka-rest/src/test/java/io/confluent/kafkarest/integration/ClusterTestHarness.java b/kafka-rest/src/test/java/io/confluent/kafkarest/integration/ClusterTestHarness.java index 230321d2c5..acb0fd1373 100644 --- a/kafka-rest/src/test/java/io/confluent/kafkarest/integration/ClusterTestHarness.java +++ b/kafka-rest/src/test/java/io/confluent/kafkarest/integration/ClusterTestHarness.java @@ -22,7 +22,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; -import io.confluent.kafka.schemaregistry.avro.AvroCompatibilityLevel; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.confluent.kafka.schemaregistry.CompatibilityLevel; import io.confluent.kafka.schemaregistry.rest.SchemaRegistryConfig; import io.confluent.kafka.schemaregistry.rest.SchemaRegistryRestApplication; import io.confluent.kafka.serializers.KafkaAvroSerializer; @@ -35,9 +36,11 @@ import java.net.ServerSocket; import java.net.URI; import java.net.URISyntaxException; +import java.time.Duration; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -66,6 +69,10 @@ import org.apache.kafka.clients.admin.ListTopicsResult; import org.apache.kafka.clients.admin.NewPartitionReassignment; import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; import org.apache.kafka.clients.producer.KafkaProducer; import org.apache.kafka.clients.producer.ProducerConfig; import org.apache.kafka.clients.producer.ProducerRecord; @@ -75,6 +82,7 @@ import org.apache.kafka.common.config.ConfigResource; import org.apache.kafka.common.security.auth.SecurityProtocol; import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.Deserializer; import org.eclipse.jetty.server.Server; import org.glassfish.jersey.apache.connector.ApacheConnectorProvider; import org.glassfish.jersey.client.ClientConfig; @@ -139,7 +147,7 @@ public static int choosePort() { protected String plaintextBrokerList = null; // Schema registry config - protected String schemaRegCompatibility = AvroCompatibilityLevel.NONE.name; + protected String schemaRegCompatibility = CompatibilityLevel.NONE.name; protected Properties schemaRegProperties = null; protected String schemaRegConnect = null; protected SchemaRegistryRestApplication schemaRegApp = null; @@ -201,12 +209,13 @@ private void setupMethod() throws Exception { if (withSchemaRegistry) { int schemaRegPort = choosePort(); schemaRegProperties.put( - SchemaRegistryConfig.PORT_CONFIG, ((Integer) schemaRegPort).toString()); - schemaRegProperties.put(SchemaRegistryConfig.KAFKASTORE_CONNECTION_URL_CONFIG, zkConnect); + SchemaRegistryConfig.LISTENERS_CONFIG, + String.format("http://127.0.0.1:%d", schemaRegPort)); schemaRegProperties.put( SchemaRegistryConfig.KAFKASTORE_TOPIC_CONFIG, SchemaRegistryConfig.DEFAULT_KAFKASTORE_TOPIC); - schemaRegProperties.put(SchemaRegistryConfig.COMPATIBILITY_CONFIG, schemaRegCompatibility); + schemaRegProperties.put( + SchemaRegistryConfig.SCHEMA_COMPATIBILITY_CONFIG, schemaRegCompatibility); String broker = SecurityProtocol.PLAINTEXT.name + "://" @@ -237,6 +246,8 @@ private void setupMethod() throws Exception { // Reduce the metadata 
fetch timeout so requests for topics that don't exist timeout much // faster than the default restProperties.put("producer." + ProducerConfig.MAX_BLOCK_MS_CONFIG, "5000"); + restProperties.put( + "producer." + ProducerConfig.MAX_REQUEST_SIZE_CONFIG, String.valueOf((2 << 20) * 10)); restConfig = new KafkaRestConfig(restProperties); @@ -340,6 +351,7 @@ protected Properties getBrokerProperties(int i) { (short) 1, false); props.setProperty("auto.create.topics.enable", "false"); + props.setProperty("message.max.bytes", String.valueOf((2 << 20) * 10)); // We *must* override this to use the port we allocated (Kafka currently allocates one port // that it always uses for ZK props.setProperty("zookeeper.connect", this.zkConnect); @@ -717,6 +729,63 @@ private void doProduce( }); } + protected final <K, V> ConsumerRecord<K, V> getMessage( + String topic, + int partition, + long offset, + Deserializer<K> keyDeserializer, + Deserializer<V> valueDeserializer) { + + Properties props = new Properties(); + props.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList); + + KafkaConsumer<K, V> consumer = new KafkaConsumer<>(props, keyDeserializer, valueDeserializer); + TopicPartition tp = new TopicPartition(topic, partition); + consumer.assign(Collections.singleton(tp)); + consumer.seek(tp, offset); + + ConsumerRecords<K, V> records = consumer.poll(Duration.ofSeconds(60)); + consumer.close(); + + return records.isEmpty() ? null : records.records(tp).get(0); + } + + protected final <K, V> ConsumerRecords<K, V> getMessages( + String topic, + Deserializer<K> keyDeserializer, + Deserializer<V> valueDeserializer, + int messageCount) { + + List<ConsumerRecord<K, V>> accumulator = new ArrayList<>(messageCount); + int numMessages = 0; + + Properties props = new Properties(); + props.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList); + + KafkaConsumer<K, V> consumer = new KafkaConsumer<>(props, keyDeserializer, valueDeserializer); + TopicPartition tp = new TopicPartition(topic, 0); + consumer.assign(Collections.singleton(tp)); + consumer.seekToBeginning(Collections.singleton(tp)); + + ConsumerRecords<K, V> records; + while (numMessages < messageCount) { + records = consumer.poll(Duration.ofSeconds(60)); + Iterator<ConsumerRecord<K, V>> it = records.iterator(); + while (it.hasNext() && (numMessages < messageCount)) { + ConsumerRecord<K, V> rec = it.next(); + accumulator.add(rec); + numMessages++; + } + } + consumer.close(); + + return new ConsumerRecords<>(Collections.singletonMap(tp, accumulator)); + } + + protected ObjectMapper getObjectMapper() { + return restApp.getJsonMapper(); + } + protected Map<Integer, List<Integer>> createAssignment( + List<Integer> replicaIds, int numReplicas) { + Map<Integer, List<Integer>> replicaAssignments = new HashMap<>(); diff --git a/kafka-rest/src/test/java/io/confluent/kafkarest/integration/v2/SchemaProduceConsumeTest.java b/kafka-rest/src/test/java/io/confluent/kafkarest/integration/v2/SchemaProduceConsumeTest.java index fe02b515ef..b9481a4c45 100644 --- a/kafka-rest/src/test/java/io/confluent/kafkarest/integration/v2/SchemaProduceConsumeTest.java +++ b/kafka-rest/src/test/java/io/confluent/kafkarest/integration/v2/SchemaProduceConsumeTest.java @@ -31,7 +31,7 @@ import io.confluent.kafkarest.entities.v2.SchemaConsumerRecord; import io.confluent.kafkarest.entities.v2.SchemaTopicProduceRequest; import io.confluent.kafkarest.entities.v2.SchemaTopicProduceRequest.SchemaTopicProduceRecord; -import io.confluent.kafkarest.testing.DefaultKafkaRestTestEnvironment; +import io.confluent.kafkarest.integration.ClusterTestHarness; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -44,12 +44,11 @@ import 
javax.ws.rs.core.Response.Status; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @Tag("IntegrationTest") -public abstract class SchemaProduceConsumeTest { +public abstract class SchemaProduceConsumeTest extends ClusterTestHarness { private static final String TOPIC = "topic-1"; @@ -57,9 +56,6 @@ public abstract class SchemaProduceConsumeTest { private static final Logger log = LoggerFactory.getLogger(SchemaProduceConsumeTest.class); - @RegisterExtension - public final DefaultKafkaRestTestEnvironment testEnv = new DefaultKafkaRestTestEnvironment(); - protected abstract EmbeddedFormat getFormat(); protected abstract String getContentType(); @@ -70,18 +66,16 @@ public abstract class SchemaProduceConsumeTest { protected abstract List getProduceRecords(); + public SchemaProduceConsumeTest() { + super(/* numBrokers= */ 1, /* withSchemaRegistry= */ true); + } + @Test public void produceThenConsume_returnsExactlyProduced() throws Exception { - testEnv - .kafkaCluster() - .createTopic(TOPIC, /* numPartitions= */ 1, /* replicationFactor= */ (short) 3); + createTopic(TOPIC, /* numPartitions= */ 1, /* replicationFactor= */ (short) 1); Response createConsumerInstanceResponse = - testEnv - .kafkaRest() - .target() - .path(String.format("/consumers/%s", CONSUMER_GROUP)) - .request() + request(String.format("/consumers/%s", CONSUMER_GROUP)) .post( Entity.entity( new CreateConsumerInstanceRequest( @@ -100,14 +94,10 @@ public void produceThenConsume_returnsExactlyProduced() throws Exception { createConsumerInstanceResponse.readEntity(CreateConsumerInstanceResponse.class); Response subscribeResponse = - testEnv - .kafkaRest() - .target() - .path( + request( String.format( "/consumers/%s/instances/%s/subscription", CONSUMER_GROUP, createConsumerInstance.getInstanceId())) - .request() .post( Entity.entity( new ConsumerSubscriptionRecord(singletonList(TOPIC), null), @@ -116,14 +106,10 @@ public void produceThenConsume_returnsExactlyProduced() throws Exception { assertEquals(Status.NO_CONTENT.getStatusCode(), subscribeResponse.getStatus()); // Needs to consume empty once before producing. 
- testEnv - .kafkaRest() - .target() - .path( + request( String.format( "/consumers/%s/instances/%s/records", CONSUMER_GROUP, createConsumerInstance.getInstanceId())) - .request() .accept(getContentType()) .get(); @@ -136,11 +122,7 @@ public void produceThenConsume_returnsExactlyProduced() throws Exception { null); Response genericResponse = - testEnv - .kafkaRest() - .target() - .path(String.format("/topics/%s", TOPIC)) - .request() + request(String.format("/topics/%s", TOPIC)) .post(Entity.entity(produceRequest, getContentType())); ProduceResponse produceResponse; @@ -162,14 +144,10 @@ public void produceThenConsume_returnsExactlyProduced() throws Exception { assertEquals(Status.OK, produceResponse.getRequestStatus()); Response readRecordsResponse = - testEnv - .kafkaRest() - .target() - .path( + request( String.format( "/consumers/%s/instances/%s/records", CONSUMER_GROUP, createConsumerInstance.getInstanceId())) - .request() .accept(getContentType()) .get(); diff --git a/kafka-rest/src/test/java/io/confluent/kafkarest/integration/v3/ProduceActionNoSchemaIntegrationTest.java b/kafka-rest/src/test/java/io/confluent/kafkarest/integration/v3/ProduceActionNoSchemaIntegrationTest.java new file mode 100644 index 0000000000..cfe777a1a1 --- /dev/null +++ b/kafka-rest/src/test/java/io/confluent/kafkarest/integration/v3/ProduceActionNoSchemaIntegrationTest.java @@ -0,0 +1,1039 @@ +/* + * Copyright 2023 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ + +package io.confluent.kafkarest.integration.v3; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.MappingIterator; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.BinaryNode; +import com.fasterxml.jackson.databind.node.IntNode; +import com.fasterxml.jackson.databind.node.NullNode; +import com.fasterxml.jackson.databind.node.TextNode; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; +import io.confluent.kafka.serializers.KafkaJsonDeserializer; +import io.confluent.kafkarest.entities.EmbeddedFormat; +import io.confluent.kafkarest.entities.v3.ProduceRequest; +import io.confluent.kafkarest.entities.v3.ProduceRequest.ProduceRequestData; +import io.confluent.kafkarest.entities.v3.ProduceRequest.ProduceRequestHeader; +import io.confluent.kafkarest.entities.v3.ProduceResponse; +import io.confluent.kafkarest.exceptions.v3.ErrorResponse; +import io.confluent.kafkarest.integration.ClusterTestHarness; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import javax.ws.rs.ProcessingException; +import javax.ws.rs.client.Entity; +import javax.ws.rs.core.GenericType; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import javax.ws.rs.core.Response.Status; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ProduceActionNoSchemaIntegrationTest extends ClusterTestHarness { + + private static final String TOPIC_NAME = "topic-1"; + private static final int NUM_PARTITIONS = 3; + + public ProduceActionNoSchemaIntegrationTest() { + super(/* numBrokers= */ 1, /* withSchemaRegistry= */ false); + } + + @BeforeEach + @Override + public void setUp() throws Exception { + super.setUp(); + + createTopic(TOPIC_NAME, NUM_PARTITIONS, (short) 1); + } + + @Test + public void produceBinary() throws Exception { + String clusterId = getClusterId(); + ByteString key = ByteString.copyFromUtf8("foo"); + ByteString value = ByteString.copyFromUtf8("bar"); + ProduceRequest request = + ProduceRequest.builder() + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(BinaryNode.valueOf(key.toByteArray())) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(BinaryNode.valueOf(value.toByteArray())) + .build()) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, + actual.getPartitionId(), + actual.getOffset(), + new 
ByteArrayDeserializer(), + new ByteArrayDeserializer()); + assertEquals(key, ByteString.copyFrom(produced.key())); + assertEquals(value, ByteString.copyFrom(produced.value())); + } + + @Test + public void produceBinaryWithNullData() throws Exception { + String clusterId = getClusterId(); + ProduceRequest request = + ProduceRequest.builder() + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(NullNode.getInstance()) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(NullNode.getInstance()) + .build()) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, + actual.getPartitionId(), + actual.getOffset(), + new ByteArrayDeserializer(), + new ByteArrayDeserializer()); + assertNull(produced.key()); + assertNull(produced.value()); + } + + @Test + public void produceBinaryWithInvalidData_throwsBadRequest() throws Exception { + String clusterId = getClusterId(); + ProduceRequest request = + ProduceRequest.builder() + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(IntNode.valueOf(1)) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(TextNode.valueOf("fooba")) // invalid base64 string + .build()) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ErrorResponse actual = response.readEntity(ErrorResponse.class); + assertEquals(400, actual.getErrorCode()); + } + + @Test + public void produceString() throws Exception { + String clusterId = getClusterId(); + String key = "foo"; + String value = "bar"; + ProduceRequest request = + ProduceRequest.builder() + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.STRING) + .setData(TextNode.valueOf(key)) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.STRING) + .setData(TextNode.valueOf(value)) + .build()) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, + actual.getPartitionId(), + actual.getOffset(), + new StringDeserializer(), + new StringDeserializer()); + assertEquals(key, produced.key()); + assertEquals(value, produced.value()); + } + + @Test + public void produceStringWithEmptyData() throws Exception { + String clusterId = getClusterId(); + ProduceRequest request = + ProduceRequest.builder() + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.STRING) + .setData(TextNode.valueOf("")) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.STRING) + .setData(TextNode.valueOf("")) + .build()) + .setOriginalSize(0L) + 
.build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, + actual.getPartitionId(), + actual.getOffset(), + new StringDeserializer(), + new StringDeserializer()); + assertTrue(produced.key().isEmpty()); + assertTrue(produced.value().isEmpty()); + } + + @Test + public void produceStringWithNullData() throws Exception { + String clusterId = getClusterId(); + ProduceRequest request = + ProduceRequest.builder() + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.STRING) + .setData(NullNode.getInstance()) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.STRING) + .setData(NullNode.getInstance()) + .build()) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, + actual.getPartitionId(), + actual.getOffset(), + new StringDeserializer(), + new StringDeserializer()); + assertNull(produced.key()); + assertNull(produced.value()); + } + + @Test + public void produceWithInvalidData_throwsBadRequest() throws Exception { + String clusterId = getClusterId(); + String request = "{ \"records\": {\"subject\": \"foobar\" } }"; + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + ErrorResponse actual = response.readEntity(ErrorResponse.class); + assertEquals(400, actual.getErrorCode()); + assertEquals( + "Unrecognized field \"records\" " + + "(class io.confluent.kafkarest.entities.v3.AutoValue_ProduceRequest$Builder), " + + "not marked as ignorable (6 known properties: \"value\", \"originalSize\", " + + "\"partitionId\", \"headers\", \"key\", \"timestamp\"])", + actual.getMessage()); + } + + @Test + public void produceJson() throws Exception { + String clusterId = getClusterId(); + String key = "foo"; + String value = "bar"; + ProduceRequest request = + ProduceRequest.builder() + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.JSON) + .setData(TextNode.valueOf(key)) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.JSON) + .setData(TextNode.valueOf(value)) + .build()) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + KafkaJsonDeserializer deserializer = new KafkaJsonDeserializer<>(); + deserializer.configure(emptyMap(), /* isKey= */ false); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, actual.getPartitionId(), actual.getOffset(), deserializer, deserializer); + assertEquals(key, 
produced.key()); + assertEquals(value, produced.value()); + } + + @Test + public void produceJsonWithNullData() throws Exception { + String clusterId = getClusterId(); + ProduceRequest request = + ProduceRequest.builder() + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.JSON) + .setData(NullNode.getInstance()) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.JSON) + .setData(NullNode.getInstance()) + .build()) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + KafkaJsonDeserializer deserializer = new KafkaJsonDeserializer<>(); + deserializer.configure(emptyMap(), /* isKey= */ false); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, actual.getPartitionId(), actual.getOffset(), deserializer, deserializer); + assertNull(produced.key()); + assertNull(produced.value()); + } + + @Test + public void produceBinaryWithPartitionId() throws Exception { + String clusterId = getClusterId(); + int partitionId = 1; + ByteString key = ByteString.copyFromUtf8("foo"); + ByteString value = ByteString.copyFromUtf8("bar"); + ProduceRequest request = + ProduceRequest.builder() + .setPartitionId(partitionId) + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(BinaryNode.valueOf(key.toByteArray())) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(BinaryNode.valueOf(value.toByteArray())) + .build()) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, + partitionId, + actual.getOffset(), + new ByteArrayDeserializer(), + new ByteArrayDeserializer()); + assertEquals(key, ByteString.copyFrom(produced.key())); + assertEquals(value, ByteString.copyFrom(produced.value())); + } + + @Test + public void produceBinaryWithTimestamp() throws Exception { + String clusterId = getClusterId(); + Instant timestamp = Instant.ofEpochMilli(1000); + ByteString key = ByteString.copyFromUtf8("foo"); + ByteString value = ByteString.copyFromUtf8("bar"); + ProduceRequest request = + ProduceRequest.builder() + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(BinaryNode.valueOf(key.toByteArray())) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(BinaryNode.valueOf(value.toByteArray())) + .build()) + .setTimestamp(timestamp) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, + actual.getPartitionId(), + actual.getOffset(), + new ByteArrayDeserializer(), + new 
ByteArrayDeserializer()); + assertEquals(key, ByteString.copyFrom(produced.key())); + assertEquals(value, ByteString.copyFrom(produced.value())); + assertEquals(timestamp, Instant.ofEpochMilli(produced.timestamp())); + } + + @Test + public void produceBinaryWithHeaders() throws Exception { + String clusterId = getClusterId(); + ByteString key = ByteString.copyFromUtf8("foo"); + ByteString value = ByteString.copyFromUtf8("bar"); + ProduceRequest request = + ProduceRequest.builder() + .setHeaders( + Arrays.asList( + ProduceRequestHeader.create("header-1", ByteString.copyFromUtf8("value-1")), + ProduceRequestHeader.create("header-1", ByteString.copyFromUtf8("value-2")), + ProduceRequestHeader.create("header-2", ByteString.copyFromUtf8("value-3")))) + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(BinaryNode.valueOf(key.toByteArray())) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(BinaryNode.valueOf(value.toByteArray())) + .build()) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, + actual.getPartitionId(), + actual.getOffset(), + new ByteArrayDeserializer(), + new ByteArrayDeserializer()); + assertEquals(key, ByteString.copyFrom(produced.key())); + assertEquals(value, ByteString.copyFrom(produced.value())); + assertEquals( + Arrays.asList( + new RecordHeader("header-1", ByteString.copyFromUtf8("value-1").toByteArray()), + new RecordHeader("header-1", ByteString.copyFromUtf8("value-2").toByteArray())), + ImmutableList.copyOf(produced.headers().headers("header-1"))); + assertEquals( + singletonList( + new RecordHeader("header-2", ByteString.copyFromUtf8("value-3").toByteArray())), + ImmutableList.copyOf(produced.headers().headers("header-2"))); + } + + @Test + public void produceBinaryKeyOnly() throws Exception { + String clusterId = getClusterId(); + ByteString key = ByteString.copyFromUtf8("foo"); + ProduceRequest request = + ProduceRequest.builder() + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(BinaryNode.valueOf(key.toByteArray())) + .build()) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, + actual.getPartitionId(), + actual.getOffset(), + new ByteArrayDeserializer(), + new ByteArrayDeserializer()); + assertEquals(key, ByteString.copyFrom(produced.key())); + assertNull(produced.value()); + } + + @Test + public void produceBinaryValueOnly() throws Exception { + String clusterId = getClusterId(); + ByteString value = ByteString.copyFromUtf8("bar"); + ProduceRequest request = + ProduceRequest.builder() + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(BinaryNode.valueOf(value.toByteArray())) + .build()) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + 
clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, + actual.getPartitionId(), + actual.getOffset(), + new ByteArrayDeserializer(), + new ByteArrayDeserializer()); + assertNull(produced.key()); + assertEquals(value, ByteString.copyFrom(produced.value())); + } + + @Test + public void produceStringWithPartitionId() throws Exception { + String clusterId = getClusterId(); + int partitionId = 1; + String key = "foo"; + String value = "bar"; + ProduceRequest request = + ProduceRequest.builder() + .setPartitionId(partitionId) + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.STRING) + .setData(TextNode.valueOf(key)) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.STRING) + .setData(TextNode.valueOf(value)) + .build()) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, + partitionId, + actual.getOffset(), + new StringDeserializer(), + new StringDeserializer()); + assertEquals(key, produced.key()); + assertEquals(value, produced.value()); + } + + @Test + public void produceStringWithTimestamp() throws Exception { + String clusterId = getClusterId(); + Instant timestamp = Instant.ofEpochMilli(1000); + String key = "foo"; + String value = "bar"; + ProduceRequest request = + ProduceRequest.builder() + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.STRING) + .setData(TextNode.valueOf(key)) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.STRING) + .setData(TextNode.valueOf(value)) + .build()) + .setTimestamp(timestamp) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, + actual.getPartitionId(), + actual.getOffset(), + new StringDeserializer(), + new StringDeserializer()); + assertEquals(key, produced.key()); + assertEquals(value, produced.value()); + assertEquals(timestamp, Instant.ofEpochMilli(produced.timestamp())); + } + + @Test + public void produceStringWithHeaders() throws Exception { + String clusterId = getClusterId(); + String key = "foo"; + String value = "bar"; + ProduceRequest request = + ProduceRequest.builder() + .setHeaders( + Arrays.asList( + ProduceRequestHeader.create("header-1", ByteString.copyFromUtf8("value-1")), + ProduceRequestHeader.create("header-1", ByteString.copyFromUtf8("value-2")), + ProduceRequestHeader.create("header-2", ByteString.copyFromUtf8("value-3")))) + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.STRING) + .setData(TextNode.valueOf(key)) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.STRING) + 
.setData(TextNode.valueOf(value)) + .build()) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, + actual.getPartitionId(), + actual.getOffset(), + new StringDeserializer(), + new StringDeserializer()); + assertEquals(key, produced.key()); + assertEquals(value, produced.value()); + assertEquals( + Arrays.asList( + new RecordHeader("header-1", ByteString.copyFromUtf8("value-1").toByteArray()), + new RecordHeader("header-1", ByteString.copyFromUtf8("value-2").toByteArray())), + ImmutableList.copyOf(produced.headers().headers("header-1"))); + assertEquals( + singletonList( + new RecordHeader("header-2", ByteString.copyFromUtf8("value-3").toByteArray())), + ImmutableList.copyOf(produced.headers().headers("header-2"))); + } + + @Test + public void produceStringKeyOnly() throws Exception { + String clusterId = getClusterId(); + String key = "foo"; + ProduceRequest request = + ProduceRequest.builder() + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.STRING) + .setData(TextNode.valueOf(key)) + .build()) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, + actual.getPartitionId(), + actual.getOffset(), + new StringDeserializer(), + new ByteArrayDeserializer()); + assertEquals(key, produced.key()); + assertNull(produced.value()); + } + + @Test + public void produceStringValueOnly() throws Exception { + String clusterId = getClusterId(); + String value = "bar"; + ProduceRequest request = + ProduceRequest.builder() + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.STRING) + .setData(TextNode.valueOf(value)) + .build()) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, + actual.getPartitionId(), + actual.getOffset(), + new StringDeserializer(), + new StringDeserializer()); + assertNull(produced.key()); + assertEquals(value, produced.value()); + } + + @Test + public void produceNothing() throws Exception { + String clusterId = getClusterId(); + ProduceRequest request = ProduceRequest.builder().setOriginalSize(0L).build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + ConsumerRecord produced = + getMessage( + TOPIC_NAME, + actual.getPartitionId(), + actual.getOffset(), + new ByteArrayDeserializer(), + new 
ByteArrayDeserializer()); + assertNull(produced.key()); + assertNull(produced.value()); + } + + @Test + public void produceJsonBatch() throws Exception { + String clusterId = getClusterId(); + ArrayList requests = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + requests.add( + ProduceRequest.builder() + .setPartitionId(0) + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.JSON) + .setData(TextNode.valueOf("key-" + i)) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.JSON) + .setData(TextNode.valueOf("value-" + i)) + .build()) + .setOriginalSize(0L) + .build()); + } + + StringBuilder batch = new StringBuilder(); + ObjectMapper objectMapper = getObjectMapper(); + for (ProduceRequest produceRequest : requests) { + batch.append(objectMapper.writeValueAsString(produceRequest)); + } + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(batch.toString(), MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + List actual = readProduceResponses(response); + KafkaJsonDeserializer deserializer = new KafkaJsonDeserializer<>(); + deserializer.configure(emptyMap(), /* isKey= */ false); + + ConsumerRecords producedRecords = + getMessages(TOPIC_NAME, deserializer, deserializer, 100); + + Iterator> it = producedRecords.iterator(); + assertEquals(100, producedRecords.count()); + + for (int i = 0; i < 100; i++) { + ConsumerRecord record = it.next(); + assertEquals(actual.get(i).getPartitionId(), record.partition()); + assertEquals(actual.get(i).getOffset(), record.offset()); + assertEquals( + requests + .get(i) + .getKey() + .map(ProduceRequestData::getData) + .map(JsonNode::asText) + .orElse(null), + record.key()); + assertEquals( + requests + .get(i) + .getValue() + .map(ProduceRequestData::getData) + .map(JsonNode::asText) + .orElse(null), + record.value()); + } + } + + @Test + public void produceStringBatch() throws Exception { + String clusterId = getClusterId(); + ArrayList requests = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + requests.add( + ProduceRequest.builder() + .setPartitionId(0) + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.STRING) + .setData(TextNode.valueOf("key-" + i)) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.STRING) + .setData(TextNode.valueOf("value-" + i)) + .build()) + .setOriginalSize(0L) + .build()); + } + + StringBuilder batch = new StringBuilder(); + ObjectMapper objectMapper = getObjectMapper(); + for (ProduceRequest produceRequest : requests) { + batch.append(objectMapper.writeValueAsString(produceRequest)); + } + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(batch.toString(), MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + List actual = readProduceResponses(response); + StringDeserializer deserializer = new StringDeserializer(); + + ConsumerRecords producedRecords = + getMessages(TOPIC_NAME, deserializer, deserializer, 100); + + Iterator> it = producedRecords.iterator(); + assertEquals(100, producedRecords.count()); + + for (int i = 0; i < 100; i++) { + ConsumerRecord record = it.next(); + assertEquals(actual.get(i).getPartitionId(), record.partition()); + assertEquals(actual.get(i).getOffset(), record.offset()); + assertEquals( + 
requests + .get(i) + .getKey() + .map(ProduceRequestData::getData) + .map(JsonNode::textValue) + .orElse(null), + record.key()); + assertEquals( + requests + .get(i) + .getValue() + .map(ProduceRequestData::getData) + .map(JsonNode::textValue) + .orElse(null), + record.value()); + } + } + + @Test + public void produceBinaryBatchWithInvalidData_throwsMultipleBadRequests() throws Exception { + String clusterId = getClusterId(); + ArrayList<ProduceRequest> requests = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + requests.add( + ProduceRequest.builder() + .setPartitionId(0) + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(IntNode.valueOf(2 * i)) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(IntNode.valueOf(2 * i + 1)) + .build()) + .setOriginalSize(0L) + .build()); + } + + StringBuilder batch = new StringBuilder(); + ObjectMapper objectMapper = getObjectMapper(); + for (ProduceRequest produceRequest : requests) { + batch.append(objectMapper.writeValueAsString(produceRequest)); + } + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(batch.toString(), MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + List<ErrorResponse> actual = readErrorResponses(response); + for (int i = 0; i < 100; i++) { + assertEquals(400, actual.get(i).getErrorCode()); + } + } + + @Test + public void produceBinaryWithLargerSizeMessage() throws Exception { + String clusterId = getClusterId(); + ByteString key = ByteString.copyFromUtf8("foo"); + // The Kafka server and producer are configured to accept messages up to 20971520 bytes (20 MB), + // but KafkaProducer counts the key, value, and header sizes plus additional per-record + // overhead, hence we produce a message of 20971420 bytes. 
+ int valueSize = ((2 << 20) * 10) - 100; + byte[] value = generateBinaryData(valueSize); + ProduceRequest request = + ProduceRequest.builder() + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(BinaryNode.valueOf(key.toByteArray())) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.BINARY) + .setData(BinaryNode.valueOf(value)) + .build()) + .setOriginalSize(0L) + .build(); + + Response response = + request("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + ProduceResponse actual = readProduceResponse(response); + assertTrue(actual.getValue().isPresent()); + assertEquals(valueSize, actual.getValue().get().getSize()); + + ConsumerRecord<byte[], byte[]> produced = + getMessage( + TOPIC_NAME, + actual.getPartitionId(), + actual.getOffset(), + new ByteArrayDeserializer(), + new ByteArrayDeserializer()); + assertEquals(key, ByteString.copyFrom(produced.key())); + assertEquals(valueSize, produced.serializedValueSize()); + assertEquals(Arrays.toString(value), Arrays.toString(produced.value())); + } + + private static ProduceResponse readProduceResponse(Response response) { + response.bufferEntity(); + try { + return response.readEntity(ProduceResponse.class); + } catch (ProcessingException e) { + throw new RuntimeException(response.readEntity(ErrorResponse.class).toString(), e); + } + } + + private static ImmutableList<ProduceResponse> readProduceResponses(Response response) { + return ImmutableList.copyOf( + response.readEntity(new GenericType<List<ProduceResponse>>() {})); + } + + private static ImmutableList<ErrorResponse> readErrorResponses(Response response) { + return ImmutableList.copyOf( + response.readEntity(new GenericType<List<ErrorResponse>>() {})); + } + + private static byte[] generateBinaryData(int messageSize) { + byte[] data = new byte[messageSize]; + Arrays.fill(data, (byte) 1); + return data; + } +} diff --git a/kafka-rest/src/test/java/io/confluent/kafkarest/integration/v3/ProduceActionRateLimitIntegrationTest.java b/kafka-rest/src/test/java/io/confluent/kafkarest/integration/v3/ProduceActionRateLimitIntegrationTest.java new file mode 100644 index 0000000000..42cb2182ad --- /dev/null +++ b/kafka-rest/src/test/java/io/confluent/kafkarest/integration/v3/ProduceActionRateLimitIntegrationTest.java @@ -0,0 +1,212 @@ +/* + * Copyright 2023 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ + + package io.confluent.kafkarest.integration.v3; + + import static org.junit.jupiter.api.Assertions.assertEquals; + + import com.fasterxml.jackson.databind.MappingIterator; + import com.fasterxml.jackson.databind.node.TextNode; + import com.google.common.collect.ImmutableList; + import io.confluent.kafkarest.KafkaRestConfig; + import io.confluent.kafkarest.entities.EmbeddedFormat; + import io.confluent.kafkarest.entities.v3.ProduceRequest; + import io.confluent.kafkarest.entities.v3.ProduceRequest.ProduceRequestData; + import io.confluent.kafkarest.exceptions.v3.ErrorResponse; + import io.confluent.kafkarest.testing.DefaultKafkaRestTestEnvironment; + import java.util.List; + import java.util.Properties; + import javax.ws.rs.client.Entity; + import javax.ws.rs.core.GenericType; + import javax.ws.rs.core.MediaType; + import javax.ws.rs.core.Response; + import javax.ws.rs.core.Response.Status; + import org.junit.jupiter.api.AfterEach; + import org.junit.jupiter.api.BeforeEach; + import org.junit.jupiter.api.DisplayName; + import org.junit.jupiter.api.Tag; + import org.junit.jupiter.api.Test; + import org.junit.jupiter.api.TestInfo; + import org.junit.jupiter.api.extension.RegisterExtension; + + @Tag("IntegrationTest") + public class ProduceActionRateLimitIntegrationTest { + + private static final String TOPIC_NAME = "topic-1"; + + @RegisterExtension + public final DefaultKafkaRestTestEnvironment testEnv = new DefaultKafkaRestTestEnvironment(false); + + @BeforeEach + public void setUp(TestInfo testInfo) throws Exception { + Properties restConfigs = new Properties(); + // Adding custom KafkaRestConfigs for individual test-cases/test-methods below. + if (testInfo.getDisplayName().contains("CallerIsRateLimited")) { + restConfigs.put(KafkaRestConfig.RATE_LIMIT_ENABLE_CONFIG, "true"); + restConfigs.put(KafkaRestConfig.PRODUCE_RATE_LIMIT_ENABLED, "true"); + restConfigs.put(KafkaRestConfig.RATE_LIMIT_BACKEND_CONFIG, "resilience4j"); + // The happy path, i.e. that REST calls below the threshold succeed, is already covered by the + // other existing tests. The 4 tests below, 1 per rate-limit config, set a very low rate limit + // of "1" to deterministically make sure the limits apply and the REST calls see 429s. 
+ if (testInfo + .getDisplayName() + .contains("test_whenGlobalByteLimitReached_thenCallerIsRateLimited")) { + + restConfigs.put(KafkaRestConfig.PRODUCE_MAX_BYTES_GLOBAL_PER_SECOND, "1"); + } + if (testInfo + .getDisplayName() + .contains("test_whenClusterByteLimitReached_thenCallerIsRateLimited")) { + + restConfigs.put(KafkaRestConfig.PRODUCE_MAX_BYTES_PER_SECOND, "1"); + } + if (testInfo + .getDisplayName() + .contains("test_whenGlobalRequestCountLimitReached_thenCallerIsRateLimited")) { + + restConfigs.put(KafkaRestConfig.PRODUCE_MAX_REQUESTS_GLOBAL_PER_SECOND, "1"); + } + if (testInfo + .getDisplayName() + .contains("test_whenClusterRequestCountLimitReached_thenCallerIsRateLimited")) { + + restConfigs.put(KafkaRestConfig.PRODUCE_MAX_REQUESTS_PER_SECOND, "1"); + } + } + testEnv.kafkaRest().startApp(restConfigs); + + testEnv.kafkaCluster().createTopic(TOPIC_NAME, 3, (short) 1); + } + + @AfterEach + public void tearDown() { + testEnv.kafkaRest().closeApp(); + } + + private void doByteLimitReachedTest() throws Exception { + String clusterId = testEnv.kafkaCluster().getClusterId(); + String key = "foo"; + String value = "bar"; + ProduceRequest request = + ProduceRequest.builder() + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.JSON) + .setData(TextNode.valueOf(key)) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.JSON) + .setData(TextNode.valueOf(value)) + .build()) + // The 0 value here is meaningless and only set because originalSize is mandatory for AutoValue. + // The value set here is ignored anyway, as the "true" originalSize is calculated and set + // when the JSON request is deserialized into a ProduceRecord object on the + // server side. + .setOriginalSize(0L) + .build(); + + Response response = + testEnv + .kafkaRest() + .target() + .path("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .request() + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response.getStatus()); + + List<ErrorResponse> actual = readErrorResponses(response); + assertEquals(actual.size(), 1); + // Check the request was rate-limited, so the returned HTTP error code is 429. + // NOTE - The byte rate limit is set to 1 in setUp(), making sure the 1st request itself fails. 
+ assertEquals(actual.get(0).getErrorCode(), 429); + } + + @Test + @DisplayName("test_whenGlobalByteLimitReached_thenCallerIsRateLimited") + public void test_whenGlobalByteLimitReached_thenCallerIsRateLimited() throws Exception { + doByteLimitReachedTest(); + } + + @Test + @DisplayName("test_whenClusterByteLimitReached_thenCallerIsRateLimited") + public void test_whenClusterByteLimitReached_thenCallerIsRateLimited() throws Exception { + doByteLimitReachedTest(); + } + + private void doCountLimitTest() throws Exception { + String clusterId = testEnv.kafkaCluster().getClusterId(); + String key = "foo"; + String value = "bar"; + ProduceRequest request = + ProduceRequest.builder() + .setKey( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.JSON) + .setData(TextNode.valueOf(key)) + .build()) + .setValue( + ProduceRequestData.builder() + .setFormat(EmbeddedFormat.JSON) + .setData(TextNode.valueOf(value)) + .build()) + .setOriginalSize(0L) + .build(); + + Response response1 = + testEnv + .kafkaRest() + .target() + .path("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .request() + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + Response response2 = + testEnv + .kafkaRest() + .target() + .path("/v3/clusters/" + clusterId + "/topics/" + TOPIC_NAME + "/records") + .request() + .accept(MediaType.APPLICATION_JSON) + .post(Entity.entity(request, MediaType.APPLICATION_JSON)); + assertEquals(Status.OK.getStatusCode(), response1.getStatus()); + + assertEquals(Status.OK.getStatusCode(), response2.getStatus()); + List<ErrorResponse> actual = readErrorResponses(response2); + assertEquals(actual.size(), 1); + // Check the request was rate-limited, so the returned HTTP error code is 429. + // NOTE - The count rate limit is set to 1 in setUp(), making sure the 2nd request fails + // deterministically. + assertEquals(actual.get(0).getErrorCode(), 429); + } + + @Test + @DisplayName("test_whenGlobalRequestCountLimitReached_thenCallerIsRateLimited") + public void test_whenGlobalRequestCountLimitReached_thenCallerIsRateLimited() throws Exception { + doCountLimitTest(); + } + + @Test + @DisplayName("test_whenClusterRequestCountLimitReached_thenCallerIsRateLimited") + public void test_whenClusterRequestCountLimitReached_thenCallerIsRateLimited() throws Exception { + doCountLimitTest(); + } + + private static ImmutableList<ErrorResponse> readErrorResponses(Response response) { + return ImmutableList.copyOf( + response.readEntity(new GenericType<List<ErrorResponse>>() {})); + } +}
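
Editor's note (not part of the diff): the sketch below is a minimal, hypothetical illustration of how a test extending the updated ClusterTestHarness might use the new getMessage/getMessages helpers introduced above to read back records after producing through the REST API. The class name, topic name, and deserializer choices are illustrative only; the harness methods (createTopic, request, getClusterId, getMessage, getMessages) come from the diff.

import static org.junit.jupiter.api.Assertions.assertEquals;

import io.confluent.kafkarest.integration.ClusterTestHarness;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.common.serialization.StringDeserializer;
import org.junit.jupiter.api.Test;

public class ExampleReadBackTest extends ClusterTestHarness {

  public ExampleReadBackTest() {
    // One broker, no Schema Registry; mirrors the constructors used by the new tests above.
    super(/* numBrokers= */ 1, /* withSchemaRegistry= */ false);
  }

  @Test
  public void readBack() throws Exception {
    createTopic("example-topic", /* numPartitions= */ 1, /* replicationFactor= */ (short) 1);

    // ... produce some records first, e.g. by POSTing to
    // "/v3/clusters/" + getClusterId() + "/topics/example-topic/records" via request(...) ...

    // Fetch a single record by (topic, partition, offset), e.g. the offset returned in a
    // v3 ProduceResponse.
    ConsumerRecord<String, String> first =
        getMessage("example-topic", 0, 0L, new StringDeserializer(), new StringDeserializer());
    assertEquals(0L, first.offset());

    // Fetch the first N records from partition 0; getMessages polls until N records are seen.
    ConsumerRecords<String, String> firstTen =
        getMessages("example-topic", new StringDeserializer(), new StringDeserializer(), 10);
    assertEquals(10, firstTen.count());
  }
}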