From a2e3abf3276d837103bfe5affb94ab4505d5f736 Mon Sep 17 00:00:00 2001 From: lvyufeng Date: Tue, 19 Aug 2025 17:58:02 +0800 Subject: [PATCH] fix bfloat16 not found for numpy op --- .../transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.py | 4 ++-- mindnlp/core/_dtype.py | 10 +++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.py b/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.py index 614aaff41..7711cf1c2 100644 --- a/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.py +++ b/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.py @@ -64,11 +64,11 @@ def process_func(example): # 配置训练参数 args = TrainingArguments( output_dir="./output/Qwen2.5_instruct_lora", - per_device_train_batch_size=3, + per_device_train_batch_size=2, gradient_accumulation_steps=5, logging_steps=10, num_train_epochs=3, - save_steps=100, + save_steps=100, learning_rate=1e-4, save_on_each_node=True, # gradient_checkpointing=True diff --git a/mindnlp/core/_dtype.py b/mindnlp/core/_dtype.py index 66a4535ae..d60f3cdef 100644 --- a/mindnlp/core/_dtype.py +++ b/mindnlp/core/_dtype.py @@ -4,7 +4,12 @@ from mindspore._c_expression import typing from mindspore._c_expression.typing import Type -from .configs import ON_A1 +from .configs import ON_A1, SUPPORT_BF16 + +if SUPPORT_BF16: + from mindspore.common.np_dtype import bfloat16 as np_bfloat16  # pylint: disable=import-error +else: + from ml_dtypes import bfloat16 as np_bfloat16 bool_alias = bool @@ -107,6 +112,9 @@ def __gt__(self, other): float64 : np.float64, } +if not ON_A1: + dtype2np[bfloat16] = np_bfloat16 + py2dtype = { bool_alias: bool }