Skip to content

Commit

Permalink
Add logic to allow login page to render when an idp timeout occurs
Browse files Browse the repository at this point in the history
- Refactor services to use a trusting/nontrusting resttemplate that has timeout configured

[#157525402] https://www.pivotaltracker.com/story/show/157525402

Signed-off-by: Bruce Ricard <bruce.ricard@gmail.com>
  • Loading branch information
DennisDenuto authored and bruce-ricard committed May 17, 2018
1 parent 8197f29 commit 9a82248
Show file tree
Hide file tree
Showing 24 changed files with 899 additions and 785 deletions.
Expand Up @@ -229,6 +229,7 @@ public static class Builder {
private int singleSignOnServiceIndex;
private boolean metadataTrustCheck;
private boolean enableIdpInitiatedSso = false;
private boolean skipSslValidation = true;

private Builder(){}

Expand All @@ -243,6 +244,7 @@ public SamlServiceProviderDefinition build() {
def.setSingleSignOnServiceIndex(singleSignOnServiceIndex);
def.setMetadataTrustCheck(metadataTrustCheck);
def.setEnableIdpInitiatedSso(enableIdpInitiatedSso);
def.setSkipSslValidation(skipSslValidation);
return def;
}

Expand All @@ -256,6 +258,11 @@ public Builder setNameID(String nameID) {
return this;
}

public Builder setSkipSSLValidation(boolean skipSslValidation) {
this.skipSslValidation = skipSslValidation;
return this;
}

public Builder setSingleSignOnServiceIndex(int singleSignOnServiceIndex) {
this.singleSignOnServiceIndex = singleSignOnServiceIndex;
return this;
Expand Down
Expand Up @@ -65,7 +65,7 @@ public byte[] getUrlContent(String uri, final RestTemplate template) {
return metadata;
} catch (RestClientException x) {
logger.warn("Unable to fetch metadata for "+uri, x);
return null;
throw x;
} catch (URISyntaxException e) {
throw new IllegalArgumentException(e);
}
Expand Down
@@ -1,20 +1,23 @@
package org.cloudfoundry.identity.uaa.impl.config;

import org.cloudfoundry.identity.uaa.util.UaaHttpRequestUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;

@Configuration
public class RestTemplateConfig {
@Value("${rest.template.timeout:10000}")
public int timeout;

@Bean
public RestTemplate nonTrustingRestTemplate() {
return new RestTemplate(UaaHttpRequestUtils.createRequestFactory(false, 30_000));
return new RestTemplate(UaaHttpRequestUtils.createRequestFactory(false, timeout));
}

@Bean
public RestTemplate trustingRestTemplate() {
return new RestTemplate(UaaHttpRequestUtils.createRequestFactory(true, 30_000));
return new RestTemplate(UaaHttpRequestUtils.createRequestFactory(true, timeout));
}
}
Expand Up @@ -13,6 +13,10 @@

package org.cloudfoundry.identity.uaa.provider.oauth;

import com.fasterxml.jackson.core.type.TypeReference;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.cloudfoundry.identity.uaa.authentication.UaaAuthentication;
import org.cloudfoundry.identity.uaa.authentication.manager.ExternalGroupAuthorizationEvent;
import org.cloudfoundry.identity.uaa.authentication.manager.ExternalLoginAuthenticationManager;
Expand All @@ -38,15 +42,9 @@
import org.cloudfoundry.identity.uaa.user.UaaUserPrototype;
import org.cloudfoundry.identity.uaa.util.JsonUtils;
import org.cloudfoundry.identity.uaa.util.LinkedMaskingMultiValueMap;
import org.cloudfoundry.identity.uaa.util.RestTemplateFactory;
import org.cloudfoundry.identity.uaa.util.TokenValidation;
import org.cloudfoundry.identity.uaa.util.UaaStringUtils;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;

import com.fasterxml.jackson.core.type.TypeReference;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.http.HttpEntity;
Expand Down Expand Up @@ -90,6 +88,8 @@
import java.util.Set;
import java.util.stream.Collectors;

import static java.util.Collections.emptyList;
import static java.util.Optional.ofNullable;
import static org.cloudfoundry.identity.uaa.oauth.jwk.JsonWebKey.KeyType.MAC;
import static org.cloudfoundry.identity.uaa.oauth.jwk.JsonWebKey.KeyType.RSA;
import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.SUB;
Expand All @@ -104,23 +104,26 @@
import static org.cloudfoundry.identity.uaa.util.UaaHttpRequestUtils.isAcceptedInvitationAuthentication;
import static org.springframework.util.StringUtils.hasText;
import static org.springframework.util.StringUtils.isEmpty;
import static java.util.Collections.emptyList;
import static java.util.Optional.ofNullable;

public class XOAuthAuthenticationManager extends ExternalLoginAuthenticationManager<XOAuthAuthenticationManager.AuthenticationData> {

public static Log logger = LogFactory.getLog(XOAuthAuthenticationManager.class);

private final RestTemplateFactory restTemplateFactory;
private final RestTemplate trustingRestTemplate;
private final RestTemplate nonTrustingRestTemplate;


private UaaTokenServices tokenServices;

//origin is per thread during execution
private final ThreadLocal<String> origin = ThreadLocal.withInitial(() -> "unknown");

public XOAuthAuthenticationManager(IdentityProviderProvisioning providerProvisioning, RestTemplateFactory restTemplateFactory) {
public XOAuthAuthenticationManager(IdentityProviderProvisioning providerProvisioning,
RestTemplate trustingRestTemplate,
RestTemplate nonTrustingRestTemplate) {
super(providerProvisioning);
this.restTemplateFactory = restTemplateFactory;
this.trustingRestTemplate = trustingRestTemplate;
this.nonTrustingRestTemplate = nonTrustingRestTemplate;
}

@Override
Expand Down Expand Up @@ -420,7 +423,11 @@ protected boolean isAddNewShadowUser() {
}

public RestTemplate getRestTemplate(AbstractXOAuthIdentityProviderDefinition config) {
return restTemplateFactory.getRestTemplate(config.isSkipSslValidation());
if (config.isSkipSslValidation()) {
return trustingRestTemplate;
} else {
return nonTrustingRestTemplate;
}
}

protected String getResponseType(AbstractXOAuthIdentityProviderDefinition config) {
Expand Down
Expand Up @@ -24,9 +24,9 @@
import org.cloudfoundry.identity.uaa.provider.IdentityProviderProvisioning;
import org.cloudfoundry.identity.uaa.provider.OIDCIdentityProviderDefinition;
import org.cloudfoundry.identity.uaa.util.JsonUtils;
import org.cloudfoundry.identity.uaa.util.RestTemplateFactory;
import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.security.oauth2.common.util.RandomValueStringGenerator;
import org.springframework.web.client.RestTemplate;

import java.io.UnsupportedEncodingException;
import java.net.MalformedURLException;
Expand All @@ -36,6 +36,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

import static java.util.Collections.emptyList;
Expand All @@ -49,23 +50,34 @@ public class XOAuthProviderConfigurator implements IdentityProviderProvisioning

private final IdentityProviderProvisioning providerProvisioning;
private final UrlContentCache contentCache;
private final RestTemplateFactory restTemplateFactory;
private final RestTemplate trustingRestTemplate;
private final RestTemplate nonTrustingRestTemplate;

public XOAuthProviderConfigurator(IdentityProviderProvisioning providerProvisioning,
UrlContentCache contentCache,
RestTemplateFactory restTemplateFactory) {
RestTemplate trustingRestTemplate,
RestTemplate nonTrustingRestTemplate) {
this.providerProvisioning = providerProvisioning;
this.contentCache = contentCache;
this.restTemplateFactory = restTemplateFactory;
this.trustingRestTemplate = trustingRestTemplate;
this.nonTrustingRestTemplate = nonTrustingRestTemplate;
}

protected OIDCIdentityProviderDefinition overlay(OIDCIdentityProviderDefinition definition) {
if (definition.getDiscoveryUrl() == null) {
return definition;
}

byte[] oidcJson = contentCache.getUrlContent(definition.getDiscoveryUrl().toString(), restTemplateFactory.getRestTemplate(definition.isSkipSslValidation()));
Map<String,Object> oidcConfig = JsonUtils.readValue(oidcJson, new TypeReference<Map<String, Object>>() {});
boolean skipSslValidation = definition.isSkipSslValidation();
byte[] oidcJson;
if (skipSslValidation) {
oidcJson = contentCache.getUrlContent(definition.getDiscoveryUrl().toString(), trustingRestTemplate);
} else {
oidcJson = contentCache.getUrlContent(definition.getDiscoveryUrl().toString(), nonTrustingRestTemplate);
}

Map<String, Object> oidcConfig = JsonUtils.readValue(oidcJson, new TypeReference<Map<String, Object>>() {
});

OIDCIdentityProviderDefinition overlayedDefinition = null;
try {
Expand All @@ -80,9 +92,7 @@ protected OIDCIdentityProviderDefinition overlay(OIDCIdentityProviderDefinition
overlayedDefinition.setTokenUrl(ofNullable(overlayedDefinition.getTokenUrl()).orElse(tokenEndpoint));
overlayedDefinition.setIssuer(ofNullable(overlayedDefinition.getIssuer()).orElse(issuer));
overlayedDefinition.setTokenKeyUrl(ofNullable(overlayedDefinition.getTokenKeyUrl()).orElse(tokenKeyUrl));
} catch (MalformedURLException e) {
throw new IllegalStateException(e);
} catch (CloneNotSupportedException e) {
} catch (MalformedURLException | CloneNotSupportedException e) {
throw new IllegalStateException(e);
}

Expand All @@ -92,15 +102,15 @@ protected OIDCIdentityProviderDefinition overlay(OIDCIdentityProviderDefinition
public String getCompleteAuthorizationURI(String alias, String baseURL, AbstractXOAuthIdentityProviderDefinition definition) {
try {
String authUrlBase;
if(definition instanceof OIDCIdentityProviderDefinition) {
if (definition instanceof OIDCIdentityProviderDefinition) {
authUrlBase = overlay((OIDCIdentityProviderDefinition) definition).getAuthUrl().toString();
} else {
authUrlBase = definition.getAuthUrl().toString();
}
String queryAppendDelimiter = authUrlBase.contains("?") ? "&" : "?";
List<String> query = new ArrayList<>();
query.add("client_id=" + definition.getRelyingPartyId());
query.add("response_type="+ URLEncoder.encode(definition.getResponseType(), "UTF-8"));
query.add("response_type=" + URLEncoder.encode(definition.getResponseType(), "UTF-8"));
query.add("redirect_uri=" + URLEncoder.encode(baseURL + "/login/callback/" + alias, "UTF-8"));
if (definition.getScopes() != null && !definition.getScopes().isEmpty()) {
query.add("scope=" + URLEncoder.encode(String.join(" ", definition.getScopes()), "UTF-8"));
Expand Down Expand Up @@ -142,10 +152,10 @@ public List<IdentityProvider> retrieveActive(String zoneId) {

public IdentityProvider retrieveByIssuer(String issuer, String zoneId) throws IncorrectResultSizeDataAccessException {
List<IdentityProvider> providers = retrieveAll(true, zoneId)
.stream()
.filter(p -> OIDC10.equals(p.getType()) &&
issuer.equals(((OIDCIdentityProviderDefinition)p.getConfig()).getIssuer()))
.collect(Collectors.toList());
.stream()
.filter(p -> OIDC10.equals(p.getType()) &&
issuer.equals(((OIDCIdentityProviderDefinition) p.getConfig()).getIssuer()))
.collect(Collectors.toList());
if (providers.isEmpty()) {
throw new IncorrectResultSizeDataAccessException(String.format("Active provider with issuer[%s] not found", issuer), 1);
} else if (providers.size() > 1) {
Expand All @@ -160,19 +170,19 @@ public List<IdentityProvider> retrieveAll(boolean activeOnly, String zoneId) {
List<IdentityProvider> providers = providerProvisioning.retrieveAll(activeOnly, zoneId);
List<IdentityProvider> overlayedProviders = new ArrayList<>();
ofNullable(providers).orElse(emptyList()).stream()
.filter(p -> types.contains(p.getType()))
.forEach(p -> {
if (p.getType().equals(OIDC10)) {
try {
OIDCIdentityProviderDefinition overlayedDefinition = overlay((OIDCIdentityProviderDefinition) p.getConfig());
p.setConfig(overlayedDefinition);
} catch (Exception e) {
log.error("Identity provider excluded from login page due to a problem.", e);
return;
}
}
overlayedProviders.add(p);
});
.filter(p -> types.contains(p.getType()))
.forEach(p -> {
if (p.getType().equals(OIDC10)) {
try {
OIDCIdentityProviderDefinition overlayedDefinition = overlay((OIDCIdentityProviderDefinition) p.getConfig());
p.setConfig(overlayedDefinition);
} catch (Exception e) {
log.error("Identity provider excluded from login page due to a problem.", e);
return;
}
}
overlayedProviders.add(p);
});
return overlayedProviders;
}

Expand Down
Expand Up @@ -65,6 +65,7 @@
import org.springframework.security.saml.trust.httpclient.TLSProtocolConfigurer;
import org.springframework.security.saml.util.SAMLUtil;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestClientException;
import org.w3c.dom.Element;

import javax.xml.namespace.QName;
Expand Down Expand Up @@ -154,7 +155,7 @@ public List<ExtendedMetadataDelegate> getAvailableProviders() {
initializeProviderData(delegate);
initializeProviderFilters(delegate);
result.add(delegate);
} catch (MetadataProviderException e) {
} catch (RestClientException | MetadataProviderException e) {
log.error("Invalid SAML IDP zone[" + zone.getId() + "] alias[" + definition.getIdpEntityAlias() + "]", e);
}
}
Expand Down
Expand Up @@ -28,6 +28,8 @@
import org.springframework.security.saml.metadata.ExtendedMetadata;
import org.springframework.security.saml.metadata.ExtendedMetadataDelegate;
import org.springframework.util.StringUtils;
import org.springframework.web.client.ResourceAccessException;
import org.springframework.web.client.RestClientException;

import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -163,6 +165,8 @@ protected ExtendedMetadataDelegate configureURLMetadata(SamlServiceProvider prov
byte[] metadata;
try {
metadata = fixedHttpMetaDataProvider.fetchMetadata(def.getMetaDataLocation(), def.isSkipSslValidation());
} catch (RestClientException e) {
throw new MetadataProviderException("Unavailable Metadata Provider", e);
} catch (URISyntaxException e) {
throw new MetadataProviderException("Invalid metadata URI: " + def.getMetaDataLocation(), e);
}
Expand Down

This file was deleted.

Expand Up @@ -41,10 +41,6 @@ public abstract class UaaHttpRequestUtils {

private static Log logger = LogFactory.getLog(UaaHttpRequestUtils.class);

public static ClientHttpRequestFactory createRequestFactory() {
return createRequestFactory(false, -1);
}

public static ClientHttpRequestFactory createRequestFactory(boolean skipSslValidation, int timeout) {
return createRequestFactory(getClientBuilder(skipSslValidation), timeout);
}
Expand Down

0 comments on commit 9a82248

Please sign in to comment.