Skip to content

Commit

Permalink
rest interface for claim mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
patriot1burke committed Feb 25, 2015
1 parent 1704a6c commit 9f759ed
Show file tree
Hide file tree
Showing 20 changed files with 435 additions and 59 deletions.
Expand Up @@ -30,7 +30,7 @@ public class ApplicationRepresentation {
protected Integer nodeReRegistrationTimeout; protected Integer nodeReRegistrationTimeout;
protected Map<String, Integer> registeredNodes; protected Map<String, Integer> registeredNodes;
protected List<String> allowedIdentityProviders; protected List<String> allowedIdentityProviders;
protected Set<String> protocolMappers; protected List<ClientProtocolMappingRepresentation> protocolMappers;


public String getId() { public String getId() {
return id; return id;
Expand Down Expand Up @@ -200,11 +200,11 @@ public void setAllowedIdentityProviders(List<String> allowedIdentityProviders) {
this.allowedIdentityProviders = allowedIdentityProviders; this.allowedIdentityProviders = allowedIdentityProviders;
} }


public Set<String> getProtocolMappers() { public List<ClientProtocolMappingRepresentation> getProtocolMappers() {
return protocolMappers; return protocolMappers;
} }


public void setProtocolMappers(Set<String> protocolMappers) { public void setProtocolMappers(List<ClientProtocolMappingRepresentation> protocolMappers) {
this.protocolMappers = protocolMappers; this.protocolMappers = protocolMappers;
} }
} }
@@ -0,0 +1,26 @@
package org.keycloak.representations.idm;

/**
* @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
* @version $Revision: 1 $
*/
public class ClientProtocolMappingRepresentation {
protected String protocol;
protected String name;

public String getProtocol() {
return protocol;
}

public void setProtocol(String protocol) {
this.protocol = protocol;
}

public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}
}
Expand Up @@ -24,7 +24,7 @@ public class OAuthClientRepresentation {
protected Boolean fullScopeAllowed; protected Boolean fullScopeAllowed;
protected Boolean frontchannelLogout; protected Boolean frontchannelLogout;
protected List<String> allowedIdentityProviders; protected List<String> allowedIdentityProviders;
protected Set<String> protocolClaimMappings; protected List<ClientProtocolMappingRepresentation> protocolMappers;




public String getId() { public String getId() {
Expand Down Expand Up @@ -147,11 +147,11 @@ public void setAllowedIdentityProviders(List<String> allowedIdentityProviders) {
this.allowedIdentityProviders = allowedIdentityProviders; this.allowedIdentityProviders = allowedIdentityProviders;
} }


public Set<String> getProtocolClaimMappings() { public List<ClientProtocolMappingRepresentation> getProtocolMappers() {
return protocolClaimMappings; return protocolMappers;
} }


public void setProtocolClaimMappings(Set<String> protocolClaimMappings) { public void setProtocolMappers(List<ClientProtocolMappingRepresentation> protocolMappers) {
this.protocolClaimMappings = protocolClaimMappings; this.protocolMappers = protocolMappers;
} }
} }
6 changes: 3 additions & 3 deletions model/api/src/main/java/org/keycloak/models/ClientModel.java
Expand Up @@ -105,7 +105,7 @@ public interface ClientModel {
boolean hasIdentityProvider(String providerId); boolean hasIdentityProvider(String providerId);


Set<ProtocolMapperModel> getProtocolMappers(); Set<ProtocolMapperModel> getProtocolMappers();
void addProtocolMappers(Set<String> mapperNames); void addProtocolMappers(Set<String> mapperIds);
void removeProtocolMappers(Set<String> mapperNames); void removeProtocolMappers(Set<String> mapperIds);
void setProtocolMappers(Set<String> mapperNames); void setProtocolMappers(Set<String> mapperIds);
} }
Expand Up @@ -249,7 +249,7 @@ interface OAuthClientCreationEvent extends ClientCreationEvent {
void removeProtocolMapper(ProtocolMapperModel mapping); void removeProtocolMapper(ProtocolMapperModel mapping);
void updateProtocolMapper(ProtocolMapperModel mapping); void updateProtocolMapper(ProtocolMapperModel mapping);
public ProtocolMapperModel getProtocolMapperById(String id); public ProtocolMapperModel getProtocolMapperById(String id);
public ProtocolMapperModel getProtocolMapperByName(String name); public ProtocolMapperModel getProtocolMapperByName(String protocol, String name);




} }
Expand Up @@ -19,6 +19,7 @@
import org.keycloak.representations.idm.ApplicationRepresentation; import org.keycloak.representations.idm.ApplicationRepresentation;
import org.keycloak.representations.idm.ClaimRepresentation; import org.keycloak.representations.idm.ClaimRepresentation;
import org.keycloak.representations.idm.ClaimTypeRepresentation; import org.keycloak.representations.idm.ClaimTypeRepresentation;
import org.keycloak.representations.idm.ClientProtocolMappingRepresentation;
import org.keycloak.representations.idm.CredentialRepresentation; import org.keycloak.representations.idm.CredentialRepresentation;
import org.keycloak.representations.idm.FederatedIdentityRepresentation; import org.keycloak.representations.idm.FederatedIdentityRepresentation;
import org.keycloak.representations.idm.IdentityProviderRepresentation; import org.keycloak.representations.idm.IdentityProviderRepresentation;
Expand Down Expand Up @@ -266,8 +267,12 @@ public static ApplicationRepresentation toRepresentation(ApplicationModel applic
} }


if (!applicationModel.getProtocolMappers().isEmpty()) { if (!applicationModel.getProtocolMappers().isEmpty()) {
Set<String> mappings = new HashSet<String>(); List<ClientProtocolMappingRepresentation> mappings = new LinkedList<ClientProtocolMappingRepresentation>();
for (ProtocolMapperModel model : applicationModel.getProtocolMappers()) mappings.add(model.getName()); for (ProtocolMapperModel model : applicationModel.getProtocolMappers()) {
ClientProtocolMappingRepresentation map = new ClientProtocolMappingRepresentation();
map.setProtocol(model.getProtocol());
map.setName(model.getName());
}
rep.setProtocolMappers(mappings); rep.setProtocolMappers(mappings);
} }


Expand Down Expand Up @@ -301,10 +306,15 @@ public static OAuthClientRepresentation toRepresentation(OAuthClientModel model)
} }


if (!model.getProtocolMappers().isEmpty()) { if (!model.getProtocolMappers().isEmpty()) {
Set<String> mappings = new HashSet<String>(); List<ClientProtocolMappingRepresentation> mappings = new LinkedList<ClientProtocolMappingRepresentation>();
for (ProtocolMapperModel mappingModel : model.getProtocolMappers()) mappings.add(mappingModel.getName()); for (ProtocolMapperModel mapping : model.getProtocolMappers()) {
rep.setProtocolClaimMappings(mappings); ClientProtocolMappingRepresentation map = new ClientProtocolMappingRepresentation();
map.setProtocol(mapping.getProtocol());
map.setName(mapping.getName());
}
rep.setProtocolMappers(mappings);
} }

return rep; return rep;
} }


Expand Down
Expand Up @@ -23,6 +23,7 @@
import org.keycloak.representations.idm.ApplicationRepresentation; import org.keycloak.representations.idm.ApplicationRepresentation;
import org.keycloak.representations.idm.ClaimRepresentation; import org.keycloak.representations.idm.ClaimRepresentation;
import org.keycloak.representations.idm.ClaimTypeRepresentation; import org.keycloak.representations.idm.ClaimTypeRepresentation;
import org.keycloak.representations.idm.ClientProtocolMappingRepresentation;
import org.keycloak.representations.idm.CredentialRepresentation; import org.keycloak.representations.idm.CredentialRepresentation;
import org.keycloak.representations.idm.FederatedIdentityRepresentation; import org.keycloak.representations.idm.FederatedIdentityRepresentation;
import org.keycloak.representations.idm.IdentityProviderRepresentation; import org.keycloak.representations.idm.IdentityProviderRepresentation;
Expand Down Expand Up @@ -461,9 +462,18 @@ public static ApplicationModel createApplication(RealmModel realm, ApplicationRe
} }


if (resourceRep.getProtocolMappers() != null) { if (resourceRep.getProtocolMappers() != null) {
applicationModel.setProtocolMappers(resourceRep.getProtocolMappers()); Set<String> ids = new HashSet<String>();
for (ClientProtocolMappingRepresentation map : resourceRep.getProtocolMappers()) {
ProtocolMapperModel mapperModel = applicationModel.getRealm().getProtocolMapperByName(map.getProtocol(), map.getName());
if (mapperModel != null) {
ids.add(mapperModel.getId());
}

}
applicationModel.setProtocolMappers(ids);
} }



return applicationModel; return applicationModel;
} }


Expand Down Expand Up @@ -637,8 +647,16 @@ public static void updateOAuthClient(OAuthClientRepresentation rep, OAuthClientM
model.updateAllowedIdentityProviders(rep.getAllowedIdentityProviders()); model.updateAllowedIdentityProviders(rep.getAllowedIdentityProviders());
} }


if (rep.getProtocolClaimMappings() != null) { if (rep.getProtocolMappers() != null) {
model.addProtocolMappers(rep.getProtocolClaimMappings()); Set<String> ids = new HashSet<String>();
for (ClientProtocolMappingRepresentation map : rep.getProtocolMappers()) {
ProtocolMapperModel mapperModel = model.getRealm().getProtocolMapperByName(map.getProtocol(), map.getName());
if (mapperModel != null) {
ids.add(mapperModel.getId());
}

}
model.setProtocolMappers(ids);
} }


} }
Expand Down Expand Up @@ -777,7 +795,7 @@ private static void importProtocolMappers(RealmRepresentation rep, RealmModel ne
// we make sure we don't recreate mappers that are automatically created by the protocol providers. // we make sure we don't recreate mappers that are automatically created by the protocol providers.
Set<ProtocolMapperModel> mappers = newRealm.getProtocolMappers(); Set<ProtocolMapperModel> mappers = newRealm.getProtocolMappers();
for (ProtocolMapperRepresentation representation : rep.getProtocolMappers()) { for (ProtocolMapperRepresentation representation : rep.getProtocolMappers()) {
ProtocolMapperModel existing = newRealm.getProtocolMapperByName(representation.getName()); ProtocolMapperModel existing = newRealm.getProtocolMapperByName(representation.getProtocol(), representation.getName());
if (existing == null) { if (existing == null) {
newRealm.addProtocolMapper(toModel(representation)); newRealm.addProtocolMapper(toModel(representation));
} else { } else {
Expand Down
Expand Up @@ -920,9 +920,9 @@ public ProtocolMapperModel getProtocolMapperById(String id) {
} }


@Override @Override
public ProtocolMapperModel getProtocolMapperByName(String name) { public ProtocolMapperModel getProtocolMapperByName(String protocol, String name) {
for (ProtocolMapperModel mapping : cached.getClaimMappings()) { for (ProtocolMapperModel mapping : cached.getClaimMappings()) {
if (mapping.getName().equals(name)) return mapping; if (mapping.getProtocol().equals(protocol) && mapping.getName().equals(name)) return mapping;
} }
return null; return null;
} }
Expand Down
25 changes: 13 additions & 12 deletions model/jpa/src/main/java/org/keycloak/models/jpa/ClientAdapter.java
Expand Up @@ -380,9 +380,10 @@ public Set<ProtocolMapperModel> getProtocolMappers() {
return mappings; return mappings;
} }


protected ProtocolMapperEntity findProtocolMapperByName(String name) { protected ProtocolMapperEntity findProtocolMapperByName(String protocol, String name) {
TypedQuery<ProtocolMapperEntity> query = em.createNamedQuery("getProtocolMapperByName", ProtocolMapperEntity.class); TypedQuery<ProtocolMapperEntity> query = em.createNamedQuery("getProtocolMapperByNameProtocol", ProtocolMapperEntity.class);
query.setParameter("name", name); query.setParameter("name", name);
query.setParameter("protocol", protocol);
query.setParameter("realm", entity.getRealm()); query.setParameter("realm", entity.getRealm());
List<ProtocolMapperEntity> entities = query.getResultList(); List<ProtocolMapperEntity> entities = query.getResultList();
if (entities.size() == 0) return null; if (entities.size() == 0) return null;
Expand All @@ -396,11 +397,11 @@ public void addProtocolMappers(Set<String> mappings) {
Collection<ProtocolMapperEntity> entities = entity.getProtocolMappers(); Collection<ProtocolMapperEntity> entities = entity.getProtocolMappers();
Set<String> already = new HashSet<String>(); Set<String> already = new HashSet<String>();
for (ProtocolMapperEntity rel : entities) { for (ProtocolMapperEntity rel : entities) {
already.add(rel.getName()); already.add(rel.getId());
} }
for (String name : mappings) { for (String id : mappings) {
if (!already.contains(name)) { if (!already.contains(id)) {
ProtocolMapperEntity mapping = findProtocolMapperByName(name); ProtocolMapperEntity mapping = em.find(ProtocolMapperEntity.class, id);
if (mapping != null) { if (mapping != null) {
entities.add(mapping); entities.add(mapping);
} }
Expand All @@ -414,7 +415,7 @@ public void removeProtocolMappers(Set<String> mappings) {
Collection<ProtocolMapperEntity> entities = entity.getProtocolMappers(); Collection<ProtocolMapperEntity> entities = entity.getProtocolMappers();
List<ProtocolMapperEntity> remove = new LinkedList<ProtocolMapperEntity>(); List<ProtocolMapperEntity> remove = new LinkedList<ProtocolMapperEntity>();
for (ProtocolMapperEntity rel : entities) { for (ProtocolMapperEntity rel : entities) {
if (mappings.contains(rel.getName())) remove.add(rel); if (mappings.contains(rel.getId())) remove.add(rel);
} }
for (ProtocolMapperEntity entity : remove) { for (ProtocolMapperEntity entity : remove) {
entities.remove(entity); entities.remove(entity);
Expand All @@ -428,15 +429,15 @@ public void setProtocolMappers(Set<String> mappings) {
Set<String> already = new HashSet<String>(); Set<String> already = new HashSet<String>();
while (it.hasNext()) { while (it.hasNext()) {
ProtocolMapperEntity mapper = it.next(); ProtocolMapperEntity mapper = it.next();
if (mappings.contains(mapper.getName())) { if (mappings.contains(mapper.getId())) {
already.add(mapper.getName()); already.add(mapper.getId());
continue; continue;
} }
it.remove(); it.remove();
} }
for (String name : mappings) { for (String id : mappings) {
if (!already.contains(name)) { if (!already.contains(id)) {
ProtocolMapperEntity mapping = findProtocolMapperByName(name); ProtocolMapperEntity mapping = em.find(ProtocolMapperEntity.class, id);
if (mapping != null) { if (mapping != null) {
entities.add(mapping); entities.add(mapping);
} }
Expand Down
Expand Up @@ -632,7 +632,7 @@ public void addDefaultClientProtocolMappers(ClientModel client) {
Set<String> adding = new HashSet<String>(); Set<String> adding = new HashSet<String>();
for (ProtocolMapperEntity mapper : realm.getProtocolMappers()) { for (ProtocolMapperEntity mapper : realm.getProtocolMappers()) {
if (mapper.isAppliedByDefault()) { if (mapper.isAppliedByDefault()) {
adding.add(mapper.getName()); adding.add(mapper.getId());
} }
} }
client.setProtocolMappers(adding); client.setProtocolMappers(adding);
Expand Down Expand Up @@ -1295,8 +1295,8 @@ public Set<ProtocolMapperModel> getProtocolMappers() {


@Override @Override
public ProtocolMapperModel addProtocolMapper(ProtocolMapperModel model) { public ProtocolMapperModel addProtocolMapper(ProtocolMapperModel model) {
if (getProtocolMapperByName(model.getName()) != null) { if (getProtocolMapperByName(model.getProtocol(), model.getName()) != null) {
throw new RuntimeException("Duplicate protocol mapper with name: " + model.getName()); throw new RuntimeException("protocol mapper name must be unique per protocol");
} }
String id = KeycloakModelUtils.generateId(); String id = KeycloakModelUtils.generateId();
ProtocolMapperEntity entity = new ProtocolMapperEntity(); ProtocolMapperEntity entity = new ProtocolMapperEntity();
Expand Down Expand Up @@ -1325,9 +1325,9 @@ protected ProtocolMapperEntity getProtocolMapperEntity(String id) {


} }


protected ProtocolMapperEntity getProtocolMapperEntityByName(String name) { protected ProtocolMapperEntity getProtocolMapperEntityByName(String protocol, String name) {
for (ProtocolMapperEntity entity : realm.getProtocolMappers()) { for (ProtocolMapperEntity entity : realm.getProtocolMappers()) {
if (entity.getName().equals(name)) { if (entity.getProtocol().equals(protocol) && entity.getName().equals(name)) {
return entity; return entity;
} }
} }
Expand Down Expand Up @@ -1370,8 +1370,8 @@ public ProtocolMapperModel getProtocolMapperById(String id) {
} }


@Override @Override
public ProtocolMapperModel getProtocolMapperByName(String name) { public ProtocolMapperModel getProtocolMapperByName(String protocol, String name) {
ProtocolMapperEntity entity = getProtocolMapperEntityByName(name); ProtocolMapperEntity entity = getProtocolMapperEntityByName(protocol, name);
if (entity == null) return null; if (entity == null) return null;
return entityToModel(entity); return entityToModel(entity);
} }
Expand Down
Expand Up @@ -20,7 +20,7 @@
*/ */
@Entity @Entity
@NamedQueries({ @NamedQueries({
@NamedQuery(name="getProtocolMapperByName", query="select mapper from ProtocolMapperEntity mapper where mapper.name = :name and mapper.realm = :realm") @NamedQuery(name="getProtocolMapperByNameProtocol", query="select mapper from ProtocolMapperEntity mapper where mapper.protocol = :protocol and mapper.name = :name and mapper.realm = :realm")
}) })
@Table(name="PROTOCOL_MAPPER") @Table(name="PROTOCOL_MAPPER")
public class ProtocolMapperEntity { public class ProtocolMapperEntity {
Expand Down
Expand Up @@ -303,22 +303,22 @@ public Set<ProtocolMapperModel> getProtocolMappers() {
} }


@Override @Override
public void addProtocolMappers(Set<String> mapperNames) { public void addProtocolMappers(Set<String> mapperIds) {
getMongoEntityAsClient().getProtocolMappers().addAll(mapperNames); getMongoEntityAsClient().getProtocolMappers().addAll(mapperIds);
updateMongoEntity(); updateMongoEntity();


} }


@Override @Override
public void removeProtocolMappers(Set<String> mapperNames) { public void removeProtocolMappers(Set<String> mapperIds) {
getMongoEntityAsClient().getProtocolMappers().removeAll(mapperNames); getMongoEntityAsClient().getProtocolMappers().removeAll(mapperIds);
updateMongoEntity(); updateMongoEntity();
} }


@Override @Override
public void setProtocolMappers(Set<String> mapperNames) { public void setProtocolMappers(Set<String> mapperIds) {
getMongoEntityAsClient().getProtocolMappers().clear(); getMongoEntityAsClient().getProtocolMappers().clear();
getMongoEntityAsClient().getProtocolMappers().addAll(mapperNames); getMongoEntityAsClient().getProtocolMappers().addAll(mapperIds);
updateMongoEntity(); updateMongoEntity();
} }


Expand Down
Expand Up @@ -619,7 +619,7 @@ public List<ApplicationModel> getApplications() {
public void addDefaultClientProtocolMappers(ClientModel client) { public void addDefaultClientProtocolMappers(ClientModel client) {
Set<String> adding = new HashSet<String>(); Set<String> adding = new HashSet<String>();
for (ProtocolMapperEntity mapper : realm.getProtocolMappers()) { for (ProtocolMapperEntity mapper : realm.getProtocolMappers()) {
if (mapper.isAppliedByDefault()) adding.add(mapper.getName()); if (mapper.isAppliedByDefault()) adding.add(mapper.getId());
} }
client.setProtocolMappers(adding); client.setProtocolMappers(adding);


Expand Down Expand Up @@ -820,6 +820,9 @@ public Set<ProtocolMapperModel> getProtocolMappers() {


@Override @Override
public ProtocolMapperModel addProtocolMapper(ProtocolMapperModel model) { public ProtocolMapperModel addProtocolMapper(ProtocolMapperModel model) {
if (getProtocolMapperByName(model.getProtocol(), model.getName()) != null) {
throw new RuntimeException("protocol mapper name must be unique per protocol");
}
ProtocolMapperEntity entity = new ProtocolMapperEntity(); ProtocolMapperEntity entity = new ProtocolMapperEntity();
entity.setId(KeycloakModelUtils.generateId()); entity.setId(KeycloakModelUtils.generateId());
entity.setProtocol(model.getProtocol()); entity.setProtocol(model.getProtocol());
Expand Down Expand Up @@ -855,9 +858,9 @@ protected ProtocolMapperEntity getProtocolMapperyEntityById(String id) {
return null; return null;


} }
protected ProtocolMapperEntity getProtocolMapperyEntityByName(String name) { protected ProtocolMapperEntity getProtocolMapperEntityByName(String protocol, String name) {
for (ProtocolMapperEntity entity : realm.getProtocolMappers()) { for (ProtocolMapperEntity entity : realm.getProtocolMappers()) {
if (entity.getName().equals(name)) { if (entity.getProtocol().equals(protocol) && entity.getName().equals(name)) {
return entity; return entity;
} }
} }
Expand Down Expand Up @@ -891,8 +894,8 @@ public ProtocolMapperModel getProtocolMapperById(String id) {
} }


@Override @Override
public ProtocolMapperModel getProtocolMapperByName(String name) { public ProtocolMapperModel getProtocolMapperByName(String protocol, String name) {
ProtocolMapperEntity entity = getProtocolMapperyEntityById(name); ProtocolMapperEntity entity = getProtocolMapperEntityByName(protocol, name);
if (entity == null) return null; if (entity == null) return null;
return entityToModel(entity); return entityToModel(entity);
} }
Expand Down

0 comments on commit 9f759ed

Please sign in to comment.