Skip to content

Commit

Permalink
Pod tasks can be used for map tasks (flyteorg#182)
Browse files Browse the repository at this point in the history
  • Loading branch information
Katrina Rogan committed Jun 11, 2021
1 parent 632c508 commit 7c143b2
Show file tree
Hide file tree
Showing 9 changed files with 512 additions and 58 deletions.
22 changes: 22 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,28 @@ func DemystifySuccess(status v1.PodStatus, info pluginsCore.TaskInfo) (pluginsCo
return pluginsCore.PhaseInfoSuccess(&info), nil
}

// DeterminePrimaryContainerPhase derives a phase from the status of the container named
// primaryContainerName; the statuses of all other containers are ignored.
//
// The mapping is:
//   - Waiting or Running: a running phase at DefaultPhaseVersion.
//   - Terminated with a non-zero exit code: a retryable failure carrying the termination
//     reason as the error code and the termination message as the error message.
//   - Terminated with exit code 0: success.
//   - No state member set: treated the same as Waiting, because the Kubernetes API documents
//     Waiting as the default when none of the ContainerState members is specified.
//
// If no entry in statuses carries the primary container's name, a "PrimaryContainerMissing"
// failure is returned.
func DeterminePrimaryContainerPhase(primaryContainerName string, statuses []v1.ContainerStatus, info *pluginsCore.TaskInfo) pluginsCore.PhaseInfo {
	for _, s := range statuses {
		if s.Name != primaryContainerName {
			continue
		}

		if s.State.Waiting != nil || s.State.Running != nil {
			return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info)
		}

		if terminated := s.State.Terminated; terminated != nil {
			if terminated.ExitCode != 0 {
				return pluginsCore.PhaseInfoRetryableFailure(terminated.Reason, terminated.Message, info)
			}
			return pluginsCore.PhaseInfoSuccess(info)
		}

		// The container was found but reports no explicit state. Previously this fell through
		// and was misreported as a missing primary container; an unspecified state means the
		// container is still waiting, so report it as running.
		return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info)
	}

	// If for some reason we can't find the primary container, always just return a permanent failure
	return pluginsCore.PhaseInfoFailure("PrimaryContainerMissing",
		fmt.Sprintf("Primary container [%s] not found in pod's container statuses", primaryContainerName), info)
}

func ConvertPodFailureToError(status v1.PodStatus) (code, message string) {
code = "UnknownError"
message = "Pod failed. No message received from kubernetes."
Expand Down
76 changes: 76 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -613,3 +613,79 @@ func TestDemystifyPending_testcases(t *testing.T) {
})
}
}

func TestDeterminePrimaryContainerPhase(t *testing.T) {
primaryContainerName := "primary"
secondaryContainer := v1.ContainerStatus{
Name: "secondary",
State: v1.ContainerState{
Terminated: &v1.ContainerStateTerminated{
ExitCode: 0,
},
},
}
var info = &pluginsCore.TaskInfo{}
t.Run("primary container waiting", func(t *testing.T) {
phaseInfo := DeterminePrimaryContainerPhase(primaryContainerName, []v1.ContainerStatus{
secondaryContainer, {
Name: primaryContainerName,
State: v1.ContainerState{
Waiting: &v1.ContainerStateWaiting{
Reason: "just dawdling",
},
},
},
}, info)
assert.Equal(t, pluginsCore.PhaseRunning, phaseInfo.Phase())
})
t.Run("primary container running", func(t *testing.T) {
phaseInfo := DeterminePrimaryContainerPhase(primaryContainerName, []v1.ContainerStatus{
secondaryContainer, {
Name: primaryContainerName,
State: v1.ContainerState{
Running: &v1.ContainerStateRunning{
StartedAt: metaV1.Now(),
},
},
},
}, info)
assert.Equal(t, pluginsCore.PhaseRunning, phaseInfo.Phase())
})
t.Run("primary container failed", func(t *testing.T) {
phaseInfo := DeterminePrimaryContainerPhase(primaryContainerName, []v1.ContainerStatus{
secondaryContainer, {
Name: primaryContainerName,
State: v1.ContainerState{
Terminated: &v1.ContainerStateTerminated{
ExitCode: 1,
Reason: "foo",
Message: "foo failed",
},
},
},
}, info)
assert.Equal(t, pluginsCore.PhaseRetryableFailure, phaseInfo.Phase())
assert.Equal(t, "foo", phaseInfo.Err().Code)
assert.Equal(t, "foo failed", phaseInfo.Err().Message)
})
t.Run("primary container succeeded", func(t *testing.T) {
phaseInfo := DeterminePrimaryContainerPhase(primaryContainerName, []v1.ContainerStatus{
secondaryContainer, {
Name: primaryContainerName,
State: v1.ContainerState{
Terminated: &v1.ContainerStateTerminated{
ExitCode: 0,
},
},
},
}, info)
assert.Equal(t, pluginsCore.PhaseSuccess, phaseInfo.Phase())
})
t.Run("missing primary container", func(t *testing.T) {
phaseInfo := DeterminePrimaryContainerPhase(primaryContainerName, []v1.ContainerStatus{
secondaryContainer,
}, info)
assert.Equal(t, pluginsCore.PhasePermanentFailure, phaseInfo.Phase())
assert.Equal(t, "Primary container [primary] not found in pod's container statuses", phaseInfo.Err().Message)
})
}
2 changes: 1 addition & 1 deletion flyteplugins/go/tasks/plugins/array/core/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus
// No chance to reach the required success numbers.
if totalRunning+totalSuccesses+totalWaitingForResources < minSuccesses {
logger.Infof(ctx, "Array failed early because total failures > minSuccesses[%v]. Snapshot totalRunning[%v] + totalSuccesses[%v] + totalWaitingForResource[%v]",
totalRunning, totalSuccesses, totalWaitingForResources, minSuccesses)
minSuccesses, totalRunning, totalSuccesses, totalWaitingForResources)
return PhaseWriteToDiscoveryThenFail
}

Expand Down
10 changes: 10 additions & 0 deletions flyteplugins/go/tasks/plugins/array/k8s/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,16 @@ func FetchPodStatusAndLogs(ctx context.Context, client core.KubeClient, name k8s
case v1.PodUnknown:
phaseInfo = core.PhaseInfoUndefined
default:
primaryContainerName, ok := pod.GetAnnotations()[primaryContainerKey]
if ok {
// Special handling for determining the phase of an array job for a Pod task.
phaseInfo = flytek8s.DeterminePrimaryContainerPhase(primaryContainerName, pod.Status.ContainerStatuses, &taskInfo)
if phaseInfo.Phase() == core.PhaseRunning && len(taskInfo.Logs) > 0 {
return core.PhaseInfoRunning(core.DefaultPhaseVersion+1, phaseInfo.Info()), nil
}
return phaseInfo, nil
}

if len(taskInfo.Logs) > 0 {
phaseInfo = core.PhaseInfoRunning(core.DefaultPhaseVersion+1, &taskInfo)
} else {
Expand Down
49 changes: 28 additions & 21 deletions flyteplugins/go/tasks/plugins/array/k8s/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ import (
idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus"
Expand All @@ -22,6 +20,7 @@ import (
"github.com/flyteorg/flytestdlib/logger"
"github.com/flyteorg/flytestdlib/storage"
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/api/core/v1"
k8serrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
k8sTypes "k8s.io/apimachinery/pkg/types"
Expand Down Expand Up @@ -51,6 +50,26 @@ const (
MonitorError
)

// getTaskContainerIndex returns the index within pod.Spec.Containers of the container that
// array-job settings should be applied to.
//
// Without a primaryContainerKey annotation the pod must hold exactly one container, and index 0
// is returned. With the annotation, the index of the container whose name equals the annotation
// value is returned. In every other case the result is -1 and an ErrBuildPodTemplate error.
func getTaskContainerIndex(pod *v1.Pod) (int, error) {
	primaryName, annotated := pod.Annotations[primaryContainerKey]
	if annotated {
		containers := pod.Spec.Containers
		for i := range containers {
			if containers[i].Name == primaryName {
				return i, nil
			}
		}
		return -1, errors2.Errorf(ErrBuildPodTemplate, "Couldn't find any container matching the primary container key when building an array job with a K8sPod spec target")
	}

	// Tasks with a K8sPod target may produce multiple containers, but then one of them must be
	// designated as primary through the annotation.
	if len(pod.Spec.Containers) != 1 {
		return -1, errors2.Errorf(ErrBuildPodTemplate, "Expected a specified primary container key when building an array job with a K8sPod spec target")
	}

	// Tasks with a Container target only ever build a single container as part of the pod.
	return 0, nil
}

func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient) (LaunchResult, error) {
podTemplate, _, err := FlyteArrayJobToK8sPodTemplate(ctx, tCtx, t.Config.NamespaceTemplate)
if err != nil {
Expand All @@ -60,43 +79,31 @@ func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeCl
if t.Config.RemoteClusterConfig.Enabled {
podTemplate.OwnerReferences = nil
}
var args []string
if len(podTemplate.Spec.Containers) > 0 {
args = append(podTemplate.Spec.Containers[0].Command, podTemplate.Spec.Containers[0].Args...)
podTemplate.Spec.Containers[0].Command = []string{}
} else {
if len(podTemplate.Spec.Containers) == 0 {
return LaunchError, errors2.Wrapf(ErrReplaceCmdTemplate, err, "No containers found in podSpec.")
}
containerIndex, err := getTaskContainerIndex(&podTemplate)
if err != nil {
return LaunchError, err
}

indexStr := strconv.Itoa(t.ChildIdx)
podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr)

pod := podTemplate.DeepCopy()
pod.Name = podName
pod.Spec.Containers[0].Env = append(pod.Spec.Containers[0].Env, corev1.EnvVar{
pod.Spec.Containers[containerIndex].Env = append(pod.Spec.Containers[containerIndex].Env, corev1.EnvVar{
Name: FlyteK8sArrayIndexVarName,
Value: indexStr,
})

pod.Spec.Containers[0].Env = append(pod.Spec.Containers[0].Env, arrayJobEnvVars...)
pod.Spec.Containers[containerIndex].Env = append(pod.Spec.Containers[containerIndex].Env, arrayJobEnvVars...)
taskTemplate, err := tCtx.TaskReader().Read(ctx)
if err != nil {
return LaunchError, errors2.Wrapf(ErrGetTaskTypeVersion, err, "Unable to read task template")
} else if taskTemplate == nil {
return LaunchError, errors2.Wrapf(ErrGetTaskTypeVersion, err, "Missing task template")
}
inputReader := array.GetInputReader(tCtx, taskTemplate)
pod.Spec.Containers[0].Args, err = template.Render(ctx, args,
template.Parameters{
TaskExecMetadata: tCtx.TaskExecutionMetadata(),
Inputs: inputReader,
OutputPath: tCtx.OutputWriter(),
Task: tCtx.TaskReader(),
})
if err != nil {
return LaunchError, errors2.Wrapf(ErrReplaceCmdTemplate, err, "Failed to replace cmd args")
}

pod = ApplyPodPolicies(ctx, t.Config, pod)
pod = applyNodeSelectorLabels(ctx, t.Config, pod)
pod = applyPodTolerations(ctx, t.Config, pod)
Expand Down
85 changes: 85 additions & 0 deletions flyteplugins/go/tasks/plugins/array/k8s/task_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package k8s

import (
"testing"

"github.com/stretchr/testify/assert"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

// TestGetTaskContainerIndex covers the single-container case, the multi-container case with and
// without a primary container annotation, and an annotation naming a container that is absent.
func TestGetTaskContainerIndex(t *testing.T) {
	// newPod builds a pod with one container per name; annotations may be nil.
	newPod := func(annotations map[string]string, names ...string) *v1.Pod {
		var containers []v1.Container
		for _, name := range names {
			containers = append(containers, v1.Container{Name: name})
		}
		return &v1.Pod{
			ObjectMeta: metav1.ObjectMeta{Annotations: annotations},
			Spec:       v1.PodSpec{Containers: containers},
		}
	}

	testCases := []struct {
		name string
		pod  *v1.Pod
		// index is asserted only when errText is empty.
		index   int
		errText string
	}{
		{
			name:  "test container target",
			pod:   newPod(nil, "container"),
			index: 0,
		},
		{
			name:    "test missing primary container annotation",
			pod:     newPod(nil, "container", "container b"),
			errText: "[POD_TEMPLATE_FAILED] Expected a specified primary container key when building an array job with a K8sPod spec target",
		},
		{
			name: "test get primary container index",
			pod: newPod(map[string]string{primaryContainerKey: "container c"},
				"container a", "container b", "container c"),
			index: 2,
		},
		{
			name:    "specified primary container doesn't exist",
			pod:     newPod(map[string]string{primaryContainerKey: "container c"}, "container a"),
			errText: "[POD_TEMPLATE_FAILED] Couldn't find any container matching the primary container key when building an array job with a K8sPod spec target",
		},
	}

	for _, tc := range testCases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			index, err := getTaskContainerIndex(tc.pod)
			if tc.errText != "" {
				assert.EqualError(t, err, tc.errText)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tc.index, index)
		})
	}
}
Loading

0 comments on commit 7c143b2

Please sign in to comment.