Skip to content

Commit

Permalink
Improve the API for MfaChecker
Browse files Browse the repository at this point in the history
- Don't require unused origin
- add isMfaEnabledForZoneId

[#164817168]
  • Loading branch information
joshuatcasey committed Mar 27, 2019
1 parent 92fc49b commit 58f15bd
Show file tree
Hide file tree
Showing 11 changed files with 84 additions and 64 deletions.
Expand Up @@ -51,7 +51,7 @@ public void onApplicationEvent(AbstractUaaAuthenticationEvent event) {
passwordAuthEvent.getUser(), passwordAuthEvent.getUser(),
(Authentication) passwordAuthEvent.getSource() (Authentication) passwordAuthEvent.getSource()
); );
if (!checker.isMfaEnabled(userEvent.getIdentityZone(), userEvent.getUser().getOrigin())) { if (!checker.isMfaEnabled(userEvent.getIdentityZone())) {
publisher.publishEvent(userEvent); publisher.publishEvent(userEvent);
} }
} else if (event instanceof MfaAuthenticationSuccessEvent) { } else if (event instanceof MfaAuthenticationSuccessEvent) {
Expand Down
Expand Up @@ -750,7 +750,7 @@ private void populatePrompts(
} }
map.put(prompt.getName(), details); map.put(prompt.getName(), details);
} }
if (mfaChecker.isMfaEnabled(IdentityZoneHolder.get(), OriginKeys.UAA)) { if (mfaChecker.isMfaEnabled(IdentityZoneHolder.get())) {
Prompt p = new Prompt( Prompt p = new Prompt(
MFA_CODE, MFA_CODE,
"password", "password",
Expand Down Expand Up @@ -830,7 +830,7 @@ private String goToPasswordPage(String email, Model model) {
@ResponseBody @ResponseBody
public AutologinResponse generateAutologinCode(@RequestBody AutologinRequest request, public AutologinResponse generateAutologinCode(@RequestBody AutologinRequest request,
@RequestHeader(value = "Authorization", required = false) String auth) throws Exception { @RequestHeader(value = "Authorization", required = false) String auth) throws Exception {
if (mfaChecker.isMfaEnabled(IdentityZoneHolder.get(), "uaa")) { if (mfaChecker.isMfaEnabled(IdentityZoneHolder.get())) {
throw new BadCredentialsException("MFA is required"); throw new BadCredentialsException("MFA is required");
} }


Expand Down Expand Up @@ -876,7 +876,7 @@ public AutologinResponse generateAutologinCode(@RequestBody AutologinRequest req


@RequestMapping(value = "/autologin", method = GET) @RequestMapping(value = "/autologin", method = GET)
public String performAutologin(HttpSession session) { public String performAutologin(HttpSession session) {
if (mfaChecker.isMfaEnabled(IdentityZoneHolder.get(), "uaa")) { if (mfaChecker.isMfaEnabled(IdentityZoneHolder.get())) {
throw new BadCredentialsException("MFA is required"); throw new BadCredentialsException("MFA is required");
} }
String redirectLocation = "home"; String redirectLocation = "home";
Expand Down
@@ -1,32 +1,24 @@
/*
* Cloud Foundry
* Copyright (c) [2009-2018] Pivotal Software, Inc. All Rights Reserved.
* <p/>
* This product is licensed to you under the Apache License, Version 2.0 (the "License").
* You may not use this product except in compliance with the License.
* <p/>
* This product includes a number of subcomponents with
* separate copyright notices and license terms. Your use of these
* subcomponents is subject to the terms and conditions of the
* subcomponent's license, as noted in the LICENSE file
*/
package org.cloudfoundry.identity.uaa.mfa; package org.cloudfoundry.identity.uaa.mfa;


import org.cloudfoundry.identity.uaa.provider.IdentityProviderProvisioning;
import org.cloudfoundry.identity.uaa.zone.IdentityZone; import org.cloudfoundry.identity.uaa.zone.IdentityZone;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneProvisioning;


public class MfaChecker { public class MfaChecker {


private final IdentityProviderProvisioning providerProvisioning; private final IdentityZoneProvisioning identityZoneProvisioning;


public MfaChecker(IdentityProviderProvisioning providerProvisioning) { public MfaChecker(IdentityZoneProvisioning identityZoneProvisioning) {
this.providerProvisioning = providerProvisioning; this.identityZoneProvisioning = identityZoneProvisioning;
} }


public boolean isMfaEnabled(IdentityZone zone, String originKey) { public boolean isMfaEnabled(IdentityZone zone) {
return zone.getConfig().getMfaConfig().isEnabled(); return zone.getConfig().getMfaConfig().isEnabled();
} }


public boolean isMfaEnabledForZoneId(String zoneId) {
return isMfaEnabled(identityZoneProvisioning.retrieve(zoneId));
}

public boolean isRequired(IdentityZone zone, String originKey) { public boolean isRequired(IdentityZone zone, String originKey) {
return zone.getConfig().getMfaConfig().getIdentityProviders().contains(originKey); return zone.getConfig().getMfaConfig().getIdentityProviders().contains(originKey);
} }
Expand Down
Expand Up @@ -68,7 +68,7 @@ protected boolean isMfaRequiredAndMissing() {
return false; return false;
} }
UaaAuthentication uaaAuth = (UaaAuthentication) a; UaaAuthentication uaaAuth = (UaaAuthentication) a;
if (!mfaRequired(uaaAuth.getPrincipal().getOrigin())) { if (!mfaRequired()) {
return false; return false;
} }


Expand All @@ -80,7 +80,7 @@ protected boolean isMfaRequiredAndMissing() {
} }
} }


protected boolean mfaRequired(String origin) { protected boolean mfaRequired() {
return checker.isMfaEnabled(IdentityZoneHolder.get(), origin); return checker.isMfaEnabled(IdentityZoneHolder.get());
} }
} }
Expand Up @@ -166,7 +166,7 @@ protected void sendRedirect(String redirectUrl, HttpServletRequest request, Http
} }


protected boolean mfaRequired(String origin) { protected boolean mfaRequired(String origin) {
return checker.isMfaEnabled(IdentityZoneHolder.get(), origin) && checker.isRequired(IdentityZoneHolder.get(), origin); return checker.isMfaEnabled(IdentityZoneHolder.get()) && checker.isRequired(IdentityZoneHolder.get(), origin);
} }


private boolean logoutInProgress(HttpServletRequest request) { private boolean logoutInProgress(HttpServletRequest request) {
Expand Down
1 change: 0 additions & 1 deletion server/src/main/resources/spring/login-ui.xml
Expand Up @@ -212,7 +212,6 @@
</bean> </bean>


<bean id="mfaChecker" class="org.cloudfoundry.identity.uaa.mfa.MfaChecker"> <bean id="mfaChecker" class="org.cloudfoundry.identity.uaa.mfa.MfaChecker">
<constructor-arg name="providerProvisioning" ref="identityProviderProvisioning"/>
</bean> </bean>


<bean id="mfaUiRequiredFilter" class="org.cloudfoundry.identity.uaa.mfa.MfaUiRequiredFilter"> <bean id="mfaUiRequiredFilter" class="org.cloudfoundry.identity.uaa.mfa.MfaUiRequiredFilter">
Expand Down
Expand Up @@ -23,6 +23,7 @@
import org.cloudfoundry.identity.uaa.scim.ScimUserProvisioning; import org.cloudfoundry.identity.uaa.scim.ScimUserProvisioning;
import org.cloudfoundry.identity.uaa.user.UaaUser; import org.cloudfoundry.identity.uaa.user.UaaUser;
import org.cloudfoundry.identity.uaa.user.UaaUserPrototype; import org.cloudfoundry.identity.uaa.user.UaaUserPrototype;
import org.cloudfoundry.identity.uaa.zone.IdentityZone;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder; import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
Expand Down Expand Up @@ -125,7 +126,7 @@ public void previousLoginIsSetOnTheAuthentication() {


@Test @Test
public void provider_authentication_success_triggers_user_authentication_success() throws Exception { public void provider_authentication_success_triggers_user_authentication_success() throws Exception {
when(checker.isMfaEnabled(any(), any())).thenReturn(false); when(checker.isMfaEnabled(any(IdentityZone.class))).thenReturn(false);
IdentityProviderAuthenticationSuccessEvent event = new IdentityProviderAuthenticationSuccessEvent( IdentityProviderAuthenticationSuccessEvent event = new IdentityProviderAuthenticationSuccessEvent(
user, user,
mockAuth, mockAuth,
Expand All @@ -137,7 +138,7 @@ public void provider_authentication_success_triggers_user_authentication_success


@Test @Test
public void provider_authentication_success_does_not_trigger_user_authentication_success() throws Exception { public void provider_authentication_success_does_not_trigger_user_authentication_success() throws Exception {
when(checker.isMfaEnabled(any(), any())).thenReturn(true); when(checker.isMfaEnabled(any(IdentityZone.class))).thenReturn(true);
IdentityProviderAuthenticationSuccessEvent event = new IdentityProviderAuthenticationSuccessEvent( IdentityProviderAuthenticationSuccessEvent event = new IdentityProviderAuthenticationSuccessEvent(
user, user,
mockAuth, mockAuth,
Expand All @@ -149,7 +150,7 @@ public void provider_authentication_success_does_not_trigger_user_authentication


@Test @Test
public void mfa_authentication_success_triggers_user_authentication_success() throws Exception { public void mfa_authentication_success_triggers_user_authentication_success() throws Exception {
when(checker.isMfaEnabled(any(), any())).thenReturn(true); when(checker.isMfaEnabled(any(IdentityZone.class))).thenReturn(true);
MfaAuthenticationSuccessEvent event = new MfaAuthenticationSuccessEvent( MfaAuthenticationSuccessEvent event = new MfaAuthenticationSuccessEvent(
user, user,
mockAuth, mockAuth,
Expand Down
Expand Up @@ -33,12 +33,7 @@
import org.cloudfoundry.identity.uaa.provider.saml.SamlIdentityProviderConfigurator; import org.cloudfoundry.identity.uaa.provider.saml.SamlIdentityProviderConfigurator;
import org.cloudfoundry.identity.uaa.util.JsonUtils; import org.cloudfoundry.identity.uaa.util.JsonUtils;
import org.cloudfoundry.identity.uaa.util.PredicateMatcher; import org.cloudfoundry.identity.uaa.util.PredicateMatcher;
import org.cloudfoundry.identity.uaa.zone.MultitenantClientServices; import org.cloudfoundry.identity.uaa.zone.*;
import org.cloudfoundry.identity.uaa.zone.IdentityZone;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneConfiguration;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
import org.cloudfoundry.identity.uaa.zone.Links;
import org.cloudfoundry.identity.uaa.zone.MultitenancyFixture;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
Expand Down Expand Up @@ -142,7 +137,7 @@ public void setUpPrincipal() {
oidcMetadataFetcher = mock(OidcMetadataFetcher.class); oidcMetadataFetcher = mock(OidcMetadataFetcher.class);
IdentityZoneHolder.get().setConfig(new IdentityZoneConfiguration()); IdentityZoneHolder.get().setConfig(new IdentityZoneConfiguration());
configurator = new XOAuthProviderConfigurator(identityProviderProvisioning, oidcMetadataFetcher); configurator = new XOAuthProviderConfigurator(identityProviderProvisioning, oidcMetadataFetcher);
mfaChecker = spy(new MfaChecker(mock(IdentityProviderProvisioning.class))); mfaChecker = spy(new MfaChecker(mock(IdentityZoneProvisioning.class)));
model = new ExtendedModelMap(); model = new ExtendedModelMap();
} }


Expand Down Expand Up @@ -925,7 +920,7 @@ public void testLoginWithInvalidMediaType() throws Exception {


@Test @Test
public void testGenerateAutologinCodeFailsWhenMfaRequired() throws Exception { public void testGenerateAutologinCodeFailsWhenMfaRequired() throws Exception {
doReturn(true).when(mfaChecker).isMfaEnabled(any(IdentityZone.class), anyString()); doReturn(true).when(mfaChecker).isMfaEnabled(any(IdentityZone.class));


LoginInfoEndpoint endpoint = getEndpoint(); LoginInfoEndpoint endpoint = getEndpoint();
try { try {
Expand All @@ -938,7 +933,7 @@ public void testGenerateAutologinCodeFailsWhenMfaRequired() throws Exception {


@Test @Test
public void testPerformAutologinFailsWhenMfaRequired() throws Exception { public void testPerformAutologinFailsWhenMfaRequired() throws Exception {
doReturn(true).when(mfaChecker).isMfaEnabled(any(IdentityZone.class), anyString()); doReturn(true).when(mfaChecker).isMfaEnabled(any(IdentityZone.class));
LoginInfoEndpoint endpoint = getEndpoint(); LoginInfoEndpoint endpoint = getEndpoint();
try { try {
endpoint.performAutologin(new MockHttpSession()); endpoint.performAutologin(new MockHttpSession());
Expand Down
@@ -1,56 +1,90 @@
package org.cloudfoundry.identity.uaa.mfa; package org.cloudfoundry.identity.uaa.mfa;


import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import org.cloudfoundry.identity.uaa.provider.IdentityProviderProvisioning;
import org.cloudfoundry.identity.uaa.zone.IdentityZone; import org.cloudfoundry.identity.uaa.zone.IdentityZone;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneProvisioning;
import org.cloudfoundry.identity.uaa.zone.MfaConfig; import org.cloudfoundry.identity.uaa.zone.MfaConfig;
import org.cloudfoundry.identity.uaa.zone.MultitenancyFixture; import org.cloudfoundry.identity.uaa.zone.MultitenancyFixture;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;
import org.springframework.security.oauth2.common.util.RandomValueStringGenerator;


import java.util.Arrays; import java.util.Arrays;
import java.util.stream.Stream;


import static org.cloudfoundry.identity.uaa.constants.OriginKeys.LDAP; import static org.cloudfoundry.identity.uaa.constants.OriginKeys.*;
import static org.cloudfoundry.identity.uaa.constants.OriginKeys.SAML;
import static org.cloudfoundry.identity.uaa.constants.OriginKeys.UAA;
import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.*;


class MfaCheckerTests { class MfaCheckerTests {


private IdentityZone identityZone; private IdentityZone identityZone;
private MfaChecker mfaChecker; private MfaChecker mfaChecker;
private IdentityProviderProvisioning identityProviderProvisioning; private IdentityZoneProvisioning mockIdentityZoneProvisioning;
private RandomValueStringGenerator randomValueStringGenerator;


@BeforeEach @BeforeEach
void setUp() { void setUp() {
identityProviderProvisioning = mock(IdentityProviderProvisioning.class); randomValueStringGenerator = new RandomValueStringGenerator();
identityZone = MultitenancyFixture.identityZone("id", "domain");
mfaChecker = new MfaChecker(identityProviderProvisioning); identityZone = MultitenancyFixture.identityZone(randomValueStringGenerator.generate(), randomValueStringGenerator.generate());

mockIdentityZoneProvisioning = mock(IdentityZoneProvisioning.class);
when(mockIdentityZoneProvisioning.retrieve(any())).thenReturn(identityZone);

mfaChecker = new MfaChecker(mockIdentityZoneProvisioning);
} }


@Test static class BooleanArgumentsProvider implements ArgumentsProvider {
void isMfaEnabled_WhenEnabled() {
identityZone.getConfig().getMfaConfig().setEnabled(true); @Override
assertTrue(mfaChecker.isMfaEnabled(identityZone, UAA)); public Stream<? extends Arguments> provideArguments(ExtensionContext context) throws Exception {
return Stream.of(
Arguments.of(true),
Arguments.of(false)
);
}
} }


@Test @ParameterizedTest
void isMfaEnabled_WhenDisabled() { @ArgumentsSource(BooleanArgumentsProvider.class)
identityZone.getConfig().getMfaConfig().setEnabled(false); void isMfaEnabled(final boolean isMfaEnabled) {
assertFalse(mfaChecker.isMfaEnabled(identityZone, UAA)); identityZone.getConfig().getMfaConfig().setEnabled(isMfaEnabled);
assertEquals(isMfaEnabled, mfaChecker.isMfaEnabled(identityZone));
}

@ParameterizedTest
@ArgumentsSource(BooleanArgumentsProvider.class)
void isMfaEnabledForZoneId(final boolean isMfaEnabled) {
final String zoneId = randomValueStringGenerator.generate();
identityZone.getConfig().getMfaConfig().setEnabled(isMfaEnabled);
assertEquals(isMfaEnabled, mfaChecker.isMfaEnabledForZoneId(zoneId));

verify(mockIdentityZoneProvisioning).retrieve(zoneId);
} }


@Test @Test
void mfaIsRequiredWhenCorrectOriginsAreConfigured() { void mfaIsRequiredWhenCorrectOriginsAreConfigured() {
final String randomIdp = randomValueStringGenerator.generate();
identityZone.getConfig().getMfaConfig().setIdentityProviders( identityZone.getConfig().getMfaConfig().setIdentityProviders(
Lists.newArrayList("uaa", "ldap")); Lists.newArrayList("uaa", "george", randomIdp));

assertThat(mfaChecker.isRequired(identityZone, "uaa"), is(true));
assertThat(mfaChecker.isRequired(identityZone, "george"), is(true));
assertThat(mfaChecker.isRequired(identityZone, randomIdp), is(true));


assertThat(mfaChecker.isRequired(identityZone, UAA), is(true));
assertThat(mfaChecker.isRequired(identityZone, "other"), is(false)); assertThat(mfaChecker.isRequired(identityZone, "other"), is(false));
assertThat(mfaChecker.isRequired(identityZone, null), is(false));
assertThat(mfaChecker.isRequired(identityZone, ""), is(false));
assertThat(mfaChecker.isRequired(identityZone, randomValueStringGenerator.generate()), is(false));
} }


@Test @Test
Expand Down
Expand Up @@ -69,7 +69,7 @@ public void setup() throws Exception {
mfaChecker, mfaChecker,
entryPoint entryPoint
); );
when(mfaChecker.isMfaEnabled(any(IdentityZone.class), anyString())).thenReturn(true); when(mfaChecker.isMfaEnabled(any(IdentityZone.class))).thenReturn(true);
request = new MockHttpServletRequest(); request = new MockHttpServletRequest();
response = new MockHttpServletResponse(); response = new MockHttpServletResponse();
SecurityContextHolder.getContext().setAuthentication(authentication); SecurityContextHolder.getContext().setAuthentication(authentication);
Expand Down Expand Up @@ -116,7 +116,7 @@ public void mfa_present() throws Exception {


@Test @Test
public void mfa_not_enabled() throws Exception { public void mfa_not_enabled() throws Exception {
when(mfaChecker.isMfaEnabled(any(IdentityZone.class), anyString())).thenReturn(false); when(mfaChecker.isMfaEnabled(any(IdentityZone.class))).thenReturn(false);
assertFalse(filter.isMfaRequiredAndMissing()); assertFalse(filter.isMfaRequiredAndMissing());
} }


Expand Down
Expand Up @@ -14,6 +14,7 @@
import org.cloudfoundry.identity.uaa.zone.IdentityZone; import org.cloudfoundry.identity.uaa.zone.IdentityZone;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder; import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;


import org.cloudfoundry.identity.uaa.zone.IdentityZoneProvisioning;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -63,21 +64,19 @@ class MfaUiRequiredFilterTests {
private HttpServletResponse response; private HttpServletResponse response;
private FilterChain chain; private FilterChain chain;
private MfaUiRequiredFilter filter; private MfaUiRequiredFilter filter;
private IdentityProviderProvisioning providerProvisioning;
private AntPathRequestMatcher logoutMatcher; private AntPathRequestMatcher logoutMatcher;
private IdentityZone mfaEnabledZone; private IdentityZone mfaEnabledZone;


@BeforeEach @BeforeEach
void setup() throws Exception { void setup() throws Exception {
providerProvisioning = mock(IdentityProviderProvisioning.class);
requestCache = mock(RequestCache.class); requestCache = mock(RequestCache.class);
logoutMatcher = new AntPathRequestMatcher("/logout.do"); logoutMatcher = new AntPathRequestMatcher("/logout.do");
filter = new MfaUiRequiredFilter("/login/mfa/**", filter = new MfaUiRequiredFilter("/login/mfa/**",
"/login/mfa/register", "/login/mfa/register",
requestCache, requestCache,
"/login/mfa/completed", "/login/mfa/completed",
logoutMatcher, logoutMatcher,
new MfaChecker(providerProvisioning)); new MfaChecker(mock(IdentityZoneProvisioning.class)));
spyFilter = spy(filter); spyFilter = spy(filter);
request = new MockHttpServletRequest(); request = new MockHttpServletRequest();
usernameAuthentication = new UsernamePasswordAuthenticationToken("fake-principal","fake-credentials"); usernameAuthentication = new UsernamePasswordAuthenticationToken("fake-principal","fake-credentials");
Expand Down

0 comments on commit 58f15bd

Please sign in to comment.