Skip to content

Commit

Permalink
feat: propagate kills to the TaskRunner
Browse files Browse the repository at this point in the history
  • Loading branch information
fhussonnois committed May 14, 2024
1 parent 4bb05f1 commit fb8c757
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 13 deletions.
45 changes: 41 additions & 4 deletions src/main/java/io/kestra/plugin/gcp/runner/GcpBatchTaskRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Path;
import java.time.*;
import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -181,7 +183,7 @@ public class GcpBatchTaskRunner extends TaskRunner implements GcpInterface, Remo
public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List<String> filesToUpload, List<String> filesToDownload) throws Exception {
String renderedBucket = runContext.render(this.bucket);

GoogleCredentials credentials = CredentialService.credentials(runContext, this);
final GoogleCredentials credentials = CredentialService.credentials(runContext, this);

boolean hasFilesToUpload = !ListUtils.isEmpty(filesToUpload);
if (hasFilesToUpload && bucket == null) {
Expand All @@ -198,7 +200,7 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List<S
String workingDirectoryToBlobPath = batchWorkingDirectory.toString().substring(1);
boolean hasBucket = this.bucket != null;

try (BatchServiceClient batchServiceClient = BatchServiceClient.create(BatchServiceSettings.newBuilder().setCredentialsProvider(() -> credentials).build());
try (BatchServiceClient batchServiceClient = newBatchServiceClient(credentials);
Logging logging = LoggingOptions.getDefaultInstance().toBuilder().setCredentials(credentials).build().getService()) {
Duration waitDuration = Optional.ofNullable(taskCommands.getTimeout()).orElse(this.waitUntilCompletion);
Map<String, String> labels = ScriptService.labels(runContext, "kestra-", true, true);
Expand Down Expand Up @@ -313,7 +315,10 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List<S
.build();

result = batchServiceClient.createJob(createJobRequest);
runContext.logger().info("Job created: {}", result.getName());

final String jobName = result.getName();
onKill(() -> safelyDeleteJob(runContext, credentials, jobName));
runContext.logger().info("Job created: {}", jobName);
}


Expand Down Expand Up @@ -387,7 +392,11 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List<S
}
}
}


/**
 * Creates a new {@link BatchServiceClient} authenticated with the given credentials.
 * The caller owns the returned client and is responsible for closing it.
 *
 * @param credentials the credentials handed to the client's credentials provider
 * @return a freshly created client
 * @throws IOException if the settings or the client cannot be built
 */
private static BatchServiceClient newBatchServiceClient(GoogleCredentials credentials) throws IOException {
    final BatchServiceSettings settings = BatchServiceSettings.newBuilder()
        .setCredentialsProvider(() -> credentials)
        .build();
    return BatchServiceClient.create(settings);
}

private String labelsFilter(Map<String, String> labels) {
return labels.entrySet().stream()
.map(entry -> "labels." + entry.getKey() + "=\"" + entry.getValue().toLowerCase() + "\"")
Expand Down Expand Up @@ -462,6 +471,34 @@ protected Map<String, Object> runnerAdditionalVars(RunContext runContext, TaskCo

return additionalVars;
}

/**
 * Best-effort deletion of a Batch job, registered as the kill hook of the task runner.
 * <p>
 * The job is looked up first; if it has already reached a terminal state nothing is deleted.
 * Otherwise a delete request is sent and this method blocks until it completes.
 * <p>
 * This method never propagates a failure: every error is logged on the run context logger,
 * and an interruption restores the thread's interrupt flag.
 *
 * @param runContext  the run context whose logger receives the outcome
 * @param credentials the credentials used to open a dedicated client
 * @param jobName     the fully qualified name of the job to delete
 */
protected void safelyDeleteJob(final RunContext runContext,
                               final GoogleCredentials credentials,
                               final String jobName) {
    // Use a dedicated BatchServiceClient, as the one used in the run method may be closed in the meantime.
    try (BatchServiceClient batchServiceClient = newBatchServiceClient(credentials)) {
        final Job job = batchServiceClient.getJob(jobName);
        if (isTerminated(job.getStatus().getState())) {
            // Job execution is already terminated so we can skip deletion.
            return;
        }

        final DeleteJobRequest request = DeleteJobRequest.newBuilder()
            .setName(jobName)
            .setReason("Kestra task was killed.")
            .build();

        batchServiceClient.deleteJobAsync(request).get();
        runContext.logger().info("Job deleted: {}", jobName);
        // We don't need to clean up the storage here, as this will be properly handled
        // by the task thread in the run method once the job is terminated (i.e., deleted).
    } catch (InterruptedException e) {
        // Preserve the interrupt for the caller, but leave a trace that the job may still be running.
        Thread.currentThread().interrupt();
        runContext.logger().warn("Interrupted while deleting Job: {}", jobName);
    } catch (ExecutionException | IOException e) {
        final Throwable cause = e.getCause() != null ? e.getCause() : e;
        runContext.logger().warn("Failed to delete Job: {}", jobName, cause);
    } catch (RuntimeException e) {
        // Unchecked failures raised by the client calls (e.g. the job lookup) must not escape a kill hook.
        runContext.logger().warn("Failed to delete Job: {}", jobName, e);
    }
}

@Getter
@Builder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@
import lombok.ToString;
import lombok.experimental.SuperBuilder;

import java.io.IOException;
import java.nio.file.Path;
import java.time.Duration;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;

@SuperBuilder
Expand Down Expand Up @@ -212,8 +214,8 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List<S
}
}

try (JobsClient jobsClient = JobsClient.create(JobsSettings.newBuilder().setCredentialsProvider(() -> credentials).build());
ExecutionsClient executionsClient = ExecutionsClient.create(ExecutionsSettings.newBuilder().setCredentialsProvider(() -> credentials).build());
try (JobsClient jobsClient = newJobsClient(credentials);
ExecutionsClient executionsClient = newExecutionsClient(credentials);
Logging logging = LoggingOptions.getDefaultInstance().toBuilder().setCredentials(credentials).build().getService()) {

// Create new Job TaskTemplate
Expand Down Expand Up @@ -299,18 +301,22 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List<S

OperationFuture<Execution, Execution> future = jobsClient.runJobAsync(runJobRequest);
Execution execution = future.getMetadata().get();


final String executionName = execution.getName();

onKill(() -> safelyCancelJobExecution(runContext, credentials, executionName));

String logFilter = String.format(
"logName=\"projects/%s/logs/run.googleapis.com\" labels.\"run.googleapis.com/execution_name\"=\"%s\"",
renderedProjectId,
execution.getName()
executionName
);

LogEntryServerStream stream = logging.tailLogEntries(Logging.TailOption.filter(logFilter));
try (LogTail ignored = new LogTail(stream, taskCommands.getLogConsumer())) {
if (!isTerminated(execution)) {
runContext.logger().info("Waiting for execution completion: {}.", execution.getName());
execution = awaitJobExecutionTermination(executionsClient, execution.getName(), timeout);
runContext.logger().info("Waiting for execution completion: {}.", executionName);
execution = awaitJobExecutionTermination(executionsClient, executionName, timeout);
}
// Check for the job successful creation
if (isFailed(execution)) {
Expand All @@ -323,8 +329,8 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List<S

if (delete) {
// not waiting for Job Execution deletion
executionsClient.deleteExecutionAsync(execution.getName());
runContext.logger().info("Job Execution deleted: {}", execution.getName());
executionsClient.deleteExecutionAsync(executionName);
runContext.logger().info("Job Execution deleted: {}", executionName);
// not waiting for Job deletion
jobsClient.deleteJobAsync(runJobRequest.getName());
runContext.logger().info("Job deleted: {}", runJobRequest.getName());
Expand All @@ -349,7 +355,15 @@ public RunnerResult run(RunContext runContext, TaskCommands taskCommands, List<S
}
}
}


/**
 * Creates a new {@link ExecutionsClient} authenticated with the given credentials.
 * The caller owns the returned client and is responsible for closing it.
 *
 * @param credentials the credentials handed to the client's credentials provider
 * @return a freshly created client
 * @throws IOException if the settings or the client cannot be built
 */
private static ExecutionsClient newExecutionsClient(GoogleCredentials credentials) throws IOException {
    final ExecutionsSettings settings = ExecutionsSettings.newBuilder()
        .setCredentialsProvider(() -> credentials)
        .build();
    return ExecutionsClient.create(settings);
}

/**
 * Creates a new {@link JobsClient} authenticated with the given credentials.
 * The caller owns the returned client and is responsible for closing it.
 *
 * @param credentials the credentials handed to the client's credentials provider
 * @return a freshly created client
 * @throws IOException if the settings or the client cannot be built
 */
private static JobsClient newJobsClient(GoogleCredentials credentials) throws IOException {
    final JobsSettings settings = JobsSettings.newBuilder()
        .setCredentialsProvider(() -> credentials)
        .build();
    return JobsClient.create(settings);
}

private Duration getTaskTimeout(final TaskCommands taskCommands) {
return Optional
.ofNullable(taskCommands.getTimeout())
Expand Down Expand Up @@ -423,4 +437,26 @@ protected Map<String, Object> runnerAdditionalVars(RunContext runContext, TaskCo

return additionalVars;
}

/**
 * Best-effort cancellation of a job execution, registered as the kill hook of the task runner.
 * <p>
 * The execution is looked up first; if it is already terminated nothing is cancelled.
 * Otherwise a cancel request is sent and this method blocks until it completes.
 * <p>
 * This method never propagates a failure: every error is logged on the run context logger,
 * and an interruption restores the thread's interrupt flag.
 *
 * @param runContext    the run context whose logger receives the outcome
 * @param credentials   the credentials used to open a dedicated client
 * @param executionName the fully qualified name of the execution to cancel
 */
protected void safelyCancelJobExecution(final RunContext runContext,
                                        final GoogleCredentials credentials,
                                        final String executionName) {
    // Use a dedicated ExecutionsClient, as the one used in the run method may be closed in the meantime.
    try (ExecutionsClient executionsClient = newExecutionsClient(credentials)) {
        final Execution execution = executionsClient.getExecution(executionName);
        if (isTerminated(execution)) {
            // Execution is already terminated so we can skip cancellation.
            return;
        }
        executionsClient.cancelExecutionAsync(executionName).get();
        runContext.logger().info("Job execution canceled: {}", executionName);
        // We don't need to clean up the storage and execution here, as this will be properly
        // handled by the task thread in the run method once the execution is terminated.
    } catch (InterruptedException e) {
        // Preserve the interrupt for the caller, but leave a trace that the execution may still be running.
        Thread.currentThread().interrupt();
        runContext.logger().warn("Interrupted while canceling Job execution: {}", executionName);
    } catch (ExecutionException | IOException e) {
        final Throwable cause = e.getCause() != null ? e.getCause() : e;
        runContext.logger().warn("Failed to cancel Job execution: {}", executionName, cause);
    } catch (RuntimeException e) {
        // Unchecked failures raised by the client calls (e.g. the execution lookup) must not escape a kill hook.
        runContext.logger().warn("Failed to cancel Job execution: {}", executionName, e);
    }
}
}

0 comments on commit fb8c757

Please sign in to comment.