Skip to content

Commit

Permalink
KEYCLOAK-4035 Composite roles need to be expanded in SAML attribute m…
Browse files Browse the repository at this point in the history
…apper
  • Loading branch information
hmlnarik committed Dec 5, 2016
1 parent 32da5fe commit 3c41140
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 47 deletions.
33 changes: 33 additions & 0 deletions server-spi/src/main/java/org/keycloak/models/utils/RoleUtils.java
Expand Up @@ -20,7 +20,11 @@
import org.keycloak.models.GroupModel; import org.keycloak.models.GroupModel;
import org.keycloak.models.RoleModel; import org.keycloak.models.RoleModel;


import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.Set; import java.util.Set;
import java.util.stream.Stream;
import java.util.stream.StreamSupport; import java.util.stream.StreamSupport;


/** /**
Expand Down Expand Up @@ -98,4 +102,33 @@ public static boolean hasRoleFromGroup(Iterable<GroupModel> groups, RoleModel ta
.anyMatch(group -> hasRoleFromGroup(group, targetRole, checkParentGroup)); .anyMatch(group -> hasRoleFromGroup(group, targetRole, checkParentGroup));
} }


/**
* Recursively expands composite roles into their composite.
* @param role
* @return Stream of containing all of the composite roles and their components.
*/
public static Stream<RoleModel> expandCompositeRolesStream(RoleModel role) {
Stream.Builder<RoleModel> sb = Stream.builder();
Set<RoleModel> roles = new HashSet<>();

Deque<RoleModel> stack = new ArrayDeque<>();
stack.add(role);

while (! stack.isEmpty()) {
RoleModel current = stack.pop();
sb.add(current);

if (current.isComposite()) {
current.getComposites().stream()
.filter(r -> ! roles.contains(r))
.forEach(r -> {
roles.add(r);
stack.add(r);
});
}
}

return sb.build();
}

} }
Expand Up @@ -22,10 +22,9 @@
import org.keycloak.models.RoleModel; import org.keycloak.models.RoleModel;
import org.keycloak.models.UserModel; import org.keycloak.models.UserModel;
import org.keycloak.models.UserSessionModel; import org.keycloak.models.UserSessionModel;
import org.keycloak.models.utils.RoleUtils;
import org.keycloak.representations.IDToken; import org.keycloak.representations.IDToken;


import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Set; import java.util.Set;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
Expand Down Expand Up @@ -54,7 +53,7 @@ public static Stream<RoleModel> getAllUserRolesStream(UserModel user) {
user.getGroups().stream() user.getGroups().stream()
.flatMap(g -> groupAndItsParentsStream(g)) .flatMap(g -> groupAndItsParentsStream(g))
.flatMap(g -> g.getRoleMappings().stream())) .flatMap(g -> g.getRoleMappings().stream()))
.flatMap(role -> expandCompositeRolesStream(role)); .flatMap(RoleUtils::expandCompositeRolesStream);
} }


/** /**
Expand All @@ -71,29 +70,6 @@ private static Stream<GroupModel> groupAndItsParentsStream(GroupModel group) {
return sb.build(); return sb.build();
} }


/**
* Recursively expands composite roles into their composite.
* @param role
* @return Stream of containing all of the composite roles and their components.
*/
private static Stream<RoleModel> expandCompositeRolesStream(RoleModel role) {
Stream.Builder<RoleModel> sb = Stream.builder();

Deque<RoleModel> stack = new ArrayDeque<>();
stack.add(role);

while (! stack.isEmpty()) {
RoleModel current = stack.pop();
sb.add(current);

if (current.isComposite()) {
stack.addAll(current.getComposites());
}
}

return sb.build();
}

/** /**
* Retrieves all roles of the current user based on direct roles set to the user, its groups and their parent groups. * Retrieves all roles of the current user based on direct roles set to the user, its groups and their parent groups.
* Then it recursively expands all composite roles, and restricts according to the given predicate {@code restriction}. * Then it recursively expands all composite roles, and restricts according to the given predicate {@code restriction}.
Expand Down
Expand Up @@ -35,12 +35,13 @@
public class HardcodedRole extends AbstractSAMLProtocolMapper { public class HardcodedRole extends AbstractSAMLProtocolMapper {
public static final String PROVIDER_ID = "saml-hardcode-role-mapper"; public static final String PROVIDER_ID = "saml-hardcode-role-mapper";
public static final String ATTRIBUTE_VALUE = "attribute.value"; public static final String ATTRIBUTE_VALUE = "attribute.value";
private static final List<ProviderConfigProperty> configProperties = new ArrayList<ProviderConfigProperty>(); private static final List<ProviderConfigProperty> configProperties = new ArrayList<>();
public static final String ROLE_ATTRIBUTE = "role";


static { static {
ProviderConfigProperty property; ProviderConfigProperty property;
property = new ProviderConfigProperty(); property = new ProviderConfigProperty();
property.setName("role"); property.setName(ROLE_ATTRIBUTE);
property.setLabel("Role"); property.setLabel("Role");
property.setHelpText("Arbitrary role name you want to hardcode. This role does not have to exist in current realm and can be just any string you need"); property.setHelpText("Arbitrary role name you want to hardcode. This role does not have to exist in current realm and can be just any string you need");
property.setType(ProviderConfigProperty.ROLE_TYPE); property.setType(ProviderConfigProperty.ROLE_TYPE);
Expand Down Expand Up @@ -79,8 +80,8 @@ public static ProtocolMapperModel create(String name,
mapper.setName(name); mapper.setName(name);
mapper.setProtocolMapper(mapperId); mapper.setProtocolMapper(mapperId);
mapper.setProtocol(SamlProtocol.LOGIN_PROTOCOL); mapper.setProtocol(SamlProtocol.LOGIN_PROTOCOL);
Map<String, String> config = new HashMap<String, String>(); Map<String, String> config = new HashMap<>();
config.put("role", role); config.put(ROLE_ATTRIBUTE, role);
mapper.setConfig(config); mapper.setConfig(config);
return mapper; return mapper;


Expand Down
Expand Up @@ -23,8 +23,9 @@
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory; import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.ProtocolMapperModel; import org.keycloak.models.ProtocolMapperModel;
import org.keycloak.models.RoleModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.UserSessionModel; import org.keycloak.models.UserSessionModel;
import org.keycloak.models.utils.RoleUtils;
import org.keycloak.protocol.ProtocolMapper; import org.keycloak.protocol.ProtocolMapper;
import org.keycloak.protocol.saml.SamlProtocol; import org.keycloak.protocol.saml.SamlProtocol;
import org.keycloak.provider.ProviderConfigProperty; import org.keycloak.provider.ProviderConfigProperty;
Expand All @@ -35,7 +36,9 @@
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors;


/** /**
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a> * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
Expand All @@ -45,7 +48,7 @@ public class RoleListMapper extends AbstractSAMLProtocolMapper implements SAMLRo
public static final String PROVIDER_ID = "saml-role-list-mapper"; public static final String PROVIDER_ID = "saml-role-list-mapper";
public static final String SINGLE_ROLE_ATTRIBUTE = "single"; public static final String SINGLE_ROLE_ATTRIBUTE = "single";


private static final List<ProviderConfigProperty> configProperties = new ArrayList<ProviderConfigProperty>(); private static final List<ProviderConfigProperty> configProperties = new ArrayList<>();


static { static {
ProviderConfigProperty property; ProviderConfigProperty property;
Expand Down Expand Up @@ -120,11 +123,13 @@ public void mapRoles(AttributeStatementType roleAttributeStatement, ProtocolMapp


ProtocolMapper mapper = (ProtocolMapper)sessionFactory.getProviderFactory(ProtocolMapper.class, mapping.getProtocolMapper()); ProtocolMapper mapper = (ProtocolMapper)sessionFactory.getProviderFactory(ProtocolMapper.class, mapping.getProtocolMapper());
if (mapper == null) continue; if (mapper == null) continue;

if (mapper instanceof SAMLRoleNameMapper) { if (mapper instanceof SAMLRoleNameMapper) {
roleNameMappers.add(new SamlProtocol.ProtocolMapperProcessor<>((SAMLRoleNameMapper) mapper,mapping)); roleNameMappers.add(new SamlProtocol.ProtocolMapperProcessor<>((SAMLRoleNameMapper) mapper,mapping));
} }

if (mapper instanceof HardcodedRole) { if (mapper instanceof HardcodedRole) {
AttributeType attributeType = null; AttributeType attributeType;
if (singleAttribute) { if (singleAttribute) {
if (singleAttributeType == null) { if (singleAttributeType == null) {
singleAttributeType = AttributeStatementHelper.createAttributeType(mappingModel); singleAttributeType = AttributeStatementHelper.createAttributeType(mappingModel);
Expand All @@ -135,14 +140,26 @@ public void mapRoles(AttributeStatementType roleAttributeStatement, ProtocolMapp
attributeType = AttributeStatementHelper.createAttributeType(mappingModel); attributeType = AttributeStatementHelper.createAttributeType(mappingModel);
roleAttributeStatement.addAttribute(new AttributeStatementType.ASTChoiceType(attributeType)); roleAttributeStatement.addAttribute(new AttributeStatementType.ASTChoiceType(attributeType));
} }
attributeType.addAttributeValue(mapping.getConfig().get("role"));
attributeType.addAttributeValue(mapping.getConfig().get(HardcodedRole.ROLE_ATTRIBUTE));
} }
} }


for (String roleId : clientSession.getRoles()) { RealmModel realm = clientSession.getRealm();
// todo need a role mapping List<String> allRoleNames = clientSession.getRoles().stream()
RoleModel roleModel = clientSession.getRealm().getRoleById(roleId); // todo need a role mapping
AttributeType attributeType = null; .map(realm::getRoleById)
.filter(Objects::nonNull)
.flatMap(RoleUtils::expandCompositeRolesStream)
.map(roleModel -> roleNameMappers.stream()
.map(entry -> entry.mapper.mapName(entry.model, roleModel))
.filter(Objects::nonNull)
.findFirst()
.orElse(roleModel.getName())
).collect(Collectors.toList());

for (String roleName : allRoleNames) {
AttributeType attributeType;
if (singleAttribute) { if (singleAttribute) {
if (singleAttributeType == null) { if (singleAttributeType == null) {
singleAttributeType = AttributeStatementHelper.createAttributeType(mappingModel); singleAttributeType = AttributeStatementHelper.createAttributeType(mappingModel);
Expand All @@ -153,14 +170,7 @@ public void mapRoles(AttributeStatementType roleAttributeStatement, ProtocolMapp
attributeType = AttributeStatementHelper.createAttributeType(mappingModel); attributeType = AttributeStatementHelper.createAttributeType(mappingModel);
roleAttributeStatement.addAttribute(new AttributeStatementType.ASTChoiceType(attributeType)); roleAttributeStatement.addAttribute(new AttributeStatementType.ASTChoiceType(attributeType));
} }
String roleName = roleModel.getName();
for (SamlProtocol.ProtocolMapperProcessor<SAMLRoleNameMapper> entry : roleNameMappers) {
String newName = entry.mapper.mapName(entry.model, roleModel);
if (newName != null) {
roleName = newName;
break;
}
}
attributeType.addAttributeValue(roleName); attributeType.addAttributeValue(roleName);
} }


Expand All @@ -172,7 +182,7 @@ public static ProtocolMapperModel create(String name, String samlAttributeName,
mapper.setProtocolMapper(PROVIDER_ID); mapper.setProtocolMapper(PROVIDER_ID);
mapper.setProtocol(SamlProtocol.LOGIN_PROTOCOL); mapper.setProtocol(SamlProtocol.LOGIN_PROTOCOL);
mapper.setConsentRequired(false); mapper.setConsentRequired(false);
Map<String, String> config = new HashMap<String, String>(); Map<String, String> config = new HashMap<>();
config.put(AttributeStatementHelper.SAML_ATTRIBUTE_NAME, samlAttributeName); config.put(AttributeStatementHelper.SAML_ATTRIBUTE_NAME, samlAttributeName);
if (friendlyName != null) { if (friendlyName != null) {
config.put(AttributeStatementHelper.FRIENDLY_NAME, friendlyName); config.put(AttributeStatementHelper.FRIENDLY_NAME, friendlyName);
Expand Down
Expand Up @@ -22,6 +22,7 @@
import org.jboss.shrinkwrap.api.spec.WebArchive; import org.jboss.shrinkwrap.api.spec.WebArchive;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;

import org.keycloak.admin.client.resource.ClientResource; import org.keycloak.admin.client.resource.ClientResource;
import org.keycloak.admin.client.resource.ProtocolMappersResource; import org.keycloak.admin.client.resource.ProtocolMappersResource;
import org.keycloak.admin.client.resource.RoleScopeResource; import org.keycloak.admin.client.resource.RoleScopeResource;
Expand Down Expand Up @@ -71,6 +72,7 @@
import org.keycloak.testsuite.page.AbstractPage; import org.keycloak.testsuite.page.AbstractPage;
import org.keycloak.testsuite.util.IOUtil; import org.keycloak.testsuite.util.IOUtil;
import org.keycloak.testsuite.util.UserBuilder; import org.keycloak.testsuite.util.UserBuilder;

import org.openqa.selenium.By; import org.openqa.selenium.By;
import org.w3c.dom.Document; import org.w3c.dom.Document;
import org.xml.sax.SAXException; import org.xml.sax.SAXException;
Expand Down Expand Up @@ -104,6 +106,7 @@
import static org.keycloak.representations.idm.CredentialRepresentation.PASSWORD; import static org.keycloak.representations.idm.CredentialRepresentation.PASSWORD;
import static org.keycloak.testsuite.AbstractAuthTest.createUserRepresentation; import static org.keycloak.testsuite.AbstractAuthTest.createUserRepresentation;
import static org.keycloak.testsuite.admin.ApiUtil.createUserAndResetPasswordWithAdminClient; import static org.keycloak.testsuite.admin.ApiUtil.createUserAndResetPasswordWithAdminClient;
import static org.keycloak.testsuite.admin.Users.setPasswordFor;
import static org.keycloak.testsuite.auth.page.AuthRealm.SAMLSERVLETDEMO; import static org.keycloak.testsuite.auth.page.AuthRealm.SAMLSERVLETDEMO;
import static org.keycloak.testsuite.util.IOUtil.loadRealm; import static org.keycloak.testsuite.util.IOUtil.loadRealm;
import static org.keycloak.testsuite.util.IOUtil.loadXML; import static org.keycloak.testsuite.util.IOUtil.loadXML;
Expand Down Expand Up @@ -529,6 +532,14 @@ public void salesMetadataTest() throws Exception {
testSuccessfulAndUnauthorizedLogin(salesMetadataServletPage, testRealmSAMLPostLoginPage); testSuccessfulAndUnauthorizedLogin(salesMetadataServletPage, testRealmSAMLPostLoginPage);
} }


@Test
public void salesPostTestCompositeRoleForUser() {
UserRepresentation topGroupUser = createUserRepresentation("topGroupUser", "top@redhat.com", "", "", true);
setPasswordFor(topGroupUser, PASSWORD);

assertSuccessfulLogin(salesPostServletPage, topGroupUser, testRealmSAMLPostLoginPage, "principal=topgroupuser");
}

@Test @Test
public void salesPostTest() { public void salesPostTest() {
testSuccessfulAndUnauthorizedLogin(salesPostServletPage, testRealmSAMLPostLoginPage); testSuccessfulAndUnauthorizedLogin(salesPostServletPage, testRealmSAMLPostLoginPage);
Expand Down
Expand Up @@ -49,6 +49,7 @@
{ "type" : "password", { "type" : "password",
"value" : "password" } "value" : "password" }
], ],
"realmRoles": [ "realm-composite-role" ],
"groups": [ "groups": [
"/top" "/top"
] ]
Expand All @@ -75,6 +76,14 @@
{ {
"name": "admin", "name": "admin",
"description": "Administrator privileges" "description": "Administrator privileges"
},
{
"name": "realm-composite-role",
"description": "Realm composite role containing user role",
"composite": true,
"composites": {
"realm": ["user"]
}
} }
] ]
}, },
Expand Down

0 comments on commit 3c41140

Please sign in to comment.