Skip to content

Commit

Permalink
make API consistent with paper (#324)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgarithm committed Sep 22, 2020
1 parent 2bc07d3 commit a80f7cb
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 13 deletions.
2 changes: 1 addition & 1 deletion srcs/cpp/src/tensorflow/ops/cpu/elastic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class StepBasedSchedule : public OpKernel
&cluster_size));
bool found = false;
int result = default_;
for (const auto sch : schedule_) { // FIXME: use binary search
for (const auto &sch : schedule_) { // FIXME: use binary search
const auto r = sch.first;
if (r.first <= step && step < r.second) {
result = sch.second;
Expand Down
13 changes: 5 additions & 8 deletions srcs/python/kungfu/tensorflow/ops/adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,14 @@ def resize(n):
Inputs:
n: A scalar tensor of uint32.
Returns:
A pair of scalar tensors (changed, keep) of type bool,
{changed} indicates if the cluster has been changed,
{keep} indicates if the current peer is still in the new cluster,
the peer should quit if it is not in the new cluster.
A scalar tensor of bool, indicates if the cluster has been changed.
"""
resize_op = _op_lib.kungfu_resize_cluster(n)
changed, keep = _op_lib.kungfu_resize_cluster(n)
if hasattr(_op_lib, 'kungfu_reset_nccl_helper'):
changed, keep = resize_op
return _op_lib.kungfu_reset_nccl_helper(changed, keep)
changed, keep = _op_lib.kungfu_reset_nccl_helper(changed, keep)
return changed
else:
return resize_op
return changed


def set_tree(tree):
Expand Down
8 changes: 4 additions & 4 deletions tests/python/integration/test_tensorflow_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import kungfu
import tensorflow as tf
from kungfu.python import detached
from kungfu.tensorflow.ops import all_reduce, resize


Expand Down Expand Up @@ -37,7 +38,6 @@ def main():
inc_gs = tf.assign_add(gs, 1)
new_size = tf.placeholder(dtype=tf.uint32)
resize_op = resize(new_size)

train_op = build_fake_train_op(args.use_nccl)
init = tf.global_variables_initializer()

Expand All @@ -57,11 +57,11 @@ def main():
# END train

if step in fake_schedule:
changed, keep = sess.run(
resize_op, feed_dict={new_size: fake_schedule[step]})
changed = sess.run(resize_op,
feed_dict={new_size: fake_schedule[step]})
if changed:
need_sync = True
if not keep:
if detached():
break
else:
print('cluster not changed')
Expand Down

0 comments on commit a80f7cb

Please sign in to comment.