From 8d54c59e8a3e6d03d0ab13b2e8b8885143106de0 Mon Sep 17 00:00:00 2001
From: lvyufeng
Date: Tue, 19 Aug 2025 15:06:38 +0800
Subject: [PATCH] fix qwen2.5 finetune precision with sdpa

---
 .../peft/lora/Qwen2.5-7B-Instruct-Lora.ipynb  |  27 +---
 .../peft/lora/Qwen2.5-7B-Instruct-Lora.py     | 120 ++----------------
 mindnlp/core/_C/__init__.py                   |  88 ++++++++++++-
 mindnlp/core/__init__.py                      |   6 +-
 mindnlp/core/_tensor.py                       |  30 +++++
 mindnlp/core/nn/functional.py                 |  20 ++-
 mindnlp/core/nn/parameter.py                  |  21 ---
 7 files changed, 150 insertions(+), 162 deletions(-)

diff --git a/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.ipynb b/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.ipynb
index 2841de612..5ffed5b32 100644
--- a/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.ipynb
+++ b/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.ipynb
@@ -18,9 +18,6 @@
    "outputs": [],
    "source": [
     "import mindnlp\n",
-    "import mindspore\n",
-    "\n",
-    "# mindspore.set_context(pynative_synchronize=True)\n",
     "from datasets import Dataset\n",
     "import pandas as pd\n",
     "from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig"
@@ -167,8 +164,7 @@
    "source": [
     "import torch\n",
     "\n",
-    "model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-7B-Instruct', torch_dtype=torch.float16)\n",
-    "model = model.npu()"
+    "model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-7B-Instruct', torch_dtype=torch.float16, device_map='auto')"
    ]
   },
  {
@@ -250,22 +246,6 @@
     "model.print_trainable_parameters()"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "6b6aebd2",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# trainable LoRA parameters must be cast to fp32\n",
-    "print_flag = True\n",
-    "for param in filter(lambda p: p.requires_grad, model.parameters()):\n",
-    "    if print_flag:\n",
-    "        print(param.data.dtype)\n",
-    "        print_flag = False\n",
-    "    param.data = param.data.to(torch.float32)"
-   ]
-  },
  {
   "cell_type": "markdown",
   "id": "ca055683-837f-4865-9c57-9164ba60c00f",
@@ -362,11 +342,14 @@
    "tokenizer = AutoTokenizer.from_pretrained(mode_path, trust_remote_code=True)\n",
    "\n",
    "# load the model\n",
-   "model = AutoModelForCausalLM.from_pretrained(mode_path, device_map=\"auto\",torch_dtype=torch.bfloat16, trust_remote_code=True).eval()\n",
+   "model = AutoModelForCausalLM.from_pretrained(mode_path, torch_dtype=torch.float16, trust_remote_code=True).eval()\n",
    "\n",
    "# load the LoRA weights\n",
    "model = PeftModel.from_pretrained(model, model_id=lora_path)\n",
    "\n",
+   "# host to device\n",
+   "model = model.npu()\n",
+   "\n",
    "prompt = \"你是谁?\"\n",
    "inputs = tokenizer.apply_chat_template([{\"role\": \"user\", \"content\": \"假设你是皇帝身边的女人--甄嬛。\"},{\"role\": \"user\", \"content\": prompt}],\n",
    "                                       add_generation_prompt=True,\n",
diff --git a/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.py b/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.py
index 7a9cf71f6..a3d53f09f 100644
--- a/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.py
+++ b/examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.py
@@ -1,11 +1,7 @@
 #!/usr/bin/env python
 # coding: utf-8

-# # Import dependencies
-
-# In[ ]:
-
-
+# Import dependencies
 import mindnlp
 import mindspore
@@ -14,30 +10,12 @@
 import pandas as pd
 from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig

-
-
 # load the JSON file into a Dataset
 df = pd.read_json('/home/lvyufeng/lvyufeng/mindnlp/examples/transformers/peft/lora/huanhuan.json')
 ds = Dataset.from_pandas(df)

-
-# In[ ]:
-
-
-ds[:3]
-
-
-# # Process the dataset
-
-# In[ ]:
-
-
+# Process the dataset
 tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-7B-Instruct', use_fast=False, trust_remote_code=True)
-tokenizer
-
-
-# In[ ]:
-
 def process_func(example):
     MAX_LENGTH = 384    # the Llama tokenizer splits one Chinese character into several tokens, so relax the max length to keep examples intact
@@ -57,55 +35,14 @@ def process_func(example):
         "labels": labels
     }

-
-# In[ ]:
-
-
 tokenized_id = ds.map(process_func, remove_columns=ds.column_names)
-print(len(tokenized_id))
-
-
-# In[ ]:
-
-
-tokenizer.decode(tokenized_id[0]['input_ids'])
-
-
-# In[ ]:
-
-
-tokenizer.decode(list(filter(lambda x: x != -100, tokenized_id[1]["labels"])))
-
-
-# # Create the model
-
-# In[ ]:
-
 import torch
-
-model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-7B-Instruct', torch_dtype=torch.float16, attn_implementation='eager')
-# model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-7B-Instruct', torch_dtype=torch.float16)
-model = model.npu()
-
-
-# In[ ]:
-
-
+model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-7B-Instruct', torch_dtype=torch.float16, device_map=0)
 model.enable_input_require_grads() # required when gradient checkpointing is enabled

-
-# In[ ]:
-
-
-model.dtype
-
-
-# # lora
-
-# In[ ]:
-
-
+# lora
 from peft import LoraConfig, TaskType, get_peft_model

 config = LoraConfig(
@@ -116,30 +53,11 @@ def process_func(example):
     lora_alpha=32, # LoRA alpha; see the LoRA paper for its effect
     lora_dropout=0.1  # dropout ratio
 )
-config
-
-
-# In[ ]:
-
 model = get_peft_model(model, config)
-config
-
-
-# In[ ]:
-
-
 model.print_trainable_parameters()

-
-# In[ ]:
-
-
-# # Configure the training arguments
-
-# In[ ]:
-
-
+# Configure the training arguments
 args = TrainingArguments(
     output_dir="./output/Qwen2.5_instruct_lora",
     per_device_train_batch_size=4,
@@ -149,14 +67,9 @@
     save_steps=100,
     learning_rate=1e-4,
     save_on_each_node=True,
-    # fp16=True,
     # gradient_checkpointing=True
 )

-
-# In[ ]:
-
-
 trainer = Trainer(
     model=model,
     args=args,
     train_dataset=tokenized_id,
     data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
 )

-
-# In[ ]:
-
-
-trainer.accelerator.state
-
-
-# In[ ]:
-
-
 trainer.train()

+# Merge and load the model

-# # Merge and load the model
-
-# In[ ]:
-
-
+import mindnlp
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 from peft import PeftModel
@@ -193,11 +93,13 @@
 tokenizer = AutoTokenizer.from_pretrained(mode_path, trust_remote_code=True)

 # load the model
-model = AutoModelForCausalLM.from_pretrained(mode_path, device_map="auto",torch_dtype=torch.bfloat16, trust_remote_code=True).eval()
+model = AutoModelForCausalLM.from_pretrained(mode_path, torch_dtype=torch.float16, trust_remote_code=True).eval()

 # load the LoRA weights
 model = PeftModel.from_pretrained(model, model_id=lora_path)

+model = model.npu()
+
 prompt = "你是谁?"
 inputs = tokenizer.apply_chat_template([{"role": "user", "content": "假设你是皇帝身边的女人--甄嬛。"},{"role": "user", "content": prompt}],
                                        add_generation_prompt=True,
diff --git a/mindnlp/core/_C/__init__.py b/mindnlp/core/_C/__init__.py
index c2f172428..944fd5047 100644
--- a/mindnlp/core/_C/__init__.py
+++ b/mindnlp/core/_C/__init__.py
@@ -1,6 +1,6 @@
 from typing import Any
-from mindspore import Generator as msGenerator
 import mindspore
+from mindspore.ops.operations._inner_ops import Generator as GeneratorOp
 from mindnlp import core
 from . import _nn
@@ -105,19 +105,101 @@ def __exit__(self, type: Any, value: Any, traceback: Any):
 device_ = device

-class Generator(msGenerator):
+STEP = 0
+SEED = 1
+GET_STATE = 2
+SET_STATE = 3
+MANUAL_SEED = 4
+INITIAL_SEED = 5
+
+class Generator:
     def __init__(self, device='cpu'):
-        super().__init__()
         if device == 'cuda' and DEVICE_TARGET == 'Ascend':
             device = 'npu'
         self._device = device_(device) if isinstance(device, str) else device
+        self._seed = mindspore.Tensor(0)
+        self._offset = mindspore.Tensor(0)
+        self._generator = GeneratorOp().set_device("CPU")
+        self._generator.add_prim_attr("manual_seed", False)

+    @property
     def device(self):
         if hasattr(self, '_device'):
             return self._device
         return device('cpu')

+    def set_state(self, state):
+        """
+        Sets the generator state.
+
+        Args:
+            state (tensor): target state of the generator.
+        """
+        self._generator(SET_STATE, (self._seed, self._offset, state))
+
+    def get_state(self):
+        """
+        Get the generator state.
+
+        Returns:
+            Tensor, generator state.
+        """
+        return self._generator(GET_STATE, (self._seed, self._offset))[2]
+
+    def seed(self): # pylint: disable=redefined-outer-name
+        """
+        Seed generator with random number.
+
+        Returns:
+            Randomly generated seeds, the type is int.
+        """
+        current_seed = self._generator(
+            SEED, (self._seed, self._offset))[0]
+        return current_seed.item()
+
+    def manual_seed(self, seed): # pylint: disable=redefined-outer-name
+        """
+        Set the generator seed.
+
+        Args:
+            seed (int): Set the generator seed.
+
+        Returns:
+            Generator, the generator instance.
+        """
+        if not isinstance(seed, int):
+            raise TypeError("Seed must be an integer.")
+        seed = mindspore.Tensor(seed, mindspore.int64)
+        self._generator(MANUAL_SEED, (self._seed, self._offset, seed))
+        self._generator.add_prim_attr("manual_seed", True)
+        return self
+
+    def initial_seed(self):
+        """
+        Return the initial seed of generator.
+
+        Returns:
+            The initial seed of generator.
+        """
+        current_seed = self._generator(
+            INITIAL_SEED, (self._seed, self._offset))[0]
+        return current_seed.item()
+
+
+    def _step(self, step):
+        """
+        Return current seed and offset, and update offset for the next call.
+
+        Args:
+            step (Tensor): Update offset by step.
+
+        Returns:
+            Current seed and offset.
+ """ + return self._generator(STEP, (self._seed, self._offset, step,))[:2] + default_generator = Generator() class Tag: pass diff --git a/mindnlp/core/__init__.py b/mindnlp/core/__init__.py index 9ce4d98cd..cec68c687 100644 --- a/mindnlp/core/__init__.py +++ b/mindnlp/core/__init__.py @@ -43,12 +43,12 @@ -from ._C import * -from ._C.size import Size from ._dtype import * -from .ops import * from ._tensor import Tensor, tensor, is_tensor, \ LongTensor, FloatTensor, BoolTensor, HalfTensor, BFloat16Tensor, IntTensor +from ._C import * +from ._C.size import Size +from .ops import * from ._tensor import enable_mindspore_patch enable_mindspore_patch() diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py index 02ab6f70b..49f8544a1 100644 --- a/mindnlp/core/_tensor.py +++ b/mindnlp/core/_tensor.py @@ -5,6 +5,7 @@ from mindspore import Tensor from mindspore.common.tensor import _TensorMeta from mindspore._c_expression.typing import Type +from mindspore._c_expression import ParamInfo # pylint: disable=no-name-in-module from mindspore._c_expression import typing try: from mindspore.common._stub_tensor import StubTensor, _stub_method @@ -84,16 +85,19 @@ def tensor_meta_str(self): old_init = Tensor.__init__ def __init__(self, *args, **kwargs): + requires_grad = kwargs.pop('requires_grad', False) if len(args) > 1 and all([isinstance(arg, int) for arg in args]): tensor = Tensor_(shape=args, dtype=get_default_dtype()) old_init(self, tensor, internal=True) else: old_init(self, *args, **kwargs) + self.requires_grad_(requires_grad) Tensor.__init__ = __init__ origin_setitem = Tensor.__setitem__ Tensor._device = device_('cpu') +Tensor._requires_grad = False def tensor(data, *, dtype=None, device=None, requires_grad=False): if isinstance(data, Tensor): @@ -2517,6 +2521,32 @@ def char(self): def is_nested(self): return False + @property + def requires_grad(self): + return self._requires_grad + + @requires_grad.setter + def requires_grad(self, value): + if not isinstance(value, bool): + raise TypeError("The 'requires_grad' attribute of parameter must be set as bool.") + self._requires_grad = value + if self.param_info is not None: + self.param_info.requires_grad = value + else: + self.param_info = ParamInfo() + self.param_info.requires_grad = value + + if value: + if not hasattr(self, 'handle'): + self.retain_grad() + else: + if hasattr(self, 'handle'): + self.handle.remove() + delattr(self, 'handle') + + def retain_grad(self): + pass + def enable_mindspore_patch(): fn_keys = list(TensorPlaceHolder.__dict__) fn_keys.remove('__doc__') diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py index f1505e710..88309432d 100644 --- a/mindnlp/core/nn/functional.py +++ b/mindnlp/core/nn/functional.py @@ -1157,6 +1157,18 @@ def _in_projection_packed( b_q, b_k, b_v = b.chunk(3) return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) +def repeat_kv(hidden_states, n_rep: int): + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False) -> core.Tensor: L, S = query.size(-2), key.size(-2) @@ -1180,14 +1192,14 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. attn_bias = attn_mask + attn_bias if enable_gqa: - key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) - value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + key = repeat_kv(key, query.size(-3) // key.size(-3)).contiguous() + value = repeat_kv(value, query.size(-3) // value.size(-3)).contiguous() attn_weight = query.float() @ key.transpose(-2, -1).float() * scale_factor attn_weight += attn_bias.float() - attn_weight = softmax(attn_weight, dim=-1, dtype=core.float32).to(query.dtype) + attn_weight = softmax(attn_weight, dim=-1) attn_weight = dropout(attn_weight, dropout_p, training=True) - return attn_weight @ value + return (attn_weight @ value.float()).to(query.dtype) def _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads): diff --git a/mindnlp/core/nn/parameter.py b/mindnlp/core/nn/parameter.py index bde35e7d3..9b8a9cd03 100644 --- a/mindnlp/core/nn/parameter.py +++ b/mindnlp/core/nn/parameter.py @@ -10,7 +10,6 @@ class Parameter(Tensor): grad = None - requires_grad = False _grad_fn = None def __init__(self, input_data=None, requires_grad=True, **kwargs): @@ -49,26 +48,6 @@ def name(self): # only for O2 """ return self.param_info.name - @property - def requires_grad(self): - return self._requires_grad - - @requires_grad.setter - def requires_grad(self, value): - if not isinstance(value, bool): - raise TypeError("The 'requires_grad' attribute of parameter must be set as bool.") - self.param_info.requires_grad = value - self._requires_grad = value - if value: - if not hasattr(self, 'handle'): - self.retain_grad() - else: - if hasattr(self, 'handle'): - self.handle.remove() - delattr(self, 'handle') - - def retain_grad(self): - pass class UninitializedTensorMixin: _allowed_methods = [
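
Note for readers (not part of the commit): the `mindnlp/core/nn/functional.py` hunk is the heart of the precision fix. The old code cast the fp32 softmax output back to fp16 before multiplying by `value`; the new code keeps the scores, the softmax, and the weight-value matmul all in float32 and casts to the query dtype only once at the end. A minimal standalone sketch of that pattern in plain PyTorch, for illustration only (`sdpa_fp32` is a made-up name, not an API from this patch):

```python
import torch

def sdpa_fp32(query, key, value, scale=None):
    # query/key/value: (batch, heads, seq, head_dim), e.g. float16
    scale = scale if scale is not None else query.shape[-1] ** -0.5
    # compute attention scores in float32
    attn = query.float() @ key.transpose(-2, -1).float() * scale
    attn = torch.softmax(attn, dim=-1)   # softmax stays in float32
    out = attn @ value.float()           # accumulate in float32
    return out.to(query.dtype)           # single cast back at the end

q = torch.randn(1, 2, 8, 16, dtype=torch.float16)
k = torch.randn(1, 2, 8, 16, dtype=torch.float16)
v = torch.randn(1, 2, 8, 16, dtype=torch.float16)
out = sdpa_fp32(q, k, v)
assert out.dtype == torch.float16 and out.shape == q.shape
```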