[Discussion] About 3D Parallelism #131

Closed
feifeibear opened this issue Jan 7, 2022 · 6 comments
Labels
enhancement New feature or request

Comments

@feifeibear
Contributor

feifeibear commented Jan 7, 2022

I read the paper Maximizing Parallelism in Distributed Training for Huge Neural Networks. The idea is elegant and makes sense to me. However, I wonder about the compatibility of this method with gradient checkpointing (I mentioned it in #117; we call it GC afterwards).

Using 3D parallelism, we have to conduct an all-gather on the activations across (N/P^2) processors (a partial collective communication), where N is the number of GPUs for the 3D linear layer. Such a partial collective has to be done at least three times: in the forward pass, in the backward pass, and when recomputing activations during the backward pass with GC. Therefore, it introduces more communication overhead than model parallelism that does not split activations. Did you consider this overhead in the experiment section of the paper?
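A minimal sketch of where the third collective comes from, assuming a PyTorch-style checkpoint wrapper (the group, shapes, and plain `dist.all_gather` here are placeholders for illustration, not the actual 3D Linear code):

```python
import torch
import torch.distributed as dist
from torch.utils.checkpoint import checkpoint

def block_forward(x_shard, weight, group):
    # partial all-gather of the activation shard across one sub-group of the 3D mesh
    world = dist.get_world_size(group)
    shards = [torch.empty_like(x_shard) for _ in range(world)]
    dist.all_gather(shards, x_shard, group=group)  # plain collective; real code would use
    x_full = torch.cat(shards, dim=-1)             # an autograd-aware all-gather instead
    return x_full @ weight

# Without GC: one all-gather in forward plus its counterpart in backward.
# With GC, `checkpoint` replays block_forward during backward, so the same
# partial all-gather is issued a third time:
# out = checkpoint(block_forward, x_shard, weight, group, use_reentrant=False)
```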

Also, activation tensors are small. If we partition an activation tensor into N pieces and send/recv at the granularity of a single piece, won't the bandwidth utilization be extremely low? This is different from communication on parameters: we can pack the parameter tensors of several layers and send/recv them in a larger volume to better utilize the network bandwidth, but activations arrive one after another, so they cannot be treated the same way as parameter tensors.
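To make the bandwidth concern concrete, here is a toy latency-bandwidth estimate (all numbers are made-up assumptions, not measurements from the paper or from Colossal-AI):

```python
alpha = 5e-6                        # assumed per-message latency: 5 us
B = 100e9 / 8                       # assumed 100 Gb/s link, in bytes/s
activation_bytes = 4 * 512 * 1024   # e.g. fp32 activation, 512 tokens, hidden size 1024

def effective_bandwidth(total_bytes, pieces):
    piece = total_bytes / pieces
    return piece / (alpha + piece / B)   # bytes/s actually achieved per message

print(effective_bandwidth(activation_bytes, 1) / B)    # ~0.97 of peak: one large message
print(effective_bandwidth(activation_bytes, 64) / B)   # ~0.34 of peak: 64 small pieces
```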

PS: a small typo in the arXiv paper. Page 5, 1st line, Bij = [lnp : lnp + np + 1]

@kurisusnowdeng
Member

Hi @feifeibear. Thank you for sharing the idea! In our opinion, this is basically a trade-off between memory cost and communication cost.

The current design of the 3D Linear layer applies an all-gather on the input matrix A and a reduce-scatter on the output matrix C in the forward pass (an all-gather on the gradients of C and a reduce-scatter on the gradients of A in the backward pass), so that each activation can be 1/N of its original size.

An alternative design is to use an all-reduce on C in the forward pass, as well as on the gradients of A in the backward pass, but then the activations are 1/N^(2/3) of their original size.
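A rough sketch of the two designs using torch.distributed primitives, collapsed onto a single process group so that only the order of the collectives around the local matmul is shown (the block indexing over the P x P x P mesh of the actual 3D Linear layer is omitted, and the matmuls are placeholders):

```python
import torch
import torch.distributed as dist

def design_1_forward(a_shard, w, group):
    """All-gather the input shard, do the local matmul, then reduce-scatter the
    output, so each rank keeps only about 1/N of the activation."""
    world = dist.get_world_size(group)
    gathered = [torch.empty_like(a_shard) for _ in range(world)]
    dist.all_gather(gathered, a_shard, group=group)
    c_full = torch.cat(gathered, dim=0) @ w          # placeholder for the local block matmul
    out_chunks = list(c_full.chunk(world, dim=0))
    c_shard = torch.empty_like(out_chunks[0])
    dist.reduce_scatter(c_shard, out_chunks, group=group)
    return c_shard   # backward mirrors this: all-gather grad C, reduce-scatter grad A

def design_2_forward(a_local, w_local, group):
    """Compute a partial product and all-reduce it; every rank then holds a
    larger (1/N^(2/3)) activation, but no gather of the input is needed."""
    c_partial = a_local @ w_local                    # placeholder for the local block matmul
    dist.all_reduce(c_partial, group=group)
    return c_partial
```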

Considering activation checkpointing, since the forward pass is recomputed, the first design applies 2 * all-gather + reduce-scatter on A and 2 * reduce-scatter + all-gather on C in total, while the second design applies 3 * all-reduce. Since a ring all-reduce has a cost similar to all-gather + reduce-scatter, the total communication costs of the two designs should be similar.
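A quick check of that last claim with the standard ring-algorithm cost model (M is the message size in bytes, p the group size):

```python
def allgather_bytes(M, p):
    return (p - 1) / p * M

def reducescatter_bytes(M, p):
    return (p - 1) / p * M

def allreduce_bytes(M, p):
    return 2 * (p - 1) / p * M   # ring all-reduce = reduce-scatter phase + all-gather phase

M, p = 1 << 20, 8
assert allreduce_bytes(M, p) == allgather_bytes(M, p) + reducescatter_bytes(M, p)
```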

However, we are indeed concerned that small tensors decrease bandwidth utilization, and they are hard to fuse. To find the optimal performance, we are testing as many models and networking environments as possible, and will let the results tell.

@feifeibear
Contributor Author

I agree that 3D parallelism can shrink the peak activation footprint on one GPU at the cost of more communication. The method definitely works in some special cases. Maybe a simple search method could be derived to figure out which parts of the DNN are suitable for 3D parallelism under the constraint of a limited memory budget.

@kurisusnowdeng
Member

> I agree that 3D parallelism can shrink the peak activation footprint on one GPU at the cost of more communication. The method definitely works in some special cases. Maybe a simple search method could be derived to figure out which parts of the DNN are suitable for 3D parallelism under the constraint of a limited memory budget.

This could be a good idea. For example, self-attention blocks usually consume more activation memory than MLP (FFN) blocks.
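A toy version of such a search, purely as an illustration (not an existing Colossal-AI feature; block names and footprints below are made up): greedily switch the blocks with the largest activation footprint to 3D parallelism until the estimated peak fits the memory budget.

```python
def choose_3d_blocks(blocks, budget_bytes, n_gpus):
    """blocks: list of (name, activation_bytes) pairs; returns the set of block
    names that should use 3D parallelism so peak activations fit the budget."""
    chosen = set()
    peak = sum(act for _, act in blocks)        # start with every activation kept whole
    for name, act in sorted(blocks, key=lambda b: b[1], reverse=True):
        if peak <= budget_bytes:
            break
        peak -= act - act / n_gpus              # 3D parallelism shrinks this activation to 1/N
        chosen.add(name)
    return chosen

# e.g. self-attention blocks tend to be picked first, since they usually hold
# more activation memory than MLP/FFN blocks:
# choose_3d_blocks([("attn_0", 4e9), ("ffn_0", 2e9)], budget_bytes=3e9, n_gpus=8)
```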

@github-actions
Contributor

This issue is stale because it has been open for 14 days with no activity.

@github-actions github-actions bot added the stale label Jan 26, 2022
@feifeibear
Contributor Author

@1SAA The communication profiling results may support some of my assumptions in this discussion.

@binmakeswell binmakeswell added enhancement New feature or request and removed stale labels Apr 13, 2022
@binmakeswell
Member

We have updated a lot. This issue was closed due to inactivity. Thanks.
