Skip to content

Commit

Permalink
[GEOS-8824] Push up common behavior of access token converter
Browse files Browse the repository at this point in the history
  • Loading branch information
aaime committed Jul 5, 2018
1 parent 70a21da commit 6decbe9
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 200 deletions.
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
*/ */
package org.geoserver.security.oauth2.services; package org.geoserver.security.oauth2.services;


import org.geoserver.security.oauth2.GeoServerAccessTokenConverter;

/** /**
* Access Token Converter for GeoNode token details. * Access Token Converter for GeoNode token details.
* *
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -4,23 +4,8 @@
*/ */
package org.geoserver.security.oauth2.services; package org.geoserver.security.oauth2.services;


import java.io.IOException;
import java.util.Map; import java.util.Map;
import org.geoserver.security.oauth2.GeoServerOAuthRemoteTokenServices; import org.geoserver.security.oauth2.GeoServerOAuthRemoteTokenServices;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.DefaultResponseErrorHandler;
import org.springframework.web.client.RestTemplate;


/** /**
* Remote Token Services for GeoNode token details. * Remote Token Services for GeoNode token details.
Expand All @@ -30,72 +15,12 @@
public class GeoNodeTokenServices extends GeoServerOAuthRemoteTokenServices { public class GeoNodeTokenServices extends GeoServerOAuthRemoteTokenServices {


public GeoNodeTokenServices() { public GeoNodeTokenServices() {
tokenConverter = new GeoNodeAccessTokenConverter(); super(new GeoNodeAccessTokenConverter());
restTemplate = new RestTemplate();
((RestTemplate) restTemplate)
.setErrorHandler(
new DefaultResponseErrorHandler() {
@Override
// Ignore 400
public void handleError(ClientHttpResponse response)
throws IOException {
if (response.getRawStatusCode() != 400) {
super.handleError(response);
}
}
});
} }


@Override protected void transformNonStandardValuesToStandardValues(Map<String, Object> map) {
public OAuth2Authentication loadAuthentication(String accessToken)
throws AuthenticationException, InvalidTokenException {
Map<String, Object> checkTokenResponse = checkToken(accessToken);

if (checkTokenResponse.containsKey("error")) {
logger.debug("check_token returned error: " + checkTokenResponse.get("error"));
throw new InvalidTokenException(accessToken);
}

transformNonStandardValuesToStandardValues(checkTokenResponse);

Assert.state(
checkTokenResponse.containsKey("client_id"),
"Client id must be present in response from auth server");
return tokenConverter.extractAuthentication(checkTokenResponse);
}

private Map<String, Object> checkToken(String accessToken) {
MultiValueMap<String, String> formData = new LinkedMultiValueMap<>();
formData.add("token", accessToken);
HttpHeaders headers = new HttpHeaders();
headers.set("Authorization", getAuthorizationHeader(accessToken));
String accessTokenUrl =
new StringBuilder(checkTokenEndpointUrl)
.append("?access_token=")
.append(accessToken)
.toString();
return postForMap(accessTokenUrl, formData, headers);
}

private void transformNonStandardValuesToStandardValues(Map<String, Object> map) {
LOGGER.debug("Original map = " + map); LOGGER.debug("Original map = " + map);
map.put("user_name", map.get("issued_to")); // GeoNode sends 'client_id' as 'issued_to' map.put("user_name", map.get("issued_to")); // GeoNode sends 'client_id' as 'issued_to'
LOGGER.debug("Transformed = " + map); LOGGER.debug("Transformed = " + map);
} }

private String getAuthorizationHeader(String accessToken) {
return "Bearer " + accessToken;
}

private Map<String, Object> postForMap(
String path, MultiValueMap<String, String> formData, HttpHeaders headers) {
if (headers.getContentType() == null) {
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
}
ParameterizedTypeReference<Map<String, Object>> map =
new ParameterizedTypeReference<Map<String, Object>>() {};
return restTemplate
.exchange(path, HttpMethod.POST, new HttpEntity<>(formData, headers), map)
.getBody();
}
} }
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
*/ */
package org.geoserver.security.oauth2.services; package org.geoserver.security.oauth2.services;


import org.geoserver.security.oauth2.GeoServerAccessTokenConverter;

/** /**
* Access Token Converter for GitHub token details. * Access Token Converter for GitHub token details.
* *
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -4,23 +4,16 @@
*/ */
package org.geoserver.security.oauth2.services; package org.geoserver.security.oauth2.services;


import java.io.IOException;
import java.util.Map;
import org.geoserver.security.oauth2.GeoServerOAuthRemoteTokenServices; import org.geoserver.security.oauth2.GeoServerOAuthRemoteTokenServices;
import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity; import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException; import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import org.springframework.web.client.DefaultResponseErrorHandler;
import org.springframework.web.client.RestTemplate; import java.util.Map;


/** /**
* Remote Token Services for GitHub token details. * Remote Token Services for GitHub token details.
Expand All @@ -30,67 +23,17 @@
public class GitHubTokenServices extends GeoServerOAuthRemoteTokenServices { public class GitHubTokenServices extends GeoServerOAuthRemoteTokenServices {


public GitHubTokenServices() { public GitHubTokenServices() {
tokenConverter = new GitHubAccessTokenConverter(); super(new GitHubAccessTokenConverter());
restTemplate = new RestTemplate();
((RestTemplate) restTemplate)
.setErrorHandler(
new DefaultResponseErrorHandler() {
@Override
// Ignore 400
public void handleError(ClientHttpResponse response)
throws IOException {
if (response.getRawStatusCode() != 400) {
super.handleError(response);
}
}
});
}

@Override
public OAuth2Authentication loadAuthentication(String accessToken)
throws AuthenticationException, InvalidTokenException {
Map<String, Object> checkTokenResponse = checkToken(accessToken);

if (checkTokenResponse.containsKey("message")
&& checkTokenResponse.get("message").toString().startsWith("Problems")) {
logger.debug("check_token returned error: " + checkTokenResponse.get("message"));
throw new InvalidTokenException(accessToken);
}

transformNonStandardValuesToStandardValues(checkTokenResponse);

Assert.state(
checkTokenResponse.containsKey("client_id"),
"Client id must be present in response from auth server");
return tokenConverter.extractAuthentication(checkTokenResponse);
}

private Map<String, Object> checkToken(String accessToken) {
MultiValueMap<String, String> formData = new LinkedMultiValueMap<>();
formData.add("token", accessToken);
HttpHeaders headers = new HttpHeaders();
// headers.set("Authorization", getAuthorizationHeader(clientId, clientSecret));
headers.set("Authorization", getAuthorizationHeader(accessToken));
String accessTokenUrl =
new StringBuilder(checkTokenEndpointUrl)
.append("?access_token=")
.append(accessToken)
.toString();
return postForMap(accessTokenUrl, formData, headers);
} }


private void transformNonStandardValuesToStandardValues(Map<String, Object> map) { protected void transformNonStandardValuesToStandardValues(Map<String, Object> map) {
LOGGER.debug("Original map = " + map); LOGGER.debug("Original map = " + map);
map.put("client_id", clientId); // GitHub does not send 'client_id' map.put("client_id", clientId); // GitHub does not send 'client_id'
map.put("user_name", map.get("login")); // GitHub sends 'user_name' as 'login' map.put("user_name", map.get("login")); // GitHub sends 'user_name' as 'login'
LOGGER.debug("Transformed = " + map); LOGGER.debug("Transformed = " + map);
} }


private String getAuthorizationHeader(String accessToken) { protected Map<String, Object> postForMap(
return "Bearer " + accessToken;
}

private Map<String, Object> postForMap(
String path, MultiValueMap<String, String> formData, HttpHeaders headers) { String path, MultiValueMap<String, String> formData, HttpHeaders headers) {
if (headers.getContentType() == null) { if (headers.getContentType() == null) {
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED); headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
Expand All @@ -101,4 +44,13 @@ private Map<String, Object> postForMap(
.exchange(path, HttpMethod.GET, new HttpEntity<>(formData, headers), map) .exchange(path, HttpMethod.GET, new HttpEntity<>(formData, headers), map)
.getBody(); .getBody();
} }

@Override
protected void verifyTokenResponse(String accessToken, Map<String, Object> checkTokenResponse) {
if (checkTokenResponse.containsKey("message")
&& checkTokenResponse.get("message").toString().startsWith("Problems")) {
logger.debug("check_token returned error: " + checkTokenResponse.get("message"));
throw new InvalidTokenException(accessToken);
}
}
} }
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -4,25 +4,13 @@
*/ */
package org.geoserver.security.oauth2.services; package org.geoserver.security.oauth2.services;


import java.io.IOException;
import java.io.UnsupportedEncodingException; import java.io.UnsupportedEncodingException;
import java.util.Map; import java.util.Map;
import org.geoserver.security.oauth2.GeoServerOAuthRemoteTokenServices; import org.geoserver.security.oauth2.GeoServerOAuthRemoteTokenServices;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.crypto.codec.Base64; import org.springframework.security.crypto.codec.Base64;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import org.springframework.web.client.DefaultResponseErrorHandler;
import org.springframework.web.client.RestTemplate;


/** /**
* Remote Token Services for Google token details. * Remote Token Services for Google token details.
Expand All @@ -32,45 +20,14 @@
public class GoogleTokenServices extends GeoServerOAuthRemoteTokenServices { public class GoogleTokenServices extends GeoServerOAuthRemoteTokenServices {


public GoogleTokenServices() { public GoogleTokenServices() {
tokenConverter = new GoogleAccessTokenConverter(); super(new GoogleAccessTokenConverter());
restTemplate = new RestTemplate();
((RestTemplate) restTemplate)
.setErrorHandler(
new DefaultResponseErrorHandler() {
@Override
// Ignore 400
public void handleError(ClientHttpResponse response)
throws IOException {
if (response.getRawStatusCode() != 400) {
super.handleError(response);
}
}
});
} }


@Override protected Map<String, Object> checkToken(String accessToken) {
public OAuth2Authentication loadAuthentication(String accessToken)
throws AuthenticationException, InvalidTokenException {
Map<String, Object> checkTokenResponse = checkToken(accessToken);

if (checkTokenResponse.containsKey("error")) {
logger.debug("check_token returned error: " + checkTokenResponse.get("error"));
throw new InvalidTokenException(accessToken);
}

transformNonStandardValuesToStandardValues(checkTokenResponse);

Assert.state(
checkTokenResponse.containsKey("client_id"),
"Client id must be present in response from auth server");
return tokenConverter.extractAuthentication(checkTokenResponse);
}

private Map<String, Object> checkToken(String accessToken) {
MultiValueMap<String, String> formData = new LinkedMultiValueMap<>(); MultiValueMap<String, String> formData = new LinkedMultiValueMap<>();
formData.add("token", accessToken); formData.add("token", accessToken);
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = new HttpHeaders();
headers.set("Authorization", getAuthorizationHeader(clientId, clientSecret)); headers.set("Authorization", getAuthorizationHeader(accessToken));
String accessTokenUrl = String accessTokenUrl =
new StringBuilder(checkTokenEndpointUrl) new StringBuilder(checkTokenEndpointUrl)
.append("?access_token=") .append("?access_token=")
Expand All @@ -79,31 +36,19 @@ private Map<String, Object> checkToken(String accessToken) {
return postForMap(accessTokenUrl, formData, headers); return postForMap(accessTokenUrl, formData, headers);
} }


private void transformNonStandardValuesToStandardValues(Map<String, Object> map) { protected void transformNonStandardValuesToStandardValues(Map<String, Object> map) {
LOGGER.debug("Original map = " + map); LOGGER.debug("Original map = " + map);
map.put("client_id", map.get("issued_to")); // Google sends 'client_id' as 'issued_to' map.put("client_id", map.get("issued_to")); // Google sends 'client_id' as 'issued_to'
map.put("user_name", map.get("user_id")); // Google sends 'user_name' as 'user_id' map.put("user_name", map.get("user_id")); // Google sends 'user_name' as 'user_id'
LOGGER.debug("Transformed = " + map); LOGGER.debug("Transformed = " + map);
} }


private String getAuthorizationHeader(String clientId, String clientSecret) { protected String getAuthorizationHeader(String accessToken) {
String creds = String.format("%s:%s", clientId, clientSecret); String creds = String.format("%s:%s", clientId, clientSecret);
try { try {
return "Basic " + new String(Base64.encode(creds.getBytes("UTF-8"))); return "Basic " + new String(Base64.encode(creds.getBytes("UTF-8")));
} catch (UnsupportedEncodingException e) { } catch (UnsupportedEncodingException e) {
throw new IllegalStateException("Could not convert String"); throw new IllegalStateException("Could not convert String");
} }
} }

private Map<String, Object> postForMap(
String path, MultiValueMap<String, String> formData, HttpHeaders headers) {
if (headers.getContentType() == null) {
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
}
ParameterizedTypeReference<Map<String, Object>> map =
new ParameterizedTypeReference<Map<String, Object>>() {};
return restTemplate
.exchange(path, HttpMethod.POST, new HttpEntity<>(formData, headers), map)
.getBody();
}
} }
Original file line number Original file line Diff line number Diff line change
@@ -1,8 +1,15 @@
/*
* (c) 2018 Open Source Geospatial Foundation - all rights reserved
* This code is licensed under the GPL 2.0 license, available at the root
* application directory.
*
*/

/* (c) 2016 Open Source Geospatial Foundation - all rights reserved /* (c) 2016 Open Source Geospatial Foundation - all rights reserved
* This code is licensed under the GPL 2.0 license, available at the root * This code is licensed under the GPL 2.0 license, available at the root
* application directory. * application directory.
*/ */
package org.geoserver.security.oauth2.services; package org.geoserver.security.oauth2;


import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
Expand Down

0 comments on commit 6decbe9

Please sign in to comment.