
Conversation

@S1ro1 (Contributor) commented on Sep 22, 2025

```python
import torch
import torch.nn as nn
from torch.distributed._functional_collectives import (
    all_to_all_single,
    all_to_all_single_autograd,
)
from torch.distributed.tensor import DeviceMesh, Shard, distribute_module, distribute_tensor
from torch.distributed.tensor.parallel import ParallelStyle


class Ep2DpParallel(ParallelStyle):
    def __init__(self):
        super().__init__()
        self._num_tokens_to_send = None
        self._num_tokens_to_recv = None
        self._reshuffle_indices = None
        self._reshuffled_counts = None

    def _token_dispatch(self, mod, inputs, device_mesh):
        routed_input, num_tokens_per_expert = inputs
        ep_size = device_mesh.shape[0]

        # num_tokens_per_expert has shape (num_experts,); each element is the number of tokens
        # the local rank routes to the corresponding (global) expert
        with torch.no_grad():
            # all-to-all num_tokens_per_expert along the EP axis of the device mesh to learn how many
            # tokens each rank will send to our local experts; think of the all-to-all as a transpose
            # over the device mesh
            # grouped_tokens_per_rank has shape (ep_size * num_experts_per_rank,), laid out as:
            # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
            #  #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
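            # e.g. with ep_size=2 and 2 experts per rank (rank 0 owns experts 0-1, rank 1 owns experts 2-3):
            #   rank 0: num_tokens_per_expert = [3, 1, 4, 2] -> sends [3, 1] to rank 0 and [4, 2] to rank 1
            #   rank 1: num_tokens_per_expert = [5, 0, 2, 6] -> sends [5, 0] to rank 0 and [2, 6] to rank 1
            #   rank 0 receives grouped_tokens_per_rank = [3, 1, 5, 0]; rank 1 receives [4, 2, 2, 6]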
            grouped_tokens_per_rank = all_to_all_single(
                num_tokens_per_expert,
                None,
                None,
                group=device_mesh.get_group(),
            )

            # shape (ep_size,): [#tokens we send to rank 0, #tokens we send to rank 1, ...]
            num_tokens_to_send = (
                num_tokens_per_expert.view(ep_size, -1)
                .sum(dim=1)
                .to(torch.device("cpu"), non_blocking=True)
            )

            # shape (ep_size,): [#tokens we receive from rank 0, #tokens we receive from rank 1, ...]
            num_tokens_to_recv = (
                grouped_tokens_per_rank.view(ep_size, -1)
                .sum(dim=1)
                .to(torch.device("cpu"), non_blocking=False)
            )
            self._num_tokens_to_send = num_tokens_to_send.tolist()
            self._num_tokens_to_recv = num_tokens_to_recv.tolist()

        # perform all-to-all to send the tokens to the right ranks
        routed_input = all_to_all_single_autograd(
            routed_input,
            self._num_tokens_to_recv,
            self._num_tokens_to_send,
            device_mesh.get_group(),
        )

        # routed_input is no longer sorted by expert; it is grouped by source rank:
        # [tokens for local experts 0..n from EP rank 0, tokens for local experts 0..n from EP rank 1, ...]
        # it still needs to be reshuffled so that all tokens for the same local expert are contiguous;
        # grouped_tokens_per_rank (layout described above) carries the counts needed for that reshuffle
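        # continuing the ep_size=2 example above: EP rank 0 receives 3+1=4 tokens from rank 0 followed by
        # 5+0=5 tokens from rank 1; grouping by local expert would instead put the 3+5=8 tokens for local
        # expert 0 first and the 1+0=1 token for local expert 1 after them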
        return routed_input, grouped_tokens_per_rank

    @staticmethod
    def _partition_fn(name, mod, device_mesh):
        # shard on the expert dimension
        for name, param in mod.named_parameters(recurse=False):
            dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
            mod.register_parameter(name, dist_param)

    def _token_combine(self, mod, routed_output, device_mesh):
        # reverse all-to-all from dispatch
        routed_output = all_to_all_single_autograd(
            routed_output,
            self._num_tokens_to_send,
            self._num_tokens_to_recv,
            device_mesh.get_group(),
        )
        return routed_output

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            partition_fn=Ep2DpParallel._partition_fn,
            input_fn=self._token_dispatch,
            output_fn=self._token_combine,
        )


        
```
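For readers following along: a `ParallelStyle` like this is normally attached with `torch.distributed.tensor.parallel.parallelize_module`. Below is a minimal sketch (not part of this PR) of how the style might be wired up, assuming a hypothetical model whose decoder blocks expose their grouped experts as `block.moe.experts` and a 1D expert-parallel mesh spanning all ranks; `model`, `block.moe`, and the `"experts"` name are illustrative only.

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

# assumes the default process group is already initialized;
# a real setup would likely carve the EP dimension out of a larger DP x EP mesh
ep_mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("ep",))

# `model` is the hypothetical MoE transformer being parallelized;
# "experts" is the assumed name of the grouped-experts submodule inside each MoE block
for block in model.layers:
    parallelize_module(block.moe, ep_mesh, {"experts": Ep2DpParallel()})
```

`parallelize_module` resolves the `"experts"` key to the submodule and calls `Ep2DpParallel._apply`, which uses `distribute_module` to shard the expert weights on dim 0 and to register `_token_dispatch` / `_token_combine` as the module's input and output hooks.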

@Liuweixiong0118 commented:
Hello~ Recently I have been fine-tuning Qwen3 MoE with LLaMA-Factory. I found that training Qwen3 MoE is very slow and GPU utilization is very low: training on the same data with 30B-A3B takes five times longer than with the 32B dense model.

Is this PR meant to solve that problem? Are the current changes usable already, and when will they be merged?
Thanks~
