
ColPaliForRetrieval errors out when loaded in half precision dtypes #40875

@merveenoyan

Description


System Info

transformers version: 4.56.1

Here's the error. It can be worked around by setting dtype to float32; float16 and bfloat16 both fail.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipython-input-2540780684.py in <cell line: 0>()
      2     image_inputs = processor(images=images)
      3     image_inputs = image_inputs.to(model.device, model.dtype)
----> 4     image_outputs = model(**image_inputs)
      5     image_embeddings_torch = image_outputs.embeddings

... 23 frames omitted ...
/usr/local/lib/python3.12/dist-packages/transformers/integrations/sdpa_attention.py in sdpa_attention_forward(module, query, key, value, attention_mask, dropout, scaling, is_causal, **kwargs)
     81         is_causal = is_causal.item()
     82 
---> 83     attn_output = torch.nn.functional.scaled_dot_product_attention(
     84         query,
     85         key,

RuntimeError: Expected attn_mask dtype to be bool or float or to match query dtype, but got attn_mask.dtype: c10::Half and  query.dtype: float instead.
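
For reference, the only configuration that currently runs is full precision. A minimal sketch of that workaround (same checkpoint as in the reproduction below, with only the dtype swapped):

from transformers import ColPaliForRetrieval, infer_device
import torch

device = infer_device()

# Workaround: load in full precision; float16/bfloat16 trigger the error above.
model = ColPaliForRetrieval.from_pretrained(
    "vidore/colpali-v1.3-hf",
    dtype=torch.float32,
).to(device)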

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import ColPaliForRetrieval, ColPaliProcessor, infer_device
from PIL import Image  # assumed: `images` below is a list of PIL images
import torch

device = infer_device()
model = ColPaliForRetrieval.from_pretrained(
    "vidore/colpali-v1.3-hf",
    dtype=torch.float16,  # can also be bfloat16; both fail
).to(device)

processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.3-hf")

# Any PIL image reproduces the error; a blank one is used here for illustration.
images = [Image.new("RGB", (448, 448))]

with torch.no_grad():
    image_inputs = processor(images=images)
    image_inputs = image_inputs.to(model.device, model.dtype)
    image_outputs = model(**image_inputs)  # RuntimeError raised here
    image_embeddings_torch = image_outputs.embeddings

Expected behavior

I'd expect float16 and bfloat16 to work.
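
For what it's worth, the traceback suggests the floating-point attention mask was cast to half precision while the query stayed in float32 somewhere inside the model. A hypothetical guard in the spirit of sdpa_attention_forward (a sketch of one possible fix, not the actual library code) would cast a floating-point mask to the query dtype before calling SDPA:

import torch

def _match_mask_dtype(attention_mask, query):
    # scaled_dot_product_attention accepts bool masks as-is, but a
    # floating-point mask must match the query dtype (see the error above).
    if attention_mask is not None and attention_mask.dtype != torch.bool:
        attention_mask = attention_mask.to(query.dtype)
    return attention_mask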
