Skip to content

Commit

Permalink
ClusterStateTaskListener usage refactoring in MasterServiceTests (#82869)

Browse files Browse the repository at this point in the history

Today, node removal tasks executed by the master have a separate
ClusterStateTaskListener to feed the result back to the requester.
It would be preferable to use the task itself as the listener.
  • Loading branch information
idegtiarenko committed Jan 24, 2022
1 parent 0592c4c commit 805cd39
Showing 1 changed file with 108 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
Expand All @@ -20,6 +21,7 @@
import org.elasticsearch.cluster.ClusterStatePublicationEvent;
import org.elasticsearch.cluster.ClusterStateTaskConfig;
import org.elasticsearch.cluster.ClusterStateTaskExecutor;
import org.elasticsearch.cluster.ClusterStateTaskExecutor.ClusterTasksResult;
import org.elasticsearch.cluster.ClusterStateTaskListener;
import org.elasticsearch.cluster.ClusterStateUpdateTask;
import org.elasticsearch.cluster.LocalMasterServiceTask;
Expand Down Expand Up @@ -51,14 +53,10 @@

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Semaphore;
Expand All @@ -67,13 +65,14 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static java.util.stream.Collectors.toMap;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasKey;

public class MasterServiceTests extends ESTestCase {

Expand Down Expand Up @@ -263,9 +262,18 @@ public void testClusterStateTaskListenerThrowingExceptionIsOkay() throws Interru
AtomicBoolean published = new AtomicBoolean();

try (MasterService masterService = createMasterService(true)) {
ClusterStateTaskListener update = new ClusterStateTaskListener() {
@Override
public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
throw new RuntimeException("testing exception handling");
}

@Override
public void onFailure(Exception e) {}
};
masterService.submitStateUpdateTask(
"testClusterStateTaskListenerThrowingExceptionIsOkay",
new Object(),
update,
ClusterStateTaskConfig.build(Priority.NORMAL),
new ClusterStateTaskExecutor<Object>() {
@Override
Expand All @@ -280,15 +288,7 @@ public void clusterStatePublished(ClusterStatePublicationEvent clusterStatePubli
latch.countDown();
}
},
new ClusterStateTaskListener() {
@Override
public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
throw new IllegalStateException();
}

@Override
public void onFailure(Exception e) {}
}
update
);

latch.await();
Expand Down Expand Up @@ -464,23 +464,39 @@ public void onFailure(Exception e) {
}

public void testClusterStateBatchedUpdates() throws BrokenBarrierException, InterruptedException {
AtomicInteger counter = new AtomicInteger();
class Task {
private AtomicBoolean state = new AtomicBoolean();

AtomicInteger executedTasks = new AtomicInteger();
AtomicInteger submittedTasks = new AtomicInteger();
AtomicInteger processedStates = new AtomicInteger();
SetOnce<CountDownLatch> processedStatesLatch = new SetOnce<>();

class Task implements ClusterStateTaskListener {
private final AtomicBoolean executed = new AtomicBoolean();
private final int id;

Task(int id) {
this.id = id;
}

public void execute() {
if (state.compareAndSet(false, true) == false) {
throw new IllegalStateException();
if (executed.compareAndSet(false, true) == false) {
throw new AssertionError("Task [" + id + "] should only be executed once");
} else {
counter.incrementAndGet();
executedTasks.incrementAndGet();
}
}

@Override
public void onFailure(Exception e) {
throw new AssertionError(e);
}

@Override
public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
processedStates.incrementAndGet();
processedStatesLatch.get().countDown();
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand All @@ -491,7 +507,6 @@ public boolean equals(Object o) {
}
Task task = (Task) o;
return id == task.id;

}

@Override
Expand All @@ -505,38 +520,43 @@ public String toString() {
}
}

int numberOfThreads = randomIntBetween(2, 8);
int taskSubmissionsPerThread = randomIntBetween(1, 64);
int numberOfExecutors = Math.max(1, numberOfThreads / 4);
final Semaphore semaphore = new Semaphore(numberOfExecutors);
final int numberOfThreads = randomIntBetween(2, 8);
final int taskSubmissionsPerThread = randomIntBetween(1, 64);
final int numberOfExecutors = Math.max(1, numberOfThreads / 4);
final Semaphore semaphore = new Semaphore(1);

class TaskExecutor implements ClusterStateTaskExecutor<Task> {
private final List<Set<Task>> taskGroups;
private AtomicInteger counter = new AtomicInteger();
private AtomicInteger batches = new AtomicInteger();
private AtomicInteger published = new AtomicInteger();

TaskExecutor(List<Set<Task>> taskGroups) {
this.taskGroups = taskGroups;
}
private final AtomicInteger executed = new AtomicInteger();
private final AtomicInteger assigned = new AtomicInteger();
private final AtomicInteger batches = new AtomicInteger();
private final AtomicInteger published = new AtomicInteger();
private final List<Set<Task>> assignments = new ArrayList<>();

@Override
public ClusterTasksResult<Task> execute(ClusterState currentState, List<Task> tasks) throws Exception {
for (Set<Task> expectedSet : taskGroups) {
long count = tasks.stream().filter(expectedSet::contains).count();
int totalCount = 0;
for (Set<Task> group : assignments) {
long count = tasks.stream().filter(group::contains).count();
assertThat(
"batched set should be executed together or not at all. Expected " + expectedSet + "s. Executing " + tasks,
"batched set should be executed together or not at all. Expected " + group + "s. Executing " + tasks,
count,
anyOf(equalTo(0L), equalTo((long) expectedSet.size()))
anyOf(equalTo(0L), equalTo((long) group.size()))
);
totalCount += count;
}
assertThat("All tasks should belong to this executor", totalCount, equalTo(tasks.size()));
tasks.forEach(Task::execute);
counter.addAndGet(tasks.size());
executed.addAndGet(tasks.size());
ClusterState maybeUpdatedClusterState = currentState;
if (randomBoolean()) {
maybeUpdatedClusterState = ClusterState.builder(currentState).build();
batches.incrementAndGet();
semaphore.acquire();
assertThat(
"All cluster state modifications should be executed on a single thread",
semaphore.tryAcquire(),
equalTo(true)
);
}
return ClusterTasksResult.<Task>builder().successes(tasks).build(maybeUpdatedClusterState);
}
Expand All @@ -548,40 +568,27 @@ public void clusterStatePublished(ClusterStatePublicationEvent clusterPublicatio
}
}

ConcurrentMap<String, AtomicInteger> processedStates = new ConcurrentHashMap<>();

List<Set<Task>> taskGroups = new ArrayList<>();
List<TaskExecutor> executors = new ArrayList<>();
for (int i = 0; i < numberOfExecutors; i++) {
executors.add(new TaskExecutor(taskGroups));
executors.add(new TaskExecutor());
}

// randomly assign tasks to executors
List<Tuple<TaskExecutor, Set<Task>>> assignments = new ArrayList<>();
int taskId = 0;
AtomicInteger totalTasks = new AtomicInteger();
for (int i = 0; i < numberOfThreads; i++) {
for (int j = 0; j < taskSubmissionsPerThread; j++) {
TaskExecutor executor = randomFrom(executors);
Set<Task> tasks = new HashSet<>();
for (int t = randomInt(3); t >= 0; t--) {
tasks.add(new Task(taskId++));
}
taskGroups.add(tasks);
var executor = randomFrom(executors);
var tasks = Set.copyOf(randomList(1, 3, () -> new Task(totalTasks.getAndIncrement())));

assignments.add(Tuple.tuple(executor, tasks));
executor.assigned.addAndGet(tasks.size());
executor.assignments.add(tasks);
}
}

Map<TaskExecutor, Integer> counts = new HashMap<>();
int totalTaskCount = 0;
for (Tuple<TaskExecutor, Set<Task>> assignment : assignments) {
final int taskCount = assignment.v2().size();
counts.merge(assignment.v1(), taskCount, (previous, count) -> previous + count);
totalTaskCount += taskCount;
}
final CountDownLatch updateLatch = new CountDownLatch(totalTaskCount);
processedStatesLatch.set(new CountDownLatch(totalTasks.get()));

try (MasterService masterService = createMasterService(true)) {
final ConcurrentMap<String, AtomicInteger> submittedTasksPerThread = new ConcurrentHashMap<>();
CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
for (int i = 0; i < numberOfThreads; i++) {
final int index = i;
Expand All @@ -590,36 +597,23 @@ public void clusterStatePublished(ClusterStatePublicationEvent clusterPublicatio
try {
barrier.await();
for (int j = 0; j < taskSubmissionsPerThread; j++) {
Tuple<TaskExecutor, Set<Task>> assignment = assignments.get(index * taskSubmissionsPerThread + j);
final Set<Task> tasks = assignment.v2();
submittedTasksPerThread.computeIfAbsent(threadName, key -> new AtomicInteger()).addAndGet(tasks.size());
final TaskExecutor executor = assignment.v1();
final ClusterStateTaskListener listener = new ClusterStateTaskListener() {
@Override
public void onFailure(Exception e) {
throw new AssertionError(e);
}

@Override
public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
processedStates.computeIfAbsent(threadName, key -> new AtomicInteger()).incrementAndGet();
updateLatch.countDown();
}
};
var assignment = assignments.get(index * taskSubmissionsPerThread + j);
var tasks = assignment.v2();
var executor = assignment.v1();
submittedTasks.addAndGet(tasks.size());
if (tasks.size() == 1) {
var update = tasks.iterator().next();
masterService.submitStateUpdateTask(
threadName,
tasks.stream().findFirst().get(),
update,
ClusterStateTaskConfig.build(randomFrom(Priority.values())),
executor,
listener
update
);
} else {
Map<Task, ClusterStateTaskListener> taskListeners = new HashMap<>();
tasks.forEach(t -> taskListeners.put(t, listener));
masterService.submitStateUpdateTasks(
threadName,
taskListeners,
tasks.stream().collect(toMap(Function.<Task>identity(), Function.<ClusterStateTaskListener>identity())),
ClusterStateTaskConfig.build(randomFrom(Priority.values())),
executor
);
Expand All @@ -639,29 +633,19 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState)
barrier.await();

// wait until all the cluster state updates have been processed
updateLatch.await();
// and until all of the publication callbacks have completed
semaphore.acquire(numberOfExecutors);
processedStatesLatch.get().await();
// and until all the publication callbacks have completed
semaphore.acquire();

// assert the number of executed tasks is correct
assertEquals(totalTaskCount, counter.get());
assertThat(submittedTasks.get(), equalTo(totalTasks.get()));
assertThat(executedTasks.get(), equalTo(totalTasks.get()));
assertThat(processedStates.get(), equalTo(totalTasks.get()));

// assert each executor executed the correct number of tasks
for (TaskExecutor executor : executors) {
if (counts.containsKey(executor)) {
assertEquals((int) counts.get(executor), executor.counter.get());
assertEquals(executor.batches.get(), executor.published.get());
}
}

// assert the correct number of clusterStateProcessed events were triggered
for (Map.Entry<String, AtomicInteger> entry : processedStates.entrySet()) {
assertThat(submittedTasksPerThread, hasKey(entry.getKey()));
assertEquals(
"not all tasks submitted by " + entry.getKey() + " received a processed event",
entry.getValue().get(),
submittedTasksPerThread.get(entry.getKey()).get()
);
assertEquals(executor.assigned.get(), executor.executed.get());
assertEquals(executor.batches.get(), executor.published.get());
}
}
}
Expand All @@ -672,36 +656,37 @@ public void testBlockingCallInClusterStateTaskListenerFails() throws Interrupted
final AtomicReference<AssertionError> assertionRef = new AtomicReference<>();

try (MasterService masterService = createMasterService(true)) {
ClusterStateTaskListener update = new ClusterStateTaskListener() {
@Override
public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
BaseFuture<Void> future = new BaseFuture<Void>() {
};
try {
if (randomBoolean()) {
future.get(1L, TimeUnit.SECONDS);
} else {
future.get();
}
} catch (Exception e) {
throw new RuntimeException(e);
} catch (AssertionError e) {
assertionRef.set(e);
latch.countDown();
}
}

@Override
public void onFailure(Exception e) {}
};
masterService.submitStateUpdateTask(
"testBlockingCallInClusterStateTaskListenerFails",
new Object(),
update,
ClusterStateTaskConfig.build(Priority.NORMAL),
(currentState, tasks) -> {
ClusterState newClusterState = ClusterState.builder(currentState).build();
return ClusterStateTaskExecutor.ClusterTasksResult.builder().successes(tasks).build(newClusterState);
return ClusterTasksResult.<ClusterStateTaskListener>builder().successes(tasks).build(newClusterState);
},
new ClusterStateTaskListener() {
@Override
public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
BaseFuture<Void> future = new BaseFuture<Void>() {
};
try {
if (randomBoolean()) {
future.get(1L, TimeUnit.SECONDS);
} else {
future.get();
}
} catch (Exception e) {
throw new RuntimeException(e);
} catch (AssertionError e) {
assertionRef.set(e);
latch.countDown();
}
}

@Override
public void onFailure(Exception e) {}
}
update
);

latch.await();
Expand Down

0 comments on commit 805cd39

Please sign in to comment.