diff --git a/README.md b/README.md index c1636c967..026695aab 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,7 @@ API and command-line option may change frequently.*** - [Chroma](./docs/chroma.md) - [Chroma1-Radiance](./docs/chroma_radiance.md) - [Qwen Image](./docs/qwen_image.md) + - [Z-Image](./docs/z_image.md) - Image Edit Models - [FLUX.1-Kontext-dev](./docs/kontext.md) - [Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md) @@ -129,6 +130,7 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe - [🔥Qwen Image](./docs/qwen_image.md) - [🔥Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md) - [🔥Wan2.1/Wan2.2](./docs/wan.md) +- [🔥Z-Image](./docs/z_image.md) - [LoRA](./docs/lora.md) - [LCM/LCM-LoRA](./docs/lcm.md) - [Using PhotoMaker to personalize image generation](./docs/photo_maker.md) diff --git a/assets/z_image/bf16.png b/assets/z_image/bf16.png new file mode 100644 index 000000000..5bb7a955f Binary files /dev/null and b/assets/z_image/bf16.png differ diff --git a/assets/z_image/q2_K.png b/assets/z_image/q2_K.png new file mode 100644 index 000000000..20aff17ed Binary files /dev/null and b/assets/z_image/q2_K.png differ diff --git a/assets/z_image/q3_K.png b/assets/z_image/q3_K.png new file mode 100644 index 000000000..727b8e3e5 Binary files /dev/null and b/assets/z_image/q3_K.png differ diff --git a/assets/z_image/q4_0.png b/assets/z_image/q4_0.png new file mode 100644 index 000000000..5136b2ac2 Binary files /dev/null and b/assets/z_image/q4_0.png differ diff --git a/assets/z_image/q4_K.png b/assets/z_image/q4_K.png new file mode 100644 index 000000000..511104240 Binary files /dev/null and b/assets/z_image/q4_K.png differ diff --git a/assets/z_image/q5_0.png b/assets/z_image/q5_0.png new file mode 100644 index 000000000..a89081ec7 Binary files /dev/null and b/assets/z_image/q5_0.png differ diff --git a/assets/z_image/q6_K.png b/assets/z_image/q6_K.png new file mode 100644 index 000000000..d9f6ac9bf Binary files /dev/null and b/assets/z_image/q6_K.png differ diff --git a/assets/z_image/q8_0.png b/assets/z_image/q8_0.png new file mode 100644 index 000000000..38687a3ea Binary files /dev/null and b/assets/z_image/q8_0.png differ diff --git a/conditioner.hpp b/conditioner.hpp index bce625a2c..e28e6e158 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1638,6 +1638,8 @@ struct LLMEmbedder : public Conditioner { LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL; if (sd_version_is_flux2(version)) { arch = LLM::LLMArch::MISTRAL_SMALL_3_2; + } else if (sd_version_is_z_image(version)) { + arch = LLM::LLMArch::QWEN3; } if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) { tokenizer = std::make_shared(); @@ -1785,9 +1787,31 @@ struct LLMEmbedder : public Conditioner { prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n"; prompt += img_prompt; - prompt_attn_range.first = prompt.size(); + prompt_attn_range.first = static_cast(prompt.size()); prompt += conditioner_params.text; - prompt_attn_range.second = prompt.size(); + prompt_attn_range.second = static_cast(prompt.size()); + + prompt += "<|im_end|>\n<|im_start|>assistant\n"; + } else if (sd_version_is_flux2(version)) { + prompt_template_encode_start_idx = 0; + out_layers = {10, 20, 30}; + + prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]"; + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + prompt += "[/INST]"; + } else if (sd_version_is_z_image(version)) { + prompt_template_encode_start_idx = 0; + out_layers = {35}; // -2 + + prompt = "<|im_start|>user\n"; + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); prompt += "<|im_end|>\n<|im_start|>assistant\n"; } else if (sd_version_is_flux2(version)) { @@ -1806,9 +1830,9 @@ struct LLMEmbedder : public Conditioner { prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n"; - prompt_attn_range.first = prompt.size(); + prompt_attn_range.first = static_cast(prompt.size()); prompt += conditioner_params.text; - prompt_attn_range.second = prompt.size(); + prompt_attn_range.second = static_cast(prompt.size()); prompt += "<|im_end|>\n<|im_start|>assistant\n"; } diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 0a3914edc..5a311f57e 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -6,6 +6,7 @@ #include "qwen_image.hpp" #include "unet.hpp" #include "wan.hpp" +#include "z_image.hpp" struct DiffusionParams { struct ggml_tensor* x = nullptr; @@ -357,4 +358,67 @@ struct QwenImageModel : public DiffusionModel { } }; +struct ZImageModel : public DiffusionModel { + std::string prefix; + ZImage::ZImageRunner z_image; + + ZImageModel(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "model.diffusion_model", + SDVersion version = VERSION_Z_IMAGE) + : prefix(prefix), z_image(backend, offload_params_to_cpu, tensor_storage_map, prefix, version) { + } + + std::string get_desc() override { + return z_image.get_desc(); + } + + void alloc_params_buffer() override { + z_image.alloc_params_buffer(); + } + + void free_params_buffer() override { + z_image.free_params_buffer(); + } + + void free_compute_buffer() override { + z_image.free_compute_buffer(); + } + + void get_param_tensors(std::map& tensors) override { + z_image.get_param_tensors(tensors, prefix); + } + + size_t get_params_buffer_size() override { + return z_image.get_params_buffer_size(); + } + + void set_weight_adapter(const std::shared_ptr& adapter) override { + z_image.set_weight_adapter(adapter); + } + + int64_t get_adm_in_channels() override { + return 768; + } + + void set_flash_attn_enabled(bool enabled) { + z_image.set_flash_attention_enabled(enabled); + } + + void compute(int n_threads, + 
DiffusionParams diffusion_params, + struct ggml_tensor** output = nullptr, + struct ggml_context* output_ctx = nullptr) override { + return z_image.compute(n_threads, + diffusion_params.x, + diffusion_params.timesteps, + diffusion_params.context, + diffusion_params.ref_latents, + true, // increase_ref_index + output, + output_ctx); + } +}; + #endif diff --git a/docs/z_image.md b/docs/z_image.md new file mode 100644 index 000000000..73eacffa8 --- /dev/null +++ b/docs/z_image.md @@ -0,0 +1,28 @@ +# How to Use + +You can run Z-Image with stable-diffusion.cpp on GPUs with 4GB of VRAM — or even less. + +## Download weights + +- Download Z-Image-Turbo + - safetensors: https://huggingface.co/Comfy-Org/z_image_turbo/tree/main/split_files/diffusion_models + - gguf: https://huggingface.co/leejet/Z-Image-Turbo-GGUF/tree/main +- Download vae + - safetensors: https://huggingface.co/black-forest-labs/FLUX.1-schnell/tree/main +- Download Qwen3 4b + - safetensors: https://huggingface.co/Comfy-Org/z_image_turbo/tree/main/split_files/text_encoders + - gguf: https://huggingface.co/unsloth/Qwen3-4B-Instruct-2507-GGUF/tree/main + +## Examples + +``` +.\bin\Release\sd.exe --diffusion-model z_image_turbo-Q3_K.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\Qwen3-4B-Instruct-2507-Q4_K_M.gguf -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' 
-- moody, atmospheric, profound, dark academic" --cfg-scale 1.0 -v --offload-to-cpu --diffusion-fa -H 1024 -W 512
+```
+
+z-image example
+
+## Comparison of Different Quantization Types
+
+| bf16 | q8_0 | q6_K | q5_0 | q4_K | q4_0 | q3_K | q2_K |
+|---|---|---|---|---|---|---|---|
+| ![bf16](../assets/z_image/bf16.png) | ![q8_0](../assets/z_image/q8_0.png) | ![q6_K](../assets/z_image/q6_K.png) | ![q5_0](../assets/z_image/q5_0.png) | ![q4_K](../assets/z_image/q4_K.png) | ![q4_0](../assets/z_image/q4_0.png) | ![q3_K](../assets/z_image/q3_K.png) | ![q2_K](../assets/z_image/q2_K.png) |
\ No newline at end of file
diff --git a/llm.hpp b/llm.hpp index c96ba0f41..d1dd3a663 100644 --- a/llm.hpp +++ b/llm.hpp @@ -1,5 +1,5 @@ -#ifndef __QWENVL_HPP__ -#define __QWENVL_HPP__ +#ifndef __LLM_HPP__ +#define __LLM_HPP__ #include #include @@ -256,7 +256,7 @@ namespace LLM { ss << "\"" << token << "\", "; } ss << "]"; - // LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str()); + LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str()); // printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str()); return bpe_tokens; } @@ -469,12 +469,14 @@ namespace LLM { enum class LLMArch { QWEN2_5_VL, + QWEN3, MISTRAL_SMALL_3_2, ARCH_COUNT, }; static const char* llm_arch_to_str[] = { "qwen2.5vl", + "qwen3", "mistral_small3.2", }; @@ -501,6 +503,7 @@ namespace LLM { int64_t num_kv_heads = 4; int64_t head_dim = 128; bool qkv_bias = true; + bool qk_norm = false; int64_t vocab_size = 152064; float rms_norm_eps = 1e-06f; LLMVisionParams vision; @@ -813,14 +816,19 @@ namespace LLM { int64_t head_dim; int64_t num_heads; int64_t num_kv_heads; + bool qk_norm; public: Attention(const LLMParams& params) - : num_heads(params.num_heads), num_kv_heads(params.num_kv_heads), head_dim(params.head_dim), arch(params.arch) { + : arch(params.arch), num_heads(params.num_heads), num_kv_heads(params.num_kv_heads), head_dim(params.head_dim), qk_norm(params.qk_norm) { blocks["q_proj"] = std::make_shared<Linear>(params.hidden_size, num_heads * head_dim, params.qkv_bias); blocks["k_proj"] = std::make_shared<Linear>(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias); blocks["v_proj"] = std::make_shared<Linear>(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias); blocks["o_proj"] = std::make_shared<Linear>(num_heads * head_dim, params.hidden_size, false); + if (params.qk_norm) { + blocks["q_norm"] = std::make_shared<RMSNorm>(head_dim, params.rms_norm_eps); + blocks["k_norm"] = std::make_shared<RMSNorm>(head_dim, params.rms_norm_eps); + } } struct ggml_tensor* forward(GGMLRunnerContext* ctx, @@ -842,9 +850,20 @@ namespace LLM { k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim] v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim] + if (qk_norm) { + auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]); + auto k_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["k_norm"]); + + q = q_norm->forward(ctx, q); + k = k_norm->forward(ctx, k); + } + if (arch == LLMArch::MISTRAL_SMALL_3_2) { q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 131072, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 131072, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); + } else if (arch == LLMArch::QWEN3) { + q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 151936, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); + k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 151936, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); } else { int sections[4] = {16, 24, 24, 0}; q = ggml_rope_multi(ctx->ggml_ctx, q,
input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); @@ -1063,6 +1082,17 @@ namespace LLM { params.qkv_bias = false; params.vocab_size = 131072; params.rms_norm_eps = 1e-5f; + } else if (arch == LLMArch::QWEN3) { + params.num_layers = 36; + params.hidden_size = 2560; + params.intermediate_size = 9728; + params.head_dim = 128; + params.num_heads = 32; + params.num_kv_heads = 8; + params.qkv_bias = false; + params.qk_norm = true; + params.vocab_size = 151936; + params.rms_norm_eps = 1e-6f; } bool have_vision_weight = false; bool llama_cpp_style = false; @@ -1132,7 +1162,7 @@ namespace LLM { } int64_t n_tokens = input_ids->ne[0]; - if (params.arch == LLMArch::MISTRAL_SMALL_3_2) { + if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::QWEN3) { input_pos_vec.resize(n_tokens); for (int i = 0; i < n_tokens; ++i) { input_pos_vec[i] = i; @@ -1420,7 +1450,8 @@ namespace LLM { struct ggml_context* work_ctx = ggml_init(params); GGML_ASSERT(work_ctx != nullptr); - bool test_mistral = true; + bool test_mistral = false; + bool test_qwen3 = true; bool test_vit = false; bool test_decoder_with_vit = false; @@ -1455,9 +1486,9 @@ namespace LLM { std::pair prompt_attn_range; std::string text = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n"; text += img_prompt; - prompt_attn_range.first = text.size(); + prompt_attn_range.first = static_cast(text.size()); text += "change 'flux.cpp' to 'edit.cpp'"; - prompt_attn_range.second = text.size(); + prompt_attn_range.second = static_cast(text.size()); text += "<|im_end|>\n<|im_start|>assistant\n"; auto tokens_and_weights = tokenize(text, prompt_attn_range, 0, false); @@ -1496,9 +1527,9 @@ namespace LLM { } else if (test_mistral) { std::pair prompt_attn_range; std::string text = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. 
You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]"; - prompt_attn_range.first = text.size(); + prompt_attn_range.first = static_cast(text.size()); text += "a lovely cat"; - prompt_attn_range.second = text.size(); + prompt_attn_range.second = static_cast(text.size()); text += "[/INST]"; auto tokens_and_weights = tokenize(text, prompt_attn_range, 0, false); std::vector& tokens = std::get<0>(tokens_and_weights); @@ -1514,14 +1545,37 @@ namespace LLM { model.compute(8, input_ids, {}, {10, 20, 30}, &out, work_ctx); int t1 = ggml_time_ms(); + print_ggml_tensor(out); + LOG_DEBUG("llm test done in %dms", t1 - t0); + } else if (test_qwen3) { + std::pair prompt_attn_range; + std::string text = "<|im_start|>user\n"; + prompt_attn_range.first = static_cast(text.size()); + text += "a lovely cat"; + prompt_attn_range.second = static_cast(text.size()); + text += "<|im_end|>\n<|im_start|>assistant\n"; + auto tokens_and_weights = tokenize(text, prompt_attn_range, 0, false); + std::vector& tokens = std::get<0>(tokens_and_weights); + std::vector& weights = std::get<1>(tokens_and_weights); + for (auto token : tokens) { + printf("%d ", token); + } + printf("\n"); + auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens); + struct ggml_tensor* out = nullptr; + + int t0 = ggml_time_ms(); + model.compute(8, input_ids, {}, {35}, &out, work_ctx); + int t1 = ggml_time_ms(); + print_ggml_tensor(out); LOG_DEBUG("llm test done in %dms", t1 - t0); } else { std::pair prompt_attn_range; std::string text = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n"; - prompt_attn_range.first = text.size(); + prompt_attn_range.first = static_cast(text.size()); text += "a lovely cat"; - prompt_attn_range.second = text.size(); + prompt_attn_range.second = static_cast(text.size()); text += "<|im_end|>\n<|im_start|>assistant\n"; auto tokens_and_weights = tokenize(text, prompt_attn_range, 0, false); std::vector& tokens = std::get<0>(tokens_and_weights); @@ -1563,7 +1617,7 @@ namespace LLM { } } - LLMArch arch = LLMArch::MISTRAL_SMALL_3_2; + LLMArch arch = LLMArch::QWEN3; std::shared_ptr llm = std::make_shared(arch, backend, @@ -1587,6 +1641,6 @@ namespace LLM { llm->test(); } }; -}; // Qwen +}; // LLM -#endif // __QWENVL_HPP__ +#endif // __LLM_HPP__ diff --git a/mmdit.hpp b/mmdit.hpp index c243e034a..247c8f6d1 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -101,10 +101,14 @@ struct TimestepEmbedder : public GGMLBlock { public: TimestepEmbedder(int64_t hidden_size, - int64_t frequency_embedding_size = 256) + int64_t frequency_embedding_size = 256, + int64_t out_channels = 0) : frequency_embedding_size(frequency_embedding_size) { + if (out_channels <= 0) { + out_channels = hidden_size; + } blocks["mlp.0"] = std::shared_ptr(new Linear(frequency_embedding_size, hidden_size, true, true)); - blocks["mlp.2"] = std::shared_ptr(new Linear(hidden_size, hidden_size, true, true)); + blocks["mlp.2"] = std::shared_ptr(new Linear(hidden_size, out_channels, true, true)); } struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* t) { diff --git a/model.cpp b/model.cpp index 05afde947..1f0ad828f 100644 --- a/model.cpp +++ b/model.cpp @@ -1067,6 +1067,9 @@ SDVersion ModelLoader::get_sd_version() { if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) { return 
VERSION_FLUX2; } + if (tensor_storage.name.find("model.diffusion_model.cap_embedder.0.weight") != std::string::npos) { + return VERSION_Z_IMAGE; + } if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) { is_wan = true; } diff --git a/model.h b/model.h index 64050541c..e2ff26c49 100644 --- a/model.h +++ b/model.h @@ -44,6 +44,7 @@ enum SDVersion { VERSION_WAN2_2_TI2V, VERSION_QWEN_IMAGE, VERSION_FLUX2, + VERSION_Z_IMAGE, VERSION_COUNT, }; @@ -116,6 +117,13 @@ static inline bool sd_version_is_qwen_image(SDVersion version) { return false; } +static inline bool sd_version_is_z_image(SDVersion version) { + if (version == VERSION_Z_IMAGE) { + return true; + } + return false; +} + static inline bool sd_version_is_inpaint(SDVersion version) { if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || @@ -132,7 +140,8 @@ static inline bool sd_version_is_dit(SDVersion version) { sd_version_is_flux2(version) || sd_version_is_sd3(version) || sd_version_is_wan(version) || - sd_version_is_qwen_image(version)) { + sd_version_is_qwen_image(version) || + sd_version_is_z_image(version)) { return true; } return false; diff --git a/name_conversion.cpp b/name_conversion.cpp index c4670dfc7..8b521486d 100644 --- a/name_conversion.cpp +++ b/name_conversion.cpp @@ -133,6 +133,8 @@ std::string convert_cond_stage_model_name(std::string name, std::string prefix) {"attn_q.", "self_attn.q_proj."}, {"attn_k.", "self_attn.k_proj."}, {"attn_v.", "self_attn.v_proj."}, + {"attn_q_norm.", "self_attn.q_norm."}, + {"attn_k_norm.", "self_attn.k_norm."}, {"attn_output.", "self_attn.o_proj."}, {"attn_norm.", "input_layernorm."}, {"ffn_down.", "mlp.down_proj."}, @@ -613,6 +615,44 @@ std::string convert_diffusers_dit_to_original_flux(std::string name) { return name; } +std::string convert_diffusers_dit_to_original_lumina2(std::string name) { + int num_layers = 30; + int num_refiner_layers = 2; + static std::unordered_map z_image_name_map; + + if (z_image_name_map.empty()) { + z_image_name_map["all_x_embedder.2-1."] = "x_embedder."; + z_image_name_map["all_final_layer.2-1."] = "final_layer."; + + // --- transformer blocks --- + auto add_attention_map = [&](const std::string& prefix, int num) { + for (int i = 0; i < num; ++i) { + std::string block_prefix = prefix + std::to_string(i) + "."; + std::string dst_prefix = prefix + std::to_string(i) + "."; + + z_image_name_map[block_prefix + "attention.norm_q."] = dst_prefix + "attention.q_norm."; + z_image_name_map[block_prefix + "attention.norm_k."] = dst_prefix + "attention.k_norm."; + z_image_name_map[block_prefix + "attention.to_out.0."] = dst_prefix + "attention.out."; + + z_image_name_map[block_prefix + "attention.to_q.weight"] = dst_prefix + "attention.qkv.weight"; + z_image_name_map[block_prefix + "attention.to_q.bias"] = dst_prefix + "attention.qkv.bias"; + z_image_name_map[block_prefix + "attention.to_k.weight"] = dst_prefix + "attention.qkv.weight.1"; + z_image_name_map[block_prefix + "attention.to_k.bias"] = dst_prefix + "attention.qkv.bias.1"; + z_image_name_map[block_prefix + "attention.to_v.weight"] = dst_prefix + "attention.qkv.weight.2"; + z_image_name_map[block_prefix + "attention.to_v.bias"] = dst_prefix + "attention.qkv.bias.2"; + } + }; + + add_attention_map("noise_refiner.", num_refiner_layers); + add_attention_map("context_refiner.", num_refiner_layers); + add_attention_map("layers.", num_layers); + } + + replace_with_prefix_map(name, z_image_name_map); + + return name; +} + std::string 
convert_diffusion_model_name(std::string name, std::string prefix, SDVersion version) { if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) { name = convert_diffusers_unet_to_original_sd1(name); @@ -622,6 +662,8 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S name = convert_diffusers_dit_to_original_sd3(name); } else if (sd_version_is_flux(version) || sd_version_is_flux2(version)) { name = convert_diffusers_dit_to_original_flux(name); + } else if (sd_version_is_z_image(version)) { + name = convert_diffusers_dit_to_original_lumina2(name); } return name; } diff --git a/rope.hpp b/rope.hpp index 964dcd8ba..7a35926eb 100644 --- a/rope.hpp +++ b/rope.hpp @@ -379,6 +379,55 @@ namespace Rope { return embed_nd(ids, 1, theta, axes_dim); } + __STATIC_INLINE__ int bound_mod(int a, int m) { + return (m - (a % m)) % m; + } + + __STATIC_INLINE__ std::vector> gen_z_image_ids(int h, + int w, + int patch_size, + int bs, + int context_len, + int seq_multi_of, + const std::vector& ref_latents, + bool increase_ref_index) { + int padded_context_len = context_len + bound_mod(context_len, seq_multi_of); + auto txt_ids = std::vector>(bs * padded_context_len, std::vector(3, 0.0f)); + for (int i = 0; i < bs * padded_context_len; i++) { + txt_ids[i][0] = (i % padded_context_len) + 1.f; + } + + int axes_dim_num = 3; + int index = padded_context_len + 1; + auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, index); + + int img_pad_len = bound_mod(static_cast(img_ids.size() / bs), seq_multi_of); + if (img_pad_len > 0) { + std::vector> img_pad_ids(bs * img_pad_len, std::vector(3, 0.f)); + img_ids = concat_ids(img_ids, img_pad_ids, bs); + } + + auto ids = concat_ids(txt_ids, img_ids, bs); + + // ignore ref_latents for now + return ids; + } + + // Generate z_image positional embeddings + __STATIC_INLINE__ std::vector gen_z_image_pe(int h, + int w, + int patch_size, + int bs, + int context_len, + int seq_multi_of, + const std::vector& ref_latents, + bool increase_ref_index, + int theta, + const std::vector& axes_dim) { + std::vector> ids = gen_z_image_ids(h, w, patch_size, bs, context_len, seq_multi_of, ref_latents, increase_ref_index); + return embed_nd(ids, bs, theta, axes_dim); + } + __STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe, diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index ef798c2d8..2e873c2df 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -45,6 +45,7 @@ const char* model_version_to_str[] = { "Wan 2.2 TI2V", "Qwen Image", "Flux.2", + "Z-Image", }; const char* sampling_methods_str[] = { @@ -377,7 +378,7 @@ class StableDiffusionGGML { } else if (sd_version_is_sd3(version)) { scale_factor = 1.5305f; shift_factor = 0.0609f; - } else if (sd_version_is_flux(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { scale_factor = 0.3611f; shift_factor = 0.1159f; } else if (sd_version_is_wan(version) || @@ -495,6 +496,16 @@ class StableDiffusionGGML { tensor_storage_map, "model.diffusion_model", version); + } else if (sd_version_is_z_image(version)) { + cond_stage_model = std::make_shared(clip_backend, + offload_params_to_cpu, + tensor_storage_map, + version); + diffusion_model = std::make_shared(backend, + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); } else { // SD1.x SD2.x SDXL if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) { cond_stage_model = 
std::make_shared(clip_backend, @@ -868,6 +879,13 @@ class StableDiffusionGGML { shift = 3.0; } denoiser = std::make_shared(shift); + } else if (sd_version_is_z_image(version)) { + LOG_INFO("running in FLOW mode"); + float shift = sd_ctx_params->flow_shift; + if (shift == INFINITY) { + shift = 3.0f; + } + denoiser = std::make_shared(shift); } else if (is_using_v_parameterization) { LOG_INFO("running in v-prediction mode"); denoiser = std::make_shared(); @@ -1334,7 +1352,7 @@ class StableDiffusionGGML { if (sd_version_is_sd3(version)) { latent_rgb_proj = sd3_latent_rgb_proj; latent_rgb_bias = sd3_latent_rgb_bias; - } else if (sd_version_is_flux(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { latent_rgb_proj = flux_latent_rgb_proj; latent_rgb_bias = flux_latent_rgb_bias; } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { @@ -1635,6 +1653,8 @@ class StableDiffusionGGML { shifted_t = std::max((int64_t)0, std::min((int64_t)(TIMESTEPS - 1), shifted_t)); LOG_DEBUG("shifting timestep from %.2f to %" PRId64 " (sigma: %.4f)", t, shifted_t, sigma); timesteps_vec.assign(1, (float)shifted_t); + } else if (sd_version_is_z_image(version)) { + timesteps_vec.assign(1, 1000.f - t); } else { timesteps_vec.assign(1, t); } diff --git a/z_image.hpp b/z_image.hpp new file mode 100644 index 000000000..b692a14a5 --- /dev/null +++ b/z_image.hpp @@ -0,0 +1,670 @@ +#ifndef __Z_IMAGE_HPP__ +#define __Z_IMAGE_HPP__ + +#include + +#include "flux.hpp" +#include "ggml_extend.hpp" +#include "mmdit.hpp" + +// Ref: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py +// Ref: https://github.com/huggingface/diffusers/pull/12703 + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +namespace ZImage { + constexpr int Z_IMAGE_GRAPH_SIZE = 20480; + constexpr int ADALN_EMBED_DIM = 256; + constexpr int SEQ_MULTI_OF = 32; + + struct JointAttention : public GGMLBlock { + protected: + int64_t head_dim; + int64_t num_heads; + int64_t num_kv_heads; + bool qk_norm; + + public: + JointAttention(int64_t hidden_size, int64_t head_dim, int64_t num_heads, int64_t num_kv_heads, bool qk_norm) + : head_dim(head_dim), num_heads(num_heads), num_kv_heads(num_kv_heads), qk_norm(qk_norm) { + blocks["qkv"] = std::make_shared(hidden_size, (num_heads + num_kv_heads * 2) * head_dim, false); + blocks["out"] = std::make_shared(num_heads * head_dim, hidden_size, false); + if (qk_norm) { + blocks["q_norm"] = std::make_shared(head_dim); + blocks["k_norm"] = std::make_shared(head_dim); + } + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, + struct ggml_tensor* x, + struct ggml_tensor* pe, + struct ggml_tensor* mask = nullptr) { + // x: [N, n_token, hidden_size] + int64_t n_token = x->ne[1]; + int64_t N = x->ne[2]; + auto qkv_proj = std::dynamic_pointer_cast(blocks["qkv"]); + auto out_proj = std::dynamic_pointer_cast(blocks["out"]); + + auto qkv = qkv_proj->forward(ctx, x); // [N, n_token, (num_heads + num_kv_heads*2)*head_dim] + qkv = ggml_reshape_4d(ctx->ggml_ctx, qkv, head_dim, num_heads + num_kv_heads * 2, qkv->ne[1], qkv->ne[2]); // [N, n_token, num_heads + num_kv_heads*2, head_dim] + qkv = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, qkv, 0, 2, 3, 1)); // [num_heads + num_kv_heads*2, N, n_token, head_dim] + + auto q = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], 0); // [num_heads, N, n_token, head_dim] + auto k = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], 
qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * num_heads); // [num_kv_heads, N, n_token, head_dim] + auto v = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * (num_heads + num_kv_heads)); // [num_kv_heads, N, n_token, head_dim] + + q = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 0, 3, 1, 2)); // [N, n_token, num_heads, head_dim] + k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim] + v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim] + + if (qk_norm) { + auto q_norm = std::dynamic_pointer_cast(blocks["q_norm"]); + auto k_norm = std::dynamic_pointer_cast(blocks["k_norm"]); + + q = q_norm->forward(ctx, q); + k = k_norm->forward(ctx, k); + } + + x = Rope::attention(ctx, q, k, v, pe, mask, 1.f / 128.f); // [N, n_token, num_heads * head_dim] + + x = out_proj->forward(ctx, x); // [N, n_token, hidden_size] + return x; + } + }; + + class FeedForward : public GGMLBlock { + public: + FeedForward(int64_t dim, + int64_t hidden_dim, + int64_t multiple_of, + float ffn_dim_multiplier = 0.f) { + if (ffn_dim_multiplier > 0.f) { + hidden_dim = static_cast(ffn_dim_multiplier * hidden_dim); + } + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) / multiple_of); + blocks["w1"] = std::make_shared(dim, hidden_dim, false); + + bool force_prec_f32 = false; + float scale = 1.f / 128.f; +#ifdef SD_USE_VULKAN + force_prec_f32 = true; +#endif + // The purpose of the scale here is to prevent NaN issues in certain situations. + // For example, when using CUDA but the weights are k-quants. 
+ blocks["w2"] = std::make_shared(hidden_dim, dim, false, false, force_prec_f32, 1.f / 128.f); + blocks["w3"] = std::make_shared(dim, hidden_dim, false); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { + auto w1 = std::dynamic_pointer_cast(blocks["w1"]); + auto w2 = std::dynamic_pointer_cast(blocks["w2"]); + auto w3 = std::dynamic_pointer_cast(blocks["w3"]); + + auto x1 = w1->forward(ctx, x); + auto x3 = w3->forward(ctx, x); + x = ggml_mul(ctx->ggml_ctx, ggml_silu(ctx->ggml_ctx, x1), x3); + x = w2->forward(ctx, x); + + return x; + } + }; + + __STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* scale) { + // x: [N, L, C] + // scale: [N, C] + scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C] + x = ggml_add(ctx, x, ggml_mul(ctx, x, scale)); + return x; + } + + struct JointTransformerBlock : public GGMLBlock { + protected: + bool modulation; + + public: + JointTransformerBlock(int layer_id, + int64_t hidden_size, + int64_t head_dim, + int64_t num_heads, + int64_t num_kv_heads, + int64_t multiple_of, + float ffn_dim_multiplier, + float norm_eps, + bool qk_norm, + bool modulation = true) + : modulation(modulation) { + blocks["attention"] = std::make_shared(hidden_size, head_dim, num_heads, num_kv_heads, qk_norm); + blocks["feed_forward"] = std::make_shared(hidden_size, hidden_size, multiple_of, ffn_dim_multiplier); + blocks["attention_norm1"] = std::make_shared(hidden_size, norm_eps); + blocks["ffn_norm1"] = std::make_shared(hidden_size, norm_eps); + blocks["attention_norm2"] = std::make_shared(hidden_size, norm_eps); + blocks["ffn_norm2"] = std::make_shared(hidden_size, norm_eps); + if (modulation) { + blocks["adaLN_modulation.0"] = std::make_shared(MIN(hidden_size, ADALN_EMBED_DIM), 4 * hidden_size); + } + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, + struct ggml_tensor* x, + struct ggml_tensor* pe, + struct ggml_tensor* mask = nullptr, + struct ggml_tensor* adaln_input = nullptr) { + auto attention = std::dynamic_pointer_cast(blocks["attention"]); + auto feed_forward = std::dynamic_pointer_cast(blocks["feed_forward"]); + auto attention_norm1 = std::dynamic_pointer_cast(blocks["attention_norm1"]); + auto ffn_norm1 = std::dynamic_pointer_cast(blocks["ffn_norm1"]); + auto attention_norm2 = std::dynamic_pointer_cast(blocks["attention_norm2"]); + auto ffn_norm2 = std::dynamic_pointer_cast(blocks["ffn_norm2"]); + + if (modulation) { + GGML_ASSERT(adaln_input != nullptr); + auto adaLN_modulation_0 = std::dynamic_pointer_cast(blocks["adaLN_modulation.0"]); + + auto m = adaLN_modulation_0->forward(ctx, adaln_input); // [N, 4 * hidden_size] + auto mods = ggml_ext_chunk(ctx->ggml_ctx, m, 4, 0); + auto scale_msa = mods[0]; + auto gate_msa = mods[1]; + auto scale_mlp = mods[2]; + auto gate_mlp = mods[3]; + + auto residual = x; + x = modulate(ctx->ggml_ctx, attention_norm1->forward(ctx, x), scale_msa); + x = attention->forward(ctx, x, pe, mask); + x = attention_norm2->forward(ctx, x); + x = ggml_mul(ctx->ggml_ctx, x, ggml_tanh(ctx->ggml_ctx, gate_msa)); + x = ggml_add(ctx->ggml_ctx, x, residual); + + residual = x; + x = modulate(ctx->ggml_ctx, ffn_norm1->forward(ctx, x), scale_mlp); + x = feed_forward->forward(ctx, x); + x = ffn_norm2->forward(ctx, x); + x = ggml_mul(ctx->ggml_ctx, x, ggml_tanh(ctx->ggml_ctx, gate_mlp)); + x = ggml_add(ctx->ggml_ctx, x, residual); + } else { + GGML_ASSERT(adaln_input == nullptr); + + auto residual = x; + x = 
attention_norm1->forward(ctx, x); + x = attention->forward(ctx, x, pe, mask); + x = attention_norm2->forward(ctx, x); + x = ggml_add(ctx->ggml_ctx, x, residual); + + residual = x; + x = ffn_norm1->forward(ctx, x); + x = feed_forward->forward(ctx, x); + x = ffn_norm2->forward(ctx, x); + x = ggml_add(ctx->ggml_ctx, x, residual); + } + + return x; + } + }; + + struct FinalLayer : public GGMLBlock { + public: + FinalLayer(int64_t hidden_size, + int64_t patch_size, + int64_t out_channels) { + blocks["norm_final"] = std::make_shared(hidden_size, 1e-06f, false); + blocks["linear"] = std::make_shared(hidden_size, patch_size * patch_size * out_channels, true, true); + blocks["adaLN_modulation.1"] = std::make_shared(MIN(hidden_size, ADALN_EMBED_DIM), hidden_size); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, + struct ggml_tensor* x, + struct ggml_tensor* c) { + // x: [N, n_token, hidden_size] + // c: [N, hidden_size] + // return: [N, n_token, patch_size * patch_size * out_channels] + auto norm_final = std::dynamic_pointer_cast(blocks["norm_final"]); + auto linear = std::dynamic_pointer_cast(blocks["linear"]); + auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); + + auto scale = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, hidden_size] + x = norm_final->forward(ctx, x); + x = modulate(ctx->ggml_ctx, x, scale); + x = linear->forward(ctx, x); + + return x; + } + }; + + struct ZImageParams { + int64_t patch_size = 2; + int64_t hidden_size = 3840; + int64_t in_channels = 16; + int64_t out_channels = 16; + int64_t num_layers = 30; + int64_t num_refiner_layers = 2; + int64_t head_dim = 128; + int64_t num_heads = 30; + int64_t num_kv_heads = 30; + int64_t multiple_of = 256; + float ffn_dim_multiplier = 8.0 / 3.0f; + float norm_eps = 1e-5f; + bool qk_norm = true; + int64_t cap_feat_dim = 2560; + float theta = 256.f; + std::vector axes_dim = {32, 48, 48}; + int64_t axes_dim_sum = 128; + }; + + class ZImageModel : public GGMLBlock { + protected: + ZImageParams z_image_params; + + void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { + params["cap_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size); + params["x_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size); + } + + public: + ZImageModel() = default; + ZImageModel(ZImageParams z_image_params) + : z_image_params(z_image_params) { + blocks["x_embedder"] = std::make_shared(z_image_params.patch_size * z_image_params.patch_size * z_image_params.in_channels, z_image_params.hidden_size); + blocks["t_embedder"] = std::make_shared(MIN(z_image_params.hidden_size, 1024), 256, 256); + blocks["cap_embedder.0"] = std::make_shared(z_image_params.cap_feat_dim, z_image_params.norm_eps); + blocks["cap_embedder.1"] = std::make_shared(z_image_params.cap_feat_dim, z_image_params.hidden_size); + + for (int i = 0; i < z_image_params.num_refiner_layers; i++) { + auto block = std::make_shared(i, + z_image_params.hidden_size, + z_image_params.head_dim, + z_image_params.num_heads, + z_image_params.num_kv_heads, + z_image_params.multiple_of, + z_image_params.ffn_dim_multiplier, + z_image_params.norm_eps, + z_image_params.qk_norm, + true); + + blocks["noise_refiner." 
+ std::to_string(i)] = block; + } + + for (int i = 0; i < z_image_params.num_refiner_layers; i++) { + auto block = std::make_shared(i, + z_image_params.hidden_size, + z_image_params.head_dim, + z_image_params.num_heads, + z_image_params.num_kv_heads, + z_image_params.multiple_of, + z_image_params.ffn_dim_multiplier, + z_image_params.norm_eps, + z_image_params.qk_norm, + false); + + blocks["context_refiner." + std::to_string(i)] = block; + } + + for (int i = 0; i < z_image_params.num_layers; i++) { + auto block = std::make_shared(i, + z_image_params.hidden_size, + z_image_params.head_dim, + z_image_params.num_heads, + z_image_params.num_kv_heads, + z_image_params.multiple_of, + z_image_params.ffn_dim_multiplier, + z_image_params.norm_eps, + z_image_params.qk_norm, + true); + + blocks["layers." + std::to_string(i)] = block; + } + + blocks["final_layer"] = std::make_shared(z_image_params.hidden_size, z_image_params.patch_size, z_image_params.out_channels); + } + + struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx, + struct ggml_tensor* x) { + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + + int pad_h = (z_image_params.patch_size - H % z_image_params.patch_size) % z_image_params.patch_size; + int pad_w = (z_image_params.patch_size - W % z_image_params.patch_size) % z_image_params.patch_size; + x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] + return x; + } + + struct ggml_tensor* patchify(struct ggml_context* ctx, + struct ggml_tensor* x) { + // x: [N, C, H, W] + // return: [N, h*w, patch_size*patch_size*C] + int64_t N = x->ne[3]; + int64_t C = x->ne[2]; + int64_t H = x->ne[1]; + int64_t W = x->ne[0]; + int64_t p = z_image_params.patch_size; + int64_t h = H / z_image_params.patch_size; + int64_t w = W / z_image_params.patch_size; + + GGML_ASSERT(h * p == H && w * p == W); + + x = ggml_reshape_4d(ctx, x, p, w, p, h * C * N); // [N*C*h, p, w, p] + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p] + x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p] + x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [N, h*w, C, p*p] + x = ggml_reshape_3d(ctx, x, C * p * p, w * h, N); // [N, h*w, p*p*C] + return x; + } + + struct ggml_tensor* process_img(struct ggml_context* ctx, + struct ggml_tensor* x) { + x = pad_to_patch_size(ctx, x); + x = patchify(ctx, x); + return x; + } + + struct ggml_tensor* unpatchify(struct ggml_context* ctx, + struct ggml_tensor* x, + int64_t h, + int64_t w) { + // x: [N, h*w, patch_size*patch_size*C] + // return: [N, C, H, W] + int64_t N = x->ne[2]; + int64_t C = x->ne[0] / z_image_params.patch_size / z_image_params.patch_size; + int64_t H = h * z_image_params.patch_size; + int64_t W = w * z_image_params.patch_size; + int64_t p = z_image_params.patch_size; + + GGML_ASSERT(C * p * p == x->ne[0]); + + x = ggml_reshape_4d(ctx, x, C, p * p, w * h, N); // [N, h*w, p*p, C] + x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, h*w, p*p] + x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p] + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p] + x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p] + + return x; + } + + struct ggml_tensor* forward_core(GGMLRunnerContext* ctx, + struct ggml_tensor* x, + struct ggml_tensor* timestep, + struct ggml_tensor* context, + struct ggml_tensor* pe) { + auto x_embedder = std::dynamic_pointer_cast(blocks["x_embedder"]); + auto t_embedder = std::dynamic_pointer_cast(blocks["t_embedder"]); + 
auto cap_embedder_0 = std::dynamic_pointer_cast(blocks["cap_embedder.0"]); + auto cap_embedder_1 = std::dynamic_pointer_cast(blocks["cap_embedder.1"]); + auto norm_final = std::dynamic_pointer_cast(blocks["norm_final"]); + auto final_layer = std::dynamic_pointer_cast(blocks["final_layer"]); + + auto txt_pad_token = params["cap_pad_token"]; + auto img_pad_token = params["x_pad_token"]; + + int64_t N = x->ne[2]; + int64_t n_img_token = x->ne[1]; + int64_t n_txt_token = context->ne[1]; + + auto t_emb = t_embedder->forward(ctx, timestep); + + auto txt = cap_embedder_1->forward(ctx, cap_embedder_0->forward(ctx, context)); // [N, n_txt_token, hidden_size] + auto img = x_embedder->forward(ctx, x); // [N, n_img_token, hidden_size] + + int64_t n_txt_pad_token = Rope::bound_mod(n_txt_token, SEQ_MULTI_OF); + if (n_txt_pad_token > 0) { + auto txt_pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, txt_pad_token, txt_pad_token->ne[0], n_txt_pad_token, N, 1); + txt = ggml_concat(ctx->ggml_ctx, txt, txt_pad_tokens, 1); // [N, n_txt_token + n_txt_pad_token, hidden_size] + } + + int64_t n_img_pad_token = Rope::bound_mod(n_img_token, SEQ_MULTI_OF); + if (n_img_pad_token > 0) { + auto img_pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, img_pad_token, img_pad_token->ne[0], n_img_pad_token, N, 1); + img = ggml_concat(ctx->ggml_ctx, img, img_pad_tokens, 1); // [N, n_img_token + n_img_pad_token, hidden_size] + } + + GGML_ASSERT(txt->ne[1] + img->ne[1] == pe->ne[3]); + + auto txt_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, 0, txt->ne[1]); + auto img_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1], pe->ne[3]); + + for (int i = 0; i < z_image_params.num_refiner_layers; i++) { + auto block = std::dynamic_pointer_cast(blocks["context_refiner." + std::to_string(i)]); + + txt = block->forward(ctx, txt, txt_pe, nullptr, nullptr); + } + + for (int i = 0; i < z_image_params.num_refiner_layers; i++) { + auto block = std::dynamic_pointer_cast(blocks["noise_refiner." + std::to_string(i)]); + + img = block->forward(ctx, img, img_pe, nullptr, t_emb); + } + + auto txt_img = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_txt_pad_token + n_img_token + n_img_pad_token, hidden_size] + + for (int i = 0; i < z_image_params.num_layers; i++) { + auto block = std::dynamic_pointer_cast(blocks["layers." + std::to_string(i)]); + + txt_img = block->forward(ctx, txt_img, pe, nullptr, t_emb); + } + + txt_img = final_layer->forward(ctx, txt_img, t_emb); // [N, n_txt_token + n_txt_pad_token + n_img_token + n_img_pad_token, ph*pw*C] + + img = ggml_ext_slice(ctx->ggml_ctx, txt_img, 1, n_txt_token + n_txt_pad_token, n_txt_token + n_txt_pad_token + n_img_token); // [N, n_img_token, ph*pw*C] + + return img; + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, + struct ggml_tensor* x, + struct ggml_tensor* timestep, + struct ggml_tensor* context, + struct ggml_tensor* pe, + std::vector ref_latents = {}) { + // Forward pass of DiT. 
+ // x: [N, C, H, W] + // timestep: [N,] + // context: [N, L, D] + // pe: [L, d_head/2, 2, 2] + // return: [N, C, H, W] + + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + int64_t C = x->ne[2]; + int64_t N = x->ne[3]; + + auto img = process_img(ctx->ggml_ctx, x); + uint64_t n_img_token = img->ne[1]; + + if (ref_latents.size() > 0) { + for (ggml_tensor* ref : ref_latents) { + ref = process_img(ctx->ggml_ctx, ref); + img = ggml_concat(ctx->ggml_ctx, img, ref, 1); + } + } + + int64_t h_len = ((H + (z_image_params.patch_size / 2)) / z_image_params.patch_size); + int64_t w_len = ((W + (z_image_params.patch_size / 2)) / z_image_params.patch_size); + + auto out = forward_core(ctx, img, timestep, context, pe); + + out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, n_img_token); // [N, n_img_token, ph*pw*C] + out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w] + + // slice + out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w] + out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N, C, H, W] + + out = ggml_scale(ctx->ggml_ctx, out, -1.f); + + return out; + } + }; + + struct ZImageRunner : public GGMLRunner { + public: + ZImageParams z_image_params; + ZImageModel z_image; + std::vector pe_vec; + std::vector timestep_vec; + SDVersion version; + + ZImageRunner(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map = {}, + const std::string prefix = "", + SDVersion version = VERSION_Z_IMAGE) + : GGMLRunner(backend, offload_params_to_cpu) { + z_image = ZImageModel(z_image_params); + z_image.init(params_ctx, tensor_storage_map, prefix); + } + + std::string get_desc() override { + return "z_image"; + } + + void get_param_tensors(std::map& tensors, const std::string prefix) { + z_image.get_param_tensors(tensors, prefix); + } + + struct ggml_cgraph* build_graph(struct ggml_tensor* x, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + std::vector ref_latents = {}, + bool increase_ref_index = false) { + GGML_ASSERT(x->ne[3] == 1); + struct ggml_cgraph* gf = new_graph_custom(Z_IMAGE_GRAPH_SIZE); + + x = to_backend(x); + context = to_backend(context); + timesteps = to_backend(timesteps); + + for (int i = 0; i < ref_latents.size(); i++) { + ref_latents[i] = to_backend(ref_latents[i]); + } + + pe_vec = Rope::gen_z_image_pe(x->ne[1], + x->ne[0], + z_image_params.patch_size, + x->ne[3], + context->ne[1], + SEQ_MULTI_OF, + ref_latents, + increase_ref_index, + z_image_params.theta, + z_image_params.axes_dim); + int pos_len = pe_vec.size() / z_image_params.axes_dim_sum / 2; + // LOG_DEBUG("pos_len %d", pos_len); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, z_image_params.axes_dim_sum / 2, pos_len); + // pe->data = pe_vec.data(); + // print_ggml_tensor(pe, true, "pe"); + // pe->data = nullptr; + set_backend_tensor_data(pe, pe_vec.data()); + auto runner_ctx = get_context(); + + struct ggml_tensor* out = z_image.forward(&runner_ctx, + x, + timesteps, + context, + pe, + ref_latents); + + ggml_build_forward_expand(gf, out); + + return gf; + } + + void compute(int n_threads, + struct ggml_tensor* x, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + std::vector ref_latents = {}, + bool increase_ref_index = false, + struct ggml_tensor** output = nullptr, + struct ggml_context* output_ctx = nullptr) { + // x: [N, in_channels, h, w] + // timesteps: [N, ] + // context: [N, max_position, hidden_size] + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph(x, 
timesteps, context, ref_latents, increase_ref_index); + }; + + GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + } + + void test() { + struct ggml_init_params params; + params.mem_size = static_cast(1024 * 1024) * 1024; // 1GB + params.mem_buffer = nullptr; + params.no_alloc = false; + + struct ggml_context* work_ctx = ggml_init(params); + GGML_ASSERT(work_ctx != nullptr); + + { + // auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1); + // ggml_set_f32(x, 0.01f); + auto x = load_tensor_from_file(work_ctx, "./z_image_x.bin"); + print_ggml_tensor(x); + + std::vector timesteps_vec(1, 0.f); + auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec); + + // auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 2560, 256, 1); + // ggml_set_f32(context, 0.01f); + auto context = load_tensor_from_file(work_ctx, "./z_image_context.bin"); + print_ggml_tensor(context); + + struct ggml_tensor* out = nullptr; + + int t0 = ggml_time_ms(); + compute(8, x, timesteps, context, {}, false, &out, work_ctx); + int t1 = ggml_time_ms(); + + print_ggml_tensor(out); + LOG_DEBUG("z_image test done in %dms", t1 - t0); + } + } + + static void load_from_file_and_test(const std::string& file_path) { + // cuda q8: pass + // cuda q8 fa: pass + // ggml_backend_t backend = ggml_backend_cuda_init(0); + ggml_backend_t backend = ggml_backend_cpu_init(); + ggml_type model_data_type = GGML_TYPE_Q8_0; + + ModelLoader model_loader; + if (!model_loader.init_from_file_and_convert_name(file_path, "model.diffusion_model.")) { + LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); + return; + } + + auto& tensor_storage_map = model_loader.get_tensor_storage_map(); + if (model_data_type != GGML_TYPE_COUNT) { + for (auto& [name, tensor_storage] : tensor_storage_map) { + if (ends_with(name, "weight")) { + tensor_storage.expected_type = model_data_type; + } + } + } + + std::shared_ptr z_image = std::make_shared(backend, + false, + tensor_storage_map, + "model.diffusion_model", + VERSION_QWEN_IMAGE); + + z_image->alloc_params_buffer(); + std::map tensors; + z_image->get_param_tensors(tensors, "model.diffusion_model"); + + bool success = model_loader.load_tensors(tensors); + + if (!success) { + LOG_ERROR("load tensors from model loader failed"); + return; + } + + LOG_INFO("z_image model loaded"); + z_image->test(); + } + }; + +} // namespace ZImage + +#endif // __Z_IMAGE_HPP__ \ No newline at end of file