Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate controller implementation to kubeflow/common fashion #1171

Merged
merged 1 commit into from Aug 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
17 changes: 9 additions & 8 deletions cmd/tf-operator.v1/app/server.go
Expand Up @@ -20,15 +20,15 @@ import (
"os"
"time"

"github.com/kubeflow/common/pkg/util/signals"
"github.com/kubeflow/tf-operator/cmd/tf-operator.v1/app/options"
v1 "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1"
tfjobclientset "github.com/kubeflow/tf-operator/pkg/client/clientset/versioned"
"github.com/kubeflow/tf-operator/pkg/client/clientset/versioned/scheme"
tfjobinformers "github.com/kubeflow/tf-operator/pkg/client/informers/externalversions"
controller "github.com/kubeflow/tf-operator/pkg/controller.v1/tensorflow"
"github.com/kubeflow/tf-operator/pkg/util/signals"
"github.com/kubeflow/tf-operator/pkg/version"
kubebatchclient "github.com/kubernetes-sigs/kube-batch/pkg/client/clientset/versioned"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
log "github.com/sirupsen/logrus"
Expand All @@ -43,6 +43,7 @@ import (
election "k8s.io/client-go/tools/leaderelection"
"k8s.io/client-go/tools/leaderelection/resourcelock"
"k8s.io/client-go/tools/record"
volcanoclient "volcano.sh/volcano/pkg/client/clientset/versioned"
)

const (
Expand Down Expand Up @@ -106,8 +107,7 @@ func Run(opt *options.ServerOption) error {
kcfg.Burst = opt.Burst

// Create clients.
kubeClientSet, leaderElectionClientSet, tfJobClientSet,
kubeBatchClientSet, err := createClientSets(kcfg)
kubeClientSet, leaderElectionClientSet, tfJobClientSet, volcanoClientSet, err := createClientSets(kcfg)
if err != nil {
return err
}
Expand All @@ -122,7 +122,7 @@ func Run(opt *options.ServerOption) error {
unstructuredInformer := controller.NewUnstructuredTFJobInformer(kcfg, opt.Namespace)

// Create tf controller.
tc := controller.NewTFController(unstructuredInformer, kubeClientSet, kubeBatchClientSet, tfJobClientSet, kubeInformerFactory, tfJobInformerFactory, *opt)
tc := controller.NewTFController(unstructuredInformer, kubeClientSet, volcanoClientSet, tfJobClientSet, kubeInformerFactory, tfJobInformerFactory, *opt)

// Start informer goroutines.
go kubeInformerFactory.Start(stopCh)
Expand Down Expand Up @@ -184,7 +184,7 @@ func Run(opt *options.ServerOption) error {
return nil
}

func createClientSets(config *restclientset.Config) (kubeclientset.Interface, kubeclientset.Interface, tfjobclientset.Interface, kubebatchclient.Interface, error) {
func createClientSets(config *restclientset.Config) (kubeclientset.Interface, kubeclientset.Interface, tfjobclientset.Interface, volcanoclient.Interface, error) {

kubeClientSet, err := kubeclientset.NewForConfig(restclientset.AddUserAgent(config, "tf-operator"))
if err != nil {
Expand All @@ -201,11 +201,12 @@ func createClientSets(config *restclientset.Config) (kubeclientset.Interface, ku
return nil, nil, nil, nil, err
}

kubeBatchClientSet, err := kubebatchclient.NewForConfig(restclientset.AddUserAgent(config, "kube-batch"))
volcanoClientSet, err := volcanoclient.NewForConfig(restclientset.AddUserAgent(config, "volcano"))
if err != nil {
return nil, nil, nil, nil, err
}
return kubeClientSet, leaderElectionClientSet, tfJobClientSet, kubeBatchClientSet, nil

return kubeClientSet, leaderElectionClientSet, tfJobClientSet, volcanoClientSet, nil
}

func checkCRDExists(clientset tfjobclientset.Interface, namespace string) bool {
Expand Down
Expand Up @@ -42,7 +42,7 @@ def scale(image, label):
def build_and_compile_cnn_model():
model = models.Sequential()
model.add(
layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
Expand All @@ -63,10 +63,9 @@ def build_and_compile_cnn_model():
def decay(epoch):
if epoch < 3: #pylint: disable=no-else-return
return 1e-3
elif epoch >= 3 and epoch < 7:
if 3 <= epoch < 7:
return 1e-4
else:
return 1e-5
return 1e-5


def main(args):
Expand All @@ -75,7 +74,7 @@ def main(args):
# layers on each device across all workers
# if your GPUs don't support NCCL, replace "communication" with another
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
communication=tf.distribute.experimental.CollectiveCommunication.NCCL)
communication=tf.distribute.experimental.CollectiveCommunication.NCCL)

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
Expand All @@ -84,7 +83,7 @@ def main(args):
ds_train = make_datasets_unbatched().batch(BATCH_SIZE).repeat()
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = \
tf.data.experimental.AutoShardPolicy.DATA
tf.data.experimental.AutoShardPolicy.DATA
ds_train = ds_train.with_options(options)
# Model building/compiling need to be within `strategy.scope()`.
multi_worker_model = build_and_compile_cnn_model()
Expand All @@ -105,11 +104,11 @@ def on_epoch_end(self, epoch): #pylint: disable=no-self-use
epoch + 1, multi_worker_model.optimizer.lr.numpy()))

callbacks = [
tf.keras.callbacks.TensorBoard(log_dir='./logs'),
tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
save_weights_only=True),
tf.keras.callbacks.LearningRateScheduler(decay),
PrintLR()
tf.keras.callbacks.TensorBoard(log_dir='./logs'),
tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
save_weights_only=True),
tf.keras.callbacks.LearningRateScheduler(decay),
PrintLR()
]

# Keras' `model.fit()` trains the model with specified number of epochs and
Expand Down
5 changes: 2 additions & 3 deletions go.mod
Expand Up @@ -4,17 +4,15 @@ go 1.14

require (
github.com/go-openapi/spec v0.19.2
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b
github.com/golang/protobuf v1.3.2
github.com/google/go-cmp v0.4.1 // indirect
github.com/grpc-ecosystem/grpc-gateway v1.5.0 // indirect
github.com/kubeflow/common v0.3.1
github.com/kubernetes-sigs/kube-batch v0.0.0-20200414051246-2e934d1c8860
github.com/kubernetes-sigs/kube-batch v0.0.0-20200414051246-2e934d1c8860 // indirect
github.com/onrik/logrus v0.2.2-0.20181225141908-a09d5cdcdc62
github.com/pkg/errors v0.9.1 // indirect
github.com/prometheus/client_golang v1.5.1
github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.4.0
k8s.io/api v0.16.9
k8s.io/apiextensions-apiserver v0.16.9 // indirect
k8s.io/apimachinery v0.16.10-beta.0
Expand All @@ -25,6 +23,7 @@ require (
k8s.io/kube-openapi v0.0.0-20200410163147-594e756bea31
k8s.io/kubernetes v1.16.9
sigs.k8s.io/yaml v1.2.0 // indirect
volcano.sh/volcano v0.4.0
)

replace (
Expand Down
1 change: 1 addition & 0 deletions go.sum
Expand Up @@ -789,4 +789,5 @@ sigs.k8s.io/yaml v1.2.0 h1:kr/MCeFWJWTwyaHoR9c8EjH9OumOmoF9YGiZd7lFm/Q=
sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc=
sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0=
vbom.ml/util v0.0.0-20160121211510-db5cfe13f5cc/go.mod h1:so/NYdZXCz+E3ZpW0uAoCj6uzU2+8OWDFv/HxUSs7kI=
volcano.sh/volcano v0.4.0 h1:B4ot28vzi9bH+hpyv6+qd/EFZFEcE37Lj27/QEI6ly0=
volcano.sh/volcano v0.4.0/go.mod h1:2sNJRhY/oNg0MYdBYORxozuDhvgZxoyeOvKJww/Tl8A=
2 changes: 1 addition & 1 deletion hack/update-codegen.sh
Expand Up @@ -42,5 +42,5 @@ ${GOPATH}/bin/defaulter-gen --input-dirs github.com/kubeflow/tf-operator/pkg/ap
cd - > /dev/null

echo "Generating OpenAPI specification for tensorflow/v1"
${GOPATH}/bin/openapi-gen --input-dirs github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1,k8s.io/api/core/v1,k8s.io/apimachinery/pkg/apis/meta/v1,k8s.io/apimachinery/pkg/api/resource,k8s.io/apimachinery/pkg/runtime,k8s.io/apimachinery/pkg/util/intstr,k8s.io/apimachinery/pkg/version --output-package github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1 --go-header-file hack/boilerplate/boilerplate.go.txt "$@"
${GOPATH}/bin/openapi-gen --report-filename=hack/violation_exception.list --input-dirs github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1,k8s.io/api/core/v1,k8s.io/apimachinery/pkg/apis/meta/v1,k8s.io/apimachinery/pkg/api/resource,k8s.io/apimachinery/pkg/runtime,k8s.io/apimachinery/pkg/util/intstr,k8s.io/apimachinery/pkg/version --output-package github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1 --go-header-file hack/boilerplate/boilerplate.go.txt "$@"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, it seems we assume openapi-gen is already in the path, which will break for new users. I will create an issue to track improving this later.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure!

cd - > /dev/null