Improve BWC for persisted authentication headers (#83913)
Authentication headers are persisted as part of a task definition for ML
jobs, CCR following, etc. The persistence process stores them in either an
index or the cluster state. In both cases, the headers are retrieved from the
ThreadContext as a string, which is the serialised form of the Authentication
object. This string is always serialised with the node's version.

The problem: in a mixed cluster, a task can be created on a newer node and
persisted into an index, but then need to be loaded by an older node. The
older node does not understand the newer format of the serialised
Authentication object and hence errors out when reading it.

This PR adds logic in the places where the headers are persisted. It compares
the Authentication version with the cluster's minNodeVersion and rewrites the
header if minNodeVersion is older. Since security headers are already filtered
wherever they are persisted, the new logic hooks into the same places and is
essentially another enhancement to how security headers are handled for
persisted tasks.
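
For illustration, a minimal sketch of the rewrite guard (not part of this
commit; the helper-method framing is hypothetical, but getMinNodeVersion,
getVersion, maybeRewriteForOlderVersion and encode are the APIs the change
relies on, as seen in the diff below):

    static String persistableAuthenticationHeader(Authentication authentication, ClusterState clusterState)
        throws IOException {
        final Version minNodeVersion = clusterState.nodes().getMinNodeVersion();
        if (authentication.getVersion().after(minNodeVersion)) {
            // Serialised by a newer node: rewrite down to the oldest node's wire
            // format so every node in the mixed cluster can decode it.
            return authentication.maybeRewriteForOlderVersion(minNodeVersion).encode();
        }
        return authentication.encode(); // already readable by all nodes
    }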

Resolves: #83567
ywangd committed Feb 17, 2022
1 parent 7d094c3 commit fb65f95
Showing 32 changed files with 386 additions and 67 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/83913.yaml
@@ -0,0 +1,6 @@
pr: 83913
summary: Improve BWC for persisted authentication headers
area: Authentication
type: enhancement
issues:
- 83567
@@ -43,6 +43,7 @@
import static org.elasticsearch.xpack.core.ClientHelper.ASYNC_SEARCH_ORIGIN;

public class TransportSubmitAsyncSearchAction extends HandledTransportAction<SubmitAsyncSearchRequest, AsyncSearchResponse> {
private final ClusterService clusterService;
private final NodeClient nodeClient;
private final BiFunction<Supplier<Boolean>, SearchRequest, AggregationReduceContext> requestToAggReduceContextBuilder;
private final TransportSearchAction searchAction;
@@ -62,6 +63,7 @@ public TransportSubmitAsyncSearchAction(
BigArrays bigArrays
) {
super(SubmitAsyncSearchAction.NAME, transportService, actionFilters, SubmitAsyncSearchRequest::new);
this.clusterService = clusterService;
this.nodeClient = nodeClient;
this.requestToAggReduceContextBuilder = (task, request) -> searchService.aggReduceContextBuilder(task, request).forFinalReduction();
this.searchAction = searchAction;
@@ -144,7 +146,10 @@ public void onFailure(Exception exc) {

private SearchRequest createSearchRequest(SubmitAsyncSearchRequest request, Task submitTask, TimeValue keepAlive) {
String docID = UUIDs.randomBase64UUID();
Map<String, String> originHeaders = ClientHelper.filterSecurityHeaders(nodeClient.threadPool().getThreadContext().getHeaders());
Map<String, String> originHeaders = ClientHelper.getPersistableSafeSecurityHeaders(
nodeClient.threadPool().getThreadContext(),
clusterService.state()
);
SearchRequest searchRequest = new SearchRequest(request.getSearchRequest()) {
@Override
public AsyncSearchTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> taskHeaders) {
@@ -393,11 +393,11 @@ User getUser(final Client remoteClient) {
return securityContext.getUser();
}

public static Client wrapClient(Client client, Map<String, String> headers) {
public static Client wrapClient(Client client, Map<String, String> headers, ClusterState clusterState) {
if (headers.isEmpty()) {
return client;
} else {
Map<String, String> filteredHeaders = ClientHelper.filterSecurityHeaders(headers);
Map<String, String> filteredHeaders = ClientHelper.getPersistableSafeSecurityHeaders(headers, clusterState);
if (filteredHeaders.isEmpty()) {
return client;
}
@@ -304,7 +304,7 @@ void createAndFollow(
Runnable successHandler,
Consumer<Exception> failureHandler
) {
Client followerClient = CcrLicenseChecker.wrapClient(client, headers);
Client followerClient = CcrLicenseChecker.wrapClient(client, headers, clusterService.state());
followerClient.execute(
PutFollowAction.INSTANCE,
request,
@@ -150,7 +150,7 @@ protected AllocatedPersistentTask createTask(
Map<String, String> headers
) {
ShardFollowTask params = taskInProgress.getParams();
Client followerClient = wrapClient(client, params.getHeaders());
Client followerClient = wrapClient(client, params.getHeaders(), clusterService.state());
BiConsumer<TimeValue, Runnable> scheduler = (delay, command) -> threadPool.scheduleUnlessShuttingDown(
delay,
Ccr.CCR_THREAD_POOL_NAME,
@@ -562,7 +562,8 @@ private String getLeaderShardHistoryUUID(ShardFollowTask params) {
}

private Client remoteClient(ShardFollowTask params) {
return wrapClient(client.getRemoteClusterClient(params.getRemoteCluster()), params.getHeaders());
// TODO: do we need minNodeVersion here since it is for remote cluster
return wrapClient(client.getRemoteClusterClient(params.getRemoteCluster()), params.getHeaders(), clusterService.state());
}

interface FollowerStatsInfoHandler {
@@ -571,7 +572,7 @@ interface FollowerStatsInfoHandler {

@Override
protected void nodeOperation(final AllocatedPersistentTask task, final ShardFollowTask params, final PersistentTaskState state) {
Client followerClient = wrapClient(client, params.getHeaders());
Client followerClient = wrapClient(client, params.getHeaders(), clusterService.state());
ShardFollowNodeTask shardFollowNodeTask = (ShardFollowNodeTask) task;
logger.info("{} Starting to track leader shard {}", params.getFollowShardId(), params.getLeaderShardId());

@@ -95,7 +95,10 @@ protected void masterOperation(
return;
}
final Client remoteClient = client.getRemoteClusterClient(request.getRemoteCluster());
final Map<String, String> filteredHeaders = ClientHelper.filterSecurityHeaders(threadPool.getThreadContext().getHeaders());
final Map<String, String> filteredHeaders = ClientHelper.getPersistableSafeSecurityHeaders(
threadPool.getThreadContext(),
clusterService.state()
);

Consumer<ClusterStateResponse> consumer = remoteClusterState -> {
String[] indices = request.getLeaderIndexPatterns().toArray(new String[0]);
@@ -197,7 +197,11 @@ private void createFollowerIndex(
.masterNodeTimeout(request.masterNodeTimeout())
.indexSettings(overrideSettings);

final Client clientWithHeaders = CcrLicenseChecker.wrapClient(this.client, threadPool.getThreadContext().getHeaders());
final Client clientWithHeaders = CcrLicenseChecker.wrapClient(
this.client,
threadPool.getThreadContext().getHeaders(),
clusterService.state()
);
threadPool.executor(ThreadPool.Names.SNAPSHOT).execute(new AbstractRunnable() {

@Override
@@ -179,7 +179,10 @@ void start(
validate(request, leaderIndexMetadata, followIndexMetadata, leaderIndexHistoryUUIDs, mapperService);
final int numShards = followIndexMetadata.getNumberOfShards();
final ResponseHandler handler = new ResponseHandler(numShards, listener);
Map<String, String> filteredHeaders = ClientHelper.filterSecurityHeaders(threadPool.getThreadContext().getHeaders());
Map<String, String> filteredHeaders = ClientHelper.getPersistableSafeSecurityHeaders(
threadPool.getThreadContext(),
clusterService.state()
);

for (int shardId = 0; shardId < numShards; shardId++) {
String taskId = followIndexMetadata.getIndexUUID() + "-" + shardId;
@@ -6,6 +6,7 @@
*/
package org.elasticsearch.xpack.core;

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestBuilder;
@@ -14,12 +15,19 @@
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authc.AuthenticationField;
import org.elasticsearch.xpack.core.security.authc.AuthenticationServiceField;
import org.elasticsearch.xpack.core.security.authc.support.AuthenticationContextSerializer;
import org.elasticsearch.xpack.core.security.authc.support.SecondaryAuthentication;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
@@ -77,6 +85,89 @@ public static Map<String, String> filterSecurityHeaders(Map<String, String> head
}
}

/**
* In addition to {@link #filterSecurityHeaders}, also check the version of Authentication objects
* and rewrite them using minNodeVersion so that they are safe to be persisted as index data
* and loaded by all nodes in the cluster.
*/
public static Map<String, String> getPersistableSafeSecurityHeaders(ThreadContext threadContext, ClusterState clusterState) {
return maybeRewriteAuthenticationHeadersForVersion(
filterSecurityHeaders(threadContext.getHeaders()),
key -> new AuthenticationContextSerializer(key).readFromContext(threadContext),
clusterState.nodes().getMinNodeVersion()
);
}

/**
* Similar to {@link #getPersistableSafeSecurityHeaders(ThreadContext, ClusterState)},
* but works on a Map of headers instead of ThreadContext.
*/
public static Map<String, String> getPersistableSafeSecurityHeaders(Map<String, String> headers, ClusterState clusterState) {
final CheckedFunction<String, Authentication, IOException> authenticationReader = key -> {
final String authHeader = headers.get(key);
return authHeader == null ? null : AuthenticationContextSerializer.decode(authHeader);
};
return maybeRewriteAuthenticationHeadersForVersion(
filterSecurityHeaders(headers),
authenticationReader,
clusterState.nodes().getMinNodeVersion()
);
}

private static Map<String, String> maybeRewriteAuthenticationHeadersForVersion(
Map<String, String> filteredHeaders,
CheckedFunction<String, Authentication, IOException> authenticationReader,
Version minNodeVersion
) {
Map<String, String> newHeaders = null;

final String authHeader = maybeRewriteSingleAuthenticationHeaderForVersion(
authenticationReader,
AuthenticationField.AUTHENTICATION_KEY,
minNodeVersion
);
if (authHeader != null) {
newHeaders = new HashMap<>();
newHeaders.put(AuthenticationField.AUTHENTICATION_KEY, authHeader);
}

final String secondaryHeader = maybeRewriteSingleAuthenticationHeaderForVersion(
authenticationReader,
SecondaryAuthentication.THREAD_CTX_KEY,
minNodeVersion
);
if (secondaryHeader != null) {
if (newHeaders == null) {
newHeaders = new HashMap<>();
}
newHeaders.put(SecondaryAuthentication.THREAD_CTX_KEY, secondaryHeader);
}

if (newHeaders != null) {
final HashMap<String, String> mutableHeaders = new HashMap<>(filteredHeaders);
mutableHeaders.putAll(newHeaders);
return Map.copyOf(mutableHeaders);
} else {
return filteredHeaders;
}
}

private static String maybeRewriteSingleAuthenticationHeaderForVersion(
CheckedFunction<String, Authentication, IOException> authenticationReader,
String authenticationHeaderKey,
Version minNodeVersion
) {
try {
final Authentication authentication = authenticationReader.apply(authenticationHeaderKey);
if (authentication != null && authentication.getVersion().after(minNodeVersion)) {
return authentication.maybeRewriteForOlderVersion(minNodeVersion).encode();
}
} catch (IOException e) {
throw new UncheckedIOException("failed to read authentication with key [" + authenticationHeaderKey + "]", e);
}
return null;
}

/**
* .
* @deprecated use ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME
@@ -167,6 +258,7 @@ public static <T extends ActionResponse> T executeWithHeaders(
Client client,
Supplier<T> supplier
) {
// No need to rewrite authentication header because it will be handled by Security Interceptor
Map<String, String> filteredHeaders = filterSecurityHeaders(headers);

// no security headers, we will have to use the xpack internal user for
@@ -206,6 +298,7 @@ public static <Request extends ActionRequest, Response extends ActionResponse> v
Request request,
ActionListener<Response> listener
) {
// No need to rewrite authentication header because it will be handled by Security Interceptor
final Map<String, String> filteredHeaders = filterSecurityHeaders(headers);
final ThreadContext threadContext = client.threadPool().getThreadContext();
// No headers (e.g. security not installed/in use) so execute as origin
@@ -8,6 +8,7 @@

import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -22,6 +23,7 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -38,8 +40,6 @@
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders;

/**
* A datafeed update contains partial properties to update a {@link DatafeedConfig}.
* The main difference between this class and {@link DatafeedConfig} is that here all
@@ -334,7 +334,7 @@ public IndicesOptions getIndicesOptions() {
* Applies the update to the given {@link DatafeedConfig}
* @return a new {@link DatafeedConfig} that contains the update
*/
public DatafeedConfig apply(DatafeedConfig datafeedConfig, Map<String, String> headers) {
public DatafeedConfig apply(DatafeedConfig datafeedConfig, Map<String, String> headers, ClusterState clusterState) {
if (id.equals(datafeedConfig.getId()) == false) {
throw new IllegalArgumentException("Cannot apply update to datafeedConfig with different id");
}
@@ -384,7 +384,7 @@ public DatafeedConfig apply(DatafeedConfig datafeedConfig, Map<String, String> h
builder.setRuntimeMappings(runtimeMappings);
}
if (headers.isEmpty() == false) {
builder.setHeaders(filterSecurityHeaders(headers));
builder.setHeaders(ClientHelper.getPersistableSafeSecurityHeaders(headers, clusterState));
}
return builder.build();
}