Skip to content

Commit

Permalink
wip: zoned metadata fixes and zoned login
Browse files Browse the repository at this point in the history
- unit tests passing
  • Loading branch information
peterhaochen47 committed Jul 12, 2024
1 parent 748f5f2 commit b58a599
Show file tree
Hide file tree
Showing 11 changed files with 302 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@ public class ConfiguratorRelyingPartyRegistrationRepository
private final KeyWithCert keyWithCert;
private final String samlEntityID;

public ConfiguratorRelyingPartyRegistrationRepository(@Qualifier("samlEntityID") String samlEntityID,
private final String samlEntityIDAlias; // TODO consider renaming this to indicate UAA wide

public ConfiguratorRelyingPartyRegistrationRepository(String samlEntityID,
String samlEntityIDAlias,
KeyWithCert keyWithCert,
SamlIdentityProviderConfigurator configurator) {
Assert.notNull(configurator, "configurator cannot be null");
this.configurator = configurator;
this.keyWithCert = keyWithCert;
this.samlEntityID = samlEntityID;
this.samlEntityIDAlias = samlEntityIDAlias;
}

/**
Expand All @@ -43,34 +47,22 @@ public RelyingPartyRegistration findByRegistrationId(String registrationId) {
if (identityProviderDefinition.getIdpEntityAlias().equals(registrationId)) {

IdentityZone zone = retrieveZone();
String zonedSamlEntityID = zone.isUaa() ? samlEntityID : zone.getConfig().getSamlConfig().getEntityID();

// TODO code repetition?
String zonedSamlEntityIDAlias;
if (zone.isUaa()) { // default zone
zonedSamlEntityIDAlias = samlEntityIDAlias;
} else { // non-default zone
zonedSamlEntityIDAlias = "%s.%s".formatted(zone.getSubdomain(), samlEntityIDAlias);
}

return RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration(
samlEntityID, identityProviderDefinition.getNameID(),
zonedSamlEntityID, identityProviderDefinition.getNameID(),
keyWithCert, identityProviderDefinition.getMetaDataLocation(),
registrationId, zone.getConfig().getSamlConfig().isRequestSigned());
registrationId, zonedSamlEntityIDAlias, zone.getConfig().getSamlConfig().isRequestSigned());
}
}
return buildDefaultRelyingPartyRegistration();
}

private RelyingPartyRegistration buildDefaultRelyingPartyRegistration() {
String samlEntityID, samlServiceUri;
IdentityZone zone = retrieveZone();
if (zone.isUaa()) {
samlEntityID = this.samlEntityID;
samlServiceUri = this.samlEntityID;
}
else if (zone.getConfig() != null && zone.getConfig().getSamlConfig() != null) {

samlEntityID = zone.getConfig().getSamlConfig().getEntityID();
samlServiceUri = zone.getSubdomain() + "." + this.samlEntityID;
}
else {
return null;
}

return RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration(
samlEntityID, null,
keyWithCert, "dummy-saml-idp-metadata.xml", null,
samlServiceUri, zone.getConfig().getSamlConfig().isRequestSigned());
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package org.cloudfoundry.identity.uaa.provider.saml;

import org.cloudfoundry.identity.uaa.util.KeyWithCert;
import org.cloudfoundry.identity.uaa.zone.IdentityZone;
import org.cloudfoundry.identity.uaa.zone.ZoneAware;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;

/**
* A {@link RelyingPartyRegistrationRepository} that always returns a default {@link RelyingPartyRegistrationRepository}.
*/
public class DefaultRelyingPartyRegistrationRepository implements RelyingPartyRegistrationRepository, ZoneAware {
public static final String CLASSPATH_DUMMY_SAML_IDP_METADATA_XML = "classpath:dummy-saml-idp-metadata.xml";

private final KeyWithCert keyWithCert;
private final String samlEntityID;

private final String samlEntityIDAlias; // TODO consider renaming this to indicate UAA wide

public DefaultRelyingPartyRegistrationRepository(String samlEntityID,
String samlEntityIDAlias,
KeyWithCert keyWithCert) {
this.keyWithCert = keyWithCert;
this.samlEntityID = samlEntityID;
this.samlEntityIDAlias = samlEntityIDAlias;
}

/**
* Returns the relying party registration identified by the provided
* {@code registrationId}, or {@code null} if not found.
*
* @param registrationId the registration identifier
* @return the {@link RelyingPartyRegistration} if found, otherwise {@code null}
*/
@Override
public RelyingPartyRegistration findByRegistrationId(String registrationId) {
IdentityZone zone = retrieveZone();

String zonedSamlEntityID;
if (!zone.isUaa() && zone.getConfig() != null && zone.getConfig().getSamlConfig() != null && zone.getConfig().getSamlConfig().getEntityID() != null) {
zonedSamlEntityID = zone.getConfig().getSamlConfig().getEntityID();
} else {
zonedSamlEntityID = this.samlEntityID;
}

// TODO is this repeating code?
String zonedSamlEntityIDAlias;
if (zone.isUaa()) { // default zone
zonedSamlEntityIDAlias = samlEntityIDAlias;
} else { // non-default zone
zonedSamlEntityIDAlias = "%s.%s".formatted(zone.getSubdomain(), samlEntityIDAlias);
}

return RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration(
zonedSamlEntityID, null,
keyWithCert, CLASSPATH_DUMMY_SAML_IDP_METADATA_XML, registrationId,
zonedSamlEntityIDAlias, zone.getConfig().getSamlConfig().isRequestSigned());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,10 @@ private RelyingPartyRegistrationBuilder() {
throw new java.lang.UnsupportedOperationException("This is a utility class and cannot be instantiated");
}

public static RelyingPartyRegistration buildRelyingPartyRegistration(
String samlEntityID, String samlSpNameId,
KeyWithCert keyWithCert,
String metadataLocation, String rpRegstrationId, boolean requestSigned) {
return buildRelyingPartyRegistration(samlEntityID, samlSpNameId,
keyWithCert, metadataLocation, rpRegstrationId,
samlEntityID, requestSigned);
}

public static RelyingPartyRegistration buildRelyingPartyRegistration(
String samlEntityID, String samlSpNameId,
KeyWithCert keyWithCert, String metadataLocation,
String rpRegstrationId, String samlServiceUri, boolean requestSigned) {
String rpRegstrationId, String samlSpAlias, boolean requestSigned) {
SamlIdentityProviderDefinition.MetadataLocation type = SamlIdentityProviderDefinition.getType(metadataLocation);

RelyingPartyRegistration.Builder builder;
Expand All @@ -51,14 +42,17 @@ public static RelyingPartyRegistration buildRelyingPartyRegistration(
builder = RelyingPartyRegistrations.fromMetadataLocation(metadataLocation);
}

// fallback to entityId if alias is not provided TODO has the falling back already happened?
samlSpAlias = samlSpAlias == null ? samlEntityID : samlSpAlias;

builder.entityId(samlEntityID);
if (samlSpNameId != null) builder.nameIdFormat(samlSpNameId);
if (rpRegstrationId != null) builder.registrationId(rpRegstrationId);
return builder
.assertionConsumerServiceLocation(assertionConsumerServiceLocationFunction.apply(samlServiceUri))
.singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocationFunction.apply(samlServiceUri))
.singleLogoutServiceLocation(singleLogoutServiceLocationFunction.apply(samlServiceUri))
.singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocationFunction.apply(samlServiceUri))
.assertionConsumerServiceLocation(assertionConsumerServiceLocationFunction.apply(samlSpAlias))
.singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocationFunction.apply(samlSpAlias))
.singleLogoutServiceLocation(singleLogoutServiceLocationFunction.apply(samlSpAlias))
.singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocationFunction.apply(samlSpAlias))
// Accept both POST and REDIRECT bindings
.singleLogoutServiceBindings(c -> {
c.add(Saml2MessageBinding.REDIRECT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ public class SamlConfigProps {

private String activeKeyId;

private String entityIDAlias;

private Map<String, SamlKey> keys;

private Boolean wantAssertionSigned = true;

private Boolean signRequest = true;

public SamlKey getActiveSamlKey() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import org.cloudfoundry.identity.uaa.provider.SamlIdentityProviderDefinition;
import org.cloudfoundry.identity.uaa.saml.SamlKey;
import org.cloudfoundry.identity.uaa.util.KeyWithCert;
import org.cloudfoundry.identity.uaa.zone.IdentityZone;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
Expand Down Expand Up @@ -53,6 +54,8 @@ RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(SamlIdenti

List<RelyingPartyRegistration> relyingPartyRegistrations = new ArrayList<>();

String uaaWideSamlEntityIDAlias = samlConfigProps.getEntityIDAlias() != null ? samlConfigProps.getEntityIDAlias() : samlEntityID;

@SuppressWarnings("java:S125")
// Spring Security requires at least one relyingPartyRegistration before SAML SP metadata generation;
// and each relyingPartyRegistration needs to contain the SAML IDP metadata.
Expand All @@ -65,7 +68,7 @@ RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(SamlIdenti
// even when there are no SAML IDPs configured.
// See relevant issue: https://github.com/spring-projects/spring-security/issues/11369
RelyingPartyRegistration defaultRelyingPartyRegistration = RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration(
samlEntityID, samlSpNameID, keyWithCert, CLASSPATH_DUMMY_SAML_IDP_METADATA_XML, DEFAULT_REGISTRATION_ID, samlConfigProps.getSignRequest());
samlEntityID, samlSpNameID, keyWithCert, CLASSPATH_DUMMY_SAML_IDP_METADATA_XML, DEFAULT_REGISTRATION_ID, uaaWideSamlEntityIDAlias, samlConfigProps.getSignRequest());
relyingPartyRegistrations.add(defaultRelyingPartyRegistration);

for (SamlIdentityProviderDefinition samlIdentityProviderDefinition : bootstrapSamlIdentityProviderData.getIdentityProviderDefinitions()) {
Expand All @@ -74,13 +77,15 @@ RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(SamlIdenti
samlEntityID, samlSpNameID, keyWithCert,
samlIdentityProviderDefinition.getMetaDataLocation(),
samlIdentityProviderDefinition.getIdpEntityAlias(),
uaaWideSamlEntityIDAlias,
samlConfigProps.getSignRequest())
);
}

InMemoryRelyingPartyRegistrationRepository bootstrapRepo = new InMemoryRelyingPartyRegistrationRepository(relyingPartyRegistrations);
ConfiguratorRelyingPartyRegistrationRepository configuratorRepo = new ConfiguratorRelyingPartyRegistrationRepository(samlEntityID, keyWithCert, samlIdentityProviderConfigurator);
return new DelegatingRelyingPartyRegistrationRepository(bootstrapRepo, configuratorRepo);
ConfiguratorRelyingPartyRegistrationRepository configuratorRepo = new ConfiguratorRelyingPartyRegistrationRepository(samlEntityID, uaaWideSamlEntityIDAlias, keyWithCert, samlIdentityProviderConfigurator);
DefaultRelyingPartyRegistrationRepository defaultRepo = new DefaultRelyingPartyRegistrationRepository(samlEntityID, uaaWideSamlEntityIDAlias, keyWithCert);
return new DelegatingRelyingPartyRegistrationRepository(bootstrapRepo, configuratorRepo, defaultRepo);
}

@Autowired
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import org.cloudfoundry.identity.uaa.provider.SamlIdentityProviderDefinition;
import org.cloudfoundry.identity.uaa.util.KeyWithCert;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
Expand Down Expand Up @@ -33,7 +32,9 @@
@ExtendWith(MockitoExtension.class)
class ConfiguratorRelyingPartyRegistrationRepositoryTest {
private static final String ENTITY_ID = "entityId";
private static final String ENTITY_ID_ALIAS = "entityIdAlias";
private static final String REGISTRATION_ID = "registrationId";
private static final String REGISTRATION_ID_2 = "registrationId2";
private static final String NAME_ID = "name1";

@Mock
Expand All @@ -46,14 +47,14 @@ class ConfiguratorRelyingPartyRegistrationRepositoryTest {

@BeforeEach
void setUp() {
repository = new ConfiguratorRelyingPartyRegistrationRepository(ENTITY_ID, mockKeyWithCert,
repository = new ConfiguratorRelyingPartyRegistrationRepository(ENTITY_ID, ENTITY_ID_ALIAS, mockKeyWithCert,
mockConfigurator);
}

@Test
void constructorWithNullConfiguratorThrows() {
assertThatThrownBy(() -> new ConfiguratorRelyingPartyRegistrationRepository(
ENTITY_ID, mockKeyWithCert, null)
ENTITY_ID, ENTITY_ID_ALIAS, mockKeyWithCert, null)
).isInstanceOf(IllegalArgumentException.class);
}

Expand Down Expand Up @@ -81,15 +82,14 @@ void findByRegistrationIdWithMultipleInDb() {
.returns(ENTITY_ID, RelyingPartyRegistration::getEntityId)
.returns(NAME_ID, RelyingPartyRegistration::getNameIdFormat)
// from functions
.returns("{baseUrl}/saml/SSO/alias/entityId", RelyingPartyRegistration::getAssertionConsumerServiceLocation)
.returns("{baseUrl}/saml/SingleLogout/alias/entityId", RelyingPartyRegistration::getSingleLogoutServiceResponseLocation)
.returns("{baseUrl}/saml/SSO/alias/entityIdAlias", RelyingPartyRegistration::getAssertionConsumerServiceLocation)
.returns("{baseUrl}/saml/SingleLogout/alias/entityIdAlias", RelyingPartyRegistration::getSingleLogoutServiceResponseLocation)
// from xml
.extracting(RelyingPartyRegistration::getAssertingPartyDetails)
.returns("https://idp-saml.ua3.int/simplesaml/saml2/idp/metadata.php", RelyingPartyRegistration.AssertingPartyDetails::getEntityId);
}

@Test
@Disabled("Test not valid because ConfiguratorRelyingPartyRegistrationRepository now returns default RelyingPartyRegistration when none found")
void findByRegistrationIdWhenNoneFound() {
SamlIdentityProviderDefinition definition = mock(SamlIdentityProviderDefinition.class);
when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID);
Expand All @@ -104,21 +104,21 @@ void buildsCorrectRegistrationWhenMetadataXmlIsStored() {
when(mockKeyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class));
when(mockKeyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class));
SamlIdentityProviderDefinition definition = mock(SamlIdentityProviderDefinition.class);
when(definition.getIdpEntityAlias()).thenReturn("no_slos");
when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID);
when(definition.getNameID()).thenReturn(NAME_ID);
when(definition.getMetaDataLocation()).thenReturn(metadata);
when(mockConfigurator.getIdentityProviderDefinitions()).thenReturn(List.of(definition));

RelyingPartyRegistration registration = repository.findByRegistrationId("no_slos");
RelyingPartyRegistration registration = repository.findByRegistrationId(REGISTRATION_ID);

assertThat(registration)
// from definition
.returns("no_slos", RelyingPartyRegistration::getRegistrationId)
.returns(REGISTRATION_ID, RelyingPartyRegistration::getRegistrationId)
.returns(ENTITY_ID, RelyingPartyRegistration::getEntityId)
.returns(NAME_ID, RelyingPartyRegistration::getNameIdFormat)
// from functions
.returns("{baseUrl}/saml/SSO/alias/entityId", RelyingPartyRegistration::getAssertionConsumerServiceLocation)
.returns("{baseUrl}/saml/SingleLogout/alias/entityId", RelyingPartyRegistration::getSingleLogoutServiceResponseLocation)
.returns("{baseUrl}/saml/SSO/alias/entityIdAlias", RelyingPartyRegistration::getAssertionConsumerServiceLocation)
.returns("{baseUrl}/saml/SingleLogout/alias/entityIdAlias", RelyingPartyRegistration::getSingleLogoutServiceResponseLocation)
// from xml
.extracting(RelyingPartyRegistration::getAssertingPartyDetails)
.returns("http://uaa-acceptance.cf-app.com/saml-idp", RelyingPartyRegistration.AssertingPartyDetails::getEntityId);
Expand All @@ -129,20 +129,20 @@ void buildsCorrectRegistrationWhenMetadataLocationIsStored() {
when(mockKeyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class));
when(mockKeyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class));
SamlIdentityProviderDefinition definition = mock(SamlIdentityProviderDefinition.class);
when(definition.getIdpEntityAlias()).thenReturn("no_slos");
when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID_2);
when(definition.getNameID()).thenReturn(NAME_ID);
when(definition.getMetaDataLocation()).thenReturn("no_single_logout_service-metadata.xml");
when(mockConfigurator.getIdentityProviderDefinitions()).thenReturn(List.of(definition));

RelyingPartyRegistration registration = repository.findByRegistrationId("no_slos");
RelyingPartyRegistration registration = repository.findByRegistrationId(REGISTRATION_ID_2);
assertThat(registration)
// from definition
.returns("no_slos", RelyingPartyRegistration::getRegistrationId)
.returns(REGISTRATION_ID_2, RelyingPartyRegistration::getRegistrationId)
.returns(ENTITY_ID, RelyingPartyRegistration::getEntityId)
.returns(NAME_ID, RelyingPartyRegistration::getNameIdFormat)
// from functions
.returns("{baseUrl}/saml/SSO/alias/entityId", RelyingPartyRegistration::getAssertionConsumerServiceLocation)
.returns("{baseUrl}/saml/SingleLogout/alias/entityId", RelyingPartyRegistration::getSingleLogoutServiceResponseLocation)
.returns("{baseUrl}/saml/SSO/alias/entityIdAlias", RelyingPartyRegistration::getAssertionConsumerServiceLocation)
.returns("{baseUrl}/saml/SingleLogout/alias/entityIdAlias", RelyingPartyRegistration::getSingleLogoutServiceResponseLocation)
// from xml
.extracting(RelyingPartyRegistration::getAssertingPartyDetails)
.returns("http://uaa-acceptance.cf-app.com/saml-idp", RelyingPartyRegistration.AssertingPartyDetails::getEntityId);
Expand Down
Loading

0 comments on commit b58a599

Please sign in to comment.