This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
/
launcher.go
100 lines (81 loc) · 3.46 KB
/
launcher.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
package awsbatch
import (
"context"
"fmt"
"github.com/lyft/flyteplugins/go/tasks/errors"
"github.com/lyft/flytestdlib/logger"
arrayCore "github.com/lyft/flyteplugins/go/tasks/plugins/array/core"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/lyft/flyteplugins/go/tasks/plugins/array/arraystatus"
"github.com/lyft/flyteplugins/go/tasks/plugins/array/awsbatch/config"
)
// LaunchSubTasks validates the requested execution-array size against the plugin's configured
// maximum and, if acceptable, translates the Flyte task into an AWS Batch job submission,
// submits it, and returns the next state: phase set to PhaseCheckingSubTaskExecutions with an
// ArrayStatus summarizing all subtasks as Queued, and the external AWS Batch job ID recorded.
//
// A size violation is not a system error: the state is moved to PhasePermanentFailure with the
// reason attached, and a nil error is returned so the state machine handles it. A missing job
// definition ARN, translation failure, task-read failure, or submission failure is returned as
// an error to the caller.
func LaunchSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, batchClient Client, pluginConfig *config.Config,
	currentState *State, metrics ExecutorMetrics) (nextState *State, err error) {

	size := currentState.GetExecutionArraySize()
	if int64(size) > pluginConfig.MaxArrayJobSize {
		ee := fmt.Errorf("array size > max allowed. Requested [%v]. Allowed [%v]", size, pluginConfig.MaxArrayJobSize)
		logger.Info(ctx, ee)
		// Permanent failure is communicated through the state, not the error return.
		currentState.State = currentState.SetPhase(arrayCore.PhasePermanentFailure, 0).SetReason(ee.Error())
		return currentState, nil
	}

	// The job definition is expected to have been registered in an earlier phase.
	jobDefinition := currentState.GetJobDefinitionArn()
	if len(jobDefinition) == 0 {
		return nil, fmt.Errorf("system error; no job definition created")
	}

	batchInput, err := FlyteTaskToBatchInput(ctx, tCtx, jobDefinition, pluginConfig)
	if err != nil {
		return nil, err
	}

	t, err := tCtx.TaskReader().Read(ctx)
	if err != nil {
		return nil, err
	}

	// If the original job was marked as an array (not a single job), then make sure to set it up correctly.
	if t.Type == arrayTaskType {
		logger.Debugf(ctx, "Task is of type [%v]. Will setup task index env vars.", t.Type)
		batchInput = UpdateBatchInputForArray(ctx, batchInput, int64(size))
	}

	j, err := batchClient.SubmitJob(ctx, batchInput)
	if err != nil {
		logger.Errorf(ctx, "Failed to submit job [%+v]. Error: %v", batchInput, err)
		return nil, err
	}

	metrics.SubTasksSubmitted.Add(ctx, float64(size))

	// All subtasks start out Queued; Detailed tracks per-subtask phase compactly.
	parentState := currentState.
		SetPhase(arrayCore.PhaseCheckingSubTaskExecutions, 0).
		SetArrayStatus(arraystatus.ArrayStatus{
			Summary: arraystatus.ArraySummary{
				core.PhaseQueued: int64(size),
			},
			Detailed: arrayCore.NewPhasesCompactArray(uint(size)),
		}).
		SetReason("Successfully launched subtasks.")

	nextState = currentState.SetExternalJobID(j)
	nextState.State = parentState

	return nextState, nil
}
// TerminateSubTasks attempts to cancel the AWS Batch job recorded in the plugin state, if any.
// It is idempotent and safe to invoke repeatedly for the same job, though each invocation that
// finds a recorded job ID issues a fresh TerminateJob call to AWS Batch.
func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, batchClient Client, reason string, metrics ExecutorMetrics) error {
	state := &State{}
	if _, err := tCtx.PluginStateReader().Get(state); err != nil {
		return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to unmarshal custom state")
	}

	// A nil inner state means the task was only just kicked off; substitute an empty one so the
	// accessors below are safe to call.
	if state.State == nil {
		state.State = &arrayCore.State{}
	}

	phase, _ := state.GetPhase()
	logger.Infof(ctx, "TerminateSubTasks is called with phase [%v] and reason [%v]", phase, reason)

	externalID := state.GetExternalJobID()
	if externalID == nil {
		// Nothing was ever submitted to AWS Batch; nothing to terminate.
		return nil
	}

	logger.Infof(ctx, "Cancelling AWS Job [%v] because [%v].", *externalID, reason)
	if err := batchClient.TerminateJob(ctx, *externalID, reason); err != nil {
		return err
	}

	metrics.BatchJobTerminated.Inc(ctx)
	return nil
}