
[bnb] Fix blip2 4bit #23895

Closed
wants to merge 2 commits

Conversation

younesbelkada
Contributor

What does this PR do?

Fixes #23839

Indeed, for models such as Blip2 that have the lm head inside a submodule (rather than directly at the top level of the model), the lm head gets converted to 4-bit / 8-bit, leading to unexpected behavior for 4-bit models. This PR fixes that by making sure to consider the last term after . when creating modules_not_to_convert.
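A minimal sketch of the matching idea (hypothetical illustration only, not the actual diff; should_convert is an invented helper name):

# Hypothetical sketch: treat an entry such as "lm_head" in modules_to_not_convert
# as also matching the last component of a module's full name, so that
# "language_model.lm_head" (as in Blip2) is kept out of 4-bit / 8-bit conversion.
def should_convert(full_module_name, modules_to_not_convert):
    last = full_module_name.split(".")[-1]
    return full_module_name not in modules_to_not_convert and last not in modules_to_not_convert

print(should_convert("language_model.lm_head", ["lm_head"]))                     # False -> kept in higher precision
print(should_convert("language_model.model.decoder.layers.0.fc1", ["lm_head"]))  # True  -> quantized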

cc @sgugger

younesbelkada mentioned this pull request on May 31, 2023
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Collaborator

@sgugger left a comment

This feels a tiny bit brittle. Are you sure it doesn't break any other model quantization?

@younesbelkada
Contributor Author

It should be all good, I have verified that the slow tests pass for 8-bit and 4-bit. Let me know if there is anything in particular I should look at. To my understanding, this only affects Blip2, as it is the only model (that I know of) that has an lm head as part of a submodule.

huggingface deleted a comment from the github-actions bot on Jul 6, 2023
@younesbelkada
Contributor Author

Hmm, I'm getting some gibberish output with the fix; I need to investigate more.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot closed this on Aug 7, 2023
@kevinknights29

I know this was closed, but I'm getting the following error:
FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.

Runner: a Docker container using the python:3.9 image.

Usage:

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
input_ids = input_ids.to("cuda:0" if torch.cuda.is_available() else "cpu")
outputs = model.generate(input_ids, max_length=500)

Model class:

class ModelLoader:
    def __init__(self, model_path: str):
        self.model_path = model_path
        self.config = AutoConfig.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            use_auth_token=os.getenv("HUGGINGFACE_TOKEN"),
        )
        self.model = self._load_model()
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            use_auth_token=os.getenv("HUGGINGFACE_TOKEN"),
        )

    def _load_model(self):
        if torch.cuda.is_available():
            return AutoModelForCausalLM.from_pretrained(
                self.model_path,
                config=self.config,
                trust_remote_code=True,
                device_map="cuda:0", # or "auto"
                use_auth_token=os.getenv("HUGGINGFACE_TOKEN"),
            )
        return AutoModelForCausalLM.from_pretrained(
            self.model_path,
            config=self.config,
            trust_remote_code=True,
            load_in_8bit=True,
            device_map="cpu",
            torch_dtype=torch.float16,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True, 
                bnb_4bit_compute_dtype=torch.bfloat16,
                llm_int8_enable_fp32_cpu_offload=True,
            ),
            use_auth_token=os.getenv("HUGGINGFACE_TOKEN"),
        )

Complete traceback:

2023-09-01 18:34:47 Traceback (most recent call last):
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/celery/app/trace.py", line 477, in trace_task
2023-09-01 18:34:47     R = retval = fun(*args, **kwargs)
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/celery/app/trace.py", line 760, in __protected_call__
2023-09-01 18:34:47     return self.run(*args, **kwargs)
2023-09-01 18:34:47   File "/app/src/celery/celery.py", line 29, in generate_text_task
2023-09-01 18:34:47     time, memory, outputs = generate_output(
2023-09-01 18:34:47   File "/app/src/utils/utils.py", line 36, in wrapper
2023-09-01 18:34:47     result, exec_time = func(*args, **kwargs)
2023-09-01 18:34:47   File "/app/src/utils/utils.py", line 16, in wrapper
2023-09-01 18:34:47     result = func(*args, **kwargs)
2023-09-01 18:34:47   File "/app/src/utils/utils.py", line 53, in generate_output
2023-09-01 18:34:47     outputs = model.generate(input_ids, max_length=500)
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
2023-09-01 18:34:47     return func(*args, **kwargs)
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/transformers/generation/utils.py", line 1572, in generate
2023-09-01 18:34:47     return self.sample(
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/transformers/generation/utils.py", line 2619, in sample
2023-09-01 18:34:47     outputs = self(
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
2023-09-01 18:34:47     return forward_call(*args, **kwargs)
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
2023-09-01 18:34:47     output = old_forward(*args, **kwargs)
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 688, in forward
2023-09-01 18:34:47     outputs = self.model(
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
2023-09-01 18:34:47     return forward_call(*args, **kwargs)
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
2023-09-01 18:34:47     output = old_forward(*args, **kwargs)
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 578, in forward
2023-09-01 18:34:47     layer_outputs = decoder_layer(
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
2023-09-01 18:34:47     return forward_call(*args, **kwargs)
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
2023-09-01 18:34:47     output = old_forward(*args, **kwargs)
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 292, in forward
2023-09-01 18:34:47     hidden_states, self_attn_weights, present_key_value = self.self_attn(
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
2023-09-01 18:34:47     return forward_call(*args, **kwargs)
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
2023-09-01 18:34:47     output = old_forward(*args, **kwargs)
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 194, in forward
2023-09-01 18:34:47     query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
2023-09-01 18:34:47     return forward_call(*args, **kwargs)
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
2023-09-01 18:34:47     output = old_forward(*args, **kwargs)
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/bitsandbytes/nn/modules.py", line 248, in forward
2023-09-01 18:34:47     out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
2023-09-01 18:34:47   File "/usr/local/lib/python3.9/site-packages/bitsandbytes/autograd/_functions.py", line 567, in matmul_4bit
2023-09-01 18:34:47     assert quant_state is not None
2023-09-01 18:34:47 AssertionError

@SunMarc
Member

SunMarc commented Sep 22, 2023

Hi @kevinknights29, I see that in your script you are trying to load in 8-bit and in 4-bit at the same time. Please select only one option.

return AutoModelForCausalLM.from_pretrained(
            self.model_path,
            config=self.config,
            trust_remote_code=True,
            # either remove the load_in_8bit arg
            load_in_8bit=True,
            device_map="cpu",
            torch_dtype=torch.float16,
            # or remove quantization_config
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                llm_int8_enable_fp32_cpu_offload=True,
            ),
        )
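For reference, a 4-bit-only variant of that call (a sketch, not a verified drop-in: it keeps the quantization_config, drops load_in_8bit, and uses device_map="auto" as in the script below and in the later comments):

        # Sketch: a single quantization option (4-bit); device placement left to accelerate.
        return AutoModelForCausalLM.from_pretrained(
            self.model_path,
            config=self.config,
            trust_remote_code=True,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            ),
            use_auth_token=os.getenv("HUGGINGFACE_TOKEN"),
        )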

On my side, I was able to execute the following script with:

  • transformers version: 4.34.0.dev0 (main branch)
  • accelerate version: 0.23
  • bitsandbytes version: 0.41.1
import torch
from transformers import Blip2ForConditionalGeneration, Blip2Processor, BitsAndBytesConfig
from PIL import Image
import requests

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-6.7b-coco")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-6.7b-coco", device_map='auto', quantization_config=nf4_config)

def prepare_img():
    url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    return image

image = prepare_img()
inputs = processor(images=[image, image], return_tensors="pt").to(dtype=torch.float16)

predictions = model.generate(**inputs, num_beams=2)
print(processor.batch_decode(predictions, skip_special_tokens=True)[0].strip())
# print -> a woman sitting on the beach with her dog

@robinsonmhj

I am getting a similar error while using the Llama 2 7B model, and I am using the latest version of transformers.

Here is the code:

import torch
from transformers import AutoTokenizer, set_seed, BitsAndBytesConfig, AutoModelForCausalLM

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model_name = 'llm-models/Llama-2-7b-hf'
model = AutoModelForCausalLM.from_pretrained(
    model_name,  
    quantization_config=bnb_config,
    device_map="cuda",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, device_map='cuda')
def generate_text(prompt):
    # Tokenize the prompt
    inputs = tokenizer.encode(prompt, return_tensors='pt')
    
    print(f'inputs is {inputs} on {inputs.device}')
    
    inputs = inputs.to('cuda:0')
    
    print(f'inputs is {inputs} on {inputs.device}')
    
    # Generate a response
    outputs = model.generate(inputs)
    
    # Decode the response
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    return response

prompt = 'User1: Hey, I need a new laptop. Which one should I buy?'
response = generate_text(prompt)
print(response)
Package info:
transformers==4.38.1
accelerate==0.21.0
bitsandbytes==0.42.0

I also tried 4.34; it doesn't work either. Besides that, I checked #23895, and it doesn't look like it is in any release branch or the master branch.

Here is the error I get:

FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.
AssertionError Traceback (most recent call last)
Cell In[18], line 2
1 prompt = 'User1: Hey, I need a new laptop. Which one should I buy?'
----> 2 response = generate_text(prompt)
3 print(response)

Cell In[17], line 13, in generate_text(prompt)
10 print(f'inputs is {inputs} on {inputs.device}')
12 # Generate a response
---> 13 outputs = model.generate(inputs)
15 # Decode the response
16 response = tokenizer.decode(outputs[0], skip_special_tokens=True)

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator..decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/transformers/generation/utils.py:1345, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
1337 logger.warning(
1338 "A decoder-only architecture is being used, but right-padding was detected! For correct "
1339 "generation results, please set padding_side='left' when initializing the tokenizer."
1340 )
1342 if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
1343 # if model is encoder decoder encoder_outputs are created
1344 # and added to model_kwargs
-> 1345 model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
1346 inputs_tensor, model_kwargs, model_input_name
1347 )
1349 # 5. Prepare input_ids which will be used for auto-regressive generation
1350 if self.config.is_encoder_decoder:

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/transformers/generation/utils.py:644, in GenerationMixin._prepare_encoder_decoder_kwargs_for_generation(self, inputs_tensor, model_kwargs, model_input_name)
642 encoder_kwargs["return_dict"] = True
643 encoder_kwargs[model_input_name] = inputs_tensor
--> 644 model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
646 return model_kwargs

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module..new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:1094, in T5Stack.forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
1081 layer_outputs = checkpoint(
1082 create_custom_forward(layer_module),
1083 hidden_states,
(...)
1091 None, # past_key_value is always None with gradient checkpointing
1092 )
1093 else:
-> 1094 layer_outputs = layer_module(
1095 hidden_states,
1096 attention_mask=extended_attention_mask,
1097 position_bias=position_bias,
1098 encoder_hidden_states=encoder_hidden_states,
1099 encoder_attention_mask=encoder_extended_attention_mask,
1100 encoder_decoder_position_bias=encoder_decoder_position_bias,
1101 layer_head_mask=layer_head_mask,
1102 cross_attn_layer_head_mask=cross_attn_layer_head_mask,
1103 past_key_value=past_key_value,
1104 use_cache=use_cache,
1105 output_attentions=output_attentions,
1106 )
1108 # layer_outputs is a tuple with:
1109 # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
1110 if use_cache is False:

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module..new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:694, in T5Block.forward(self, hidden_states, attention_mask, position_bias, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias, layer_head_mask, cross_attn_layer_head_mask, past_key_value, use_cache, output_attentions, return_dict)
691 else:
692 self_attn_past_key_value, cross_attn_past_key_value = None, None
--> 694 self_attention_outputs = self.layer[0](
695 hidden_states,
696 attention_mask=attention_mask,
697 position_bias=position_bias,
698 layer_head_mask=layer_head_mask,
699 past_key_value=self_attn_past_key_value,
700 use_cache=use_cache,
701 output_attentions=output_attentions,
702 )
703 hidden_states, present_key_value_state = self_attention_outputs[:2]
704 attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module..new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:601, in T5LayerSelfAttention.forward(self, hidden_states, attention_mask, position_bias, layer_head_mask, past_key_value, use_cache, output_attentions)
590 def forward(
591 self,
592 hidden_states,
(...)
598 output_attentions=False,
599 ):
600 normed_hidden_states = self.layer_norm(hidden_states)
--> 601 attention_output = self.SelfAttention(
602 normed_hidden_states,
603 mask=attention_mask,
604 position_bias=position_bias,
605 layer_head_mask=layer_head_mask,
606 past_key_value=past_key_value,
607 use_cache=use_cache,
608 output_attentions=output_attentions,
609 )
610 hidden_states = hidden_states + self.dropout(attention_output[0])
611 outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module..new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:520, in T5Attention.forward(self, hidden_states, mask, key_value_states, position_bias, past_key_value, layer_head_mask, query_length, use_cache, output_attentions)
517 return hidden_states
519 # get query states
--> 520 query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
522 # get key/value states
523 key_states = project(
524 hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
525 )

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module..new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/bitsandbytes/nn/modules.py:256, in Linear4bit.forward(self, x)
253 x = x.to(self.compute_dtype)
255 bias = None if self.bias is None else self.bias.to(self.compute_dtype)
--> 256 out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
258 out = out.to(inp_dtype)
260 return out

File /opt/conda/envs/domino-ray/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:566, in matmul_4bit(A, B, quant_state, out, bias)
565 def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None):
--> 566 assert quant_state is not None
567 if A.numel() == A.shape[-1] and A.requires_grad == False:
568 if A.shape[-1] % quant_state.blocksize != 0:

AssertionError:

@robinsonmhj

robinsonmhj commented Feb 22, 2024

Found the issue: changing device_map to 'auto' fixes it. Can anyone explain why?
Instead of

model = AutoModelForCausalLM.from_pretrained(
    model_name,  
    quantization_config=bnb_config,
    device_map="cuda",
    trust_remote_code=True,
)

it should be

model = AutoModelForCausalLM.from_pretrained(
    model_name,  
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

The transformers and accelerate versions are as below:
transformers==4.31.0
accelerate==0.21.0

Development

Successfully merging this pull request may close these issues.

4bit Blip2 compatibility
6 participants