diff --git a/swift/tuners/adapter.py b/swift/tuners/adapter.py
index a89f52b95b..c2e04ac242 100644
--- a/swift/tuners/adapter.py
+++ b/swift/tuners/adapter.py
@@ -131,6 +131,7 @@ def _feed_forward_chunk(self, attention_output):
             setattr(module, config.method_name,
                     types.MethodType(_forward, module))
             adapter_module = AdapterModule(config.dim, adapter_name,
+                                           module_key,
                                            config.adapter_length,
                                            ACT2CLS[config.act_layer])
             setattr(module, f'adapter_{adapter_name}', adapter_module)
@@ -152,13 +153,17 @@ def mark_trainable_callback(model):
                            mark_trainable_callback)
 
     @staticmethod
-    def activate_adapter(module: torch.nn.Module, adapter_name: str,
-                         activate: bool):
-        modules: List[torch.nn.Module] = find_sub_module(
-            module, f'adapter_{adapter_name}')
+    def activate_adapter(module: torch.nn.Module,
+                         adapter_name: str,
+                         activate: bool,
+                         offload: str = None):
+        modules = find_sub_module(module, f'adapter_{adapter_name}')
         for _module in modules:
             _module: ActivationMixin
+            _module: nn.Module
             _module.set_activation(adapter_name, activate)
+            SwiftAdapter.save_memory(_module, adapter_name, _module.module_key,
+                                     activate, offload)
 
 
 class AdapterModule(nn.Module, ActivationMixin):
@@ -177,11 +182,12 @@ def __init__(
         self,
         dim,
         adapter_name,
+        module_key,
         adapter_length=None,
         act_layer=nn.GELU,
     ):
         super(AdapterModule, self).__init__()
-        super(nn.Module, self).__init__()
+        super(nn.Module, self).__init__(module_key)
         self.dim = dim
         self.adapter_name = adapter_name
         self.adapter_length = adapter_length
diff --git a/swift/tuners/base.py b/swift/tuners/base.py
index 658907712f..8fc7bd0dad 100644
--- a/swift/tuners/base.py
+++ b/swift/tuners/base.py
@@ -432,7 +432,9 @@ def save_pretrained(self,
     def base_model(self):
         return self.model
 
-    def set_active_adapters(self, adapter_names: Union[List[str], str]):
+    def set_active_adapters(self,
+                            adapter_names: Union[List[str], str],
+                            offload=None):
         if not adapter_names:
             return
 
@@ -444,7 +446,7 @@ def set_active_adapters(self, adapter_names: Union[List[str], str]):
             self.activate_adapter(adapter_name)
 
         for adapter_name in (set(self.adapters.keys()) - adapter_names):
-            self.deactivate_adapter(adapter_name)
+            self.deactivate_adapter(adapter_name, offload)
 
     def activate_adapter(self, adapter_name):
         if adapter_name not in self.adapters:
@@ -456,7 +458,7 @@ def activate_adapter(self, adapter_name):
         SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\
             .activate_adapter(self.base_model, adapter_name, True)
 
-    def deactivate_adapter(self, adapter_name):
+    def deactivate_adapter(self, adapter_name, offload=None):
         if adapter_name not in self.adapters:
             logger.warning(
                 f'{adapter_name} not in adapters: {self.adapters.keys()}')
@@ -464,7 +466,7 @@ def deactivate_adapter(self, adapter_name):
 
         from .mapping import SWIFT_MAPPING
         SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\
-            .activate_adapter(self.base_model, adapter_name, False)
+            .activate_adapter(self.base_model, adapter_name, False, offload=offload)
 
     def get_trainable_parameters(self):
         """
diff --git a/swift/tuners/lora.py b/swift/tuners/lora.py
index 19a274f26e..b62931904d 100644
--- a/swift/tuners/lora.py
+++ b/swift/tuners/lora.py
@@ -5,6 +5,7 @@
 import torch
 from packaging import version
+from peft.tuners.lora import LoraLayer
 
 from swift import LoraConfig
 from .lora_layers import *  # noqa
 
@@ -69,12 +70,15 @@ def mark_trainable_callback(model):
                            mark_trainable_callback)
 
     @staticmethod
-    def activate_adapter(module: torch.nn.Module, adapter_name: str,
-                         activate: bool):
-        set_adapter(module, adapter_name, activate)
+    def activate_adapter(module: torch.nn.Module,
+                         adapter_name: str,
+                         activate: bool,
+                         offload: str = None):
+        set_adapter(module, adapter_name, activate, offload)
         for sub_module in module.modules():
             if isinstance(sub_module, (LoraLayer, LoRALayer)):
                 sub_module.set_activation(adapter_name, activate)
+                sub_module.save_memory(adapter_name, activate, offload)
 
     @staticmethod
     def unpatch_lora(model, config: LoRAConfig, adapter_name: str):
diff --git a/swift/tuners/lora_layers.py b/swift/tuners/lora_layers.py
index 2cdf229f36..bfbfa66b73 100644
--- a/swift/tuners/lora_layers.py
+++ b/swift/tuners/lora_layers.py
@@ -16,7 +16,6 @@
 from peft.tuners.lora import Conv2d as _Conv2d
 from peft.tuners.lora import Embedding as _Embedding
 from peft.tuners.lora import Linear as _Linear
-from peft.tuners.lora import LoraLayer
 from peft.tuners.lora import LoraModel as _LoraModel
 from peft.tuners.lora.tp_layer import LoraParallelLinear as _LoraParallelLinear
 from peft.tuners.tuners_utils import BaseTunerLayer
@@ -25,7 +24,7 @@
 from transformers import Conv1D
 
 from swift import get_logger
-from .utils import ActivationMixin, ModulesToSaveWrapper
+from .utils import ActivationMixin, ModulesToSaveWrapper, SwiftAdapter
 
 logger = get_logger()
 
@@ -52,7 +51,7 @@ def active_adapters(self):
     def active_adapter(self) -> str:
         return self.get_activated_adapters()
 
-    def set_adapter(self, adapter_names):
+    def set_adapter(self, adapter_names, offload=None):
         if isinstance(adapter_names, str):
             adapter_names = [adapter_names]
 
@@ -63,9 +62,28 @@ def set_adapter(self, adapter_names):
                 if key in adapter_names:
                     self.set_activation(key, True)
                     layer.requires_grad_(True)
+                    SwiftAdapter.save_memory(layer, key, self.module_key, True)
                 else:
                     self.set_activation(key, False)
                     layer.requires_grad_(False)
+                    SwiftAdapter.save_memory(
+                        layer, key, self.module_key, False, offload=offload)
+
+    def save_memory(self, adapter_name, activate, offload=None):
+        for layer_name in self.adapter_layer_names:
+            module_dict = getattr(self, layer_name)
+            for key, layer in module_dict.items():
+                if key == adapter_name:
+                    if activate:
+                        SwiftAdapter.save_memory(layer, layer_name + '.' + key,
+                                                 self.module_key, True)
+                    else:
+                        SwiftAdapter.save_memory(
+                            layer,
+                            layer_name + '.' + key,
+                            self.module_key,
+                            False,
+                            offload=offload)
 
     def merge(self, *args, **kwargs):
         if not self.unique_thread:
@@ -85,9 +103,10 @@ class Linear8bitLt(LoRAActivationMixin, _Linear8bitLt):
 
         def __init__(
             self,
            *args,
+            module_key: str,
             **kwargs,
         ):
-            super(Linear8bitLt, self).__init__()
+            super(Linear8bitLt, self).__init__(module_key)
             self.set_activation(args[1], True)
             super(ActivationMixin, self).__init__(*args, **kwargs)
@@ -100,9 +119,10 @@ class Linear4bit(LoRAActivationMixin, _Linear4bit):
 
         def __init__(
             self,
             *args,
+            module_key: str,
             **kwargs,
         ):
-            super(Linear4bit, self).__init__()
+            super(Linear4bit, self).__init__(module_key)
             self.set_activation(args[1], True)
             super(ActivationMixin, self).__init__(*args, **kwargs)
@@ -117,9 +137,10 @@ def __init__(
             *args,
             use_qa_lora=False,
             group_size=None,
+            module_key: str,
             **kwargs,
         ):
-            super(QuantLinear, self).__init__()
+            super(QuantLinear, self).__init__(module_key)
             self.set_activation(args[1], True)
             super(ActivationMixin, self).__init__(*args, **kwargs)
             self.group_size = group_size
@@ -166,33 +187,34 @@ class Embedding(LoRAActivationMixin, _Embedding):
 
     def __init__(
         self,
         *args,
+        module_key: str,
         **kwargs,
     ) -> None:
-        super(Embedding, self).__init__()
+        super(Embedding, self).__init__(module_key)
         self.set_activation(args[1], True)
         super(ActivationMixin, self).__init__(*args, **kwargs)
 
 
 class Linear(LoRAActivationMixin, _Linear):
 
-    def __init__(self, *args, **kwargs):
-        super(Linear, self).__init__()
+    def __init__(self, *args, module_key: str, **kwargs):
+        super(Linear, self).__init__(module_key)
         self.set_activation(args[1], True)
         super(ActivationMixin, self).__init__(*args, **kwargs)
 
 
 class Conv2d(LoRAActivationMixin, _Conv2d):
 
-    def __init__(self, *args, **kwargs):
-        super(Conv2d, self).__init__()
+    def __init__(self, *args, module_key: str, **kwargs):
+        super(Conv2d, self).__init__(module_key)
         self.set_activation(args[1], True)
         super(ActivationMixin, self).__init__(*args, **kwargs)
 
 
 class LoraParallelLinear(LoRAActivationMixin, _LoraParallelLinear):
 
-    def __init__(self, *args, **kwargs):
-        super(LoraParallelLinear, self).__init__()
+    def __init__(self, *args, module_key: str, **kwargs):
+        super(LoraParallelLinear, self).__init__(module_key)
         self.set_activation(args[1], True)
         super(ActivationMixin, self).__init__(*args, **kwargs)
@@ -249,7 +271,8 @@ def inject_adapter(self, model: nn.Module, adapter_name: str):
                 parent, target, target_name = _get_submodules(model, key)
 
                 if not isinstance(target, ModulesToSaveWrapper):
-                    new_module = ModulesToSaveWrapper(target, adapter_name)
+                    new_module = ModulesToSaveWrapper(
+                        target, adapter_name, module_key=key)
                     setattr(parent, target_name, new_module)
                 else:
                     target.update(adapter_name)
@@ -384,8 +407,12 @@ def _create_and_replace(
             )
             self._convert_dtype(target, lora_config.lora_dtype)
         else:
-            new_module = self._create_new_module(lora_config, adapter_name,
-                                                 target, **kwargs)
+            new_module = self._create_new_module(
+                lora_config,
+                adapter_name,
+                target,
+                current_key=current_key,
+                **kwargs)
             if new_module is not None:
                 if adapter_name != self.active_adapter:
                     # adding an additional adapter: it is not automatically trainable
@@ -395,6 +422,7 @@ def _create_and_replace(
 
     @staticmethod
     def _create_new_module(lora_config, adapter_name, target, **kwargs):
+        current_key = kwargs.pop('current_key')
         gptq_quantization_config = kwargs.get('gptq_quantization_config', None)
         AutoGPTQQuantLinear = get_auto_gptq_quant_linear(
             gptq_quantization_config)
@@ -422,7 +450,11 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
                 'threshold': target.state.threshold,
                 'index': target.index,
             })
-            new_module = Linear8bitLt(target, adapter_name, **eightbit_kwargs)
+            new_module = Linear8bitLt(
+                target,
+                adapter_name,
+                module_key=current_key,
+                **eightbit_kwargs)
         elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(
                 target_base_layer, bnb.nn.Linear4bit):
             fourbit_kwargs = kwargs.copy()
@@ -434,19 +466,26 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
                 'quant_type': target_base_layer.weight.quant_type,
             })
-            new_module = Linear4bit(target, adapter_name, **fourbit_kwargs)
+            new_module = Linear4bit(
+                target, adapter_name, module_key=current_key, **fourbit_kwargs)
         elif AutoGPTQQuantLinear is not None and isinstance(
                 target_base_layer, AutoGPTQQuantLinear):
-            new_module = QuantLinear(target, adapter_name, **kwargs)
+            new_module = QuantLinear(
+                target, adapter_name, module_key=current_key, **kwargs)
             target.qweight = target_base_layer.qweight
         elif isinstance(target_base_layer, torch.nn.Embedding):
             embedding_kwargs = kwargs.copy()
             embedding_kwargs.pop('fan_in_fan_out', None)
             embedding_kwargs.update(lora_config.loftq_config)
-            new_module = Embedding(target, adapter_name, **embedding_kwargs)
+            new_module = Embedding(
+                target,
+                adapter_name,
+                module_key=current_key,
+                **embedding_kwargs)
         elif isinstance(target_base_layer, torch.nn.Conv2d):
             kwargs.update(lora_config.loftq_config)
-            new_module = Conv2d(target, adapter_name, **kwargs)
+            new_module = Conv2d(
+                target, adapter_name, module_key=current_key, **kwargs)
         elif lora_config.use_merged_linear:
             new_module = MergedLinear(
                 adapter_name,
@@ -461,7 +500,8 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
                     'Setting fan_in_fan_out to False.')
                 kwargs['fan_in_fan_out'] = lora_config.fan_in_fan_out = False
             kwargs.update(lora_config.loftq_config)
-            new_module = Linear(target, adapter_name, **kwargs)
+            new_module = Linear(
+                target, adapter_name, module_key=current_key, **kwargs)
         elif megatron_core and isinstance(
                 target_base_layer,  # noqa
             (  # noqa
@@ -486,6 +526,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
             new_module = LoraParallelLinear(
                 base_layer=target,
                 adapter_name=adapter_name,
+                module_key=current_key,
                 backend=megatron_core.tensor_parallel,
                 **megatron_kwargs)
         elif isinstance(target_base_layer, Conv1D):
@@ -496,7 +537,11 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
                 kwargs['fan_in_fan_out'] = lora_config.fan_in_fan_out = True
             kwargs.update(lora_config.loftq_config)
             new_module = Linear(
-                target, adapter_name, is_target_conv_1d_layer=True, **kwargs)
+                target,
+                adapter_name,
+                module_key=current_key,
+                is_target_conv_1d_layer=True,
+                **kwargs)
         else:
             logger.debug(
                 f'Target module {target} is not supported. Currently, only the following modules are supported: '
@@ -512,12 +557,13 @@ class LoRALayer(ActivationMixin):
     def __init__(
         self,
         adapter_name: str,
+        module_key: str,
         r: int,
         lora_alpha: int,
         lora_dropout: float,
         merge_weights: bool,
     ):
-        super().__init__()
+        super().__init__(module_key)
         self.adapter_name = adapter_name
         self.r = r
         self.lora_alpha = lora_alpha
@@ -537,6 +583,7 @@ class MergedLinear(nn.Linear, LoRALayer):
     # LoRA implemented in a dense layer
     def __init__(self,
                  adapter_name: str,
+                 module_key: str,
                  base_layer: nn.Linear,
                  r: int = 0,
                  lora_alpha: int = 1,
@@ -558,6 +605,7 @@ def __init__(self,
         LoRALayer.__init__(
             self,
             adapter_name,
+            module_key,
             r=r,
             lora_alpha=lora_alpha,
             lora_dropout=lora_dropout,
diff --git a/swift/tuners/prompt.py b/swift/tuners/prompt.py
index 1f6b4ae66c..23eeb34e72 100644
--- a/swift/tuners/prompt.py
+++ b/swift/tuners/prompt.py
@@ -152,7 +152,7 @@ def _forward(self, *args, **kwargs):
                     input_dim = config.dim
                 prompt_module = PromptModule(input_dim,
                                              int(module_key.rsplit('.')[-1]),
-                                             adapter_name,
+                                             adapter_name, module_key,
                                              config.prompt_length,
                                              config.attention_mask_value,
                                              config.attach_front)
@@ -176,13 +176,17 @@ def mark_trainable_callback(model):
                            mark_trainable_callback)
 
     @staticmethod
-    def activate_adapter(module: torch.nn.Module, adapter_name: str,
-                         activate: bool):
-        modules: List[torch.nn.Module] = find_sub_module(
-            module, f'prompt_{adapter_name}')
+    def activate_adapter(module: torch.nn.Module,
+                         adapter_name: str,
+                         activate: bool,
+                         offload: str = None):
+        modules = find_sub_module(module, f'prompt_{adapter_name}')
         for _module in modules:
             _module: ActivationMixin
+            _module: nn.Module
             _module.set_activation(adapter_name, activate)
+            SwiftAdapter.save_memory(_module, adapter_name, _module.module_key,
+                                     activate, offload)
 
 
 class PromptModule(nn.Module, ActivationMixin):
@@ -203,11 +207,12 @@ def __init__(self,
                  dim,
                  layer_num,
                  adapter_name,
+                 module_key,
                  prompt_length=None,
                  mask_values=0.,
                  attach_front=True):
         super(PromptModule, self).__init__()
-        super(nn.Module, self).__init__()
+        super(nn.Module, self).__init__(module_key)
         self.dim = dim
         self.layer_num = layer_num
         self.adapter_name = adapter_name
diff --git a/swift/tuners/restuning.py b/swift/tuners/restuning.py
index 65c213c8d0..e5d5103e45 100644
--- a/swift/tuners/restuning.py
+++ b/swift/tuners/restuning.py
@@ -303,13 +303,17 @@ def mark_trainable_callback(model):
                            mark_trainable_callback)
 
     @staticmethod
-    def activate_adapter(module: torch.nn.Module, adapter_name: str,
-                         activate: bool):
-        modules: List[torch.nn.Module] = find_sub_module(
-            module, f'restuning_{adapter_name}')
+    def activate_adapter(module: torch.nn.Module,
+                         adapter_name: str,
+                         activate: bool,
+                         offload: str = None):
+        modules = find_sub_module(module, f'restuning_{adapter_name}')
         for _module in modules:
             _module: ActivationMixin
+            _module: nn.Module
             _module.set_activation(adapter_name, activate)
+            SwiftAdapter.save_memory(_module, adapter_name, _module.module_key,
+                                     activate, offload)
 
 
 class ResTuningBypassModule(nn.Module, ActivationMixin):
@@ -327,7 +331,7 @@ def __init__(
         self,
         tuner_cfg=None,
     ):
         super(ResTuningBypassModule, self).__init__()
-        super(nn.Module, self).__init__()
+        super(nn.Module, self).__init__('')
         self.adapter_name = adapter_name
         self.bypass_blocks = nn.Sequential(*[
diff --git a/swift/tuners/scetuning/scetuning.py b/swift/tuners/scetuning/scetuning.py
index d007d0ef91..dcb5615eb3 100644
--- a/swift/tuners/scetuning/scetuning.py
+++ b/swift/tuners/scetuning/scetuning.py
@@ -200,6 +200,7 @@ def _forward_decoder_mode(self, *args, **kwargs):
             setattr(t_module, 'forward', types.MethodType(_forward, t_module))
             tuner_op = SCETunerModule(
                 name=config.tuner_op,
+                adapter_name=adapter_name,
                 dim=dims[tuner_id],
                 tuner_length=int(dims[tuner_id] * config.down_ratio))
             setattr(t_module, f'scetuner_{adapter_name}', tuner_op)
@@ -221,19 +222,24 @@ def mark_trainable_callback(model):
                            mark_trainable_callback)
 
     @staticmethod
-    def activate_adapter(module: torch.nn.Module, adapter_name: str,
-                         activate: bool):
-        modules: List[torch.nn.Module] = find_sub_module(
-            module, f'scetuner_{adapter_name}')
+    def activate_adapter(module: torch.nn.Module,
+                         adapter_name: str,
+                         activate: bool,
+                         offload: str = None):
+        modules = find_sub_module(module, f'scetuner_{adapter_name}')
         for _module in modules:
             _module: ActivationMixin
+            _module: nn.Module
             _module.set_activation(adapter_name, activate)
+            SwiftAdapter.save_memory(_module, adapter_name, _module.module_key,
+                                     activate, offload)
 
 
 class SCETunerModule(nn.Module, ActivationMixin):
 
     def __init__(self,
                  name,
+                 adapter_name,
                  dim,
                  tuner_length,
                  tuner_type=None,
@@ -242,8 +248,9 @@ def __init__(self,
                  zero_init_last=True,
                  use_bias=True):
         super(SCETunerModule, self).__init__()
-        super(nn.Module, self).__init__()
+        super(nn.Module, self).__init__('')
         self.name = name
+        self.adapter_name = adapter_name
         self.dim = dim
         if name == 'SCEAdapter':
             from .scetuning_components import SCEAdapter
@@ -257,6 +264,8 @@ def __init__(self,
             raise Exception(f'Error tuner op {name}')
 
     def forward(self, x, x_shortcut=None, use_shortcut=True, **kwargs):
+        if not self.is_activated(self.adapter_name):
+            return x
         if self.name == 'SCEAdapter':
             out = self.tuner_op(x)
         else:
diff --git a/swift/tuners/side.py b/swift/tuners/side.py
index febf92cacf..cdd985b34f 100644
--- a/swift/tuners/side.py
+++ b/swift/tuners/side.py
@@ -121,7 +121,7 @@ def forward_seq(self, input, *args, **kwargs):
             setattr(tgt_module, f'forward_origin_{adapter_name}',
                     tgt_module.forward)
             tgt_module.forward = types.MethodType(_forward, tgt_module)
-            side_module = SideModule(config.dim, adapter_name,
+            side_module = SideModule(config.dim, adapter_name, module_key,
                                      config.side_module_name)
             setattr(tgt_module, f'side_{adapter_name}', side_module)
             logger.info(
@@ -142,13 +142,17 @@ def mark_trainable_callback(model):
                            mark_trainable_callback)
 
     @staticmethod
-    def activate_adapter(module: torch.nn.Module, adapter_name: str,
-                         activate: bool):
-        modules: List[torch.nn.Module] = find_sub_module(
-            module, f'side_{adapter_name}')
+    def activate_adapter(module: torch.nn.Module,
+                         adapter_name: str,
+                         activate: bool,
+                         offload: str = None):
+        modules = find_sub_module(module, f'side_{adapter_name}')
         for _module in modules:
             _module: ActivationMixin
+            _module: nn.Module
             _module.set_activation(adapter_name, activate)
+            SwiftAdapter.save_memory(_module, adapter_name, _module.module_key,
+                                     activate, offload)
 
 
 class SideModule(nn.Module, ActivationMixin):
@@ -164,9 +168,9 @@ class SideModule(nn.Module, ActivationMixin):
         side_module_name: The name of the additive side networks.
     """
 
-    def __init__(self, dim, adapter_name, side_module_name='fcn4'):
+    def __init__(self, dim, adapter_name, module_key, side_module_name='fcn4'):
         super(SideModule, self).__init__()
-        super(nn.Module, self).__init__()
+        super(nn.Module, self).__init__(module_key)
         self.adapter_name = adapter_name
 
         side_module_name = side_module_name.lower()
diff --git a/swift/tuners/utils.py b/swift/tuners/utils.py
index 25dd812753..73477936eb 100644
--- a/swift/tuners/utils.py
+++ b/swift/tuners/utils.py
@@ -1,20 +1,23 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # Copyright 2023-present the HuggingFace Inc. team.
+import hashlib
 import os
+import shutil
 import threading
 from dataclasses import asdict, dataclass, field
 from types import FunctionType
-from typing import Dict, List, Optional
+from typing import Dict, OrderedDict
 
 import json
-import peft.utils
+import numpy as np
 import torch
 from peft.utils import CONFIG_NAME
 from peft.utils import ModulesToSaveWrapper as _ModulesToSaveWrapper
 from peft.utils import _get_submodules
 
 from swift.hub.snapshot_download import snapshot_download
+from swift.hub.utils.utils import get_cache_dir
 from swift.utils.constants import BIN_EXTENSIONS
 from swift.utils.logger import get_logger
 
 
@@ -140,12 +143,16 @@ class ActivationMixin:
 
     USE_UNIQUE_THREAD = 'USE_UNIQUE_THREAD'
 
-    def __init__(self):
+    REMINEDED = False
+
+    def __init__(self, module_key):
+        self.module_key = module_key
         self._thread_inf: Dict[int, Dict[str, bool]] = {}
         self._unique_thread = bool(
             int(os.environ.get(ActivationMixin.USE_UNIQUE_THREAD, '1')))
-        if not self._unique_thread:
-            logger.info(
+        if not self._unique_thread and not ActivationMixin.REMINEDED:
+            ActivationMixin.REMINEDED = True
+            logger.warn(
                 'Using multiple thread mode, gradient checkpointing is not supported.'
             )
 
@@ -175,6 +182,80 @@ def get_activated_adapters(self):
         ]
 
 
+class OffloadHelper:
+
+    sub_dir = 'offload_cache'
+    cache_dir = os.path.join(get_cache_dir(), sub_dir)
+    shutil.rmtree(cache_dir, ignore_errors=True)
+    os.makedirs(cache_dir, exist_ok=True)
+    index = {}
+
+    @staticmethod
+    def offload_weight(weight, weight_name, offload_folder, index=None):
+        dtype = None
+        if str(weight.dtype) == 'torch.bfloat16':
+            weight = weight.view(torch.int16)
+            dtype = 'bfloat16'
+        array = weight.cpu().numpy()
+        tensor_file = os.path.join(offload_folder, f'{weight_name}.dat')
+        if index is not None:
+            if dtype is None:
+                dtype = str(array.dtype)
+            index[weight_name] = {'dtype': dtype, 'shape': list(array.shape)}
+        if array.ndim == 0:
+            array = array[None]
+        file_array = np.memmap(
+            tensor_file, dtype=array.dtype, mode='w+', shape=array.shape)
+        file_array[:] = array[:]
+        file_array.flush()
+        return index
+
+    @staticmethod
+    def load_offloaded_weight(weight_file, weight_info):
+        shape = tuple(weight_info['shape'])
+        if shape == ():
+            shape = (1, )
+
+        dtype = weight_info['dtype']
+        if dtype == 'bfloat16':
+            dtype = 'int16'
+
+        weight = np.memmap(weight_file, dtype=dtype, shape=shape, mode='r')
+
+        if len(weight_info['shape']) == 0:
+            weight = weight[0]
+        weight = torch.tensor(weight)
+        if weight_info['dtype'] == 'bfloat16':
+            weight = weight.view(torch.bfloat16)
+
+        return weight
+
+    @staticmethod
+    def offload_disk(module: torch.nn.Module, adapter_name, module_key):
+        key = adapter_name + ':' + module_key
+        md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
+        sub_folder = os.path.join(OffloadHelper.cache_dir, md5)
+        os.makedirs(sub_folder, exist_ok=True)
+        state_dict = module.state_dict()
+        OffloadHelper.index[md5] = {}
+        for key, tensor in state_dict.items():
+            OffloadHelper.offload_weight(tensor, key, sub_folder,
+                                         OffloadHelper.index[md5])
+
+    @staticmethod
+    def load_disk(module: torch.nn.Module, adapter_name, module_key):
+        key = adapter_name + ':' + module_key
+        md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
+        sub_folder = os.path.join(OffloadHelper.cache_dir, md5)
+        state_dict = {}
+        for key, value in OffloadHelper.index[md5].items():
+            file = os.path.join(sub_folder, f'{key}.dat')
+            state_dict[key] = OffloadHelper.load_offloaded_weight(
+                file, OffloadHelper.index[md5][key])
+        module.load_state_dict(state_dict, assign=True)
+        shutil.rmtree(sub_folder, ignore_errors=True)
+
+
 class SwiftAdapter:
 
     @staticmethod
@@ -183,10 +264,63 @@ def prepare_model(model: torch.nn.Module, config: SwiftConfig,
         raise NotImplementedError
 
     @staticmethod
-    def activate_adapter(module: torch.nn.Module, adapter_name: str,
-                         activate: bool):
+    def activate_adapter(module: torch.nn.Module,
+                         adapter_name: str,
+                         activate: bool,
+                         offload: str = None):
         raise NotImplementedError
 
+    @staticmethod
+    def save_memory(module: torch.nn.Module,
+                    adapter_name: str,
+                    module_key: str,
+                    activate: bool,
+                    offload: str = None):
+        if activate:
+            SwiftAdapter.load(module, adapter_name, module_key)
+        else:
+            SwiftAdapter.offload(
+                module, adapter_name, module_key, offload=offload)
+
+    @staticmethod
+    def offload(module: torch.nn.Module, adapter_name, module_key,
+                offload: str):
+        if not offload:
+            return
+        device = next(iter(module.parameters())).device
+        if hasattr(module,
+                   'origin_device') and module.origin_device != str(device):
+            return
+        module.origin_device = str(device)
+        if offload == 'cpu':
+            if str(device) != 'cpu':
+                module.to('cpu')
+        if offload == 'meta':
+            if str(device) != 'meta':
+                OffloadHelper.offload_disk(
+                    module, adapter_name=adapter_name, module_key=module_key)
+                module.to('meta')
+        else:
+            raise NotImplementedError
+        torch.cuda.empty_cache()
+
+    @staticmethod
+    def load(module: torch.nn.Module, adapter_name, module_key):
+        device = next(iter(module.parameters())).device
+        if not hasattr(module,
+                       'origin_device') or module.origin_device == str(device):
+            return
+        if str(device) == 'cpu':
+            module.to(module.origin_device)
+            delattr(module, 'origin_device')
+        elif str(device) == 'meta':
+            OffloadHelper.load_disk(
+                module, adapter_name=adapter_name, module_key=module_key)
+            module.to(module.origin_device)
+            delattr(module, 'origin_device')
+        else:
+            raise NotImplementedError
+
     @staticmethod
     def freeze_model():
         return True
@@ -194,7 +328,8 @@ def freeze_model():
 
 class ModulesToSaveWrapper(ActivationMixin, _ModulesToSaveWrapper):
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, module_key, **kwargs):
+        self.module_key = module_key
         super(ModulesToSaveWrapper, self).__init__()
         super(ActivationMixin, self).__init__(*args, **kwargs)
 
@@ -209,27 +344,35 @@ def active_adapter(self):
             )
         return active_adapters[0]
 
-    def set_adapter(self, adapter_name: str):
+    def set_adapter(self, adapter_name: str, offload: str):
         if adapter_name not in self.modules_to_save:
             raise ValueError(
                 f'Adapter {adapter_name} not found in {self.modules_to_save.keys()}'
             )
         self.modules_to_save[adapter_name].requires_grad_(True)
         self.set_activation(adapter_name, True)
+        SwiftAdapter.save_memory(self.modules_to_save[adapter_name],
+                                 adapter_name, self.module_key, True)
 
-    def deactivate_adapter(self, adapter_name: str):
+    def deactivate_adapter(self, adapter_name: str, offload: str):
         if adapter_name in self.modules_to_save and self.unique_thread:
             self.modules_to_save[adapter_name].requires_grad_(False)
             self.set_activation(adapter_name, False)
+            SwiftAdapter.save_memory(
+                self.modules_to_save[adapter_name],
+                adapter_name,
+                self.module_key,
+                False,
+                offload=offload)
 
 
-def set_adapter(model, adapter_name, activate):
+def set_adapter(model, adapter_name, activate, offload):
     for module in model.modules():
         if isinstance(module, ModulesToSaveWrapper):
             if activate:
-                module.set_adapter(adapter_name)
+                module.set_adapter(adapter_name, offload)
             else:
-                module.deactivate_adapter(adapter_name)
+                module.deactivate_adapter(adapter_name, offload)
 
 
 def set_trainable(model, adapter_name):
diff --git a/swift/utils/torch_utils.py b/swift/utils/torch_utils.py
index 5524e74f5b..1e05b25a34 100644
--- a/swift/utils/torch_utils.py
+++ b/swift/utils/torch_utils.py
@@ -74,10 +74,8 @@ def find_sub_module(module: torch.nn.Module,
     for name, sub_module in module.named_modules():
         if not name:
             continue
-        if module_name == name:
+        if name.endswith(module_name):
             _modules.append(sub_module)
-        else:
-            _modules.extend(find_sub_module(sub_module, module_name))
     return _modules
 
 
diff --git a/tests/tuners/test_swift_base.py b/tests/tuners/test_swift_base.py
index 72ab10fc03..36d2ac1fbe 100644
--- a/tests/tuners/test_swift_base.py
+++ b/tests/tuners/test_swift_base.py
@@ -356,15 +356,15 @@ def _init_weights(m):
                     hidden_pos=0),
             })
 
-        model.deactivate_adapter('adapter2')
-        model.deactivate_adapter('lora2')
+        model.deactivate_adapter('adapter2', offload='meta')
+        model.deactivate_adapter('lora2', offload='meta')
         outputs1 = model(**inputs)
         outputs2 = model1(**inputs)
         self.assertTrue(torch.allclose(outputs1.logits, outputs2.logits))
         model.activate_adapter('adapter2')
         model.activate_adapter('lora2')
-        model.deactivate_adapter('adapter1')
-        model.deactivate_adapter('lora1')
+        model.deactivate_adapter('adapter1', offload='meta')
+        model.deactivate_adapter('lora1', offload='meta')
         outputs1 = model(**inputs)
         outputs2 = model2(**inputs)
         self.assertTrue(torch.allclose(outputs1.logits, outputs2.logits))
@@ -372,16 +372,16 @@ def _init_weights(m):
         if os.environ.get('USE_UNIQUE_THREAD') == '0':
 
             def thread_func1():
-                model1.set_active_adapters(['lora1', 'adapter1'])
-                model.set_active_adapters(['lora1', 'adapter1'])
+                model1.set_active_adapters(['lora1', 'adapter1'], offload=None)
+                model.set_active_adapters(['lora1', 'adapter1'], offload=None)
                 outputs_single = model1(**inputs)
                 outputs_t1 = model(**inputs)
                 self.assertTrue(
                     torch.allclose(outputs_single.logits, outputs_t1.logits))
 
             def thread_func2():
-                model2.set_active_adapters(['lora2', 'adapter2'])
-                model.set_active_adapters(['lora2', 'adapter2'])
+                model2.set_active_adapters(['lora2', 'adapter2'], offload=None)
+                model.set_active_adapters(['lora2', 'adapter2'], offload=None)
                 outputs_single = model2(**inputs)
                 outputs_t2 = model(**inputs)
                 self.assertTrue(
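
Usage sketch (not part of the patch): the snippet below illustrates the new `offload` argument, mirroring the calls exercised in tests/tuners/test_swift_base.py. The toy base module and the adapter name 'lora1' are placeholders; only the Swift APIs touched by this diff are assumed.

    from collections import OrderedDict

    import torch.nn as nn
    from swift import LoRAConfig, Swift

    # Toy base network; 'linear' is the module key targeted by LoRA.
    base = nn.Sequential(OrderedDict([('linear', nn.Linear(16, 16))]))
    model = Swift.prepare_model(base, {'lora1': LoRAConfig(target_modules=['linear'])})

    # Deactivating with offload pushes the adapter weights off the device:
    #   offload='cpu'  moves the LoRA parameters to CPU memory,
    #   offload='meta' serializes them into the offload_cache directory via
    #                  OffloadHelper and leaves meta tensors in the module.
    model.deactivate_adapter('lora1', offload='meta')

    # Re-activating reloads the weights back onto their original device.
    model.activate_adapter('lora1')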