Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Nougat] Fix pipeline #28242

Merged
merged 7 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,8 @@
# any tokenizer/feature_extractor might be use for a given model so we cannot
# use the statically defined TOKENIZER_MAPPING and FEATURE_EXTRACTOR_MAPPING to
# see if the model defines such objects or not.
MULTI_MODEL_CONFIGS = {"SpeechEncoderDecoderConfig", "VisionEncoderDecoderConfig", "VisionTextDualEncoderConfig"}
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
MULTI_MODEL_SPEECH_CONFIGS = {"SpeechEncoderDecoderConfig"}
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
MULTI_MODEL_VISION_CONFIGS = {"VisionEncoderDecoderConfig", "VisionTextDualEncoderConfig"}
for task, values in SUPPORTED_TASKS.items():
if values["type"] == "text":
NO_FEATURE_EXTRACTOR_TASKS.add(task)
Expand Down Expand Up @@ -896,7 +897,10 @@ def pipeline(
and not load_tokenizer
and normalized_task not in NO_TOKENIZER_TASKS
# Using class name to avoid importing the real class.
and model_config.__class__.__name__ in MULTI_MODEL_CONFIGS
and (
model_config.__class__.__name__ in MULTI_MODEL_SPEECH_CONFIGS
or model_config.__class__.__name__ in MULTI_MODEL_VISION_CONFIGS
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
)
):
# This is a special category of models, that are fusions of multiple models
# so the model_config might not define a tokenizer, but it seems to be
Expand All @@ -907,8 +911,7 @@ def pipeline(
and not load_image_processor
and normalized_task not in NO_IMAGE_PROCESSOR_TASKS
# Using class name to avoid importing the real class.
and model_config.__class__.__name__ in MULTI_MODEL_CONFIGS
and normalized_task != "automatic-speech-recognition"
and model_config.__class__.__name__ in MULTI_MODEL_VISION_CONFIGS
):
# This is a special category of models, that are fusions of multiple models
# so the model_config might not define a tokenizer, but it seems to be
Expand All @@ -919,7 +922,7 @@ def pipeline(
and not load_feature_extractor
and normalized_task not in NO_FEATURE_EXTRACTOR_TASKS
# Using class name to avoid importing the real class.
and model_config.__class__.__name__ in MULTI_MODEL_CONFIGS
and model_config.__class__.__name__ in MULTI_MODEL_SPEECH_CONFIGS
):
# This is a special category of models, that are fusions of multiple models
# so the model_config might not define a tokenizer, but it seems to be
Expand Down
20 changes: 17 additions & 3 deletions tests/pipelines/test_pipelines_image_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,14 +257,16 @@ def test_large_model_tf(self):
@require_torch
def test_conditional_generation_llava(self):
pipe = pipeline("image-to-text", model="llava-hf/bakLlava-v1-hf")
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)

prompt = (
"<image>\nUSER: What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud?\nASSISTANT:"
)

outputs = pipe(image, prompt=prompt, generate_kwargs={"max_new_tokens": 200})
outputs = pipe(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
prompt=prompt,
generate_kwargs={"max_new_tokens": 200},
)
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(
outputs,
[
Expand All @@ -273,3 +275,15 @@ def test_conditional_generation_llava(self):
}
],
)

@slow
@require_torch
def test_nougat(self):
pipe = pipeline("image-to-text", "facebook/nougat-base")

outputs = pipe("https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/nougat_paper.png")

self.assertEqual(
outputs,
[{"generated_text": "# Nougat: Neural Optical Understanding for Academic Documents\n\n Lukas Blec"}],
)