
The output of nccl_all_to_all_scatter_async may be incomplete when num_local_experts>1. #172

Closed
Fragile-azalea opened this issue Jul 19, 2022 · 11 comments
Labels
wontfix This will not be worked on

Comments

@Fragile-azalea

Describe the bug
The output of nccl_all_to_all_scatter_async may be incomplete.

To Reproduce
Steps to reproduce the behavior:

on host0(master): SKIP_EXPERT=1 python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=2 --node_rank=0 --master_addr=host0 -m tutel.examples.helloworld --batch_size=4 --num_tokens=1 --model_dim=2 --hidden_size=2 --num_steps=1 --a2a_ffn_overlap_degree=1
on host1: SKIP_EXPERT=1 python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=2 --node_rank=1 --master_addr=host0 -m tutel.examples.helloworld --batch_size=4 --num_tokens=1 --model_dim=2 --hidden_size=2 --num_steps=1 --a2a_ffn_overlap_degree=1

Log
The value of

y = y.view(self.world_size, -1, y.size(2))

tensor([[[ 1.5410, -0.2934],
[-1.0845, -1.3986]],
[[ 1.5410, -0.2934],
[ 0.4033, 0.8380]],
[[-2.1788, 0.5684],
[-1.0845, -1.3986]],
[[ 0.4033, 0.8380],
[-2.1788, 0.5684]]], device='cuda:0')

The value of

y = C.all_to_all(y, 0, 1, use_2dh=self.use_2dh, group=self.group)

tensor([[[ 1.5410, -0.2934],
[-1.0845, -1.3986]],
[[ 1.5410, -0.2934],
[ 0.4033, 0.8380]],
[[-2.1788, 0.5684],
[-1.0845, -1.3986]],
[[ 0.4033, 0.8380],
[-2.1788, 0.5684]]], device='cuda:0')

This is the result I expect. However, when I run:
on host0(master): SKIP_EXPERT=1 python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=2 --node_rank=0 --master_addr=host0 -m tutel.examples.helloworld --batch_size=4 --num_tokens=1 --model_dim=2 --hidden_size=2 --num_steps=1 --a2a_ffn_overlap_degree=2
on host1: SKIP_EXPERT=1 python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=2 --node_rank=1 --master_addr=host0 -m tutel.examples.helloworld --batch_size=4 --num_tokens=1 --model_dim=2 --hidden_size=2 --num_steps=1 --a2a_ffn_overlap_degree=2

The value of

y = y.view(self.world_size, -1, y.size(2))

tensor([[[ 1.5410, -0.2934],
[-1.0845, -1.3986]],
[[ 1.5410, -0.2934],
[ 0.4033, 0.8380]],
[[-2.1788, 0.5684],
[-1.0845, -1.3986]],
[[ 0.4033, 0.8380],
[-2.1788, 0.5684]]], device='cuda:0')

The value of

y = a2a_ffn_overlap_forward(y, expert_fn=expert_fn, a2a_ffn_overlap_degree=a2a_ffn_overlap_degree, use_2dh=self.use_2dh, group=self.group)

tensor([[[ 0.0000, 0.0000],
[ 0.0000, 0.0000]],
[[ 1.5410, -0.2934],
[ 0.4033, 0.8380]],
[[ 0.0000, 0.0000],
[ 0.0000, 0.0000]],
[[ 0.4033, 0.8380],
[-2.1788, 0.5684]]], device='cuda:0')

It seems incomplete.

The code that most likely causes this is:

CHECK_EQ(0, ncclGroupStart());
for (int j = 0; j < num_slices_per_split; j++) {
CHECK_EQ(0, ncclSend(
((char*)input.data_ptr()) + (j * num_split + calc_idx) * slice_size,
slice_size,
ncclInt8,
g_world_size * j / num_slices_per_split,
g_nccl_comm,
get_nccl_stream().stream()));
CHECK_EQ(0, ncclRecv(
((char*)output_list[calc_idx].data_ptr()) + j * slice_size,
slice_size,
ncclInt8,
g_world_size * j / num_slices_per_split,
g_nccl_comm,
get_nccl_stream().stream()));
}
CHECK_EQ(0, ncclGroupEnd());

It looks like the NCCL group keeps only the last send-recv pair for each peer.
The problem does not occur when num_local_experts=1.
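If that is the cause, a possible workaround (only a minimal, untested sketch; it assumes num_slices_per_split is a multiple of g_world_size, and slices_per_peer is a name introduced here purely for illustration) would be to split the exchange into several complete all-to-all rounds, so that each ncclGroupStart/ncclGroupEnd contains at most one send-recv pair per peer:

  // Sketch of a possible restructuring of the loop above (untested):
  // round k issues exactly one send and one recv per peer, instead of
  // putting several pairs for the same peer into a single group.
  int slices_per_peer = num_slices_per_split / g_world_size;
  for (int k = 0; k < slices_per_peer; k++) {
    CHECK_EQ(0, ncclGroupStart());
    for (int p = 0; p < g_world_size; p++) {
      int j = p * slices_per_peer + k;  // same slice index j as in the loop above
      CHECK_EQ(0, ncclSend(
          ((char*)input.data_ptr()) + (j * num_split + calc_idx) * slice_size,
          slice_size,
          ncclInt8,
          p,
          g_nccl_comm,
          get_nccl_stream().stream()));
      CHECK_EQ(0, ncclRecv(
          ((char*)output_list[calc_idx].data_ptr()) + j * slice_size,
          slice_size,
          ncclInt8,
          p,
          g_nccl_comm,
          get_nccl_stream().stream()));
    }
    CHECK_EQ(0, ncclGroupEnd());
  }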

@ghostplant
Contributor

@yzygitzh

@yzygitzh
Member

Hi @Fragile-azalea , thanks for reporting this issue!

Currently I don't have 2 nodes, so I tried a 2-GPU-1-node run instead of a 1-GPU-2-node run, and I didn't see the missing-value phenomenon. Is the issue reproducible with the 2-GPU-1-node setting?
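A 2-GPU-1-node run can be launched roughly like this (a sketch based on your command, assuming both GPUs are on the same host):

SKIP_EXPERT=1 python3 -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 -m tutel.examples.helloworld --batch_size=4 --num_tokens=1 --model_dim=2 --hidden_size=2 --num_steps=1 --a2a_ffn_overlap_degree=2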

BTW, what was your PyTorch version?

@Fragile-azalea
Author

Thank you for your quick response. I don't have a node with two GPUs. Here is information about my platform:
Platform

  • Device: GeForce RTX 2080Ti
  • OS: Linux gpu9 4.4.0-142-generic 168-Ubuntu SMP Wed Jan 16 21:00:45 UTC 2019 x86_64 x86_64 x86_64 GNU/Linux
  • CUDA version: 10.2
  • NCCL version: 2.7.8-1
  • PyTorch version: 1.9.1
  • Python Version: 3.8

@Fragile-azalea
Author

Fragile-azalea commented Jul 19, 2022

To verify my idea, I performed an extra experiment with https://github.com/NVIDIA/nccl-tests/blob/master/src/alltoall.cu.
I replaced https://github.com/NVIDIA/nccl-tests/blob/8274cb47b6dc70ce4411e7f114b77173d3892414/src/alltoall.cu#L71-L76
with

  NCCLCHECK(ncclGroupStart());
  for (int r=0; r<nRanks; r++) {
    NCCLCHECK(ncclSend(((char*)sendbuff)+r*rankOffset, count / 2, type, r, comm, stream));
    NCCLCHECK(ncclRecv(((char*)recvbuff)+r*rankOffset, count / 2, type, r, comm, stream));
  }
  NCCLCHECK(ncclGroupEnd());
  NCCLCHECK(ncclGroupStart());
  for (int r=0; r<nRanks; r++) {
    NCCLCHECK(ncclSend(((char*)sendbuff)+r*rankOffset + rankOffset / 2, count / 2, type, r, comm, stream));
    NCCLCHECK(ncclRecv(((char*)recvbuff)+r*rankOffset + rankOffset / 2, count / 2, type, r, comm, stream));
  }
  NCCLCHECK(ncclGroupEnd());

Run code with:

$ mpirun --hostfile ./servers  ./build/alltoall_perf -b 128M -e 128M -f 2 -g 1

The output:

# nThread 1 nGpus 1 minBytes 134217728 maxBytes 134217728 step: 2(factor) warmup iters: 5 iters: 20 validation: 1
#
# Using devices
#   Rank  0 Pid 1366276 on     host12 device  0 [0x02] GeForce RTX 2070 SUPER
#   Rank  1 Pid 3620583 on     host13 device  0 [0x02] GeForce RTX 2070 SUPER
#
#                                                       out-of-place                       in-place
#       size         count      type   redop     time   algbw   busbw  error     time   algbw   busbw  error
#        (B)    (elements)                       (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)
   134217728      16777216     float           678590    0.20    0.10  0e+00   678399    0.20    0.10  0e+00
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 0.0989085
#

Then I replaced https://github.com/NVIDIA/nccl-tests/blob/8274cb47b6dc70ce4411e7f114b77173d3892414/src/alltoall.cu#L71-L76
with

  NCCLCHECK(ncclGroupStart());
  for (int r=0; r<nRanks; r++) {
    NCCLCHECK(ncclSend(((char*)sendbuff)+r*rankOffset, count / 2, type, r, comm, stream));
    NCCLCHECK(ncclRecv(((char*)recvbuff)+r*rankOffset, count / 2, type, r, comm, stream));
    NCCLCHECK(ncclSend(((char*)sendbuff)+r*rankOffset + rankOffset / 2, count / 2, type, r, comm, stream));
    NCCLCHECK(ncclRecv(((char*)recvbuff)+r*rankOffset + rankOffset / 2, count / 2, type, r, comm, stream));
  }
  NCCLCHECK(ncclGroupEnd());

Run code with:

$ mpirun --hostfile ./servers  ./build/alltoall_perf -b 128M -e 128M -f 2 -g 1

The output:

# nThread 1 nGpus 1 minBytes 134217728 maxBytes 134217728 step: 2(factor) warmup iters: 5 iters: 20 validation: 1
#
# Using devices
#   Rank  0 Pid 1366729 on     host12 device  0 [0x02] GeForce RTX 2070 SUPER
#   Rank  1 Pid 3621476 on     host13 device  0 [0x02] GeForce RTX 2070 SUPER
#
#                                                       out-of-place                       in-place
#       size         count      type   redop     time   algbw   busbw  error     time   algbw   busbw  error
#        (B)    (elements)                       (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)
   134217728      16777216     float           337519    0.40    0.20  1e+00   337973    0.40    0.20  1e+00
# Out of bounds values : 2 FAILED
# Avg bus bandwidth    : 0.198696
#

@yzygitzh
Member

I tried your modification on NCCL 2.7.8 and it caused nccl-tests to crash, but it works well on NCCL 2.10.3. Could you please try upgrading NCCL to 2.10.3?

@Fragile-azalea
Copy link
Author

I recompiled nccl-tests with the following command:
make MPI=1 MPI_HOME=/home/xxx/local/openmpi-4.0.1 NCCL_HOME=/xxx/nccl_2.10.3-1+cuda10.2_x86_64
Here are my logs. They seem unchanged.

# nThread 1 nGpus 1 minBytes 134217728 maxBytes 134217728 step: 2(factor) warmup iters: 5 iters: 20 validation: 1
#
# Using devices
#   Rank  0 Pid 1371961 on     host12 device  0 [0x02] GeForce RTX 2070 SUPER
#   Rank  1 Pid 3905986 on     host13 device  0 [0x02] GeForce RTX 2070 SUPER
#
#                                                       out-of-place                       in-place
#       size         count      type   redop     time   algbw   busbw  error     time   algbw   busbw  error
#        (B)    (elements)                       (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)
   134217728      16777216     float           677072    0.20    0.10  0e+00   678719    0.20    0.10  0e+00
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 0.098996
#

# nThread 1 nGpus 1 minBytes 134217728 maxBytes 134217728 step: 2(factor) warmup iters: 5 iters: 20 validation: 1
#
# Using devices
#   Rank  0 Pid 1371509 on     host12 device  0 [0x02] GeForce RTX 2070 SUPER
#   Rank  1 Pid 3904849 on     host13 device  0 [0x02] GeForce RTX 2070 SUPER
#
#                                                       out-of-place                       in-place
#       size         count      type   redop     time   algbw   busbw  error     time   algbw   busbw  error
#        (B)    (elements)                       (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)
   134217728      16777216     float           338925    0.40    0.20  1e+00   339634    0.40    0.20  1e+00
# Out of bounds values : 2 FAILED
# Avg bus bandwidth    : 0.197799
#
--------------------------------------------------------------------------
Primary job  terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
--------------------------------------------------------------------------
mpirun detected that one or more processes exited with non-zero status, thus causing
the job to be terminated. The first process to do so was:

  Process name: [[6242,1],1]
  Exit code:    1
--------------------------------------------------------------------------

@yzygitzh
Member

yzygitzh commented Jul 20, 2022

Could you please add NCCL_DEBUG=VERSION when running nccl-tests to check the actual NCCL version you're using? Specifying NCCL_HOME at compile time may not change the library that is loaded at runtime.
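With OpenMPI, that could look something like the following (reusing your command; -x exports the variable to the remote ranks):

$ mpirun -x NCCL_DEBUG=VERSION --hostfile ./servers ./build/alltoall_perf -b 128M -e 128M -f 2 -g 1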

BTW, it's strange that both out-of-place and in-place results show no error. NCCL all-to-all should not support in-place operation.

@yzygitzh
Member

yzygitzh commented Jul 20, 2022

Comparison from my side FYI:

Original nccl-tests:

[1,0]<stdout>:#                                                     out-of-place                       in-place
[1,0]<stdout>:#       size         count    type   redop     time   algbw   busbw  error     time   algbw   busbw  error
[1,0]<stdout>:#        (B)    (elements)                     (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)
[1,0]<stdout>:        1024            32   float            28.52    0.04    0.03  0e+00    27.43    0.04    0.03  1e+00
[1,0]<stdout>:        2048            64   float            27.57    0.07    0.07  0e+00    27.54    0.07    0.07  7e-03
[1,0]<stdout>:        4096           128   float            27.79    0.15    0.13  0e+00    27.41    0.15    0.13  1e+00
[1,0]<stdout>:        8192           256   float            27.72    0.30    0.26  0e+00    27.51    0.30    0.26  1e+00
[1,0]<stdout>:       16384           512   float            28.14    0.58    0.51  0e+00    27.68    0.59    0.52  1e+00
[1,0]<stdout>:       32768          1024   float            29.81    1.10    0.96  0e+00    29.58    1.11    0.97  1e+00
[1,0]<stdout>:       65536          2048   float            29.80    2.20    1.92  0e+00    29.38    2.23    1.95  1e+00
[1,0]<stdout>:      131072          4096   float            29.65    4.42    3.87  0e+00    29.09    4.51    3.94  1e+00
[1,0]<stdout>:      262144          8192   float            30.16    8.69    7.60  0e+00    30.66    8.55    7.48  1e+00
[1,0]<stdout>:      524288         16384   float            33.83   15.50   13.56  0e+00    32.78   15.99   13.99  0e+00
[1,0]<stdout>:     1048576         32768   float            40.29   26.02   22.77  0e+00    40.60   25.83   22.60  0e+00
[1,0]<stdout>:     2097152         65536   float            51.11   41.03   35.90  0e+00    51.78   40.50   35.44  0e+00
[1,0]<stdout>:     4194304        131072   float            69.77   60.12   52.60  0e+00    69.61   60.26   52.73  0e+00
[1,0]<stdout>:     8388608        262144   float            110.0   76.23   66.70  0e+00    106.1   79.05   69.17  0e+00
[1,0]<stdout>:    16777216        524288   float            191.4   87.63   76.68  0e+00    186.8   89.80   78.57  0e+00
[1,0]<stdout>:    33554432       1048576   float            218.5  153.57  134.37  0e+00    215.9  155.42  136.00  1e+00
[1,0]<stdout>:    67108864       2097152   float            318.1  210.98  184.61  0e+00    317.6  211.30  184.89  1e+00
[1,0]<stdout>:   134217728       4194304   float            585.6  229.21  200.56  0e+00    598.9  224.09  196.08  1e+00
[1,0]<stdout>:   268435456       8388608   float           1141.8  235.11  205.72  0e+00   1170.4  229.35  200.68  1e+00
[1,0]<stdout>:   536870912      16777216   float           2205.1  243.46  213.03  0e+00   2160.7  248.47  217.41  1e+00
[1,0]<stdout>:  1073741824      33554432   float           4251.5  252.56  220.99  0e+00   4126.4  260.21  227.68  1e+00
[1,0]<stdout>:  2147483648      67108864   float           8364.4  256.74  224.65  0e+00   8081.1  265.74  232.52  1e+00
[1,0]<stdout>:  4294967296     134217728   float            16630  258.27  225.99  0e+00    16013  268.22  234.69  1e+00
[1,0]<stdout>:  8589934592     268435456   float            33186  258.84  226.48  0e+00    31877  269.47  235.79  1e+00

Modified nccl-tests:

[1,0]<stdout>:#                                                       out-of-place                       in-place
[1,0]<stdout>:#       size         count      type   redop     time   algbw   busbw  error     time   algbw   busbw  error
[1,0]<stdout>:#        (B)    (elements)                       (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)
[1,0]<stdout>:        1024            32     float            43.59    0.02    0.02  0e+00    43.32    0.02    0.02  1e+00
[1,0]<stdout>:        2048            64     float            43.75    0.05    0.04  0e+00    45.71    0.04    0.04  1e+00
[1,0]<stdout>:        4096           128     float            45.70    0.09    0.08  0e+00    45.45    0.09    0.08  1e+00
[1,0]<stdout>:        8192           256     float            45.21    0.18    0.16  0e+00    44.55    0.18    0.16  1e+00
[1,0]<stdout>:       16384           512     float            44.84    0.37    0.32  0e+00    44.48    0.37    0.32  1e+00
[1,0]<stdout>:       32768          1024     float            45.16    0.73    0.63  0e+00    44.57    0.74    0.64  1e+00
[1,0]<stdout>:       65536          2048     float            44.33    1.48    1.29  0e+00    44.77    1.46    1.28  1e+00
[1,0]<stdout>:      131072          4096     float            45.12    2.91    2.54  0e+00    45.15    2.90    2.54  1e+00
[1,0]<stdout>:      262144          8192     float            45.62    5.75    5.03  0e+00    45.23    5.80    5.07  1e+00
[1,0]<stdout>:      524288         16384     float            48.53   10.80    9.45  0e+00    47.85   10.96    9.59  1e+00
[1,0]<stdout>:     1048576         32768     float            55.42   18.92   16.55  0e+00    55.21   18.99   16.62  1e+00
[1,0]<stdout>:     2097152         65536     float            68.69   30.53   26.71  0e+00    68.27   30.72   26.88  1e+00
[1,0]<stdout>:     4194304        131072     float            92.01   45.58   39.89  0e+00    90.83   46.18   40.40  1e+00
[1,0]<stdout>:     8388608        262144     float            136.5   61.48   53.79  0e+00    134.9   62.17   54.39  1e+00
[1,0]<stdout>:    16777216        524288     float            166.4  100.80   88.20  0e+00    164.7  101.87   89.14  1e+00
[1,0]<stdout>:    33554432       1048576     float            226.0  148.45  129.90  0e+00    226.4  148.18  129.66  1e+00
[1,0]<stdout>:    67108864       2097152     float            360.0  186.40  163.10  0e+00    361.0  185.87  162.64  1e+00
[1,0]<stdout>:   134217728       4194304     float            644.8  208.17  182.15  0e+00    646.1  207.73  181.77  1e+00
[1,0]<stdout>:   268435456       8388608     float           1198.7  223.94  195.95  0e+00   1194.5  224.73  196.64  1e+00
[1,0]<stdout>:   536870912      16777216     float           2316.6  231.75  202.78  0e+00   2318.0  231.61  202.66  1e+00
[1,0]<stdout>:  1073741824      33554432     float           4549.4  236.02  206.52  0e+00   4537.5  236.64  207.06  1e+00
[1,0]<stdout>:  2147483648      67108864     float           9008.3  238.39  208.59  0e+00   9003.4  238.52  208.70  1e+00
[1,0]<stdout>:  4294967296     134217728     float            17874  240.29  210.26  0e+00    17885  240.15  210.13  1e+00
[1,0]<stdout>:  8589934592     268435456     float            35562  241.55  211.36  0e+00    35564  241.53  211.34  1e+00

Overall all-to-all latency in the latter case is slightly higher due to the smaller packet size and the larger number of P2P operations.

@Fragile-azalea
Author

After setting LD_LIBRARY_PATH=/xxx/nccl_2.10.3-1+cuda10.2_x86_64/lib, it works now.

# nThread 1 nGpus 1 minBytes 134217728 maxBytes 134217728 step: 2(factor) warmup iters: 5 iters: 20 validation: 1
#
# Using devices
#   Rank  0 Pid 1372214 on     host12 device  0 [0x02] GeForce RTX 2070 SUPER
#   Rank  1 Pid 3915732 on     host13 device  0 [0x02] GeForce RTX 2070 SUPER
NCCL version 2.10.3+cuda10.2
#
#                                                       out-of-place                       in-place
#       size         count      type   redop     time   algbw   busbw  error     time   algbw   busbw  error
#        (B)    (elements)                       (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)
   134217728      16777216     float           675213    0.20    0.10  0e+00   674851    0.20    0.10  0e+00
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 0.0994159
#
# nThread 1 nGpus 1 minBytes 134217728 maxBytes 134217728 step: 2(factor) warmup iters: 5 iters: 20 validation: 1
#
# Using devices
#   Rank  0 Pid 1372636 on     host12 device  0 [0x02] GeForce RTX 2070 SUPER
#   Rank  1 Pid 3916798 on     host13 device  0 [0x02] GeForce RTX 2070 SUPER
NCCL version 2.10.3+cuda10.2
#
#                                                       out-of-place                       in-place
#       size         count      type   redop     time   algbw   busbw  error     time   algbw   busbw  error
#        (B)    (elements)                       (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)
   134217728      16777216     float           677196    0.20    0.10  0e+00   677425    0.20    0.10  0e+00
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 0.0990814
#

@ghostplant added the wontfix (This will not be worked on) label on Jul 20, 2022
@ghostplant
Contributor

Thanks @yzygitzh. This seems to be an old NCCL issue. I'll close it since it is resolved by upgrading NCCL.

@Fragile-azalea
Author

After setting LD_LIBRARY_PATH=/xxx/nccl_2.10.3-1+cuda10.2_x86_64/lib, Tutel also works now.
Here is my log.

[Statistics] param count for MoE local_experts = 24, param count for MoE gate = 8.

ExampleModel(
  (_moe_layer): MOELayer(
    Top-K(s) = ['k=2, noise=0.0'], Total-Experts = 4 [managed by 2 device(s)],
    (experts): FusedExpertsNetwork(model_dim=2, hidden_size=2, output_dim=2, local_experts=2)
    (gates): ModuleList(
      (0): LinearTopKGate(
        (wg): Linear(in_features=2, out_features=4, bias=False)
      )
    )
  )
)
[Benchmark] world_size = 2, dtype = float32, model_dim = 2, hidden_size = 2, samples = 4, num_local_experts = 2, topK = 2, a2a_ffn_overlap_degree = 2, parallel_type = `auto`, device = `cuda:0`
tensor([[[ 1.5410, -0.2934],
         [-1.0845, -1.3986]],

        [[ 1.5410, -0.2934],
         [ 0.4033,  0.8380]],

        [[-2.1788,  0.5684],
         [-1.0845, -1.3986]],

        [[ 0.4033,  0.8380],
         [-2.1788,  0.5684]]], device='cuda:0')
NCCL version 2.7.8+cuda10.2
NCCL version 2.10.3+cuda10.2
tensor([[[ 1.5410, -0.2934],
         [-1.0845, -1.3986]],

        [[ 1.5410, -0.2934],
         [ 0.4033,  0.8380]],

        [[-2.1788,  0.5684],
         [-1.0845, -1.3986]],

        [[ 0.4033,  0.8380],
         [-2.1788,  0.5684]]], device='cuda:0')
STEP-0: loss = 0.00000, step_time = 3.229446 sec, perf = 0.00 tflops.

It's odd that the log contains both NCCL version 2.7.8+cuda10.2 and NCCL version 2.10.3+cuda10.2, but it works well (presumably the 2.7.8 line comes from the NCCL bundled with PyTorch for torch.distributed, while Tutel's extension picks up 2.10.3 via LD_LIBRARY_PATH).
