Skip to content

Commit

Permalink
wip: Zoned Login
Browse files Browse the repository at this point in the history
[#187902333]

Signed-off-by: Duane May <duane.may@broadcom.com>
Signed-off-by: Peter Chen <peter-h.chen@broadcom.com>
  • Loading branch information
duanemay authored and peterhaochen47 committed Jul 10, 2024
1 parent 51deef5 commit 369afa6
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,34 +43,14 @@ public RelyingPartyRegistration findByRegistrationId(String registrationId) {
if (identityProviderDefinition.getIdpEntityAlias().equals(registrationId)) {

IdentityZone zone = retrieveZone();
String zonedSamlEntityID = zone.isUaa() ? samlEntityID : zone.getConfig().getSamlConfig().getEntityID();

return RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration(
samlEntityID, identityProviderDefinition.getNameID(),
zonedSamlEntityID, identityProviderDefinition.getNameID(),
keyWithCert, identityProviderDefinition.getMetaDataLocation(),
registrationId, zone.getConfig().getSamlConfig().isRequestSigned());
}
}
return buildDefaultRelyingPartyRegistration();
}

private RelyingPartyRegistration buildDefaultRelyingPartyRegistration() {
String samlEntityID, samlServiceUri;
IdentityZone zone = retrieveZone();
if (zone.isUaa()) {
samlEntityID = this.samlEntityID;
samlServiceUri = this.samlEntityID;
}
else if (zone.getConfig() != null && zone.getConfig().getSamlConfig() != null) {

samlEntityID = zone.getConfig().getSamlConfig().getEntityID();
samlServiceUri = zone.getSubdomain() + "." + this.samlEntityID;
}
else {
return null;
}

return RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration(
samlEntityID, null,
keyWithCert, "dummy-saml-idp-metadata.xml", null,
samlServiceUri, zone.getConfig().getSamlConfig().isRequestSigned());
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package org.cloudfoundry.identity.uaa.provider.saml;

import org.cloudfoundry.identity.uaa.util.KeyWithCert;
import org.cloudfoundry.identity.uaa.zone.IdentityZone;
import org.cloudfoundry.identity.uaa.zone.ZoneAware;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;

/**
* A {@link RelyingPartyRegistrationRepository} that always returns a default {@link RelyingPartyRegistrationRepository}.
*/
public class DefaultRelyingPartyRegistrationRepository implements RelyingPartyRegistrationRepository, ZoneAware {
public static final String CLASSPATH_DUMMY_SAML_IDP_METADATA_XML = "classpath:dummy-saml-idp-metadata.xml";

private final KeyWithCert keyWithCert;
private final String samlEntityID;

public DefaultRelyingPartyRegistrationRepository(String samlEntityID,
KeyWithCert keyWithCert) {
this.keyWithCert = keyWithCert;
this.samlEntityID = samlEntityID;
}

/**
* Returns the relying party registration identified by the provided
* {@code registrationId}, or {@code null} if not found.
*
* @param registrationId the registration identifier
* @return the {@link RelyingPartyRegistration} if found, otherwise {@code null}
*/
@Override
public RelyingPartyRegistration findByRegistrationId(String registrationId) {
IdentityZone zone = retrieveZone();

String zonedSamlEntityID;
if (zone.isUaa()) {
zonedSamlEntityID = this.samlEntityID;
} else if (zone.getConfig() != null && zone.getConfig().getSamlConfig() != null) {
zonedSamlEntityID = zone.getConfig().getSamlConfig().getEntityID();
} else {
return null;
}

return RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration(
zonedSamlEntityID, null,
keyWithCert, CLASSPATH_DUMMY_SAML_IDP_METADATA_XML, registrationId,
zonedSamlEntityID, zone.getConfig().getSamlConfig().isRequestSigned());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(SamlIdenti

InMemoryRelyingPartyRegistrationRepository bootstrapRepo = new InMemoryRelyingPartyRegistrationRepository(relyingPartyRegistrations);
ConfiguratorRelyingPartyRegistrationRepository configuratorRepo = new ConfiguratorRelyingPartyRegistrationRepository(samlEntityID, keyWithCert, samlIdentityProviderConfigurator);
return new DelegatingRelyingPartyRegistrationRepository(bootstrapRepo, configuratorRepo);
DefaultRelyingPartyRegistrationRepository defaultRepo = new DefaultRelyingPartyRegistrationRepository(samlEntityID, keyWithCert);
return new DelegatingRelyingPartyRegistrationRepository(bootstrapRepo, configuratorRepo, defaultRepo);
}

@Autowired
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import org.cloudfoundry.identity.uaa.provider.SamlIdentityProviderDefinition;
import org.cloudfoundry.identity.uaa.util.KeyWithCert;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
Expand Down Expand Up @@ -34,6 +33,7 @@
class ConfiguratorRelyingPartyRegistrationRepositoryTest {
private static final String ENTITY_ID = "entityId";
private static final String REGISTRATION_ID = "registrationId";
private static final String REGISTRATION_ID_2 = "registrationId2";
private static final String NAME_ID = "name1";

@Mock
Expand Down Expand Up @@ -89,7 +89,6 @@ void findByRegistrationIdWithMultipleInDb() {
}

@Test
@Disabled("Test not valid because ConfiguratorRelyingPartyRegistrationRepository now returns default RelyingPartyRegistration when none found")
void findByRegistrationIdWhenNoneFound() {
SamlIdentityProviderDefinition definition = mock(SamlIdentityProviderDefinition.class);
when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID);
Expand All @@ -104,16 +103,16 @@ void buildsCorrectRegistrationWhenMetadataXmlIsStored() {
when(mockKeyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class));
when(mockKeyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class));
SamlIdentityProviderDefinition definition = mock(SamlIdentityProviderDefinition.class);
when(definition.getIdpEntityAlias()).thenReturn("no_slos");
when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID);
when(definition.getNameID()).thenReturn(NAME_ID);
when(definition.getMetaDataLocation()).thenReturn(metadata);
when(mockConfigurator.getIdentityProviderDefinitions()).thenReturn(List.of(definition));

RelyingPartyRegistration registration = repository.findByRegistrationId("no_slos");
RelyingPartyRegistration registration = repository.findByRegistrationId(REGISTRATION_ID);

assertThat(registration)
// from definition
.returns("no_slos", RelyingPartyRegistration::getRegistrationId)
.returns(REGISTRATION_ID, RelyingPartyRegistration::getRegistrationId)
.returns(ENTITY_ID, RelyingPartyRegistration::getEntityId)
.returns(NAME_ID, RelyingPartyRegistration::getNameIdFormat)
// from functions
Expand All @@ -129,15 +128,15 @@ void buildsCorrectRegistrationWhenMetadataLocationIsStored() {
when(mockKeyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class));
when(mockKeyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class));
SamlIdentityProviderDefinition definition = mock(SamlIdentityProviderDefinition.class);
when(definition.getIdpEntityAlias()).thenReturn("no_slos");
when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID_2);
when(definition.getNameID()).thenReturn(NAME_ID);
when(definition.getMetaDataLocation()).thenReturn("no_single_logout_service-metadata.xml");
when(mockConfigurator.getIdentityProviderDefinitions()).thenReturn(List.of(definition));

RelyingPartyRegistration registration = repository.findByRegistrationId("no_slos");
RelyingPartyRegistration registration = repository.findByRegistrationId(REGISTRATION_ID_2);
assertThat(registration)
// from definition
.returns("no_slos", RelyingPartyRegistration::getRegistrationId)
.returns(REGISTRATION_ID_2, RelyingPartyRegistration::getRegistrationId)
.returns(ENTITY_ID, RelyingPartyRegistration::getEntityId)
.returns(NAME_ID, RelyingPartyRegistration::getNameIdFormat)
// from functions
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package org.cloudfoundry.identity.uaa.provider.saml;

import org.cloudfoundry.identity.uaa.util.KeyWithCert;
import org.cloudfoundry.identity.uaa.zone.IdentityZone;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneConfiguration;
import org.cloudfoundry.identity.uaa.zone.SamlConfig;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;

import java.security.PrivateKey;
import java.security.cert.X509Certificate;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class DefaultRelyingPartyRegistrationRepositoryTest {
private static final String ENTITY_ID = "entityId";
private static final String ZONED_ENTITY_ID = "%s.%s".formatted("subdomain",ENTITY_ID);
private static final String REGISTRATION_ID = "registrationId";
private static final String NAME_ID = "name1";

@Mock
private KeyWithCert mockKeyWithCert;

@Mock
private IdentityZone identityZone;

@Mock
private IdentityZoneConfiguration identityZoneConfig;

@Mock
private SamlConfig samlConfig;

private DefaultRelyingPartyRegistrationRepository repository;

@BeforeEach
void setUp() {
repository = spy(new DefaultRelyingPartyRegistrationRepository(ENTITY_ID, mockKeyWithCert));
}

@Test
void findByRegistrationId() {
when(mockKeyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class));
when(mockKeyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class));

RelyingPartyRegistration registration = repository.findByRegistrationId(REGISTRATION_ID);

assertThat(registration)
// from definition
.returns(REGISTRATION_ID, RelyingPartyRegistration::getRegistrationId)
.returns(ENTITY_ID, RelyingPartyRegistration::getEntityId)
.returns(null, RelyingPartyRegistration::getNameIdFormat)
// from functions
.returns("{baseUrl}/saml/SSO/alias/entityId", RelyingPartyRegistration::getAssertionConsumerServiceLocation)
.returns("{baseUrl}/saml/SingleLogout/alias/entityId", RelyingPartyRegistration::getSingleLogoutServiceResponseLocation)
// from xml
.extracting(RelyingPartyRegistration::getAssertingPartyDetails)
.returns("exampleEntityId", RelyingPartyRegistration.AssertingPartyDetails::getEntityId);
}

@Test
void findByRegistrationIdForZone() {
when(mockKeyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class));
when(mockKeyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class));
when(repository.retrieveZone()).thenReturn(identityZone);
when(identityZone.isUaa()).thenReturn(false);
when(identityZone.getConfig()).thenReturn(identityZoneConfig);
when(identityZoneConfig.getSamlConfig()).thenReturn(samlConfig);
when(samlConfig.getEntityID()).thenReturn(ZONED_ENTITY_ID);

RelyingPartyRegistration registration = repository.findByRegistrationId(REGISTRATION_ID);

assertThat(registration)
// from definition
.returns(REGISTRATION_ID, RelyingPartyRegistration::getRegistrationId)
.returns(ZONED_ENTITY_ID, RelyingPartyRegistration::getEntityId)
.returns(null, RelyingPartyRegistration::getNameIdFormat)
// from functions
.returns("{baseUrl}/saml/SSO/alias/subdomain.entityId", RelyingPartyRegistration::getAssertionConsumerServiceLocation)
.returns("{baseUrl}/saml/SingleLogout/alias/subdomain.entityId", RelyingPartyRegistration::getSingleLogoutServiceResponseLocation)
// from xml
.extracting(RelyingPartyRegistration::getAssertingPartyDetails)
.returns("exampleEntityId", RelyingPartyRegistration.AssertingPartyDetails::getEntityId);
}
}

0 comments on commit 369afa6

Please sign in to comment.