Skip to content

Commit

Permalink
Fix node execution for tasks in new events (flyteorg#206)
Browse files Browse the repository at this point in the history
Fix uniqueness for task execution events.
  • Loading branch information
anandswaminathan authored Nov 30, 2020
1 parent 9585314 commit 318ee7e
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflowTemplate(ctx context.Con
nodeID := node.Id
var subNodeStatus v1alpha1.ExecutableNodeStatus
if nCtx.ExecutionContext().GetEventVersion() == v1alpha1.EventVersion0 {
newID, err := hierarchicalNodeID(parentNodeID, currentAttemptStr, node.Id)
newID, err := hierarchicalNodeID(parentNodeID, currentAttemptStr, nodeID)
if err != nil {
return nil, err
}
Expand All @@ -69,7 +69,6 @@ func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflowTemplate(ctx context.Con
if err != nil {
return nil, err
}

subNodeStatus.SetDataDir(originalNodePath)
subNodeStatus.SetOutputDir(outputDir)
}
Expand Down
8 changes: 4 additions & 4 deletions flytepropeller/pkg/controller/nodes/task/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ func (p *pluginRequestedTransition) TransitionPreviouslyRecorded() {
}

func (p *pluginRequestedTransition) FinalTaskEvent(id *core.TaskExecutionIdentifier, in io.InputFilePaths, out io.OutputFilePaths,
nodeExecutionMetadata handler.NodeExecutionMetadata) (*event.TaskExecutionEvent, error) {
nodeExecutionMetadata handler.NodeExecutionMetadata, execContext executors.ExecutionContext) (*event.TaskExecutionEvent, error) {
if p.previouslyObserved {
return nil, nil
}

return ToTaskExecutionEvent(id, in, out, p.pInfo, nodeExecutionMetadata)
return ToTaskExecutionEvent(id, in, out, p.pInfo, nodeExecutionMetadata, execContext)
}

func (p *pluginRequestedTransition) ObserveSuccess(outputPath storage.DataReference, taskMetadata *event.TaskNodeMetadata) {
Expand Down Expand Up @@ -579,7 +579,7 @@ func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext)
// STEP 4: Send buffered events!
logger.Debugf(ctx, "Sending buffered Task events.")
for _, ev := range tCtx.ber.GetAll(ctx) {
evInfo, err := ToTaskExecutionEvent(&execID, nCtx.InputReader(), tCtx.ow, ev, nCtx.NodeExecutionMetadata())
evInfo, err := ToTaskExecutionEvent(&execID, nCtx.InputReader(), tCtx.ow, ev, nCtx.NodeExecutionMetadata(), nCtx.ExecutionContext())
if err != nil {
return handler.UnknownTransition, err
}
Expand All @@ -593,7 +593,7 @@ func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext)

// STEP 5: Send Transition events
logger.Debugf(ctx, "Sending transition event for plugin phase [%s]", pluginTrns.pInfo.Phase().String())
evInfo, err := pluginTrns.FinalTaskEvent(&execID, nCtx.InputReader(), tCtx.ow, nCtx.NodeExecutionMetadata())
evInfo, err := pluginTrns.FinalTaskEvent(&execID, nCtx.InputReader(), tCtx.ow, nCtx.NodeExecutionMetadata(), nCtx.ExecutionContext())
if err != nil {
logger.Errorf(ctx, "failed to convert plugin transition to TaskExecutionEvent. Error: %s", err.Error())
return handler.UnknownTransition, err
Expand Down
8 changes: 8 additions & 0 deletions flytepropeller/pkg/controller/nodes/task/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,8 @@ func Test_task_Handle_NoCatalog(t *testing.T) {

executionContext := &mocks.ExecutionContext{}
executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{})
executionContext.OnGetEventVersion().Return(v1alpha1.EventVersion0)
executionContext.OnGetParentInfo().Return(nil)
nCtx.OnExecutionContext().Return(executionContext)

st := bytes.NewBuffer([]byte{})
Expand Down Expand Up @@ -768,6 +770,8 @@ func Test_task_Handle_Catalog(t *testing.T) {

executionContext := &mocks.ExecutionContext{}
executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{})
executionContext.OnGetEventVersion().Return(v1alpha1.EventVersion0)
executionContext.OnGetParentInfo().Return(nil)
nCtx.OnExecutionContext().Return(executionContext)

nCtx.OnRawOutputPrefix().Return("s3://sandbox/")
Expand Down Expand Up @@ -992,6 +996,8 @@ func Test_task_Handle_Barrier(t *testing.T) {

executionContext := &mocks.ExecutionContext{}
executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{})
executionContext.OnGetEventVersion().Return(v1alpha1.EventVersion0)
executionContext.OnGetParentInfo().Return(nil)
nCtx.OnExecutionContext().Return(executionContext)

nCtx.OnRawOutputPrefix().Return("s3://sandbox/")
Expand Down Expand Up @@ -1264,6 +1270,7 @@ func Test_task_Abort(t *testing.T) {

executionContext := &mocks.ExecutionContext{}
executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{})
executionContext.OnGetParentInfo().Return(nil)
nCtx.OnExecutionContext().Return(executionContext)

nCtx.OnRawOutputPrefix().Return("s3://sandbox/")
Expand Down Expand Up @@ -1404,6 +1411,7 @@ func Test_task_Finalize(t *testing.T) {

executionContext := &mocks.ExecutionContext{}
executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{})
executionContext.OnGetParentInfo().Return(nil)
nCtx.OnExecutionContext().Return(executionContext)

nCtx.OnRawOutputPrefix().Return("s3://sandbox/")
Expand Down
8 changes: 7 additions & 1 deletion flytepropeller/pkg/controller/nodes/task/taskexec_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"context"
"strconv"

"github.com/lyft/flytepropeller/pkg/controller/nodes/common"

"github.com/lyft/flytepropeller/pkg/controller/nodes/task/resourcemanager"

"github.com/lyft/flytestdlib/logger"
Expand Down Expand Up @@ -124,7 +126,11 @@ func (t *Handler) newTaskExecutionContext(ctx context.Context, nCtx handler.Node

id := GetTaskExecutionIdentifier(nCtx)

uniqueID, err := utils.FixedLengthUniqueIDForParts(IDMaxLength, nCtx.NodeExecutionMetadata().GetOwnerID().Name, nCtx.NodeID(), strconv.Itoa(int(id.RetryAttempt)))
currentNodeUniqueID, err := common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeID())
if err != nil {
return nil, err
}
uniqueID, err := utils.FixedLengthUniqueIDForParts(IDMaxLength, nCtx.NodeExecutionMetadata().GetOwnerID().Name, currentNodeUniqueID, strconv.Itoa(int(id.RetryAttempt)))
if err != nil {
// SHOULD never really happen
return nil, err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"context"
"testing"

mocks2 "github.com/lyft/flytepropeller/pkg/controller/executors/mocks"

"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/catalog/mocks"
ioMocks "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io/mocks"
Expand Down Expand Up @@ -75,6 +77,11 @@ func TestHandler_newTaskExecutionContext(t *testing.T) {
nCtx.OnEventsRecorder().Return(nil)
nCtx.OnEnqueueOwnerFunc().Return(nil)

executionContext := &mocks2.ExecutionContext{}
executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{})
executionContext.OnGetParentInfo().Return(nil)
nCtx.OnExecutionContext().Return(executionContext)

ds, err := storage.NewDataStore(
&storage.Config{
Type: storage.TypeMemory,
Expand Down
19 changes: 17 additions & 2 deletions flytepropeller/pkg/controller/nodes/task/transformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import (
"github.com/lyft/flyteidl/gen/pb-go/flyteidl/event"
pluginCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
"github.com/lyft/flytepropeller/pkg/controller/executors"
"github.com/lyft/flytepropeller/pkg/controller/nodes/common"

"github.com/lyft/flytepropeller/pkg/controller/nodes/handler"
)
Expand Down Expand Up @@ -51,7 +54,7 @@ func trimErrorMessage(original string, maxLength int) string {
}

func ToTaskExecutionEvent(taskExecID *core.TaskExecutionIdentifier, in io.InputFilePaths, out io.OutputFilePaths, info pluginCore.PhaseInfo,
nodeExecutionMetadata handler.NodeExecutionMetadata) (*event.TaskExecutionEvent, error) {
nodeExecutionMetadata handler.NodeExecutionMetadata, execContext executors.ExecutionContext) (*event.TaskExecutionEvent, error) {
// Transitions to a new phase

tm := ptypes.TimestampNow()
Expand All @@ -63,9 +66,21 @@ func ToTaskExecutionEvent(taskExecID *core.TaskExecutionIdentifier, in io.InputF
}
}

nodeExecutionID := &core.NodeExecutionIdentifier{
ExecutionId: taskExecID.NodeExecutionId.ExecutionId,
}
if execContext.GetEventVersion() != v1alpha1.EventVersion0 {
currentNodeUniqueID, err := common.GenerateUniqueID(execContext.GetParentInfo(), taskExecID.NodeExecutionId.NodeId)
if err != nil {
return nil, err
}
nodeExecutionID.NodeId = currentNodeUniqueID
} else {
nodeExecutionID.NodeId = taskExecID.NodeExecutionId.NodeId
}
tev := &event.TaskExecutionEvent{
TaskId: taskExecID.TaskId,
ParentNodeExecutionId: taskExecID.NodeExecutionId,
ParentNodeExecutionId: nodeExecutionID,
RetryAttempt: taskExecID.RetryAttempt,
Phase: ToTaskEventPhase(info.Phase()),
PhaseVersion: info.Version(),
Expand Down
82 changes: 79 additions & 3 deletions flytepropeller/pkg/controller/nodes/task/transformer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"testing"
"time"

"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"
mocks2 "github.com/lyft/flytepropeller/pkg/controller/executors/mocks"

"github.com/lyft/flyteidl/gen/pb-go/flyteidl/event"

"github.com/golang/protobuf/ptypes"
Expand Down Expand Up @@ -73,8 +76,13 @@ func TestToTaskExecutionEvent(t *testing.T) {

nodeExecutionMetadata := handlerMocks.NodeExecutionMetadata{}
nodeExecutionMetadata.OnIsInterruptible().Return(true)

mockExecContext := &mocks2.ExecutionContext{}
mockExecContext.OnGetEventVersion().Return(v1alpha1.EventVersion0)
mockExecContext.OnGetParentInfo().Return(nil)

tev, err := ToTaskExecutionEvent(id, in, out, pluginCore.PhaseInfoWaitingForResources(n, 0, "reason"),
&nodeExecutionMetadata)
&nodeExecutionMetadata, mockExecContext)
assert.NoError(t, err)
assert.Nil(t, tev.Logs)
assert.Equal(t, core.TaskExecution_WAITING_FOR_RESOURCES, tev.Phase)
Expand All @@ -94,7 +102,7 @@ func TestToTaskExecutionEvent(t *testing.T) {
OccurredAt: &n,
Logs: l,
CustomInfo: c,
}), &nodeExecutionMetadata)
}), &nodeExecutionMetadata, mockExecContext)
assert.NoError(t, err)
assert.Equal(t, core.TaskExecution_RUNNING, tev.Phase)
assert.Equal(t, uint32(1), tev.PhaseVersion)
Expand All @@ -113,7 +121,7 @@ func TestToTaskExecutionEvent(t *testing.T) {
OccurredAt: &n,
Logs: l,
CustomInfo: c,
}), &defaultNodeExecutionMetadata)
}), &defaultNodeExecutionMetadata, mockExecContext)
assert.NoError(t, err)
assert.Equal(t, core.TaskExecution_SUCCEEDED, tev.Phase)
assert.Equal(t, uint32(0), tev.PhaseVersion)
Expand All @@ -133,3 +141,71 @@ func TestToTransitionType(t *testing.T) {
assert.Equal(t, handler.TransitionTypeEphemeral, ToTransitionType(pluginCore.TransitionTypeEphemeral))
assert.Equal(t, handler.TransitionTypeBarrier, ToTransitionType(pluginCore.TransitionTypeBarrier))
}

// TestToTaskExecutionEventWithParent exercises ToTaskExecutionEvent with an
// execution context that reports EventVersion1 and carries parent info. It
// checks that the emitted event's ParentNodeExecutionId holds the generated
// unique node ID rather than the raw NodeId of the task execution identifier,
// for both a WAITING_FOR_RESOURCES and a RUNNING phase.
func TestToTaskExecutionEventWithParent(t *testing.T) {
	tkID := &core.Identifier{}
	nodeID := &core.NodeExecutionIdentifier{
		NodeId: "n1234567812345678123344568",
	}
	id := &core.TaskExecutionIdentifier{
		TaskId:          tkID,
		NodeExecutionId: nodeID,
	}

	n := time.Now()
	// Check the conversion error instead of discarding it: a nil timestamp
	// would otherwise only show up as a confusing OccurredAt mismatch below.
	np, err := ptypes.TimestampProto(n)
	assert.NoError(t, err)

	in := &mocks.InputFilePaths{}
	const inputPath = "in"
	in.On("GetInputPath").Return(storage.DataReference(inputPath))

	out := &mocks.OutputFilePaths{}
	const outputPath = "out"
	out.On("GetOutputPath").Return(storage.DataReference(outputPath))

	nodeExecutionMetadata := handlerMocks.NodeExecutionMetadata{}
	nodeExecutionMetadata.OnIsInterruptible().Return(true)

	// A non-V0 event version together with parent info.
	mockExecContext := &mocks2.ExecutionContext{}
	mockExecContext.OnGetEventVersion().Return(v1alpha1.EventVersion1)
	mockParentInfo := &mocks2.ImmutableParentInfo{}
	mockParentInfo.OnGetUniqueID().Return("np1")
	mockParentInfo.OnCurrentAttempt().Return(uint32(2))
	mockExecContext.OnGetParentInfo().Return(mockParentInfo)

	// Fixed fixture: the unique ID expected for the node ID and parent info
	// configured above (presumably derived from the parent unique ID, the
	// attempt and the node ID; confirm against common.GenerateUniqueID).
	// ExecutionId stays unset because the input identifier carries none.
	expectedNodeID := &core.NodeExecutionIdentifier{
		NodeId: "fmxzd5ta",
	}

	// assertCommon holds the expectations shared by every phase checked below.
	assertCommon := func(tev *event.TaskExecutionEvent) {
		assert.Equal(t, np, tev.OccurredAt)
		assert.Equal(t, tkID, tev.TaskId)
		assert.Equal(t, expectedNodeID, tev.ParentNodeExecutionId)
		assert.Equal(t, inputPath, tev.InputUri)
		assert.Nil(t, tev.OutputResult)
		assert.Equal(t, event.TaskExecutionMetadata_INTERRUPTIBLE, tev.Metadata.InstanceClass)
	}

	// Scenario 1: waiting for resources; no logs are attached to the event.
	tev, err := ToTaskExecutionEvent(id, in, out, pluginCore.PhaseInfoWaitingForResources(n, 0, "reason"),
		&nodeExecutionMetadata, mockExecContext)
	assert.NoError(t, err)
	assert.Nil(t, tev.Logs)
	assert.Equal(t, core.TaskExecution_WAITING_FOR_RESOURCES, tev.Phase)
	assert.Equal(t, uint32(0), tev.PhaseVersion)
	assertCommon(tev)

	// Scenario 2: running, with logs and custom info carried by the TaskInfo.
	l := []*core.TaskLog{
		{Uri: "x", Name: "y", MessageFormat: core.TaskLog_JSON},
	}
	c := &structpb.Struct{}
	tev, err = ToTaskExecutionEvent(id, in, out, pluginCore.PhaseInfoRunning(1, &pluginCore.TaskInfo{
		OccurredAt: &n,
		Logs:       l,
		CustomInfo: c,
	}), &nodeExecutionMetadata, mockExecContext)
	assert.NoError(t, err)
	assert.Equal(t, core.TaskExecution_RUNNING, tev.Phase)
	assert.Equal(t, uint32(1), tev.PhaseVersion)
	assert.Equal(t, l, tev.Logs)
	assert.Equal(t, c, tev.CustomInfo)
	assertCommon(tev)
}

0 comments on commit 318ee7e

Please sign in to comment.