Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/auto_docstring.md
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ The `@auto_docstring` decorator automatically generates docstrings by:

8. Unrolling kwargs typed with the unpack operator. For specific methods (defined in `UNROLL_KWARGS_METHODS`) or classes (defined in `UNROLL_KWARGS_CLASSES`), the decorator processes `**kwargs` parameters that are typed with `Unpack[KwargsTypedDict]`. It extracts the documentation from the `TypedDict` and adds each parameter to the function's docstring.

Currently only supported for [`FastImageProcessorKwargs`].
Currently only supported for [`ImagesKwargs`].

## Best practices

Expand Down
9 changes: 6 additions & 3 deletions src/transformers/image_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

from .image_processing_base import BatchFeature, ImageProcessingMixin
from .image_transforms import center_crop, normalize, rescale
from .image_utils import ChannelDimension, get_image_size
from .image_utils import ChannelDimension, ImageInput, get_image_size
from .processing_utils import ImagesKwargs, Unpack
from .utils import logging
from .utils.import_utils import requires

Expand All @@ -36,6 +37,8 @@

@requires(backends=("vision",))
class BaseImageProcessor(ImageProcessingMixin):
valid_kwargs = ImagesKwargs

def __init__(self, **kwargs):
super().__init__(**kwargs)

Expand All @@ -46,9 +49,9 @@ def is_fast(self) -> bool:
"""
return False

def __call__(self, images, **kwargs) -> BatchFeature:
def __call__(self, images: ImageInput, *args, **kwargs: Unpack[ImagesKwargs]) -> BatchFeature:
"""Preprocess an image or a batch of images."""
return self.preprocess(images, **kwargs)
return self.preprocess(images, *args, **kwargs)

def preprocess(self, images, **kwargs) -> BatchFeature:
raise NotImplementedError("Each image processor must implement its own preprocess method")
Expand Down
37 changes: 6 additions & 31 deletions src/transformers/image_processing_utils_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from collections.abc import Iterable
from copy import deepcopy
from functools import lru_cache, partial
from typing import Any, Optional, TypedDict, Union
from typing import Any, Optional, Union

import numpy as np

Expand All @@ -40,7 +40,7 @@
validate_kwargs,
validate_preprocess_arguments,
)
from .processing_utils import Unpack
from .processing_utils import ImagesKwargs, Unpack
from .utils import (
TensorType,
auto_docstring,
Expand Down Expand Up @@ -163,28 +163,6 @@ def divide_to_patches(
return patches


class DefaultFastImageProcessorKwargs(TypedDict, total=False):
do_resize: Optional[bool]
size: Optional[dict[str, int]]
default_to_square: Optional[bool]
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]]
do_center_crop: Optional[bool]
crop_size: Optional[dict[str, int]]
do_rescale: Optional[bool]
rescale_factor: Optional[Union[int, float]]
do_normalize: Optional[bool]
image_mean: Optional[Union[float, list[float]]]
image_std: Optional[Union[float, list[float]]]
do_pad: Optional[bool]
pad_size: Optional[dict[str, int]]
do_convert_rgb: Optional[bool]
return_tensors: Optional[Union[str, TensorType]]
data_format: Optional[ChannelDimension]
input_data_format: Optional[Union[str, ChannelDimension]]
device: Optional["torch.device"]
disable_grouping: Optional[bool]


@auto_docstring
class BaseImageProcessorFast(BaseImageProcessor):
resample = None
Expand All @@ -206,10 +184,10 @@ class BaseImageProcessorFast(BaseImageProcessor):
input_data_format = None
device = None
model_input_names = ["pixel_values"]
valid_kwargs = DefaultFastImageProcessorKwargs
valid_kwargs = ImagesKwargs
unused_kwargs = None

def __init__(self, **kwargs: Unpack[DefaultFastImageProcessorKwargs]):
def __init__(self, **kwargs: Unpack[ImagesKwargs]):
super().__init__(**kwargs)
kwargs = self.filter_out_unused_kwargs(kwargs)
size = kwargs.pop("size", self.size)
Expand Down Expand Up @@ -728,11 +706,8 @@ def _validate_preprocess_kwargs(
data_format=data_format,
)

def __call__(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
return self.preprocess(images, *args, **kwargs)

@auto_docstring
def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]) -> BatchFeature:
def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[ImagesKwargs]) -> BatchFeature:
# args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names)
# Set default kwargs from self. This ensures that if a kwarg is not provided
Expand Down Expand Up @@ -765,7 +740,7 @@ def _preprocess_image_like_inputs(
do_convert_rgb: bool,
input_data_format: ChannelDimension,
device: Optional[Union[str, "torch.device"]] = None,
**kwargs: Unpack[DefaultFastImageProcessorKwargs],
**kwargs: Unpack[ImagesKwargs],
) -> BatchFeature:
"""
Preprocess image-like inputs.
Expand Down
2 changes: 0 additions & 2 deletions src/transformers/models/aria/modular_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,8 +959,6 @@ def __call__(
self,
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]],
images: Optional[ImageInput] = None,
audio=None,
videos=None,
**kwargs: Unpack[AriaProcessorKwargs],
) -> BatchFeature:
"""
Expand Down
2 changes: 0 additions & 2 deletions src/transformers/models/aria/processing_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ def __call__(
self,
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]],
images: Optional[ImageInput] = None,
audio=None,
videos=None,
**kwargs: Unpack[AriaProcessorKwargs],
) -> BatchFeature:
"""
Expand Down
11 changes: 1 addition & 10 deletions src/transformers/models/aya_vision/processing_aya_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,11 @@

from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput, make_flat_list_of_images
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput


class AyaVisionImagesKwargs(ImagesKwargs, total=False):
crop_to_patches: Optional[bool]
min_patches: Optional[int]
max_patches: Optional[int]


class AyaVisionProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: AyaVisionImagesKwargs
_defaults = {
"text_kwargs": {
"padding_side": "left",
Expand Down Expand Up @@ -140,8 +133,6 @@ def __call__(
self,
images: Optional[ImageInput] = None,
text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
audio=None,
videos=None,
**kwargs: Unpack[AyaVisionProcessorKwargs],
) -> BatchFeature:
"""
Expand Down
13 changes: 13 additions & 0 deletions src/transformers/models/beit/image_processing_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
valid_images,
validate_preprocess_arguments,
)
from ...processing_utils import ImagesKwargs
from ...utils import (
TensorType,
filter_out_non_signature_kwargs,
Expand All @@ -54,6 +55,17 @@
logger = logging.get_logger(__name__)


class BeitImageProcessorKwargs(ImagesKwargs):
r"""
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
is used for background, and background itself is not included in all classes of a dataset (e.g.
ADE20k). The background label will be replaced by 255.
"""

do_reduce_labels: Optional[bool]


@requires(backends=("vision",))
class BeitImageProcessor(BaseImageProcessor):
r"""
Expand Down Expand Up @@ -99,6 +111,7 @@ class BeitImageProcessor(BaseImageProcessor):
"""

model_input_names = ["pixel_values"]
valid_kwargs = BeitImageProcessorKwargs

@filter_out_non_signature_kwargs(extra=INIT_SERVICE_KWARGS)
def __init__(
Expand Down
21 changes: 5 additions & 16 deletions src/transformers/models/beit/image_processing_beit_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
Expand All @@ -40,17 +39,7 @@
TensorType,
auto_docstring,
)


class BeitFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
r"""
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
is used for background, and background itself is not included in all classes of a dataset (e.g.
ADE20k). The background label will be replaced by 255.
"""

do_reduce_labels: Optional[bool]
from .image_processing_beit import BeitImageProcessorKwargs


@auto_docstring
Expand All @@ -66,9 +55,9 @@ class BeitImageProcessorFast(BaseImageProcessorFast):
do_rescale = True
do_normalize = True
do_reduce_labels = False
valid_kwargs = BeitFastImageProcessorKwargs
valid_kwargs = BeitImageProcessorKwargs

def __init__(self, **kwargs: Unpack[BeitFastImageProcessorKwargs]):
def __init__(self, **kwargs: Unpack[BeitImageProcessorKwargs]):
super().__init__(**kwargs)

def reduce_label(self, labels: list["torch.Tensor"]):
Expand All @@ -86,7 +75,7 @@ def preprocess(
self,
images: ImageInput,
segmentation_maps: Optional[ImageInput] = None,
**kwargs: Unpack[BeitFastImageProcessorKwargs],
**kwargs: Unpack[BeitImageProcessorKwargs],
) -> BatchFeature:
r"""
segmentation_maps (`ImageInput`, *optional*):
Expand All @@ -101,7 +90,7 @@ def _preprocess_image_like_inputs(
do_convert_rgb: bool,
input_data_format: ChannelDimension,
device: Optional[Union[str, "torch.device"]] = None,
**kwargs: Unpack[BeitFastImageProcessorKwargs],
**kwargs: Unpack[BeitImageProcessorKwargs],
) -> BatchFeature:
"""
Preprocess image-like inputs.
Expand Down
3 changes: 0 additions & 3 deletions src/transformers/models/blip/processing_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class BlipProcessorKwargs(ProcessingKwargs, total=False):
"return_length": False,
"verbose": True,
},
"images_kwargs": {},
}


Expand Down Expand Up @@ -67,8 +66,6 @@ def __call__(
self,
images: Optional[ImageInput] = None,
text: Optional[Union[str, list[str], TextInput, PreTokenizedInput]] = None,
audio=None,
videos=None,
**kwargs: Unpack[BlipProcessorKwargs],
) -> BatchEncoding:
"""
Expand Down
3 changes: 0 additions & 3 deletions src/transformers/models/blip_2/processing_blip_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ class Blip2ProcessorKwargs(ProcessingKwargs, total=False):
"return_length": False,
"verbose": True,
},
"images_kwargs": {},
}


Expand Down Expand Up @@ -81,8 +80,6 @@ def __call__(
self,
images: Optional[ImageInput] = None,
text: Optional[Union[str, list[str], TextInput, PreTokenizedInput]] = None,
audio=None,
videos=None,
**kwargs: Unpack[Blip2ProcessorKwargs],
) -> BatchEncoding:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
valid_images,
validate_preprocess_arguments,
)
from ...processing_utils import ImagesKwargs
from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging


Expand Down Expand Up @@ -122,6 +123,10 @@ def get_resize_output_image_size(
return new_height, new_width


class BridgeTowerImageProcessorKwargs(ImagesKwargs):
size_divisor: Optional[int]


class BridgeTowerImageProcessor(BaseImageProcessor):
r"""
Constructs a BridgeTower image processor.
Expand Down Expand Up @@ -169,6 +174,7 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
"""

model_input_names = ["pixel_values", "pixel_mask"]
valid_kwargs = BridgeTowerImageProcessorKwargs

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
BatchFeature,
DefaultFastImageProcessorKwargs,
ImageInput,
SizeDict,
TensorType,
Expand All @@ -33,6 +32,7 @@
)
from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
from ...utils import auto_docstring
from .image_processing_bridgetower import BridgeTowerImageProcessorKwargs


def make_pixel_mask(
Expand Down Expand Up @@ -85,17 +85,6 @@ def get_resize_output_image_size(
return new_height, new_width


class BridgeTowerFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
"""
Args:
size_divisor (`int`, *optional*, defaults to 32):
The size by which to make sure both the height and width can be divided. Only has an effect if `do_resize`
is set to `True`. Can be overridden by the `size_divisor` parameter in the `preprocess` method.
"""

size_divisor: Optional[int]


@auto_docstring
class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BICUBIC
Expand All @@ -110,14 +99,14 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
do_normalize = True
do_pad = True
size_divisor = 32
valid_kwargs = BridgeTowerFastImageProcessorKwargs
valid_kwargs = BridgeTowerImageProcessorKwargs
model_input_names = ["pixel_values", "pixel_mask"]

def __init__(self, **kwargs: Unpack[BridgeTowerFastImageProcessorKwargs]):
def __init__(self, **kwargs: Unpack[BridgeTowerImageProcessorKwargs]):
super().__init__(**kwargs)

@auto_docstring
def preprocess(self, images: ImageInput, **kwargs: Unpack[BridgeTowerFastImageProcessorKwargs]) -> BatchFeature:
def preprocess(self, images: ImageInput, **kwargs: Unpack[BridgeTowerImageProcessorKwargs]) -> BatchFeature:
return super().preprocess(images, **kwargs)

def resize(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,10 @@
Processor class for BridgeTower.
"""

from typing import Optional

from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin


class BridgeTowerImagesKwargs(ImagesKwargs):
size_divisor: Optional[int]
from ...processing_utils import ProcessingKwargs, ProcessorMixin


class BridgeTowerProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: BridgeTowerImagesKwargs
_defaults = {
"text_kwargs": {
"add_special_tokens": True,
Expand Down
Loading