Skip to content

Commit

Permalink
Prebuilt workload suport (#1358)
Browse files Browse the repository at this point in the history
* [jobframework] Add prebuilt workload support

* [batch/job] Prebuilt workload support

* [batch/job] Prebuilt workload support - tests

* Fixup after rebase

* Review Remarks

* Review Remarks

* Split the functionality of `getPrebuiltWorkload`

* Drop queue vs prebuilt workload exclusivity.

* Review Remarks

* Review Remarks

* Review Remarks

* Review Remarks

* Review Remarks

* Review Remarks

* Review Remarks

* Review Remarks

* Review Remarks
  • Loading branch information
trasc committed Dec 18, 2023
1 parent d69244f commit bb69906
Show file tree
Hide file tree
Showing 13 changed files with 575 additions and 5 deletions.
3 changes: 3 additions & 0 deletions pkg/controller/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ const (
// DEPRECATED: Use QueueLabel as a label key.
QueueAnnotation = QueueLabel

// PrebuiltWorkloadLabel is the label key of the job holding the name of the pre-built workload to use.
PrebuiltWorkloadLabel = "kueue.x-k8s.io/prebuilt-workload-name"

// ParentWorkloadAnnotation is the annotation used to mark a kubernetes Job
// as a child of a Workload. The value is the name of the workload,
// in the same namespace. It is used when the parent workload corresponds to
Expand Down
5 changes: 5 additions & 0 deletions pkg/controller/jobframework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,8 @@ func workloadPriorityClassName(job GenericJob) string {
}
return ""
}

func prebuiltWorkload(job GenericJob) (string, bool) {
name, found := job.Object().GetLabels()[constants.PrebuiltWorkloadLabel]
return name, found
}
66 changes: 66 additions & 0 deletions pkg/controller/jobframework/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (
"sigs.k8s.io/kueue/pkg/podset"
"sigs.k8s.io/kueue/pkg/util/equality"
"sigs.k8s.io/kueue/pkg/util/kubeversion"
"sigs.k8s.io/kueue/pkg/util/maps"
utilpriority "sigs.k8s.io/kueue/pkg/util/priority"
utilslices "sigs.k8s.io/kueue/pkg/util/slices"
"sigs.k8s.io/kueue/pkg/workload"
Expand Down Expand Up @@ -439,6 +440,23 @@ func (r *JobReconciler) getParentWorkload(ctx context.Context, job GenericJob, o
func (r *JobReconciler) ensureOneWorkload(ctx context.Context, job GenericJob, object client.Object) (*kueue.Workload, error) {
log := ctrl.LoggerFrom(ctx)

if prebuiltWorkloadName, usePrebuiltWorkload := prebuiltWorkload(job); usePrebuiltWorkload {
wl := &kueue.Workload{}
err := r.client.Get(ctx, types.NamespacedName{Name: prebuiltWorkloadName, Namespace: object.GetNamespace()}, wl)
if err != nil {
return nil, client.IgnoreNotFound(err)
}

if owns, err := r.ensurePrebuiltWorkloadOwnership(ctx, wl, object); !owns || err != nil {
return nil, err
}

if inSync, err := r.ensurePrebuiltWorkloadInSync(ctx, wl, job); !inSync || err != nil {
return nil, err
}
return wl, nil
}

// Find a matching workload first if there is one.
var toDelete []*kueue.Workload
var match *kueue.Workload
Expand Down Expand Up @@ -537,6 +555,41 @@ func FindMatchingWorkloads(ctx context.Context, c client.Client, job GenericJob)
return match, toDelete, nil
}

func (r *JobReconciler) ensurePrebuiltWorkloadOwnership(ctx context.Context, wl *kueue.Workload, object client.Object) (bool, error) {
if !metav1.IsControlledBy(wl, object) {
if err := ctrl.SetControllerReference(object, wl, r.client.Scheme()); err != nil {
// don't return an error here, since a retry cannot give a different result,
// log the error.
log := ctrl.LoggerFrom(ctx)
log.Error(err, "Cannot take ownership of the workload")
return false, nil
}

if errs := validation.IsValidLabelValue(string(object.GetUID())); len(errs) == 0 {
wl.Labels = maps.MergeKeepFirst(map[string]string{controllerconsts.JobUIDLabel: string(object.GetUID())}, wl.Labels)
}

if err := r.client.Update(ctx, wl); err != nil {
return false, err
}
}
return true, nil
}

func (r *JobReconciler) ensurePrebuiltWorkloadInSync(ctx context.Context, wl *kueue.Workload, job GenericJob) (bool, error) {
if !equivalentToWorkload(job, wl) {
// mark the workload as finished
err := workload.UpdateStatus(ctx, r.client, wl,
kueue.WorkloadFinished,
metav1.ConditionTrue,
"OutOfSync",
"The prebuilt workload is out of sync with its user job",
constants.JobControllerName)
return false, err
}
return true, nil
}

// equivalentToWorkload checks if the job corresponds to the workload
func equivalentToWorkload(job GenericJob, wl *kueue.Workload) bool {
owner := metav1.GetControllerOf(wl)
Expand Down Expand Up @@ -780,12 +833,25 @@ func (r *JobReconciler) getPodSetsInfoFromStatus(ctx context.Context, w *kueue.W
func (r *JobReconciler) handleJobWithNoWorkload(ctx context.Context, job GenericJob, object client.Object) error {
log := ctrl.LoggerFrom(ctx)

_, usePrebuiltWorkload := prebuiltWorkload(job)
if usePrebuiltWorkload {
// Stop the job if not already suspended
if stopErr := r.stopJob(ctx, job, nil, StopReasonNoMatchingWorkload, "missing workload"); stopErr != nil {
return stopErr
}
}

// Wait until there are no active pods.
if job.IsActive() {
log.V(2).Info("Job is suspended but still has active pods, waiting")
return nil
}

if usePrebuiltWorkload {
log.V(2).Info("Skip workload creation for job with prebuilt workload")
return nil
}

// Create the corresponding workload.
wl, err := r.constructWorkload(ctx, job, object)
if err != nil {
Expand Down
16 changes: 16 additions & 0 deletions pkg/controller/jobframework/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ limitations under the License.
package jobframework

import (
"fmt"
"strings"

apivalidation "k8s.io/apimachinery/pkg/api/validation"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/validation"
"k8s.io/apimachinery/pkg/util/validation/field"

Expand All @@ -29,12 +31,22 @@ var (
parentWorkloadKeyPath = annotationsPath.Key(constants.ParentWorkloadAnnotation)
queueNameLabelPath = labelsPath.Key(constants.QueueLabel)
workloadPriorityClassNamePath = labelsPath.Key(constants.WorkloadPriorityClassLabel)
supportedPrebuiltWlJobGVKs = sets.New("batch/v1, Kind=Job")
)

func ValidateCreateForQueueName(job GenericJob) field.ErrorList {
var allErrs field.ErrorList
allErrs = append(allErrs, ValidateLabelAsCRDName(job, constants.QueueLabel)...)
allErrs = append(allErrs, ValidateLabelAsCRDName(job, constants.PrebuiltWorkloadLabel)...)
allErrs = append(allErrs, ValidateAnnotationAsCRDName(job, constants.QueueAnnotation)...)

// this rule should be relaxed when its confirmed that running wit a prebuilt wl is fully supported by each integration
if _, hasPrebuilt := job.Object().GetLabels()[constants.PrebuiltWorkloadLabel]; hasPrebuilt {
gvk := job.GVK().String()
if !supportedPrebuiltWlJobGVKs.Has(gvk) {
allErrs = append(allErrs, field.Forbidden(labelsPath.Key(constants.PrebuiltWorkloadLabel), fmt.Sprintf("Is not supported for %q", gvk)))
}
}
return allErrs
}

Expand Down Expand Up @@ -73,6 +85,10 @@ func ValidateUpdateForQueueName(oldJob, newJob GenericJob) field.ErrorList {
if !newJob.IsSuspended() {
allErrs = append(allErrs, apivalidation.ValidateImmutableField(QueueName(oldJob), QueueName(newJob), queueNameLabelPath)...)
}

oldWlName, _ := prebuiltWorkload(oldJob)
newWlName, _ := prebuiltWorkload(newJob)
allErrs = append(allErrs, apivalidation.ValidateImmutableField(oldWlName, newWlName, labelsPath.Key(constants.PrebuiltWorkloadLabel))...)
return allErrs
}

Expand Down
26 changes: 24 additions & 2 deletions pkg/controller/jobs/job/job_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,27 @@ func (j *Job) ReclaimablePods() ([]kueue.ReclaimablePod, error) {
}}, nil
}

// The following labels are managed internally by batch/job controller, we should not
// propagate them to the workload.
var (
// the legacy names are no longer defined in the api, only in k/2/apis/batch
legacyJobNameLabel = "job-name"
legacyControllerUidLabel = "controller-uid"
managedLabels = []string{legacyJobNameLabel, legacyControllerUidLabel, batchv1.JobNameLabel, batchv1.ControllerUidLabel}
)

func cleanManagedLabels(pt *corev1.PodTemplateSpec) *corev1.PodTemplateSpec {
for _, managedLabel := range managedLabels {
delete(pt.Labels, managedLabel)
}
return pt
}

func (j *Job) PodSets() []kueue.PodSet {
return []kueue.PodSet{
{
Name: kueue.DefaultPodSetName,
Template: *j.Spec.Template.DeepCopy(),
Template: *cleanManagedLabels(j.Spec.Template.DeepCopy()),
Count: j.podsCount(),
MinCount: j.minPodsCount(),
},
Expand Down Expand Up @@ -247,7 +263,13 @@ func (j *Job) RestorePodSetsInfo(podSetsInfo []podset.PodSetInfo) bool {
j.Spec.Completions = j.Spec.Parallelism
}
}
changed = podset.RestorePodSpec(&j.Spec.Template.ObjectMeta, &j.Spec.Template.Spec, podSetsInfo[0]) || changed
info := podSetsInfo[0]
for _, managedLabel := range managedLabels {
if v, found := j.Spec.Template.Labels[managedLabel]; found {
info.AddOrUpdateLabel(managedLabel, v)
}
}
changed = podset.RestorePodSpec(&j.Spec.Template.ObjectMeta, &j.Spec.Template.Spec, info) || changed
return changed
}

Expand Down
154 changes: 151 additions & 3 deletions pkg/controller/jobs/job/job_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,20 @@ var (
cmpopts.IgnoreFields(metav1.Condition{}, "LastTransitionTime"),
cmpopts.IgnoreFields(kueue.AdmissionCheckState{}, "LastTransitionTime"),
}
workloadCmpOptsWithOwner = []cmp.Option{
cmpopts.EquateEmpty(),
cmpopts.SortSlices(func(a, b kueue.Workload) bool {
return a.Name < b.Name
}),
cmpopts.SortSlices(func(a, b metav1.Condition) bool {
return a.Type < b.Type
}),
cmpopts.IgnoreFields(
kueue.Workload{}, "TypeMeta", "ObjectMeta.Name", "ObjectMeta.ResourceVersion",
),
cmpopts.IgnoreFields(metav1.Condition{}, "LastTransitionTime"),
cmpopts.IgnoreFields(kueue.AdmissionCheckState{}, "LastTransitionTime"),
}
)

func TestReconciler(t *testing.T) {
Expand Down Expand Up @@ -1668,6 +1682,128 @@ func TestReconciler(t *testing.T) {
Obj(),
wantWorkloads: []kueue.Workload{},
},
"when the prebuilt workload is missing, no new one is created and the job is suspended": {
job: *baseJobWrapper.
Clone().
Suspend(false).
Label(controllerconsts.PrebuiltWorkloadLabel, "missing-workload").
UID("test-uid").
Obj(),
wantJob: *baseJobWrapper.
Clone().
Label(controllerconsts.PrebuiltWorkloadLabel, "missing-workload").
UID("test-uid").
Obj(),
},
"when the prebuilt workload exists its owner info is updated": {
job: *baseJobWrapper.
Clone().
Suspend(false).
Label(controllerconsts.PrebuiltWorkloadLabel, "prebuilt-workload").
UID("test-uid").
Obj(),
wantJob: *baseJobWrapper.
Clone().
Label(controllerconsts.PrebuiltWorkloadLabel, "prebuilt-workload").
UID("test-uid").
Obj(),
workloads: []kueue.Workload{
*utiltesting.MakeWorkload("prebuilt-workload", "ns").Finalizers(kueue.ResourceInUseFinalizerName).
PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").PriorityClass("test-pc").Obj()).
Queue("test-queue").
PriorityClass("test-wpc").
Priority(100).
PriorityClassSource(constants.WorkloadPriorityClassSource).
Obj(),
},
wantWorkloads: []kueue.Workload{
*utiltesting.MakeWorkload("prebuilt-workload", "ns").Finalizers(kueue.ResourceInUseFinalizerName).
PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").PriorityClass("test-pc").Obj()).
Queue("test-queue").
PriorityClass("test-wpc").
Priority(100).
PriorityClassSource(constants.WorkloadPriorityClassSource).
Labels(map[string]string{
controllerconsts.JobUIDLabel: "test-uid",
}).
OwnerReference(batchv1.SchemeGroupVersion.String(), "Job", "job", "test-uid", true, true).
Obj(),
},
},
"when the prebuilt workload is owned by another object": {
job: *baseJobWrapper.
Clone().
Suspend(false).
Label(controllerconsts.PrebuiltWorkloadLabel, "prebuilt-workload").
UID("test-uid").
Obj(),
wantJob: *baseJobWrapper.
Clone().
Label(controllerconsts.PrebuiltWorkloadLabel, "prebuilt-workload").
UID("test-uid").
Obj(),
workloads: []kueue.Workload{
*utiltesting.MakeWorkload("prebuilt-workload", "ns").Finalizers(kueue.ResourceInUseFinalizerName).
PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").PriorityClass("test-pc").Obj()).
Queue("test-queue").
PriorityClass("test-wpc").
Priority(100).
PriorityClassSource(constants.WorkloadPriorityClassSource).
OwnerReference(batchv1.SchemeGroupVersion.String(), "Job", "other-job", "other-uid", true, true).
Obj(),
},
wantWorkloads: []kueue.Workload{
*utiltesting.MakeWorkload("prebuilt-workload", "ns").Finalizers(kueue.ResourceInUseFinalizerName).
PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 10).Request(corev1.ResourceCPU, "1").PriorityClass("test-pc").Obj()).
Queue("test-queue").
PriorityClass("test-wpc").
Priority(100).
PriorityClassSource(constants.WorkloadPriorityClassSource).
OwnerReference(batchv1.SchemeGroupVersion.String(), "Job", "other-job", "other-uid", true, true).
Obj(),
},
},
"when the prebuilt workload is not equivalent to the job": {
job: *baseJobWrapper.
Clone().
Suspend(false).
Label(controllerconsts.PrebuiltWorkloadLabel, "prebuilt-workload").
UID("test-uid").
Obj(),
wantJob: *baseJobWrapper.
Clone().
Label(controllerconsts.PrebuiltWorkloadLabel, "prebuilt-workload").
UID("test-uid").
Obj(),
workloads: []kueue.Workload{
*utiltesting.MakeWorkload("prebuilt-workload", "ns").Finalizers(kueue.ResourceInUseFinalizerName).
PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 1).Request(corev1.ResourceCPU, "1").PriorityClass("test-pc").Obj()).
Queue("test-queue").
PriorityClass("test-wpc").
Priority(100).
PriorityClassSource(constants.WorkloadPriorityClassSource).
Obj(),
},
wantWorkloads: []kueue.Workload{
*utiltesting.MakeWorkload("prebuilt-workload", "ns").Finalizers(kueue.ResourceInUseFinalizerName).
PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 1).Request(corev1.ResourceCPU, "1").PriorityClass("test-pc").Obj()).
Queue("test-queue").
PriorityClass("test-wpc").
Priority(100).
PriorityClassSource(constants.WorkloadPriorityClassSource).
Labels(map[string]string{
controllerconsts.JobUIDLabel: "test-uid",
}).
OwnerReference(batchv1.SchemeGroupVersion.String(), "Job", "job", "test-uid", true, true).
Condition(metav1.Condition{
Type: kueue.WorkloadFinished,
Status: metav1.ConditionTrue,
Reason: "OutOfSync",
Message: "The prebuilt workload is out of sync with its user job",
}).
Obj(),
},
},
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
Expand All @@ -1684,10 +1820,16 @@ func TestReconciler(t *testing.T) {
kcBuilder = kcBuilder.WithStatusSubresource(&tc.workloads[i])
}

// For prebuilt workloads we are skipping the ownership setup in the test body and
// expect the reconciler to do it.
_, useesPrebuiltWorkload := tc.job.Labels[controllerconsts.PrebuiltWorkloadLabel]

kClient := kcBuilder.Build()
for i := range tc.workloads {
if err := ctrl.SetControllerReference(&tc.job, &tc.workloads[i], kClient.Scheme()); err != nil {
t.Fatalf("Could not setup owner reference in Workloads: %v", err)
if !useesPrebuiltWorkload {
if err := ctrl.SetControllerReference(&tc.job, &tc.workloads[i], kClient.Scheme()); err != nil {
t.Fatalf("Could not setup owner reference in Workloads: %v", err)
}
}
if err := kClient.Create(ctx, &tc.workloads[i]); err != nil {
t.Fatalf("Could not create workload: %v", err)
Expand Down Expand Up @@ -1715,7 +1857,13 @@ func TestReconciler(t *testing.T) {
if err := kClient.List(ctx, &gotWorkloads); err != nil {
t.Fatalf("Could not get Workloads after reconcile: %v", err)
}
if diff := cmp.Diff(tc.wantWorkloads, gotWorkloads.Items, workloadCmpOpts...); diff != "" {

wlCheckOpts := workloadCmpOpts
if useesPrebuiltWorkload {
wlCheckOpts = workloadCmpOptsWithOwner
}

if diff := cmp.Diff(tc.wantWorkloads, gotWorkloads.Items, wlCheckOpts...); diff != "" {
t.Errorf("Workloads after reconcile (-want,+got):\n%s", diff)
}
})
Expand Down

0 comments on commit bb69906

Please sign in to comment.