Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

SECOAUTH-366: add optional filter for token endpoint

  • Loading branch information...
commit 6004abc975e17fd68150e5a5c748b7490d2552a8 1 parent 7c05904
Dave Syer dsyer authored
28 ...g-security-oauth2/src/main/java/org/springframework/security/oauth2/provider/endpoint/TokenEndpoint.java
View
@@ -34,6 +34,7 @@
import org.springframework.security.oauth2.common.util.OAuth2Utils;
import org.springframework.security.oauth2.provider.ClientRegistrationException;
import org.springframework.security.oauth2.provider.DefaultAuthorizationRequest;
+import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RequestMapping;
@@ -70,13 +71,11 @@
"There is no client authentication. Try adding an appropriate authentication filter.");
}
- Authentication client = (Authentication) principal;
- if (!client.isAuthenticated()) {
- throw new InsufficientAuthenticationException("The client is not authenticated.");
- }
HashMap<String, String> request = new HashMap<String, String>(parameters);
- String clientId = client.getName();
- request.put("client_id", clientId);
+ String clientId = getClientId(principal);
+ if (clientId != null) {
+ request.put("client_id", clientId);
+ }
if (!StringUtils.hasText(grantType)) {
throw new InvalidRequestException("Missing grant type");
@@ -107,6 +106,23 @@
}
+ /**
+ * @param principal the currently authentication principal
+ * @return a client id if there is one in the principal
+ */
+ protected String getClientId(Principal principal) {
+ Authentication client = (Authentication) principal;
+ if (!client.isAuthenticated()) {
+ throw new InsufficientAuthenticationException("The client is not authenticated.");
+ }
+ String clientId = client.getName();
+ if (client instanceof OAuth2Authentication) {
+ // Might be a client and user combined authentication
+ clientId = ((OAuth2Authentication) client).getAuthorizationRequest().getClientId();
+ }
+ return clientId;
+ }
+
@ExceptionHandler(ClientRegistrationException.class)
public ResponseEntity<OAuth2Exception> handleClientRegistrationException(Exception e) throws Exception {
logger.info("Handling error: " + e.getClass().getSimpleName() + ", " + e.getMessage());
217 ...c/main/java/org/springframework/security/oauth2/provider/endpoint/TokenEndpointAuthenticationFilter.java
View
@@ -0,0 +1,217 @@
+/*
+ * Copyright 2012-2013 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.provider.endpoint;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+
+import javax.servlet.Filter;
+import javax.servlet.FilterChain;
+import javax.servlet.FilterConfig;
+import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.springframework.security.authentication.AuthenticationDetailsSource;
+import org.springframework.security.authentication.AuthenticationManager;
+import org.springframework.security.authentication.BadCredentialsException;
+import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.AuthenticationException;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.oauth2.common.util.OAuth2Utils;
+import org.springframework.security.oauth2.provider.DefaultAuthorizationRequest;
+import org.springframework.security.oauth2.provider.OAuth2Authentication;
+import org.springframework.security.oauth2.provider.error.OAuth2AuthenticationEntryPoint;
+import org.springframework.security.web.AuthenticationEntryPoint;
+import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
+import org.springframework.security.web.authentication.www.BasicAuthenticationFilter;
+
+/**
+ * <p>
+ * An optional authentication filter for the {@link TokenEndpoint}. It sits downstream of another filter (usually
+ * {@link BasicAuthenticationFilter}) for the client, and creates an {@link OAuth2Authentication} for the Spring
+ * {@link SecurityContext} if the request also contains user credentials, e.g. as typically would be the case in a
+ * password grant. This filter is only required if the TokenEndpoint (or one of it's dependencies) needs to know about
+ * the authenticated user. In a vanilla password grant this <b>isn't</b> normally necessary because the token granter
+ * will also authenticate the user.
+ * </p>
+ *
+ * <p>
+ * If this filter is used the Spring Security context will contain an OAuth2Authentication encapsulating (as the
+ * authorization request) the form parameters coming into the filter and the client id from the already authenticated
+ * client authentication, and the authenticated user token extracted from the request and validated using the
+ * authentication manager.
+ * </p>
+ *
+ * @author Dave Syer
+ *
+ */
+public class TokenEndpointAuthenticationFilter implements Filter {
+
+ private static final Log logger = LogFactory.getLog(TokenEndpointAuthenticationFilter.class);
+
+ private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();
+
+ private AuthenticationEntryPoint authenticationEntryPoint = new OAuth2AuthenticationEntryPoint();
+
+ private final AuthenticationManager authenticationManager;
+
+ /**
+ * @param authenticationManager an AuthenticationManager for the incoming request
+ */
+ public TokenEndpointAuthenticationFilter(AuthenticationManager authenticationManager) {
+ super();
+ this.authenticationManager = authenticationManager;
+ }
+
+ /**
+ * An authentication entry point that can handle unsuccessful authentication. Defaults to an
+ * {@link OAuth2AuthenticationEntryPoint}.
+ *
+ * @param authenticationEntryPoint the authenticationEntryPoint to set
+ */
+ public void setAuthenticationEntryPoint(AuthenticationEntryPoint authenticationEntryPoint) {
+ this.authenticationEntryPoint = authenticationEntryPoint;
+ }
+
+ /**
+ * A source of authentication details for requests that result in authentication.
+ *
+ * @param authenticationDetailsSource the authenticationDetailsSource to set
+ */
+ public void setAuthenticationDetailsSource(
+ AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource) {
+ this.authenticationDetailsSource = authenticationDetailsSource;
+ }
+
+ public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException,
+ ServletException {
+
+ final boolean debug = logger.isDebugEnabled();
+ final HttpServletRequest request = (HttpServletRequest) req;
+ final HttpServletResponse response = (HttpServletResponse) res;
+
+ try {
+ Authentication credentials = extractCredentials(request);
+
+ if (credentials != null) {
+
+ if (debug) {
+ logger.debug("Authentication credentials found for '" + credentials.getName() + "'");
+ }
+
+ Authentication authResult = authenticationManager.authenticate(credentials);
+
+ if (debug) {
+ logger.debug("Authentication success: " + authResult.getName());
+ }
+
+ Authentication clientAuth = SecurityContextHolder.getContext().getAuthentication();
+ if (clientAuth == null) {
+ throw new BadCredentialsException(
+ "No client authentication found. Remember to put a filter upstream of the TokenEndpointAuthenticationFilter.");
+ }
+ DefaultAuthorizationRequest authorizationRequest = new DefaultAuthorizationRequest(
+ clientAuth.getName(), getScope(request));
+ authorizationRequest.setAuthorizationParameters(getSingleValueMap(request));
+ if (clientAuth.isAuthenticated()) {
+ // Ensure the OAuth2Authentication is authenticated
+ authorizationRequest.setApproved(true);
+ }
+
+ SecurityContextHolder.getContext().setAuthentication(
+ new OAuth2Authentication(authorizationRequest, authResult));
+
+ onSuccessfulAuthentication(request, response, authResult);
+
+ }
+
+ }
+ catch (AuthenticationException failed) {
+ SecurityContextHolder.clearContext();
+
+ if (debug) {
+ logger.debug("Authentication request for failed: " + failed);
+ }
+
+ onUnsuccessfulAuthentication(request, response, failed);
+
+ authenticationEntryPoint.commence(request, response, failed);
+
+ return;
+ }
+
+ chain.doFilter(request, response);
+ }
+
+ private Map<String, String> getSingleValueMap(HttpServletRequest request) {
+ Map<String, String> map = new HashMap<String, String>();
+ @SuppressWarnings("unchecked")
+ Map<String, String[]> parameters = request.getParameterMap();
+ for (String key : parameters.keySet()) {
+ String[] values = parameters.get(key);
+ map.put(key, values != null && values.length > 0 ? values[0] : null);
+ }
+ return map;
+ }
+
+ private Collection<String> getScope(HttpServletRequest request) {
+ return OAuth2Utils.parseParameterList(request.getParameter("scope"));
+ }
+
+ protected void onSuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response,
+ Authentication authResult) throws IOException {
+ }
+
+ protected void onUnsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response,
+ AuthenticationException failed) throws IOException {
+ }
+
+ /**
+ * If the incoming request contains user credentials in headers or parameters then extract them here into an
+ * Authentication token that can be validated later. This implementation only recognises password grant requests and
+ * extracts the username and password.
+ *
+ * @param request the incoming request, possibly with user credentials
+ * @return an authentication for validation (or null if there is no further authentication)
+ */
+ protected Authentication extractCredentials(HttpServletRequest request) {
+ String grantType = request.getParameter("grant_type");
+ if (grantType != null && grantType.equals("password")) {
+ UsernamePasswordAuthenticationToken result = new UsernamePasswordAuthenticationToken(
+ request.getParameter("username"), request.getParameter("password"));
+ result.setDetails(authenticationDetailsSource.buildDetails(request));
+ return result;
+ }
+ return null;
+ }
+
+ public void init(FilterConfig filterConfig) throws ServletException {
+ }
+
+ public void destroy() {
+ }
+
+}
99 ...st/java/org/springframework/security/oauth2/provider/endpoint/TestTokenEndpointAuthenticationFilter.java
View
@@ -0,0 +1,99 @@
+/*
+ * Copyright 2012-2013 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.provider.endpoint;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mockito;
+import org.springframework.mock.web.MockFilterChain;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.authentication.AuthenticationManager;
+import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.oauth2.provider.OAuth2Authentication;
+
+/**
+ * @author Dave Syer
+ *
+ */
+public class TestTokenEndpointAuthenticationFilter {
+
+ private MockHttpServletRequest request = new MockHttpServletRequest();
+
+ private MockHttpServletResponse response = new MockHttpServletResponse();
+
+ private MockFilterChain chain = new MockFilterChain();
+
+ private AuthenticationManager authenticationManager = Mockito.mock(AuthenticationManager.class);
+
+ @Before
+ public void init() {
+ SecurityContextHolder.clearContext();
+ SecurityContextHolder.getContext().setAuthentication(
+ new UsernamePasswordAuthenticationToken("client", "secret", AuthorityUtils
+ .commaSeparatedStringToAuthorityList("ROLE_CLIENT")));
+ }
+
+ @After
+ public void close() {
+ SecurityContextHolder.clearContext();
+ }
+
+ @Test
+ public void testPasswordGrant() throws Exception {
+ request.setParameter("grant_type", "password");
+ Mockito.when(authenticationManager.authenticate(Mockito.<Authentication> any())).thenReturn(
+ new UsernamePasswordAuthenticationToken("foo", "bar", AuthorityUtils
+ .commaSeparatedStringToAuthorityList("ROLE_USER")));
+ TokenEndpointAuthenticationFilter filter = new TokenEndpointAuthenticationFilter(authenticationManager);
+ filter.doFilter(request, response, chain);
+ Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
+ assertTrue(authentication instanceof OAuth2Authentication);
+ assertTrue(authentication.isAuthenticated());
+ }
+
+ @Test
+ public void testPasswordGrantWithUnAuthenticatedClient() throws Exception {
+ SecurityContextHolder.getContext().setAuthentication(
+ new UsernamePasswordAuthenticationToken("client", "secret"));
+ request.setParameter("grant_type", "password");
+ Mockito.when(authenticationManager.authenticate(Mockito.<Authentication> any())).thenReturn(
+ new UsernamePasswordAuthenticationToken("foo", "bar", AuthorityUtils
+ .commaSeparatedStringToAuthorityList("ROLE_USER")));
+ TokenEndpointAuthenticationFilter filter = new TokenEndpointAuthenticationFilter(authenticationManager);
+ filter.doFilter(request, response, chain);
+ Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
+ assertTrue(authentication instanceof OAuth2Authentication);
+ assertFalse(authentication.isAuthenticated());
+ }
+
+ @Test
+ public void testNoGrantType() throws Exception {
+ TokenEndpointAuthenticationFilter filter = new TokenEndpointAuthenticationFilter(authenticationManager);
+ filter.doFilter(request, response, chain);
+ // Just the client
+ assertTrue(SecurityContextHolder.getContext().getAuthentication() instanceof UsernamePasswordAuthenticationToken);
+ }
+
+}
Please sign in to comment.
Something went wrong with that request. Please try again.