From 369afa624a6d216e970cd20d7a475d0a151c9943 Mon Sep 17 00:00:00 2001 From: Duane May Date: Wed, 10 Jul 2024 17:24:53 -0400 Subject: [PATCH] wip: Zoned Login [#187902333] Signed-off-by: Duane May Signed-off-by: Peter Chen --- ...torRelyingPartyRegistrationRepository.java | 28 +----- ...ultRelyingPartyRegistrationRepository.java | 49 ++++++++++ ...yingPartyRegistrationRepositoryConfig.java | 3 +- ...elyingPartyRegistrationRepositoryTest.java | 15 ++- ...elyingPartyRegistrationRepositoryTest.java | 92 +++++++++++++++++++ 5 files changed, 154 insertions(+), 33 deletions(-) create mode 100644 server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/DefaultRelyingPartyRegistrationRepository.java create mode 100644 server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/DefaultRelyingPartyRegistrationRepositoryTest.java diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/ConfiguratorRelyingPartyRegistrationRepository.java b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/ConfiguratorRelyingPartyRegistrationRepository.java index 1cfa860c1b6..780e28fabe8 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/ConfiguratorRelyingPartyRegistrationRepository.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/ConfiguratorRelyingPartyRegistrationRepository.java @@ -43,34 +43,14 @@ public RelyingPartyRegistration findByRegistrationId(String registrationId) { if (identityProviderDefinition.getIdpEntityAlias().equals(registrationId)) { IdentityZone zone = retrieveZone(); + String zonedSamlEntityID = zone.isUaa() ? samlEntityID : zone.getConfig().getSamlConfig().getEntityID(); + return RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration( - samlEntityID, identityProviderDefinition.getNameID(), + zonedSamlEntityID, identityProviderDefinition.getNameID(), keyWithCert, identityProviderDefinition.getMetaDataLocation(), registrationId, zone.getConfig().getSamlConfig().isRequestSigned()); } } - return buildDefaultRelyingPartyRegistration(); - } - - private RelyingPartyRegistration buildDefaultRelyingPartyRegistration() { - String samlEntityID, samlServiceUri; - IdentityZone zone = retrieveZone(); - if (zone.isUaa()) { - samlEntityID = this.samlEntityID; - samlServiceUri = this.samlEntityID; - } - else if (zone.getConfig() != null && zone.getConfig().getSamlConfig() != null) { - - samlEntityID = zone.getConfig().getSamlConfig().getEntityID(); - samlServiceUri = zone.getSubdomain() + "." + this.samlEntityID; - } - else { - return null; - } - - return RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration( - samlEntityID, null, - keyWithCert, "dummy-saml-idp-metadata.xml", null, - samlServiceUri, zone.getConfig().getSamlConfig().isRequestSigned()); + return null; } } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/DefaultRelyingPartyRegistrationRepository.java b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/DefaultRelyingPartyRegistrationRepository.java new file mode 100644 index 00000000000..6cdffebd058 --- /dev/null +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/DefaultRelyingPartyRegistrationRepository.java @@ -0,0 +1,49 @@ +package org.cloudfoundry.identity.uaa.provider.saml; + +import org.cloudfoundry.identity.uaa.util.KeyWithCert; +import org.cloudfoundry.identity.uaa.zone.IdentityZone; +import org.cloudfoundry.identity.uaa.zone.ZoneAware; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; + +/** + * A {@link RelyingPartyRegistrationRepository} that always returns a default {@link RelyingPartyRegistrationRepository}. + */ +public class DefaultRelyingPartyRegistrationRepository implements RelyingPartyRegistrationRepository, ZoneAware { + public static final String CLASSPATH_DUMMY_SAML_IDP_METADATA_XML = "classpath:dummy-saml-idp-metadata.xml"; + + private final KeyWithCert keyWithCert; + private final String samlEntityID; + + public DefaultRelyingPartyRegistrationRepository(String samlEntityID, + KeyWithCert keyWithCert) { + this.keyWithCert = keyWithCert; + this.samlEntityID = samlEntityID; + } + + /** + * Returns the relying party registration identified by the provided + * {@code registrationId}, or {@code null} if not found. + * + * @param registrationId the registration identifier + * @return the {@link RelyingPartyRegistration} if found, otherwise {@code null} + */ + @Override + public RelyingPartyRegistration findByRegistrationId(String registrationId) { + IdentityZone zone = retrieveZone(); + + String zonedSamlEntityID; + if (zone.isUaa()) { + zonedSamlEntityID = this.samlEntityID; + } else if (zone.getConfig() != null && zone.getConfig().getSamlConfig() != null) { + zonedSamlEntityID = zone.getConfig().getSamlConfig().getEntityID(); + } else { + return null; + } + + return RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration( + zonedSamlEntityID, null, + keyWithCert, CLASSPATH_DUMMY_SAML_IDP_METADATA_XML, registrationId, + zonedSamlEntityID, zone.getConfig().getSamlConfig().isRequestSigned()); + } +} diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlRelyingPartyRegistrationRepositoryConfig.java b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlRelyingPartyRegistrationRepositoryConfig.java index ceae85a8e04..0f746292c31 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlRelyingPartyRegistrationRepositoryConfig.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlRelyingPartyRegistrationRepositoryConfig.java @@ -80,7 +80,8 @@ RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(SamlIdenti InMemoryRelyingPartyRegistrationRepository bootstrapRepo = new InMemoryRelyingPartyRegistrationRepository(relyingPartyRegistrations); ConfiguratorRelyingPartyRegistrationRepository configuratorRepo = new ConfiguratorRelyingPartyRegistrationRepository(samlEntityID, keyWithCert, samlIdentityProviderConfigurator); - return new DelegatingRelyingPartyRegistrationRepository(bootstrapRepo, configuratorRepo); + DefaultRelyingPartyRegistrationRepository defaultRepo = new DefaultRelyingPartyRegistrationRepository(samlEntityID, keyWithCert); + return new DelegatingRelyingPartyRegistrationRepository(bootstrapRepo, configuratorRepo, defaultRepo); } @Autowired diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/ConfiguratorRelyingPartyRegistrationRepositoryTest.java b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/ConfiguratorRelyingPartyRegistrationRepositoryTest.java index a14f3a6ce35..702f9c085ca 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/ConfiguratorRelyingPartyRegistrationRepositoryTest.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/ConfiguratorRelyingPartyRegistrationRepositoryTest.java @@ -3,7 +3,6 @@ import org.cloudfoundry.identity.uaa.provider.SamlIdentityProviderDefinition; import org.cloudfoundry.identity.uaa.util.KeyWithCert; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -34,6 +33,7 @@ class ConfiguratorRelyingPartyRegistrationRepositoryTest { private static final String ENTITY_ID = "entityId"; private static final String REGISTRATION_ID = "registrationId"; + private static final String REGISTRATION_ID_2 = "registrationId2"; private static final String NAME_ID = "name1"; @Mock @@ -89,7 +89,6 @@ void findByRegistrationIdWithMultipleInDb() { } @Test - @Disabled("Test not valid because ConfiguratorRelyingPartyRegistrationRepository now returns default RelyingPartyRegistration when none found") void findByRegistrationIdWhenNoneFound() { SamlIdentityProviderDefinition definition = mock(SamlIdentityProviderDefinition.class); when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID); @@ -104,16 +103,16 @@ void buildsCorrectRegistrationWhenMetadataXmlIsStored() { when(mockKeyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class)); when(mockKeyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class)); SamlIdentityProviderDefinition definition = mock(SamlIdentityProviderDefinition.class); - when(definition.getIdpEntityAlias()).thenReturn("no_slos"); + when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID); when(definition.getNameID()).thenReturn(NAME_ID); when(definition.getMetaDataLocation()).thenReturn(metadata); when(mockConfigurator.getIdentityProviderDefinitions()).thenReturn(List.of(definition)); - RelyingPartyRegistration registration = repository.findByRegistrationId("no_slos"); + RelyingPartyRegistration registration = repository.findByRegistrationId(REGISTRATION_ID); assertThat(registration) // from definition - .returns("no_slos", RelyingPartyRegistration::getRegistrationId) + .returns(REGISTRATION_ID, RelyingPartyRegistration::getRegistrationId) .returns(ENTITY_ID, RelyingPartyRegistration::getEntityId) .returns(NAME_ID, RelyingPartyRegistration::getNameIdFormat) // from functions @@ -129,15 +128,15 @@ void buildsCorrectRegistrationWhenMetadataLocationIsStored() { when(mockKeyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class)); when(mockKeyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class)); SamlIdentityProviderDefinition definition = mock(SamlIdentityProviderDefinition.class); - when(definition.getIdpEntityAlias()).thenReturn("no_slos"); + when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID_2); when(definition.getNameID()).thenReturn(NAME_ID); when(definition.getMetaDataLocation()).thenReturn("no_single_logout_service-metadata.xml"); when(mockConfigurator.getIdentityProviderDefinitions()).thenReturn(List.of(definition)); - RelyingPartyRegistration registration = repository.findByRegistrationId("no_slos"); + RelyingPartyRegistration registration = repository.findByRegistrationId(REGISTRATION_ID_2); assertThat(registration) // from definition - .returns("no_slos", RelyingPartyRegistration::getRegistrationId) + .returns(REGISTRATION_ID_2, RelyingPartyRegistration::getRegistrationId) .returns(ENTITY_ID, RelyingPartyRegistration::getEntityId) .returns(NAME_ID, RelyingPartyRegistration::getNameIdFormat) // from functions diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/DefaultRelyingPartyRegistrationRepositoryTest.java b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/DefaultRelyingPartyRegistrationRepositoryTest.java new file mode 100644 index 00000000000..7f27a1c53fa --- /dev/null +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/DefaultRelyingPartyRegistrationRepositoryTest.java @@ -0,0 +1,92 @@ +package org.cloudfoundry.identity.uaa.provider.saml; + +import org.cloudfoundry.identity.uaa.util.KeyWithCert; +import org.cloudfoundry.identity.uaa.zone.IdentityZone; +import org.cloudfoundry.identity.uaa.zone.IdentityZoneConfiguration; +import org.cloudfoundry.identity.uaa.zone.SamlConfig; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; + +import java.security.PrivateKey; +import java.security.cert.X509Certificate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class DefaultRelyingPartyRegistrationRepositoryTest { + private static final String ENTITY_ID = "entityId"; + private static final String ZONED_ENTITY_ID = "%s.%s".formatted("subdomain",ENTITY_ID); + private static final String REGISTRATION_ID = "registrationId"; + private static final String NAME_ID = "name1"; + + @Mock + private KeyWithCert mockKeyWithCert; + + @Mock + private IdentityZone identityZone; + + @Mock + private IdentityZoneConfiguration identityZoneConfig; + + @Mock + private SamlConfig samlConfig; + + private DefaultRelyingPartyRegistrationRepository repository; + + @BeforeEach + void setUp() { + repository = spy(new DefaultRelyingPartyRegistrationRepository(ENTITY_ID, mockKeyWithCert)); + } + + @Test + void findByRegistrationId() { + when(mockKeyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class)); + when(mockKeyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class)); + + RelyingPartyRegistration registration = repository.findByRegistrationId(REGISTRATION_ID); + + assertThat(registration) + // from definition + .returns(REGISTRATION_ID, RelyingPartyRegistration::getRegistrationId) + .returns(ENTITY_ID, RelyingPartyRegistration::getEntityId) + .returns(null, RelyingPartyRegistration::getNameIdFormat) + // from functions + .returns("{baseUrl}/saml/SSO/alias/entityId", RelyingPartyRegistration::getAssertionConsumerServiceLocation) + .returns("{baseUrl}/saml/SingleLogout/alias/entityId", RelyingPartyRegistration::getSingleLogoutServiceResponseLocation) + // from xml + .extracting(RelyingPartyRegistration::getAssertingPartyDetails) + .returns("exampleEntityId", RelyingPartyRegistration.AssertingPartyDetails::getEntityId); + } + + @Test + void findByRegistrationIdForZone() { + when(mockKeyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class)); + when(mockKeyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class)); + when(repository.retrieveZone()).thenReturn(identityZone); + when(identityZone.isUaa()).thenReturn(false); + when(identityZone.getConfig()).thenReturn(identityZoneConfig); + when(identityZoneConfig.getSamlConfig()).thenReturn(samlConfig); + when(samlConfig.getEntityID()).thenReturn(ZONED_ENTITY_ID); + + RelyingPartyRegistration registration = repository.findByRegistrationId(REGISTRATION_ID); + + assertThat(registration) + // from definition + .returns(REGISTRATION_ID, RelyingPartyRegistration::getRegistrationId) + .returns(ZONED_ENTITY_ID, RelyingPartyRegistration::getEntityId) + .returns(null, RelyingPartyRegistration::getNameIdFormat) + // from functions + .returns("{baseUrl}/saml/SSO/alias/subdomain.entityId", RelyingPartyRegistration::getAssertionConsumerServiceLocation) + .returns("{baseUrl}/saml/SingleLogout/alias/subdomain.entityId", RelyingPartyRegistration::getSingleLogoutServiceResponseLocation) + // from xml + .extracting(RelyingPartyRegistration::getAssertingPartyDetails) + .returns("exampleEntityId", RelyingPartyRegistration.AssertingPartyDetails::getEntityId); + } +} \ No newline at end of file