Skip to content

Commit

Permalink
Add error message in JSON response (#54389) (#54656)
Browse files Browse the repository at this point in the history
When the SAML authentication is not successful, we return a SAML
Response with a status that indicates a failure. This commit adds
an error message in the REST API response along with the SAML
Response XML string so that the caller of the API can identify
that this is an unsuccessful response without needing to parse the
XML.
  • Loading branch information
jkakavas committed Apr 2, 2020
1 parent afbd8e6 commit ddab3f7
Show file tree
Hide file tree
Showing 11 changed files with 172 additions and 145 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ private String initSso(String entityId, String acs, UsernamePasswordToken second

final Map<String, Object> map = entityAsMap(response);
assertThat(map, notNullValue());
assertThat(map.keySet(), containsInAnyOrder("post_url", "saml_response", "service_provider"));
assertThat(map.keySet(), containsInAnyOrder("post_url", "saml_response", "saml_status", "service_provider", "error"));
assertThat(map.get("post_url"), equalTo(acs));
assertThat(map.get("saml_response"), instanceOf(String.class));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.idp.saml.support.SamlAuthenticationState;
Expand Down Expand Up @@ -42,15 +41,6 @@ public ActionRequestValidationException validate() {
if (Strings.isNullOrEmpty(assertionConsumerService)) {
validationException = addValidationError("acs is missing", validationException);
}
if (samlAuthenticationState != null) {
final ValidationException authnStateException = samlAuthenticationState.validate();
if (authnStateException != null && authnStateException.validationErrors().isEmpty() == false) {
if (validationException == null) {
validationException = new ActionRequestValidationException();
}
validationException.addValidationErrors(authnStateException.validationErrors());
}
}
return validationException;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.elasticsearch.xpack.idp.action;

import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;

Expand All @@ -16,18 +17,25 @@ public class SamlInitiateSingleSignOnResponse extends ActionResponse {
private String postUrl;
private String samlResponse;
private String entityId;
private String samlStatus;
private String error;

public SamlInitiateSingleSignOnResponse(StreamInput in) throws IOException {
super(in);
this.entityId = in.readString();
this.postUrl = in.readString();
this.samlResponse = in.readString();
this.entityId = in.readString();
this.samlStatus = in.readString();
this.error = in.readOptionalString();
}

public SamlInitiateSingleSignOnResponse(String postUrl, String samlResponse, String entityId) {
public SamlInitiateSingleSignOnResponse(String entityId, String postUrl, String samlResponse, String samlStatus,
@Nullable String error) {
this.entityId = entityId;
this.postUrl = postUrl;
this.samlResponse = samlResponse;
this.entityId = entityId;
this.samlStatus = samlStatus;
this.error = error;
}

public String getPostUrl() {
Expand All @@ -42,10 +50,20 @@ public String getEntityId() {
return entityId;
}

public String getError() {
return error;
}

public String getSamlStatus() {
return samlStatus;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(entityId);
out.writeString(postUrl);
out.writeString(samlResponse);
out.writeString(entityId);
out.writeString(samlStatus);
out.writeOptionalString(error);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,50 +62,53 @@ protected void doExecute(Task task, SamlInitiateSingleSignOnRequest request,
request.getAssertionConsumerService(),
false,
ActionListener.wrap(
sp -> {
if (null == sp) {
final String message = "Service Provider with Entity ID [" + request.getSpEntityId() + "] and ACS ["
+ request.getAssertionConsumerService() + "] is not known to this Identity Provider";
logger.debug(message);
possiblyReplyWithSamlFailure(authenticationState, StatusCode.RESPONDER, new IllegalArgumentException(message),
listener);
return;
}
final SecondaryAuthentication secondaryAuthentication = SecondaryAuthentication.readFromContext(securityContext);
if (secondaryAuthentication == null) {
possiblyReplyWithSamlFailure(authenticationState,
StatusCode.REQUESTER,
new ElasticsearchSecurityException("Request is missing secondary authentication", RestStatus.FORBIDDEN),
listener);
return;
}
buildUserFromAuthentication(secondaryAuthentication, sp, ActionListener.wrap(
user -> {
if (user == null) {
possiblyReplyWithSamlFailure(authenticationState,
StatusCode.REQUESTER,
new ElasticsearchSecurityException("User [{}] is not permitted to access service [{}]",
RestStatus.FORBIDDEN, secondaryAuthentication.getUser(), sp),
listener);
return;
}
final SuccessfulAuthenticationResponseMessageBuilder builder =
new SuccessfulAuthenticationResponseMessageBuilder(samlFactory, Clock.systemUTC(), identityProvider);
try {
final Response response = builder.build(user, authenticationState);
listener.onResponse(new SamlInitiateSingleSignOnResponse(
user.getServiceProvider().getAssertionConsumerService().toString(),
samlFactory.getXmlContent(response),
user.getServiceProvider().getEntityId()));
} catch (ElasticsearchException e) {
listener.onFailure(e);
}
},
e -> possiblyReplyWithSamlFailure(authenticationState, StatusCode.RESPONDER, e, listener)
));
},
e -> possiblyReplyWithSamlFailure(authenticationState, StatusCode.RESPONDER, e, listener)
));
sp -> {
if (null == sp) {
final String message = "Service Provider with Entity ID [" + request.getSpEntityId() + "] and ACS ["
+ request.getAssertionConsumerService() + "] is not known to this Identity Provider";
possiblyReplyWithSamlFailure(authenticationState, request.getSpEntityId(), request.getAssertionConsumerService(),
StatusCode.RESPONDER, new IllegalArgumentException(message), listener);
return;
}
final SecondaryAuthentication secondaryAuthentication = SecondaryAuthentication.readFromContext(securityContext);
if (secondaryAuthentication == null) {
possiblyReplyWithSamlFailure(authenticationState, request.getSpEntityId(), request.getAssertionConsumerService(),
StatusCode.REQUESTER,
new ElasticsearchSecurityException("Request is missing secondary authentication", RestStatus.FORBIDDEN),
listener);
return;
}
buildUserFromAuthentication(secondaryAuthentication, sp, ActionListener.wrap(
user -> {
if (user == null) {
possiblyReplyWithSamlFailure(authenticationState, request.getSpEntityId(),
request.getAssertionConsumerService(), StatusCode.REQUESTER,
new ElasticsearchSecurityException("User [{}] is not permitted to access service [{}]",
RestStatus.FORBIDDEN, secondaryAuthentication.getUser().principal(), sp.getEntityId()),
listener);
return;
}
final SuccessfulAuthenticationResponseMessageBuilder builder =
new SuccessfulAuthenticationResponseMessageBuilder(samlFactory, Clock.systemUTC(), identityProvider);
try {
final Response response = builder.build(user, authenticationState);
listener.onResponse(new SamlInitiateSingleSignOnResponse(
user.getServiceProvider().getEntityId(),
user.getServiceProvider().getAssertionConsumerService().toString(),
samlFactory.getXmlContent(response),
StatusCode.SUCCESS,
null));
} catch (ElasticsearchException e) {
listener.onFailure(e);
}
},
e -> possiblyReplyWithSamlFailure(authenticationState, request.getSpEntityId(),
request.getAssertionConsumerService(), StatusCode.RESPONDER, e, listener)
));
},
e -> possiblyReplyWithSamlFailure(authenticationState, request.getSpEntityId(), request.getAssertionConsumerService(),
StatusCode.RESPONDER, e, listener)
));
}

private void buildUserFromAuthentication(SecondaryAuthentication secondaryAuthentication, SamlServiceProvider serviceProvider,
Expand All @@ -129,20 +132,23 @@ private void buildUserFromAuthentication(SecondaryAuthentication secondaryAuthen
);
}

private void possiblyReplyWithSamlFailure(SamlAuthenticationState authenticationState, String statusCode, Exception e,
private void possiblyReplyWithSamlFailure(SamlAuthenticationState authenticationState, String spEntityId,
String acsUrl, String statusCode, Exception e,
ActionListener<SamlInitiateSingleSignOnResponse> listener) {
logger.debug("Failed to generate a successful SAML response: ", e);
if (authenticationState != null) {
final FailedAuthenticationResponseMessageBuilder builder =
new FailedAuthenticationResponseMessageBuilder(samlFactory, Clock.systemUTC(), identityProvider)
.setInResponseTo(authenticationState.getAuthnRequestId())
.setAcsUrl(authenticationState.getRequestedAcsUrl())
.setAcsUrl(acsUrl)
.setPrimaryStatusCode(statusCode);
final Response response = builder.build();
//TODO: Log and indicate SAML Response status is failure in the response
listener.onResponse(new SamlInitiateSingleSignOnResponse(
authenticationState.getRequestedAcsUrl(),
spEntityId,
acsUrl,
samlFactory.getXmlContent(response),
authenticationState.getEntityId()));
statusCode,
e.getMessage()));
} else {
listener.onFailure(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ private void validateAuthnRequest(AuthnRequest authnRequest, SamlServiceProvider
checkDestination(authnRequest);
final String acs = checkAcs(authnRequest, sp, authnState);
validateNameIdPolicy(authnRequest, sp, authnState);
authnState.put(SamlAuthenticationState.Fields.ENTITY_ID.getPreferredName(), sp.getEntityId());
authnState.put(SamlAuthenticationState.Fields.AUTHN_REQUEST_ID.getPreferredName(), authnRequest.getID());
final SamlValidateAuthnRequestResponse response = new SamlValidateAuthnRequestResponse(sp.getEntityId(), acs,
authnRequest.isForceAuthn(), authnState);
Expand Down Expand Up @@ -268,7 +267,6 @@ private String checkAcs(AuthnRequest request, SamlServiceProvider sp, Map<String
throw new ElasticsearchSecurityException("The registered ACS URL for this Service Provider is [{}] but the authentication " +
"request contained [{}]", RestStatus.BAD_REQUEST, sp.getAssertionConsumerService(), acs);
}
authnState.put(SamlAuthenticationState.Fields.ACS_URL.getPreferredName(), acs);
return acs;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ public RestResponse buildResponse(SamlInitiateSingleSignOnResponse response, XCo
builder.startObject();
builder.field("post_url", response.getPostUrl());
builder.field("saml_response", response.getSamlResponse());
builder.field("saml_status", response.getSamlStatus());
builder.field("error", response.getError());
builder.startObject("service_provider");
builder.field("entity_id", response.getEntityId());
builder.endObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
Expand All @@ -27,8 +26,6 @@
* these.
*/
public class SamlAuthenticationState implements Writeable, ToXContentObject {
private String entityId;
private String requestedAcsUrl;
@Nullable
private String requestedNameidFormat;
@Nullable
Expand All @@ -39,8 +36,6 @@ public SamlAuthenticationState() {
}

public SamlAuthenticationState(StreamInput in) throws IOException {
entityId = in.readString();
requestedAcsUrl = in.readString();
requestedNameidFormat = in.readOptionalString();
authnRequestId = in.readOptionalString();
}
Expand All @@ -61,39 +56,10 @@ public void setAuthnRequestId(String authnRequestId) {
this.authnRequestId = authnRequestId;
}

public String getEntityId() {
return entityId;
}

public void setEntityId(String entityId) {
this.entityId = entityId;
}

public String getRequestedAcsUrl() {
return requestedAcsUrl;
}

public void setRequestedAcsUrl(String requestedAcsUrl) {
this.requestedAcsUrl = requestedAcsUrl;
}

public ValidationException validate() {
final ValidationException validation = new ValidationException();
if (Strings.isNullOrEmpty(entityId)) {
validation.addValidationError("field [" + Fields.ENTITY_ID + "] is required, but was [" + entityId + "]");
}
if (Strings.isNullOrEmpty(requestedAcsUrl)) {
validation.addValidationError("field [" + Fields.ACS_URL + "] is required, but was [" + requestedAcsUrl + "]");
}
return validation;
}

public static final ObjectParser<SamlAuthenticationState, SamlAuthenticationState> PARSER
= new ObjectParser<>("saml_authn_state", true, SamlAuthenticationState::new);

static {
PARSER.declareString(SamlAuthenticationState::setEntityId, Fields.ENTITY_ID);
PARSER.declareString(SamlAuthenticationState::setRequestedAcsUrl, Fields.ACS_URL);
PARSER.declareStringOrNull(SamlAuthenticationState::setRequestedNameidFormat, Fields.NAMEID_FORMAT);
PARSER.declareStringOrNull(SamlAuthenticationState::setAuthnRequestId, Fields.AUTHN_REQUEST_ID);
}
Expand All @@ -103,8 +69,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();
builder.field(Fields.NAMEID_FORMAT.getPreferredName(), requestedNameidFormat);
builder.field(Fields.AUTHN_REQUEST_ID.getPreferredName(), authnRequestId);
builder.field(Fields.ENTITY_ID.getPreferredName(), entityId);
builder.field(Fields.ACS_URL.getPreferredName(), requestedAcsUrl);
return builder.endObject();
}

Expand All @@ -116,14 +80,10 @@ public static SamlAuthenticationState fromXContent(XContentParser parser) throws
public interface Fields {
ParseField NAMEID_FORMAT = new ParseField("nameid_format");
ParseField AUTHN_REQUEST_ID = new ParseField("authn_request_id");
ParseField ENTITY_ID = new ParseField("entity_id");
ParseField ACS_URL = new ParseField("acs_url");
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(entityId);
out.writeString(requestedAcsUrl);
out.writeOptionalString(requestedNameidFormat);
out.writeOptionalString(authnRequestId);
}
Expand All @@ -138,14 +98,12 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
SamlAuthenticationState that = (SamlAuthenticationState) o;
return entityId.equals(that.entityId) &&
requestedAcsUrl.equals(that.requestedAcsUrl) &&
Objects.equals(requestedNameidFormat, that.requestedNameidFormat) &&
return Objects.equals(requestedNameidFormat, that.requestedNameidFormat) &&
Objects.equals(authnRequestId, that.authnRequestId);
}

@Override
public int hashCode() {
return Objects.hash(entityId, requestedAcsUrl, requestedNameidFormat, authnRequestId);
return Objects.hash(requestedNameidFormat, authnRequestId);
}
}

0 comments on commit ddab3f7

Please sign in to comment.