diff --git a/examples/elastic/tensorflow2/tensorflow2_mnist_elastic.py b/examples/elastic/tensorflow2/tensorflow2_mnist_elastic.py index dce0d31a22..33d45164e6 100644 --- a/examples/elastic/tensorflow2/tensorflow2_mnist_elastic.py +++ b/examples/elastic/tensorflow2/tensorflow2_mnist_elastic.py @@ -106,6 +106,9 @@ def on_state_reset(): if hvd.rank() == 0: checkpoint.save(checkpoint_dir) + # return some serializable data if needed + return hvd.rank() + if __name__ == '__main__': if len(sys.argv) == 5: @@ -115,8 +118,12 @@ def on_state_reset(): max_num_proc = int(sys.argv[3]) hosts = sys.argv[4] print('Running training through horovod.run') - horovod.run(main, num_proc=num_proc, min_num_proc=min_num_proc, max_num_proc=max_num_proc, - hosts=hosts, use_gloo=True, verbose=2) + ranks = horovod.run(main, num_proc=num_proc, min_num_proc=min_num_proc, max_num_proc=max_num_proc, + hosts=hosts, use_gloo=True, verbose=2) + + if ranks != range(min_num_proc): + raise RuntimeError(f'Expected training method to return ranks {",".join(range(min_num_proc))}, ' + f'got: {",".join(ranks)}') else: # this is running via horovodrun main()