Skip to content

Commit

Permalink
fix(core): handle properly replay of a task from taskRunId
Browse files Browse the repository at this point in the history
  • Loading branch information
tchiotludo committed Sep 3, 2021
1 parent e6d716e commit 9dce89d
Show file tree
Hide file tree
Showing 13 changed files with 312 additions and 135 deletions.
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
package io.kestra.core.models.hierarchies;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.Value;
import org.apache.commons.lang3.tuple.Pair;
import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.executions.Execution;
import io.kestra.core.models.flows.Flow;
import io.kestra.core.services.GraphService;
import lombok.*;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

@Value
@Builder
Expand Down Expand Up @@ -48,6 +43,8 @@ public static FlowGraph of(Flow flow, Execution execution) throws IllegalVariabl

@Getter
@AllArgsConstructor
@ToString
@EqualsAndHashCode
public static class Edge {
private final String source;
private final String target;
Expand All @@ -56,6 +53,8 @@ public static class Edge {

@Getter
@AllArgsConstructor
@ToString
@EqualsAndHashCode
public static class Cluster {
private final GraphCluster cluster;
private final List<String> nodes;
Expand Down
106 changes: 54 additions & 52 deletions core/src/main/java/io/kestra/core/services/ExecutionService.java
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
package io.kestra.core.services;

import io.kestra.core.exceptions.InternalException;
import io.micronaut.context.ApplicationContext;
import io.micronaut.core.annotation.Nullable;
import io.micronaut.core.util.StringUtils;
import io.kestra.core.models.executions.Execution;
import io.kestra.core.models.executions.TaskRun;
import io.kestra.core.models.flows.Flow;
import io.kestra.core.models.flows.State;
import io.kestra.core.queues.QueueFactoryInterface;
import io.kestra.core.queues.QueueInterface;
import io.kestra.core.models.hierarchies.GraphCluster;
import io.kestra.core.repositories.FlowRepositoryInterface;
import io.kestra.core.utils.IdUtils;
import io.micronaut.context.ApplicationContext;
import io.micronaut.core.annotation.Nullable;

import java.util.*;
import java.util.function.Predicate;
import java.util.AbstractMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import javax.inject.Inject;
import javax.inject.Named;
import javax.inject.Singleton;

import static io.kestra.core.utils.Rethrow.throwFunction;
Expand All @@ -32,25 +30,24 @@ public class ExecutionService {
@Inject
private FlowRepositoryInterface flowRepositoryInterface;

@Inject
@Named(QueueFactoryInterface.EXECUTION_NAMED)
private QueueInterface<Execution> executionQueue;

public Execution restart(final Execution execution, @Nullable String taskId, @Nullable Integer revision) throws Exception {
public Execution restart(final Execution execution, @Nullable Integer revision) throws Exception {
if (!execution.getState().isTerninated()) {
throw new IllegalStateException("Execution must be terminated to be restarted, " +
"current state is '" + execution.getState().getCurrent() + "' !"
);
}

Execution newExecution;
if (StringUtils.hasText(taskId)) {
newExecution = newExecutionFromTaskRunId(execution, taskId, State.Type.RESTARTED, revision);
} else {
newExecution = newExecutionFromFailed(execution, State.Type.RESTARTED, revision);
return restartExecutionFromFailed(execution, State.Type.RESTARTED, revision);
}

public Execution replay(final Execution execution, String taskRunId, @Nullable Integer revision) throws Exception {
if (!execution.getState().isTerninated()) {
throw new IllegalStateException("Execution must be terminated to be restarted, " +
"current state is '" + execution.getState().getCurrent() + "' !"
);
}

return newExecution;
return replayExecutionFromTaskRunId(execution, taskRunId, State.Type.RESTARTED, revision);
}

private Set<String> getAncestors(Execution execution, TaskRun taskRun) {
Expand All @@ -65,46 +62,51 @@ private Set<String> getAncestors(Execution execution, TaskRun taskRun) {
.collect(Collectors.toSet());
}

private Execution newExecutionFromTaskRunId(final Execution execution, String referenceTaskId, State.Type newStateType, Integer revision) throws IllegalArgumentException, InternalException {
private Execution replayExecutionFromTaskRunId(final Execution execution, String taskRunId, State.Type newStateType, Integer revision) throws IllegalArgumentException, InternalException {
final Flow flow = flowRepositoryInterface.findByExecution(execution);
GraphCluster graphCluster = GraphService.of(flow, execution);

final Predicate<TaskRun> isNotReferenceTask = taskRun -> !(referenceTaskId.equals(taskRun.getTaskId()));
final Predicate<TaskRun> isNotFailed = taskRun -> !taskRun.getState().getCurrent().equals(State.Type.FAILED);

// Extract the reference task run index
final long refTaskRunIndex = execution
.getTaskRunList()
.stream()
.takeWhile(isNotFailed.and(isNotReferenceTask))
.count();
Set<String> taskRunToRestart = this.taskRunWithAncestors(
execution,
execution
.getTaskRunList()
.stream()
.filter(taskRun -> taskRun.getId().equals(taskRunId))
.collect(Collectors.toList())
);

if (refTaskRunIndex == execution.getTaskRunList().size()) {
throw new IllegalArgumentException("Task [" + referenceTaskId + "] does not exist !");
if (taskRunToRestart.size() == 0) {
throw new IllegalArgumentException("No task found to restart execution from !");
}

Map<String, String> mappingTaskRunId = this.mapTaskRunId(execution, false);
final String newExecutionId = IdUtils.create();

// Create new task run list
List<TaskRun> newTaskRuns = IntStream
.range(0, (int) refTaskRunIndex + 1)
.boxed()
.map(throwFunction(currentIndex -> {
final TaskRun originalTaskRun = execution.getTaskRunList().get(currentIndex);

return this.mapTaskRun(
flow,
originalTaskRun,
mappingTaskRunId,
newExecutionId,
newStateType,
currentIndex == refTaskRunIndex
);
}))
List<TaskRun> newTaskRuns = execution
.getTaskRunList()
.stream()
.map(throwFunction(originalTaskRun -> this.mapTaskRun(
flow,
originalTaskRun,
mappingTaskRunId,
newExecutionId,
newStateType,
taskRunToRestart.contains(originalTaskRun.getId()))
))
.collect(Collectors.toList());

// Build and launch new execution
Set<String> taskRunToRemove = GraphService.successors(graphCluster, List.of(taskRunId))
.stream()
.filter(task -> task.getTaskRun() != null)
.filter(task -> !task.getTaskRun().getId().equals(taskRunId))
.filter(task -> !taskRunToRestart.contains(task.getTaskRun().getId()))
.map(s -> mappingTaskRunId.get(s.getTaskRun().getId()))
.collect(Collectors.toSet());

taskRunToRemove
.forEach(r -> newTaskRuns.removeIf(taskRun -> taskRun.getId().equals(r)));

// Build and launch new execution
Execution newExecution = execution.childExecution(
newExecutionId,
newTaskRuns,
Expand Down Expand Up @@ -153,17 +155,17 @@ private TaskRun mapTaskRun(
);
}

private Set<String> taskRunToRestart(Execution execution, List<TaskRun> taskRuns) {
private Set<String> taskRunWithAncestors(Execution execution, List<TaskRun> taskRuns) {
return taskRuns
.stream()
.flatMap(throwFunction(taskRun -> this.getAncestors(execution, taskRun).stream()))
.collect(Collectors.toSet());
}

private Execution newExecutionFromFailed(final Execution execution, State.Type newStateType, Integer revision) throws InternalException {
private Execution restartExecutionFromFailed(final Execution execution, State.Type newStateType, Integer revision) throws InternalException {
final Flow flow = flowRepositoryInterface.findByExecution(execution);

Set<String> taskRunToRestart = this.taskRunToRestart(
Set<String> taskRunToRestart = this.taskRunWithAncestors(
execution,
execution
.getTaskRunList()
Expand Down
32 changes: 32 additions & 0 deletions core/src/main/java/io/kestra/core/services/GraphService.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,38 @@ public static List<Pair<GraphCluster, List<String>>> clusters(GraphCluster graph
.collect(Collectors.toList());
}

public static Set<AbstractGraphTask> successors(GraphCluster graphCluster, List<String> taskRunIds) {
List<FlowGraph.Edge> edges = GraphService.edges(graphCluster);
List<AbstractGraphTask> nodes = GraphService.nodes(graphCluster);

List<AbstractGraphTask> selectedTaskRuns = nodes
.stream()
.filter(task -> task.getTaskRun() != null && taskRunIds.contains(task.getTaskRun().getId()))
.collect(Collectors.toList());

Set<String> edgeUuid = selectedTaskRuns
.stream()
.flatMap(task -> recursiveEdge(edges, task.getUid()).stream())
.map(FlowGraph.Edge::getSource)
.collect(Collectors.toSet());

return nodes
.stream()
.filter(task -> edgeUuid.contains(task.getUid()))
.collect(Collectors.toSet());
}

private static List<FlowGraph.Edge> recursiveEdge(List<FlowGraph.Edge> edges, String selectedUuid) {
return edges
.stream()
.filter(edge -> edge.getSource().equals(selectedUuid))
.flatMap(edge -> Stream.concat(
Stream.of(edge),
recursiveEdge(edges, edge.getTarget()).stream()
))
.collect(Collectors.toList());
}

public static void sequential(
GraphCluster graph,
List<Task> tasks,
Expand Down

0 comments on commit 9dce89d

Please sign in to comment.