Skip to content

thd format is not supported with hierarchical CP implementation #6001

@stormchasingg

Description

@stormchasingg

Describe the bug
What the bug is and how to reproduce it, ideally with screenshots (描述bug以及复现过程,最好有截图)
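After adding the parameters `"hierarchical_context_parallel_sizes": [2, 2]` and `"cp_comm_type": "a2a+p2p"`, training fails with the error below. The assertion is raised inside Transformer Engine's `AttnFuncWithCPAndKVP2P.forward`, which rejects the `thd` QKV format when hierarchical context parallelism is enabled: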

[rank0]: Traceback (most recent call last):
[rank0]:   File "/ms-swift/swift/cli/_megatron/sft.py", line 5, in <module>
[rank0]:     megatron_sft_main()
[rank0]:   File "/ms-swift/swift/megatron/train/sft.py", line 79, in megatron_sft_main
[rank0]:     return MegatronSft(args).main()
[rank0]:   File "/ms-swift/swift/llm/base.py", line 49, in main
[rank0]:     result = self.run()
[rank0]:   File "/ms-swift/swift/megatron/train/sft.py", line 69, in run
[rank0]:     self.trainer.train(train_dataset, val_dataset, data_collator)
[rank0]:   File "/ms-swift/swift/megatron/trainers/base.py", line 749, in train
[rank0]:     pretrain(
[rank0]:   File "/Megatron-LM/megatron/training/training.py", line 864, in pretrain
[rank0]:     iteration, num_floating_point_operations_so_far = train(
[rank0]:   File "/Megatron-LM/megatron/training/training.py", line 2279, in train
[rank0]:     ) = train_step(
[rank0]:   File "/ms-swift/swift/megatron/trainers/base.py", line 302, in train_step
[rank0]:     return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler,
[rank0]:   File "/Megatron-LM/megatron/training/training.py", line 1395, in train_step
[rank0]:     losses_reduced = forward_backward_func(
[rank0]:   File "/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 1889, in forward_backward_pipelining_without_interleaving
[rank0]:     output_tensor, num_tokens = forward_step(
[rank0]:   File "/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 289, in forward_step
[rank0]:     output_tensor, loss_func = forward_step_func(data_iterator, model)
[rank0]:   File "/ms-swift/swift/megatron/trainers/trainer.py", line 116, in forward_step
[rank0]:     output_tensor = model(**data)
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/Megatron-LM/megatron/core/distributed/data_parallel_base.py", line 22, in forward
[rank0]:     return self.module(*inputs, **kwargs)
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/Megatron-LM/megatron/core/transformer/module.py", line 237, in forward
[rank0]:     outputs = self.module(*inputs, **kwargs)
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/ms-swift/swift/megatron/model/gpt_model.py", line 196, in forward
[rank0]:     hidden_states = self.decoder(
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/Megatron-LM/megatron/core/transformer/transformer_block.py", line 563, in forward
[rank0]:     hidden_states = self._checkpointed_forward(
[rank0]:   File "/Megatron-LM/megatron/core/transformer/transformer_block.py", line 454, in _checkpointed_forward
[rank0]:     hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1))
[rank0]:   File "/Megatron-LM/megatron/core/transformer/transformer_block.py", line 417, in checkpoint_handler
[rank0]:     return tensor_parallel.checkpoint(
[rank0]:   File "/Megatron-LM/megatron/core/tensor_parallel/random.py", line 477, in checkpoint
[rank0]:     return CheckpointFunction.apply(function, distribute_saved_activations, *args)
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:   File "/Megatron-LM/megatron/core/tensor_parallel/random.py", line 423, in forward
[rank0]:     outputs = run_function(*args)
[rank0]:   File "/Megatron-LM/megatron/core/transformer/transformer_block.py", line 388, in custom_forward
[rank0]:     hidden_states, context = layer(
[rank0]:   File "/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 875, in __call__
[rank0]:     return super(MegatronModule, self).__call__(*args, **kwargs)
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/ms-swift/swift/megatron/init.py", line 434, in forward
[rank0]:     hidden_states, context = self._forward_attention(*_args, **kwargs)
[rank0]:   File "/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 501, in _forward_attention
[rank0]:     attention_output_with_bias = self.self_attention(
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/Megatron-LM/megatron/core/transformer/attention.py", line 737, in forward
[rank0]:     core_attn_out = self.core_attention(
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/Megatron-LM/megatron/core/extensions/transformer_engine.py", line 931, in forward
[rank0]:     core_attn_out = super().forward(
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py", line 1026, in forward
[rank0]:     return self.flash_attention(
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/transformer_engine/pytorch/attention/dot_product_attention/backends.py", line 659, in forward
[rank0]:     output = attn_forward_func_with_cp(
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py", line 3619, in attn_forward_func_with_cp
[rank0]:     out = AttnFuncWithCPAndKVP2P.apply(*args)
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:   File "/opt/conda/lib/python3.10/site-packages/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py", line 469, in forward
[rank0]:     qkv_format != "thd"
[rank0]: AssertionError: thd format is not supported with hierarchical CP implementation yet!

Your hardware and system info
Write your system info, such as CUDA version, OS, GPU model, and torch version, here (在这里给出硬件信息和系统信息,如CUDA版本,系统,GPU型号和torch版本等)
platform: H800
pytorch: 2.7
megatron-lm: branch core_r0.13.0
transformer_engine: 2.4.0

Additional context
Add any other context about the problem here (在这里补充其他信息)

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions