Skip to content

Commit

Permalink
fix: Refactors rate limiter within LoggingHandler to be singleton (#6693
Browse files Browse the repository at this point in the history
)

* fix: Refactors rate limiter within LoggingHandler to be singleton
  • Loading branch information
AlanConfluent committed Dec 1, 2020
1 parent 14465a2 commit 72bc27e
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@

package io.confluent.ksql.api.server;

import static io.confluent.ksql.rest.server.KsqlRestConfig.KSQL_LOGGING_SERVER_RATE_LIMITED_REQUEST_PATHS_CONFIG;
import static io.confluent.ksql.rest.server.KsqlRestConfig.KSQL_LOGGING_SERVER_SKIPPED_RESPONSE_CODES_CONFIG;
import static java.util.Objects.requireNonNull;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.RateLimiter;
import io.confluent.ksql.api.auth.ApiUser;
Expand All @@ -31,11 +30,9 @@
import io.vertx.ext.web.impl.Utils;
import java.time.Clock;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -45,28 +42,27 @@ public class LoggingHandler implements Handler<RoutingContext> {
static final String HTTP_HEADER_USER_AGENT = "User-Agent";

private final Set<Integer> skipResponseCodes;
private final Map<String, Double> rateLimitedPaths;
private final Logger logger;
private final Clock clock;
private final Function<Double, RateLimiter> rateLimiterFactory;
private final LoggingRateLimiter loggingRateLimiter;

private final Map<String, RateLimiter> rateLimiters = new ConcurrentHashMap<>();

public LoggingHandler(final Server server) {
this(server, LOG, Clock.systemUTC(), RateLimiter::create);
public LoggingHandler(final Server server, final LoggingRateLimiter loggingRateLimiter) {
this(server, loggingRateLimiter, LOG, Clock.systemUTC());
}

@VisibleForTesting
LoggingHandler(
final Server server,
final LoggingRateLimiter loggingRateLimiter,
final Logger logger,
final Clock clock,
final Function<Double, RateLimiter> rateLimiterFactory) {
final Clock clock) {
requireNonNull(server);
this.loggingRateLimiter = requireNonNull(loggingRateLimiter);
this.skipResponseCodes = getSkipResponseCodes(server.getConfig());
this.rateLimitedPaths = getSkipRequestPaths(server.getConfig());
this.logger = logger;
this.clock = clock;
this.rateLimiterFactory = rateLimiterFactory;
}

@Override
Expand All @@ -76,13 +72,8 @@ public void handle(final RoutingContext routingContext) {
if (skipResponseCodes.contains(routingContext.response().getStatusCode())) {
return;
}
if (rateLimitedPaths.containsKey(routingContext.request().path())) {
final String path = routingContext.request().path();
final double rateLimit = rateLimitedPaths.get(path);
rateLimiters.computeIfAbsent(path, (k) -> rateLimiterFactory.apply(rateLimit));
if (!rateLimiters.get(path).tryAcquire()) {
return;
}
if (!loggingRateLimiter.shouldLog(routingContext.request().path())) {
return;
}
final long contentLength = routingContext.request().response().bytesWritten();
final HttpVersion version = routingContext.request().version();
Expand Down Expand Up @@ -134,14 +125,6 @@ private static Set<Integer> getSkipResponseCodes(final KsqlRestConfig config) {
.map(Integer::parseInt).collect(ImmutableSet.toImmutableSet());
}

private static Map<String, Double> getSkipRequestPaths(final KsqlRestConfig config) {
// Already validated as having double values
return config.getStringAsMap(KSQL_LOGGING_SERVER_RATE_LIMITED_REQUEST_PATHS_CONFIG)
.entrySet().stream()
.collect(ImmutableMap.toImmutableMap(Entry::getKey,
entry -> Double.parseDouble(entry.getValue())));
}

private void doLog(final int status, final String message) {
if (status >= 500) {
logger.error(message);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright 2020 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"); you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at
*
* http://www.confluent.io/confluent-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package io.confluent.ksql.api.server;

import static io.confluent.ksql.rest.server.KsqlRestConfig.KSQL_LOGGING_SERVER_RATE_LIMITED_REQUEST_PATHS_CONFIG;
import static java.util.Objects.requireNonNull;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.RateLimiter;
import io.confluent.ksql.rest.server.KsqlRestConfig;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

class LoggingRateLimiter {

private final Map<String, Double> rateLimitedPaths;
private final Function<Double, RateLimiter> rateLimiterFactory;

private final Map<String, RateLimiter> rateLimiters = new ConcurrentHashMap<>();

LoggingRateLimiter(final KsqlRestConfig ksqlRestConfig) {
this(ksqlRestConfig, RateLimiter::create);
}

@VisibleForTesting
LoggingRateLimiter(
final KsqlRestConfig ksqlRestConfig,
final Function<Double, RateLimiter> rateLimiterFactory) {
requireNonNull(ksqlRestConfig);
this.rateLimiterFactory = requireNonNull(rateLimiterFactory);
this.rateLimitedPaths = getRateLimitedRequestPaths(ksqlRestConfig);
}

public boolean shouldLog(final String path) {
if (rateLimitedPaths.containsKey(path)) {
final double rateLimit = rateLimitedPaths.get(path);
rateLimiters.computeIfAbsent(path, (k) -> rateLimiterFactory.apply(rateLimit));
return rateLimiters.get(path).tryAcquire();
}
return true;
}

private static Map<String, Double> getRateLimitedRequestPaths(final KsqlRestConfig config) {
// Already validated as having double values
return config.getStringAsMap(KSQL_LOGGING_SERVER_RATE_LIMITED_REQUEST_PATHS_CONFIG)
.entrySet().stream()
.collect(ImmutableMap.toImmutableMap(Entry::getKey,
entry -> Double.parseDouble(entry.getValue())));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ public synchronized void start() {
}
this.workerExecutor = vertx.createSharedWorkerExecutor("ksql-workers",
config.getInt(KsqlRestConfig.WORKER_POOL_SIZE));
final LoggingRateLimiter loggingRateLimiter = new LoggingRateLimiter(config);
configureTlsCertReload(config);

final List<URI> listenUris = parseListeners(config);
Expand All @@ -131,7 +132,7 @@ public synchronized void start() {
final ServerVerticle serverVerticle = new ServerVerticle(endpoints,
createHttpServerOptions(config, listener.getHost(), listener.getPort(),
listener.getScheme().equalsIgnoreCase("https"), isInternalListener.orElse(false)),
this, isInternalListener, pullQueryMetrics);
this, isInternalListener, pullQueryMetrics, loggingRateLimiter);
vertx.deployVerticle(serverVerticle, vcf);
final int index = i;
final CompletableFuture<String> deployFuture = vcf.thenApply(s -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,21 @@ public class ServerVerticle extends AbstractVerticle {
private HttpServer httpServer;
private final Optional<Boolean> isInternalListener;
private final Optional<PullQueryExecutorMetrics> pullQueryMetrics;
private final LoggingRateLimiter loggingRateLimiter;

public ServerVerticle(
final Endpoints endpoints,
final HttpServerOptions httpServerOptions,
final Server server,
final Optional<Boolean> isInternalListener,
final Optional<PullQueryExecutorMetrics> pullQueryMetrics) {
final Optional<PullQueryExecutorMetrics> pullQueryMetrics,
final LoggingRateLimiter loggingRateLimiter) {
this.endpoints = Objects.requireNonNull(endpoints);
this.httpServerOptions = Objects.requireNonNull(httpServerOptions);
this.server = Objects.requireNonNull(server);
this.isInternalListener = Objects.requireNonNull(isInternalListener);
this.pullQueryMetrics = Objects.requireNonNull(pullQueryMetrics);
this.loggingRateLimiter = Objects.requireNonNull(loggingRateLimiter);
}

@Override
Expand Down Expand Up @@ -112,7 +115,7 @@ int actualPort() {
private Router setupRouter() {
final Router router = Router.router(vertx);

router.route().handler(new LoggingHandler(server));
router.route().handler(new LoggingHandler(server, loggingRateLimiter));

KsqlCorsHandler.setupCorsHandler(server, router);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.RateLimiter;
import io.confluent.ksql.rest.server.KsqlRestConfig;
import io.vertx.core.AsyncResult;
import io.vertx.core.Handler;
Expand All @@ -37,6 +35,8 @@ public class LoggingHandlerTest {
@Mock
private Server server;
@Mock
private LoggingRateLimiter loggingRateLimiter;
@Mock
private Logger logger;
@Mock
private RoutingContext routingContext;
Expand All @@ -50,8 +50,6 @@ public class LoggingHandlerTest {
private SocketAddress socketAddress;
@Mock
private Clock clock;
@Mock
private RateLimiter rateLimiter;
@Captor
private ArgumentCaptor<String> logStringCaptor;
@Captor
Expand All @@ -68,18 +66,17 @@ public void setUp() {
when(request.response()).thenReturn(response);
when(request.remoteAddress()).thenReturn(socketAddress);
when(ksqlRestConfig.getList(any())).thenReturn(ImmutableList.of("401"));
when(ksqlRestConfig.getStringAsMap(any())).thenReturn(ImmutableMap.of("/query", "100"));
when(rateLimiter.tryAcquire()).thenReturn(true);
when(loggingRateLimiter.shouldLog("/query")).thenReturn(true);
when(clock.millis()).thenReturn(1699813434333L);
when(response.bytesWritten()).thenReturn(5678L);
when(request.uri()).thenReturn("/query");
when(request.path()).thenReturn("/query");
when(request.uri()).thenReturn("/query");
when(request.getHeader(HTTP_HEADER_USER_AGENT)).thenReturn("bot");
when(socketAddress.host()).thenReturn("123.111.222.333");
when(request.bytesRead()).thenReturn(3456L);
when(request.version()).thenReturn(HttpVersion.HTTP_1_1);
when(request.method()).thenReturn(HttpMethod.POST);
loggingHandler = new LoggingHandler(server, logger, clock, (rateLimit) -> rateLimiter);
loggingHandler = new LoggingHandler(server, loggingRateLimiter, logger, clock);
}

@Test
Expand Down Expand Up @@ -136,7 +133,7 @@ public void shouldSkipLog() {
public void shouldSkipRateLimited() {
// Given:
when(response.getStatusCode()).thenReturn(200);
when(rateLimiter.tryAcquire()).thenReturn(true, true, false, false);
when(loggingRateLimiter.shouldLog("/query")).thenReturn(true, true, false, false);

// When:
loggingHandler.handle(routingContext);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package io.confluent.ksql.api.server;


import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.RateLimiter;
import io.confluent.ksql.rest.server.KsqlRestConfig;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;

@RunWith(MockitoJUnitRunner.class)
public class LoggingRateLimiterTest {

private static final String PATH = "/query";

@Mock
private RateLimiter rateLimiter;
@Mock
private KsqlRestConfig ksqlRestConfig;

private LoggingRateLimiter loggingRateLimiter;


@Before
public void setUp() {
when(ksqlRestConfig.getStringAsMap(any())).thenReturn(ImmutableMap.of(PATH, "2"));
when(rateLimiter.tryAcquire()).thenReturn(true);
loggingRateLimiter = new LoggingRateLimiter(ksqlRestConfig, (rateLimit) -> rateLimiter);
}

@Test
public void shouldLog() {
// When:
assertThat(loggingRateLimiter.shouldLog(PATH), is(true));

// Then:
verify(rateLimiter).tryAcquire();
}

@Test
public void shouldSkipRateLimited() {
// Given:
when(rateLimiter.tryAcquire()).thenReturn(true, true, false, false);

// When:
assertThat(loggingRateLimiter.shouldLog(PATH), is(true));
assertThat(loggingRateLimiter.shouldLog(PATH), is(true));
assertThat(loggingRateLimiter.shouldLog(PATH), is(false));
assertThat(loggingRateLimiter.shouldLog(PATH), is(false));

// Then:
verify(rateLimiter, times(4)).tryAcquire();
}

@Test
public void shouldLog_notRateLimited() {
// When:
assertThat(loggingRateLimiter.shouldLog("/foo"), is(true));

// Then:
verify(rateLimiter, never()).tryAcquire();
}
}

0 comments on commit 72bc27e

Please sign in to comment.