27 changes: 27 additions & 0 deletions docs/source/en/api/pipelines/aura_flow.md
@@ -62,6 +62,33 @@ image = pipeline(prompt).images[0]
image.save("auraflow.png")
```

Loading [GGUF checkpoints](https://huggingface.co/docs/diffusers/quantization/gguf) is also supported:

```py
import torch
from diffusers import (
AuraFlowPipeline,
GGUFQuantizationConfig,
AuraFlowTransformer2DModel,
)

transformer = AuraFlowTransformer2DModel.from_single_file(
"https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
)

pipeline = AuraFlowPipeline.from_pretrained(
"fal/AuraFlow-v0.3",
transformer=transformer,
torch_dtype=torch.bfloat16,
)

prompt = "a cute pony in a field of flowers"
image = pipeline(prompt).images[0]
image.save("auraflow.png")
```
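
If VRAM is still tight, the pipeline built in the example above can also offload sub-models to the CPU between forward passes, the same call the GGUF test added in this PR uses. A minimal follow-on sketch:

```py
# Optional: keep sub-models on the CPU and move them to the GPU only when needed,
# trading some speed for a lower peak VRAM footprint.
pipeline.enable_model_cpu_offload()
```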

## AuraFlowPipeline

[[autodoc]] AuraFlowPipeline
5 changes: 5 additions & 0 deletions src/diffusers/loaders/single_file_model.py
@@ -25,6 +25,7 @@
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
convert_auraflow_transformer_checkpoint_to_diffusers,
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
@@ -106,6 +107,10 @@
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AuraFlowTransformer2DModel": {
"checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
}


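With this registry entry in place, `AuraFlowTransformer2DModel.from_single_file(...)` resolves the AuraFlow checkpoint converter automatically. A minimal sketch of the user-facing call this enables (the original-format checkpoint URL below is an assumption, not taken from this PR):

```py
import torch

from diffusers import AuraFlowTransformer2DModel

# Placeholder URL for an original-format AuraFlow checkpoint; from_single_file()
# converts its keys via the mapping function registered above.
transformer = AuraFlowTransformer2DModel.from_single_file(
    "https://huggingface.co/fal/AuraFlow-v0.3/blob/main/aura_flow_0.3.safetensors",
    torch_dtype=torch.bfloat16,
)
```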
103 changes: 103 additions & 0 deletions src/diffusers/loaders/single_file_utils.py
@@ -94,6 +94,12 @@
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
"animatediff_rgb": "controlnet_cond_embedding.weight",
"auraflow": [
"double_layers.0.attn.w2q.weight",
"double_layers.0.attn.w1q.weight",
"cond_seq_linear.weight",
"t_embedder.mlp.0.weight",
],
"flux": [
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -154,6 +160,7 @@
"animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
"auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"},
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
@@ -635,6 +642,9 @@ def infer_diffusers_model_type(checkpoint):
elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
model_type = "hunyuan-video"

elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]):
model_type = "auraflow"

elif (
CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
@@ -2090,6 +2100,7 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())

for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
@@ -2689,3 +2700,95 @@ def update_state_dict_(state_dict, old_key, new_key):
handler_fn_inplace(key, checkpoint)

return checkpoint


def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
state_dict_keys = list(checkpoint.keys())

# Handle register tokens and positional embeddings
converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None)

# Handle time step projection
converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None)
converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None)
converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None)
converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None)

# Handle context embedder
converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None)

# Calculate the number of layers
def calculate_layers(keys, key_prefix):
layers = set()
for k in keys:
if key_prefix in k:
layer_num = int(k.split(".")[1]) # get the layer number
layers.add(layer_num)
return len(layers)

mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")

# MMDiT blocks
for i in range(mmdit_layers):
# Feed-forward
path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
for orig_k, diffuser_k in path_mapping.items():
for k, v in weight_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop(
f"double_layers.{i}.{orig_k}.{k}.weight", None
)

# Norms
path_mapping = {"modX": "norm1", "modC": "norm1_context"}
for orig_k, diffuser_k in path_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop(
f"double_layers.{i}.{orig_k}.1.weight", None
)

# Attentions
x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
for attn_mapping in [x_attn_mapping, context_attn_mapping]:
for k, v in attn_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
f"double_layers.{i}.attn.{k}.weight", None
)

# Single-DiT blocks
for i in range(single_dit_layers):
# Feed-forward
mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
for k, v in mapping.items():
converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop(
f"single_layers.{i}.mlp.{k}.weight", None
)

# Norms
converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
f"single_layers.{i}.modCX.1.weight", None
)

# Attentions
x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
for k, v in x_attn_mapping.items():
converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
f"single_layers.{i}.attn.{k}.weight", None
)
# Final blocks
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None)

# Handle the final norm layer
norm_weight = checkpoint.pop("modF.1.weight", None)
if norm_weight is not None:
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None)
else:
converted_state_dict["norm_out.linear.weight"] = None

converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding")
converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight")
converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")

return converted_state_dict
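As a sanity check on the key layout this converter assumes, here is a tiny self-contained sketch of the layer-counting step using made-up AuraFlow-style keys (it mirrors `calculate_layers` above, rewritten with a set comprehension):

```py
# Illustrative only: fake original-format keys following "<prefix>.<index>.<...>".
fake_keys = [
    "double_layers.0.attn.w1q.weight",
    "double_layers.1.attn.w1q.weight",
    "single_layers.0.mlp.c_fc1.weight",
]

def count_layers(keys, key_prefix):
    # Count the distinct layer indices that appear after the prefix.
    return len({int(k.split(".")[1]) for k in keys if key_prefix in k})

assert count_layers(fake_keys, "double_layers") == 2
assert count_layers(fake_keys, "single_layers") == 1
```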
3 changes: 2 additions & 1 deletion src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -20,6 +20,7 @@
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import (
@@ -253,7 +254,7 @@ def forward(
return encoder_hidden_states, hidden_states


class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).

2 changes: 1 addition & 1 deletion src/diffusers/quantizers/gguf/utils.py
@@ -450,7 +450,7 @@ def __init__(
def forward(self, inputs):
weight = dequantize_gguf_tensor(self.weight)
weight = weight.to(self.compute_dtype)
bias = self.bias.to(self.compute_dtype)
bias = self.bias.to(self.compute_dtype) if self.bias is not None else None

output = torch.nn.functional.linear(inputs, weight, bias)
return output
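The `if self.bias is not None` guard above matters because some GGUF-quantized linears are built without a bias. A minimal plain-PyTorch sketch of the case it protects against (illustrative, not the GGUF layer itself):

```py
import torch
import torch.nn.functional as F

# A Linear created without a bias: `layer.bias` is None.
layer = torch.nn.Linear(8, 8, bias=False)
assert layer.bias is None

# The fixed forward() only converts the bias when it exists; calling
# `.to(...)` on None unconditionally would raise an AttributeError.
bias = layer.bias.to(torch.float32) if layer.bias is not None else None
out = F.linear(torch.randn(2, 8), layer.weight, bias)
print(out.shape)  # torch.Size([2, 8])
```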
81 changes: 80 additions & 1 deletion tests/quantization/gguf/test_gguf.py
@@ -6,6 +6,8 @@
import torch.nn as nn

from diffusers import (
AuraFlowPipeline,
AuraFlowTransformer2DModel,
FluxPipeline,
FluxTransformer2DModel,
GGUFQuantizationConfig,
@@ -54,7 +56,8 @@ def test_gguf_linear_layers(self):
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
assert module.weight.dtype == torch.uint8
assert module.bias.dtype == torch.float32
if module.bias is not None:
assert module.bias.dtype == torch.float32

def test_gguf_memory_usage(self):
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
@@ -377,3 +380,79 @@ def test_pipeline_inference(self):
)
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4


class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf"
torch_dtype = torch.bfloat16
model_cls = AuraFlowTransformer2DModel
expected_memory_use_in_gb = 4

def setUp(self):
gc.collect()
torch.cuda.empty_cache()

def tearDown(self):
gc.collect()
torch.cuda.empty_cache()

def get_dummy_inputs(self):
return {
"hidden_states": torch.randn((1, 4, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"encoder_hidden_states": torch.randn(
(1, 512, 2048),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
}

def test_pipeline_inference(self):
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
transformer = self.model_cls.from_single_file(
self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
)
pipe = AuraFlowPipeline.from_pretrained(
"fal/AuraFlow-v0.3", transformer=transformer, torch_dtype=self.torch_dtype
)
pipe.enable_model_cpu_offload()

prompt = "a pony holding a sign that says hello"
output = pipe(
prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
).images[0]
output_slice = output[:3, :3, :].flatten()
expected_slice = np.array(
[
0.46484375,
0.546875,
0.64453125,
0.48242188,
0.53515625,
0.59765625,
0.47070312,
0.5078125,
0.5703125,
0.42773438,
0.50390625,
0.5703125,
0.47070312,
0.515625,
0.57421875,
0.45898438,
0.48632812,
0.53515625,
0.4453125,
0.5078125,
0.56640625,
0.47851562,
0.5234375,
0.57421875,
0.48632812,
0.5234375,
0.56640625,
]
)
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4