Skip to content

Commit

Permalink
Migrate controller implementation to kubeflow/common fashion (#1171)
Browse files: browse the repository at this point in the history
Signed-off-by: ChanYiLin <j5111261112@gmail.com>
  • Loading branch information
ChanYiLin committed Aug 18, 2020
1 parent da22601 commit 984adc2
Show file tree
Hide file tree
Showing 557 changed files with 6,037 additions and 64,331 deletions.
17 changes: 9 additions & 8 deletions cmd/tf-operator.v1/app/server.go
Expand Up @@ -20,15 +20,15 @@ import (
"os"
"time"

"github.com/kubeflow/common/pkg/util/signals"
"github.com/kubeflow/tf-operator/cmd/tf-operator.v1/app/options"
v1 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1"
tfjobclientset "github.com/kubeflow/tf-operator/pkg/client/clientset/versioned"
"github.com/kubeflow/tf-operator/pkg/client/clientset/versioned/scheme"
tfjobinformers "github.com/kubeflow/tf-operator/pkg/client/informers/externalversions"
controller "github.com/kubeflow/tf-operator/pkg/controller.v1/tensorflow"
"github.com/kubeflow/tf-operator/pkg/util/signals"
"github.com/kubeflow/tf-operator/pkg/version"
kubebatchclient "github.com/kubernetes-sigs/kube-batch/pkg/client/clientset/versioned"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
log "github.com/sirupsen/logrus"
Expand All @@ -43,6 +43,7 @@ import (
election "k8s.io/client-go/tools/leaderelection"
"k8s.io/client-go/tools/leaderelection/resourcelock"
"k8s.io/client-go/tools/record"
volcanoclient "volcano.sh/volcano/pkg/client/clientset/versioned"
)

const (
Expand Down Expand Up @@ -106,8 +107,7 @@ func Run(opt *options.ServerOption) error {
kcfg.Burst = opt.Burst

// Create clients.
kubeClientSet, leaderElectionClientSet, tfJobClientSet,
kubeBatchClientSet, err := createClientSets(kcfg)
kubeClientSet, leaderElectionClientSet, tfJobClientSet, volcanoClientSet, err := createClientSets(kcfg)
if err != nil {
return err
}
Expand All @@ -122,7 +122,7 @@ func Run(opt *options.ServerOption) error {
unstructuredInformer := controller.NewUnstructuredTFJobInformer(kcfg, opt.Namespace)

// Create tf controller.
tc := controller.NewTFController(unstructuredInformer, kubeClientSet, kubeBatchClientSet, tfJobClientSet, kubeInformerFactory, tfJobInformerFactory, *opt)
tc := controller.NewTFController(unstructuredInformer, kubeClientSet, volcanoClientSet, tfJobClientSet, kubeInformerFactory, tfJobInformerFactory, *opt)

// Start informer goroutines.
go kubeInformerFactory.Start(stopCh)
Expand Down Expand Up @@ -184,7 +184,7 @@ func Run(opt *options.ServerOption) error {
return nil
}

func createClientSets(config *restclientset.Config) (kubeclientset.Interface, kubeclientset.Interface, tfjobclientset.Interface, kubebatchclient.Interface, error) {
func createClientSets(config *restclientset.Config) (kubeclientset.Interface, kubeclientset.Interface, tfjobclientset.Interface, volcanoclient.Interface, error) {

kubeClientSet, err := kubeclientset.NewForConfig(restclientset.AddUserAgent(config, "tf-operator"))
if err != nil {
Expand All @@ -201,11 +201,12 @@ func createClientSets(config *restclientset.Config) (kubeclientset.Interface, ku
return nil, nil, nil, nil, err
}

kubeBatchClientSet, err := kubebatchclient.NewForConfig(restclientset.AddUserAgent(config, "kube-batch"))
volcanoClientSet, err := volcanoclient.NewForConfig(restclientset.AddUserAgent(config, "volcano"))
if err != nil {
return nil, nil, nil, nil, err
}
return kubeClientSet, leaderElectionClientSet, tfJobClientSet, kubeBatchClientSet, nil

return kubeClientSet, leaderElectionClientSet, tfJobClientSet, volcanoClientSet, nil
}

func checkCRDExists(clientset tfjobclientset.Interface, namespace string) bool {
Expand Down
Expand Up @@ -42,7 +42,7 @@ def scale(image, label):
def build_and_compile_cnn_model():
model = models.Sequential()
model.add(
layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
Expand All @@ -63,10 +63,9 @@ def build_and_compile_cnn_model():
def decay(epoch):
if epoch < 3: #pylint: disable=no-else-return
return 1e-3
elif epoch >= 3 and epoch < 7:
if 3 <= epoch < 7:
return 1e-4
else:
return 1e-5
return 1e-5


def main(args):
Expand All @@ -75,7 +74,7 @@ def main(args):
# layers on each device across all workers
# If your GPUs don't support NCCL, replace "communication" with another CollectiveCommunication option (e.g. RING or AUTO).
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
communication=tf.distribute.experimental.CollectiveCommunication.NCCL)
communication=tf.distribute.experimental.CollectiveCommunication.NCCL)

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
Expand All @@ -84,7 +83,7 @@ def main(args):
ds_train = make_datasets_unbatched().batch(BATCH_SIZE).repeat()
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = \
tf.data.experimental.AutoShardPolicy.DATA
tf.data.experimental.AutoShardPolicy.DATA
ds_train = ds_train.with_options(options)
# Model building/compiling need to be within `strategy.scope()`.
multi_worker_model = build_and_compile_cnn_model()
Expand All @@ -105,11 +104,11 @@ def on_epoch_end(self, epoch): #pylint: disable=no-self-use
epoch + 1, multi_worker_model.optimizer.lr.numpy()))

callbacks = [
tf.keras.callbacks.TensorBoard(log_dir='./logs'),
tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
save_weights_only=True),
tf.keras.callbacks.LearningRateScheduler(decay),
PrintLR()
tf.keras.callbacks.TensorBoard(log_dir='./logs'),
tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
save_weights_only=True),
tf.keras.callbacks.LearningRateScheduler(decay),
PrintLR()
]

# Keras' `model.fit()` trains the model with specified number of epochs and
Expand Down
5 changes: 2 additions & 3 deletions go.mod
Expand Up @@ -4,17 +4,15 @@ go 1.14

require (
github.com/go-openapi/spec v0.19.2
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b
github.com/golang/protobuf v1.3.2
github.com/google/go-cmp v0.4.1 // indirect
github.com/grpc-ecosystem/grpc-gateway v1.5.0 // indirect
github.com/kubeflow/common v0.3.1
github.com/kubernetes-sigs/kube-batch v0.0.0-20200414051246-2e934d1c8860
github.com/kubernetes-sigs/kube-batch v0.0.0-20200414051246-2e934d1c8860 // indirect
github.com/onrik/logrus v0.2.2-0.20181225141908-a09d5cdcdc62
github.com/pkg/errors v0.9.1 // indirect
github.com/prometheus/client_golang v1.5.1
github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.4.0
k8s.io/api v0.16.9
k8s.io/apiextensions-apiserver v0.16.9 // indirect
k8s.io/apimachinery v0.16.10-beta.0
Expand All @@ -25,6 +23,7 @@ require (
k8s.io/kube-openapi v0.0.0-20200410163147-594e756bea31
k8s.io/kubernetes v1.16.9
sigs.k8s.io/yaml v1.2.0 // indirect
volcano.sh/volcano v0.4.0
)

replace (
Expand Down
1 change: 1 addition & 0 deletions go.sum
Expand Up @@ -789,4 +789,5 @@ sigs.k8s.io/yaml v1.2.0 h1:kr/MCeFWJWTwyaHoR9c8EjH9OumOmoF9YGiZd7lFm/Q=
sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc=
sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0=
vbom.ml/util v0.0.0-20160121211510-db5cfe13f5cc/go.mod h1:so/NYdZXCz+E3ZpW0uAoCj6uzU2+8OWDFv/HxUSs7kI=
volcano.sh/volcano v0.4.0 h1:B4ot28vzi9bH+hpyv6+qd/EFZFEcE37Lj27/QEI6ly0=
volcano.sh/volcano v0.4.0/go.mod h1:2sNJRhY/oNg0MYdBYORxozuDhvgZxoyeOvKJww/Tl8A=
2 changes: 1 addition & 1 deletion hack/update-codegen.sh
Expand Up @@ -42,5 +42,5 @@ ${GOPATH}/bin/defaulter-gen --input-dirs github.com/kubeflow/tf-operator/pkg/ap
cd - > /dev/null

echo "Generating OpenAPI specification for tensorflow/v1"
${GOPATH}/bin/openapi-gen --input-dirs github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1,k8s.io/api/core/v1,k8s.io/apimachinery/pkg/apis/meta/v1,k8s.io/apimachinery/pkg/api/resource,k8s.io/apimachinery/pkg/runtime,k8s.io/apimachinery/pkg/util/intstr,k8s.io/apimachinery/pkg/version --output-package github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1 --go-header-file hack/boilerplate/boilerplate.go.txt "$@"
${GOPATH}/bin/openapi-gen --report-filename=hack/violation_exception.list --input-dirs github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1,k8s.io/api/core/v1,k8s.io/apimachinery/pkg/apis/meta/v1,k8s.io/apimachinery/pkg/api/resource,k8s.io/apimachinery/pkg/runtime,k8s.io/apimachinery/pkg/util/intstr,k8s.io/apimachinery/pkg/version --output-package github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1 --go-header-file hack/boilerplate/boilerplate.go.txt "$@"
cd - > /dev/null

0 comments on commit 984adc2

Please sign in to comment.