Skip to content

Commit

Permalink
Create Pod instead of Job
Browse files Browse the repository at this point in the history
  • Loading branch information
ScorpioCPH committed Feb 10, 2018
1 parent 57567df commit be20406
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 68 deletions.
56 changes: 19 additions & 37 deletions pkg/trainer/replicas.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ import (
"strings"

log "github.com/golang/glog"
"github.com/golang/protobuf/proto"
batch "k8s.io/api/batch/v1"
"k8s.io/api/core/v1"
k8s_errors "k8s.io/apimachinery/pkg/api/errors"
meta_v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -125,7 +123,7 @@ func (s *TFReplicaSet) Create(config *tfv1alpha1.ControllerConfig) error {
// Create the service.
service := &v1.Service{
ObjectMeta: meta_v1.ObjectMeta{
Name: s.jobName(index),
Name: s.genName(index),
Labels: taskLabels,
OwnerReferences: []meta_v1.OwnerReference{
helper.AsOwner(s.Job.job),
Expand All @@ -148,7 +146,7 @@ func (s *TFReplicaSet) Create(config *tfv1alpha1.ControllerConfig) error {
// If the job already exists do nothing.
if err != nil {
if k8s_errors.IsAlreadyExists(err) {
log.Infof("Service %v already exists.", s.jobName(index))
log.Infof("Service %v already exists.", s.genName(index))
} else {
s.recorder.Eventf(s.Job.job, v1.EventTypeWarning, FailedCreateReason, "Error creating: %v", err)
return k8sErrors.NewAggregate([]error{fmt.Errorf("Creating service %v returned error.", createdService.ObjectMeta.Name), err})
Expand All @@ -174,38 +172,22 @@ func (s *TFReplicaSet) Create(config *tfv1alpha1.ControllerConfig) error {
return err
}

// Make a copy of the template because we will modify it below. .
newPodSpecTemplate := s.Spec.Template.DeepCopy()

newJ := &batch.Job{
newP := &v1.Pod{
ObjectMeta: meta_v1.ObjectMeta{
Name: s.jobName(index),
Name: s.genName(index),
Labels: taskLabels,
OwnerReferences: []meta_v1.OwnerReference{
helper.AsOwner(s.Job.job),
},
},
Spec: batch.JobSpec{
Completions: proto.Int32(1),
Parallelism: proto.Int32(1),
Template: *newPodSpecTemplate,
},
}

if newJ.Spec.Template.ObjectMeta.Labels == nil {
newJ.Spec.Template.ObjectMeta.Labels = make(map[string]string)
}

// Pods need to be tagged with the labels.
for k, v := range taskLabels {
newJ.Spec.Template.ObjectMeta.Labels[k] = v
Spec: *s.Spec.Template.Spec.DeepCopy(),
}

// Add TF_CONFIG environment variable.
for i, _ := range newJ.Spec.Template.Spec.Containers {
for i, _ := range newP.Spec.Containers {
// We can't get c in the loop variable because that would be by value so our modifications
// wouldn't have any effect.
c := &newJ.Spec.Template.Spec.Containers[i]
c := &newP.Spec.Containers[i]
if tfv1alpha1.ContainerName(c.Name) != tfv1alpha1.TENSORFLOW {
continue
}
Expand All @@ -218,20 +200,20 @@ func (s *TFReplicaSet) Create(config *tfv1alpha1.ControllerConfig) error {
})
}

log.Infof("Creating Job: %v", newJ.ObjectMeta.Name)
createdJob, err := s.ClientSet.BatchV1().Jobs(s.Job.job.ObjectMeta.Namespace).Create(newJ)
log.Infof("Creating Pod: %v", newP.ObjectMeta.Name)
createdPod, err := s.ClientSet.CoreV1().Pods(s.Job.job.ObjectMeta.Namespace).Create(newP)

// If the job already exists do nothing.
// If the Pod already exists do nothing.
if err != nil {
if k8s_errors.IsAlreadyExists(err) {
log.Infof("%v already exists.", s.jobName(index))
log.Infof("%v already exists.", s.genName(index))

} else {
s.recorder.Eventf(s.Job.job, v1.EventTypeWarning, FailedCreateReason, "Error creating: %v", err)
return k8sErrors.NewAggregate([]error{fmt.Errorf("Creating Job %v returned error.", createdJob.ObjectMeta.Name), err})
return k8sErrors.NewAggregate([]error{fmt.Errorf("Creating Pod %v returned error.", createdPod.ObjectMeta.Name), err})
}
} else {
s.recorder.Eventf(s.Job.job, v1.EventTypeNormal, SuccessfulCreateReason, "Created job: %v", createdJob.Name)
s.recorder.Eventf(s.Job.job, v1.EventTypeNormal, SuccessfulCreateReason, "Created Pod: %v", createdPod.Name)
}
}
return nil
Expand Down Expand Up @@ -270,11 +252,11 @@ func (s *TFReplicaSet) Delete() error {
// Services doesn't support DeleteCollection so we delete them individually.
// TODO(jlewi): We should check if this has changed with K8s 1.8 or other releases.
for index := int32(0); index < *s.Spec.Replicas; index++ {
log.V(1).Infof("Deleting Service %v:%v", s.Job.job.ObjectMeta.Namespace, s.jobName((index)))
err = s.ClientSet.CoreV1().Services(s.Job.job.ObjectMeta.Namespace).Delete(s.jobName(index), &meta_v1.DeleteOptions{})
log.V(1).Infof("Deleting Service %v:%v", s.Job.job.ObjectMeta.Namespace, s.genName((index)))
err = s.ClientSet.CoreV1().Services(s.Job.job.ObjectMeta.Namespace).Delete(s.genName(index), &meta_v1.DeleteOptions{})

if err != nil {
log.Errorf("Error deleting service %v; %v", s.jobName(index), err)
log.Errorf("Error deleting service %v; %v", s.genName(index), err)
failures = true
}
}
Expand Down Expand Up @@ -359,7 +341,7 @@ func replicaStatusFromPodList(l v1.PodList, name tfv1alpha1.ContainerName) tfv1a
}

func (s *TFReplicaSet) GetSingleReplicaStatus(index int32) tfv1alpha1.ReplicaState {
j, err := s.ClientSet.BatchV1().Jobs(s.Job.job.ObjectMeta.Namespace).Get(s.jobName(index), meta_v1.GetOptions{})
j, err := s.ClientSet.BatchV1().Jobs(s.Job.job.ObjectMeta.Namespace).Get(s.genName(index), meta_v1.GetOptions{})

if err != nil {
return tfv1alpha1.ReplicaStateUnknown
Expand Down Expand Up @@ -436,10 +418,10 @@ func (s *TFReplicaSet) GetStatus() (tfv1alpha1.TFReplicaStatus, error) {
return status, nil
}

func (s *TFReplicaSet) jobName(index int32) string {
func (s *TFReplicaSet) genName(index int32) string {
// Truncate tfjob name to 40 characters
// The whole job name should be compliant with the DNS_LABEL spec, up to a max length of 63 characters
// Thus jobname(40 chars)-replicaType(6 chars)-runtimeId(4 chars)-index(4 chars), also leaving some spaces
// Thus genName(40 chars)-replicaType(6 chars)-runtimeId(4 chars)-index(4 chars), also leaving some spaces
// See https://github.com/kubernetes/community/blob/master/contributors/design-proposals/architecture/identifiers.md
return fmt.Sprintf("%v-%v-%v-%v", fmt.Sprintf("%.40s", s.Job.job.ObjectMeta.Name), strings.ToLower(string(s.Spec.TFReplicaType)), s.Job.job.Spec.RuntimeId, index)
}
Expand Down
32 changes: 16 additions & 16 deletions pkg/trainer/replicas_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,39 +136,39 @@ func TestTFReplicaSet(t *testing.T) {
t.Fatalf("Service.Metadata.OwnerReferences; Got %v; want %v", util.Pformat(s.ObjectMeta.OwnerReferences[0]), util.Pformat(expectedOwnerReference))
}

// Check that a job was created.
l, err := clientSet.BatchV1().Jobs(replica.Job.job.ObjectMeta.Namespace).List(meta_v1.ListOptions{})
// Check that a pod was created.
l, err := clientSet.CoreV1().Pods(replica.Job.job.ObjectMeta.Namespace).List(meta_v1.ListOptions{})
if err != nil {
t.Fatalf("List jobs error; %v", err)
t.Fatalf("List pods error; %v", err)
}

if len(l.Items) != 2 {
t.Fatalf("Expected 1 job got %v", len(l.Items))
t.Fatalf("Expected 1 pod got %v", len(l.Items))
}

j := l.Items[index]
p := l.Items[index]

if !reflect.DeepEqual(expectedLabels, j.ObjectMeta.Labels) {
t.Fatalf("Job Labels; Got %v Want: %v", expectedLabels, j.ObjectMeta.Labels)
if !reflect.DeepEqual(expectedLabels, p.ObjectMeta.Labels) {
t.Fatalf("Pod Labels; Got %v Want: %v", expectedLabels, p.ObjectMeta.Labels)
}

if j.ObjectMeta.Name != name {
t.Fatalf("Job.ObjectMeta.Name = %v; want %v", j.ObjectMeta.Name, name)
if p.ObjectMeta.Name != name {
t.Fatalf("Pod.ObjectMeta.Name = %v; want %v", p.ObjectMeta.Name, name)
}

if len(j.Spec.Template.Spec.Containers) != 1 {
t.Fatalf("Expected 1 container got %v", len(j.Spec.Template.Spec.Containers))
if len(p.Spec.Containers) != 1 {
t.Fatalf("Expected 1 container got %v", len(p.Spec.Containers))
}

if len(j.ObjectMeta.OwnerReferences) != 1 {
t.Fatalf("Expected 1 owner reference got %v", len(j.ObjectMeta.OwnerReferences))
if len(p.ObjectMeta.OwnerReferences) != 1 {
t.Fatalf("Expected 1 owner reference got %v", len(p.ObjectMeta.OwnerReferences))
}

if !reflect.DeepEqual(j.ObjectMeta.OwnerReferences[0], expectedOwnerReference) {
t.Fatalf("Job.Metadata.OwnerReferences; Got %v; want %v", util.Pformat(j.ObjectMeta.OwnerReferences[0]), util.Pformat(expectedOwnerReference))
if !reflect.DeepEqual(p.ObjectMeta.OwnerReferences[0], expectedOwnerReference) {
t.Fatalf("Pod.Metadata.OwnerReferences; Got %v; want %v", util.Pformat(p.ObjectMeta.OwnerReferences[0]), util.Pformat(expectedOwnerReference))
}

c := j.Spec.Template.Spec.Containers[0]
c := p.Spec.Containers[0]
if len(c.Env) != 1 {
t.Fatalf("Expected 1 environment variable got %v", len(c.Env))
}
Expand Down
28 changes: 14 additions & 14 deletions pkg/trainer/tensorboard.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (s *TBReplicaSet) Create() error {
// create the service exposing TensorBoard
service := &v1.Service{
ObjectMeta: meta_v1.ObjectMeta{
Name: s.jobName(),
Name: s.genName(),
Labels: s.Labels(),
OwnerReferences: []meta_v1.OwnerReference{
helper.AsOwner(s.Job.job),
Expand All @@ -91,15 +91,15 @@ func (s *TBReplicaSet) Create() error {
// If the job already exists do nothing.
if err != nil {
if k8s_errors.IsAlreadyExists(err) {
log.Infof("Service %v already exists.", s.jobName())
log.Infof("Service %v already exists.", s.genName())
} else {
return err
}
}

newD := &v1beta1.Deployment{
ObjectMeta: meta_v1.ObjectMeta{
Name: s.jobName(),
Name: s.genName(),
Labels: s.Labels(),
OwnerReferences: []meta_v1.OwnerReference{
helper.AsOwner(s.Job.job),
Expand All @@ -119,7 +119,7 @@ func (s *TBReplicaSet) Create() error {

if err != nil {
if k8s_errors.IsAlreadyExists(err) {
log.Infof("%v already exists.", s.jobName())
log.Infof("%v already exists.", s.genName())
} else {
return err
}
Expand All @@ -131,19 +131,19 @@ func (s *TBReplicaSet) Delete() error {
failures := false

delProp := meta_v1.DeletePropagationForeground
log.V(1).Infof("Deleting deployment %v:%v", s.Job.job.ObjectMeta.Namespace, s.jobName())
err := s.ClientSet.ExtensionsV1beta1().Deployments(s.Job.job.ObjectMeta.Namespace).Delete(s.jobName(), &meta_v1.DeleteOptions{
log.V(1).Infof("Deleting deployment %v:%v", s.Job.job.ObjectMeta.Namespace, s.genName())
err := s.ClientSet.ExtensionsV1beta1().Deployments(s.Job.job.ObjectMeta.Namespace).Delete(s.genName(), &meta_v1.DeleteOptions{
PropagationPolicy: &delProp,
})
if err != nil {
log.Errorf("There was a problem deleting TensorBoard's deployment %v; %v", s.jobName(), err)
log.Errorf("There was a problem deleting TensorBoard's deployment %v; %v", s.genName(), err)
failures = true
}

log.V(1).Infof("Deleting service %v:%v", s.Job.job.ObjectMeta.Namespace, s.jobName())
err = s.ClientSet.CoreV1().Services(s.Job.job.ObjectMeta.Namespace).Delete(s.jobName(), &meta_v1.DeleteOptions{})
log.V(1).Infof("Deleting service %v:%v", s.Job.job.ObjectMeta.Namespace, s.genName())
err = s.ClientSet.CoreV1().Services(s.Job.job.ObjectMeta.Namespace).Delete(s.genName(), &meta_v1.DeleteOptions{})
if err != nil {
log.Errorf("Error deleting service: %v; %v", s.jobName(), err)
log.Errorf("Error deleting service: %v; %v", s.genName(), err)
failures = true
}

Expand All @@ -156,7 +156,7 @@ func (s *TBReplicaSet) Delete() error {
func (s *TBReplicaSet) getDeploymentSpecTemplate(image string) v1.PodTemplateSpec {
// TODO: make the TensorFlow image a parameter of the job operator.
c := &v1.Container{
Name: s.jobName(),
Name: s.genName(),
Image: image,
Command: []string{
"tensorboard", "--logdir", s.Spec.LogDir, "--host", "0.0.0.0",
Expand All @@ -180,7 +180,7 @@ func (s *TBReplicaSet) getDeploymentSpecTemplate(image string) v1.PodTemplateSpe

return v1.PodTemplateSpec{
ObjectMeta: meta_v1.ObjectMeta{
Name: s.jobName(),
Name: s.genName(),
Labels: s.Labels(),
},
Spec: *ps,
Expand All @@ -197,10 +197,10 @@ func (s *TBReplicaSet) Labels() KubernetesLabels {
})
}

func (s *TBReplicaSet) jobName() string {
func (s *TBReplicaSet) genName() string {
// Truncate tfjob name to 40 characters
// The whole job name should be compliant with the DNS_LABEL spec, up to a max length of 63 characters
// Thus jobname(40 chars)-tensorboard(11 chars)-runtimeId(4 chars), also leaving some spaces
// Thus genName(40 chars)-tensorboard(11 chars)-runtimeId(4 chars), also leaving some spaces
// See https://github.com/kubernetes/community/blob/master/contributors/design-proposals/architecture/identifiers.md
return fmt.Sprintf("%v-tensorboard-%v", fmt.Sprintf("%.40s", s.Job.job.ObjectMeta.Name), strings.ToLower(s.Job.job.Spec.RuntimeId))
}
2 changes: 1 addition & 1 deletion pkg/trainer/training.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (j *TrainingJob) ClusterSpec() ClusterSpec {
replicaNames := make([]string, 0, *p.Spec.Replicas)

for i := int32(0); i < *p.Spec.Replicas; i++ {
replicaNames = append(replicaNames, fmt.Sprintf("%v:%v", p.jobName(i), *p.Spec.TFPort))
replicaNames = append(replicaNames, fmt.Sprintf("%v:%v", p.genName(i), *p.Spec.TFPort))
}

clusterSpec[strings.ToLower(string(p.Spec.TFReplicaType))] = replicaNames
Expand Down

0 comments on commit be20406

Please sign in to comment.