Skip to content

Commit

Permalink
Use the custom ServerRequestCache that the user configures
Browse files Browse the repository at this point in the history
on for the default authentication entry point and authentication
success handler

Fixes spring-projectsgh-7721

spring-projects#7721

Set RequestCache on the Oauth2LoginSpec default authentication success handler

import static ReflectionTestUtils.getField

Feedback incorporated per

spring-projects#7734 (review)
  • Loading branch information
fhanik committed Dec 18, 2019
1 parent 0d24e2b commit 32a131c
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 10 deletions.
Expand Up @@ -76,9 +76,11 @@
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter;
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
Expand Down Expand Up @@ -984,7 +986,7 @@ public class OAuth2LoginSpec {

private ServerWebExchangeMatcher authenticationMatcher;

private ServerAuthenticationSuccessHandler authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler();
private ServerAuthenticationSuccessHandler authenticationSuccessHandler;

private ServerAuthenticationFailureHandler authenticationFailureHandler;

Expand Down Expand Up @@ -1175,24 +1177,37 @@ protected void configure(ServerHttpSecurity http) {
authenticationFilter.setRequiresAuthenticationMatcher(getAuthenticationMatcher());
authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository));

authenticationFilter.setAuthenticationSuccessHandler(this.authenticationSuccessHandler);
authenticationFilter.setAuthenticationSuccessHandler(getAuthenticationSuccessHandler(http));
authenticationFilter.setAuthenticationFailureHandler(getAuthenticationFailureHandler());
authenticationFilter.setSecurityContextRepository(this.securityContextRepository);

MediaTypeServerWebExchangeMatcher htmlMatcher = new MediaTypeServerWebExchangeMatcher(
MediaType.TEXT_HTML);
htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL));
Map<String, String> urlToText = http.oauth2Login.getLinks();
String authenticationEntryPointRedirectPath;
if (urlToText.size() == 1) {
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint(urlToText.keySet().iterator().next())));
authenticationEntryPointRedirectPath = urlToText.keySet().iterator().next();
} else {
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint("/login")));
authenticationEntryPointRedirectPath = "/login";
}
RedirectServerAuthenticationEntryPoint entryPoint = new RedirectServerAuthenticationEntryPoint(authenticationEntryPointRedirectPath);
entryPoint.setRequestCache(http.requestCache.requestCache);
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, entryPoint));

http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC);
http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION);
}

private ServerAuthenticationSuccessHandler getAuthenticationSuccessHandler(ServerHttpSecurity http) {
if (this.authenticationSuccessHandler == null) {
RedirectServerAuthenticationSuccessHandler handler = new RedirectServerAuthenticationSuccessHandler();
handler.setRequestCache(http.requestCache.requestCache);
this.authenticationSuccessHandler = handler;
}
return this.authenticationSuccessHandler;
}

private ServerAuthenticationFailureHandler getAuthenticationFailureHandler() {
if (this.authenticationFailureHandler == null) {
this.authenticationFailureHandler = new RedirectServerAuthenticationFailureHandler("/login?error");
Expand Down
Expand Up @@ -20,10 +20,12 @@
import static org.mockito.BDDMockito.given;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.test.util.ReflectionTestUtils.getField;

import java.util.Arrays;
import java.util.List;
Expand All @@ -35,16 +37,20 @@
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;

import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
import reactor.core.publisher.Mono;
import reactor.test.publisher.TestPublisher;

Expand All @@ -64,7 +70,6 @@
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
import org.springframework.security.web.server.csrf.CsrfWebFilter;
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.reactive.server.EntityExchangeResult;
import org.springframework.test.web.reactive.server.FluxExchangeResult;
import org.springframework.test.web.reactive.server.WebTestClient;
Expand Down Expand Up @@ -200,7 +205,7 @@ public void csrfServerLogoutHandlerNotAppliedIfCsrfIsntEnabled() {
.isNotPresent();

Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
.map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));

assertThat(logoutHandler)
.get()
Expand All @@ -213,17 +218,17 @@ public void csrfServerLogoutHandlerAppliedIfCsrfIsEnabled() {

assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class))
.get()
.extracting(csrfWebFilter -> ReflectionTestUtils.getField(csrfWebFilter, "csrfTokenRepository"))
.extracting(csrfWebFilter -> getField(csrfWebFilter, "csrfTokenRepository"))
.isEqualTo(this.csrfTokenRepository);

Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
.map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));

assertThat(logoutHandler)
.get()
.isExactlyInstanceOf(DelegatingServerLogoutHandler.class)
.extracting(delegatingLogoutHandler ->
((List<ServerLogoutHandler>) ReflectionTestUtils.getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
((List<ServerLogoutHandler>) getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
.map(ServerLogoutHandler::getClass)
.collect(Collectors.toList()))
.isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class));
Expand Down Expand Up @@ -479,6 +484,33 @@ public void postWhenCustomCsrfTokenRepositoryThenUsed() {
verify(customServerCsrfTokenRepository).loadToken(any());
}

@Test
public void shouldConfigureRequestCacheForOAuth2LoginAuthenticationEntryPointAndSuccessHandler() {
ServerRequestCache requestCache = spy(new WebSessionServerRequestCache());
ReactiveClientRegistrationRepository clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class);

SecurityWebFilterChain securityFilterChain = this.http
.oauth2Login()
.clientRegistrationRepository(clientRegistrationRepository)
.and()
.authorizeExchange().anyExchange().authenticated()
.and()
.requestCache(c -> c.requestCache(requestCache))
.build();

WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build();
client.get().uri("/test").exchange();
ArgumentCaptor<ServerWebExchange> captor = ArgumentCaptor.forClass(ServerWebExchange.class);
verify(requestCache).saveRequest(captor.capture());
assertThat(captor.getValue().getRequest().getURI().toString()).isEqualTo("/test");


OAuth2LoginAuthenticationWebFilter authenticationWebFilter =
getWebFilter(securityFilterChain, OAuth2LoginAuthenticationWebFilter.class).get();
Object handler = getField(authenticationWebFilter, "authenticationSuccessHandler");
assertThat(getField(handler, "requestCache")).isSameAs(requestCache);
}

@Test
public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() {
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class);
Expand All @@ -503,7 +535,7 @@ public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() {

private boolean isX509Filter(WebFilter filter) {
try {
Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");
Object converter = getField(filter, "authenticationConverter");
return converter.getClass().isAssignableFrom(ServerX509AuthenticationConverter.class);
} catch (IllegalArgumentException e) {
// field doesn't exist
Expand Down

0 comments on commit 32a131c

Please sign in to comment.