Import bitsandbytes only if it's going to be used (#180)
borzunov committed Jan 5, 2023
1 parent e277063 commit 6dd9a93
Showing 1 changed file with 6 additions and 3 deletions.
src/petals/utils/convert_block.py (9 changes: 6 additions & 3 deletions)
@@ -4,7 +4,6 @@
 import re
 from typing import Sequence

-import bitsandbytes as bnb
 import tensor_parallel as tp
 import torch
 import torch.nn as nn
@@ -14,7 +13,6 @@
 from transformers.models.bloom.modeling_bloom import BloomAttention

 from petals.bloom.block import WrappedBloomBlock
-from petals.utils.linear8bitlt_patch import CustomLinear8bitLt

 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -75,6 +73,12 @@ def replace_8bit_linear(model: nn.Module, threshold=6.0):
     `int8_threshold` for outlier detection as described in the aforementioned paper. This parameter is set to
     `6.0` as described in the paper.
     """
+
+    # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes
+    import bitsandbytes as bnb
+
+    from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
+
     for n, module in model.named_children():
         if len(list(module.children())) > 0:
             replace_8bit_linear(module, threshold)
@@ -98,7 +102,6 @@ def replace_8bit_linear(model: nn.Module, threshold=6.0):
 def make_tensor_parallel(
     block: WrappedBloomBlock, model_config: BloomConfig, devices: Sequence[torch.device], output_device: torch.device
 ):
-    assert isinstance(block, (WrappedBloomBlock, CustomLinear8bitLt))
     tp_config = get_bloom_config(model_config, devices)
     del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
     tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)
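For context: the change moves both 8-bit-related imports from module scope into the body of replace_8bit_linear, so bitsandbytes is only imported when 8-bit quantization is actually requested. Below is a minimal sketch of the same deferred-import pattern with illustrative names (convert_linear_to_8bit is not a Petals API, and the real code uses the patched CustomLinear8bitLt and also copies the original weights):

import torch.nn as nn

def convert_linear_to_8bit(model: nn.Module, threshold: float = 6.0) -> nn.Module:
    # Deferred import: nothing at module scope touches bitsandbytes, so this
    # file imports cleanly on platforms that bitsandbytes does not support.
    import bitsandbytes as bnb

    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            # Replace the fp16/fp32 linear layer with an 8-bit one.
            setattr(
                model,
                name,
                bnb.nn.Linear8bitLt(
                    module.in_features,
                    module.out_features,
                    bias=module.bias is not None,
                    has_fp16_weights=False,
                    threshold=threshold,
                ),
            )
        else:
            # Recurse into container modules to reach nested linear layers.
            convert_linear_to_8bit(module, threshold)
    return model

The tradeoff of this pattern is that a missing bitsandbytes now surfaces as an ImportError the first time the 8-bit path runs rather than at import time, which is the desired behavior for an optional dependency.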
