Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Implement Spark pod template tolerations
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Dye <andrewwdye@gmail.com>
  • Loading branch information
andrewwdye committed Oct 2, 2023
1 parent 445fc58 commit 24a2e43
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 16 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,5 @@ require (
)

replace github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d

replace github.com/flyteorg/flyteidl => /Users/andrew/dev/forks/flyteidl
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,6 @@ github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQL
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/flyteorg/flyteidl v1.5.13 h1:IQ2Cw+u36ew3BPyRDAcHdzc/GyNEOXOxhKy9jbS4hbo=
github.com/flyteorg/flyteidl v1.5.13/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og=
github.com/flyteorg/flytestdlib v1.0.24 h1:jDvymcjlsTRCwOtxPapro0WZBe3isTz+T3Tiq+mZUuk=
github.com/flyteorg/flytestdlib v1.0.24/go.mod h1:6nXa5g00qFIsgdvQ7jKQMJmDniqO0hG6Z5X5olfduqQ=
github.com/flyteorg/stow v0.3.7 h1:Cx7j8/Ux6+toD5hp5fy++927V+yAcAttDeQAlUD/864=
Expand Down
37 changes: 33 additions & 4 deletions go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"

"google.golang.org/protobuf/types/known/structpb"
"sigs.k8s.io/controller-runtime/pkg/client"

"strconv"
Expand All @@ -20,14 +21,17 @@ import (
"github.com/flyteorg/flyteplugins/go/tasks/logs"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"

v1 "k8s.io/api/core/v1"
"k8s.io/client-go/kubernetes/scheme"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
"k8s.io/client-go/kubernetes/scheme"

sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"regexp"
"strings"
Expand Down Expand Up @@ -61,6 +65,21 @@ func (sparkResourceHandler) GetProperties() k8s.PluginProperties {
return k8s.PluginProperties{}
}

// getTolerations returns the tolerations for a Spark driver or executor pod:
// the default tolerations from the K8s plugin config, followed by any
// tolerations declared in the optional pod spec.
//
// podSpecPb is the struct-encoded pod spec; when it is nil only the defaults
// are returned. The returned slice is newly allocated (elements are copied
// shallowly), so appending to it does not modify the configured defaults.
//
// A BadTaskSpecification error is returned when podSpecPb cannot be
// unmarshaled into a v1.PodSpec.
func getTolerations(podSpecPb *structpb.Struct) ([]v1.Toleration, error) {
	defaults := config.GetK8sPluginConfig().DefaultTolerations
	tolerations := make([]v1.Toleration, 0, len(defaults))
	tolerations = append(tolerations, defaults...)
	if podSpecPb == nil {
		return tolerations, nil
	}

	var podSpec v1.PodSpec
	if err := utils.UnmarshalStruct(podSpecPb, &podSpec); err != nil {
		// Report the offending input; podSpec is only the decode target and
		// carries no useful information after a failed unmarshal.
		return nil, errors.Wrapf(errors.BadTaskSpecification, err,
			"invalid pod spec [%v], failed to unmarshal", podSpecPb)
	}
	return append(tolerations, podSpec.Tolerations...), nil
}

// Creates a new Job that will execute the main container as well as any generated types that result from the execution.
func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
Expand Down Expand Up @@ -99,6 +118,11 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
if len(serviceAccountName) == 0 {
serviceAccountName = sparkTaskType
}

tolerations, err := getTolerations(sparkJob.GetDriverPod().GetPodSpec())
if err != nil {
return nil, err
}
driverSpec := sparkOp.DriverSpec{
SparkPodSpec: sparkOp.SparkPodSpec{
Affinity: config.GetK8sPluginConfig().DefaultAffinity,
Expand All @@ -108,14 +132,19 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
Image: &container.Image,
SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(),
DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(),
Tolerations: config.GetK8sPluginConfig().DefaultTolerations,
Tolerations: tolerations,
SchedulerName: &config.GetK8sPluginConfig().SchedulerName,
NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector,
HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod,
},
ServiceAccount: &serviceAccountName,
}

tolerations, err = getTolerations(sparkJob.GetExecutorPod().GetPodSpec())
if err != nil {
return nil, err
}

executorSpec := sparkOp.ExecutorSpec{
SparkPodSpec: sparkOp.SparkPodSpec{
Affinity: config.GetK8sPluginConfig().DefaultAffinity.DeepCopy(),
Expand All @@ -125,7 +154,7 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
EnvVars: sparkEnvVars,
SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(),
DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(),
Tolerations: config.GetK8sPluginConfig().DefaultTolerations,
Tolerations: tolerations,
SchedulerName: &config.GetK8sPluginConfig().SchedulerName,
NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector,
HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod,
Expand Down
92 changes: 82 additions & 10 deletions go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ import (
pluginIOMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"

sj "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/golang/protobuf/jsonpb"
structpb "github.com/golang/protobuf/ptypes/struct"
"github.com/stretchr/testify/assert"
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
)

const sparkMainClass = "MainClass"
Expand Down Expand Up @@ -87,7 +88,8 @@ func TestGetEventInfo(t *testing.T) {
},
},
}))
taskCtx := dummySparkTaskContext(dummySparkTaskTemplate("blah-1", dummySparkConf), false)
sparkJob := dummySparkCustomObj(dummySparkConf)
taskCtx := dummySparkTaskContext(dummySparkTaskTemplate("blah-1", sparkJob), false)
info, err := getEventInfoForSpark(taskCtx, dummySparkApplication(sj.RunningState))
assert.NoError(t, err)
assert.Len(t, info.Logs, 6)
Expand Down Expand Up @@ -157,7 +159,8 @@ func TestGetTaskPhase(t *testing.T) {
sparkResourceHandler := sparkResourceHandler{}

ctx := context.TODO()
taskCtx := dummySparkTaskContext(dummySparkTaskTemplate("", dummySparkConf), false)
sparkJob := dummySparkCustomObj(dummySparkConf)
taskCtx := dummySparkTaskContext(dummySparkTaskTemplate("", sparkJob), false)
taskPhase, err := sparkResourceHandler.GetTaskPhase(ctx, taskCtx, dummySparkApplication(sj.NewState))
assert.NoError(t, err)
assert.Equal(t, taskPhase.Phase(), pluginsCore.PhaseQueued)
Expand Down Expand Up @@ -242,17 +245,14 @@ func dummySparkApplication(state sj.ApplicationStateType) *sj.SparkApplication {

// dummySparkCustomObj builds a Python SparkJob fixture that uses the test's
// main class and application file together with the supplied Spark conf.
func dummySparkCustomObj(sparkConf map[string]string) *plugins.SparkJob {
	return &plugins.SparkJob{
		ApplicationType:     plugins.SparkApplication_PYTHON,
		MainApplicationFile: sparkApplicationFile,
		MainClass:           sparkMainClass,
		SparkConf:           sparkConf,
	}
}

func dummySparkTaskTemplate(id string, sparkConf map[string]string) *core.TaskTemplate {

sparkJob := dummySparkCustomObj(sparkConf)
func dummySparkTaskTemplate(id string, sparkJob *plugins.SparkJob) *core.TaskTemplate {
sparkJobJSON, err := utils.MarshalToString(sparkJob)
if err != nil {
panic(err)
Expand Down Expand Up @@ -335,7 +335,8 @@ func TestBuildResourceSpark(t *testing.T) {
sparkResourceHandler := sparkResourceHandler{}

// Case1: Valid Spark Task-Template
taskTemplate := dummySparkTaskTemplate("blah-1", dummySparkConf)
sparkJob := dummySparkCustomObj(dummySparkConf)
taskTemplate := dummySparkTaskTemplate("blah-1", sparkJob)

// Set spark custom feature config.
assert.NoError(t, setSparkConfig(&Config{
Expand Down Expand Up @@ -619,7 +620,8 @@ func TestBuildResourceSpark(t *testing.T) {
dummyConfWithRequest["spark.kubernetes.driver.request.cores"] = "3"
dummyConfWithRequest["spark.kubernetes.executor.request.cores"] = "4"

taskTemplate = dummySparkTaskTemplate("blah-1", dummyConfWithRequest)
sparkJob = dummySparkCustomObj(dummyConfWithRequest)
taskTemplate = dummySparkTaskTemplate("blah-1", sparkJob)
resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false))
assert.Nil(t, err)
assert.NotNil(t, resource)
Expand Down Expand Up @@ -678,6 +680,76 @@ func TestBuildResourceSpark(t *testing.T) {
assert.Nil(t, resource)
}

// TestBuildResourcePodTemplate verifies that BuildResource combines the default
// tolerations from the K8s plugin config with the tolerations declared in the
// Spark job's driver and executor pod specs, defaults first.
func TestBuildResourcePodTemplate(t *testing.T) {
	defaultToleration := corev1.Toleration{
		Key:      "x/flyte",
		Value:    "default",
		Operator: "Equal",
	}
	err := config.SetK8sPluginConfig(&config.K8sPluginConfig{
		DefaultTolerations: []corev1.Toleration{
			defaultToleration,
		},
	})
	assert.NoError(t, err)
	sparkJob := dummySparkCustomObj(dummySparkConf)

	// Driver pod spec carrying one extra toleration.
	extraDriverToleration := corev1.Toleration{
		Key:      "x/flyte",
		Value:    "extra-driver",
		Operator: "Equal",
	}
	podSpec := corev1.PodSpec{
		Tolerations: []corev1.Toleration{
			extraDriverToleration,
		},
	}
	driverPodSpecPb := structpb.Struct{}
	err = utils.MarshalStruct(&podSpec, &driverPodSpecPb)
	assert.NoError(t, err)
	sparkJob.DriverPodValue = &plugins.SparkJob_DriverPod{
		DriverPod: &core.K8SPod{
			PodSpec: &driverPodSpecPb,
		},
	}

	// Executor pod spec carrying a different extra toleration.
	extraExecutorToleration := corev1.Toleration{
		Key:      "x/flyte",
		Value:    "extra-executor",
		Operator: "Equal",
	}
	podSpec = corev1.PodSpec{
		Tolerations: []corev1.Toleration{
			extraExecutorToleration,
		},
	}
	execPodSpecPb := structpb.Struct{}
	err = utils.MarshalStruct(&podSpec, &execPodSpecPb)
	assert.NoError(t, err)
	sparkJob.ExecutorPodValue = &plugins.SparkJob_ExecutorPod{
		ExecutorPod: &core.K8SPod{
			PodSpec: &execPodSpecPb,
		},
	}

	taskTemplate := dummySparkTaskTemplate("blah-1", sparkJob)
	sparkResourceHandler := sparkResourceHandler{}
	resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false))
	assert.NoError(t, err)
	assert.NotNil(t, resource)

	sparkApp, ok := resource.(*sj.SparkApplication)
	if !assert.True(t, ok) {
		// sparkApp is nil when the type assertion fails; stop before
		// dereferencing it so the test fails cleanly instead of panicking.
		return
	}

	// assert.Equal takes (expected, actual); keeping that order makes the
	// failure output label the two sides correctly.
	assert.Len(t, sparkApp.Spec.Driver.Tolerations, 2)
	assert.Equal(t, []corev1.Toleration{
		defaultToleration,
		extraDriverToleration,
	}, sparkApp.Spec.Driver.Tolerations)
	assert.Len(t, sparkApp.Spec.Executor.Tolerations, 2)
	assert.Equal(t, []corev1.Toleration{
		defaultToleration,
		extraExecutorToleration,
	}, sparkApp.Spec.Executor.Tolerations)
}

func TestGetPropertiesSpark(t *testing.T) {
sparkResourceHandler := sparkResourceHandler{}
expected := k8s.PluginProperties{}
Expand Down

0 comments on commit 24a2e43

Please sign in to comment.