Skip to content

Commit

Permalink
feat: add rate-limiting to ksql command topic (#8809)
Browse files Browse the repository at this point in the history
* feat: add rate-limiting to ksql command topic

* make rate limit quota set by config

* make rate limiting return an error

* test + review

* move rate limiting to commandstore

* move rate limiting to distributing executor

* move config to rest config

* rohan's review
  • Loading branch information
lct45 committed Mar 2, 2022
1 parent 1cb7846 commit b2f1540
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,12 @@ public class KsqlRestConfig extends AbstractConfig {
+ "returns a 421 mis-directed response. (NOTE: this check should not be enabled if "
+ "ksqlDB servers have mutual TLS enabled)";

public static final String KSQL_COMMAND_TOPIC_RATE_LIMIT_CONFIG =
KSQL_CONFIG_PREFIX + "server.command.topic.rate.limit";
public static final double KSQL_COMMAND_TOPIC_RATE_LIMIT_CONFIG_DEFAULT = Double.MAX_VALUE;
private static final String KSQL_COMMAND_TOPIC_RATE_LIMIT_CONFIG_DEFAULT_DOC =
"Sets the number of statements that can be executed against the command topic per second";

private static final ConfigDef CONFIG_DEF;

static {
Expand Down Expand Up @@ -749,7 +755,13 @@ public class KsqlRestConfig extends AbstractConfig {
KSQL_SERVER_SNI_CHECK_ENABLE_DEFAULT,
Importance.LOW,
KSQL_SERVER_SNI_CHECK_ENABLE_DOC
);
).define(
KSQL_COMMAND_TOPIC_RATE_LIMIT_CONFIG,
Type.DOUBLE,
KSQL_COMMAND_TOPIC_RATE_LIMIT_CONFIG_DEFAULT,
Importance.LOW,
KSQL_COMMAND_TOPIC_RATE_LIMIT_CONFIG_DEFAULT_DOC
);
}

public KsqlRestConfig(final Map<?, ?> props) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package io.confluent.ksql.rest.server.computation;

import com.google.common.util.concurrent.RateLimiter;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import io.confluent.ksql.KsqlExecutionContext;
import io.confluent.ksql.metastore.MetaStore;
Expand All @@ -31,7 +32,9 @@
import io.confluent.ksql.rest.entity.CommandStatusEntity;
import io.confluent.ksql.rest.entity.KsqlWarning;
import io.confluent.ksql.rest.entity.WarningEntity;
import io.confluent.ksql.rest.server.KsqlRestConfig;
import io.confluent.ksql.rest.server.execution.StatementExecutorResponse;
import io.confluent.ksql.rest.server.resources.KsqlRestException;
import io.confluent.ksql.security.KsqlAuthorizationValidator;
import io.confluent.ksql.security.KsqlSecurityContext;
import io.confluent.ksql.services.ServiceContext;
Expand All @@ -46,6 +49,7 @@
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import org.apache.kafka.clients.producer.Producer;
Expand All @@ -60,7 +64,9 @@
* duration for the command to be executed remotely if configured with a
* {@code distributedCmdResponseTimeout}.
*/
// CHECKSTYLE_RULES.OFF: ClassDataAbstractionCoupling
public class DistributingExecutor {
// CHECKSTYLE_RULES.ON: ClassDataAbstractionCoupling
private final CommandQueue commandQueue;
private final Duration distributedCmdResponseTimeout;
private final BiFunction<KsqlExecutionContext, ServiceContext, Injector> injectorFactory;
Expand All @@ -70,6 +76,7 @@ public class DistributingExecutor {
private final ReservedInternalTopics internalTopics;
private final Errors errorHandler;
private final Supplier<String> commandRunnerWarning;
private final RateLimiter rateLimiter;

@SuppressFBWarnings(value = "EI_EXPOSE_REP2")
public DistributingExecutor(
Expand Down Expand Up @@ -98,6 +105,9 @@ public DistributingExecutor(
this.errorHandler = Objects.requireNonNull(errorHandler, "errorHandler");
this.commandRunnerWarning =
Objects.requireNonNull(commandRunnerWarning, "commandRunnerWarning");
final KsqlRestConfig restConfig = new KsqlRestConfig(ksqlConfig.originals());
this.rateLimiter =
RateLimiter.create(restConfig.getDouble(KsqlRestConfig.KSQL_COMMAND_TOPIC_RATE_LIMIT_CONFIG));
}

// CHECKSTYLE_RULES.OFF: CyclomaticComplexity
Expand Down Expand Up @@ -196,6 +206,13 @@ public StatementExecutorResponse execute(
statement.getStatementText()), e);
}

if (!rateLimiter.tryAcquire(1, TimeUnit.SECONDS)) {
throw new KsqlRestException(
Errors.tooManyRequests(
"DDL/DML rate is crossing the configured rate limit of statements/second"
));
}

CommandId commandId = null;
try {
transactionalProducer.beginTransaction();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyLong;
Expand All @@ -30,7 +28,6 @@
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand All @@ -39,7 +36,6 @@

import com.google.common.collect.ImmutableList;
import io.confluent.ksql.engine.KsqlEngine;
import io.confluent.ksql.metrics.MetricCollectors;
import io.confluent.ksql.rest.Errors;
import io.confluent.ksql.rest.server.resources.IncompatibleKsqlCommandVersionException;
import io.confluent.ksql.rest.server.state.ServerState;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.isA;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doNothing;
Expand Down Expand Up @@ -59,8 +61,11 @@
import io.confluent.ksql.rest.entity.CommandStatus;
import io.confluent.ksql.rest.entity.CommandStatus.Status;
import io.confluent.ksql.rest.entity.CommandStatusEntity;
import io.confluent.ksql.rest.entity.KsqlErrorMessage;
import io.confluent.ksql.rest.entity.WarningEntity;
import io.confluent.ksql.rest.server.KsqlRestConfig;
import io.confluent.ksql.rest.server.execution.StatementExecutorResponse;
import io.confluent.ksql.rest.server.resources.KsqlRestException;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.security.KsqlAuthorizationValidator;
import io.confluent.ksql.security.KsqlSecurityContext;
Expand Down Expand Up @@ -496,4 +501,33 @@ CommonCreateConfigs.VALUE_FORMAT_PROPERTY, new StringLiteral("json")
assertThat("Should be present", response.getEntity().isPresent());
assertThat(((WarningEntity) response.getEntity().get()).getMessage(), containsString(""));
}

@Test
public void shouldThrowIfRateLimitHit() {
// Given:
final DistributingExecutor rateLimitedDistributor = new DistributingExecutor(
new KsqlConfig(ImmutableMap.of(KsqlRestConfig.KSQL_COMMAND_TOPIC_RATE_LIMIT_CONFIG, 0.5)),
queue,
DURATION_10_MS,
(ec, sc) -> InjectorChain.of(schemaInjector, topicInjector),
Optional.of(authorizationValidator),
validatedCommandFactory,
errorHandler,
commandRunnerWarning
);

// When:
rateLimitedDistributor.execute(CONFIGURED_STATEMENT, executionContext, securityContext);


// Then:
final KsqlRestException e = assertThrows(
KsqlRestException.class,
() -> rateLimitedDistributor.execute(CONFIGURED_STATEMENT, executionContext, securityContext)
);

assertEquals(e.getResponse().getStatus(), 429);
final KsqlErrorMessage errorMessage = (KsqlErrorMessage) e.getResponse().getEntity();
assertTrue(errorMessage.getMessage().contains("DDL/DML rate is crossing the configured rate limit of statements/second"));
}
}
10 changes: 10 additions & 0 deletions ksqldb-rest-model/src/main/java/io/confluent/ksql/rest/Errors.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import static io.netty.handler.codec.http.HttpResponseStatus.NOT_FOUND;
import static io.netty.handler.codec.http.HttpResponseStatus.PRECONDITION_REQUIRED;
import static io.netty.handler.codec.http.HttpResponseStatus.SERVICE_UNAVAILABLE;
import static io.netty.handler.codec.http.HttpResponseStatus.TOO_MANY_REQUESTS;
import static io.netty.handler.codec.http.HttpResponseStatus.UNAUTHORIZED;

import io.confluent.ksql.rest.entity.KsqlEntityList;
Expand Down Expand Up @@ -72,6 +73,8 @@ public final class Errors {
public static final int ERROR_CODE_SERVER_ERROR =
toErrorCode(INTERNAL_SERVER_ERROR.code());

public static final int ERROR_CODE_TOO_MANY_REQUESTS = toErrorCode(TOO_MANY_REQUESTS.code());

private final ErrorMessages errorMessages;

public static int toStatusCode(final int errorCode) {
Expand Down Expand Up @@ -208,6 +211,13 @@ public static EndpointResponse serverNotReady(final KsqlErrorMessage error) {
.build();
}

public static EndpointResponse tooManyRequests(final String msg) {
return EndpointResponse.create()
.status(TOO_MANY_REQUESTS.code())
.entity(new KsqlErrorMessage(ERROR_CODE_TOO_MANY_REQUESTS, msg))
.build();
}

public Errors(final ErrorMessages errorMessages) {
this.errorMessages = Objects.requireNonNull(errorMessages, "errorMessages");
}
Expand Down

0 comments on commit b2f1540

Please sign in to comment.