From 445f01bfb0a193dceeb42225a181840793062f27 Mon Sep 17 00:00:00 2001 From: Hitoshi Mitake Date: Mon, 26 Feb 2018 14:22:52 +0900 Subject: [PATCH] Add a field SchedulerName to TFJob for specifying a scheduler This commit adds a new field SchedulerName to the definition of TFJob. The purpose of the field is to specify the scheduler name of the pods created by tf-operator, so that the named scheduler (typically not the default scheduler) handles them. This is convenient for letting kube-batchd (a component of kube-arbitrator) handle the pods. --- pkg/apis/tensorflow/v1alpha1/types.go | 3 +++ pkg/trainer/replicas.go | 2 ++ pkg/trainer/replicas_test.go | 8 ++++++++ pkg/trainer/training.go | 4 ++++ 4 files changed, 17 insertions(+) diff --git a/pkg/apis/tensorflow/v1alpha1/types.go b/pkg/apis/tensorflow/v1alpha1/types.go index 7dc6ba6bc7..26e26d3c45 100644 --- a/pkg/apis/tensorflow/v1alpha1/types.go +++ b/pkg/apis/tensorflow/v1alpha1/types.go @@ -56,6 +56,9 @@ type TFJobSpec struct { // TerminationPolicy specifies the condition that the tfjob should be considered finished. TerminationPolicy *TerminationPolicySpec `json:"terminationPolicy,omitempty"` + + // SchedulerName specifies the name of scheduler which should handle the TFJob + SchedulerName string `json:"schedulerName,omitempty"` } type TerminationPolicySpec struct { diff --git a/pkg/trainer/replicas.go b/pkg/trainer/replicas.go index 047842905b..4e44d1f8a0 100644 --- a/pkg/trainer/replicas.go +++ b/pkg/trainer/replicas.go @@ -171,6 +171,8 @@ func (s *TFReplicaSet) CreatePodWithIndex(index int32) (*v1.Pod, error) { Spec: *s.Spec.Template.Spec.DeepCopy(), } + pod.Spec.SchedulerName = s.Job.SchedulerName() + // Configure the TFCONFIG environment variable. 
tfConfig := TFConfig{ Cluster: s.Job.ClusterSpec(), diff --git a/pkg/trainer/replicas_test.go b/pkg/trainer/replicas_test.go index 43284abf0e..28642e8321 100644 --- a/pkg/trainer/replicas_test.go +++ b/pkg/trainer/replicas_test.go @@ -18,6 +18,7 @@ import ( "encoding/json" "fmt" "reflect" + "strings" "testing" "time" @@ -44,6 +45,8 @@ var ( func TestTFReplicaSet(t *testing.T) { clientSet := fake.NewSimpleClientset() + testSchedulerName := "test-scheduler" + jobSpec := &tfv1alpha1.TFJob{ ObjectMeta: meta_v1.ObjectMeta{ Name: "some-job", @@ -67,6 +70,7 @@ func TestTFReplicaSet(t *testing.T) { TFReplicaType: tfv1alpha1.PS, }, }, + SchedulerName: testSchedulerName, }, } @@ -169,6 +173,10 @@ func TestTFReplicaSet(t *testing.T) { t.Fatalf("Expected 1 environment variable got %v", len(c.Env)) } + if strings.Compare(p.Spec.SchedulerName, testSchedulerName) != 0 { + t.Fatalf("p.Spec.Template.Spec.SchedulerName; Got %v; want %v", p.Spec.SchedulerName, testSchedulerName) + } + actualTFConfig := &TFConfig{} if err := json.Unmarshal([]byte(c.Env[0].Value), actualTFConfig); err != nil { t.Fatalf("Could not unmarshal TFConfig %v", err) diff --git a/pkg/trainer/training.go b/pkg/trainer/training.go index 316bf572ae..fe2939bf16 100644 --- a/pkg/trainer/training.go +++ b/pkg/trainer/training.go @@ -412,3 +412,7 @@ func (j *TrainingJob) name() string { func (j *TrainingJob) fullname() string { return j.job.ObjectMeta.GetNamespace() + ":" + j.job.ObjectMeta.GetName() } + +func (j *TrainingJob) SchedulerName() string { + return j.job.Spec.SchedulerName +}