Skip to content

Commit

Permalink
Allow running an extra spawn for local branch of dynamic execution.
Browse files Browse the repository at this point in the history
Allow `DynamicExecutionModule` to specify an extra spawn to be ran in the local
branch. Add support in `DynamicSpawnStrategy` for running the extra spawn when
it is provided.

PiperOrigin-RevId: 344826682
  • Loading branch information
alexjski authored and Copybara-Service committed Nov 30, 2020
1 parent d155376 commit f395157
Show file tree
Hide file tree
Showing 4 changed files with 327 additions and 3 deletions.
Expand Up @@ -38,6 +38,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

Expand Down Expand Up @@ -140,7 +141,12 @@ final void registerSpawnStrategies(
if (options.legacySpawnScheduler) {
strategy = new LegacyDynamicSpawnStrategy(executorService, options, this::getExecutionPolicy);
} else {
strategy = new DynamicSpawnStrategy(executorService, options, this::getExecutionPolicy);
strategy =
new DynamicSpawnStrategy(
executorService,
options,
this::getExecutionPolicy,
this::getPostProcessingSpawnForLocalExecution);
}
registryBuilder.registerStrategy(strategy, "dynamic", "dynamic_worker");

Expand Down Expand Up @@ -183,6 +189,18 @@ protected ExecutionPolicy getExecutionPolicy(Spawn spawn) {
return ExecutionPolicy.ANYWHERE;
}

/**
* Returns a post processing {@link Spawn} if one needs to be executed after given {@link Spawn}
* when running locally.
*
* <p>The intention of this is to allow post-processing of the original {@linkplain Spawn spawn}
* when executing it locally. In particular, such spawn should never create outputs which are not
* included in the generating action of the original one.
*/
protected Optional<Spawn> getPostProcessingSpawnForLocalExecution(Spawn spawn) {
return Optional.empty();
}

@Override
public void afterCommand() {
ExecutorUtil.interruptibleShutdown(executorService);
Expand Down
Expand Up @@ -33,6 +33,7 @@
import com.google.devtools.build.lib.actions.SandboxedSpawnStrategy;
import com.google.devtools.build.lib.actions.Spawn;
import com.google.devtools.build.lib.actions.SpawnResult;
import com.google.devtools.build.lib.actions.SpawnResult.Status;
import com.google.devtools.build.lib.actions.SpawnStrategy;
import com.google.devtools.build.lib.events.Event;
import com.google.devtools.build.lib.exec.ExecutionPolicy;
Expand All @@ -42,6 +43,7 @@
import com.google.devtools.build.lib.util.io.FileOutErr;
import com.google.devtools.build.lib.vfs.Path;
import java.io.IOException;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -88,6 +90,8 @@ public class DynamicSpawnStrategy implements SpawnStrategy {
*/
private final AtomicBoolean delayLocalExecution = new AtomicBoolean(false);

private final Function<Spawn, Optional<Spawn>> getExtraSpawnForLocalExecution;

/**
* Constructs a {@code DynamicSpawnStrategy}.
*
Expand All @@ -96,10 +100,12 @@ public class DynamicSpawnStrategy implements SpawnStrategy {
public DynamicSpawnStrategy(
ExecutorService executorService,
DynamicExecutionOptions options,
Function<Spawn, ExecutionPolicy> getExecutionPolicy) {
Function<Spawn, ExecutionPolicy> getExecutionPolicy,
Function<Spawn, Optional<Spawn>> getPostProcessingSpawnForLocalExecution) {
this.executorService = MoreExecutors.listeningDecorator(executorService);
this.options = options;
this.getExecutionPolicy = getExecutionPolicy;
this.getExtraSpawnForLocalExecution = getPostProcessingSpawnForLocalExecution;
}

/**
Expand Down Expand Up @@ -463,7 +469,34 @@ private static FileOutErr getSuffixedFileOutErr(FileOutErr fileOutErr, String su
outDir.getChild(outBaseName + suffix), errDir.getChild(errBaseName + suffix));
}

private static ImmutableList<SpawnResult> runLocally(
private ImmutableList<SpawnResult> runLocally(
Spawn spawn,
ActionExecutionContext actionExecutionContext,
@Nullable SandboxedSpawnStrategy.StopConcurrentSpawns stopConcurrentSpawns)
throws ExecException, InterruptedException {
ImmutableList<SpawnResult> spawnResult =
runSpawnLocally(spawn, actionExecutionContext, stopConcurrentSpawns);
if (spawnResult.stream().anyMatch(result -> result.status() != Status.SUCCESS)) {
return spawnResult;
}

Optional<Spawn> extraSpawn = getExtraSpawnForLocalExecution.apply(spawn);
if (!extraSpawn.isPresent()) {
return spawnResult;
}

// The remote branch was already cancelled -- we are holding the output lock during the
// execution of the extra spawn.
ImmutableList<SpawnResult> extraSpawnResult =
runSpawnLocally(extraSpawn.get(), actionExecutionContext, null);
return ImmutableList.<SpawnResult>builderWithExpectedSize(
spawnResult.size() + extraSpawnResult.size())
.addAll(spawnResult)
.addAll(extraSpawnResult)
.build();
}

private static ImmutableList<SpawnResult> runSpawnLocally(
Spawn spawn,
ActionExecutionContext actionExecutionContext,
@Nullable SandboxedSpawnStrategy.StopConcurrentSpawns stopConcurrentSpawns)
Expand Down
19 changes: 19 additions & 0 deletions src/test/java/com/google/devtools/build/lib/dynamic/BUILD
Expand Up @@ -12,6 +12,25 @@ filegroup(
visibility = ["//src:__subpackages__"],
)

java_test(
name = "DynamicSpawnStrategyUnitTest",
size = "small",
srcs = ["DynamicSpawnStrategyUnitTest.java"],
deps = [
"//src/main/java/com/google/devtools/build/lib/actions",
"//src/main/java/com/google/devtools/build/lib/dynamic",
"//src/main/java/com/google/devtools/build/lib/exec:execution_policy",
"//src/main/protobuf:failure_details_java_proto",
"//src/test/java/com/google/devtools/build/lib/exec/util",
"//src/test/java/com/google/devtools/build/lib/testutil",
"//src/test/java/com/google/devtools/build/lib/testutil:TestUtils",
"//third_party:guava",
"//third_party:junit4",
"//third_party:mockito",
"//third_party:truth",
],
)

java_test(
name = "DynamicSpawnStrategyTest",
size = "small",
Expand Down
@@ -0,0 +1,254 @@
// Copyright 2020 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.devtools.build.lib.dynamic;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.truth.Truth.assertThat;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNotNull;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableList;
import com.google.devtools.build.lib.actions.ActionExecutionContext;
import com.google.devtools.build.lib.actions.DynamicStrategyRegistry;
import com.google.devtools.build.lib.actions.SandboxedSpawnStrategy;
import com.google.devtools.build.lib.actions.SandboxedSpawnStrategy.StopConcurrentSpawns;
import com.google.devtools.build.lib.actions.Spawn;
import com.google.devtools.build.lib.actions.SpawnResult;
import com.google.devtools.build.lib.actions.SpawnResult.Status;
import com.google.devtools.build.lib.exec.ExecutionPolicy;
import com.google.devtools.build.lib.exec.util.SpawnBuilder;
import com.google.devtools.build.lib.server.FailureDetails.Execution;
import com.google.devtools.build.lib.server.FailureDetails.FailureDetail;
import com.google.devtools.build.lib.testutil.TestFileOutErr;
import com.google.devtools.build.lib.testutil.TestUtils;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;
import java.util.function.Function;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

/** Unit tests for {@link DynamicSpawnStrategy}. */
@RunWith(JUnit4.class)
public class DynamicSpawnStrategyUnitTest {

private static final SpawnResult SUCCESSFUL_SPAWN_RESULT =
new SpawnResult.Builder().setRunnerName("test").setStatus(Status.SUCCESS).build();
private static final FailureDetail FAILURE_DETAIL =
FailureDetail.newBuilder().setExecution(Execution.getDefaultInstance()).build();

private ExecutorService executorServiceForCleanup;

@Mock private Function<Spawn, Optional<Spawn>> mockGetPostProcessingSpawn;

@Before
public void initMocks() {
MockitoAnnotations.initMocks(this);
}

@After
public void stopExecutorService() throws InterruptedException {
executorServiceForCleanup.shutdown();
assertThat(
executorServiceForCleanup.awaitTermination(
TestUtils.WAIT_TIMEOUT_MILLISECONDS, MILLISECONDS))
.isTrue();
}

@Test
public void exec_remoteOnlySpawn_doesNotGetLocalPostProcessingSpawn() throws Exception {
DynamicSpawnStrategy dynamicSpawnStrategy =
createDynamicSpawnStrategy(
ExecutionPolicy.REMOTE_EXECUTION_ONLY, mockGetPostProcessingSpawn);
SandboxedSpawnStrategy local = createMockSpawnStrategy();
SandboxedSpawnStrategy remote = createMockSpawnStrategy();
ArgumentCaptor<Spawn> remoteSpawnCaptor = ArgumentCaptor.forClass(Spawn.class);
when(remote.exec(remoteSpawnCaptor.capture(), any(), any()))
.thenReturn(ImmutableList.of(SUCCESSFUL_SPAWN_RESULT));
ActionExecutionContext actionExecutionContext = createMockActionExecutionContext(local, remote);
Spawn spawn = new SpawnBuilder().build();

ImmutableList<SpawnResult> results = dynamicSpawnStrategy.exec(spawn, actionExecutionContext);

assertThat(results).containsExactly(SUCCESSFUL_SPAWN_RESULT);
verify(mockGetPostProcessingSpawn, never()).apply(any());
verify(local, never()).exec(any(), any(), any());
assertThat(remoteSpawnCaptor.getAllValues()).containsExactly(spawn);
}

@Test
public void exec_localOnlySpawn_runsLocalPostProcessingSpawn() throws Exception {
Spawn spawn = new SpawnBuilder("command").build();
Spawn postProcessingSpawn = new SpawnBuilder("extra_command").build();
DynamicSpawnStrategy dynamicSpawnStrategy =
createDynamicSpawnStrategy(
ExecutionPolicy.LOCAL_EXECUTION_ONLY, ignored -> Optional.of(postProcessingSpawn));
SandboxedSpawnStrategy local = createMockSpawnStrategy();
ArgumentCaptor<Spawn> localSpawnCaptor = ArgumentCaptor.forClass(Spawn.class);
when(local.exec(localSpawnCaptor.capture(), any(), any()))
.thenReturn(ImmutableList.of(SUCCESSFUL_SPAWN_RESULT));
SandboxedSpawnStrategy remote = createMockSpawnStrategy();
ActionExecutionContext actionExecutionContext = createMockActionExecutionContext(local, remote);

ImmutableList<SpawnResult> results = dynamicSpawnStrategy.exec(spawn, actionExecutionContext);

assertThat(results).containsExactly(SUCCESSFUL_SPAWN_RESULT, SUCCESSFUL_SPAWN_RESULT);
verifyZeroInteractions(remote);
assertThat(localSpawnCaptor.getAllValues())
.containsExactly(spawn, postProcessingSpawn)
.inOrder();
}

@Test
public void exec_failedLocalSpawn_doesNotGetLocalPostProcessingSpawn() throws Exception {
testExecFailedLocalSpawnDoesNotGetLocalPostProcessingSpawn(
new SpawnResult.Builder()
.setRunnerName("test")
.setStatus(Status.TIMEOUT)
.setExitCode(SpawnResult.POSIX_TIMEOUT_EXIT_CODE)
.setFailureDetail(FAILURE_DETAIL)
.build());
}

@Test
public void exec_nonZeroExitCodeLocalSpawn_doesNotGetLocalPostProcessingSpawn() throws Exception {
testExecFailedLocalSpawnDoesNotGetLocalPostProcessingSpawn(
new SpawnResult.Builder()
.setRunnerName("test")
.setStatus(Status.EXECUTION_FAILED)
.setExitCode(123)
.setFailureDetail(FAILURE_DETAIL)
.build());
}

private void testExecFailedLocalSpawnDoesNotGetLocalPostProcessingSpawn(SpawnResult failedResult)
throws Exception {
DynamicSpawnStrategy dynamicSpawnStrategy =
createDynamicSpawnStrategy(
ExecutionPolicy.LOCAL_EXECUTION_ONLY, mockGetPostProcessingSpawn);
SandboxedSpawnStrategy local = createMockSpawnStrategy();
ArgumentCaptor<Spawn> localSpawnCaptor = ArgumentCaptor.forClass(Spawn.class);
when(local.exec(localSpawnCaptor.capture(), any(), any()))
.thenReturn(ImmutableList.of(failedResult));
SandboxedSpawnStrategy remote = createMockSpawnStrategy();
ActionExecutionContext actionExecutionContext = createMockActionExecutionContext(local, remote);
Spawn spawn = new SpawnBuilder().build();

ImmutableList<SpawnResult> results = dynamicSpawnStrategy.exec(spawn, actionExecutionContext);

assertThat(results).containsExactly(failedResult);
assertThat(localSpawnCaptor.getAllValues()).containsExactly(spawn);
verify(remote, never()).exec(any(), any(), any());
verify(mockGetPostProcessingSpawn, never()).apply(any());
}

@Test
public void exec_runAnywhereSpawn_runsLocalPostProcessingSpawn() throws Exception {
Spawn spawn = new SpawnBuilder().build();
Spawn postProcessingSpawn = new SpawnBuilder("extra_command").build();
DynamicSpawnStrategy dynamicSpawnStrategy =
createDynamicSpawnStrategy(
ExecutionPolicy.ANYWHERE, ignored -> Optional.of(postProcessingSpawn));
SandboxedSpawnStrategy local = createMockSpawnStrategy();
// Make sure that local execution does not win the race before remote starts.
Semaphore remoteStarted = new Semaphore(0);
// Only the first spawn should be able to stop the concurrent remote execution (get the output
// lock).
when(local.exec(eq(spawn), any(), /*stopConcurrentSpawns=*/ isNotNull()))
.thenAnswer(
invocation -> {
remoteStarted.acquire();
StopConcurrentSpawns stopConcurrentSpawns = invocation.getArgument(2);
stopConcurrentSpawns.stop();
return ImmutableList.of(SUCCESSFUL_SPAWN_RESULT);
});
when(local.exec(eq(postProcessingSpawn), any(), /*stopConcurrentSpawns=*/ isNull()))
.thenReturn(ImmutableList.of(SUCCESSFUL_SPAWN_RESULT));
SandboxedSpawnStrategy remote = createMockSpawnStrategy();
when(remote.exec(eq(spawn), any(), any()))
.thenAnswer(
invocation -> {
remoteStarted.release();
Thread.sleep(TestUtils.WAIT_TIMEOUT_MILLISECONDS);
throw new AssertionError("Timed out waiting for interruption");
});
ActionExecutionContext actionExecutionContext = createMockActionExecutionContext(local, remote);

ImmutableList<SpawnResult> results = dynamicSpawnStrategy.exec(spawn, actionExecutionContext);

assertThat(results).containsExactly(SUCCESSFUL_SPAWN_RESULT, SUCCESSFUL_SPAWN_RESULT);
}

private DynamicSpawnStrategy createDynamicSpawnStrategy(
ExecutionPolicy executionPolicy,
Function<Spawn, Optional<Spawn>> getPostProcessingSpawnForLocalExecution) {
checkState(
executorServiceForCleanup == null,
"Creating the DynamicSpawnStrategy twice in the same test is not supported.");
executorServiceForCleanup = Executors.newCachedThreadPool();
return new DynamicSpawnStrategy(
executorServiceForCleanup,
new DynamicExecutionOptions(),
ignored -> executionPolicy,
getPostProcessingSpawnForLocalExecution);
}

private static ActionExecutionContext createMockActionExecutionContext(
SandboxedSpawnStrategy localStrategy, SandboxedSpawnStrategy remoteStrategy) {
ActionExecutionContext actionExecutionContext = mock(ActionExecutionContext.class);
when(actionExecutionContext.getFileOutErr()).thenReturn(new TestFileOutErr());
when(actionExecutionContext.getContext(DynamicStrategyRegistry.class))
.thenReturn(
new DynamicStrategyRegistry() {
@Override
public ImmutableList<SandboxedSpawnStrategy> getDynamicSpawnActionContexts(
Spawn spawn, DynamicMode dynamicMode) {
switch (dynamicMode) {
case LOCAL:
return ImmutableList.of(localStrategy);
case REMOTE:
return ImmutableList.of(remoteStrategy);
}
throw new AssertionError("Unexpected mode: " + dynamicMode);
}

@Override
public void notifyUsedDynamic(ActionContextRegistry actionContextRegistry) {}
});
when(actionExecutionContext.withFileOutErr(any())).thenReturn(actionExecutionContext);
return actionExecutionContext;
}

private static SandboxedSpawnStrategy createMockSpawnStrategy() throws InterruptedException {
SandboxedSpawnStrategy strategy = mock(SandboxedSpawnStrategy.class);
when(strategy.canExec(any(), any())).thenReturn(true);
when(strategy.beginExecution(any(), any())).thenThrow(UnsupportedOperationException.class);
return strategy;
}
}

0 comments on commit f395157

Please sign in to comment.