Commit
Create PDB of TFReplicaSet for gang scheduling by kube-arbitrator (#452)
* Create PDB of TFReplicaSet for gang scheduling by kube-arbitrator

kube-arbitrator (technically its component kube-batchd) requires a PDB
for its gang scheduling feature. The feature is useful for creating
every pod of a TFJob at the same time. This commit lets tf-operator
create the PDB for that purpose.

* Add a new unit test, TestPDBForGangScheduling, for checking PDB creation
mitake authored and k8s-ci-robot committed Mar 20, 2018
1 parent a36b350 commit ebb0053
Showing 5 changed files with 225 additions and 6 deletions.
2 changes: 2 additions & 0 deletions cmd/tf-operator/app/options/options.go
@@ -26,6 +26,7 @@ type ServerOption struct {
PrintVersion bool
GCInterval time.Duration
JsonLogFormat bool
EnableGangScheduling bool
}

// NewServerOption creates a new CMServer with a default config.
@@ -42,4 +43,5 @@ func (s *ServerOption) AddFlags(fs *flag.FlagSet) {
fs.DurationVar(&s.GCInterval, "gc-interval", 10*time.Minute, "GC interval")
fs.StringVar(&s.ControllerConfigFile, "controller-config-file", "", "Path to file containing the controller config.")
fs.BoolVar(&s.JsonLogFormat, "json-log-format", true, "Set true to use json style log format. Set false to use plaintext style log format")
fs.BoolVar(&s.EnableGangScheduling, "enable-gang-scheduling", false, "Set true to enable gang scheduling by kube-arbitrator.")
}
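For reference, a minimal sketch of how the new option is consumed. NewServerOption and AddFlags come from the hunk above; the wiring in main shown here is an assumption for illustration, not part of this diff.

```go
package main

import (
	"flag"
	"fmt"

	"github.com/kubeflow/tf-operator/cmd/tf-operator/app/options"
)

func main() {
	// Register the operator flags, including the new -enable-gang-scheduling
	// flag added above (it defaults to false).
	opt := options.NewServerOption()
	opt.AddFlags(flag.CommandLine)
	flag.Parse()

	// When the operator is started with -enable-gang-scheduling=true, the
	// value is later passed down to controller.New (see the server.go change below).
	fmt.Println("gang scheduling enabled:", opt.EnableGangScheduling)
}
```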
2 changes: 1 addition & 1 deletion cmd/tf-operator/app/server.go
@@ -79,7 +79,7 @@ func Run(opt *options.ServerOption) error {
defer close(neverStop)

tfJobInformerFactory := informers.NewSharedInformerFactory(tfJobClient, time.Second*30)
controller, err := controller.New(kubeClient, tfJobClient, *controllerConfig, tfJobInformerFactory)
controller, err := controller.New(kubeClient, tfJobClient, *controllerConfig, tfJobInformerFactory, opt.EnableGangScheduling)
if err != nil {
return err
}
12 changes: 8 additions & 4 deletions pkg/controller/controller.go
@@ -81,10 +81,13 @@ type Controller struct {
recorder record.EventRecorder

syncHandler func(jobKey string) (bool, error)

enableGangScheduling bool
}

func New(kubeClient kubernetes.Interface, tfJobClient tfjobclient.Interface,
config tfv1alpha1.ControllerConfig, tfJobInformerFactory informers.SharedInformerFactory) (*Controller, error) {
config tfv1alpha1.ControllerConfig, tfJobInformerFactory informers.SharedInformerFactory,
enableGangScheduling bool) (*Controller, error) {
tfJobInformer := tfJobInformerFactory.Kubeflow().V1alpha1().TFJobs()

kubeflowscheme.AddToScheme(scheme.Scheme)
@@ -100,8 +103,9 @@ func New(kubeClient kubernetes.Interface, tfJobClient tfjobclient.Interface,
WorkQueue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "TFjobs"),
recorder: recorder,
// TODO(jlewi): What to do about cluster.Cluster?
jobs: make(map[string]*trainer.TrainingJob),
config: config,
jobs: make(map[string]*trainer.TrainingJob),
config: config,
enableGangScheduling: enableGangScheduling,
}

log.Info("Setting up event handlers")
@@ -236,7 +240,7 @@ func (c *Controller) syncTFJob(key string) (bool, error) {

nc := c.jobs[key]

if err := nc.Reconcile(&c.config); err != nil {
if err := nc.Reconcile(&c.config, c.enableGangScheduling); err != nil {
return false, err
}

70 changes: 69 additions & 1 deletion pkg/trainer/training.go
@@ -22,7 +22,11 @@ import (

log "github.com/sirupsen/logrus"
"k8s.io/api/core/v1"
"k8s.io/api/policy/v1beta1"
k8s_errors "k8s.io/apimachinery/pkg/api/errors"
meta_v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/tools/record"

@@ -53,6 +57,8 @@ type TrainingJob struct {
status tfv1alpha1.TFJobStatus

memberCounter int

pdb *v1beta1.PodDisruptionBudget
}

// ClusterSpec represents a cluster TensorFlow specification.
@@ -268,6 +274,14 @@ func (j *TrainingJob) Delete() {
if cErr := j.deleteResources(); cErr != nil {
log.Errorf("trainingJob.deleteResources() error; %v", cErr)
}

if j.pdb != nil {
// If the job has a PDB for gang scheduling, delete it as well.
err := j.KubeCli.PolicyV1beta1().PodDisruptionBudgets(j.job.ObjectMeta.Namespace).Delete(j.pdb.ObjectMeta.Name, &meta_v1.DeleteOptions{})
if err != nil {
log.Errorf("Error deleting PDB %v; %v", j.pdb.ObjectMeta.Name, err)
}
}
}

// updateCRDStatus updates the job status based on TrainingJob.status.
@@ -290,7 +304,7 @@ func (j *TrainingJob) updateCRDStatus() error {
}

// Reconcile tries to get the job into the desired state.
func (j *TrainingJob) Reconcile(config *tfv1alpha1.ControllerConfig) error {
func (j *TrainingJob) Reconcile(config *tfv1alpha1.ControllerConfig, enableGangScheduling bool) error {
if j.job.Status.Phase == tfv1alpha1.TFJobPhaseNone {
// The job hasn't been setup.
j.setup(config)
@@ -313,6 +327,15 @@ func (j *TrainingJob) Reconcile(config *tfv1alpha1.ControllerConfig) error {
return err
}

// sync PDB for gang scheduling
// TODO(mitake): replace the PDB with a newer mechanism if PDBs are superseded for gang scheduling
if enableGangScheduling {
err := j.syncPdb()
if err != nil {
log.Errorf("SyncPdb error: %v", err)
}
}

// sync pods
for _, rc := range j.Replicas {
err := rc.SyncPods()
@@ -398,3 +421,48 @@ func (j *TrainingJob) fullname() string {
func (j *TrainingJob) SchedulerName() string {
return j.job.Spec.SchedulerName
}

// syncPdb creates a PDB for gang scheduling by kube-arbitrator.
func (j *TrainingJob) syncPdb() error {
nrReplicas := int32(0)
for _, r := range j.Replicas {
nrReplicas += *r.Spec.Replicas
}

if nrReplicas == 1 {
// Gang scheduling isn't needed for a non-distributed training job.
return nil
}

minAvailable := intstr.FromInt(int(nrReplicas))
pdb := &v1beta1.PodDisruptionBudget{
ObjectMeta: meta_v1.ObjectMeta{
GenerateName: "tf-job-pdb-",
},
Spec: v1beta1.PodDisruptionBudgetSpec{
MinAvailable: &minAvailable,
Selector: &meta_v1.LabelSelector{
MatchLabels: map[string]string{
"runtime_id": j.job.Spec.RuntimeId,
"tf_job_name": j.job.ObjectMeta.Name,
},
},
},
}

createdPdb, err := j.KubeCli.PolicyV1beta1().PodDisruptionBudgets(j.job.ObjectMeta.Namespace).Create(pdb)
if err != nil {
if k8s_errors.IsAlreadyExists(err) {
log.Infof("PDB: %v already exists.", j.job.ObjectMeta.Name)
return nil
}

j.recorder.Eventf(j.job, v1.EventTypeWarning, FailedCreateReason, "Error creating: %v", err)
return err
}

j.pdb = createdPdb

j.recorder.Eventf(j.job, v1.EventTypeNormal, SuccessfulCreateReason, "Created PDB: %v", createdPdb.Name)
return nil
}
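As a rough illustration of the result (an assumption-laden sketch, not part of this change): after Reconcile runs with gang scheduling enabled, the PDB created by syncPdb above can be read back through the same client-go policy/v1beta1 API, and its MinAvailable equals the total replica count, which is what lets kube-arbitrator schedule the whole gang together. The helper name below is hypothetical.

```go
package trainer

import (
	log "github.com/sirupsen/logrus"
	meta_v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/kubernetes"
)

// listTFJobPDBs is a hypothetical helper (not part of this commit) that lists
// the PodDisruptionBudgets in a namespace and logs their gang size.
func listTFJobPDBs(kubeCli kubernetes.Interface, namespace string) error {
	pdbs, err := kubeCli.PolicyV1beta1().PodDisruptionBudgets(namespace).List(meta_v1.ListOptions{})
	if err != nil {
		return err
	}
	for _, pdb := range pdbs.Items {
		// For PDBs created by syncPdb above, MinAvailable is the total number
		// of replicas in the TFJob, i.e. the gang size.
		log.Infof("PDB %v: minAvailable=%v", pdb.ObjectMeta.Name, pdb.Spec.MinAvailable.String())
	}
	return nil
}
```

The unit test added below performs the equivalent check against a fake clientset.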
145 changes: 145 additions & 0 deletions pkg/trainer/training_test.go
@@ -22,8 +22,10 @@ import (
tfv1alpha1 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1alpha1"
tfJobFake "github.com/kubeflow/tf-operator/pkg/client/clientset/versioned/fake"
"k8s.io/api/core/v1"
"k8s.io/api/policy/v1beta1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/client-go/kubernetes/fake"
"k8s.io/client-go/tools/record"
)
@@ -342,3 +344,146 @@ func TestJobSetup(t *testing.T) {
}
}
}

func TestPDBForGangScheduling(t *testing.T) {
clientSet := fake.NewSimpleClientset()

type testCase struct {
jobSpec *tfv1alpha1.TFJob
expectPdb *v1beta1.PodDisruptionBudget
}

minAvailable3 := intstr.FromInt(3)

testCases := []testCase{
{
jobSpec: &tfv1alpha1.TFJob{
ObjectMeta: metav1.ObjectMeta{
Name: "some-meta-name",
},
Spec: tfv1alpha1.TFJobSpec{
RuntimeId: "some-runtime-id",
ReplicaSpecs: []*tfv1alpha1.TFReplicaSpec{
{
Replicas: proto.Int32(1),
TFPort: proto.Int32(10),
Template: &v1.PodTemplateSpec{
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "tensorflow",
},
},
},
},
TFReplicaType: tfv1alpha1.WORKER,
},
},
},
},
expectPdb: nil,
},

{
jobSpec: &tfv1alpha1.TFJob{
ObjectMeta: metav1.ObjectMeta{
Name: "some-meta-name",
},
Spec: tfv1alpha1.TFJobSpec{
RuntimeId: "some-runtime-id",
ReplicaSpecs: []*tfv1alpha1.TFReplicaSpec{
{
Replicas: proto.Int32(1),
TFPort: proto.Int32(10),
Template: &v1.PodTemplateSpec{
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "tensorflow",
},
},
},
},
TFReplicaType: tfv1alpha1.MASTER,
},
{
Replicas: proto.Int32(1),
TFPort: proto.Int32(10),
Template: &v1.PodTemplateSpec{
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "tensorflow",
},
},
},
},
TFReplicaType: tfv1alpha1.PS,
},
{
Replicas: proto.Int32(1),
TFPort: proto.Int32(10),
Template: &v1.PodTemplateSpec{
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "tensorflow",
},
},
},
},
TFReplicaType: tfv1alpha1.WORKER,
},
},
},
},
expectPdb: &v1beta1.PodDisruptionBudget{
Spec: v1beta1.PodDisruptionBudgetSpec{
MinAvailable: &minAvailable3,
Selector: &metav1.LabelSelector{
MatchLabels: map[string]string{
"runtime_id": "some-runtime-id",
"tf_job_name": "some-meta-name",
},
},
},
},
},
}

for _, c := range testCases {
recorder := record.NewFakeRecorder(100)
job, err := initJob(clientSet, &tfJobFake.Clientset{}, recorder, c.jobSpec)
if err != nil {
t.Errorf("j.initJob() error: %v", err)
}

err = job.setupReplicas()
if err != nil {
t.Errorf("j.setupReplicas() error: %v", err)
}

err = job.syncPdb()
if err != nil {
t.Errorf("j.Reconcile() error: %v", err)
}

actualPdbList, err := clientSet.PolicyV1beta1().PodDisruptionBudgets(job.job.ObjectMeta.Namespace).List(metav1.ListOptions{})
if err != nil {
t.Fatalf("Could not get PDB List: %v", err)
}
if len(actualPdbList.Items) != 1 && c.expectPdb != nil {
t.Fatalf("k8s should have one PDB but the length of actually created PDB isn't 1, Got %d", len(actualPdbList.Items))
}

if c.expectPdb == nil {
// A non-distributed training job shouldn't have a PDB.
continue
}

actualPdb := actualPdbList.Items[0]
if !reflect.DeepEqual(c.expectPdb.Spec, actualPdb.Spec) {
t.Fatalf("Got %v, Want %v", actualPdb.Spec, c.expectPdb.Spec)
}
}
}
