2 changes: 2 additions & 0 deletions README.md
@@ -45,6 +45,7 @@ API and command-line option may change frequently.***
- [Chroma](./docs/chroma.md)
- [Chroma1-Radiance](./docs/chroma_radiance.md)
- [Qwen Image](./docs/qwen_image.md)
- [Z-Image](./docs/z_image.md)
- Image Edit Models
- [FLUX.1-Kontext-dev](./docs/kontext.md)
- [Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md)
@@ -129,6 +130,7 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
- [🔥Qwen Image](./docs/qwen_image.md)
- [🔥Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md)
- [🔥Wan2.1/Wan2.2](./docs/wan.md)
- [🔥Z-Image](./docs/z_image.md)
- [LoRA](./docs/lora.md)
- [LCM/LCM-LoRA](./docs/lcm.md)
- [Using PhotoMaker to personalize image generation](./docs/photo_maker.md)
Binary file added assets/z_image/bf16.png
Binary file added assets/z_image/q2_K.png
Binary file added assets/z_image/q3_K.png
Binary file added assets/z_image/q4_0.png
Binary file added assets/z_image/q4_K.png
Binary file added assets/z_image/q5_0.png
Binary file added assets/z_image/q6_K.png
Binary file added assets/z_image/q8_0.png
32 changes: 28 additions & 4 deletions conditioner.hpp
@@ -1638,6 +1638,8 @@ struct LLMEmbedder : public Conditioner {
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
if (sd_version_is_flux2(version)) {
arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
} else if (sd_version_is_z_image(version)) {
arch = LLM::LLMArch::QWEN3;
}
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
tokenizer = std::make_shared<LLM::MistralTokenizer>();
@@ -1785,9 +1787,31 @@ struct LLMEmbedder : public Conditioner {
prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
prompt += img_prompt;

prompt_attn_range.first = prompt.size();
prompt_attn_range.first = static_cast<int>(prompt.size());
prompt += conditioner_params.text;
prompt_attn_range.second = prompt.size();
prompt_attn_range.second = static_cast<int>(prompt.size());

prompt += "<|im_end|>\n<|im_start|>assistant\n";
} else if (sd_version_is_flux2(version)) {
prompt_template_encode_start_idx = 0;
out_layers = {10, 20, 30};

prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";

prompt_attn_range.first = static_cast<int>(prompt.size());
prompt += conditioner_params.text;
prompt_attn_range.second = static_cast<int>(prompt.size());

prompt += "[/INST]";
} else if (sd_version_is_z_image(version)) {
prompt_template_encode_start_idx = 0;
out_layers = {35}; // -2

prompt = "<|im_start|>user\n";

prompt_attn_range.first = static_cast<int>(prompt.size());
prompt += conditioner_params.text;
prompt_attn_range.second = static_cast<int>(prompt.size());

prompt += "<|im_end|>\n<|im_start|>assistant\n";
} else if (sd_version_is_flux2(version)) {
@@ -1806,9 +1830,9 @@

prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";

prompt_attn_range.first = prompt.size();
prompt_attn_range.first = static_cast<int>(prompt.size());
prompt += conditioner_params.text;
prompt_attn_range.second = prompt.size();
prompt_attn_range.second = static_cast<int>(prompt.size());

prompt += "<|im_end|>\n<|im_start|>assistant\n";
}
64 changes: 64 additions & 0 deletions diffusion_model.hpp
@@ -6,6 +6,7 @@
#include "qwen_image.hpp"
#include "unet.hpp"
#include "wan.hpp"
#include "z_image.hpp"

struct DiffusionParams {
struct ggml_tensor* x = nullptr;
@@ -357,4 +358,67 @@ struct QwenImageModel : public DiffusionModel {
}
};

struct ZImageModel : public DiffusionModel {
std::string prefix;
ZImage::ZImageRunner z_image;

ZImageModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "model.diffusion_model",
SDVersion version = VERSION_Z_IMAGE)
: prefix(prefix), z_image(backend, offload_params_to_cpu, tensor_storage_map, prefix, version) {
}

std::string get_desc() override {
return z_image.get_desc();
}

void alloc_params_buffer() override {
z_image.alloc_params_buffer();
}

void free_params_buffer() override {
z_image.free_params_buffer();
}

void free_compute_buffer() override {
z_image.free_compute_buffer();
}

void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) override {
z_image.get_param_tensors(tensors, prefix);
}

size_t get_params_buffer_size() override {
return z_image.get_params_buffer_size();
}

void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
z_image.set_weight_adapter(adapter);
}

int64_t get_adm_in_channels() override {
return 768;
}

void set_flash_attn_enabled(bool enabled) {
z_image.set_flash_attention_enabled(enabled);
}

void compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr) override {
return z_image.compute(n_threads,
diffusion_params.x,
diffusion_params.timesteps,
diffusion_params.context,
diffusion_params.ref_latents,
true, // increase_ref_index
output,
output_ctx);
}
};

#endif
28 changes: 28 additions & 0 deletions docs/z_image.md
@@ -0,0 +1,28 @@
# How to Use

You can run Z-Image with stable-diffusion.cpp on GPUs with 4 GB of VRAM, or even less.

## Download weights

- Download Z-Image-Turbo
  - safetensors: https://huggingface.co/Comfy-Org/z_image_turbo/tree/main/split_files/diffusion_models
  - gguf: https://huggingface.co/leejet/Z-Image-Turbo-GGUF/tree/main
- Download VAE (Z-Image uses the FLUX.1 VAE, `ae.safetensors`)
  - safetensors: https://huggingface.co/black-forest-labs/FLUX.1-schnell/tree/main
- Download Qwen3 4B
  - safetensors: https://huggingface.co/Comfy-Org/z_image_turbo/tree/main/split_files/text_encoders
  - gguf: https://huggingface.co/unsloth/Qwen3-4B-Instruct-2507-GGUF/tree/main

## Examples

```
.\bin\Release\sd.exe --diffusion-model z_image_turbo-Q3_K.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\Qwen3-4B-Instruct-2507-Q4_K_M.gguf -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' -- moody, atmospheric, profound, dark academic" --cfg-scale 1.0 -v --offload-to-cpu --diffusion-fa -H 1024 -W 512
```

<img width="256" alt="z-image example" src="../assets/z_image/q3_K.png" />
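
The command above uses Windows paths. On Linux or macOS the same options apply; only the executable and file paths change. Below is a minimal sketch, assuming the binary was built to `./build/bin/sd` and the weights were placed in a local `models/` directory; both paths are placeholders, so adjust them to your setup.

```
./build/bin/sd \
  --diffusion-model models/z_image_turbo-Q3_K.gguf \
  --vae models/ae.sft \
  --llm models/Qwen3-4B-Instruct-2507-Q4_K_M.gguf \
  -p "a lovely cat holding a sign that says 'Z-Image'" \
  --cfg-scale 1.0 --offload-to-cpu --diffusion-fa -H 1024 -W 512 -v
```

With a quantized diffusion model and `--offload-to-cpu`, this should fit on low-VRAM GPUs, which is what the 4 GB figure at the top of this page refers to.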

## Comparison of Different Quantization Types

| bf16 | q8_0 | q6_K | q5_0 | q4_K | q4_0 | q3_K | q2_K|
|---|---|---|---|---|---|---|---|
| <img width="256" alt="bf16" src="../assets/z_image/bf16.png" /> | <img width="256" alt="q8_0" src="../assets/z_image/q8_0.png" /> | <img width="256" alt="q6_K" src="../assets/z_image/q6_K.png" /> | <img width="256" alt="q5_0" src="../assets/z_image/q5_0.png" /> | <img width="256" alt="q4_K" src="../assets/z_image/q4_K.png" /> | <img width="256" alt="q4_0" src="../assets/z_image/q4_0.png" /> | <img width="256" alt="q3_K" src="../assets/z_image/q3_K.png" /> | <img width="256" alt="q2_K" src="../assets/z_image/q2_K.png" /> |
84 changes: 69 additions & 15 deletions llm.hpp
@@ -1,5 +1,5 @@
#ifndef __QWENVL_HPP__
#define __QWENVL_HPP__
#ifndef __LLM_HPP__
#define __LLM_HPP__

#include <algorithm>
#include <fstream>
@@ -256,7 +256,7 @@ namespace LLM {
ss << "\"" << token << "\", ";
}
ss << "]";
// LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
// printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str());
return bpe_tokens;
}
@@ -469,12 +469,14 @@

enum class LLMArch {
QWEN2_5_VL,
QWEN3,
MISTRAL_SMALL_3_2,
ARCH_COUNT,
};

static const char* llm_arch_to_str[] = {
"qwen2.5vl",
"qwen3",
"mistral_small3.2",
};

Expand All @@ -501,6 +503,7 @@ namespace LLM {
int64_t num_kv_heads = 4;
int64_t head_dim = 128;
bool qkv_bias = true;
bool qk_norm = false;
int64_t vocab_size = 152064;
float rms_norm_eps = 1e-06f;
LLMVisionParams vision;
@@ -813,14 +816,19 @@ namespace LLM {
int64_t head_dim;
int64_t num_heads;
int64_t num_kv_heads;
bool qk_norm;

public:
Attention(const LLMParams& params)
: num_heads(params.num_heads), num_kv_heads(params.num_kv_heads), head_dim(params.head_dim), arch(params.arch) {
: arch(params.arch), num_heads(params.num_heads), num_kv_heads(params.num_kv_heads), head_dim(params.head_dim), qk_norm(params.qk_norm) {
blocks["q_proj"] = std::make_shared<Linear>(params.hidden_size, num_heads * head_dim, params.qkv_bias);
blocks["k_proj"] = std::make_shared<Linear>(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias);
blocks["v_proj"] = std::make_shared<Linear>(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias);
blocks["o_proj"] = std::make_shared<Linear>(num_heads * head_dim, params.hidden_size, false);
if (params.qk_norm) {
blocks["q_norm"] = std::make_shared<RMSNorm>(head_dim, params.rms_norm_eps);
blocks["k_norm"] = std::make_shared<RMSNorm>(head_dim, params.rms_norm_eps);
}
}

struct ggml_tensor* forward(GGMLRunnerContext* ctx,
@@ -842,9 +850,20 @@ namespace LLM {
k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim]
v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim]

if (qk_norm) {
auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
auto k_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["k_norm"]);

q = q_norm->forward(ctx, q);
k = k_norm->forward(ctx, k);
}

if (arch == LLMArch::MISTRAL_SMALL_3_2) {
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 131072, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 131072, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
} else if (arch == LLMArch::QWEN3) {
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 151936, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NEOX, 151936, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
} else {
int sections[4] = {16, 24, 24, 0};
q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
@@ -1063,6 +1082,17 @@ namespace LLM {
params.qkv_bias = false;
params.vocab_size = 131072;
params.rms_norm_eps = 1e-5f;
} else if (arch == LLMArch::QWEN3) {
params.num_layers = 36;
params.hidden_size = 2560;
params.intermediate_size = 9728;
params.head_dim = 128;
params.num_heads = 32;
params.num_kv_heads = 8;
params.qkv_bias = false;
params.qk_norm = true;
params.vocab_size = 151936;
params.rms_norm_eps = 1e-6f;
}
bool have_vision_weight = false;
bool llama_cpp_style = false;
@@ -1132,7 +1162,7 @@ namespace LLM {
}

int64_t n_tokens = input_ids->ne[0];
if (params.arch == LLMArch::MISTRAL_SMALL_3_2) {
if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::QWEN3) {
input_pos_vec.resize(n_tokens);
for (int i = 0; i < n_tokens; ++i) {
input_pos_vec[i] = i;
@@ -1420,7 +1450,8 @@ namespace LLM {

struct ggml_context* work_ctx = ggml_init(params);
GGML_ASSERT(work_ctx != nullptr);
bool test_mistral = true;
bool test_mistral = false;
bool test_qwen3 = true;
bool test_vit = false;
bool test_decoder_with_vit = false;

@@ -1455,9 +1486,9 @@ namespace LLM {
std::pair<int, int> prompt_attn_range;
std::string text = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
text += img_prompt;
prompt_attn_range.first = text.size();
prompt_attn_range.first = static_cast<int>(text.size());
text += "change 'flux.cpp' to 'edit.cpp'";
prompt_attn_range.second = text.size();
prompt_attn_range.second = static_cast<int>(text.size());
text += "<|im_end|>\n<|im_start|>assistant\n";

auto tokens_and_weights = tokenize(text, prompt_attn_range, 0, false);
@@ -1496,9 +1527,9 @@ namespace LLM {
} else if (test_mistral) {
std::pair<int, int> prompt_attn_range;
std::string text = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
prompt_attn_range.first = text.size();
prompt_attn_range.first = static_cast<int>(text.size());
text += "a lovely cat";
prompt_attn_range.second = text.size();
prompt_attn_range.second = static_cast<int>(text.size());
text += "[/INST]";
auto tokens_and_weights = tokenize(text, prompt_attn_range, 0, false);
std::vector<int>& tokens = std::get<0>(tokens_and_weights);
@@ -1514,14 +1545,37 @@ namespace LLM {
model.compute(8, input_ids, {}, {10, 20, 30}, &out, work_ctx);
int t1 = ggml_time_ms();

print_ggml_tensor(out);
LOG_DEBUG("llm test done in %dms", t1 - t0);
} else if (test_qwen3) {
std::pair<int, int> prompt_attn_range;
std::string text = "<|im_start|>user\n";
prompt_attn_range.first = static_cast<int>(text.size());
text += "a lovely cat";
prompt_attn_range.second = static_cast<int>(text.size());
text += "<|im_end|>\n<|im_start|>assistant\n";
auto tokens_and_weights = tokenize(text, prompt_attn_range, 0, false);
std::vector<int>& tokens = std::get<0>(tokens_and_weights);
std::vector<float>& weights = std::get<1>(tokens_and_weights);
for (auto token : tokens) {
printf("%d ", token);
}
printf("\n");
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
struct ggml_tensor* out = nullptr;

int t0 = ggml_time_ms();
model.compute(8, input_ids, {}, {35}, &out, work_ctx);
int t1 = ggml_time_ms();

print_ggml_tensor(out);
LOG_DEBUG("llm test done in %dms", t1 - t0);
} else {
std::pair<int, int> prompt_attn_range;
std::string text = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
prompt_attn_range.first = text.size();
prompt_attn_range.first = static_cast<int>(text.size());
text += "a lovely cat";
prompt_attn_range.second = text.size();
prompt_attn_range.second = static_cast<int>(text.size());
text += "<|im_end|>\n<|im_start|>assistant\n";
auto tokens_and_weights = tokenize(text, prompt_attn_range, 0, false);
std::vector<int>& tokens = std::get<0>(tokens_and_weights);
@@ -1563,7 +1617,7 @@ namespace LLM {
}
}

LLMArch arch = LLMArch::MISTRAL_SMALL_3_2;
LLMArch arch = LLMArch::QWEN3;

std::shared_ptr<LLMEmbedder> llm = std::make_shared<LLMEmbedder>(arch,
backend,
@@ -1587,6 +1641,6 @@ namespace LLM {
llm->test();
}
};
}; // Qwen
}; // LLM

#endif // __QWENVL_HPP__
#endif // __LLM_HPP__
8 changes: 6 additions & 2 deletions mmdit.hpp
@@ -101,10 +101,14 @@ struct TimestepEmbedder : public GGMLBlock {

public:
TimestepEmbedder(int64_t hidden_size,
int64_t frequency_embedding_size = 256)
int64_t frequency_embedding_size = 256,
int64_t out_channels = 0)
: frequency_embedding_size(frequency_embedding_size) {
if (out_channels <= 0) {
out_channels = hidden_size;
}
blocks["mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(frequency_embedding_size, hidden_size, true, true));
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size, true, true));
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, out_channels, true, true));
}

struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* t) {