No overlapping observed when enabling Smart Scheduling #168

Open
chenyu-jiang opened this issue Aug 4, 2023 · 8 comments

@chenyu-jiang

Describe the bug
I am trying to create a minimal runnable example of the Smart Scheduling proposed in the FasterMoE paper. However, when I profile the example with Nsight Systems, there appears to be no overlap between the all-to-all communication and the expert computation.

Example of the profile result (one of the forward passes):
[screenshot: Nsight Systems timeline of one forward pass]

Looking at the CUDA API stack trace, it does appear to be running the smart-scheduling code path:
[screenshot: CUDA API stack trace]

The code I used can be found below. Could you let me know whether this is caused by my misuse of FastMoE or by something else? Thanks.

To Reproduce
The test is done on 2 nodes, each with 8 V100 GPUs.
The code I used for the tests (example.py):

import torch
import os

from fmoe import DistributedGroupedDataParallel as fmoeDDP
from fmoe.transformer import FMoETransformerMLP
from fmoe.gates import SwitchGate

class DummyMoEModel(torch.nn.Module):
    def __init__(self, world_size):
        super().__init__()
        self.non_moe = torch.nn.Sequential(
            torch.nn.Linear(1024, 8192),
            torch.nn.ReLU(),
            torch.nn.Linear(8192, 1024))
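        # MoE FFN: one local expert per rank, sharded across world_size GPUs.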
        self.moe = FMoETransformerMLP(
            num_expert=1,
            world_size=world_size,
            d_model=1024,
            d_hidden=4096,
            top_k=1,
        )

    def forward(self, inp):
        torch.cuda.nvtx.range_push("Non-MoE")
        out = self.non_moe(inp)
        torch.cuda.nvtx.range_pop()
        torch.cuda.nvtx.range_push("FMoETransformerMLP")
        out = self.moe(out)
        torch.cuda.nvtx.range_pop()
        return torch.sum(out)

if __name__ == "__main__":
    torch.distributed.init_process_group(backend="nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = torch.distributed.get_world_size()
    torch.cuda.set_device(local_rank)
    model = DummyMoEModel(world_size).to(f"cuda:{local_rank}")
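    # Wrap the model with FastMoE's distributed data-parallel wrapper.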
    model = fmoeDDP(model)
    opt = torch.optim.SGD(model.parameters(), lr=0.01)

    for i in range(20):
        inp = torch.randn(8192, 1024).to(f"cuda:{local_rank}")
        opt.zero_grad()
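        # Limit profiling to iterations 10-15 via the CUDA profiler start/stop API.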
        if i == 10:
            torch.cuda.cudart().cudaProfilerStart()
        out = model(inp)
        if i == 15:
            torch.cuda.cudart().cudaProfilerStop()
        out.backward()

Steps to reproduce the behavior:

  1. Set up the Docker environment using the image pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel and install FastMoE.
  2. Run the above code with the command:
    FMOE_FASTER_SCHEDULE_ENABLE=1 torchrun --nnodes=2 --nproc-per-node=8 --rdzv-id=0 --rdzv-backend=c10d --rdzv-endpoint=xxx.xxx.xx.xx example.py

Expected behavior
Overlap between expert computation and all-to-all communication.

Logs
N/A

Platform

  • Device: NVIDIA V100
  • OS: Ubuntu 20.04.5 LTS
  • CUDA version: 11.7
  • NCCL version: 2.14.3-1
  • PyTorch version: 2.0.1

Additional context
N/A

@zms1999
Collaborator

zms1999 commented Aug 4, 2023

I will check it out. However, it looks like you forgot to set the FMOE_FASTER_GROUP_SIZE environment variable.

@chenyu-jiang
Author

Thanks for the fast reply! I tried setting FMOE_FASTER_GROUP_SIZE=4, but I am still not seeing any overlap:
[screenshot: Nsight Systems timeline, still no overlap]
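
(For reference, this just means adding the variable to the launch command above, e.g.:)

FMOE_FASTER_SCHEDULE_ENABLE=1 FMOE_FASTER_GROUP_SIZE=4 torchrun --nnodes=2 --nproc-per-node=8 --rdzv-id=0 --rdzv-backend=c10d --rdzv-endpoint=xxx.xxx.xx.xx example.py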

@laekov
Owner

laekov commented Aug 17, 2023

This issue turns out to be caused by using the default CUDA stream, which synchronizes with all other streams. Simply using another stream in smgr for NCCL solves the problem. Credit to @Harry-Chen for spotting this. Looking forward to a pull request.
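
To illustrate the idea (a generic PyTorch sketch of the pattern, not FastMoE's actual smgr code; the function and buffer names are hypothetical): issuing the all-to-all on a dedicated non-default stream lets expert kernels launched on the compute stream overlap with it, as long as the two streams are ordered with explicit events rather than through the default stream's implicit synchronization.

import torch
import torch.distributed as dist

comm_stream = torch.cuda.Stream()

def all_to_all_on_side_stream(send_buf, recv_buf):
    # Order the comm stream after the work that produced send_buf.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        dist.all_to_all_single(recv_buf, send_buf)
        # Tell the caching allocator these buffers are used on comm_stream.
        send_buf.record_stream(comm_stream)
        recv_buf.record_stream(comm_stream)
    done = torch.cuda.Event()
    done.record(comm_stream)
    return done

# Expert kernels launched on the compute stream after this call can overlap
# with the all-to-all; call torch.cuda.current_stream().wait_event(done)
# before consuming recv_buf.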

@zms1999
Collaborator

zms1999 commented Aug 25, 2023

Hi @chenyu-jiang, I finally found some bugs. I've fixed them in this branch; could you re-trace your program on it?

@chenyu-jiang
Author

Hi @zms1999, extremely sorry for the (very) delayed response. With the fix, I can now see overlap in the example program. Thanks a lot for the fix! It is tremendously helpful.

@chenyu-jiang
Author

Sorry for bothering again, but I am still running into problems when running the above example code with SwitchGate (i.e., adding gate=SwitchGate when initializing FMoETransformerMLP).
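
For clarity, the only change relative to the example above is the extra gate argument:

        self.moe = FMoETransformerMLP(
            num_expert=1,
            world_size=world_size,
            d_model=1024,
            d_hidden=4096,
            top_k=1,
            gate=SwitchGate,  # the added argument
        )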

The error message is:

Traceback (most recent call last):
  File "example.py", line 47, in <module>
    out = model(inp)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/fastmoe-1.0.2-py3.8-linux-x86_64.egg/fmoe/distributed.py", line 114, in forward
    return self.module(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "example.py", line 29, in forward
    out = self.moe(out)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/fastmoe-1.0.2-py3.8-linux-x86_64.egg/fmoe/transformer.py", line 65, in forward
    output = super().forward(inp)
  File "/root/.local/lib/python3.8/site-packages/fastmoe-1.0.2-py3.8-linux-x86_64.egg/fmoe/layers.py", line 228, in forward
    gate_top_k_idx, gate_score = self.gate(moe_inp)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/fastmoe-1.0.2-py3.8-linux-x86_64.egg/fmoe/gates/switch_gate.py", line 49, in forward
    valid_idx = top1_idx[top1_idx > -1]
RuntimeError: CUDA error: an illegal memory access was encountered

However, if the code is run with CUDA_LAUNCH_BLOCKING=1 (which makes kernel launches synchronous), the error goes away. Could there still be some issue with synchronization?

chenyu-jiang reopened this Sep 10, 2023
@zms1999
Collaborator

zms1999 commented Sep 10, 2023

I guess you're right, since I've already fixed some synchronization bugs. There could be more; I will check next week.

@laekov
Owner

laekov commented Sep 11, 2023

The switch gate problem seems to be caused by using the old, problematic stream manager in the expert counting and balancing kernels. I put the torch stream into smgr and replaced the smgr streams in the other places in PR #173. @zms1999, can you please have a look?
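
As a rough illustration of that failure mode (hypothetical names and a generic counting kernel, not FastMoE's actual internals): if expert counting runs on an internal side stream, the torch stream must explicitly wait on that stream before the gate reads the result; a missing wait is exactly the kind of race that disappears under CUDA_LAUNCH_BLOCKING=1.

import torch

side_stream = torch.cuda.Stream()

def count_experts(top1_idx, num_expert):
    counts = torch.zeros(num_expert, dtype=torch.long, device=top1_idx.device)
    # Order the side stream after the work that produced top1_idx and counts.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        valid_idx = top1_idx[top1_idx > -1]
        counts.scatter_add_(0, valid_idx, torch.ones_like(valid_idx))
        # Tell the caching allocator counts is used on the side stream.
        counts.record_stream(side_stream)
    # Without this wait, the gate could read `counts` while the side-stream
    # kernels are still writing it; launching the kernels on the torch stream
    # itself (as the PR describes) removes the hazard entirely.
    torch.cuda.current_stream().wait_stream(side_stream)
    return counts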
