
The program hangs at the forward function when using model parallelism in Megatron-LM #58

Closed
seanM29 opened this issue Jul 7, 2021 · 14 comments

seanM29 commented Jul 7, 2021

Thanks for your work! I love it very much.
I ran into a problem and hope you can help me. Thanks a lot!

Platform

  • V100, single node, 8 GPUs
  • PyTorch 1.8.0
  • CUDA 11.1
  • cuDNN 8

Update

  • If pipeline-model-parallel-size = 1 (with tensor-model-parallel-size > 1), the program runs well.
  • If pipeline-model-parallel-size > 1, the program hangs.
ymjiang (Contributor) commented Jul 7, 2021

Is it possible that two NCCL calls happen concurrently when enabling the pipeline for Megatron? @laekov

(i.e., Megatron calls send/recv for pipelining, and FastMoE calls send/recv for its all-to-all)

xptree (Collaborator) commented Jul 7, 2021

@ymjiang We haven't officially tested fmoe with pipeline-model-parallel-size > 1.

The communication of pipeline model parallelism happens between transformer layers, while the communication of fmoe happens within each transformer layer, so it is a little weird to me that they would call NCCL concurrently.

There could be other possible reasons, e.g., the CUDA version. @laekov have we tested fmoe with CUDA 11?

ymjiang (Contributor) commented Jul 7, 2021

The communication of pipeline model parallelism happens between transformer layers, while the communication of fmoe happens within each transformer layer, so it is a little weird to me that they would call NCCL concurrently.

I suppose a case like this: Node-1 has just received the inputs (using NCCL) from its upstream Node-0 and begins fmoe's NCCL call. Meanwhile, Node-1 continues to wait for new inputs from Node-0 (still using NCCL). I think this is likely to happen, given that Megatron v2.2 already enables pipeline parallelism by breaking data into micro-batches.


That said, I haven't dived into the code further. So it is just a guess.

seanM29 (Author) commented Jul 7, 2021

Have you checked the number of model parameters on each GPU when model parallelism is enabled? @xptree @ymjiang @laekov

In actual use, I found some strange things:

  • When FastMoE is not used, a single GPU holds 300 million parameters,

  • and with pipeline-model-parallel-size set to 1 and tensor-model-parallel-size set to 8, on a single node with 8 GPUs each GPU holds 45 million parameters.

But after FastMoE is enabled (with 12 experts),

  • a single GPU on a single machine holds 1.3 billion parameters,
  • and with pipeline-model-parallel-size set to 1 and tensor-model-parallel-size set to 8, on a single node with 8 GPUs each GPU holds 1.2 billion parameters.
  • It seems that FastMoE's experts have no model parallelism at all? Or am I using it the wrong way?

laekov (Owner) commented Jul 7, 2021

I am running fastmoe with cuda@11 and nccl@2.9.9, so it should not be a CUDA issue. @xptree

FastMoE currently does not support pipeline parallelism in Megatron. I am not sure about its behavior. I will inspect it now and see if we can support it with less burden. @ymjiang @seanM29

For tensor model parallelism, the experts are not divided into pieces the way Megatron does it. Instead, each GPU holds a different expert. However, the attention layer is partitioned the same way as in Megatron. So, in your observation, I suppose 1.3 billion = attention + 1 expert and 1.2 billion = attention / 8 + 1 expert. @seanM29

laekov (Owner) commented Jul 7, 2021

I think I have an idea of why it gets stuck. Given that we have data parallelism (DP), tensor model parallelism (MP), and pipeline parallelism (PP), world = DP x MP x PP. For an MLP layer that is fmoefied, it should expand across DP x MP in our assumption, which is perpendicular to PP. However, in FastMoE's current implementation, we assume that PP = 1 and let the MoE layer expand across the whole world, which leads to this issue.
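
A minimal sketch of that intent, assuming torch.distributed is already initialized and a hypothetical rank layout in which each pipeline stage owns a consecutive block of DP x MP ranks (this is not FastMoE's actual group-construction code):

```python
import torch.distributed as dist

# Hypothetical sizes: world = DP x MP x PP = 2 x 2 x 2 = 8 ranks.
dp_size, mp_size, pp_size = 2, 2, 2
stage_width = dp_size * mp_size

# Assumed layout: ranks [stage*stage_width, (stage+1)*stage_width) belong
# to the same pipeline stage.
moe_groups = []
for stage in range(pp_size):
    stage_ranks = list(range(stage * stage_width, (stage + 1) * stage_width))
    # new_group must be called with the same arguments on every rank.
    moe_groups.append(dist.new_group(ranks=stage_ranks))

# Each rank uses the group of its own pipeline stage for the MoE all-to-all,
# so the MoE communicator never spans ranks that are also exchanging
# inter-stage activations via Megatron's send/recv.
my_stage = dist.get_rank() // stage_width
moe_group = moe_groups[my_stage]
```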

laekov (Owner) commented Jul 7, 2021

Should be resolved in #59

seanM29 (Author) commented Jul 7, 2021

I am running fastmoe with cuda@11 and nccl@2.9.9, so it should not be a CUDA issue. @xptree

FastMoE currently does not support pipeline parallelism in Megatron. I am not sure about its behavior. I will inspect it now and see if we can support it with less burden. @ymjiang @seanM29

For tensor model parallelism, the experts are not divided into pieces the way Megatron does it. Instead, each GPU holds a different expert. However, the attention layer is partitioned the same way as in Megatron. So, in your observation, I suppose 1.3 billion = attention + 1 expert and 1.2 billion = attention / 8 + 1 expert. @seanM29

Thank you very much for your reply, but from the pictures in the README it looks like FastMoE supports different experts and puts them on different machines?

If I want to split the experts into pieces like Megatron does, do you have any suggestions?

TiagoMAntunes (Collaborator) commented

Thank you very much for your reply, but from the pictures in the README it looks like FastMoE supports different experts and puts them on different machines?

In the current version, you have three ways to give experts to FastMoE: you give it a single expert class, you give it a list of N experts, or you use a fused expert (see fmoe/transformer.py). The same experts will be replicated across all nodes, so you cannot customize a per-node expert type for now.
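
A rough sketch of the three options; treat the exact constructor arguments and the expert forward signature as assumptions and check fmoe/layers.py and fmoe/transformer.py for the real interface:

```python
import torch
import torch.nn as nn
from fmoe import FMoE
from fmoe.transformer import FMoETransformerMLP

d_model, d_hidden, num_expert = 1024, 4096, 4

class MyExpert(nn.Module):
    """Toy expert used only for illustration."""
    def __init__(self, d_model):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_hidden)
        self.fc2 = nn.Linear(d_hidden, d_model)

    def forward(self, x, fwd_expert_count=None):
        # FastMoE passes a per-expert token count as the second argument;
        # this toy expert ignores it.
        return self.fc2(torch.relu(self.fc1(x)))

# 1) A single expert class, replicated num_expert times per worker.
moe_a = FMoE(num_expert=num_expert, d_model=d_model, expert=MyExpert)

# 2) A list of N expert constructors, one per local expert.
moe_b = FMoE(num_expert=num_expert, d_model=d_model,
             expert=[MyExpert for _ in range(num_expert)])

# 3) The fused expert shipped with FastMoE (see fmoe/transformer.py).
moe_c = FMoETransformerMLP(num_expert=num_expert, d_model=d_model,
                           d_hidden=d_hidden)
```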

If I want to split the experts into pieces like Megatron does, do you have any suggestions?

I am not sure, but most likely you'll need to use a customized fused expert. I think it would cause some issues, though, since you might need to add some extra communication between the different shards. Someone with more experience with Megatron might be able to help you better.

laekov (Owner) commented Jul 8, 2021

If I want to split the experts into pieces like Megatron does, do you have any suggestions?

Megatron-LM uses column and row partitioning for the two layers of the MLP, which is equivalent to activating a group of MLPs (or experts) at the same time. The easiest way of doing this is to develop a new gate with a smaller number of logical experts and activate a group of experts, rather than one specific expert, via repeat_interleave.
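
For illustration, a minimal sketch of the index expansion this suggests, with hypothetical numbers (32 physical experts grouped into 8 logical experts of 4 shards each); this is not FastMoE's actual gate code:

```python
import torch

group_size = 4                            # physical shards per logical expert
tot_physical = 32
tot_logical = tot_physical // group_size  # the gate only scores 8 logical experts

# Suppose the gate picked one logical expert per token (top-1 for brevity).
logical_idx = torch.tensor([0, 3, 7, 3])

# Expand each logical choice into its group of physical experts:
# logical expert g owns physical experts [g*4, g*4+1, g*4+2, g*4+3].
physical_idx = (logical_idx.repeat_interleave(group_size) * group_size
                + torch.arange(group_size).repeat(logical_idx.numel()))
# -> tensor([ 0,  1,  2,  3, 12, 13, 14, 15, 28, 29, 30, 31, 12, 13, 14, 15])
```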

seanM29 (Author) commented Jul 8, 2021

I saw some public reports that Wenlan used FastMoE to reach 1.75 trillion parameters. @laekov @TiagoMAntunes
If you put all the experts on each machine, how can you reach 1.75 trillion parameters?

In my experiment, on a V100 with a 345-million-parameter BERT and 12 experts, GPU memory is basically full, and it only reached 1.3 billion parameters.

laekov (Owner) commented Jul 8, 2021

If you put all the experts on each machine, how can you reach 1.75 trillion parameters?

They are using a private version of FastMoE and Megatron on SunWay platform, which is quite different from NVIDIA's stuff.

seanM29 (Author) commented Jul 8, 2021

If you put all the experts on each machine, how can you reach 1.75 trillion parameters?

They are using a private version of FastMoE and Megatron on SunWay platform, which is quite different from NVIDIA's stuff.

Thank you very much for your reply @laekov
I found in the code that num_expert stands for the number of experts on each worker, and the gate seems to use world_size * num_expert different experts.

Sorry to disturb you again. I want to confirm: when setting num_expert=4 and using 8 GPUs, in BaseGate, self.tot_expert should be 8*4=32, so it seems to use 32 different experts?

Does the whole network have 4*8=32 different experts, or are there still only 4 experts?

laekov (Owner) commented Jul 8, 2021

If you put all the experts on each machine, how can you reach 1.75 trillion parameters?

They are using a private version of FastMoE and Megatron on SunWay platform, which is quite different from NVIDIA's stuff.

Thank you very much for your reply @laekov
I found in the code that num_expert stands for the number of experts on each worker, and the gate seems to use world_size * num_expert different experts.

exactly

Sorry to disturb you again. I want to confirm: when setting num_expert=4 and using 8 GPUs, in BaseGate, self.tot_expert should be 8*4=32, so it seems to use 32 different experts?

Does the whole network have 4*8=32 different experts, or are there still only 4 experts?

yup, there are 32 experts in total
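
For the record, the arithmetic confirmed above; the names mirror BaseGate, but treat this as a sketch rather than the library's code:

```python
num_expert = 4                        # local experts per worker
world_size = 8                        # number of GPUs / workers
tot_expert = num_expert * world_size  # what BaseGate stores as self.tot_expert
assert tot_expert == 32               # 32 distinct experts across the network
```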

laekov closed this as completed Jul 9, 2021