allow using WORKER:0 as chief (#221)
Allow training jobs to work without a master by treating one of the workers as the chief.

* Fixes #192

* This will allow TfJobs to be used with many existing TensorFlow programs without modification, since using worker 0 as the chief is a common pattern (see the sketch below). Currently, to run these programs with TfJob you need to spin up a dummy TensorFlow gRPC server just to serve as the master.

* This is also necessary to support changes to the Estimator API in TF 1.4 (#61).
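
For illustration, here is a minimal sketch of a TfJobSpec that designates WORKER 0 as the chief through the new TerminationPolicy, mirroring the third case added to TestValidate in pkg/spec/tf_job_test.go below. The import paths and the standalone main wrapper are assumptions made for the sketch and may not match this repo's actual package layout or vendoring.

package main

import (
    "fmt"

    "github.com/golang/protobuf/proto"

    "github.com/tensorflow/k8s/pkg/spec" // assumed import path for this repo's spec package
    "k8s.io/client-go/pkg/api/v1"        // assumed import path for the core v1 types
)

func main() {
    // A job with a single WORKER replica and no MASTER. The termination policy
    // names WORKER:0 as the chief, so Validate() is expected to accept the spec.
    s := &spec.TfJobSpec{
        ReplicaSpecs: []*spec.TfReplicaSpec{
            {
                Template: &v1.PodTemplateSpec{
                    Spec: v1.PodSpec{
                        Containers: []v1.Container{
                            {Name: "tensorflow"},
                        },
                    },
                },
                TfReplicaType: spec.WORKER,
                Replicas:      proto.Int32(1),
            },
        },
        TfImage: "tensorflow/tensorflow:1.3.0",
        TerminationPolicy: &spec.TerminationPolicySpec{
            Chief: &spec.ChiefSpec{
                ReplicaName:  "WORKER",
                ReplicaIndex: 0,
            },
        },
    }

    s.SetDefaults("") // fills in defaults such as TfPort, as the tests do before validating
    fmt.Printf("validation error: %v\n", s.Validate())
}

Without the TerminationPolicy block, the same WORKER-only spec is expected to fail validation, which is what the second TestValidate case below checks.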
lluunn authored and jlewi committed Dec 20, 2017
1 parent 6a2fc9c commit cb1e053
Showing 5 changed files with 193 additions and 62 deletions.
19 changes: 12 additions & 7 deletions pkg/spec/tf_job.go
@@ -127,7 +127,11 @@ type ChiefSpec struct {

// Validate checks that the TfJobSpec is valid.
func (c *TfJobSpec) Validate() error {
if c.TerminationPolicy == nil || c.TerminationPolicy.Chief == nil {
return fmt.Errorf("invalid termination policy: %v", c.TerminationPolicy)
}
// Check that each replica has a TensorFlow container.
chiefExists := false
for _, r := range c.ReplicaSpecs {
found := false
if r.Template == nil && r.TfReplicaType != PS {
@@ -138,6 +142,10 @@ func (c *TfJobSpec) Validate() error {
return errors.New("The MASTER must have Replicas = 1")
}

if r.TfReplicaType == TfReplicaType(c.TerminationPolicy.Chief.ReplicaName) {
chiefExists = true
}

if r.TfPort == nil {
return errors.New("tfReplicaSpec.TfPort can't be nil.")
}
@@ -167,14 +175,11 @@ func (c *TfJobSpec) Validate() error {
return fmt.Errorf("Replica type %v is missing a container named %v", r.TfReplicaType, TENSORFLOW)
}
}
if c.TerminationPolicy != nil {
if c.TerminationPolicy.Chief == nil {
return errors.New("invalid termination policy, Chief cannot be nil")
}
if c.TerminationPolicy.Chief.ReplicaName != "MASTER" || c.TerminationPolicy.Chief.ReplicaIndex != 0 {
return errors.New("invalid termination policy, Chief should have replicaName=MASTER and index=0")
}

if !chiefExists {
return fmt.Errorf("Missing ReplicaSpec for chief: %v", c.TerminationPolicy.Chief.ReplicaName)
}

return nil
}

86 changes: 86 additions & 0 deletions pkg/spec/tf_job_test.go
@@ -383,3 +383,89 @@ func TestSetDefaults(t *testing.T) {
})
}
}

func TestValidate(t *testing.T) {
type testCase struct {
in *TfJobSpec
expectingError bool
}

testCases := []testCase{
{
in: &TfJobSpec{
ReplicaSpecs: []*TfReplicaSpec{
{
Template: &v1.PodTemplateSpec{
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "tensorflow",
},
},
},
},
TfReplicaType: MASTER,
Replicas: proto.Int32(1),
},
},
TfImage: "tensorflow/tensorflow:1.3.0",
},
expectingError: false,
},
{
in: &TfJobSpec{
ReplicaSpecs: []*TfReplicaSpec{
{
Template: &v1.PodTemplateSpec{
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "tensorflow",
},
},
},
},
TfReplicaType: WORKER,
Replicas: proto.Int32(1),
},
},
TfImage: "tensorflow/tensorflow:1.3.0",
},
expectingError: true,
},
{
in: &TfJobSpec{
ReplicaSpecs: []*TfReplicaSpec{
{
Template: &v1.PodTemplateSpec{
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "tensorflow",
},
},
},
},
TfReplicaType: WORKER,
Replicas: proto.Int32(1),
},
},
TfImage: "tensorflow/tensorflow:1.3.0",
TerminationPolicy: &TerminationPolicySpec{
Chief: &ChiefSpec{
ReplicaName: "WORKER",
ReplicaIndex: 0,
},
},
},
expectingError: false,
},
}

for _, c := range testCases {
c.in.SetDefaults("")
if err := c.in.Validate(); (err != nil) != c.expectingError {
t.Errorf("unexpected validation result: %v", err)
}
}
}
71 changes: 35 additions & 36 deletions pkg/trainer/replicas.go
@@ -406,6 +406,40 @@ func replicaStatusFromPodList(l v1.PodList, name spec.ContainerName) spec.Replic
return spec.ReplicaStateUnknown
}

func (s *TFReplicaSet) GetSingleReplicaStatus(index int32) (spec.ReplicaState) {
j, err := s.ClientSet.BatchV1().Jobs(s.Job.job.Metadata.Namespace).Get(s.jobName(index), meta_v1.GetOptions{})

if err != nil {
return spec.ReplicaStateUnknown
}

if j.Status.Succeeded >= 1 {
return spec.ReplicaStateSucceeded
}

labels := s.Labels()
labels["task_index"] = fmt.Sprintf("%v", index)
selector, err := labels.ToSelector()
if err != nil {
log.Errorf("labels.ToSelector() error; %v", err)
return spec.ReplicaStateFailed
}

// TODO(jlewi): Handle errors. We need to get the pod and looking at recent container exits.
l, err := s.ClientSet.CoreV1().Pods(s.Job.job.Metadata.Namespace).List(meta_v1.ListOptions{
// TODO(jlewi): Why isn't the label selector working?
LabelSelector: selector,
})

if err != nil {
// TODO(jlewi): Are there errors that should be treated as retryable errors?
return spec.ReplicaStateFailed
}

status := replicaStatusFromPodList(*l, spec.TENSORFLOW)
return status
}

// Status returns the status of the replica set.
func (s *TFReplicaSet) GetStatus() (spec.TfReplicaStatus, error) {

@@ -425,42 +459,7 @@ func (s *TFReplicaSet) GetStatus() (spec.TfReplicaStatus, error) {
}

for index := int32(0); index < *s.Spec.Replicas; index++ {

j, err := s.ClientSet.BatchV1().Jobs(s.Job.job.Metadata.Namespace).Get(s.jobName(index), meta_v1.GetOptions{})

if err != nil {
increment(spec.ReplicaStateUnknown)
continue
}

if j.Status.Succeeded >= 1 {
increment(spec.ReplicaStateSucceeded)
continue
}

labels := s.Labels()
labels["task_index"] = fmt.Sprintf("%v", index)
selector, err := labels.ToSelector()
if err != nil {
log.Errorf("labels.ToSelector() error; %v", err)
increment(spec.ReplicaStateFailed)
continue
}

// TODO(jlewi): Handle errors. We need to get the pod and looking at recent container exits.
l, err := s.ClientSet.CoreV1().Pods(s.Job.job.Metadata.Namespace).List(meta_v1.ListOptions{
// TODO(jlewi): Why isn't the label selector working?
LabelSelector: selector,
})

if err != nil {
// TODO(jlewi): Are there errors that should be treated as retryable errors?
increment(spec.ReplicaStateFailed)
continue
}

status := replicaStatusFromPodList(*l, spec.TENSORFLOW)
increment(status)
increment(s.GetSingleReplicaStatus(index))
}

// Determine the overall status for the replica set based on the status of the individual
27 changes: 8 additions & 19 deletions pkg/trainer/training.go
@@ -160,8 +160,9 @@ func (j *TrainingJob) deleteResources() error {
return nil
}

func (j *TrainingJob) GetStatus() (spec.State, []*spec.TfReplicaStatus, error) {
state := spec.StateUnknown
func (j *TrainingJob) GetStatus() (spec.ReplicaState, []*spec.TfReplicaStatus, error) {
chief := j.job.Spec.TerminationPolicy.Chief
chiefState := spec.ReplicaStateUnknown
replicaStatuses := make([]*spec.TfReplicaStatus, 0)

// The state for each replica.
@@ -178,24 +179,12 @@ func (j *TrainingJob) GetStatus() (spec.State, []*spec.TfReplicaStatus, error) {

replicaStatuses = append(replicaStatuses, &rStatus)

// If any replicas are failed mark job as failed.
if rStatus.State == spec.ReplicaStateFailed {
state = spec.StateFailed
if string(r.Spec.TfReplicaType) == string(chief.ReplicaName) {
chiefState = r.GetSingleReplicaStatus(int32(chief.ReplicaIndex))
}
}

if v, ok := replicaSetStates[spec.MASTER]; ok && v == spec.ReplicaStateSucceeded {
state = spec.StateSucceeded
return state, replicaStatuses, nil
}

if v, ok := replicaSetStates[spec.MASTER]; ok && v == spec.ReplicaStateFailed {
state = spec.StateFailed
return state, replicaStatuses, nil
}

state = spec.StateRunning
return state, replicaStatuses, nil
return chiefState, replicaStatuses, nil
}

// isRetryableTerminationState returns true if a container terminated in a state
@@ -373,11 +362,11 @@ func (j *TrainingJob) reconcile(config *spec.ControllerConfig) {
log.Errorf("GetStatus() for job %v returned error: %v", j.job.Metadata.Name, err)
}
// TODO(jlewi): We should update the Phase if we detect the job is done.
if state == spec.StateFailed {
if state == spec.ReplicaStateFailed {
log.Errorf("Master failed Job: %v.", j.job.Metadata.Name)
j.status.SetPhase(spec.TfJobPhaseDone)
j.status.SetState(spec.StateFailed)
} else if state == spec.StateSucceeded {
} else if state == spec.ReplicaStateSucceeded {
log.Infof("Master succeeded Job: %v.", j.job.Metadata.Name)
j.status.SetPhase(spec.TfJobPhaseDone)
j.status.SetState(spec.StateSucceeded)
52 changes: 52 additions & 0 deletions pkg/trainer/training_test.go
@@ -202,6 +202,20 @@ func TestJobSetup(t *testing.T) {
},
TfReplicaType: spec.PS,
},
{
Replicas: proto.Int32(1),
TfPort: proto.Int32(10),
Template: &v1.PodTemplateSpec{
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "tensorflow",
},
},
},
},
TfReplicaType: spec.MASTER,
},
},
},
},
@@ -232,6 +246,25 @@ func TestJobSetup(t *testing.T) {
},
TfReplicaType: spec.PS,
},
{
Replicas: proto.Int32(1),
TfPort: proto.Int32(10),
Template: &v1.PodTemplateSpec{
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "tensorflow",
Resources: v1.ResourceRequirements{
Requests: map[v1.ResourceName]resource.Quantity{
"nvidia-gpu": resource.MustParse("1"),
},
},
},
},
},
},
TfReplicaType: spec.MASTER,
},
},
},
},
@@ -263,6 +296,25 @@ func TestJobSetup(t *testing.T) {
},
TfReplicaType: spec.PS,
},
{
Replicas: proto.Int32(1),
TfPort: proto.Int32(10),
Template: &v1.PodTemplateSpec{
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "tensorflow",
Resources: v1.ResourceRequirements{
Requests: map[v1.ResourceName]resource.Quantity{
"nvidia-gpu": resource.MustParse("1"),
},
},
},
},
},
},
TfReplicaType: spec.MASTER,
},
},
TensorBoard: &spec.TensorBoardSpec{},
},
