Import bitsandbytes only if it's going to be used #180

Merged · 1 commit · Jan 5, 2023
src/petals/utils/convert_block.py: 9 changes (6 additions, 3 deletions)
@@ -4,7 +4,6 @@
 import re
 from typing import Sequence
 
-import bitsandbytes as bnb
 import tensor_parallel as tp
 import torch
 import torch.nn as nn
@@ -14,7 +13,6 @@
 from transformers.models.bloom.modeling_bloom import BloomAttention
 
 from petals.bloom.block import WrappedBloomBlock
-from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -75,6 +73,12 @@ def replace_8bit_linear(model: nn.Module, threshold=6.0):
     `int8_threshold` for outlier detection as described in the aforementioned paper. This parameter is set to
     `6.0` as described by the paper.
     """
+
+    # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes
+    import bitsandbytes as bnb
+
+    from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
+
     for n, module in model.named_children():
         if len(list(module.children())) > 0:
             replace_8bit_linear(module, threshold)
@@ -98,7 +102,6 @@
 def make_tensor_parallel(
     block: WrappedBloomBlock, model_config: BloomConfig, devices: Sequence[torch.device], output_device: torch.device
 ):
-    assert isinstance(block, (WrappedBloomBlock, CustomLinear8bitLt))
     tp_config = get_bloom_config(model_config, devices)
     del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
     tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)
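The change applies the standard deferred-import pattern: a platform-specific dependency is imported inside the only function that needs it, so importing the module itself never fails on platforms where the dependency is unavailable. A minimal sketch of the pattern, assuming a hypothetical quantize_linears helper (the function name, the error message, and the layer-replacement details are illustrative, not part of this PR):

import torch.nn as nn

def quantize_linears(model: nn.Module) -> None:
    """Recursively replace nn.Linear layers with 8-bit equivalents (sketch)."""
    # Deferred import: this module stays importable on platforms where
    # bitsandbytes is unavailable; only calling this function requires it.
    try:
        import bitsandbytes as bnb
    except ImportError as e:
        raise RuntimeError("8-bit quantization requires bitsandbytes") from e

    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            # Sketch only: swap in bitsandbytes' 8-bit linear layer
            # (weight copying omitted for brevity).
            setattr(
                model,
                name,
                bnb.nn.Linear8bitLt(child.in_features, child.out_features, bias=child.bias is not None),
            )
        else:
            quantize_linears(child)

With the imports moved inside replace_8bit_linear, CustomLinear8bitLt is no longer available at module scope, which is presumably why the isinstance assert in make_tensor_parallel is removed in the same commit.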