Description
We're seeing 2 failing PEFT/bnb integration tests. This is definitely a bnb issue, since downgrading to bitsandbytes==0.41.2.post2 makes the tests pass.
My best guess is that an implicit assumption in the interface between bnb and PEFT has changed: previously, the tensor had already been moved to CUDA on the bnb side; now it hasn't been yet (see the sketch below). I might be wrong, though, so this needs to be investigated.
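To make the hypothesis concrete, here is a minimal sketch of the suspected failure mode. It is not a confirmed root cause: only the `bnb.functional.dequantize_4bit` / `pre_call` path comes from the traceback below; the shapes, the `quantize_4bit` round trip, and the CPU move are illustrative, and it assumes a CUDA machine with bitsandbytes==0.41.3.

```python
# Minimal sketch of the suspected device assumption (hypothesis, not a fix).
import torch
import bitsandbytes as bnb

weight = torch.randn(64, 64, device="cuda", dtype=torch.float32)

# Quantize on the GPU, roughly as transformers does when loading in 4-bit.
q, quant_state = bnb.functional.quantize_4bit(weight)

# Fine: the packed tensor lives on a CUDA device.
bnb.functional.dequantize_4bit(q, quant_state)

# Suspected regression: in 0.41.3, dequantize_4bit calls
# torch.cuda.set_device(A.device) via pre_call(), so it rejects a packed
# tensor that still sits on the CPU instead of (per my hypothesis) moving
# it to CUDA first, as earlier versions effectively did.
try:
    bnb.functional.dequantize_4bit(q.cpu(), quant_state)
except ValueError as exc:
    print(exc)  # "Expected a cuda device, but got: cpu"
```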
@poedator could you take a look and see if this might be related to your work in #868?
For now, we've pinned the bnb version that PEFT pulls in, but we should fix this soon.
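For reference, the temporary workaround on the PEFT side is just a version pin; roughly like the sketch below (illustrative only — the actual file, extra names, and dependency list in PEFT's packaging are not reproduced here).

```python
# Illustrative setup.py sketch of the temporary pin (not PEFT's actual file).
from setuptools import setup

setup(
    name="peft",
    extras_require={
        "test": [
            "pytest",
            # Pin until the bnb 0.41.3 device-handling regression is resolved:
            "bitsandbytes==0.41.2.post2",
        ],
    },
)
```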
P.S. We'll need to establish a firm procedure to avoid such failures in the future; I think this one could have been prevented. At a minimum, we should run the HF libraries' bnb integration tests as part of our release process and, ideally, integrate them into a CI system on our side.
See the logs below, from a run with Python 3.8, PEFT 86562ee, and bnb==0.41.3:
❯ pytest -m single_gpu_tests tests/test_common_gpu.py -k test_4bit_merge
================================================================ test session starts ================================================================
platform linux -- Python 3.8.18, pytest-7.4.3, pluggy-1.3.0
rootdir: /mnt/D/titus/src/peft
configfile: pyproject.toml
plugins: cov-4.1.0
collected 16 items / 14 deselected / 2 selected
tests/test_common_gpu.py FF [100%]
===================================================================== FAILURES ======================================================================
________________________________________________ PeftGPUCommonTests.test_4bit_merge_and_disable_lora ________________________________________________
self = <tests.test_common_gpu.PeftGPUCommonTests testMethod=test_4bit_merge_and_disable_lora>
@require_torch_gpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_4bit_merge_and_disable_lora(self):
torch.manual_seed(3000)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_compute_type=torch.float32,
)
model = AutoModelForCausalLM.from_pretrained(
"facebook/opt-125m",
quantization_config=bnb_config,
torch_dtype=torch.float32,
)
random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
# compare outputs in probability space, because logits can have outliers
# and token ids are not precise enough
out_base = F.softmax(model(random_input).logits, dim=-1)
config = LoraConfig(
r=8,
init_lora_weights=False,
)
model = get_peft_model(model, config)
with torch.inference_mode():
out_before = F.softmax(model(random_input).logits, dim=-1)
model.merge_adapter()
with model.disable_adapter():
with torch.inference_mode():
> out_after = F.softmax(model(random_input).logits, dim=-1)
tests/test_common_gpu.py:628:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
return forward_call(*args, **kwargs)
src/peft/peft_model.py:534: in forward
return self.get_base_model()(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:879: in forward
outputs = self.model.decoder(
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:645: in forward
layer_outputs = decoder_layer(
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:299: in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:142: in forward
query_states = self.q_proj(hidden_states) * self.scaling
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
return forward_call(*args, **kwargs)
src/peft/tuners/lora/bnb.py:283: in forward
self.unmerge()
src/peft/tuners/lora/bnb.py:266: in unmerge
w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) - lora_data
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/bitsandbytes/functional.py:1030: in dequantize_4bit
device = pre_call(A.device)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/bitsandbytes/functional.py:417: in pre_call
torch.cuda.set_device(device)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/cuda/__init__.py:402: in set_device
device = _get_device_index(device)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
device = device(type='cpu'), optional = False, allow_cpu = False
def _get_device_index(
device: Any, optional: bool = False, allow_cpu: bool = False
) -> int:
r"""Gets the device index from :attr:`device`, which can be a torch.device
object, a Python integer, or ``None``.
If :attr:`device` is a torch.device object, returns the device index if it
is a CUDA device. Note that for a CUDA device without a specified index,
i.e., ``torch.device('cuda')``, this will return the current default CUDA
device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
CPU devices will be accepted and ``-1`` will be returned in this case.
If :attr:`device` is a Python integer, it is returned as is.
If :attr:`device` is ``None``, this will return the current default CUDA
device if :attr:`optional` is ``True``.
"""
if isinstance(device, int):
return device
if isinstance(device, str):
device = torch.device(device)
if isinstance(device, torch.device):
if allow_cpu:
if device.type not in ["cuda", "cpu"]:
raise ValueError(f"Expected a cuda or cpu device, but got: {device}")
elif device.type != "cuda":
> raise ValueError(f"Expected a cuda device, but got: {device}")
E ValueError: Expected a cuda device, but got: cpu
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/cuda/_utils.py:35: ValueError
--------------------------------------------------------------- Captured stderr call ----------------------------------------------------------------
config.json: 100%|██████████| 651/651 [00:00<00:00, 156kB/s]
pytorch_model.bin: 100%|██████████| 251M/251M [00:02<00:00, 110MB/s]
generation_config.json: 100%|██████████| 137/137 [00:00<00:00, 102kB/s]
______________________________________________________ PeftGPUCommonTests.test_4bit_merge_lora ______________________________________________________
self = <tests.test_common_gpu.PeftGPUCommonTests testMethod=test_4bit_merge_lora>
@require_torch_gpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_4bit_merge_lora(self):
torch.manual_seed(3000)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_compute_type=torch.float32,
)
model = AutoModelForCausalLM.from_pretrained(
"facebook/opt-125m",
quantization_config=bnb_config,
torch_dtype=torch.float32,
)
random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
# compare outputs in probability space, because logits can have outliers
# and token ids are not precise enough
out_base = F.softmax(model(random_input).logits, dim=-1)
config = LoraConfig(
r=8,
init_lora_weights=False,
)
model = get_peft_model(model, config)
with torch.inference_mode():
out_before_merge = F.softmax(model(random_input).logits, dim=-1)
model.merge_and_unload()
with torch.inference_mode():
> out_after_merge = F.softmax(model(random_input).logits, dim=-1)
tests/test_common_gpu.py:585:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
return forward_call(*args, **kwargs)
src/peft/peft_model.py:534: in forward
return self.get_base_model()(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:879: in forward
outputs = self.model.decoder(
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:645: in forward
layer_outputs = decoder_layer(
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:299: in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/transformers/models/opt/modeling_opt.py:142: in forward
query_states = self.q_proj(hidden_states) * self.scaling
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1518: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/nn/modules/module.py:1527: in _call_impl
return forward_call(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/accelerate/hooks.py:165: in new_forward
output = module._old_forward(*args, **kwargs)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/bitsandbytes/nn/modules.py:258: in forward
out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/bitsandbytes/autograd/_functions.py:577: in matmul_4bit
return MatMul4Bit.apply(A, B, out, bias, quant_state)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/autograd/function.py:539: in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/bitsandbytes/autograd/_functions.py:516: in forward
output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/bitsandbytes/functional.py:1030: in dequantize_4bit
device = pre_call(A.device)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/bitsandbytes/functional.py:417: in pre_call
torch.cuda.set_device(device)
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/cuda/__init__.py:402: in set_device
device = _get_device_index(device)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
device = device(type='cpu'), optional = False, allow_cpu = False
def _get_device_index(
device: Any, optional: bool = False, allow_cpu: bool = False
) -> int:
r"""Gets the device index from :attr:`device`, which can be a torch.device
object, a Python integer, or ``None``.
If :attr:`device` is a torch.device object, returns the device index if it
is a CUDA device. Note that for a CUDA device without a specified index,
i.e., ``torch.device('cuda')``, this will return the current default CUDA
device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
CPU devices will be accepted and ``-1`` will be returned in this case.
If :attr:`device` is a Python integer, it is returned as is.
If :attr:`device` is ``None``, this will return the current default CUDA
device if :attr:`optional` is ``True``.
"""
if isinstance(device, int):
return device
if isinstance(device, str):
device = torch.device(device)
if isinstance(device, torch.device):
if allow_cpu:
if device.type not in ["cuda", "cpu"]:
raise ValueError(f"Expected a cuda or cpu device, but got: {device}")
elif device.type != "cuda":
> raise ValueError(f"Expected a cuda device, but got: {device}")
E ValueError: Expected a cuda device, but got: cpu
../../.condax/mamba/envs/peft/lib/python3.8/site-packages/torch/cuda/_utils.py:35: ValueError
================================================================= warnings summary ==================================================================
tests/test_common_gpu.py::PeftGPUCommonTests::test_4bit_merge_and_disable_lora
tests/test_common_gpu.py::PeftGPUCommonTests::test_4bit_merge_lora
/mnt/D/titus/src/peft/src/peft/tuners/lora/bnb.py:229: UserWarning: Merge lora module to 4-bit linear may get different generations due to rounding errors.
warnings.warn(
tests/test_common_gpu.py::PeftGPUCommonTests::test_4bit_merge_and_disable_lora
/mnt/D/titus/src/peft/src/peft/tuners/lora/bnb.py:260: UserWarning: Unmerge lora module to 4-bit linear may get different generations due to rounding errors.
warnings.warn(