
fabiorigano
Contributor

What does this PR do?

Makes it possible to load a pipeline with an IP Adapter into an AnimateDiff pipeline with from_pipe()

Fixes #7661

Member

@sayakpaul sayakpaul left a comment


Looks very clean!

Could we also see some results?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@fabiorigano
Contributor Author

fabiorigano commented Apr 27, 2024

hi @sayakpaul thank you

I used YiYi's code to test:

import torch
from diffusers import AnimateDiffPipeline, DDIMScheduler, DiffusionPipeline, MotionAdapter
from diffusers.utils import export_to_gif, load_image

base_repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
num_inference_steps = 20
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
prompt = "bear eats pizza"
# Defined in this test script but not passed to the pipeline call below.
negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality"

# Load the IP Adapter into the base Stable Diffusion pipeline first.
pipe_sd = DiffusionPipeline.from_pretrained(base_repo, torch_dtype=torch.float16)
pipe_sd.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe_sd.set_ip_adapter_scale(0.6)
pipe_sd.to("cuda")

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)

# Build the AnimateDiff pipeline from the pipeline that already has the IP Adapter loaded.
pipe_animate = AnimateDiffPipeline.from_pipe(pipe_sd, motion_adapter=adapter)
pipe_animate.scheduler = DDIMScheduler.from_config(pipe_animate.scheduler.config, beta_schedule="linear")

pipe_animate.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
pipe_animate.to("cuda")
pipe_animate.enable_vae_slicing()
pipe_animate.enable_model_cpu_offload()

# Fixed seed so the run can be compared against the second snippet.
generator = torch.Generator(device="cpu").manual_seed(33)
pipe_animate.set_adapters("zoom-out", adapter_weights=0.75)
out = pipe_animate(
    prompt=prompt,
    num_frames=8,
    num_inference_steps=num_inference_steps,
    ip_adapter_image=image,
    generator=generator,
).frames[0]
export_to_gif(out, "out_animate.gif")

[Image: out_animate.gif]

The output is the same as when the IP Adapter is loaded after from_pipe():

pipe_sd = DiffusionPipeline.from_pretrained(base_repo, torch_dtype=torch.float16)
pipe_sd.to("cuda")
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe_animate = AnimateDiffPipeline.from_pipe(pipe_sd, motion_adapter=adapter)
pipe_animate.scheduler = DDIMScheduler.from_config(pipe_animate.scheduler.config, beta_schedule="linear")

pipe_animate.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe_animate.set_ip_adapter_scale(0.6)
pipe_animate.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")

and the code doesn't break during loading.

Collaborator

@DN6 DN6 left a comment


LGTM 👍🏽

Comment on lines +461 to +483
attn_procs = {}
for name, processor in unet.attn_processors.items():
    if name.endswith("attn1.processor"):
        attn_processor_class = (
            AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
        )
        attn_procs[name] = attn_processor_class()
    else:
        attn_processor_class = (
            IPAdapterAttnProcessor2_0
            if hasattr(F, "scaled_dot_product_attention")
            else IPAdapterAttnProcessor
        )
        attn_procs[name] = attn_processor_class(
            hidden_size=processor.hidden_size,
            cross_attention_dim=processor.cross_attention_dim,
            scale=processor.scale,
            num_tokens=processor.num_tokens,
        )
for name, processor in model.attn_processors.items():
    if name not in attn_procs:
        attn_procs[name] = processor.__class__()
model.set_attn_processor(attn_procs)
Collaborator


Is it the same as this?

model.set_attn_processor(unet.attn_processors)

Contributor Author


Actually no, because the UNetMotionModel used in AnimateDiff has motion modules that the UNet of the original pipeline does not have.

if you do something like this:

if any(
    isinstance(proc, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
    for proc in unet.attn_processors.values()
):
    model.set_attn_processor(unet.attn_processors)
    model.config.encoder_hid_dim_type = "ip_image_proj"
    model.encoder_hid_proj = unet.encoder_hid_proj

you will end up (in the particular case of the code snippet above) with a ValueError, because the number of attention processors in the source UNet does not match the number of attention layers in the motion UNet:
ValueError: A dict of processors was passed, but the number of processors 32 does not match the number of attention layers: 74. Please make sure to pass 74 processor classes.
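
The mismatch and the merge used in this PR can be illustrated with a toy sketch. It uses plain Python stand-ins rather than diffusers classes: `DefaultProcessor`, `IPAdapterProcessor`, `set_processors`, `merge_processors`, and the layer names are all hypothetical and only mirror the shape of the logic in the diff above.

```python
class DefaultProcessor:
    """Stand-in for a plain attention processor (hypothetical, not a diffusers class)."""


class IPAdapterProcessor:
    """Stand-in for an IP Adapter processor that carries its own settings (hypothetical)."""

    def __init__(self, scale=1.0):
        self.scale = scale


def set_processors(layer_names, processors):
    """Mimic a strict check: exactly one processor per attention layer."""
    layer_names = list(layer_names)
    if len(processors) != len(layer_names):
        raise ValueError(
            f"number of processors {len(processors)} does not match "
            f"the number of attention layers: {len(layer_names)}"
        )
    return dict(processors)


def merge_processors(source, target):
    """Rebuild the source processors by name, then fill in layers only the target has."""
    merged = {}
    for name, proc in source.items():
        if name.endswith("attn1.processor"):
            # Self-attention layers get a fresh default processor.
            merged[name] = DefaultProcessor()
        else:
            # Cross-attention layers keep the IP Adapter settings.
            merged[name] = IPAdapterProcessor(scale=proc.scale)
    for name, proc in target.items():
        if name not in merged:
            # Layers missing from the source (e.g. motion modules) keep their own class.
            merged[name] = proc.__class__()
    return merged


# Toy source UNet: two attention layers, IP Adapter on the cross-attention one.
source = {
    "down.0.attn1.processor": DefaultProcessor(),
    "down.0.attn2.processor": IPAdapterProcessor(scale=0.6),
}
# Toy motion UNet: the same two layers plus one extra motion-module layer.
target = {
    "down.0.attn1.processor": DefaultProcessor(),
    "down.0.attn2.processor": DefaultProcessor(),
    "down.0.motion.processor": DefaultProcessor(),
}

direct_copy_error = None
try:
    set_processors(target, source)  # direct copy: 2 processors for 3 layers
except ValueError as err:
    direct_copy_error = str(err)

merged = merge_processors(source, target)
applied = set_processors(target, merged)  # passes the length check
```

In this sketch the direct copy fails the length check, while the merged dict covers every layer of the target and still carries the IP Adapter scale on the cross-attention layer.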

Contributor Author


@yiyixuxu I don't know if you saw this; I'm pinging you because the issue has gone stale.
Thank you!

Collaborator


hey, yes, indeed, I missed this one!
thanks for pinging!

@yiyixuxu yiyixuxu merged commit 44aa9e5 into huggingface:main May 13, 2024
@yiyixuxu
Collaborator

merged! sorry for the delay!
thanks again @fabiorigano for the great work:)

@fabiorigano fabiorigano deleted the unetmotionloadsipadapter branch May 13, 2024 18:21
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* Fix loading from_pipe

* Fix style

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Development

Successfully merging this pull request may close these issues.

make it possible to create an animate diff pipeline with unet loaded with ip-adapter
5 participants