diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index cef8f474c00e..f6ddcc3a1df1 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -19,6 +19,8 @@ title: Train a diffusion model - local: tutorials/using_peft_for_inference title: Inference with PEFT + - local: tutorials/custom_pipelines_components + title: Working with fully custom pipelines and components title: Tutorials - sections: - sections: diff --git a/docs/source/en/tutorials/custom_pipelines_components.md b/docs/source/en/tutorials/custom_pipelines_components.md new file mode 100644 index 000000000000..8f19dd2abd90 --- /dev/null +++ b/docs/source/en/tutorials/custom_pipelines_components.md @@ -0,0 +1,135 @@ + + +# Working with fully custom pipelines and components + +Diffusers supports the use of [custom pipelines](../using-diffusers/contribute_pipeline), letting users add any additional features on top of the [`DiffusionPipeline`]. However, it can get cumbersome if you're dealing with a custom pipeline where its components (such as the UNet, VAE, scheduler) are also custom. + +We allow loading of such pipelines by exposing a `trust_remote_code` argument inside [`DiffusionPipeline`]. The advantage of `trust_remote_code` lies in its flexibility. You can have different levels of customization for a pipeline. Following are a few examples: + +* Only UNet is custom +* UNet and VAE both are custom +* Pipeline is custom +* UNet, VAE, scheduler, and pipeline are custom + +With `trust_remote_code=True`, you can achieve any of the above! + +This tutorial covers how to author your pipeline repository so that it becomes compatible with `trust_remote_code`. You'll use a custom UNet, a custom scheduler, and a custom pipeline for this purpose. + + + +You should use `trust_remote_code=True` _only_ when you fully trust the code and have verified its usage. 
+ + + ## Pipeline components + + In the interest of brevity, you'll use the custom UNet, scheduler, and pipeline classes that we've already authored: + + ```bash + # Custom UNet + wget https://huggingface.co/sayakpaul/custom_pipeline_remote_code/raw/main/unet/my_unet_model.py + # Custom scheduler + wget https://huggingface.co/sayakpaul/custom_pipeline_remote_code/raw/main/scheduler/my_scheduler.py + # Custom pipeline + wget https://huggingface.co/sayakpaul/custom_pipeline_remote_code/raw/main/my_pipeline.py + ``` + + + + The above classes are just for reference. We encourage you to experiment with these classes for desired customizations. + + + + Load the individual components, starting with the UNet: + + ```python + from my_unet_model import MyUNetModel + + pretrained_id = "hf-internal-testing/tiny-sdxl-custom-all" + unet = MyUNetModel.from_pretrained(pretrained_id, subfolder="unet") + ``` + + Then go for the scheduler: + + ```python + from my_scheduler import MyScheduler + + scheduler = MyScheduler.from_pretrained(pretrained_id, subfolder="scheduler") + ``` + + Finally, the VAE and the text encoders: + + ```python + from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + from diffusers import AutoencoderKL + + text_encoder = CLIPTextModel.from_pretrained(pretrained_id, subfolder="text_encoder") + text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(pretrained_id, subfolder="text_encoder_2") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_id, subfolder="tokenizer") + tokenizer_2 = CLIPTokenizer.from_pretrained(pretrained_id, subfolder="tokenizer_2") + + vae = AutoencoderKL.from_pretrained(pretrained_id, subfolder="vae") + ``` + + `MyUNetModel`, `MyScheduler`, and `MyPipeline` use blocks that are already supported by Diffusers. If you are using any custom blocks, make sure to put them in the module files themselves. 
## Pipeline initialization and serialization + +With all the components, you can now initialize the custom pipeline: + +```python +pipeline = MyPipeline( + vae=vae, + unet=unet, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + scheduler=scheduler, +) +``` + +Now, push the pipeline to the Hub: + +```python +pipeline.push_to_hub("custom_pipeline_remote_code") +``` + +Since the `pipeline` itself is a custom pipeline, its corresponding Python module will also be pushed ([example](https://huggingface.co/sayakpaul/custom_pipeline_remote_code/blob/main/my_pipeline.py)). If the pipeline has any other custom components, they will be pushed as well ([UNet](https://huggingface.co/sayakpaul/custom_pipeline_remote_code/blob/main/unet/my_unet_model.py), [scheduler](https://huggingface.co/sayakpaul/custom_pipeline_remote_code/blob/main/scheduler/my_scheduler.py)). + +If you want to keep the pipeline local, then use the [`PushToHubMixin.save_pretrained`] method. + +## Pipeline loading + +You can load this pipeline from the Hub by specifying `trust_remote_code=True`: + +```python +import torch +from diffusers import DiffusionPipeline +reloaded_pipeline = DiffusionPipeline.from_pretrained( + "sayakpaul/custom_pipeline_remote_code", + torch_dtype=torch.float16, + trust_remote_code=True, +).to("cuda") +``` + +And then perform inference: + +```python +prompt = "hey" +num_inference_steps = 2 + +_ = reloaded_pipeline(prompt=prompt, num_inference_steps=num_inference_steps)[0] +``` + +For more complex pipelines, readers are welcome to check out [this comment](https://github.com/huggingface/diffusers/pull/5472#issuecomment-1775034461) on GitHub. 
\ No newline at end of file diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index a67fa9d41ca5..b48279135323 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -21,6 +21,7 @@ import json import os import re +import sys from collections import OrderedDict from pathlib import PosixPath from typing import Any, Dict, Tuple, Union @@ -162,6 +163,30 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool self.to_json_file(output_config_file) logger.info(f"Configuration saved in {output_config_file}") + # Additionally, save the implementation file too. It can happen for a pipeline, for a model, and + # for a scheduler. + + # To avoid circular import problems. + from .models import _import_structure as model_modules + from .pipelines import _import_structure as pipeline_modules + from .schedulers import _import_structure as scheduler_modules + + _all_available_pipelines_schedulers_model_classes = sum( + (list(model_modules.values()) + list(scheduler_modules.values()) + list(pipeline_modules.values())), [] + ) + if self.__class__.__name__ not in _all_available_pipelines_schedulers_model_classes: + module_to_save = self.__class__.__module__ + absolute_module_path = sys.modules[module_to_save].__file__ + try: + with open(absolute_module_path, "r") as original_file: + content = original_file.read() + path_to_write = os.path.join(save_directory, f"{module_to_save}.py") + with open(path_to_write, "w") as new_file: + new_file.write(content) + logger.info(f"{module_to_save}.py saved in {save_directory}") + except Exception as e: + logger.error(e) + if push_to_hub: commit_message = kwargs.pop("commit_message", None) private = kwargs.pop("private", False) @@ -567,7 +592,24 @@ def to_json_string(self) -> str: String containing all the attributes that make up the configuration instance in JSON format. 
""" config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {} - config_dict["_class_name"] = self.__class__.__name__ + cls_name = self.__class__.__name__ + + # For custom classes, record the defining module alongside the class name so that + # the implementation can be located when loading with `trust_remote_code=True`. + + # To avoid circular import problems. + from .models import _import_structure as model_modules + from .pipelines import _import_structure as pipeline_modules + from .schedulers import _import_structure as scheduler_modules + + _all_available_pipelines_schedulers_model_classes = sum( + (list(model_modules.values()) + list(scheduler_modules.values()) + list(pipeline_modules.values())), [] + ) + + if cls_name not in _all_available_pipelines_schedulers_model_classes: + config_dict["_class_name"] = [str(self.__class__.__module__), cls_name] + else: + config_dict["_class_name"] = cls_name config_dict["_diffusers_version"] = __version__ def to_json_saveable(value):