From 46f833cd91ba8b07f88cc0ceac4af8f48d8babcb Mon Sep 17 00:00:00 2001 From: "yuze.zyz" Date: Thu, 4 Jan 2024 20:15:27 +0800 Subject: [PATCH 1/7] no message --- swift/tuners/adapter.py | 15 ++-- swift/tuners/base.py | 4 +- swift/tuners/lora.py | 6 +- swift/tuners/lora_layers.py | 95 ++++++++++++++++++------ swift/tuners/prompt.py | 15 ++-- swift/tuners/restuning.py | 15 ++-- swift/tuners/scetuning/scetuning.py | 5 ++ swift/tuners/side.py | 14 ++-- swift/tuners/utils.py | 111 ++++++++++++++++++++++++++-- swift/utils/torch_utils.py | 14 +++- 10 files changed, 239 insertions(+), 55 deletions(-) diff --git a/swift/tuners/adapter.py b/swift/tuners/adapter.py index a89f52b95b..f98df01492 100644 --- a/swift/tuners/adapter.py +++ b/swift/tuners/adapter.py @@ -152,13 +152,18 @@ def mark_trainable_callback(model): mark_trainable_callback) @staticmethod - def activate_adapter(module: torch.nn.Module, adapter_name: str, - activate: bool): - modules: List[torch.nn.Module] = find_sub_module( - module, f'adapter_{adapter_name}') - for _module in modules: + def activate_adapter(module: torch.nn.Module, + adapter_name: str, + activate: bool, + offload: str = None): + modules, module_keys = find_sub_module(module, + f'adapter_{adapter_name}') + for _module_key, _module in zip(module_keys, modules): _module: ActivationMixin + _module: nn.Module _module.set_activation(adapter_name, activate) + SwiftAdapter.save_memory(_module, adapter_name, _module_key, + activate, offload) class AdapterModule(nn.Module, ActivationMixin): diff --git a/swift/tuners/base.py b/swift/tuners/base.py index 658907712f..0577b61af7 100644 --- a/swift/tuners/base.py +++ b/swift/tuners/base.py @@ -456,7 +456,7 @@ def activate_adapter(self, adapter_name): SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\ .activate_adapter(self.base_model, adapter_name, True) - def deactivate_adapter(self, adapter_name): + def deactivate_adapter(self, adapter_name, offload='cpu'): if adapter_name not in self.adapters: logger.warning( f'{adapter_name} not in adapters: {self.adapters.keys()}') @@ -464,7 +464,7 @@ def deactivate_adapter(self, adapter_name): from .mapping import SWIFT_MAPPING SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\ - .activate_adapter(self.base_model, adapter_name, False) + .activate_adapter(self.base_model, adapter_name, False, offload=offload) def get_trainable_parameters(self): """ diff --git a/swift/tuners/lora.py b/swift/tuners/lora.py index 19a274f26e..4242a97dd6 100644 --- a/swift/tuners/lora.py +++ b/swift/tuners/lora.py @@ -69,8 +69,10 @@ def mark_trainable_callback(model): mark_trainable_callback) @staticmethod - def activate_adapter(module: torch.nn.Module, adapter_name: str, - activate: bool): + def activate_adapter(module: torch.nn.Module, + adapter_name: str, + activate: bool, + offload: str = None): set_adapter(module, adapter_name, activate) for sub_module in module.modules(): if isinstance(sub_module, (LoraLayer, LoRALayer)): diff --git a/swift/tuners/lora_layers.py b/swift/tuners/lora_layers.py index 2cdf229f36..9411ca069c 100644 --- a/swift/tuners/lora_layers.py +++ b/swift/tuners/lora_layers.py @@ -16,7 +16,6 @@ from peft.tuners.lora import Conv2d as _Conv2d from peft.tuners.lora import Embedding as _Embedding from peft.tuners.lora import Linear as _Linear -from peft.tuners.lora import LoraLayer from peft.tuners.lora import LoraModel as _LoraModel from peft.tuners.lora.tp_layer import LoraParallelLinear as _LoraParallelLinear from peft.tuners.tuners_utils import BaseTunerLayer @@ 
-25,7 +24,7 @@ from transformers import Conv1D from swift import get_logger -from .utils import ActivationMixin, ModulesToSaveWrapper +from .utils import ActivationMixin, ModulesToSaveWrapper, SwiftAdapter logger = get_logger() @@ -44,6 +43,14 @@ def is_auto_gptq_available(): class LoRAActivationMixin(ActivationMixin): + def __init__(self, module_key): + self.module_key = module_key + self.offloads = {} + super().__init__() + + def add_offload(self, adapter_name: str, offload=None): + self.offloads[adapter_name] = offload + @property def active_adapters(self): return self.get_activated_adapters() @@ -63,9 +70,21 @@ def set_adapter(self, adapter_names): if key in adapter_names: self.set_activation(key, True) layer.requires_grad_(True) + SwiftAdapter.save_memory( + layer, + key, + self.module_key, + True, + offload=self.offloads.get(key)) else: self.set_activation(key, False) layer.requires_grad_(False) + SwiftAdapter.save_memory( + layer, + key, + self.module_key, + False, + offload=self.offloads.get(key)) def merge(self, *args, **kwargs): if not self.unique_thread: @@ -85,9 +104,10 @@ class Linear8bitLt(LoRAActivationMixin, _Linear8bitLt): def __init__( self, *args, + module_key: str, **kwargs, ): - super(Linear8bitLt, self).__init__() + super(Linear8bitLt, self).__init__(module_key) self.set_activation(args[1], True) super(ActivationMixin, self).__init__(*args, **kwargs) @@ -100,9 +120,10 @@ class Linear4bit(LoRAActivationMixin, _Linear4bit): def __init__( self, *args, + module_key: str, **kwargs, ): - super(Linear4bit, self).__init__() + super(Linear4bit, self).__init__(module_key) self.set_activation(args[1], True) super(ActivationMixin, self).__init__(*args, **kwargs) @@ -117,9 +138,10 @@ def __init__( *args, use_qa_lora=False, group_size=None, + module_key: str, **kwargs, ): - super(QuantLinear, self).__init__() + super(QuantLinear, self).__init__(module_key) self.set_activation(args[1], True) super(ActivationMixin, self).__init__(*args, **kwargs) self.group_size = group_size @@ -166,33 +188,34 @@ class Embedding(LoRAActivationMixin, _Embedding): def __init__( self, *args, + module_key: str, **kwargs, ) -> None: - super(Embedding, self).__init__() + super(Embedding, self).__init__(module_key) self.set_activation(args[1], True) super(ActivationMixin, self).__init__(*args, **kwargs) class Linear(LoRAActivationMixin, _Linear): - def __init__(self, *args, **kwargs): - super(Linear, self).__init__() + def __init__(self, *args, module_key: str, **kwargs): + super(Linear, self).__init__(module_key) self.set_activation(args[1], True) super(ActivationMixin, self).__init__(*args, **kwargs) class Conv2d(LoRAActivationMixin, _Conv2d): - def __init__(self, *args, **kwargs): - super(Conv2d, self).__init__() + def __init__(self, *args, module_key: str, **kwargs): + super(Conv2d, self).__init__(module_key) self.set_activation(args[1], True) super(ActivationMixin, self).__init__(*args, **kwargs) class LoraParallelLinear(LoRAActivationMixin, _LoraParallelLinear): - def __init__(self, *args, **kwargs): - super(LoraParallelLinear, self).__init__() + def __init__(self, *args, module_key: str, **kwargs): + super(LoraParallelLinear, self).__init__(module_key) self.set_activation(args[1], True) super(ActivationMixin, self).__init__(*args, **kwargs) @@ -249,10 +272,12 @@ def inject_adapter(self, model: nn.Module, adapter_name: str): parent, target, target_name = _get_submodules(model, key) if not isinstance(target, ModulesToSaveWrapper): - new_module = ModulesToSaveWrapper(target, adapter_name) + new_module = 
ModulesToSaveWrapper( + target, adapter_name, module_key=key) setattr(parent, target_name, new_module) else: target.update(adapter_name) + target.add_offload(adapter_name, peft_config.offload) _has_modules_to_save = True continue @@ -364,6 +389,7 @@ def _create_and_replace( lora_config.lora_dropout, lora_config.init_lora_weights, ) + target.add_offload(adapter_name, lora_config.offload) self._convert_dtype(target, lora_config.lora_dtype) elif isinstance(target, Embedding): target.update_layer_embedding( @@ -373,6 +399,7 @@ def _create_and_replace( lora_config.lora_dropout, lora_config.init_lora_weights, ) + target.add_offload(adapter_name, lora_config.offload) self._convert_dtype(target, lora_config.lora_dtype) elif isinstance(target, linear_types): target.update_layer( @@ -382,19 +409,26 @@ def _create_and_replace( lora_config.lora_dropout, lora_config.init_lora_weights, ) + target.add_offload(adapter_name, lora_config.offload) self._convert_dtype(target, lora_config.lora_dtype) else: - new_module = self._create_new_module(lora_config, adapter_name, - target, **kwargs) + new_module = self._create_new_module( + lora_config, + adapter_name, + target, + current_key=current_key, + **kwargs) if new_module is not None: if adapter_name != self.active_adapter: # adding an additional adapter: it is not automatically trainable new_module.requires_grad_(False) self._replace_module(parent, target_name, new_module, target) self._convert_dtype(new_module, lora_config.lora_dtype) + new_module.add_offload(adapter_name, lora_config.offload) @staticmethod def _create_new_module(lora_config, adapter_name, target, **kwargs): + current_key = kwargs.pop('current_key') gptq_quantization_config = kwargs.get('gptq_quantization_config', None) AutoGPTQQuantLinear = get_auto_gptq_quant_linear( gptq_quantization_config) @@ -422,7 +456,11 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): 'threshold': target.state.threshold, 'index': target.index, }) - new_module = Linear8bitLt(target, adapter_name, **eightbit_kwargs) + new_module = Linear8bitLt( + target, + adapter_name, + module_key=current_key, + **eightbit_kwargs) elif loaded_in_4bit and is_bnb_4bit_available() and isinstance( target_base_layer, bnb.nn.Linear4bit): fourbit_kwargs = kwargs.copy() @@ -434,19 +472,26 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): 'quant_type': target_base_layer.weight.quant_type, }) - new_module = Linear4bit(target, adapter_name, **fourbit_kwargs) + new_module = Linear4bit( + target, adapter_name, module_key=current_key, **fourbit_kwargs) elif AutoGPTQQuantLinear is not None and isinstance( target_base_layer, AutoGPTQQuantLinear): - new_module = QuantLinear(target, adapter_name, **kwargs) + new_module = QuantLinear( + target, adapter_name, module_key=current_key, **kwargs) target.qweight = target_base_layer.qweight elif isinstance(target_base_layer, torch.nn.Embedding): embedding_kwargs = kwargs.copy() embedding_kwargs.pop('fan_in_fan_out', None) embedding_kwargs.update(lora_config.loftq_config) - new_module = Embedding(target, adapter_name, **embedding_kwargs) + new_module = Embedding( + target, + adapter_name, + module_key=current_key, + **embedding_kwargs) elif isinstance(target_base_layer, torch.nn.Conv2d): kwargs.update(lora_config.loftq_config) - new_module = Conv2d(target, adapter_name, **kwargs) + new_module = Conv2d( + target, adapter_name, module_key=current_key, **kwargs) elif lora_config.use_merged_linear: new_module = MergedLinear( adapter_name, @@ -461,7 +506,8 @@ def 
_create_new_module(lora_config, adapter_name, target, **kwargs): 'Setting fan_in_fan_out to False.') kwargs['fan_in_fan_out'] = lora_config.fan_in_fan_out = False kwargs.update(lora_config.loftq_config) - new_module = Linear(target, adapter_name, **kwargs) + new_module = Linear( + target, adapter_name, module_key=current_key, **kwargs) elif megatron_core and isinstance( target_base_layer, # noqa ( # noqa @@ -486,6 +532,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): new_module = LoraParallelLinear( base_layer=target, adapter_name=adapter_name, + module_key=current_key, backend=megatron_core.tensor_parallel, **megatron_kwargs) elif isinstance(target_base_layer, Conv1D): @@ -496,7 +543,11 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): kwargs['fan_in_fan_out'] = lora_config.fan_in_fan_out = True kwargs.update(lora_config.loftq_config) new_module = Linear( - target, adapter_name, is_target_conv_1d_layer=True, **kwargs) + target, + adapter_name, + module_key=current_key, + is_target_conv_1d_layer=True, + **kwargs) else: logger.debug( f'Target module {target} is not supported. Currently, only the following modules are supported: ' diff --git a/swift/tuners/prompt.py b/swift/tuners/prompt.py index 1f6b4ae66c..6c90bcd115 100644 --- a/swift/tuners/prompt.py +++ b/swift/tuners/prompt.py @@ -176,13 +176,18 @@ def mark_trainable_callback(model): mark_trainable_callback) @staticmethod - def activate_adapter(module: torch.nn.Module, adapter_name: str, - activate: bool): - modules: List[torch.nn.Module] = find_sub_module( - module, f'prompt_{adapter_name}') - for _module in modules: + def activate_adapter(module: torch.nn.Module, + adapter_name: str, + activate: bool, + offload: str = None): + modules, module_keys = find_sub_module(module, + f'prompt_{adapter_name}') + for _module_key, _module in zip(module_keys, modules): _module: ActivationMixin + _module: nn.Module _module.set_activation(adapter_name, activate) + SwiftAdapter.save_memory(_module, adapter_name, _module_key, + activate, offload) class PromptModule(nn.Module, ActivationMixin): diff --git a/swift/tuners/restuning.py b/swift/tuners/restuning.py index 65c213c8d0..3751b819c8 100644 --- a/swift/tuners/restuning.py +++ b/swift/tuners/restuning.py @@ -303,13 +303,18 @@ def mark_trainable_callback(model): mark_trainable_callback) @staticmethod - def activate_adapter(module: torch.nn.Module, adapter_name: str, - activate: bool): - modules: List[torch.nn.Module] = find_sub_module( - module, f'restuning_{adapter_name}') - for _module in modules: + def activate_adapter(module: torch.nn.Module, + adapter_name: str, + activate: bool, + offload: str = None): + modules, module_keys = find_sub_module(module, + f'restuning_{adapter_name}') + for _module_key, _module in zip(module_keys, modules): _module: ActivationMixin + _module: nn.Module _module.set_activation(adapter_name, activate) + SwiftAdapter.save_memory(_module, adapter_name, _module_key, + activate, offload) class ResTuningBypassModule(nn.Module, ActivationMixin): diff --git a/swift/tuners/scetuning/scetuning.py b/swift/tuners/scetuning/scetuning.py index d007d0ef91..313fbd58de 100644 --- a/swift/tuners/scetuning/scetuning.py +++ b/swift/tuners/scetuning/scetuning.py @@ -200,6 +200,7 @@ def _forward_decoder_mode(self, *args, **kwargs): setattr(t_module, 'forward', types.MethodType(_forward, t_module)) tuner_op = SCETunerModule( name=config.tuner_op, + adapter_name=adapter_name, dim=dims[tuner_id], tuner_length=int(dims[tuner_id] * 
config.down_ratio)) setattr(t_module, f'scetuner_{adapter_name}', tuner_op) @@ -234,6 +235,7 @@ class SCETunerModule(nn.Module, ActivationMixin): def __init__(self, name, + adapter_name, dim, tuner_length, tuner_type=None, @@ -244,6 +246,7 @@ def __init__(self, super(SCETunerModule, self).__init__() super(nn.Module, self).__init__() self.name = name + self.adapter_name = adapter_name self.dim = dim if name == 'SCEAdapter': from .scetuning_components import SCEAdapter @@ -257,6 +260,8 @@ def __init__(self, raise Exception(f'Error tuner op {name}') def forward(self, x, x_shortcut=None, use_shortcut=True, **kwargs): + if not self.is_activated(self.adapter_name): + return x if self.name == 'SCEAdapter': out = self.tuner_op(x) else: diff --git a/swift/tuners/side.py b/swift/tuners/side.py index febf92cacf..1bb7afe39b 100644 --- a/swift/tuners/side.py +++ b/swift/tuners/side.py @@ -142,13 +142,17 @@ def mark_trainable_callback(model): mark_trainable_callback) @staticmethod - def activate_adapter(module: torch.nn.Module, adapter_name: str, - activate: bool): - modules: List[torch.nn.Module] = find_sub_module( - module, f'side_{adapter_name}') - for _module in modules: + def activate_adapter(module: torch.nn.Module, + adapter_name: str, + activate: bool, + offload: str = None): + modules, module_keys = find_sub_module(module, f'side_{adapter_name}') + for _module_key, _module in zip(module_keys, modules): _module: ActivationMixin + _module: nn.Module _module.set_activation(adapter_name, activate) + SwiftAdapter.save_memory(_module, adapter_name, _module_key, + activate, offload) class SideModule(nn.Module, ActivationMixin): diff --git a/swift/tuners/utils.py b/swift/tuners/utils.py index 25dd812753..796cfef5d5 100644 --- a/swift/tuners/utils.py +++ b/swift/tuners/utils.py @@ -1,20 +1,22 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # Copyright 2023-present the HuggingFace Inc. team. 
+import hashlib import os +import shutil import threading from dataclasses import asdict, dataclass, field from types import FunctionType -from typing import Dict, List, Optional +from typing import Dict import json -import peft.utils import torch from peft.utils import CONFIG_NAME from peft.utils import ModulesToSaveWrapper as _ModulesToSaveWrapper from peft.utils import _get_submodules from swift.hub.snapshot_download import snapshot_download +from swift.hub.utils.utils import get_cache_dir from swift.utils.constants import BIN_EXTENSIONS from swift.utils.logger import get_logger @@ -175,6 +177,43 @@ def get_activated_adapters(self): ] +class OffloadHelper: + + sub_dir = 'offload_cache' + cache_dir = os.path.join(get_cache_dir(), sub_dir) + shutil.rmtree(cache_dir) + os.makedirs(cache_dir, exist_ok=True) + + @staticmethod + def read_safe_tensors(safe_tensor_file): + if os.path.exists(safe_tensor_file): + from safetensors.torch import load_file as safe_load_file + return safe_load_file( + safe_tensor_file, + device='cuda' if torch.cuda.is_available() else 'cpu') + + @staticmethod + def write_safe_tensors(state_dict, safe_tensor_file): + from safetensors.torch import save_file as safe_save_file + safe_save_file(state_dict, safe_tensor_file, metadata={'format': 'pt'}) + + @staticmethod + def offload_disk(module: torch.nn.Module, adapter_name, module_key): + key = adapter_name + ':' + module_key + md5 = hashlib.md5(key).hexdigest() + file = os.path.join(OffloadHelper.cache_dir, md5 + '.safetensors') + OffloadHelper.write_safe_tensors(module.state_dict(), file) + + @staticmethod + def load_disk(module: torch.nn.Module, adapter_name, module_key): + key = adapter_name + ':' + module_key + md5 = hashlib.md5(key).hexdigest() + file = os.path.join(OffloadHelper.cache_dir, md5 + '.safetensors') + state_dict = OffloadHelper.read_safe_tensors(file) + module.load_state_dict(state_dict, assign=True) + shutil.rmtree(file) + + class SwiftAdapter: @staticmethod @@ -183,10 +222,55 @@ def prepare_model(model: torch.nn.Module, config: SwiftConfig, raise NotImplementedError @staticmethod - def activate_adapter(module: torch.nn.Module, adapter_name: str, - activate: bool): + def activate_adapter(module: torch.nn.Module, + adapter_name: str, + activate: bool, + offload: str = None): raise NotImplementedError + @staticmethod + def save_memory(module: torch.nn.Module, + adapter_name: str, + module_key: str, + activate: bool, + offload: str = None): + if offload is not None: + if activate: + SwiftAdapter.load( + module, adapter_name, module_key, offload=offload) + else: + SwiftAdapter.offload( + module, adapter_name, module_key, offload=offload) + + @staticmethod + def offload(module: torch.nn.Module, adapter_name, module_key, + offload: str): + module.origin_device = str(module.device) + if offload == 'cpu' and str(module.device) != 'cpu': + module.to('cpu') + if offload == 'meta' and str(module.device) != 'meta': + OffloadHelper.offload_disk( + module, adapter_name=adapter_name, module_key=module_key) + module.to('meta') + else: + raise NotImplementedError + + @staticmethod + def load(module: torch.nn.Module, adapter_name, module_key, offload: str): + if not hasattr(module, 'origin_device') or module.origin_device == str( + module.device): + return + if offload == 'cpu': + module.to(module.origin_device) + delattr(module, 'origin_device') + elif offload == 'meta': + OffloadHelper.load_disk( + module, adapter_name=adapter_name, module_key=module_key) + module.to(module.origin_device) + delattr(module, 
'origin_device') + else: + raise NotImplementedError + @staticmethod def freeze_model(): return True @@ -194,7 +278,9 @@ def freeze_model(): class ModulesToSaveWrapper(ActivationMixin, _ModulesToSaveWrapper): - def __init__(self, *args, **kwargs): + def __init__(self, *args, module_key, **kwargs): + self.module_key = module_key + self.offloads = {} super(ModulesToSaveWrapper, self).__init__() super(ActivationMixin, self).__init__(*args, **kwargs) @@ -209,6 +295,9 @@ def active_adapter(self): ) return active_adapters[0] + def add_offload(self, adapter_name: str, offload=None): + self.offloads[adapter_name] = offload + def set_adapter(self, adapter_name: str): if adapter_name not in self.modules_to_save: raise ValueError( @@ -216,11 +305,23 @@ def set_adapter(self, adapter_name: str): ) self.modules_to_save[adapter_name].requires_grad_(True) self.set_activation(adapter_name, True) + SwiftAdapter.save_memory( + self.modules_to_save[adapter_name], + adapter_name, + self.module_key, + True, + offload=self.offloads.get(adapter_name)) def deactivate_adapter(self, adapter_name: str): if adapter_name in self.modules_to_save and self.unique_thread: self.modules_to_save[adapter_name].requires_grad_(False) self.set_activation(adapter_name, False) + SwiftAdapter.save_memory( + self.modules_to_save[adapter_name], + adapter_name, + self.module_key, + False, + offload=self.offloads.get(adapter_name)) def set_adapter(model, adapter_name, activate): diff --git a/swift/utils/torch_utils.py b/swift/utils/torch_utils.py index 5524e74f5b..510dd14d95 100644 --- a/swift/utils/torch_utils.py +++ b/swift/utils/torch_utils.py @@ -69,16 +69,22 @@ def get_model_info(model: Module, name: Optional[str] = None) -> str: def find_sub_module(module: torch.nn.Module, - module_name: str) -> List[torch.nn.Module]: + module_name: str, + prefix='') -> Tuple[List[torch.nn.Module], List[str]]: _modules = list() - for name, sub_module in module.named_modules(): + _module_keys = list() + for name, sub_module in module.named_modules(prefix=prefix): if not name: continue if module_name == name: _modules.append(sub_module) + _module_keys.append(name) else: - _modules.extend(find_sub_module(sub_module, module_name)) - return _modules + _sub_modules, _sub_module_keys = find_sub_module( + sub_module, module_name, prefix=name) + _modules.extend(_sub_modules) + _module_keys.extend(_sub_module_keys) + return _modules, _module_keys def get_dist_setting() -> Tuple[int, int, int, int]: From 3be79ca3bd433068a1f5bdf8f88a99d1105b9f15 Mon Sep 17 00:00:00 2001 From: "yuze.zyz" Date: Thu, 4 Jan 2024 20:19:58 +0800 Subject: [PATCH 2/7] add offload param --- swift/tuners/lora.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/swift/tuners/lora.py b/swift/tuners/lora.py index 4242a97dd6..5f0d1aea6f 100644 --- a/swift/tuners/lora.py +++ b/swift/tuners/lora.py @@ -48,6 +48,13 @@ class LoRAConfig(LoraConfig, SwiftConfig): 'The lora dtype, default None means following the original layer\'s dtype' }) + offload: str = field( + default=None, + metadata={ + 'help': + 'Offload deactivated adapters. 
Support None(no offloading), `cpu` or `meta`(meta device)' + }) + def __post_init__(self): from .mapping import SwiftTuners self.swift_type = SwiftTuners.LORA From 51a8c348d9f1650aebc0d3d50c20a43693fd272c Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Thu, 4 Jan 2024 20:29:38 +0800 Subject: [PATCH 3/7] fix --- tests/tuners/test_swift_base.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/tuners/test_swift_base.py b/tests/tuners/test_swift_base.py index 72ab10fc03..81160a5652 100644 --- a/tests/tuners/test_swift_base.py +++ b/tests/tuners/test_swift_base.py @@ -312,31 +312,33 @@ def _init_weights(m): model1, config={ 'lora1': - LoRAConfig(target_modules=['query', 'key', 'value']), + LoRAConfig(target_modules=['query', 'key', 'value'], offload='meta'), 'adapter1': AdapterConfig( dim=model.config.hidden_size, target_modules=r'.*layer\.\d+$', method_name='feed_forward_chunk', + offload='meta', hidden_pos=0) }) model2 = Swift.prepare_model( model2, config={ 'lora2': - LoRAConfig(target_modules=['query', 'key', 'value']), + LoRAConfig(target_modules=['query', 'key', 'value'], offload='meta'), 'adapter2': AdapterConfig( dim=model.config.hidden_size, target_modules=r'.*layer\.\d+$', method_name='feed_forward_chunk', + offload='meta', hidden_pos=0) }) model = Swift.prepare_model( model, config={ - 'lora1': LoRAConfig(target_modules=['query', 'key', 'value']), - 'lora2': LoRAConfig(target_modules=['query', 'key', 'value']), + 'lora1': LoRAConfig(target_modules=['query', 'key', 'value'], offload='meta'), + 'lora2': LoRAConfig(target_modules=['query', 'key', 'value'], offload='meta'), }) model = Swift.prepare_model( @@ -347,12 +349,14 @@ def _init_weights(m): dim=model.config.hidden_size, target_modules=r'.*layer\.\d+$', method_name='feed_forward_chunk', + offload='meta', hidden_pos=0), 'adapter2': AdapterConfig( dim=model.config.hidden_size, target_modules=r'.*layer\.\d+$', method_name='feed_forward_chunk', + offload='meta', hidden_pos=0), }) From c2ea1bad2ec84e2bca9d0eb5409e6433a030342a Mon Sep 17 00:00:00 2001 From: "yuze.zyz" Date: Thu, 4 Jan 2024 23:49:25 +0800 Subject: [PATCH 4/7] unfinished --- swift/tuners/adapter.py | 27 ++++++--- swift/tuners/lora.py | 1 + swift/tuners/lora_layers.py | 13 ++-- swift/tuners/prompt.py | 21 +++++-- swift/tuners/restuning.py | 19 ++++-- swift/tuners/scetuning/scetuning.py | 17 +++++- swift/tuners/side.py | 22 ++++--- swift/tuners/utils.py | 94 ++++++++++++++++++++++++----- swift/utils/torch_utils.py | 16 ++--- 9 files changed, 163 insertions(+), 67 deletions(-) diff --git a/swift/tuners/adapter.py b/swift/tuners/adapter.py index f98df01492..05db6a56e6 100644 --- a/swift/tuners/adapter.py +++ b/swift/tuners/adapter.py @@ -68,6 +68,13 @@ class AdapterConfig(SwiftConfig): default='gelu', metadata={'help': 'The activation layer of the adapter'}) + offload: str = field( + default=None, + metadata={ + 'help': + 'Offload deactivated adapters. 
Support None(no offloading), `cpu` or `meta`(meta device)' + }) + def __post_init__(self): from .mapping import SwiftTuners self.swift_type = SwiftTuners.ADAPTER @@ -130,9 +137,10 @@ def _feed_forward_chunk(self, attention_output): else: setattr(module, config.method_name, types.MethodType(_forward, module)) - adapter_module = AdapterModule(config.dim, adapter_name, + adapter_module = AdapterModule(config.dim, adapter_name, module_key, config.adapter_length, ACT2CLS[config.act_layer]) + adapter_module.add_offload(adapter_name, config.offload) setattr(module, f'adapter_{adapter_name}', adapter_module) logger.info( f'Adapter modules(module_key): {module_key}.adapter_{adapter_name}' @@ -156,14 +164,13 @@ def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): - modules, module_keys = find_sub_module(module, - f'adapter_{adapter_name}') - for _module_key, _module in zip(module_keys, modules): + modules = find_sub_module(module, f'adapter_{adapter_name}') + for _module in modules: _module: ActivationMixin _module: nn.Module _module.set_activation(adapter_name, activate) - SwiftAdapter.save_memory(_module, adapter_name, _module_key, - activate, offload) + SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, + activate, _module.offloads.get(adapter_name)) class AdapterModule(nn.Module, ActivationMixin): @@ -182,11 +189,12 @@ def __init__( self, dim, adapter_name, + module_key, adapter_length=None, act_layer=nn.GELU, ): super(AdapterModule, self).__init__() - super(nn.Module, self).__init__() + super(nn.Module, self).__init__(module_key) self.dim = dim self.adapter_name = adapter_name self.adapter_length = adapter_length @@ -216,7 +224,10 @@ def forward(self, x, identity=None): x_dtype = x.dtype x = x.to(self.linear1.weight.dtype) - out = self.linear2(self.act(self.linear1(x))) + try: + out = self.linear2(self.act(self.linear1(x))) + except: + print() if identity is None: identity = x identity = identity.to(out.dtype) diff --git a/swift/tuners/lora.py b/swift/tuners/lora.py index 5f0d1aea6f..e6d1884b39 100644 --- a/swift/tuners/lora.py +++ b/swift/tuners/lora.py @@ -5,6 +5,7 @@ import torch from packaging import version +from peft.tuners.lora import LoraLayer from swift import LoraConfig from .lora_layers import * # noqa diff --git a/swift/tuners/lora_layers.py b/swift/tuners/lora_layers.py index 9411ca069c..8c4d2603b5 100644 --- a/swift/tuners/lora_layers.py +++ b/swift/tuners/lora_layers.py @@ -43,14 +43,6 @@ def is_auto_gptq_available(): class LoRAActivationMixin(ActivationMixin): - def __init__(self, module_key): - self.module_key = module_key - self.offloads = {} - super().__init__() - - def add_offload(self, adapter_name: str, offload=None): - self.offloads[adapter_name] = offload - @property def active_adapters(self): return self.get_activated_adapters() @@ -563,12 +555,13 @@ class LoRALayer(ActivationMixin): def __init__( self, adapter_name: str, + module_key: str, r: int, lora_alpha: int, lora_dropout: float, merge_weights: bool, ): - super().__init__() + super().__init__(module_key) self.adapter_name = adapter_name self.r = r self.lora_alpha = lora_alpha @@ -588,6 +581,7 @@ class MergedLinear(nn.Linear, LoRALayer): # LoRA implemented in a dense layer def __init__(self, adapter_name: str, + module_key: str, base_layer: nn.Linear, r: int = 0, lora_alpha: int = 1, @@ -609,6 +603,7 @@ def __init__(self, LoRALayer.__init__( self, adapter_name, + module_key, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, diff --git 
a/swift/tuners/prompt.py b/swift/tuners/prompt.py index 6c90bcd115..8721e4c335 100644 --- a/swift/tuners/prompt.py +++ b/swift/tuners/prompt.py @@ -74,6 +74,13 @@ class PromptConfig(SwiftConfig): 'Whether the embedding is extracted at final stage to keep the same dims with inputs' }) + offload: str = field( + default=None, + metadata={ + 'help': + 'Offload deactivated adapters. Support None(no offloading), `cpu` or `meta`(meta device)' + }) + def __post_init__(self): from .mapping import SwiftTuners self.swift_type = SwiftTuners.PROMPT @@ -153,9 +160,11 @@ def _forward(self, *args, **kwargs): prompt_module = PromptModule(input_dim, int(module_key.rsplit('.')[-1]), adapter_name, + module_key, config.prompt_length, config.attention_mask_value, config.attach_front) + prompt_module.add_offload(adapter_name, config.offload) setattr(module, f'prompt_{adapter_name}', prompt_module) logger.info( f'Prompt modules(module_key): {module_key}.prompt_{adapter_name}' @@ -180,14 +189,13 @@ def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): - modules, module_keys = find_sub_module(module, - f'prompt_{adapter_name}') - for _module_key, _module in zip(module_keys, modules): + modules = find_sub_module(module, f'prompt_{adapter_name}') + for _module in modules: _module: ActivationMixin _module: nn.Module _module.set_activation(adapter_name, activate) - SwiftAdapter.save_memory(_module, adapter_name, _module_key, - activate, offload) + SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, + activate, _module.offloads.get(adapter_name)) class PromptModule(nn.Module, ActivationMixin): @@ -208,11 +216,12 @@ def __init__(self, dim, layer_num, adapter_name, + module_key, prompt_length=None, mask_values=0., attach_front=True): super(PromptModule, self).__init__() - super(nn.Module, self).__init__() + super(nn.Module, self).__init__(module_key) self.dim = dim self.layer_num = layer_num self.adapter_name = adapter_name diff --git a/swift/tuners/restuning.py b/swift/tuners/restuning.py index 3751b819c8..05712c35b8 100644 --- a/swift/tuners/restuning.py +++ b/swift/tuners/restuning.py @@ -118,6 +118,13 @@ class ResTuningConfig(SwiftConfig): use_bypass: bool = field( default=True, metadata={'help': 'Whether to use bypass'}) + offload: str = field( + default=None, + metadata={ + 'help': + 'Offload deactivated adapters. Support None(no offloading), `cpu` or `meta`(meta device)' + }) + def __post_init__(self): from .mapping import SwiftTuners self.swift_type = SwiftTuners.RESTUNING @@ -253,6 +260,7 @@ def _forward_restuning(self, origin_arg): config.dims, depth, adapter_name, config.use_upsample, config.upsample_out_channels, config.zero_init_last, config.tuner_cfg) + restuning_module.add_offload(adapter_name, config.offload) setattr(top_module, f'restuning_{adapter_name}', restuning_module) # 4. 
Matching the target module @@ -307,14 +315,13 @@ def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): - modules, module_keys = find_sub_module(module, - f'restuning_{adapter_name}') - for _module_key, _module in zip(module_keys, modules): + modules = find_sub_module(module, f'restuning_{adapter_name}') + for _module in modules: _module: ActivationMixin _module: nn.Module _module.set_activation(adapter_name, activate) - SwiftAdapter.save_memory(_module, adapter_name, _module_key, - activate, offload) + SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, + activate, _module.offloads.get(adapter_name)) class ResTuningBypassModule(nn.Module, ActivationMixin): @@ -332,7 +339,7 @@ def __init__( tuner_cfg=None, ): super(ResTuningBypassModule, self).__init__() - super(nn.Module, self).__init__() + super(nn.Module, self).__init__('') self.adapter_name = adapter_name self.bypass_blocks = nn.Sequential(*[ diff --git a/swift/tuners/scetuning/scetuning.py b/swift/tuners/scetuning/scetuning.py index 313fbd58de..2d7a533ee4 100644 --- a/swift/tuners/scetuning/scetuning.py +++ b/swift/tuners/scetuning/scetuning.py @@ -67,6 +67,13 @@ class SCETuningConfig(SwiftConfig): default=1.0, metadata={'help': 'The dim down ratio of tuner hidden state'}) + offload: str = field( + default=None, + metadata={ + 'help': + 'Offload deactivated adapters. Support None(no offloading), `cpu` or `meta`(meta device)' + }) + def __post_init__(self): from swift.tuners.mapping import SwiftTuners self.swift_type = SwiftTuners.SCETUNING @@ -203,6 +210,7 @@ def _forward_decoder_mode(self, *args, **kwargs): adapter_name=adapter_name, dim=dims[tuner_id], tuner_length=int(dims[tuner_id] * config.down_ratio)) + tuner_op.add_offload(adapter_name, offload=config.offload) setattr(t_module, f'scetuner_{adapter_name}', tuner_op) if len(hint_module_ins_list) > 0: setattr(t_module, 'hint', hint_module_ins_list[tuner_id]) @@ -223,12 +231,15 @@ def mark_trainable_callback(model): @staticmethod def activate_adapter(module: torch.nn.Module, adapter_name: str, - activate: bool): - modules: List[torch.nn.Module] = find_sub_module( + activate: bool, offload: str = None): + modules = find_sub_module( module, f'scetuner_{adapter_name}') for _module in modules: _module: ActivationMixin + _module: nn.Module _module.set_activation(adapter_name, activate) + SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, + activate, _module.offloads.get(adapter_name)) class SCETunerModule(nn.Module, ActivationMixin): @@ -244,7 +255,7 @@ def __init__(self, zero_init_last=True, use_bias=True): super(SCETunerModule, self).__init__() - super(nn.Module, self).__init__() + super(nn.Module, self).__init__('') self.name = name self.adapter_name = adapter_name self.dim = dim diff --git a/swift/tuners/side.py b/swift/tuners/side.py index 1bb7afe39b..55d4c2e910 100644 --- a/swift/tuners/side.py +++ b/swift/tuners/side.py @@ -60,6 +60,13 @@ class SideConfig(SwiftConfig): 'The position of the hidden state output from the target module, can be int (args) or str (kwargs)' }) + offload: str = field( + default=None, + metadata={ + 'help': + 'Offload deactivated adapters. 
Support None(no offloading), `cpu` or `meta`(meta device)' + }) + def __post_init__(self): from .mapping import SwiftTuners self.swift_type = SwiftTuners.SIDE @@ -121,8 +128,9 @@ def forward_seq(self, input, *args, **kwargs): setattr(tgt_module, f'forward_origin_{adapter_name}', tgt_module.forward) tgt_module.forward = types.MethodType(_forward, tgt_module) - side_module = SideModule(config.dim, adapter_name, + side_module = SideModule(config.dim, adapter_name, module_key, config.side_module_name) + tgt_module.add_offload(adapter_name, config.offload) setattr(tgt_module, f'side_{adapter_name}', side_module) logger.info( f'Side modules(module_key): {module_key}.side_{adapter_name}' @@ -146,13 +154,13 @@ def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): - modules, module_keys = find_sub_module(module, f'side_{adapter_name}') - for _module_key, _module in zip(module_keys, modules): + modules = find_sub_module(module, f'side_{adapter_name}') + for _module in modules: _module: ActivationMixin _module: nn.Module _module.set_activation(adapter_name, activate) - SwiftAdapter.save_memory(_module, adapter_name, _module_key, - activate, offload) + SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, + activate, _module.offloads.get(adapter_name)) class SideModule(nn.Module, ActivationMixin): @@ -168,9 +176,9 @@ class SideModule(nn.Module, ActivationMixin): side_module_name: The name of the additive side networks. """ - def __init__(self, dim, adapter_name, side_module_name='fcn4'): + def __init__(self, dim, adapter_name, module_key, side_module_name='fcn4'): super(SideModule, self).__init__() - super(nn.Module, self).__init__() + super(nn.Module, self).__init__(module_key) self.adapter_name = adapter_name side_module_name = side_module_name.lower() diff --git a/swift/tuners/utils.py b/swift/tuners/utils.py index 796cfef5d5..b89d333fec 100644 --- a/swift/tuners/utils.py +++ b/swift/tuners/utils.py @@ -142,7 +142,9 @@ class ActivationMixin: USE_UNIQUE_THREAD = 'USE_UNIQUE_THREAD' - def __init__(self): + def __init__(self, module_key): + self.module_key = module_key + self.offloads = {} self._thread_inf: Dict[int, Dict[str, bool]] = {} self._unique_thread = bool( int(os.environ.get(ActivationMixin.USE_UNIQUE_THREAD, '1'))) @@ -151,6 +153,9 @@ def __init__(self): 'Using multiple thread mode, gradient checkpointing is not supported.' ) + def add_offload(self, adapter_name: str, offload=None): + self.offloads[adapter_name] = offload + @property def indent(self): return 0 if self.unique_thread else threading.get_ident() @@ -181,7 +186,7 @@ class OffloadHelper: sub_dir = 'offload_cache' cache_dir = os.path.join(get_cache_dir(), sub_dir) - shutil.rmtree(cache_dir) + shutil.rmtree(cache_dir, ignore_errors=True) os.makedirs(cache_dir, exist_ok=True) @staticmethod @@ -197,21 +202,70 @@ def write_safe_tensors(state_dict, safe_tensor_file): from safetensors.torch import save_file as safe_save_file safe_save_file(state_dict, safe_tensor_file, metadata={'format': 'pt'}) + @staticmethod + def offload_weight(weight, weight_name, offload_folder, index=None): + dtype = None + # Check the string instead of the dtype to be compatible with versions of PyTorch that don't have bfloat16. + if str(weight.dtype) == "torch.bfloat16": + # Need to reinterpret the underlined data as int16 since NumPy does not handle bfloat16s. 
+ weight = weight.view(torch.int16) + dtype = "bfloat16" + array = weight.cpu().numpy() + tensor_file = os.path.join(offload_folder, f"{weight_name}.dat") + if index is not None: + if dtype is None: + dtype = str(array.dtype) + index[weight_name] = {"dtype": dtype, "shape": list(array.shape)} + if array.ndim == 0: + array = array[None] + file_array = np.memmap(tensor_file, dtype=array.dtype, mode="w+", shape=array.shape) + file_array[:] = array[:] + file_array.flush() + return index + + @staticmethod + def load_offloaded_weight(weight_file, weight_info): + shape = tuple(weight_info["shape"]) + if shape == (): + # NumPy memory-mapped arrays can't have 0 dims so it was saved as 1d tensor + shape = (1,) + + dtype = weight_info["dtype"] + if dtype == "bfloat16": + # NumPy does not support bfloat16 so this was saved as a int16 + dtype = "int16" + + weight = np.memmap(weight_file, dtype=dtype, shape=shape, mode="r") + + if len(weight_info["shape"]) == 0: + weight = weight[0] + weight = torch.tensor(weight) + if weight_info["dtype"] == "bfloat16": + weight = weight.view(torch.bfloat16) + + return weight + @staticmethod def offload_disk(module: torch.nn.Module, adapter_name, module_key): key = adapter_name + ':' + module_key - md5 = hashlib.md5(key).hexdigest() + md5 = hashlib.md5(key.encode('utf-8')).hexdigest() file = os.path.join(OffloadHelper.cache_dir, md5 + '.safetensors') OffloadHelper.write_safe_tensors(module.state_dict(), file) @staticmethod def load_disk(module: torch.nn.Module, adapter_name, module_key): key = adapter_name + ':' + module_key - md5 = hashlib.md5(key).hexdigest() + md5 = hashlib.md5(key.encode('utf-8')).hexdigest() file = os.path.join(OffloadHelper.cache_dir, md5 + '.safetensors') state_dict = OffloadHelper.read_safe_tensors(file) - module.load_state_dict(state_dict, assign=True) - shutil.rmtree(file) + print(module.load_state_dict(state_dict, assign=True)) + shutil.rmtree(file, ignore_errors=True) + try: + print('here1!!!') + module.to(module.origin_device) + print('here2!!!') + except: + print() class SwiftAdapter: @@ -245,20 +299,25 @@ def save_memory(module: torch.nn.Module, @staticmethod def offload(module: torch.nn.Module, adapter_name, module_key, offload: str): - module.origin_device = str(module.device) - if offload == 'cpu' and str(module.device) != 'cpu': - module.to('cpu') - if offload == 'meta' and str(module.device) != 'meta': - OffloadHelper.offload_disk( - module, adapter_name=adapter_name, module_key=module_key) - module.to('meta') + device = next(iter(module.parameters())).device + if hasattr(module, 'origin_device') and module.origin_device != str(device): + return + module.origin_device = str(device) + if offload == 'cpu': + if str(device) != 'cpu': + module.to('cpu') + if offload == 'meta': + if str(device) != 'meta': + OffloadHelper.offload_disk( + module, adapter_name=adapter_name, module_key=module_key) + module.to('meta') else: raise NotImplementedError @staticmethod def load(module: torch.nn.Module, adapter_name, module_key, offload: str): - if not hasattr(module, 'origin_device') or module.origin_device == str( - module.device): + device = next(iter(module.parameters())).device + if not hasattr(module, 'origin_device') or module.origin_device == str(device): return if offload == 'cpu': module.to(module.origin_device) @@ -266,7 +325,10 @@ def load(module: torch.nn.Module, adapter_name, module_key, offload: str): elif offload == 'meta': OffloadHelper.load_disk( module, adapter_name=adapter_name, module_key=module_key) - 
module.to(module.origin_device) + try: + module.to(module.origin_device) + except: + print() delattr(module, 'origin_device') else: raise NotImplementedError diff --git a/swift/utils/torch_utils.py b/swift/utils/torch_utils.py index 510dd14d95..1e05b25a34 100644 --- a/swift/utils/torch_utils.py +++ b/swift/utils/torch_utils.py @@ -69,22 +69,14 @@ def get_model_info(model: Module, name: Optional[str] = None) -> str: def find_sub_module(module: torch.nn.Module, - module_name: str, - prefix='') -> Tuple[List[torch.nn.Module], List[str]]: + module_name: str) -> List[torch.nn.Module]: _modules = list() - _module_keys = list() - for name, sub_module in module.named_modules(prefix=prefix): + for name, sub_module in module.named_modules(): if not name: continue - if module_name == name: + if name.endswith(module_name): _modules.append(sub_module) - _module_keys.append(name) - else: - _sub_modules, _sub_module_keys = find_sub_module( - sub_module, module_name, prefix=name) - _modules.extend(_sub_modules) - _module_keys.extend(_sub_module_keys) - return _modules, _module_keys + return _modules def get_dist_setting() -> Tuple[int, int, int, int]: From 105cc731e0b02849c803bde46c4a402e89cc8349 Mon Sep 17 00:00:00 2001 From: "yuze.zyz" Date: Fri, 5 Jan 2024 15:07:52 +0800 Subject: [PATCH 5/7] fix bug --- swift/tuners/adapter.py | 18 +--- swift/tuners/base.py | 8 +- swift/tuners/lora.py | 10 +-- swift/tuners/lora_layers.py | 36 ++++---- swift/tuners/prompt.py | 13 +-- swift/tuners/restuning.py | 10 +-- swift/tuners/scetuning/scetuning.py | 19 ++-- swift/tuners/side.py | 10 +-- swift/tuners/utils.py | 135 ++++++++++++---------------- tests/tuners/test_swift_base.py | 28 +++--- 10 files changed, 111 insertions(+), 176 deletions(-) diff --git a/swift/tuners/adapter.py b/swift/tuners/adapter.py index 05db6a56e6..c2e04ac242 100644 --- a/swift/tuners/adapter.py +++ b/swift/tuners/adapter.py @@ -68,13 +68,6 @@ class AdapterConfig(SwiftConfig): default='gelu', metadata={'help': 'The activation layer of the adapter'}) - offload: str = field( - default=None, - metadata={ - 'help': - 'Offload deactivated adapters. 
Support None(no offloading), `cpu` or `meta`(meta device)' - }) - def __post_init__(self): from .mapping import SwiftTuners self.swift_type = SwiftTuners.ADAPTER @@ -137,10 +130,10 @@ def _feed_forward_chunk(self, attention_output): else: setattr(module, config.method_name, types.MethodType(_forward, module)) - adapter_module = AdapterModule(config.dim, adapter_name, module_key, + adapter_module = AdapterModule(config.dim, adapter_name, + module_key, config.adapter_length, ACT2CLS[config.act_layer]) - adapter_module.add_offload(adapter_name, config.offload) setattr(module, f'adapter_{adapter_name}', adapter_module) logger.info( f'Adapter modules(module_key): {module_key}.adapter_{adapter_name}' @@ -170,7 +163,7 @@ def activate_adapter(module: torch.nn.Module, _module: nn.Module _module.set_activation(adapter_name, activate) SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, - activate, _module.offloads.get(adapter_name)) + activate, offload) class AdapterModule(nn.Module, ActivationMixin): @@ -224,10 +217,7 @@ def forward(self, x, identity=None): x_dtype = x.dtype x = x.to(self.linear1.weight.dtype) - try: - out = self.linear2(self.act(self.linear1(x))) - except: - print() + out = self.linear2(self.act(self.linear1(x))) if identity is None: identity = x identity = identity.to(out.dtype) diff --git a/swift/tuners/base.py b/swift/tuners/base.py index 0577b61af7..8fc7bd0dad 100644 --- a/swift/tuners/base.py +++ b/swift/tuners/base.py @@ -432,7 +432,9 @@ def save_pretrained(self, def base_model(self): return self.model - def set_active_adapters(self, adapter_names: Union[List[str], str]): + def set_active_adapters(self, + adapter_names: Union[List[str], str], + offload=None): if not adapter_names: return @@ -444,7 +446,7 @@ def set_active_adapters(self, adapter_names: Union[List[str], str]): self.activate_adapter(adapter_name) for adapter_name in (set(self.adapters.keys()) - adapter_names): - self.deactivate_adapter(adapter_name) + self.deactivate_adapter(adapter_name, offload) def activate_adapter(self, adapter_name): if adapter_name not in self.adapters: @@ -456,7 +458,7 @@ def activate_adapter(self, adapter_name): SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\ .activate_adapter(self.base_model, adapter_name, True) - def deactivate_adapter(self, adapter_name, offload='cpu'): + def deactivate_adapter(self, adapter_name, offload=None): if adapter_name not in self.adapters: logger.warning( f'{adapter_name} not in adapters: {self.adapters.keys()}') diff --git a/swift/tuners/lora.py b/swift/tuners/lora.py index e6d1884b39..b62931904d 100644 --- a/swift/tuners/lora.py +++ b/swift/tuners/lora.py @@ -49,13 +49,6 @@ class LoRAConfig(LoraConfig, SwiftConfig): 'The lora dtype, default None means following the original layer\'s dtype' }) - offload: str = field( - default=None, - metadata={ - 'help': - 'Offload deactivated adapters. 
Support None(no offloading), `cpu` or `meta`(meta device)' - }) - def __post_init__(self): from .mapping import SwiftTuners self.swift_type = SwiftTuners.LORA @@ -81,10 +74,11 @@ def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): - set_adapter(module, adapter_name, activate) + set_adapter(module, adapter_name, activate, offload) for sub_module in module.modules(): if isinstance(sub_module, (LoraLayer, LoRALayer)): sub_module.set_activation(adapter_name, activate) + sub_module.save_memory(adapter_name, activate, offload) @staticmethod def unpatch_lora(model, config: LoRAConfig, adapter_name: str): diff --git a/swift/tuners/lora_layers.py b/swift/tuners/lora_layers.py index 8c4d2603b5..bfbfa66b73 100644 --- a/swift/tuners/lora_layers.py +++ b/swift/tuners/lora_layers.py @@ -51,7 +51,7 @@ def active_adapters(self): def active_adapter(self) -> str: return self.get_activated_adapters() - def set_adapter(self, adapter_names): + def set_adapter(self, adapter_names, offload=None): if isinstance(adapter_names, str): adapter_names = [adapter_names] @@ -62,21 +62,28 @@ def set_adapter(self, adapter_names): if key in adapter_names: self.set_activation(key, True) layer.requires_grad_(True) - SwiftAdapter.save_memory( - layer, - key, - self.module_key, - True, - offload=self.offloads.get(key)) + SwiftAdapter.save_memory(layer, key, self.module_key, True) else: self.set_activation(key, False) layer.requires_grad_(False) SwiftAdapter.save_memory( - layer, - key, - self.module_key, - False, - offload=self.offloads.get(key)) + layer, key, self.module_key, False, offload=offload) + + def save_memory(self, adapter_name, activate, offload=None): + for layer_name in self.adapter_layer_names: + module_dict = getattr(self, layer_name) + for key, layer in module_dict.items(): + if key == adapter_name: + if activate: + SwiftAdapter.save_memory(layer, layer_name + '.' + key, + self.module_key, True) + else: + SwiftAdapter.save_memory( + layer, + layer_name + '.' 
+ key, + self.module_key, + False, + offload=offload) def merge(self, *args, **kwargs): if not self.unique_thread: @@ -269,7 +276,6 @@ def inject_adapter(self, model: nn.Module, adapter_name: str): setattr(parent, target_name, new_module) else: target.update(adapter_name) - target.add_offload(adapter_name, peft_config.offload) _has_modules_to_save = True continue @@ -381,7 +387,6 @@ def _create_and_replace( lora_config.lora_dropout, lora_config.init_lora_weights, ) - target.add_offload(adapter_name, lora_config.offload) self._convert_dtype(target, lora_config.lora_dtype) elif isinstance(target, Embedding): target.update_layer_embedding( @@ -391,7 +396,6 @@ def _create_and_replace( lora_config.lora_dropout, lora_config.init_lora_weights, ) - target.add_offload(adapter_name, lora_config.offload) self._convert_dtype(target, lora_config.lora_dtype) elif isinstance(target, linear_types): target.update_layer( @@ -401,7 +405,6 @@ def _create_and_replace( lora_config.lora_dropout, lora_config.init_lora_weights, ) - target.add_offload(adapter_name, lora_config.offload) self._convert_dtype(target, lora_config.lora_dtype) else: new_module = self._create_new_module( @@ -416,7 +419,6 @@ def _create_and_replace( new_module.requires_grad_(False) self._replace_module(parent, target_name, new_module, target) self._convert_dtype(new_module, lora_config.lora_dtype) - new_module.add_offload(adapter_name, lora_config.offload) @staticmethod def _create_new_module(lora_config, adapter_name, target, **kwargs): diff --git a/swift/tuners/prompt.py b/swift/tuners/prompt.py index 8721e4c335..23eeb34e72 100644 --- a/swift/tuners/prompt.py +++ b/swift/tuners/prompt.py @@ -74,13 +74,6 @@ class PromptConfig(SwiftConfig): 'Whether the embedding is extracted at final stage to keep the same dims with inputs' }) - offload: str = field( - default=None, - metadata={ - 'help': - 'Offload deactivated adapters. Support None(no offloading), `cpu` or `meta`(meta device)' - }) - def __post_init__(self): from .mapping import SwiftTuners self.swift_type = SwiftTuners.PROMPT @@ -159,12 +152,10 @@ def _forward(self, *args, **kwargs): input_dim = config.dim prompt_module = PromptModule(input_dim, int(module_key.rsplit('.')[-1]), - adapter_name, - module_key, + adapter_name, module_key, config.prompt_length, config.attention_mask_value, config.attach_front) - prompt_module.add_offload(adapter_name, config.offload) setattr(module, f'prompt_{adapter_name}', prompt_module) logger.info( f'Prompt modules(module_key): {module_key}.prompt_{adapter_name}' @@ -195,7 +186,7 @@ def activate_adapter(module: torch.nn.Module, _module: nn.Module _module.set_activation(adapter_name, activate) SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, - activate, _module.offloads.get(adapter_name)) + activate, offload) class PromptModule(nn.Module, ActivationMixin): diff --git a/swift/tuners/restuning.py b/swift/tuners/restuning.py index 05712c35b8..e5d5103e45 100644 --- a/swift/tuners/restuning.py +++ b/swift/tuners/restuning.py @@ -118,13 +118,6 @@ class ResTuningConfig(SwiftConfig): use_bypass: bool = field( default=True, metadata={'help': 'Whether to use bypass'}) - offload: str = field( - default=None, - metadata={ - 'help': - 'Offload deactivated adapters. 
Support None(no offloading), `cpu` or `meta`(meta device)' - }) - def __post_init__(self): from .mapping import SwiftTuners self.swift_type = SwiftTuners.RESTUNING @@ -260,7 +253,6 @@ def _forward_restuning(self, origin_arg): config.dims, depth, adapter_name, config.use_upsample, config.upsample_out_channels, config.zero_init_last, config.tuner_cfg) - restuning_module.add_offload(adapter_name, config.offload) setattr(top_module, f'restuning_{adapter_name}', restuning_module) # 4. Matching the target module @@ -321,7 +313,7 @@ def activate_adapter(module: torch.nn.Module, _module: nn.Module _module.set_activation(adapter_name, activate) SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, - activate, _module.offloads.get(adapter_name)) + activate, offload) class ResTuningBypassModule(nn.Module, ActivationMixin): diff --git a/swift/tuners/scetuning/scetuning.py b/swift/tuners/scetuning/scetuning.py index 2d7a533ee4..dcb5615eb3 100644 --- a/swift/tuners/scetuning/scetuning.py +++ b/swift/tuners/scetuning/scetuning.py @@ -67,13 +67,6 @@ class SCETuningConfig(SwiftConfig): default=1.0, metadata={'help': 'The dim down ratio of tuner hidden state'}) - offload: str = field( - default=None, - metadata={ - 'help': - 'Offload deactivated adapters. Support None(no offloading), `cpu` or `meta`(meta device)' - }) - def __post_init__(self): from swift.tuners.mapping import SwiftTuners self.swift_type = SwiftTuners.SCETUNING @@ -210,7 +203,6 @@ def _forward_decoder_mode(self, *args, **kwargs): adapter_name=adapter_name, dim=dims[tuner_id], tuner_length=int(dims[tuner_id] * config.down_ratio)) - tuner_op.add_offload(adapter_name, offload=config.offload) setattr(t_module, f'scetuner_{adapter_name}', tuner_op) if len(hint_module_ins_list) > 0: setattr(t_module, 'hint', hint_module_ins_list[tuner_id]) @@ -230,16 +222,17 @@ def mark_trainable_callback(model): mark_trainable_callback) @staticmethod - def activate_adapter(module: torch.nn.Module, adapter_name: str, - activate: bool, offload: str = None): - modules = find_sub_module( - module, f'scetuner_{adapter_name}') + def activate_adapter(module: torch.nn.Module, + adapter_name: str, + activate: bool, + offload: str = None): + modules = find_sub_module(module, f'scetuner_{adapter_name}') for _module in modules: _module: ActivationMixin _module: nn.Module _module.set_activation(adapter_name, activate) SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, - activate, _module.offloads.get(adapter_name)) + activate, offload) class SCETunerModule(nn.Module, ActivationMixin): diff --git a/swift/tuners/side.py b/swift/tuners/side.py index 55d4c2e910..cdd985b34f 100644 --- a/swift/tuners/side.py +++ b/swift/tuners/side.py @@ -60,13 +60,6 @@ class SideConfig(SwiftConfig): 'The position of the hidden state output from the target module, can be int (args) or str (kwargs)' }) - offload: str = field( - default=None, - metadata={ - 'help': - 'Offload deactivated adapters. 
Support None(no offloading), `cpu` or `meta`(meta device)' - }) - def __post_init__(self): from .mapping import SwiftTuners self.swift_type = SwiftTuners.SIDE @@ -130,7 +123,6 @@ def forward_seq(self, input, *args, **kwargs): tgt_module.forward = types.MethodType(_forward, tgt_module) side_module = SideModule(config.dim, adapter_name, module_key, config.side_module_name) - tgt_module.add_offload(adapter_name, config.offload) setattr(tgt_module, f'side_{adapter_name}', side_module) logger.info( f'Side modules(module_key): {module_key}.side_{adapter_name}' @@ -160,7 +152,7 @@ def activate_adapter(module: torch.nn.Module, _module: nn.Module _module.set_activation(adapter_name, activate) SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, - activate, _module.offloads.get(adapter_name)) + activate, offload) class SideModule(nn.Module, ActivationMixin): diff --git a/swift/tuners/utils.py b/swift/tuners/utils.py index b89d333fec..61169a8b3a 100644 --- a/swift/tuners/utils.py +++ b/swift/tuners/utils.py @@ -7,9 +7,10 @@ import threading from dataclasses import asdict, dataclass, field from types import FunctionType -from typing import Dict +from typing import Dict, OrderedDict import json +import numpy as np import torch from peft.utils import CONFIG_NAME from peft.utils import ModulesToSaveWrapper as _ModulesToSaveWrapper @@ -142,20 +143,19 @@ class ActivationMixin: USE_UNIQUE_THREAD = 'USE_UNIQUE_THREAD' + REMINEDED = False + def __init__(self, module_key): self.module_key = module_key - self.offloads = {} self._thread_inf: Dict[int, Dict[str, bool]] = {} self._unique_thread = bool( int(os.environ.get(ActivationMixin.USE_UNIQUE_THREAD, '1'))) - if not self._unique_thread: - logger.info( + if not self._unique_thread and not ActivationMixin.REMINEDED: + ActivationMixin.REMINEDED = True + logger.warn( 'Using multiple thread mode, gradient checkpointing is not supported.' ) - def add_offload(self, adapter_name: str, offload=None): - self.offloads[adapter_name] = offload - @property def indent(self): return 0 if self.unique_thread else threading.get_ident() @@ -188,59 +188,44 @@ class OffloadHelper: cache_dir = os.path.join(get_cache_dir(), sub_dir) shutil.rmtree(cache_dir, ignore_errors=True) os.makedirs(cache_dir, exist_ok=True) - - @staticmethod - def read_safe_tensors(safe_tensor_file): - if os.path.exists(safe_tensor_file): - from safetensors.torch import load_file as safe_load_file - return safe_load_file( - safe_tensor_file, - device='cuda' if torch.cuda.is_available() else 'cpu') - - @staticmethod - def write_safe_tensors(state_dict, safe_tensor_file): - from safetensors.torch import save_file as safe_save_file - safe_save_file(state_dict, safe_tensor_file, metadata={'format': 'pt'}) + index = {} @staticmethod def offload_weight(weight, weight_name, offload_folder, index=None): dtype = None - # Check the string instead of the dtype to be compatible with versions of PyTorch that don't have bfloat16. - if str(weight.dtype) == "torch.bfloat16": - # Need to reinterpret the underlined data as int16 since NumPy does not handle bfloat16s. 
@@ -188,59 +188,44 @@ class OffloadHelper:
         cache_dir = os.path.join(get_cache_dir(), sub_dir)
         shutil.rmtree(cache_dir, ignore_errors=True)
         os.makedirs(cache_dir, exist_ok=True)
-
-    @staticmethod
-    def read_safe_tensors(safe_tensor_file):
-        if os.path.exists(safe_tensor_file):
-            from safetensors.torch import load_file as safe_load_file
-            return safe_load_file(
-                safe_tensor_file,
-                device='cuda' if torch.cuda.is_available() else 'cpu')
-
-    @staticmethod
-    def write_safe_tensors(state_dict, safe_tensor_file):
-        from safetensors.torch import save_file as safe_save_file
-        safe_save_file(state_dict, safe_tensor_file, metadata={'format': 'pt'})
+    index = {}
 
     @staticmethod
     def offload_weight(weight, weight_name, offload_folder, index=None):
         dtype = None
-        # Check the string instead of the dtype to be compatible with versions of PyTorch that don't have bfloat16.
-        if str(weight.dtype) == "torch.bfloat16":
-            # Need to reinterpret the underlined data as int16 since NumPy does not handle bfloat16s.
+        if str(weight.dtype) == 'torch.bfloat16':
             weight = weight.view(torch.int16)
-            dtype = "bfloat16"
+            dtype = 'bfloat16'
         array = weight.cpu().numpy()
-        tensor_file = os.path.join(offload_folder, f"{weight_name}.dat")
+        tensor_file = os.path.join(offload_folder, f'{weight_name}.dat')
         if index is not None:
             if dtype is None:
                 dtype = str(array.dtype)
-            index[weight_name] = {"dtype": dtype, "shape": list(array.shape)}
+            index[weight_name] = {'dtype': dtype, 'shape': list(array.shape)}
         if array.ndim == 0:
             array = array[None]
-        file_array = np.memmap(tensor_file, dtype=array.dtype, mode="w+", shape=array.shape)
+        file_array = np.memmap(
+            tensor_file, dtype=array.dtype, mode='w+', shape=array.shape)
         file_array[:] = array[:]
         file_array.flush()
         return index
 
     @staticmethod
     def load_offloaded_weight(weight_file, weight_info):
-        shape = tuple(weight_info["shape"])
+        shape = tuple(weight_info['shape'])
         if shape == ():
-            # NumPy memory-mapped arrays can't have 0 dims so it was saved as 1d tensor
-            shape = (1,)
+            shape = (1, )
 
-        dtype = weight_info["dtype"]
-        if dtype == "bfloat16":
-            # NumPy does not support bfloat16 so this was saved as a int16
-            dtype = "int16"
+        dtype = weight_info['dtype']
+        if dtype == 'bfloat16':
+            dtype = 'int16'
 
-        weight = np.memmap(weight_file, dtype=dtype, shape=shape, mode="r")
+        weight = np.memmap(weight_file, dtype=dtype, shape=shape, mode='r')
 
-        if len(weight_info["shape"]) == 0:
+        if len(weight_info['shape']) == 0:
             weight = weight[0]
         weight = torch.tensor(weight)
-        if weight_info["dtype"] == "bfloat16":
+        if weight_info['dtype'] == 'bfloat16':
             weight = weight.view(torch.bfloat16)
 
         return weight
@@ -249,23 +234,26 @@ def load_offloaded_weight(weight_file, weight_info):
     def offload_disk(module: torch.nn.Module, adapter_name, module_key):
         key = adapter_name + ':' + module_key
         md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
-        file = os.path.join(OffloadHelper.cache_dir, md5 + '.safetensors')
-        OffloadHelper.write_safe_tensors(module.state_dict(), file)
+        sub_folder = os.path.join(OffloadHelper.cache_dir, md5)
+        os.makedirs(sub_folder, exist_ok=True)
+        state_dict = module.state_dict()
+        OffloadHelper.index[md5] = {}
+        for key, tensor in state_dict.items():
+            OffloadHelper.offload_weight(tensor, key, sub_folder,
+                                         OffloadHelper.index[md5])
 
     @staticmethod
     def load_disk(module: torch.nn.Module, adapter_name, module_key):
         key = adapter_name + ':' + module_key
         md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
-        file = os.path.join(OffloadHelper.cache_dir, md5 + '.safetensors')
-        state_dict = OffloadHelper.read_safe_tensors(file)
-        print(module.load_state_dict(state_dict, assign=True))
-        shutil.rmtree(file, ignore_errors=True)
-        try:
-            print('here1!!!')
-            module.to(module.origin_device)
-            print('here2!!!')
-        except:
-            print()
+        sub_folder = os.path.join(OffloadHelper.cache_dir, md5)
+        state_dict = {}
+        for key, value in OffloadHelper.index[md5].items():
+            file = os.path.join(sub_folder, f'{key}.dat')
+            state_dict[key] = OffloadHelper.load_offloaded_weight(
+                file, OffloadHelper.index[md5][key])
+        module.load_state_dict(state_dict, assign=True)
+        shutil.rmtree(sub_folder, ignore_errors=True)
 
 
 class SwiftAdapter:
@@ -288,19 +276,20 @@ def save_memory(module: torch.nn.Module,
                     module_key: str,
                     activate: bool,
                     offload: str = None):
-        if offload is not None:
-            if activate:
-                SwiftAdapter.load(
-                    module, adapter_name, module_key, offload=offload)
-            else:
-                SwiftAdapter.offload(
-                    module, adapter_name, module_key, offload=offload)
+        if activate:
+            SwiftAdapter.load(module, adapter_name, module_key)
+        else:
+            SwiftAdapter.offload(
+                module, adapter_name, module_key, offload=offload)
 
     @staticmethod
     def offload(module: torch.nn.Module, adapter_name, module_key,
                 offload: str):
+        if not offload:
+            return
         device = next(iter(module.parameters())).device
-        if hasattr(module, 'origin_device') and module.origin_device != str(device):
+        if hasattr(module,
+                   'origin_device') and module.origin_device != str(device):
             return
         module.origin_device = str(device)
         if offload == 'cpu':
@@ -315,20 +304,18 @@ def offload(module: torch.nn.Module, adapter_name, module_key,
             raise NotImplementedError
 
     @staticmethod
-    def load(module: torch.nn.Module, adapter_name, module_key, offload: str):
+    def load(module: torch.nn.Module, adapter_name, module_key):
         device = next(iter(module.parameters())).device
-        if not hasattr(module, 'origin_device') or module.origin_device == str(device):
+        if not hasattr(module,
+                       'origin_device') or module.origin_device == str(device):
             return
-        if offload == 'cpu':
+        if str(device) == 'cpu':
             module.to(module.origin_device)
             delattr(module, 'origin_device')
-        elif offload == 'meta':
+        elif str(device) == 'meta':
             OffloadHelper.load_disk(
                 module, adapter_name=adapter_name, module_key=module_key)
-            try:
-                module.to(module.origin_device)
-            except:
-                print()
+            module.to(module.origin_device)
             delattr(module, 'origin_device')
         else:
             raise NotImplementedError
@@ -342,7 +329,6 @@ class ModulesToSaveWrapper(ActivationMixin, _ModulesToSaveWrapper):
 
     def __init__(self, *args, module_key, **kwargs):
         self.module_key = module_key
-        self.offloads = {}
         super(ModulesToSaveWrapper, self).__init__()
         super(ActivationMixin, self).__init__(*args, **kwargs)
@@ -357,10 +343,7 @@ def active_adapter(self):
             )
         return active_adapters[0]
 
-    def add_offload(self, adapter_name: str, offload=None):
-        self.offloads[adapter_name] = offload
-
-    def set_adapter(self, adapter_name: str):
+    def set_adapter(self, adapter_name: str, offload: str):
         if adapter_name not in self.modules_to_save:
             raise ValueError(
                 f'Adapter {adapter_name} not found in {self.modules_to_save.keys()}'
             )
         self.modules_to_save[adapter_name].requires_grad_(True)
         self.set_activation(adapter_name, True)
@@ -372,9 +355,9 @@ def set_adapter(self, adapter_name: str):
             adapter_name,
             self.module_key,
             True,
-            offload=self.offloads.get(adapter_name))
+            offload=offload)
 
-    def deactivate_adapter(self, adapter_name: str):
+    def deactivate_adapter(self, adapter_name: str, offload: str):
         if adapter_name in self.modules_to_save and self.unique_thread:
             self.modules_to_save[adapter_name].requires_grad_(False)
         self.set_activation(adapter_name, False)
@@ -383,16 +366,16 @@ def deactivate_adapter(self, adapter_name: str):
             adapter_name,
             self.module_key,
             False,
-            offload=self.offloads.get(adapter_name))
+            offload=offload)
 
 
-def set_adapter(model, adapter_name, activate):
+def set_adapter(model, adapter_name, activate, offload):
     for module in model.modules():
         if isinstance(module, ModulesToSaveWrapper):
             if activate:
-                module.set_adapter(adapter_name)
+                module.set_adapter(adapter_name, offload)
             else:
-                module.deactivate_adapter(adapter_name)
+                module.deactivate_adapter(adapter_name, offload)
 
 
 def set_trainable(model, adapter_name):
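Note: with the hunks above, the offload target is chosen at deactivation time and the reload path is inferred from the device the module currently sits on, rather than looked up in a per-adapter `offloads` dict. For the `meta` path, the cached tensors are materialized back into the meta module via `load_state_dict(..., assign=True)`, the same call the patch uses. A tiny runnable sketch of that mechanism (the `nn.Linear` and the in-memory "cache" stand in for an adapter sub-module and the on-disk memmaps; requires a PyTorch recent enough to support `assign=True`):

    import torch
    import torch.nn as nn

    layer = nn.Linear(4, 4)  # stands in for an adapter sub-module
    cached = {k: v.clone() for k, v in layer.state_dict().items()}  # disk cache stand-in

    layer.to('meta')  # deactivate with offload='meta': the real weights are dropped
    assert next(layer.parameters()).is_meta

    layer.load_state_dict(cached, assign=True)  # reactivate: materialize cached tensors in place
    assert not next(layer.parameters()).is_meta
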
diff --git a/tests/tuners/test_swift_base.py b/tests/tuners/test_swift_base.py
index 81160a5652..36d2ac1fbe 100644
--- a/tests/tuners/test_swift_base.py
+++ b/tests/tuners/test_swift_base.py
@@ -312,33 +312,31 @@ def _init_weights(m):
             model1,
             config={
                 'lora1':
-                LoRAConfig(target_modules=['query', 'key', 'value'], offload='meta'),
+                LoRAConfig(target_modules=['query', 'key', 'value']),
                 'adapter1':
                 AdapterConfig(
                     dim=model.config.hidden_size,
                     target_modules=r'.*layer\.\d+$',
                     method_name='feed_forward_chunk',
-                    offload='meta',
                     hidden_pos=0)
             })
         model2 = Swift.prepare_model(
             model2,
             config={
                 'lora2':
-                LoRAConfig(target_modules=['query', 'key', 'value'], offload='meta'),
+                LoRAConfig(target_modules=['query', 'key', 'value']),
                 'adapter2':
                 AdapterConfig(
                     dim=model.config.hidden_size,
                     target_modules=r'.*layer\.\d+$',
                     method_name='feed_forward_chunk',
-                    offload='meta',
                     hidden_pos=0)
             })
         model = Swift.prepare_model(
             model,
             config={
-                'lora1': LoRAConfig(target_modules=['query', 'key', 'value'], offload='meta'),
-                'lora2': LoRAConfig(target_modules=['query', 'key', 'value'], offload='meta'),
+                'lora1': LoRAConfig(target_modules=['query', 'key', 'value']),
+                'lora2': LoRAConfig(target_modules=['query', 'key', 'value']),
             })
 
         model = Swift.prepare_model(
@@ -349,26 +347,24 @@ def _init_weights(m):
                     dim=model.config.hidden_size,
                     target_modules=r'.*layer\.\d+$',
                     method_name='feed_forward_chunk',
-                    offload='meta',
                     hidden_pos=0),
                 'adapter2':
                 AdapterConfig(
                     dim=model.config.hidden_size,
                     target_modules=r'.*layer\.\d+$',
                     method_name='feed_forward_chunk',
-                    offload='meta',
                     hidden_pos=0),
             })
-        model.deactivate_adapter('adapter2')
-        model.deactivate_adapter('lora2')
+        model.deactivate_adapter('adapter2', offload='meta')
+        model.deactivate_adapter('lora2', offload='meta')
         outputs1 = model(**inputs)
         outputs2 = model1(**inputs)
         self.assertTrue(torch.allclose(outputs1.logits, outputs2.logits))
         model.activate_adapter('adapter2')
         model.activate_adapter('lora2')
-        model.deactivate_adapter('adapter1')
-        model.deactivate_adapter('lora1')
+        model.deactivate_adapter('adapter1', offload='meta')
+        model.deactivate_adapter('lora1', offload='meta')
         outputs1 = model(**inputs)
         outputs2 = model2(**inputs)
         self.assertTrue(torch.allclose(outputs1.logits, outputs2.logits))
@@ -376,16 +372,16 @@ def _init_weights(m):
         if os.environ.get('USE_UNIQUE_THREAD') == '0':
 
             def thread_func1():
-                model1.set_active_adapters(['lora1', 'adapter1'])
-                model.set_active_adapters(['lora1', 'adapter1'])
+                model1.set_active_adapters(['lora1', 'adapter1'], offload=None)
+                model.set_active_adapters(['lora1', 'adapter1'], offload=None)
                 outputs_single = model1(**inputs)
                 outputs_t1 = model(**inputs)
                 self.assertTrue(
                     torch.allclose(outputs_single.logits, outputs_t1.logits))
 
             def thread_func2():
-                model2.set_active_adapters(['lora2', 'adapter2'])
-                model.set_active_adapters(['lora2', 'adapter2'])
+                model2.set_active_adapters(['lora2', 'adapter2'], offload=None)
+                model.set_active_adapters(['lora2', 'adapter2'], offload=None)
                 outputs_single = model2(**inputs)
                 outputs_t2 = model(**inputs)
                 self.assertTrue(

From 4848c0b718742fd72920327c0db9c426dd97bcb9 Mon Sep 17 00:00:00 2001
From: "yuze.zyz"
Date: Fri, 5 Jan 2024 15:11:55 +0800
Subject: [PATCH 6/7] fix

---
 swift/tuners/utils.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/swift/tuners/utils.py b/swift/tuners/utils.py
index 61169a8b3a..70122ee371 100644
--- a/swift/tuners/utils.py
+++ b/swift/tuners/utils.py
@@ -350,12 +350,8 @@ def set_adapter(self, adapter_name: str, offload: str):
             )
         self.modules_to_save[adapter_name].requires_grad_(True)
         self.set_activation(adapter_name, True)
-        SwiftAdapter.save_memory(
-            self.modules_to_save[adapter_name],
-            adapter_name,
-            self.module_key,
-            True,
-            offload=offload)
+        SwiftAdapter.save_memory(self.modules_to_save[adapter_name],
+                                 adapter_name, self.module_key, True)
 
     def deactivate_adapter(self, adapter_name: str, offload: str):
         if adapter_name in self.modules_to_save and self.unique_thread:

From ecf180b085c69c9e036060997ee4fec6d698b27b Mon Sep 17 00:00:00 2001
From: "yuze.zyz"
Date: Fri, 5 Jan 2024 15:15:35 +0800
Subject: [PATCH 7/7] fix

---
 swift/tuners/utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/swift/tuners/utils.py b/swift/tuners/utils.py
index 70122ee371..73477936eb 100644
--- a/swift/tuners/utils.py
+++ b/swift/tuners/utils.py
@@ -302,6 +302,7 @@ def offload(module: torch.nn.Module, adapter_name, module_key,
             module.to('meta')
         else:
             raise NotImplementedError
+        torch.cuda.empty_cache()
 
     @staticmethod
     def load(module: torch.nn.Module, adapter_name, module_key):
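Note: taken together, the series moves the offload choice from the tuner configs to the (de)activation call, as exercised by the updated test above. A round-trip would look roughly like the sketch below; the tiny backbone module is a placeholder for any transformer model, and the import path is assumed to match the package layout used in the tests.

    import torch
    import torch.nn as nn
    from swift import LoRAConfig, Swift

    class TinyBlock(nn.Module):
        # Placeholder backbone with query/key/value linears for LoRA to target.
        def __init__(self, dim=32):
            super().__init__()
            self.query = nn.Linear(dim, dim)
            self.key = nn.Linear(dim, dim)
            self.value = nn.Linear(dim, dim)

        def forward(self, x):
            return self.query(x) + self.key(x) + self.value(x)

    model = Swift.prepare_model(
        TinyBlock(),
        config={
            'lora1': LoRAConfig(target_modules=['query', 'key', 'value']),
            'lora2': LoRAConfig(target_modules=['query', 'key', 'value']),
        })

    # The offload target is chosen when deactivating: 'cpu' keeps the adapter
    # weights in host memory, 'meta' writes them to the disk cache and frees them.
    model.deactivate_adapter('lora2', offload='meta')
    out = model(torch.randn(1, 32))  # only lora1 contributes here
    model.activate_adapter('lora2')  # weights are restored from the offload target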