From ce75466a41e0558acedf13c9de40ea5225e34694 Mon Sep 17 00:00:00 2001 From: Daniel Regado Date: Mon, 10 Feb 2025 13:56:04 +0000 Subject: [PATCH 1/9] More robust from_pretrained init_kwargs type checking --- src/diffusers/pipelines/pipeline_utils.py | 99 +++++++++++++++++------ 1 file changed, 75 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 2fde0bb9f861..6851f780d60a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import enum import fnmatch import importlib import inspect @@ -22,7 +21,7 @@ import sys from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin import numpy as np import PIL.Image @@ -864,26 +863,6 @@ def load_module(name, value): init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} - for key in init_dict.keys(): - if key not in passed_class_obj: - continue - if "scheduler" in key: - continue - - class_obj = passed_class_obj[key] - _expected_class_types = [] - for expected_type in expected_types[key]: - if isinstance(expected_type, enum.EnumMeta): - _expected_class_types.extend(expected_type.__members__.keys()) - else: - _expected_class_types.append(expected_type.__name__) - - _is_valid_type = class_obj.__class__.__name__ in _expected_class_types - if not _is_valid_type: - logger.warning( - f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}." - ) - # Special case: safety_checker must be loaded separately when using `from_flax` if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: raise NotImplementedError( @@ -1003,10 +982,82 @@ def load_module(name, value): f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." ) - # 10. Instantiate the pipeline + # 10. Type checking init arguments + def is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool: + if not isinstance(class_or_tuple, tuple): + class_or_tuple = (class_or_tuple,) + + # Unpack unions + unpacked_class_or_tuple = [] + for t in class_or_tuple: + if get_origin(t) is Union: + unpacked_class_or_tuple.extend(get_args(t)) + else: + unpacked_class_or_tuple.append(t) + class_or_tuple = tuple(unpacked_class_or_tuple) + + if Any in class_or_tuple: + return True + + obj_type = type(obj) + # Classes with obj's type + class_or_tuple = {t for t in class_or_tuple if (get_origin(t) or t) is obj_type} + + # Singular types (e.g. int, ControlNet, ...) + # Untyped collections (e.g. List, but not List[int]) + elem_class_or_tuple = {get_args(t) for t in class_or_tuple} + if () in elem_class_or_tuple: + return True + # Typed lists or sets + elif obj_type in (list, set): + return any(all(is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple) + # Typed tuples + elif obj_type is tuple: + return any( + # Tuples with any length and single type (e.g. Tuple[int, ...]) + (len(t) == 2 and t[-1] is Ellipsis and all(is_valid_type(x, t[0]) for x in obj)) + or + # Tuples with fixed length and any types (e.g. Tuple[int, str]) + (len(obj) == len(t) and all(is_valid_type(x, tt) for x, tt in zip(obj, t))) + for t in elem_class_or_tuple + ) + # Typed dicts + elif obj_type is dict: + return any( + all(is_valid_type(k, kt) and is_valid_type(v, vt) for k, v in obj.items()) + for kt, vt in elem_class_or_tuple + ) + + else: + return False + + def get_detailed_type(obj: Any) -> Type: + obj_type = type(obj) + + if obj_type in (list, set): + obj_origin_type = List if obj_type is list else Set + elems_type = Union[*{get_detailed_type(x) for x in obj}] + return obj_origin_type[elems_type] + elif obj_type is tuple: + return Tuple[tuple(get_detailed_type(x) for x in obj)] + elif obj_type is dict: + keys_type = Union[*{get_detailed_type(k) for k in obj.keys()}] + values_type = Union[*{get_detailed_type(k) for k in obj.values()}] + return Dict[keys_type, values_type] + else: + return obj_type + + for key, class_obj in init_kwargs.items(): + if "scheduler" in key: + continue + + if class_obj is not None and not is_valid_type(class_obj, expected_types[key]): + logger.warning(f"Expected types for {key}: {expected_types[key]}, got {get_detailed_type(class_obj)}.") + + # 11. Instantiate the pipeline model = pipeline_class(**init_kwargs) - # 11. Save where the model was instantiated from + # 12. Save where the model was instantiated from model.register_to_config(_name_or_path=pretrained_model_name_or_path) if device_map is not None: setattr(model, "hf_device_map", final_device_map) From b1f26c53477b16a33190e7171d8729ca99817163 Mon Sep 17 00:00:00 2001 From: Daniel Regado Date: Mon, 10 Feb 2025 14:38:52 +0000 Subject: [PATCH 2/9] Corrected for Python 3.10 --- src/diffusers/pipelines/pipeline_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 6851f780d60a..0eae1ac75ba4 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1036,13 +1036,13 @@ def get_detailed_type(obj: Any) -> Type: if obj_type in (list, set): obj_origin_type = List if obj_type is list else Set - elems_type = Union[*{get_detailed_type(x) for x in obj}] + elems_type = Union[tuple({get_detailed_type(x) for x in obj})] return obj_origin_type[elems_type] elif obj_type is tuple: return Tuple[tuple(get_detailed_type(x) for x in obj)] elif obj_type is dict: - keys_type = Union[*{get_detailed_type(k) for k in obj.keys()}] - values_type = Union[*{get_detailed_type(k) for k in obj.values()}] + keys_type = Union[tuple({get_detailed_type(k) for k in obj.keys()})] + values_type = Union[tuple({get_detailed_type(k) for k in obj.values()})] return Dict[keys_type, values_type] else: return obj_type From 5ca27aaf09f0ad9771692d81125869cbdd6271f6 Mon Sep 17 00:00:00 2001 From: Daniel Regado Date: Wed, 12 Feb 2025 11:13:06 +0000 Subject: [PATCH 3/9] Type checks subclasses and fixed type warnings --- .../pipeline_animatediff_video2video.py | 2 +- ...line_animatediff_video2video_controlnet.py | 2 +- .../pipeline_hunyuandit_controlnet.py | 4 ++-- .../pipeline_dance_diffusion.py | 4 +++- src/diffusers/pipelines/ddim/pipeline_ddim.py | 3 ++- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 4 +++- .../deprecated/repaint/pipeline_repaint.py | 2 +- .../hunyuandit/pipeline_hunyuandit.py | 4 ++-- .../pipelines/lumina/pipeline_lumina.py | 12 +++++------ .../pipelines/pag/pipeline_pag_sana.py | 6 +++--- src/diffusers/pipelines/pipeline_utils.py | 2 +- src/diffusers/pipelines/sana/pipeline_sana.py | 6 +++---
.../stable_cascade/pipeline_stable_cascade.py | 6 +++--- .../pipeline_stable_cascade_combined.py | 20 +++++++++++-------- .../pipeline_stable_unclip.py | 2 +- tests/fixtures/custom_pipeline/pipeline.py | 4 ++-- tests/fixtures/custom_pipeline/what_ever.py | 3 ++- 17 files changed, 48 insertions(+), 38 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index edac6bfd9e4e..59a473e32ae1 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -224,7 +224,7 @@ def __init__( vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, + unet: Union[UNet2DConditionModel, UNetMotionModel], motion_adapter: MotionAdapter, scheduler: Union[ DDIMScheduler, diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py index 1a75d658b3ad..fd4d5346f7c1 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py @@ -246,7 +246,7 @@ def __init__( vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, + unet: Union[UNet2DConditionModel, UNetMotionModel], motion_adapter: MotionAdapter, controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: Union[ diff --git a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py index f01c8cc4674d..5ee712b5f116 100644 --- a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +++ b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py @@ -232,8 +232,8 @@ def __init__( Tuple[HunyuanDiT2DControlNetModel], HunyuanDiT2DMultiControlNetModel, ], - text_encoder_2=T5EncoderModel, - tokenizer_2=MT5Tokenizer, + text_encoder_2: Optional[T5EncoderModel] = None, + tokenizer_2: Optional[MT5Tokenizer] = None, requires_safety_checker: bool = True, ): super().__init__() diff --git a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py index ed342f66804a..34b2a3945572 100644 --- a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +++ b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py @@ -17,6 +17,8 @@ import torch +from ...models import UNet1DModel +from ...schedulers import SchedulerMixin from ...utils import is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline @@ -49,7 +51,7 @@ class DanceDiffusionPipeline(DiffusionPipeline): model_cpu_offload_seq = "unet" - def __init__(self, unet, scheduler): + def __init__(self, unet: UNet1DModel, scheduler: SchedulerMixin): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index 1b424f5742f2..1fd8ce4e6570 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -16,6 +16,7 @@ import torch +from ...models import UNet2DModel from ...schedulers import DDIMScheduler from ...utils import is_torch_xla_available from ...utils.torch_utils import randn_tensor @@ -47,7 +48,7 @@ class DDIMPipeline(DiffusionPipeline): model_cpu_offload_seq = "unet" - def __init__(self, unet, scheduler): + def __init__(self, unet: UNet2DModel, scheduler: DDIMScheduler): super().__init__() # make sure scheduler can always be converted to DDIM diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index e58a53b5b7e8..1c5ac4baeae0 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -17,6 +17,8 @@ import torch +from ...models import UNet2DModel +from ...schedulers import DDPMScheduler from ...utils import is_torch_xla_available from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -47,7 +49,7 @@ class DDPMPipeline(DiffusionPipeline): model_cpu_offload_seq = "unet" - def __init__(self, unet, scheduler): + def __init__(self, unet: UNet2DModel, scheduler: DDPMScheduler): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) diff --git a/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py b/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py index 101d315dfe59..843528a532f1 100644 --- a/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +++ b/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py @@ -91,7 +91,7 @@ class RePaintPipeline(DiffusionPipeline): scheduler: RePaintScheduler model_cpu_offload_seq = "unet" - def __init__(self, unet, scheduler): + def __init__(self, unet: UNet2DModel, scheduler: RePaintScheduler): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py index 6a5cf298d2d4..febf2b0392cc 100644 --- a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py @@ -207,8 +207,8 @@ def __init__( safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, - text_encoder_2=T5EncoderModel, - tokenizer_2=MT5Tokenizer, + text_encoder_2: Optional[T5EncoderModel] = None, + tokenizer_2: Optional[MT5Tokenizer] = None, ): super().__init__() diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 133cb2c5f146..5128e20eaacd 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -20,7 +20,7 @@ from typing import List, Optional, Tuple, Union import torch -from transformers import AutoModel, AutoTokenizer +from transformers import PreTrainedModel, PreTrainedTokenizerBase from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL @@ -143,13 +143,13 @@ class LuminaText2ImgPipeline(DiffusionPipeline): Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`AutoModel`]): + text_encoder ([`PreTrainedModel`]): Frozen text-encoder. Lumina-T2I uses [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the [t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
- tokenizer (`AutoModel`): + tokenizer (`AutoTokenizer`): Tokenizer of class - [AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel). + [AutoTokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel). transformer ([`Transformer2DModel`]): A text conditioned `Transformer2DModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): @@ -180,8 +180,8 @@ def __init__( transformer: LuminaNextDiT2DModel, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, - text_encoder: AutoModel, - tokenizer: AutoTokenizer, + text_encoder: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, ): super().__init__() diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index 416b2f7c60f2..a64481f2f46a 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -20,7 +20,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import PreTrainedModel, PreTrainedTokenizerBase from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PixArtImageProcessor @@ -160,8 +160,8 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin): def __init__( self, - tokenizer: AutoTokenizer, - text_encoder: AutoModelForCausalLM, + tokenizer: PreTrainedTokenizerBase, + text_encoder: PreTrainedModel, vae: AutoencoderDC, transformer: SanaTransformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 0eae1ac75ba4..83fee4d419c0 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1001,7 +1001,7 @@ def is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bo obj_type = type(obj) # Classes with obj's type - class_or_tuple = {t for t in class_or_tuple if (get_origin(t) or t) is obj_type} + class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)} # Singular types (e.g. int, ControlNet, ...) # Untyped collections (e.g. List, but not List[int]) diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index cca4dfe5e8ba..8dcbef5f99cf 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -20,7 +20,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import PreTrainedModel, PreTrainedTokenizerBase from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PixArtImageProcessor @@ -200,8 +200,8 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin): def __init__( self, - tokenizer: AutoTokenizer, - text_encoder: AutoModelForCausalLM, + tokenizer: PreTrainedTokenizerBase, + text_encoder: PreTrainedModel, vae: AutoencoderDC, transformer: SanaTransformer2DModel, scheduler: DPMSolverMultistepScheduler, diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py index e3b9ec44005a..38f1c4314e4f 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py @@ -15,7 +15,7 @@ from typing import Callable, Dict, List, Optional, Union import torch -from transformers import CLIPTextModel, CLIPTokenizer +from transformers import CLIPTextModelWithProjection, CLIPTokenizer from ...models import StableCascadeUNet from ...schedulers import DDPMWuerstchenScheduler @@ -65,7 +65,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline): Args: tokenizer (`CLIPTokenizer`): The CLIP tokenizer. - text_encoder (`CLIPTextModel`): + text_encoder (`CLIPTextModelWithProjection`): The CLIP text encoder. decoder ([`StableCascadeUNet`]): The Stable Cascade decoder unet. @@ -93,7 +93,7 @@ def __init__( self, decoder: StableCascadeUNet, tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, + text_encoder: CLIPTextModelWithProjection, scheduler: DDPMWuerstchenScheduler, vqgan: PaellaVQModel, latent_dim_scale: float = 10.67, diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py index 6724b60cc424..28a74ab83733 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py @@ -15,7 +15,7 @@ import PIL import torch -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection from ...models import StableCascadeUNet from ...schedulers import DDPMWuerstchenScheduler @@ -52,7 +52,7 @@ class StableCascadeCombinedPipeline(DiffusionPipeline): Args: tokenizer (`CLIPTokenizer`): The decoder tokenizer to be used for text inputs. - text_encoder (`CLIPTextModel`): + text_encoder (`CLIPTextModelWithProjection`): The decoder text encoder to be used for text inputs. decoder (`StableCascadeUNet`): The decoder model to be used for decoder image generation pipeline. @@ -60,14 +60,18 @@ class StableCascadeCombinedPipeline(DiffusionPipeline): The scheduler to be used for decoder image generation pipeline. vqgan (`PaellaVQModel`): The VQGAN model to be used for decoder image generation pipeline.
- feature_extractor ([`~transformers.CLIPImageProcessor`]): - Model that extracts features from generated images to be used as inputs for the `image_encoder`. - image_encoder ([`CLIPVisionModelWithProjection`]): - Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). prior_prior (`StableCascadeUNet`): The prior model to be used for prior pipeline. + prior_text_encoder (`CLIPTextModelWithProjection`): + The prior text encoder to be used for text inputs. + prior_tokenizer (`CLIPTokenizer`): + The prior tokenizer to be used for text inputs. prior_scheduler (`DDPMWuerstchenScheduler`): The scheduler to be used for prior pipeline. + prior_feature_extractor ([`~transformers.CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `image_encoder`. + prior_image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). """ _load_connected_pipes = True @@ -76,12 +80,12 @@ class StableCascadeCombinedPipeline(DiffusionPipeline): def __init__( self, tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, + text_encoder: CLIPTextModelWithProjection, decoder: StableCascadeUNet, scheduler: DDPMWuerstchenScheduler, vqgan: PaellaVQModel, prior_prior: StableCascadeUNet, - prior_text_encoder: CLIPTextModel, + prior_text_encoder: CLIPTextModelWithProjection, prior_tokenizer: CLIPTokenizer, prior_scheduler: DDPMWuerstchenScheduler, prior_feature_extractor: Optional[CLIPImageProcessor] = None, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 07d82251d4ba..be01e0acbf18 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -141,7 +141,7 @@ def __init__( image_noising_scheduler: KarrasDiffusionSchedulers, # regular denoising components tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModelWithProjection, + text_encoder: CLIPTextModel, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, # vae diff --git a/tests/fixtures/custom_pipeline/pipeline.py b/tests/fixtures/custom_pipeline/pipeline.py index 601f51b1263e..e197cb6859fa 100644 --- a/tests/fixtures/custom_pipeline/pipeline.py +++ b/tests/fixtures/custom_pipeline/pipeline.py @@ -18,7 +18,7 @@ import torch -from diffusers import DiffusionPipeline, ImagePipelineOutput +from diffusers import DiffusionPipeline, ImagePipelineOutput, SchedulerMixin, UNet2DModel class CustomLocalPipeline(DiffusionPipeline): @@ -33,7 +33,7 @@ class CustomLocalPipeline(DiffusionPipeline): [`DDPMScheduler`], or [`DDIMScheduler`]. """ - def __init__(self, unet, scheduler): + def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) diff --git a/tests/fixtures/custom_pipeline/what_ever.py b/tests/fixtures/custom_pipeline/what_ever.py index 8ceeb4211e37..bbe7f4f16bd8 100644 --- a/tests/fixtures/custom_pipeline/what_ever.py +++ b/tests/fixtures/custom_pipeline/what_ever.py @@ -18,6 +18,7 @@ import torch +from diffusers import SchedulerMixin, UNet2DModel from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -33,7 +34,7 @@ class CustomLocalPipeline(DiffusionPipeline): [`DDPMScheduler`], or [`DDIMScheduler`].
""" - def __init__(self, unet, scheduler): + def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) From 12eb38faf49fd5bf6df1bad3b95ef23811067c1e Mon Sep 17 00:00:00 2001 From: Daniel Regado Date: Thu, 13 Feb 2025 10:15:19 +0000 Subject: [PATCH 4/9] More type corrections and skip tokenizer type checking --- .../pipeline_stable_diffusion_3_controlnet.py | 12 +++--- ...table_diffusion_3_controlnet_inpainting.py | 8 ++-- .../pipelines/lumina2/pipeline_lumina2.py | 10 ++--- .../pipelines/pag/pipeline_pag_sana.py | 6 +-- src/diffusers/pipelines/pipeline_utils.py | 18 ++++++--- src/diffusers/pipelines/sana/pipeline_sana.py | 6 +-- .../pipeline_stable_diffusion_3.py | 12 +++--- .../pipeline_stable_diffusion_3_img2img.py | 12 ++++-- .../pipeline_stable_diffusion_3_inpaint.py | 12 +++--- .../pipeline_stable_diffusion_k_diffusion.py | 38 +++++++++++++------ 10 files changed, 81 insertions(+), 53 deletions(-) diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py index 7f85fcc1d90d..c353d0a8f2f7 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py @@ -17,10 +17,10 @@ import torch from transformers import ( - BaseImageProcessor, + SiglipImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, - PreTrainedModel, + SiglipVisionModel, T5EncoderModel, T5TokenizerFast, ) @@ -178,9 +178,9 @@ class StableDiffusion3ControlNetPipeline( Provides additional conditioning to the `unet` during the denoising process. If you set multiple ControlNets as a list, the outputs from each ControlNet are added together to create one combined additional conditioning. - image_encoder (`PreTrainedModel`, *optional*): + image_encoder (`SiglipVisionModel`, *optional*): Pre-trained Vision Model for IP Adapter. - feature_extractor (`BaseImageProcessor`, *optional*): + feature_extractor (`SiglipImageProcessor`, *optional*): Image processor for IP Adapter. 
""" @@ -202,8 +202,8 @@ def __init__( controlnet: Union[ SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel ], - image_encoder: PreTrainedModel = None, - feature_extractor: BaseImageProcessor = None, + image_encoder: Optional[SiglipVisionModel] = None, + feature_extractor: Optional[SiglipImageProcessor] = None, ): super().__init__() if isinstance(controlnet, (list, tuple)): diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py index 35e47f4d650e..fbfac8f63ad3 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py @@ -17,10 +17,10 @@ import torch from transformers import ( - BaseImageProcessor, + SiglipImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, - PreTrainedModel, + SiglipModel, T5EncoderModel, T5TokenizerFast, ) @@ -223,8 +223,8 @@ def __init__( controlnet: Union[ SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel ], - image_encoder: PreTrainedModel = None, - feature_extractor: BaseImageProcessor = None, + image_encoder: SiglipModel = None, + feature_extractor: Optional[SiglipImageProcessor] = None, ): super().__init__() diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index 801ed25093a3..242dd998284c 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -17,7 +17,7 @@ import numpy as np import torch -from transformers import AutoModel, AutoTokenizer +from transformers import PreTrainedModel, PreTrainedTokenizerBase from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL @@ -150,11 +150,11 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline): Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`AutoModel`]): + text_encoder ([`PreTrainedModel`]): Frozen text-encoder. Lumina-T2I uses [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the [t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant. - tokenizer (`AutoModel`): + tokenizer (`PreTrainedTokenizerBase`): Tokenizer of class [AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel). 
transformer ([`Transformer2DModel`]): @@ -172,8 +172,8 @@ def __init__( transformer: Lumina2Transformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, - text_encoder: AutoModel, - tokenizer: AutoTokenizer, + text_encoder: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, ): super().__init__() diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index a64481f2f46a..fd1cc1c537a7 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -20,7 +20,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import torch -from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers import Gemma2PreTrainedModel, GemmaTokenizerFast, GemmaTokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PixArtImageProcessor @@ -160,8 +160,8 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin): def __init__( self, - tokenizer: PreTrainedTokenizerBase, - text_encoder: PreTrainedModel, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + text_encoder: Gemma2PreTrainedModel, vae: AutoencoderDC, transformer: SanaTransformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 83fee4d419c0..68eb0825778f 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1047,12 +1047,20 @@ def get_detailed_type(obj: Any) -> Type: else: return obj_type - for key, class_obj in init_kwargs.items(): - if "scheduler" in key: + for kw, arg in init_kwargs.items(): + # Too complex to validate with type annotation alone + if "scheduler" in kw: continue - - if class_obj is not None and not is_valid_type(class_obj, expected_types[key]): - logger.warning(f"Expected types for {key}: {expected_types[key]}, got {get_detailed_type(class_obj)}.") + # Many tokenizer annotations don't include its "Fast" variant, so skip this + # e.g T5Tokenizer but not T5TokenizerFast + elif "tokenizer" in kw: + continue + elif ( + arg is not None + and expected_types[kw] is not inspect.Signature.empty # no type annotations + and not is_valid_type(arg, expected_types[kw]) + ): + logger.warning(f"Expected types for {kw}: {expected_types[kw]}, got {get_detailed_type(arg)}.") # 11. Instantiate the pipeline model = pipeline_class(**init_kwargs) diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index 8dcbef5f99cf..790946df0b09 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -20,7 +20,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PixArtImageProcessor @@ -200,8 +200,8 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin): def __init__( self, - tokenizer: PreTrainedTokenizerBase, - text_encoder: PreTrainedModel, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + text_encoder: Gemma2PreTrainedModel, vae: AutoencoderDC, transformer: SanaTransformer2DModel, scheduler: DPMSolverMultistepScheduler, diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 23950f895aae..6468da00772f 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -17,10 +17,10 @@ import torch from transformers import ( - BaseImageProcessor, + SiglipImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, - PreTrainedModel, + SiglipVisionModel, T5EncoderModel, T5TokenizerFast, ) @@ -176,9 +176,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle tokenizer_3 (`T5TokenizerFast`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). - image_encoder (`PreTrainedModel`, *optional*): + image_encoder (`SiglipVisionModel`, *optional*): Pre-trained Vision Model for IP Adapter. - feature_extractor (`BaseImageProcessor`, *optional*): + feature_extractor (`SiglipImageProcessor`, *optional*): Image processor for IP Adapter. """ @@ -197,8 +197,8 @@ def __init__( tokenizer_2: CLIPTokenizer, text_encoder_3: T5EncoderModel, tokenizer_3: T5TokenizerFast, - image_encoder: PreTrainedModel = None, - feature_extractor: BaseImageProcessor = None, + image_encoder: SiglipVisionModel = None, + feature_extractor: SiglipImageProcessor = None, ): super().__init__() diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index 2fa63cf7ee81..b62a982a8747 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -18,10 +18,10 @@ import PIL.Image import torch from transformers import ( - BaseImageProcessor, + SiglipImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, - PreTrainedModel, + SiglipVisionModel, T5EncoderModel, T5TokenizerFast, ) @@ -197,6 +197,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro tokenizer_3 (`T5TokenizerFast`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + image_encoder (`SiglipVisionModel`, *optional*): + Pre-trained Vision Model for IP Adapter.
+ feature_extractor (`SiglipImageProcessor`, *optional*): + Image processor for IP Adapter. """ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae" @@ -214,8 +218,8 @@ def __init__( tokenizer_2: CLIPTokenizer, text_encoder_3: T5EncoderModel, tokenizer_3: T5TokenizerFast, - image_encoder: PreTrainedModel = None, - feature_extractor: BaseImageProcessor = None, + image_encoder: Optional[SiglipVisionModel] = None, + feature_extractor: Optional[SiglipImageProcessor] = None, ): super().__init__() diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index de9842913e98..0cf16891fe63 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -17,10 +17,10 @@ import torch from transformers import ( - BaseImageProcessor, + SiglipImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, - PreTrainedModel, + SiglipVisionModel, T5EncoderModel, T5TokenizerFast, ) @@ -196,9 +196,9 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro tokenizer_3 (`T5TokenizerFast`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). - image_encoder (`PreTrainedModel`, *optional*): + image_encoder (`SiglipVisionModel`, *optional*): Pre-trained Vision Model for IP Adapter. - feature_extractor (`BaseImageProcessor`, *optional*): + feature_extractor (`SiglipImageProcessor`, *optional*): Image processor for IP Adapter. """ @@ -217,8 +217,8 @@ def __init__( tokenizer_2: CLIPTokenizer, text_encoder_3: T5EncoderModel, tokenizer_3: T5TokenizerFast, - image_encoder: PreTrainedModel = None, - feature_extractor: BaseImageProcessor = None, + image_encoder: Optional[SiglipVisionModel] = None, + feature_extractor: Optional[SiglipImageProcessor] = None, ): super().__init__() diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py index 24e11bff3052..1f29f577f8e0 100755 --- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -19,15 +19,31 @@ import torch from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPTokenizerFast, +) from ...image_processor import VaeImageProcessor -from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import ( + StableDiffusionLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder -from ...schedulers import LMSDiscreteScheduler +from ...schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -95,13 +111,13 @@ class StableDiffusionKDiffusionPipeline( def __init__( self, - vae, - text_encoder, - tokenizer, - unet, - scheduler, - safety_checker, - feature_extractor, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast], + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() From 03a8fcf16aceccc1bf20bd85b9de20edc2fa6adf Mon Sep 17 00:00:00 2001 From: Daniel Regado Date: Thu, 13 Feb 2025 11:31:56 +0000 Subject: [PATCH 5/9] make style && make quality --- .../controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py | 2 +- .../pipeline_stable_diffusion_3_controlnet_inpainting.py | 2 +- src/diffusers/pipelines/pag/pipeline_pag_sana.py | 2 +- src/diffusers/pipelines/pipeline_utils.py | 2 +- .../pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py | 2 +- .../stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py | 2 +- .../stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py index c353d0a8f2f7..7f7acd882b59 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py @@ -17,9 +17,9 @@ import torch from transformers import ( - SiglipImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, + SiglipImageProcessor, SiglipVisionModel, T5EncoderModel, T5TokenizerFast, ) diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py index fbfac8f63ad3..cb35f67fa112 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py @@ -17,9 +17,9 @@ import torch from transformers import ( - SiglipImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, + SiglipImageProcessor, SiglipModel, T5EncoderModel, T5TokenizerFast, ) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index fd1cc1c537a7..7dab0cd0c42a 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -20,7 +20,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import torch -from transformers import Gemma2PreTrainedModel, GemmaTokenizerFast, GemmaTokenizer +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PixArtImageProcessor diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 68eb0825778f..334f3f2402ae 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1057,7 +1057,7 @@ def get_detailed_type(obj: Any) -> Type: continue elif ( arg is not None - and expected_types[kw] is not inspect.Signature.empty # no type annotations + and expected_types[kw] is not inspect.Signature.empty  # no type annotations and not is_valid_type(arg, expected_types[kw]) ): logger.warning(f"Expected types for {kw}: {expected_types[kw]}, got {get_detailed_type(arg)}.") diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 6468da00772f..5211894a40bd 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -17,9 +17,9 @@ import torch from transformers import ( - SiglipImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, + SiglipImageProcessor, SiglipVisionModel, T5EncoderModel, T5TokenizerFast, ) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index b62a982a8747..0685345c1b76 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -18,9 +18,9 @@ import PIL.Image import torch from transformers import ( - SiglipImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, + SiglipImageProcessor, SiglipVisionModel, T5EncoderModel, T5TokenizerFast, ) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index 0cf16891fe63..3979ae3b7069 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -17,9 +17,9 @@ import torch from transformers import ( - SiglipImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, + SiglipImageProcessor, SiglipVisionModel, T5EncoderModel, T5TokenizerFast, From 0afbe6c06d97b44d8cd88ff17574179130dc12a9 Mon Sep 17 00:00:00 2001 From: Daniel Regado Date: Thu, 13 Feb 2025 11:37:09 +0000 Subject: [PATCH 6/9] Updated docs and types for Lumina pipelines --- .../pipelines/lumina/pipeline_lumina.py | 17 +++++++---------- .../pipelines/lumina2/pipeline_lumina2.py | 17 +++++++---------- .../pipelines/lumina2/test_pipeline_lumina2.py | 10 +++++----- 3 files changed, 19 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 5128e20eaacd..5b7b8830b00d 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -20,7 +20,7 @@ from typing import List, Optional, Tuple, Union import torch -from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers import GemmaPreTrainedModel, GemmaTokenizer, GemmaTokenizerFast from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL @@ -143,13 +143,10 @@ class LuminaText2ImgPipeline(DiffusionPipeline): Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`PreTrainedModel`]): - Frozen text-encoder. Lumina-T2I uses - [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the - [t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant. - tokenizer (`AutoTokenizer`): - Tokenizer of class - [AutoTokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel). + text_encoder ([`GemmaPreTrainedModel`]): + Frozen Gemma text-encoder. + tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`): + Gemma tokenizer. transformer ([`Transformer2DModel`]): A text conditioned `Transformer2DModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): @@ -180,8 +177,8 @@ def __init__( transformer: LuminaNextDiT2DModel, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, - text_encoder: PreTrainedModel, - tokenizer: PreTrainedTokenizerBase, + text_encoder: GemmaPreTrainedModel, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], ): super().__init__() diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index 242dd998284c..0828738fc339 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -17,7 +17,7 @@ import numpy as np import torch -from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL @@ -150,13 +150,10 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline): Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`PreTrainedModel`]): - Frozen text-encoder. Lumina-T2I uses - [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the - [t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant. - tokenizer (`PreTrainedTokenizerBase`): - Tokenizer of class - [AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel). + text_encoder ([`Gemma2PreTrainedModel`]): + Frozen Gemma2 text-encoder. + tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`): + Gemma tokenizer. transformer ([`Transformer2DModel`]): A text conditioned `Transformer2DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]): @@ -172,8 +169,8 @@ def __init__( transformer: Lumina2Transformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, - text_encoder: PreTrainedModel, - tokenizer: PreTrainedTokenizerBase, + text_encoder: Gemma2PreTrainedModel, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], ): super().__init__() diff --git a/tests/pipelines/lumina2/test_pipeline_lumina2.py b/tests/pipelines/lumina2/test_pipeline_lumina2.py index f8e0667ce1d2..646069a3df1f 100644 --- a/tests/pipelines/lumina2/test_pipeline_lumina2.py +++ b/tests/pipelines/lumina2/test_pipeline_lumina2.py @@ -2,7 +2,7 @@ import numpy as np import torch -from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM +from transformers import AutoTokenizer, Gemma2Config, Gemma2ForCausalLM from diffusers import ( AutoencoderKL, @@ -81,7 +81,7 @@ def get_dummy_components(self): tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma") torch.manual_seed(0) - config = GemmaConfig( + config = Gemma2Config( head_dim=2, hidden_size=8, intermediate_size=37, @@ -89,13 +89,13 @@ def get_dummy_components(self): num_hidden_layers=2, num_key_value_heads=4, ) - text_encoder = GemmaForCausalLM(config) + text_encoder = Gemma2ForCausalLM(config) components = { - "transformer": transformer.eval(), + "transformer": transformer, "vae": vae.eval(), "scheduler": scheduler, - "text_encoder": text_encoder.eval(), + "text_encoder": text_encoder, "tokenizer": tokenizer, } return components From e367fd327b911accfe1083d83679ec7c9ef9a434 Mon Sep 17 00:00:00 2001 From: Daniel Regado Date: Thu, 13 Feb 2025 12:11:58 +0000 Subject: [PATCH 7/9] Fixed check for empty signature --- src/diffusers/pipelines/pipeline_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 334f3f2402ae..237b8550487f 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1057,7 +1057,7 @@ def get_detailed_type(obj: Any) -> Type: continue elif ( arg is not None - and expected_types[kw] is not inspect.Signature.empty  # no type annotations + and not expected_types[kw] == (inspect.Signature.empty,)  # no type annotations and not is_valid_type(arg, expected_types[kw]) ): logger.warning(f"Expected types for {kw}: {expected_types[kw]}, got {get_detailed_type(arg)}.") From b17fc6eabaae0a5637a5999b7e843b69e4f98a73 Mon Sep 17 00:00:00 2001 From: Daniel Regado Date: Fri, 21 Feb 2025 12:05:02 +0000 Subject: [PATCH 8/9] changed location of helper functions --- .../pipelines/pipeline_loading_utils.py | 75 +++++++++++++++++- src/diffusers/pipelines/pipeline_utils.py | 76 ++----------------- 2 files changed, 81 insertions(+), 70 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 9a9afa198b4c..0e2cbb32d3c1 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -17,7 +17,7 @@ import re import warnings from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin import requests import torch @@ -1059,3 +1059,76 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict): break if has_transformers_component and not is_transformers_version(">", "4.47.1"): raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.") + + +def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool: + """ + Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of + the correct type as well. + """ + if not isinstance(class_or_tuple, tuple): + class_or_tuple = (class_or_tuple,) + + # Unpack unions + unpacked_class_or_tuple = [] + for t in class_or_tuple: + if get_origin(t) is Union: + unpacked_class_or_tuple.extend(get_args(t)) + else: + unpacked_class_or_tuple.append(t) + class_or_tuple = tuple(unpacked_class_or_tuple) + + if Any in class_or_tuple: + return True + + obj_type = type(obj) + # Classes with obj's type + class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)} + + # Singular types (e.g. int, ControlNet, ...) + # Untyped collections (e.g. List, but not List[int]) + elem_class_or_tuple = {get_args(t) for t in class_or_tuple} + if () in elem_class_or_tuple: + return True + # Typed lists or sets + elif obj_type in (list, set): + return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple) + # Typed tuples + elif obj_type is tuple: + return any( + # Tuples with any length and single type (e.g. Tuple[int, ...]) + (len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj)) + or + # Tuples with fixed length and any types (e.g. Tuple[int, str]) + (len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t))) + for t in elem_class_or_tuple + ) + # Typed dicts + elif obj_type is dict: + return any( + all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items()) + for kt, vt in elem_class_or_tuple + ) + + else: + return False + + +def _get_detailed_type(obj: Any) -> Type: + """ + Gets a detailed type for an object, including nested types for collections. + """ + obj_type = type(obj) + + if obj_type in (list, set): + obj_origin_type = List if obj_type is list else Set + elems_type = Union[tuple({_get_detailed_type(x) for x in obj})] + return obj_origin_type[elems_type] + elif obj_type is tuple: + return Tuple[tuple(_get_detailed_type(x) for x in obj)] + elif obj_type is dict: + keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})] + values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})] + return Dict[keys_type, values_type] + else: + return obj_type diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 1729f5747b1f..b4b342d64be7 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -21,7 +21,7 @@ import sys from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin +from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin import numpy as np import PIL.Image @@ -78,10 +78,12 @@ _fetch_class_library_tuple, _get_custom_components_and_folders, _get_custom_pipeline_class, + _get_detailed_type, _get_final_device_map, _get_ignore_patterns, _get_pipeline_class, _identify_model_variants, + _is_valid_type, _maybe_raise_error_for_incorrect_transformers, _maybe_raise_warning_for_inpainting, _resolve_custom_pipeline_and_cls, @@ -995,70 +997,6 @@ def load_module(name, value): ) # 10. Type checking init arguments - def is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool: - if not isinstance(class_or_tuple, tuple): - class_or_tuple = (class_or_tuple,) - - # Unpack unions - unpacked_class_or_tuple = [] - for t in class_or_tuple: - if get_origin(t) is Union: - unpacked_class_or_tuple.extend(get_args(t)) - else: - unpacked_class_or_tuple.append(t) - class_or_tuple = tuple(unpacked_class_or_tuple) - - if Any in class_or_tuple: - return True - - obj_type = type(obj) - # Classes with obj's type - class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)} - - # Singular types (e.g. int, ControlNet, ...) - # Untyped collections (e.g. List, but not List[int]) - elem_class_or_tuple = {get_args(t) for t in class_or_tuple} - if () in elem_class_or_tuple: - return True - # Typed lists or sets - elif obj_type in (list, set): - return any(all(is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple) - # Typed tuples - elif obj_type is tuple: - return any( - # Tuples with any length and single type (e.g. Tuple[int, ...]) - (len(t) == 2 and t[-1] is Ellipsis and all(is_valid_type(x, t[0]) for x in obj)) - or - # Tuples with fixed length and any types (e.g. Tuple[int, str]) - (len(obj) == len(t) and all(is_valid_type(x, tt) for x, tt in zip(obj, t))) - for t in elem_class_or_tuple - ) - # Typed dicts - elif obj_type is dict: - return any( - all(is_valid_type(k, kt) and is_valid_type(v, vt) for k, v in obj.items()) - for kt, vt in elem_class_or_tuple - ) - - else: - return False - - def get_detailed_type(obj: Any) -> Type: - obj_type = type(obj) - - if obj_type in (list, set): - obj_origin_type = List if obj_type is list else Set - elems_type = Union[tuple({get_detailed_type(x) for x in obj})] - return obj_origin_type[elems_type] - elif obj_type is tuple: - return Tuple[tuple(get_detailed_type(x) for x in obj)] - elif obj_type is dict: - keys_type = Union[tuple({get_detailed_type(k) for k in obj.keys()})] - values_type = Union[tuple({get_detailed_type(k) for k in obj.values()})] - return Dict[keys_type, values_type] - else: - return obj_type - for kw, arg in init_kwargs.items(): # Too complex to validate with type annotation alone if "scheduler" in kw: continue @@ -1068,11 +1006,11 @@ def get_detailed_type(obj: Any) -> Type: elif "tokenizer" in kw: continue elif ( - arg is not None - and not expected_types[kw] == (inspect.Signature.empty,) # no type annotations - and not is_valid_type(arg, expected_types[kw]) + arg is not None  # Skip if None + and not expected_types[kw] == (inspect.Signature.empty,)  # Skip if no type annotations + and not _is_valid_type(arg, expected_types[kw])  # Check type ): - logger.warning(f"Expected types for {kw}: {expected_types[kw]}, got {get_detailed_type(arg)}.") + logger.warning(f"Expected types for {kw}: {expected_types[kw]}, got {_get_detailed_type(arg)}.") # 11. Instantiate the pipeline model = pipeline_class(**init_kwargs) From 46d462972a9639a1a5d55de3c65b46aa94263605 Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 22 Feb 2025 12:58:28 +0000 Subject: [PATCH 9/9] make style --- src/diffusers/pipelines/pipeline_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 9b680ef2dbcf..90a05e97f614 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1006,9 +1006,9 @@ def load_module(name, value): elif "tokenizer" in kw: continue elif ( - arg is not None # Skip if None + arg is not None  # Skip if None and not expected_types[kw] == (inspect.Signature.empty,) # Skip if no type annotations - and not _is_valid_type(arg, expected_types[kw]) # Check type + and not _is_valid_type(arg, expected_types[kw])  # Check type ): logger.warning(f"Expected types for {kw}: {expected_types[kw]}, got {_get_detailed_type(arg)}.")
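
---
For reference only (not part of the patches): a minimal sketch of how the helpers this series ends with behave, assuming a diffusers checkout at this revision. `_is_valid_type` and `_get_detailed_type` are the private functions moved into `pipeline_loading_utils.py` in PATCH 8/9:

    from typing import List, Optional

    from diffusers.pipelines.pipeline_loading_utils import (
        _get_detailed_type,
        _is_valid_type,
    )

    # Plain classes and unions behave like isinstance checks
    assert _is_valid_type(1, int)
    assert _is_valid_type(None, Optional[int])

    # Typed collections are validated element-wise
    assert _is_valid_type([1, 2], List[int])
    assert not _is_valid_type([1, "a"], List[int])

    # _get_detailed_type reconstructs a nested annotation for the warning message
    print(_get_detailed_type([1, "a"]))  # e.g. typing.List[typing.Union[int, str]]

This mirrors the check performed in step 10 of `DiffusionPipeline.from_pretrained`, which logs a warning such as `Expected types for unet: (...), got ...` when a passed component does not match the pipeline's annotated type.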