Skip to content

Commit

Permalink
Merge a17ad0c into aa48a48
Browse files Browse the repository at this point in the history
  • Loading branch information
marcosbarbero committed Jun 6, 2020
2 parents aa48a48 + a17ad0c commit edf79d8
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@
package com.marcosbarbero.cloud.autoconfigure.zuul.ratelimit.config.repository;

import com.marcosbarbero.cloud.autoconfigure.zuul.ratelimit.config.Rate;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;

import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.Objects;

import static java.util.concurrent.TimeUnit.SECONDS;
Expand Down Expand Up @@ -63,17 +69,19 @@ private Long calcRemaining(Long limit, Duration refreshInterval, long usage, Str
rate.setReset(refreshInterval.toMillis());
Long current = 0L;
try {
Boolean present = redisTemplate.opsForValue().setIfAbsent(key, Long.toString(usage), refreshInterval.getSeconds(), SECONDS);
if (Boolean.FALSE.equals(present)) {
// Key already exists, increment
current = redisTemplate.opsForValue().increment(key, usage);
} else {
current = usage;
}
current = redisTemplate.execute(getScript(), Collections.singletonList(key), Long.toString(usage),
Long.toString(refreshInterval.getSeconds()));
} catch (RuntimeException e) {
String msg = "Failed retrieving rate for " + key + ", will return the current value";
rateLimiterErrorHandler.handleError(msg, e);
}
return Math.max(-1, limit - (current != null ? current : 0L));
return Math.max(-1, limit - (current != null ? current.intValue() : 0));
}

private RedisScript<Long> getScript() {
DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
redisScript.setLocation(new ClassPathResource("/scripts/ratelimit.lua"));
redisScript.setResultType(Long.class);
return redisScript;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
local current = redis.call('incrby', KEYS[1], ARGV[1])

if tonumber(current) == tonumber(ARGV[1]) then
redis.call('expire', KEYS[1], ARGV[2])
end

return current
Original file line number Diff line number Diff line change
@@ -1,22 +1,9 @@
package com.marcosbarbero.cloud.autoconfigure.zuul.ratelimit.config.repository;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.matches;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.common.collect.Maps;
import com.marcosbarbero.cloud.autoconfigure.zuul.ratelimit.config.properties.RateLimitProperties.Policy;
import java.time.Duration;
import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.mockito.Mock;
import org.mockito.Mockito;
Expand All @@ -25,6 +12,12 @@
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.ValueOperations;

import java.time.Duration;
import java.util.Map;

import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;

@SuppressWarnings("unchecked")
public class RedisRateLimiterTest extends BaseRateLimiterTest {

Expand All @@ -36,102 +29,76 @@ public class RedisRateLimiterTest extends BaseRateLimiterTest {
@BeforeEach
public void setUp() {
MockitoAnnotations.initMocks(this);
Map<String, BoundValueOperations<String, String>> map = Maps.newHashMap();
Map<String, Long> longMap = Maps.newHashMap();

when(this.redisTemplate.boundValueOps(any())).thenAnswer(invocation -> {
String key = invocation.getArgument(0);
BoundValueOperations<String, String> mock = map.computeIfAbsent(key, k -> Mockito.mock(BoundValueOperations.class));
when(mock.increment(anyLong())).thenAnswer(invocationOnMock -> {
long value = invocationOnMock.getArgument(0);
return longMap.compute(key, (k, v) -> ((v != null) ? v : 0L) + value);
});
return mock;
});
when(this.redisTemplate.opsForValue()).thenAnswer(invocation -> {
ValueOperations<String, String> mock = mock(ValueOperations.class);
when(mock.increment(any(), anyLong())).thenAnswer(invocationOnMock -> {
String key = invocationOnMock.getArgument(0);
long value = invocationOnMock.getArgument(1);
return longMap.compute(key, (k, v) -> ((v != null) ? v : 0L) + value);
});
return mock;
});
doReturn(1L, 2L)
.when(redisTemplate).execute(any(), anyList(), anyString(), anyString());

this.target = new RedisRateLimiter(this.rateLimiterErrorHandler, this.redisTemplate);
}

@Test
@Disabled
public void testConsumeOnlyQuota() {
// disabling in favor of integration tests
}

@Test
@Disabled
public void testConsume() {
// disabling in favor of integration tests
}

@Test
public void testConsumeRemainingLimitException() {
ValueOperations<String, String> ops = mock(ValueOperations.class);
when(ops.setIfAbsent(anyString(), anyString(), anyLong(), any())).thenReturn(false);
doReturn(ops).when(redisTemplate).opsForValue();
doThrow(new RuntimeException()).when(ops).increment(anyString(), anyLong());
doThrow(new RuntimeException()).when(redisTemplate).execute(any(), anyList(), anyString(), anyString());

Policy policy = new Policy();
policy.setLimit(100L);
target.consume(policy, "key", 0L);
verify(redisTemplate.opsForValue()).setIfAbsent(anyString(), anyString(), anyLong(), any());
verify(redisTemplate.opsForValue()).increment(anyString(), anyLong());
verify(rateLimiterErrorHandler).handleError(matches(".* key, .*"), any());
}

@Test
public void testConsumeRemainingQuotaLimitException() {
ValueOperations<String, String> ops = mock(ValueOperations.class);
when(ops.setIfAbsent(anyString(), anyString(), anyLong(), any())).thenReturn(false);
doReturn(ops).when(redisTemplate).opsForValue();
doThrow(new RuntimeException()).when(ops).increment(anyString(), anyLong());
doThrow(new RuntimeException()).when(redisTemplate).execute(any(), anyList(), anyString(), anyString());

Policy policy = new Policy();
policy.setQuota(Duration.ofSeconds(100));
target.consume(policy, "key", 0L);
verify(redisTemplate.opsForValue()).setIfAbsent(anyString(), anyString(), anyLong(), any());
verify(redisTemplate.opsForValue()).increment(anyString(), anyLong());
verify(rateLimiterErrorHandler).handleError(matches(".* key-quota, .*"), any());
}

@Test
public void testConsumeGetExpireException() {
ValueOperations<String, String> ops = mock(ValueOperations.class);
when(ops.setIfAbsent(anyString(), anyString(), anyLong(), any())).thenReturn(false);
doReturn(ops).when(redisTemplate).opsForValue();
doThrow(new RuntimeException()).when(ops).increment(anyString(), anyLong());
doThrow(new RuntimeException()).when(redisTemplate).execute(any(), anyList(), anyString(), anyString());

Policy policy = new Policy();
policy.setLimit(100L);
policy.setQuota(Duration.ofSeconds(50));
target.consume(policy, "key", 0L);
verify(redisTemplate.opsForValue(), times(2)).setIfAbsent(anyString(), anyString(), anyLong(), any());
verify(redisTemplate.opsForValue(), times(2)).increment(anyString(), anyLong());
verify(rateLimiterErrorHandler).handleError(matches(".* key, .*"), any());
verify(rateLimiterErrorHandler).handleError(matches(".* key-quota, .*"), any());
}

@Test
public void testConsumeExpireException() {
ValueOperations<String, String> ops = mock(ValueOperations.class);
doThrow(new RuntimeException()).when(ops).setIfAbsent(anyString(), anyString(), anyLong(), any());
when(ops.increment(anyString(), anyLong())).thenReturn(0L);
doReturn(ops).when(redisTemplate).opsForValue();
doThrow(new RuntimeException()).when(redisTemplate).execute(any(), anyList(), anyString(), anyString());

Policy policy = new Policy();
policy.setLimit(100L);
target.consume(policy, "key", 0L);
verify(redisTemplate.opsForValue()).setIfAbsent(anyString(), anyString(), anyLong(), any());
verify(redisTemplate.opsForValue(), never()).increment(any(), anyLong());
verify(rateLimiterErrorHandler).handleError(matches(".* key, .*"), any());
}

@Test
public void testConsumeSetKey() {
ValueOperations<String, String> ops = mock(ValueOperations.class);
when(ops.setIfAbsent(anyString(), anyString(), anyLong(), any())).thenReturn(true);
doReturn(ops).when(redisTemplate).opsForValue();
doReturn(1L, 2L)
.when(redisTemplate).execute(any(), anyList(), anyString(), anyString());

Policy policy = new Policy();
policy.setLimit(20L);
target.consume(policy, "key", 0L);
verify(redisTemplate.opsForValue()).setIfAbsent(anyString(), anyString(), anyLong(), any());
verify(redisTemplate.opsForValue(), never()).increment(any(), anyLong());

verify(redisTemplate).execute(any(), anyList(), anyString(), anyString());
verify(rateLimiterErrorHandler, never()).handleError(any(), any());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ public void testNoRateLimitService() {
}

String exceeded = (String) this.context.get("rateLimitExceeded");
assertFalse(Boolean.valueOf(exceeded), "RateLimit not exceeded");
assertFalse(Boolean.parseBoolean(exceeded), "RateLimit not exceeded");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
package com.marcosbarbero.cloud.autoconfigure.zuul.ratelimit.filters.pre;

import static com.marcosbarbero.cloud.autoconfigure.zuul.ratelimit.support.RateLimitConstants.HEADER_REMAINING;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import com.marcosbarbero.cloud.autoconfigure.zuul.ratelimit.config.repository.RateLimiterErrorHandler;
import com.marcosbarbero.cloud.autoconfigure.zuul.ratelimit.config.repository.RedisRateLimiter;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.ValueOperations;

import java.util.Objects;
import java.util.concurrent.TimeUnit;

import static com.marcosbarbero.cloud.autoconfigure.zuul.ratelimit.support.RateLimitConstants.HEADER_REMAINING;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;

/**
* @author Marcos Barbero
Expand All @@ -41,21 +37,18 @@ public void setUp() {
@Override
@SuppressWarnings("unchecked")
public void testRateLimitExceedCapacity() throws Exception {
ValueOperations ops = mock(ValueOperations.class);
doReturn(ops).when(redisTemplate).opsForValue();
doReturn(3L)
.when(redisTemplate).execute(any(), anyList(), anyString(), anyString());

when(ops.increment(anyString(), anyLong())).thenReturn(3L);
super.testRateLimitExceedCapacity();
}

@Test
@Override
@SuppressWarnings("unchecked")
public void testRateLimit() throws Exception {
ValueOperations ops = mock(ValueOperations.class);
when(ops.increment(anyString(), anyLong())).thenReturn(1L);
doReturn(ops).when(redisTemplate).opsForValue();
when(ops.increment(anyString(), anyLong())).thenReturn(2L);
doReturn(1L, 2L)
.when(redisTemplate).execute(any(), anyList(), anyString(), anyString());


this.request.setRequestURI("/serviceA");
Expand All @@ -73,19 +66,18 @@ public void testRateLimit() throws Exception {

TimeUnit.SECONDS.sleep(2);

when(ops.increment(anyString(), anyLong())).thenReturn(1L);
doReturn(1L)
.when(redisTemplate).execute(any(), anyList(), anyString(), anyString());

this.filter.run();
remaining = this.response.getHeader(HEADER_REMAINING + key);
assertEquals("1", remaining);
}

@Test
public void testShouldReturnCorrectRateRemainingValue() {
String redisKey = "null:serviceA:10.0.0.100:anonymous:GET";
ValueOperations<String, String> ops = mock(ValueOperations.class);
when(redisTemplate.opsForValue()).thenReturn(ops);
when(ops.setIfAbsent(eq(redisKey), eq("1"), anyLong(), any())).thenReturn(true, false);
when(ops.increment(eq(redisKey), anyLong())).thenReturn(2L);
doReturn(1L, 2L)
.when(redisTemplate).execute(any(), anyList(), anyString(), anyString());

this.request.setRequestURI("/serviceA");
this.request.setRemoteAddr("10.0.0.100");
Expand Down

0 comments on commit edf79d8

Please sign in to comment.