Skip to content

Commit

Permalink
Filter identity providers by client in the login page
Browse files Browse the repository at this point in the history
[84601994] https://www.pivotaltracker.com/story/show/84601994

Signed-off-by: Bokuk Seo <bkseo74@gmail.com>
Signed-off-by: Madhura Bhave <mbhave@pivotal.io>
  • Loading branch information
Chris Dutra committed Mar 6, 2015
1 parent 1b19f24 commit 0e45ff5
Show file tree
Hide file tree
Showing 7 changed files with 274 additions and 33 deletions.
Expand Up @@ -13,6 +13,7 @@
package org.cloudfoundry.identity.uaa.authentication.login; package org.cloudfoundry.identity.uaa.authentication.login;


import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;


import java.io.IOException; import java.io.IOException;
import java.io.UnsupportedEncodingException; import java.io.UnsupportedEncodingException;
Expand All @@ -24,9 +25,7 @@
import java.security.Principal; import java.security.Principal;
import java.sql.Timestamp; import java.sql.Timestamp;
import java.text.SimpleDateFormat; import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Date; import java.util.Date;
import java.util.HashMap; import java.util.HashMap;
Expand All @@ -36,8 +35,6 @@
import java.util.Map; import java.util.Map;
import java.util.Properties; import java.util.Properties;


import static org.springframework.http.MediaType.APPLICATION_JSON_VALUE;

import org.cloudfoundry.identity.uaa.authentication.AuthzAuthenticationRequest; import org.cloudfoundry.identity.uaa.authentication.AuthzAuthenticationRequest;
import org.cloudfoundry.identity.uaa.authentication.Origin; import org.cloudfoundry.identity.uaa.authentication.Origin;
import org.cloudfoundry.identity.uaa.authentication.UaaAuthentication; import org.cloudfoundry.identity.uaa.authentication.UaaAuthentication;
Expand All @@ -48,28 +45,24 @@
import org.cloudfoundry.identity.uaa.login.AutologinRequest; import org.cloudfoundry.identity.uaa.login.AutologinRequest;
import org.cloudfoundry.identity.uaa.login.AutologinResponse; import org.cloudfoundry.identity.uaa.login.AutologinResponse;
import org.cloudfoundry.identity.uaa.login.PasscodeInformation; import org.cloudfoundry.identity.uaa.login.PasscodeInformation;
import org.cloudfoundry.identity.uaa.login.SamlUserDetails;
import org.cloudfoundry.identity.uaa.login.saml.IdentityProviderConfigurator; import org.cloudfoundry.identity.uaa.login.saml.IdentityProviderConfigurator;
import org.cloudfoundry.identity.uaa.login.saml.IdentityProviderDefinition; import org.cloudfoundry.identity.uaa.login.saml.IdentityProviderDefinition;
import org.cloudfoundry.identity.uaa.login.saml.LoginSamlAuthenticationToken; import org.cloudfoundry.identity.uaa.login.saml.LoginSamlAuthenticationToken;
import org.cloudfoundry.identity.uaa.user.UaaAuthority; import org.cloudfoundry.identity.uaa.user.UaaAuthority;
import org.cloudfoundry.identity.uaa.util.UaaStringUtils; import org.cloudfoundry.identity.uaa.util.UaaStringUtils;
import org.cloudfoundry.identity.uaa.util.UaaUrlUtils; import org.cloudfoundry.identity.uaa.util.UaaUrlUtils;
import org.codehaus.jackson.map.JsonMappingException;
import org.codehaus.jackson.map.ObjectMapper; import org.codehaus.jackson.map.ObjectMapper;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.env.Environment; import org.springframework.core.env.Environment;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder; import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
import org.springframework.core.io.support.PropertiesLoaderUtils; import org.springframework.core.io.support.PropertiesLoaderUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.crypto.codec.Base64; import org.springframework.security.crypto.codec.Base64;
import org.springframework.security.providers.ExpiringUsernameAuthenticationToken; import org.springframework.security.oauth2.provider.ClientDetails;
import org.springframework.security.oauth2.provider.ClientDetailsService;
import org.springframework.security.web.savedrequest.SavedRequest;
import org.springframework.stereotype.Controller; import org.springframework.stereotype.Controller;
import org.springframework.ui.Model; import org.springframework.ui.Model;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
Expand Down Expand Up @@ -105,6 +98,7 @@ public class LoginInfoEndpoint {
private AuthenticationManager authenticationManager; private AuthenticationManager authenticationManager;


private ExpiringCodeStore expiringCodeStore; private ExpiringCodeStore expiringCodeStore;
private ClientDetailsService clientDetailsService;


public void setExpiringCodeStore(ExpiringCodeStore expiringCodeStore) { public void setExpiringCodeStore(ExpiringCodeStore expiringCodeStore) {
this.expiringCodeStore = expiringCodeStore; this.expiringCodeStore = expiringCodeStore;
Expand Down Expand Up @@ -154,16 +148,6 @@ public LoginInfoEndpoint() {
} }
} }


protected List<IdentityProviderDefinition> filterIdpsForZone() {
List<IdentityProviderDefinition> result = new LinkedList<>();
for (IdentityProviderDefinition def : idpDefinitions.getIdentityProviderDefinitions()) {
if (IdentityZoneHolder.get().getId().equals(def.getZoneId())) {
result.add(def);
}
}
return result;
}

private List<Prompt> prompts = Arrays.asList(new Prompt("username", "text", "Email"), new Prompt("password", private List<Prompt> prompts = Arrays.asList(new Prompt("username", "text", "Email"), new Prompt("password",
"password", "Password")); "password", "Password"));


Expand Down Expand Up @@ -191,19 +175,24 @@ public String infoForHtml(Model model, Principal principal) {
} }


@RequestMapping(value = {"/login" }, headers = "Accept=text/html, */*") @RequestMapping(value = {"/login" }, headers = "Accept=text/html, */*")
public String loginForHtml(Model model, Principal principal) { public String loginForHtml(Model model, Principal principal, HttpServletRequest request) {
return login(model, principal, Arrays.asList("passcode"), false); return login(model, principal, Arrays.asList("passcode"), false, request);
}

private String login(Model model, Principal principal, List<String> excludedPrompts, boolean nonHtml) {
return login(model, principal, excludedPrompts, nonHtml, null);
} }


public String login(Model model, Principal principal, List<String> excludedPrompts, boolean nonHtml) { private String login(Model model, Principal principal, List<String> excludedPrompts, boolean nonHtml, HttpServletRequest request) {
populatePrompts(model, excludedPrompts, nonHtml); populatePrompts(model, excludedPrompts, nonHtml);
setCommitInfo(model); setCommitInfo(model);
model.addAttribute("zone_name", IdentityZoneHolder.get().getName()); model.addAttribute("zone_name", IdentityZoneHolder.get().getName());
model.addAttribute("links", getLinksInfo()); model.addAttribute("links", getLinksInfo());


// Entity ID to start the discovery // Entity ID to start the discovery
model.addAttribute("entityID", UaaUrlUtils.getSubdomain() + entityID); model.addAttribute("entityID", UaaUrlUtils.getSubdomain() + entityID);
List<IdentityProviderDefinition> idps = filterIdpsForZone();
List<IdentityProviderDefinition> idps = getIdentityProviderDefinitions(request != null ? request.getSession(false) : null);
model.addAttribute("idpDefinitions", idps); model.addAttribute("idpDefinitions", idps);
for (IdentityProviderDefinition idp : idps) { for (IdentityProviderDefinition idp : idps) {
if(idp.isShowSamlLink()) { if(idp.isShowSamlLink()) {
Expand Down Expand Up @@ -233,6 +222,21 @@ public String login(Model model, Principal principal, List<String> excludedPromp
return "home"; return "home";
} }


private List<IdentityProviderDefinition> getIdentityProviderDefinitions(HttpSession session) {
List<IdentityProviderDefinition> idps = idpDefinitions.getIdentityProviderDefinitionsForZone(IdentityZoneHolder.get());
SavedRequest savedRequest;
if (session != null && (savedRequest = (SavedRequest) session.getAttribute("SPRING_SECURITY_SAVED_REQUEST")) != null) {
String redirectUrl = savedRequest.getRedirectUrl();
String[] client_ids = savedRequest.getParameterValues("client_id");
if (redirectUrl != null && redirectUrl.contains("/oauth/authorize") && client_ids != null && client_ids.length != 0) {
ClientDetails clientDetails = clientDetailsService.loadClientByClientId(client_ids[0]);
List<String> allowedIdps = (List<String>) clientDetails.getAdditionalInformation().get("allowedproviders");
idps = idpDefinitions.getIdentityProviderDefinitionsForClient(allowedIdps, IdentityZoneHolder.get(), false);
}
}
return idps;
}

private void setCommitInfo(Model model) { private void setCommitInfo(Model model) {
model.addAttribute("commit_id", gitProperties.getProperty("git.commit.id.abbrev", "UNKNOWN")); model.addAttribute("commit_id", gitProperties.getProperty("git.commit.id.abbrev", "UNKNOWN"));
model.addAttribute( model.addAttribute(
Expand Down Expand Up @@ -424,6 +428,10 @@ protected String extractPath(HttpServletRequest request) {
return path; return path;
} }


public void setClientDetailsService(ClientDetailsService clientDetailsService) {
this.clientDetailsService = clientDetailsService;
}

@ResponseStatus(value = HttpStatus.FORBIDDEN, reason = "Unknown authentication token type, unable to derive user ID.") @ResponseStatus(value = HttpStatus.FORBIDDEN, reason = "Unknown authentication token type, unable to derive user ID.")
public static final class UnknownPrincipalException extends RuntimeException {} public static final class UnknownPrincipalException extends RuntimeException {}


Expand Down
Expand Up @@ -23,7 +23,6 @@
import org.opensaml.saml2.metadata.provider.MetadataProviderException; import org.opensaml.saml2.metadata.provider.MetadataProviderException;
import org.opensaml.xml.parse.BasicParserPool; import org.opensaml.xml.parse.BasicParserPool;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.config.InstantiationAwareBeanPostProcessor;
import org.springframework.security.saml.metadata.ExtendedMetadata; import org.springframework.security.saml.metadata.ExtendedMetadata;
import org.springframework.security.saml.metadata.ExtendedMetadataDelegate; import org.springframework.security.saml.metadata.ExtendedMetadataDelegate;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
Expand Down Expand Up @@ -56,6 +55,31 @@ public class IdentityProviderConfigurator implements InitializingBean {
public List<IdentityProviderDefinition> getIdentityProviderDefinitions() { public List<IdentityProviderDefinition> getIdentityProviderDefinitions() {
return Collections.unmodifiableList(identityProviders); return Collections.unmodifiableList(identityProviders);
} }

public List<IdentityProviderDefinition> getIdentityProviderDefinitionsForZone(IdentityZone zone) {
List<IdentityProviderDefinition> result = new LinkedList<>();
for (IdentityProviderDefinition def : getIdentityProviderDefinitions()) {
if (zone.getId().equals(def.getZoneId())) {
result.add(def);
}
}
return result;
}

public List<IdentityProviderDefinition> getIdentityProviderDefinitionsForClient(List<String> allowedIdps, IdentityZone zone, boolean allowEmptyDefaultIdpList) {
List<IdentityProviderDefinition> idpsInTheZone = getIdentityProviderDefinitionsForZone(zone);
if (allowedIdps != null && !allowedIdps.isEmpty()) {
List<IdentityProviderDefinition> result = new LinkedList<>();
for (IdentityProviderDefinition def : idpsInTheZone) {
if (allowedIdps.contains(def.getIdpEntityAlias())) {
result.add(def);
}
}
return result;
}
return allowEmptyDefaultIdpList ? Collections.<IdentityProviderDefinition>emptyList() : idpsInTheZone;
}

protected List<IdentityProviderDefinition> parseIdentityProviderDefinitions() { protected List<IdentityProviderDefinition> parseIdentityProviderDefinitions() {
List<IdentityProviderDefinition> providerDefinitions = new LinkedList<>(identityProviders); List<IdentityProviderDefinition> providerDefinitions = new LinkedList<>(identityProviders);
if (getLegacyIdpMetaData()!=null) { if (getLegacyIdpMetaData()!=null) {
Expand Down
Expand Up @@ -9,26 +9,34 @@
import org.cloudfoundry.identity.uaa.login.saml.IdentityProviderConfigurator; import org.cloudfoundry.identity.uaa.login.saml.IdentityProviderConfigurator;
import org.cloudfoundry.identity.uaa.login.saml.IdentityProviderDefinition; import org.cloudfoundry.identity.uaa.login.saml.IdentityProviderDefinition;
import org.cloudfoundry.identity.uaa.login.saml.LoginSamlAuthenticationToken; import org.cloudfoundry.identity.uaa.login.saml.LoginSamlAuthenticationToken;
import org.cloudfoundry.identity.uaa.oauth.RemoteUserAuthentication;
import org.cloudfoundry.identity.uaa.zone.IdentityProvider;
import org.cloudfoundry.identity.uaa.zone.IdentityZone; import org.cloudfoundry.identity.uaa.zone.IdentityZone;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder; import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.mock.env.MockEnvironment; import org.springframework.mock.env.MockEnvironment;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpSession;
import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.provider.ClientDetailsService;
import org.springframework.security.oauth2.provider.client.BaseClientDetails;
import org.springframework.security.providers.ExpiringUsernameAuthenticationToken; import org.springframework.security.providers.ExpiringUsernameAuthenticationToken;
import org.springframework.security.web.savedrequest.SavedRequest;
import org.springframework.ui.ExtendedModelMap; import org.springframework.ui.ExtendedModelMap;
import org.springframework.ui.Model; import org.springframework.ui.Model;


import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map; import java.util.Map;


import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;


public class LoginInfoEndpointTest { public class LoginInfoEndpointTest {


Expand All @@ -50,7 +58,7 @@ public void testLoginReturnsSystemZone() throws Exception {
LoginInfoEndpoint endpoint = getEndpoint(); LoginInfoEndpoint endpoint = getEndpoint();
Model model = new ExtendedModelMap(); Model model = new ExtendedModelMap();
assertFalse(model.containsAttribute("zone_name")); assertFalse(model.containsAttribute("zone_name"));
endpoint.loginForHtml(model, null); endpoint.loginForHtml(model, null, new MockHttpServletRequest());
assertEquals(Origin.UAA, model.asMap().get("zone_name")); assertEquals(Origin.UAA, model.asMap().get("zone_name"));
} }


Expand All @@ -63,7 +71,7 @@ public void testLoginReturnsOtherZone() throws Exception {
LoginInfoEndpoint endpoint = getEndpoint(); LoginInfoEndpoint endpoint = getEndpoint();
Model model = new ExtendedModelMap(); Model model = new ExtendedModelMap();
assertFalse(model.containsAttribute("zone_name")); assertFalse(model.containsAttribute("zone_name"));
endpoint.loginForHtml(model, null); endpoint.loginForHtml(model, null, new MockHttpServletRequest());
assertEquals("some_other_zone", model.asMap().get("zone_name")); assertEquals("some_other_zone", model.asMap().get("zone_name"));
} }


Expand Down Expand Up @@ -92,6 +100,101 @@ public void testGeneratePasscodeForUnknownUaaPrincipal() throws Exception {
assertEquals("passcode", endpoint.generatePasscode(model, token)); assertEquals("passcode", endpoint.generatePasscode(model, token));
} }


@Test
public void testFilterIdpsForZone() throws Exception {
// mock session and saved request
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpSession session = new MockHttpSession();
SavedRequest savedRequest = mock(SavedRequest.class);
when(savedRequest.getParameterValues("client_id")).thenReturn(new String[]{"client-id"});
when(savedRequest.getRedirectUrl()).thenReturn("http://localhost:8080/uaa");
session.setAttribute("SPRING_SECURITY_SAVED_REQUEST", savedRequest);
request.setSession(session);
// mock IdentityProviderConfigurator
List<IdentityProviderDefinition> idps = getIdps();
IdentityProviderConfigurator mockIDPConfigurator = mock(IdentityProviderConfigurator.class);
when(mockIDPConfigurator.getIdentityProviderDefinitionsForZone(IdentityZoneHolder.get())).thenReturn(idps);

LoginInfoEndpoint endpoint = getEndpoint();
endpoint.setIdpDefinitions(mockIDPConfigurator);
Model model = new ExtendedModelMap();
endpoint.loginForHtml(model, null, request);

List<IdentityProviderDefinition> idpDefinitions = (List<IdentityProviderDefinition>) model.asMap().get("idpDefinitions");
assertEquals(2, idpDefinitions.size());

Iterator<IdentityProviderDefinition> iterator = idpDefinitions.iterator();
IdentityProviderDefinition clientIdp = iterator.next();
assertEquals("awesome-idp", clientIdp.getIdpEntityAlias());
assertEquals(true, clientIdp.isShowSamlLink());

clientIdp = iterator.next();
assertEquals("my-client-awesome-idp", clientIdp.getIdpEntityAlias());
assertEquals(true, clientIdp.isShowSamlLink());
}

@Test
public void testFilterIdpsWithNoSavedRequest() throws Exception {
// mock IdentityProviderConfigurator
List<IdentityProviderDefinition> idps = getIdps();
IdentityProviderConfigurator mockIDPConfigurator = mock(IdentityProviderConfigurator.class);
when(mockIDPConfigurator.getIdentityProviderDefinitionsForZone(IdentityZoneHolder.get())).thenReturn(idps);

LoginInfoEndpoint endpoint = getEndpoint();
endpoint.setIdpDefinitions(mockIDPConfigurator);
Model model = new ExtendedModelMap();
endpoint.loginForHtml(model, null, new MockHttpServletRequest());

List<IdentityProviderDefinition> idpDefinitions = (List<IdentityProviderDefinition>) model.asMap().get("idpDefinitions");
assertEquals(2, idpDefinitions.size());

Iterator<IdentityProviderDefinition> iterator = idpDefinitions.iterator();
IdentityProviderDefinition clientIdp = iterator.next();
assertEquals("awesome-idp", clientIdp.getIdpEntityAlias());
assertEquals(true, clientIdp.isShowSamlLink());

clientIdp = iterator.next();
assertEquals("my-client-awesome-idp", clientIdp.getIdpEntityAlias());
assertEquals(true, clientIdp.isShowSamlLink());
}

@Test
public void testFilterIDPsForAuthcodeClient() throws Exception {
// mock session and saved request
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpSession session = new MockHttpSession();
SavedRequest savedRequest = mock(SavedRequest.class);
when(savedRequest.getParameterValues("client_id")).thenReturn(new String[]{"client-id"});
when(savedRequest.getRedirectUrl())
.thenReturn("http://localhost:8080/uaa/oauth/authorize?client_id=identity&redirect_uri=http%3A%2F%2Flocalhost%3A8888%2Flogin&response_type=code&state=8tp0tR");
session.setAttribute("SPRING_SECURITY_SAVED_REQUEST", savedRequest);
request.setSession(session);
// mock Client service
BaseClientDetails clientDetails = new BaseClientDetails();
clientDetails.setClientId("client-id");
clientDetails.addAdditionalInformation("allowedproviders", Arrays.asList("my-client-awesome-idp"));
ClientDetailsService clientDetailsService = mock(ClientDetailsService.class);
when(clientDetailsService.loadClientByClientId("client-id")).thenReturn(clientDetails);

// mock IdentityProviderConfigurator
List<IdentityProviderDefinition> clientIDPs = new LinkedList<>();
clientIDPs.add(getIdentityProviderDefinition("my-client-awesome-idp"));
IdentityProviderConfigurator mockIDPConfigurator = mock(IdentityProviderConfigurator.class);
when(mockIDPConfigurator.getIdentityProviderDefinitionsForClient(Arrays.asList("my-client-awesome-idp"), IdentityZoneHolder.get(), false)).thenReturn(clientIDPs);

LoginInfoEndpoint endpoint = getEndpoint();
endpoint.setClientDetailsService(clientDetailsService);
endpoint.setIdpDefinitions(mockIDPConfigurator);
Model model = new ExtendedModelMap();
endpoint.loginForHtml(model, null, request);

List<IdentityProviderDefinition> idpDefinitions = (List<IdentityProviderDefinition>) model.asMap().get("idpDefinitions");
assertEquals(1, idpDefinitions.size());

IdentityProviderDefinition clientIdp = idpDefinitions.iterator().next();
assertEquals("my-client-awesome-idp", clientIdp.getIdpEntityAlias());
assertEquals(true, clientIdp.isShowSamlLink());
}


private LoginInfoEndpoint getEndpoint() { private LoginInfoEndpoint getEndpoint() {
LoginInfoEndpoint endpoint = new LoginInfoEndpoint(); LoginInfoEndpoint endpoint = new LoginInfoEndpoint();
Expand All @@ -101,4 +204,21 @@ private LoginInfoEndpoint getEndpoint() {
endpoint.setEnvironment(new MockEnvironment()); endpoint.setEnvironment(new MockEnvironment());
return endpoint; return endpoint;
} }
}
private List<IdentityProviderDefinition> getIdps() {
List<IdentityProviderDefinition> idps = new LinkedList<>();

idps.add(getIdentityProviderDefinition("awesome-idp"));
idps.add(getIdentityProviderDefinition("my-client-awesome-idp"));

return idps;
}

private IdentityProviderDefinition getIdentityProviderDefinition(String idpEntityAlias) {
IdentityProviderDefinition idp1 = new IdentityProviderDefinition();
idp1.setIdpEntityAlias(idpEntityAlias);
idp1.setShowSamlLink(true);
idp1.setZoneId("uaa");
return idp1;
}
}

0 comments on commit 0e45ff5

Please sign in to comment.