Skip to content

Commit

Permalink
add unit test for xgboostjob controller
Browse files Browse the repository at this point in the history
  • Loading branch information
zw0610 committed Jan 6, 2022
1 parent 487f961 commit 30f053e
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 49 deletions.
81 changes: 81 additions & 0 deletions pkg/common/util/v1/testutil/xgboostjob.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright 2022 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package testutil

import (
commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

xgboostjobv1 "github.com/kubeflow/training-operator/pkg/apis/xgboost/v1"
)

func NewXGBoostJobWithMaster(worker int) *xgboostjobv1.XGBoostJob {
job := NewXGoostJob(worker)
master := int32(1)
masterReplicaSpec := &commonv1.ReplicaSpec{
Replicas: &master,
Template: NewXGBoostReplicaSpecTemplate(),
}
job.Spec.XGBReplicaSpecs[commonv1.ReplicaType(xgboostjobv1.XGBoostReplicaTypeMaster)] = masterReplicaSpec
return job
}

func NewXGoostJob(worker int) *xgboostjobv1.XGBoostJob {

job := &xgboostjobv1.XGBoostJob{
TypeMeta: metav1.TypeMeta{
Kind: xgboostjobv1.Kind,
},
ObjectMeta: metav1.ObjectMeta{
Name: "test-xgboostjob",
Namespace: metav1.NamespaceDefault,
},
Spec: xgboostjobv1.XGBoostJobSpec{
XGBReplicaSpecs: make(map[commonv1.ReplicaType]*commonv1.ReplicaSpec),
},
}

if worker > 0 {
worker := int32(worker)
workerReplicaSpec := &commonv1.ReplicaSpec{
Replicas: &worker,
Template: NewXGBoostReplicaSpecTemplate(),
}
job.Spec.XGBReplicaSpecs[commonv1.ReplicaType(xgboostjobv1.XGBoostReplicaTypeWorker)] = workerReplicaSpec
}

return job
}

func NewXGBoostReplicaSpecTemplate() corev1.PodTemplateSpec {
return corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
corev1.Container{
Name: xgboostjobv1.DefaultContainerName,
Image: "test-image-for-kubeflow-xgboost-operator:latest",
Args: []string{"Fake", "Fake"},
Ports: []corev1.ContainerPort{
corev1.ContainerPort{
Name: xgboostjobv1.DefaultPortName,
ContainerPort: xgboostjobv1.DefaultPort,
},
},
},
},
},
}
}
79 changes: 79 additions & 0 deletions pkg/controller.v1/xgboost/pod_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright 2021 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package xgboost

import (
"testing"

commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"

xgboostv1 "github.com/kubeflow/training-operator/pkg/apis/xgboost/v1"
"github.com/kubeflow/training-operator/pkg/common/util/v1/testutil"
)

func TestClusterSpec(t *testing.T) {
type tc struct {
job *xgboostv1.XGBoostJob
rt commonv1.ReplicaType
index string
expectedClusterSpec map[string]string
}
testCase := []tc{
tc{
job: testutil.NewXGBoostJobWithMaster(0),
rt: xgboostv1.XGBoostReplicaTypeMaster,
index: "0",
expectedClusterSpec: map[string]string{"WORLD_SIZE": "1", "MASTER_PORT": "9999", "RANK": "0", "MASTER_ADDR": "test-xgboostjob-master-0"},
},
tc{
job: testutil.NewXGBoostJobWithMaster(1),
rt: xgboostv1.XGBoostReplicaTypeMaster,
index: "1",
expectedClusterSpec: map[string]string{"WORLD_SIZE": "2", "MASTER_PORT": "9999", "RANK": "1", "MASTER_ADDR": "test-xgboostjob-master-0", "WORKER_PORT": "9999", "WORKER_ADDRS": "test-xgboostjob-worker-0"},
},
tc{
job: testutil.NewXGBoostJobWithMaster(2),
rt: xgboostv1.XGBoostReplicaTypeMaster,
index: "0",
expectedClusterSpec: map[string]string{"WORLD_SIZE": "3", "MASTER_PORT": "9999", "RANK": "0", "MASTER_ADDR": "test-xgboostjob-master-0", "WORKER_PORT": "9999", "WORKER_ADDRS": "test-xgboostjob-worker-0,test-xgboostjob-worker-1"},
},
tc{
job: testutil.NewXGBoostJobWithMaster(2),
rt: xgboostv1.XGBoostReplicaTypeWorker,
index: "0",
expectedClusterSpec: map[string]string{"WORLD_SIZE": "3", "MASTER_PORT": "9999", "RANK": "1", "MASTER_ADDR": "test-xgboostjob-master-0", "WORKER_PORT": "9999", "WORKER_ADDRS": "test-xgboostjob-worker-0,test-xgboostjob-worker-1"},
},
tc{
job: testutil.NewXGBoostJobWithMaster(2),
rt: xgboostv1.XGBoostReplicaTypeWorker,
index: "1",
expectedClusterSpec: map[string]string{"WORLD_SIZE": "3", "MASTER_PORT": "9999", "RANK": "2", "MASTER_ADDR": "test-xgboostjob-master-0", "WORKER_PORT": "9999", "WORKER_ADDRS": "test-xgboostjob-worker-0,test-xgboostjob-worker-1"},
},
}
for _, c := range testCase {
demoTemplateSpec := c.job.Spec.XGBReplicaSpecs[commonv1.ReplicaType(c.rt)].Template
if err := SetPodEnv(c.job, &demoTemplateSpec, string(c.rt), c.index); err != nil {
t.Errorf("Failed to set cluster spec: %v", err)
}
actual := demoTemplateSpec.Spec.Containers[0].Env
for _, env := range actual {
if val, ok := c.expectedClusterSpec[env.Name]; ok {
if val != env.Value {
t.Errorf("For name %s Got %s. Expected %s ", env.Name, env.Value, c.expectedClusterSpec[env.Name])
}
}
}
}
}
52 changes: 48 additions & 4 deletions pkg/controller.v1/xgboost/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
package xgboost

import (
"context"
"fmt"
"path/filepath"
"testing"
"time"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
corev1 "k8s.io/api/core/v1"
"k8s.io/client-go/kubernetes/scheme"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/envtest"
"sigs.k8s.io/controller-runtime/pkg/envtest/printer"
Expand All @@ -34,8 +39,13 @@ import (
// These tests use Ginkgo (BDD-style Go testing framework). Refer to
// http://onsi.github.io/ginkgo/ to learn more about Ginkgo.

var k8sClient client.Client
var testEnv *envtest.Environment
var (
testK8sClient client.Client
testEnv *envtest.Environment
testCtx context.Context
testCancel context.CancelFunc
reconciler *XGBoostJobReconciler
)

func TestAPIs(t *testing.T) {
RegisterFailHandler(Fail)
Expand All @@ -46,8 +56,14 @@ func TestAPIs(t *testing.T) {
}

var _ = BeforeSuite(func() {
const (
timeout = 10 * time.Second
interval = 1000 * time.Millisecond
)
logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true)))

testCtx, testCancel = context.WithCancel(context.TODO())

By("bootstrapping test environment")
testEnv = &envtest.Environment{
CRDDirectoryPaths: []string{filepath.Join("..", "..", "..", "manifests", "base", "crds")},
Expand All @@ -63,14 +79,42 @@ var _ = BeforeSuite(func() {

//+kubebuilder:scaffold:scheme

k8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme})
testK8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme})
Expect(err).NotTo(HaveOccurred())
Expect(k8sClient).NotTo(BeNil())
Expect(testK8sClient).NotTo(BeNil())

mgr, err := ctrl.NewManager(cfg, ctrl.Options{
MetricsBindAddress: "0",
})
Expect(err).NotTo(HaveOccurred())

reconciler = NewReconciler(mgr, false)
Expect(reconciler.SetupWithManager(mgr)).NotTo(HaveOccurred())

go func() {
defer GinkgoRecover()
err = mgr.Start(testCtx)
Expect(err).ToNot(HaveOccurred(), "failed to run manager")
}()

// This step is introduced to make sure cache starts before running any tests
Eventually(func() error {
nsList := &corev1.NamespaceList{}
if err := testK8sClient.List(context.Background(), nsList); err != nil {
return err
} else if len(nsList.Items) < 1 {
return fmt.Errorf("cannot get at lease one namespace, got %d", len(nsList.Items))
}
return nil
}, timeout, interval).Should(BeNil())

}, 60)

var _ = AfterSuite(func() {
By("tearing down the test environment")
testCancel()
// Give 5 seconds to stop all tests
time.Sleep(5 * time.Second)
err := testEnv.Stop()
Expect(err).NotTo(HaveOccurred())
})
45 changes: 0 additions & 45 deletions pkg/version/version.go

This file was deleted.

0 comments on commit 30f053e

Please sign in to comment.