diff --git a/README.md b/README.md
index b46430d5d0..c29b39193b 100644
--- a/README.md
+++ b/README.md
@@ -21,12 +21,17 @@ Currently supported approaches (and counting):
 1. LoRA: [LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS](https://arxiv.org/abs/2106.09685)
 2. Adapter: [Parameter-Efficient Transfer Learning for NLP](http://arxiv.org/abs/1902.00751)
 3. Prompt Tuning: [Visual Prompt Tuning](https://arxiv.org/abs/2203.12119)
-4. All tuners offered on [Peft](https://github.com/huggingface/peft).
+4. Side: [Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks](https://arxiv.org/abs/1912.13503)
+5. ResTuning-Bypass
+6. All tuners offered on [Peft](https://github.com/huggingface/peft)

 Key features:
 1. By integrating the ModelScope library, models can be readily obtained via a model-id.
 2. Tuners provided by SWIFT can be combined together to allow exploration of multiple tuners on a model for best result.
+3. Support calling `activate_adapter` or `deactivate_adapter` to activate/deactivate a single tuner. Users can use one model with multiple tuners in different threads.
+
+Users can check the [documentation of Swift](./docs/Get Started/1.Introduction.md) for detailed tutorials.

 ## LLM SFT Example
 [code link](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm)
diff --git a/README_CN.md b/README_CN.md
index f4f206840e..ac9b4ce191 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -20,11 +20,16 @@ SWIFT(Scalable lightWeight Infrastructure for Fine-Tuning)是一个可扩展
 1. LoRA:[LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS](https://arxiv.org/abs/2106.09685)
 2. Adapter:[Parameter-Efficient Transfer Learning for NLP](http://arxiv.org/abs/1902.00751)
 3. Prompt Tuning: [Visual Prompt Tuning](https://arxiv.org/abs/2203.12119)
-4. 所有在[Peft](https://github.com/huggingface/peft)上提供的tuners。
+4. Side: [Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks](https://arxiv.org/abs/1912.13503)
+5. ResTuning-Bypass
+6. 所有在[Peft](https://github.com/huggingface/peft)上提供的tuners

 关键特点:
 1. 通过集成ModelScope库,可以通过model id轻松获取模型。
 2. SWIFT提供的tuners可以组合在一起,以便在模型上探索多个tuners,以获得最佳结果。
+3. 
支持调用`activate_adapter`或`deactivate_adapter`来使tuner激活或失活,用户可以在推理时用一个模型在不同线程中使用多种tuners而互不干扰。 + +用户可以查看 [Swift官方文档](./docs/Get Started/1.Introduction.md) 来了解详细信息。 ## 大模型微调的例子 [code link](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm) diff --git a/docs/Get Started/1.Introduction.md b/docs/Get Started/1.Introduction.md new file mode 100644 index 0000000000..14f68b2d0c --- /dev/null +++ b/docs/Get Started/1.Introduction.md @@ -0,0 +1,103 @@ +# 介绍 + +Swift是一个提供LLM模型轻量级训练和推理的开源框架。Swift提供的主要能力是`efficient tuners`,tuners是运行时动态加载到模型上的额外结构,在训练时将原模型的参数冻结,只训练tuner部分,这样可以达到快速训练、降低显存使用的目的。比如,最常用的tuner是LoRA。 + +总之,在这个框架中提供了以下特性: + +- **具备SOTA特性的Efficient Tuners**:用于结合大模型实现轻量级(在商业级显卡上)训练和推理,并取得较好效果 +- **使用ModelScope Hub的Trainer**:基于`transformers trainer`提供,支持LLM模型的训练,并支持将训练后的模型上传到[ModelScope Hub](https://www.modelscope.cn/models)中 +- **可运行的模型Examples**:针对热门大模型提供的训练脚本和推理脚本,并针对热门开源数据集提供了预处理逻辑,可直接运行使用 + +# 快速开始 + +在本章节会介绍如何快速安装swift并设定好运行环境,并跑通一个用例。 + +安装swift的方式非常简单,用户只需要在python>=3.8环境中运行: + +```shell +pip install ms-swift +``` + +下面的代码使用LoRA在分类任务上训练了`bert-base-uncased`模型: + +**运行下面的代码前请额外安装modelscope: ** + +```shell +pip install modelscope>=1.9.0 +``` + +```python +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + +from modelscope import AutoModelForSequenceClassification, AutoTokenizer, MsDataset +from transformers import default_data_collator + +from swift import Trainer, LoRAConfig, Swift, TrainingArguments + + +model = AutoModelForSequenceClassification.from_pretrained( + 'AI-ModelScope/bert-base-uncased', revision='v1.0.0') +tokenizer = AutoTokenizer.from_pretrained( + 'AI-ModelScope/bert-base-uncased', revision='v1.0.0') +lora_config = LoRAConfig(target_modules=['query', 'key', 'value']) +model = Swift.prepare_model(model, config=lora_config) + +train_dataset = MsDataset.load('clue', subset_name='afqmc', split='train').to_hf_dataset().select(range(100)) +val_dataset = MsDataset.load('clue', subset_name='afqmc', split='validation').to_hf_dataset().select(range(100)) + + +def tokenize_function(examples): + return tokenizer(examples["sentence1"], examples["sentence2"], + padding="max_length", truncation=True, max_length=128) + + +train_dataset = train_dataset.map(tokenize_function) +val_dataset = val_dataset.map(tokenize_function) + +arguments = TrainingArguments( + output_dir='./outputs', + per_device_train_batch_size=16, +) + +trainer = Trainer(model, arguments, train_dataset=train_dataset, + eval_dataset=val_dataset, + data_collator=default_data_collator,) + +trainer.train() +``` + +在上面的例子中,我们使用了`bert-base-uncased`作为基模型,将LoRA模块patch到了['query', 'key', 'value']三个Linear上,进行了一次训练。 + +训练结束后可以看到outputs文件夹,它的文件结构如下: + +> outputs +> +> ​ |-- checkpoint-xx +> +> ​ |-- configuration.json +> +> ​ |-- default +> +> ​ |-- adapter_config.json +> +> ​ |-- adapter_model.bin +> +> ​ |-- ... 
+ +可以使用该文件夹执行推理: + +```python +from modelscope import AutoModelForSequenceClassification, AutoTokenizer +from swift import Trainer, LoRAConfig, Swift + + +model = AutoModelForSequenceClassification.from_pretrained( + 'AI-ModelScope/bert-base-uncased', revision='v1.0.0') +tokenizer = AutoTokenizer.from_pretrained( + 'AI-ModelScope/bert-base-uncased', revision='v1.0.0') +lora_config = LoRAConfig(target_modules=['query', 'key', 'value']) +model = Swift.from_pretrained(model, model_id='./outputs/checkpoint-21') + +print(model(**tokenizer('this is a test', return_tensors='pt'))) +``` diff --git a/docs/Get Started/2.Installation.md b/docs/Get Started/2.Installation.md new file mode 100644 index 0000000000..7bc620c51d --- /dev/null +++ b/docs/Get Started/2.Installation.md @@ -0,0 +1,25 @@ +# 安装和使用 + +## Wheel包安装 + +可以使用pip进行安装: + +```shell +pip install ms-swift +``` + +## 源代码安装 + +```shell +git clone https://github.com/modelscope/swift.git +cd swift +pip install -e . +``` + +## Notebook环境 + +Swift支持训练的绝大多数模型都可以在`A10`显卡上使用,用户可以使用ModelScope官方提供的免费显卡资源: + +1. 进入[ModelScope](https://www.modelscope.cn)官方网站并登录 +2. 点击左侧的`我的Notebook`并开启一个免费GPU实例 +3. 愉快地薅A10显卡羊毛 diff --git a/docs/Get Started/3.Use in train and infer.md b/docs/Get Started/3.Use in train and infer.md new file mode 100644 index 0000000000..2209cecfc6 --- /dev/null +++ b/docs/Get Started/3.Use in train and infer.md @@ -0,0 +1,123 @@ +# Swift API + +## 在训练中使用Swift + +调用`Swift.prepare_model()`来将tuners添加到模型上: + +```python +from modelscope import Model +from swift import Swift, LoRAConfig +import torch +model = Model.from_pretrained('ZhipuAI/chatglm2-6b', torch_dtype=torch.bfloat16, device_map='auto') +lora_config = LoRAConfig( + r=16, + target_modules=['query_key_value'], + lora_alpha=32, + lora_dropout=0.) +model = Swift.prepare_model(model, lora_config) +# use model to do other things +``` + +也可以同时使用多个tuners: + +```python +from modelscope import Model +from swift import Swift, LoRAConfig, AdapterConfig +import torch +model = Model.from_pretrained('ZhipuAI/chatglm2-6b', torch_dtype=torch.bfloat16, device_map='auto') +lora_config = LoRAConfig( + r=16, + target_modules=['query_key_value'], + lora_alpha=32, + lora_dropout=0.) +adapter_config = AdapterConfig( + dim=model.config.hidden_size, + target_modules=['mlp'], + method_name='forward', + hidden_pos=0, + adapter_length=32, + ) +model = Swift.prepare_model(model, {'first_tuner': lora_config, 'second_tuner': adapter_config}) +# use model to do other things +``` + +在使用多个tuners时,传入的第二个参数需要是Dict,key是tuner名字,value是tuner配置。 + +训练后可以调用: + +```python +model.save_pretrained(save_directory='./output') +``` + +来存储模型checkpoint。模型的checkpoint文件只会包括tuners的权重,不会包含模型本身的权重。存储后的结构如下: + +> outputs +> +> ​ |-- configuration.json +> +> ​ |-- first_tuner +> +> ​ |-- adapter_config.json +> +> ​ |-- adapter_model.bin +> +> ​ |-- second_tuner +> +> ​ |-- adapter_config.json +> +> ​ |-- adapter_model.bin +> +> ​ |-- ... + +如果只传入单独的config,则会使用默认的名称`default`: + +> outputs +> +> ​ |-- configuration.json +> +> ​ |-- default +> +> ​ |-- adapter_config.json +> +> ​ |-- adapter_model.bin +> +> ​ |-- ... 
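+
+如果只想单独存储其中某一个tuner,可以通过`save_pretrained`的`adapter_name`参数指定(以下为示意代码,输出目录名仅作示例,tuner名称沿用上文例子中的`first_tuner`):
+
+```python
+# 只存储名为first_tuner的权重,second_tuner不会被写入该目录
+model.save_pretrained(save_directory='./output_first_tuner', adapter_name='first_tuner')
+```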
+
+## 在推理时使用Swift
+
+使用`Swift.from_pretrained()`来拉起训练后存储的checkpoint:
+
+```python
+from modelscope import Model
+from swift import Swift
+import torch
+model = Model.from_pretrained('ZhipuAI/chatglm2-6b', torch_dtype=torch.bfloat16, device_map='auto')
+model = Swift.from_pretrained(model, './output')
+```
+
+## 加载多个tuners并在不同线程中并行使用
+
+在模型提供服务时,很可能出现一个模型同时服务多个http线程的情况,其中每个线程代表了一类用户请求。Swift支持在不同线程中激活不同tuners:
+
+```python
+from modelscope import Model
+from swift import Swift
+import torch
+model = Model.from_pretrained('ZhipuAI/chatglm2-6b', torch_dtype=torch.bfloat16, device_map='auto')
+# 假设output中存在训练完成的a、b、c、d四个tuners
+model = Swift.from_pretrained(model, './output')
+
+# 假设两类请求,一类使用a、b、c三个tuner,另一类使用a、c、d三个tuner
+type_1 = ['a', 'b', 'c']
+type_2 = ['a', 'c', 'd']
+
+def request(_input, _type):
+    if _type == 'type_1':
+        model.set_active_adapters(type_1)
+    elif _type == 'type_2':
+        model.set_active_adapters(type_2)
+    return model(**_input)
+```
+
+在不同线程中使用同一个tuner是安全的。
diff --git a/docs/Get Started/4.examples.md b/docs/Get Started/4.examples.md
new file mode 100644
index 0000000000..80240e2679
--- /dev/null
+++ b/docs/Get Started/4.examples.md
@@ -0,0 +1,3 @@
+# LLM训练方案
+
+Swift提供了完整的LLM训练方案,可以查看[Examples的README](../../examples/pytorch/llm/README_CN.md)。
diff --git a/docs/Modules/1.swift.md b/docs/Modules/1.swift.md
new file mode 100644
index 0000000000..0d5b35c9ab
--- /dev/null
+++ b/docs/Modules/1.swift.md
@@ -0,0 +1,69 @@
+# 接口介绍
+
+## Swift
+
+##### Swift.prepare_model(model: Union[nn.Module, 'SwiftModel'], config: Union[SwiftConfig, PeftConfig, Dict[str, SwiftConfig]], **kwargs)
+
+> 该静态方法随机初始化指定类型的tuners
+>
+> model: 需要加载tuner的模型,可以是SwiftModel,后添加的tuners会和前面SwiftModel中的一起生效
+>
+> config: 加载的tuner的config,可以是SwiftConfig或PeftConfig,或者带有名称的config的dict。如果不传递名称则名称默认为`default`
+>
+> kwargs:
+>
+>     extra_state_keys: List[str] 需要被额外存储到文件的原始模型weights的key
+>
+>     inference_mode: bool 是否以推理模式启动
+
+SwiftConfig的具体参数可以查看每个tuner的文档。
+
+##### Swift.from_pretrained(model: Union[nn.Module, 'SwiftModel'], model_id: str = None, adapter_name: Union[str, List[str]] = None, revision: str = None, **kwargs)
+
+> 该静态方法拉起之前存储过的tuners的checkpoint
+>
+> model: 需要加载tuner的模型,可以是SwiftModel,后添加的tuners会和前面SwiftModel中的一起生效
+>
+> model_id: 已存储的tuners的本地目录或modelscope hub id
+>
+> adapter_name: 需要被拉起的adapter名称,默认为None代表全部拉起
+>
+> kwargs:
+>
+>     inference_mode: bool 是否以推理模式启动
+>
+>     revision: model_id的revision
+>
+>     extra_state_keys: 下次save_pretrained时额外存储的weights
+
+## SwiftModel
+
+在`Swift.prepare_model`或`Swift.from_pretrained`拉起后,都会返回一个`SwiftModel`类型的实例。该实例包装了实际传入的模型。
+
+##### save_pretrained(self, save_directory: str, safe_serialization: bool = False, adapter_name: Union[str, List[str]] = None, **kwargs)
+
+> 实例方法,将模型存储到本地磁盘中,可直接被Swift.from_pretrained拉起
+>
+> save_directory: 存储的目录
+>
+> safe_serialization: 是否存储为safetensors格式
+>
+> adapter_name: 待存储的adapter名称,默认为None代表全部存储
+
+##### set_active_adapters(self, adapter_names: List[str])
+
+> 实例方法,设置模型在当前线程中生效的所有adapter。如果将环境变量`USE_UNIQUE_THREAD`设置为'0',则设置对所有线程同时生效。
+>
+> adapter_names: adapter名称列表
+
+##### activate_adapter(self, adapter_name)
+
+> 实例方法,在当前线程中单独激活某个adapter,如果将环境变量`USE_UNIQUE_THREAD`设置为'0',则设置对所有线程同时生效。
+>
+> adapter_name: adapter名称
+
+##### deactivate_adapter(self, adapter_name)
+
+> 实例方法,在当前线程中单独失活某个adapter,如果将环境变量`USE_UNIQUE_THREAD`设置为'0',则设置对所有线程同时生效。
+>
+> adapter_name: adapter名称
diff --git a/docs/Modules/2.lora.md b/docs/Modules/2.lora.md
new file mode 100644
index 0000000000..013c4da7ee
--- /dev/null
+++ b/docs/Modules/2.lora.md
@@ -0,0 +1,32 @@
+# LoRA
+
+LoRA是[LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) 论文提供的轻量级训练组件。LoRA可以添加到Linear、Embedding、Conv2d等算子上生效。
+
+>```python
+>LoRAConfig (
+>    r: int LoRA结构的秩
+>    target_modules: Union[List[str], str] MLP结构的module_key,如果是str类型则进行full_match通配查找,如果是List,则进行末尾匹配
+>    lora_alpha: int LoRA结构的权重比例,lora_alpha/r的值是LoRA结构的权重
+>    lora_dropout: float LoRA结构的dropout比例
+>    merge_weights: bool 在推理时是否将LoRA权重合并到原始weights上
+>    use_merged_linear: bool 是否是merged linear结构
+>    enable_lora: List[bool] 如果是use_merged_linear,哪些module需要添加LoRA结构
+>    bias: str 偏置是否参与训练和存储,可以为`none`:所有偏置不参与训练, `all`:所有模块的偏置均参与训练, `lora_only`:仅LoRA结构的偏置参与训练
+>)
+>```
+
+一个使用LoRA的例子如下:
+
+```python
+from modelscope import Model
+from swift import Swift, LoRAConfig
+import torch
+model = Model.from_pretrained('ZhipuAI/chatglm2-6b', torch_dtype=torch.bfloat16, device_map='auto')
+lora_config = LoRAConfig(
+    r=16,
+    target_modules=['query_key_value'],
+    lora_alpha=32,
+    lora_dropout=0.)
+model = Swift.prepare_model(model, lora_config)
+# use model to do other things
+```
diff --git a/docs/Modules/3.Restuning.md b/docs/Modules/3.Restuning.md
new file mode 100644
index 0000000000..4beb11a022
--- /dev/null
+++ b/docs/Modules/3.Restuning.md
@@ -0,0 +1,41 @@
+# Restuning
+
+Restuning是[Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone]()论文提供的轻量级训练组件。Restuning工作在深度学习模型多层结构的layer上。
+
+>```python
+>ResTuningConfig (
+>    dims: Union[List[int], int] layers输出的hidden_state的维度,可以传入List以适配上采样或下采样
+>    root_modules: str 提供root hidden_state的模块的正则表达式
+>    root_modules_hook: str 可以为`input`或`output`,表示hidden_state从root_module的输入或输出中取到
+>    stem_modules: Union[List[str], str] 提供stem hidden_state的模块的正则表达式(str)或完整module路径(List)
+>    stem_modules_hook: str 可以为`input`或`output`,表示hidden_state从stem_module的输入或输出中取到
+>    target_modules: str target module的正则表达式
+>    target_modules_hook: str 可以为`input`或`output`,表示hidden_state从target_module的输入或输出中取到
+>    target_hidden_pos: Union[int, str] target_module forward输入或输出中hidden_state的index
+>    tuner_cfg: restuning模块中子tuner的配置,可以传入str或dict
+>    use_upsample: bool 是否加入上采样模块
+>    upsample_out_channels: List[int] 如果进行上采样,上采样的通道数
+>    zero_init_last: bool 是否对tuner的最后一层Linear进行全零初始化
+>)
+>```
+
+一个使用Restuning的例子如下:
+
+```python
+from swift import (ResTuningConfig, Swift, snapshot_download)
+
+model_dir = snapshot_download('AI-ModelScope/vit-base-patch16-224')
+from transformers import AutoModelForImageClassification
+
+model = AutoModelForImageClassification.from_pretrained(model_dir)
+restuning_config_1 = ResTuningConfig(
+    dims=768,
+    root_modules=r'.*vit.encoder.layer.0$',
+    stem_modules=r'.*vit.encoder.layer\.\d+$',
+    target_modules=r'.*vit.layernorm',
+    target_modules_hook='input',
+    tuner_cfg='res_adapter',
+)
+model = Swift.prepare_model(model, config=restuning_config_1)
+# use model to do other things
+```
diff --git a/docs/Modules/4.adapter.md b/docs/Modules/4.adapter.md
new file mode 100644
index 0000000000..10ab21c665
--- /dev/null
+++ b/docs/Modules/4.adapter.md
@@ -0,0 +1,31 @@
+# Adapter
+
+Adapter是[Parameter-Efficient Transfer Learning for NLP](http://arxiv.org/abs/1902.00751) 论文提供的轻量级训练组件。一般添加到MLP结构之后生效。
+
+>```python
+>AdapterConfig (
+>    dim: int MLP结构输出中hidden_state的dim,一般等于模型的hidden_size
+>    target_modules: Union[List[str], str] MLP结构的module_key,如果是str类型则进行full_match通配查找,如果是List,则进行末尾匹配
+>    hidden_pos: Union[str, int] MLP输出结构中hidden_state的位置,如果是tuple/list则传入int,如果是dict则传入str类型的key
+>    method_name: str MLP结构的前向方法,Adapter默认会patch到该方法上,在forward调用后使用其hidden_state输入tuner,默认是forward。
+>    adapter_length: int adapter结构中间层长度,默认为128
+>    act_layer: str 激活算子,默认为gelu
+>)
+>```
+
+一个使用Adapter的例子如下:
+
+```python
+from modelscope import Model
+from swift import Swift, AdapterConfig
+import torch
+model = Model.from_pretrained('ZhipuAI/chatglm2-6b', torch_dtype=torch.bfloat16, device_map='auto')
+adapter_config = AdapterConfig(
+    dim=model.config.hidden_size,
+    target_modules=['mlp'],
+    method_name='forward',
+    hidden_pos=0,
+)
+model = Swift.prepare_model(model, adapter_config)
+# use model to do other things
+```
diff --git a/docs/Modules/5.side.md b/docs/Modules/5.side.md
new file mode 100644
index 0000000000..6c49e2fad3
--- /dev/null
+++ b/docs/Modules/5.side.md
@@ -0,0 +1,30 @@
+# Side
+
+Side是[Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks](https://arxiv.org/abs/1912.13503) 论文提供的轻量级训练组件。Side可以添加到MLP结构上。
+
+>```python
+>SideConfig (
+>    dim: int hidden_state的维度
+>    target_modules: str 需要嵌入的位置的正则表达式
+>    side_module_name: str side module的名字,可以是fcn4、mlp、alexnet
+>    hidden_pos: Union[str, int] hidden_state在MLP结构中的位置,如果MLP输出为tuple/list,则hidden_pos需要是一个int,否则需要是一个str
+>)
+>```
+
+一个使用Side的例子如下:
+
+```python
+from modelscope import Model
+
+from swift import (SideConfig, Swift)
+
+model = Model.from_pretrained(
+    'damo/nlp_structbert_sentence-similarity_chinese-base')
+side_config = SideConfig(
+    dim=model.config.hidden_size,
+    target_modules=r'.*encoder.encoder',
+    side_module_name='mlp',
+    hidden_pos='last_hidden_state')
+model = Swift.prepare_model(model, side_config)
+# use model to do other things
+```
diff --git a/docs/Modules/6.prompt.md b/docs/Modules/6.prompt.md
new file mode 100644
index 0000000000..a9578911d5
--- /dev/null
+++ b/docs/Modules/6.prompt.md
@@ -0,0 +1,34 @@
+# Prompt
+
+Prompt是[Visual Prompt Tuning](https://arxiv.org/abs/2203.12119) 论文提供的轻量级训练组件。Prompt可以添加到每个layer的输入上,为hidden_state添加prompt embedding。
+
+>```python
+>PromptConfig (
+>    dim: int layer输入参数中hidden_state的维度
+>    target_modules: Union[str, List[str]] 可以是需要嵌入prompt的layer的正则表达式(字符串类型),如果是List,则匹配这些layers名称的末尾
+>    embedding_pos: Union[str, int] layer输入参数中hidden_state的位置,如果是tuple/list则是int类型,如果是dict则是str类型
+>    attention_mask_pos: Union[str, int] layer输入参数中attention_mask的位置,如果是tuple/list则是int类型,如果是dict则是str类型
+>    attention_mask_value: Union[float, int, bool] prompt部分的attention值,默认为0.0
+>    prompt_length: int prompt的长度
+>    attach_front: bool prompt和hidden_state组合的方式,True代表将prompt concat到hidden_state的前面,反之则concat到后面
+>    extract_embedding: bool 是否在最后的layer结束后将hidden_state中的prompt部分移除
+>)
+>```
+
+一个使用Prompt的例子如下:
+
+```python
+from modelscope import Model
+
+from swift import (PromptConfig, Swift)
+
+model = Model.from_pretrained(
+    'damo/nlp_structbert_sentence-similarity_chinese-base')
+prompt_config = PromptConfig(
+    dim=model.config.hidden_size,
+    target_modules=r'.*layer\.\d+$',
+    embedding_pos=0,
+    attention_mask_pos=1)
+model = Swift.prepare_model(model, config=prompt_config)
+# use model to do other things
+```
diff --git a/docs/Modules/7.peft.md b/docs/Modules/7.peft.md
new file mode 100644
index 0000000000..aadfa08023
--- /dev/null
+++ b/docs/Modules/7.peft.md
@@ -0,0 +1,38 @@
+# 对Peft的兼容性
+
+为了支持习惯使用Peft的用户,Swift提供了对Peft的兼容性。用户可以从swift中import以下peft组件:
+
+>PeftModel
+>PeftConfig
+>PeftModelForSeq2SeqLM
+>PeftModelForSequenceClassification
+>PeftModelForTokenClassification
+>PeftModelForCausalLM
+>PromptEncoderConfig
+>PromptTuningConfig
+>PrefixTuningConfig
+>PromptLearningConfig
+>LoraConfig
+>get_peft_config
+>get_peft_model_state_dict
+>get_peft_model
+
+以上组件均可以直接从swift中import:
+
+```python
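+# 从swift中import的这些Peft组件与peft库中的同名组件用法一致,Swift仅做了浅封装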
+from swift import PeftModel, PeftConfig +``` + +Swift类也支持初始化Peft的tuner: + +```python +from modelscope.models.nlp import SbertForSequenceClassification +from modelscope.models.nlp.structbert import SbertConfig + +from swift import LoraConfig, Swift +model = SbertForSequenceClassification(SbertConfig()) +lora_config = LoraConfig(target_modules=['query', 'key', 'value']) +model = Swift.prepare_model(model, lora_config) +``` + +Swift对Peft进行了浅封装,使Peft可以在from_pretrained时使用modelscope hub中的模型。 diff --git a/examples/pytorch/llm/scripts/baichuan2_7b_chat/lora_ddp/infer.sh b/examples/pytorch/llm/scripts/baichuan2_7b_chat/lora_ddp/infer.sh index ce54c3ffaa..6988d4a37d 100644 --- a/examples/pytorch/llm/scripts/baichuan2_7b_chat/lora_ddp/infer.sh +++ b/examples/pytorch/llm/scripts/baichuan2_7b_chat/lora_ddp/infer.sh @@ -8,7 +8,7 @@ python src/llm_infer.py \ --eval_human false \ --dataset damo-agent-mini-zh \ --max_length 4096 \ - --max_new_tokens 1024 \ + --max_new_tokens 2048 \ --temperature 0.9 \ --top_k 50 \ --top_p 0.9 \ diff --git a/examples/pytorch/llm/scripts/chatglm2_6b/lora_ddp/sft.sh b/examples/pytorch/llm/scripts/chatglm2_6b/lora_ddp/sft.sh index 06ae8c240a..f0eee33da8 100644 --- a/examples/pytorch/llm/scripts/chatglm2_6b/lora_ddp/sft.sh +++ b/examples/pytorch/llm/scripts/chatglm2_6b/lora_ddp/sft.sh @@ -1,5 +1,5 @@ -# Experimental environment: A100 -# 50GB GPU memory +# Experimental environment: 2 * A100 +# 2 * 50GB GPU memory nproc_per_node=2 CUDA_VISIBLE_DEVICES=0,1 \ torchrun \ diff --git a/examples/pytorch/llm/scripts/llama2_70b_chat/qlora/infer.sh b/examples/pytorch/llm/scripts/llama2_70b_chat/qlora/infer.sh index b47ece0d8c..46ad5c849f 100644 --- a/examples/pytorch/llm/scripts/llama2_70b_chat/qlora/infer.sh +++ b/examples/pytorch/llm/scripts/llama2_70b_chat/qlora/infer.sh @@ -1,6 +1,6 @@ CUDA_VISIBLE_DEVICES=0,1 \ python src/llm_infer.py \ - --model_type llama2-7b-chat \ + --model_type llama2-70b-chat \ --sft_type lora \ --ckpt_dir "runs/llama2-70b-chat/vx_xxx/checkpoint-xxx" \ --eval_human true \ diff --git a/examples/pytorch/llm/scripts/qwen_7b_chat/full_mp/infer.sh b/examples/pytorch/llm/scripts/qwen_7b_chat/full_mp/infer.sh index 5d280cf86f..d28d312972 100644 --- a/examples/pytorch/llm/scripts/qwen_7b_chat/full_mp/infer.sh +++ b/examples/pytorch/llm/scripts/qwen_7b_chat/full_mp/infer.sh @@ -7,9 +7,9 @@ python src/llm_infer.py \ --ckpt_dir "runs/qwen-7b-chat/vx_xxx/checkpoint-xxx" \ --eval_human false \ --dataset damo-agent-zh \ - --max_length 8192 \ + --max_length 6144 \ --use_flash_attn true \ - --max_new_tokens 1024 \ + --max_new_tokens 2048 \ --temperature 0.9 \ --top_k 50 \ --top_p 0.9 \ diff --git a/examples/pytorch/llm/scripts/qwen_7b_chat/full_mp_ddp/infer.sh b/examples/pytorch/llm/scripts/qwen_7b_chat/full_mp_ddp/infer.sh index d02ca2471f..41ba146157 100644 --- a/examples/pytorch/llm/scripts/qwen_7b_chat/full_mp_ddp/infer.sh +++ b/examples/pytorch/llm/scripts/qwen_7b_chat/full_mp_ddp/infer.sh @@ -7,9 +7,9 @@ python src/llm_infer.py \ --ckpt_dir "runs/qwen-7b-chat/vx_xxx/checkpoint-xxx" \ --eval_human false \ --dataset medical-en,medical-zh \ - --max_length 8192 \ + --max_length 6144 \ --use_flash_attn true \ - --max_new_tokens 1024 \ + --max_new_tokens 2048 \ --temperature 0.9 \ --top_k 50 \ --top_p 0.9 \ diff --git a/examples/pytorch/llm/src/llm_infer.py b/examples/pytorch/llm/src/llm_infer.py index 7f852f0056..bdd8b3152c 100644 --- a/examples/pytorch/llm/src/llm_infer.py +++ b/examples/pytorch/llm/src/llm_infer.py @@ -42,8 +42,7 @@ class InferArguments: system: str = 'you 
are a helpful assistant!' max_length: Optional[int] = 2048 - quantization_bit: Optional[int] = field( - default=None, metadata={'choices': {4, 8}}) + quantization_bit: int = field(default=0, metadata={'choices': {0, 4, 8}}) bnb_4bit_comp_dtype: str = field( default=None, metadata={'choices': {'fp16', 'bf16', 'fp32'}}) bnb_4bit_quant_type: str = field( @@ -110,7 +109,8 @@ def llm_infer(args: InferArguments) -> None: # ### Preparing lora if args.sft_type == 'lora': - model = Swift.from_pretrained(model, args.ckpt_dir) + model = Swift.from_pretrained( + model, args.ckpt_dir, inference_mode=True) show_layers(model) print_model_info(model) diff --git a/examples/pytorch/llm/src/llm_sft.py b/examples/pytorch/llm/src/llm_sft.py index 2214f1d6ff..5d484e423b 100644 --- a/examples/pytorch/llm/src/llm_sft.py +++ b/examples/pytorch/llm/src/llm_sft.py @@ -10,16 +10,17 @@ import numpy as np import torch import torch.distributed as dist -from transformers import BitsAndBytesConfig +from transformers import BitsAndBytesConfig, GenerationConfig from utils import (DATASET_MAPPING, MODEL_MAPPING, TEMPLATE_MAPPING, - broadcast_string, check_json_format, dataset_map, - find_all_linear_for_lora, get_dataset, get_dist_setting, - get_model_tokenizer, get_preprocess, is_ddp_plus_mp, - is_dist, is_master, plot_images, select_bnb, select_dtype, - show_layers, sort_by_max_length) - -from swift import (HubStrategy, LoraConfig, Seq2SeqTrainer, - Seq2SeqTrainingArguments, Swift, get_logger) + broadcast_string, check_json_format, compute_nlg_metrics, + dataset_map, find_all_linear_for_lora, get_dataset, + get_dist_setting, get_model_tokenizer, get_preprocess, + is_ddp_plus_mp, is_dist, is_master, plot_images, + prepare_model, select_bnb, select_dtype, show_layers, + sort_by_max_length) + +from swift import (HubStrategy, Seq2SeqTrainer, Seq2SeqTrainingArguments, + Swift, get_logger) from swift.hub import HubApi, ModelScopeConfig from swift.utils import (add_version_to_work_dir, parse_args, print_model_info, seed_everything) @@ -73,6 +74,7 @@ class SftArguments: gradient_checkpointing: bool = False batch_size: int = 1 + eval_batch_size: Optional[int] = None num_train_epochs: int = 1 # if max_steps >= 0, override num_train_epochs max_steps: int = -1 @@ -81,6 +83,7 @@ class SftArguments: weight_decay: float = 0.01 gradient_accumulation_steps: int = 16 max_grad_norm: float = 1. 
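+    # evaluate with `model.generate()` and NLG metrics (rouge-l / bleu-4) instead of loss-only evaluation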
+ predict_with_generate: bool = False lr_scheduler_type: str = 'cosine' warmup_ratio: float = 0.05 @@ -119,6 +122,13 @@ class SftArguments: "This parameter is used only when model_type.startswith('qwen')" }) + # generation config, only useful when `predict_with_generate=True` + max_new_tokens: int = 1024 + do_sample: bool = True + temperature: float = 0.9 + top_k: int = 50 + top_p: float = 0.9 + def __post_init__(self): if is_dist(): rank, local_rank, _, _ = get_dist_setting() @@ -181,6 +191,11 @@ def __post_init__(self): if self.use_flash_attn is None: self.use_flash_attn = 'auto' self.train_sampler_random = not self.test_oom_error + if self.eval_batch_size is None: + if self.predict_with_generate: + self.eval_batch_size = 1 + else: + self.eval_batch_size = self.batch_size def llm_sft(args: SftArguments) -> None: @@ -211,37 +226,30 @@ def llm_sft(args: SftArguments) -> None: model, tokenizer = get_model_tokenizer( args.model_type, torch_dtype=args.torch_dtype, **kwargs) - # ### Preparing lora - if args.sft_type == 'lora': - if 'ALL' in args.lora_target_modules: - assert len(args.lora_target_modules) == 1 - args.lora_target_modules = find_all_linear_for_lora( - model, args.quantization_bit, args.model_type) - logger.info( - f'Setting lora_target_modules: {args.lora_target_modules}') - if args.resume_from_ckpt is None: - lora_config = LoraConfig( - r=args.lora_rank, - target_modules=args.lora_target_modules, - lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout_p, - task_type='CAUSAL_LM') - logger.info(f'lora_config: {lora_config}') - model = Swift.prepare_model(model, lora_config) - else: - model = Swift.from_pretrained( - model, args.resume_from_ckpt, is_trainable=True) + if args.resume_from_ckpt is None: + if args.sft_type != 'full': + # lora + model = prepare_model(model, args) + else: + model = Swift.from_pretrained( + model, args.resume_from_ckpt, is_trainable=True) show_layers(model) print_model_info(model) logger.info(model) # ### Loading Dataset + generation_config = GenerationConfig( + do_sample=args.do_sample, + max_length=None, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + ) train_dataset, val_dataset = get_dataset( args.dataset.split(','), args.dataset_test_ratio, args.dataset_split_seed) - preprocess_func = get_preprocess(args.template_type, tokenizer, - args.system, args.max_length) if args.train_dataset_sample >= 0: val_dataset_sample = int(args.train_dataset_sample * args.dataset_test_ratio) @@ -252,8 +260,20 @@ def llm_sft(args: SftArguments) -> None: val_dataset = val_dataset.select(val_idxs) logger.info(f'train_dataset: {train_dataset}') logger.info(f'val_dataset: {val_dataset}') - train_dataset = dataset_map(train_dataset, preprocess_func) - val_dataset = dataset_map(val_dataset, preprocess_func) + preprocess_func_train = get_preprocess( + args.template_type, + tokenizer, + args.system, + args.max_length, + validate_generation=False) + preprocess_func_eval = get_preprocess( + args.template_type, + tokenizer, + args.system, + args.max_length, + validate_generation=args.predict_with_generate) + train_dataset = dataset_map(train_dataset, preprocess_func_train) + val_dataset = dataset_map(val_dataset, preprocess_func_eval) if args.test_oom_error: train_dataset = sort_by_max_length(train_dataset, 20000) # Data analysis @@ -287,7 +307,7 @@ def llm_sft(args: SftArguments) -> None: do_eval=True, evaluation_strategy='steps', per_device_train_batch_size=args.batch_size, - 
per_device_eval_batch_size=args.batch_size, + per_device_eval_batch_size=args.eval_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, learning_rate=args.learning_rate, weight_decay=args.weight_decay, @@ -305,8 +325,9 @@ def llm_sft(args: SftArguments) -> None: eval_steps=args.eval_steps, dataloader_num_workers=args.dataloader_num_workers, load_best_model_at_end=True, - metric_for_best_model='loss', - greater_is_better=False, + metric_for_best_model='rouge-l' + if args.predict_with_generate else 'loss', + greater_is_better=args.predict_with_generate, sortish_sampler=True, optim=args.optim, hub_model_id=args.hub_model_id, @@ -317,11 +338,12 @@ def llm_sft(args: SftArguments) -> None: resume_from_checkpoint=args.resume_from_ckpt, ddp_backend=args.ddp_backend, gradient_checkpointing=args.gradient_checkpointing, + predict_with_generate=args.predict_with_generate, + generation_config=generation_config, local_rank=local_rank, **kwargs) if args.gradient_checkpointing: - model.config.use_cache = False model.enable_input_require_grads() if is_dist(): # Compatible with https://github.com/huggingface/transformers/pull/25903 @@ -342,6 +364,8 @@ def llm_sft(args: SftArguments) -> None: train_dataset=train_dataset, eval_dataset=val_dataset, tokenizer=tokenizer, + compute_metrics=partial(compute_nlg_metrics, tokenizer=tokenizer) + if args.predict_with_generate else None, ) if is_master(): for args_obj, fname in zip([args, training_args], @@ -354,6 +378,7 @@ def llm_sft(args: SftArguments) -> None: ensure_ascii=False, indent=2) trainer.train(training_args.resume_from_checkpoint) + logger.info(trainer.perf) # ### Visualization if is_master(): diff --git a/examples/pytorch/llm/src/utils/__init__.py b/examples/pytorch/llm/src/utils/__init__.py index 10ace8ba3c..341293902d 100644 --- a/examples/pytorch/llm/src/utils/__init__.py +++ b/examples/pytorch/llm/src/utils/__init__.py @@ -1,6 +1,8 @@ from .dataset import DATASET_MAPPING, get_dataset +from .metric_utils import compute_nlg_metrics from .model import MODEL_MAPPING, get_model_tokenizer from .preprocess import TEMPLATE_MAPPING, get_preprocess +from .swift_utils import prepare_model from .utils import (broadcast_string, check_json_format, dataset_map, download_dataset, find_all_linear_for_lora, get_dist_setting, inference, is_ddp_plus_mp, is_dist, diff --git a/examples/pytorch/llm/src/utils/metric_utils.py b/examples/pytorch/llm/src/utils/metric_utils.py new file mode 100644 index 0000000000..2e8df7d53d --- /dev/null +++ b/examples/pytorch/llm/src/utils/metric_utils.py @@ -0,0 +1,54 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
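+# Generation-based (NLG) metrics: rouge-1/2/l and bleu-4 over decoded predictions;
+# used as `compute_metrics` in llm_sft.py when `predict_with_generate=True`.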
+ +import numpy as np + +from swift import get_logger + +logger = get_logger() + + +def compute_nlg_metrics(prediction, tokenizer): + import jieba + from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu + from rouge.rouge import Rouge + preds, labels = prediction[0], prediction[1] + + score_dict = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []} + + def _decode(tokens, ignore_pad_token_for_loss=False): + if ignore_pad_token_for_loss: + tokens = np.where(tokens != -100, tokens, tokenizer.pad_token_id) + tokens = np.where(tokens < tokenizer.vocab_size, tokens, + tokenizer.pad_token_id) + return [ + t + for t in tokenizer.batch_decode(tokens, skip_special_tokens=True) + ] + + for pred, label in zip(preds, labels): + pred = ''.join(_decode(pred, False)) + label = ''.join(_decode(label, True)) + hypothesis = list(jieba.cut(pred)) + if len(hypothesis) == 0 or ''.join(hypothesis) == '.': + hypothesis = [tokenizer.decode(tokenizer.eos_token_id)] + reference = list(jieba.cut(label)) + try: + rouge = Rouge() + scores = rouge.get_scores(' '.join(hypothesis), + ' '.join(reference)) + result = scores[0] + + for k, v in result.items(): + score_dict[k].append(round(v['f'] * 100, 4)) + bleu_score = sentence_bleu( + [list(label)], + list(pred), + smoothing_function=SmoothingFunction().method3) + score_dict['bleu-4'].append(round(bleu_score * 100, 4)) + except Exception as e: + logger.error(e) + logger.error(f'eval error {hypothesis}, {reference}') + + for k, v in score_dict.items(): + score_dict[k] = float(np.mean(v)) + return score_dict diff --git a/examples/pytorch/llm/src/utils/model.py b/examples/pytorch/llm/src/utils/model.py index 3c02aedc5c..0a8cbcbb8c 100644 --- a/examples/pytorch/llm/src/utils/model.py +++ b/examples/pytorch/llm/src/utils/model.py @@ -175,6 +175,54 @@ class LoRATM(NamedTuple): internlm = ['q_proj', 'k_proj', 'v_proj'] +class AdapterTM(NamedTuple): + # default adapter target modules. + baichuan = ['mlp'] + chatglm2 = ['mlp'] + llama2 = ['mlp'] + qwen = ['mlp'] + polylm = ['mlp'] + + +class ResTunerTM(NamedTuple): + # default res-tuning config. 
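+    # Keys mirror ResTuningConfig: `root_modules`/`stem_modules` are regexes matching backbone layers,
+    # and `target_modules` with `target_modules_hook='input'` takes the hidden_state from the final norm's input.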
+ baichuan = { + 'root_modules': r'.*layers.0$', + 'stem_modules': r'.*layers\.\d+$', + 'target_modules': r'.*model.norm', + 'target_modules_hook': 'input', + 'tuner_cfg': 'res_adapter', + } + chatglm2 = { + 'root_modules': r'.*layers.0$', + 'stem_modules': r'.*layers\.\d+$', + 'target_modules': r'.*final_layernorm', + 'target_modules_hook': 'input', + 'tuner_cfg': 'res_adapter', + } + llama2 = { + 'root_modules': r'.*layers.0$', + 'stem_modules': r'.*layers\.\d+$', + 'target_modules': r'.*model.norm', + 'target_modules_hook': 'input', + 'tuner_cfg': 'res_adapter', + } + qwen = { + 'root_modules': r'.*transformer.h.0$', + 'stem_modules': r'.*transformer.h\.\d+$', + 'target_modules': r'.*transformer.ln_f', + 'target_modules_hook': 'input', + 'tuner_cfg': 'res_adapter', + } + polylm = { + 'root_modules': r'.*transformer.h.0$', + 'stem_modules': r'.*transformer.h\.\d+$', + 'target_modules': r'.*transformer.ln_f', + 'target_modules_hook': 'input', + 'tuner_cfg': 'res_adapter', + } + + # Model Home: 'https://modelscope.cn/models/{model_id}/summary' # keys: 'model_id', 'revision', 'get_function', 'template', # 'ignore_file_pattern', 'lora_TM' @@ -184,6 +232,8 @@ class LoRATM(NamedTuple): 'revision': 'v1.0.0', 'get_function': get_model_tokenizer_qwen, 'lora_TM': LoRATM.qwen, + 'adapter_TM': AdapterTM.qwen, + 'restuner_TM': ResTunerTM.qwen, }, 'qwen-7b-chat': { 'model_id': 'ccyh123/Qwen-7B-Chat', @@ -191,12 +241,16 @@ class LoRATM(NamedTuple): 'get_function': get_model_tokenizer_qwen, 'template': 'chatml', 'lora_TM': LoRATM.qwen, + 'adapter_TM': AdapterTM.qwen, + 'restuner_TM': ResTunerTM.qwen, }, 'qwen-vl': { 'model_id': 'ccyh123/Qwen-VL', 'revision': 'v1.0.0', 'get_function': get_model_tokenizer_qwen_vl, 'lora_TM': LoRATM.qwen, + 'adapter_TM': AdapterTM.qwen, + 'restuner_TM': ResTunerTM.qwen, }, 'qwen-vl-chat': { 'model_id': 'ccyh123/Qwen-VL-Chat', @@ -204,23 +258,31 @@ class LoRATM(NamedTuple): 'get_function': get_model_tokenizer_qwen_vl, 'template': 'chatml', 'lora_TM': LoRATM.qwen, + 'adapter_TM': AdapterTM.qwen, + 'restuner_TM': ResTunerTM.qwen, }, 'baichuan-7b': { 'model_id': 'baichuan-inc/baichuan-7B', 'revision': 'v1.0.7', 'lora_TM': LoRATM.baichuan, + 'adapter_TM': AdapterTM.baichuan, + 'restuner_TM': ResTunerTM.baichuan, }, 'baichuan-13b': { 'model_id': 'baichuan-inc/Baichuan-13B-Base', 'revision': 'v1.0.5', 'get_function': get_model_tokenizer_baichuan13b, 'lora_TM': LoRATM.baichuan, + 'adapter_TM': AdapterTM.baichuan, + 'restuner_TM': ResTunerTM.baichuan, }, 'baichuan-13b-chat': { 'model_id': 'baichuan-inc/Baichuan-13B-Chat', 'revision': 'v1.0.8', 'template': 'baichuan', 'lora_TM': LoRATM.baichuan, + 'adapter_TM': AdapterTM.baichuan, + 'restuner_TM': ResTunerTM.baichuan, }, 'chatglm2-6b': { 'model_id': 'ZhipuAI/chatglm2-6b', @@ -228,18 +290,24 @@ class LoRATM(NamedTuple): 'get_function': get_model_tokenizer_chatglm2, 'template': 'chatglm2', 'lora_TM': LoRATM.chatglm2, + 'adapter_TM': AdapterTM.chatglm2, + 'restuner_TM': ResTunerTM.chatglm2, }, 'chatglm2-6b-32k': { 'model_id': 'ZhipuAI/chatglm2-6b-32k', 'revision': 'v1.0.1', 'template': 'chatglm2', 'lora_TM': LoRATM.chatglm2, + 'adapter_TM': AdapterTM.chatglm2, + 'restuner_TM': ResTunerTM.chatglm2, }, 'llama2-7b': { 'model_id': 'modelscope/Llama-2-7b-ms', 'revision': 'v1.0.2', 'ignore_file_pattern': [r'.+\.bin$'], # use safetensors 'lora_TM': LoRATM.llama2, + 'adapter_TM': AdapterTM.llama2, + 'restuner_TM': ResTunerTM.llama2, }, 'llama2-13b': { 'model_id': 'modelscope/Llama-2-13b-ms', @@ -247,12 +315,16 @@ class LoRATM(NamedTuple): 
'get_function': get_model_tokenizer_llama2, 'ignore_file_pattern': [r'.+\.bin$'], 'lora_TM': LoRATM.llama2, + 'adapter_TM': AdapterTM.llama2, + 'restuner_TM': ResTunerTM.llama2, }, 'llama2-70b': { 'model_id': 'modelscope/Llama-2-70b-ms', 'revision': 'v1.0.0', 'ignore_file_pattern': [r'.+\.bin$'], 'lora_TM': LoRATM.llama2, + 'adapter_TM': AdapterTM.llama2, + 'restuner_TM': ResTunerTM.llama2, }, 'llama2-7b-chat': { 'model_id': 'modelscope/Llama-2-7b-chat-ms', @@ -260,6 +332,8 @@ class LoRATM(NamedTuple): 'template': 'llama', 'ignore_file_pattern': [r'.+\.bin$'], # use safetensors 'lora_TM': LoRATM.llama2, + 'adapter_TM': AdapterTM.llama2, + 'restuner_TM': ResTunerTM.llama2, }, 'llama2-13b-chat': { 'model_id': 'modelscope/Llama-2-13b-chat-ms', @@ -268,6 +342,8 @@ class LoRATM(NamedTuple): 'template': 'llama', 'ignore_file_pattern': [r'.+\.bin$'], 'lora_TM': LoRATM.llama2, + 'adapter_TM': AdapterTM.llama2, + 'restuner_TM': ResTunerTM.llama2, }, 'llama2-70b-chat': { 'model_id': 'modelscope/Llama-2-70b-chat-ms', @@ -276,12 +352,16 @@ class LoRATM(NamedTuple): 'template': 'llama', 'ignore_file_pattern': [r'.+\.bin$'], 'lora_TM': LoRATM.llama2, + 'adapter_TM': AdapterTM.llama2, + 'restuner_TM': ResTunerTM.llama2, }, 'openbuddy-llama2-13b': { 'model_id': 'OpenBuddy/openbuddy-llama2-13b-v8.1-fp16', 'revision': 'v1.0.0', 'template': 'openbuddy-llama', 'lora_TM': LoRATM.llama2, + 'adapter_TM': AdapterTM.llama2, + 'restuner_TM': ResTunerTM.llama2, }, 'openbuddy-llama-65b': { 'model_id': 'OpenBuddy/openbuddy-llama-65b-v8-bf16', @@ -294,12 +374,16 @@ class LoRATM(NamedTuple): 'revision': 'v1.0.0', 'template': 'openbuddy-llama', 'lora_TM': LoRATM.llama2, + 'adapter_TM': AdapterTM.llama2, + 'restuner_TM': ResTunerTM.llama2, }, 'polylm-13b': { 'model_id': 'damo/nlp_polylm_13b_text_generation', 'revision': 'v1.0.3', 'get_function': get_model_tokenizer_polylm, 'lora_TM': LoRATM.polylm, + 'adapter_TM': AdapterTM.polylm, + 'restuner_TM': ResTunerTM.polylm, }, 'baichuan2-7b': { 'model_id': 'baichuan-inc/Baichuan2-7B-Base', diff --git a/examples/pytorch/llm/src/utils/preprocess.py b/examples/pytorch/llm/src/utils/preprocess.py index f3a08a7ab5..2115c83aa4 100644 --- a/examples/pytorch/llm/src/utils/preprocess.py +++ b/examples/pytorch/llm/src/utils/preprocess.py @@ -37,6 +37,11 @@ 'chat_sep': ['\n\n'], 'suffix': [['eos_token_id']], }, + 'chatglm2-generation': { + 'prefix': [[64790, 64792]], + 'prompt': ['{{query}}'], + 'suffix': [['eos_token_id']], + }, 'llama': { 'prefix': [['bos_token_id'], '[INST] <>\n{{SYSTEM}}\n<>\n\n'], @@ -119,13 +124,15 @@ def _encode(tokenizer: PreTrainedTokenizer, context_list: List[Context], def _preprocess( - template_type: str, - tokenizer: PreTrainedTokenizer, - query: str, - response: Optional[str] = None, - history: Optional[History] = None, - system: Optional[str] = None, - max_length: Optional[int] = None, + template_type: str, + tokenizer: PreTrainedTokenizer, + query: str, + response: Optional[str] = None, + history: Optional[History] = None, + system: Optional[str] = None, + max_length: Optional[int] = None, + validate_generation: Optional[ + bool] = True, # do cross-validation with `model.generate()` ) -> Dict[str, List[int]]: if history is None: history = [] @@ -159,11 +166,15 @@ def _preprocess( labels = None if response is not None: - labels = [-100] * len(input_ids) tgt_input_ids = _encode(tokenizer, [response], []) tgt_input_ids += _encode(tokenizer, template_config['suffix'], []) - input_ids += tgt_input_ids - labels += tgt_input_ids + if not validate_generation: + # 
train, or validate with `loss` + labels = [-100] * len(input_ids) + tgt_input_ids + input_ids += tgt_input_ids + else: + # validate with `model.generate()` + labels = tgt_input_ids if max_length is not None: input_ids = input_ids[-max_length:] @@ -178,6 +189,7 @@ def get_preprocess( tokenizer: PreTrainedTokenizer, system: Optional[str] = None, max_length: Optional[int] = None, + validate_generation: Optional[bool] = False, ) -> Callable[[Dict[str, Any]], Dict[str, List[int]]]: def preprocess(example: Dict[str, Any]) -> Dict[str, List[int]]: @@ -186,6 +198,6 @@ def preprocess(example: Dict[str, Any]) -> Dict[str, List[int]]: response: str = example.get('response', None) custom_system = example.get('system', system) return _preprocess(template_type, tokenizer, query, response, history, - custom_system, max_length) + custom_system, max_length, validate_generation) return preprocess diff --git a/examples/pytorch/llm/src/utils/swift_utils.py b/examples/pytorch/llm/src/utils/swift_utils.py new file mode 100644 index 0000000000..ee8ef3b489 --- /dev/null +++ b/examples/pytorch/llm/src/utils/swift_utils.py @@ -0,0 +1,49 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +from torch.nn import Module + +from swift import (AdapterConfig, LoRAConfig, ResTuningConfig, Swift, + SwiftConfig, SwiftTuners, get_logger) +from .model import MODEL_MAPPING +from .utils import find_all_linear_for_lora + +logger = get_logger() + + +def prepare_model(model: Module, args) -> Module: + swift_config: Dict[str, SwiftConfig] = dict() + for sft_type in [_type.strip() for _type in args.sft_type.split(',')]: + if sft_type.lower() == SwiftTuners.LORA.lower(): + if 'ALL' in args.lora_target_modules: + assert len(args.lora_target_modules) == 1 + args.lora_target_modules = find_all_linear_for_lora( + model, args.quantization_bit, args.model_type) + logger.info( + f'Setting lora_target_modules: {args.lora_target_modules}') + + lora_config = LoRAConfig( + r=args.lora_rank, + target_modules=args.lora_target_modules, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout_p) + logger.debug(f'lora_config: {lora_config}') + swift_config['lora'] = lora_config + elif sft_type.lower() == SwiftTuners.ADAPTER.lower(): + adapter_config = AdapterConfig( + dim=model.config.hidden_size, + target_modules=MODEL_MAPPING[args.model_type].get( + 'adapter_TM', ['mlp']), + method_name='forward', + hidden_pos=0, + adapter_length=args.adapter_length, + ) + logger.debug(f'adapter_config: {adapter_config}') + swift_config['adapter'] = adapter_config + elif sft_type.lower() == SwiftTuners.RESTUNING.lower(): + restuner_config = ResTuningConfig( + dims=model.config.hidden_size, + **MODEL_MAPPING[args.model_type]['restuner_TM']) + logger.debug(f'restuner_config: {restuner_config}') + swift_config['restuner'] = restuner_config + return Swift.prepare_model(model, swift_config) diff --git a/requirements/framework.txt b/requirements/framework.txt index 4247a138db..c4ecc554c0 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -3,7 +3,7 @@ datasets diffusers>=0.18.0 numpy pandas -peft +peft>=0.5.0 requests safetensors tensorboard diff --git a/swift/__init__.py b/swift/__init__.py index d4ab2b8c64..9049f2e70d 100644 --- a/swift/__init__.py +++ b/swift/__init__.py @@ -8,11 +8,12 @@ from .tuners import ( Adapter, AdapterConfig, AdapterModule, SwiftModel, LoRA, LoRAConfig, SWIFT_MAPPING, LoraConfig, PeftConfig, PeftModel, PeftModelForCausalLM, - PeftModelForSeq2SeqLM, 
PeftModelForSequenceClassification, - PeftModelForTokenClassification, PrefixTuningConfig, - PromptEncoderConfig, PromptLearningConfig, PromptTuningConfig, - get_peft_config, get_peft_model, get_peft_model_state_dict, Prompt, - PromptConfig, PromptModule, SwiftConfig, SwiftOutput, Swift) + ResTuningConfig, SideConfig, PeftModelForSeq2SeqLM, + PeftModelForSequenceClassification, PeftModelForTokenClassification, + PrefixTuningConfig, PromptEncoderConfig, PromptLearningConfig, + PromptTuningConfig, get_peft_config, get_peft_model, + get_peft_model_state_dict, Prompt, PromptConfig, PromptModule, + SwiftConfig, SwiftOutput, Swift, SwiftTuners) from .hub import snapshot_download, push_to_hub, push_to_hub_async, push_to_hub_in_queue from .trainers import (EvaluationStrategy, FSDPOption, HPSearchBackend, HubStrategy, IntervalStrategy, SchedulerType, @@ -29,13 +30,15 @@ 'tuners': [ 'Adapter', 'AdapterConfig', 'AdapterModule', 'SwiftModel', 'LoRA', 'LoRAConfig', 'SWIFT_MAPPING', 'LoraConfig', 'PeftConfig', - 'PeftModel', 'PeftModelForCausalLM', 'PeftModelForSeq2SeqLM', + 'ResTuningConfig', 'SideConfig', 'PeftModel', + 'PeftModelForCausalLM', 'PeftModelForSeq2SeqLM', 'PeftModelForSequenceClassification', 'PeftModelForTokenClassification', 'PrefixTuningConfig', 'PromptEncoderConfig', 'PromptLearningConfig', 'PromptTuningConfig', 'get_peft_config', 'get_peft_model', 'get_peft_model_state_dict', 'Prompt', 'PromptConfig', - 'PromptModule', 'SwiftConfig', 'SwiftOutput', 'Swift' + 'PromptModule', 'SwiftConfig', 'SwiftOutput', 'Swift', + 'SwiftTuners' ], 'trainers': [ 'EvaluationStrategy', 'FSDPOption', 'HPSearchBackend', diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py index 15e1445aa9..c51eae8841 100644 --- a/swift/trainers/trainers.py +++ b/swift/trainers/trainers.py @@ -1,5 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+import time +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn from transformers import Seq2SeqTrainer as HfSeq2SeqTrainer from transformers import Trainer as HfTrainer from transformers import trainer @@ -7,13 +12,143 @@ from .callback import DefaultFlowCallbackNew, ProgressCallbackNew from .mixin import PushToMsHubMixin, SwiftMixin +try: + from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +except ImportError: + from transformers.deepspeed import is_deepspeed_zero3_enabled + class Trainer(PushToMsHubMixin, SwiftMixin, HfTrainer): pass class Seq2SeqTrainer(PushToMsHubMixin, SwiftMixin, HfSeq2SeqTrainer): - pass + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # performance + self.perf: Dict[str, Any] = { + 'gen_time': + 0., + 'gen_len': + 0, + 'memory': {}, + 'train_time': + 0., + 'model': + self.model.get_trainable_parameters() if hasattr( + self.model, 'get_trainable_parameters') else None, + } + self._iter_perf = 0 + + def training_step(self, *args, **kwargs) -> torch.Tensor: + train_time = time.time() + training_output = super().training_step(*args, **kwargs) + train_time = time.time() - train_time + self.perf['train_time'] = self.perf['train_time'] + train_time + self._iter_perf += 1 + if self._iter_perf > 20 and not self.perf[ + 'memory'] and torch.cuda.device_count() > 0: + for i in range(torch.cuda.device_count()): + self.perf['memory'][ + f'device:{i}'] = torch.cuda.memory_reserved(i) + return training_output + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + **gen_kwargs, + ) -> Tuple[Optional[float], Optional[torch.Tensor], + Optional[torch.Tensor]]: + if not self.args.predict_with_generate or prediction_loss_only: + return super().prediction_step( + model, + inputs, + prediction_loss_only=prediction_loss_only, + ignore_keys=ignore_keys) + + has_labels = 'labels' in inputs + inputs = self._prepare_inputs(inputs) + + # XXX: adapt synced_gpus for fairscale as well + # Priority (handled in generate): + # gen_kwargs > model.generation_config > default GenerationConfig() + + if len(gen_kwargs) == 0 and hasattr(self, '_gen_kwargs'): + gen_kwargs = self._gen_kwargs.copy() + if hasattr(self.model, 'generation_config'): + gen_kwargs.update(self.model.generation_config.to_dict()) + + if gen_kwargs.get('max_length') is None and gen_kwargs.get( + 'max_new_tokens') is None: + gen_kwargs['max_length'] = self.model.config.max_length + gen_kwargs['num_beams'] = ( + gen_kwargs['num_beams'] if gen_kwargs.get('num_beams') is not None + else self.model.config.num_beams) + default_synced_gpus = True if is_deepspeed_zero3_enabled() else False + gen_kwargs['synced_gpus'] = ( + gen_kwargs['synced_gpus'] if gen_kwargs.get('synced_gpus') + is not None else default_synced_gpus) + + # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate + # (otherwise, it would continue generating from the padded `decoder_input_ids`) + if ('labels' in inputs and 'decoder_input_ids' in inputs and + inputs['labels'].shape == inputs['decoder_input_ids'].shape): + inputs = { + k: v + for k, v in inputs.items() if k != 'decoder_input_ids' + } + + gen_kwargs['pad_token_id'] = self.tokenizer.pad_token_id + gen_kwargs['eos_token_id'] = self.tokenizer.eos_token_id + gen_time = time.time() + generated_tokens = self.model.generate(**inputs, **gen_kwargs) 
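+        # wall-clock time spent in generate() for this eval batch; accumulated into self.perf['gen_time'] below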
+ gen_time = time.time() - gen_time + + if hasattr( + self.model, 'encoder' + ) and self.model.encoder.main_input_name != self.model.main_input_name: + generation_inputs = inputs[self.model.encoder.main_input_name] + else: + generation_inputs = inputs[self.model.main_input_name] + + generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:] + gen_len = len(generated_tokens[0]) + self.perf['gen_time'] = self.perf['gen_time'] + gen_time + self.perf['gen_len'] = self.perf['gen_len'] + gen_len + + # in case the batch is shorter than max length, the output should be padded + if gen_kwargs.get('max_length') is not None and generated_tokens.shape[ + -1] < gen_kwargs['max_length']: + generated_tokens = self._pad_tensors_to_max_len( + generated_tokens, gen_kwargs['max_length']) + elif gen_kwargs.get('max_new_tokens' + ) is not None and generated_tokens.shape[-1] < ( + gen_kwargs['max_new_tokens'] + 1): + generated_tokens = self._pad_tensors_to_max_len( + generated_tokens, gen_kwargs['max_new_tokens'] + 1) + + if self.args.prediction_loss_only: + return None, None, None + + if has_labels: + labels = inputs['labels'] + if gen_kwargs.get('max_length') is not None and labels.shape[ + -1] < gen_kwargs['max_length']: + labels = self._pad_tensors_to_max_len(labels, + gen_kwargs['max_length']) + elif gen_kwargs.get( + 'max_new_tokens') is not None and labels.shape[-1] < ( + gen_kwargs['max_new_tokens'] + 1): + labels = self._pad_tensors_to_max_len( + labels, (gen_kwargs['max_new_tokens'] + 1)) + else: + labels = None + + return None, generated_tokens, labels # monkey patching diff --git a/swift/tuners/__init__.py b/swift/tuners/__init__.py index bed8803d70..1ecc496850 100644 --- a/swift/tuners/__init__.py +++ b/swift/tuners/__init__.py @@ -7,7 +7,9 @@ from .adapter import Adapter, AdapterConfig, AdapterModule from .base import SwiftModel, Swift from .lora import LoRA, LoRAConfig - from .mapping import SWIFT_MAPPING + from .mapping import SWIFT_MAPPING, SwiftTuners + from .side import Side, SideConfig, SideModule + from .restuning import ResTuning, ResTuningConfig, ResTuningBypassModule from .peft import (LoraConfig, PeftConfig, PeftModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM, PeftModelForSequenceClassification, @@ -22,7 +24,9 @@ 'adapter': ['Adapter', 'AdapterConfig', 'AdapterModule'], 'base': ['SwiftModel', 'Swift'], 'lora': ['LoRA', 'LoRAConfig'], - 'mapping': ['SWIFT_MAPPING'], + 'mapping': ['SWIFT_MAPPING', 'SwiftTuners'], + 'side': ['Side', 'SideConfig', 'SideModule'], + 'restuning': ['ResTuning', 'ResTuningConfig', 'ResTuningBypassModule'], 'peft': [ 'LoraConfig', 'PeftConfig', 'PeftModel', 'PeftModelForCausalLM', 'PeftModelForSeq2SeqLM', 'PeftModelForSequenceClassification', diff --git a/swift/tuners/adapter.py b/swift/tuners/adapter.py index 19233e60eb..98f829525a 100644 --- a/swift/tuners/adapter.py +++ b/swift/tuners/adapter.py @@ -3,13 +3,17 @@ import re import types from dataclasses import dataclass, field -from typing import Union +from typing import List, Union import torch from torch import nn from transformers.activations import ACT2CLS -from .utils import SwiftConfig, SwiftOutput +from swift import get_logger +from swift.utils.torch_utils import find_sub_module +from .utils import ActivationMixin, SwiftConfig, SwiftOutput + +logger = get_logger() @dataclass @@ -22,10 +26,12 @@ class AdapterConfig(SwiftConfig): See http://arxiv.org/abs/1902.00751 Args: - dim: The dimension of the hidden states - target_modules: The feedforward module to be replaced, in regex format - 
hidden_pos: The position of the hidden state to passed into the adapter, can be int (args) or str (kwargs) - method_name: The method to be replaced, default to replace the forward method + dim(`int`): The dimension of the hidden states + target_modules(`Union[str, List[str]]`): The feedforward module to be replaced. + in regex format if this argument is str, else will match with `end with` if List[str]. + hidden_pos(`Union[str, int]`): The position of the hidden state to be passed into the adapter, + can be int (args) or str (kwargs) + method_name(`str`): The method to be replaced, default is `forward` adapter_length: The length of the adapter length (intermediate length) act_layer: The activation layer of the adapter """ @@ -33,25 +39,24 @@ class AdapterConfig(SwiftConfig): dim: int = field( default=None, metadata={'help': 'The dimension of the hidden states'}) - target_modules: str = field( + target_modules: Union[str, List[str]] = field( default=None, metadata={ - 'help': 'The feedforward module to be replaced, in regex format' + 'help': + 'The feedforward module to be replaced. in regex format if this argument is str, ' + 'else will match with `end with` if List[str].' }) hidden_pos: Union[str, int] = field( default=None, metadata={ 'help': - 'The position of the hidden state to passed into the adapter, can be int (args) or str (kwargs)' + 'The position of the hidden state to be passed into the adapter, can be int (args) or str (kwargs)' }) method_name: str = field( default='forward', - metadata={ - 'help': - 'The method to be replaced, default to replace the forward method' - }) + metadata={'help': 'The method to be replaced, default is `forward`'}) adapter_length: int = field( default=128, @@ -71,35 +76,54 @@ def __post_init__(self): class Adapter: @staticmethod - def prepare_model(model: nn.Module, config: AdapterConfig) -> SwiftOutput: + def prepare_model(model: nn.Module, config: AdapterConfig, + adapter_name: str) -> SwiftOutput: """Prepare a model with `AdapterConfig`""" module_keys = [key for key, _ in model.named_modules()] for module_key in module_keys: - if re.fullmatch(config.target_modules, module_key): # noqa + if isinstance(config.target_modules, str): + target_module_found = re.fullmatch(config.target_modules, + module_key) + else: + target_module_found = any( + module_key.endswith(target_key) + for target_key in config.target_modules) + + if target_module_found: # noqa module = model.get_submodule(module_key) def _forward(self, *args, **kwargs): - args = self.forward_origin(*args, **kwargs) + args = getattr(self, + f'forward_origin_{adapter_name}')(*args, + **kwargs) if isinstance(args, (tuple, list, dict)): if isinstance(config.hidden_pos, int): - return args[0:config.hidden_pos] + args[ - config.hidden_pos] + getattr(self, 'adapter')(args[config.hidden_pos]) \ - + args[config.hidden_pos + 1:] # noqa + _type = type(args) + args = list(args) + args[config.hidden_pos] = getattr( + self, f'adapter_{adapter_name}')( + args[config.hidden_pos]) + args = _type(args) else: - kwargs[config.hidden_pos] = args[ - config.hidden_pos] + getattr(self, 'adapter')( + args[config.hidden_pos] = getattr( + self, f'adapter_{adapter_name}')( args[config.hidden_pos]) elif isinstance(args, torch.Tensor): - args = getattr(self, 'adapter')(args) + args = getattr(self, f'adapter_{adapter_name}')(args) return args def _feed_forward_chunk(self, attention_output): return _forward(self, attention_output) - module.forward_origin = getattr(module, config.method_name) + # TODO The `config.method_name` 
method should not be replaced twice. + + setattr(module, f'forward_origin_{adapter_name}', + getattr(module, config.method_name)) num_args_in_forward_chunk_fn = len( - inspect.signature(module.forward_origin).parameters) + inspect.signature( + getattr(module, + f'forward_origin_{adapter_name}')).parameters) if config.method_name == 'feed_forward_chunk' and num_args_in_forward_chunk_fn == 1: setattr(module, config.method_name, types.MethodType(_feed_forward_chunk, module)) @@ -109,12 +133,16 @@ def _feed_forward_chunk(self, attention_output): adapter_module = AdapterModule(config.dim, config.adapter_length, ACT2CLS[config.act_layer]) - setattr(module, 'adapter', adapter_module) + setattr(module, f'adapter_{adapter_name}', adapter_module) + logger.info( + f'Adapter modules(module_key): {module_key}.adapter_{adapter_name}' + ) - def state_dict_callback(state_dict): + def state_dict_callback(state_dict, adapter_name: str): return { key: value - for key, value in state_dict.items() if 'adapter' in key + for key, value in state_dict.items() + if f'adapter_{adapter_name}' in key } def mark_trainable_callback(model): @@ -123,8 +151,17 @@ def mark_trainable_callback(model): return SwiftOutput(config, state_dict_callback, mark_trainable_callback) + @staticmethod + def activate_adapter(module: torch.nn.Module, adapter_name: str, + activate: bool): + modules: List[torch.nn.Module] = find_sub_module( + module, f'adapter_{adapter_name}') + for _module in modules: + _module: ActivationMixin + _module.set_activation(activate) + -class AdapterModule(nn.Module): +class AdapterModule(nn.Module, ActivationMixin): """The implementation of adapter tuning method. Adapters project input tokens by an MLP layer. @@ -143,13 +180,14 @@ def __init__( act_layer=nn.GELU, ): super(AdapterModule, self).__init__() + super(nn.Module, self).__init__() self.dim = dim self.adapter_length = adapter_length - # self.adapter_type = adapter_type - self.ln1 = nn.Linear(dim, adapter_length) - self.activate = act_layer() - self.ln2 = nn.Linear(adapter_length, dim) + self.linear1 = nn.Linear(dim, adapter_length) + self.act = act_layer() + self.linear2 = nn.Linear(adapter_length, dim) self.init_weights() + self._prepared = False def init_weights(self): @@ -161,8 +199,19 @@ def _init_weights(m): self.apply(_init_weights) def forward(self, x, identity=None): - out = self.ln2(self.activate(self.ln1(x))) + if not self.is_activated(): + return x + if not self._prepared: + self.linear1.to(x.device) + self.act.to(x.device) + self.linear2.to(x.device) + self._prepared = True + + x_dtype = x.dtype + x = x.to(self.linear1.weight.dtype) + out = self.linear2(self.act(self.linear1(x))) if identity is None: identity = x + identity = identity.to(out.dtype) out = identity + out - return out + return out.to(x_dtype) diff --git a/swift/tuners/base.py b/swift/tuners/base.py index b6f4d1c3db..8ad9807e09 100644 --- a/swift/tuners/base.py +++ b/swift/tuners/base.py @@ -28,20 +28,27 @@ class SwiftModel(nn.Module): """The Swift wrapper model. Args: - model (`torch.nn.Module`) A module to be tuned by Swift. - config (`Union[SwiftConfig, Dict[str, SwiftConfig]]`) A config or a dict of adapter_name: SwiftConfig. + model (`Union[nn.Module, 'SwiftModel']`) A module to be tuned by Swift. + config (`Union[SwiftConfig, Dict[str, SwiftConfig]]`) A config or a dict of {adapter_name: SwiftConfig}. If it's a config class, the adapter_name will be `default` extra_state_keys (`List[str]`, `optional`) A list of regex to match the extra state keys to be saved. 
inference_mode (bool, `optional`): Load model at inference mode, default False. """ def __init__(self, - model: nn.Module, + model: Union[nn.Module, 'SwiftModel'], config: Union[SwiftConfig, Dict[str, SwiftConfig]], extra_state_keys: List[str] = None, inference_mode: bool = False, **kwargs): super().__init__() + self.adapters = {} + if isinstance(model, SwiftModel): + self.adapters = model.adapters + extra_state_keys = extra_state_keys or [] + extra_state_keys.extend(model.extra_state_keys) + model = model.base_model + if (getattr(model, 'hf_device_map', None) is not None) and ( len(set(model.hf_device_map.values()) & {'cpu', 'disk'}) > 0): from accelerate.hooks import remove_hook_from_submodules @@ -50,14 +57,14 @@ def __init__(self, for _, p in model.named_parameters(): p.requires_grad = False - self.adapters = {} if isinstance(config, SwiftConfig): - self.adapters[DEFAULT_ADAPTER] = self._prepare_model(model, config) + self.adapters[DEFAULT_ADAPTER] = self._prepare_model( + model, config, DEFAULT_ADAPTER) elif isinstance(config, dict): assert (all(isinstance(c, SwiftConfig) for c in config.values())) - for adapter_name, config in config.items(): + for adapter_name, _config in config.items(): self.adapters[adapter_name] = self._prepare_model( - model, config) + model, _config, adapter_name) self.model = model self.extra_state_keys = extra_state_keys or [] @@ -151,7 +158,8 @@ def state_dict(self, if kwargs.get('save_adapter', True): for name, output in self.adapters.items(): if adapter_name == name or adapter_name is None: - state_dicts.update(output.state_dict_callback(destination)) + state_dicts.update( + output.state_dict_callback(destination, name)) if kwargs.get('save_extra_states', True): state_dicts.update({ k: v @@ -194,18 +202,20 @@ def load_state_file(path): @classmethod def from_pretrained(cls, - model: nn.Module, + model: Union[nn.Module, 'SwiftModel'], model_id: str = None, - adapter_name: Union[str, List[str]] = 'default', + adapter_name: Union[str, List[str]] = None, inference_mode: bool = False, revision: str = None, **kwargs): """Load a set of tuners and corresponding weights by a model_id. Args: - model (`torch.nn.Module`): The model to be tuned. - model_id (`str`): The model_id or a local model dir to use to tune the model. + model (`Union[torch.nn.Module, 'SwiftModel']`): The model to be tuned, + if the model is already a `SwiftModel` it will be un-wrapped and re-wrapped.. + model_id (`str`): The model_id or a local model dir of tuners to use to tune the model. adapter_name (`Union[str, List[str]]`): The adapter_names saved in the model repo to load. + Default `None`, means load all tuners saved in the model_id inference_mode (`bool`): Use in the inference mode or not. revision (`str`): The model revision to use. 
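Because `SwiftModel.__init__` now accepts an existing `SwiftModel`, tuners can be stacked one `prepare_model` call at a time. A small sketch, with illustrative module and adapter names:

```python
import torch.nn as nn

from swift import LoRAConfig, Swift


class Toy(nn.Module):
    """Toy module; only its sub-module names matter for patching."""

    def __init__(self, dim=32):
        super().__init__()
        self.attn = nn.Linear(dim, dim)
        self.mlp = nn.Linear(dim, dim)


model = Swift.prepare_model(Toy(), {'lora_attn': LoRAConfig(r=4, target_modules=['attn'])})
# The wrapper unwraps the inner SwiftModel, keeps its `adapters` dict and extra state keys,
# and re-wraps the original base model, so the first tuner is preserved.
model = Swift.prepare_model(model, {'lora_mlp': LoRAConfig(r=4, target_modules=['mlp'])})
print(sorted(model.adapters))  # ['lora_attn', 'lora_mlp']
```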
**kwargs: @@ -228,11 +238,21 @@ def from_pretrained(cls, ) if not os.path.exists(model_id): model_dir = snapshot_download(model_id, revision=revision) + if adapter_name is None: + adapter_name = [ + sub_dir for sub_dir in os.listdir(model_dir) + if os.path.isdir(os.path.join(model_dir, sub_dir)) and + os.path.isfile(os.path.join(model_dir, sub_dir, CONFIG_NAME)) + ] for _name in adapter_name if isinstance(adapter_name, list) else [adapter_name]: sub_folder = os.path.join(model_dir, _name) config_file = os.path.join(sub_folder, CONFIG_NAME) + if not os.path.isfile(config_file): + logger.warning(f'{_name} is not a valid tuner') + continue + with open(config_file, 'r') as file: json_object = json.load(file) @@ -260,10 +280,12 @@ def _prepare_model( cls, model: nn.Module, config: SwiftConfig, + adapter_name: str, ): assert (hasattr(config, SWIFT_TYPE_KEY)) from .mapping import SWIFT_MAPPING - return SWIFT_MAPPING[config.swift_type][1].prepare_model(model, config) + return SWIFT_MAPPING[config.swift_type][1].prepare_model( + model, config, adapter_name) def create_or_update_model_card(self, output_dir: str): """ @@ -299,7 +321,6 @@ def create_or_update_model_card(self, output_dir: str): lines.append( f'{training_procedure_heading}\n{training_config_text}') - # Adds peft version framework_block_heading = '### Framework versions\n' from swift.version import __version__ if framework_block_heading in lines: @@ -310,6 +331,11 @@ def create_or_update_model_card(self, output_dir: str): lines.append( f'{framework_block_heading}\n\n- SWIFT {__version__}\n') + base_model_heading = '### Base model information\n' + lines.append( + f'{base_model_heading}\n\n- BaseModel Class {self.base_model.__class__.__name__}\n' + ) + # write the lines back to README.md with open(os.path.join(output_dir, 'README.md'), 'w') as f: f.writelines(lines) @@ -317,14 +343,14 @@ def create_or_update_model_card(self, output_dir: str): def save_pretrained(self, save_directory: str, safe_serialization: bool = False, - adapter_name: Union[str, List[str]] = 'default', + adapter_name: Union[str, List[str]] = None, **kwargs): """Save the adapters to a local directory. Args: save_directory (`str`): The directory to use. safe_serialization (`bool`): Use safe tensors to save the weights, default False. - adapter_name(`Union[str, List[str]]`): The adapters to be saved, default is `default`. + adapter_name(`Union[str, List[str]]`): The adapters to be saved, default is `None` to save all. 
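With `adapter_name` defaulting to `None`, saving and re-loading now covers every attached tuner. A sketch continuing the stacking example above; the directory name is illustrative:

```python
model.save_pretrained('./toy_tuners')             # one sub-directory is written per tuner
reloaded = Swift.from_pretrained(Toy(), './toy_tuners')
# adapter_name=None: every sub-directory containing a config file is loaded;
# directories without one are skipped with a "not a valid tuner" warning.
print(sorted(reloaded.adapters))                  # ['lora_attn', 'lora_mlp']
```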
""" if os.path.isfile(save_directory): raise ValueError( @@ -333,10 +359,10 @@ def save_pretrained(self, os.makedirs(save_directory, exist_ok=True) self.create_or_update_model_card(save_directory) - adapter_names = adapter_name if isinstance(adapter_name, - list) else [adapter_name] + adapter_names = adapter_name if isinstance( + adapter_name, list) or adapter_name is None else [adapter_name] for adapter_name, output in self.adapters.items(): - if adapter_name not in adapter_names: + if adapter_names is not None and adapter_name not in adapter_names: continue # save only the trainable weights @@ -381,6 +407,37 @@ def save_pretrained(self, def base_model(self): return self.model + def set_active_adapters(self, adapter_names: List[str]): + if not adapter_names: + return + + adapter_names = set(adapter_names) + for adapter_name in (adapter_names & set(self.adapters.keys())): + self.activate_adapter(adapter_name) + + for adapter_name in (set(self.adapters.keys()) - adapter_names): + self.deactivate_adapter(adapter_name) + + def activate_adapter(self, adapter_name): + if adapter_name not in self.adapters: + logger.warning( + f'{adapter_name} not in adapters: {self.adapters.keys()}') + return + + from .mapping import SWIFT_MAPPING + SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\ + .activate_adapter(self.base_model, adapter_name, True) + + def deactivate_adapter(self, adapter_name): + if adapter_name not in self.adapters: + logger.warning( + f'{adapter_name} not in adapters: {self.adapters.keys()}') + return + + from .mapping import SWIFT_MAPPING + SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\ + .activate_adapter(self.base_model, adapter_name, False) + def get_trainable_parameters(self): """ Get the content of trainable parameters in the model. @@ -397,20 +454,21 @@ def get_trainable_parameters(self): if param.requires_grad: trainable_params += num_params return f'trainable params: {trainable_params:,d} || all params: {all_param:,d} ' \ - f'|| trainable%: {100 * trainable_params / all_param}' + f'|| trainable%: {100 * trainable_params / all_param}' \ + f'|| cuda memory: {sum([torch.cuda.memory_allocated(i) for i in range(torch.cuda.device_count())])}' class Swift: """The Wrapper to use both Peft and Swift tuners.""" @staticmethod - def prepare_model(model: nn.Module, config: Union[SwiftConfig, PeftConfig, - Dict[str, SwiftConfig]], - **kwargs): + def prepare_model(model: Union[nn.Module, 'SwiftModel'], + config: Union[SwiftConfig, PeftConfig, + Dict[str, SwiftConfig]], **kwargs): """Prepare a model by the input config. Args: - model(`nn.Module`): The model to be tuned. + model(`Union[nn.Module, 'SwiftModel']`): The model to be tuned. config(`Union[SwiftConfig, PeftConfig, Dict[str, SwiftConfig]]`): The config or config dict, can be either SwiftConfigs or PeftConfigs **kwargs: @@ -428,15 +486,15 @@ def prepare_model(model: nn.Module, config: Union[SwiftConfig, PeftConfig, raise ValueError(f'Unsupported swift config type: {config.__class__}') @staticmethod - def from_pretrained(model: nn.Module, + def from_pretrained(model: Union[nn.Module, 'SwiftModel'], model_id: str = None, - adapter_name: Union[str, List[str]] = 'default', + adapter_name: Union[str, List[str]] = None, revision: str = None, **kwargs): """Prepare a model by a model_id in the ModelScope hub or a local dir. Args: - model(`nn.Module`): The model to be tuned. + model(`Union[nn.Module, 'SwiftModel']`): The model to be tuned. 
model_id(`str`): The model id of the modelhub or a local dir containing the configs/weights. adapter_name(`str`, `optional`): The adapter_name to use. revision(`str`, `optional`): The model revision if the model_id is a model id of the modelhub. @@ -453,8 +511,9 @@ def from_pretrained(model: nn.Module, _json = json.load(f) is_peft_model = PEFT_TYPE_KEY in _json - _name = adapter_name if isinstance(adapter_name, - str) else adapter_name[0] + _name = adapter_name if isinstance( + adapter_name, str) or adapter_name is None else adapter_name[0] + _name = _name or '' if os.path.exists(os.path.join(model_id, _name, CONFIG_NAME)): with open(os.path.join(model_id, _name, CONFIG_NAME), 'r') as f: _json = json.load(f) @@ -464,7 +523,7 @@ def from_pretrained(model: nn.Module, model, model_id, revision=revision, - adapter_name=adapter_name, + adapter_name=adapter_name or 'default', **kwargs) else: return SwiftModel.from_pretrained( diff --git a/swift/tuners/lora.py b/swift/tuners/lora.py index 69719c9df1..15a6594aa1 100644 --- a/swift/tuners/lora.py +++ b/swift/tuners/lora.py @@ -1,19 +1,103 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. -import logging import math import re from dataclasses import dataclass, field -from typing import Dict, List +from types import MethodType +from typing import Dict, List, Union import torch import torch.nn as nn import torch.nn.functional as F +from peft.import_utils import (is_auto_gptq_available, is_bnb_4bit_available, + is_bnb_available) +from peft.utils import get_auto_gptq_quant_linear, get_quantization_config -from .utils import SwiftConfig, SwiftOutput +from swift import get_logger +from ..utils.torch_utils import find_sub_module +from .utils import ActivationMixin, SwiftConfig, SwiftOutput -logger = logging.getLogger(__name__) +if is_bnb_available(): + import bitsandbytes as bnb + + from peft.tuners.lora import Linear8bitLt + + class Linear8bitLtSwift(ActivationMixin, Linear8bitLt): + + def __init__( + self, + adapter_name, + in_features, + out_features, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + **kwargs, + ): + super(ActivationMixin, + self).__init__(adapter_name, in_features, out_features, r, + lora_alpha, lora_dropout, **kwargs) + super(Linear8bitLtSwift, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.is_activated(): + return bnb.nn.Linear8bitLt.forward(self, x) + return super().forward(x) + + +if is_bnb_4bit_available(): + from peft.tuners.lora import Linear4bit + + class Linear4bitSwift(ActivationMixin, Linear4bit): + + def __init__( + self, + adapter_name, + in_features, + out_features, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + **kwargs, + ): + super(ActivationMixin, + self).__init__(adapter_name, in_features, out_features, r, + lora_alpha, lora_dropout, **kwargs) + super(Linear4bitSwift, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.is_activated(): + return bnb.nn.Linear4bit.forward(self, x) + return super().forward(x) + + +if is_auto_gptq_available(): + from peft.tuners.lora import QuantLinear + + class QuantLinearSwift(ActivationMixin, QuantLinear): + + def __init__( + self, + adapter_name, + quant_linear_module, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + **kwargs, + ): + super(ActivationMixin, + self).__init__(adapter_name, 
quant_linear_module, r, + lora_alpha, lora_dropout, **kwargs) + super(QuantLinearSwift, self).__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.is_activated(): + return self.quant_linear_module(x) + return super().forward(x) + + +logger = get_logger() @dataclass @@ -22,19 +106,20 @@ class LoRAConfig(SwiftConfig): The configuration class for the loRA module. Args: - r: The rank of the LoRA module - target_modules: The modules to be replaced by LoRA, can be the end of the module name or a regex string - lora_alpha: The factor to add the lora weights - lora_dropout: The dropout rate of the lora module - merge_weights: Whether to merge weights when validating - use_merged_linear: Whether to replace with merged linear layer - enable_lora: The modules need to be turned on when using the merged linear layer - fan_in_fan_out: Set this to True if the layer to replace stores weight like (fan_in, fan_out) - bias: Bias type. Values ca be "none", "all" or "lora_only" + r(int): The rank of the LoRA module + target_modules(List[str]): The modules to be replaced by LoRA, + can be the end of the module name or a regex string + lora_alpha(float): The factor to add the lora weights + lora_dropout(float): The dropout rate of the lora module + merge_weights(bool): Whether to merge weights when validating + use_merged_linear(bool): Whether to replace with merged linear layer + enable_lora(List[bool]): The modules need to be turned on when using the merged linear layer + fan_in_fan_out(bool): Set this to True if the layer to replace stores weight like (fan_in, fan_out) + bias(str): Bias type. Values ca be "none", "all" or "lora_only" """ r: int = field(default=6, metadata={'help': 'The rank of the LoRA module'}) - target_modules: List = field( + target_modules: List[str] = field( default=None, metadata={ 'help': @@ -76,12 +161,13 @@ def __post_init__(self): class LoRA: @staticmethod - def prepare_model(model: nn.Module, config: LoRAConfig): + def prepare_model(model: nn.Module, config: LoRAConfig, adapter_name: str): """Prepare a model with `LoRAConfig`""" LoRA._dynamic_patch_lora( model, replace_modules=config.target_modules, r=config.r, + adapter_name=adapter_name, lora_alpha=config.lora_alpha, lora_dropout=config.lora_dropout, merge_weights=config.merge_weights, @@ -89,34 +175,44 @@ def prepare_model(model: nn.Module, config: LoRAConfig): enable_lora=config.enable_lora, fan_in_fan_out=config.fan_in_fan_out) - def state_dict_callback(state_dict): - return lora_state_dict(state_dict, config.bias) + def state_dict_callback(state_dict, adapter_name): + return lora_state_dict(state_dict, adapter_name, config.bias) def mark_trainable_callback(model): - mark_lora_as_trainable(model, config.bias) + mark_lora_as_trainable(model, adapter_name, config.bias) return SwiftOutput(config, state_dict_callback, mark_trainable_callback) @staticmethod - def _dynamic_patch_lora(model, replace_modules, use_merged_linear, + def activate_adapter(module: torch.nn.Module, adapter_name: str, + activate: bool): + modules: List[torch.nn.Module] = find_sub_module( + module, f'loramodule_{adapter_name}') + for _module in modules: + _module: ActivationMixin + _module.set_activation(activate) + + @staticmethod + def _dynamic_patch_lora(model: torch.nn.Module, + replace_modules: Union[str, List[str]], + use_merged_linear: bool, adapter_name: str, **kwargs): """Dynamic patch lora to model Args: - model: The torch.nn.Module containing the target module to be patched. 
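A small sketch of the per-adapter checkpoint filtering: LoRA layers now live under `loramodule_<adapter_name>` attributes, so each tuner's state dict contains only its own weights. The `adapter_name` keyword and the printed key names are illustrative assumptions, not taken verbatim from this patch:

```python
sd = model.state_dict(adapter_name='lora_attn')  # assumes state_dict forwards adapter_name to the callback
print(list(sd))  # e.g. ['attn.loramodule_lora_attn.lora_A', 'attn.loramodule_lora_attn.lora_B']
```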
- replace_modules: The module names to be replaced, the replacing strategy is `end with`. - use_merged_linear: Whether to replace with merged linear layer + model(`torch.nn.Module`): The torch.nn.Module containing the target module to be patched. + replace_modules(`Union[str, List[str]]`): The module names to be replaced, + the replacing strategy is `end with`. + use_merged_linear(bool): Whether to replace with merged linear layer. + adapter_name(str): The adapter name. **kwargs: The arguments passed from `tune` which are needed by lora. - - Returns: - The lora modules """ - modules = [] + modules = {} module_keys = [key for key, _ in model.named_modules()] assert isinstance(replace_modules, (str, list)) - if isinstance(replace_modules, str): - replace_modules = [replace_modules] + AutoGPTQQuantLinear = get_auto_gptq_quant_linear( + get_quantization_config(model, method='gptq')) for module_key in module_keys: if isinstance(replace_modules, str): @@ -126,26 +222,82 @@ def _dynamic_patch_lora(model, replace_modules, use_merged_linear, module_key.endswith(target_key) for target_key in replace_modules) if target_module_found: # noqa - parts = module_key.split('.') - module = model.get_submodule('.'.join(parts[:-1])) sub_module = model.get_submodule(module_key) - _key = parts[-1] lora_module = None - if isinstance(sub_module, torch.nn.Linear): + if getattr(model, 'is_loaded_in_8bit', False) and isinstance( + sub_module, bnb.nn.Linear8bitLt): + eight_bit_kwargs = kwargs.copy() + eight_bit_kwargs.update({ + 'has_fp16_weights': + sub_module.state.has_fp16_weights, + 'memory_efficient_backward': + sub_module.state.memory_efficient_backward, + 'threshold': + sub_module.state.threshold, + 'index': + sub_module.index, + }) + lora_module = Linear8bitLtSwift( + 'default', + sub_module.in_features, + sub_module.out_features, + bias=hasattr(sub_module, 'bias') + and sub_module.bias is not None, + **eight_bit_kwargs) + elif getattr(model, 'is_loaded_in_4bit', + False) and is_bnb_4bit_available() and isinstance( + sub_module, bnb.nn.Linear4bit): + four_bit_kwargs = kwargs.copy() + four_bit_kwargs.update({ + 'compute_dtype': + sub_module.compute_dtype, + 'compress_statistics': + sub_module.weight.compress_statistics, + 'quant_type': + sub_module.weight.quant_type, + }) + lora_module = Linear4bitSwift( + 'default', + sub_module.in_features, + sub_module.out_features, + bias=hasattr(sub_module, 'bias') + and sub_module.bias is not None, + **four_bit_kwargs) + elif AutoGPTQQuantLinear is not None and isinstance( + sub_module, AutoGPTQQuantLinear): + lora_module = QuantLinearSwift('default', sub_module, + **kwargs) + sub_module.weight = sub_module.qweight + elif isinstance(sub_module, torch.nn.Linear): if use_merged_linear: lora_module = MergedLinear( sub_module.in_features, sub_module.out_features, - bias=sub_module.bias is not None, + bias=hasattr(sub_module, 'bias') + and sub_module.bias is not None, **kwargs) else: kwargs.pop('enable_lora', None) lora_module = Linear( sub_module.in_features, sub_module.out_features, - bias=sub_module.bias is not None, + bias=hasattr(sub_module, 'bias') + and sub_module.bias is not None, **kwargs) + elif isinstance(sub_module, torch.nn.Embedding): + lora_module = Embedding( + num_embeddings=sub_module.num_embeddings, + embedding_dim=sub_module.embedding_dim, + padding_idx=sub_module.padding_idx, + max_norm=sub_module.max_norm, + norm_type=sub_module.norm_type, + scale_grad_by_freq=sub_module.scale_grad_by_freq, + sparse=sub_module.sparse, + r=kwargs['r'], + 
lora_alpha=kwargs['lora_alpha'], + merge_weights=kwargs['merge_weights'], + ) elif isinstance(sub_module, torch.nn.Conv2d): kwargs.pop('fan_in_fan_out', None) lora_module = Conv2d( @@ -158,14 +310,27 @@ def _dynamic_patch_lora(model, replace_modules, use_merged_linear, groups=sub_module.groups, **kwargs) + def _forward(self, *args, **kwargs): + for _name, _module in self.named_modules(): + if 'loramodule_' in _name and _module.is_activated(): + return _module.forward(*args, **kwargs) + return self.forward_origin(*args, **kwargs) + if lora_module is not None: lora_module.weight = sub_module.weight - if sub_module.bias is not None: + if getattr(sub_module, 'bias', None) is not None: lora_module.bias = sub_module.bias + if getattr(sub_module, 'state', None) is not None: + lora_module.state = sub_module.state lora_module.to(sub_module.weight.device) - setattr(module, _key, lora_module) - modules.append(lora_module) - return modules + setattr(sub_module, f'loramodule_{adapter_name}', + lora_module) + if not hasattr(sub_module, 'forward_origin'): + sub_module.forward_origin = sub_module.forward + sub_module.forward = MethodType(_forward, sub_module) + modules[module_key] = adapter_name + + logger.info(f'Lora modules(module_key -> adapter_name): {modules}') @staticmethod def unpatch_lora(model, config: LoRAConfig): @@ -178,11 +343,7 @@ def unpatch_lora(model, config: LoRAConfig): Args: model: The model called with `tune` function. config: The `LoRAConfig` to use. - - Returns: - The lora modules. """ - modules = [] module_keys = [key for key, _ in model.named_modules()] assert isinstance(config.replace_modules, (str, list)) replace_modules = config.replace_modules @@ -205,7 +366,19 @@ def unpatch_lora(model, config: LoRAConfig): origin_module = torch.nn.Linear( sub_module.in_features, sub_module.out_features, - bias=sub_module.bias is not None) + bias=hasattr(sub_module, 'bias') + and sub_module.bias is not None, + ) + elif isinstance(sub_module, Embedding): + origin_module = torch.nn.Embedding( + num_embeddings=sub_module.num_embeddings, + embedding_dim=sub_module.embedding_dim, + padding_idx=sub_module.padding_idx, + max_norm=sub_module.max_norm, + norm_type=sub_module.norm_type, + scale_grad_by_freq=sub_module.scale_grad_by_freq, + sparse=sub_module.sparse, + ) elif isinstance(sub_module, Conv2d): origin_module = torch.nn.Conv2d( sub_module.in_channels, @@ -220,22 +393,14 @@ def unpatch_lora(model, config: LoRAConfig): sub_module.merge_weights = True sub_module.eval() origin_module.weight = sub_module.weight - if sub_module.bias is not None: + if getattr(sub_module, 'bias', None) is not None: origin_module.bias = sub_module.bias origin_module.to(sub_module.weight.device).to( sub_module.weight.dtype) setattr(module, _key, origin_module) - modules.append(sub_module) - - model.state_dict_hook_handle.remove() - if hasattr(model, 'load_state_dict_hook_handle'): - model.load_state_dict_hook_handle.remove() - else: - model.load_state_dict = model.load_state_dict_origin - return modules -class LoRALayer: +class LoRALayer(ActivationMixin): def __init__( self, @@ -244,6 +409,7 @@ def __init__( lora_dropout: float, merge_weights: bool, ): + super().__init__() self.r = r self.lora_alpha = lora_alpha # Optional dropout @@ -254,6 +420,8 @@ def __init__( # Mark the weight as unmerged self.merged = False self.merge_weights = merge_weights + if not self._unique_thread: + self.merge_weights = False class Embedding(nn.Embedding, LoRALayer): @@ -318,7 +486,7 @@ def eval(self): self.merged = True def forward(self, 
x: torch.Tensor): - if self.r > 0 and not self.merged: + if self.r > 0 and not self.merged and self.is_activated(): result = nn.Embedding.forward(self, x) if self.r > 0: after_A = F.embedding(x, self.lora_A.T, self.padding_idx, @@ -409,7 +577,7 @@ def forward(self, x: torch.Tensor): def T(w): return w.T if self.fan_in_fan_out else w - if self.r > 0 and not self.merged: + if self.r > 0 and not self.merged and self.is_activated(): result = F.linear(x, T(self.weight), bias=self.bias) if self.r > 0: x_dtype = x.dtype @@ -529,7 +697,7 @@ def forward(self, x: torch.Tensor): def T(w): return w.T if self.fan_in_fan_out else w - if self.merged: + if self.merged or not self.is_activated(): return F.linear(x, T(self.weight), bias=self.bias) else: result = F.linear(x, T(self.weight), bias=self.bias) @@ -611,7 +779,7 @@ def eval(self): self.merged = True def forward(self, x: torch.Tensor): - if self.r > 0 and not self.merged: + if self.r > 0 and not self.merged and self.is_activated(): return F.conv2d( x, self.weight + # noqa @@ -625,7 +793,9 @@ def forward(self, x: torch.Tensor): return nn.Conv2d.forward(self, x) -def mark_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None: +def mark_lora_as_trainable(model: nn.Module, + adapter_name: str, + bias: str = 'none') -> None: if bias == 'none': return elif bias == 'all': @@ -633,8 +803,8 @@ def mark_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None: if 'bias' in n: p.requires_grad = True elif bias == 'lora_only': - for m in model.modules(): - if isinstance(m, LoRALayer) and \ + for n, m in model.named_modules(): + if f'loramodule_{adapter_name}' in n and \ hasattr(m, 'bias') and \ m.bias is not None: m.bias.requires_grad = True @@ -642,18 +812,26 @@ def mark_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None: raise NotImplementedError -def lora_state_dict(state_dict, bias: str = 'none') -> Dict[str, torch.Tensor]: +def lora_state_dict(state_dict, + adapter_name: str, + bias: str = 'none') -> Dict[str, torch.Tensor]: if bias == 'none': - return {k: state_dict[k] for k in state_dict if 'lora_' in k} + return { + k: state_dict[k] + for k in state_dict + if f'loramodule_{adapter_name}' in k and 'lora_' in k + } elif bias == 'all': return { k: state_dict[k] - for k in state_dict if 'lora_' in k or 'bias' in k + for k in state_dict + if ('lora_' in k and f'loramodule_{adapter_name}' in k) or ( + 'bias' in k and f'loramodule_{adapter_name}' not in k) } elif bias == 'lora_only': to_return = {} for k in state_dict: - if 'lora_' in k: + if f'loramodule_{adapter_name}' in k and 'lora_' in k: to_return[k] = state_dict[k] bias_name = k.split('lora_')[0] + 'bias' if bias_name in state_dict: diff --git a/swift/tuners/mapping.py b/swift/tuners/mapping.py index 1f91c542ef..b958cc1305 100644 --- a/swift/tuners/mapping.py +++ b/swift/tuners/mapping.py @@ -3,16 +3,22 @@ from .adapter import Adapter, AdapterConfig from .lora import LoRA, LoRAConfig from .prompt import Prompt, PromptConfig +from .restuning import ResTuning, ResTuningConfig +from .side import Side, SideConfig class SwiftTuners: ADAPTER = 'ADAPTER' PROMPT = 'PROMPT' LORA = 'LORA' + SIDE = 'SIDE' + RESTUNING = 'RESTUNING' SWIFT_MAPPING = { SwiftTuners.ADAPTER: (AdapterConfig, Adapter), SwiftTuners.PROMPT: (PromptConfig, Prompt), - SwiftTuners.LORA: (LoRAConfig, LoRA) + SwiftTuners.LORA: (LoRAConfig, LoRA), + SwiftTuners.SIDE: (SideConfig, Side), + SwiftTuners.RESTUNING: (ResTuningConfig, ResTuning), } diff --git a/swift/tuners/prompt.py b/swift/tuners/prompt.py index 
f426a4dd83..141c196fdb 100644 --- a/swift/tuners/prompt.py +++ b/swift/tuners/prompt.py @@ -3,12 +3,16 @@ import re import types from dataclasses import dataclass, field -from typing import Union +from typing import List, Union import torch from torch import nn -from .utils import SwiftConfig, SwiftOutput +from swift import get_logger +from ..utils.torch_utils import find_sub_module +from .utils import ActivationMixin, SwiftConfig, SwiftOutput + +logger = get_logger() @dataclass @@ -24,17 +28,18 @@ class PromptConfig(SwiftConfig): Here we apply the VPT to other fields. Args: - dim: The dimension of the hidden states - target_modules: The layer module to be replaced, in regex format - embedding_pos: The position of the embedding tensor - attention_mask_pos: The position of the attention mask - attention_mask_value: The value to pad to the attention mask - prompt_length: The length of the prompt tokens - attach_front: When set to True, prompt is attached in front of the embedding - extract_embedding: Whether the embedding is extracted at final stage to keep the same dims with inputs + dim(`Union[int, List[int]]`): The dimension of the hidden states, use list if there are up-sample blocks + or down-sample blocks + target_modules(str): The layer module to be replaced, in regex format + embedding_pos(Union[str, int]): The position of the embedding tensor + attention_mask_pos(Union[str, int]): The position of the attention mask + attention_mask_value(Union[float, int, bool]): The value to pad to the attention mask + prompt_length(int): The length of the prompt tokens + attach_front(bool): When set to True, prompt is attached in front of the embedding + extract_embedding(bool): Whether the embedding is extracted at final stage to keep the same dims with inputs """ - dim: int = field( + dim: Union[int, List[int]] = field( default=None, metadata={'help': 'The dimension of the hidden states'}) target_modules: str = field( @@ -77,11 +82,19 @@ def __post_init__(self): class Prompt: @staticmethod - def prepare_model(model: nn.Module, config: PromptConfig): + def prepare_model(model: nn.Module, config: PromptConfig, + adapter_name: str): module_keys = [key for key, _ in model.named_modules()] match_module_keys = [] for module_key in module_keys: - if re.fullmatch(config.target_modules, module_key): # noqa + if isinstance(config.target_modules, str): + target_module_found = re.fullmatch(config.target_modules, + module_key) + else: + target_module_found = any( + module_key.endswith(target_key) + for target_key in config.target_modules) + if target_module_found: # noqa module = model.get_submodule(module_key) def _forward(self, *args, **kwargs): @@ -91,7 +104,8 @@ def _forward(self, *args, **kwargs): input_embedding = kwargs[config.embedding_pos] input_embedding = getattr( - self, 'prompt').forward(input_embedding) + self, + f'prompt_{adapter_name}').forward(input_embedding) if isinstance(config.embedding_pos, int): args = type(args)( args[0:config.embedding_pos] + (input_embedding, ) @@ -109,7 +123,8 @@ def _forward(self, *args, **kwargs): if attention_mask is not None: attention_mask = getattr( self, - 'prompt').patch_attention_mask(attention_mask) + f'prompt_{adapter_name}').patch_attention_mask( + attention_mask) if isinstance(config.attention_mask_pos, int): args = type(args)( args[0:config.attention_mask_pos] @@ -118,14 +133,18 @@ def _forward(self, *args, **kwargs): else: kwargs[config.attention_mask_pos] = attention_mask - forward_output = self.forward_origin(*args, **kwargs) + forward_output = 
getattr( + self, f'forward_origin_{adapter_name}')(*args, + **kwargs) if config.extract_embedding: forward_output = getattr( - self, 'prompt').extract(forward_output) + self, + f'prompt_{adapter_name}').extract(forward_output) return forward_output - module.forward_origin = module.forward + setattr(module, f'forward_origin_{adapter_name}', + module.forward) module.forward = types.MethodType(_forward, module) if isinstance(config.dim, list): input_dim = config.dim[len(match_module_keys)] @@ -136,13 +155,17 @@ def _forward(self, *args, **kwargs): config.prompt_length, config.attention_mask_value, config.attach_front) - setattr(module, 'prompt', prompt_module) + setattr(module, f'prompt_{adapter_name}', prompt_module) + logger.info( + f'Prompt modules(module_key): {module_key}.prompt_{adapter_name}' + ) match_module_keys.append(module_key) - def state_dict_callback(state_dict): + def state_dict_callback(state_dict, adapter_name): return { key: value - for key, value in state_dict.items() if 'prompt' in key + for key, value in state_dict.items() + if f'prompt_{adapter_name}' in key } def mark_trainable_callback(model): @@ -151,8 +174,17 @@ def mark_trainable_callback(model): return SwiftOutput(config, state_dict_callback, mark_trainable_callback) + @staticmethod + def activate_adapter(module: torch.nn.Module, adapter_name: str, + activate: bool): + modules: List[torch.nn.Module] = find_sub_module( + module, f'prompt_{adapter_name}') + for _module in modules: + _module: ActivationMixin + _module.set_activation(activate) + -class PromptModule(nn.Module): +class PromptModule(nn.Module, ActivationMixin): """The implementation of vision prompt tuning method. Visual prompt tuning (VPT) is proposed to initialize tunable prompt tokens @@ -173,17 +205,20 @@ def __init__(self, mask_values=0., attach_front=True): super(PromptModule, self).__init__() + super(nn.Module, self).__init__() self.dim = dim self.layer_num = layer_num self.prompt_length = prompt_length self.mask_values = mask_values self.attach_front = attach_front - self.prompt_token = nn.Parameter(torch.zeros(1, prompt_length, dim)) nn.init.xavier_uniform_(self.prompt_token) def forward(self, x): - prompt_token = self.prompt_token.expand(x.shape[0], -1, -1) + if not self.is_activated(): + return x + prompt_token = self.prompt_token.expand(x.shape[0], -1, + -1).to(x.device, x.dtype) if self.layer_num == 0: if self.attach_front: @@ -200,9 +235,14 @@ def forward(self, x): return x def patch_attention_mask(self, m): + if not self.is_activated(): + return m prefix_attention_mask = torch.full((*m.shape[:-1], self.prompt_length), self.mask_values).to(m.device) - return torch.cat((prefix_attention_mask, m), dim=-1) + if self.attach_front: + return torch.cat((prefix_attention_mask, m), dim=-1) + else: + return torch.cat((m, prefix_attention_mask), dim=-1) def extract(self, x): if self.attach_front: diff --git a/swift/tuners/restuning.py b/swift/tuners/restuning.py new file mode 100644 index 0000000000..d808551971 --- /dev/null +++ b/swift/tuners/restuning.py @@ -0,0 +1,402 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
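Before moving on to the new ResTuning files, here is a runnable sketch of the prompt-tuning changes above. The toy backbone, its module names and the prompt length are illustrative; `PromptConfig` is assumed to be exported from the top-level `swift` package.

```python
import torch
import torch.nn as nn

from swift import PromptConfig, Swift  # assumed top-level exports


class TinyViT(nn.Module):
    """Toy stand-in for a ViT encoder: every block works on (batch, tokens, dim)."""

    def __init__(self, dim=32, depth=2):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


config = PromptConfig(
    dim=32,                          # a per-layer list of dims is also accepted now
    target_modules=r'blocks\.\d+',   # every block gets its own prompt_<adapter_name> module
    embedding_pos=0,                 # the hidden state is the first positional argument
    prompt_length=4,
    attach_front=True)
model = Swift.prepare_model(TinyViT(), {'vpt': config})

out = model.base_model(torch.randn(2, 16, 32))
print(out.shape)  # with extract_embedding left False the 4 prompt tokens stay attached: (2, 20, 32)
```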
+import copy +import re +import types +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn + +from swift import get_logger +from ..utils.torch_utils import find_sub_module +from .restuning_components import (ResTuner, detach_tensors, + probe_input_pre_hook, probe_output_hook) +from .utils import ActivationMixin, SwiftConfig, SwiftOutput + +logger = get_logger() + + +@dataclass +class ResTuningConfig(SwiftConfig): + """ + The configuration class for the ResTuning module. + + ResTuning is a flexible parameter-efficient and memory-efficient tuning paradigm framework. + 'Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone' + by Jiang et al.(2023) + See + + Args: + dims(`Union[List[int], int]`): The dimensions of the hidden states + root_modules(`str`): The root module to be replaced, can a regex string + root_modules_hook(`str`): The hook type of root modules, can be "input" or "output" + stem_modules(`Union[List[str], str]`): The stem modules to be replaced, + can a regex string or name list of full match format + stem_modules_hook(`Union[List[str], str]`): The hook type of stem modules, can be "input" or "output" + target_modules(`str`): The target module to be replaced, can a regex string + target_modules_hook(`str`): The hook type of target modules, can be "input" or "output" + tuner_cfg(`Union[List[Dict], Dict, str]`): The configuration of the tuning module, + can a string or customized config + use_upsample(bool): Whether to use auxiliary upsample module + upsample_out_channels(List[int]): The channels if `use_upsample` + zero_init_last(bool): Use zero to initialize the last Linear in every sub tuner. + + """ + + dims: Optional[Union[List[int], int]] = field( + default=None, metadata={'help': 'The dimensions of the hidden states'}) + + root_modules: str = field( + default=None, + metadata={ + 'help': + 'The root module to be replaced, can a regex string (use the first matching module) or full match format' + }) + + root_modules_hook: str = field( + default='input', + metadata={ + 'help': 'The hook type of root modules, can be "input" or "output"' + }) + + stem_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + 'help': + 'The stem modules to be replaced, can a regex string or name list of full match format' + }) + + stem_modules_hook: str = field( + default='output', + metadata={ + 'help': 'The hook type of stem modules, can be "input" or "output"' + }) + + target_modules: str = field( + default=None, + metadata={ + 'help': + 'The target module to be replaced, can a regex string (use the first matching module) or full match format' + }) + + target_modules_hook: str = field( + default='input', + metadata={ + 'help': + 'The hook type of target modules, can be "input" or "output"' + }) + + target_hidden_pos: Union[int, str] = field( + default=None, + metadata={ + 'help': + 'The position of the hidden state for target modules output' + }) + + tuner_cfg: Optional[Union[List[Dict], Dict, str]] = field( + default=None, + metadata={ + 'help': + 'The configuration of the tuning module, can a string or customized config' + }) + + use_upsample: bool = field( + default=False, + metadata={'help': 'Whether to use auxiliary upsample module'}) + + upsample_out_channels: List[int] = field( + default=None, + metadata={ + 'help': + 'The number of output channels when "use_upsample" is set to "True"' + }) + + zero_init_last: bool = field( + default=False, 
metadata={'help': 'Zero init last weight'}) + + use_bypass: bool = field( + default=True, metadata={'help': 'Whether to use bypass'}) + + def __post_init__(self): + from .mapping import SwiftTuners + self.swift_type = SwiftTuners.RESTUNING + self.target_hidden_pos = 0 if self.target_hidden_pos is None else self.target_hidden_pos + + +class ResTuning: + + @staticmethod + def prepare_model(model: nn.Module, config: ResTuningConfig, + adapter_name: str) -> SwiftOutput: + """Prepare a model with `ResTuningConfig`""" + + def _forward_seq(self, input, *args, **kwargs): + for idx, module in enumerate(self): + if idx >= len(self.origin_module_keys): + continue + input = module(input) + return input + + def _forward_target(self, *args, **kwargs): + if self.target_modules_hook == 'input': + if isinstance(self.target_hidden_pos, int): + args = list(args) + _arg = args[self.target_hidden_pos] + else: + _arg = kwargs[self.target_hidden_pos] + args_main = _forward_restuning(self, _arg) + if isinstance(self.target_hidden_pos, int): + args[self.target_hidden_pos] = args_main + else: + kwargs[self.target_hidden_pos] = args_main + args_main = getattr(self, + f'forward_origin_{adapter_name}')(*args, + **kwargs) + else: + _args_main = getattr(self, f'forward_origin_{adapter_name}')( + *args, **kwargs) + _arg = _args_main[self.target_hidden_pos] if isinstance( + _args_main, (tuple, list, dict)) else _args_main + args_main = _forward_restuning(self, _arg) + if type(_args_main) != type(args_main): + _args_main[self.target_hidden_pos] = args_main + args_main = _args_main + return args_main + + def _forward_restuning(self, origin_arg): + probe_results = [] + root_module_ins = self.root_module_ins_list[0] + stem_module_ins_list = self.stem_module_ins_list + top_module = model.get_submodule('') + if root_module_ins: + if root_module_ins.root_modules_hook == 'input': + probe_results.append(root_module_ins.probe_input_data) + else: + probe_results.append(root_module_ins.probe_output_data) + for i, st_mod in enumerate(stem_module_ins_list): + if i == 0 and root_module_ins is None: + probe_results.append(st_mod.probe_input_data) + if st_mod.stem_modules_hook == 'input': + probe_results.append(st_mod.probe_input_data) + else: + probe_results.append(st_mod.probe_output_data) + args_main = getattr(top_module, + f'restuning_{adapter_name}')(probe_results, + origin_arg) + return args_main + + # 1. Matching the root module + module_keys = [key for key, _ in model.named_modules()] + root_module_ins_list = [] + if config.root_modules: + for module_key in module_keys: + if re.fullmatch(config.root_modules, module_key): + root_module = model.get_submodule(module_key) + logger.info( + f'Matching root module [{module_key}] of type {type(root_module)}' + ) + if isinstance(root_module, (nn.ModuleList, nn.ModuleDict)): + logger.warning( + f'Type of {type(root_module)} may not be supported because of its customized forward' + ) + if config.root_modules_hook == 'input': + root_module.register_forward_pre_hook( + probe_input_pre_hook) + else: + root_module.register_forward_hook(probe_output_hook) + root_module.root_modules_hook = config.root_modules_hook + root_module_ins_list.append(root_module) + break + if len(root_module_ins_list) == 0: + logger.error('Cannot match root modules') + + # 2. 
Matching the stem module + stem_module_ins_list = [] + stem_module_ins_index = [] + for module_key in module_keys: + if (isinstance(config.stem_modules, str) and re.fullmatch(config.stem_modules, module_key)) or \ + (isinstance(config.stem_modules, list) and module_key in config.stem_modules): + stem_module = model.get_submodule(module_key) + if isinstance(config.stem_modules, list): + stem_module_ins_index.append( + config.stem_modules.index(module_key)) + logger.info( + f'Matching stem module [{module_key}] of type {type(stem_module)}' + ) + if isinstance(stem_module, (nn.ModuleList, nn.ModuleDict)): + logger.warning( + f'Type of {type(stem_module)} may not be supported because of its customized forward' + ) + if len(root_module_ins_list) == 0 and len( + stem_module_ins_list) == 0: + stem_module.register_forward_pre_hook(probe_input_pre_hook) + if config.stem_modules_hook == 'input': + stem_module.register_forward_pre_hook(probe_input_pre_hook) + else: + stem_module.register_forward_hook(probe_output_hook) + stem_module.stem_modules_hook = config.stem_modules_hook + stem_module_ins_list.append(stem_module) + if isinstance(config.stem_modules, list): + stem_module_ins_list = [ + stem_module_ins_list[stem_module_ins_index.index(i)] + for i in range(len(stem_module_ins_index)) + ] + depth = len(stem_module_ins_list) + if len(stem_module_ins_list) == 0: + raise Exception('Cannot match source modules') + + # 3. Init restuning module + if len(stem_module_ins_list) != 0: + top_module = model.get_submodule('') + restuning_module = ResTuningBypassModule( + config.dims, depth, config.use_upsample, + config.upsample_out_channels, config.zero_init_last, + config.tuner_cfg) + setattr(top_module, f'restuning_{adapter_name}', restuning_module) + + # 4. Matching the target module + target_module_ins = None + for module_key in module_keys: + if re.fullmatch(config.target_modules, module_key): + tgt_module = model.get_submodule(module_key) + logger.info( + f'Matching target module [{module_key}] of type {type(tgt_module)}' + ) + if isinstance(tgt_module, (nn.ModuleList, nn.ModuleDict)): + raise Exception( + f'Type of {type(tgt_module)} may not be supported because of its customized forward' + ) + + tgt_module.target_modules_hook = config.target_modules_hook + tgt_module.target_hidden_pos = config.target_hidden_pos + tgt_module.root_module_ins_list = root_module_ins_list + tgt_module.stem_module_ins_list = stem_module_ins_list + target_module_ins = tgt_module + + if isinstance(tgt_module, nn.Sequential) and not hasattr( + tgt_module, 'origin_module_keys'): + tgt_module.origin_module_keys = copy.deepcopy( + list(tgt_module._modules.keys())) + + setattr(tgt_module, f'forward_origin_{adapter_name}', + types.MethodType(_forward_seq, tgt_module)) + else: + setattr(tgt_module, f'forward_origin_{adapter_name}', + tgt_module.forward) + tgt_module.forward = types.MethodType(_forward_target, + tgt_module) + if target_module_ins is None: + raise Exception('Cannot match target modules') + + def state_dict_callback(state_dict, adapter_name): + return { + key: value + for key, value in state_dict.items() + if f'restuning_{adapter_name}' in key + } + + def mark_trainable_callback(model): + return + + return SwiftOutput(config, state_dict_callback, + mark_trainable_callback) + + @staticmethod + def activate_adapter(module: torch.nn.Module, adapter_name: str, + activate: bool): + modules: List[torch.nn.Module] = find_sub_module( + module, f'restuning_{adapter_name}') + for _module in modules: + _module: ActivationMixin + 
_module.set_activation(activate) + + +class ResTuningBypassModule(nn.Module, ActivationMixin): + """The implementation of ResTuningBypass method. + """ + + def __init__( + self, + dims, + depth, + use_upsample=False, + upsample_out_channels=None, + zero_init_last=False, + tuner_cfg=None, + ): + super(ResTuningBypassModule, self).__init__() + super(nn.Module, self).__init__() + + self.bypass_blocks = nn.Sequential(*[ + ResTunerBypassBlock( + dim=dims[i] if isinstance(dims, list) else dims, + layer_num=i, + depth=depth, + use_upsample=use_upsample, + upsample_out_channels=upsample_out_channels[i] if isinstance( + upsample_out_channels, list) else upsample_out_channels, + zero_init_last=zero_init_last, + tuner_cfg=tuner_cfg[i] if isinstance(tuner_cfg, list + ) else tuner_cfg) + for i in range(depth) + ]) + + def forward(self, x_list, origin_arg, **kwargs): + if not self.is_activated(): + return origin_arg + x_bypass = detach_tensors(x_list.pop(0)) + x_bypass = x_bypass[0] if isinstance(x_bypass, + (list, tuple)) else x_bypass + x_list = detach_tensors(x_list) + x_list = [ + _x[0] if isinstance(_x, (list, tuple)) else _x for _x in x_list + ] + for i, (bp_blk, x_stem) in enumerate(zip(self.bypass_blocks, x_list)): + target_size = x_list[ + i + 1].shape[2:] if i < len(x_list) - 1 else None + x_bypass = bp_blk(x_stem, x_bypass, target_size, **kwargs) + return x_bypass + + +class ResTunerBypassBlock(nn.Module): + + def __init__(self, + dim, + layer_num=-1, + depth=-1, + use_upsample=False, + zero_init_last=False, + tuner_cfg=None, + **kwargs): + super().__init__() + self.layer_num = layer_num + self.depth = depth + + if isinstance(tuner_cfg, str): + lateral_cfg = tuner_cfg + vertical_cfg = tuner_cfg + aux_cfg = 'upsample' if use_upsample and layer_num != depth - 1 else None + elif isinstance(tuner_cfg, dict): + lateral_cfg = tuner_cfg[ + 'lateral_cfg'] if 'lateral_cfg' in tuner_cfg else None + vertical_cfg = tuner_cfg[ + 'vertical_cfg'] if 'vertical_cfg' in tuner_cfg else None + aux_cfg = tuner_cfg['aux_cfg'] if 'aux_cfg' in tuner_cfg else None + + self.lateral_tuner = ResTuner(dim, layer_num, depth, zero_init_last, + 'lateral', lateral_cfg, **kwargs) + self.vertical_tuner = ResTuner(dim, layer_num, depth, zero_init_last, + 'vertical', vertical_cfg, **kwargs) + if aux_cfg and len(aux_cfg) != 0: + self.aux_tuner = ResTuner(dim, layer_num, depth, zero_init_last, + 'aux', aux_cfg, **kwargs) + + def forward(self, x_stem, x_bypass, target_size=None, **kwargs): + x_lateral = self.lateral_tuner(x_stem) + x_vertical = self.vertical_tuner(x_bypass) + + x_bypass_out = x_lateral + x_vertical + if hasattr(self, 'aux_tuner'): + x_bypass_out = self.aux_tuner(x_bypass_out, target_size) + return x_bypass_out diff --git a/swift/tuners/restuning_components.py b/swift/tuners/restuning_components.py new file mode 100644 index 0000000000..e7f02aa5d8 --- /dev/null +++ b/swift/tuners/restuning_components.py @@ -0,0 +1,398 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
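As an orientation aid for the new ResTuning files above, the following sketch shows how a `ResTuningConfig` might be filled in for a ViT-style backbone. The module patterns, dimensions and the `'res_adapter'` tuner choice are illustrative assumptions; the config is only constructed here, not attached to a real model.

```python
from swift import ResTuningConfig  # assumed to be exported like the other config classes

restuning_config = ResTuningConfig(
    dims=768,                          # hidden size of the probed stem features
    root_modules=r'.*patch_embed',     # probe the backbone input (first matching module)
    stem_modules=r'.*blocks\.\d+',     # probe every transformer block's output
    target_modules=r'.*norm',          # feed the bypass result into the final norm's input
    tuner_cfg='res_adapter',           # one of the tuner types defined in restuning_components.py
    use_bypass=True)
# A hypothetical attach step, mirroring the other tuners:
# model = Swift.prepare_model(vit_backbone, {'restuning': restuning_config})
print(restuning_config.swift_type)     # 'RESTUNING', set by __post_init__
```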
+import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from swift.utils.logger import get_logger + +logger = get_logger() + + +class ResTuner(nn.Module): + + def __init__(self, + dim=None, + layer_num=-1, + depth=-1, + zero_init_last=False, + stage='', + tuner_cfg={}, + **kwargs): + super().__init__() + self.dim = dim + self.layer_num = layer_num + self.depth = depth + self.stage = stage + self.tuner_cfg = tuner_cfg + + if (isinstance(tuner_cfg, str) and tuner_cfg == 'res_adapter') or \ + (isinstance(tuner_cfg, dict) and 'res_adapter' in tuner_cfg): + tuner_cfg = tuner_cfg['res_adapter'] if isinstance( + tuner_cfg, dict) else tuner_cfg + self.tuner = ResAdapter( + dim=dim, + layer_num=layer_num, + depth=depth, + zero_init_last=zero_init_last, + stage=stage, + tuner_cfg=tuner_cfg, + **kwargs) + elif (isinstance(tuner_cfg, str) and tuner_cfg == 'res_group_adapter') or \ + (isinstance(tuner_cfg, dict) and 'res_group_adapter' in tuner_cfg): + tuner_cfg = tuner_cfg['res_group_adapter'] if isinstance( + tuner_cfg, dict) else tuner_cfg + self.tuner = ResGroupAdapter( + dim=dim, + layer_num=layer_num, + depth=depth, + zero_init_last=zero_init_last, + stage=stage, + tuner_cfg=tuner_cfg, + **kwargs) + elif (isinstance(tuner_cfg, str) and tuner_cfg == 'upsample') or \ + (isinstance(tuner_cfg, dict) and 'upsample' in tuner_cfg): + tuner_cfg = tuner_cfg['upsample'] if isinstance( + tuner_cfg, dict) else tuner_cfg + if 'upsample_out_channels' in kwargs: + out_channels = kwargs['upsample_out_channels'] + use_conv = True if out_channels else False + else: + out_channels = dim + use_conv = False + self.tuner = Upsample( + channels=dim, + use_conv=use_conv, + out_channels=out_channels, + tuner_cfg=tuner_cfg, + **kwargs) + else: + self.tuner = Identity() + + def forward(self, x, *args, **kwargs): + if self.tuner_cfg == 'zero' or 'zero' in self.tuner_cfg: + x_out = 0.0 + else: + x_out = self.tuner(x, *args, **kwargs) + return x_out + + +class ResAdapter(nn.Module): + + def __init__(self, + dim, + layer_num=-1, + depth=-1, + zero_init_last=False, + stage='', + tuner_cfg=None, + act_layer=nn.GELU, + **kwargs): + super(ResAdapter, self).__init__() + self.dim = dim + self.layer_num = layer_num + self.depth = depth + + self.adapter_length = tuner_cfg[ + 'adapter_length'] if 'adapter_length' in tuner_cfg else 32 + self.adapter_type = tuner_cfg[ + 'adapter_type'] if 'adapter_type' in tuner_cfg else None + self.adapter_weight = tuner_cfg[ + 'adapter_weight'] if 'adapter_weight' in tuner_cfg else None + + self.adapter_length = self.adapter_length[ + self.layer_num] if isinstance(self.adapter_length, + list) else self.adapter_length + assert isinstance(self.adapter_length, + int) or (isinstance(self.adapter_length, tuple) + and len(self.adapter_length) == 3) + if isinstance(self.adapter_length, int): + self.ln1 = nn.Linear(dim, self.adapter_length) + else: + self.ln1 = nn.Linear(self.adapter_length[0], + self.adapter_length[1]) + self.activate = act_layer() + if isinstance(self.adapter_length, int): + self.ln2 = nn.Linear(self.adapter_length, dim) + else: + self.ln2 = nn.Linear(self.adapter_length[1], + self.adapter_length[2]) + dim = self.adapter_length[2] + + self._xavier_init_weights(self.ln1) + if zero_init_last and layer_num == depth - 1: + self._zero_init_weights(self.ln2) + else: + self._xavier_init_weights(self.ln2) + + self.scaling = init_weight_type(dim, self.adapter_weight) + self._prepared = False + + def _zero_init_weights(self, m): + if 
isinstance(m, nn.Linear): + nn.init.zeros_(m.weight) + nn.init.zeros_(m.bias) + + def _kaiming_init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)) + nn.init.normal_(m.bias) + + def _xavier_init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, x): + if not self._prepared: + self.ln1.to(x.device) + self.activate.to(x.device) + self.ln2.to(x.device) + self._prepared = True + + x_dtype = x.dtype + x = x.to(self.ln1.weight.dtype) + x_shortcut = x + if len(x_shortcut.size()) == 4: + B, C, N1, N2 = x.size() + x = x.view(x_shortcut.size()[0], + x_shortcut.size()[1], -1).permute(0, 2, 1) + + x_adapter = self.ln2(self.activate(self.ln1(x))) + + if self.adapter_weight: + x_adapter = apply_data_weight(x_adapter, self.scaling, + self.adapter_weight) + + if len(x_shortcut.size()) == 4: + x_adapter = x_adapter.permute(0, 2, + 1).view(x_shortcut.size()[0], + x_adapter.size()[-1], + x_shortcut.size()[2], + x_shortcut.size()[3]) + x_out = x_shortcut + x_adapter + return x_out.to(x_dtype) + + +class ResGroupAdapter(nn.Module): + + def __init__(self, + dim, + layer_num=-1, + depth=-1, + zero_init_last=False, + stage='', + tuner_cfg=None, + act_layer=nn.GELU, + **kwargs): + super(ResGroupAdapter, self).__init__() + self.dim = dim + self.layer_num = layer_num + self.depth = depth + + self.adapter_type = tuner_cfg[ + 'adapter_type'] if 'adapter_type' in tuner_cfg else None + self.adapter_weight = tuner_cfg[ + 'adapter_weight'] if 'adapter_weight' in tuner_cfg else None + + self.adapter_dim = tuner_cfg['dim'] if 'dim' in tuner_cfg else dim + self.adapter_head = tuner_cfg['head'] if 'head' in tuner_cfg else 4 + self.adapter_scale_factor = tuner_cfg[ + 'scale_factor'] if 'scale_factor' in tuner_cfg else 2 + + assert self.adapter_dim % self.adapter_head == 0, 'adapter dim should be divisible by adapter head' + self.dim_mlp = self.adapter_dim // self.adapter_head + + self.ln1 = nn.Linear(self.dim_mlp, + self.dim_mlp * self.adapter_scale_factor) + self.ln2 = nn.Linear(self.dim_mlp * self.adapter_scale_factor, + self.dim_mlp) + self.activate = act_layer() + + self._kaiming_init_weights(self.ln1) + if zero_init_last and layer_num == depth - 1: + self._zero_init_weights(self.ln2) + else: + self._kaiming_init_weights(self.ln2) + self.scaling = init_weight_type(dim, self.adapter_weight) + self._prepared = False + + def _zero_init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.weight) + nn.init.zeros_(m.bias) + + def _kaiming_init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)) + nn.init.normal_(m.bias) + + def _xavier_init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, x): + if not self._prepared: + self.ln1.to(x.device) + self.activate.to(x.device) + self.ln2.to(x.device) + self._prepared = True + + x_dtype = x.dtype + x = x.to(self.ln1.weight.dtype) + x_shortcut = x + + batch, inner_dim, height, width = x.shape + + x_adapter = x.permute(0, 2, 3, 1).reshape(batch, height * width, + inner_dim) + + x_adapter = rearrange( + x_adapter, 'b n (c h) -> (b h) n c', h=self.adapter_head) + x_adapter = self.ln2(self.activate(self.ln1(x_adapter))) + x_adapter = rearrange( + x_adapter, '(b h) n c -> b n (c h)', h=self.adapter_head) + + if self.adapter_weight: + x_adapter = apply_data_weight(x_adapter, self.scaling, + 
self.adapter_weight) + + x_adapter = x_adapter.reshape(batch, height, width, + -1).permute(0, 3, 1, 2).contiguous() + x_out = x_shortcut + x_adapter + + return x_out.to(x_dtype) + + +class Identity(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, inputs, *args, **kwargs): + return inputs + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, + channels, + use_conv=False, + out_channels=None, + padding=1, + **kwargs): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + if use_conv: + self.conv = nn.Conv2d( + self.channels, self.out_channels, 3, padding=padding) + self.init_weights() + + def init_weights(self): + + def _init_weights(m): + if isinstance(m, nn.Conv2d): + nn.init.zeros_(m.weight) + nn.init.zeros_(m.bias) + + self.apply(_init_weights) + + def forward(self, x, target_size=None, *args, **kwargs): + assert x.shape[1] == self.channels + if target_size is None: + x = F.interpolate( + x.float(), scale_factor=2, mode='nearest').type_as(x) + else: + x = F.interpolate( + x.float(), target_size, mode='nearest').type_as(x) + if self.use_conv: + x = self.conv(x) + return x + + +def init_weight_type(dim, weight_type): + if weight_type is None: + scaling = None + elif weight_type == 'gate': + scaling = nn.Linear(dim, 1) + elif weight_type == 'scale': + scaling = nn.Parameter(torch.Tensor(1)) + scaling.data.fill_(1) + elif weight_type == 'scale_kv': + scaling_k = nn.Parameter(torch.Tensor(1)) + scaling_k.data.fill_(1) + scaling_v = nn.Parameter(torch.Tensor(1)) + scaling_v.data.fill_(1) + scaling = (scaling_k, scaling_v) + elif weight_type == 'scale_channel': + scaling = nn.Parameter(torch.Tensor(dim)) + scaling.data.fill_(1) + elif weight_type == 'scale_kv_channel': + scaling_k = nn.Parameter(torch.Tensor(dim)) + scaling_k.data.fill_(1) + scaling_v = nn.Parameter(torch.Tensor(dim)) + scaling_v.data.fill_(1) + scaling = (scaling_k, scaling_v) + elif weight_type and weight_type.startswith('scalar'): + scaling = float(weight_type.split('_')[-1]) + else: + scaling = None + return scaling + + +def apply_data_weight(data, scaling, weight_type): + if weight_type in ['gate']: + scaling = torch.mean( + torch.sigmoid(scaling(data)), dim=1).view(-1, 1, 1) + elif weight_type in ['scale', 'scale_channel' + ] or weight_type.startswith('scalar'): + scaling = scaling + else: + scaling = None + if scaling is not None: + data = data * scaling + return data + + +def detach_tensors(feats): + if type(feats) in [list, tuple]: + feats = [ + detach_tensors(feat) if feat is not None else None + for feat in feats + ] + elif isinstance(feats, dict): + feats = {key: detach_tensors(val) for key, val in feats.items()} + elif isinstance(feats, torch.Tensor): + feats = feats.detach() + else: + feats = feats.detach() + return feats + + +def probe_tensors(module, feats, name): + feats = detach_tensors(feats) + setattr(module, name, feats) + + +def probe_input_pre_hook(self, args): + input = args[0] + probe_tensors(self, input, 'probe_input_data') + return args + + +def probe_output_hook(self, args, result): + output = result + probe_tensors(self, output, 'probe_output_data') + return output diff --git a/swift/tuners/side.py 
b/swift/tuners/side.py new file mode 100644 index 0000000000..168cc2bb2c --- /dev/null +++ b/swift/tuners/side.py @@ -0,0 +1,296 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import copy +import re +import types +from collections import OrderedDict +from dataclasses import dataclass, field +from functools import partial +from itertools import repeat +from typing import List, Union + +import torch +import torchvision +from torch import nn + +from swift.utils.logger import get_logger +from ..utils.torch_utils import find_sub_module +from .utils import ActivationMixin, SwiftConfig, SwiftOutput + +logger = get_logger() + + +@dataclass +class SideConfig(SwiftConfig): + """ + The configuration class for the side module. + + Side-Tuning only needs to train one side network and + weights the output of pre-trained model and side network. + 'Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks' + by Zhang et al.(2019) + See https://arxiv.org/abs/1912.13503 + + Args: + target_modules: The feedforward module to be replaced, in regex format + """ + + dim: int = field( + default=None, metadata={'help': 'The dimension of the hidden states'}) + + target_modules: str = field( + default=None, + metadata={ + 'help': 'The target module to be replaced, in full match format' + }) + + side_module_name: str = field( + default='fcn4', + metadata={'help': 'The name of the additive side networks'}) + + source_hidden_pos: Union[str, int] = field( + default=0, + metadata={ + 'help': + 'The position of the hidden state input to the target module, can be int (args) or str (kwargs)' + }) + + target_hidden_pos: Union[str, int] = field( + default=0, + metadata={ + 'help': + 'The position of the hidden state output from the target module, can be int (args) or str (kwargs)' + }) + + def __post_init__(self): + from .mapping import SwiftTuners + self.swift_type = SwiftTuners.SIDE + + +class Side: + + @staticmethod + def prepare_model(model: nn.Module, config: SideConfig, + adapter_name: str) -> SwiftOutput: + """Prepare a model with `SideConfig`""" + module_keys = [key for key, _ in model.named_modules()] + + for module_key in module_keys: + if re.fullmatch(config.target_modules, module_key): # noqa + tgt_module = model.get_submodule(module_key) + logger.info( + f'Matching target module [{module_key}] of type {type(tgt_module)}' + ) + if isinstance(tgt_module, (nn.ModuleList, nn.ModuleDict)): + raise Exception( + f'Type of {type(tgt_module)} may not be supported because of its customized forward' + ) + + def _forward(self, *args, **kwargs): + args_main = getattr( + self, f'forward_origin_{adapter_name}')(*args, + **kwargs) + + if isinstance(config.source_hidden_pos, int): + x = args[config.source_hidden_pos] + else: + x = kwargs[config.source_hidden_pos] + + x_main = args_main[config.target_modules] \ + if isinstance(args_main, (tuple, list, dict)) else args_main + out = getattr(self, f'side_{adapter_name}')(x, x_main) + if isinstance(args_main, (tuple, list, dict)): + args_main[config.target_modules] = out + else: + args_main = out + return args_main + + if isinstance(tgt_module, nn.Sequential) and not hasattr( + tgt_module, 'tgt_module_keys'): + tgt_module.tgt_module_keys = copy.deepcopy( + list(tgt_module._modules.keys())) + + def forward_seq(self, input, *args, **kwargs): + for idx, module in enumerate(self): + if idx >= len(tgt_module.tgt_module_keys): + continue + input = module(input) + return input + + setattr(tgt_module, f'forward_origin_{adapter_name}', + types.MethodType(forward_seq, 
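The patching idiom used here by `Side.prepare_model` (keep the original `forward` under a per-adapter attribute, then bind a wrapper in its place with `types.MethodType`) reduces to the standalone sketch below. The toy module and attribute name are illustrative; in the real code the wrapper also routes the hidden state through the side network.

```python
import types
import torch
from torch import nn

model = nn.Linear(4, 4)

# Keep a handle to the original bound forward, then install a wrapper in its place
model.forward_origin_default = model.forward

def _forward(self, *args, **kwargs):
    out = self.forward_origin_default(*args, **kwargs)
    # ... a tuner would combine `out` with its own branch here ...
    return out

model.forward = types.MethodType(_forward, model)

x = torch.ones(1, 4)
assert torch.allclose(model(x), model.forward_origin_default(x))
```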
tgt_module)) + else: + setattr(tgt_module, f'forward_origin_{adapter_name}', + tgt_module.forward) + tgt_module.forward = types.MethodType(_forward, tgt_module) + side_module = SideModule(config.dim, config.side_module_name) + setattr(tgt_module, f'side_{adapter_name}', side_module) + logger.info( + f'Side modules(module_key): {module_key}.side_{adapter_name}' + ) + + def state_dict_callback(state_dict, adapter_name): + return { + key: value + for key, value in state_dict.items() + if f'side_{adapter_name}' in key + } + + def mark_trainable_callback(model): + return + + return SwiftOutput(config, state_dict_callback, + mark_trainable_callback) + + @staticmethod + def activate_adapter(module: torch.nn.Module, adapter_name: str, + activate: bool): + modules: List[torch.nn.Module] = find_sub_module( + module, f'side_{adapter_name}') + for _module in modules: + _module: ActivationMixin + _module.set_activation(activate) + + +class SideModule(nn.Module, ActivationMixin): + """The implementation of vision side-tuning method. + + Side-Tuning only needs to train one side network and + weights the output of pre-trained model and side network. + 'Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks' + by Zhang et al.(2019) + See https://arxiv.org/abs/1912.13503 + + Attributes: + side_module_name: The name of the additive side networks. + """ + + def __init__(self, dim, side_module_name='fcn4'): + super(SideModule, self).__init__() + super(nn.Module, self).__init__() + + side_module_name = side_module_name.lower() + if side_module_name == 'fcn4': + self.side_net = FCN4(out_dims=dim) + elif side_module_name == 'mlp': + self.side_net = Mlp(dim) + elif side_module_name == 'alexnet': + mm = torchvision.models.alexnet(pretrained=True) + self.side_net = nn.Sequential( + OrderedDict([('features', mm.features), + ('avgpool', mm.avgpool), + ('flatten', nn.Flatten()), + ('fc', nn.Linear(9216, dim, bias=False))])) + else: + raise ValueError( + f'Unsupported side_module_name: {side_module_name}') + self.alpha = nn.Parameter(torch.tensor(0.0)) + + def forward(self, x, x_main): + if not self.is_activated(): + return x_main + alpha_squashed = torch.sigmoid(self.alpha) + x_side = self.side_net(x) + x_out = alpha_squashed * x_main + (1 - alpha_squashed) * x_side + return x_out + + +class FCN4(nn.Module): + """The implementation of simple FCN4 network for side network. + """ + + def __init__(self, out_dims=-1, **kwargs): + super(FCN4, self).__init__(**kwargs) + + self.conv1 = nn.Sequential( + nn.Conv2d( + 3, + 16, + kernel_size=3, + stride=1, + padding=1, + bias=False, + dilation=1), nn.GroupNorm(2, 16), nn.ReLU()) + self.conv2 = nn.Sequential( + nn.Conv2d( + 16, + 16, + kernel_size=3, + stride=2, + padding=0, + bias=False, + dilation=1), nn.GroupNorm(2, 16), nn.ReLU()) + self.conv3 = nn.Sequential( + nn.Conv2d( + 16, + 32, + kernel_size=3, + stride=2, + padding=0, + bias=False, + dilation=1), nn.GroupNorm(2, 32), nn.ReLU()) + self.conv4 = nn.Sequential( + nn.Conv2d( + 32, + 64, + kernel_size=3, + stride=1, + padding=0, + bias=False, + dilation=1), nn.GroupNorm(2, 64), nn.ReLU()) + self.pool = nn.AdaptiveAvgPool2d((1, 1)) + if out_dims > 0: + self.fc = nn.Linear(64, out_dims) + else: + self.fc = None + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + x = self.pool(x) + x = x.view(x.size(0), -1) + if self.fc is not None: + x = self.fc(x) + return x + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer. 
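`SideModule.forward` above blends the frozen base output and the side-network output through a single learned gate. Because `alpha` is initialised to 0, `sigmoid(alpha)` is 0.5 and training starts from an even mix of the two branches; a standalone numerical sketch with toy tensors:

```python
import torch

alpha = torch.tensor(0.0)          # SideModule initialises alpha to 0
x_main = torch.full((2, 3), 1.0)   # output of the frozen base module
x_side = torch.full((2, 3), 3.0)   # output of the trainable side network

gate = torch.sigmoid(alpha)        # 0.5 at initialisation
x_out = gate * x_main + (1 - gate) * x_side
print(x_out)                       # every entry is 2.0: an even mix of the two branches
```

As training moves `alpha`, the mix shifts toward either the base or the side branch; deactivating the adapter simply returns `x_main` unchanged, which is what the activate/deactivate tests rely on.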
+ """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0., + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = tuple(repeat(bias, 2)) + drop_probs = tuple(repeat(drop, 2)) + linear_layer = partial( + nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer( + hidden_features) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x diff --git a/swift/tuners/utils.py b/swift/tuners/utils.py index 42faa94e84..7289773532 100644 --- a/swift/tuners/utils.py +++ b/swift/tuners/utils.py @@ -2,8 +2,10 @@ # Copyright 2023-present the HuggingFace Inc. team. import os +import threading from dataclasses import asdict, dataclass, field from types import FunctionType +from typing import Dict import json from peft.utils import CONFIG_NAME @@ -109,10 +111,10 @@ class SwiftOutput: which is used to get the tuner's state dict among the model's state dict. This callback should receive a state dict, and returns a created state dict. Examples: - >>> def state_dict_callback(state_dict): + >>> def state_dict_callback(state_dict, adapter_name): >>> return { >>> key: value - >>> for key, value in state_dict.items() if 'adapter' in key + >>> for key, value in state_dict.items() if adapter_name in key >>> } mark_trainable_callback (`FunctionType`): A callback returned by the tuner which is used to mark the tuner's adapter's parameters to trainable. @@ -125,3 +127,21 @@ class SwiftOutput: config: SwiftConfig = None state_dict_callback: FunctionType = None mark_trainable_callback: FunctionType = None + + +class ActivationMixin: + + USE_UNIQUE_THREAD = 'USE_UNIQUE_THREAD' + + def __init__(self): + self._thread_inf: Dict[int, bool] = {} + self._unique_thread = bool( + int(os.environ.get(ActivationMixin.USE_UNIQUE_THREAD, '0'))) + + def set_activation(self, activate=True): + tid = 0 if self._unique_thread else threading.get_ident() + self._thread_inf[tid] = activate + + def is_activated(self): + tid = 0 if self._unique_thread else threading.get_ident() + return self._thread_inf.get(tid, True) diff --git a/swift/utils/llm_utils.py b/swift/utils/llm_utils.py index 3ae6e3aca7..b61ce173d7 100644 --- a/swift/utils/llm_utils.py +++ b/swift/utils/llm_utils.py @@ -41,6 +41,7 @@ def data_collate_fn(batch: List[Dict[str, Any]], tokenizer) -> Dict[str, Any]: attention_mask = pad_sequence( attention_mask, batch_first=True, padding_value=0) labels = pad_sequence(labels, batch_first=True, padding_value=-100) + return { 'input_ids': input_ids, 'attention_mask': attention_mask, diff --git a/swift/utils/torch_utils.py b/swift/utils/torch_utils.py index d993c7318d..b51453df2e 100644 --- a/swift/utils/torch_utils.py +++ b/swift/utils/torch_utils.py @@ -89,13 +89,24 @@ def print_model_info(model: Module, name: Optional[str] = None) -> None: n_params /= 1e6 n_grads /= 1e6 n_buffers /= 1e6 - s = [ - f'{name}: ', - f'{n_params:.4f}M Params ({n_grads:.4f}M Trainable), ', - f'{n_buffers:.4f}M Buffers', - ] - s += '.' 
- logger.info(''.join(s)) + s = (f'{name}: ' + f'{n_params:.4f}M Params ({n_grads:.4f}M Trainable ' + f'[{100 * n_grads / n_params:.4f}%]), ' + f'{n_buffers:.4f}M Buffers.') + logger.info(s) + + +def find_sub_module(module: torch.nn.Module, + module_name: str) -> List[torch.nn.Module]: + _modules = list() + for name, sub_module in module.named_modules(): + if not name: + continue + if module_name == name: + _modules.append(sub_module) + else: + _modules.extend(find_sub_module(sub_module, module_name)) + return _modules def get_seed(random_state: RandomState) -> int: diff --git a/tests/tuners/test_swift_base.py b/tests/tuners/test_swift_base.py index 715fd0c743..f6deec9f86 100644 --- a/tests/tuners/test_swift_base.py +++ b/tests/tuners/test_swift_base.py @@ -1,16 +1,21 @@ import copy +import math import os import shutil import tempfile import unittest +from concurrent.futures import ThreadPoolExecutor from time import time import torch +from modelscope import Model, Preprocessor from modelscope.models.nlp.structbert import (SbertConfig, SbertForSequenceClassification) from peft.utils import WEIGHTS_NAME +from torch import nn -from swift import AdapterConfig, LoRAConfig, Swift, SwiftModel, push_to_hub +from swift import (AdapterConfig, LoRAConfig, PromptConfig, ResTuningConfig, + SideConfig, Swift, SwiftModel, push_to_hub) class TestSwift(unittest.TestCase): @@ -25,12 +30,140 @@ def tearDown(self): shutil.rmtree(self.tmp_dir) super().tearDown() + def test_swift_lora_forward(self): + + from swift.tuners.lora import Linear + + def reset_parameters(self): + nn.Linear.reset_parameters(self) + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.ones_(self.lora_B) + + Linear.reset_parameters = reset_parameters + + model = Model.from_pretrained( + 'damo/nlp_structbert_sentence-similarity_chinese-base') + preprocessor = Preprocessor.from_pretrained( + 'damo/nlp_structbert_sentence-similarity_chinese-base') + inputs = preprocessor('how are you') + lora_config = LoRAConfig(target_modules=['query', 'key', 'value']) + outputs = model(**inputs) + model = Swift.prepare_model(model, config=lora_config) + outputs_lora = model(**inputs) + model.deactivate_adapter('default') + outputs_deactivate = model(**inputs) + model.activate_adapter('default') + outputs_reactivate = model(**inputs) + self.assertTrue( + torch.allclose(outputs.logits, outputs_deactivate.logits)) + self.assertTrue( + not torch.allclose(outputs.logits, outputs_lora.logits)) + self.assertTrue( + torch.allclose(outputs_lora.logits, outputs_reactivate.logits)) + + def test_swift_adapter_forward(self): + model = Model.from_pretrained( + 'damo/nlp_structbert_sentence-similarity_chinese-base') + preprocessor = Preprocessor.from_pretrained( + 'damo/nlp_structbert_sentence-similarity_chinese-base') + inputs = preprocessor('how are you') + adapter_config = AdapterConfig( + dim=model.config.hidden_size, + target_modules=r'.*layer\.\d+$', + method_name='feed_forward_chunk', + hidden_pos=0) + outputs = model(**inputs) + model = Swift.prepare_model(model, config=adapter_config) + outputs_lora = model(**inputs) + model.deactivate_adapter('default') + outputs_deactivate = model(**inputs) + model.activate_adapter('default') + outputs_reactivate = model(**inputs) + self.assertTrue( + torch.allclose(outputs.logits, outputs_deactivate.logits)) + self.assertTrue( + not torch.allclose(outputs.logits, outputs_lora.logits)) + self.assertTrue( + 
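The `find_sub_module` helper added to `swift/utils/torch_utils.py` above walks `named_modules()` and collects the sub-modules whose name matches exactly; the Side tuner uses it to locate `side_{adapter_name}` modules when toggling activation. A standalone usage sketch on a toy module follows; the import path is the one this diff introduces, and the module names are illustrative.

```python
from torch import nn
from swift.utils.torch_utils import find_sub_module  # helper added in this diff

model = nn.Module()
model.encoder = nn.Linear(4, 4)
model.side_default = nn.Linear(4, 4)  # the kind of name the Side tuner registers

matches = find_sub_module(model, 'side_default')
print([type(m).__name__ for m in matches])  # ['Linear']
```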
torch.allclose(outputs_lora.logits, outputs_reactivate.logits)) + + def test_swift_prompt_forward(self): + model = Model.from_pretrained( + 'damo/nlp_structbert_sentence-similarity_chinese-base') + preprocessor = Preprocessor.from_pretrained( + 'damo/nlp_structbert_sentence-similarity_chinese-base') + inputs = preprocessor('how are you') + prompt_config = PromptConfig( + dim=model.config.hidden_size, + target_modules=r'.*layer\.\d+$', + embedding_pos=0, + attention_mask_pos=1) + outputs = model(**inputs) + model = Swift.prepare_model(model, config=prompt_config) + outputs_lora = model(**inputs) + model.deactivate_adapter('default') + outputs_deactivate = model(**inputs) + model.activate_adapter('default') + outputs_reactivate = model(**inputs) + self.assertTrue( + torch.allclose(outputs.logits, outputs_deactivate.logits)) + self.assertTrue( + not torch.allclose(outputs.logits, outputs_lora.logits)) + self.assertTrue( + torch.allclose(outputs_lora.logits, outputs_reactivate.logits)) + + def test_swift_restuner_forward(self): + model = Model.from_pretrained( + 'damo/nlp_structbert_sentence-similarity_chinese-base') + preprocessor = Preprocessor.from_pretrained( + 'damo/nlp_structbert_sentence-similarity_chinese-base') + inputs = preprocessor('how are you') + restuner_config = ResTuningConfig( + dims=model.config.hidden_size, + root_modules=r'.*layer.0$', + stem_modules=r'.*layer\.\d+$', + target_modules=r'.*pooler', + target_modules_hook='input', + tuner_cfg='res_adapter', + ) + outputs = model(**inputs) + model = Swift.prepare_model(model, config=restuner_config) + outputs_lora = model(**inputs) + model.deactivate_adapter('default') + outputs_deactivate = model(**inputs) + model.activate_adapter('default') + outputs_reactivate = model(**inputs) + self.assertTrue( + torch.allclose(outputs.logits, outputs_deactivate.logits)) + self.assertTrue( + not torch.allclose(outputs.logits, outputs_lora.logits)) + self.assertTrue( + torch.allclose(outputs_lora.logits, outputs_reactivate.logits)) + def test_swift_lora_injection(self): - model = SbertForSequenceClassification(SbertConfig()) + + from swift.tuners.lora import Linear + + def reset_parameters(self): + nn.Linear.reset_parameters(self) + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.ones_(self.lora_A) + nn.init.ones_(self.lora_B) + + Linear.reset_parameters = reset_parameters + + model = Model.from_pretrained( + 'damo/nlp_structbert_sentence-similarity_chinese-base') + preprocessor = Preprocessor.from_pretrained( + 'damo/nlp_structbert_sentence-similarity_chinese-base') + input = preprocessor('this is a test') model2 = copy.deepcopy(model) lora_config = LoRAConfig(target_modules=['query', 'key', 'value']) model = Swift.prepare_model(model, config=lora_config) self.assertTrue(isinstance(model, SwiftModel)) + output1 = model(**input) model.save_pretrained(self.tmp_dir) self.assertTrue(os.path.exists(os.path.join(self.tmp_dir, 'default'))) self.assertTrue( @@ -38,7 +171,8 @@ def test_swift_lora_injection(self): os.path.join(self.tmp_dir, 'default', WEIGHTS_NAME))) model2 = Swift.from_pretrained(model2, self.tmp_dir) - + output2 = model2(**input) + self.assertTrue(torch.allclose(output1.logits, output2.logits)) state_dict = model.state_dict() state_dict2 = model2.state_dict() for key in state_dict: @@ -92,3 +226,172 @@ def test_swift_multiple_adapters(self): all( torch.isclose(state_dict[key], state_dict2[key]).flatten().detach().cpu())) + + def 
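A note on why the tests above and below monkey-patch `Linear.reset_parameters`: with the standard LoRA initialisation the `lora_B` matrix is zero, so a freshly created adapter contributes nothing and would be indistinguishable from a deactivated one. Overriding the init to ones makes the adapter's contribution non-zero, which is presumably what lets the assertions distinguish activated, deactivated and re-activated outputs. A standalone sketch of that effect (toy shapes, no Swift involved):

```python
import math
import torch
from torch import nn

in_f, out_f, r = 8, 8, 4
x = torch.randn(2, in_f)

lora_A = torch.empty(r, in_f)
nn.init.kaiming_uniform_(lora_A, a=math.sqrt(5))

# Standard LoRA init: B = 0, so the low-rank update B @ A is exactly zero
lora_B = torch.zeros(out_f, r)
assert torch.allclose(x @ (lora_B @ lora_A).T, torch.zeros(2, out_f))

# The tests' override: B = 1, so toggling the adapter visibly changes the output
lora_B = torch.ones(out_f, r)
assert not torch.allclose(x @ (lora_B @ lora_A).T, torch.zeros(2, out_f))
```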
test_swift_multiple_adapters_switching(self): + from swift.tuners.lora import Linear + from swift.tuners.adapter import AdapterModule + + def reset_parameters(self): + nn.Linear.reset_parameters(self) + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.ones_(self.lora_A) + nn.init.ones_(self.lora_B) + + Linear.reset_parameters = reset_parameters + + def init_weights(self): + + def _init_weights(m): + if isinstance(m, nn.Linear): + nn.init.ones_(m.weight) + nn.init.ones_(m.bias) + + self.apply(_init_weights) + + AdapterModule.init_weights = init_weights + + model = Model.from_pretrained( + 'damo/nlp_structbert_sentence-similarity_chinese-base') + preprocessor = Preprocessor.from_pretrained( + 'damo/nlp_structbert_sentence-similarity_chinese-base') + inputs = preprocessor('how are you') + model1 = copy.deepcopy(model) + model2 = copy.deepcopy(model) + model1 = Swift.prepare_model( + model1, + config={ + 'lora1': + LoRAConfig(target_modules=['query', 'key', 'value']), + 'adapter1': + AdapterConfig( + dim=model.config.hidden_size, + target_modules=r'.*layer\.\d+$', + method_name='feed_forward_chunk', + hidden_pos=0) + }) + model2 = Swift.prepare_model( + model2, + config={ + 'lora2': + LoRAConfig(target_modules=['query', 'key', 'value']), + 'adapter2': + AdapterConfig( + dim=model.config.hidden_size, + target_modules=r'.*layer\.\d+$', + method_name='feed_forward_chunk', + hidden_pos=0) + }) + model = Swift.prepare_model( + model, + config={ + 'lora1': LoRAConfig(target_modules=['query', 'key', 'value']), + 'lora2': LoRAConfig(target_modules=['query', 'key', 'value']), + }) + + model = Swift.prepare_model( + model, + config={ + 'adapter1': + AdapterConfig( + dim=model.config.hidden_size, + target_modules=r'.*layer\.\d+$', + method_name='feed_forward_chunk', + hidden_pos=0), + 'adapter2': + AdapterConfig( + dim=model.config.hidden_size, + target_modules=r'.*layer\.\d+$', + method_name='feed_forward_chunk', + hidden_pos=0), + }) + + model.deactivate_adapter('adapter2') + model.deactivate_adapter('lora2') + outputs1 = model(**inputs) + outputs2 = model1(**inputs) + self.assertTrue(torch.allclose(outputs1.logits, outputs2.logits)) + model.activate_adapter('adapter2') + model.activate_adapter('lora2') + model.deactivate_adapter('adapter1') + model.deactivate_adapter('lora1') + outputs1 = model(**inputs) + outputs2 = model2(**inputs) + self.assertTrue(torch.allclose(outputs1.logits, outputs2.logits)) + + def thread_func1(): + model.set_active_adapters(['lora1', 'adapter1']) + outputs_single = model1(**inputs) + outputs_t1 = model(**inputs) + self.assertTrue( + torch.allclose(outputs_single.logits, outputs_t1.logits)) + + def thread_func2(): + model.set_active_adapters(['lora2', 'adapter2']) + outputs_single = model2(**inputs) + outputs_t2 = model(**inputs) + self.assertTrue( + torch.allclose(outputs_single.logits, outputs_t2.logits)) + + with ThreadPoolExecutor(2) as executor: + f1 = executor.submit(thread_func1) + f2 = executor.submit(thread_func2) + e1 = f1.exception() + e2 = f2.exception() + if e1 is not None: + raise e1 + if e2 is not None: + raise e2 + + def test_swift_side_bert(self): + model = Model.from_pretrained( + 'damo/nlp_structbert_sentence-similarity_chinese-base') + preprocessor = Preprocessor.from_pretrained( + 'damo/nlp_structbert_sentence-similarity_chinese-base') + inputs = preprocessor('how are you') + model2 = copy.deepcopy(model) + result_origin = model(**inputs).logits + print( + f'test_swift_side_bert result_origin 
shape: {result_origin.shape}, ' + f'result_origin sum: {torch.sum(result_origin)}') + + side_config = SideConfig( + dim=model.config.hidden_size, + target_modules=r'.*encoder.encoder', + side_module_name='mlp', + target_hidden_pos='last_hidden_state') + + model = Swift.prepare_model(model, config=side_config) + result_activate = model(**inputs).logits + model.deactivate_adapter('default') + result_deactivate = model(**inputs).logits + model.activate_adapter('default') + result_reactivate = model(**inputs).logits + self.assertTrue(torch.allclose(result_origin, result_deactivate)) + self.assertTrue(not torch.allclose(result_origin, result_activate)) + self.assertTrue(torch.allclose(result_activate, result_reactivate)) + print( + f'test_swift_side_bert result shape: {result_origin.shape}, result sum: {torch.sum(result_origin)}' + ) + + self.assertTrue(isinstance(model, SwiftModel)) + model.save_pretrained(self.tmp_dir) + self.assertTrue(os.path.exists(os.path.join(self.tmp_dir, 'default'))) + self.assertTrue( + os.path.exists( + os.path.join(self.tmp_dir, 'default', WEIGHTS_NAME))) + + model2 = Swift.from_pretrained(model2, self.tmp_dir) + + state_dict = model.state_dict() + state_dict2 = model2.state_dict() + for key in state_dict: + self.assertTrue(key in state_dict2) + self.assertTrue( + all( + torch.isclose(state_dict[key], + state_dict2[key]).flatten().detach().cpu())) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tuners/test_swift_restuning.py b/tests/tuners/test_swift_restuning.py new file mode 100644 index 0000000000..016c8d7361 --- /dev/null +++ b/tests/tuners/test_swift_restuning.py @@ -0,0 +1,152 @@ +import copy +import os +import shutil +import tempfile +import unittest + +import torch + +from swift import ResTuningConfig, Swift, SwiftModel, snapshot_download + + +class TestSwiftResTuning(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + def set_random_seed(self, seed=123): + """Set random seed manually to get deterministic results""" + import random + import numpy as np + import torch + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + def model_comparison(self, model, model2): + model_key = list(model.state_dict().keys()) + model2_key = list(model2.state_dict().keys()) + self.assertTrue(model_key == model2_key) + model_val = torch.sum( + torch.stack( + [torch.sum(val) for val in model.state_dict().values()])) + model2_val = torch.sum( + torch.stack( + [torch.sum(val) for val in model2.state_dict().values()])) + self.assertTrue(torch.isclose(model_val, model2_val)) + + def test_swift_restuning_vit(self): + model_dir = snapshot_download('AI-ModelScope/vit-base-patch16-224') + from transformers import AutoModelForImageClassification + model = AutoModelForImageClassification.from_pretrained(model_dir) + model_swift_1 = copy.deepcopy(model) + model_swift_2 = copy.deepcopy(model) + result_origin = model(torch.ones((1, 3, 224, 224))).logits + print( + f'test_swift_restuning_vit result_origin shape: {result_origin.shape}, ' + f'result_origin sum: {torch.sum(result_origin)}') + + # load type - 1 + self.set_random_seed() + restuning_config_1 = ResTuningConfig( + dims=768, + root_modules=r'.*vit.encoder.layer.0$', + 
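The `set_random_seed` helper above is called before each `Swift.prepare_model` in this test, which is what makes the two differently-wired ResTuning configurations comparable: both sets of adapter weights are initialised identically, so the later assertion that the two outputs match is really checking that the two hook placements are equivalent. A minimal illustration of the seeding argument:

```python
import torch

torch.manual_seed(123)
a = torch.randn(3)
torch.manual_seed(123)
b = torch.randn(3)
assert torch.equal(a, b)  # same seed, same random initialisation
```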
stem_modules=r'.*vit.encoder.layer\.\d+$', + target_modules=r'.*vit.layernorm', + target_modules_hook='input', + tuner_cfg='res_adapter', + ) + model_swift_1 = Swift.prepare_model( + model_swift_1, config=restuning_config_1) + self.assertTrue(isinstance(model_swift_1, SwiftModel)) + print(model_swift_1.get_trainable_parameters()) + result_swift_1 = model_swift_1(torch.ones((1, 3, 224, 224))).logits + print( + f'test_swift_restuning_vit result_swift_1 shape: {result_swift_1.shape}, ' + f'result_swift_1 sum: {torch.sum(result_swift_1)}') + + # load type - 2 + self.set_random_seed() + restuning_config_2 = ResTuningConfig( + dims=768, + root_modules=r'.*vit.encoder.layer.0$', + stem_modules=r'.*vit.encoder.layer\.\d+$', + target_modules=r'.*vit.encoder', + target_modules_hook='output', + target_hidden_pos='last_hidden_state', + tuner_cfg='res_adapter', + ) + model_swift_2 = Swift.prepare_model( + model_swift_2, config=restuning_config_2) + self.assertTrue(isinstance(model_swift_2, SwiftModel)) + print(model_swift_2.get_trainable_parameters()) + result_swift_2 = model_swift_2(torch.ones((1, 3, 224, 224))).logits + print( + f'test_swift_restuning_vit result_swift_2 shape: {result_swift_2.shape}, ' + f'result_swift_2 sum: {torch.sum(result_swift_2)}') + + self.assertTrue( + all(torch.isclose(result_swift_1, result_swift_2).flatten())) + + model_swift_1.save_pretrained(self.tmp_dir) + self.assertTrue(os.path.exists(os.path.join(self.tmp_dir, 'default'))) + model_loaded = Swift.from_pretrained(model, self.tmp_dir) + self.model_comparison(model_swift_1, model_loaded) + + def test_swift_restuning_diffusers_sd(self): + model_dir = snapshot_download('AI-ModelScope/stable-diffusion-v1-5') + from diffusers import UNet2DConditionModel + model = UNet2DConditionModel.from_pretrained( + model_dir, subfolder='unet') + model.requires_grad_(False) + model2 = copy.deepcopy(model) + self.set_random_seed() + input_data = { + 'sample': torch.ones((1, 4, 64, 64)), + 'timestep': 10, + 'encoder_hidden_states': torch.ones((1, 77, 768)) + } + result_origin = model(**input_data).sample + print( + f'test_swift_restuning_diffusers_sd result_origin shape: {result_origin.shape}, ' + f'result_origin sum: {torch.sum(result_origin)}') + + self.set_random_seed() + restuning_config = ResTuningConfig( + dims=[1280, 1280, 1280, 640, 320], + root_modules='mid_block', + stem_modules=[ + 'mid_block', 'up_blocks.0', 'up_blocks.1', 'up_blocks.2', + 'up_blocks.3' + ], + target_modules='conv_norm_out', + tuner_cfg='res_group_adapter', + use_upsample=True, + upsample_out_channels=[1280, 1280, 640, 320, None], + zero_init_last=True) + + model = Swift.prepare_model(model, config=restuning_config) + self.assertTrue(isinstance(model, SwiftModel)) + print(model.get_trainable_parameters()) + + result = model(**input_data).sample + print( + f'test_swift_restuning_diffusers_sd result shape: {result.shape}, result sum: {torch.sum(result)}' + ) + model.save_pretrained(self.tmp_dir) + self.assertTrue(os.path.exists(os.path.join(self.tmp_dir, 'default'))) + model2 = Swift.from_pretrained(model2, self.tmp_dir) + self.model_comparison(model, model2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/utils/test_torch_utils.py b/tests/utils/test_torch_utils.py new file mode 100644 index 0000000000..106f5148eb --- /dev/null +++ b/tests/utils/test_torch_utils.py @@ -0,0 +1,17 @@ +import unittest + +from modelscope import Model + +from swift.utils.torch_utils import find_sub_module + + +class TestTorchUtils(unittest.TestCase): + + def 
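`zero_init_last=True` in the Stable Diffusion configuration above zero-initialises the final projection of the last adapter (the `_zero_init_weights` path defined earlier), so the residual branch contributes nothing at step 0 and the freshly patched model reproduces the frozen base model's output. A standalone sketch of why that holds:

```python
import torch
from torch import nn

bottleneck = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 8))
nn.init.zeros_(bottleneck[2].weight)  # what _zero_init_weights does to the last linear
nn.init.zeros_(bottleneck[2].bias)

x = torch.randn(2, 8)
out = x + bottleneck(x)               # residual add, as in the adapter forward
assert torch.allclose(out, x)         # the adapter is an exact no-op at initialisation
```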
test_find_sub_module(self):
+        model = Model.from_pretrained(
+            'damo/nlp_structbert_sentence-similarity_chinese-base')
+        # find_sub_module returns a list, so check that at least one match was found
+        self.assertTrue(len(find_sub_module(model, 'query')) > 0)
+
+
+if __name__ == '__main__':
+    unittest.main()