Skip to content

Commit

Permalink
Limit max concurrency of test cluster nodes to a function of max work…
Browse files Browse the repository at this point in the history
…ers (#51338)
  • Loading branch information
mark-vieira committed Jan 29, 2020
1 parent 25e8732 commit 91f8d9d
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 65 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
package org.elasticsearch.gradle.testclusters;

import org.elasticsearch.gradle.tool.Boilerplate;
import org.gradle.api.provider.Provider;
import org.gradle.api.services.internal.BuildServiceRegistryInternal;
import org.gradle.api.tasks.CacheableTask;
import org.gradle.api.tasks.Internal;
import org.gradle.api.tasks.Nested;
import org.gradle.api.tasks.testing.Test;
import org.gradle.internal.resources.ResourceLock;
import org.gradle.internal.resources.SharedResource;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;

import static org.elasticsearch.gradle.testclusters.TestClustersPlugin.THROTTLE_SERVICE_NAME;

/**
* Customized version of Gradle {@link Test} task which tracks a collection of {@link ElasticsearchCluster} as a task input. We must do this
Expand Down Expand Up @@ -47,4 +57,19 @@ public Collection<ElasticsearchCluster> getClusters() {
return clusters;
}

/**
 * Returns the shared resource locks this task must hold while it runs: everything the parent task already
 * requires plus, when the task uses cluster nodes, a lock on the test clusters throttle service.
 *
 * The throttle lock asks for one usage per node across all clusters of this task, capped at the maximum the
 * throttle allows so that a task needing more nodes than the limit can still be scheduled.
 *
 * @return an unmodifiable list of the resource locks required by this task
 */
@Override
@Internal
public List<ResourceLock> getSharedResources() {
    List<ResourceLock> locks = new ArrayList<>(super.getSharedResources());
    BuildServiceRegistryInternal serviceRegistry = getServices().get(BuildServiceRegistryInternal.class);
    Provider<TestClustersThrottle> throttleProvider = Boilerplate.getBuildService(serviceRegistry, THROTTLE_SERVICE_NAME);
    SharedResource resource = serviceRegistry.forService(throttleProvider);

    int nodeCount = clusters.stream().mapToInt(cluster -> cluster.getNodes().size()).sum();
    int maxUsages = resource.getMaxUsages();

    // The throttle limit is registered as max-workers / 2, which rounds down to 0 when a single worker is used.
    // A non-positive limit leaves nothing meaningful to reserve, so only request a lock when there is capacity
    // to throttle against instead of asking for a lock with zero (or negative) usages.
    if (nodeCount > 0 && maxUsages > 0) {
        locks.add(resource.getResourceLock(Math.min(nodeCount, maxUsages)));
    }

    return Collections.unmodifiableList(locks);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.gradle.DistributionDownloadPlugin;
import org.elasticsearch.gradle.ReaperPlugin;
import org.elasticsearch.gradle.ReaperService;
import org.elasticsearch.gradle.tool.Boilerplate;
import org.gradle.api.NamedDomainObjectContainer;
import org.gradle.api.Plugin;
import org.gradle.api.Project;
Expand All @@ -30,53 +31,50 @@
import org.gradle.api.invocation.Gradle;
import org.gradle.api.logging.Logger;
import org.gradle.api.logging.Logging;
import org.gradle.api.provider.Provider;
import org.gradle.api.tasks.TaskState;

import java.io.File;

public class TestClustersPlugin implements Plugin<Project> {

private static final String LIST_TASK_NAME = "listTestClusters";
public static final String EXTENSION_NAME = "testClusters";
private static final String REGISTRY_EXTENSION_NAME = "testClustersRegistry";
public static final String THROTTLE_SERVICE_NAME = "testClustersThrottle";

private static final String LIST_TASK_NAME = "listTestClusters";
private static final String REGISTRY_SERVICE_NAME = "testClustersRegistry";
private static final Logger logger = Logging.getLogger(TestClustersPlugin.class);

private ReaperService reaper;

@Override
public void apply(Project project) {
project.getPlugins().apply(DistributionDownloadPlugin.class);

project.getRootProject().getPluginManager().apply(ReaperPlugin.class);
reaper = project.getRootProject().getExtensions().getByType(ReaperService.class);

ReaperService reaper = project.getRootProject().getExtensions().getByType(ReaperService.class);

// enable the DSL to describe clusters
NamedDomainObjectContainer<ElasticsearchCluster> container = createTestClustersContainerExtension(project);
NamedDomainObjectContainer<ElasticsearchCluster> container = createTestClustersContainerExtension(project, reaper);

// provide a task to be able to list defined clusters.
createListClustersTask(project, container);

if (project.getRootProject().getExtensions().findByName(REGISTRY_EXTENSION_NAME) == null) {
TestClustersRegistry registry = project.getRootProject()
.getExtensions()
.create(REGISTRY_EXTENSION_NAME, TestClustersRegistry.class);

// When we know what tasks will run, we claim the clusters of those task to differentiate between clusters
// that are defined in the build script and the ones that will actually be used in this invocation of gradle
// we use this information to determine when the last task that required the cluster executed so that we can
// terminate the cluster right away and free up resources.
configureClaimClustersHook(project.getGradle(), registry);
// register cluster registry as a global build service
project.getGradle().getSharedServices().registerIfAbsent(REGISTRY_SERVICE_NAME, TestClustersRegistry.class, spec -> {});

// Before each task, we determine if a cluster needs to be started for that task.
configureStartClustersHook(project.getGradle(), registry);
// register throttle so we only run at most max-workers/2 nodes concurrently
project.getGradle()
.getSharedServices()
.registerIfAbsent(
THROTTLE_SERVICE_NAME,
TestClustersThrottle.class,
spec -> spec.getMaxParallelUsages().set(project.getGradle().getStartParameter().getMaxWorkerCount() / 2)
);

// After each task we determine if there are clusters that are no longer needed.
configureStopClustersHook(project.getGradle(), registry);
}
// register cluster hooks
project.getRootProject().getPluginManager().apply(TestClustersHookPlugin.class);
}

private NamedDomainObjectContainer<ElasticsearchCluster> createTestClustersContainerExtension(Project project) {
private NamedDomainObjectContainer<ElasticsearchCluster> createTestClustersContainerExtension(Project project, ReaperService reaper) {
// Create an extension that allows describing clusters
NamedDomainObjectContainer<ElasticsearchCluster> container = project.container(
ElasticsearchCluster.class,
Expand All @@ -95,52 +93,78 @@ private void createListClustersTask(Project project, NamedDomainObjectContainer<
);
}

/**
 * Registers a task-graph callback that claims every cluster used by a task scheduled to run.
 *
 * Claims are counted by the registry so that it can later tell when the last task using a cluster has finished
 * and the cluster is safe to stop.
 */
private static void configureClaimClustersHook(Gradle gradle, TestClustersRegistry registry) {
    gradle.getTaskGraph().whenReady(graph -> {
        for (Task scheduled : graph.getAllTasks()) {
            // only tasks that declare clusters take part in the claim counting
            if (scheduled instanceof TestClustersAware) {
                ((TestClustersAware) scheduled).getClusters().forEach(registry::claimCluster);
            }
        }
    });
}
static class TestClustersHookPlugin implements Plugin<Project> {
/**
 * Installs the cluster lifecycle hooks (claim, start, stop) on the build, backed by the shared
 * {@link TestClustersRegistry} build service.
 *
 * @throws IllegalStateException if applied to anything other than the root project
 */
@Override
public void apply(Project project) {
    final boolean appliedToRoot = project == project.getRootProject();
    if (appliedToRoot == false) {
        throw new IllegalStateException(this.getClass().getName() + " can only be applied to the root project.");
    }

    final Gradle gradle = project.getGradle();
    final TestClustersRegistry registry = Boilerplate.<TestClustersRegistry>getBuildService(
        gradle.getSharedServices(),
        REGISTRY_SERVICE_NAME
    ).get();

    // Once the set of tasks to run is known, claim the clusters of those tasks. This separates clusters that are
    // merely defined in the build script from the ones actually used in this invocation, and lets us work out when
    // the last task needing a cluster has executed so the cluster can be terminated right away to free resources.
    configureClaimClustersHook(gradle, registry);

    // Before each task runs, decide whether a cluster has to be started for it.
    configureStartClustersHook(gradle, registry);

    // After each task runs, check for clusters that are no longer needed.
    configureStopClustersHook(gradle, registry);
}

private static void configureStartClustersHook(Gradle gradle, TestClustersRegistry registry) {
gradle.addListener(new TaskActionListener() {
@Override
public void beforeActions(Task task) {
if (task instanceof TestClustersAware == false) {
return;
/**
 * Registers a task-graph callback that claims every cluster used by a task scheduled to run.
 *
 * Claims are counted by the registry so that it can later tell when the last task using a cluster has finished
 * and the cluster is safe to stop.
 */
private static void configureClaimClustersHook(Gradle gradle, TestClustersRegistry registry) {
    gradle.getTaskGraph().whenReady(graph -> {
        for (Task scheduled : graph.getAllTasks()) {
            // only tasks that declare clusters take part in the claim counting
            if (scheduled instanceof TestClustersAware) {
                ((TestClustersAware) scheduled).getClusters().forEach(registry::claimCluster);
            }
        }
    });
}

private static void configureStartClustersHook(Gradle gradle, TestClustersRegistry registry) {
gradle.addListener(new TaskActionListener() {
@Override
public void beforeActions(Task task) {
if (task instanceof TestClustersAware == false) {
return;
}
// we only start the cluster before the actions, so we'll not start it if the task is up-to-date
TestClustersAware awareTask = (TestClustersAware) task;
awareTask.beforeStart();
awareTask.getClusters().forEach(registry::maybeStartCluster);
}
// we only start the cluster before the actions, so we'll not start it if the task is up-to-date
TestClustersAware awareTask = (TestClustersAware) task;
awareTask.beforeStart();
awareTask.getClusters().forEach(registry::maybeStartCluster);
}

@Override
public void afterActions(Task task) {}
});
}
@Override
public void afterActions(Task task) {}
});
}

private static void configureStopClustersHook(Gradle gradle, TestClustersRegistry registry) {
gradle.addListener(new TaskExecutionListener() {
@Override
public void afterExecute(Task task, TaskState state) {
if (task instanceof TestClustersAware == false) {
return;
private static void configureStopClustersHook(Gradle gradle, TestClustersRegistry registry) {
gradle.addListener(new TaskExecutionListener() {
@Override
public void afterExecute(Task task, TaskState state) {
if (task instanceof TestClustersAware == false) {
return;
}
// always unclaim the cluster, even if _this_ task is up-to-date, as others might not have been
// and caused the cluster to start.
((TestClustersAware) task).getClusters().forEach(cluster -> registry.stopCluster(cluster, state.getFailure() != null));
}
// always unclaim the cluster, even if _this_ task is up-to-date, as others might not have been
// and caused the cluster to start.
((TestClustersAware) task).getClusters().forEach(cluster -> registry.stopCluster(cluster, state.getFailure() != null));
}

@Override
public void beforeExecute(Task task) {}
});
@Override
public void beforeExecute(Task task) {}
});
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import org.gradle.api.logging.Logger;
import org.gradle.api.logging.Logging;
import org.gradle.api.services.BuildService;
import org.gradle.api.services.BuildServiceParameters;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class TestClustersRegistry {
public abstract class TestClustersRegistry implements BuildService<BuildServiceParameters.None> {
private static final Logger logger = Logging.getLogger(TestClustersRegistry.class);
private static final String TESTCLUSTERS_INSPECT_FAILURE = "testclusters.inspect.failure";
private final Boolean allowClusterToSurvive = Boolean.valueOf(System.getProperty(TESTCLUSTERS_INSPECT_FAILURE, "false"));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package org.elasticsearch.gradle.testclusters;

import org.gradle.api.services.BuildService;
import org.gradle.api.services.BuildServiceParameters;

/**
 * Parameterless build service with no behavior of its own. It is registered under
 * {@code TestClustersPlugin.THROTTLE_SERVICE_NAME} with a maximum number of parallel usages, and tasks request
 * resource locks against it to limit how many test cluster nodes run concurrently.
 */
public abstract class TestClustersThrottle implements BuildService<BuildServiceParameters.None> {
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@
package org.elasticsearch.gradle.tool;

import org.gradle.api.Action;
import org.gradle.api.GradleException;
import org.gradle.api.NamedDomainObjectContainer;
import org.gradle.api.PolymorphicDomainObjectContainer;
import org.gradle.api.Project;
import org.gradle.api.Task;
import org.gradle.api.UnknownTaskException;
import org.gradle.api.plugins.JavaPluginConvention;
import org.gradle.api.provider.Provider;
import org.gradle.api.services.BuildService;
import org.gradle.api.services.BuildServiceRegistration;
import org.gradle.api.services.BuildServiceRegistry;
import org.gradle.api.tasks.SourceSetContainer;
import org.gradle.api.tasks.TaskContainer;
import org.gradle.api.tasks.TaskProvider;
Expand Down Expand Up @@ -102,4 +107,14 @@ public static TaskProvider<?> findByName(TaskContainer tasks, String name) {

return task;
}

/**
 * Looks up a previously registered build service by name.
 *
 * @param registry the build service registry to search
 * @param name the name the service was registered under
 * @param <T> the service type expected by the caller; the cast is unchecked, so the caller is responsible for
 *            asking for the type the service was actually registered with
 * @return a provider for the registered service
 * @throws GradleException if no service is registered under {@code name}
 */
@SuppressWarnings("unchecked")
public static <T extends BuildService<?>> Provider<T> getBuildService(BuildServiceRegistry registry, String name) {
    final BuildServiceRegistration<?, ?> found = registry.getRegistrations().findByName(name);
    if (found != null) {
        return (Provider<T>) found.getService();
    }

    throw new GradleException("Unable to find build service with name '" + name + "'.");
}
}

0 comments on commit 91f8d9d

Please sign in to comment.