This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 60
/
task_replayer_plugin.go
131 lines (117 loc) · 4.52 KB
/
task_replayer_plugin.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package fakeplugins
import (
"context"
"fmt"
pluginCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core"
)
// HandleResponse is one scripted outcome of a ReplayerPlugin.Handle call: the
// transition to hand back together with the error to hand back alongside it.
type HandleResponse struct {
	// T is the transition returned from Handle.
	T pluginCore.Transition
	// Err is the error returned from Handle; it may be nil.
	Err error
}
// NewHandleTransition builds a HandleResponse that yields the given transition
// and a nil error.
func NewHandleTransition(transition pluginCore.Transition) HandleResponse {
	var resp HandleResponse
	resp.T = transition
	// resp.Err stays at its zero value (nil).
	return resp
}
// NewHandleError builds a HandleResponse that yields err, paired with
// pluginCore.UnknownTransition as the transition.
func NewHandleError(err error) HandleResponse {
	resp := HandleResponse{Err: err}
	resp.T = pluginCore.UnknownTransition
	return resp
}
// taskReplayer holds the replay cursors for a single task execution. Each
// counter is the index of the next scripted response to use for the
// corresponding ReplayerPlugin method.
type taskReplayer struct {
	nextOnHandleResponseIdx   int // cursor into orderedOnHandleResponses
	nextOnAbortResponseIdx    int // cursor into orderedAbortResponses
	nextOnFinalizeResponseIdx int // cursor into orderedFinalizeResponses
}
// ReplayerPlugin is a test plugin that replays pre-scripted responses for
// Handle, Abort and Finalize, so that any plugin behaviour scenario can be
// reproduced (panics excepted). It does not use any plugin state to decide what
// to return: every call simply hands back the next scripted response for the
// calling task execution, driving the state machine steadily forward.
//
// Per-task progress lives in a plain map with no locking, so a ReplayerPlugin
// must be driven from a single goroutine only.
type ReplayerPlugin struct {
	id    string
	props pluginCore.PluginProperties

	// Scripted responses, consumed in order for each task execution.
	orderedOnHandleResponses []HandleResponse
	orderedAbortResponses    []error
	orderedFinalizeResponses []error

	// taskReplayState maps a task execution's generated name to its cursors.
	taskReplayState map[string]*taskReplayer
}
// GetID returns the identifier this plugin was constructed with.
func (r ReplayerPlugin) GetID() string {
	return r.id
}
// GetProperties returns the plugin properties this plugin was constructed with.
func (r ReplayerPlugin) GetProperties() pluginCore.PluginProperties {
	return r.props
}
// Handle returns the next scripted HandleResponse for the task execution
// identified by tCtx's generated name, then advances that task's Handle cursor.
// Once every scripted response has been consumed, further calls return
// pluginCore.UnknownTransition and an error rather than indexing past the end
// of the script.
func (r ReplayerPlugin) Handle(_ context.Context, tCtx pluginCore.TaskExecutionContext) (pluginCore.Transition, error) {
	n := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
	s, ok := r.taskReplayState[n]
	if !ok {
		s = &taskReplayer{}
		r.taskReplayState[n] = s
	}
	// Count every invocation, including ones that overrun the script, so the
	// cursor always equals the number of calls made so far.
	idx := s.nextOnHandleResponseIdx
	s.nextOnHandleResponseIdx++
	// idx == len is already out of range, hence >= rather than >.
	if idx >= len(r.orderedOnHandleResponses) {
		return pluginCore.UnknownTransition, fmt.Errorf("plugin Handle Invoked [%d] times, expected [%d] for task [%s]", idx+1, len(r.orderedOnHandleResponses), n)
	}
	hr := r.orderedOnHandleResponses[idx]
	return hr.T, hr.Err
}
// Abort returns the next scripted Abort error for the task execution
// identified by tCtx's generated name, then advances that task's Abort cursor.
// Calls beyond the scripted responses return a descriptive error rather than
// indexing past the end of the script.
func (r ReplayerPlugin) Abort(_ context.Context, tCtx pluginCore.TaskExecutionContext) error {
	n := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
	s, ok := r.taskReplayState[n]
	if !ok {
		s = &taskReplayer{}
		r.taskReplayState[n] = s
	}
	// Count every invocation, including ones that overrun the script.
	idx := s.nextOnAbortResponseIdx
	s.nextOnAbortResponseIdx++
	// idx == len is already out of range, hence >= rather than >.
	if idx >= len(r.orderedAbortResponses) {
		return fmt.Errorf("plugin Abort Invoked [%d] times, expected [%d] for task [%s]", idx+1, len(r.orderedAbortResponses), n)
	}
	return r.orderedAbortResponses[idx]
}
// Finalize returns the next scripted Finalize error for the task execution
// identified by tCtx's generated name, then advances that task's Finalize
// cursor. Calls beyond the scripted responses return a descriptive error
// rather than indexing past the end of the script.
//
// Only the Finalize cursor and orderedFinalizeResponses are touched; the Abort
// cursor is left alone.
func (r ReplayerPlugin) Finalize(_ context.Context, tCtx pluginCore.TaskExecutionContext) error {
	n := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
	s, ok := r.taskReplayState[n]
	if !ok {
		s = &taskReplayer{}
		r.taskReplayState[n] = s
	}
	// Count every invocation, including ones that overrun the script.
	idx := s.nextOnFinalizeResponseIdx
	s.nextOnFinalizeResponseIdx++
	// idx == len is already out of range, hence >= rather than >.
	if idx >= len(r.orderedFinalizeResponses) {
		return fmt.Errorf("plugin Finalize Invoked [%d] times, expected [%d] for task [%s]", idx+1, len(r.orderedFinalizeResponses), n)
	}
	return r.orderedFinalizeResponses[idx]
}
// VerifyAllCallsCompleted checks that, for the task execution whose generated
// name is taskExecID, Finalize, Abort and Handle were each invoked exactly as
// many times as there are scripted responses for them. A task that was never
// seen is treated as having received no calls. It returns an error describing
// the first mismatch (checked in the order Finalize, Abort, Handle), or nil.
func (r ReplayerPlugin) VerifyAllCallsCompleted(taskExecID string) error {
	s, ok := r.taskReplayState[taskExecID]
	if !ok {
		// A fresh zero value is equivalent to "no calls yet"; don't record
		// state merely because a verification was requested.
		s = &taskReplayer{}
	}
	if s.nextOnFinalizeResponseIdx != len(r.orderedFinalizeResponses) {
		return fmt.Errorf("finalize method expected invocations [%d], actual invocations [%d]", len(r.orderedFinalizeResponses), s.nextOnFinalizeResponseIdx)
	}
	if s.nextOnAbortResponseIdx != len(r.orderedAbortResponses) {
		return fmt.Errorf("abort method expected invocations [%d], actual invocations [%d]", len(r.orderedAbortResponses), s.nextOnAbortResponseIdx)
	}
	if s.nextOnHandleResponseIdx != len(r.orderedOnHandleResponses) {
		return fmt.Errorf("handle method expected invocations [%d], actual invocations [%d]", len(r.orderedOnHandleResponses), s.nextOnHandleResponseIdx)
	}
	return nil
}
// NewReplayer builds a ReplayerPlugin whose ID is "replayer-for-" followed by
// forPluginID. orderedOnHandleResponses, orderedAbortResponses and
// orderedFinalizeResponses are the scripted responses for Handle, Abort and
// Finalize respectively; the slices are stored as given, not copied.
func NewReplayer(forPluginID string, props pluginCore.PluginProperties, orderedOnHandleResponses []HandleResponse, orderedAbortResponses, orderedFinalizeResponses []error) *ReplayerPlugin {
	p := &ReplayerPlugin{
		props:                    props,
		orderedOnHandleResponses: orderedOnHandleResponses,
		orderedAbortResponses:    orderedAbortResponses,
		orderedFinalizeResponses: orderedFinalizeResponses,
		taskReplayState:          map[string]*taskReplayer{},
	}
	p.id = "replayer-for-" + forPluginID
	return p
}