Skip to content

Commit

Permalink
Improve pod support
Browse files Browse the repository at this point in the history
* Add UnretryableError, an error type that doesn't require a
  reconcile retry.

* Add integration/unit tests.

* Address PR comments.

* Add ValidateLabelAsCRDName call for the pod-group,
  make pod-group label immutable.
  • Loading branch information
achernevskii committed Nov 17, 2023
1 parent bf0b607 commit 5695e32
Show file tree
Hide file tree
Showing 9 changed files with 382 additions and 89 deletions.
22 changes: 22 additions & 0 deletions pkg/controller/jobframework/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package jobframework

import "errors"

// UnretryableError is an error that doesn't require reconcile retry
// and will not be returned by the JobReconciler.
func UnretryableError(msg string) error {
return &unretryableError{msg: msg}
}

type unretryableError struct {
msg string
}

func (e *unretryableError) Error() string {
return e.msg
}

func IsUnretryableError(e error) bool {
var unretryableError *unretryableError
return errors.As(e, &unretryableError)
}
6 changes: 3 additions & 3 deletions pkg/controller/jobframework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ type JobWithPriorityClass interface {
}

// ComposableJob interface should be implemented by generic jobs that
// has to be composed out of different API objects.
// are composed out of multiple API objects.
type ComposableJob interface {
// ConstructComposableWorkload returns a new Workload that's assembled out of all parts of a ComposableJob.
// ConstructComposableWorkload returns a new Workload that's assembled out of all members of the ComposableJob.
ConstructComposableWorkload(ctx context.Context, c client.Client, r record.EventRecorder) (*kueue.Workload, error)
// FindMatchingWorkloads returns the workload that matches the ComposableJob and any duplicates that have to be deleted.
FindMatchingWorkloads(ctx context.Context, c client.Client) (match *kueue.Workload, toDelete []*kueue.Workload, workloads *kueue.WorkloadList, err error)
FindMatchingWorkloads(ctx context.Context, c client.Client) (match *kueue.Workload, toDelete []*kueue.Workload, err error)
// IsComposableJobFinished fetches all the parts of ComposableJob and checks if the whole job has finished.
IsComposableJobFinished(ctx context.Context, c client.Client) (condition metav1.Condition, finished bool)
}
Expand Down
48 changes: 33 additions & 15 deletions pkg/controller/jobframework/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,19 @@ func NewReconciler(
}
}

func (r *JobReconciler) ReconcileGenericJob(ctx context.Context, req ctrl.Request, job GenericJob) (ctrl.Result, error) {
object := job.Object()
func (r *JobReconciler) ReconcileGenericJobWrapper(ctx context.Context, req ctrl.Request, job GenericJob) (ctrl.Result, error) {
log := ctrl.LoggerFrom(ctx).WithValues("job", req.String(), "gvk", job.GVK())
ctx = ctrl.LoggerInto(ctx, log)

result, err := r.ReconcileGenericJob(ctx, req, job)

return result, r.ignoreUnretryableError(ctx, err)
}

func (r *JobReconciler) ReconcileGenericJob(ctx context.Context, req ctrl.Request, job GenericJob) (ctrl.Result, error) {
object := job.Object()
log := ctrl.LoggerFrom(ctx)

err := r.client.Get(ctx, req.NamespacedName, object)

if jws, implements := job.(JobWithSkip); implements {
Expand Down Expand Up @@ -398,17 +406,16 @@ func (r *JobReconciler) ensureOneWorkload(ctx context.Context, job GenericJob, o
// Find a matching workload first if there is one.
var toDelete []*kueue.Workload
var match *kueue.Workload
var workloads *kueue.WorkloadList
if cj, implements := job.(ComposableJob); implements {
var err error
match, toDelete, workloads, err = cj.FindMatchingWorkloads(ctx, r.client)
match, toDelete, err = cj.FindMatchingWorkloads(ctx, r.client)
if err != nil {
log.Error(err, "Composable job is unable to find matching workloads")
return nil, err
}
} else {
var err error
match, toDelete, workloads, err = FindMatchingWorkloads(ctx, r.client, job, object)
match, toDelete, err = FindMatchingWorkloads(ctx, r.client, job)
if err != nil {
log.Error(err, "Unable to list child workloads")
return nil, err
Expand All @@ -425,11 +432,11 @@ func (r *JobReconciler) ensureOneWorkload(ctx context.Context, job GenericJob, o
if match == nil && !job.IsSuspended() {
log.V(2).Info("job with no matching workload, suspending")
var w *kueue.Workload
if len(workloads.Items) == 1 {
if len(toDelete) == 1 {
// The job may have been modified and hence the existing workload
// doesn't match the job anymore. All bets are off if there are more
// than one workload...
w = &workloads.Items[0]
w = toDelete[0]
}

if _, finished := job.Finished(); finished {
Expand Down Expand Up @@ -477,31 +484,33 @@ func (r *JobReconciler) ensureOneWorkload(ctx context.Context, job GenericJob, o
return match, nil
}

func FindMatchingWorkloads(ctx context.Context, c client.Client, job GenericJob, object client.Object) (match *kueue.Workload, toDelete []*kueue.Workload, workloads *kueue.WorkloadList, err error) {
workloads = &kueue.WorkloadList{}
func FindMatchingWorkloads(ctx context.Context, c client.Client, job GenericJob) (match *kueue.Workload, toDelete []*kueue.Workload, err error) {
object := job.Object()

workloads := &kueue.WorkloadList{}
if err := c.List(ctx, workloads, client.InNamespace(object.GetNamespace()),
client.MatchingFields{getOwnerKey(job.GVK()): object.GetName()}); err != nil {
return nil, nil, nil, err
return nil, nil, err
}

for i := range workloads.Items {
w := &workloads.Items[i]
if match == nil && equivalentToWorkload(job, object, w) {
if match == nil && equivalentToWorkload(job, w) {
match = w
} else {
toDelete = append(toDelete, w)
}
}

return match, toDelete, workloads, nil
return match, toDelete, nil
}

// equivalentToWorkload checks if the job corresponds to the workload
func equivalentToWorkload(job GenericJob, object client.Object, wl *kueue.Workload) bool {
func equivalentToWorkload(job GenericJob, wl *kueue.Workload) bool {
owner := metav1.GetControllerOf(wl)
// Indexes don't work in unit tests, so we explicitly check for the
// owner here.
if owner.Name != object.GetName() {
if owner.Name != job.Object().GetName() {
return false
}

Expand Down Expand Up @@ -763,6 +772,15 @@ func (r *JobReconciler) handleJobWithNoWorkload(ctx context.Context, job Generic
return nil
}

// ignoreUnretryableError logs and swallows errors created with
// UnretryableError so they are not surfaced to controller-runtime as
// reconcile failures; any other error (including nil) passes through
// unchanged.
func (r *JobReconciler) ignoreUnretryableError(ctx context.Context, err error) error {
	if !IsUnretryableError(err) {
		return err
	}
	ctrl.LoggerFrom(ctx).V(2).Info("Received an unretryable error", "error", err)
	return nil
}

func generatePodsReadyCondition(job GenericJob, wl *kueue.Workload) metav1.Condition {
conditionStatus := metav1.ConditionFalse
message := "Not all pods are ready or succeeded"
Expand Down Expand Up @@ -814,7 +832,7 @@ type genericReconciler struct {
}

func (r *genericReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
return r.jr.ReconcileGenericJob(ctx, req, r.newJob())
return r.jr.ReconcileGenericJobWrapper(ctx, req, r.newJob())
}

func (r *genericReconciler) SetupWithManager(mgr ctrl.Manager) error {
Expand Down
4 changes: 2 additions & 2 deletions pkg/controller/jobframework/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ var (

func ValidateCreateForQueueName(job GenericJob) field.ErrorList {
var allErrs field.ErrorList
allErrs = append(allErrs, validateLabelAsCRDName(job, constants.QueueLabel)...)
allErrs = append(allErrs, ValidateLabelAsCRDName(job, constants.QueueLabel)...)
allErrs = append(allErrs, ValidateAnnotationAsCRDName(job, constants.QueueAnnotation)...)
return allErrs
}
Expand All @@ -48,7 +48,7 @@ func ValidateAnnotationAsCRDName(job GenericJob, crdNameAnnotation string) field
return allErrs
}

func validateLabelAsCRDName(job GenericJob, crdNameLabel string) field.ErrorList {
func ValidateLabelAsCRDName(job GenericJob, crdNameLabel string) field.ErrorList {
var allErrs field.ErrorList
if value, exists := job.Object().GetLabels()[crdNameLabel]; exists {
if errs := validation.IsDNS1123Subdomain(value); len(errs) > 0 {
Expand Down
68 changes: 40 additions & 28 deletions pkg/controller/jobs/pod/pod_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@ limitations under the License.
package pod

import (
"cmp"
"context"
"fmt"
"slices"
"strconv"
"strings"
"time"

corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
apimeta "k8s.io/apimachinery/pkg/api/meta"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
Expand All @@ -48,8 +51,8 @@ const (
FrameworkName = "pod"
gateNotFound = -1
ConditionTypeTerminationTarget = "TerminationTarget"
WorkloadNameKey = "metadata.name"
errMsgIncorrectTotalGroupCount = "group total count is different from the actual number of pods in the cluster"
errMsgIncorrectGroupRoleCount = "pod group can't include more than 8 roles"
)

var (
Expand Down Expand Up @@ -118,8 +121,8 @@ func (p *Pod) Suspend() {

// RunWithPodSetsInfo will inject the node affinity and podSet counts extracting from workload to job and unsuspend it.
func (p *Pod) RunWithPodSetsInfo(podSetsInfo []podset.PodSetInfo) error {
if len(podSetsInfo) != 1 {
return fmt.Errorf("%w: expecting 1 got %d", podset.ErrInvalidPodsetInfo, len(podSetsInfo))
if p.groupName() == "" && len(podSetsInfo) != 1 {
return fmt.Errorf("%w: expecting 1 pod set got %d", podset.ErrInvalidPodsetInfo, len(podSetsInfo))
}
idx := p.gateIndex()
if idx != gateNotFound {
Expand Down Expand Up @@ -290,9 +293,7 @@ func (p *Pod) constructGroupPodSets(podsInGroup corev1.PodList) ([]kueue.PodSet,
return nil, err
}

var (
resultPodSets []kueue.PodSet
)
var resultPodSets []kueue.PodSet

for _, podInGroup := range podsInGroup.Items {
tc, err := strconv.Atoi(podInGroup.GetAnnotations()[GroupTotalCountAnnotation])
Expand All @@ -304,19 +305,17 @@ func (p *Pod) constructGroupPodSets(podsInGroup corev1.PodList) ([]kueue.PodSet,
}
if tc != groupTotalCount {
// ToDo: Add a warning event
return nil, fmt.Errorf("pods '%s' and '%s' has different '%s' values: %d!=%d",
return nil, jobframework.UnretryableError(fmt.Sprintf("pods '%s' and '%s' has different '%s' values: %d!=%d",
p.GetName(), podInGroup.GetName(),
GroupTotalCountAnnotation,
groupTotalCount, tc)
groupTotalCount, tc))
}

roleHash, ok := podInGroup.Annotations[RoleHashAnnotation]
if !ok {
roleHash, err = getRoleHash(fromObject(&podInGroup))
}

roleHash = roleHash[:8]

podRoleFound := false
for psi := range resultPodSets {
if resultPodSets[psi].Name == roleHash {
Expand All @@ -333,6 +332,10 @@ func (p *Pod) constructGroupPodSets(podsInGroup corev1.PodList) ([]kueue.PodSet,
}
}

slices.SortFunc(resultPodSets, func(a, b kueue.PodSet) int {
return cmp.Compare(a.Name, b.Name)
})

return resultPodSets, nil
}

Expand Down Expand Up @@ -361,10 +364,14 @@ func (p *Pod) ConstructComposableWorkload(ctx context.Context, c client.Client,

if len(podsInGroup.Items) != groupTotalCount {
r.Eventf(object, corev1.EventTypeWarning, "ErrWorkloadCompose", errMsgIncorrectTotalGroupCount)
return nil, fmt.Errorf(errMsgIncorrectTotalGroupCount)
return nil, jobframework.UnretryableError(errMsgIncorrectTotalGroupCount)
}

podSets, err = p.constructGroupPodSets(podsInGroup)
if len(podSets) > 8 {
return nil, jobframework.UnretryableError(errMsgIncorrectGroupRoleCount)
}

if err != nil {
return nil, err
}
Expand Down Expand Up @@ -411,58 +418,59 @@ func (p *Pod) ConstructComposableWorkload(ctx context.Context, c client.Client,
return wl, nil
}

func (p *Pod) FindMatchingWorkloads(ctx context.Context, c client.Client) (*kueue.Workload, []*kueue.Workload, *kueue.WorkloadList, error) {
func (p *Pod) FindMatchingWorkloads(ctx context.Context, c client.Client) (*kueue.Workload, []*kueue.Workload, error) {
log := ctrl.LoggerFrom(ctx)
object := p.Object()
groupName := p.groupName()

if groupName == "" {
return jobframework.FindMatchingWorkloads(ctx, c, p, object)
return jobframework.FindMatchingWorkloads(ctx, c, p)
}

// Find a matching workload first if there is one.
workload := &kueue.Workload{}
if err := c.Get(ctx, types.NamespacedName{Name: p.groupName(), Namespace: p.GetNamespace()}, workload); err != nil {
if apierrors.IsNotFound(err) {
return nil, nil, nil, nil
return nil, nil, nil
}
log.Error(err, "Unable to list related workloads")
return nil, nil, nil, err
log.Error(err, "Unable to get related workload")
return nil, nil, err
}

var podsInGroup corev1.PodList
if err := c.List(ctx, &podsInGroup, client.MatchingLabels{
GroupNameLabel: p.groupName(),
}); err != nil {
return nil, nil, nil, err
return nil, nil, err
}

jobPodSets, err := p.constructGroupPodSets(podsInGroup)
if err != nil {
return nil, nil, nil, err
return nil, nil, err
}

if p.equivalentToWorkload(workload, jobPodSets) {
return workload, []*kueue.Workload{}, &kueue.WorkloadList{Items: []kueue.Workload{*workload}}, nil
return workload, []*kueue.Workload{}, nil
} else {
return nil, []*kueue.Workload{workload}, &kueue.WorkloadList{Items: []kueue.Workload{*workload}}, nil
return nil, []*kueue.Workload{workload}, nil
}
}

func (p *Pod) equivalentToWorkload(wl *kueue.Workload, jobPodSets []kueue.PodSet) bool {
workloadFinished := apimeta.FindStatusCondition(wl.Status.Conditions, kueue.WorkloadFinished) != nil

if wl.GetName() != p.groupName() {
return false
}

if len(jobPodSets) != len(wl.Spec.PodSets) {
if !workloadFinished && len(jobPodSets) != len(wl.Spec.PodSets) {
return false
}

for i := range wl.Spec.PodSets {
if wl.Spec.PodSets[i].Count != jobPodSets[i].Count {
if !workloadFinished && wl.Spec.PodSets[i].Count != jobPodSets[i].Count {
return false
}
if wl.Spec.PodSets[i].Name != jobPodSets[i].Name {
if i < len(jobPodSets) && wl.Spec.PodSets[i].Name != jobPodSets[i].Name {
return false
}
}
Expand All @@ -484,11 +492,11 @@ func (p *Pod) IsComposableJobFinished(ctx context.Context, c client.Client) (met
return metav1.Condition{}, false
}

failedPodCount := 0
succeededPodCount := 0

for _, pod := range podsInGroup.Items {
if pod.Status.Phase == corev1.PodFailed {
failedPodCount++
if pod.Status.Phase == corev1.PodSucceeded {
succeededPodCount++
}

podJob := fromObject(&pod)
Expand All @@ -509,10 +517,14 @@ func (p *Pod) IsComposableJobFinished(ctx context.Context, c client.Client) (met
Reason: "JobFinished",
Message: fmt.Sprintf(
"Pod group has finished. Pods succeeded: %d/%d. Pods failed: %d/%d",
groupTotalCount-failedPodCount, groupTotalCount, failedPodCount, groupTotalCount,
succeededPodCount, groupTotalCount, groupTotalCount-succeededPodCount, groupTotalCount,
),
}

if succeededPodCount < groupTotalCount {
condition.Status = metav1.ConditionFalse
}

return condition, true
}

Expand Down
Loading

0 comments on commit 5695e32

Please sign in to comment.