Skip to content

Commit

Permalink
HttpSessionSecurityContextRepository does not persist @transient Auth…
Browse files Browse the repository at this point in the history
  • Loading branch information
jgrandja committed Dec 7, 2021
1 parent f489dd0 commit 4d9f040
Show file tree
Hide file tree
Showing 3 changed files with 262 additions and 0 deletions.
Expand Up @@ -19,13 +19,24 @@
import java.util.LinkedHashMap;
import java.util.Map;

import javax.servlet.AsyncContext;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;

import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.config.annotation.web.configurers.ExceptionHandlingConfigurer;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.Transient;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenIntrospectionAuthenticationProvider;
Expand All @@ -39,6 +50,10 @@
import org.springframework.security.web.access.intercept.FilterSecurityInterceptor;
import org.springframework.security.web.authentication.HttpStatusEntryPoint;
import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter;
import org.springframework.security.web.context.HttpRequestResponseHolder;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.SaveContextOnUpdateOrErrorResponseWrapper;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.OrRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
Expand Down Expand Up @@ -212,6 +227,105 @@ public void init(B builder) {
this.tokenRevocationEndpointMatcher)
);
}

// gh-482
initSecurityContextRepository(builder);
}

private void initSecurityContextRepository(B builder) {
// TODO This is a temporary fix and should be removed after upgrading to Spring Security 5.7.0 GA.
//
// See:
// Prevent Save @Transient Authentication with existing HttpSession
// https://github.com/spring-projects/spring-security/pull/9993

final SecurityContextRepository securityContextRepository = builder.getSharedObject(SecurityContextRepository.class);
if (!(securityContextRepository instanceof HttpSessionSecurityContextRepository)) {
return;
}

SecurityContextRepository securityContextRepositoryTransientNotSaved = new SecurityContextRepository() {
// OAuth2ClientAuthenticationToken is @Transient and is accepted by
// OAuth2TokenEndpointFilter, OAuth2TokenIntrospectionEndpointFilter and OAuth2TokenRevocationEndpointFilter
private final RequestMatcher clientAuthenticationRequestMatcher = new OrRequestMatcher(
getRequestMatcher(OAuth2TokenEndpointConfigurer.class),
OAuth2AuthorizationServerConfigurer.this.tokenIntrospectionEndpointMatcher,
OAuth2AuthorizationServerConfigurer.this.tokenRevocationEndpointMatcher);

// JwtAuthenticationToken is @Transient and is accepted by
// OidcUserInfoEndpointFilter and OidcClientRegistrationEndpointFilter
private final RequestMatcher jwtAuthenticationRequestMatcher = getRequestMatcher(OidcConfigurer.class);

@Override
public SecurityContext loadContext(HttpRequestResponseHolder requestResponseHolder) {
final HttpServletRequest unwrappedRequest = requestResponseHolder.getRequest();
final HttpServletResponse unwrappedResponse = requestResponseHolder.getResponse();

SecurityContext securityContext = securityContextRepository.loadContext(requestResponseHolder);

if (this.clientAuthenticationRequestMatcher.matches(unwrappedRequest) ||
this.jwtAuthenticationRequestMatcher.matches(unwrappedRequest)) {

final SaveContextOnUpdateOrErrorResponseWrapper transientAuthenticationResponseWrapper =
new SaveContextOnUpdateOrErrorResponseWrapper(unwrappedResponse, false) {

@Override
protected void saveContext(SecurityContext context) {
// @Transient Authentication should not be saved
if (context.getAuthentication() != null) {
Assert.state(isTransientAuthentication(context.getAuthentication()), "Expected @Transient Authentication");
}
}

};
// Override the default HttpSessionSecurityContextRepository.SaveToSessionResponseWrapper
requestResponseHolder.setResponse(transientAuthenticationResponseWrapper);

final HttpServletRequestWrapper transientAuthenticationRequestWrapper =
new HttpServletRequestWrapper(unwrappedRequest) {

@Override
public AsyncContext startAsync() {
transientAuthenticationResponseWrapper.disableSaveOnResponseCommitted();
return super.startAsync();
}

@Override
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse)
throws IllegalStateException {
transientAuthenticationResponseWrapper.disableSaveOnResponseCommitted();
return super.startAsync(servletRequest, servletResponse);
}

};
// Override the default HttpSessionSecurityContextRepository.SaveToSessionRequestWrapper
requestResponseHolder.setRequest(transientAuthenticationRequestWrapper);
}

return securityContext;
}

@Override
public void saveContext(SecurityContext context, HttpServletRequest request, HttpServletResponse response) {
Authentication authentication = context.getAuthentication();
if (authentication == null || isTransientAuthentication(authentication)) {
return;
}
securityContextRepository.saveContext(context, request, response);
}

@Override
public boolean containsContext(HttpServletRequest request) {
return securityContextRepository.containsContext(request);
}

private boolean isTransientAuthentication(Authentication authentication) {
return AnnotationUtils.getAnnotation(authentication.getClass(), Transient.class) != null;
}

};

builder.setSharedObject(SecurityContextRepository.class, securityContextRepositoryTransientNotSaved);
}

@Override
Expand Down
Expand Up @@ -37,9 +37,11 @@
import org.assertj.core.matcher.AssertionMatcher;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.mockito.ArgumentCaptor;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
Expand All @@ -56,6 +58,7 @@
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
Expand Down Expand Up @@ -102,6 +105,8 @@
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
import org.springframework.security.web.context.SecurityContextRepository;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
Expand All @@ -116,6 +121,10 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user;
Expand Down Expand Up @@ -154,6 +163,7 @@ public class OAuth2AuthorizationCodeGrantTests {
private static AuthenticationProvider authorizationRequestAuthenticationProvider;
private static AuthenticationSuccessHandler authorizationResponseHandler;
private static AuthenticationFailureHandler authorizationErrorResponseHandler;
private static SecurityContextRepository securityContextRepository;
private static String consentPage = "/oauth2/consent";

@Rule
Expand Down Expand Up @@ -187,6 +197,7 @@ public static void init() {
authorizationRequestAuthenticationProvider = mock(AuthenticationProvider.class);
authorizationResponseHandler = mock(AuthenticationSuccessHandler.class);
authorizationErrorResponseHandler = mock(AuthenticationFailureHandler.class);
securityContextRepository = spy(new HttpSessionSecurityContextRepository());
db = new EmbeddedDatabaseBuilder()
.generateUniqueName(true)
.setType(EmbeddedDatabaseType.HSQL)
Expand All @@ -197,6 +208,11 @@ public static void init() {
.build();
}

@Before
public void setup() {
reset(securityContextRepository);
}

@After
public void tearDown() {
jdbcOperations.update("truncate table oauth2_authorization");
Expand Down Expand Up @@ -615,6 +631,48 @@ public void requestWhenAuthorizationEndpointCustomizedThenUsed() throws Exceptio
verify(authorizationResponseHandler).onAuthenticationSuccess(any(), any(), eq(authorizationCodeRequestAuthenticationResult));
}

// gh-482
@Test
public void requestWhenClientObtainsAccessTokenThenClientAuthenticationNotPersisted() throws Exception {
this.spring.register(AuthorizationServerConfigurationWithSecurityContextRepository.class).autowire();

RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build();
this.registeredClientRepository.save(registeredClient);

MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
.params(getAuthorizationRequestParameters(registeredClient))
.param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
.param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
.with(user("user")))
.andExpect(status().is3xxRedirection())
.andReturn();

ArgumentCaptor<org.springframework.security.core.context.SecurityContext> securityContextCaptor =
ArgumentCaptor.forClass(org.springframework.security.core.context.SecurityContext.class);
verify(securityContextRepository, times(2)).saveContext(securityContextCaptor.capture(), any(), any());
securityContextCaptor.getAllValues().forEach(securityContext ->
assertThat(securityContext.getAuthentication()).isInstanceOf(UsernamePasswordAuthenticationToken.class));
reset(securityContextRepository);

String authorizationCode = extractParameterFromRedirectUri(mvcResult.getResponse().getRedirectedUrl(), "code");
OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);

this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
.params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization))
.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
.param(PkceParameterNames.CODE_VERIFIER, S256_CODE_VERIFIER))
.andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store")))
.andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache")))
.andExpect(status().isOk())
.andExpect(jsonPath("$.access_token").isNotEmpty())
.andExpect(jsonPath("$.token_type").isNotEmpty())
.andExpect(jsonPath("$.expires_in").isNotEmpty())
.andExpect(jsonPath("$.refresh_token").doesNotExist())
.andExpect(jsonPath("$.scope").isNotEmpty());

verify(securityContextRepository, never()).saveContext(any(), any(), any());
}

private static MultiValueMap<String, String> getAuthorizationRequestParameters(RegisteredClient registeredClient) {
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
Expand Down Expand Up @@ -739,6 +797,29 @@ static class ParametersMapper extends JdbcOAuth2AuthorizationService.OAuth2Autho

}

@EnableWebSecurity
static class AuthorizationServerConfigurationWithSecurityContextRepository extends AuthorizationServerConfiguration {
// @formatter:off
@Bean
public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception {
OAuth2AuthorizationServerConfigurer<HttpSecurity> authorizationServerConfigurer =
new OAuth2AuthorizationServerConfigurer<>();
RequestMatcher endpointsMatcher = authorizationServerConfigurer.getEndpointsMatcher();

http
.requestMatcher(endpointsMatcher)
.authorizeRequests(authorizeRequests ->
authorizeRequests.anyRequest().authenticated()
)
.csrf(csrf -> csrf.ignoringRequestMatchers(endpointsMatcher))
.securityContext(securityContext ->
securityContext.securityContextRepository(securityContextRepository))
.apply(authorizationServerConfigurer);
return http.build();
}
// @formatter:on
}

@EnableWebSecurity
@Import(OAuth2AuthorizationServerConfiguration.class)
static class AuthorizationServerConfigurationWithJwtEncoder extends AuthorizationServerConfiguration {
Expand Down

0 comments on commit 4d9f040

Please sign in to comment.