From a80f7cbae65c0188fe4b0dc55534cb9b4df219b1 Mon Sep 17 00:00:00 2001 From: LG Date: Tue, 22 Sep 2020 14:16:57 +0100 Subject: [PATCH] make API consistent with paper (#324) --- srcs/cpp/src/tensorflow/ops/cpu/elastic.cpp | 2 +- srcs/python/kungfu/tensorflow/ops/adapt.py | 13 +++++-------- tests/python/integration/test_tensorflow_resize.py | 8 ++++---- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/srcs/cpp/src/tensorflow/ops/cpu/elastic.cpp b/srcs/cpp/src/tensorflow/ops/cpu/elastic.cpp index 413d8a00f..1853e5773 100644 --- a/srcs/cpp/src/tensorflow/ops/cpu/elastic.cpp +++ b/srcs/cpp/src/tensorflow/ops/cpu/elastic.cpp @@ -59,7 +59,7 @@ class StepBasedSchedule : public OpKernel &cluster_size)); bool found = false; int result = default_; - for (const auto sch : schedule_) { // FIXME: use binary search + for (const auto &sch : schedule_) { // FIXME: use binary search const auto r = sch.first; if (r.first <= step && step < r.second) { result = sch.second; diff --git a/srcs/python/kungfu/tensorflow/ops/adapt.py b/srcs/python/kungfu/tensorflow/ops/adapt.py index d32e09f50..0c8075eaa 100644 --- a/srcs/python/kungfu/tensorflow/ops/adapt.py +++ b/srcs/python/kungfu/tensorflow/ops/adapt.py @@ -35,17 +35,14 @@ def resize(n): Inputs: n: A scalar tensor of uint32. Returns: - A pair of scalar tensors (changed, keep) of type bool, - {changed} indicates if the cluster has been changed, - {keep} indicates if the current peer is still in the new cluster, - the peer should quit if it is not in the new cluster. + A scalar tensor of bool, indicates if the cluster has been changed. """ - resize_op = _op_lib.kungfu_resize_cluster(n) + changed, keep = _op_lib.kungfu_resize_cluster(n) if hasattr(_op_lib, 'kungfu_reset_nccl_helper'): - changed, keep = resize_op - return _op_lib.kungfu_reset_nccl_helper(changed, keep) + changed, keep = _op_lib.kungfu_reset_nccl_helper(changed, keep) + return changed else: - return resize_op + return changed def set_tree(tree): diff --git a/tests/python/integration/test_tensorflow_resize.py b/tests/python/integration/test_tensorflow_resize.py index 698a18870..a404fa617 100644 --- a/tests/python/integration/test_tensorflow_resize.py +++ b/tests/python/integration/test_tensorflow_resize.py @@ -2,6 +2,7 @@ import kungfu import tensorflow as tf +from kungfu.python import detached from kungfu.tensorflow.ops import all_reduce, resize @@ -37,7 +38,6 @@ def main(): inc_gs = tf.assign_add(gs, 1) new_size = tf.placeholder(dtype=tf.uint32) resize_op = resize(new_size) - train_op = build_fake_train_op(args.use_nccl) init = tf.global_variables_initializer() @@ -57,11 +57,11 @@ def main(): # END train if step in fake_schedule: - changed, keep = sess.run( - resize_op, feed_dict={new_size: fake_schedule[step]}) + changed = sess.run(resize_op, + feed_dict={new_size: fake_schedule[step]}) if changed: need_sync = True - if not keep: + if detached(): break else: print('cluster not changed')