Skip to content

Commit

Permalink
KEYCLOAK-5234 (#4585)
Browse files Browse the repository at this point in the history
  • Loading branch information
stianst committed Oct 23, 2017
1 parent 20d0fa1 commit 9b75b60
Show file tree
Hide file tree
Showing 16 changed files with 318 additions and 102 deletions.
Expand Up @@ -46,7 +46,6 @@ public class AccountFederatedIdentityBean {


public AccountFederatedIdentityBean(KeycloakSession session, RealmModel realm, UserModel user, URI baseUri, String stateChecker) { public AccountFederatedIdentityBean(KeycloakSession session, RealmModel realm, UserModel user, URI baseUri, String stateChecker) {
this.session = session; this.session = session;
URI accountIdentityUpdateUri = Urls.accountFederatedIdentityUpdate(baseUri, realm.getName());


List<IdentityProviderModel> identityProviders = realm.getIdentityProviders(); List<IdentityProviderModel> identityProviders = realm.getIdentityProviders();
Set<FederatedIdentityModel> identities = session.users().getFederatedIdentities(user, realm); Set<FederatedIdentityModel> identities = session.users().getFederatedIdentities(user, realm);
Expand All @@ -63,15 +62,8 @@ public AccountFederatedIdentityBean(KeycloakSession session, RealmModel realm, U
availableIdentities++; availableIdentities++;
} }


String action = identity != null ? "remove" : "add";
String actionUrl = UriBuilder.fromUri(accountIdentityUpdateUri)
.queryParam("action", action)
.queryParam("provider_id", providerId)
.queryParam("stateChecker", stateChecker)
.build().toString();

String displayName = KeycloakModelUtils.getIdentityProviderDisplayName(session, provider); String displayName = KeycloakModelUtils.getIdentityProviderDisplayName(session, provider);
FederatedIdentityEntry entry = new FederatedIdentityEntry(identity, displayName, provider.getAlias(), provider.getAlias(), actionUrl, FederatedIdentityEntry entry = new FederatedIdentityEntry(identity, displayName, provider.getAlias(), provider.getAlias(),
provider.getConfig() != null ? provider.getConfig().get("guiOrder") : null); provider.getConfig() != null ? provider.getConfig().get("guiOrder") : null);
orderedSet.add(entry); orderedSet.add(entry);
} }
Expand Down Expand Up @@ -105,17 +97,15 @@ public class FederatedIdentityEntry {
private FederatedIdentityModel federatedIdentityModel; private FederatedIdentityModel federatedIdentityModel;
private final String providerId; private final String providerId;
private final String providerName; private final String providerName;
private final String actionUrl;
private final String guiOrder; private final String guiOrder;
private final String displayName; private final String displayName;


public FederatedIdentityEntry(FederatedIdentityModel federatedIdentityModel, String displayName, String providerId, public FederatedIdentityEntry(FederatedIdentityModel federatedIdentityModel, String displayName, String providerId,
String providerName, String actionUrl, String guiOrder) { String providerName, String guiOrder) {
this.federatedIdentityModel = federatedIdentityModel; this.federatedIdentityModel = federatedIdentityModel;
this.displayName = displayName; this.displayName = displayName;
this.providerId = providerId; this.providerId = providerId;
this.providerName = providerName; this.providerName = providerName;
this.actionUrl = actionUrl;
this.guiOrder = guiOrder; this.guiOrder = guiOrder;
} }


Expand All @@ -139,10 +129,6 @@ public boolean isConnected() {
return federatedIdentityModel != null; return federatedIdentityModel != null;
} }


public String getActionUrl() {
return actionUrl;
}

public String getGuiOrder() { public String getGuiOrder() {
return guiOrder; return guiOrder;
} }
Expand Down Expand Up @@ -186,4 +172,4 @@ private int parseOrder(FederatedIdentityEntry ip) {
return 10000; return 10000;
} }
} }
} }
Expand Up @@ -33,15 +33,13 @@ public class UrlBean {
private URI baseURI; private URI baseURI;
private URI baseQueryURI; private URI baseQueryURI;
private URI currentURI; private URI currentURI;
private String stateChecker;


public UrlBean(RealmModel realm, Theme theme, URI baseURI, URI baseQueryURI, URI currentURI, String stateChecker) { public UrlBean(RealmModel realm, Theme theme, URI baseURI, URI baseQueryURI, URI currentURI, String stateChecker) {
this.realm = realm.getName(); this.realm = realm.getName();
this.theme = theme; this.theme = theme;
this.baseURI = baseURI; this.baseURI = baseURI;
this.baseQueryURI = baseQueryURI; this.baseQueryURI = baseQueryURI;
this.currentURI = currentURI; this.currentURI = currentURI;
this.stateChecker = stateChecker;
} }


public String getApplicationsUrl() { public String getApplicationsUrl() {
Expand Down Expand Up @@ -73,15 +71,15 @@ public String getSessionsUrl() {
} }


public String getSessionsLogoutUrl() { public String getSessionsLogoutUrl() {
return Urls.accountSessionsLogoutPage(baseQueryURI, realm, stateChecker).toString(); return Urls.accountSessionsLogoutPage(baseQueryURI, realm).toString();
} }


public String getRevokeClientUrl() { public String getRevokeClientUrl() {
return Urls.accountRevokeClientPage(baseQueryURI, realm).toString(); return Urls.accountRevokeClientPage(baseQueryURI, realm).toString();
} }


public String getTotpRemoveUrl() { public String getTotpRemoveUrl() {
return Urls.accountTotpRemove(baseQueryURI, realm, stateChecker).toString(); return Urls.accountTotpRemove(baseQueryURI, realm).toString();
} }


public String getLogoutUrl() { public String getLogoutUrl() {
Expand Down
6 changes: 2 additions & 4 deletions services/src/main/java/org/keycloak/services/Urls.java
Expand Up @@ -126,9 +126,8 @@ public static URI accountTotpPage(URI baseUri, String realmName) {
return accountBase(baseUri).path(AccountFormService.class, "totpPage").build(realmName); return accountBase(baseUri).path(AccountFormService.class, "totpPage").build(realmName);
} }


public static URI accountTotpRemove(URI baseUri, String realmName, String stateChecker) { public static URI accountTotpRemove(URI baseUri, String realmName) {
return accountBase(baseUri).path(AccountFormService.class, "processTotpRemove") return accountBase(baseUri).path(AccountFormService.class, "processTotpRemove")
.queryParam("stateChecker", stateChecker)
.build(realmName); .build(realmName);
} }


Expand All @@ -140,9 +139,8 @@ public static URI accountSessionsPage(URI baseUri, String realmName) {
return accountBase(baseUri).path(AccountFormService.class, "sessionsPage").build(realmName); return accountBase(baseUri).path(AccountFormService.class, "sessionsPage").build(realmName);
} }


public static URI accountSessionsLogoutPage(URI baseUri, String realmName, String stateChecker) { public static URI accountSessionsLogoutPage(URI baseUri, String realmName) {
return accountBase(baseUri).path(AccountFormService.class, "processSessionsLogout") return accountBase(baseUri).path(AccountFormService.class, "processSessionsLogout")
.queryParam("stateChecker", stateChecker)
.build(realmName); .build(realmName);
} }


Expand Down
Expand Up @@ -572,6 +572,11 @@ public static String getRealmCookiePath(RealmModel realm, UriInfo uriInfo) {
return uri.getRawPath(); return uri.getRawPath();
} }


public static String getAccountCookiePath(RealmModel realm, UriInfo uriInfo) {
URI uri = RealmsResource.accountUrl(uriInfo.getBaseUriBuilder()).build(realm.getName());
return uri.getRawPath();
}

public static void expireCookie(RealmModel realm, String cookieName, String path, boolean httpOnly, ClientConnection connection) { public static void expireCookie(RealmModel realm, String cookieName, String path, boolean httpOnly, ClientConnection connection) {
logger.debugv("Expiring cookie: {0} path: {1}", cookieName, path); logger.debugv("Expiring cookie: {0} path: {1}", cookieName, path);
boolean secureOnly = realm.getSslRequired().isRequired(connection);; boolean secureOnly = realm.getSslRequired().isRequired(connection);;
Expand Down
Expand Up @@ -24,14 +24,12 @@
import org.keycloak.common.ClientConnection; import org.keycloak.common.ClientConnection;
import org.keycloak.common.util.Base64Url; import org.keycloak.common.util.Base64Url;
import org.keycloak.common.util.KeycloakUriBuilder; import org.keycloak.common.util.KeycloakUriBuilder;
import org.keycloak.common.util.UriUtils;
import org.keycloak.models.ClientModel; import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakSession; import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel; import org.keycloak.models.RealmModel;
import org.keycloak.models.utils.KeycloakModelUtils; import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.protocol.oidc.OIDCLoginProtocolService; import org.keycloak.protocol.oidc.OIDCLoginProtocolService;
import org.keycloak.services.ForbiddenException; import org.keycloak.services.ForbiddenException;
import org.keycloak.services.managers.AppAuthManager;
import org.keycloak.services.managers.Auth; import org.keycloak.services.managers.Auth;
import org.keycloak.services.managers.AuthenticationManager; import org.keycloak.services.managers.AuthenticationManager;
import org.keycloak.services.util.CookieHelper; import org.keycloak.services.util.CookieHelper;
Expand Down Expand Up @@ -130,14 +128,20 @@ public Response loginRedirect(@QueryParam("code") String code,
} }


protected void updateCsrfChecks() { protected void updateCsrfChecks() {
Cookie cookie = headers.getCookies().get(KEYCLOAK_STATE_CHECKER); stateChecker = getStateChecker();
if (cookie != null) { if (stateChecker == null) {
stateChecker = cookie.getValue();
} else {
stateChecker = Base64Url.encode(KeycloakModelUtils.generateSecret()); stateChecker = Base64Url.encode(KeycloakModelUtils.generateSecret());
String cookiePath = AuthenticationManager.getRealmCookiePath(realm, uriInfo);
StringBuilder sb = new StringBuilder();
sb.append(auth.getSession().getId());
sb.append("/");
sb.append(stateChecker);

String sessionCookieValue = sb.toString();

String cookiePath = AuthenticationManager.getAccountCookiePath(realm, uriInfo);
boolean secureOnly = realm.getSslRequired().isRequired(clientConnection); boolean secureOnly = realm.getSslRequired().isRequired(clientConnection);
CookieHelper.addCookie(KEYCLOAK_STATE_CHECKER, stateChecker, cookiePath, null, null, -1, secureOnly, true); CookieHelper.addCookie(KEYCLOAK_STATE_CHECKER, sessionCookieValue, cookiePath, null, null, -1, secureOnly, true);
} }
} }


Expand All @@ -149,25 +153,27 @@ protected void updateCsrfChecks() {
* @param formData * @param formData
*/ */
protected void csrfCheck(final MultivaluedMap<String, String> formData) { protected void csrfCheck(final MultivaluedMap<String, String> formData) {
if (!auth.isCookieAuthenticated()) return;
String stateChecker = formData.getFirst("stateChecker"); String stateChecker = formData.getFirst("stateChecker");
if (!this.stateChecker.equals(stateChecker)) { if (stateChecker == null || !stateChecker.equals(getStateChecker())) {
throw new ForbiddenException(); throw new ForbiddenException();
} }

} }


/** protected String getStateChecker() {
* Check to see if form post has sessionId hidden field and match it against the session id. Cookie cookie = headers.getCookies().get(KEYCLOAK_STATE_CHECKER);
* if (cookie != null) {
*/ stateChecker = cookie.getValue();
protected void csrfCheck(String stateChecker) { String[] s = stateChecker.split("/");
if (!auth.isCookieAuthenticated()) return; if (s.length == 2) {
if (auth.getSession() == null) return; String sessionId = s[0];
if (!this.stateChecker.equals(stateChecker)) { String stateChecker = s[1];
throw new ForbiddenException();
if (auth.getSession().getId().equals(sessionId)) {
return stateChecker;
}
}
} }

return null;
} }


protected abstract URI getBaseRedirectUri(); protected abstract URI getBaseRedirectUri();
Expand Down
Expand Up @@ -44,6 +44,7 @@
import org.keycloak.models.UserSessionModel; import org.keycloak.models.UserSessionModel;
import org.keycloak.models.utils.CredentialValidation; import org.keycloak.models.utils.CredentialValidation;
import org.keycloak.models.utils.FormMessage; import org.keycloak.models.utils.FormMessage;
import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.protocol.oidc.utils.RedirectUtils; import org.keycloak.protocol.oidc.utils.RedirectUtils;
import org.keycloak.services.ForbiddenException; import org.keycloak.services.ForbiddenException;
import org.keycloak.services.ServicesLogger; import org.keycloak.services.ServicesLogger;
Expand Down Expand Up @@ -343,15 +344,15 @@ public Response processAccountUpdate(final MultivaluedMap<String, String> formDa
} }


@Path("totp-remove") @Path("totp-remove")
@GET @POST
public Response processTotpRemove(@QueryParam("stateChecker") String stateChecker) { public Response processTotpRemove(final MultivaluedMap<String, String> formData) {
if (auth == null) { if (auth == null) {
return login("totp"); return login("totp");
} }


auth.require(AccountRoles.MANAGE_ACCOUNT); auth.require(AccountRoles.MANAGE_ACCOUNT);


csrfCheck(stateChecker); csrfCheck(formData);


UserModel user = auth.getUser(); UserModel user = auth.getUser();
session.userCredentialManager().disableCredentialType(realm, user, CredentialModel.OTP); session.userCredentialManager().disableCredentialType(realm, user, CredentialModel.OTP);
Expand All @@ -364,14 +365,14 @@ public Response processTotpRemove(@QueryParam("stateChecker") String stateChecke




@Path("sessions-logout") @Path("sessions-logout")
@GET @POST
public Response processSessionsLogout(@QueryParam("stateChecker") String stateChecker) { public Response processSessionsLogout(final MultivaluedMap<String, String> formData) {
if (auth == null) { if (auth == null) {
return login("sessions"); return login("sessions");
} }


auth.require(AccountRoles.MANAGE_ACCOUNT); auth.require(AccountRoles.MANAGE_ACCOUNT);
csrfCheck(stateChecker); csrfCheck(formData);


UserModel user = auth.getUser(); UserModel user = auth.getUser();


Expand Down Expand Up @@ -588,19 +589,21 @@ public Response processPasswordUpdate(final MultivaluedMap<String, String> formD
return account.setPasswordSet(true).setSuccess(Messages.ACCOUNT_PASSWORD_UPDATED).createResponse(AccountPages.PASSWORD); return account.setPasswordSet(true).setSuccess(Messages.ACCOUNT_PASSWORD_UPDATED).createResponse(AccountPages.PASSWORD);
} }


@Path("federated-identity-update") @Path("identity")
@GET @POST
public Response processFederatedIdentityUpdate(@QueryParam("action") String action, @Consumes(MediaType.APPLICATION_FORM_URLENCODED)
@QueryParam("provider_id") String providerId, public Response processFederatedIdentityUpdate(final MultivaluedMap<String, String> formData) {
@QueryParam("stateChecker") String stateChecker) {
if (auth == null) { if (auth == null) {
return login("identity"); return login("identity");
} }


auth.require(AccountRoles.MANAGE_ACCOUNT); auth.require(AccountRoles.MANAGE_ACCOUNT);
csrfCheck(stateChecker); csrfCheck(formData);
UserModel user = auth.getUser(); UserModel user = auth.getUser();


String action = formData.getFirst("action");
String providerId = formData.getFirst("providerId");

if (Validation.isEmpty(providerId)) { if (Validation.isEmpty(providerId)) {
setReferrerOnPage(); setReferrerOnPage();
return account.setError(Messages.MISSING_IDENTITY_PROVIDER).createResponse(AccountPages.FEDERATED_IDENTITY); return account.setError(Messages.MISSING_IDENTITY_PROVIDER).createResponse(AccountPages.FEDERATED_IDENTITY);
Expand Down
Expand Up @@ -22,6 +22,9 @@
import org.openqa.selenium.WebElement; import org.openqa.selenium.WebElement;
import org.openqa.selenium.support.FindBy; import org.openqa.selenium.support.FindBy;


import java.util.LinkedList;
import java.util.List;

/** /**
* @author <a href="mailto:mposolda@redhat.com">Marek Posolda</a> * @author <a href="mailto:mposolda@redhat.com">Marek Posolda</a>
*/ */
Expand Down Expand Up @@ -50,26 +53,72 @@ public String getPath() {
public boolean isCurrent() { public boolean isCurrent() {
return driver.getTitle().contains("Account Management") && driver.getPageSource().contains("Federated Identities"); return driver.getTitle().contains("Account Management") && driver.getPageSource().contains("Federated Identities");
} }


public WebElement findAddProviderButton(String alias) { public List<FederatedIdentity> getIdentities() {
return driver.findElement(By.id("add-" + alias)); List<FederatedIdentity> identities = new LinkedList<>();
WebElement identitiesElement = driver.findElement(By.id("federated-identities"));
for (WebElement i : identitiesElement.findElements(By.className("row"))) {

String providerId = i.findElement(By.tagName("label")).getText();
String subject = i.findElement(By.tagName("input")).getAttribute("value");
WebElement button = i.findElement(By.tagName("button"));

identities.add(new FederatedIdentity(providerId, subject, button));
}
return identities;
} }

public WebElement findRemoveProviderButton(String alias) { public WebElement findAddProvider(String providerId) {
return driver.findElement(By.id("remove-" + alias)); return driver.findElement(By.id("add-link-" + providerId));
} }


public void clickAddProvider(String alias) { public void clickAddProvider(String providerId) {
WebElement addButton = findAddProviderButton(alias); findAddProvider(providerId).click();
addButton.click();
} }


public void clickRemoveProvider(String alias) { public void clickRemoveProvider(String providerId) {
WebElement addButton = findRemoveProviderButton(alias); driver.findElement(By.id("remove-link-" + providerId)).click();
addButton.click();
} }


public String getError() { public String getError() {
return errorMessage.getText(); return errorMessage.getText();
} }

public static class FederatedIdentity {

private String providerId;
private String subject;
private WebElement action;

public FederatedIdentity(String providerId, String subject, WebElement action) {
this.providerId = providerId;
this.subject = subject;
this.action = action;
}

public String getProvider() {
return providerId;
}

public void setProviderId(String providerId) {
this.providerId = providerId;
}

public String getSubject() {
return subject;
}

public void setSubject(String subject) {
this.subject = subject;
}

public WebElement getAction() {
return action;
}

public void setAction(WebElement action) {
this.action = action;
}
}

} }

0 comments on commit 9b75b60

Please sign in to comment.