Conversation

@patrickvonplaten
Contributor

@patrickvonplaten patrickvonplaten commented Oct 20, 2023

What does this PR do?

This PR adds code that allows running remote models, schedulers, and pipelines.
You can try it out by looking at the tests and how these two example repos are structured:
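A minimal sketch of what this enables, using the tiny test repo referenced later in this thread (the two example repos mentioned above are not reproduced here):

from diffusers import DiffusionPipeline

# trust_remote_code=True opts in to downloading and executing the model/scheduler/pipeline
# code that ships inside the Hub repo.
pipeline = DiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-sdxl-custom-components", trust_remote_code=True
)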

@patrickvonplaten patrickvonplaten changed the title from "upload custom remote poc" to "[Remote code] Add functionality to run remote models, schedulers, pipelines" on Oct 20, 2023
Comment on lines 339 to 343
hub_repo_id=None,
hub_revision=None,
class_name=None,
cache_dir=None,
revision=None,
Member

I suggest we keep these as **kwargs. Instead of hub_repo_id or hub_revision, we could also opt for repo_id or revision, respectively, following what we do across the library for Hub-related utilities.
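A hypothetical sketch of that suggestion; the function name and extra parameters are made up, only the argument-handling pattern is the point:

def _load_remote_component(class_name=None, cache_dir=None, **kwargs):
    # Hub-related options are collected via **kwargs and reuse the library-wide names
    repo_id = kwargs.pop("repo_id", None)      # instead of hub_repo_id
    revision = kwargs.pop("revision", None)    # instead of hub_revision
    return repo_id, revision, class_name, cache_dir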

Contributor Author

We need to keep hub_revision because the other revision argument is already used for Git revisions.
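A hedged illustration of the distinction, reusing the parameter names from the snippet above inside a made-up helper:

def _resolve_refs(hub_repo_id=None, hub_revision=None, revision=None):
    # revision pins the Git ref of the checkpoint repo being loaded,
    # while hub_revision pins the Git ref of the Hub repo hosting the remote code.
    code_ref = hub_revision or "main"
    weights_ref = revision or "main"
    return code_ref, weights_ref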


# Check that loading only the custom components "my_unet", "my_scheduler" and an explicit custom pipeline works
pipeline = DiffusionPipeline.from_pretrained(
    "/home/patrick/tiny-stable-diffusion-xl-pipe", custom_pipeline="my_pipeline"
)
Member

The checkpoint path needs to be changed (it currently points to a local directory).

@sayakpaul sayakpaul left a comment (Member)

Wow, that was pretty clean!

Can we also add a nice doc, since I believe this will be a very powerful feature? We could readily test it out with Show-1: https://github.com/showlab/Show-1.

Also, can we add a test to check that a custom pipeline (loaded with trust_remote_code=True) works seamlessly with an existing component from the library?

For example:

from diffusers import DiffusionPipeline
from diffusers import UniPCMultistepScheduler
from diffusers.utils.testing_utils import torch_device

pipeline = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-custom-components", trust_remote_code=True)

# Swap the remotely loaded scheduler for one from the library
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)

pipeline = pipeline.to(torch_device)
images = pipeline("test", num_inference_steps=2, output_type="np")[0]

@sayakpaul
Member

Hmm, I guess it's okay for the first shipment, but I don't think it will support loading a custom UNet in isolation.

What if I only have a model repo on the Hub, similar to the ones you have for Transformers? I guess for that we need Auto classes.
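For reference, transformers already supports that standalone pattern through its Auto classes (this is the existing transformers API, not something this PR adds to diffusers; the repo id below is hypothetical):

from transformers import AutoModel

# A single custom model loaded in isolation from the remote code in its Hub repo.
model = AutoModel.from_pretrained("some-user/custom-model", trust_remote_code=True)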

@sayakpaul
Member

Facing an issue and proposed a solution here: #5491.

@sayakpaul
Member

sayakpaul commented Oct 23, 2023

I have been playing with this PR to support somewhat more complicated pipelines such as https://huggingface.co/showlab/show-1-base. It took a while to understand how the pipeline repository should be structured, so I'm documenting everything here.

If your pipeline has custom components that diffusers doesn't already support, you need to include the Python modules that implement them. These custom components could be a VAE, a UNet, a scheduler, etc. For the text encoder, we rely on transformers anyway, so that should be handled separately (more info here). The pipeline code itself can be custom as well.

In the case of "showlab/show-1-base", we have a custom UNet and a custom pipeline (TextToVideoIFPipeline). For convenience, let's call the UNet ShowOneUNet3DConditionModel.

"showlab/show-1-base" already provides the checkpoints in the diffusers format, which is great. So, let's start loading up the components which are already well-supported:

  1. Text encoder:

from transformers import T5Tokenizer, T5EncoderModel

pipe_id = "showlab/show-1-base"
tokenizer = T5Tokenizer.from_pretrained(pipe_id, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(pipe_id, subfolder="text_encoder")

  2. Scheduler:

from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler.from_pretrained(pipe_id, subfolder="scheduler")

  3. Image feature extractor:

from transformers import CLIPFeatureExtractor

feature_extractor = CLIPFeatureExtractor.from_pretrained(pipe_id, subfolder="feature_extractor")

Now, we need to implement the custom UNet. It's already available here: https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py. So, we create a Python script called showone_unet_3d_condition.py and copy over the implementation, changing the UNet3DConditionModel class name to ShowOneUNet3DConditionModel to avoid conflicts with diffusers, since diffusers already has a UNet3DConditionModel. We put everything needed to implement the class in showone_unet_3d_condition.py only. You can find the entire file here.
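A rough, heavily abridged skeleton of what such a module looks like, assuming the standard ModelMixin/ConfigMixin pattern; the real file implements the full 3D UNet, this stub only shows the structure:

import torch
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config


class ShowOneUNet3DConditionModel(ModelMixin, ConfigMixin):
    # register_to_config records the __init__ arguments so that
    # save_pretrained / from_pretrained can round-trip the config.
    @register_to_config
    def __init__(self, sample_size=64, in_channels=4, out_channels=4):
        super().__init__()
        # Placeholder layer; the real class builds the full 3D UNet here.
        self.conv_in = torch.nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, sample, timestep, encoder_hidden_states):
        return self.conv_in(sample)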

Once this is done, we can initialize the UNet:

from showone_unet_3d_condition import ShowOneUNet3DConditionModel

unet = ShowOneUNet3DConditionModel.from_pretrained(pipe_id, subfolder="unet")

And then we implement the custom TextToVideoIFPipeline in another Python script: pipeline_t2v_base_pixel.py.
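Likewise, a minimal sketch of what the pipeline module contains, assuming the usual DiffusionPipeline pattern; the real pipeline_t2v_base_pixel.py implements the full denoising loop:

import torch
from diffusers import DiffusionPipeline


class TextToVideoIFPipeline(DiffusionPipeline):
    def __init__(self, unet, text_encoder, tokenizer, scheduler, feature_extractor):
        super().__init__()
        # register_modules lets save_pretrained / push_to_hub track every component.
        self.register_modules(
            unet=unet,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(self, prompt_embeds=None, negative_prompt_embeds=None, num_frames=8, **kwargs):
        # The real implementation encodes prompts, prepares latents, and runs the 3D UNet loop here.
        ...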

Now that we have all the components, we can fully initialize the TextToVideoIFPipeline:

from pipeline_t2v_base_pixel import TextToVideoIFPipeline
import torch

pipeline = TextToVideoIFPipeline(
    unet=unet, 
    text_encoder=text_encoder, 
    tokenizer=tokenizer, 
    scheduler=scheduler, 
    feature_extractor=feature_extractor
)
pipeline = pipeline.to(device="cuda")
pipeline.torch_dtype = torch.float16

For sharing with others, we can push this pipeline to the Hub:

pipeline.push_to_hub("custom-t2v-pipeline")

After the pipeline is successfully pushed, we need to make a few changes:

  1. In the model_index.json file, change the _class_name attribute. It should look like so: https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/model_index.json#L2 (see the sketch after this list).
  2. Upload showone_unet_3d_condition.py to the unet directory (example).
  3. Upload pipeline_t2v_base_pixel.py to the pipeline base directory (example).
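Assuming the format in the linked model_index.json, the top of the file ends up looking roughly like this (module file name first, class name second; check the linked file for the exact contents):

{
  "_class_name": [
    "pipeline_t2v_base_pixel",
    "TextToVideoIFPipeline"
  ],
  ...
}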

Then we're ready for inference:

from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained(
    "sayakpaul/show-1-base-with-code", trust_remote_code=True, torch_dtype=torch.float16
).to("cuda")

prompt = "hello"

# Text embeds
prompt_embeds, negative_embeds = pipeline.encode_prompt(prompt)

# Keyframes generation (8x64x40, 2fps)
video_frames = pipeline(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
    num_frames=8,
    height=40,
    width=64,
    num_inference_steps=2,
    guidance_scale=9.0,
    output_type="pt"
).frames

@patrickvonplaten
Contributor Author

> (quoting @sayakpaul's walkthrough above)

That's a great summary! Maybe we could make this a doc page in a follow-up PR? :-)

@sayakpaul
Member

@patrickvonplaten if you could look into #5491 before merging. Happy to create a doc after that.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Oct 25, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten
Contributor Author

@sayakpaul Think it's ready for a final review :-)

@sayakpaul sayakpaul left a comment (Member)

Okay for me to merge.

I will do a follow-up after the merge.

@patrickvonplaten patrickvonplaten merged commit cee1cd6 into main Oct 26, 2023
@patrickvonplaten patrickvonplaten deleted the add_custom_remote_pipelines branch October 26, 2023 16:02
kashif pushed a commit to kashif/diffusers that referenced this pull request Nov 11, 2023
…elines (huggingface#5472)

* upload custom remote poc

* up

* make style

* finish

* better name

* Apply suggestions from code review

* Update tests/pipelines/test_pipelines.py

* more fixes

* remove ipdb

* more fixes

* fix more

* finish tests

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024