Skip to content

Commit

Permalink
Add static attributes to the SAML assertion
Browse files Browse the repository at this point in the history
  • Loading branch information
fhanik committed Oct 11, 2017
1 parent e34676d commit d770a48
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 7 deletions.
Expand Up @@ -65,8 +65,12 @@

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static java.util.Optional.ofNullable;


public class IdpWebSsoProfileImpl extends WebSSOProfileImpl implements IdpWebSsoProfile {
Expand Down Expand Up @@ -311,7 +315,28 @@ protected void buildAttributeStatement(Assertion assertion, Authentication authe
Attribute zoneAttribute = buildStringAttribute("zoneId", Collections.singletonList(principal.getZoneId()));
attributeStatement.getAttributes().add(zoneAttribute);

Map<String, Object> attributeMappings = samlServiceProviderProvisioning.retrieveByEntityId(providerEntityId, IdentityZoneHolder.get().getId()).getConfig().getAttributeMappings();
SamlServiceProviderDefinition config = samlServiceProviderProvisioning.retrieveByEntityId(providerEntityId, IdentityZoneHolder.get().getId()).getConfig();

//static attributes
for (Map.Entry<String,Object> staticAttribute : (ofNullable(config.getStaticCustomAttributes()).orElse(Collections.emptyMap())).entrySet()) {
String name = staticAttribute.getKey();
Object value = staticAttribute.getValue();
if (value==null) {
continue;
}

List values = new LinkedList<>();
if (value instanceof List) {
values = (List) value;
} else {
values.add(value);
}

List<String> stringValues = (List) values.stream().map(s -> s==null ? "null" : s.toString()).collect(Collectors.toList());
attributeStatement.getAttributes().add(buildStringAttribute(name, stringValues));
}

Map<String, Object> attributeMappings = config.getAttributeMappings();

if (attributeMappings.size() > 0) {
ScimUser user = scimUserProvisioning.retrieve(principal.getId(), IdentityZoneHolder.get().getId());
Expand Down
Expand Up @@ -18,22 +18,26 @@
import org.opensaml.saml2.metadata.provider.MetadataProviderException;
import org.opensaml.ws.message.encoder.MessageEncodingException;
import org.opensaml.xml.ConfigurationException;
import org.opensaml.xml.XMLObject;
import org.opensaml.xml.io.MarshallingException;
import org.opensaml.xml.schema.XSString;
import org.opensaml.xml.security.SecurityException;
import org.opensaml.xml.signature.SignatureException;
import org.springframework.security.core.Authentication;
import org.springframework.security.saml.context.SAMLMessageContext;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -175,6 +179,10 @@ public void verifyAttributeMappings() throws Exception {
String phone = "123";
user.setPhoneNumbers(Collections.singletonList(new ScimUser.PhoneNumber(phone)));
when(scimUserProvisioning.extractPhoneNumber(any(ScimUser.class))).thenReturn(phone);
Map<String, Object> staticAttributes = new HashMap<>();
staticAttributes.put("organization-id","12345");
staticAttributes.put("organization-dba", Arrays.asList("The Org", "Acme Inc"));
samlServiceProvider.getConfig().setStaticCustomAttributes(staticAttributes);

Map<String, Object> attributeMappings = new HashMap<>();
attributeMappings.put("given_name", "first_name");
Expand All @@ -198,6 +206,8 @@ public void verifyAttributeMappings() throws Exception {
assertAttributeValue(attributes, "first_name", user.getGivenName());
assertAttributeValue(attributes, "last_name", user.getFamilyName());
assertAttributeValue(attributes, "cell_phone", user.getPhoneNumbers().get(0).getValue());
assertAttributeValue(attributes, "organization-dba", "The Org", "Acme Inc");
assertAttributeValue(attributes, "organization-id", "12345");
}

@Test
Expand Down Expand Up @@ -246,14 +256,15 @@ private void assertAttributeDoesNotExist(List<Attribute> attributeList, String n
}

private void assertAttributeValue(List<Attribute> attributeList, String name, String expectedValue) {
assertAttributeValue(attributeList, name, new String[] {expectedValue});
}

private void assertAttributeValue(List<Attribute> attributeList, String name, String... expectedValue) {
for (Attribute attribute : attributeList) {
if (attribute.getName().equals(name)) {
if (1 != attribute.getAttributeValues().size()) {
Assert.fail(String.format("More than one attribute value with name of '%s'.", name));
}
XSString xsString = (XSString) attribute.getAttributeValues().get(0);
Assert.assertEquals(String.format("Attribute mismatch for '%s'.", name), expectedValue,
xsString.getValue());
List<XMLObject> xsString = attribute.getAttributeValues();
List<String> attributeValues = xsString.stream().map(xs -> ((XSString)xs).getValue()).collect(Collectors.toList());
assertThat(String.format("Attribute mismatch for '%s'.", name), attributeValues, containsInAnyOrder(expectedValue));
return;
}
}
Expand Down

0 comments on commit d770a48

Please sign in to comment.