FSDP2 LoRA training fails with "FSDP expects uniform original parameter dtype but got {torch.bfloat16, torch.float32}" #182

@0hujun

Checklist

  • I have searched existing issues, and this is a new bug report.

Bug Description

I pulled the master branch of twinkle, modified the NPU-related parameters, and then ran cd /home/ma-user/twinkle/cookbook/transformers && sh fsdp2_moe.sh. Training fails with the following error:

[rank2]: Traceback (most recent call last):
[rank2]:   File "/home/ma-user/twinkle/cookbook/transformers/fsdp2_moe.py", line 89, in <module>
[rank2]:     train()
[rank2]:   File "/home/ma-user/twinkle/cookbook/transformers/fsdp2_moe.py", line 71, in train
[rank2]:     model.forward_backward(inputs=batch)
[rank2]:   File "/home/ma-user/twinkle/src/twinkle/infra/__init__.py", line 647, in wrapper
[rank2]:     return func(self, *args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/ma-user/twinkle/src/twinkle/model/transformers/transformers.py", line 553, in forward_backward
[rank2]:     outputs = self.forward(inputs=inputs, **kwargs)
[rank2]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/ma-user/twinkle/src/twinkle/infra/__init__.py", line 647, in wrapper
[rank2]:     return func(self, *args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/ma-user/twinkle/src/twinkle/model/transformers/transformers.py", line 373, in forward
[rank2]:     outputs = self.model(**inputs)
[rank2]:               ^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/ma-user/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/ma-user/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1857, in _call_impl
[rank2]:     return inner()
[rank2]:            ^^^^^^^
[rank2]:   File "/home/ma-user/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in inner
[rank2]:     args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
[rank2]:                          ^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/ma-user/.local/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 62, in fsdp_hook_wrapper
[rank2]:     return torch._dynamo.disable(func, recursive=True)(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/ma-user/.local/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
[rank2]:     return fn(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/ma-user/.local/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 222, in _pre_forward
[rank2]:     args, kwargs = self._root_pre_forward(module, args, kwargs)
[rank2]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/ma-user/.local/lib/python3.11/site-packages/torch_npu/distributed/fsdp/_add_fsdp_patch.py", line 95, in _patched_root_pre_forward
[rank2]:     self._lazy_init()
[rank2]:   File "/home/ma-user/.local/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 178, in _lazy_init
[rank2]:     state._fsdp_param_group.lazy_init()
[rank2]:   File "/home/ma-user/.local/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 248, in lazy_init
[rank2]:     self._init_mp_dtypes()
[rank2]:   File "/home/ma-user/.local/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 218, in _init_mp_dtypes
[rank2]:     raise AssertionError(
[rank2]: AssertionError: FSDP expects uniform original parameter dtype but got {torch.bfloat16, torch.float32}
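
The assertion is raised in _init_mp_dtypes during FSDP2 lazy init, and the set in the message suggests that at least one FSDP parameter group holds both bfloat16 and float32 parameters. One possible cause, not confirmed for this setup, is that the base weights were loaded in bfloat16 while the LoRA adapter weights were created in float32. A minimal diagnostic sketch for finding the offending parameters is below. The helper name group_params_by_dtype and the parameter names in fake_params are hypothetical; the stand-in objects only exist so the sketch runs without PyTorch.

```python
from collections import defaultdict
from types import SimpleNamespace


def group_params_by_dtype(named_params):
    """Group parameter names by their dtype.

    `named_params` is any iterable of (name, param) pairs where `param`
    exposes a `.dtype` attribute, for example `model.named_parameters()`
    on a torch.nn.Module.
    """
    groups = defaultdict(list)
    for name, param in named_params:
        groups[str(param.dtype)].append(name)
    return dict(groups)


# Stand-in parameters (made-up names) so the sketch runs without PyTorch.
fake_params = [
    ("layers.0.q_proj.base_layer.weight", SimpleNamespace(dtype="torch.bfloat16")),
    ("layers.0.q_proj.lora_A.default.weight", SimpleNamespace(dtype="torch.float32")),
    ("layers.0.q_proj.lora_B.default.weight", SimpleNamespace(dtype="torch.float32")),
]

groups = group_params_by_dtype(fake_params)
for dtype, names in sorted(groups.items()):
    print(f"{dtype}: {len(names)} parameter(s), e.g. {names[0]}")
```

In the real script this would be called as group_params_by_dtype(model.named_parameters()) before the model is sharded. Because the check in the traceback runs per FSDP parameter group, running the helper on each wrapped submodule separately narrows the result further than a whole-model pass. If the float32 entries turn out to be adapter weights, making the dtypes within each group uniform before sharding is the usual direction to investigate.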

How to Reproduce

Pull the master branch of twinkle, modify the NPU-related parameters, then run cd /home/ma-user/twinkle/cookbook/transformers && sh fsdp2_moe.sh

Additional Information

No response


Labels: bug
