Add a tuner to train a part of the model #1200

Merged · 2 commits · Jun 21, 2024
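In short, this PR adds a `part` tuner: freeze the whole model, then mark only the modules whose names match a regex as trainable. A minimal usage sketch, based on the test added in this PR (`model` can be any `nn.Module`; the test uses `SbertForSequenceClassification`):

from swift import Swift
from swift.tuners.part import PartConfig

# Train only the attention projections; everything else stays frozen.
part_config = PartConfig(target_modules=r'.*(query|key|value).*')
model = Swift.prepare_model(model, config={'part': part_config})
# After this call, p.requires_grad is True only for parameters under
# modules whose names fully match target_modules.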
3 changes: 3 additions & 0 deletions swift/tuners/mapping.py
@@ -5,6 +5,7 @@
from .longlora.longlora import LongLoRA, LongLoRAConfig
from .lora import LoRA, LoRAConfig
from .neftune import NEFTune, NEFTuneConfig
from .part import Part, PartConfig
from .prompt import Prompt, PromptConfig
from .restuning import ResTuning, ResTuningConfig
from .rome import Rome, RomeConfig
@@ -23,6 +24,7 @@ class SwiftTuners:
    NEFTUNE = 'neftune'
    LLAMAPRO = 'LLAMAPRO'
    SCETUNING = 'SCETuning'
    PART = 'part'


SWIFT_MAPPING = {
@@ -36,4 +38,5 @@ class SwiftTuners:
    SwiftTuners.NEFTUNE: (NEFTuneConfig, NEFTune),
    SwiftTuners.SCETUNING: (SCETuningConfig, SCETuning),
    SwiftTuners.LLAMAPRO: (LLaMAProConfig, LLaMAPro),
    SwiftTuners.PART: (PartConfig, Part),
}
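For context, `SWIFT_MAPPING` appears to be the registry Swift consults to resolve a config's `swift_type` to its tuner; the new entry makes the lookup below work (a sketch, assuming the module paths shown in this diff):

from swift.tuners.mapping import SWIFT_MAPPING, SwiftTuners
from swift.tuners.part import Part, PartConfig

# SwiftTuners.PART is the string 'part'; PartConfig.__post_init__ sets
# swift_type to it, and the mapping resolves it back to (PartConfig, Part).
config_cls, tuner_cls = SWIFT_MAPPING[SwiftTuners.PART]
assert config_cls is PartConfig and tuner_cls is Part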
56 changes: 56 additions & 0 deletions swift/tuners/part.py
@@ -0,0 +1,56 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import re
import types
from dataclasses import dataclass, field
from typing import List, Optional, Union

import torch
from torch import nn

from swift import get_logger
from swift.utils.torch_utils import find_sub_module
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class PartConfig(SwiftConfig):
    """
    Freeze the model and train only a part of it.

    Args:
        target_modules(`Optional[str]`): A regex pattern; modules whose full
            names match it (via `re.fullmatch`) are trained.
    """

    target_modules: Optional[str] = None

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.PART


class Part(SwiftAdapter):

    @staticmethod
    def target_module_matched(module_key: str, config: PartConfig):
        return re.fullmatch(config.target_modules, module_key)

    @staticmethod
    def prepare_model(model: nn.Module, config: PartConfig, adapter_name: str):

        def state_dict_callback(state_dict, adapter_name):
            # Save only the parameters whose keys fully match the target regex.
            return {key: value for key, value in state_dict.items() if Part.target_module_matched(key, config)}

        def mark_trainable_callback(model: nn.Module):
            # Unfreeze every module whose name fully matches the target regex.
            for name, module in model.named_modules():
                module: nn.Module
                if Part.target_module_matched(name, config):
                    module.requires_grad_(True)

        return SwiftOutput(config, state_dict_callback, mark_trainable_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        # Activation is a no-op for this tuner: the trained part is always active.
        pass
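Two details worth noting. First, `target_module_matched` uses `re.fullmatch`, so the pattern must cover the entire name, and it is applied both to module names (in `mark_trainable_callback`) and to parameter keys such as `...query.weight` (in `state_dict_callback`), which is why patterns typically need a trailing `.*`. Second, `activate_adapter` is a no-op here, so a `part` adapter is not toggled on and off like other tuners. A small demonstration of the matching semantics, with hypothetical module names:

import re

pattern = r'.*(query|key|value).*'
# Module name and parameter key both fully match thanks to the leading/trailing .*:
assert re.fullmatch(pattern, 'encoder.layer.0.attention.self.query')
assert re.fullmatch(pattern, 'encoder.layer.0.attention.self.query.weight')
# A non-target module does not match, so it stays frozen and is not saved:
assert re.fullmatch(pattern, 'encoder.layer.0.output.dense') is None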
30 changes: 29 additions & 1 deletion tests/tuners/test_swift_base.py
@@ -1,6 +1,7 @@
import copy
import math
import os
import re
import shutil
import tempfile
import unittest
@@ -10,12 +11,12 @@
import torch
from modelscope import Model, Preprocessor
from modelscope.models.nlp.structbert import SbertConfig, SbertForSequenceClassification
from packaging import version
from peft import PeftModel
from peft.utils import WEIGHTS_NAME
from torch import nn

from swift import AdapterConfig, LoRAConfig, PromptConfig, ResTuningConfig, SideConfig, Swift, SwiftModel
from swift.tuners.part import PartConfig


class TestSwift(unittest.TestCase):
@@ -280,6 +281,33 @@ def test_swift_multiple_adapters(self):
            self.assertTrue(key in state_dict2)
            self.assertTrue(all(torch.isclose(state_dict[key], state_dict2[key]).flatten().detach().cpu()))

    def test_part(self):
        model = SbertForSequenceClassification(SbertConfig())
        model2 = copy.deepcopy(model)
        targets = r'.*(query|key|value).*'
        part_config = PartConfig(target_modules=targets)
        model = Swift.prepare_model(model, config={'part': part_config})
        self.assertTrue(isinstance(model, SwiftModel))
        trainable = [name for name, p in model.named_parameters() if p.requires_grad]
        not_trainable = [name for name, p in model.named_parameters() if not p.requires_grad]

        def target_in(t: str):
            return re.fullmatch(targets, t)

        self.assertTrue(all([target_in(t) for t in trainable]))
        self.assertTrue(not any([target_in(t) for t in not_trainable]))
        model.save_pretrained(self.tmp_dir, adapter_name=['part'])
        with open(os.path.join(self.tmp_dir, 'configuration.json'), 'w') as f:
            f.write('{}')
        self.assertTrue(os.path.exists(os.path.join(self.tmp_dir, 'part')))
        self.assertTrue(os.path.exists(os.path.join(self.tmp_dir, 'part', WEIGHTS_NAME)))
        model2 = Swift.from_pretrained(model2, self.tmp_dir, adapter_name=['part'])
        state_dict = model.state_dict()
        state_dict2 = model2.state_dict()
        for key in state_dict:
            self.assertTrue(key in state_dict2)
            self.assertTrue(all(torch.isclose(state_dict[key], state_dict2[key]).flatten().detach().cpu()))

    def test_swift_multiple_adapters_switching(self):
        from swift.tuners.lora import Linear
        from swift.tuners.adapter import AdapterModule
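One way to double-check what `save_pretrained` wrote in the test above (a hypothetical snippet, not part of this diff; `tmp_dir` as in the test, and assuming the adapter file is a plain `torch.save`d state dict):

import os
import re

import torch
from peft.utils import WEIGHTS_NAME  # same constant the test imports

weights = torch.load(os.path.join(tmp_dir, 'part', WEIGHTS_NAME))
# Only keys that fully match target_modules should have been persisted.
assert all(re.fullmatch(r'.*(query|key|value).*', k) for k in weights)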