From 5b9852b02d32b3d3e06b7754e7593f93db52845b Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 27 Sep 2023 21:42:23 +0800 Subject: [PATCH 1/2] sync plugin to be merged Signed-off-by: Future Outlier --- go.mod | 2 + .../pluginmachinery/internal/webapi/core.go | 20 ++++- go/tasks/pluginmachinery/webapi/plugin.go | 4 +- go/tasks/plugins/webapi/agent/plugin.go | 85 +++++++++++++++++-- go/tasks/plugins/webapi/athena/plugin.go | 8 +- go/tasks/plugins/webapi/bigquery/plugin.go | 8 +- go/tasks/plugins/webapi/databricks/plugin.go | 4 +- go/tasks/plugins/webapi/snowflake/plugin.go | 4 +- 8 files changed, 114 insertions(+), 21 deletions(-) diff --git a/go.mod b/go.mod index 9d1af74f5..484692ba0 100644 --- a/go.mod +++ b/go.mod @@ -136,3 +136,5 @@ require ( ) replace github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d + +replace github.com/flyteorg/flyteidl => /mnt/c/code/dev/flyteidl diff --git a/go/tasks/pluginmachinery/internal/webapi/core.go b/go/tasks/pluginmachinery/internal/webapi/core.go index 049fc431a..1c855f3b1 100644 --- a/go/tasks/pluginmachinery/internal/webapi/core.go +++ b/go/tasks/pluginmachinery/internal/webapi/core.go @@ -36,6 +36,7 @@ const ( type CorePlugin struct { id string p webapi.AsyncPlugin + sp webapi.SyncPlugin cache cache.AutoRefresh tokenAllocator tokenAllocator metrics Metrics @@ -68,12 +69,28 @@ func (c CorePlugin) GetProperties() core.PluginProperties { return core.PluginProperties{} } +// syncHandle +// TODO: ADD Sync Handle func (c CorePlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) { incomingState, err := c.unmarshalState(ctx, tCtx.PluginStateReader()) if err != nil { return core.UnknownTransition, err } + taskTemplate, err := tCtx.TaskReader().Read(ctx) + + if taskTemplate.Type == "requester" { + res, err := c.sp.Do(ctx, tCtx) + if err != nil { + return core.UnknownTransition, err + } + 
logger.Infof(ctx, "@@@ SyncPlugin [%v] returned result: %v", c.GetID(), res) + // if err := tCtx.PluginStateWriter().Put(pluginStateVersion, nextState); err != nil { + // return core.UnknownTransition, err + // } + return core.DoTransition(core.PhaseInfoSuccess(nil)), nil + } + var nextState *State var phaseInfo core.PhaseInfo switch incomingState.Phase { @@ -165,7 +182,7 @@ func createRemotePlugin(pluginEntry webapi.PluginEntry, c clock.Clock) core.Plug RegisteredTaskTypes: pluginEntry.SupportedTaskTypes, LoadPlugin: func(ctx context.Context, iCtx core.SetupContext) ( core.Plugin, error) { - p, err := pluginEntry.PluginLoader(ctx, iCtx) + p, sp, err := pluginEntry.PluginLoader(ctx, iCtx) if err != nil { return nil, err } @@ -205,6 +222,7 @@ func createRemotePlugin(pluginEntry webapi.PluginEntry, c clock.Clock) core.Plug return CorePlugin{ id: pluginEntry.ID, p: p, + sp: sp, cache: resourceCache, metrics: newMetrics(iCtx.MetricsScope()), tokenAllocator: newTokenAllocator(c), diff --git a/go/tasks/pluginmachinery/webapi/plugin.go b/go/tasks/pluginmachinery/webapi/plugin.go index 63b6b5e2b..622cc7d2e 100644 --- a/go/tasks/pluginmachinery/webapi/plugin.go +++ b/go/tasks/pluginmachinery/webapi/plugin.go @@ -22,7 +22,7 @@ import ( // A Lazy loading function, that will load the plugin. Plugins should be initialized in this method. It is guaranteed // that the plugin loader will be called before any Handle/Abort/Finalize functions are invoked -type PluginLoader func(ctx context.Context, iCtx PluginSetupContext) (AsyncPlugin, error) +type PluginLoader func(ctx context.Context, iCtx PluginSetupContext) (AsyncPlugin, SyncPlugin, error) // PluginEntry is a structure that is used to indicate to the system a K8s plugin type PluginEntry struct { @@ -150,5 +150,5 @@ type SyncPlugin interface { GetConfig() PluginConfig // Do performs the action associated with this plugin. 
- Do(ctx context.Context, tCtx TaskExecutionContext) (phase pluginsCore.PhaseInfo, err error) + Do(ctx context.Context, tCtx TaskExecutionContextReader) (latest Resource, err error) } diff --git a/go/tasks/plugins/webapi/agent/plugin.go b/go/tasks/plugins/webapi/agent/plugin.go index 9e663319e..2b68db72b 100644 --- a/go/tasks/plugins/webapi/agent/plugin.go +++ b/go/tasks/plugins/webapi/agent/plugin.go @@ -61,6 +61,73 @@ func (p Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionCo return "default", p.cfg.ResourceConstraints, nil } +func (p Plugin) Do(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (latest webapi.Resource, err error) { + // write the resource here + taskTemplate, err := taskCtx.TaskReader().Read(ctx) + if err != nil { + return nil, err + } + + inputs, err := taskCtx.InputReader().Get(ctx) + if err != nil { + return nil, err + } + + if taskTemplate.GetContainer() != nil { + templateParameters := template.Parameters{ + TaskExecMetadata: taskCtx.TaskExecutionMetadata(), + Inputs: taskCtx.InputReader(), + OutputPath: taskCtx.OutputWriter(), + Task: taskCtx.TaskReader(), + } + modifiedArgs, err := template.Render(ctx, taskTemplate.GetContainer().Args, templateParameters) + if err != nil { + return nil, err + } + taskTemplate.GetContainer().Args = modifiedArgs + } + + agent, err := getFinalAgent(taskTemplate.Type, p.cfg) + if err != nil { + return nil, fmt.Errorf("failed to find agent agent with error: %v", err) + } + + client, err := p.getClient(ctx, agent, p.connectionCache) + if err != nil { + return nil, fmt.Errorf("failed to connect to agent with error: %v", err) + } + + finalCtx, cancel := getFinalContext(ctx, "DoTask", agent) + + defer cancel() + + // taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) + // write it in agent? 
+ + logger.Infof(ctx, "@@@ inputs: [%v]", inputs) + logger.Infof(ctx, "@@@ taskTemplate: [%v]", taskTemplate) + + res, err := client.DoTask(finalCtx, &admin.DoTaskRequest{Inputs: inputs, Template: taskTemplate}) + if err != nil { + return nil, err + } + + logger.Infof(ctx, "@@@ res.Resource.State: [%v]", res.Resource.State) + logger.Infof(ctx, "@@@ res.Resource.Outputs: [%v]", res.Resource.Outputs) + + return &ResourceWrapper{ + State: res.Resource.State, + Outputs: res.Resource.Outputs, + }, nil +} + +// todo: write the output + +// we can get the task type in core.go +/* +taskTemplate, err := taskCtx.TaskReader().Read(ctx) +taskTemplate.type = spark, requester ... +*/ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta, webapi.Resource, error) { taskTemplate, err := taskCtx.TaskReader().Read(ctx) @@ -298,6 +365,7 @@ func getFinalContext(ctx context.Context, operation string, agent *Agent) (conte return context.WithTimeout(ctx, timeout) } +// TODO: Add sync agent plugin func newAgentPlugin(supportedTaskTypes SupportedTaskTypes) webapi.PluginEntry { if len(supportedTaskTypes) == 0 { supportedTaskTypes = SupportedTaskTypes{"default_supported_task_type"} @@ -306,13 +374,18 @@ func newAgentPlugin(supportedTaskTypes SupportedTaskTypes) webapi.PluginEntry { return webapi.PluginEntry{ ID: "agent-service", SupportedTaskTypes: supportedTaskTypes, - PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { + PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) { return &Plugin{ - metricScope: iCtx.MetricsScope(), - cfg: GetConfig(), - getClient: getClientFunc, - connectionCache: make(map[*Agent]*grpc.ClientConn), - }, nil + metricScope: iCtx.MetricsScope(), + cfg: GetConfig(), + getClient: getClientFunc, + connectionCache: make(map[*Agent]*grpc.ClientConn), + }, &Plugin{ + metricScope: iCtx.MetricsScope(), + 
cfg: GetConfig(), + getClient: getClientFunc, + connectionCache: make(map[*Agent]*grpc.ClientConn), + }, nil }, } } diff --git a/go/tasks/plugins/webapi/athena/plugin.go b/go/tasks/plugins/webapi/athena/plugin.go index 5bbf99bdd..ef6c8579a 100644 --- a/go/tasks/plugins/webapi/athena/plugin.go +++ b/go/tasks/plugins/webapi/athena/plugin.go @@ -200,10 +200,10 @@ func createTaskInfo(queryID string, cfg awsSdk.Config) *core.TaskInfo { } } -func NewPlugin(_ context.Context, cfg *Config, awsConfig *aws.Config, metricScope promutils.Scope) (Plugin, error) { +func NewPlugin(_ context.Context, cfg *Config, awsConfig *aws.Config, metricScope promutils.Scope) (Plugin, webapi.SyncPlugin, error) { sdkCfg, err := awsConfig.GetSdkConfig() if err != nil { - return Plugin{}, err + return Plugin{}, nil, err } return Plugin{ @@ -211,14 +211,14 @@ func NewPlugin(_ context.Context, cfg *Config, awsConfig *aws.Config, metricScop client: athena.NewFromConfig(sdkCfg), cfg: cfg, awsConfig: sdkCfg, - }, nil + }, nil, nil } func init() { pluginmachinery.PluginRegistry().RegisterRemotePlugin(webapi.PluginEntry{ ID: "athena", SupportedTaskTypes: []core.TaskType{"hive", "presto"}, - PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { + PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) { return NewPlugin(ctx, GetConfig(), aws.GetConfig(), iCtx.MetricsScope()) }, }) diff --git a/go/tasks/plugins/webapi/bigquery/plugin.go b/go/tasks/plugins/webapi/bigquery/plugin.go index 264957643..eaec69de2 100644 --- a/go/tasks/plugins/webapi/bigquery/plugin.go +++ b/go/tasks/plugins/webapi/bigquery/plugin.go @@ -547,25 +547,25 @@ func (p Plugin) newBigQueryClient(ctx context.Context, identity google.Identity) return bigquery.NewService(ctx, options...) 
} -func NewPlugin(cfg *Config, metricScope promutils.Scope) (*Plugin, error) { +func NewPlugin(cfg *Config, metricScope promutils.Scope) (*Plugin, webapi.SyncPlugin, error) { googleTokenSource, err := google.NewTokenSourceFactory(cfg.GoogleTokenSource) if err != nil { - return nil, pluginErrors.Wrapf(pluginErrors.PluginInitializationFailed, err, "failed to get google token source") + return nil, nil, pluginErrors.Wrapf(pluginErrors.PluginInitializationFailed, err, "failed to get google token source") } return &Plugin{ metricScope: metricScope, cfg: cfg, googleTokenSource: googleTokenSource, - }, nil + }, nil, nil } func newBigQueryJobTaskPlugin() webapi.PluginEntry { return webapi.PluginEntry{ ID: "bigquery", SupportedTaskTypes: []core.TaskType{bigqueryQueryJobTask}, - PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { + PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) { cfg := GetConfig() return NewPlugin(cfg, iCtx.MetricsScope()) diff --git a/go/tasks/plugins/webapi/databricks/plugin.go b/go/tasks/plugins/webapi/databricks/plugin.go index dca3581d5..c69a209a1 100644 --- a/go/tasks/plugins/webapi/databricks/plugin.go +++ b/go/tasks/plugins/webapi/databricks/plugin.go @@ -339,12 +339,12 @@ func newDatabricksJobTaskPlugin() webapi.PluginEntry { return webapi.PluginEntry{ ID: "databricks", SupportedTaskTypes: []core.TaskType{"spark"}, - PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { + PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) { return &Plugin{ metricScope: iCtx.MetricsScope(), cfg: GetConfig(), client: &http.Client{}, - }, nil + }, nil, nil }, } } diff --git a/go/tasks/plugins/webapi/snowflake/plugin.go b/go/tasks/plugins/webapi/snowflake/plugin.go index 09ff0156d..78079e4b1 100644 --- 
a/go/tasks/plugins/webapi/snowflake/plugin.go +++ b/go/tasks/plugins/webapi/snowflake/plugin.go @@ -275,12 +275,12 @@ func newSnowflakeJobTaskPlugin() webapi.PluginEntry { return webapi.PluginEntry{ ID: "snowflake", SupportedTaskTypes: []core.TaskType{"snowflake"}, - PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { + PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, webapi.SyncPlugin, error) { return &Plugin{ metricScope: iCtx.MetricsScope(), cfg: GetConfig(), client: &http.Client{}, - }, nil + }, nil, nil }, } } From f06c3953116171a0e0a18111846550d206cf4b0b Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 27 Sep 2023 22:17:35 +0800 Subject: [PATCH 2/2] dispatcher Signed-off-by: Future Outlier --- go/tasks/pluginmachinery/internal/webapi/core.go | 2 +- go/tasks/plugins/webapi/agent/plugin.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/go/tasks/pluginmachinery/internal/webapi/core.go b/go/tasks/pluginmachinery/internal/webapi/core.go index 1c855f3b1..ecf2b11f8 100644 --- a/go/tasks/pluginmachinery/internal/webapi/core.go +++ b/go/tasks/pluginmachinery/internal/webapi/core.go @@ -79,7 +79,7 @@ func (c CorePlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) taskTemplate, err := tCtx.TaskReader().Read(ctx) - if taskTemplate.Type == "requester" { + if taskTemplate.Type == "dispatcher" { res, err := c.sp.Do(ctx, tCtx) if err != nil { return core.UnknownTransition, err diff --git a/go/tasks/plugins/webapi/agent/plugin.go b/go/tasks/plugins/webapi/agent/plugin.go index 2b68db72b..45ec56ed1 100644 --- a/go/tasks/plugins/webapi/agent/plugin.go +++ b/go/tasks/plugins/webapi/agent/plugin.go @@ -126,7 +126,7 @@ func (p Plugin) Do(ctx context.Context, taskCtx webapi.TaskExecutionContextReade // we can get the task type in core.go /* taskTemplate, err := taskCtx.TaskReader().Read(ctx) -taskTemplate.type = spark, requester ... 
+taskTemplate.Type holds the task type, e.g. "spark" or "dispatcher". */ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta, webapi.Resource, error) {