
google/siglip-so400m-patch14-384 inference output mismatch with pipeline output #30951

Closed
aliencaocao opened this issue May 22, 2024 · 5 comments · Fixed by #31343
Labels
Examples (Which is related to examples in general), Multimodal

Comments

@aliencaocao
Contributor

System Info

  • transformers version: 4.41.0
  • Platform: Windows-10-10.0.19041-SP0
  • Python version: 3.9.13
  • Huggingface_hub version: 0.23.0
  • Safetensors version: 0.4.3
  • Accelerate version: 0.30.1
  • Accelerate config: - compute_environment: LOCAL_MACHINE
    - distributed_type: NO
    - mixed_precision: fp16
    - use_cpu: False
    - debug: False
    - num_processes: 1
    - machine_rank: 0
    - num_machines: 1
    - gpu_ids: all
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - enable_cpu_affinity: False
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - dynamo_config: {'dynamo_backend': 'INDUCTOR'}
  • PyTorch version (GPU?): 2.3.0+cu121 (True)
  • Tensorflow version (GPU?): 2.16.1 (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@amyeroberts

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Using sample code from https://huggingface.co/google/siglip-so400m-patch14-384:

from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
import torch

model = AutoModel.from_pretrained("google/siglip-so400m-patch14-384")
processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

texts = ["a photo of 2 cats", "a photo of 2 dogs"]
inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

logits_per_image = outputs.logits_per_image
probs = torch.sigmoid(logits_per_image) # these are the probabilities
print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")

The output mismatches with the pipeline approach:

from transformers import pipeline
from PIL import Image
import requests

# load pipe
image_classifier = pipeline(task="zero-shot-image-classification", model="google/siglip-so400m-patch14-384")

# load image
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# inference
outputs = image_classifier(image, candidate_labels=["2 cats", "a plane", "a remote"])
outputs = [{"score": round(output["score"], 4), "label": output["label"] } for output in outputs]
print(outputs)

and the difference is massive.
The pipeline approach seems to give the right results, which also align with the Inference API.
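
A quick way to see what text each approach actually scores (a small debugging sketch, reusing processor and inputs from the first snippet) is to decode the tokenized inputs back to strings:

# decode the token ids to see exactly which strings the model compares
decoded = processor.tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=True)
print(decoded)  # the bare texts passed to the processor, with no prompt prepended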

Expected behavior

Correct result for the manual approach

@jla524
Contributor

jla524 commented May 22, 2024

I found that the pipeline uses a prompt, which adds "This is a photo of " before every candidate label.

For example, when we pass in the label "2 cats", the pipeline converts it to "This is a photo of 2 cats".

We can modify the manual approach to match this behaviour, and we'll get the same results:

from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
import torch

model = AutoModel.from_pretrained("google/siglip-so400m-patch14-384")
processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

labels = ["2 cats", "2 dogs"]  # bare labels; the prompt is added on the next line
texts = [f"this is a photo of {label}" for label in labels]
inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

logits_per_image = outputs.logits_per_image
probs = torch.sigmoid(logits_per_image) # these are the probabilities
print(f"{probs[0][0]:.2%} that '{texts[0]}'")  # 50.89% that 'this is a photo of 2 cats

@aliencaocao
Contributor Author

Wow that requires some docs...

@amyeroberts
Collaborator

Hi @aliencaocao, thanks for raising this issue!

This is in the docs, although it isn't super obvious. If you'd like to open a PR to update the example for the pipeline to highlight this, I'd be very happy to review.

@NielsRogge Could you fix the example on the siglip page, as you'll have permissions to open and merge the PR there?

@amyeroberts added the Examples and Multimodal labels May 22, 2024
@aliencaocao
Contributor Author

> If you'd like to open a PR to update the example for the pipeline to highlight this I'd be very happy to review.

Sure, I will do it soon, thanks for pointing that out.

@aliencaocao
Contributor Author

PR made
