
bnb==0.41.3: PEFT integration tests failing after most recent release (likely 4-bit serialization PR) #901

@Titus-von-Koeller

Description

We're seeing two failing PEFT/bnb integration tests. This is definitely a bnb issue, since downgrading to bitsandbytes==0.41.2.post2 makes the tests pass.

My best guess is that an implicit assumption in the interface between bnb and PEFT has changed: previously, the tensor had already been moved to CUDA on the bnb side; now it hasn't been yet. I might be wrong, though, so this will need to be investigated.
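
If that hypothesis holds, a defensive device move before dequantizing would sidestep the error. Purely as a sketch of the idea (a hypothetical helper, not the actual PEFT code or a proposed fix), something along these lines around the call that fails in the traceback below:

```python
import torch
import bitsandbytes as bnb

def dequantize_4bit_weight(weight: "bnb.nn.Params4bit") -> torch.Tensor:
    # Hypothetical guard: with bnb==0.41.2.post2 weight.data was already on CUDA
    # at this point; with bnb==0.41.3 it can still be on CPU, and dequantize_4bit
    # then raises "ValueError: Expected a cuda device, but got: cpu".
    data = weight.data
    if data.device.type != "cuda":
        data = data.to("cuda")
    return bnb.functional.dequantize_4bit(data, weight.quant_state)
```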

@poedator could you take a look and see if this might be related to your work in #868?

For now, we've pinned the bnb version that PEFT pulls in, but we should fix this soon.
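
For reference, the failure can be reproduced outside the PEFT test suite. The snippet below is condensed from the second failing test further down (test_4bit_merge_lora): same model, config, and input, with the compute-dtype setting and assertions dropped for brevity.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=False)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=bnb_config,
    torch_dtype=torch.float32,
)
model = get_peft_model(model, LoraConfig(r=8, init_lora_weights=False))
model.merge_and_unload()

random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
with torch.inference_mode():
    # bnb==0.41.3: ValueError: Expected a cuda device, but got: cpu
    # bnb==0.41.2.post2: runs fine
    model(random_input)
```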

P.S. We'll need to establish a firm procedure for avoiding such failures in the future; I think this one was avoidable. At a minimum, we should run the bnb integration tests of the HF libraries as part of our release process and, ideally, integrate them into a CI system on our side.

See the logs below, from a run with Python 3.8, PEFT 86562ee, and bnb==0.41.3:

❯ pytest -m single_gpu_tests tests/test_common_gpu.py -k test_4bit_merge
================================================================ test session starts ================================================================
platform linux -- Python 3.8.18, pytest-7.4.3, pluggy-1.3.0
rootdir: /mnt/D/titus/src/peft
configfile: pyproject.toml
plugins: cov-4.1.0
collected 16 items / 14 deselected / 2 selected                                                                                                     

tests/test_common_gpu.py FF                                                                                                                   [100%]

===================================================================== FAILURES ======================================================================
________________________________________________ PeftGPUCommonTests.test_4bit_merge_and_disable_lora ________________________________________________

self = <tests.test_common_gpu.PeftGPUCommonTests testMethod=test_4bit_merge_and_disable_lora>

    @require_torch_gpu
    @pytest.mark.single_gpu_tests
    @require_bitsandbytes
    def test_4bit_merge_and_disable_lora(self):
        torch.manual_seed(3000)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=False,
            bnb_4bit_compute_type=torch.float32,
        )
        model = AutoModelForCausalLM.from_pretrained(
            "facebook/opt-125m",
            quantization_config=bnb_config,
            torch_dtype=torch.float32,
        )
        random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
        # compare outputs in probability space, because logits can have outliers
        # and token ids are not precise enough
        out_base = F.softmax(model(random_input).logits, dim=-1)
    
        config = LoraConfig(
            r=8,
            init_lora_weights=False,
        )
        model = get_peft_model(model, config)
    
        with torch.inference_mode():
            out_before = F.softmax(model(random_input).logits, dim=-1)
    
        model.merge_adapter()
        with model.disable_adapter():
            with torch.inference_mode():
>               out_after = F.softmax(model(random_input).logits, dim=-1)

tests/test_common_gpu.py:628: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
    return forward_call(*args, **kwargs)
src/peft/peft_model.py:534: in forward
    return self.get_base_model()(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
    return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
    output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:879: in forward
    outputs = self.model.decoder(
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
    return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
    output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:645: in forward
    layer_outputs = decoder_layer(
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
    return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
    output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:299: in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
    return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
    output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:142: in forward
    query_states = self.q_proj(hidden_states) * self.scaling
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
    return forward_call(*args, **kwargs)
src/peft/tuners/lora/bnb.py:283: in forward
    self.unmerge()
src/peft/tuners/lora/bnb.py:266: in unmerge
    w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) - lora_data
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/bitsandbytes/functional.py:1030: in dequantize_4bit
    device = pre_call(A.device)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/bitsandbytes/functional.py:417: in pre_call
    torch.cuda.set_device(device)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/cuda/__init__.py:402: in set_device
    device = _get_device_index(device)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

device = device(type='cpu'), optional = False, allow_cpu = False

    def _get_device_index(
        device: Any, optional: bool = False, allow_cpu: bool = False
    ) -> int:
        r"""Gets the device index from :attr:`device`, which can be a torch.device
        object, a Python integer, or ``None``.
    
        If :attr:`device` is a torch.device object, returns the device index if it
        is a CUDA device. Note that for a CUDA device without a specified index,
        i.e., ``torch.device('cuda')``, this will return the current default CUDA
        device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
        CPU devices will be accepted and ``-1`` will be returned in this case.
    
        If :attr:`device` is a Python integer, it is returned as is.
    
        If :attr:`device` is ``None``, this will return the current default CUDA
        device if :attr:`optional` is ``True``.
        """
        if isinstance(device, int):
            return device
        if isinstance(device, str):
            device = torch.device(device)
        if isinstance(device, torch.device):
            if allow_cpu:
                if device.type not in ["cuda", "cpu"]:
                    raise ValueError(f"Expected a cuda or cpu device, but got: {device}")
            elif device.type != "cuda":
>               raise ValueError(f"Expected a cuda device, but got: {device}")
E               ValueError: Expected a cuda device, but got: cpu

../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/cuda/_utils.py:35: ValueError
--------------------------------------------------------------- Captured stderr call ----------------------------------------------------------------
config.json: 100%|██████████| 651/651 [00:00<00:00, 156kB/s]
pytorch_model.bin: 100%|██████████| 251M/251M [00:02<00:00, 110MB/s] 
generation_config.json: 100%|██████████| 137/137 [00:00<00:00, 102kB/s]
______________________________________________________ PeftGPUCommonTests.test_4bit_merge_lora ______________________________________________________

self = <tests.test_common_gpu.PeftGPUCommonTests testMethod=test_4bit_merge_lora>

    @require_torch_gpu
    @pytest.mark.single_gpu_tests
    @require_bitsandbytes
    def test_4bit_merge_lora(self):
        torch.manual_seed(3000)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=False,
            bnb_4bit_compute_type=torch.float32,
        )
        model = AutoModelForCausalLM.from_pretrained(
            "facebook/opt-125m",
            quantization_config=bnb_config,
            torch_dtype=torch.float32,
        )
        random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
        # compare outputs in probability space, because logits can have outliers
        # and token ids are not precise enough
        out_base = F.softmax(model(random_input).logits, dim=-1)
    
        config = LoraConfig(
            r=8,
            init_lora_weights=False,
        )
        model = get_peft_model(model, config)
    
        with torch.inference_mode():
            out_before_merge = F.softmax(model(random_input).logits, dim=-1)
    
        model.merge_and_unload()
        with torch.inference_mode():
>           out_after_merge = F.softmax(model(random_input).logits, dim=-1)

tests/test_common_gpu.py:585: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
    return forward_call(*args, **kwargs)
src/peft/peft_model.py:534: in forward
    return self.get_base_model()(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
    return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
    output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:879: in forward
    outputs = self.model.decoder(
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
    return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
    output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:645: in forward
    layer_outputs = decoder_layer(
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
    return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
    output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:299: in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
    return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
    output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:142: in forward
    query_states = self.q_proj(hidden_states) * self.scaling
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
    return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
    output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/bitsandbytes/nn/modules.py:258: in forward
    out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/bitsandbytes/autograd/_functions.py:577: in matmul_4bit
    return MatMul4Bit.apply(A, B, out, bias, quant_state)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/autograd/function.py:539: in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/bitsandbytes/autograd/_functions.py:516: in forward
    output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/bitsandbytes/functional.py:1030: in dequantize_4bit
    device = pre_call(A.device)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/bitsandbytes/functional.py:417: in pre_call
    torch.cuda.set_device(device)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/cuda/__init__.py:402: in set_device
    device = _get_device_index(device)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

device = device(type='cpu'), optional = False, allow_cpu = False

    def _get_device_index(
        device: Any, optional: bool = False, allow_cpu: bool = False
    ) -> int:
        r"""Gets the device index from :attr:`device`, which can be a torch.device
        object, a Python integer, or ``None``.
    
        If :attr:`device` is a torch.device object, returns the device index if it
        is a CUDA device. Note that for a CUDA device without a specified index,
        i.e., ``torch.device('cuda')``, this will return the current default CUDA
        device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
        CPU devices will be accepted and ``-1`` will be returned in this case.
    
        If :attr:`device` is a Python integer, it is returned as is.
    
        If :attr:`device` is ``None``, this will return the current default CUDA
        device if :attr:`optional` is ``True``.
        """
        if isinstance(device, int):
            return device
        if isinstance(device, str):
            device = torch.device(device)
        if isinstance(device, torch.device):
            if allow_cpu:
                if device.type not in ["cuda", "cpu"]:
                    raise ValueError(f"Expected a cuda or cpu device, but got: {device}")
            elif device.type != "cuda":
>               raise ValueError(f"Expected a cuda device, but got: {device}")
E               ValueError: Expected a cuda device, but got: cpu

../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/cuda/_utils.py:35: ValueError
================================================================= warnings summary ==================================================================
tests/test_common_gpu.py::PeftGPUCommonTests::test_4bit_merge_and_disable_lora
tests/test_common_gpu.py::PeftGPUCommonTests::test_4bit_merge_lora
  /mnt/D/titus/src/peft/src/peft/tuners/lora/bnb.py:229: UserWarning: Merge lora module to 4-bit linear may get different generations due to rounding errors.
    warnings.warn(

tests/test_common_gpu.py::PeftGPUCommonTests::test_4bit_merge_and_disable_lora
  /mnt/D/titus/src/peft/src/peft/tuners/lora/bnb.py:260: UserWarning: Unmerge lora module to 4-bit linear may get different generations due to rounding errors.
    warnings.warn(
