diff --git a/conditioner.hpp b/conditioner.hpp
index 94e98a511..89bd33f0b 100644
--- a/conditioner.hpp
+++ b/conditioner.hpp
@@ -2,6 +2,7 @@
 #define __CONDITIONER_HPP__
 
 #include "clip.hpp"
+#include "qwen3.hpp"
 #include "qwenvl.hpp"
 #include "t5.hpp"
 
@@ -1830,4 +1831,123 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
     }
 };
 
+struct ZImageConditioner : public Conditioner {
+    Qwen::Qwen2Tokenizer tokenizer;
+    std::shared_ptr<Qwen3::Qwen3Runner> qwen3;
+    std::string chat_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n";
+    int64_t skip_token_count  = 0;  // Use full sequence for Z-Image
+
+    ZImageConditioner(ggml_backend_t backend,
+                      bool offload_params_to_cpu,
+                      const String2TensorStorage& tensor_storage_map = {},
+                      const std::string& prefix                      = "text_encoders.qwen3") {
+        qwen3 = std::make_shared<Qwen3::Qwen3Runner>(backend,
+                                                     offload_params_to_cpu,
+                                                     tensor_storage_map,
+                                                     prefix);
+    }
+
+    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) override {
+        qwen3->get_param_tensors(tensors, "text_encoders.qwen3");
+    }
+
+    void alloc_params_buffer() override {
+        qwen3->alloc_params_buffer();
+    }
+
+    void free_params_buffer() override {
+        qwen3->free_params_buffer();
+    }
+
+    size_t get_params_buffer_size() override {
+        return qwen3->get_params_buffer_size();
+    }
+
+    std::string apply_chat_template(const std::string& text) {
+        std::string result = chat_template;
+        size_t pos         = result.find("{}");
+        if (pos != std::string::npos) {
+            result.replace(pos, 2, text);
+        }
+        return result;
+    }
+
+    std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
+                                                              size_t max_length = 0,
+                                                              bool padding      = false) {
+        auto parsed_attention = parse_prompt_attention(text);
+        std::vector<int> tokens;
+        std::vector<float> weights;
+        for (const auto& item : parsed_attention) {
+            const std::string& curr_text = item.first;
+            float curr_weight            = item.second;
+            std::vector<int> curr_tokens = tokenizer.tokenize(curr_text, nullptr);
+            tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
+            weights.insert(weights.end(), curr_tokens.size(), curr_weight);
+        }
+        tokenizer.pad_tokens(tokens, weights, max_length, padding);
+        return {tokens, weights};
+    }
+
+    SDCondition get_learned_condition(ggml_context* work_ctx,
+                                      int n_threads,
+                                      const ConditionerParams& conditioner_params) override {
+        std::string prompt      = apply_chat_template(conditioner_params.text);
+        auto tokens_and_weights = tokenize(prompt, 0, false);
+        auto& tokens            = std::get<0>(tokens_and_weights);
+        auto& weights           = std::get<1>(tokens_and_weights);
+
+        int64_t t0                        = ggml_time_ms();
+        struct ggml_tensor* hidden_states = nullptr;
+
+        auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
+
+        qwen3->compute(n_threads,
+                       input_ids,
+                       &hidden_states,
+                       work_ctx);
+        {
+            auto tensor         = hidden_states;
+            float original_mean = ggml_ext_tensor_mean(tensor);
+            for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+                for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+                    for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+                        float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
+                        value *= weights[i1];
+                        ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
+                    }
+                }
+            }
+            float new_mean = ggml_ext_tensor_mean(tensor);
+            if (new_mean > 0) {
+                ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
+            }
+        }
+
+        int64_t skip_count     = skip_token_count;
+        int64_t output_seq_len = hidden_states->ne[1] - skip_count;
+        if (output_seq_len <= 0) {
+            LOG_WARN("ZImageConditioner: output sequence length would be %lld (hidden_states seq=%lld, skip=%lld), using full sequence",
+                     output_seq_len, hidden_states->ne[1], skip_count);
+            output_seq_len = hidden_states->ne[1];
+            skip_count     = 0;
+        }
+
+        ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
+                                                            GGML_TYPE_F32,
+                                                            hidden_states->ne[0],
+                                                            output_seq_len,
+                                                            hidden_states->ne[2]);
+
+        ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
+            float value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + skip_count, i2, i3);
+            ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
+        });
+
+        int64_t t1 = ggml_time_ms();
+        LOG_DEBUG("computing Z-Image condition graph completed, taking %" PRId64 " ms", t1 - t0);
+        return {new_hidden_states, nullptr, nullptr};
+    }
+};
+
 #endif
diff --git a/denoiser.hpp b/denoiser.hpp
index 12ba8a77a..9dd887db0 100644
--- a/denoiser.hpp
+++ b/denoiser.hpp
@@ -627,6 +627,55 @@ struct FluxFlowDenoiser : public Denoiser {
     }
 };
 
+// Z-Image flow matching denoiser
+struct ZImageFlowDenoiser : public Denoiser {
+    float sigmas[TIMESTEPS];
+    float shift = 3.0f;
+
+    ZImageFlowDenoiser(float shift = 3.0f) {
+        this->shift = shift;
+        for (int i = 0; i < TIMESTEPS; i++) {
+            sigmas[i] = t_to_sigma(i);
+        }
+    }
+
+    float sigma_min() override {
+        return sigmas[0];
+    }
+
+    float sigma_max() override {
+        return sigmas[TIMESTEPS - 1];
+    }
+
+    float sigma_to_t(float sigma) override {
+        return 1.0f - sigma;
+    }
+
+    float t_to_sigma(float t) override {
+        float sigma_raw = (t + 1) / TIMESTEPS;
+        return shift * sigma_raw / (1.0f + (shift - 1.0f) * sigma_raw);
+    }
+
+    std::vector<float> get_scalings(float sigma) override {
+        float c_skip = 1.0f;
+        float c_out  = sigma;
+        float c_in   = 1.0f;
+        return {c_skip, c_out, c_in};
+    }
+
+    ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) override {
+        ggml_ext_tensor_scale_inplace(noise, sigma);
+        ggml_ext_tensor_scale_inplace(latent, 1.0f - sigma);
+        ggml_ext_tensor_add_inplace(latent, noise);
+        return latent;
+    }
+
+    ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) override {
+        ggml_ext_tensor_scale_inplace(latent, 1.0f / (1.0f - sigma));
+        return latent;
+    }
+};
+
 typedef std::function denoise_cb_t;
 
 // k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
diff --git a/diffusion_model.hpp b/diffusion_model.hpp
index 0a3914edc..2a005b61c 100644
--- a/diffusion_model.hpp
+++ b/diffusion_model.hpp
@@ -6,6 +6,7 @@
 #include "qwen_image.hpp"
 #include "unet.hpp"
 #include "wan.hpp"
+#include "zimage.hpp"
 
 struct DiffusionParams {
     struct ggml_tensor* x = nullptr;
@@ -357,4 +358,65 @@ struct QwenImageModel : public DiffusionModel {
     }
 };
 
+struct ZImageDiffusionModel : public DiffusionModel {
+    std::string prefix;
+    ZImage::ZImageRunner zimage;
+
+    ZImageDiffusionModel(ggml_backend_t backend,
+                         bool offload_params_to_cpu,
+                         const String2TensorStorage& tensor_storage_map = {},
+                         const std::string prefix                       = "model.diffusion_model")
+        : prefix(prefix), zimage(backend, offload_params_to_cpu, tensor_storage_map, prefix) {
+    }
+
+    std::string get_desc() override {
+        return zimage.get_desc();
+    }
+
+    void alloc_params_buffer() override {
+        zimage.alloc_params_buffer();
+    }
+
+    void free_params_buffer() override {
+        zimage.free_params_buffer();
+    }
+
+    void free_compute_buffer() override {
+        zimage.free_compute_buffer();
+    }
+
+    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) override {
+        zimage.get_param_tensors(tensors, prefix);
+    }
+
+    size_t get_params_buffer_size() override {
+        return zimage.get_params_buffer_size();
+    }
+
+    int64_t get_adm_in_channels() override {
+        return 0;
+    }
+
+    void set_flash_attn_enabled(bool enabled) {
+        zimage.set_flash_attention_enabled(enabled);
+    }
+
+    void compute(int n_threads,
+                 DiffusionParams diffusion_params,
+                 struct ggml_tensor** output     = nullptr,
+                 struct ggml_context* output_ctx = nullptr) override {
+        int height = diffusion_params.x->ne[1] * 8;
+        int width  = diffusion_params.x->ne[0] * 8;
+
+        zimage.compute(n_threads,
+                       diffusion_params.x,
+                       diffusion_params.timesteps,
+                       diffusion_params.context,
+                       height,
+                       width,
+                       output,
+                       output_ctx);
+    }
+};
+
 #endif
diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index 427364a46..0585d1927 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -72,6 +72,7 @@ struct SDParams {
     std::string t5xxl_path;
     std::string qwen2vl_path;
     std::string qwen2vl_vision_path;
+    std::string qwen3_path;
     std::string diffusion_model_path;
    std::string high_noise_diffusion_model_path;
     std::string vae_path;
@@ -176,6 +177,7 @@ void print_params(SDParams params) {
     printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
     printf(" qwen2vl_path: %s\n", params.qwen2vl_path.c_str());
     printf(" qwen2vl_vision_path: %s\n", params.qwen2vl_vision_path.c_str());
+    printf(" qwen3_path: %s\n", params.qwen3_path.c_str());
     printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
     printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str());
     printf(" vae_path: %s\n", params.vae_path.c_str());
@@ -540,6 +542,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         "--qwen2vl_vision",
         "path to the qwen2vl vit",
         &params.qwen2vl_vision_path},
+       {"",
+        "--qwen3",
+        "path to the qwen3 text encoder (for Z-Image)",
+        &params.qwen3_path},
        {"",
         "--diffusion-model",
         "path to the standalone diffusion model",
@@ -1428,7 +1434,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
         parameter_string += " " + std::string(sd_scheduler_name(params.sample_params.scheduler));
     }
     parameter_string += ", ";
-    for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) {
+    for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path, params.qwen3_path}) {
         if (!te.empty()) {
             parameter_string += "TE: " + sd_basename(te) + ", ";
         }
@@ -1847,6 +1853,7 @@ int main(int argc, const char* argv[]) {
         params.t5xxl_path.c_str(),
         params.qwen2vl_path.c_str(),
         params.qwen2vl_vision_path.c_str(),
+        params.qwen3_path.c_str(),
         params.diffusion_model_path.c_str(),
         params.high_noise_diffusion_model_path.c_str(),
         params.vae_path.c_str(),
diff --git a/model.cpp b/model.cpp
index dac6e88f5..bd57d9fac 100644
--- a/model.cpp
+++ b/model.cpp
@@ -1062,6 +1062,13 @@ SDVersion ModelLoader::get_sd_version() {
         if (tensor_storage.name.find("model.diffusion_model.transformer_blocks.0.img_mod.1.weight") != std::string::npos) {
             return VERSION_QWEN_IMAGE;
         }
+        if (tensor_storage.name.find("model.diffusion_model.context_refiner.") != std::string::npos ||
+            tensor_storage.name.find("model.diffusion_model.noise_refiner.") != std::string::npos ||
+            // Also check without prefix for safetensors files exported directly
+            tensor_storage.name.find("context_refiner.") == 0 ||
+            tensor_storage.name.find("noise_refiner.") == 0) {
+            return VERSION_ZIMAGE;
+        }
         if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
             is_wan = true;
         }
diff --git a/model.h b/model.h
index 2ac079fb5..e5f61a0bf 100644
--- a/model.h
+++ b/model.h
@@ -43,6 +43,7 @@ enum SDVersion {
     VERSION_WAN2_2_I2V,
     VERSION_WAN2_2_TI2V,
     VERSION_QWEN_IMAGE,
+    VERSION_ZIMAGE,
     VERSION_COUNT,
 };
 
@@ -108,6 +109,13 @@ static inline bool sd_version_is_qwen_image(SDVersion version) {
     return false;
 }
 
+static inline bool sd_version_is_zimage(SDVersion version) {
+    if (version == VERSION_ZIMAGE) {
+        return true;
+    }
+    return false;
+}
+
 static inline bool sd_version_is_inpaint(SDVersion version) {
     if (version == VERSION_SD1_INPAINT ||
         version == VERSION_SD2_INPAINT ||
@@ -123,7 +131,8 @@ static inline bool sd_version_is_dit(SDVersion version) {
     if (sd_version_is_flux(version) ||
         sd_version_is_sd3(version) ||
         sd_version_is_wan(version) ||
-        sd_version_is_qwen_image(version)) {
+        sd_version_is_qwen_image(version) ||
+        sd_version_is_zimage(version)) {
         return true;
     }
     return false;
diff --git a/qwen3.hpp b/qwen3.hpp
new file mode 100644
index 000000000..be1643844
--- /dev/null
+++ b/qwen3.hpp
@@ -0,0 +1,303 @@
+#ifndef __QWEN3_HPP__
+#define __QWEN3_HPP__
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "ggml_extend.hpp"
+#include "json.hpp"
+
+namespace Qwen3 {
+    constexpr int QWEN3_GRAPH_SIZE = 10240;
+
+    struct Qwen3MLP : public GGMLBlock {
+    public:
+        Qwen3MLP(int64_t hidden_size, int64_t intermediate_size) {
+            blocks["gate_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, false));
+            blocks["up_proj"]   = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, intermediate_size, false));
+            blocks["down_proj"] = std::shared_ptr<GGMLBlock>(new Linear(intermediate_size, hidden_size, false));
+        }
+
+        struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
+            auto gate_proj = std::dynamic_pointer_cast<Linear>(blocks["gate_proj"]);
+            auto up_proj   = std::dynamic_pointer_cast<Linear>(blocks["up_proj"]);
+            auto down_proj = std::dynamic_pointer_cast<Linear>(blocks["down_proj"]);
+
+            auto h = gate_proj->forward(ctx, x);
+            h      = ggml_silu_inplace(ctx->ggml_ctx, h);
+            h      = ggml_mul_inplace(ctx->ggml_ctx, h, up_proj->forward(ctx, x));
+            h      = down_proj->forward(ctx, h);
+            return h;
+        }
+    };
+
+    struct Qwen3Attention : public GGMLBlock {
+    protected:
+        int64_t head_dim;
+        int64_t num_heads;
+        int64_t num_kv_heads;
+        float rope_theta;
+
+    public:
+        Qwen3Attention(int64_t hidden_size,
+                       int64_t num_heads,
+                       int64_t num_kv_heads,
+                       int64_t head_dim = 128,
+                       float rope_theta = 1000000.f)
+            : head_dim(head_dim), num_heads(num_heads), num_kv_heads(num_kv_heads), rope_theta(rope_theta) {
+            blocks["q_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, num_heads * head_dim, false));
+            blocks["k_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, num_kv_heads * head_dim, false));
+            blocks["v_proj"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, num_kv_heads * head_dim, false));
+            blocks["o_proj"] = std::shared_ptr<GGMLBlock>(new Linear(num_heads * head_dim, hidden_size, false));
+            blocks["q_norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(head_dim, 1e-6f));
+            blocks["k_norm"] = std::shared_ptr<GGMLBlock>(new RMSNorm(head_dim, 1e-6f));
+        }
+
+        struct ggml_tensor* forward(GGMLRunnerContext* ctx,
+                                    struct ggml_tensor* x,
+                                    struct ggml_tensor* input_pos) {
+            int64_t n_token = x->ne[1];
+            int64_t N       = x->ne[2];
+            auto q_proj     = std::dynamic_pointer_cast<Linear>(blocks["q_proj"]);
+            auto k_proj     = std::dynamic_pointer_cast<Linear>(blocks["k_proj"]);
+            auto v_proj     = std::dynamic_pointer_cast<Linear>(blocks["v_proj"]);
+            auto out_proj   = std::dynamic_pointer_cast<Linear>(blocks["o_proj"]);
+            auto q_norm     = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
+            auto k_norm     = std::dynamic_pointer_cast<RMSNorm>(blocks["k_norm"]);
+
+            auto q
= q_proj->forward(ctx, x); + auto k = k_proj->forward(ctx, x); + auto v = v_proj->forward(ctx, x); + + q = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim, num_heads, n_token, N); + k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_kv_heads, n_token, N); + v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_kv_heads, n_token, N); + + q = q_norm->forward(ctx, q); + k = k_norm->forward(ctx, k); + + q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, GGML_ROPE_TYPE_NEOX, 128000, rope_theta, 1.f, 0.f, 1.f, 32.f, 1.f); + k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, head_dim, GGML_ROPE_TYPE_NEOX, 128000, rope_theta, 1.f, 0.f, 1.f, 32.f, 1.f); + + q = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 0, 2, 1, 3)); + q = ggml_reshape_3d(ctx->ggml_ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]); + + k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); + k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); + + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, true, true, false); + + x = out_proj->forward(ctx, x); + return x; + } + }; + + struct Qwen3Block : public GGMLBlock { + public: + Qwen3Block(int64_t hidden_size, + int64_t intermediate_size, + int64_t num_heads, + int64_t num_kv_heads, + int64_t head_dim = 128, + float rope_theta = 1000000.f, + float eps = 1e-6f) { + blocks["self_attn"] = std::shared_ptr(new Qwen3Attention(hidden_size, num_heads, num_kv_heads, head_dim, rope_theta)); + blocks["mlp"] = std::shared_ptr(new Qwen3MLP(hidden_size, intermediate_size)); + blocks["input_layernorm"] = std::shared_ptr(new RMSNorm(hidden_size, eps)); + blocks["post_attention_layernorm"] = std::shared_ptr(new RMSNorm(hidden_size, eps)); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, + struct ggml_tensor* x, + struct ggml_tensor* input_pos) { + auto self_attn = std::dynamic_pointer_cast(blocks["self_attn"]); + auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); + auto input_layernorm = std::dynamic_pointer_cast(blocks["input_layernorm"]); + auto post_attention_layernorm = std::dynamic_pointer_cast(blocks["post_attention_layernorm"]); + + auto residual = x; + x = input_layernorm->forward(ctx, x); + x = self_attn->forward(ctx, x, input_pos); + x = ggml_add_inplace(ctx->ggml_ctx, x, residual); + + residual = x; + x = post_attention_layernorm->forward(ctx, x); + x = mlp->forward(ctx, x); + x = ggml_add_inplace(ctx->ggml_ctx, x, residual); + + return x; + } + }; + + struct Qwen3TextModel : public GGMLBlock { + protected: + int64_t num_layers; + + public: + Qwen3TextModel(int64_t num_layers, + int64_t vocab_size, + int64_t hidden_size, + int64_t intermediate_size, + int64_t num_heads, + int64_t num_kv_heads, + int64_t head_dim = 128, + float rope_theta = 1000000.f, + float eps = 1e-6f) + : num_layers(num_layers) { + blocks["embed_tokens"] = std::shared_ptr(new Embedding(vocab_size, hidden_size)); + for (int i = 0; i < num_layers; i++) { + blocks["layers." 
+ std::to_string(i)] = std::shared_ptr(new Qwen3Block(hidden_size, + intermediate_size, + num_heads, + num_kv_heads, + head_dim, + rope_theta, + eps)); + } + blocks["norm"] = std::shared_ptr(new RMSNorm(hidden_size, eps)); + } + + // This matches the diffusers implementation which uses .hidden_states[-2] + struct ggml_tensor* forward(GGMLRunnerContext* ctx, + struct ggml_tensor* input_ids, + struct ggml_tensor* input_pos, + bool return_second_to_last = false) { + auto embed_tokens = std::dynamic_pointer_cast(blocks["embed_tokens"]); + auto norm = std::dynamic_pointer_cast(blocks["norm"]); + + auto x = embed_tokens->forward(ctx, input_ids); + + struct ggml_tensor* second_to_last = nullptr; + for (int i = 0; i < num_layers; i++) { + auto block = std::dynamic_pointer_cast(blocks["layers." + std::to_string(i)]); + x = block->forward(ctx, x, input_pos); + if (i == num_layers - 2) { + second_to_last = x; // Save output from second-to-last layer + } + } + + if (return_second_to_last && second_to_last != nullptr) { + return second_to_last; // Return without final norm + } + + x = norm->forward(ctx, x); + return x; + } + }; + + struct Qwen3Params { + int64_t num_layers = 36; + int64_t hidden_size = 2560; + int64_t intermediate_size = 9728; + int64_t num_heads = 32; + int64_t num_kv_heads = 8; + int64_t head_dim = 128; + int64_t vocab_size = 151936; + float rope_theta = 1000000.f; + float rms_norm_eps = 1e-06f; + }; + + struct Qwen3 : public GGMLBlock { + Qwen3Params params; + + public: + Qwen3() {} + Qwen3(Qwen3Params params) + : params(params) { + blocks["model"] = std::shared_ptr(new Qwen3TextModel(params.num_layers, + params.vocab_size, + params.hidden_size, + params.intermediate_size, + params.num_heads, + params.num_kv_heads, + params.head_dim, + params.rope_theta, + params.rms_norm_eps)); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, + struct ggml_tensor* input_ids, + struct ggml_tensor* input_pos, + bool return_second_to_last = false) { + auto model = std::dynamic_pointer_cast(blocks["model"]); + auto x = model->forward(ctx, input_ids, input_pos, return_second_to_last); + return x; + } + }; + + struct Qwen3Runner : public GGMLRunner { + Qwen3Params params; + Qwen3 model; + std::vector input_pos_vec; + bool return_second_to_last = true; // For Z-Image, return hidden_states[-2] + + Qwen3Runner(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map, + const std::string prefix) + : GGMLRunner(backend, offload_params_to_cpu) { + model = Qwen3(params); + model.init(params_ctx, tensor_storage_map, prefix); + } + + std::string get_desc() override { + return "qwen3"; + } + + void get_param_tensors(std::map& tensors, const std::string prefix) { + model.get_param_tensors(tensors, prefix); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, + struct ggml_tensor* input_ids, + struct ggml_tensor* input_pos) { + auto hidden_states = model.forward(ctx, input_ids, input_pos, return_second_to_last); + return hidden_states; + } + + struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids) { + struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); + + input_ids = to_backend(input_ids); + + int64_t n_tokens = input_ids->ne[0]; + input_pos_vec.resize(n_tokens); + for (int i = 0; i < n_tokens; ++i) { + input_pos_vec[i] = i; + } + + auto input_pos = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, n_tokens); + set_backend_tensor_data(input_pos, input_pos_vec.data()); + + auto runner_ctx = get_context(); + + struct ggml_tensor* hidden_states = 
forward(&runner_ctx, input_ids, input_pos); + + ggml_build_forward_expand(gf, hidden_states); + + return gf; + } + + void compute(const int n_threads, + struct ggml_tensor* input_ids, + ggml_tensor** output, + ggml_context* output_ctx = nullptr) { + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph(input_ids); + }; + GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); + } + }; + +}; + +#endif diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index b129d53d4..a32be49bd 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -44,6 +44,7 @@ const char* model_version_to_str[] = { "Wan 2.2 I2V", "Wan 2.2 TI2V", "Qwen Image", + "Z-Image", }; const char* sampling_methods_str[] = { @@ -282,6 +283,13 @@ class StableDiffusionGGML { } } + if (strlen(SAFE_STR(sd_ctx_params->qwen3_path)) > 0) { + LOG_INFO("loading qwen3 from '%s'", sd_ctx_params->qwen3_path); + if (!model_loader.init_from_file(sd_ctx_params->qwen3_path, "text_encoders.qwen3.")) { + LOG_WARN("loading qwen3 from '%s' failed", sd_ctx_params->qwen3_path); + } + } + if (strlen(SAFE_STR(sd_ctx_params->qwen2vl_vision_path)) > 0) { LOG_INFO("loading qwen2vl vision from '%s'", sd_ctx_params->qwen2vl_vision_path); if (!model_loader.init_from_file(sd_ctx_params->qwen2vl_vision_path, "text_encoders.qwen2vl.visual.")) { @@ -381,6 +389,9 @@ class StableDiffusionGGML { shift_factor = 0.1159f; } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { scale_factor = 1.0f; + } else if (sd_version_is_zimage(version)) { + scale_factor = 0.3611f; + shift_factor = 0.1159f; } if (sd_version_is_control(version)) { @@ -479,6 +490,13 @@ class StableDiffusionGGML { tensor_storage_map, "model.diffusion_model", version); + } else if (sd_version_is_zimage(version)) { + cond_stage_model = std::make_shared(clip_backend, + offload_params_to_cpu, + tensor_storage_map); + diffusion_model = std::make_shared(backend, + offload_params_to_cpu, + tensor_storage_map); } else { // SD1.x SD2.x SDXL if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) { cond_stage_model = std::make_shared(clip_backend, @@ -844,6 +862,13 @@ class StableDiffusionGGML { shift = 3.0; } denoiser = std::make_shared(shift); + } else if (sd_version_is_zimage(version)) { + LOG_INFO("running in Z-Image FLOW mode"); + float shift = sd_ctx_params->flow_shift; + if (shift == INFINITY) { + shift = 3.0; + } + denoiser = std::make_shared(shift); } else if (is_using_v_parameterization) { LOG_INFO("running in v-prediction mode"); denoiser = std::make_shared(); @@ -2398,6 +2423,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { "t5xxl_path: %s\n" "qwen2vl_path: %s\n" "qwen2vl_vision_path: %s\n" + "qwen3_path: %s\n" "diffusion_model_path: %s\n" "high_noise_diffusion_model_path: %s\n" "vae_path: %s\n" @@ -2429,6 +2455,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { SAFE_STR(sd_ctx_params->t5xxl_path), SAFE_STR(sd_ctx_params->qwen2vl_path), SAFE_STR(sd_ctx_params->qwen2vl_vision_path), + SAFE_STR(sd_ctx_params->qwen3_path), SAFE_STR(sd_ctx_params->diffusion_model_path), SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path), SAFE_STR(sd_ctx_params->vae_path), diff --git a/stable-diffusion.h b/stable-diffusion.h index 309da9b1a..282c12b17 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -158,6 +158,7 @@ typedef struct { const char* t5xxl_path; const char* qwen2vl_path; const char* qwen2vl_vision_path; + const char* qwen3_path; const char* diffusion_model_path; const char* 
high_noise_diffusion_model_path; const char* vae_path; diff --git a/zimage.hpp b/zimage.hpp new file mode 100644 index 000000000..9d06ad810 --- /dev/null +++ b/zimage.hpp @@ -0,0 +1,730 @@ +#ifndef __ZIMAGE_HPP__ +#define __ZIMAGE_HPP__ + +#include "ggml_extend.hpp" +#include "rope.hpp" + +namespace ZImage { + constexpr int ZIMAGE_GRAPH_SIZE = 40960; + constexpr int ADALN_EMBED_DIM = 256; + constexpr int SEQ_MULTI_OF = 32; + + struct TimestepEmbedder : public UnaryBlock { + protected: + int64_t out_size; + int64_t frequency_embedding_size; + + public: + TimestepEmbedder(int64_t out_size, int64_t mid_size = 1024, int64_t frequency_embedding_size = 256) + : out_size(out_size), frequency_embedding_size(frequency_embedding_size) { + blocks["mlp.0"] = std::shared_ptr(new Linear(frequency_embedding_size, mid_size, true)); + blocks["mlp.2"] = std::shared_ptr(new Linear(mid_size, out_size, true)); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* t) override { + auto mlp_0 = std::dynamic_pointer_cast(blocks["mlp.0"]); + auto mlp_2 = std::dynamic_pointer_cast(blocks["mlp.2"]); + + auto t_freq = ggml_ext_timestep_embedding(ctx->ggml_ctx, t, (int)frequency_embedding_size, 10000, 1.0f); + auto t_emb = mlp_0->forward(ctx, t_freq); + t_emb = ggml_silu_inplace(ctx->ggml_ctx, t_emb); + t_emb = mlp_2->forward(ctx, t_emb); + return t_emb; + } + }; + + struct ZImageFeedForward : public UnaryBlock { + protected: + int64_t dim; + int64_t hidden_dim; + + public: + ZImageFeedForward(int64_t dim) + : dim(dim) { + hidden_dim = (int64_t)(dim / 3 * 8); + blocks["w1"] = std::shared_ptr(new Linear(dim, hidden_dim, false)); + blocks["w2"] = std::shared_ptr(new Linear(hidden_dim, dim, false)); + blocks["w3"] = std::shared_ptr(new Linear(dim, hidden_dim, false)); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { + auto w1 = std::dynamic_pointer_cast(blocks["w1"]); + auto w2 = std::dynamic_pointer_cast(blocks["w2"]); + auto w3 = std::dynamic_pointer_cast(blocks["w3"]); + + auto h1 = w1->forward(ctx, x); + h1 = ggml_silu_inplace(ctx->ggml_ctx, h1); + auto h3 = w3->forward(ctx, x); + h1 = ggml_mul_inplace(ctx->ggml_ctx, h1, h3); + return w2->forward(ctx, h1); + } + }; + + struct ZImageSelfAttention : public GGMLBlock { + protected: + int64_t dim; + int64_t num_heads; + int64_t head_dim; + bool qk_norm; + + public: + ZImageSelfAttention(int64_t dim, int64_t num_heads, bool qk_norm = true) + : dim(dim), num_heads(num_heads), qk_norm(qk_norm) { + head_dim = dim / num_heads; + blocks["qkv"] = std::shared_ptr(new Linear(dim, 3 * dim, false)); + blocks["out"] = std::shared_ptr(new Linear(dim, dim, false)); + if (qk_norm) { + blocks["q_norm"] = std::shared_ptr(new RMSNorm(head_dim, 1e-5f)); + blocks["k_norm"] = std::shared_ptr(new RMSNorm(head_dim, 1e-5f)); + } + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, + struct ggml_tensor* x, + struct ggml_tensor* pe, + struct ggml_tensor* attn_mask = nullptr) { + auto qkv_proj = std::dynamic_pointer_cast(blocks["qkv"]); + auto out = std::dynamic_pointer_cast(blocks["out"]); + + int64_t n_token = x->ne[1]; + int64_t N = x->ne[2]; + + auto qkv = qkv_proj->forward(ctx, x); + size_t elem_size = ggml_element_size(qkv); + auto q_view = ggml_view_3d(ctx->ggml_ctx, qkv, dim, n_token, N, qkv->nb[1], qkv->nb[2], 0 * dim * elem_size); + auto k_view = ggml_view_3d(ctx->ggml_ctx, qkv, dim, n_token, N, qkv->nb[1], qkv->nb[2], 1 * dim * elem_size); + auto v_view = ggml_view_3d(ctx->ggml_ctx, qkv, dim, n_token, N, 
qkv->nb[1], qkv->nb[2], 2 * dim * elem_size); + q_view = ggml_cont(ctx->ggml_ctx, q_view); + k_view = ggml_cont(ctx->ggml_ctx, k_view); + v_view = ggml_cont(ctx->ggml_ctx, v_view); + + auto q = ggml_reshape_4d(ctx->ggml_ctx, q_view, head_dim, num_heads, n_token, N); + auto k = ggml_reshape_4d(ctx->ggml_ctx, k_view, head_dim, num_heads, n_token, N); + auto v = ggml_reshape_4d(ctx->ggml_ctx, v_view, head_dim, num_heads, n_token, N); + + if (qk_norm) { + auto norm_q = std::dynamic_pointer_cast(blocks["q_norm"]); + auto norm_k = std::dynamic_pointer_cast(blocks["k_norm"]); + q = norm_q->forward(ctx, q); + k = norm_k->forward(ctx, k); + } + + if (pe != nullptr) { + float kv_scale = ctx->flash_attn_enabled ? (1.0f / 256.0f) : 1.0f; + x = Rope::attention(ctx, q, k, v, pe, attn_mask, kv_scale, true); + } else { + float kv_scale = ctx->flash_attn_enabled ? (1.0f / 256.0f) : 1.0f; + q = ggml_reshape_3d(ctx->ggml_ctx, q, dim, n_token, N); + k = ggml_reshape_3d(ctx->ggml_ctx, k, dim, n_token, N); + v = ggml_reshape_3d(ctx->ggml_ctx, v, dim, n_token, N); + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attn_mask, false, false, ctx->flash_attn_enabled, kv_scale); + } + x = out->forward(ctx, x); + return x; + } + }; + + struct ZImageTransformerBlock : public GGMLBlock { + protected: + int64_t dim; + bool modulation; + + public: + ZImageTransformerBlock(int64_t dim, + int64_t n_heads, + float norm_eps = 1e-5f, + bool qk_norm = true, + bool modulation = true) + : dim(dim), modulation(modulation) { + blocks["attention"] = std::shared_ptr(new ZImageSelfAttention(dim, n_heads, qk_norm)); + blocks["feed_forward"] = std::shared_ptr(new ZImageFeedForward(dim)); + blocks["attention_norm1"] = std::shared_ptr(new RMSNorm(dim, norm_eps)); + blocks["ffn_norm1"] = std::shared_ptr(new RMSNorm(dim, norm_eps)); + blocks["attention_norm2"] = std::shared_ptr(new RMSNorm(dim, norm_eps)); + blocks["ffn_norm2"] = std::shared_ptr(new RMSNorm(dim, norm_eps)); + if (modulation) { + int64_t adaln_in = dim < ADALN_EMBED_DIM ? 
dim : ADALN_EMBED_DIM; + blocks["adaLN_modulation.0"] = std::shared_ptr(new Linear(adaln_in, 4 * dim, true)); + } + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, + struct ggml_tensor* x, + struct ggml_tensor* attn_mask, + struct ggml_tensor* freqs_cis, + struct ggml_tensor* adaln_input = nullptr) { + auto attention = std::dynamic_pointer_cast(blocks["attention"]); + auto feed_forward = std::dynamic_pointer_cast(blocks["feed_forward"]); + auto attention_norm1 = std::dynamic_pointer_cast(blocks["attention_norm1"]); + auto ffn_norm1 = std::dynamic_pointer_cast(blocks["ffn_norm1"]); + auto attention_norm2 = std::dynamic_pointer_cast(blocks["attention_norm2"]); + auto ffn_norm2 = std::dynamic_pointer_cast(blocks["ffn_norm2"]); + + if (modulation && adaln_input != nullptr) { + auto adaLN = std::dynamic_pointer_cast(blocks["adaLN_modulation.0"]); + auto mod = adaLN->forward(ctx, adaln_input); + int64_t B = mod->ne[1]; + size_t elem_size = ggml_element_size(mod); + int64_t stride_B = dim * 4 * elem_size; + + auto scale_msa = ggml_view_2d(ctx->ggml_ctx, mod, dim, B, stride_B, 0 * dim * elem_size); + auto gate_msa = ggml_view_2d(ctx->ggml_ctx, mod, dim, B, stride_B, 1 * dim * elem_size); + auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, mod, dim, B, stride_B, 2 * dim * elem_size); + auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, mod, dim, B, stride_B, 3 * dim * elem_size); + + scale_msa = ggml_reshape_3d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, scale_msa), dim, 1, B); + gate_msa = ggml_tanh(ctx->ggml_ctx, ggml_reshape_3d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, gate_msa), dim, 1, B)); + scale_mlp = ggml_reshape_3d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, scale_mlp), dim, 1, B); + gate_mlp = ggml_tanh(ctx->ggml_ctx, ggml_reshape_3d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, gate_mlp), dim, 1, B)); + + auto normed = attention_norm1->forward(ctx, x); + normed = ggml_add(ctx->ggml_ctx, normed, ggml_mul(ctx->ggml_ctx, normed, scale_msa)); + auto attn_out = attention->forward(ctx, normed, freqs_cis, attn_mask); + attn_out = attention_norm2->forward(ctx, attn_out); + attn_out = ggml_mul_inplace(ctx->ggml_ctx, attn_out, gate_msa); + x = ggml_add_inplace(ctx->ggml_ctx, x, attn_out); + + normed = ffn_norm1->forward(ctx, x); + normed = ggml_add(ctx->ggml_ctx, normed, ggml_mul(ctx->ggml_ctx, normed, scale_mlp)); + auto ffn_out = feed_forward->forward(ctx, normed); + ffn_out = ffn_norm2->forward(ctx, ffn_out); + ffn_out = ggml_mul_inplace(ctx->ggml_ctx, ffn_out, gate_mlp); + x = ggml_add_inplace(ctx->ggml_ctx, x, ffn_out); + } else { + auto normed = attention_norm1->forward(ctx, x); + auto attn_out = attention->forward(ctx, normed, freqs_cis, attn_mask); + attn_out = attention_norm2->forward(ctx, attn_out); + x = ggml_add_inplace(ctx->ggml_ctx, x, attn_out); + + normed = ffn_norm1->forward(ctx, x); + auto ffn_out = feed_forward->forward(ctx, normed); + ffn_out = ffn_norm2->forward(ctx, ffn_out); + x = ggml_add_inplace(ctx->ggml_ctx, x, ffn_out); + } + return x; + } + }; + + struct FinalLayer : public GGMLBlock { + protected: + int64_t hidden_size; + + public: + FinalLayer(int64_t hidden_size, int64_t out_channels) + : hidden_size(hidden_size) { + blocks["norm_final"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + blocks["linear"] = std::shared_ptr(new Linear(hidden_size, out_channels, true)); + int64_t adaln_in = hidden_size < ADALN_EMBED_DIM ? 
hidden_size : ADALN_EMBED_DIM; + blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(adaln_in, hidden_size, true)); + } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, + struct ggml_tensor* x, + struct ggml_tensor* c) { + auto norm_final = std::dynamic_pointer_cast(blocks["norm_final"]); + auto linear = std::dynamic_pointer_cast(blocks["linear"]); + auto adaLN = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); + + auto scale = ggml_silu(ctx->ggml_ctx, c); + scale = adaLN->forward(ctx, scale); + scale = ggml_reshape_3d(ctx->ggml_ctx, scale, scale->ne[0], 1, scale->ne[1]); + + x = norm_final->forward(ctx, x); + x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, x, scale)); + x = linear->forward(ctx, x); + return x; + } + }; + + struct ZImageTransformer2DModel : public GGMLBlock { + protected: + int64_t in_channels; + int64_t dim; + int64_t n_layers; + int64_t n_refiner_layers; + int64_t n_heads; + int64_t patch_size; + float t_scale; + int theta; + std::vector axes_dims; + std::vector axes_lens; + + public: + ZImageTransformer2DModel(int64_t in_channels = 16, + int64_t dim = 3840, + int64_t n_layers = 30, + int64_t n_refiner_layers = 2, + int64_t n_heads = 30, + float norm_eps = 1e-5f, + bool qk_norm = true, + int64_t cap_feat_dim = 2560, + int64_t patch_size = 2, + float t_scale = 1000.f, + int theta_ = 256, + std::vector axes_dims = {32, 48, 48}, + std::vector axes_lens = {1536, 512, 512}) + : in_channels(in_channels), dim(dim), n_layers(n_layers), n_refiner_layers(n_refiner_layers), + n_heads(n_heads), patch_size(patch_size), t_scale(t_scale), theta(theta_), + axes_dims(axes_dims), axes_lens(axes_lens) { + blocks["x_embedder"] = std::shared_ptr( + new Linear(patch_size * patch_size * in_channels, dim, true)); + blocks["final_layer"] = std::shared_ptr( + new FinalLayer(dim, patch_size * patch_size * in_channels)); + + for (int i = 0; i < n_refiner_layers; i++) { + blocks["noise_refiner." + std::to_string(i)] = std::shared_ptr( + new ZImageTransformerBlock(dim, n_heads, norm_eps, qk_norm, true)); + } + + for (int i = 0; i < n_refiner_layers; i++) { + blocks["context_refiner." + std::to_string(i)] = std::shared_ptr( + new ZImageTransformerBlock(dim, n_heads, norm_eps, qk_norm, false)); + } + + int64_t adaln_size = dim < ADALN_EMBED_DIM ? dim : ADALN_EMBED_DIM; + blocks["t_embedder"] = std::shared_ptr(new TimestepEmbedder(adaln_size, 1024, 256)); + + blocks["cap_embedder.0"] = std::shared_ptr(new RMSNorm(cap_feat_dim, norm_eps)); + blocks["cap_embedder.1"] = std::shared_ptr(new Linear(cap_feat_dim, dim, true)); + + for (int i = 0; i < n_layers; i++) { + blocks["layers." 
+ std::to_string(i)] = std::shared_ptr( + new ZImageTransformerBlock(dim, n_heads, norm_eps, qk_norm, true)); + } + } + + + int64_t get_dim() const { return dim; } + int64_t get_patch_size() const { return patch_size; } + int64_t get_in_channels() const { return in_channels; } + float get_t_scale() const { return t_scale; } + int get_theta() const { return theta; } + const std::vector& get_axes_dims() const { return axes_dims; } + const std::vector& get_axes_lens() const { return axes_lens; } + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, + struct ggml_tensor* x, + struct ggml_tensor* t, + struct ggml_tensor* cap_feats, + struct ggml_tensor* x_freqs_cis, + struct ggml_tensor* cap_freqs_cis, + struct ggml_tensor* unified_freqs_cis, + int64_t x_seq_len, + int64_t cap_seq_len) { + auto x_embedder = std::dynamic_pointer_cast(blocks["x_embedder"]); + auto cap_norm = std::dynamic_pointer_cast(blocks["cap_embedder.0"]); + auto cap_proj = std::dynamic_pointer_cast(blocks["cap_embedder.1"]); + auto final_layer = std::dynamic_pointer_cast(blocks["final_layer"]); + auto t_embedder = std::dynamic_pointer_cast(blocks["t_embedder"]); + + x = x_embedder->forward(ctx, x); + auto t_scaled = ggml_scale(ctx->ggml_ctx, t, t_scale); + auto c = t_embedder->forward(ctx, t_scaled); + cap_feats = cap_norm->forward(ctx, cap_feats); + cap_feats = cap_proj->forward(ctx, cap_feats); + + for (int i = 0; i < n_refiner_layers; i++) { + auto block = std::dynamic_pointer_cast( + blocks["noise_refiner." + std::to_string(i)]); + x = block->forward(ctx, x, nullptr, x_freqs_cis, c); + } + + for (int i = 0; i < n_refiner_layers; i++) { + auto block = std::dynamic_pointer_cast( + blocks["context_refiner." + std::to_string(i)]); + cap_feats = block->forward(ctx, cap_feats, nullptr, cap_freqs_cis, nullptr); + } + + x = ggml_concat(ctx->ggml_ctx, x, cap_feats, 1); + int64_t total_seq_len = x->ne[1]; + + for (int i = 0; i < n_layers; i++) { + auto block = std::dynamic_pointer_cast( + blocks["layers." 
+ std::to_string(i)]); + x = block->forward(ctx, x, nullptr, unified_freqs_cis, c); + } + + x = ggml_view_3d(ctx->ggml_ctx, x, x->ne[0], x_seq_len, x->ne[2], + x->nb[1], x->nb[2], 0); + x = ggml_cont(ctx->ggml_ctx, x); + + x = final_layer->forward(ctx, x, c); + + return x; + } + }; + + __STATIC_INLINE__ std::vector gen_zimage_pe( + const std::vector>& pos_ids, + int theta, + const std::vector& axes_dims) { + std::vector> ids(pos_ids.size(), std::vector(3)); + for (size_t i = 0; i < pos_ids.size(); i++) { + ids[i][0] = (float)pos_ids[i][0]; + ids[i][1] = (float)pos_ids[i][1]; + ids[i][2] = (float)pos_ids[i][2]; + } + return Rope::embed_nd(ids, 1, theta, axes_dims); + } + + __STATIC_INLINE__ std::vector> gen_zimage_img_ids(int h, int w, int patch_size, int bs, int axis0_offset = 0) { + int h_len = h / patch_size; + int w_len = w / patch_size; + std::vector> ids(bs * h_len * w_len, std::vector(3)); + for (int b = 0; b < bs; b++) { + for (int i = 0; i < h_len; i++) { + for (int j = 0; j < w_len; j++) { + int idx = b * h_len * w_len + i * w_len + j; + ids[idx][0] = (float)axis0_offset; + ids[idx][1] = (float)i; + ids[idx][2] = (float)j; + } + } + } + return ids; + } + + __STATIC_INLINE__ std::vector> gen_zimage_cap_ids(int cap_seq_len, int bs, int axis0_start = 1) { + std::vector> ids(bs * cap_seq_len, std::vector(3)); + for (int b = 0; b < bs; b++) { + for (int i = 0; i < cap_seq_len; i++) { + int idx = b * cap_seq_len + i; + ids[idx][0] = (float)(axis0_start + i); + ids[idx][1] = 0.0f; + ids[idx][2] = 0.0f; + } + } + return ids; + } + + __STATIC_INLINE__ std::vector cpu_patchify( + const float* src, int64_t W, int64_t H, int64_t C, int64_t B, int64_t patch_size) { + + int64_t pW = patch_size, pH = patch_size; + int64_t W_tok = W / pW; + int64_t H_tok = H / pH; + int64_t inner_dim = pH * pW * C; + int64_t seq_len = H_tok * W_tok; + + std::vector dst(B * seq_len * inner_dim); + + for (int64_t b = 0; b < B; b++) { + for (int64_t h_tok = 0; h_tok < H_tok; h_tok++) { + for (int64_t w_tok = 0; w_tok < W_tok; w_tok++) { + int64_t seq_idx = h_tok * W_tok + w_tok; + for (int64_t ph = 0; ph < pH; ph++) { + for (int64_t pw = 0; pw < pW; pw++) { + for (int64_t c = 0; c < C; c++) { + int64_t inner_idx = c + C * (pw + pW * ph); + + int64_t src_w = pw + pW * w_tok; + int64_t src_h = ph + pH * h_tok; + int64_t src_idx = src_w + W * (src_h + H * (c + C * b)); + + int64_t dst_idx = inner_idx + inner_dim * (seq_idx + seq_len * b); + + dst[dst_idx] = src[src_idx]; + } + } + } + } + } + } + return dst; + } + + __STATIC_INLINE__ std::vector cpu_unpatchify( + const float* src, int64_t W, int64_t H, int64_t C, int64_t B, int64_t patch_size) { + + int64_t pW = patch_size, pH = patch_size; + int64_t W_tok = W / pW; + int64_t H_tok = H / pH; + int64_t inner_dim = pH * pW * C; + int64_t seq_len = H_tok * W_tok; + + std::vector dst(B * C * H * W); + + for (int64_t b = 0; b < B; b++) { + for (int64_t h_tok = 0; h_tok < H_tok; h_tok++) { + for (int64_t w_tok = 0; w_tok < W_tok; w_tok++) { + int64_t seq_idx = h_tok * W_tok + w_tok; + for (int64_t ph = 0; ph < pH; ph++) { + for (int64_t pw = 0; pw < pW; pw++) { + for (int64_t c = 0; c < C; c++) { + int64_t inner_idx = c + C * (pw + pW * ph); + + int64_t dst_w = pw + pW * w_tok; + int64_t dst_h = ph + pH * h_tok; + int64_t dst_idx = dst_w + W * (dst_h + H * (c + C * b)); + + int64_t src_idx = inner_idx + inner_dim * (seq_idx + seq_len * b); + + dst[dst_idx] = src[src_idx]; + } + } + } + } + } + } + return dst; + } + + __STATIC_INLINE__ int64_t pad_cap_len(int64_t cap_len) { + 
int64_t padding = (-cap_len) % SEQ_MULTI_OF; + if (padding < 0) padding += SEQ_MULTI_OF; + return cap_len + padding; + } + + __STATIC_INLINE__ std::vector gen_zimage_unified_pe(int h, int w, int patch_size, int cap_seq_len, int bs, + int theta, const std::vector& axes_dims) { + int64_t cap_padded_len = pad_cap_len(cap_seq_len); + + auto img_ids = gen_zimage_img_ids(h, w, patch_size, bs, cap_padded_len + 1); + auto cap_ids = gen_zimage_cap_ids(cap_padded_len, bs, 1); + + auto ids = Rope::concat_ids(img_ids, cap_ids, bs); + return Rope::embed_nd(ids, bs, theta, axes_dims); + } + + struct ZImageRunner : public GGMLRunner { + ZImageTransformer2DModel model; + + std::vector unified_pe_vec; + std::vector img_pe_vec; + std::vector cap_pe_vec; + std::vector timestep_vec; + + ZImageRunner(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2TensorStorage& tensor_storage_map, + const std::string& prefix = "model.diffusion_model.") + : GGMLRunner(backend, offload_params_to_cpu) { + model = ZImageTransformer2DModel(); + model.init(params_ctx, tensor_storage_map, prefix); + + std::map model_tensors; + model.get_param_tensors(model_tensors, prefix); + } + + std::string get_desc() override { + return "Z-Image"; + } + + void get_param_tensors(std::map& tensors, const std::string& prefix = "model.diffusion_model.") { + model.get_param_tensors(tensors, prefix); + } + + struct ggml_cgraph* build_graph(struct ggml_tensor* x, + struct ggml_tensor* timestep, + struct ggml_tensor* cap_feats, + int height, + int width) { + struct ggml_cgraph* gf = new_graph_custom(ZIMAGE_GRAPH_SIZE); + + x = to_backend(x); + timestep = to_backend(timestep); + cap_feats = to_backend(cap_feats); + + int64_t C = x->ne[2]; + int64_t H = x->ne[1]; + int64_t W = x->ne[0]; + int64_t B = x->ne[3]; + int64_t patch_size = model.get_patch_size(); + + int64_t H_patches = H / patch_size; + int64_t W_patches = W / patch_size; + int64_t x_seq_len = H_patches * W_patches; + int64_t cap_seq_len = cap_feats->ne[1]; + + auto x_patchified = ggml_reshape_4d(compute_ctx, x, + patch_size, W_patches, + patch_size, H_patches * C * B); + x_patchified = ggml_cont(compute_ctx, ggml_ext_torch_permute(compute_ctx, x_patchified, 0, 2, 1, 3)); + x_patchified = ggml_reshape_3d(compute_ctx, x_patchified, + patch_size * patch_size * C, + H_patches * W_patches, + B); + + cap_feats = ggml_reshape_3d(compute_ctx, cap_feats, cap_feats->ne[0], cap_feats->ne[1], B); + + bool use_rope = true; + + struct ggml_tensor* x_freqs_cis = nullptr; + struct ggml_tensor* cap_freqs_cis = nullptr; + struct ggml_tensor* unified_freqs_cis = nullptr; + + if (use_rope) { + auto axes_dims = model.get_axes_dims(); + int theta = model.get_theta(); + std::vector axes_dims_int(axes_dims.begin(), axes_dims.end()); + int emb_dim = 0; + for (int d : axes_dims_int) emb_dim += d / 2; + + int64_t cap_padded_len = pad_cap_len(cap_seq_len); + int img_axis0_offset = (int)(cap_padded_len + 1); + auto img_ids = gen_zimage_img_ids(H, W, patch_size, B, img_axis0_offset); + img_pe_vec = Rope::embed_nd(img_ids, B, theta, axes_dims_int); + + x_freqs_cis = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, emb_dim, x_seq_len); + set_backend_tensor_data(x_freqs_cis, img_pe_vec.data()); + + auto cap_ids = gen_zimage_cap_ids(cap_seq_len, B, 1); + cap_pe_vec = Rope::embed_nd(cap_ids, B, theta, axes_dims_int); + cap_freqs_cis = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, emb_dim, cap_seq_len); + set_backend_tensor_data(cap_freqs_cis, cap_pe_vec.data()); + + auto unified_ids = 
Rope::concat_ids(img_ids, cap_ids, B); + unified_pe_vec = Rope::embed_nd(unified_ids, B, theta, axes_dims_int); + int64_t unified_seq_len = x_seq_len + cap_seq_len; + unified_freqs_cis = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, emb_dim, unified_seq_len); + set_backend_tensor_data(unified_freqs_cis, unified_pe_vec.data()); + } + + auto runner_ctx = get_context(); + struct ggml_tensor* out = model.forward(&runner_ctx, x_patchified, timestep, cap_feats, + x_freqs_cis, cap_freqs_cis, unified_freqs_cis, + x_seq_len, cap_seq_len); + + out = ggml_reshape_4d(compute_ctx, out, patch_size, patch_size, C, H_patches * W_patches * B); + out = ggml_cont(compute_ctx, ggml_ext_torch_permute(compute_ctx, out, 0, 2, 1, 3)); + out = ggml_reshape_4d(compute_ctx, out, patch_size * W_patches, patch_size * H_patches, C, B); + + ggml_build_forward_expand(gf, out); + + return gf; + } + + void compute(const int n_threads, + struct ggml_tensor* x, + struct ggml_tensor* timestep, + struct ggml_tensor* cap_feats, + int height, + int width, + ggml_tensor** output, + ggml_context* output_ctx = nullptr) { + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + int64_t C = x->ne[2]; + int64_t B = x->ne[3]; + int64_t patch_size = model.get_patch_size(); + int64_t H_patches = H / patch_size; + int64_t W_patches = W / patch_size; + int64_t inner_dim = patch_size * patch_size * C; + int64_t seq_len = H_patches * W_patches; + + std::vector x_cpu(W * H * C * B); + if (x->buffer != nullptr) { + ggml_backend_tensor_get(x, x_cpu.data(), 0, x_cpu.size() * sizeof(float)); + } else if (x->data != nullptr) { + memcpy(x_cpu.data(), x->data, x_cpu.size() * sizeof(float)); + } else { + LOG_ERROR("ZImage compute: x tensor has no data!"); + return; + } + + std::vector x_patchified_data = cpu_patchify( + x_cpu.data(), W, H, C, B, patch_size); + + size_t temp_ctx_size = ggml_tensor_overhead() * 4 + 2 * inner_dim * seq_len * B * sizeof(float); + struct ggml_init_params temp_params = { + /*.mem_size =*/ temp_ctx_size, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ false, + }; + struct ggml_context* temp_ctx = ggml_init(temp_params); + + struct ggml_tensor* x_patchified = ggml_new_tensor_3d(temp_ctx, GGML_TYPE_F32, inner_dim, seq_len, B); + memcpy(x_patchified->data, x_patchified_data.data(), x_patchified_data.size() * sizeof(float)); + + struct ggml_tensor* patchified_output = ggml_new_tensor_3d(temp_ctx, GGML_TYPE_F32, inner_dim, seq_len, B); + + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph_patchified(x_patchified, timestep, cap_feats, height, width, H_patches, W_patches); + }; + bool free_compute_buffer_immediately = !flash_attn_enabled; + GGMLRunner::compute(get_graph, n_threads, free_compute_buffer_immediately, &patchified_output, temp_ctx); + + if (patchified_output != nullptr && patchified_output->data != nullptr) { + std::vector patchified_output_data(inner_dim * seq_len * B); + memcpy(patchified_output_data.data(), patchified_output->data, + patchified_output_data.size() * sizeof(float)); + + std::vector unpatchified_data = cpu_unpatchify( + patchified_output_data.data(), W, H, C, B, patch_size); + + if (output != nullptr && *output != nullptr) { + memcpy((*output)->data, unpatchified_data.data(), unpatchified_data.size() * sizeof(float)); + } + } + + ggml_free(temp_ctx); + } + + struct ggml_cgraph* build_graph_patchified(struct ggml_tensor* x_patchified, + struct ggml_tensor* timestep, + struct ggml_tensor* cap_feats, + int height, + int width, + int64_t H_patches, + int64_t W_patches) { + struct 
ggml_cgraph* gf = new_graph_custom(ZIMAGE_GRAPH_SIZE); + + int64_t patch_size = model.get_patch_size(); + int64_t C = model.get_in_channels(); + int64_t B = x_patchified->ne[2]; + int64_t x_seq_len = H_patches * W_patches; + int64_t cap_seq_len = cap_feats->ne[1]; + int64_t H = H_patches * patch_size; + int64_t W = W_patches * patch_size; + + x_patchified = to_backend(x_patchified); + timestep = to_backend(timestep); + cap_feats = to_backend(cap_feats); + + cap_feats = ggml_reshape_3d(compute_ctx, cap_feats, cap_feats->ne[0], cap_feats->ne[1], B); + + struct ggml_tensor* x_freqs_cis = nullptr; + struct ggml_tensor* cap_freqs_cis = nullptr; + struct ggml_tensor* unified_freqs_cis = nullptr; + + auto axes_dims = model.get_axes_dims(); + int theta = model.get_theta(); + std::vector axes_dims_int(axes_dims.begin(), axes_dims.end()); + int emb_dim = 0; + for (int d : axes_dims_int) emb_dim += d / 2; + + int64_t cap_padded_len = pad_cap_len(cap_seq_len); + int img_axis0_offset = (int)(cap_padded_len + 1); + + auto img_ids = gen_zimage_img_ids(H, W, patch_size, B, img_axis0_offset); + img_pe_vec = Rope::embed_nd(img_ids, B, theta, axes_dims_int); + x_freqs_cis = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, emb_dim, x_seq_len); + set_backend_tensor_data(x_freqs_cis, img_pe_vec.data()); + + auto cap_ids = gen_zimage_cap_ids(cap_seq_len, B, 1); + cap_pe_vec = Rope::embed_nd(cap_ids, B, theta, axes_dims_int); + cap_freqs_cis = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, emb_dim, cap_seq_len); + set_backend_tensor_data(cap_freqs_cis, cap_pe_vec.data()); + + auto unified_ids = Rope::concat_ids(img_ids, cap_ids, B); + unified_pe_vec = Rope::embed_nd(unified_ids, B, theta, axes_dims_int); + int64_t unified_seq_len = x_seq_len + cap_seq_len; + unified_freqs_cis = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, emb_dim, unified_seq_len); + set_backend_tensor_data(unified_freqs_cis, unified_pe_vec.data()); + + auto runner_ctx = get_context(); + struct ggml_tensor* out = model.forward(&runner_ctx, x_patchified, timestep, cap_feats, + x_freqs_cis, cap_freqs_cis, unified_freqs_cis, + x_seq_len, cap_seq_len); + + ggml_set_name(out, "ggml_runner_final_result_tensor"); + + ggml_build_forward_expand(gf, out); + + return gf; + } + }; + +}; + +#endif
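
For readers unfamiliar with the shifted flow-matching schedule used by the new ZImageFlowDenoiser in the denoiser.hpp hunk above, the following standalone snippet (not part of the patch) reproduces its arithmetic: t_to_sigma maps a timestep index to a shifted sigma, and noise_scaling mixes latent and noise as x_t = (1 - sigma) * latent + sigma * noise. It assumes TIMESTEPS is 1000, the value the existing denoisers in denoiser.hpp use, and a default shift of 3.0 as in the patch.

// Illustrative sketch only; mirrors ZImageFlowDenoiser::t_to_sigma and noise_scaling.
// Assumption: TIMESTEPS = 1000, shift = 3.0f (defaults suggested by the patch).
#include <cstdio>

constexpr int TIMESTEPS = 1000;

// Shifted flow-matching schedule: a linear ramp in (0, 1], bent toward 1 by `shift`.
float t_to_sigma(float t, float shift = 3.0f) {
    float sigma_raw = (t + 1) / TIMESTEPS;
    return shift * sigma_raw / (1.0f + (shift - 1.0f) * sigma_raw);
}

int main() {
    // sigma_min = sigmas[0], sigma_max = sigmas[TIMESTEPS - 1], as in the struct above
    printf("sigma_min = %f\n", t_to_sigma(0));
    printf("sigma_mid = %f\n", t_to_sigma(TIMESTEPS / 2 - 1));
    printf("sigma_max = %f\n", t_to_sigma(TIMESTEPS - 1));

    // noise_scaling on scalars: x_t = (1 - sigma) * latent + sigma * noise
    // (the patch performs the same mix in-place on ggml tensors)
    float sigma  = t_to_sigma(TIMESTEPS - 1);
    float latent = 1.0f, noise = 0.5f;
    float x_t    = (1.0f - sigma) * latent + sigma * noise;
    printf("x_t at sigma_max = %f\n", x_t);
    return 0;
}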