Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Store and use only internal security headers (#66365)
For async searches (EQL included) the client's request headers were
erroneously stored in the .tasks index. This might expose the requesting
client's HTTP Authorization header. This PR fixes that by employing the
usual approach to store only the security-internal headers, which carry
the authentication result, instead of the original Authorization header,
which is commonly utilized to redo authentication for scheduled tasks.
  • Loading branch information
albertzaharovits committed Dec 17, 2020
1 parent 50aca62 commit 480561d
Show file tree
Hide file tree
Showing 24 changed files with 84 additions and 65 deletions.
Expand Up @@ -324,7 +324,13 @@ public String getHeader(String key) {
}

/**
* Returns all of the request contexts headers
* Returns all of the request headers from the thread's context.<br>
* <b>Be advised, headers might contain credentials.</b>
* In order to avoid storing, and erroneously exposing, such headers,
* it is recommended to instead store security headers that prove
* the credentials have been verified successfully, and which are
* internal to the system, in the sense that they cannot be sent
* by the clients.
*/
public Map<String, String> getHeaders() {
HashMap<String, String> map = new HashMap<>(defaultHeader);
Expand Down
Expand Up @@ -32,6 +32,7 @@
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.XPackPlugin;
import org.elasticsearch.xpack.core.async.AsyncExecutionId;
import org.elasticsearch.xpack.core.async.AsyncTaskIndexService;
Expand Down Expand Up @@ -136,7 +137,7 @@ public void onFailure(Exception exc) {

private SearchRequest createSearchRequest(SubmitAsyncSearchRequest request, Task submitTask, TimeValue keepAlive) {
String docID = UUIDs.randomBase64UUID();
Map<String, String> originHeaders = nodeClient.threadPool().getThreadContext().getHeaders();
Map<String, String> originHeaders = ClientHelper.filterSecurityHeaders(nodeClient.threadPool().getThreadContext().getHeaders());
SearchRequest searchRequest = new SearchRequest(request.getSearchRequest()) {
@Override
public AsyncSearchTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> taskHeaders) {
Expand Down
Expand Up @@ -37,7 +37,7 @@
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.ccr.action.ShardChangesAction;
import org.elasticsearch.xpack.ccr.action.ShardFollowTask;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.XPackPlugin;
import org.elasticsearch.xpack.core.security.SecurityContext;
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesAction;
Expand All @@ -58,7 +58,6 @@
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
* Encapsulates licensing checking for CCR.
Expand Down Expand Up @@ -363,18 +362,15 @@ public static Client wrapClient(Client client, Map<String, String> headers) {
if (headers.isEmpty()) {
return client;
} else {
final ThreadContext threadContext = client.threadPool().getThreadContext();
Map<String, String> filteredHeaders = headers.entrySet().stream()
.filter(e -> ShardFollowTask.HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
Map<String, String> filteredHeaders = ClientHelper.filterSecurityHeaders(headers);
if (filteredHeaders.isEmpty()) {
return client;
}
return new FilterClient(client) {
@Override
protected <Request extends ActionRequest, Response extends ActionResponse>
void doExecute(ActionType<Response> action, Request request, ActionListener<Response> listener) {
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(false);
try (ThreadContext.StoredContext ignore = stashWithHeaders(threadContext, filteredHeaders)) {
super.doExecute(action, request, new ContextPreservingActionListener<>(supplier, listener));
}
ClientHelper.executeWithHeadersAsync(filteredHeaders, null, client, action, request, listener);
}
};
}
Expand Down
Expand Up @@ -21,21 +21,14 @@
import org.elasticsearch.xpack.core.ccr.action.ImmutableFollowParameters;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

public class ShardFollowTask extends ImmutableFollowParameters implements XPackPlugin.XPackPersistentTaskParams {

public static final String NAME = "xpack/ccr/shard_follow_task";

// list of headers that will be stored when a job is created
public static final Set<String> HEADER_FILTERS =
Collections.unmodifiableSet(new HashSet<>(Arrays.asList("es-security-runas-user", "_xpack_security_authentication")));

private static final ParseField REMOTE_CLUSTER_FIELD = new ParseField("remote_cluster");
private static final ParseField FOLLOW_SHARD_INDEX_FIELD = new ParseField("follow_shard_index");
private static final ParseField FOLLOW_SHARD_INDEX_UUID_FIELD = new ParseField("follow_shard_index_uuid");
Expand Down
Expand Up @@ -27,6 +27,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.ccr.CcrLicenseChecker;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ccr.AutoFollowMetadata;
import org.elasticsearch.xpack.core.ccr.AutoFollowMetadata.AutoFollowPattern;
import org.elasticsearch.xpack.core.ccr.action.PutAutoFollowPatternAction;
Expand Down Expand Up @@ -92,9 +93,7 @@ protected void masterOperation(PutAutoFollowPatternAction.Request request,
return;
}
final Client remoteClient = client.getRemoteClusterClient(request.getRemoteCluster());
final Map<String, String> filteredHeaders = threadPool.getThreadContext().getHeaders().entrySet().stream()
.filter(e -> ShardFollowTask.HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
final Map<String, String> filteredHeaders = ClientHelper.filterSecurityHeaders(threadPool.getThreadContext().getHeaders());

Consumer<ClusterStateResponse> consumer = remoteClusterState -> {
String[] indices = request.getLeaderIndexPatterns().toArray(new String[0]);
Expand Down
Expand Up @@ -47,6 +47,7 @@
import org.elasticsearch.xpack.ccr.Ccr;
import org.elasticsearch.xpack.ccr.CcrLicenseChecker;
import org.elasticsearch.xpack.ccr.CcrSettings;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ccr.action.FollowParameters;
import org.elasticsearch.xpack.core.ccr.action.ResumeFollowAction;

Expand All @@ -58,7 +59,6 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

public class TransportResumeFollowAction extends TransportMasterNodeAction<ResumeFollowAction.Request, AcknowledgedResponse> {

Expand Down Expand Up @@ -173,9 +173,7 @@ void start(
validate(request, leaderIndexMetadata, followIndexMetadata, leaderIndexHistoryUUIDs, mapperService);
final int numShards = followIndexMetadata.getNumberOfShards();
final ResponseHandler handler = new ResponseHandler(numShards, listener);
Map<String, String> filteredHeaders = threadPool.getThreadContext().getHeaders().entrySet().stream()
.filter(e -> ShardFollowTask.HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
Map<String, String> filteredHeaders = ClientHelper.filterSecurityHeaders(threadPool.getThreadContext().getHeaders());

for (int shardId = 0; shardId < numShards; shardId++) {
String taskId = followIndexMetadata.getIndexUUID() + "-" + shardId;
Expand Down
Expand Up @@ -18,12 +18,14 @@
import org.elasticsearch.persistent.PersistentTasksService;
import org.elasticsearch.xpack.core.security.authc.AuthenticationField;
import org.elasticsearch.xpack.core.security.authc.AuthenticationServiceField;
import org.elasticsearch.xpack.core.security.authc.support.SecondaryAuthentication;

import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
Expand All @@ -32,13 +34,27 @@
*/
public final class ClientHelper {

private static Pattern authorizationHeaderPattern = Pattern.compile("\\s*" + Pattern.quote("Authorization") + "\\s*",
Pattern.CASE_INSENSITIVE);

public static void assertNoAuthorizationHeader(Map<String, String> headers) {
if (org.elasticsearch.Assertions.ENABLED) {
for (String header : headers.keySet()) {
if (authorizationHeaderPattern.matcher(header).find()) {
assert false : "headers contain \"Authorization\"";
}
}
}
}

/**
* List of headers that are related to security
*/
public static final Set<String> SECURITY_HEADER_FILTERS =
Sets.newHashSet(
AuthenticationServiceField.RUN_AS_USER_HEADER,
AuthenticationField.AUTHENTICATION_KEY);
AuthenticationField.AUTHENTICATION_KEY,
SecondaryAuthentication.THREAD_CTX_KEY);

/**
* Leaves only headers that are related to security and filters out the rest.
Expand All @@ -47,9 +63,14 @@ public final class ClientHelper {
* @return A portion of entries that are related to security
*/
public static Map<String, String> filterSecurityHeaders(Map<String, String> headers) {
return Objects.requireNonNull(headers).entrySet().stream()
.filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
if (SECURITY_HEADER_FILTERS.containsAll(headers.keySet())) {
// fast-track to skip the artifice below
return headers;
} else {
return Objects.requireNonNull(headers).entrySet().stream()
.filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}
}

/**
Expand Down Expand Up @@ -162,11 +183,8 @@ public static <T extends ActionResponse> T executeWithHeaders(Map<String, String
public static <Request extends ActionRequest, Response extends ActionResponse>
void executeWithHeadersAsync(Map<String, String> headers, String origin, Client client, ActionType<Response> action, Request request,
ActionListener<Response> listener) {

Map<String, String> filteredHeaders = filterSecurityHeaders(headers);

final Map<String, String> filteredHeaders = filterSecurityHeaders(headers);
final ThreadContext threadContext = client.threadPool().getThreadContext();

// No headers (e.g. security not installed/in use) so execute as origin
if (filteredHeaders.isEmpty()) {
ClientHelper.executeAsyncWithOrigin(client, origin, action, request, listener);
Expand All @@ -181,6 +199,7 @@ void executeWithHeadersAsync(Map<String, String> headers, String origin, Client

private static ThreadContext.StoredContext stashWithHeaders(ThreadContext threadContext, Map<String, String> headers) {
final ThreadContext.StoredContext storedContext = threadContext.stashContext();
assertNoAuthorizationHeader(headers);
threadContext.copyHeaders(headers.entrySet());
return storedContext;
}
Expand Down
Expand Up @@ -50,6 +50,8 @@
import java.util.Random;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.core.ClientHelper.assertNoAuthorizationHeader;

/**
* Datafeed configuration options. Describes where to proactively pull input
* data from.
Expand Down Expand Up @@ -506,6 +508,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(CHUNKING_CONFIG.getPreferredName(), chunkingConfig);
}
if (headers.isEmpty() == false && params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
assertNoAuthorizationHeader(headers);
builder.field(HEADERS.getPreferredName(), headers);
}
if (delayedDataCheckConfig != null) {
Expand Down
Expand Up @@ -35,6 +35,7 @@

import static org.elasticsearch.common.xcontent.ObjectParser.ValueType.OBJECT_ARRAY_BOOLEAN_OR_STRING;
import static org.elasticsearch.common.xcontent.ObjectParser.ValueType.VALUE;
import static org.elasticsearch.xpack.core.ClientHelper.assertNoAuthorizationHeader;

public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable {

Expand Down Expand Up @@ -251,6 +252,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
builder.field(MODEL_MEMORY_LIMIT.getPreferredName(), getModelMemoryLimit().getStringRep());
if (headers.isEmpty() == false && params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
assertNoAuthorizationHeader(headers);
builder.field(HEADERS.getPreferredName(), headers);
}
if (createTime != null) {
Expand Down Expand Up @@ -414,6 +416,7 @@ public Builder setAnalyzedFields(FetchSourceContext fields) {

public Builder setHeaders(Map<String, String> headers) {
this.headers = headers;
assertNoAuthorizationHeader(this.headers);
return this;
}

Expand Down
Expand Up @@ -21,6 +21,8 @@
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.core.ClientHelper.assertNoAuthorizationHeader;

/**
* This class is the main wrapper object that is serialized into the PersistentTask's cluster state.
* It holds the config (RollupJobConfig) and a map of authentication headers. Only RollupJobConfig
Expand Down Expand Up @@ -67,6 +69,7 @@ public Map<String, String> getHeaders() {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CONFIG.getPreferredName(), config);
assertNoAuthorizationHeader(headers);
builder.field(HEADERS.getPreferredName(), headers);
builder.endObject();
return builder;
Expand Down
Expand Up @@ -25,7 +25,7 @@
*/
public class SecondaryAuthentication {

private static final String THREAD_CTX_KEY = "_xpack_security_secondary_authc";
public static final String THREAD_CTX_KEY = "_xpack_security_secondary_authc";

private final SecurityContext securityContext;
private final Authentication authentication;
Expand Down
Expand Up @@ -24,6 +24,8 @@
import java.util.Objects;
import java.util.Optional;

import static org.elasticsearch.xpack.core.ClientHelper.assertNoAuthorizationHeader;

/**
* {@code SnapshotLifecyclePolicyMetadata} encapsulates a {@link SnapshotLifecyclePolicy} as well as
* the additional meta information link headers used for execution, version (a monotonically
Expand Down Expand Up @@ -86,6 +88,7 @@ public static SnapshotLifecyclePolicyMetadata parse(XContentParser parser, Strin
SnapshotInvocationRecord lastSuccess, SnapshotInvocationRecord lastFailure) {
this.policy = policy;
this.headers = headers;
assertNoAuthorizationHeader(this.headers);
this.version = version;
this.modifiedDate = modifiedDate;
this.lastSuccess = lastSuccess;
Expand All @@ -96,6 +99,7 @@ public static SnapshotLifecyclePolicyMetadata parse(XContentParser parser, Strin
SnapshotLifecyclePolicyMetadata(StreamInput in) throws IOException {
this.policy = new SnapshotLifecyclePolicy(in);
this.headers = (Map<String, String>) in.readGenericValue();
assertNoAuthorizationHeader(this.headers);
this.version = in.readVLong();
this.modifiedDate = in.readVLong();
this.lastSuccess = in.readOptionalWriteable(SnapshotInvocationRecord::new);
Expand Down
Expand Up @@ -28,6 +28,8 @@
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.core.ClientHelper.assertNoAuthorizationHeader;

public abstract class WatchExecutionContext {

private final Wid id;
Expand Down Expand Up @@ -261,6 +263,7 @@ public WatchExecutionSnapshot createSnapshot(Thread executionThread) {
*/
public static String getUsernameFromWatch(Watch watch) throws IOException {
if (watch != null && watch.status() != null && watch.status().getHeaders() != null) {
assertNoAuthorizationHeader(watch.status().getHeaders());
String header = watch.status().getHeaders().get(AuthenticationField.AUTHENTICATION_KEY);
if (header != null) {
Authentication auth = AuthenticationContextSerializer.decode(header);
Expand Down
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.async.AsyncExecutionId;
import org.elasticsearch.xpack.core.async.AsyncTask;
import org.elasticsearch.xpack.core.async.AsyncTaskIndexService;
Expand Down Expand Up @@ -91,8 +92,9 @@ public TaskId getParentTask() {

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return operation.createTask(request, id, type, action, parentTaskId, headers, threadPool.getThreadContext().getHeaders(),
new AsyncExecutionId(doc, new TaskId(node, id)));
Map<String, String> originHeaders = ClientHelper.filterSecurityHeaders(threadPool.getThreadContext().getHeaders());
return operation.createTask(request, id, type, action, parentTaskId, headers, originHeaders, new AsyncExecutionId(doc,
new TaskId(node, id)));
}

@Override
Expand Down Expand Up @@ -193,7 +195,7 @@ private void storeResults(T searchTask, StoredAsyncResponse<Response> storedResp
private void storeResults(T searchTask, StoredAsyncResponse<Response> storedResponse, ActionListener<Void> finalListener) {
try {
asyncTaskIndexService.createResponse(searchTask.getExecutionId().getDocId(),
threadPool.getThreadContext().getHeaders(), storedResponse, ActionListener.wrap(
searchTask.getOriginHeaders(), storedResponse, ActionListener.wrap(
// We should only unregister after the result is saved
resp -> {
logger.trace(() -> new ParameterizedMessage("stored eql search results for [{}]",
Expand Down
Expand Up @@ -93,9 +93,7 @@ protected void masterOperation(Request request, ClusterState state, ActionListen
// REST layer and the Transport layer here must be accessed within this thread and not in the
// cluster state thread in the ClusterStateUpdateTask below since that thread does not share the
// same context, and therefore does not have access to the appropriate security headers.
Map<String, String> filteredHeaders = threadPool.getThreadContext().getHeaders().entrySet().stream()
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
Map<String, String> filteredHeaders = ClientHelper.filterSecurityHeaders(threadPool.getThreadContext().getHeaders());
LifecyclePolicy.validatePolicyName(request.getPolicy().getName());
clusterService.submitStateUpdateTask("put-lifecycle-" + request.getPolicy().getName(),
new AckedClusterStateUpdateTask<Response>(request, listener) {
Expand Down

0 comments on commit 480561d

Please sign in to comment.