Skip to content

Commit

Permalink
plugin helper function adn error handling
Browse files Browse the repository at this point in the history
Signed-off-by: Future Outlier <eric901201@gmai.com>
  • Loading branch information
Future Outlier committed Nov 9, 2023
1 parent 4b32845 commit a125bd2
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 28 deletions.
2 changes: 1 addition & 1 deletion flyteidl/protos/flyteidl/core/tasks.proto
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ message RuntimeMetadata {
// checks to ensure tighter validation or setting expectations.
string version = 2;

//+optional It can be used to provide extra information about the runtime (e.g. python, golang... etc.).
//+optional It can be used to provide extra information about the plugin type (e.g. async plugin, sync plugin... etc.).
string flavor = 3;
}

Expand Down
62 changes: 36 additions & 26 deletions flyteplugins/go/tasks/pluginmachinery/internal/webapi/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ const (
maxBurst = 10000
minQPS = 1
maxQPS = 100000
asyncPlugin = "async_plugin"
syncPlugin = "sync_plugin"
)

Expand Down Expand Up @@ -66,24 +67,41 @@ func (c CorePlugin) GetProperties() core.PluginProperties {
return core.PluginProperties{}
}

func (c CorePlugin) syncHandle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
plugin, ok := c.p.(webapi.SyncPlugin)
if !ok {
taskTemplate, err := tCtx.TaskReader().Read(ctx)
if err != nil {
return core.UnknownTransition, err
func (c CorePlugin) getPlugin(pluginType string) (webapi.AsyncPlugin, webapi.SyncPlugin, error) {
if pluginType == syncPlugin {
plugin, ok := c.p.(webapi.SyncPlugin)
if !ok {
return nil, nil, fmt.Errorf("Core plugin does not implement the sync plugin interface")
}
return core.UnknownTransition, fmt.Errorf("%s does not implement required sync plugin interface methods", taskTemplate.GetType())
return nil, plugin, nil
}
// Assume the plugin is an async plugin if not explicitly specified as sync.
// This helps maintain backward compatibility with existing implementations that
// expect an async plugin by default, thereby avoiding breaking changes.
plugin, ok := c.p.(webapi.AsyncPlugin)
if !ok {
return nil, nil, fmt.Errorf("Core plugin does not implement the async plugin interface")
}
return plugin, nil, nil
}

func (c CorePlugin) syncHandle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
_, plugin, err := c.getPlugin(syncPlugin)
if err != nil {
return core.UnknownTransition, err
}

phaseInfo, err := plugin.Do(ctx, tCtx)
if err != nil {
logger.Errorf(ctx, "please check if the sync plugin interface is implemented or not")
taskTemplate, err := tCtx.TaskReader().Read(ctx)
if err != nil {
return core.UnknownTransition, err
}
logger.Errorf(ctx, "please check if [%v] task type has implemented sync plugin method or not", taskTemplate.GetType())
return core.UnknownTransition, err
}

return core.DoTransition(phaseInfo), nil

}

func (c CorePlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
Expand All @@ -104,9 +122,9 @@ func (c CorePlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext)

var nextState *State
var phaseInfo core.PhaseInfo
plugin, ok := c.p.(webapi.AsyncPlugin)
if !ok {
return core.UnknownTransition, fmt.Errorf("%s does not implement required async plugin interface methods", taskTemplate.GetType())
plugin, _, err := c.getPlugin(asyncPlugin)
if err != nil {
return core.UnknownTransition, err
}

switch incomingState.Phase {
Expand Down Expand Up @@ -141,13 +159,9 @@ func (c CorePlugin) Abort(ctx context.Context, tCtx core.TaskExecutionContext) e

logger.Infof(ctx, "Attempting to abort resource [%v].", tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID())

plugin, ok := c.p.(webapi.AsyncPlugin)
if !ok {
taskTemplate, err := tCtx.TaskReader().Read(ctx)
if err != nil {
return err
}
return fmt.Errorf("%s does not implement required plugin interface methods", taskTemplate.GetType())
plugin, _, err := c.getPlugin(asyncPlugin)
if err != nil {
return err
}

err = plugin.Delete(ctx, newPluginContext(incomingState.ResourceMeta, nil, "Aborted", tCtx))
Expand All @@ -169,13 +183,9 @@ func (c CorePlugin) Finalize(ctx context.Context, tCtx core.TaskExecutionContext
logger.Infof(ctx, "Attempting to finalize resource [%v].",
tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName())

plugin, ok := c.p.(webapi.AsyncPlugin)
if !ok {
taskTemplate, err := tCtx.TaskReader().Read(ctx)
if err != nil {
return err
}
return fmt.Errorf("%s does not implement required plugin interface methods", taskTemplate.GetType())
plugin, _, err := c.getPlugin(asyncPlugin)
if err != nil {
return err
}

return c.tokenAllocator.releaseToken(ctx, plugin, tCtx, c.metrics)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ func TestHandle(t *testing.T) {
ctx := context.TODO()
tCtx := getTaskContext(t)
taskReader := new(mocks.TaskReader)
taskInfo := &core.TaskInfo{}

template := flyteIdlCore.TaskTemplate{
Type: "api_task",
Expand All @@ -40,7 +41,7 @@ func TestHandle(t *testing.T) {
tCtx.On("TaskReader").Return(taskReader)

p := new(webapiMocks.SyncPlugin)
p.On("Do", ctx, tCtx).Return(core.PhaseInfo{}, nil)
p.On("Do", ctx, tCtx).Return(core.PhaseInfoSuccess(taskInfo), nil)

c := CorePlugin{
id: "test",
Expand All @@ -50,6 +51,8 @@ func TestHandle(t *testing.T) {
transition, err := c.Handle(ctx, tCtx)
assert.NoError(t, err)
assert.NotNil(t, transition)
assert.Equal(t, core.PhaseInfoSuccess(taskInfo), transition.Info())
assert.Equal(t, core.TransitionTypeEphemeral, transition.Type())
}

func Test_validateConfig(t *testing.T) {
Expand Down

0 comments on commit a125bd2

Please sign in to comment.