From 08092778ca474623d870ad9e646dce34664a294c Mon Sep 17 00:00:00 2001
From: lvyufeng
Date: Tue, 19 Aug 2025 23:20:45 +0800
Subject: [PATCH] sdpa: use flash attention for 910B

---
 .../peft/lora/Qwen2.5-7B-Instruct-Lora.ipynb |  2 +-
 .../peft/lora/Qwen2.5-7B-Instruct-Lora.py    |  2 +-
 mindnlp/core/_prims/ascend.py                |  4 +++
 mindnlp/core/configs.py                      |  1 +
 mindnlp/core/dispatcher.py                   |  2 +-
 mindnlp/core/executor.py                     |  3 +++
 mindnlp/core/nn/functional.py                | 26 ++++++++++++++-----
 7 files changed, 31 insertions(+), 9 deletions(-)

diff --git a/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.ipynb b/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.ipynb
index 359f7e9d4..3845a06da 100644
--- a/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.ipynb
+++ b/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.ipynb
@@ -336,7 +336,7 @@
     "from peft import PeftModel\n",
     "\n",
     "mode_path = 'Qwen/Qwen2.5-7B-Instruct'\n",
-    "lora_path = './output/Qwen2.5_instruct_lora/checkpoint-747' # change this to the checkpoint path of your LoRA output\n",
+    "lora_path = './output/Qwen2.5_instruct_lora/checkpoint-1100' # change this to the checkpoint path of your LoRA output\n",
     "\n",
     "# Load the tokenizer\n",
     "tokenizer = AutoTokenizer.from_pretrained(mode_path, trust_remote_code=True)\n",
diff --git a/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.py b/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.py
index 7711cf1c2..6826f83ff 100644
--- a/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.py
+++ b/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.py
@@ -91,7 +91,7 @@ def process_func(example):
 from peft import PeftModel
 
 mode_path = 'Qwen/Qwen2.5-7B-Instruct'
-lora_path = './output/Qwen2.5_instruct_lora/checkpoint-747' # change this to the checkpoint path of your LoRA output
+lora_path = './output/Qwen2.5_instruct_lora/checkpoint-1100' # change this to the checkpoint path of your LoRA output
 
 # Load the tokenizer
 tokenizer = AutoTokenizer.from_pretrained(mode_path, trust_remote_code=True)
diff --git a/mindnlp/core/_prims/ascend.py b/mindnlp/core/_prims/ascend.py
index ca6f7c4d3..c82880ffb 100644
--- a/mindnlp/core/_prims/ascend.py
+++ b/mindnlp/core/_prims/ascend.py
@@ -239,3 +239,7 @@ def one_hot_ext(tensor, num_classes):
 
 __all__.append('one_hot_ext')
 
+def flash_attention_score(*args, **kwargs):
+    return pyboost_inner_prim.flash_attention_score_impl(*args, **kwargs)
+
+__all__.append('flash_attention_score')
diff --git a/mindnlp/core/configs.py b/mindnlp/core/configs.py
index c66fb2ae0..e1cc192ab 100644
--- a/mindnlp/core/configs.py
+++ b/mindnlp/core/configs.py
@@ -6,6 +6,7 @@
 DEVICE_TARGET = mindspore.get_context('device_target')
 SUPPORT_BF16 = DEVICE_TARGET == 'Ascend' and SOC not in ['ascend910', 'ascend310b']
 ON_A1 = SOC == 'ascend910'
+ON_A2 = SOC in ['ascend910b', 'ascend910_93']
 ON_ORANGE_PI = '310b' in SOC
 USE_PYBOOST = DEVICE_TARGET == 'Ascend'
 DEFAULT_DTYPE = mindspore.float32
diff --git a/mindnlp/core/dispatcher.py b/mindnlp/core/dispatcher.py
index ac1743c77..a3b0a736c 100644
--- a/mindnlp/core/dispatcher.py
+++ b/mindnlp/core/dispatcher.py
@@ -154,7 +154,7 @@ def dispatch(self, func_name, *args, **kwargs):
             raise RuntimeError(
                 f"No implementation for function: {func_name} on {device_type}."
             )
-        return func(*args), device
+        return func(*args, **kwargs), device
 
 
 dispatcher = Dispatcher()
diff --git a/mindnlp/core/executor.py b/mindnlp/core/executor.py
index 2762a7394..94bed3830 100644
--- a/mindnlp/core/executor.py
+++ b/mindnlp/core/executor.py
@@ -2,6 +2,9 @@
 from .dispatcher import dispatcher
 
 def execute(func_name, *args, **kwargs):
+    requires_grad = kwargs.pop('requires_grad', False)
+    user_created = kwargs.pop('user_created', False)
+
     out, device = dispatcher.dispatch(func_name, *args, **kwargs)
     if not isinstance(out, (tuple, list)):
         out._device = device
diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py
index eaa23fad0..e4b03d3c5 100644
--- a/mindnlp/core/nn/functional.py
+++ b/mindnlp/core/nn/functional.py
@@ -7,7 +7,7 @@
 from mindnlp import core
 from mindnlp.core.executor import execute
 
-from ..configs import DEVICE_TARGET, ON_ORANGE_PI, use_pyboost, ON_A1
+from ..configs import DEVICE_TARGET, ON_ORANGE_PI, use_pyboost, ON_A1, ON_A2
 
 generator_step_ = 12
 
@@ -1171,9 +1171,27 @@ def repeat_kv(hidden_states, n_rep: int):
 
 
 def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False) -> core.Tensor:
+
     L, S = query.size(-2), key.size(-2)
     scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
 
+    if enable_gqa:
+        key = repeat_kv(key, query.size(-3) // key.size(-3)).contiguous()
+        value = repeat_kv(value, query.size(-3) // value.size(-3)).contiguous()
+
+    if query.device.type == 'npu' and ON_A2:
+        if attn_mask is not None:
+            attn_mask = ~attn_mask
+
+        head_num = query.shape[1]
+        output = execute('flash_attention_score', query, key, value, head_num=head_num, input_layout='BNSD', real_shift=None, padding_mask=None, attn_mask=attn_mask,
+                         scale_value=scale_factor, keep_prob=1 - dropout_p, pre_tokens=2147483647, next_tokens=2147483647, inner_precise=0,
+                         drop_mask=None, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0)
+
+        sfm_max, sfm_sum, sfm_out, atten_out = output
+
+        return atten_out
+
     attn_bias_shape = (L, S) if attn_mask is None else attn_mask.shape
     attn_bias = core.zeros(attn_bias_shape, dtype=query.dtype, device=query.device)
 
@@ -1190,11 +1208,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
             attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), core.finfo(attn_bias.dtype).min)
         else:
             attn_bias = attn_mask + attn_bias
-
-    if enable_gqa:
-        key = repeat_kv(key, query.size(-3) // key.size(-3)).contiguous()
-        value = repeat_kv(value, query.size(-3) // value.size(-3)).contiguous()
-
+
     attn_weight = query.float() @ key.transpose(-2, -1).float() * scale_factor
     attn_weight += attn_bias.float()
     attn_weight = softmax(attn_weight, dim=-1, dtype=core.float32).to(query.dtype)
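
A minimal usage sketch of the new path, for reviewers. It is illustrative only and not part of the patch: it assumes an A2-class SoC (ascend910b / ascend910_93, so ON_A2 is True), tensors placed on the 'npu' device, and that mindnlp.core mirrors the torch tensor API (randn, ones, tril, bool, float16, string device arguments). On any other target the call falls through to the existing explicit-softmax implementation. Note that the flash-attention branch does not consult is_causal, so causal behaviour must be expressed through attn_mask.

# Illustrative sketch (not part of the patch): exercises the new ON_A2
# flash-attention branch of mindnlp.core.nn.functional.scaled_dot_product_attention.
# Assumes mindnlp.core mirrors the torch tensor API (randn, ones, tril, bool, float16)
# and that the process runs on an Ascend 910B / 910_93 device ('npu').
from mindnlp import core
from mindnlp.core.nn import functional as F

B, N, S, D = 1, 8, 128, 64  # batch, heads, sequence length, head dim (BNSD layout)

query = core.randn(B, N, S, D, dtype=core.float16, device='npu')
key = core.randn(B, N, S, D, dtype=core.float16, device='npu')
value = core.randn(B, N, S, D, dtype=core.float16, device='npu')

# Torch-style boolean mask: True means "may attend". The patched function negates it
# (~attn_mask) before calling flash_attention_score, whose kernel uses the opposite
# convention (True means "masked out").
causal_mask = core.ones(S, S, device='npu').tril().bool()

out = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask)
print(out.shape)  # expected: (1, 8, 128, 64)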