Skip to content
This repository was archived by the owner on Sep 12, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/apis/common/v1/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type ControllerInterface interface {
UpdateJobStatusInApiServer(job interface{}, jobStatus *JobStatus) error

// SetClusterSpec sets the cluster spec for the pod
SetClusterSpec(job interface{}, podTemplate *v1.PodTemplateSpec, rtype, index string) error
SetClusterSpec(job interface{}, podTemplate *v1.PodTemplateSpec, rtype ReplicaType, index string) error

// Returns the default container name in pod
GetDefaultContainerName() string
Expand Down
4 changes: 1 addition & 3 deletions pkg/controller.v1/common/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"
"reflect"
"sort"
"strings"
"time"

apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
Expand Down Expand Up @@ -368,8 +367,7 @@ func (jc *JobController) PastBackoffLimit(jobName string, runPolicy *apiv1.RunPo
continue
}
// Convert ReplicaType to lower string.
rt := strings.ToLower(string(rtype))
pods, err := jc.FilterPodsForReplicaType(pods, rt)
pods, err := jc.FilterPodsForReplicaType(pods, rtype)
if err != nil {
return false, err
}
Expand Down
29 changes: 13 additions & 16 deletions pkg/controller.v1/common/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ package common

import (
"fmt"
"reflect"
"strconv"
"strings"

"github.com/kubeflow/common/pkg/controller.v1/control"
"github.com/kubeflow/common/pkg/controller.v1/expectation"
"github.com/prometheus/client_golang/prometheus"
Expand All @@ -32,6 +28,8 @@ import (
"k8s.io/apimachinery/pkg/runtime"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/client-go/tools/cache"
"reflect"
"strconv"

apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
commonutil "github.com/kubeflow/common/pkg/util"
Expand Down Expand Up @@ -104,7 +102,7 @@ func (jc *JobController) AddPod(obj interface{}) {
}

rtype := pod.Labels[apiv1.ReplicaTypeLabel]
expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype)
expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, apiv1.ReplicaType(rtype))

jc.Expectations.CreationObserved(expectationPodsKey)
// TODO: we may need to add backoff here
Expand Down Expand Up @@ -205,7 +203,7 @@ func (jc *JobController) DeletePod(obj interface{}) {
}

rtype := pod.Labels[apiv1.ReplicaTypeLabel]
expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, rtype)
expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, apiv1.ReplicaType(rtype))

jc.Expectations.DeletionObserved(expectationPodsKey)
deletedPodsCount.Inc()
Expand Down Expand Up @@ -254,14 +252,14 @@ func (jc *JobController) GetPodsForJob(jobObject interface{}) ([]*v1.Pod, error)
}

// FilterPodsForReplicaType returns pods belonging to a replicaType.
func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType string) ([]*v1.Pod, error) {
func (jc *JobController) FilterPodsForReplicaType(pods []*v1.Pod, replicaType apiv1.ReplicaType) ([]*v1.Pod, error) {
var result []*v1.Pod

replicaSelector := &metav1.LabelSelector{
MatchLabels: make(map[string]string),
}

replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = replicaType
replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = string(replicaType)

for _, pod := range pods {
selector, err := metav1.LabelSelectorAsSelector(replicaSelector)
Expand Down Expand Up @@ -337,10 +335,9 @@ func (jc *JobController) ReconcilePods(
}

// Convert ReplicaType to lower string.
rt := strings.ToLower(string(rtype))
logger := commonutil.LoggerForReplica(metaObject, rt)
logger := commonutil.LoggerForReplica(metaObject, rtype)
// Get all pods for the type rt.
pods, err := jc.FilterPodsForReplicaType(pods, rt)
pods, err := jc.FilterPodsForReplicaType(pods, rtype)
if err != nil {
return err
}
Expand All @@ -358,13 +355,13 @@ func (jc *JobController) ReconcilePods(
podSlices := jc.GetPodSlices(pods, numReplicas, logger)
for index, podSlice := range podSlices {
if len(podSlice) > 1 {
logger.Warningf("We have too many pods for %s %d", rt, index)
logger.Warningf("We have too many pods for %s %d", rtype, index)
} else if len(podSlice) == 0 {
logger.Infof("Need to create new pod: %s-%d", rt, index)
logger.Infof("Need to create new pod: %s-%d", rtype, index)

// check if this replica is the master role
masterRole = jc.Controller.IsMasterRole(replicas, rtype, index)
err = jc.createNewPod(job, rt, strconv.Itoa(index), spec, masterRole, replicas)
err = jc.createNewPod(job, rtype, strconv.Itoa(index), spec, masterRole, replicas)
if err != nil {
return err
}
Expand Down Expand Up @@ -408,7 +405,7 @@ func (jc *JobController) ReconcilePods(
}

// createNewPod creates a new pod for the given index and type.
func (jc *JobController) createNewPod(job interface{}, rt, index string, spec *apiv1.ReplicaSpec, masterRole bool,
func (jc *JobController) createNewPod(job interface{}, rt apiv1.ReplicaType, index string, spec *apiv1.ReplicaSpec, masterRole bool,
replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) error {

metaObject, ok := job.(metav1.Object)
Expand All @@ -433,7 +430,7 @@ func (jc *JobController) createNewPod(job interface{}, rt, index string, spec *a

// Set type and index for the worker.
labels := jc.GenLabels(metaObject.GetName())
labels[apiv1.ReplicaTypeLabel] = rt
labels[apiv1.ReplicaTypeLabel] = string(rt)
labels[apiv1.ReplicaIndexLabel] = index

if masterRole {
Expand Down
30 changes: 12 additions & 18 deletions pkg/controller.v1/common/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ package common

import (
"fmt"
"strconv"
"strings"

apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
"github.com/kubeflow/common/pkg/controller.v1/control"
"github.com/kubeflow/common/pkg/controller.v1/expectation"
Expand All @@ -31,6 +28,7 @@ import (
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/runtime"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"strconv"
)

var (
Expand Down Expand Up @@ -71,8 +69,8 @@ func (jc *JobController) AddService(obj interface{}) {
return
}

rtype := service.Labels[apiv1.ReplicaTypeLabel]
expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype)
rtypeValue := service.Labels[apiv1.ReplicaTypeLabel]
expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, apiv1.ReplicaType(rtypeValue))

jc.Expectations.CreationObserved(expectationServicesKey)
// TODO: we may need to add backoff here
Expand Down Expand Up @@ -137,14 +135,14 @@ func (jc *JobController) GetServicesForJob(jobObject interface{}) ([]*v1.Service
}

// FilterServicesForReplicaType returns services belonging to a replicaType.
func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType string) ([]*v1.Service, error) {
func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, replicaType apiv1.ReplicaType) ([]*v1.Service, error) {
var result []*v1.Service

replicaSelector := &metav1.LabelSelector{
MatchLabels: make(map[string]string),
}

replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = replicaType
replicaSelector.MatchLabels[apiv1.ReplicaTypeLabel] = string(replicaType)

for _, service := range services {
selector, err := metav1.LabelSelectorAsSelector(replicaSelector)
Expand Down Expand Up @@ -209,12 +207,9 @@ func (jc *JobController) ReconcileServices(
rtype apiv1.ReplicaType,
spec *apiv1.ReplicaSpec) error {

// Convert ReplicaType to lower string.
rt := strings.ToLower(string(rtype))

replicas := int(*spec.Replicas)
// Get all services for the type rt.
services, err := jc.FilterServicesForReplicaType(services, rt)
services, err := jc.FilterServicesForReplicaType(services, rtype)
if err != nil {
return err
}
Expand All @@ -225,13 +220,13 @@ func (jc *JobController) ReconcileServices(
// If replica is 4, return a slice with size 4. [[0],[1],[2],[]], a svc with replica-index 3 will be created.
//
// If replica is 1, return a slice with size 3. [[0],[1],[2]], svc with replica-index 1 and 2 are out of range and will be deleted.
serviceSlices := jc.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rt))
serviceSlices := jc.GetServiceSlices(services, replicas, commonutil.LoggerForReplica(job, rtype))

for index, serviceSlice := range serviceSlices {
if len(serviceSlice) > 1 {
commonutil.LoggerForReplica(job, rt).Warningf("We have too many services for %s %d", rt, index)
commonutil.LoggerForReplica(job, rtype).Warningf("We have too many services for %s %d", rtype, index)
} else if len(serviceSlice) == 0 {
commonutil.LoggerForReplica(job, rt).Infof("need to create new service: %s-%d", rt, index)
commonutil.LoggerForReplica(job, rtype).Infof("need to create new service: %s-%d", rtype, index)
err = jc.CreateNewService(job, rtype, spec, strconv.Itoa(index))
if err != nil {
return err
Expand Down Expand Up @@ -283,16 +278,15 @@ func (jc *JobController) CreateNewService(job metav1.Object, rtype apiv1.Replica
}

// Convert ReplicaType to lower string.
rt := strings.ToLower(string(rtype))
expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rt)
expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, rtype)
err = jc.Expectations.ExpectCreations(expectationServicesKey, 1)
if err != nil {
return err
}

// Append ReplicaTypeLabel and ReplicaIndexLabel labels.
labels := jc.GenLabels(job.GetName())
labels[apiv1.ReplicaTypeLabel] = rt
labels[apiv1.ReplicaTypeLabel] = string(rtype)
labels[apiv1.ReplicaIndexLabel] = index

ports, err := jc.GetPortsFromJob(spec)
Expand All @@ -314,7 +308,7 @@ func (jc *JobController) CreateNewService(job metav1.Object, rtype apiv1.Replica
service.Spec.Ports = append(service.Spec.Ports, svcPort)
}

service.Name = GenGeneralName(job.GetName(), rt, index)
service.Name = GenGeneralName(job.GetName(), rtype, index)
service.Labels = labels
// Create OwnerReference.
controllerRef := jc.GenOwnerReference(job)
Expand Down
4 changes: 2 additions & 2 deletions pkg/controller.v1/common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ func (p ReplicasPriority) Swap(i, j int) {
p[i], p[j] = p[j], p[i]
}

func GenGeneralName(jobName, rtype, index string) string {
n := jobName + "-" + rtype + "-" + index
func GenGeneralName(jobName string, rtype apiv1.ReplicaType, index string) string {
n := jobName + "-" + string(rtype) + "-" + index
return strings.Replace(n, "/", "-", -1)
}

Expand Down
3 changes: 2 additions & 1 deletion pkg/controller.v1/common/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ package common

import (
"fmt"
apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
"github.com/stretchr/testify/assert"
"testing"
)

func TestGenGeneralName(t *testing.T) {
testRType := "worker"
var testRType apiv1.ReplicaType = "worker"
testIndex := "1"
testKey := "1/2/3/4/5"
expectedName := fmt.Sprintf("1-2-3-4-5-%s-%s", testRType, testIndex)
Expand Down
13 changes: 8 additions & 5 deletions pkg/controller.v1/expectation/util.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package expectation

import "strings"
import (
apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
"strings"
)

// GenExpectationPodsKey generates an expectation key for pods of a job
func GenExpectationPodsKey(jobKey, replicaType string) string {
return jobKey + "/" + strings.ToLower(replicaType) + "/pods"
func GenExpectationPodsKey(jobKey string, replicaType apiv1.ReplicaType) string {
return jobKey + "/" + strings.ToLower(string(replicaType)) + "/pods"
}

// GenExpectationServicesKey generates an expectation key for services of a job
func GenExpectationServicesKey(jobKey, replicaType string) string {
return jobKey + "/" + strings.ToLower(replicaType) + "/services"
func GenExpectationServicesKey(jobKey string, replicaType apiv1.ReplicaType) string {
return jobKey + "/" + strings.ToLower(string(replicaType)) + "/services"
}
3 changes: 2 additions & 1 deletion pkg/util/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package util

import (
apiv1 "github.com/kubeflow/common/pkg/apis/common/v1"
"strings"

log "github.com/sirupsen/logrus"
Expand All @@ -23,7 +24,7 @@ import (
metav1unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
)

func LoggerForReplica(job metav1.Object, rtype string) *log.Entry {
func LoggerForReplica(job metav1.Object, rtype apiv1.ReplicaType) *log.Entry {
return log.WithFields(log.Fields{
// We use job to match the key used in controller.go
// It's more common in K8s to use a period to indicate namespace.name. So that's what we use.
Expand Down
2 changes: 1 addition & 1 deletion test_job/controller.v1/test_job/test_job_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (t *TestJobController) UpdateJobStatusInApiServer(job interface{}, jobStatu
return nil
}

func (t *TestJobController) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
func (t *TestJobController) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype commonv1.ReplicaType, index string) error {
return nil
}

Expand Down