diff --git a/cmd/training-operator.v1/main.go b/cmd/training-operator.v1/main.go index 0d5cd145eb..87b2c00550 100644 --- a/cmd/training-operator.v1/main.go +++ b/cmd/training-operator.v1/main.go @@ -29,6 +29,7 @@ import ( ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/healthz" "sigs.k8s.io/controller-runtime/pkg/log/zap" + "volcano.sh/apis/pkg/apis/scheduling/v1beta1" commonutil "github.com/kubeflow/common/pkg/util" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" @@ -45,6 +46,7 @@ var ( func init() { utilruntime.Must(clientgoscheme.AddToScheme(scheme)) utilruntime.Must(kubeflowv1.AddToScheme(scheme)) + utilruntime.Must(v1beta1.AddToScheme(scheme)) //+kubebuilder:scaffold:scheme } diff --git a/pkg/controller.v1/mpi/mpijob_controller.go b/pkg/controller.v1/mpi/mpijob_controller.go index 8463543e4a..d36d5fa609 100644 --- a/pkg/controller.v1/mpi/mpijob_controller.go +++ b/pkg/controller.v1/mpi/mpijob_controller.go @@ -53,6 +53,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/source" + "volcano.sh/apis/pkg/apis/scheduling/v1beta1" volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" @@ -242,6 +243,19 @@ func (jc *MPIJobReconciler) SetupWithManager(mgr ctrl.Manager) error { return err } + // skip watching podgroup if PodGroup is not installed + _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: v1beta1.SchemeGroupVersion.Group, Kind: "PodGroup"}, + v1beta1.SchemeGroupVersion.Version) + if err == nil { + // inject watching for job related PodGroup + if err = c.Watch(&source.Kind{Type: &v1beta1.PodGroup{}}, &handler.EnqueueRequestForOwner{ + IsController: true, + OwnerType: &kubeflowv1.MPIJob{}, + }, predicates); err != nil { + return err + } + } + return nil } diff --git a/pkg/controller.v1/mpi/suite_test.go b/pkg/controller.v1/mpi/suite_test.go index d814c32692..e821208644 100644 --- a/pkg/controller.v1/mpi/suite_test.go +++ b/pkg/controller.v1/mpi/suite_test.go @@ -25,11 +25,13 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "k8s.io/client-go/kubernetes/scheme" + ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/envtest" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" + v1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" //+kubebuilder:scaffold:imports ) @@ -65,6 +67,8 @@ var _ = BeforeSuite(func() { Expect(err).NotTo(HaveOccurred()) Expect(cfg).NotTo(BeNil()) + err = v1beta1.AddToScheme(scheme.Scheme) + Expect(err).NotTo(HaveOccurred()) err = kubeflowv1.AddToScheme(scheme.Scheme) Expect(err).NotTo(HaveOccurred()) diff --git a/pkg/controller.v1/mxnet/mxjob_controller.go b/pkg/controller.v1/mxnet/mxjob_controller.go index 615f878c18..c0d977c488 100644 --- a/pkg/controller.v1/mxnet/mxjob_controller.go +++ b/pkg/controller.v1/mxnet/mxjob_controller.go @@ -47,6 +47,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/source" + "volcano.sh/apis/pkg/apis/scheduling/v1beta1" volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" @@ -216,6 +217,23 @@ func (r *MXJobReconciler) SetupWithManager(mgr ctrl.Manager) error { return err } + // skip watching podgroup if podgroup is not installed + _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: v1beta1.SchemeGroupVersion.Group, Kind: "PodGroup"}, + v1beta1.SchemeGroupVersion.Version) + if err == nil { + // inject watching for job related podgroup + if err = c.Watch(&source.Kind{Type: &v1beta1.PodGroup{}}, &handler.EnqueueRequestForOwner{ + IsController: true, + OwnerType: &kubeflowv1.MXJob{}, + }, predicate.Funcs{ + CreateFunc: util.OnDependentCreateFunc(r.Expectations), + UpdateFunc: util.OnDependentUpdateFunc(&r.JobController), + DeleteFunc: util.OnDependentDeleteFunc(r.Expectations), + }); err != nil { + return err + } + } + return nil } diff --git a/pkg/controller.v1/mxnet/suite_test.go b/pkg/controller.v1/mxnet/suite_test.go index 1e15cb2911..e78479fe7d 100644 --- a/pkg/controller.v1/mxnet/suite_test.go +++ b/pkg/controller.v1/mxnet/suite_test.go @@ -27,6 +27,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/envtest" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" + v1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" //+kubebuilder:scaffold:imports ) @@ -55,6 +56,8 @@ var _ = BeforeSuite(func() { Expect(err).NotTo(HaveOccurred()) Expect(cfg).NotTo(BeNil()) + err = v1beta1.AddToScheme(scheme.Scheme) + Expect(err).NotTo(HaveOccurred()) err = kubeflowv1.AddToScheme(scheme.Scheme) Expect(err).NotTo(HaveOccurred()) diff --git a/pkg/controller.v1/pytorch/pytorchjob_controller.go b/pkg/controller.v1/pytorch/pytorchjob_controller.go index 717f6c259d..18aa303a1a 100644 --- a/pkg/controller.v1/pytorch/pytorchjob_controller.go +++ b/pkg/controller.v1/pytorch/pytorchjob_controller.go @@ -47,6 +47,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/source" + "volcano.sh/apis/pkg/apis/scheduling/v1beta1" volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" @@ -214,6 +215,22 @@ func (r *PyTorchJobReconciler) SetupWithManager(mgr ctrl.Manager) error { }); err != nil { return err } + // skip watching podgroup if podgroup is not installed + _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: v1beta1.SchemeGroupVersion.Group, Kind: "PodGroup"}, + v1beta1.SchemeGroupVersion.Version) + if err == nil { + // inject watching for job related podgroup + if err = c.Watch(&source.Kind{Type: &v1beta1.PodGroup{}}, &handler.EnqueueRequestForOwner{ + IsController: true, + OwnerType: &kubeflowv1.PyTorchJob{}, + }, predicate.Funcs{ + CreateFunc: util.OnDependentCreateFunc(r.Expectations), + UpdateFunc: util.OnDependentUpdateFunc(&r.JobController), + DeleteFunc: util.OnDependentDeleteFunc(r.Expectations), + }); err != nil { + return err + } + } return nil } diff --git a/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go b/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go index a0e2fb0497..ea7816966c 100644 --- a/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go +++ b/pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go @@ -31,6 +31,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/envtest" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" + v1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" //+kubebuilder:scaffold:imports ) @@ -65,6 +66,8 @@ var _ = BeforeSuite(func() { Expect(err).NotTo(HaveOccurred()) Expect(cfg).NotTo(BeNil()) + err = v1beta1.AddToScheme(scheme.Scheme) + Expect(err).NotTo(HaveOccurred()) err = kubeflowv1.AddToScheme(scheme.Scheme) Expect(err).NotTo(HaveOccurred()) diff --git a/pkg/controller.v1/tensorflow/suite_test.go b/pkg/controller.v1/tensorflow/suite_test.go index 160279e60a..667e31c9e0 100644 --- a/pkg/controller.v1/tensorflow/suite_test.go +++ b/pkg/controller.v1/tensorflow/suite_test.go @@ -30,6 +30,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/envtest" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" + v1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" //+kubebuilder:scaffold:imports @@ -71,6 +72,8 @@ var _ = BeforeSuite(func() { Expect(err).NotTo(HaveOccurred()) Expect(cfg).NotTo(BeNil()) + err = v1beta1.AddToScheme(scheme.Scheme) + Expect(err).NotTo(HaveOccurred()) err = kubeflowv1.AddToScheme(scheme.Scheme) Expect(err).NotTo(HaveOccurred()) diff --git a/pkg/controller.v1/tensorflow/tfjob_controller.go b/pkg/controller.v1/tensorflow/tfjob_controller.go index 56ca4fe561..5f29bf5141 100644 --- a/pkg/controller.v1/tensorflow/tfjob_controller.go +++ b/pkg/controller.v1/tensorflow/tfjob_controller.go @@ -49,6 +49,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/source" + "volcano.sh/apis/pkg/apis/scheduling/v1beta1" volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" @@ -235,6 +236,22 @@ func (r *TFJobReconciler) SetupWithManager(mgr ctrl.Manager) error { }); err != nil { return err } + // skip watching podgroup if podgroup is not installed + _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: v1beta1.SchemeGroupVersion.Group, Kind: "PodGroup"}, + v1beta1.SchemeGroupVersion.Version) + if err == nil { + // inject watching for job related podgroup + if err = c.Watch(&source.Kind{Type: &v1beta1.PodGroup{}}, &handler.EnqueueRequestForOwner{ + IsController: true, + OwnerType: &kubeflowv1.TFJob{}, + }, predicate.Funcs{ + CreateFunc: util.OnDependentCreateFunc(r.Expectations), + UpdateFunc: util.OnDependentUpdateFunc(&r.JobController), + DeleteFunc: util.OnDependentDeleteFunc(r.Expectations), + }); err != nil { + return err + } + } return nil } diff --git a/pkg/controller.v1/xgboost/suite_test.go b/pkg/controller.v1/xgboost/suite_test.go index f45bfc97d8..3b0d789c43 100644 --- a/pkg/controller.v1/xgboost/suite_test.go +++ b/pkg/controller.v1/xgboost/suite_test.go @@ -25,6 +25,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/envtest" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" + v1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" //+kubebuilder:scaffold:imports @@ -55,6 +56,8 @@ var _ = BeforeSuite(func() { Expect(err).NotTo(HaveOccurred()) Expect(cfg).NotTo(BeNil()) + err = v1beta1.AddToScheme(scheme.Scheme) + Expect(err).NotTo(HaveOccurred()) err = kubeflowv1.AddToScheme(scheme.Scheme) Expect(err).NotTo(HaveOccurred()) diff --git a/pkg/controller.v1/xgboost/xgboostjob_controller.go b/pkg/controller.v1/xgboost/xgboostjob_controller.go index 2d6da31143..2b9295ec54 100644 --- a/pkg/controller.v1/xgboost/xgboostjob_controller.go +++ b/pkg/controller.v1/xgboost/xgboostjob_controller.go @@ -49,6 +49,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/reconcile" "sigs.k8s.io/controller-runtime/pkg/source" + "volcano.sh/apis/pkg/apis/scheduling/v1beta1" volcanoclient "volcano.sh/apis/pkg/client/clientset/versioned" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" @@ -221,6 +222,22 @@ func (r *XGBoostJobReconciler) SetupWithManager(mgr ctrl.Manager) error { }); err != nil { return err } + // skip watching podgroup if podgroup is not installed + _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: v1beta1.SchemeGroupVersion.Group, Kind: "PodGroup"}, + v1beta1.SchemeGroupVersion.Version) + if err == nil { + // inject watching for job related podgroup + if err = c.Watch(&source.Kind{Type: &v1beta1.PodGroup{}}, &handler.EnqueueRequestForOwner{ + IsController: true, + OwnerType: &kubeflowv1.XGBoostJob{}, + }, predicate.Funcs{ + CreateFunc: util.OnDependentCreateFunc(r.Expectations), + UpdateFunc: util.OnDependentUpdateFunc(&r.JobController), + DeleteFunc: util.OnDependentDeleteFunc(r.Expectations), + }); err != nil { + return err + } + } return nil }