Import petals.utils.peft only when needed to avoid unnecessary import of bitsandbytes (#345)

The motivation is the same as in #180.
borzunov committed Jul 12, 2023
1 parent 294970f commit 43acfe5
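
Both files apply the same deferred-import pattern: the peft imports move from module level into the code paths that actually need them, so importing `petals.server.backend` or `petals.utils.convert_block` no longer loads peft (and, through it, bitsandbytes). A minimal sketch of the pattern, using bitsandbytes itself as the heavy dependency:

    # Minimal sketch of the deferred-import pattern used in this commit.

    def work_that_needs_adapters():
        # Imported on first call only, so importing this module stays cheap;
        # bitsandbytes does nontrivial work (loading native libraries) at import time.
        import bitsandbytes as bnb

        return bnb.__version__
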
Showing 2 changed files with 9 additions and 3 deletions.

src/petals/server/backend.py (7 changes: 5 additions & 2 deletions)

@@ -4,7 +4,6 @@
 from itertools import chain
 from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
-import peft
 import torch
 from hivemind import BatchTensorDescriptor, TensorDescriptor
 from hivemind.moe.expert_uid import ExpertUID

@@ -156,9 +155,13 @@ def shutdown(self):
 
     def load_adapter_(self, active_adapter: Optional[str] = None) -> bool:
         """Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
 
+        # Import petals.utils.peft only when necessary to avoid importing bitsandbytes
+        from peft.tuners.lora import Linear, Linear4bit, Linear8bitLt
+
         adapter_was_loaded = False
         for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
-            if isinstance(layer, (peft.tuners.lora.Linear, peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
+            if isinstance(layer, (Linear, Linear4bit, Linear8bitLt)):
                 layer.active_adapter = active_adapter  # empty string for no adapter
                 if active_adapter in layer.lora_A.keys():
                     adapter_was_loaded = True
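
As a quick sanity check (not part of the commit), and assuming no other module in the import chain pulls in bitsandbytes eagerly, one could confirm the effect of this hunk by inspecting `sys.modules` right after importing the backend module:

    # Hypothetical check: importing the backend module should no longer
    # pull in bitsandbytes as a side effect.
    import sys

    import petals.server.backend  # noqa: F401  # triggers module-level imports only

    assert "bitsandbytes" not in sys.modules, "bitsandbytes was imported eagerly"
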
src/petals/utils/convert_block.py (5 changes: 4 additions & 1 deletion)

@@ -13,7 +13,6 @@
 from transformers import PretrainedConfig
 
 from petals.utils.misc import QuantType
-from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)

@@ -56,6 +55,10 @@ def convert_block(
         shard.to(device)
 
     if adapters:
+        # Import petals.utils.peft only when necessary to avoid importing bitsandbytes
+        os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+        from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
+
         create_lora_adapter(block, quant_type=quant_type)
         for adapter_name in adapters:
            adapter_config, adapter_state_dict = load_peft(
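
A side note on the added environment variable: in bitsandbytes versions of this era, `BITSANDBYTES_NOWELCOME=1` suppressed the welcome banner printed at import time, so the assignment must happen before the first import, which is exactly why it sits immediately above the deferred import. A minimal sketch of the ordering requirement:

    import os

    # Must be set before bitsandbytes is first imported: the banner (and the
    # check for this variable) runs at import time, so setting it later is a no-op.
    os.environ["BITSANDBYTES_NOWELCOME"] = "1"

    import bitsandbytes  # noqa: E402,F401  # banner suppressed
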
