From b226d67d1d11419c9e63c564e8b006b5936db740 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 2 Aug 2024 12:53:12 +0200 Subject: [PATCH 1/3] update --- src/diffusers/loaders/single_file_model.py | 5 + src/diffusers/loaders/single_file_utils.py | 193 ++++++++++++++++++ .../models/transformers/transformer_flux.py | 4 +- 3 files changed, 200 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 92438620abd8..23d0b0ab2e7d 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -24,6 +24,7 @@ SingleFileComponentError, convert_animatediff_checkpoint_to_diffusers, convert_controlnet_checkpoint, + convert_flux_transformer_checkpoint_to_diffusers, convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint, convert_sd3_transformer_checkpoint_to_diffusers, @@ -74,6 +75,10 @@ "MotionAdapter": { "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers, }, + "FluxTransformer2DModel": { + "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, } diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 483125f24825..3b2e62da5eb9 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -77,6 +77,7 @@ "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe", "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias", "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight", + "flux": "double_blocks.0.img_attn.norm.key_norm.scale", } DIFFUSERS_DEFAULT_PIPELINE_PATHS = { @@ -110,6 +111,8 @@ "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"}, "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"}, "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"}, + "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"}, + "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, } # Use to configure model sample size when original config is provided @@ -503,6 +506,11 @@ def infer_diffusers_model_type(checkpoint): else: model_type = "animatediff_v3" + elif CHECKPOINT_KEY_NAMES["flux"] in checkpoint: + if "guidance_in.in_layer.bias" in checkpoint: + model_type = "flux-dev" + else: + model_type = "flux-schnell" else: model_type = "v1" @@ -1859,3 +1867,188 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs): ] = v return converted_state_dict + + +def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + + num_layers = sum([1 if key.startswith("double_blocks.") else 0 for key in checkpoint.keys()]) + num_single_layers = sum([1 if key.startswith("single_blocks.") else 0 for key in checkpoint.keys()]) + mlp_ratio = 4.0 + inner_dim = 3072 + + ## time_text_embed.timestep_embedder <- time_in + converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop( + "time_in.in_layer.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias") + converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop( + "time_in.out_layer.weight" + ) + converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias") + + ## time_text_embed.text_embedder <- vector_in + converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight") + converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias") + converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop( + "vector_in.out_layer.weight" + ) + converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias") + + # guidance + has_guidance = any("guidance" in k for k in checkpoint) + if has_guidance: + converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop( + "guidance_in.in_layer.weight" + ) + converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop( + "guidance_in.in_layer.bias" + ) + converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop( + "guidance_in.out_layer.weight" + ) + converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop( + "guidance_in.out_layer.bias" + ) + + # context_embedder + converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight") + converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias") + + # x_embedder + converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight") + converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias") + + # double transformer blocks + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + # norms. + ## norm1 + converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_mod.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop( + f"double_blocks.{i}.img_mod.lin.bias" + ) + ## norm1_context + converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_mod.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_mod.lin.bias" + ) + # Q, K, V + sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0) + context_q, context_k, context_v = torch.chunk( + checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0 + ) + sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( + checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0 + ) + context_q_bias, context_k_bias, context_v_bias = torch.chunk( + checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q]) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k]) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v]) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias]) + converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q]) + converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias]) + # qk_norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.norm.key_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.norm.key_norm.scale" + ) + # ff img_mlp + converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_mlp.0.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias") + converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight") + converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias") + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.0.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.0.bias" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.2.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.2.bias" + ) + # output projections. + converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.proj.bias" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.proj.bias" + ) + + # single transfomer blocks + for i in range(num_single_layers): + block_prefix = f"single_transformer_blocks.{i}." + # norm.linear <- single_blocks.0.modulation.lin + converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop( + f"single_blocks.{i}.modulation.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop( + f"single_blocks.{i}.modulation.lin.bias" + ) + # Q, K, V, mlp + mlp_hidden_dim = int(inner_dim * mlp_ratio) + split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) + q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0) + q_bias, k_bias, v_bias, mlp_bias = torch.split( + checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q]) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k]) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v]) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp]) + converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias]) + # qk norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( + f"single_blocks.{i}.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( + f"single_blocks.{i}.norm.key_norm.scale" + ) + # output projections. + converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight") + converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias") + + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( + checkpoint.pop("final_layer.adaLN_modulation.1.weight") + ) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( + checkpoint.pop("final_layer.adaLN_modulation.1.bias") + ) + + return converted_state_dict diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 364867275dc2..7d3c4e1552ea 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -20,7 +20,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import PeftAdapterMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...models.attention import FeedForward from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0 from ...models.modeling_utils import ModelMixin @@ -227,7 +227,7 @@ def forward( return encoder_hidden_states, hidden_states -class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): """ The Transformer model introduced in Flux. From 15924bc73bfd74c769f23c8d2636d6c7514163a0 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 5 Aug 2024 08:10:37 +0200 Subject: [PATCH 2/3] update --- src/diffusers/loaders/single_file_utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 3b2e62da5eb9..0dce9d5c7aff 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -1872,11 +1872,18 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs): def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): converted_state_dict = {} - num_layers = sum([1 if key.startswith("double_blocks.") else 0 for key in checkpoint.keys()]) - num_single_layers = sum([1 if key.startswith("single_blocks.") else 0 for key in checkpoint.keys()]) + num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401 + num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401 mlp_ratio = 4.0 inner_dim = 3072 + # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; + # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + ## time_text_embed.timestep_embedder <- time_in converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop( "time_in.in_layer.weight" From 27910ee691e92890df0e534532e5a481d24d93bd Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 6 Aug 2024 10:25:23 +0000 Subject: [PATCH 3/3] update --- docs/source/en/api/pipelines/flux.md | 53 ++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md index 095bf76af37f..f7b3ee175057 100644 --- a/docs/source/en/api/pipelines/flux.md +++ b/docs/source/en/api/pipelines/flux.md @@ -77,6 +77,59 @@ out = pipe( out.save("image.png") ``` +## Single File Loading for the `FluxTransformer2DModel` + +The `FluxTransformer2DModel` supports loading checkpoints in the original format shipped by Black Forest Labs. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community. + + +`FP8` inference can be brittle depending on the GPU type, CUDA version, and `torch` version that you are using. It is recommended that you use the `optimum-quanto` library in order to run FP8 inference on your machine. + + +The following example demonstrates how to run Flux with less than 16GB of VRAM. + +First install `optimum-quanto` + +```shell +pip install optimum-quanto +``` + +Then run the following example + +```python +import torch +from diffusers import FluxTransformer2DModel, FluxPipeline +from transformers import T5EncoderModel, CLIPTextModel +from optimum.quanto import freeze, qfloat8, quantize + +bfl_repo = "black-forest-labs/FLUX.1-dev" +dtype = torch.bfloat16 + +transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype) +quantize(transformer, weights=qfloat8) +freeze(transformer) + +text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype) +quantize(text_encoder_2, weights=qfloat8) +freeze(text_encoder_2) + +pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype) +pipe.transformer = transformer +pipe.text_encoder_2 = text_encoder_2 + +pipe.enable_model_cpu_offload() + +prompt = "A cat holding a sign that says hello world" +image = pipe( + prompt, + guidance_scale=3.5, + output_type="pil", + num_inference_steps=20, + generator=torch.Generator("cpu").manual_seed(0) +).images[0] + +image.save("flux-fp8-dev.png") +``` + ## FluxPipeline [[autodoc]] FluxPipeline