diff --git a/go.mod b/go.mod index 75e519e2e6..e90df6e37d 100644 --- a/go.mod +++ b/go.mod @@ -74,12 +74,15 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/procfs v0.11.1 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/cobra v1.7.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stoewer/go-strcase v1.3.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect + github.com/stretchr/testify v1.8.4 // indirect go.etcd.io/etcd/api/v3 v3.5.9 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.9 // indirect go.etcd.io/etcd/client/v3 v3.5.9 // indirect diff --git a/go.sum b/go.sum index 1853395222..a2b37d63f3 100644 --- a/go.sum +++ b/go.sum @@ -301,6 +301,7 @@ github.com/stoewer/go-strcase v1.3.0 h1:g0eASXYtp+yvN9fK8sH94oCIk0fau9uV1/ZdJ0AV github.com/stoewer/go-strcase v1.3.0/go.mod h1:fAH5hQ5pehh+j3nZfvwdk2RgEgQjAoM8wodgtPmh1xo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= diff --git a/pkg/controller/core/workload_controller.go b/pkg/controller/core/workload_controller.go index b6df543523..c145e70d6d 100644 --- a/pkg/controller/core/workload_controller.go +++ b/pkg/controller/core/workload_controller.go @@ -34,6 +34,7 @@ import ( "k8s.io/utils/ptr" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/event" "sigs.k8s.io/controller-runtime/pkg/handler" "sigs.k8s.io/controller-runtime/pkg/reconcile" @@ -131,7 +132,7 @@ func (r *WorkloadReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c log.V(2).Info("Reconciling Workload") if apimeta.IsStatusConditionTrue(wl.Status.Conditions, kueue.WorkloadFinished) { - return ctrl.Result{}, nil + return ctrl.Result{}, r.removeFinalizer(ctx, &wl) } if rejectedChecks := workload.GetRejectedChecks(&wl); len(rejectedChecks) > 0 { @@ -379,6 +380,13 @@ func (r *WorkloadReconciler) Delete(e event.DeleteEvent) bool { return true } +func (r *WorkloadReconciler) removeFinalizer(ctx context.Context, wl *kueue.Workload) error { + if controllerutil.RemoveFinalizer(wl, kueue.ResourceInUseFinalizerName) { + return r.client.Update(ctx, wl) + } + return nil +} + func (r *WorkloadReconciler) Update(e event.UpdateEvent) bool { oldWl, isWorkload := e.ObjectOld.(*kueue.Workload) if !isWorkload { diff --git a/pkg/controller/core/workload_controller_test.go b/pkg/controller/core/workload_controller_test.go index de8e18b70c..91d03912c2 100644 --- a/pkg/controller/core/workload_controller_test.go +++ b/pkg/controller/core/workload_controller_test.go @@ -17,16 +17,22 @@ limitations under the License. package core import ( + "context" + "errors" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/mock" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" testingclock "k8s.io/utils/clock/testing" "k8s.io/utils/ptr" + ctrl "sigs.k8s.io/controller-runtime" kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" + "sigs.k8s.io/kueue/test/util" ) func TestAdmittedNotReadyWorkload(t *testing.T) { @@ -290,3 +296,106 @@ func TestSyncCheckStates(t *testing.T) { }) } } + +func TestFinalizer(t *testing.T) { + now := time.Now() + + testCases := map[string]struct { + workload kueue.Workload + expectRequeue bool + expectError bool + setupMock func(mockClient *util.MockClient) + assert func(t *testing.T, mockClient *util.MockClient) + }{ + "Workload with WorkloadFinished=True and existing Finalizer; remove Finalizer": { + workload: kueue.Workload{ + ObjectMeta: metav1.ObjectMeta{ + Finalizers: []string{kueue.ResourceInUseFinalizerName}, + }, + Status: kueue.WorkloadStatus{ + Conditions: []metav1.Condition{ + { + Type: kueue.WorkloadFinished, + Status: metav1.ConditionTrue, + LastTransitionTime: metav1.NewTime(now), + }, + }, + }, + }, + expectRequeue: false, + expectError: false, + setupMock: func(mockClient *util.MockClient) { + mockClient.On("Update", mock.Anything, mock.Anything, mock.Anything).Return(nil) + }, + assert: func(t *testing.T, mockClient *util.MockClient) { + mockClient.AssertCalled(t, "Update", mock.Anything, mock.MatchedBy(func(workload *kueue.Workload) bool { + return len(workload.ObjectMeta.Finalizers) == 0 + }), mock.Anything) + }, + }, + "Workload with WorkloadFinished=True and existing Finalizer, removing fails, retry; remove Finalizer": { + workload: kueue.Workload{ + ObjectMeta: metav1.ObjectMeta{ + Finalizers: []string{kueue.ResourceInUseFinalizerName}, + }, + Status: kueue.WorkloadStatus{ + Conditions: []metav1.Condition{ + { + Type: kueue.WorkloadFinished, + Status: metav1.ConditionTrue, + LastTransitionTime: metav1.NewTime(now), + }, + }, + }, + }, + expectRequeue: false, + expectError: true, + setupMock: func(mockClient *util.MockClient) { + mockClient.On("Update", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("test")) + }, + assert: func(t *testing.T, mockClient *util.MockClient) { + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + mockClient := new(util.MockClient) + + mockClient.On("Get", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + workload := args.Get(2).(*kueue.Workload) + *workload = tc.workload + }).Return(nil) + + tc.setupMock(mockClient) + + wRec := WorkloadReconciler{client: mockClient} + + req := ctrl.Request{ + NamespacedName: types.NamespacedName{ + Name: "test", + Namespace: "test", + }, + } + + res, err := wRec.Reconcile(context.TODO(), req) + if tc.expectError && err == nil { + t.Errorf("Expected error, but got none") + } + + if !tc.expectError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if tc.expectRequeue && !res.Requeue { + t.Errorf("Expected requeue, but got none") + } + + if !tc.expectRequeue && res.Requeue { + t.Errorf("Expected no requeue, but got one") + } + + tc.assert(t, mockClient) + }) + } +} diff --git a/pkg/controller/jobframework/reconciler.go b/pkg/controller/jobframework/reconciler.go index 81ddba14fd..3d64a2f071 100644 --- a/pkg/controller/jobframework/reconciler.go +++ b/pkg/controller/jobframework/reconciler.go @@ -255,16 +255,10 @@ func (r *JobReconciler) ReconcileGenericJob(ctx context.Context, req ctrl.Reques log.V(2).Info("The workload is marked for deletion") err := r.stopJob(ctx, job, wl, StopReasonWorkloadDeleted, "Workload is deleted") if err != nil { - if !apierrors.IsNotFound(err) { - return ctrl.Result{}, err - } - log.Error(err, "Suspending job with deleted workload") - } - if wl != nil { - err := r.removeFinalizer(ctx, wl) - if err != nil { - return ctrl.Result{}, err + if apierrors.IsNotFound(err) { + log.Error(err, "Suspending job with deleted workload") } + return ctrl.Result{}, err } return ctrl.Result{}, nil } diff --git a/test/util/mock.go b/test/util/mock.go new file mode 100644 index 0000000000..c781926105 --- /dev/null +++ b/test/util/mock.go @@ -0,0 +1,80 @@ +package util + +import ( + "context" + + "github.com/stretchr/testify/mock" + "k8s.io/apimachinery/pkg/api/meta" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +type MockClient struct { + mock.Mock +} + +func (m *MockClient) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + args := m.Called(ctx, key, obj, opts) + return args.Error(0) +} + +func (m *MockClient) Update(ctx context.Context, obj client.Object, opts ...client.UpdateOption) error { + args := m.Called(ctx, obj, opts) + return args.Error(0) +} + +func (m *MockClient) Delete(ctx context.Context, obj client.Object, opts ...client.DeleteOption) error { + args := m.Called(ctx, obj, opts) + return args.Error(0) +} + +func (m *MockClient) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error { + args := m.Called(ctx, obj, opts) + return args.Error(0) +} + +func (m *MockClient) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error { + args := m.Called(ctx, list, opts) + return args.Error(0) +} + +func (m *MockClient) Patch(ctx context.Context, obj client.Object, patch client.Patch, opts ...client.PatchOption) error { + args := m.Called(ctx, obj, patch, opts) + return args.Error(0) +} + +func (m *MockClient) DeleteAllOf(ctx context.Context, obj client.Object, opts ...client.DeleteAllOfOption) error { + args := m.Called(ctx, obj, opts) + return args.Error(0) +} + +func (m *MockClient) Status() client.StatusWriter { + args := m.Called() + return args.Get(0).(client.StatusWriter) +} + +func (m *MockClient) Scheme() *runtime.Scheme { + args := m.Called() + return args.Get(0).(*runtime.Scheme) +} + +func (m *MockClient) RESTMapper() meta.RESTMapper { + args := m.Called() + return args.Get(0).(meta.RESTMapper) +} + +func (m *MockClient) GroupVersionKindFor(obj runtime.Object) (gvk schema.GroupVersionKind, err error) { + args := m.Called(obj) + return args.Get(0).(schema.GroupVersionKind), args.Error(1) +} + +func (m *MockClient) IsObjectNamespaced(obj runtime.Object) (isNamespaced bool, err error) { + args := m.Called(obj) + return args.Get(0).(bool), args.Error(1) +} + +func (m *MockClient) SubResource(subResource string) client.SubResourceClient { + args := m.Called(subResource) + return args.Get(0).(client.SubResourceClient) +}