Skip to content

Commit

Permalink
Add logic to allow login page to render when an idp timeout occurs
Browse files Browse the repository at this point in the history
- Refactor services to use a trusting/nontrusting resttemplate that has timeout configured

[#157525402] https://www.pivotaltracker.com/story/show/157525402

Signed-off-by: Bruce Ricard <bruce.ricard@gmail.com>
  • Loading branch information
DennisDenuto authored and bruce-ricard committed May 17, 2018
1 parent 8197f29 commit 9a82248
Show file tree
Hide file tree
Showing 24 changed files with 899 additions and 785 deletions.
Expand Up @@ -229,6 +229,7 @@ public static class Builder {
private int singleSignOnServiceIndex; private int singleSignOnServiceIndex;
private boolean metadataTrustCheck; private boolean metadataTrustCheck;
private boolean enableIdpInitiatedSso = false; private boolean enableIdpInitiatedSso = false;
private boolean skipSslValidation = true;


private Builder(){} private Builder(){}


Expand All @@ -243,6 +244,7 @@ public SamlServiceProviderDefinition build() {
def.setSingleSignOnServiceIndex(singleSignOnServiceIndex); def.setSingleSignOnServiceIndex(singleSignOnServiceIndex);
def.setMetadataTrustCheck(metadataTrustCheck); def.setMetadataTrustCheck(metadataTrustCheck);
def.setEnableIdpInitiatedSso(enableIdpInitiatedSso); def.setEnableIdpInitiatedSso(enableIdpInitiatedSso);
def.setSkipSslValidation(skipSslValidation);
return def; return def;
} }


Expand All @@ -256,6 +258,11 @@ public Builder setNameID(String nameID) {
return this; return this;
} }


public Builder setSkipSSLValidation(boolean skipSslValidation) {
this.skipSslValidation = skipSslValidation;
return this;
}

public Builder setSingleSignOnServiceIndex(int singleSignOnServiceIndex) { public Builder setSingleSignOnServiceIndex(int singleSignOnServiceIndex) {
this.singleSignOnServiceIndex = singleSignOnServiceIndex; this.singleSignOnServiceIndex = singleSignOnServiceIndex;
return this; return this;
Expand Down
Expand Up @@ -65,7 +65,7 @@ public byte[] getUrlContent(String uri, final RestTemplate template) {
return metadata; return metadata;
} catch (RestClientException x) { } catch (RestClientException x) {
logger.warn("Unable to fetch metadata for "+uri, x); logger.warn("Unable to fetch metadata for "+uri, x);
return null; throw x;
} catch (URISyntaxException e) { } catch (URISyntaxException e) {
throw new IllegalArgumentException(e); throw new IllegalArgumentException(e);
} }
Expand Down
@@ -1,20 +1,23 @@
package org.cloudfoundry.identity.uaa.impl.config; package org.cloudfoundry.identity.uaa.impl.config;


import org.cloudfoundry.identity.uaa.util.UaaHttpRequestUtils; import org.cloudfoundry.identity.uaa.util.UaaHttpRequestUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;


@Configuration @Configuration
public class RestTemplateConfig { public class RestTemplateConfig {
@Value("${rest.template.timeout:10000}")
public int timeout;


@Bean @Bean
public RestTemplate nonTrustingRestTemplate() { public RestTemplate nonTrustingRestTemplate() {
return new RestTemplate(UaaHttpRequestUtils.createRequestFactory(false, 30_000)); return new RestTemplate(UaaHttpRequestUtils.createRequestFactory(false, timeout));
} }


@Bean @Bean
public RestTemplate trustingRestTemplate() { public RestTemplate trustingRestTemplate() {
return new RestTemplate(UaaHttpRequestUtils.createRequestFactory(true, 30_000)); return new RestTemplate(UaaHttpRequestUtils.createRequestFactory(true, timeout));
} }
} }
Expand Up @@ -13,6 +13,10 @@


package org.cloudfoundry.identity.uaa.provider.oauth; package org.cloudfoundry.identity.uaa.provider.oauth;


import com.fasterxml.jackson.core.type.TypeReference;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.cloudfoundry.identity.uaa.authentication.UaaAuthentication; import org.cloudfoundry.identity.uaa.authentication.UaaAuthentication;
import org.cloudfoundry.identity.uaa.authentication.manager.ExternalGroupAuthorizationEvent; import org.cloudfoundry.identity.uaa.authentication.manager.ExternalGroupAuthorizationEvent;
import org.cloudfoundry.identity.uaa.authentication.manager.ExternalLoginAuthenticationManager; import org.cloudfoundry.identity.uaa.authentication.manager.ExternalLoginAuthenticationManager;
Expand All @@ -38,15 +42,9 @@
import org.cloudfoundry.identity.uaa.user.UaaUserPrototype; import org.cloudfoundry.identity.uaa.user.UaaUserPrototype;
import org.cloudfoundry.identity.uaa.util.JsonUtils; import org.cloudfoundry.identity.uaa.util.JsonUtils;
import org.cloudfoundry.identity.uaa.util.LinkedMaskingMultiValueMap; import org.cloudfoundry.identity.uaa.util.LinkedMaskingMultiValueMap;
import org.cloudfoundry.identity.uaa.util.RestTemplateFactory;
import org.cloudfoundry.identity.uaa.util.TokenValidation; import org.cloudfoundry.identity.uaa.util.TokenValidation;
import org.cloudfoundry.identity.uaa.util.UaaStringUtils; import org.cloudfoundry.identity.uaa.util.UaaStringUtils;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder; import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;

import com.fasterxml.jackson.core.type.TypeReference;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ParameterizedTypeReference;
import org.springframework.dao.IncorrectResultSizeDataAccessException; import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.http.HttpEntity; import org.springframework.http.HttpEntity;
Expand Down Expand Up @@ -90,6 +88,8 @@
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;


import static java.util.Collections.emptyList;
import static java.util.Optional.ofNullable;
import static org.cloudfoundry.identity.uaa.oauth.jwk.JsonWebKey.KeyType.MAC; import static org.cloudfoundry.identity.uaa.oauth.jwk.JsonWebKey.KeyType.MAC;
import static org.cloudfoundry.identity.uaa.oauth.jwk.JsonWebKey.KeyType.RSA; import static org.cloudfoundry.identity.uaa.oauth.jwk.JsonWebKey.KeyType.RSA;
import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.SUB; import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.SUB;
Expand All @@ -104,23 +104,26 @@
import static org.cloudfoundry.identity.uaa.util.UaaHttpRequestUtils.isAcceptedInvitationAuthentication; import static org.cloudfoundry.identity.uaa.util.UaaHttpRequestUtils.isAcceptedInvitationAuthentication;
import static org.springframework.util.StringUtils.hasText; import static org.springframework.util.StringUtils.hasText;
import static org.springframework.util.StringUtils.isEmpty; import static org.springframework.util.StringUtils.isEmpty;
import static java.util.Collections.emptyList;
import static java.util.Optional.ofNullable;


public class XOAuthAuthenticationManager extends ExternalLoginAuthenticationManager<XOAuthAuthenticationManager.AuthenticationData> { public class XOAuthAuthenticationManager extends ExternalLoginAuthenticationManager<XOAuthAuthenticationManager.AuthenticationData> {


public static Log logger = LogFactory.getLog(XOAuthAuthenticationManager.class); public static Log logger = LogFactory.getLog(XOAuthAuthenticationManager.class);


private final RestTemplateFactory restTemplateFactory; private final RestTemplate trustingRestTemplate;
private final RestTemplate nonTrustingRestTemplate;



private UaaTokenServices tokenServices; private UaaTokenServices tokenServices;


//origin is per thread during execution //origin is per thread during execution
private final ThreadLocal<String> origin = ThreadLocal.withInitial(() -> "unknown"); private final ThreadLocal<String> origin = ThreadLocal.withInitial(() -> "unknown");


public XOAuthAuthenticationManager(IdentityProviderProvisioning providerProvisioning, RestTemplateFactory restTemplateFactory) { public XOAuthAuthenticationManager(IdentityProviderProvisioning providerProvisioning,
RestTemplate trustingRestTemplate,
RestTemplate nonTrustingRestTemplate) {
super(providerProvisioning); super(providerProvisioning);
this.restTemplateFactory = restTemplateFactory; this.trustingRestTemplate = trustingRestTemplate;
this.nonTrustingRestTemplate = nonTrustingRestTemplate;
} }


@Override @Override
Expand Down Expand Up @@ -420,7 +423,11 @@ protected boolean isAddNewShadowUser() {
} }


public RestTemplate getRestTemplate(AbstractXOAuthIdentityProviderDefinition config) { public RestTemplate getRestTemplate(AbstractXOAuthIdentityProviderDefinition config) {
return restTemplateFactory.getRestTemplate(config.isSkipSslValidation()); if (config.isSkipSslValidation()) {
return trustingRestTemplate;
} else {
return nonTrustingRestTemplate;
}
} }


protected String getResponseType(AbstractXOAuthIdentityProviderDefinition config) { protected String getResponseType(AbstractXOAuthIdentityProviderDefinition config) {
Expand Down
Expand Up @@ -24,9 +24,9 @@
import org.cloudfoundry.identity.uaa.provider.IdentityProviderProvisioning; import org.cloudfoundry.identity.uaa.provider.IdentityProviderProvisioning;
import org.cloudfoundry.identity.uaa.provider.OIDCIdentityProviderDefinition; import org.cloudfoundry.identity.uaa.provider.OIDCIdentityProviderDefinition;
import org.cloudfoundry.identity.uaa.util.JsonUtils; import org.cloudfoundry.identity.uaa.util.JsonUtils;
import org.cloudfoundry.identity.uaa.util.RestTemplateFactory;
import org.springframework.dao.IncorrectResultSizeDataAccessException; import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.security.oauth2.common.util.RandomValueStringGenerator; import org.springframework.security.oauth2.common.util.RandomValueStringGenerator;
import org.springframework.web.client.RestTemplate;


import java.io.UnsupportedEncodingException; import java.io.UnsupportedEncodingException;
import java.net.MalformedURLException; import java.net.MalformedURLException;
Expand All @@ -36,6 +36,7 @@
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors; import java.util.stream.Collectors;


import static java.util.Collections.emptyList; import static java.util.Collections.emptyList;
Expand All @@ -49,23 +50,34 @@ public class XOAuthProviderConfigurator implements IdentityProviderProvisioning


private final IdentityProviderProvisioning providerProvisioning; private final IdentityProviderProvisioning providerProvisioning;
private final UrlContentCache contentCache; private final UrlContentCache contentCache;
private final RestTemplateFactory restTemplateFactory; private final RestTemplate trustingRestTemplate;
private final RestTemplate nonTrustingRestTemplate;


public XOAuthProviderConfigurator(IdentityProviderProvisioning providerProvisioning, public XOAuthProviderConfigurator(IdentityProviderProvisioning providerProvisioning,
UrlContentCache contentCache, UrlContentCache contentCache,
RestTemplateFactory restTemplateFactory) { RestTemplate trustingRestTemplate,
RestTemplate nonTrustingRestTemplate) {
this.providerProvisioning = providerProvisioning; this.providerProvisioning = providerProvisioning;
this.contentCache = contentCache; this.contentCache = contentCache;
this.restTemplateFactory = restTemplateFactory; this.trustingRestTemplate = trustingRestTemplate;
this.nonTrustingRestTemplate = nonTrustingRestTemplate;
} }


protected OIDCIdentityProviderDefinition overlay(OIDCIdentityProviderDefinition definition) { protected OIDCIdentityProviderDefinition overlay(OIDCIdentityProviderDefinition definition) {
if (definition.getDiscoveryUrl() == null) { if (definition.getDiscoveryUrl() == null) {
return definition; return definition;
} }


byte[] oidcJson = contentCache.getUrlContent(definition.getDiscoveryUrl().toString(), restTemplateFactory.getRestTemplate(definition.isSkipSslValidation())); boolean skipSslValidation = definition.isSkipSslValidation();
Map<String,Object> oidcConfig = JsonUtils.readValue(oidcJson, new TypeReference<Map<String, Object>>() {}); byte[] oidcJson;
if (skipSslValidation) {
oidcJson = contentCache.getUrlContent(definition.getDiscoveryUrl().toString(), trustingRestTemplate);
} else {
oidcJson = contentCache.getUrlContent(definition.getDiscoveryUrl().toString(), nonTrustingRestTemplate);
}

Map<String, Object> oidcConfig = JsonUtils.readValue(oidcJson, new TypeReference<Map<String, Object>>() {
});


OIDCIdentityProviderDefinition overlayedDefinition = null; OIDCIdentityProviderDefinition overlayedDefinition = null;
try { try {
Expand All @@ -80,9 +92,7 @@ protected OIDCIdentityProviderDefinition overlay(OIDCIdentityProviderDefinition
overlayedDefinition.setTokenUrl(ofNullable(overlayedDefinition.getTokenUrl()).orElse(tokenEndpoint)); overlayedDefinition.setTokenUrl(ofNullable(overlayedDefinition.getTokenUrl()).orElse(tokenEndpoint));
overlayedDefinition.setIssuer(ofNullable(overlayedDefinition.getIssuer()).orElse(issuer)); overlayedDefinition.setIssuer(ofNullable(overlayedDefinition.getIssuer()).orElse(issuer));
overlayedDefinition.setTokenKeyUrl(ofNullable(overlayedDefinition.getTokenKeyUrl()).orElse(tokenKeyUrl)); overlayedDefinition.setTokenKeyUrl(ofNullable(overlayedDefinition.getTokenKeyUrl()).orElse(tokenKeyUrl));
} catch (MalformedURLException e) { } catch (MalformedURLException | CloneNotSupportedException e) {
throw new IllegalStateException(e);
} catch (CloneNotSupportedException e) {
throw new IllegalStateException(e); throw new IllegalStateException(e);
} }


Expand All @@ -92,15 +102,15 @@ protected OIDCIdentityProviderDefinition overlay(OIDCIdentityProviderDefinition
public String getCompleteAuthorizationURI(String alias, String baseURL, AbstractXOAuthIdentityProviderDefinition definition) { public String getCompleteAuthorizationURI(String alias, String baseURL, AbstractXOAuthIdentityProviderDefinition definition) {
try { try {
String authUrlBase; String authUrlBase;
if(definition instanceof OIDCIdentityProviderDefinition) { if (definition instanceof OIDCIdentityProviderDefinition) {
authUrlBase = overlay((OIDCIdentityProviderDefinition) definition).getAuthUrl().toString(); authUrlBase = overlay((OIDCIdentityProviderDefinition) definition).getAuthUrl().toString();
} else { } else {
authUrlBase = definition.getAuthUrl().toString(); authUrlBase = definition.getAuthUrl().toString();
} }
String queryAppendDelimiter = authUrlBase.contains("?") ? "&" : "?"; String queryAppendDelimiter = authUrlBase.contains("?") ? "&" : "?";
List<String> query = new ArrayList<>(); List<String> query = new ArrayList<>();
query.add("client_id=" + definition.getRelyingPartyId()); query.add("client_id=" + definition.getRelyingPartyId());
query.add("response_type="+ URLEncoder.encode(definition.getResponseType(), "UTF-8")); query.add("response_type=" + URLEncoder.encode(definition.getResponseType(), "UTF-8"));
query.add("redirect_uri=" + URLEncoder.encode(baseURL + "/login/callback/" + alias, "UTF-8")); query.add("redirect_uri=" + URLEncoder.encode(baseURL + "/login/callback/" + alias, "UTF-8"));
if (definition.getScopes() != null && !definition.getScopes().isEmpty()) { if (definition.getScopes() != null && !definition.getScopes().isEmpty()) {
query.add("scope=" + URLEncoder.encode(String.join(" ", definition.getScopes()), "UTF-8")); query.add("scope=" + URLEncoder.encode(String.join(" ", definition.getScopes()), "UTF-8"));
Expand Down Expand Up @@ -142,10 +152,10 @@ public List<IdentityProvider> retrieveActive(String zoneId) {


public IdentityProvider retrieveByIssuer(String issuer, String zoneId) throws IncorrectResultSizeDataAccessException { public IdentityProvider retrieveByIssuer(String issuer, String zoneId) throws IncorrectResultSizeDataAccessException {
List<IdentityProvider> providers = retrieveAll(true, zoneId) List<IdentityProvider> providers = retrieveAll(true, zoneId)
.stream() .stream()
.filter(p -> OIDC10.equals(p.getType()) && .filter(p -> OIDC10.equals(p.getType()) &&
issuer.equals(((OIDCIdentityProviderDefinition)p.getConfig()).getIssuer())) issuer.equals(((OIDCIdentityProviderDefinition) p.getConfig()).getIssuer()))
.collect(Collectors.toList()); .collect(Collectors.toList());
if (providers.isEmpty()) { if (providers.isEmpty()) {
throw new IncorrectResultSizeDataAccessException(String.format("Active provider with issuer[%s] not found", issuer), 1); throw new IncorrectResultSizeDataAccessException(String.format("Active provider with issuer[%s] not found", issuer), 1);
} else if (providers.size() > 1) { } else if (providers.size() > 1) {
Expand All @@ -160,19 +170,19 @@ public List<IdentityProvider> retrieveAll(boolean activeOnly, String zoneId) {
List<IdentityProvider> providers = providerProvisioning.retrieveAll(activeOnly, zoneId); List<IdentityProvider> providers = providerProvisioning.retrieveAll(activeOnly, zoneId);
List<IdentityProvider> overlayedProviders = new ArrayList<>(); List<IdentityProvider> overlayedProviders = new ArrayList<>();
ofNullable(providers).orElse(emptyList()).stream() ofNullable(providers).orElse(emptyList()).stream()
.filter(p -> types.contains(p.getType())) .filter(p -> types.contains(p.getType()))
.forEach(p -> { .forEach(p -> {
if (p.getType().equals(OIDC10)) { if (p.getType().equals(OIDC10)) {
try { try {
OIDCIdentityProviderDefinition overlayedDefinition = overlay((OIDCIdentityProviderDefinition) p.getConfig()); OIDCIdentityProviderDefinition overlayedDefinition = overlay((OIDCIdentityProviderDefinition) p.getConfig());
p.setConfig(overlayedDefinition); p.setConfig(overlayedDefinition);
} catch (Exception e) { } catch (Exception e) {
log.error("Identity provider excluded from login page due to a problem.", e); log.error("Identity provider excluded from login page due to a problem.", e);
return; return;
} }
} }
overlayedProviders.add(p); overlayedProviders.add(p);
}); });
return overlayedProviders; return overlayedProviders;
} }


Expand Down
Expand Up @@ -65,6 +65,7 @@
import org.springframework.security.saml.trust.httpclient.TLSProtocolConfigurer; import org.springframework.security.saml.trust.httpclient.TLSProtocolConfigurer;
import org.springframework.security.saml.util.SAMLUtil; import org.springframework.security.saml.util.SAMLUtil;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.client.RestClientException;
import org.w3c.dom.Element; import org.w3c.dom.Element;


import javax.xml.namespace.QName; import javax.xml.namespace.QName;
Expand Down Expand Up @@ -154,7 +155,7 @@ public List<ExtendedMetadataDelegate> getAvailableProviders() {
initializeProviderData(delegate); initializeProviderData(delegate);
initializeProviderFilters(delegate); initializeProviderFilters(delegate);
result.add(delegate); result.add(delegate);
} catch (MetadataProviderException e) { } catch (RestClientException | MetadataProviderException e) {
log.error("Invalid SAML IDP zone[" + zone.getId() + "] alias[" + definition.getIdpEntityAlias() + "]", e); log.error("Invalid SAML IDP zone[" + zone.getId() + "] alias[" + definition.getIdpEntityAlias() + "]", e);
} }
} }
Expand Down
Expand Up @@ -28,6 +28,8 @@
import org.springframework.security.saml.metadata.ExtendedMetadata; import org.springframework.security.saml.metadata.ExtendedMetadata;
import org.springframework.security.saml.metadata.ExtendedMetadataDelegate; import org.springframework.security.saml.metadata.ExtendedMetadataDelegate;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.client.ResourceAccessException;
import org.springframework.web.client.RestClientException;


import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -163,6 +165,8 @@ protected ExtendedMetadataDelegate configureURLMetadata(SamlServiceProvider prov
byte[] metadata; byte[] metadata;
try { try {
metadata = fixedHttpMetaDataProvider.fetchMetadata(def.getMetaDataLocation(), def.isSkipSslValidation()); metadata = fixedHttpMetaDataProvider.fetchMetadata(def.getMetaDataLocation(), def.isSkipSslValidation());
} catch (RestClientException e) {
throw new MetadataProviderException("Unavailable Metadata Provider", e);
} catch (URISyntaxException e) { } catch (URISyntaxException e) {
throw new MetadataProviderException("Invalid metadata URI: " + def.getMetaDataLocation(), e); throw new MetadataProviderException("Invalid metadata URI: " + def.getMetaDataLocation(), e);
} }
Expand Down

This file was deleted.

Expand Up @@ -41,10 +41,6 @@ public abstract class UaaHttpRequestUtils {


private static Log logger = LogFactory.getLog(UaaHttpRequestUtils.class); private static Log logger = LogFactory.getLog(UaaHttpRequestUtils.class);


public static ClientHttpRequestFactory createRequestFactory() {
return createRequestFactory(false, -1);
}

public static ClientHttpRequestFactory createRequestFactory(boolean skipSslValidation, int timeout) { public static ClientHttpRequestFactory createRequestFactory(boolean skipSslValidation, int timeout) {
return createRequestFactory(getClientBuilder(skipSslValidation), timeout); return createRequestFactory(getClientBuilder(skipSslValidation), timeout);
} }
Expand Down

0 comments on commit 9a82248

Please sign in to comment.