Skip to content

Commit

Permalink
Merge e773609 into 6532a8e
Browse files Browse the repository at this point in the history
  • Loading branch information
pkwarren committed Jun 30, 2016
2 parents 6532a8e + e773609 commit 18425f7
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 14 deletions.
Expand Up @@ -4,14 +4,18 @@
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.Timer;
import com.google.common.base.Predicate;
import com.google.common.cache.Cache;
import com.google.common.base.Throwables;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheBuilderSpec;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.CacheStats;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.UncheckedExecutionException;

import java.security.Principal;
import java.util.Optional;
import java.util.concurrent.ExecutionException;

import static com.codahale.metrics.MetricRegistry.name;

Expand All @@ -23,8 +27,7 @@
* @param <P> the type of principals the authenticator returns
*/
public class CachingAuthenticator<C, P extends Principal> implements Authenticator<C, P> {
private final Authenticator<C, P> underlying;
private final Cache<C, Optional<P>> cache;
private final LoadingCache<C, Optional<P>> cache;
private final Meter cacheMisses;
private final Timer gets;

Expand All @@ -51,25 +54,37 @@ public CachingAuthenticator(final MetricRegistry metricRegistry,
public CachingAuthenticator(final MetricRegistry metricRegistry,
final Authenticator<C, P> authenticator,
final CacheBuilder<Object, Object> builder) {
this.underlying = authenticator;
this.cacheMisses = metricRegistry.meter(name(authenticator.getClass(), "cache-misses"));
this.gets = metricRegistry.timer(name(authenticator.getClass(), "gets"));
this.cache = builder.recordStats().build();
this.cache = builder.recordStats().build(new CacheLoader<C, Optional<P>>() {
@Override
public Optional<P> load(C key) throws Exception {
cacheMisses.mark();
final Optional<P> optPrincipal = authenticator.authenticate(key);
if (!optPrincipal.isPresent()) {
// Prevent caching of unknown credentials
throw new InvalidCredentialsException();
}
return optPrincipal;
}
});
}

@Override
public Optional<P> authenticate(C credentials) throws AuthenticationException {
final Timer.Context context = gets.time();
try {
Optional<P> optionalPrincipal = cache.getIfPresent(credentials);
if (optionalPrincipal == null) {
cacheMisses.mark();
optionalPrincipal = underlying.authenticate(credentials);
if (optionalPrincipal.isPresent()) {
cache.put(credentials, optionalPrincipal);
}
return cache.get(credentials);
} catch (ExecutionException e) {
final Throwable cause = e.getCause();
if (cause instanceof InvalidCredentialsException) {
return Optional.empty();
}
return optionalPrincipal;
// Attempt to re-throw as-is
Throwables.propagateIfPossible(cause, AuthenticationException.class);
throw new AuthenticationException(cause);
} catch (UncheckedExecutionException e) {
throw Throwables.propagate(e.getCause());
} finally {
context.stop();
}
Expand Down Expand Up @@ -126,4 +141,10 @@ public long size() {
public CacheStats stats() {
return cache.stats();
}

/**
* Exception thrown by {@link CacheLoader#load(Object)} when the authenticator returns {@link Optional#empty()}.
* This is used to prevent caching of invalid credentials.
*/
private static class InvalidCredentialsException extends Exception {}
}
Expand Up @@ -3,8 +3,11 @@
import com.codahale.metrics.MetricRegistry;
import com.google.common.cache.CacheBuilderSpec;
import com.google.common.collect.ImmutableSet;
import org.hamcrest.CoreMatchers;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.InOrder;

import java.security.Principal;
Expand All @@ -23,6 +26,8 @@ public class CachingAuthenticatorTest {
private final Authenticator<String, Principal> underlying = mock(Authenticator.class);
private final CachingAuthenticator<String, Principal> cached =
new CachingAuthenticator<>(new MetricRegistry(), underlying, CacheBuilderSpec.parse("maximumSize=1"));
@Rule
public ExpectedException expected = ExpectedException.none();

@Before
public void setUp() throws Exception {
Expand Down Expand Up @@ -94,7 +99,7 @@ public void calculatesTheSizeOfTheCache() throws Exception {
@Test
public void calculatesCacheStats() throws Exception {
cached.authenticate("credentials1");
assertThat(cached.stats().loadCount()).isEqualTo(0);
assertThat(cached.stats().loadCount()).isEqualTo(1);
assertThat(cached.size()).isEqualTo(1);
}

Expand All @@ -105,4 +110,20 @@ public void shouldNotCacheAbsentPrincipals() throws Exception {
verify(underlying).authenticate("credentials");
assertThat(cached.size()).isEqualTo(0);
}

@Test
public void shouldPropagateAuthenticationException() throws AuthenticationException {
final AuthenticationException e = new AuthenticationException("Auth failed");
when(underlying.authenticate(anyString())).thenThrow(e);
expected.expect(CoreMatchers.sameInstance(e));
cached.authenticate("credentials");
}

@Test
public void shouldPropagateRuntimeException() throws AuthenticationException {
final RuntimeException e = new NullPointerException();
when(underlying.authenticate(anyString())).thenThrow(e);
expected.expect(CoreMatchers.sameInstance(e));
cached.authenticate("credentials");
}
}

0 comments on commit 18425f7

Please sign in to comment.