Skip to content

Commit

Permalink
restrict generic parameter type in ClusterStateTaskExecutor (#83024)
Browse files Browse the repository at this point in the history
Currently, submitStateUpdateTask only accept tasks that implement
ClusterStateTaskListener. This pr adjusts the ClusterStateTaskExecutor
to have similar restriction.
  • Loading branch information
idegtiarenko committed Jan 31, 2022
1 parent 1418823 commit 0431d85
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import java.util.List;
import java.util.Map;

public interface ClusterStateTaskExecutor<T> {
public interface ClusterStateTaskExecutor<T extends ClusterStateTaskListener> {
/**
* Update the cluster state based on the current state and the given tasks. Return the *same instance* if no state
* should be changed.
Expand Down Expand Up @@ -63,16 +63,16 @@ default String describeTasks(List<T> tasks) {
*
* @param <T> the type of the cluster state update task
*/
record ClusterTasksResult<T> (
record ClusterTasksResult<T extends ClusterStateTaskListener> (
@Nullable ClusterState resultingState, // the resulting cluster state
Map<T, TaskResult> executionResults // the correspondence between tasks and their outcome
) {

public static <T> Builder<T> builder() {
public static <T extends ClusterStateTaskListener> Builder<T> builder() {
return new Builder<>();
}

public static class Builder<T> {
public static class Builder<T extends ClusterStateTaskListener> {
private final Map<T, TaskResult> executionResults = new IdentityHashMap<>();

public Builder<T> success(T task) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ protected void onTimeout(List<? extends BatchedTask> tasks, TimeValue timeout) {

@Override
protected void run(Object batchingKey, List<? extends BatchedTask> tasks, String tasksSummary) {
ClusterStateTaskExecutor<Object> taskExecutor = (ClusterStateTaskExecutor<Object>) batchingKey;
List<UpdateTask> updateTasks = (List<UpdateTask>) tasks;
runTasks(new TaskInputs(taskExecutor, updateTasks, tasksSummary));
runTasks(
new TaskInputs((ClusterStateTaskExecutor<ClusterStateTaskListener>) batchingKey, (List<UpdateTask>) tasks, tasksSummary)
);
}

class UpdateTask extends BatchedTask {
Expand All @@ -180,10 +180,15 @@ class UpdateTask extends BatchedTask {

@Override
public String describeTasks(List<? extends BatchedTask> tasks) {
return ((ClusterStateTaskExecutor<Object>) batchingKey).describeTasks(
tasks.stream().map(BatchedTask::getTask).collect(Collectors.toList())
return ((ClusterStateTaskExecutor<ClusterStateTaskListener>) batchingKey).describeTasks(
tasks.stream().map(task -> (ClusterStateTaskListener) task.task).toList()
);
}

@Override
public ClusterStateTaskListener getTask() {
return (ClusterStateTaskListener) task;
}
}
}

Expand Down Expand Up @@ -389,7 +394,7 @@ private void handleException(String summary, long startTimeMillis, ClusterState
}

private TaskOutputs calculateTaskOutputs(TaskInputs taskInputs, ClusterState previousClusterState) {
ClusterTasksResult<Object> clusterTasksResult = executeTasks(taskInputs, previousClusterState);
ClusterTasksResult<ClusterStateTaskListener> clusterTasksResult = executeTasks(taskInputs, previousClusterState);
ClusterState newClusterState = patchVersions(previousClusterState, clusterTasksResult);
return new TaskOutputs(
taskInputs,
Expand Down Expand Up @@ -474,14 +479,14 @@ class TaskOutputs {
final ClusterState previousClusterState;
final ClusterState newClusterState;
final List<Batcher.UpdateTask> nonFailedTasks;
final Map<Object, ClusterStateTaskExecutor.TaskResult> executionResults;
final Map<ClusterStateTaskListener, ClusterStateTaskExecutor.TaskResult> executionResults;

TaskOutputs(
TaskInputs taskInputs,
ClusterState previousClusterState,
ClusterState newClusterState,
List<Batcher.UpdateTask> nonFailedTasks,
Map<Object, ClusterStateTaskExecutor.TaskResult> executionResults
Map<ClusterStateTaskListener, ClusterStateTaskExecutor.TaskResult> executionResults
) {
this.taskInputs = taskInputs;
this.previousClusterState = previousClusterState;
Expand Down Expand Up @@ -806,10 +811,10 @@ public void onTimeout() {
}
}

private ClusterTasksResult<Object> executeTasks(TaskInputs taskInputs, ClusterState previousClusterState) {
ClusterTasksResult<Object> clusterTasksResult;
private ClusterTasksResult<ClusterStateTaskListener> executeTasks(TaskInputs taskInputs, ClusterState previousClusterState) {
ClusterTasksResult<ClusterStateTaskListener> clusterTasksResult;
try {
List<Object> inputs = taskInputs.updateTasks.stream().map(tUpdateTask -> tUpdateTask.task).collect(Collectors.toList());
List<ClusterStateTaskListener> inputs = taskInputs.updateTasks.stream().map(Batcher.UpdateTask::getTask).toList();
clusterTasksResult = taskInputs.executor.execute(previousClusterState, inputs);
if (previousClusterState != clusterTasksResult.resultingState()
&& previousClusterState.nodes().isLocalNodeElectedMaster()
Expand All @@ -829,8 +834,8 @@ private ClusterTasksResult<Object> executeTasks(TaskInputs taskInputs, ClusterSt
), // may be expensive => construct message lazily
e
);
clusterTasksResult = ClusterTasksResult.builder()
.failures(taskInputs.updateTasks.stream().map(updateTask -> updateTask.task)::iterator, e)
clusterTasksResult = ClusterTasksResult.<ClusterStateTaskListener>builder()
.failures(taskInputs.updateTasks.stream().map(Batcher.UpdateTask::getTask)::iterator, e)
.build(previousClusterState);
}

Expand All @@ -844,7 +849,7 @@ private ClusterTasksResult<Object> executeTasks(TaskInputs taskInputs, ClusterSt
clusterTasksResult.executionResults().size()
);
if (Assertions.ENABLED) {
ClusterTasksResult<Object> finalClusterTasksResult = clusterTasksResult;
ClusterTasksResult<ClusterStateTaskListener> finalClusterTasksResult = clusterTasksResult;
taskInputs.updateTasks.forEach(
updateTask -> {
assert finalClusterTasksResult.executionResults().containsKey(updateTask.task)
Expand All @@ -856,11 +861,13 @@ private ClusterTasksResult<Object> executeTasks(TaskInputs taskInputs, ClusterSt
return clusterTasksResult;
}

private List<Batcher.UpdateTask> getNonFailedTasks(TaskInputs taskInputs, ClusterTasksResult<Object> clusterTasksResult) {
private List<Batcher.UpdateTask> getNonFailedTasks(
TaskInputs taskInputs,
ClusterTasksResult<ClusterStateTaskListener> clusterTasksResult
) {
return taskInputs.updateTasks.stream().filter(updateTask -> {
assert clusterTasksResult.executionResults().containsKey(updateTask.task) : "missing " + updateTask;
final ClusterStateTaskExecutor.TaskResult taskResult = clusterTasksResult.executionResults().get(updateTask.task);
return taskResult.isSuccess();
assert clusterTasksResult.executionResults().containsKey(updateTask.getTask()) : "missing " + updateTask;
return clusterTasksResult.executionResults().get(updateTask.getTask()).isSuccess();
}).collect(Collectors.toList());
}

Expand All @@ -870,9 +877,9 @@ private List<Batcher.UpdateTask> getNonFailedTasks(TaskInputs taskInputs, Cluste
private class TaskInputs {
final String summary;
final List<Batcher.UpdateTask> updateTasks;
final ClusterStateTaskExecutor<Object> executor;
final ClusterStateTaskExecutor<ClusterStateTaskListener> executor;

TaskInputs(ClusterStateTaskExecutor<Object> executor, List<Batcher.UpdateTask> updateTasks, String summary) {
TaskInputs(ClusterStateTaskExecutor<ClusterStateTaskListener> executor, List<Batcher.UpdateTask> updateTasks, String summary) {
this.summary = summary;
this.executor = executor;
this.updateTasks = updateTasks;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,24 @@

import org.elasticsearch.test.ESTestCase;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.hamcrest.Matchers.equalTo;

public class ClusterStateTaskExecutorTests extends ESTestCase {

private class TestTask {
private class TestTask implements ClusterStateTaskListener {
private final String description;

TestTask(String description) {
this.description = description;
}

@Override
public void onFailure(Exception e) {
throw new AssertionError("Should not fail in test", e);
}

@Override
public String toString() {
return description == null ? "" : "Task{" + description + "}";
Expand All @@ -32,28 +36,18 @@ public String toString() {
public void testDescribeTasks() {
final ClusterStateTaskExecutor<TestTask> executor = (currentState, tasks) -> { throw new AssertionError("should not be called"); };

assertThat("describes an empty list", executor.describeTasks(Collections.emptyList()), equalTo(""));
assertThat(
"describes a singleton list",
executor.describeTasks(Collections.singletonList(new TestTask("a task"))),
equalTo("Task{a task}")
);
assertThat("describes an empty list", executor.describeTasks(List.of()), equalTo(""));
assertThat("describes a singleton list", executor.describeTasks(List.of(new TestTask("a task"))), equalTo("Task{a task}"));
assertThat(
"describes a list of two tasks",
executor.describeTasks(Arrays.asList(new TestTask("a task"), new TestTask("another task"))),
executor.describeTasks(List.of(new TestTask("a task"), new TestTask("another task"))),
equalTo("Task{a task}, Task{another task}")
);

assertThat(
"skips the only item if it has no description",
executor.describeTasks(Collections.singletonList(new TestTask(null))),
equalTo("")
);
assertThat("skips the only item if it has no description", executor.describeTasks(List.of(new TestTask(null))), equalTo(""));
assertThat(
"skips an item if it has no description",
executor.describeTasks(
Arrays.asList(new TestTask("a task"), new TestTask(null), new TestTask("another task"), new TestTask(null))
),
executor.describeTasks(List.of(new TestTask("a task"), new TestTask(null), new TestTask("another task"), new TestTask(null))),
equalTo("Task{a task}, Task{another task}")
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.elasticsearch.cluster.ClusterStateTaskExecutor;
import org.elasticsearch.cluster.ClusterStateTaskExecutor.ClusterTasksResult;
import org.elasticsearch.cluster.ClusterStateTaskExecutor.TaskResult;
import org.elasticsearch.cluster.ClusterStateTaskListener;
import org.elasticsearch.cluster.ClusterStateUpdateTask;
import org.elasticsearch.cluster.EmptyClusterInfoService;
import org.elasticsearch.cluster.action.shard.ShardStateAction;
Expand Down Expand Up @@ -101,9 +102,9 @@
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import static com.carrotsearch.randomizedtesting.RandomizedTest.getRandom;
import static java.util.stream.Collectors.toMap;
import static org.elasticsearch.env.Environment.PATH_HOME_SETTING;
import static org.elasticsearch.test.CheckedFunctionUtils.anyCheckedFunction;
import static org.hamcrest.Matchers.notNullValue;
Expand Down Expand Up @@ -325,7 +326,7 @@ public ClusterState closeIndices(ClusterState state, CloseIndexRequest request)
newState = MetadataIndexStateServiceUtils.closeRoutingTable(
newState,
blockedIndices,
blockedIndices.keySet().stream().collect(Collectors.toMap(Function.identity(), CloseIndexResponse.IndexResult::new))
blockedIndices.keySet().stream().collect(toMap(Function.identity(), CloseIndexResponse.IndexResult::new))
);
return allocationService.reroute(newState, "indices closed");
}
Expand Down Expand Up @@ -358,7 +359,7 @@ public ClusterState addNodes(ClusterState clusterState, List<DiscoveryNode> node
ActionListener.wrap(() -> { throw new AssertionError("should not complete publication"); })
)
)
.collect(Collectors.toList())
.toList()
);
}

Expand All @@ -375,7 +376,7 @@ public ClusterState joinNodesAndBecomeMaster(ClusterState clusterState, List<Dis
ActionListener.wrap(() -> { throw new AssertionError("should not complete publication"); })
)
)
.collect(Collectors.toList())
.toList()
);

return runTasks(joinTaskExecutor, clusterState, joinNodes);
Expand All @@ -385,7 +386,7 @@ public ClusterState removeNodes(ClusterState clusterState, List<DiscoveryNode> n
return runTasks(
nodeRemovalExecutor,
clusterState,
nodes.stream().map(n -> new NodeRemovalClusterStateTaskExecutor.Task(n, "dummy reason", () -> {})).collect(Collectors.toList())
nodes.stream().map(n -> new NodeRemovalClusterStateTaskExecutor.Task(n, "dummy reason", () -> {})).toList()
);
}

Expand All @@ -404,12 +405,12 @@ public ClusterState applyFailedShards(ClusterState clusterState, List<FailedShar
createTestListener()
)
)
.collect(Collectors.toList());
.toList();
return runTasks(shardFailedClusterStateTaskExecutor, clusterState, entries);
}

public ClusterState applyStartedShards(ClusterState clusterState, List<ShardRouting> startedShards) {
final Map<ShardRouting, Long> entries = startedShards.stream().collect(Collectors.toMap(Function.identity(), startedShard -> {
final Map<ShardRouting, Long> entries = startedShards.stream().collect(toMap(Function.identity(), startedShard -> {
final IndexMetadata indexMetadata = clusterState.metadata().index(startedShard.shardId().getIndex());
return indexMetadata != null ? indexMetadata.primaryTerm(startedShard.shardId().id()) : 0L;
}));
Expand All @@ -434,11 +435,15 @@ public ClusterState applyStartedShards(ClusterState clusterState, Map<ShardRouti
createTestListener()
)
)
.collect(Collectors.toList())
.toList()
);
}

private <T> ClusterState runTasks(ClusterStateTaskExecutor<T> executor, ClusterState clusterState, List<T> entries) {
private <T extends ClusterStateTaskListener> ClusterState runTasks(
ClusterStateTaskExecutor<T> executor,
ClusterState clusterState,
List<T> entries
) {
try {
ClusterTasksResult<T> result = executor.execute(clusterState, entries);
for (TaskResult taskResult : result.executionResults().values()) {
Expand Down

0 comments on commit 0431d85

Please sign in to comment.