Skip to content

Commit

Permalink
Audit authn success in the SecurityRestFilter (#94120)
Browse files Browse the repository at this point in the history
Authentication for the HTTP channel is changing to not have access
to the contents of the HTTP request body.
But auditing of authentication success still requires access to the
request body. Consequently, auditing of authentication success
is now separated from the authentication logic: auditing is invoked
in the SecurityRestFilter#handleRequest, after the
AuthenticationService has done its job earlier in the handling
flow for the HTTP request.
  • Loading branch information
albertzaharovits committed Apr 3, 2023
1 parent a375393 commit f353be2
Show file tree
Hide file tree
Showing 14 changed files with 169 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -929,7 +929,11 @@ Collection<Object> createComponents(
components.add(allRolesStore); // for SecurityInfoTransportAction and clear roles cache
components.add(authzService);

final SecondaryAuthenticator secondaryAuthenticator = new SecondaryAuthenticator(securityContext.get(), authcService.get());
final SecondaryAuthenticator secondaryAuthenticator = new SecondaryAuthenticator(
securityContext.get(),
authcService.get(),
auditTrailService
);
this.secondayAuthc.set(secondaryAuthenticator);
components.add(secondaryAuthenticator);

Expand Down Expand Up @@ -1655,6 +1659,7 @@ public UnaryOperator<RestHandler> getRestHandlerInterceptor(ThreadContext thread
threadContext,
authcService.get(),
secondayAuthc.get(),
auditTrailService.get(),
handler,
extractClientCertificate
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public interface AuditTrail {

String name();

void authenticationSuccess(String requestId, Authentication authentication, RestRequest request);
void authenticationSuccess(RestRequest request);

void authenticationSuccess(String requestId, Authentication authentication, String action, TransportRequest transportRequest);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public String name() {
}

@Override
public void authenticationSuccess(String requestId, Authentication authentication, RestRequest request) {}
public void authenticationSuccess(RestRequest request) {}

@Override
public void authenticationSuccess(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.json.JsonStringEncoder;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.security.SecurityContext;
import org.elasticsearch.xpack.core.security.action.Grant;
import org.elasticsearch.xpack.core.security.action.apikey.BaseUpdateApiKeyRequest;
import org.elasticsearch.xpack.core.security.action.apikey.BulkUpdateApiKeyAction;
Expand Down Expand Up @@ -95,6 +96,7 @@
import org.elasticsearch.xpack.security.Security;
import org.elasticsearch.xpack.security.audit.AuditLevel;
import org.elasticsearch.xpack.security.audit.AuditTrail;
import org.elasticsearch.xpack.security.audit.AuditUtil;
import org.elasticsearch.xpack.security.authc.ApiKeyService;
import org.elasticsearch.xpack.security.authc.service.ServiceAccountToken;
import org.elasticsearch.xpack.security.rest.RemoteHostHeader;
Expand All @@ -121,6 +123,7 @@
import java.util.stream.Stream;

import static java.util.Map.entry;
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.security.SecurityField.setting;
import static org.elasticsearch.xpack.core.security.authc.service.ServiceAccountSettings.TOKEN_NAME_FIELD;
import static org.elasticsearch.xpack.core.security.authc.service.ServiceAccountSettings.TOKEN_SOURCE_FIELD;
Expand Down Expand Up @@ -366,6 +369,7 @@ public class LoggingAuditTrail implements AuditTrail, ClusterStateListener {

private final Logger logger;
private final ThreadContext threadContext;
private final SecurityContext securityContext;
final EventFilterPolicyRegistry eventFilterPolicyRegistry;
// package for testing
volatile EnumSet<AuditLevel> events;
Expand All @@ -387,6 +391,7 @@ public LoggingAuditTrail(Settings settings, ClusterService clusterService, Threa
this.events = parse(INCLUDE_EVENT_SETTINGS.get(settings), EXCLUDE_EVENT_SETTINGS.get(settings));
this.includeRequestBody = INCLUDE_REQUEST_BODY.get(settings);
this.threadContext = threadContext;
this.securityContext = new SecurityContext(settings, threadContext);
this.entryCommonFields = new EntryCommonFields(settings, null, clusterService);
this.eventFilterPolicyRegistry = new EventFilterPolicyRegistry(settings);
clusterService.addListener(this);
Expand Down Expand Up @@ -444,7 +449,24 @@ public LoggingAuditTrail(Settings settings, ClusterService clusterService, Threa
}

@Override
public void authenticationSuccess(String requestId, Authentication authentication, RestRequest request) {
public void authenticationSuccess(RestRequest request) {
final String requestId = AuditUtil.extractRequestId(securityContext.getThreadContext());
if (requestId == null) {
// should never happen
throw new ElasticsearchSecurityException("Authenticated context must include request id");
}
final Authentication authentication;
try {
authentication = securityContext.getAuthentication();
} catch (Exception e) {
logger.error(() -> format("caught exception while trying to read authentication from request [%s]", request), e);
tamperedRequest(requestId, request);
throw new ElasticsearchSecurityException("rest request attempted to inject a user", e);
}
if (authentication == null) {
// should never happen
throw new ElasticsearchSecurityException("Context is not authenticated");
}
if (events.contains(AUTHENTICATION_SUCCESS)
&& eventFilterPolicyRegistry.ignorePredicate()
.test(
Expand All @@ -468,7 +490,7 @@ public void authenticationSuccess(String requestId, Authentication authenticatio
.withAuthentication(authentication)
.withRestOrigin(threadContext)
.withRequestBody(request)
.withThreadContext(threadContext)
.withThreadContext(securityContext.getThreadContext())
.build();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
*/
package org.elasticsearch.xpack.security.authc;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.cache.Cache;
Expand Down Expand Up @@ -67,7 +65,6 @@ public class AuthenticationService {
TimeValue.timeValueHours(1L),
Property.NodeScope
);
private static final Logger logger = LogManager.getLogger(AuthenticationService.class);

private final Realms realms;
private final AuditTrailService auditTrailService;
Expand Down Expand Up @@ -383,7 +380,14 @@ static class AuditableRestRequest extends AuditableRequest {

@Override
void authenticationSuccess(Authentication authentication) {
auditTrail.authenticationSuccess(requestId, authentication, request);
// REST requests are audited in the {@code SecurityRestFilter} because they need access to the request body
// see {@code AuditTrail#authenticationSuccess(RestRequest)}
// It's still valuable to keep the parent interface {@code AuditableRequest#AuthenticationSuccess(Authentication)} around
// in order to audit authN success for transport requests for CCS. We may be able to find another way to audit that, which
// doesn't rely on an `AuditableRequest` instance, but it's not trivial because we'd have to make sure to not audit
// existing authentications. Separately, it's not easy to reconstruct another `AuditableRequest` outside the
// `AuthenticationService` because that's tied to the audit `request.id` generation.
// For more context see: https://github.com/elastic/elasticsearch/pull/94120#discussion_r1152804133
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authc.support.SecondaryAuthentication;
import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
import org.elasticsearch.xpack.security.audit.AuditTrailService;
import org.elasticsearch.xpack.security.authc.AuthenticationService;

import java.util.function.Consumer;
Expand All @@ -39,14 +40,25 @@ public class SecondaryAuthenticator {
private final Logger logger = LogManager.getLogger(SecondaryAuthenticator.class);
private final SecurityContext securityContext;
private final AuthenticationService authenticationService;
private final AuditTrailService auditTrailService;

public SecondaryAuthenticator(Settings settings, ThreadContext threadContext, AuthenticationService authenticationService) {
this(new SecurityContext(settings, threadContext), authenticationService);
public SecondaryAuthenticator(
Settings settings,
ThreadContext threadContext,
AuthenticationService authenticationService,
AuditTrailService auditTrailService
) {
this(new SecurityContext(settings, threadContext), authenticationService, auditTrailService);
}

public SecondaryAuthenticator(SecurityContext securityContext, AuthenticationService authenticationService) {
public SecondaryAuthenticator(
SecurityContext securityContext,
AuthenticationService authenticationService,
AuditTrailService auditTrailService
) {
this.securityContext = securityContext;
this.authenticationService = authenticationService;
this.auditTrailService = auditTrailService;
}

/**
Expand Down Expand Up @@ -76,7 +88,10 @@ public void authenticateAndAttachToContext(RestRequest request, ActionListener<S
// Use cases for secondary authentication are far more likely to want to fall back to the primary authentication if no secondary
// auth is provided, so in that case we do no want to set anything in the context
authenticate(
authListener -> authenticationService.authenticate(request, false, authListener),
authListener -> authenticationService.authenticate(request, false, authListener.delegateFailure((l, authentication) -> {
auditTrailService.get().authenticationSuccess(request);
l.onResponse(authentication);
})),
ActionListener.wrap(secondaryAuthentication -> {
if (secondaryAuthentication != null) {
secondaryAuthentication.writeToContext(threadContext);
Expand Down Expand Up @@ -121,16 +136,4 @@ private void authenticate(Consumer<ActionListener<Authentication>> authenticate,
authenticate.accept(authenticationListener);
}
}

/**
* Checks whether this thread context provides secondary authentication credentials.
* This does not check whether the header contains valid credentials
* - you must call {@link #authenticateAndAttachToContext} to validate the header.
*
* @return {@code true} if a secondary authentication header exists in the thread context.
*/
public boolean hasSecondaryAuthenticationHeader() {
final String header = securityContext.getThreadContext().getHeader(SECONDARY_AUTH_HEADER_NAME);
return Strings.isNullOrEmpty(header) == false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.rest.RestRequestFilter;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.security.audit.AuditTrailService;
import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.elasticsearch.xpack.security.authc.support.SecondaryAuthenticator;
import org.elasticsearch.xpack.security.transport.SSLEngineUtils;
Expand All @@ -38,6 +39,7 @@ public class SecurityRestFilter implements RestHandler {
private final RestHandler restHandler;
private final AuthenticationService authenticationService;
private final SecondaryAuthenticator secondaryAuthenticator;
private final AuditTrailService auditTrailService;
private final boolean enabled;
private final ThreadContext threadContext;
private final boolean extractClientCertificate;
Expand All @@ -64,13 +66,15 @@ public SecurityRestFilter(
ThreadContext threadContext,
AuthenticationService authenticationService,
SecondaryAuthenticator secondaryAuthenticator,
AuditTrailService auditTrailService,
RestHandler restHandler,
boolean extractClientCertificate
) {
this.enabled = enabled;
this.threadContext = threadContext;
this.authenticationService = authenticationService;
this.secondaryAuthenticator = secondaryAuthenticator;
this.auditTrailService = auditTrailService;
this.restHandler = restHandler;
this.extractClientCertificate = extractClientCertificate;
}
Expand Down Expand Up @@ -102,15 +106,16 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c
SSLEngineUtils.extractClientCertificates(logger, threadContext, httpChannel);
}

final RestRequest wrappedRequest = maybeWrapRestRequest(request);
RemoteHostHeader.process(request, threadContext);

authenticationService.authenticate(maybeWrapRestRequest(request), ActionListener.wrap(authentication -> {
authenticationService.authenticate(wrappedRequest, ActionListener.wrap(authentication -> {
if (authentication == null) {
logger.trace("No authentication available for REST request [{}]", request.uri());
} else {
logger.trace("Authenticated REST request [{}] as {}", request.uri(), authentication);
}
secondaryAuthenticator.authenticateAndAttachToContext(request, ActionListener.wrap(secondaryAuthentication -> {
auditTrailService.get().authenticationSuccess(wrappedRequest);
secondaryAuthenticator.authenticateAndAttachToContext(wrappedRequest, ActionListener.wrap(secondaryAuthentication -> {
if (secondaryAuthentication != null) {
logger.trace("Found secondary authentication {} in REST request [{}]", secondaryAuthentication, request.uri());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,10 @@ public void testAuthenticationSuccessRest() throws Exception {
.realmRef(new RealmRef("_look", "_type", "node", null))
.build();
final String requestId = randomAlphaOfLengthBetween(6, 12);
service.get().authenticationSuccess(requestId, authentication, restRequest);
service.get().authenticationSuccess(restRequest);
verify(licenseState).isAllowed(Security.AUDITING_FEATURE);
if (isAuditingAllowed) {
verify(auditTrail).authenticationSuccess(requestId, authentication, restRequest);
verify(auditTrail).authenticationSuccess(restRequest);
} else {
verifyNoMoreInteractions(auditTrail);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.elasticsearch.xpack.core.security.authz.AuthorizationEngine.AuthorizationInfo;
import org.elasticsearch.xpack.core.security.user.SystemUser;
import org.elasticsearch.xpack.core.security.user.User;
import org.elasticsearch.xpack.security.audit.AuditUtil;
import org.elasticsearch.xpack.security.audit.logfile.LoggingAuditTrail.AuditEventMetaInfo;
import org.elasticsearch.xpack.security.audit.logfile.LoggingAuditTrailTests.MockRequest;
import org.elasticsearch.xpack.security.audit.logfile.LoggingAuditTrailTests.RestContent;
Expand Down Expand Up @@ -1139,12 +1140,16 @@ public void testUsersFilter() throws Exception {
threadContext.stashContext();

// authentication Success
auditTrail.authenticationSuccess(randomAlphaOfLength(8), unfilteredAuthentication, getRestRequest());
AuditUtil.generateRequestId(threadContext);
unfilteredAuthentication.writeToContext(threadContext);
auditTrail.authenticationSuccess(getRestRequest());
assertThat("AuthenticationSuccess rest request: unfiltered user is filtered out", logOutput.size(), is(1));
logOutput.clear();
threadContext.stashContext();

auditTrail.authenticationSuccess(randomAlphaOfLength(8), filteredAuthentication, getRestRequest());
AuditUtil.generateRequestId(threadContext);
filteredAuthentication.writeToContext(threadContext);
auditTrail.authenticationSuccess(getRestRequest());
assertThat("AuthenticationSuccess rest request: filtered user is not filtered out", logOutput.size(), is(0));
logOutput.clear();
threadContext.stashContext();
Expand Down Expand Up @@ -1634,12 +1639,16 @@ public void testRealmsFilter() throws Exception {
threadContext.stashContext();

// authentication Success
auditTrail.authenticationSuccess(randomAlphaOfLength(8), createAuthentication(user, authUser, unfilteredRealm), getRestRequest());
AuditUtil.generateRequestId(threadContext);
createAuthentication(user, authUser, unfilteredRealm).writeToContext(threadContext);
auditTrail.authenticationSuccess(getRestRequest());
assertThat("AuthenticationSuccess rest request: unfiltered realm is filtered out", logOutput.size(), is(1));
logOutput.clear();
threadContext.stashContext();

auditTrail.authenticationSuccess(randomAlphaOfLength(8), createAuthentication(user, authUser, filteredRealm), getRestRequest());
AuditUtil.generateRequestId(threadContext);
createAuthentication(user, authUser, filteredRealm).writeToContext(threadContext);
auditTrail.authenticationSuccess(getRestRequest());
assertThat("AuthenticationSuccess rest request: filtered realm is not filtered out", logOutput.size(), is(0));
logOutput.clear();
threadContext.stashContext();
Expand Down Expand Up @@ -1950,7 +1959,9 @@ public void testRolesFilter() throws Exception {
threadContext.stashContext();

// authentication Success
auditTrail.authenticationSuccess(randomAlphaOfLength(8), authentication, getRestRequest());
AuditUtil.generateRequestId(threadContext);
authentication.writeToContext(threadContext);
auditTrail.authenticationSuccess(getRestRequest());
if (filterMissingRoles) {
assertThat("AuthenticationSuccess rest request: is not filtered out by the missing roles filter", logOutput.size(), is(0));
} else {
Expand Down Expand Up @@ -2426,7 +2437,9 @@ public void testIndicesFilter() throws Exception {
threadContext.stashContext();

// authentication Success
auditTrail.authenticationSuccess(randomAlphaOfLength(8), authentication, getRestRequest());
AuditUtil.generateRequestId(threadContext);
authentication.writeToContext(threadContext);
auditTrail.authenticationSuccess(getRestRequest());
if (filterMissingIndices) {
assertThat("AuthenticationSuccess rest request: is not filtered out by the missing indices filter", logOutput.size(), is(0));
} else {
Expand Down Expand Up @@ -2697,7 +2710,9 @@ public void testActionsFilter() throws Exception {
threadContext.stashContext();

// authentication Success
auditTrail.authenticationSuccess(randomAlphaOfLength(8), createAuthentication(user, authUser, "realm"), getRestRequest());
AuditUtil.generateRequestId(threadContext);
createAuthentication(user, authUser, "realm").writeToContext(threadContext);
auditTrail.authenticationSuccess(getRestRequest());
if (filterMissingAction) {
assertThat("AuthenticationSuccess rest request: not filtered out by the missing action filter", logOutput.size(), is(0));
} else {
Expand Down

0 comments on commit f353be2

Please sign in to comment.