Skip to content

Commit

Permalink
Add components for Actor ID uniqeness (#341)
Browse files Browse the repository at this point in the history
## Overview
This PR adds components to ensure each Actor Environment is unique. The attribute list is org, project, domain, a user provided ID, and an auto-generate version hash.

## Test Plan
This will be tested locally under a variety of scenarios. This change only affects which fasttasks are executed in each environment, so there is no need for testing on Union.

## Rollout Plan (if applicable)
Once validated this update may be  deployed to all tenants immediately.

## Upstream Changes
Should this change be upstreamed to OSS (flyteorg/flyte)? If not, please uncheck this box, which is used for auditing. Note, it is the responsibility of each developer to actually upstream their changes. See [this guide](https://unionai.atlassian.net/wiki/spaces/ENG/pages/447610883/Flyte+-+Union+Cloud+Development+Runbook/#When-are-versions-updated%3F).
- [ ] To be upstreamed to OSS

## Issue
https://linear.app/unionai/issue/EXO-110/decide-on-execution-environment-id-uniqueness-for-actor-environments

## Checklist
* [ ] Added tests
* [ ] Ran a deploy dry run and shared the terraform plan
* [ ] Added logging and metrics
* [ ] Updated [dashboards](https://unionai.grafana.net/dashboards) and [alerts](https://unionai.grafana.net/alerting/list)
* [ ] Updated documentation
  • Loading branch information
hamersaw committed Jul 2, 2024
1 parent 375cfc6 commit 964421b
Show file tree
Hide file tree
Showing 18 changed files with 302 additions and 131 deletions.
53 changes: 28 additions & 25 deletions fasttask/plugin/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ func newBuilderMetrics(scope promutils.Scope) builderMetrics {
// environment represents a managed fast task environment, including it's definition and current
// state
type environment struct {
lastAccessedAt time.Time
extant *_struct.Struct
lastAccessedAt time.Time
name string
replicas []string
spec *pb.FastTaskEnvironmentSpec
state state
Expand All @@ -83,8 +84,8 @@ type InMemoryEnvBuilder struct {

// Get retrieves the environment with the given execution environment ID. If the environment does
// not exist or has been tombstoned, nil is returned.
func (i *InMemoryEnvBuilder) Get(ctx context.Context, executionEnvID string) *_struct.Struct {
if environment := i.environments[executionEnvID]; environment != nil {
func (i *InMemoryEnvBuilder) Get(ctx context.Context, executionEnvID core.ExecutionEnvID) *_struct.Struct {
if environment := i.environments[executionEnvID.String()]; environment != nil {
i.lock.Lock()
defer i.lock.Unlock()

Expand All @@ -98,23 +99,23 @@ func (i *InMemoryEnvBuilder) Get(ctx context.Context, executionEnvID string) *_s

// Create creates a new fast task environment with the given execution environment ID and
// specification. If the environment already exists, the existing environment is returned.
func (i *InMemoryEnvBuilder) Create(ctx context.Context, executionEnvID string, spec *_struct.Struct) (*_struct.Struct, error) {
func (i *InMemoryEnvBuilder) Create(ctx context.Context, executionEnvID core.ExecutionEnvID, spec *_struct.Struct) (*_struct.Struct, error) {
// unmarshall and validate FastTaskEnvironmentSpec
fastTaskEnvironmentSpec := &pb.FastTaskEnvironmentSpec{}
if err := utils.UnmarshalStruct(spec, fastTaskEnvironmentSpec); err != nil {
return nil, err
}

if err := isValidEnvironmentSpec(fastTaskEnvironmentSpec); err != nil {
if err := isValidEnvironmentSpec(executionEnvID, fastTaskEnvironmentSpec); err != nil {
return nil, flyteerrors.Errorf(flyteerrors.BadTaskSpecification,
"detected invalid FastTaskEnvironmentSpec [%v], Err: [%v]", fastTaskEnvironmentSpec.GetPodTemplateSpec(), err)
"detected invalid EnvironmentSpec for environment '%s', Err: [%v]", executionEnvID, err)
}

logger.Debug(ctx, "creating environment '%s'", executionEnvID)
logger.Debugf(ctx, "creating environment '%s'", executionEnvID)

// build fastTaskEnvironment extant
fastTaskEnvironment := &pb.FastTaskEnvironment{
QueueId: executionEnvID,
QueueId: executionEnvID.String(),
}
environmentStruct := &_struct.Struct{}
if err := utils.MarshalStruct(fastTaskEnvironment, environmentStruct); err != nil {
Expand All @@ -124,7 +125,7 @@ func (i *InMemoryEnvBuilder) Create(ctx context.Context, executionEnvID string,
// create environment
i.lock.Lock()

env, exists := i.environments[executionEnvID]
env, exists := i.environments[executionEnvID.String()]
if exists && env.state != ORPHANED {
i.lock.Unlock()

Expand All @@ -141,8 +142,9 @@ func (i *InMemoryEnvBuilder) Create(ctx context.Context, executionEnvID string,
}

env = &environment{
lastAccessedAt: time.Now(),
extant: environmentStruct,
lastAccessedAt: time.Now(),
name: executionEnvID.Name,
replicas: replicas,
spec: fastTaskEnvironmentSpec,
state: HEALTHY,
Expand All @@ -155,20 +157,20 @@ func (i *InMemoryEnvBuilder) Create(ctx context.Context, executionEnvID string,
return nil, err
}

podName := fmt.Sprintf("%s-%s", executionEnvID, hex.EncodeToString(nonceBytes)[:GetConfig().NonceLength])
podName := fmt.Sprintf("%s-%s", env.name, hex.EncodeToString(nonceBytes)[:GetConfig().NonceLength])
env.replicas = append(env.replicas, podName)
podNames = append(podNames, podName)
}

i.environments[executionEnvID] = env
i.environments[executionEnvID.String()] = env
i.metrics.environmentsCreated.Inc()

i.lock.Unlock()

// create replicas
for _, podName := range podNames {
logger.Debugf(ctx, "creating pod '%s' for environment '%s'", podName, executionEnvID)
if err := i.createPod(ctx, fastTaskEnvironmentSpec, executionEnvID, podName); err != nil {
if err := i.createPod(ctx, fastTaskEnvironmentSpec, executionEnvID.String(), podName); err != nil {
logger.Warnf(ctx, "failed to create pod '%s' for environment '%s' [%v]", podName, executionEnvID, err)
}
}
Expand All @@ -179,12 +181,12 @@ func (i *InMemoryEnvBuilder) Create(ctx context.Context, executionEnvID string,

// Status returns the status of the environment with the given execution environment ID. This
// includes the details of each pod in the environment replica set.
func (i *InMemoryEnvBuilder) Status(ctx context.Context, executionEnvID string) (interface{}, error) {
func (i *InMemoryEnvBuilder) Status(ctx context.Context, executionEnvID core.ExecutionEnvID) (interface{}, error) {
i.lock.Lock()
defer i.lock.Unlock()

// check if environment exists
environment, exists := i.environments[executionEnvID]
environment, exists := i.environments[executionEnvID.String()]
if !exists {
return nil, nil
}
Expand Down Expand Up @@ -425,20 +427,20 @@ func (i *InMemoryEnvBuilder) repairEnvironments(ctx context.Context) error {
// identify environments in need of repair
i.lock.Lock()
pod := &v1.Pod{}
for environmentID, environment := range i.environments {
for environmentID, env := range i.environments {
// check if environment is repairable (ie. HEALTHY or REPAIRING state)
if environment.state != HEALTHY && environment.state != REPAIRING {
if env.state != HEALTHY && env.state != REPAIRING {
continue
}

podTemplateSpec := &v1.PodTemplateSpec{}
if err := json.Unmarshal(environment.spec.GetPodTemplateSpec(), podTemplateSpec); err != nil {
if err := json.Unmarshal(env.spec.GetPodTemplateSpec(), podTemplateSpec); err != nil {
return flyteerrors.Errorf(flyteerrors.BadTaskSpecification,
"unable to unmarshal PodTemplateSpec [%v], Err: [%v]", environment.spec.GetPodTemplateSpec(), err.Error())
"unable to unmarshal PodTemplateSpec [%v], Err: [%v]", env.spec.GetPodTemplateSpec(), err.Error())
}

podNames := make([]string, 0)
for index, podName := range environment.replicas {
for index, podName := range env.replicas {
err := i.kubeClient.GetCache().Get(ctx, types.NamespacedName{
Name: podName,
Namespace: podTemplateSpec.Namespace,
Expand All @@ -450,16 +452,16 @@ func (i *InMemoryEnvBuilder) repairEnvironments(ctx context.Context) error {
return err
}

newPodName := fmt.Sprintf("%s-%s", environmentID, hex.EncodeToString(nonceBytes)[:GetConfig().NonceLength])
environment.replicas[index] = newPodName
newPodName := fmt.Sprintf("%s-%s", env.name, hex.EncodeToString(nonceBytes)[:GetConfig().NonceLength])
env.replicas[index] = newPodName
podNames = append(podNames, newPodName)
}
}

if len(podNames) > 0 {
logger.Infof(ctx, "repairing environment '%s'", environmentID)
environment.state = REPAIRING
environmentSpecs[environmentID] = *environment.spec
env.state = REPAIRING
environmentSpecs[environmentID] = *env.spec
environmentReplicas[environmentID] = podNames
}
}
Expand Down Expand Up @@ -568,8 +570,9 @@ func (i *InMemoryEnvBuilder) detectOrphanedEnvironments(ctx context.Context, k8s

// create orphaned environment
orphanedEnvironment = &environment{
lastAccessedAt: time.Now(),
extant: nil,
lastAccessedAt: time.Now(),
name: "orphaned",
replicas: make([]string, 0),
spec: &pb.FastTaskEnvironmentSpec{
PodTemplateSpec: podTemplateSpecBytes,
Expand Down
31 changes: 23 additions & 8 deletions fasttask/plugin/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/cache"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
coremocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flyte/flytestdlib/promutils"
Expand Down Expand Up @@ -63,8 +64,15 @@ func (k *kubeCache) List(ctx context.Context, list client.ObjectList, opts ...cl
}

func TestCreate(t *testing.T) {
executionEnvID := core.ExecutionEnvID{
Project: "project",
Domain: "domain",
Name: "foo",
Version: "0",
}

fastTaskEnvironment := &pb.FastTaskEnvironment{
QueueId: "foo",
QueueId: executionEnvID.String(),
}

fastTaskEnvStruct := &_struct.Struct{}
Expand Down Expand Up @@ -112,7 +120,7 @@ func TestCreate(t *testing.T) {
name: "Exists",
environmentSpec: fastTaskEnvSpec,
environments: map[string]*environment{
"foo": &environment{
executionEnvID.String(): &environment{
extant: fastTaskEnvStruct,
state: HEALTHY,
},
Expand All @@ -124,7 +132,7 @@ func TestCreate(t *testing.T) {
name: "Orphaned",
environmentSpec: fastTaskEnvSpec,
environments: map[string]*environment{
"foo": &environment{
executionEnvID.String(): &environment{
extant: fastTaskEnvStruct,
replicas: []string{"bar"},
state: ORPHANED,
Expand Down Expand Up @@ -155,7 +163,7 @@ func TestCreate(t *testing.T) {
builder.environments = test.environments

// call `Create`
environment, err := builder.Create(ctx, "foo", fastTaskEnvSpecStruct)
environment, err := builder.Create(ctx, executionEnvID, fastTaskEnvSpecStruct)
assert.Nil(t, err)
assert.True(t, proto.Equal(test.expectedEnvironment, environment))
assert.Equal(t, test.expectedCreateCalls, kubeClient.createCalls)
Expand Down Expand Up @@ -336,8 +344,15 @@ func TestGCEnvironments(t *testing.T) {
}

func TestGet(t *testing.T) {
executionEnvID := core.ExecutionEnvID{
Project: "project",
Domain: "domain",
Name: "name",
Version: "version",
}

fastTaskEnvironment := &pb.FastTaskEnvironment{
QueueId: "foo",
QueueId: executionEnvID.String(),
}

fastTaskEnvStruct := &_struct.Struct{}
Expand All @@ -353,7 +368,7 @@ func TestGet(t *testing.T) {
{
name: "Exists",
environments: map[string]*environment{
"foo": &environment{
executionEnvID.String(): &environment{
extant: fastTaskEnvStruct,
state: HEALTHY,
},
Expand All @@ -368,7 +383,7 @@ func TestGet(t *testing.T) {
{
name: "Tombstoned",
environments: map[string]*environment{
"foo": &environment{
executionEnvID.String(): &environment{
state: TOMBSTONED,
},
},
Expand All @@ -392,7 +407,7 @@ func TestGet(t *testing.T) {
builder.environments = test.environments

// call `Get`
environment := builder.Get(ctx, "foo")
environment := builder.Get(ctx, executionEnvID)
assert.True(t, proto.Equal(test.expectedEnvironment, environment))
})
}
Expand Down
54 changes: 39 additions & 15 deletions fasttask/plugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,51 +86,73 @@ func (p *Plugin) GetProperties() core.PluginProperties {
return core.PluginProperties{}
}

// buildExecutionEnvID creates an `ExecutionEnvID` from a task ID and an `ExecutionEnv`. This
// collection of attributes is used to uniquely identify an execution environment.
func buildExecutionEnvID(taskID *idlcore.Identifier, executionEnv *idlcore.ExecutionEnv) core.ExecutionEnvID {
return core.ExecutionEnvID{
Org: taskID.GetOrg(),
Project: taskID.GetProject(),
Domain: taskID.GetDomain(),
Name: executionEnv.GetName(),
Version: executionEnv.GetVersion(),
}
}

// getExecutionEnv retrieves the execution environment for the task. If the environment does not
// exist, it will create it.
// this is here because we wanted uniformity within `TaskExecutionContext` where functions simply
// return an interface rather than doing any actual work. alternatively, we could bury this within
// `NodeExecutionContext` so other `ExecutionEnvironment` plugins do not need to duplicate this.
func (p *Plugin) getExecutionEnv(ctx context.Context, tCtx core.TaskExecutionContext) (*pb.FastTaskEnvironment, error) {
func (p *Plugin) getExecutionEnv(ctx context.Context, tCtx core.TaskExecutionContext) (*idlcore.ExecutionEnv, *pb.FastTaskEnvironment, error) {
taskTemplate, err := tCtx.TaskReader().Read(ctx)
if err != nil {
return nil, err
return nil, nil, err
}

executionEnv := &idlcore.ExecutionEnv{}
if err := utils.UnmarshalStruct(taskTemplate.GetCustom(), executionEnv); err != nil {
return nil, flyteerrors.Wrapf(flyteerrors.BadTaskSpecification, err, "failed to unmarshal environment")
return nil, nil, flyteerrors.Wrapf(flyteerrors.BadTaskSpecification, err, "failed to unmarshal environment")
}

switch e := executionEnv.GetEnvironment().(type) {
case *idlcore.ExecutionEnv_Spec:
executionEnvClient := tCtx.GetExecutionEnvClient()
taskExecutionID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID()
executionEnvID := buildExecutionEnvID(taskExecutionID.GetTaskId(), executionEnv)

// if environment already exists then return it
if environment := executionEnvClient.Get(ctx, executionEnv.GetId()); environment != nil {
if environment := executionEnvClient.Get(ctx, executionEnvID); environment != nil {
fastTaskEnvironment := &pb.FastTaskEnvironment{}
if err := utils.UnmarshalStruct(environment, fastTaskEnvironment); err != nil {
return nil, flyteerrors.Wrapf(flyteerrors.BadTaskSpecification, err, "failed to unmarshal environment client")
return nil, nil, flyteerrors.Wrapf(flyteerrors.BadTaskSpecification, err, "failed to unmarshal environment client")
}

return fastTaskEnvironment, nil
return executionEnv, fastTaskEnvironment, nil
}

// otherwise create the environment
return p.createExecutionEnv(ctx, tCtx, executionEnv.GetId(), e)
fastTaskEnvironment, err := p.createExecutionEnv(ctx, tCtx, executionEnvID, e)
if err != nil {
return nil, nil, err
}

return executionEnv, fastTaskEnvironment, nil
case *idlcore.ExecutionEnv_Extant:
fastTaskEnvironment := &pb.FastTaskEnvironment{}
if err := utils.UnmarshalStruct(e.Extant, fastTaskEnvironment); err != nil {
return nil, flyteerrors.Wrapf(flyteerrors.BadTaskSpecification, err, "failed to unmarshal environment extant")
return nil, nil, flyteerrors.Wrapf(flyteerrors.BadTaskSpecification, err, "failed to unmarshal environment extant")
}

return fastTaskEnvironment, nil
return executionEnv, fastTaskEnvironment, nil
}

return nil, nil
return nil, nil, nil
}

func (p *Plugin) createExecutionEnv(ctx context.Context, tCtx core.TaskExecutionContext, envID string, envSpec *idlcore.ExecutionEnv_Spec) (*pb.FastTaskEnvironment, error) {
// createExecutionEnv creates a new execution environment based on the specified parameters.
func (p *Plugin) createExecutionEnv(ctx context.Context, tCtx core.TaskExecutionContext,
executionEnvID core.ExecutionEnvID, envSpec *idlcore.ExecutionEnv_Spec) (*pb.FastTaskEnvironment, error) {

environmentSpec := envSpec.Spec

fastTaskEnvironmentSpec := &pb.FastTaskEnvironmentSpec{}
Expand Down Expand Up @@ -171,7 +193,7 @@ func (p *Plugin) createExecutionEnv(ctx context.Context, tCtx core.TaskExecution
}

executionEnvClient := tCtx.GetExecutionEnvClient()
environment, err := executionEnvClient.Create(ctx, envID, environmentSpec)
environment, err := executionEnvClient.Create(ctx, executionEnvID, environmentSpec)
if err != nil {
return nil, flyteerrors.Wrapf(flyteerrors.BadTaskSpecification, err, "failed to create environment")
}
Expand Down Expand Up @@ -220,7 +242,7 @@ func (p *Plugin) addObjectMetadata(ctx context.Context, tCtx core.TaskExecutionC
// Handle is the main entrypoint for the plugin. It will offer the task to the worker pool and
// monitor the task until completion.
func (p *Plugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
fastTaskEnvironment, err := p.getExecutionEnv(ctx, tCtx)
executionEnv, fastTaskEnvironment, err := p.getExecutionEnv(ctx, tCtx)
if err != nil {
return core.UnknownTransition, err
}
Expand Down Expand Up @@ -282,7 +304,9 @@ func (p *Plugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (co
}

// fail if all replicas for this environment are in a failed state
statuses, err := tCtx.GetExecutionEnvClient().Status(ctx, queueID)
taskExecutionID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID()
executionEnvID := buildExecutionEnvID(taskExecutionID.GetTaskId(), executionEnv)
statuses, err := tCtx.GetExecutionEnvClient().Status(ctx, executionEnvID)
if err != nil {
return core.UnknownTransition, err
}
Expand Down Expand Up @@ -387,7 +411,7 @@ func (p *Plugin) Abort(ctx context.Context, tCtx core.TaskExecutionContext) erro

// Finalize is called when the task execution is complete, performing any necessary cleanup.
func (p *Plugin) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error {
fastTaskEnvironment, err := p.getExecutionEnv(ctx, tCtx)
_, fastTaskEnvironment, err := p.getExecutionEnv(ctx, tCtx)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 964421b

Please sign in to comment.