Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
- local: using-diffusers/schedulers
title: Load and compare different schedulers
- local: using-diffusers/custom_pipeline_overview
title: Load community pipelines
title: Load community pipelines and components
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're not adding a separate doc page here.

See: https://github.com/huggingface/diffusers/pull/5556/files#r1376500960/

- local: using-diffusers/using_safetensors
title: Load safetensors
- local: using-diffusers/other-formats
Expand Down
113 changes: 111 additions & 2 deletions docs/source/en/using-diffusers/custom_pipeline_overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->

# Load community pipelines
# Load community pipelines and components

[[open-in-colab]]

## Community pipelines

Community pipelines are any [`DiffusionPipeline`] class that are different from the original implementation as specified in their paper (for example, the [`StableDiffusionControlNetPipeline`] corresponds to the [Text-to-Image Generation with ControlNet Conditioning](https://arxiv.org/abs/2302.05543) paper). They provide additional functionality or extend the original implementation of a pipeline.

There are many cool community pipelines like [Speech to Image](https://github.com/huggingface/diffusers/tree/main/examples/community#speech-to-image) or [Composable Stable Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/community#composable-stable-diffusion), and you can find all the official community pipelines [here](https://github.com/huggingface/diffusers/tree/main/examples/community).
Expand Down Expand Up @@ -54,4 +56,111 @@ pipeline = DiffusionPipeline.from_pretrained(
)
```

For more information about community pipelines, take a look at the [Community pipelines](custom_pipeline_examples) guide for how to use them and if you're interested in adding a community pipeline check out the [How to contribute a community pipeline](contribute_pipeline) guide!
For more information about community pipelines, take a look at the [Community pipelines](custom_pipeline_examples) guide for how to use them and if you're interested in adding a community pipeline check out the [How to contribute a community pipeline](contribute_pipeline) guide!

## Community components

If your pipeline has custom components that Diffusers doesn't support already, you need to accompany the Python modules that implement them. These customized components could be VAE, UNet, scheduler, etc. For the text encoder, we rely on `transformers` anyway. So, that should be handled separately (more info here). The pipeline code itself can be customized as well.

Community components allow users to build pipelines that may have customized components that are not part of Diffusers. This section shows how users should use community components to build a community pipeline.

You'll use the [showlab/show-1-base](https://huggingface.co/showlab/show-1-base) pipeline checkpoint as an example here. Here, you have a custom UNet and a customized pipeline (`TextToVideoIFPipeline`). For convenience, let's call the UNet `ShowOneUNet3DConditionModel`.

"showlab/show-1-base" already provides the checkpoints in the Diffusers format, which is a great starting point. So, let's start loading up the components which are already well-supported:

1. **Text encoder**

```python
from transformers import T5Tokenizer, T5EncoderModel

pipe_id = "showlab/show-1-base"
tokenizer = T5Tokenizer.from_pretrained(pipe_id, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(pipe_id, subfolder="text_encoder")
```

2. **Scheduler**

```python
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler.from_pretrained(pipe_id, subfolder="scheduler")
```

3. **Image processor**

```python
from transformers import CLIPFeatureExtractor

feature_extractor = CLIPFeatureExtractor.from_pretrained(pipe_id, subfolder="feature_extractor")
```

Now, you need to implement the custom UNet. The implementation is available [here](https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py). So, let's create a Python script called `showone_unet_3d_condition.py` and copy over the implementation, changing the `UNet3DConditionModel` classname to `ShowOneUNet3DConditionModel` to avoid any conflicts with Diffusers. This is because Diffusers already has one `UNet3DConditionModel`. We put all the components needed to implement the class in `showone_unet_3d_condition.py` only. You can find the entire file [here](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py).

Once this is done, we can initialize the UNet:

```python
from showone_unet_3d_condition import ShowOneUNet3DConditionModel

unet = ShowOneUNet3DConditionModel.from_pretrained(pipe_id, subfolder="unet")
```

Then implement the custom `TextToVideoIFPipeline` in another Python script: `pipeline_t2v_base_pixel.py`. This is already available [here](https://github.com/showlab/Show-1/blob/main/showone/pipelines/pipeline_t2v_base_pixel.py).

Now that you have all the components, initialize the `TextToVideoIFPipeline`:

```python
from pipeline_t2v_base_pixel import TextToVideoIFPipeline
import torch

pipeline = TextToVideoIFPipeline(
unet=unet,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
feature_extractor=feature_extractor
)
pipeline = pipeline.to(device="cuda")
pipeline.torch_dtype = torch.float16
```

Push to the pipeline to the Hub to share with the community:

```python
pipeline.push_to_hub("custom-t2v-pipeline")
```

After the pipeline is successfully pushed, you need a couple of changes:

1. In `model_index.json` file, change the `_class_name` attribute. It should be like [so](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/model_index.json#L2).
2. Upload `showone_unet_3d_condition.py` to the `unet` directory ([example](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py)).
3. Upload `pipeline_t2v_base_pixel.py` to the pipeline base directory ([example](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py)).

To run inference, just do:

```python
from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained(
"<change-username>/<change-id>", trust_remote_code=True, torch_dtype=torch.float16
).to("cuda")

prompt = "hello"

# Text embeds
prompt_embeds, negative_embeds = pipeline.encode_prompt(prompt)

# Keyframes generation (8x64x40, 2fps)
video_frames = pipeline(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_embeds,
num_frames=8,
height=40,
width=64,
num_inference_steps=2,
guidance_scale=9.0,
output_type="pt"
).frames
```

Here, notice the use of the `trust_remote_code` argument while initializing the pipeline. It is responsible for handling all the "magic" behind the scenes.