feat(rate-limit): Reactive implementation
brasseld committed Sep 26, 2019
1 parent b6c2cc2 commit 7ec8ee5
Showing 27 changed files with 1,144 additions and 640 deletions.
24 changes: 24 additions & 0 deletions gravitee-gateway-services-ratelimit/pom.xml
@@ -79,6 +79,30 @@
</exclusion>
</exclusions>
</dependency>

<!-- Logging -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-core</artifactId>
</dependency>

<!-- Test dependencies -->
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
</dependency>
</dependencies>

<build>
@@ -15,72 +15,167 @@
*/
package io.gravitee.gateway.services.ratelimit;

import io.gravitee.gateway.services.ratelimit.util.KeySplitter;
import io.gravitee.repository.ratelimit.api.RateLimitRepository;
import io.gravitee.repository.ratelimit.model.RateLimit;
import io.reactivex.*;
import io.reactivex.disposables.Disposable;
import io.reactivex.functions.BiFunction;
import io.reactivex.functions.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Iterator;
import java.util.concurrent.BlockingQueue;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

/**
* @author David BRASSELY (david.brassely at graviteesource.com)
* @author GraviteeSource Team
*/
public class AsyncRateLimitRepository implements RateLimitRepository {
public class AsyncRateLimitRepository implements RateLimitRepository<RateLimit> {

private RateLimitRepository localCacheRateLimitRepository;
private RateLimitRepository aggregateCacheRateLimitRepository;
private BlockingQueue<RateLimit> rateLimitsQueue;
private final Logger logger = LoggerFactory.getLogger(AsyncRateLimitRepository.class);

@Override
public RateLimit get(String rateLimitKey) {
// Get data from local cache
RateLimit cachedRateLimit = localCacheRateLimitRepository.get(rateLimitKey);
cachedRateLimit = (cachedRateLimit != null) ? cachedRateLimit : new RateLimit(rateLimitKey);

// Aggregate counter with data from aggregate cache
// Split the key to remove the gateway_id and get only the needed part
String [] parts = KeySplitter.split(rateLimitKey);
RateLimit aggregateRateLimit = aggregateCacheRateLimitRepository.get(parts[1]);
if (aggregateRateLimit != null) {
cachedRateLimit.setCounter(cachedRateLimit.getCounter() + aggregateRateLimit.getCounter());
AggregateRateLimit extendRateLimit = new AggregateRateLimit(cachedRateLimit);
extendRateLimit.setAggregateCounter(aggregateRateLimit.getCounter());
return extendRateLimit;
}
private LocalRateLimitRepository localCacheRateLimitRepository;
private RateLimitRepository<RateLimit> remoteCacheRateLimitRepository;

private final Set<String> keys = new HashSet<>();

private final BaseSchedulerProvider schedulerProvider;

// Keep one lock per rate-limit key to ensure data consistency during merge
private final Map<String, Semaphore> locks = new ConcurrentHashMap<>();

public AsyncRateLimitRepository(BaseSchedulerProvider schedulerProvider) {
this.schedulerProvider = schedulerProvider;
}

return cachedRateLimit;
public void initialize() {
Disposable subscribe = Observable
.timer(2000, TimeUnit.MILLISECONDS)
.repeat()
.subscribe(tick -> merge());

// TODO: dispose the subscription when the service is stopped
}

@Override
public void save(RateLimit rateLimit) {
if (rateLimit instanceof AggregateRateLimit) {
AggregateRateLimit aggregateRateLimit = (AggregateRateLimit) rateLimit;
aggregateRateLimit.setCounter(aggregateRateLimit.getCounter() - aggregateRateLimit.getAggregateCounter());
aggregateRateLimit.setAggregateCounter(0L);
public Single<RateLimit> incrementAndGet(String key, long weight, Supplier<RateLimit> supplier) {
return
isLocked(key)
.subscribeOn(schedulerProvider.computation())
.andThen(
Single.defer(() -> localCacheRateLimitRepository
.incrementAndGet(key, weight, () -> new LocalRateLimit(supplier.get()))
.map(localRateLimit -> {
keys.add(localRateLimit.getKey());
return localRateLimit;
}))
);
}

void merge() {
if (!keys.isEmpty()) {
keys.forEach(new java.util.function.Consumer<String>() {
@Override
public void accept(String key) {
lock(key)
// By default, delayed signals are delivered on the computation scheduler
// .observeOn(Schedulers.computation())
.andThen(localCacheRateLimitRepository.get(key)
// The remote counter is incremented by the local counter value
// If the remote does not contain an existing value, the local counter is used
.flatMapSingle((Function<LocalRateLimit, SingleSource<RateLimit>>) localRateLimit ->
remoteCacheRateLimitRepository.incrementAndGet(key, localRateLimit.getLocal(), () -> localRateLimit))
.zipWith(
localCacheRateLimitRepository.get(key).toSingle(),
new BiFunction<RateLimit, LocalRateLimit, LocalRateLimit>() {
@Override
public LocalRateLimit apply(RateLimit rateLimit, LocalRateLimit localRateLimit) throws Exception {
// Set the counter with the latest value from the repository
localRateLimit.setCounter(rateLimit.getCounter());

// Re-init the local counter
localRateLimit.setLocal(0L);

return localRateLimit;
}
})
// And save the new counter value into the local cache
.flatMap((Function<LocalRateLimit, SingleSource<LocalRateLimit>>) rateLimit ->
localCacheRateLimitRepository.save(rateLimit))
.doAfterTerminate(() -> unlock(key))
.doOnError(throwable -> logger.error("An unexpected error occurred while refreshing the asynchronous rate-limit", throwable)))
.subscribe();
}
});

// Clear keys
keys.clear();
}
}

// Push data in local cache
localCacheRateLimitRepository.save(rateLimit);
private Completable isLocked(String key) {
return Completable.create(emitter -> {
Semaphore sem = locks.get(key);

// Push data in queue to store rate-limit asynchronously
rateLimitsQueue.offer(rateLimit);
if (sem == null) {
emitter.onComplete();
} else {
// Wait until unlocked
boolean acquired = false;
while(!acquired) {
acquired = sem.tryAcquire();
}

// Once we get access, release
sem.release();
}

emitter.onComplete();
});
}

private Completable lock(String key) {
return Completable.create(emitter -> {
Semaphore sem = locks.computeIfAbsent(key, key1 -> new Semaphore(1));

boolean acquired = false;
while(!acquired) {
acquired = sem.tryAcquire();
}

emitter.onComplete();
});
}

private void unlock(String key) {
Semaphore lock = this.locks.get(key);
if (lock != null) {
lock.release();
}
}

@Override
public Iterator<RateLimit> findAsyncAfter(long timestamp) {
public Maybe<RateLimit> get(String key) {
throw new IllegalStateException();
}

public void setLocalCacheRateLimitRepository(RateLimitRepository localCacheRateLimitRepository) {
this.localCacheRateLimitRepository = localCacheRateLimitRepository;
@Override
public Single<RateLimit> save(RateLimit rateLimit) {
throw new IllegalStateException();
}

public void setAggregateCacheRateLimitRepository(RateLimitRepository aggregateCacheRateLimitRepository) {
this.aggregateCacheRateLimitRepository = aggregateCacheRateLimitRepository;
public void setLocalCacheRateLimitRepository(LocalRateLimitRepository localCacheRateLimitRepository) {
this.localCacheRateLimitRepository = localCacheRateLimitRepository;
}

public void setRateLimitsQueue(BlockingQueue<RateLimit> rateLimitsQueue) {
this.rateLimitsQueue = rateLimitsQueue;
public void setRemoteCacheRateLimitRepository(RateLimitRepository<RateLimit> remoteCacheRateLimitRepository) {
this.remoteCacheRateLimitRepository = remoteCacheRateLimitRepository;
}
}
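The initialize() method above leaves a TODO about disposing the timer subscription when the service stops. A minimal sketch of one way to address it, assuming a stop() hook gets added to AsyncRateLimitRepository (the mergeDisposable field and the stop() method are illustrative only, not part of this commit; Observable and Disposable are already imported by the class):

    // Illustrative members of AsyncRateLimitRepository (assumed, not in this commit)
    private Disposable mergeDisposable;

    public void initialize() {
        // Keep a handle on the subscription so it can be released later
        mergeDisposable = Observable
                .timer(2000, TimeUnit.MILLISECONDS)
                .repeat()
                .subscribe(tick -> merge());
    }

    public void stop() {
        // Release the periodic merge subscription when the service shuts down
        if (mergeDisposable != null && !mergeDisposable.isDisposed()) {
            mergeDisposable.dispose();
        }
    }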
@@ -16,7 +16,7 @@
package io.gravitee.gateway.services.ratelimit;

import io.gravitee.common.service.AbstractService;
import io.gravitee.common.util.BlockingArrayQueue;
import io.gravitee.gateway.services.ratelimit.rx.SchedulerProvider;
import io.gravitee.repository.ratelimit.api.RateLimitRepository;
import io.gravitee.repository.ratelimit.api.RateLimitService;
import io.gravitee.repository.ratelimit.model.RateLimit;
@@ -29,8 +29,6 @@
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.context.ConfigurableApplicationContext;

import java.util.concurrent.*;

/**
* @author David BRASSELY (david.brassely at graviteesource.com)
* @author GraviteeSource Team
@@ -42,24 +40,10 @@ public class AsyncRateLimitService extends AbstractService {
@Value("${services.ratelimit.enabled:true}")
private boolean enabled;

@Value("${services.ratelimit.async.polling:10000}")
private int polling;

@Value("${services.ratelimit.async.queue:10000}")
private int queueCapacity;

@Autowired
@Qualifier("local")
private Cache localCache;

@Autowired
@Qualifier("aggregate")
private Cache aggregateCache;

private ScheduledExecutorService rateLimitPollerExecutor;

private ExecutorService rateLimitUpdaterExecutor;

@Override
protected void doStart() throws Exception {
super.doStart();
@@ -68,50 +52,26 @@ protected void doStart() throws Exception {
DefaultListableBeanFactory parentBeanFactory = (DefaultListableBeanFactory) ((ConfigurableApplicationContext) applicationContext.getParent()).getBeanFactory();

// Retrieve the current rate-limit repository implementation
RateLimitRepository rateLimitRepository = parentBeanFactory.getBean(RateLimitRepository.class);
RateLimitRepository<RateLimit> rateLimitRepository = parentBeanFactory.getBean(RateLimitRepository.class);
LOGGER.debug("Rate-limit repository implementation is {}", rateLimitRepository.getClass().getName());

if (enabled) {
// Prepare caches
RateLimitRepository aggregateCacheRateLimitRepository = new CachedRateLimitRepository(aggregateCache);
RateLimitRepository localCacheRateLimitRepository = new CachedRateLimitRepository(localCache);

// Prepare queue to flush data into the final repository implementation
BlockingQueue<RateLimit> rateLimitsQueue = new BlockingArrayQueue<>(queueCapacity);
// Prepare local cache
LocalRateLimitRepository localCacheRateLimitRepository = new LocalRateLimitRepository();

LOGGER.debug("Register rate-limit repository asynchronous implementation {}",
AsyncRateLimitRepository.class.getName());
AsyncRateLimitRepository asyncRateLimitRepository = new AsyncRateLimitRepository();
AsyncRateLimitRepository asyncRateLimitRepository = new AsyncRateLimitRepository(new SchedulerProvider());
beanFactory.autowireBean(asyncRateLimitRepository);
asyncRateLimitRepository.setLocalCacheRateLimitRepository(localCacheRateLimitRepository);
asyncRateLimitRepository.setAggregateCacheRateLimitRepository(aggregateCacheRateLimitRepository);
asyncRateLimitRepository.setRateLimitsQueue(rateLimitsQueue);
asyncRateLimitRepository.setRemoteCacheRateLimitRepository(rateLimitRepository);
asyncRateLimitRepository.initialize();

LOGGER.info("Register the rate-limit service bridge for synchronous and asynchronous mode");
DefaultRateLimitService rateLimitService = new DefaultRateLimitService();
rateLimitService.setRateLimitRepository(rateLimitRepository);
rateLimitService.setAsyncRateLimitRepository(asyncRateLimitRepository);
parentBeanFactory.registerSingleton(RateLimitService.class.getName(), rateLimitService);

// Prepare and start rate-limit poller
rateLimitPollerExecutor = Executors.newSingleThreadScheduledExecutor(r -> new Thread(r, "rate-limit-poller"));
RateLimitPoller rateLimitPoller = new RateLimitPoller();
beanFactory.autowireBean(rateLimitPoller);
rateLimitPoller.setRateLimitRepository(rateLimitRepository);
rateLimitPoller.setAggregateCacheRateLimitRepository(aggregateCacheRateLimitRepository);

LOGGER.info("Schedule rate-limit poller at fixed rate: {} {}", polling, TimeUnit.MILLISECONDS);
rateLimitPollerExecutor.scheduleAtFixedRate(
rateLimitPoller, 0L, polling, TimeUnit.MILLISECONDS);

// Prepare and start rate-limit updater
rateLimitUpdaterExecutor = Executors.newSingleThreadExecutor(r -> new Thread(r, "rate-limit-updater"));
RateLimitUpdater rateLimitUpdater = new RateLimitUpdater(rateLimitsQueue);
beanFactory.autowireBean(rateLimitUpdater);
rateLimitUpdater.setRateLimitRepository(rateLimitRepository);

LOGGER.info("Start rate-limit updater");
rateLimitUpdaterExecutor.submit(rateLimitUpdater);
} else {
// By disabling async and cached rate limiting, only the strict mode is allowed
LOGGER.info("Register the rate-limit service bridge for strict mode only");
@@ -126,22 +86,6 @@ protected void doStart() throws Exception {
protected void doStop() throws Exception {
if (enabled) {
super.doStop();

if (rateLimitPollerExecutor != null) {
try {
rateLimitPollerExecutor.shutdownNow();
} catch (Exception ex) {
LOGGER.error("Unexpected error when shutdown rate-limit poller", ex);
}
}

if (rateLimitUpdaterExecutor != null) {
try {
rateLimitUpdaterExecutor.shutdownNow();
} catch (Exception ex) {
LOGGER.error("Unexpected error when shutdown rate-limit updater", ex);
}
}
}
}

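For context, a hypothetical call site showing how the reactive incrementAndGet exposed by AsyncRateLimitRepository might be consumed. The key format, quota value, and RateLimit construction below are assumptions for illustration; the policy-side code is not part of this commit:

    long limit = 100L;              // assumed quota, for illustration only
    String key = "gateway#api#plan"; // assumed key format, for illustration only

    asyncRateLimitRepository
            .incrementAndGet(key, 1L, () -> new RateLimit(key))
            .subscribe(
                    rateLimit -> {
                        // Decide whether to accept the call based on the aggregated counter
                        boolean exceeded = rateLimit.getCounter() > limit;
                    },
                    throwable -> {
                        // Depending on the policy configuration, fail open or reject the call
                    });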
@@ -0,0 +1,25 @@
/**
* Copyright (C) 2015 The Gravitee team (http://gravitee.io)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.gravitee.gateway.services.ratelimit;

import io.reactivex.Scheduler;

public interface BaseSchedulerProvider {

Scheduler io();

Scheduler computation();
}
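AsyncRateLimitService imports io.gravitee.gateway.services.ratelimit.rx.SchedulerProvider, whose source is not included in this excerpt. A minimal sketch of what an implementation of BaseSchedulerProvider typically looks like with RxJava 2, delegating to the standard Schedulers (an assumption for illustration, not the file from this commit):

    package io.gravitee.gateway.services.ratelimit.rx;

    import io.gravitee.gateway.services.ratelimit.BaseSchedulerProvider;
    import io.reactivex.Scheduler;
    import io.reactivex.schedulers.Schedulers;

    public class SchedulerProvider implements BaseSchedulerProvider {

        @Override
        public Scheduler io() {
            // Cached thread pool intended for blocking I/O work
            return Schedulers.io();
        }

        @Override
        public Scheduler computation() {
            // Fixed-size pool bounded by the number of available processors
            return Schedulers.computation();
        }
    }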
