Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Sync Plugin Interface #409

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,5 @@ require (
)

replace github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d

replace github.com/flyteorg/flyteidl => /mnt/c/code/dev/flyteidl
20 changes: 19 additions & 1 deletion go/tasks/pluginmachinery/internal/webapi/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ const (
type CorePlugin struct {
id string
p webapi.AsyncPlugin
sp webapi.SyncPlugin
cache cache.AutoRefresh
tokenAllocator tokenAllocator
metrics Metrics
Expand Down Expand Up @@ -68,12 +69,28 @@ func (c CorePlugin) GetProperties() core.PluginProperties {
return core.PluginProperties{}
}

// syncHandle
// TODO: ADD Sync Handle
func (c CorePlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
incomingState, err := c.unmarshalState(ctx, tCtx.PluginStateReader())
if err != nil {
return core.UnknownTransition, err
}

taskTemplate, err := tCtx.TaskReader().Read(ctx)

if taskTemplate.Type == "dispatcher" {
res, err := c.sp.Do(ctx, tCtx)
if err != nil {
return core.UnknownTransition, err
}
logger.Infof(ctx, "@@@ SyncPlugin [%v] returned result: %v", c.GetID(), res)
// if err := tCtx.PluginStateWriter().Put(pluginStateVersion, nextState); err != nil {
// return core.UnknownTransition, err
// }
return core.DoTransition(core.PhaseInfoSuccess(nil)), nil
}

var nextState *State
var phaseInfo core.PhaseInfo
switch incomingState.Phase {
Expand Down Expand Up @@ -165,7 +182,7 @@ func createRemotePlugin(pluginEntry webapi.PluginEntry, c clock.Clock) core.Plug
RegisteredTaskTypes: pluginEntry.SupportedTaskTypes,
LoadPlugin: func(ctx context.Context, iCtx core.SetupContext) (
core.Plugin, error) {
p, err := pluginEntry.PluginLoader(ctx, iCtx)
p, sp, err := pluginEntry.PluginLoader(ctx, iCtx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -205,6 +222,7 @@ func createRemotePlugin(pluginEntry webapi.PluginEntry, c clock.Clock) core.Plug
return CorePlugin{
id: pluginEntry.ID,
p: p,
sp: sp,
cache: resourceCache,
metrics: newMetrics(iCtx.MetricsScope()),
tokenAllocator: newTokenAllocator(c),
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/pluginmachinery/webapi/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (

// A Lazy loading function, that will load the plugin. Plugins should be initialized in this method. It is guaranteed
// that the plugin loader will be called before any Handle/Abort/Finalize functions are invoked
type PluginLoader func(ctx context.Context, iCtx PluginSetupContext) (AsyncPlugin, error)
type PluginLoader func(ctx context.Context, iCtx PluginSetupContext) (AsyncPlugin, SyncPlugin, error)

// PluginEntry is a structure that is used to indicate to the system a K8s plugin
type PluginEntry struct {
Expand Down Expand Up @@ -150,5 +150,5 @@ type SyncPlugin interface {
GetConfig() PluginConfig

// Do performs the action associated with this plugin.
Do(ctx context.Context, tCtx TaskExecutionContext) (phase pluginsCore.PhaseInfo, err error)
Do(ctx context.Context, tCtx TaskExecutionContextReader) (latest Resource, err error)
}
85 changes: 79 additions & 6 deletions go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,73 @@ func (p Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionCo
return "default", p.cfg.ResourceConstraints, nil
}

func (p Plugin) Do(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (latest webapi.Resource, err error) {
// write the resource here
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
if err != nil {
return nil, err
}

inputs, err := taskCtx.InputReader().Get(ctx)
if err != nil {
return nil, err
}

if taskTemplate.GetContainer() != nil {
templateParameters := template.Parameters{
TaskExecMetadata: taskCtx.TaskExecutionMetadata(),
Inputs: taskCtx.InputReader(),
OutputPath: taskCtx.OutputWriter(),
Task: taskCtx.TaskReader(),
}
modifiedArgs, err := template.Render(ctx, taskTemplate.GetContainer().Args, templateParameters)
if err != nil {
return nil, err
}
taskTemplate.GetContainer().Args = modifiedArgs
}

agent, err := getFinalAgent(taskTemplate.Type, p.cfg)
if err != nil {
return nil, fmt.Errorf("failed to find agent agent with error: %v", err)
}

client, err := p.getClient(ctx, agent, p.connectionCache)
if err != nil {
return nil, fmt.Errorf("failed to connect to agent with error: %v", err)
}

finalCtx, cancel := getFinalContext(ctx, "DoTask", agent)

defer cancel()

// taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
// write it in agent?

logger.Infof(ctx, "@@@ inputs: [%v]", inputs)
logger.Infof(ctx, "@@@ taskTemplate: [%v]", taskTemplate)

res, err := client.DoTask(finalCtx, &admin.DoTaskRequest{Inputs: inputs, Template: taskTemplate})
if err != nil {
return nil, err
}

logger.Infof(ctx, "@@@ res.Resource.State: [%v]", res.Resource.State)
logger.Infof(ctx, "@@@ res.Resource.Outputs: [%v]", res.Resource.Outputs)

return &ResourceWrapper{
State: res.Resource.State,
Outputs: res.Resource.Outputs,
}, nil
}

// todo: write the output

// we can get the task type in core.go
/*
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
taskTemplate.type = spark, dispatcher ...
*/
func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta,
webapi.Resource, error) {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
Expand Down Expand Up @@ -298,6 +365,7 @@ func getFinalContext(ctx context.Context, operation string, agent *Agent) (conte
return context.WithTimeout(ctx, timeout)
}

// TODO: Add sync agent plugin
func newAgentPlugin(supportedTaskTypes SupportedTaskTypes) webapi.PluginEntry {
if len(supportedTaskTypes) == 0 {
supportedTaskTypes = SupportedTaskTypes{"default_supported_task_type"}
Expand All @@ -306,13 +374,18 @@ func newAgentPlugin(supportedTaskTypes SupportedTaskTypes) webapi.PluginEntry {
return webapi.PluginEntry{
ID: "agent-service",
SupportedTaskTypes: supportedTaskTypes,
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) {
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
getClient: getClientFunc,
connectionCache: make(map[*Agent]*grpc.ClientConn),
}, nil
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
getClient: getClientFunc,
connectionCache: make(map[*Agent]*grpc.ClientConn),
}, &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
getClient: getClientFunc,
connectionCache: make(map[*Agent]*grpc.ClientConn),
}, nil
},
}
}
Expand Down
8 changes: 4 additions & 4 deletions go/tasks/plugins/webapi/athena/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,25 +200,25 @@ func createTaskInfo(queryID string, cfg awsSdk.Config) *core.TaskInfo {
}
}

func NewPlugin(_ context.Context, cfg *Config, awsConfig *aws.Config, metricScope promutils.Scope) (Plugin, error) {
func NewPlugin(_ context.Context, cfg *Config, awsConfig *aws.Config, metricScope promutils.Scope) (Plugin, webapi.SyncPlugin, error) {
sdkCfg, err := awsConfig.GetSdkConfig()
if err != nil {
return Plugin{}, err
return Plugin{}, nil, err
}

return Plugin{
metricScope: metricScope,
client: athena.NewFromConfig(sdkCfg),
cfg: cfg,
awsConfig: sdkCfg,
}, nil
}, nil, nil
}

func init() {
pluginmachinery.PluginRegistry().RegisterRemotePlugin(webapi.PluginEntry{
ID: "athena",
SupportedTaskTypes: []core.TaskType{"hive", "presto"},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) {
return NewPlugin(ctx, GetConfig(), aws.GetConfig(), iCtx.MetricsScope())
},
})
Expand Down
8 changes: 4 additions & 4 deletions go/tasks/plugins/webapi/bigquery/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,25 +547,25 @@ func (p Plugin) newBigQueryClient(ctx context.Context, identity google.Identity)
return bigquery.NewService(ctx, options...)
}

func NewPlugin(cfg *Config, metricScope promutils.Scope) (*Plugin, error) {
func NewPlugin(cfg *Config, metricScope promutils.Scope) (*Plugin, webapi.SyncPlugin, error) {
googleTokenSource, err := google.NewTokenSourceFactory(cfg.GoogleTokenSource)

if err != nil {
return nil, pluginErrors.Wrapf(pluginErrors.PluginInitializationFailed, err, "failed to get google token source")
return nil, nil, pluginErrors.Wrapf(pluginErrors.PluginInitializationFailed, err, "failed to get google token source")
}

return &Plugin{
metricScope: metricScope,
cfg: cfg,
googleTokenSource: googleTokenSource,
}, nil
}, nil, nil
}

func newBigQueryJobTaskPlugin() webapi.PluginEntry {
return webapi.PluginEntry{
ID: "bigquery",
SupportedTaskTypes: []core.TaskType{bigqueryQueryJobTask},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) {
cfg := GetConfig()

return NewPlugin(cfg, iCtx.MetricsScope())
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/webapi/databricks/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,12 +339,12 @@ func newDatabricksJobTaskPlugin() webapi.PluginEntry {
return webapi.PluginEntry{
ID: "databricks",
SupportedTaskTypes: []core.TaskType{"spark"},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) {
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
client: &http.Client{},
}, nil
}, nil, nil
},
}
}
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/webapi/snowflake/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,12 @@ func newSnowflakeJobTaskPlugin() webapi.PluginEntry {
return webapi.PluginEntry{
ID: "snowflake",
SupportedTaskTypes: []core.TaskType{"snowflake"},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) {
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
client: &http.Client{},
}, nil
}, nil, nil
},
}
}
Expand Down
Loading