diff --git a/pkg/controller/core/workload_controller.go b/pkg/controller/core/workload_controller.go
index 1c00783f34..b899386ef7 100644
--- a/pkg/controller/core/workload_controller.go
+++ b/pkg/controller/core/workload_controller.go
@@ -48,6 +48,7 @@ import (
 	"sigs.k8s.io/kueue/pkg/cache"
 	"sigs.k8s.io/kueue/pkg/constants"
 	"sigs.k8s.io/kueue/pkg/controller/core/indexer"
+	"sigs.k8s.io/kueue/pkg/features"
 	"sigs.k8s.io/kueue/pkg/queue"
 	"sigs.k8s.io/kueue/pkg/util/slices"
 	"sigs.k8s.io/kueue/pkg/workload"
@@ -192,6 +193,33 @@ func (r *WorkloadReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
 		return ctrl.Result{}, err
 	}
 
+	if features.Enabled(features.DynamicallySizedJobs) && ptr.Deref(wl.Status.Admission.PodSetAssignments[1].Count, 0) != wl.Spec.PodSets[1].Count {
+		// Get Memory and CPU values
+		originalMemoryPerPod := wl.Spec.PodSets[1].Template.Spec.Containers[0].Resources.Requests.Memory()
+		originalCpuPerPod := wl.Spec.PodSets[1].Template.Spec.Containers[0].Resources.Requests.Cpu()
+		currentAssignedMemoryPerPod := wl.Status.Admission.PodSetAssignments[1].ResourceUsage.Memory()
+		currentAssignedCpuPerPod := wl.Status.Admission.PodSetAssignments[1].ResourceUsage.Cpu()
+
+		diff := ptr.Deref(wl.Status.Admission.PodSetAssignments[1].Count, 0) - wl.Spec.PodSets[1].Count
+		originalMemoryPerPod.Mul(int64(diff))
+		originalCpuPerPod.Mul(int64(diff))
+
+		currentAssignedMemoryPerPod.Sub(*originalMemoryPerPod)
+		currentAssignedCpuPerPod.Sub(*originalCpuPerPod)
+
+		wl.Status.Admission.PodSetAssignments[1].Count = ptr.To(wl.Spec.PodSets[1].Count)
+		wl.Status.Admission.PodSetAssignments[1].ResourceUsage = corev1.ResourceList{
+			corev1.ResourceMemory: *currentAssignedMemoryPerPod,
+			corev1.ResourceCPU:    *currentAssignedCpuPerPod,
+		}
+
+		// Update Status
+		workload.SyncAdmittedCondition(&wl)
+		if err := workload.ApplyAdmissionStatus(ctx, r.client, &wl, true); err != nil {
+			return ctrl.Result{}, err
+		}
+	}
+
 		return r.reconcileNotReadyTimeout(ctx, req, &wl)
 	}
 
@@ -237,7 +265,6 @@ func (r *WorkloadReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
 			return ctrl.Result{}, client.IgnoreNotFound(err)
 		}
 	}
-
 	return ctrl.Result{}, nil
 }
 
diff --git a/pkg/controller/jobframework/reconciler.go b/pkg/controller/jobframework/reconciler.go
index 36159f21cc..d677cf50b2 100644
--- a/pkg/controller/jobframework/reconciler.go
+++ b/pkg/controller/jobframework/reconciler.go
@@ -346,6 +346,21 @@ func (r *JobReconciler) ReconcileGenericJob(ctx context.Context, req ctrl.Reques
 		}
 	}
 
+	// 4.1 update podSetCount for RayCluster resize
+	if features.Enabled(features.DynamicallySizedJobs) && workload.IsAdmitted(wl) && job.GVK().Kind == "RayCluster" {
+		podSets := job.PodSets()
+		jobPodSetCount := podSets[1].Count
+		workloadPodSetCount := wl.Spec.PodSets[1].Count
+		if workloadPodSetCount > jobPodSetCount {
+			toUpdate := wl
+			_, err := r.updateWorkloadToMatchJob(ctx, job, object, toUpdate, "Updated Workload due to resize: %v")
+			if err != nil {
+				return ctrl.Result{}, err
+			}
+			return ctrl.Result{}, nil
+		}
+	}
+
 	// 5. handle WaitForPodsReady only for a standalone job.
 	// handle a job when waitForPodsReady is enabled, and it is the main job
 	if r.waitForPodsReady {
@@ -572,7 +587,7 @@ func (r *JobReconciler) ensureOneWorkload(ctx context.Context, job GenericJob, o
 	}
 
 	if toUpdate != nil {
-		return r.updateWorkloadToMatchJob(ctx, job, object, toUpdate)
+		return r.updateWorkloadToMatchJob(ctx, job, object, toUpdate, "Updated not matching Workload for suspended job: %v")
 	}
 	return match, nil
 }
@@ -677,7 +692,8 @@ func equivalentToWorkload(ctx context.Context, c client.Client, job GenericJob,
 	jobPodSets := clearMinCountsIfFeatureDisabled(job.PodSets())
 
 	if runningPodSets := expectedRunningPodSets(ctx, c, wl); runningPodSets != nil {
-		if equality.ComparePodSetSlices(jobPodSets, runningPodSets) {
+		jobPodSetCount := job.PodSets()
+		if equality.ComparePodSetSlices(jobPodSets, runningPodSets) || (features.Enabled(features.DynamicallySizedJobs) && job.GVK().Kind == "RayCluster" && jobPodSetCount[1].Count < wl.Spec.PodSets[1].Count) {
 			return true
 		}
 		// If the workload is admitted but the job is suspended, do the check
@@ -690,7 +706,7 @@ func equivalentToWorkload(ctx context.Context, c client.Client, job GenericJob,
 	return equality.ComparePodSetSlices(jobPodSets, wl.Spec.PodSets)
 }
-func (r *JobReconciler) updateWorkloadToMatchJob(ctx context.Context, job GenericJob, object client.Object, wl *kueue.Workload) (*kueue.Workload, error) {
+func (r *JobReconciler) updateWorkloadToMatchJob(ctx context.Context, job GenericJob, object client.Object, wl *kueue.Workload, message string) (*kueue.Workload, error) {
 	newWl, err := r.constructWorkload(ctx, job, object)
 	if err != nil {
 		return nil, fmt.Errorf("can't construct workload for update: %w", err)
 	}
@@ -705,7 +721,7 @@ func (r *JobReconciler) updateWorkloadToMatchJob(ctx context.Context, job Generi
 	}
 
 	r.record.Eventf(object, corev1.EventTypeNormal, ReasonUpdatedWorkload,
-		"Updated not matching Workload for suspended job: %v", klog.KObj(wl))
+		message, klog.KObj(wl))
 	return newWl, nil
 }
 
diff --git a/pkg/features/kube_features.go b/pkg/features/kube_features.go
index b9d75db47d..a653ecb25c 100644
--- a/pkg/features/kube_features.go
+++ b/pkg/features/kube_features.go
@@ -83,6 +83,10 @@ const (
 	//
 	// Enables lending limit.
 	LendingLimit featuregate.Feature = "LendingLimit"
+	// owner: @vicenteferrara
+	// kep:
+	// alpha: v0.8
+	DynamicallySizedJobs featuregate.Feature = "DynamicallySizedJobs"
 )
 
 func init() {
@@ -104,6 +108,7 @@ var defaultFeatureGates = map[featuregate.Feature]featuregate.FeatureSpec{
 	PrioritySortingWithinCohort: {Default: true, PreRelease: featuregate.Beta},
 	MultiKueue:                  {Default: false, PreRelease: featuregate.Alpha},
 	LendingLimit:                {Default: false, PreRelease: featuregate.Alpha},
+	DynamicallySizedJobs:        {Default: false, PreRelease: featuregate.Alpha},
 }
 
 func SetFeatureGateDuringTest(tb testing.TB, f featuregate.Feature, value bool) func() {
diff --git a/pkg/util/testingjobs/raycluster/wrappers.go b/pkg/util/testingjobs/raycluster/wrappers.go
index 7d55e93bdf..e573d4974b 100644
--- a/pkg/util/testingjobs/raycluster/wrappers.go
+++ b/pkg/util/testingjobs/raycluster/wrappers.go
@@ -80,6 +80,12 @@ func (j *ClusterWrapper) NodeSelectorHeadGroup(k, v string) *ClusterWrapper {
 	return j
 }
 
+// SetReplicaCount sets the replica count of the first worker group.
+func (j *ClusterWrapper) SetReplicaCount(c int32) *ClusterWrapper {
+	j.Spec.WorkerGroupSpecs[0].Replicas = ptr.To(c)
+	return j
+}
+
 // Obj returns the inner Job.
 func (j *ClusterWrapper) Obj() *rayv1.RayCluster {
 	return &j.RayCluster
diff --git a/pkg/webhooks/workload_webhook.go b/pkg/webhooks/workload_webhook.go
index a4527f693c..1214a59e35 100644
--- a/pkg/webhooks/workload_webhook.go
+++ b/pkg/webhooks/workload_webhook.go
@@ -347,7 +347,9 @@ func ValidateWorkloadUpdate(newObj, oldObj *kueue.Workload) field.ErrorList {
 	allErrs = append(allErrs, ValidateWorkload(newObj)...)
 
 	if workload.HasQuotaReservation(oldObj) {
-		allErrs = append(allErrs, apivalidation.ValidateImmutableField(newObj.Spec.PodSets, oldObj.Spec.PodSets, specPath.Child("podSets"))...)
+		if !features.Enabled(features.DynamicallySizedJobs) {
+			allErrs = append(allErrs, apivalidation.ValidateImmutableField(newObj.Spec.PodSets, oldObj.Spec.PodSets, specPath.Child("podSets"))...)
+		}
 		allErrs = append(allErrs, apivalidation.ValidateImmutableField(newObj.Spec.PriorityClassSource, oldObj.Spec.PriorityClassSource, specPath.Child("priorityClassSource"))...)
 		allErrs = append(allErrs, apivalidation.ValidateImmutableField(newObj.Spec.PriorityClassName, oldObj.Spec.PriorityClassName, specPath.Child("priorityClassName"))...)
 	}
@@ -355,7 +357,9 @@ func ValidateWorkloadUpdate(newObj, oldObj *kueue.Workload) field.ErrorList {
 		allErrs = append(allErrs, apivalidation.ValidateImmutableField(newObj.Spec.QueueName, oldObj.Spec.QueueName, specPath.Child("queueName"))...)
 		allErrs = append(allErrs, validateReclaimablePodsUpdate(newObj, oldObj, field.NewPath("status", "reclaimablePods"))...)
 	}
-	allErrs = append(allErrs, validateAdmissionUpdate(newObj.Status.Admission, oldObj.Status.Admission, field.NewPath("status", "admission"))...)
+	if !features.Enabled(features.DynamicallySizedJobs) {
+		allErrs = append(allErrs, validateAdmissionUpdate(newObj.Status.Admission, oldObj.Status.Admission, field.NewPath("status", "admission"))...)
+	}
 	allErrs = append(allErrs, validateImmutablePodSetUpdates(newObj, oldObj, statusPath.Child("admissionChecks"))...)
 
 	return allErrs
diff --git a/test/integration/controller/jobs/raycluster/raycluster_controller_test.go b/test/integration/controller/jobs/raycluster/raycluster_controller_test.go
index 37bdd78c89..f338730aed 100644
--- a/test/integration/controller/jobs/raycluster/raycluster_controller_test.go
+++ b/test/integration/controller/jobs/raycluster/raycluster_controller_test.go
@@ -39,6 +39,7 @@ import (
 	"sigs.k8s.io/kueue/pkg/controller/jobframework"
 	workloadraycluster "sigs.k8s.io/kueue/pkg/controller/jobs/raycluster"
 	_ "sigs.k8s.io/kueue/pkg/controller/jobs/rayjob" // to enable the framework
+	"sigs.k8s.io/kueue/pkg/features"
 	"sigs.k8s.io/kueue/pkg/util/testing"
 	testingraycluster "sigs.k8s.io/kueue/pkg/util/testingjobs/raycluster"
 	testingrayjob "sigs.k8s.io/kueue/pkg/util/testingjobs/rayjob"
@@ -205,7 +206,7 @@ var _ = ginkgo.Describe("RayCluster controller", ginkgo.Ordered, ginkgo.Continue
 			return apimeta.IsStatusConditionTrue(createdWorkload.Status.Conditions, kueue.WorkloadQuotaReserved)
 		}, util.Timeout, util.Interval).Should(gomega.BeTrue())
 
-		ginkgo.By("checking the job gets suspended when parallelism changes and the added node selectors are removed")
+		ginkgo.By("checking the job is suspended when parallelism increases and the added node selectors are removed")
 		parallelism := ptr.Deref(job.Spec.WorkerGroupSpecs[0].Replicas, 1)
 		newParallelism := parallelism + 1
 		createdJob.Spec.WorkerGroupSpecs[0].Replicas = &newParallelism
@@ -632,6 +633,69 @@ var _ = ginkgo.Describe("RayCluster Job controller interacting with scheduler",
 		util.ExpectPendingWorkloadsMetric(clusterQueue, 0, 0)
 		util.ExpectReservingActiveWorkloadsMetric(clusterQueue, 1)
 	})
+
+	gomega.Eventually(func() bool {
+		if err := features.SetEnable(features.DynamicallySizedJobs, true); err != nil {
+			return false
+		}
+		return features.Enabled(features.DynamicallySizedJobs)
+	}, util.Timeout, util.Interval).Should(gomega.BeTrue())
+
+	ginkgo.It("Should not suspend job when there's a scale down", func() {
+		ginkgo.By("creating localQueue")
+		localQueue = testing.MakeLocalQueue("local-queue", ns.Name).ClusterQueue(clusterQueue.Name).Obj()
+		gomega.Expect(k8sClient.Create(ctx, localQueue)).Should(gomega.Succeed())
+
+		ginkgo.By("checking a dev job starts")
+		job := testingraycluster.MakeCluster("dev-job", ns.Name).SetReplicaCount(4).Queue(localQueue.Name).
+			RequestHead(corev1.ResourceCPU, "1").
+			RequestWorkerGroup(corev1.ResourceCPU, "1").
+			Obj()
+		gomega.Expect(k8sClient.Create(ctx, job)).Should(gomega.Succeed())
+		createdJob := &rayv1.RayCluster{}
+		gomega.Eventually(func() bool {
+			gomega.Expect(k8sClient.Get(ctx, types.NamespacedName{Name: job.Name, Namespace: job.Namespace}, createdJob)).
+				Should(gomega.Succeed())
+			return *createdJob.Spec.Suspend
+		}, util.Timeout, util.Interval).Should(gomega.BeFalse())
+		gomega.Expect(createdJob.Spec.HeadGroupSpec.Template.Spec.NodeSelector[instanceKey]).Should(gomega.Equal(spotUntaintedFlavor.Name))
+		gomega.Expect(createdJob.Spec.WorkerGroupSpecs[0].Template.Spec.NodeSelector[instanceKey]).Should(gomega.Equal(onDemandFlavor.Name))
+		util.ExpectPendingWorkloadsMetric(clusterQueue, 0, 0)
+		util.ExpectReservingActiveWorkloadsMetric(clusterQueue, 1)
+
+		ginkgo.By("reducing the number of replicas and checking the job is not suspended")
+		replicaCount := ptr.Deref(job.Spec.WorkerGroupSpecs[0].Replicas, 1)
+		newReplicaCount := replicaCount - 2
+		createdJob.Spec.WorkerGroupSpecs[0].Replicas = &newReplicaCount
+		gomega.Expect(k8sClient.Update(ctx, createdJob)).Should(gomega.Succeed())
+		gomega.Eventually(func() bool {
+			gomega.Expect(k8sClient.Get(ctx, types.NamespacedName{Name: job.Name, Namespace: job.Namespace}, createdJob)).
+				Should(gomega.Succeed())
+			return *createdJob.Spec.Suspend
+		}, util.Timeout, util.Interval).Should(gomega.BeFalse())
+
+		ginkgo.By("checking the workload is updated with the new count")
+		createdWorkload := &kueue.Workload{}
+		wlLookupKey := types.NamespacedName{Name: workloadraycluster.GetWorkloadNameForRayCluster(job.Name, job.UID), Namespace: ns.Name}
+
+		gomega.Eventually(func() error {
+			return k8sClient.Get(ctx, wlLookupKey, createdWorkload)
+		}, util.Timeout, util.Interval).Should(gomega.Succeed())
+		gomega.Eventually(func() bool {
+			if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil {
+				return false
+			}
+			return createdWorkload.Spec.PodSets[1].Count == newReplicaCount
+		}, util.Timeout, util.Interval).Should(gomega.BeTrue())
+		gomega.Eventually(func() bool {
+			if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil {
+				return false
+			}
+			return *createdWorkload.Status.Admission.PodSetAssignments[1].Count == newReplicaCount
+		}, util.Timeout, util.Interval).Should(gomega.BeTrue())
+
+	})
+
 })
 
 var _ = ginkgo.Describe("Job controller with preemption enabled", ginkgo.Ordered, ginkgo.ContinueOnFailure, func() {