
Add Reducescatter operator #1496

Closed
jessebenson wants to merge 5 commits from the reducescatter branch

Conversation

jessebenson
Contributor

@jessebenson jessebenson commented Nov 4, 2019

Add Reducescatter operator. From the NVIDIA documentation:
[Diagram: ReduceScatter reduces values across ranks and scatters equal-sized blocks of the result, one block per rank.]

  1. Implement Reducescatter operator on CPU and GPU (with MPI+CUDA)
  2. Support tensor fusion with Reducescatter
  3. Expose Reducescatter in Python through pytorch, tensorflow, keras, mxnet (a usage sketch follows this list)
  4. Add unit tests for pytorch/tensorflow (similar set as covered by allreduce)
  5. Update concepts.rst documentation to describe the operator
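
A minimal usage sketch of the Python API this PR exposes, shown here for PyTorch (the hvd.reducescatter name follows this thread; exact defaults and argument names may differ in the final version):

import torch
import horovod.torch as hvd

hvd.init()

# Each rank contributes a tensor of the same shape; the first dimension is
# split into equal blocks, so every rank gets back roughly 1/size of the
# reduced result.
tensor = torch.ones(4 * hvd.size(), 2) * hvd.rank()
reduced_block = hvd.reducescatter(tensor)

print("rank %d received a block of shape %s" % (hvd.rank(), tuple(reduced_block.shape)))

Run under MPI (e.g. horovodrun -np 4 python script.py); each rank prints a (4, 2) block.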

@jessebenson
Contributor Author

jessebenson commented Nov 5, 2019

Many (but not all) of the "Run PyTests" runs are timing out when hitting the Reducescatter unit tests. I am trying to understand what's causing this.
Update 1: the Reducescatter unit tests only fail if the 'Join' unit tests run first. Join is not enabled in PyTorch v1, which is why some runs pass. Investigating ...
Update 2: disabled the two 'Join' unit tests causing the issue for now (discussed with the author). It is not related to Reducescatter - if you force the Join unit tests to run first (e.g. by prefixing them with 'aaa'), the same issue appears.

@kit1980
Contributor

kit1980 commented Nov 5, 2019

The two problematic Join tests are the ones that test "not implemented" failures for allgather and broadcast. I think it's OK to disable them for this PR; I'll work on a proper fix separately.

@jessebenson
Contributor Author

Looks like there may be a breaking change in TF head. The build-image steps in the unit tests are failing now:

ModuleNotFoundError: No module named 'tensorflow_core.keras'

@tgaddair
Collaborator

tgaddair commented Nov 6, 2019

@jessebenson just triggered a rebuild. TensorFlow's last nightly had a bug that they've since rolled back, so it should be working now.

@tgaddair
Collaborator

Hey @jessebenson, there was another breaking change by TensorFlow that required a fix in #1515. Can you rebase again?

@jessebenson
Contributor Author

@tgaddair - will do.

I don't currently have ReduceScatter for MLSL, NCCL, or GLOO; those will take a bit longer. MPI and GLOO ReduceScatter allow different receive counts per rank, while MLSL and NCCL require all ranks to have the same receive count. I was planning to do a partial Reduce to the last rank (if the tensor doesn't divide evenly) to solve this, similar in principle to how Horovod currently does hierarchical APIs (see the sketch after the links below).

GLOO has ReduceScatter, but doesn't expose a public API to call it - compare:
https://github.com/facebookincubator/gloo/blob/master/gloo/reduce_scatter.h
https://github.com/facebookincubator/gloo/blob/master/gloo/allgather.h#L71
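
An illustrative sketch (not Horovod code) of the receive-count scheme described above: split the first dimension evenly across ranks and give the remainder to the last rank, which would receive those extra rows via a separate partial Reduce.

def receive_counts(first_dim, world_size):
    # Equal blocks for every rank; any leftover rows go to the last rank.
    base = first_dim // world_size
    counts = [base] * world_size
    counts[-1] += first_dim % world_size
    return counts

# e.g. a tensor whose first dimension is 10, reduce-scattered across 4 ranks:
assert receive_counts(10, 4) == [2, 2, 2, 4]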

@jessebenson
Contributor Author

I think this pull request is in a good state to review.

Future work would be Reducescatter implementations for NCCL, GLOO, and the new Intel CCL (MLSL is being removed). However, GLOO doesn't have a proper public API for Reducescatter and CCL doesn't have Reducescatter at all - so that leaves NCCL.


@tgaddair tgaddair left a comment


Thanks @jessebenson, and apologies for the delay. Looks good, just one question regarding API alignment before we land.

@@ -118,6 +118,37 @@ def allreduce(tensor, average=None, device_dense='', device_sparse='',
     return new_tensor
 
 
+def reducescatter(tensor, average=True, device_dense='', compression=Compression.none):
Collaborator


Given that the Adasum PR deprecated the average param in favor of op, I'm wondering if we should do the same here. Might we want to support other reduction ops (min, max, product) in the future?

My understanding is that NCCL reduceScatter supports other reductions: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/api/types.html#c.ncclRedOp_t
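
A hedged sketch of the suggested change: drop the boolean average flag in favor of an op parameter, mirroring what the Adasum PR did for allreduce (the import path and default value below are assumptions for illustration, not the final signature):

from horovod.tensorflow import Average, Compression

def reducescatter(tensor, op=Average, device_dense='',
                  compression=Compression.none):
    """Reduce `tensor` across ranks with `op`, then scatter equal-sized
    blocks of the result, one block per rank."""
    ...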

Contributor Author


Yes this seems perfectly reasonable.

horovod/tensorflow/mpi_ops.cc (outdated review comment, resolved)
@tgaddair
Collaborator

Hey @jessebenson, any updates on the API consolidation?

@jessebenson
Contributor Author

Allreduce sort of supports both average and op now, with some legacy logic to decide between them. I'm wondering - should I support just the op parameter?

@tgaddair
Collaborator

I would just support the new op param since this is new functionality. Thanks for checking!

@jessebenson jessebenson force-pushed the reducescatter branch 2 times, most recently from 2cc1ee5 to ac2178b on December 17, 2019 19:35
@jessebenson
Contributor Author

It became a bit more involved to update all of mxnet/keras/tensorflow/pytorch. Hopefully I was thorough and careful enough. I also added a unit test to test_torch.py to verify that passing op=hvd.Adasum gives an error.
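
A minimal sketch of the kind of check described above (the test name and exact exception type are assumptions; the real test lives in test_torch.py):

import pytest
import torch
import horovod.torch as hvd

def test_reducescatter_rejects_adasum():
    hvd.init()
    tensor = torch.ones(4 * hvd.size())
    # Adasum is not a supported reduction for reducescatter, so this should fail.
    with pytest.raises(Exception):
        hvd.reducescatter(tensor, op=hvd.Adasum)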

@@ -201,7 +201,7 @@ def _broadcast_grad(op, grad):
     return grad_reduced
 
 
-def _reducescatter(tensor, name=None):
+def _reducescatter(tensor, name=None, op=Sum):
Collaborator


Looks like everywhere else the default is Average but here it's Sum. Is there a reason, or should we make it Average here as well?

Contributor Author


So, Average needs to be handled in the higher-level reducescatter() in __init__.py (which is where the division happens). I was trying to match the behavior of Allreduce, where allreduce() handles Average and passes Sum to _allreduce().

_allreduce() and _reducescatter() default to Sum; allreduce() and reducescatter() default to Average.
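
An illustrative, self-contained sketch of that split: the public reducescatter() handles Average by dividing the summed result by the number of ranks, while the low-level _reducescatter() only ever performs a Sum. The names mimic the Horovod API, but the bodies are stand-ins rather than Horovod code:

SUM, AVERAGE = "Sum", "Average"

def _reducescatter(tensor, op=SUM):
    # Stand-in for the C++-backed op: pretend two identical ranks contributed.
    return [2.0 * x for x in tensor]

def size():
    return 2  # stand-in for hvd.size()

def reducescatter(tensor, op=AVERAGE):
    summed = _reducescatter(tensor, op=SUM)  # the backend always sums
    return [x / size() for x in summed] if op == AVERAGE else summed

assert reducescatter([1.0, 2.0]) == [1.0, 2.0]            # Average divides by size
assert reducescatter([1.0, 2.0], op=SUM) == [2.0, 4.0]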

Contributor Author


Let me make a small tweak to reducescatter() though. Since the underlying _reducescatter() only supports Sum at the moment, I never actually pass the reduce op through to it. However, I should pass it through anyway: any errors will get caught in EnqueueTensorReduceScatter, so it's "safe" to pass the op, and future readers won't be confused about why the op they pass (e.g. Max) never shows up in the C++ API.


@tgaddair tgaddair left a comment


Looks good, just one small question about default param before we land.

@jessebenson
Contributor Author

Some tests are failing now - likely a bad interaction with my changes and a recent change from master. I'll have to investigate.

@kit1980
Contributor

kit1980 commented Dec 21, 2019

> Some tests are failing now - likely a bad interaction with my changes and a recent change from master. I'll have to investigate.

After #1594 you need to set tensor sizes and data type for requests, like here: https://github.com/horovod/horovod/blob/master/horovod/common/controller.cc#L581

Also, there is the ongoing #1604.

@jessebenson
Contributor Author

jessebenson commented Dec 22, 2019

@kit1980 - thanks, that is useful to know. What are these tensor sizes used for? Should they correspond to the fusion buffer size, the response size, or something else?

For Allreduce/Adasum, the input/output tensors are the same size.
For AllGather, the input (per rank) is size ~T and output is ~T*N.
For ReduceScatter, the input (per rank) is size T and output is ~T/N.
For Broadcast, it doesn't seem to set the tensor sizes (is it not needed?)

@kit1980
Contributor

kit1980 commented Dec 22, 2019

@jessebenson, currently the sizes are used in this way only for AllReduce and AdaSum.

The tensor sizes are used in https://github.com/horovod/horovod/blob/master/horovod/common/controller.cc#L642 and https://github.com/horovod/horovod/blob/master/horovod/common/controller.cc#L651

The sizes set in the response in this case mean the number of elements in the tensor. I'm not sure how this will work when the input and output sizes are different.

Basically, the recent change was that FuseResponses used to call tensor_queue_.GetTensorSizeAndType(response.tensor_names()[0], tensor_size, dtype); now tensor_size and dtype come from the response directly (with the caveat that tensor_sizes in the response are element counts, so you need to multiply by the element size to get bytes).
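
A small sketch of the size bookkeeping described above, assuming tensor_sizes in the response are element counts rather than bytes (the helper name is hypothetical, not Horovod code):

def fused_response_bytes(tensor_sizes, element_size):
    # tensor_sizes are element counts per fused tensor; multiply by the
    # element size (e.g. 4 for float32) to get the byte total.
    return sum(num_elements * element_size for num_elements in tensor_sizes)

# e.g. three float32 tensors of 1024, 2048, and 512 elements:
assert fused_response_bytes([1024, 2048, 512], 4) == (1024 + 2048 + 512) * 4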

jessebenson and others added 3 commits January 30, 2020 20:26 (signed off by Jesse Benson)
@legatoo

legatoo commented Aug 18, 2020

Any updates on this PR?

@stale

stale bot commented Nov 6, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix label Nov 6, 2020
@stale stale bot closed this Nov 13, 2020
@maxhgerlach
Collaborator

@legatoo, @ducviet00, and anybody else who's interested: I've revived this PR in #3299 and any feedback would be appreciated.
