
Conversation

zucchini-nlp
Member

@zucchini-nlp zucchini-nlp commented Sep 17, 2025

What does this PR do?

This PR refactors how TypedDicts are used across processing classes to cut down duplication and avoid mismatches. Key updates:

  • We previously had two separate “base TypedDicts” for images (one in processing, one in fast image processing). They were identical, both defining the same kwargs. Now we keep a single copy and import it where needed.
  • For models with non-standard kwargs, we often forgot to define new ModelVideosKwargs / ModelImagesKwargs, which must exist to properly merge kwargs with ModelProcessingKwargs. This PR removes that manual step: we now obtain them dynamically at runtime from the preprocessor class attributes.
  • The base slow image processor now also exposes a valid_kwargs attribute as a TypedDict. With this, slow and fast image processors share a consistent view of the available kwargs.
  • Redundant code where the overridden part is the same as the defaults from the mixin is deleted.
  • The __call__ method on slow image processors now also shows hints for kwargs.
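The runtime merge described in the second bullet can be sketched roughly like this (all class and function names below are illustrative, not the actual transformers API):

```python
from typing import Optional, TypedDict

# hypothetical kwargs TypedDict exposed by an image processor via `valid_kwargs`
class ImagesKwargs(TypedDict, total=False):
    do_resize: bool
    do_reduce_labels: Optional[bool]

class ImageProcessor:
    valid_kwargs = ImagesKwargs  # shared by slow and fast processors

# hypothetical general processing kwargs
class ProcessingKwargs(TypedDict, total=False):
    return_tensors: Optional[str]

def merged_kwarg_names(processor_kwargs, preprocessor):
    # general kwargs declared on the processing TypedDict ...
    names = set(processor_kwargs.__annotations__)
    # ... plus whatever the preprocessor declares as valid, read at runtime,
    # so no Model*Kwargs class has to be written by hand
    valid = getattr(preprocessor, "valid_kwargs", None)
    if valid is not None:
        names |= set(valid.__annotations__)
    return names

print(sorted(merged_kwarg_names(ProcessingKwargs, ImageProcessor())))
# contains both the general and the image-specific kwarg names
```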

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

"""

do_reduce_labels: Optional[bool]
BeitFastImageProcessorKwargs = BeitImageProcessorKwargs
Member Author


kept for BC as a reference, ideally should be deleted. Maybe in v5?

Member


Yes we can do it in v5, especially as I don't expect anyone importing this class! Let's add a clear comment to say it's scheduled to be deleted and is only there for BC in the meantime!

Member Author


removed!

Comment on lines +1287 to +1299
# Some preprocessors define a set of accepted "valid_kwargs" (currently only vision).
# In those cases, we don’t declare a `ModalityKwargs` attribute in the TypedDict.
# Instead, we dynamically obtain the kwargs from the preprocessor and merge them
# with the general kwargs set. This ensures consistency between preprocessor and
# processor classes, and helps prevent accidental mismatches.
modality_valid_kwargs = set(ModelProcessorKwargs.__annotations__[modality].__annotations__)
if modality in map_preprocessor_kwargs:
    preprocessor = getattr(self, map_preprocessor_kwargs[modality], None)
    preprocessor_valid_kwargs = (
        getattr(preprocessor, "valid_kwargs", None) if preprocessor is not None else None
    )
    modality_valid_kwargs.update(
        set(preprocessor_valid_kwargs.__annotations__ if preprocessor_valid_kwargs is not None else [])
    )
Member Author


This allows us to not define `images_kwargs: SiglipImagesKwargs` in the processing file. Instead, we check whether the image processor has a `valid_kwargs` attribute and obtain the kwarg names from it.
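As a rough before/after sketch (SiglipImagesKwargs used illustratively, per the comment above; the lookup shown is a simplification of the real code):

```python
from typing import TypedDict

# lives in the image processing file, exposed via `valid_kwargs`
class SiglipImagesKwargs(TypedDict, total=False):
    do_resize: bool
    do_rescale: bool

class SiglipImageProcessor:
    valid_kwargs = SiglipImagesKwargs

# before this PR, the processing file had to repeat the annotation:
#     class SiglipProcessorKwargs(ProcessingKwargs, total=False):
#         images_kwargs: SiglipImagesKwargs
# now the processor discovers the kwarg names from the preprocessor instead:
processor_image_kwargs = set(SiglipImageProcessor.valid_kwargs.__annotations__)
print(sorted(processor_image_kwargs))
```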

Comment on lines +1364 to +1368
# For `common_kwargs` just update all modality-specific kwargs with same key/values
common_kwargs = kwargs.get("common_kwargs", {})
common_kwargs.update(ModelProcessorKwargs._defaults.get("common_kwargs", {}))
if common_kwargs:
    for kwarg in output_kwargs.values():
        kwarg.update(common_kwargs)
Member Author


We don't return `out["common_kwargs"]` anymore and instead merge the common kwargs into each modality. In all models the only common kwarg is `return_tensors`, so we don't really need a separate field for it.
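A minimal sketch of that merge, with hypothetical dict contents:

```python
# hypothetical per-modality kwargs as assembled by the processor
output_kwargs = {
    "text_kwargs": {"padding": True},
    "images_kwargs": {"do_resize": True},
}
common_kwargs = {"return_tensors": "pt"}

# instead of returning a separate output_kwargs["common_kwargs"] entry,
# fold the shared values into every modality so each preprocessor sees them
for kwarg in output_kwargs.values():
    kwarg.update(common_kwargs)

print(output_kwargs["text_kwargs"]["return_tensors"])    # "pt"
print(output_kwargs["images_kwargs"]["return_tensors"])  # "pt"
```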

do_resize=True,
size=None,
do_normalize=True,
do_pad=False,
Member Author


The processor does not use the `do_pad` key.

Comment on lines 1059 to 1063
-    "ImageProcessorFast": "image_processing*_fast",  # "*" indicates where to insert the model name before the "_fast" suffix
+    "ImageProcessorFast": "image_processing.*_fast",  # "*" indicates where to insert the model name before the "_fast" suffix
     "VideoProcessor": "video_processing",
     "VideoProcessorInitKwargs": "video_processing",
-    "FastImageProcessorKwargs": "image_processing*_fast",
     "ImageProcessorKwargs": "image_processing",
+    "FastImageProcessorKwargs": "image_processing.*_fast",
Member Author


It was not working properly for fast image processors because the match pattern below does not match the file name. This fixes it.
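Assuming these patterns are used as regular expressions (which the `.*` fix suggests), the difference can be checked directly: in `image_processing*_fast` the `g*` is a zero-or-more quantifier, not a wildcard, so the pattern can never match a file name with a model name in the middle.

```python
import re

broken = "image_processing*_fast"   # `g*` means "zero or more g", not a glob wildcard
fixed = "image_processing.*_fast"   # `.*` actually matches the inserted model name

filename = "image_processing_beit_fast"

print(bool(re.match(broken, filename)))  # False: after `g*` the pattern requires "_fast" immediately
print(bool(re.match(fixed, filename)))   # True
```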

Member

@Cyrilvallez Cyrilvallez left a comment


Hey! It's very nice that we try to unify those, but overall I'm still having a hard time understanding the logic.
Why do we need valid_kwargs? If I check for example the cohere2 fast image processor, the whole Kwargs class is completely redundant with the class attributes... This makes it hard to read and understand what the classes are used for. IMO, we should either drop the Kwargs class completely and use only the class attributes, or drop the class attributes and use only the Kwargs class. But having both is a struggle IMO

Comment on lines -556 to -546
annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None,
masks_path: Optional[Union[str, pathlib.Path]] = None,
Member


Why do we lose this?

Member Author

@zucchini-nlp zucchini-nlp Sep 25, 2025


Because they are already in kwargs (via the TypedDict); keeping them would duplicate them between the explicit args and the kwargs.
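A hypothetical sketch of the deduplication (names loosely modeled on the detection processors; not the real signatures):

```python
from typing import Optional, TypedDict, Union

# hypothetical subset of kwargs that used to also be explicit parameters
class DetrImageProcessorKwargs(TypedDict, total=False):
    annotations: Optional[Union[dict, list[dict]]]
    masks_path: Optional[str]

class DetrImageProcessor:
    valid_kwargs = DetrImageProcessorKwargs

    # before: preprocess(self, images, annotations=None, masks_path=None, ...)
    # after: the explicit parameters are gone; the values arrive through
    # **kwargs and are declared once, in the TypedDict, instead of twice
    def preprocess(self, images, **kwargs):
        return kwargs.get("annotations"), kwargs.get("masks_path")

proc = DetrImageProcessor()
print(proc.preprocess([], annotations={"boxes": []}, masks_path="masks/"))
```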

@zucchini-nlp
Member Author

zucchini-nlp commented Sep 25, 2025

Why do we need valid_kwargs? If I check for example cohere2 fast image processor, all the Kwargs class is completely redundant with all the class attributes

The "valid_kwargs" field is used by the processing class. Previously we had to define XXXImagesKwargs(TypedDict) in the processor file and define the same TypedDict in the image processing file (for example, see the "Emu3Processor" diffs). Now we no longer need to explicitly declare which kwargs a processor's attributes accept: we infer it on the fly by looking at the valid_kwargs field. This is best seen in processing_utils.py, where we merge the kwargs and return them as a dict per modality.

The class attributes define defaults for the model, while the TypedDict is for type hints and to show users (and us) what can be passed as a kwarg.

You can look at it as having a config class, except that in our case the "TypedDict" is not meant to hold defaults: it is not a dataclass. For why we don't use a dataclass, see #40793
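A minimal, hypothetical sketch of that split: the class attributes hold the model's defaults, while the TypedDict only declares what can be passed:

```python
from typing import Optional, TypedDict

# the TypedDict declares *what can be passed*; it holds no values
class CohereImagesKwargs(TypedDict, total=False):
    do_resize: bool
    size: Optional[dict]

# the class attributes declare *the model's defaults*
class CohereImageProcessor:
    valid_kwargs = CohereImagesKwargs
    do_resize = True
    size = {"height": 224, "width": 224}

    def resolve(self, **kwargs):
        # start from the class-attribute defaults, then apply user overrides
        resolved = {name: getattr(self, name) for name in self.valid_kwargs.__annotations__}
        resolved.update(kwargs)
        return resolved

print(CohereImageProcessor().resolve(do_resize=False))
# {'do_resize': False, 'size': {'height': 224, 'width': 224}}
```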

@zucchini-nlp zucchini-nlp requested review from Cyrilvallez and removed request for qubvel September 25, 2025 12:25
@zucchini-nlp zucchini-nlp force-pushed the video-kwargs-typed-dict branch from d6418f9 to c0bfac7 Compare September 30, 2025 17:39
Member

@Cyrilvallez Cyrilvallez left a comment


Alright, let's merge then! I did not follow all the conversation on why a dataclass was not good, but we may think about a way to unify both in the future! Would simplify quite a bit!
This PR is still very nice though 🤗

@zucchini-nlp zucchini-nlp changed the title [unbloating] unify TypedDict usage in processing 🚨 [unbloating] unify TypedDict usage in processing Oct 3, 2025
Contributor

github-actions bot commented Oct 3, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: aria, aya_vision, beit, blip, blip_2, bridgetower, chameleon, cohere2_vision, colpali, colqwen2, conditional_detr, convnext, csm, deepseek_vl, deepseek_vl_hybrid

@zucchini-nlp zucchini-nlp merged commit 5339f72 into huggingface:main Oct 3, 2025
25 checks passed
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request Oct 4, 2025