@@ -336,7 +336,7 @@
"from peft import PeftModel\n",
"\n",
"mode_path = 'Qwen/Qwen2.5-7B-Instruct'\n",
"lora_path = './output/Qwen2.5_instruct_lora/checkpoint-747' # 这里改称你的 lora 输出对应 checkpoint 地址\n",
"lora_path = './output/Qwen2.5_instruct_lora/checkpoint-1100' # 这里改称你的 lora 输出对应 checkpoint 地址\n",
"\n",
"# 加载tokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(mode_path, trust_remote_code=True)\n",
@@ -91,7 +91,7 @@ def process_func(example):
from peft import PeftModel

mode_path = 'Qwen/Qwen2.5-7B-Instruct'
lora_path = './output/Qwen2.5_instruct_lora/checkpoint-747' # Change this to the checkpoint path of your LoRA output
lora_path = './output/Qwen2.5_instruct_lora/checkpoint-1100' # Change this to the checkpoint path of your LoRA output

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(mode_path, trust_remote_code=True)
4 changes: 4 additions & 0 deletions mindnlp/core/_prims/ascend.py
@@ -239,3 +239,7 @@ def one_hot_ext(tensor, num_classes):

__all__.append('one_hot_ext')

def flash_attention_score(*args, **kwargs):
return pyboost_inner_prim.flash_attention_score_impl(*args, **kwargs)

__all__.append('flash_attention_score')
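Note on how this new wrapper is reached: the dispatcher resolves the string 'flash_attention_score' to this function and forwards all keyword arguments to the pyboost kernel. A minimal sketch of the expected call pattern, mirroring the call site added in mindnlp/core/nn/functional.py below; the BNSD shapes, the core.randn helper, and the return-tuple ordering are assumptions drawn from that call site, not verified against the kernel.

from mindnlp import core
from mindnlp.core.executor import execute

# BNSD layout: batch, heads, seq_len, head_dim (torch-style core.randn assumed)
q = core.randn(1, 8, 128, 64)
k = core.randn(1, 8, 128, 64)
v = core.randn(1, 8, 128, 64)

# Requires an Ascend A2 device; the kernel may additionally expect fp16/bf16 inputs.
out = execute(
    'flash_attention_score', q, k, v,
    head_num=q.shape[1], input_layout='BNSD',
    real_shift=None, padding_mask=None, attn_mask=None,
    scale_value=1 / 8.0, keep_prob=1.0,
    pre_tokens=2147483647, next_tokens=2147483647, inner_precise=0,
    drop_mask=None, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None,
    sparse_mode=0,
)
# The call site unpacks four outputs and keeps the last one as the attention result.
sfm_max, sfm_sum, sfm_out, atten_out = out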
1 change: 1 addition & 0 deletions mindnlp/core/configs.py
@@ -6,6 +6,7 @@
DEVICE_TARGET = mindspore.get_context('device_target')
SUPPORT_BF16 = DEVICE_TARGET == 'Ascend' and SOC not in ['ascend910', 'ascend310b']
ON_A1 = SOC == 'ascend910'
ON_A2 = SOC in ['ascend910b', 'ascend910_93']
ON_ORANGE_PI = '310b' in SOC
USE_PYBOOST = DEVICE_TARGET == 'Ascend'
DEFAULT_DTYPE = mindspore.float32
2 changes: 1 addition & 1 deletion mindnlp/core/dispatcher.py
@@ -154,7 +154,7 @@ def dispatch(self, func_name, *args, **kwargs):
raise RuntimeError(
f"No implementation for function: {func_name} on {device_type}."
)
return func(*args), device
return func(*args, **kwargs), device


dispatcher = Dispatcher()
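This one-line change matters for the new attention path: everything except the input tensors is passed to execute() as keyword arguments, and func(*args) would never forward them to the backend implementation. A toy illustration of the behaviour being fixed (not the real Dispatcher; backend_impl is a stand-in for a prim that relies on keyword arguments):

def backend_impl(q, k, v, *, head_num=1, input_layout='BNSD'):
    # Stand-in for a backend prim configured via keyword arguments.
    return f'{head_num} heads, layout {input_layout}'

def dispatch_old(func, *args, **kwargs):
    return func(*args)            # keyword arguments silently fall back to defaults

def dispatch_new(func, *args, **kwargs):
    return func(*args, **kwargs)  # keyword arguments reach the backend

print(dispatch_old(backend_impl, 'q', 'k', 'v', head_num=8))  # 1 heads, layout BNSD
print(dispatch_new(backend_impl, 'q', 'k', 'v', head_num=8))  # 8 heads, layout BNSD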
3 changes: 3 additions & 0 deletions mindnlp/core/executor.py
@@ -2,6 +2,9 @@
from .dispatcher import dispatcher

def execute(func_name, *args, **kwargs):
requires_grad = kwargs.pop('requires_grad', False)
user_created = kwargs.pop('user_created', False)

out, device = dispatcher.dispatch(func_name, *args, **kwargs)
if not isinstance(out, (tuple, list)):
out._device = device
26 changes: 20 additions & 6 deletions mindnlp/core/nn/functional.py
@@ -7,7 +7,7 @@
from mindnlp import core
from mindnlp.core.executor import execute

from ..configs import DEVICE_TARGET, ON_ORANGE_PI, use_pyboost, ON_A1
from ..configs import DEVICE_TARGET, ON_ORANGE_PI, use_pyboost, ON_A1, ON_A2

generator_step_ = 12

@@ -1171,9 +1171,27 @@ def repeat_kv(hidden_states, n_rep: int):

def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
is_causal=False, scale=None, enable_gqa=False) -> core.Tensor:

L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

if enable_gqa:
key = repeat_kv(key, query.size(-3) // key.size(-3)).contiguous()
value = repeat_kv(value, query.size(-3) // value.size(-3)).contiguous()

if query.device.type == 'npu' and ON_A2:
if attn_mask is not None:
attn_mask = ~attn_mask

head_num = query.shape[1]
output = execute('flash_attention_score', query, key, value, head_num=head_num, input_layout='BNSD', real_shift=None, padding_mask=None, attn_mask=attn_mask,
scale_value=scale_factor, keep_prob=1 - dropout_p, pre_tokens=2147483647, next_tokens=2147483647, inner_precise=0,
drop_mask=None, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0)

sfm_max, sfm_sum, sfm_out, atten_out = output

return atten_out

attn_bias_shape = (L, S) if attn_mask is None else attn_mask.shape
attn_bias = core.zeros(attn_bias_shape, dtype=query.dtype, device=query.device)

@@ -1190,11 +1208,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), core.finfo(attn_bias.dtype).min)
else:
attn_bias = attn_mask + attn_bias

if enable_gqa:
key = repeat_kv(key, query.size(-3) // key.size(-3)).contiguous()
value = repeat_kv(value, query.size(-3) // value.size(-3)).contiguous()


attn_weight = query.float() @ key.transpose(-2, -1).float() * scale_factor
attn_weight += attn_bias.float()
attn_weight = softmax(attn_weight, dim=-1, dtype=core.float32).to(query.dtype)
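For context, a minimal usage sketch of the updated function. The shapes and the core.randn helper are assumptions (a torch-style API is assumed throughout mindnlp.core); on an A2 NPU the call takes the fused flash-attention branch added above, otherwise it falls back to the explicit softmax path. Note that a boolean attn_mask is inverted (~attn_mask) before being handed to the kernel, presumably because the fused op treats nonzero mask entries as positions to discard, the opposite of this function's convention.

from mindnlp import core
from mindnlp.core.nn import functional as F

B, H, S, D = 2, 8, 256, 64            # batch, heads, seq_len, head_dim
query = core.randn(B, H, S, D)        # BNSD layout, matching input_layout='BNSD'
key = core.randn(B, H, S, D)
value = core.randn(B, H, S, D)

# On an Ascend A2 SoC this dispatches to the fused flash_attention_score kernel
# and returns only the attention output; on other targets it runs the explicit
# softmax(Q @ K^T * scale) @ V path further down in this function.
out = F.scaled_dot_product_attention(query, key, value)
print(out.shape)                      # expected: (B, H, S, D)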