Merged
16 changes: 16 additions & 0 deletions docs/source/en/api/models/hidream_image_transformer.md
@@ -21,6 +21,22 @@ from diffusers import HiDreamImageTransformer2DModel
transformer = HiDreamImageTransformer2DModel.from_pretrained("HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16)
```

## Loading GGUF quantized checkpoints for HiDream-I1

GGUF checkpoints for the [`HiDreamImageTransformer2DModel`] can be loaded with [`~FromOriginalModelMixin.from_single_file`]:

```python
import torch
from diffusers import GGUFQuantizationConfig, HiDreamImageTransformer2DModel

ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
transformer = HiDreamImageTransformer2DModel.from_single_file(
ckpt_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16
)
```
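
Once loaded, the quantized transformer can be passed to the pipeline. A minimal sketch, following the pattern from the `HiDreamImagePipeline` docstring (the Llama text encoder repo and the pipeline kwargs below are taken from that example):

```python
import torch
from transformers import AutoTokenizer, LlamaForCausalLM
from diffusers import HiDreamImagePipeline

# Load the Llama-3.1 text encoder separately, as in the pipeline docstring.
tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)

# `transformer` is the GGUF-quantized model loaded in the snippet above.
pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Dev",
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
```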

## HiDreamImageTransformer2DModel

[[autodoc]] HiDreamImageTransformer2DModel
5 changes: 5 additions & 0 deletions src/diffusers/loaders/single_file_model.py
@@ -31,6 +31,7 @@
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hidream_transformer_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
@@ -133,6 +134,10 @@
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
"default_subfolder": "vae",
},
"HiDreamImageTransformer2DModel": {
"checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
"default_subfolder": "transformer",
},
}
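
This entry wires HiDream into the single-file loader's registry. A simplified sketch of how the registry is consulted, assuming the real loader adds validation and config resolution on top:

```python
# Simplified illustration (not the actual diffusers loader code): the model
# class name selects a checkpoint-conversion function and the subfolder the
# model config is pulled from.
def lookup_single_file_entry(model_cls_name: str):
    entry = SINGLE_FILE_LOADABLE_CLASSES[model_cls_name]
    return entry["checkpoint_mapping_fn"], entry.get("default_subfolder")

mapping_fn, subfolder = lookup_single_file_entry("HiDreamImageTransformer2DModel")
# mapping_fn -> convert_hidream_transformer_to_diffusers, subfolder -> "transformer"
```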


13 changes: 13 additions & 0 deletions src/diffusers/loaders/single_file_utils.py
@@ -126,6 +126,7 @@
],
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
"wan_vae": "decoder.middle.0.residual.0.gamma",
"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
}

DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -190,6 +191,7 @@
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
}

# Use to configure model sample size when original config is provided
@@ -701,6 +703,8 @@ def infer_diffusers_model_type(checkpoint):
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
model_type = "wan-t2v-14B"
elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
model_type = "hidream"
else:
model_type = "v1"
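
Detection works by probing for a key that only HiDream checkpoints contain. A minimal sketch of the check (the dict below stands in for a real state dict):

```python
import torch

# Illustrative sketch only: model-type inference is a key-presence test on
# the checkpoint's state dict, using the key registered above.
CHECKPOINT_KEY_NAMES = {"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias"}
state_dict = {"double_stream_blocks.0.block.adaLN_modulation.1.bias": torch.zeros(1)}

model_type = "hidream" if CHECKPOINT_KEY_NAMES["hidream"] in state_dict else "v1"
assert model_type == "hidream"
```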

@@ -3293,3 +3297,12 @@ def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
converted_state_dict[key] = value

return converted_state_dict


def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

return checkpoint
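
The conversion itself is just a prefix strip: original HiDream checkpoints store every tensor under a `model.diffusion_model.` prefix that the diffusers module tree does not use. A toy example (hypothetical key and value, assuming the function above is in scope):

```python
# Toy demonstration of the prefix strip performed by the function above.
ckpt = {"model.diffusion_model.double_stream_blocks.0.block.adaLN_modulation.1.bias": 0}
converted = convert_hidream_transformer_to_diffusers(ckpt)
assert list(converted) == ["double_stream_blocks.0.block.adaLN_modulation.1.bias"]
```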
@@ -5,7 +5,7 @@
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import PeftAdapterMixin
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
@@ -602,7 +602,7 @@ def forward(
)


-class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]

@@ -36,11 +36,11 @@
Examples:
```py
>>> import torch
->>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
->>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline
+>>> from transformers import AutoTokenizer, LlamaForCausalLM
+>>> from diffusers import HiDreamImagePipeline


->>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
+>>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
... "meta-llama/Meta-Llama-3.1-8B-Instruct",
... output_hidden_states=True,
28 changes: 28 additions & 0 deletions tests/quantization/gguf/test_gguf.py
@@ -12,6 +12,7 @@
FluxPipeline,
FluxTransformer2DModel,
GGUFQuantizationConfig,
HiDreamImageTransformer2DModel,
SD3Transformer2DModel,
StableDiffusion3Pipeline,
)
@@ -549,3 +550,30 @@ def test_lora_loading(self):

max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)


class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
torch_dtype = torch.bfloat16
model_cls = HiDreamImageTransformer2DModel
expected_memory_use_in_gb = 8

def get_dummy_inputs(self):
return {
"hidden_states": torch.randn((1, 16, 128, 128), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"encoder_hidden_states_t5": torch.randn(
(1, 128, 4096),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"encoder_hidden_states_llama3": torch.randn(
(32, 1, 128, 4096),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"pooled_embeds": torch.randn(
(1, 2048),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
}
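
To exercise only the new test class locally, something like the following should work, assuming the repository's test environment and enough memory for the roughly 8 GB expected footprint:

```python
# Hypothetical invocation via pytest's Python API; a plain CLI run works too.
import pytest

pytest.main(["tests/quantization/gguf/test_gguf.py", "-k", "HiDreamGGUFSingleFileTests"])
```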