Skip to content

Commit

Permalink
BREAKING Change - by default always validate SSL certificates
Browse files Browse the repository at this point in the history
Remove separate proxy usage
Remove socket factory class - this is no longer used
[#132273335] https://www.pivotaltracker.com/story/show/132273335
  • Loading branch information
fhanik committed Oct 14, 2016
1 parent 7d979bc commit 81dfdee
Show file tree
Hide file tree
Showing 11 changed files with 159 additions and 339 deletions.
Expand Up @@ -33,9 +33,6 @@
@JsonIgnoreProperties(ignoreUnknown = true)
public class SamlIdentityProviderDefinition extends ExternalIdentityProviderDefinition {

public static final String DEFAULT_HTTP_SOCKET_FACTORY = "org.apache.commons.httpclient.protocol.DefaultProtocolSocketFactory";
public static final String DEFAULT_HTTPS_SOCKET_FACTORY = "org.apache.commons.httpclient.contrib.ssl.EasySSLProtocolSocketFactory";

public enum MetadataLocation {
URL,
DATA,
Expand All @@ -54,7 +51,6 @@ public enum ExternalGroupMappingMode {
private int assertionConsumerIndex;
private boolean metadataTrustCheck;
private boolean showSamlLink;
private String socketFactoryClassName;
private String linkText;
private String iconUrl;
private ExternalGroupMappingMode groupMappingMode = ExternalGroupMappingMode.EXPLICITLY_MAPPED;
Expand Down Expand Up @@ -187,34 +183,11 @@ public void setGroupMappingMode(ExternalGroupMappingMode asScopes) {
}

public String getSocketFactoryClassName() {
if (socketFactoryClassName!=null && socketFactoryClassName.trim().length()>0) {
return socketFactoryClassName;
}
if (getMetaDataLocation()==null || getMetaDataLocation().trim().length()==0) {
throw new IllegalStateException("Invalid meta data URL[" + getMetaDataLocation() + "] cannot determine socket factory.");
}
if (getMetaDataLocation().startsWith("https")) {
return DEFAULT_HTTPS_SOCKET_FACTORY;
} else {
return DEFAULT_HTTP_SOCKET_FACTORY;
}
return null;
}

public SamlIdentityProviderDefinition setSocketFactoryClassName(String socketFactoryClassName) {
if (socketFactoryClassName!=null && socketFactoryClassName.trim().length()>0) {
try {
Class.forName(
socketFactoryClassName,
true,
Thread.currentThread().getContextClassLoader()
);
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException(e);
} catch (ClassCastException e) {
throw new IllegalArgumentException(e);
}
}
this.socketFactoryClassName = socketFactoryClassName;
//no op
return this;
}

Expand Down Expand Up @@ -283,7 +256,8 @@ public String toString() {
", assertionConsumerIndex=" + assertionConsumerIndex +
", metadataTrustCheck=" + metadataTrustCheck +
", showSamlLink=" + showSamlLink +
", socketFactoryClassName='" + socketFactoryClassName + '\'' +
", socketFactoryClassName='deprected-not used'" +
", skipSslValidation=" + skipSslValidation +
", linkText='" + linkText + '\'' +
", iconUrl='" + iconUrl + '\'' +
", zoneId='" + zoneId + '\'' +
Expand Down
Expand Up @@ -12,19 +12,17 @@
*******************************************************************************/
package org.cloudfoundry.identity.uaa.provider.saml.idp;

import java.io.IOException;
import java.io.StringReader;
import java.net.MalformedURLException;
import java.net.URL;
import com.fasterxml.jackson.annotation.JsonIgnore;
import org.xml.sax.InputSource;
import org.xml.sax.SAXException;

import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;

import org.xml.sax.InputSource;
import org.xml.sax.SAXException;

import com.fasterxml.jackson.annotation.JsonIgnore;
import java.io.IOException;
import java.io.StringReader;
import java.net.MalformedURLException;
import java.net.URL;

public class SamlServiceProviderDefinition {

Expand All @@ -38,24 +36,28 @@ public enum MetadataLocation {
private String nameID;
private int singleSignOnServiceIndex;
private boolean metadataTrustCheck;
private boolean skipSslValidation = false;

public SamlServiceProviderDefinition clone() {
return new SamlServiceProviderDefinition(metaDataLocation,
nameID,
singleSignOnServiceIndex,
metadataTrustCheck);
nameID,
singleSignOnServiceIndex,
metadataTrustCheck,
skipSslValidation);
}

public SamlServiceProviderDefinition() {}

public SamlServiceProviderDefinition(String metaDataLocation,
String nameID,
int singleSignOnServiceIndex,
boolean metadataTrustCheck) {
boolean metadataTrustCheck,
boolean skipSslValidation) {
this.metaDataLocation = metaDataLocation;
this.nameID = nameID;
this.singleSignOnServiceIndex = singleSignOnServiceIndex;
this.metadataTrustCheck = metadataTrustCheck;
this.skipSslValidation = skipSslValidation;
}

@JsonIgnore
Expand Down Expand Up @@ -144,6 +146,14 @@ public int hashCode() {
return result;
}

public boolean isSkipSslValidation() {
return skipSslValidation;
}

public void setSkipSslValidation(boolean skipSslValidation) {
this.skipSslValidation = skipSslValidation;
}

@Override
public boolean equals(Object obj) {
if (this == obj)
Expand Down
Expand Up @@ -120,6 +120,15 @@ public void setIdentityProviders(Map<String, Map<String, Object>> providers) {
String groupMappingMode = (String)((Map)entry.getValue()).get("groupMappingMode");
String providerDescription = (String)((Map)entry.getValue()).get(PROVIDER_DESCRIPTION);
Boolean addShadowUserOnLogin = (Boolean)((Map)entry.getValue()).get("addShadowUserOnLogin");
Boolean skipSslValidation = (Boolean)((Map)entry.getValue()).get("skipSslValidation");
if (skipSslValidation==null) {
if (socketFactoryClassName != null) {
skipSslValidation = false;
} else {
skipSslValidation = true;
}
}

List<String> emailDomain = (List<String>) saml.get(EMAIL_DOMAIN_ATTR);
List<String> externalGroupsWhitelist = (List<String>) saml.get(EXTERNAL_GROUPS_WHITELIST);
Map<String, Object> attributeMappings = (Map<String, Object>) saml.get(ATTRIBUTE_MAPPINGS);
Expand Down Expand Up @@ -148,6 +157,7 @@ public void setIdentityProviders(Map<String, Map<String, Object>> providers) {
def.setAttributeMappings(attributeMappings);
def.setZoneId(hasText(zoneId) ? zoneId : IdentityZone.getUaa().getId());
def.setAddShadowUserOnLogin(addShadowUserOnLogin==null?true:addShadowUserOnLogin);
def.setSkipSslValidation(skipSslValidation);
toBeFetchedProviders.add(def);
}
}
Expand Down
Expand Up @@ -16,17 +16,17 @@
import com.google.common.base.Ticker;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import org.apache.commons.httpclient.HostConfiguration;
import org.apache.commons.httpclient.HttpClient;
import org.apache.commons.httpclient.SimpleHttpConnectionManager;
import org.apache.commons.httpclient.params.HttpClientParams;
import org.apache.commons.httpclient.protocol.ProtocolSocketFactory;
import org.opensaml.saml2.metadata.provider.HTTPMetadataProvider;
import org.opensaml.saml2.metadata.provider.MetadataProviderException;
import org.springframework.web.client.RestTemplate;

import java.net.URISyntaxException;
import java.util.Timer;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

/**
* This class works around the problem described in <a href="http://issues.apache.org/jira/browse/HTTPCLIENT-646">http://issues.apache.org/jira/browse/HTTPCLIENT-646</a> when a socket factory is set
Expand All @@ -41,133 +41,91 @@
*/
public class FixedHttpMetaDataProvider extends HTTPMetadataProvider {

private static final byte[] CLASS_DEF = new byte[0];
private static AtomicLong expirationTimeMillis = new AtomicLong(10*60*1000); //10 minutes refresh on the URL fetch

/**
* Track if we have a custom socket factory
*/
private boolean socketFactorySet = false;
private long lastFetchTime = 0;
private static long expirationTimeMillis = 10*60*1000; //10 minutes refresh on the URL fetch
private static Ticker ticker = new Ticker() {
@Override
public long read() {
return System.nanoTime();
}
};
private RestTemplate template;

protected static Cache<String, byte[]> metadataCache = buildCache();
protected static volatile Cache<String, CacheEntry> metadataCache = buildCache();

protected static Cache<String, byte[]> buildCache() {
protected static Cache<String, CacheEntry> buildCache() {
return CacheBuilder
.newBuilder()
.expireAfterWrite(expirationTimeMillis, TimeUnit.MILLISECONDS)
.expireAfterWrite(expirationTimeMillis.get(), TimeUnit.MILLISECONDS)
.maximumSize(20000)
.ticker(ticker)
.build();
}

public static FixedHttpMetaDataProvider buildProvider(Timer backgroundTaskTimer, HttpClientParams params, String metadataURL) throws MetadataProviderException {
public static FixedHttpMetaDataProvider buildProvider(Timer backgroundTaskTimer,
HttpClientParams params,
String metadataURL,
RestTemplate template) throws MetadataProviderException {
SimpleHttpConnectionManager connectionManager = new SimpleHttpConnectionManager(true);
connectionManager.getParams().setDefaults(params);
HttpClient client = new HttpClient(connectionManager);
configureProxyIfNeeded(client, metadataURL);
return new FixedHttpMetaDataProvider(backgroundTaskTimer, client, metadataURL);
return new FixedHttpMetaDataProvider(backgroundTaskTimer, client, metadataURL, template);
}

private FixedHttpMetaDataProvider(Timer backgroundTaskTimer, HttpClient client, String metadataURL) throws MetadataProviderException {
private FixedHttpMetaDataProvider(Timer backgroundTaskTimer,
HttpClient client,
String metadataURL,
RestTemplate template) throws MetadataProviderException {
super(backgroundTaskTimer, client, metadataURL);
this.template = template;
}

public static void configureProxyIfNeeded(HttpClient client, String metadataURL) {
if (System.getProperty("http.proxyHost")!=null && System.getProperty("http.proxyPort")!=null && metadataURL.toLowerCase().startsWith("http://")) {
setProxy(client, "http");
} else if (System.getProperty("https.proxyHost")!=null && System.getProperty("https.proxyPort")!=null && metadataURL.toLowerCase().startsWith("https://")) {
setProxy(client, "https");
}
}

protected static void setProxy(HttpClient client, String prefix) {
try {
String host = System.getProperty(prefix + ".proxyHost");
int port = Integer.parseInt(System.getProperty(prefix + ".proxyPort"));
HostConfiguration configuration = client.getHostConfiguration();
configuration.setProxy(host, port);
} catch (NumberFormatException e) {
throw new IllegalStateException("Invalid proxy port configured:"+System.getProperty(prefix + ".proxyPort"));
}
}


@Override
public byte[] fetchMetadata() throws MetadataProviderException {
byte[] metadata = metadataCache.getIfPresent(getMetadataURI());
if (metadata==null || (System.currentTimeMillis()-lastFetchTime)>getExpirationTimeMillis()) {
metadata = super.fetchMetadata();
lastFetchTime = System.currentTimeMillis();
metadataCache.put(getMetadataURI(), metadata);
CacheEntry entry = metadataCache.getIfPresent(getMetadataURI());
byte[] metadata = entry != null ? entry.getData() : null;
if (metadata==null || (System.currentTimeMillis()-entry.getTimeEntered())>getExpirationTimeMillis()) {
metadata = template.getForObject(getMetadataURI(), CLASS_DEF.getClass());
metadataCache.put(getMetadataURI(), new CacheEntry(System.currentTimeMillis(), metadata));
}
return metadata;
}

/**
* {@inheritDoc}
*/
@Override
public void setSocketFactory(ProtocolSocketFactory newSocketFactory) {
// TODO Auto-generated method stub
super.setSocketFactory(newSocketFactory);
if (newSocketFactory != null) {
socketFactorySet = true;
} else {
socketFactorySet = false;
}

public static long getExpirationTimeMillis() {
return expirationTimeMillis.get();
}

/**
* If a custom socket factory has been set, only
* return a relative URL so that the custom factory is retained.
* This works around
* https://issues.apache.org/jira/browse/HTTPCLIENT-646 {@inheritDoc}
*/
@Override
public String getMetadataURI() {
if (isSocketFactorySet()) {
java.net.URI uri;
try {
uri = new java.net.URI(super.getMetadataURI());
String result = uri.getPath();
if (uri.getQuery() != null && uri.getQuery().trim().length() > 0) {
result = result + "?" + uri.getQuery();
}
return result;
} catch (URISyntaxException e) {
// this can never happen, satisfy compiler
throw new IllegalArgumentException(e);
}
} else {
return super.getMetadataURI();
public static void setExpirationTimeMillis(long expirationTimeMillis) {
if (FixedHttpMetaDataProvider.expirationTimeMillis.getAndSet(expirationTimeMillis) != expirationTimeMillis) {
metadataCache = buildCache();
}
}

public boolean isSocketFactorySet() {
return socketFactorySet;
public static void setTicker(Ticker ticker) {
if (ticker != FixedHttpMetaDataProvider.ticker) {
FixedHttpMetaDataProvider.ticker = ticker;
metadataCache = buildCache();
}
}

public long getExpirationTimeMillis() {
return expirationTimeMillis;
}
static class CacheEntry {
private final long timeEntered;
private final byte[] data;

public void setExpirationTimeMillis(long expirationTimeMillis) {
this.expirationTimeMillis = expirationTimeMillis;
metadataCache = buildCache();
}
public CacheEntry(long timeEntered, byte[] data) {
this.timeEntered = timeEntered;
this.data = data;
}

public Ticker getTicker() {
return ticker;
}
public long getTimeEntered() {
return timeEntered;
}

public byte[] getData() {
return data;
}

public void setTicker(Ticker ticker) {
FixedHttpMetaDataProvider.ticker = ticker;
metadataCache = buildCache();
}
}

0 comments on commit 81dfdee

Please sign in to comment.