Merged
36 commits
7d7fd83
Switch to peft and multi proj layers
fabiorigano Mar 2, 2024
b92c16f
Merge branch 'main' into faceidcore
fabiorigano Mar 3, 2024
4e56997
Move Face ID loading and inference to core
fabiorigano Mar 3, 2024
98a1aa4
Add support for Face ID XL
fabiorigano Mar 9, 2024
03d84fa
Add checks
fabiorigano Mar 10, 2024
fe35a42
Add test
fabiorigano Mar 10, 2024
bc016f4
Fix style
fabiorigano Mar 10, 2024
acf0a4e
Merge branch 'main' into faceidcore
fabiorigano Mar 10, 2024
92bfe50
Remove old pipeline
fabiorigano Mar 12, 2024
b8eed8d
Fix copies
fabiorigano Mar 12, 2024
ec539d0
Fix loading for full face
fabiorigano Mar 12, 2024
9521f89
Revert community pipeline delete
fabiorigano Mar 13, 2024
aad463c
Revert copies
fabiorigano Mar 13, 2024
e3c3518
Revert encode_image and loading warning
fabiorigano Mar 13, 2024
b8fe711
Add a separate loop to load lora weights
fabiorigano Mar 13, 2024
0d0baee
Merge branch 'main' into faceidcore
fabiorigano Mar 13, 2024
3c9382d
Fix style
fabiorigano Mar 13, 2024
c6f106e
Split Full and Face ID blocks
fabiorigano Mar 23, 2024
21ed0cc
Load Face ID Plus
fabiorigano Mar 27, 2024
50261bf
Add Face ID Plus proj layers
fabiorigano Mar 27, 2024
66f9117
Bugfixes + add shortcut
fabiorigano Apr 7, 2024
a47644d
Merge branch 'main' into faceidcore
fabiorigano Apr 11, 2024
b934489
Update docs
fabiorigano Apr 11, 2024
9b168ce
Merge branch 'main' into faceidcore
sayakpaul Apr 12, 2024
32d6943
Update docs/source/en/using-diffusers/loading_adapters.md
fabiorigano Apr 12, 2024
dc0a5fb
Fix test and docs
fabiorigano Apr 12, 2024
5a7ec1d
Merge branch 'faceidcore' of https://github.com/fabiorigano/diffusers…
fabiorigano Apr 12, 2024
4073be8
Fix style
fabiorigano Apr 12, 2024
2b9d5a5
Move lora loading to separate function
fabiorigano Apr 15, 2024
c14f077
Add IPAdapterPlusImageProjectionBlock
fabiorigano Apr 15, 2024
5cab6da
Fix style
fabiorigano Apr 15, 2024
8077c30
Merge branch 'main' into faceidcore
sayakpaul Apr 16, 2024
0e03446
Fix quality + add PEFT check
fabiorigano Apr 16, 2024
e6495a2
Fix names
fabiorigano Apr 16, 2024
d7772e9
Add fast test
fabiorigano Apr 16, 2024
f77739f
Fix style
fabiorigano Apr 16, 2024
56 changes: 52 additions & 4 deletions docs/source/en/using-diffusers/ip_adapter.md
@@ -362,14 +362,12 @@ IP-Adapter's image prompting and compatibility with other adapters and models ma

### Face model

Generating accurate faces is challenging because they are complex and nuanced. Diffusers supports two IP-Adapter checkpoints specifically trained to generate faces:
Generating accurate faces is challenging because they are complex and nuanced. Diffusers supports two IP-Adapter checkpoints specifically trained to generate faces from the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repository:

* [ip-adapter-full-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-full-face_sd15.safetensors) is conditioned with images of cropped faces and removed backgrounds
* [ip-adapter-plus-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-plus-face_sd15.safetensors) uses patch embeddings and is conditioned with images of cropped faces

> [!TIP]
>
> [IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) is a face-specific IP-Adapter trained with face ID embeddings instead of CLIP image embeddings, allowing you to generate more consistent faces in different contexts and styles. Try out this popular [community pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#ip-adapter-face-id) and see how it compares to the other face IP-Adapters.
Additionally, Diffusers supports all IP-Adapter checkpoints trained with face embeddings extracted by `insightface` face models. Supported models are from the [h94/IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) repository.

For face models, use the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) checkpoint. It is also recommended to use [`DDIMScheduler`] or [`EulerDiscreteScheduler`] for face models.

@@ -411,6 +409,56 @@ image
</div>
</div>

To use IP-Adapter FaceID models, first extract face embeddings with `insightface`. Then pass the list of tensors to the pipeline as `ip_adapter_image_embeds`.

```py
import cv2
import numpy as np
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.utils import load_image
from insightface.app import FaceAnalysis

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sd15.bin", image_encoder_folder=None)
pipeline.set_ip_adapter_scale(0.6)

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png")

# Detect the face and extract its identity embedding with insightface
ref_images_embeds = []
app = FaceAnalysis(name="buffalo_l", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))
# load_image returns an RGB PIL image; insightface's examples feed it BGR arrays (as from cv2.imread)
image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
faces = app.get(image)
image = torch.from_numpy(faces[0].normed_embedding)
ref_images_embeds.append(image.unsqueeze(0))
ref_images_embeds = torch.stack(ref_images_embeds, dim=0).unsqueeze(0)
# Zero embeddings serve as the negative half, concatenated ahead of the positive embeddings
neg_ref_images_embeds = torch.zeros_like(ref_images_embeds)
id_embeds = torch.cat([neg_ref_images_embeds, ref_images_embeds]).to(dtype=torch.float16, device="cuda")

generator = torch.Generator(device="cpu").manual_seed(42)

images = pipeline(
    prompt="A photo of a girl",
    ip_adapter_image_embeds=[id_embeds],
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=20,
    num_images_per_prompt=1,
    generator=generator,
).images
```
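
The example above indexes `faces[0]` directly, which raises an `IndexError` when no face is detected and picks an arbitrary face when several are present. A minimal sketch of a guard, assuming each detection exposes a `bbox` of `(x1, y1, x2, y2)` alongside `normed_embedding`: `DetectedFace` and `pick_largest_face` are hypothetical names used only for illustration and are not part of `insightface` or Diffusers.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class DetectedFace:
    """Hypothetical stand-in for one insightface detection (assumed fields)."""
    bbox: Sequence[float]              # (x1, y1, x2, y2) in pixels
    normed_embedding: Sequence[float]  # identity embedding


def pick_largest_face(faces):
    """Return the detection with the largest bounding-box area.

    Raises ValueError instead of IndexError when nothing was detected.
    """
    if len(faces) == 0:
        raise ValueError("No face detected in the reference image.")

    def area(face):
        x1, y1, x2, y2 = face.bbox[:4]
        return max(0.0, x2 - x1) * max(0.0, y2 - y1)

    return max(faces, key=area)


# Stand-in detections; in the example above this list would come from app.get(image)
faces = [
    DetectedFace(bbox=(0, 0, 10, 10), normed_embedding=[0.1]),
    DetectedFace(bbox=(5, 5, 105, 85), normed_embedding=[0.2]),
]
best = pick_largest_face(faces)
```

With a guard like this, `faces[0].normed_embedding` in the example would become `pick_largest_face(faces).normed_embedding`.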

Both IP-Adapter FaceID Plus and Plus v2 models also require CLIP image embeddings. Prepare the face embeddings as shown previously, then extract the CLIP embeddings and set them on the image projection layer of the UNet.

```py
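# `ip_adapter_images` (the reference face image(s)) and `num_images` (images per prompt)
# are assumed to be defined by the caller; the final `True` enables classifier-free guidance.
clip_embeds = pipeline.prepare_ip_adapter_image_embeds([ip_adapter_images], None, torch.device("cuda"), num_images, True)[0]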

pipeline.unet.encoder_hid_proj.image_projection_layers[0].clip_embeds = clip_embeds.to(dtype=torch.float16)
pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut = False # True if Plus v2
```
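
The two attribute assignments above can be wrapped in one place so the Plus and Plus v2 cases cannot drift apart. This is a minimal sketch: `configure_faceid_plus` is a hypothetical helper, not a Diffusers API, and it uses a `SimpleNamespace` stand-in for the projection layer so it runs without a loaded pipeline. Only the attribute names `clip_embeds` and `shortcut` come from the snippet above.

```python
from types import SimpleNamespace


def configure_faceid_plus(projection_layer, clip_embeds, is_plus_v2):
    """Attach CLIP embeddings and set the shortcut flag on a projection layer.

    Hypothetical helper: `shortcut` is False for FaceID Plus and True for Plus v2,
    mirroring the comment in the snippet above.
    """
    projection_layer.clip_embeds = clip_embeds
    projection_layer.shortcut = bool(is_plus_v2)
    return projection_layer


# Stand-in for pipeline.unet.encoder_hid_proj.image_projection_layers[0]
layer = SimpleNamespace(clip_embeds=None, shortcut=False)
configure_faceid_plus(layer, clip_embeds="<CLIP embedding tensor>", is_plus_v2=True)
```

In real use, the first argument would be `pipeline.unet.encoder_hid_proj.image_projection_layers[0]` and `clip_embeds` the float16 tensor computed above.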


### Multi IP-Adapter

More than one IP-Adapter can be used at the same time to generate specific images in more diverse styles. For example, you can use IP-Adapter-Face to generate consistent faces and characters, and IP-Adapter Plus to generate those faces in a specific style.
37 changes: 37 additions & 0 deletions docs/source/en/using-diffusers/loading_adapters.md
@@ -320,3 +320,40 @@ pipeline = AutoPipelineForText2Image.from_pretrained(

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus_sdxl_vit-h.safetensors")
```

### IP-Adapter Face ID models

The IP-Adapter FaceID models are experimental IP Adapters that use image embeddings generated by `insightface` instead of CLIP image embeddings. Some of these models also use LoRA to improve ID consistency.
You need to install `insightface` and all its requirements to use these models.

<Tip warning={true}>
As InsightFace pretrained models are available for non-commercial research purposes, IP-Adapter-FaceID models are released exclusively for research purposes and are not intended for commercial use.
</Tip>

```py
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
).to("cuda")

pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sdxl.bin", image_encoder_folder=None)
```

If you want to use one of the two IP-Adapter FaceID Plus models, you must also load the CLIP image encoder, as these models use both `insightface` and CLIP image embeddings to achieve better photorealism.

```py
import torch
from diffusers import AutoPipelineForText2Image
from transformers import CLIPVisionModelWithProjection

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    torch_dtype=torch.float16,
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    image_encoder=image_encoder,
    torch_dtype=torch.float16
).to("cuda")

pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid-plus_sd15.bin")
```
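
The two loading patterns above differ only in the weight file and in whether `image_encoder_folder=None` is passed. A minimal sketch of that choice as a lookup follows. `FaceIDLoadSpec`, `FACEID_SPECS`, and `load_ip_adapter_kwargs` are hypothetical names, and the table lists only the weight files quoted in this document.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FaceIDLoadSpec:
    weight_name: str
    needs_clip_encoder: bool  # True for the Plus variants described above


# Hypothetical lookup table; limited to the checkpoints named in this document
FACEID_SPECS = {
    "faceid_sd15": FaceIDLoadSpec("ip-adapter-faceid_sd15.bin", needs_clip_encoder=False),
    "faceid_sdxl": FaceIDLoadSpec("ip-adapter-faceid_sdxl.bin", needs_clip_encoder=False),
    "faceid_plus_sd15": FaceIDLoadSpec("ip-adapter-faceid-plus_sd15.bin", needs_clip_encoder=True),
}


def load_ip_adapter_kwargs(variant):
    """Build the keyword arguments used in the load_ip_adapter calls above."""
    spec = FACEID_SPECS[variant]
    kwargs = {"subfolder": None, "weight_name": spec.weight_name}
    if not spec.needs_clip_encoder:
        # Plain FaceID checkpoints skip the CLIP image encoder entirely
        kwargs["image_encoder_folder"] = None
    return kwargs


# Intended use (requires a loaded pipeline, so it is left as a comment):
# pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", **load_ip_adapter_kwargs("faceid_sdxl"))
sdxl_kwargs = load_ip_adapter_kwargs("faceid_sdxl")
```

Keeping the mapping in one table makes it harder to pass `image_encoder_folder=None` by accident for a Plus checkpoint, which needs the CLIP image encoder.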
2 changes: 0 additions & 2 deletions examples/community/README.md
@@ -3819,12 +3819,10 @@ export_to_gif(frames, "animation.gif")
IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded.
You need to install `insightface` and all its requirements to use this model.
You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`.
You have to disable PEFT BACKEND in order to load weights.
You can find more results [here](https://github.com/huggingface/diffusers/pull/6276).

```py
import diffusers
diffusers.utils.USE_PEFT_BACKEND = False
import torch
from diffusers.utils import load_image
import cv2