Skip to content

Commit

Permalink
User Profile - Request cancellation for SuggestProfiles on HTTP disco…
Browse files Browse the repository at this point in the history
…nnect (#86332)

This PR adds support for automatic request cancellation on HTTP
connection drop for the SuggestProfiles API. Both the Suggest
request itself and its child Search requests are cancelled once the 
incominng HTTP connection is closed.
  • Loading branch information
ywangd committed May 6, 2022
1 parent 4075bd4 commit 97222b7
Show file tree
Hide file tree
Showing 8 changed files with 260 additions and 10 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/86332.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 86332
summary: User Profile - Support request cancellation on HTTP disconnect
area: Security
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -90,6 +93,16 @@ public ActionRequestValidationException validate() {
return validationException;
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers);
}

@Override
public String getDescription() {
return "SuggestProfiles{" + "name='" + name + "', hint=" + hint + '}';
}

public static class Hint implements Writeable {
@Nullable
private final List<String> uids;
Expand Down Expand Up @@ -180,5 +193,10 @@ private ActionRequestValidationException validate(ActionRequestValidationExcepti
}
return validationException;
}

@Override
public String toString() {
return "Hint{" + "uids=" + uids + ", labels=" + labels + '}';
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.security.profile;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.client.Cancellable;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseListener;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexModule;
import org.elasticsearch.index.shard.SearchOperationListener;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.PluginsService;
import org.elasticsearch.search.internal.ReaderContext;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.security.action.profile.SuggestProfilesAction;
import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;

import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.test.SecuritySettingsSource.TEST_USER_NAME;
import static org.elasticsearch.test.SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItems;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.startsWith;

public class ProfileCancellationIntegTests extends AbstractProfileIntegTestCase {

@Override
protected boolean addMockHttpTransport() {
return false;
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
final List<Class<? extends Plugin>> plugins = new ArrayList<>(super.nodePlugins());
plugins.add(SearchBlockPlugin.class);
return List.copyOf(plugins);
}

public void testSuggestProfilesCancellation() throws Exception {
doActivateProfile(RAC_USER_NAME, TEST_PASSWORD_SECURE_STRING);

final String xOpaqueId = randomAlphaOfLength(10);
final Request request = new Request("GET", "/_security/profile/_suggest");
RequestOptions.Builder options = request.getOptions()
.toBuilder()
.addHeader("Authorization", UsernamePasswordToken.basicAuthHeaderValue(TEST_USER_NAME, TEST_PASSWORD_SECURE_STRING))
.addHeader(Task.X_OPAQUE_ID_HTTP_HEADER, xOpaqueId);
request.setOptions(options);

// Stall the search
enableSearchBlock();

final CountDownLatch latch = new CountDownLatch(1);
final AtomicReference<Exception> error = new AtomicReference<>();
final Cancellable cancellable = getRestClient().performRequestAsync(request, new ResponseListener() {
@Override
public void onSuccess(Response response) {
latch.countDown();
}

@Override
public void onFailure(Exception exception) {
error.set(exception);
latch.countDown();
}
});

// Assert that suggest task and search sub-tasks are initiated
final Set<Long> taskIds = ConcurrentHashMap.newKeySet();
assertBusy(() -> {
final List<Task> tasks = getTasksForXOpaqueId(xOpaqueId);
final List<String> taskActions = tasks.stream().map(Task::getAction).toList();
assertThat(taskActions, hasItems(equalTo(SuggestProfilesAction.NAME), startsWith(SearchAction.NAME)));
tasks.forEach(t -> taskIds.add(t.getId()));
});

// Cancel the suggest request and all tasks should be cancelled
cancellable.cancel();
assertBusy(() -> {
final List<CancellableTask> cancellableTasks = getCancellableTasksForXOpaqueId(xOpaqueId);
cancellableTasks.forEach(cancellableTask -> {
assertThat(
"task " + cancellableTask.getId() + "/" + cancellableTask.getAction() + " not cancelled",
cancellableTask.isCancelled(),
is(true)
);
taskIds.remove(cancellableTask.getId());
});
assertThat(taskIds, empty());
});

disableSearchBlock();
latch.await();
assertThat(error.get(), instanceOf(CancellationException.class));
}

private List<Task> getTasksForXOpaqueId(String xOpaqueId) {
final ArrayList<Task> tasks = new ArrayList<>();
for (TransportService transportService : internalCluster().getInstances(TransportService.class)) {
tasks.addAll(
transportService.getTaskManager()
.getTasks()
.values()
.stream()
.filter(t -> xOpaqueId.equals(t.headers().get(Task.X_OPAQUE_ID_HTTP_HEADER)))
.toList()
);
}
return tasks;
}

private List<CancellableTask> getCancellableTasksForXOpaqueId(String xOpaqueId) {
final ArrayList<CancellableTask> cancellableTasks = new ArrayList<>();
for (TransportService transportService : internalCluster().getInstances(TransportService.class)) {
cancellableTasks.addAll(
transportService.getTaskManager()
.getCancellableTasks()
.values()
.stream()
.filter(t -> xOpaqueId.equals(t.headers().get(Task.X_OPAQUE_ID_HTTP_HEADER)))
.toList()
);
}
return cancellableTasks;
}

private void enableSearchBlock() {
for (PluginsService pluginsService : internalCluster().getInstances(PluginsService.class)) {
pluginsService.filterPlugins(SearchBlockPlugin.class).forEach(SearchBlockPlugin::enableSearchBlock);
}
}

private void disableSearchBlock() {
for (PluginsService pluginsService : internalCluster().getInstances(PluginsService.class)) {
pluginsService.filterPlugins(SearchBlockPlugin.class).forEach(SearchBlockPlugin::disableSearchBlock);
}
}

public static class SearchBlockPlugin extends Plugin implements ActionPlugin {
protected static final Logger logger = LogManager.getLogger(SearchBlockPlugin.class);

private final String nodeId;
private final AtomicBoolean shouldBlockOnSearch = new AtomicBoolean(false);

public SearchBlockPlugin(Settings settings, Path configPath) throws Exception {
nodeId = settings.get("node.name");
}

@Override
public void onIndexModule(IndexModule indexModule) {
super.onIndexModule(indexModule);
indexModule.addSearchOperationListener(new SearchOperationListener() {
@Override
public void onNewReaderContext(ReaderContext readerContext) {
try {
logger.info("blocking search on " + nodeId);
assertBusy(() -> assertFalse(shouldBlockOnSearch.get()));
logger.info("unblocking search on " + nodeId);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
});
}

void enableSearchBlock() {
shouldBlockOnSearch.set(true);
}

void disableSearchBlock() {
shouldBlockOnSearch.set(false);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.security.action.profile.SuggestProfilesAction;
import org.elasticsearch.xpack.core.security.action.profile.SuggestProfilesRequest;
Expand All @@ -21,15 +23,22 @@
public class TransportSuggestProfilesAction extends HandledTransportAction<SuggestProfilesRequest, SuggestProfilesResponse> {

private final ProfileService profileService;
private final ClusterService clusterService;

@Inject
public TransportSuggestProfilesAction(TransportService transportService, ActionFilters actionFilters, ProfileService profileService) {
public TransportSuggestProfilesAction(
TransportService transportService,
ActionFilters actionFilters,
ProfileService profileService,
ClusterService clusterService
) {
super(SuggestProfilesAction.NAME, transportService, actionFilters, SuggestProfilesRequest::new);
this.profileService = profileService;
this.clusterService = clusterService;
}

@Override
protected void doExecute(Task task, SuggestProfilesRequest request, ActionListener<SuggestProfilesResponse> listener) {
profileService.suggestProfile(request, listener);
profileService.suggestProfile(request, new TaskId(clusterService.localNode().getId(), task.getId()), listener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down Expand Up @@ -210,7 +211,7 @@ public void updateProfileData(UpdateProfileDataRequest request, ActionListener<A
);
}

public void suggestProfile(SuggestProfilesRequest request, ActionListener<SuggestProfilesResponse> listener) {
public void suggestProfile(SuggestProfilesRequest request, TaskId parentTaskId, ActionListener<SuggestProfilesResponse> listener) {
tryFreezeAndCheckIndex(listener.map(response -> {
assert response == null : "only null response can reach here";
return new SuggestProfilesResponse(
Expand All @@ -219,7 +220,7 @@ public void suggestProfile(SuggestProfilesRequest request, ActionListener<Sugges
new TotalHits(0, TotalHits.Relation.EQUAL_TO)
);
})).ifPresent(frozenProfileIndex -> {
final SearchRequest searchRequest = buildSearchRequest(request);
final SearchRequest searchRequest = buildSearchRequest(request, parentTaskId);

frozenProfileIndex.checkIndexVersionThenExecute(
listener::onFailure,
Expand Down Expand Up @@ -271,7 +272,7 @@ public void setEnabled(String uid, boolean enabled, RefreshPolicy refreshPolicy,
}

// package private for testing
SearchRequest buildSearchRequest(SuggestProfilesRequest request) {
SearchRequest buildSearchRequest(SuggestProfilesRequest request, TaskId parentTaskId) {
final BoolQueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.termQuery("user_profile.enabled", true));
if (Strings.hasText(request.getName())) {
query.must(
Expand Down Expand Up @@ -302,12 +303,14 @@ SearchRequest buildSearchRequest(SuggestProfilesRequest request) {
query.minimumShouldMatch(0);
}

return client.prepareSearch(SECURITY_PROFILE_ALIAS)
final SearchRequest searchRequest = client.prepareSearch(SECURITY_PROFILE_ALIAS)
.setQuery(query)
.setSize(request.getSize())
.addSort("_score", SortOrder.DESC)
.addSort("user_profile.last_synchronized", SortOrder.DESC)
.request();
searchRequest.setParentTask(parentTaskId);
return searchRequest;
}

private void getVersionedDocument(String uid, ActionListener<VersionedDocument> listener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
Expand Down Expand Up @@ -84,7 +86,12 @@ protected RestChannelConsumer innerPrepareRequest(RestRequest request, NodeClien
payload.size(),
payload.hint() == null ? null : new SuggestProfilesRequest.Hint(payload.hint().uids(), payload.hint().labels())
);
return channel -> client.execute(SuggestProfilesAction.INSTANCE, suggestProfilesRequest, new RestToXContentListener<>(channel));
final HttpChannel httpChannel = request.getHttpChannel();
return channel -> new RestCancellableNodeClient(client, httpChannel).execute(
SuggestProfilesAction.INSTANCE,
suggestProfilesRequest,
new RestToXContentListener<>(channel)
);
}

record Payload(String name, Integer size, PayloadHint hint, String data) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.ScoreSortBuilder;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.MockLogAppender;
import org.elasticsearch.test.VersionUtils;
Expand Down Expand Up @@ -459,8 +460,11 @@ public void testBuildSearchRequest() {
final int size = randomIntBetween(0, Integer.MAX_VALUE);
final SuggestProfilesRequest.Hint hint = SuggestProfilesRequestTests.randomHint();
final SuggestProfilesRequest suggestProfilesRequest = new SuggestProfilesRequest(Set.of(), name, size, hint);
final TaskId parentTaskId = new TaskId(randomAlphaOfLength(20), randomNonNegativeLong());

final SearchRequest searchRequest = profileService.buildSearchRequest(suggestProfilesRequest, parentTaskId);
assertThat(searchRequest.getParentTask(), is(parentTaskId));

final SearchRequest searchRequest = profileService.buildSearchRequest(suggestProfilesRequest);
final SearchSourceBuilder searchSourceBuilder = searchRequest.source();

assertThat(
Expand Down Expand Up @@ -572,7 +576,11 @@ public void testSecurityProfileOrigin() {
return null;
}).when(client).execute(eq(SearchAction.INSTANCE), any(SearchRequest.class), anyActionListener());
final PlainActionFuture<SuggestProfilesResponse> future3 = new PlainActionFuture<>();
profileService.suggestProfile(new SuggestProfilesRequest(Set.of(), "", 1, null), future3);
profileService.suggestProfile(
new SuggestProfilesRequest(Set.of(), "", 1, null),
new TaskId(randomAlphaOfLength(20), randomNonNegativeLong()),
future3
);
final RuntimeException e3 = expectThrows(RuntimeException.class, future3::actionGet);
assertThat(e3, is(expectedException));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public void init() {
requestHolder = new AtomicReference<>();
restSuggestProfilesAction = new RestSuggestProfilesAction(settings, licenseState);
controller().registerHandler(restSuggestProfilesAction);
verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> {
verifyingClient.setExecuteLocallyVerifier(((actionType, actionRequest) -> {
assertThat(actionRequest, instanceOf(SuggestProfilesRequest.class));
requestHolder.set((SuggestProfilesRequest) actionRequest);
return mock(SuggestProfilesResponse.class);
Expand Down

0 comments on commit 97222b7

Please sign in to comment.