@leejet
Owner

leejet commented Nov 29, 2025

.\bin\Release\sd.exe --diffusion-model  ..\..\ComfyUI\models\diffusion_models\z_image_turbo-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\ae.sft  --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p 'a lovely cat' --cfg-scale 1.0 -v --offload-to-cpu --diffusion-fa
[output image]

@leejet leejet changed the base branch from master to flux2 November 29, 2025 19:15
@wbruna
Contributor

wbruna commented Nov 29, 2025

There seems to be an issue with text rendering (or possibly with long prompts). Comparing with rmatif's implementation from #1018:

sd --diffusion-model z_image_turbo-Q8_0.gguf --llm qwen_3_4b-Q8_0.gguf -W 768 -H 1024 --vae ae.safetensors --vae-conv-direct --sampling-method euler --scheduler smoothstep --steps 8 --cfg-scale 1 -p 'A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: "THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR." -- moody, atmospheric, profound, dark academic' -s 42

[Comparison images: 7120eb7 | #1018]

(Vulkan+radv)

@leejet
Owner Author

leejet commented Nov 29, 2025

Fixed.

 .\bin\Release\sd.exe --diffusion-model  ..\..\ComfyUI\models\diffusion_models\z_image_turbo-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\ae.sft  --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' -- moody, atmospheric, profound, dark academic" --cfg-scale 1.0 -v --offload-to-cpu -H 1024 -W 512 --rng cpu --steps 9 --rng cpu --seed 1061061743296960 --scheduler simple
[output image]

@stduhpf
Contributor

stduhpf commented Nov 29, 2025

To support llama.cpp Qwen3 quants:

diff --git a/name_conversion.cpp b/name_conversion.cpp
index c4670df..b26e1a2 100644
--- a/name_conversion.cpp
+++ b/name_conversion.cpp
@@ -133,6 +133,8 @@ std::string convert_cond_stage_model_name(std::string name, std::string prefix)
         {"attn_q.", "self_attn.q_proj."},
         {"attn_k.", "self_attn.k_proj."},
         {"attn_v.", "self_attn.v_proj."},
+        {"attn_q_norm.", "self_attn.q_norm."},
+        {"attn_k_norm.", "self_attn.k_norm."},
         {"attn_output.", "self_attn.o_proj."},
         {"attn_norm.", "input_layernorm."},
         {"ffn_down.", "mlp.down_proj."},
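
To illustrate what the two added entries do, here is a minimal sketch of a substring-based name mapping. It is not the actual body of `convert_cond_stage_model_name`; the table `kQwen3NameMap` and the function `convert_tensor_name_sketch` are hypothetical names, and only the entries visible in the diff above are reproduced. The example tensor names in the usage note are assumed llama.cpp-style names.

```cpp
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the mapping table in name_conversion.cpp;
// only the entries visible in the diff are listed here.
static const std::vector<std::pair<std::string, std::string>> kQwen3NameMap = {
    {"attn_q.", "self_attn.q_proj."},
    {"attn_k.", "self_attn.k_proj."},
    {"attn_v.", "self_attn.v_proj."},
    {"attn_q_norm.", "self_attn.q_norm."},
    {"attn_k_norm.", "self_attn.k_norm."},
    {"attn_output.", "self_attn.o_proj."},
    {"attn_norm.", "input_layernorm."},
    {"ffn_down.", "mlp.down_proj."},
};

// Replace the first matching llama.cpp-style fragment with its counterpart.
// Names without a matching fragment are returned unchanged.
std::string convert_tensor_name_sketch(const std::string& name) {
    for (const auto& [from, to] : kQwen3NameMap) {
        const std::size_t pos = name.find(from);
        if (pos != std::string::npos) {
            std::string out = name;
            out.replace(pos, from.size(), to);
            return out;
        }
    }
    return name;
}
```

Because every key ends in a dot, `attn_q.` does not match inside `attn_q_norm.`, so a name such as `blk.0.attn_q_norm.weight` would only be picked up by the newly added entry.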

@leejet
Owner Author

leejet commented Nov 29, 2025

I’ll update it later, including compatibility with different kinds of LoRA models.

@wbruna
Contributor

wbruna commented Nov 29, 2025

I'm getting fully black images on ROCm (Linux), even on the first image preview. My card:

[DEBUG] stable-diffusion.cpp:159  - Using CUDA backend
[INFO ] ggml_extend.hpp:69   - ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
[INFO ] ggml_extend.hpp:69   - ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
[INFO ] ggml_extend.hpp:69   - ggml_cuda_init: found 1 ROCm devices:
[INFO ] ggml_extend.hpp:69   -   Device 0: AMD Radeon RX 7600 XT, gfx1102 (0x1102), VMM: no, Wave Size: 32

@Green-Sky
Contributor

Same with the other pr, in most cases I get black images after the second step. (cuda rtx2070)

@dsignarius

z-image branch

[DEBUG] stable-diffusion.cpp:167 - Using Vulkan backend
[DEBUG] ggml_extend.hpp:66 - ggml_vulkan: Found 1 Vulkan devices:
[DEBUG] ggml_extend.hpp:66 - ggml_vulkan: 0 = AMD RADV RENOIR (ACO) (radv) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 64 | shared memory: 32768 | int dot: 0 | matrix cores: none

sd --diffusion-model z_image_turbo_bf16.safetensors --vae ae.safetensors --llm qwen_3_4b.safetensors --cfg-scale 1.0 --sampling-method euler --steps 8 --diffusion-fa -W 512 -H 512 -p 'a red apple on a wooden table. Text across bottom reads : sdcpp z-image' --seed 10 --type q8_0

[output image]

@leejet
Owner Author

leejet commented Nov 30, 2025

> Same with the other pr, in most cases I get black images after the second step. (cuda rtx2070)

@Green-Sky Could you share the cmd?

@netrunnereve

This seems to work fine on Vulkan using a q4_K model and a q8_0 Qwen with this prompt and others 😃.

[output image]

@leejet leejet changed the base branch from flux2 to master November 30, 2025 04:28
@whoreson

Works nicely on an MI50 with ROCm. Thanks!

@Green-Sky
Contributor

> > Same with the other pr, in most cases I get black images after the second step. (cuda rtx2070)
>
> @Green-Sky Could you share the cmd?

$ result/bin/sd --diffusion-model models/z-image/z_image_turbo-Q3_K_M.gguf --vae models/flux-extra/ae.safetensors --llm models/z-image/qwen_3_4b.safetensors --cfg-scale 1 --steps 8 -W 1024 -H 1024 --diffusion-fa --offload-to-cpu --clip-on-cpu --vae-tiling --vae-tile-size 64 -v --preview proj -p "a lovely cat"

Also black at step3 with lower resolution, without flash attention.

That specific quant mix is from https://huggingface.co/jayn7/Z-Image-Turbo-GGUF/blob/main/z_image_turbo-Q3_K_M.gguf.


q5_k/q4_k/q5_1 1024x1024, 768x768 starts off full black.
q5_k/q4_k/q5_1 512x512, 768x512 blackens at step3.

👉 But it works with q8_0 and q6_k.

I only tested up to 1024x1024


[INFO ] ggml_extend.hpp:69   - ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
[INFO ] ggml_extend.hpp:69   - ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
[INFO ] ggml_extend.hpp:69   - ggml_cuda_init: found 1 CUDA devices:
[INFO ] ggml_extend.hpp:69   -   Device 0: NVIDIA GeForce RTX 2070, compute capability 7.5, VMM: yes

Driver Version: 575.51.02

@Green-Sky
Contributor

With smoothstep and a moderate loss in quality, you can go down to 4 steps.

[output image]

(q8_0 diffusion model)

@wbruna
Contributor

wbruna commented Nov 30, 2025

I uploaded a few quantized models to https://huggingface.co/wbruna/Z-Image-Turbo-sdcpp-GGUF/tree/main , to make testing the lower quants a bit easier. Even legacy q4 can work very well: this is with 5 steps, q8_0 llm, euler+sgm_uniform, hitting ~4G peak VRAM with Vulkan + --offload-to-cpu:

[Comparison images: q4_0 | q4_1]

@leejet
Owner Author

leejet commented Nov 30, 2025

LoRA support

.\bin\Release\sd.exe --diffusion-model  ..\..\ComfyUI\models\diffusion_models\z_image_turbo-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\ae.sft  --llm ..\..\ComfyUI\models\text_encoders\Qwen3-4B-Instruct-2507-Q4_K_M.gguf -p "l3n0v0, raw unedited amateurish candid photo. amateur photo with bad quality, selfie, casual style. A young sweden woman with long hair and a white tank top holds up her old iphone to capture the image; there is a photo of woman on phone screen, she is 20 years old<lora:z-image/lenovo_z:1>" --cfg-scale 1.0 -v --offload-to-cpu --diffusion-fa -H 1024 -W 768 --lora-model-dir ..\..\ComfyUI\models\loras
[Comparison images: without LoRA | with LoRA]

@leejet
Owner Author

leejet commented Nov 30, 2025

@Green-Sky This should be fixed. Could you try again?

.\bin\Release\sd.exe --diffusion-model  ..\..\ComfyUI\models\diffusion_models\z_image_turbo-Q3_K_M.gguf --vae ..\..\ComfyUI\models\vae\ae.sft  
--llm ..\..\ComfyUI\models\text_encoders\Qwen3-4B-Instruct-2507-Q4_K_M.gguf -p "a lovly cat" --cfg-scale 1.0 -v
[output image]

@leejet
Owner Author

leejet commented Nov 30, 2025

@wbruna Fixed.

@leejet
Owner Author

leejet commented Dec 1, 2025

Comparison of Different Quantization Types

GGUF: https://huggingface.co/leejet/Z-Image-Turbo-GGUF/tree/main

.\bin\Release\sd.exe --diffusion-model  z_image_turbo-Q3_K.gguf --vae ae.safetensors  --llm Qwen3-4B-Instruct-2507-Q4_K_M.gguf -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' -- moody, atmospheric, profound, dark academic" --cfg-scale 1.0 -v --offload-to-cpu --diffusion-fa -H 1024 -W 512
[Comparison images: bf16 | q8_0 | q6_K | q5_0 | q4_K | q4_0 | q3_K | q2_K]

@leejet
Owner Author

leejet commented Dec 1, 2025

I think this PR can be merged now. Thank you everyone!

@leejet leejet merged commit 34a6fd4 into master Dec 1, 2025
9 checks passed
@kurnevsky

It seems it has the same problem for ROCm as described in #968

@leejet
Owner Author

leejet commented Dec 2, 2025

Since I don’t have a ROCm device to reproduce the issue, it’s difficult for me to locate and fix the problem.

@wbruna
Contributor

wbruna commented Dec 2, 2025

@leejet , if you could guide me through it, I could try to debug it on my card.

I tried to run the ZImage::ZImageRunner::load_from_file_and_test function after setting the input tensors to 0.01 on ZImage::ZImageRunner::test, and got only NaNs in the output:

[DEBUG] ggml_extend.hpp:1688 - z_image compute buffer size: 47.83 MB(VRAM)
 (f32): shape(16, 16, 16, 1)
  [0, 0, 0, 0] = nan
  [0, 0, 0, 1] = nan
  [0, 0, 0, 2] = nan
  [0, 0, 0, 13] = nan
  [0, 0, 0, 14] = nan
  [0, 0, 0, 15] = nan
  [0, 0, 1, 0] = nan
  [0, 0, 1, 1] = nan
(...)

@whoreson

whoreson commented Dec 2, 2025

@leejet I can set up SSH access to my ROCm PC for you, just drop me an e-mail or something.

@linuzel

linuzel commented Dec 2, 2025

I can confirm that the issue from #968 is also happening with this model.

I also don't know how to help with troubleshooting.

The NaN issue is interesting because it is why I switched from Python/PyTorch to this project to run Qwen: I kept running into the same issue, where all values were NaN.

@mmlingyu

mmlingyu commented Dec 2, 2025

./build/bin/sd --diffusion-model models/z_image_turbo-Q4_K_M.gguf --vae models/vae/diffusion_pytorch_model.safetensors --llm models/Qwen3-4B-Q8_0.gguf -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' -- moody, atmospheric, profound, dark academic" --cfg-scale 1.0 -v -H 1024 -W 512 --steps 9 --seed 1061061743296960 --scheduler simple
[DEBUG] ggml_extend.hpp:1688 - z_image compute buffer size: 704.48 MB(RAM)
|==================================================| 9/9 - 212.12s/it
[INFO ] stable-diffusion.cpp:3069 - sampling completed, taking 1909.39s
[INFO ] stable-diffusion.cpp:3077 - generating 1 latent images completed, taking 1910.28s
[INFO ] stable-diffusion.cpp:3080 - decoding 1 latents
[DEBUG] ggml_extend.hpp:1688 - vae compute buffer size: 3328.00 MB(RAM)
[DEBUG] stable-diffusion.cpp:2286 - computing vae decode graph completed, taking 79.63s
[INFO ] stable-diffusion.cpp:3090 - latent 1 decoded, taking 79.63s
[INFO ] stable-diffusion.cpp:3094 - decode_first_stage completed, taking 79.63s
[INFO ] stable-diffusion.cpp:3390 - generate_image completed in 1996.30s

Why does it take 40 minutes on a 4090???

@stduhpf
Contributor

stduhpf commented Dec 2, 2025

> Why does it take 40 minutes on a 4090???

> z_image compute buffer size: 704.48 MB(RAM)

The model is in RAM, not VRAM, which indicates that the program is running on your CPU, not on the 4090. You should use the CUDA or Vulkan builds if you want to take advantage of the GPU.
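
As a rough illustration of the check described above, the following sketch classifies a verbose log line by the `MB(RAM)` or `MB(VRAM)` suffix of its `compute buffer size` message. The function name `compute_buffer_location` and the `BufferLocation` enum are hypothetical and not part of stable-diffusion.cpp; the log format is assumed to match the lines quoted in this thread.

```cpp
#include <optional>
#include <string>

enum class BufferLocation { Ram, Vram };

// Inspect a single log line for the "compute buffer size: <n> MB(RAM|VRAM)"
// message. Returns std::nullopt when the line is not a compute-buffer line.
std::optional<BufferLocation> compute_buffer_location(const std::string& log_line) {
    if (log_line.find("compute buffer size:") == std::string::npos) {
        return std::nullopt;
    }
    if (log_line.find("MB(VRAM)") != std::string::npos) {
        return BufferLocation::Vram;  // compute buffer lives on the GPU
    }
    if (log_line.find("MB(RAM)") != std::string::npos) {
        return BufferLocation::Ram;   // compute buffer lives in host memory
    }
    return std::nullopt;
}
```

This is only a heuristic: a `Ram` result for the `z_image` compute buffer suggests the diffusion model is running on the CPU backend, which is consistent with the very long per-step times in the log above.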

@leejet
Owner Author

leejet commented Dec 2, 2025

@wbruna You can try changing the default value of the scale parameter in the Linear layer from 1.f to 1.f/256.f first, then recompile and check whether the issue is related to numerical precision. If it is, you can use a binary-search style approach to gradually identify which specific Linear layer requires the scale adjustment.

If the issue is not caused by precision in the Linear layers, then you may need to print out the outputs layer by layer to determine which operation is producing the nan values.
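
The binary-search idea above can be sketched as follows. This is only an illustration of the search procedure, assuming exactly one Linear layer needs the scale adjustment. The callback type `StillNan` and the function `find_layer_needing_scale` are hypothetical; in practice each call stands for a manual edit, recompile, and run.

```cpp
#include <cstddef>
#include <functional>

// Hypothetical callback: run the model with the reduced scale applied only to
// the Linear layers with index in [first, last) and report whether the output
// still contains NaN. Nothing like this exists in stable-diffusion.cpp.
using StillNan = std::function<bool(std::size_t first, std::size_t last)>;

// Narrow down the single layer in [0, n_layers) that needs the scale change.
std::size_t find_layer_needing_scale(std::size_t n_layers, const StillNan& still_nan) {
    std::size_t lo = 0;
    std::size_t hi = n_layers;
    while (hi - lo > 1) {
        const std::size_t mid = lo + (hi - lo) / 2;
        if (still_nan(lo, mid)) {
            lo = mid;  // scaling [lo, mid) did not help: look in [mid, hi)
        } else {
            hi = mid;  // scaling [lo, mid) fixed the output: look in [lo, mid)
        }
    }
    return lo;
}
```

With this approach the number of rebuilds grows with the logarithm of the layer count rather than linearly. If several layers overflow independently, the single-culprit assumption breaks and the search would have to be repeated after each fix.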

@leejet
Owner Author

leejet commented Dec 2, 2025

> @leejet I can set up SSH access to my ROCm PC for you, just drop me an e-mail or something.

@whoreson Thanks! I’ll reach out if I need your help.

@shakfu

shakfu commented Dec 2, 2025

This worked for me:

sd \
	--diffusion-model models/z_image_turbo-Q6_K.gguf \
	--vae ae.safetensors \
	--llm models/Qwen3-4B-Q8_0.gguf \
	--cfg-scale 1.0 -v \
	--offload-to-cpu \
	--diffusion-fa \
	-H 1024 -W 512 \
	-p "a lovely plump cat"

But it took ages... I'm on an M1 MacBook Air (16GB memory).

Is there any way to make this quicker?

[output image]

@wbruna
Contributor

wbruna commented Dec 2, 2025

@leejet , found it!

diff --git a/z_image.hpp b/z_image.hpp
index b692a14..5d07f6f 100644
--- a/z_image.hpp
+++ b/z_image.hpp
@@ -30,7 +30,7 @@ namespace ZImage {
         JointAttention(int64_t hidden_size, int64_t head_dim, int64_t num_heads, int64_t num_kv_heads, bool qk_norm)
             : head_dim(head_dim), num_heads(num_heads), num_kv_heads(num_kv_heads), qk_norm(qk_norm) {
             blocks["qkv"] = std::make_shared<Linear>(hidden_size, (num_heads + num_kv_heads * 2) * head_dim, false);
-            blocks["out"] = std::make_shared<Linear>(num_heads * head_dim, hidden_size, false);
+            blocks["out"] = std::make_shared<Linear>(num_heads * head_dim, hidden_size, false, false, false, 1.f / 8.f);
             if (qk_norm) {
                 blocks["q_norm"] = std::make_shared<RMSNorm>(head_dim);
                 blocks["k_norm"] = std::make_shared<RMSNorm>(head_dim);

That 1/8 was enough to generate images on my card. Setting force_prec_f32 worked too, while being around 30% slower.
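
A rough sketch of why a smaller scale can avoid the NaNs, assuming the failure is an overflow of half-precision intermediates and assuming the `scale` argument pre-scales the input and undoes the scaling on the result. Neither assumption is confirmed in this thread beyond the observation that both the 1/8 scale and `force_prec_f32` helped. The helpers `clamp_to_half_range` and `scaled_dot` are hypothetical and only model the range limit of fp16, not its rounding.

```cpp
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Largest finite value representable in IEEE 754 half precision.
constexpr float kHalfMax = 65504.0f;

// Crude model of an fp16 accumulator: anything outside the half range becomes
// infinity. Rounding of in-range values is deliberately ignored.
float clamp_to_half_range(float v) {
    if (std::fabs(v) > kHalfMax) {
        return std::copysign(std::numeric_limits<float>::infinity(), v);
    }
    return v;
}

// Dot product where the input is pre-multiplied by `scale`, the running sum is
// kept inside the modelled half range, and the scale is undone in f32.
float scaled_dot(const std::vector<float>& x, const std::vector<float>& w, float scale) {
    float acc = 0.0f;
    for (std::size_t i = 0; i < x.size() && i < w.size(); ++i) {
        acc = clamp_to_half_range(acc + (x[i] * scale) * w[i]);
    }
    return acc / scale;
}
```

For example, 1024 inputs of 100.0 with unit weights overflow the modelled accumulator at scale 1, while a scale of 1/8 keeps the running sum at 12800 and recovers the full result after rescaling. Once an infinity appears, later operations such as subtracting two infinities produce NaN, which would match the all-NaN test output reported earlier.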

I could make a PR, but I'm not sure how to enable it only when needed?

@stduhpf
Contributor

stduhpf commented Dec 2, 2025

@shakfu Are you sure you have Metal enabled? If you run it with the -v (verbose) flag, you can see at the beginning of the log whether it's using the "Metal" backend or the "CPU" backend. CPU is generally much slower.

@wbruna
Contributor

wbruna commented Dec 3, 2025

People getting black images on ROCm, especially right at the beginning: please give #1034 a try.

@shakfu

shakfu commented Dec 3, 2025

@stduhpf Thanks for your help. I'll run it again and check whether Metal is enabled.

@RealHacker

RealHacker commented Dec 5, 2025

Why am I getting this error:

[INFO ] ggml_extend.hpp:1791 - qwen3 offload params (7672.62 MB, 398 tensors) to runtime backend (CUDA0), taking 1.46s
[DEBUG] ggml_extend.hpp:1691 - qwen3 compute buffer size: 1.20 MB(VRAM)
[ERROR] ggml_extend.hpp:75   - ggml_cuda_compute_forward: GET_ROWS failed
[ERROR] ggml_extend.hpp:75   - CUDA error: no kernel image is available for execution on the device
[ERROR] ggml_extend.hpp:75   -   current device: 0, in function ggml_cuda_compute_forward at D:\a\stable-diffusion.cpp\stable-diffusion.cpp\ggml\src\ggml-cuda\ggml-cuda.cu:2540
[ERROR] ggml_extend.hpp:75   -   err

My command line:
.\sd.exe --diffusion-model .\z_image_turbo-Q8_0.gguf --vae .\ae.safetensors --llm .\qwen_3_4b.safetensors -p 'a lovely cat' --cfg-scale 1.0 -v --offload-to-cpu --scheduler simple --steps 8 --diffusion-fa --verbose -o test.png

My device is an Nvidia RTX 3060 with 12GB of VRAM, running Windows 11.

@popters

popters commented Dec 7, 2025

.\build\bin\sd.exe --diffusion-model models/z-image/z_image_turbo-Q8_0.gguf --vae models/flux-extra/diffusion_pytorch_model.safetensors --llm models/z-image/qwen_3_4b.safetensors --preview proj -p "girl"
Why is my generated image black? It is pure black, with no image at all. Thanks.

@leejet
Owner Author

leejet commented Dec 8, 2025

@RealHacker I believe the latest release has already addressed this issue. For details, please refer to this PR: #1062. You can give it another try.

@leejet
Owner Author

leejet commented Dec 8, 2025

@popters Please create an issue describing your environment, such as the commit you are using, the backend, etc., or search for existing issues and provide the relevant information.

@popters

popters commented Dec 8, 2025

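Hello, I built sd this way:
mkdir build && cd build
cmake .. -DSD_CUDA=ON
cmake --build . --config Release
and the generated images are black.
But when I build it this way:
mkdir build && cd build
cmake ..
cmake --build . --config Release
images are generated normally. So the CPU build works and the GPU build does not.
Here is the log from running the CPU build: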
(Q:\stable-diffusion.cpp>.\build\bin\Release\sd.exe --diffusion-model models/z-image/z_image_turbo-Q3_K.gguf --vae models/extra/ae.safetensors --llm models/z-image/Qwen3-4B-Instruct-2507-Q4_K_M.gguf --cfg-scale 1 --steps 8 -W 512 -H 512 --vae-tile-size 64 -v --preview proj -p "a cat"
System Info:
SSE3 = 1 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | VSX = 0 |
SDCliParams {
mode: img_gen,
output_path: "output.png",
verbose: true,
color: false,
canny_preprocess: false,
preview_method: proj,
preview_interval: 1,
preview_path: "preview.png",
preview_fps: 4,
taesd_preview: false,
preview_noisy: false
}
SDContextParams {
n_threads: 6,
model_path: "",
clip_l_path: "",
clip_g_path: "",
clip_vision_path: "",
t5xxl_path: "",
llm_path: "models/z-image/Qwen3-4B-Instruct-2507-Q4_K_M.gguf",
llm_vision_path: "",
diffusion_model_path: "models/z-image/z_image_turbo-Q3_K.gguf",
high_noise_diffusion_model_path: "",
vae_path: "models/extra/ae.safetensors",
taesd_path: "",
esrgan_path: "",
control_net_path: "",
embedding_dir: "",
embeddings: {
}
wtype: NONE,
tensor_type_rules: "",
lora_model_dir: "",
photo_maker_path: "",
rng_type: cuda,
sampler_rng_type: NONE,
flow_shift: INF
offload_params_to_cpu: false,
control_net_cpu: false,
clip_on_cpu: false,
vae_on_cpu: false,
diffusion_flash_attn: false,
diffusion_conv_direct: false,
vae_conv_direct: false,
chroma_use_dit_mask: true,
chroma_use_t5_mask: false,
chroma_t5_mask_pad: 1,
prediction: NONE,
lora_apply_mode: auto,
vae_tiling_params: { 0, 64, 64, 0.5, 0, 0 },
force_sdxl_vae_conv_scale: false
}
SDGenerationParams {
prompt: "a cat",
negative_prompt: "",
clip_skip: -1,
width: 512,
height: 512,
batch_count: 1,
init_image_path: "",
end_image_path: "",
mask_image_path: "",
control_image_path: "",
ref_image_paths: [],
control_video_path: "",
auto_resize_ref_image: true,
increase_ref_index: false,
pm_id_images_dir: "",
pm_id_embed_path: "",
pm_style_strength: 20,
skip_layers: [7, 8, 9],
sample_params: (txt_cfg: 1.00, img_cfg: 1.00, distilled_guidance: 3.50, slg.layer_count: 3, slg.layer_start: 0.01, slg.layer_end: 0.20, slg.scale: 0.00, scheduler: NONE, sample_method: NONE, sample_steps: 8, eta: 0.00, shifted_timestep: 0),
high_noise_skip_layers: [7, 8, 9],
high_noise_sample_params: (txt_cfg: 7.00, img_cfg: 7.00, distilled_guidance: 3.50, slg.layer_count: 3, slg.layer_start: 0.01, slg.layer_end: 0.20, slg.scale: 0.00, scheduler: NONE, sample_method: NONE, sample_steps: 20, eta: 0.00, shifted_timestep: 0),
easycache_option: "",
easycache: disabled (threshold=0, start=1.4013e-45, end=0),
moe_boundary: 0.875,
video_frames: 1,
fps: 16,
vace_strength: 1,
strength: 0.75,
control_strength: 0.9,
seed: 42,
upscale_repeats: 1,
}
[DEBUG] stable-diffusion.cpp:190 - Using CPU backend
[INFO ] stable-diffusion.cpp:235 - loading diffusion model from 'models/z-image/z_image_turbo-Q3_K.gguf'
[INFO ] model.cpp:370 - load models/z-image/z_image_turbo-Q3_K.gguf using gguf format
[DEBUG] model.cpp:412 - init from 'models/z-image/z_image_turbo-Q3_K.gguf'
[INFO ] stable-diffusion.cpp:282 - loading llm from 'models/z-image/Qwen3-4B-Instruct-2507-Q4_K_M.gguf'
[INFO ] model.cpp:370 - load models/z-image/Qwen3-4B-Instruct-2507-Q4_K_M.gguf using gguf format
[DEBUG] model.cpp:412 - init from 'models/z-image/Qwen3-4B-Instruct-2507-Q4_K_M.gguf'
[INFO ] stable-diffusion.cpp:296 - loading vae from 'models/extra/ae.safetensors'
[INFO ] model.cpp:373 - load models/extra/ae.safetensors using safetensors format
[DEBUG] model.cpp:503 - init from 'models/extra/ae.safetensors', prefix = 'vae.'
[INFO ] stable-diffusion.cpp:312 - Version: Z-Image
[INFO ] stable-diffusion.cpp:340 - Weight type stat: f32: 640 | q8_0: 22 | q3_K: 180 | q4_K: 216 | q6_K: 37
[INFO ] stable-diffusion.cpp:341 - Conditioner weight type stat: f32: 145 | q4_K: 216 | q6_K: 37
[INFO ] stable-diffusion.cpp:342 - Diffusion model weight type stat: f32: 251 | q8_0: 22 | q3_K: 180
[INFO ] stable-diffusion.cpp:343 - VAE weight type stat: f32: 244
[DEBUG] stable-diffusion.cpp:345 - ggml tensor size = 400 bytes
[DEBUG] llm.hpp:285 - merges size 151387
[DEBUG] llm.hpp:317 - vocab size: 151669
[DEBUG] llm.hpp:1134 - llm: num_layers = 36, vocab_size = 151936, hidden_size = 2560, intermediate_size = 9728
[DEBUG] ggml_extend.hpp:1883 - qwen3 params backend buffer size = 3555.38 MB(RAM) (398 tensors)
[DEBUG] ggml_extend.hpp:1883 - z_image params backend buffer size = 2997.90 MB(RAM) (453 tensors)
[DEBUG] ggml_extend.hpp:1883 - vae params backend buffer size = 94.57 MB(RAM) (138 tensors)
[DEBUG] stable-diffusion.cpp:688 - loading weights
[DEBUG] model.cpp:1351 - using 6 threads for model loading
[DEBUG] model.cpp:1373 - loading tensors from models/z-image/z_image_turbo-Q3_K.gguf
|====================> | 453/1095 - 728.30it/s
[DEBUG] model.cpp:1373 - loading tensors from models/z-image/Qwen3-4B-Instruct-2507-Q4_K_M.gguf
|======================================> | 851/1095 - 520.81it/s
[DEBUG] model.cpp:1373 - loading tensors from models/extra/ae.safetensors
|==================================================| 1095/1095 - 596.41it/s
[INFO ] model.cpp:1580 - loading tensors completed, taking 1.84s (process: 0.00s, read: 1.11s, memcpy: 0.00s, convert: 0.13s, copy_to_backend: 0.00s)
[DEBUG] stable-diffusion.cpp:720 - finished loaded file
[INFO ] stable-diffusion.cpp:792 - total params memory size = 6647.85MB (VRAM 0.00MB, RAM 6647.85MB): text_encoders 3555.38MB(RAM), diffusion_model 2997.90MB(RAM), vae 94.57MB(RAM), controlnet 0.00MB(VRAM), pmid 0.00MB(RAM)
[INFO ] stable-diffusion.cpp:860 - running in FLOW mode
[DEBUG] stable-diffusion.cpp:3154 - generate_image 512x512
[INFO ] stable-diffusion.cpp:3185 - sampling using Euler method
[INFO ] denoiser.hpp:364 - get_sigmas with discrete scheduler
[INFO ] stable-diffusion.cpp:3298 - TXT2IMG
[DEBUG] conditioner.hpp:1671 - parse '<|im_start|>user
a cat<|im_end|>
<|im_start|>assistant
' to [['<|im_start|>user
', 1], ['a cat', 1], ['<|im_end|>
<|im_start|>assistant
', 1], ]
[DEBUG] llm.hpp:259 - split prompt "<|im_start|>user
" to tokens ["<|im_start|>", "user", "Ċ", ]
[DEBUG] llm.hpp:259 - split prompt "a cat" to tokens ["a", "Ġcat", ]
[DEBUG] llm.hpp:259 - split prompt "<|im_end|>
<|im_start|>assistant
" to tokens ["<|im_end|>", "Ċ", "<|im_start|>", "assistant", "Ċ", ]
[DEBUG] ggml_extend.hpp:1697 - qwen3 compute buffer size: 1.09 MB(RAM)
[DEBUG] conditioner.hpp:1884 - computing condition graph completed, taking 265 ms
[INFO ] stable-diffusion.cpp:2929 - get_learned_condition completed, taking 266 ms
[INFO ] stable-diffusion.cpp:3040 - generating image: 1/1 - seed 42
[DEBUG] ggml_extend.hpp:1697 - z_image compute buffer size: 221.56 MB(RAM)
|==================================================| 8/8 - 68.54s/it
[INFO ] stable-diffusion.cpp:3082 - sampling completed, taking 548.42s
[INFO ] stable-diffusion.cpp:3093 - generating 1 latent images completed, taking 548.76s
[INFO ] stable-diffusion.cpp:3096 - decoding 1 latents
[DEBUG] ggml_extend.hpp:1697 - vae compute buffer size: 1664.00 MB(RAM)
[DEBUG] stable-diffusion.cpp:2301 - computing vae decode graph completed, taking 23.93s
[INFO ] stable-diffusion.cpp:3106 - latent 1 decoded, taking 23.93s
[INFO ] stable-diffusion.cpp:3110 - decode_first_stage completed, taking 23.94s
[INFO ] stable-diffusion.cpp:3406 - generate_image completed in 572.97s
save result PNG image to 'output.png' (success))
The image is generated correctly, which shows that the model and parameters are fine.

### Log from the GPU build:
(Q:\2stable-diffusion.cpp>.\build\bin\Release\sd.exe --diffusion-model models/z-image/z_image_turbo-Q3_K.gguf --vae models/extra/ae.safetensors --llm models/z-image/Qwen3-4B-Instruct-2507-Q4_K_M.gguf --cfg-scale 1 --steps 8 -W 512 -H 512 --vae-tile-size 64 -v --preview proj -p "a cat"
System Info:
SSE3 = 1 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | VSX = 0 |
SDCliParams {
mode: img_gen,
output_path: "output.png",
verbose: true,
color: false,
canny_preprocess: false,
preview_method: proj,
preview_interval: 1,
preview_path: "preview.png",
preview_fps: 4,
taesd_preview: false,
preview_noisy: false
}
SDContextParams {
n_threads: 6,
model_path: "",
clip_l_path: "",
clip_g_path: "",
clip_vision_path: "",
t5xxl_path: "",
llm_path: "models/z-image/Qwen3-4B-Instruct-2507-Q4_K_M.gguf",
llm_vision_path: "",
diffusion_model_path: "models/z-image/z_image_turbo-Q3_K.gguf",
high_noise_diffusion_model_path: "",
vae_path: "models/extra/ae.safetensors",
taesd_path: "",
esrgan_path: "",
control_net_path: "",
embedding_dir: "",
wtype: NONE,
tensor_type_rules: "",
lora_model_dir: "",
photo_maker_path: "",
rng_type: cuda,
sampler_rng_type: NONE,
flow_shift: INF
offload_params_to_cpu: false,
control_net_cpu: false,
clip_on_cpu: false,
vae_on_cpu: false,
diffusion_flash_attn: false,
diffusion_conv_direct: false,
vae_conv_direct: false,
chroma_use_dit_mask: true,
chroma_use_t5_mask: false,
chroma_t5_mask_pad: 1,
prediction: NONE,
lora_apply_mode: auto,
vae_tiling_params: { 0, 64, 64, 0.5, 0, 0 },
force_sdxl_vae_conv_scale: false
}
SDGenerationParams {
prompt: "a cat",
negative_prompt: "",
clip_skip: -1,
width: 512,
height: 512,
batch_count: 1,
init_image_path: "",
end_image_path: "",
mask_image_path: "",
control_image_path: "",
ref_image_paths: [],
control_video_path: "",
auto_resize_ref_image: true,
increase_ref_index: false,
pm_id_images_dir: "",
pm_id_embed_path: "",
pm_style_strength: 20,
skip_layers: [7, 8, 9],
sample_params: (txt_cfg: 1.00, img_cfg: 1.00, distilled_guidance: 3.50, slg.layer_count: 3, slg.layer_start: 0.01, slg.layer_end: 0.20, slg.scale: 0.00, scheduler: NONE, sample_method: NONE, sample_steps: 8, eta: 0.00, shifted_timestep: 0),
high_noise_skip_layers: [7, 8, 9],
high_noise_sample_params: (txt_cfg: 7.00, img_cfg: 7.00, distilled_guidance: 3.50, slg.layer_count: 3, slg.layer_start: 0.01, slg.layer_end: 0.20, slg.scale: 0.00, scheduler: NONE, sample_method: NONE, sample_steps: 20, eta: 0.00, shifted_timestep: 0),
easycache_option: "",
easycache: disabled (threshold=0, start=0, end=0),
moe_boundary: 0.875,
video_frames: 1,
fps: 16,
vace_strength: 1,
strength: 0.75,
control_strength: 0.9,
seed: 42,
upscale_repeats: 1,
}
[DEBUG] stable-diffusion.cpp:160 - Using CUDA backend
[INFO ] ggml_extend.hpp:69 - ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
[INFO ] ggml_extend.hpp:69 - ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
[INFO ] ggml_extend.hpp:69 - ggml_cuda_init: found 1 CUDA devices:
[INFO ] ggml_extend.hpp:69 - Device 0: Tesla V100-SXM2-16GB, compute capability 7.0, VMM: yes
[INFO ] stable-diffusion.cpp:235 - loading diffusion model from 'models/z-image/z_image_turbo-Q3_K.gguf'
[INFO ] model.cpp:370 - load models/z-image/z_image_turbo-Q3_K.gguf using gguf format
[DEBUG] model.cpp:412 - init from 'models/z-image/z_image_turbo-Q3_K.gguf'
[INFO ] stable-diffusion.cpp:282 - loading llm from 'models/z-image/Qwen3-4B-Instruct-2507-Q4_K_M.gguf'
[INFO ] model.cpp:370 - load models/z-image/Qwen3-4B-Instruct-2507-Q4_K_M.gguf using gguf format
[DEBUG] model.cpp:412 - init from 'models/z-image/Qwen3-4B-Instruct-2507-Q4_K_M.gguf'
[INFO ] stable-diffusion.cpp:296 - loading vae from 'models/extra/ae.safetensors'
[INFO ] model.cpp:373 - load models/extra/ae.safetensors using safetensors format
[DEBUG] model.cpp:503 - init from 'models/extra/ae.safetensors', prefix = 'vae.'
[INFO ] stable-diffusion.cpp:312 - Version: Z-Image
[INFO ] stable-diffusion.cpp:340 - Weight type stat: f32: 640 | q8_0: 22 | q3_K: 180 | q4_K: 216 | q6_K: 37
[INFO ] stable-diffusion.cpp:341 - Conditioner weight type stat: f32: 145 | q4_K: 216 | q6_K: 37
[INFO ] stable-diffusion.cpp:342 - Diffusion model weight type stat: f32: 251 | q8_0: 22 | q3_K: 180
[INFO ] stable-diffusion.cpp:343 - VAE weight type stat: f32: 244
[DEBUG] stable-diffusion.cpp:345 - ggml tensor size = 400 bytes
[DEBUG] llm.hpp:285 - merges size 151387
[DEBUG] llm.hpp:317 - vocab size: 151669
[DEBUG] llm.hpp:1134 - llm: num_layers = 36, vocab_size = 151936, hidden_size = 2560, intermediate_size = 9728
[DEBUG] ggml_extend.hpp:1883 - qwen3 params backend buffer size = 3555.38 MB(VRAM) (398 tensors)
[DEBUG] ggml_extend.hpp:1883 - z_image params backend buffer size = 2997.93 MB(VRAM) (453 tensors)
[DEBUG] ggml_extend.hpp:1883 - vae params backend buffer size = 94.57 MB(VRAM) (138 tensors)
[DEBUG] stable-diffusion.cpp:684 - loading weights
[DEBUG] model.cpp:1351 - using 6 threads for model loading
[DEBUG] model.cpp:1373 - loading tensors from models/z-image/z_image_turbo-Q3_K.gguf
|====================> | 453/1095 - 351.16it/s
[DEBUG] model.cpp:1373 - loading tensors from models/z-image/Qwen3-4B-Instruct-2507-Q4_K_M.gguf
|======================================> | 851/1095 - 288.57it/s
[DEBUG] model.cpp:1373 - loading tensors from models/extra/ae.safetensors
|==================================================| 1095/1095 - 347.62it/s
[INFO ] model.cpp:1580 - loading tensors completed, taking 3.15s (process: 0.00s, read: 1.06s, memcpy: 0.00s, convert: 0.10s, copy_to_backend: 0.73s)
[DEBUG] stable-diffusion.cpp:716 - finished loaded file
[INFO ] stable-diffusion.cpp:788 - total params memory size = 6647.88MB (VRAM 6647.88MB, RAM 0.00MB): text_encoders 3555.38MB(VRAM), diffusion_model 2997.93MB(VRAM), vae 94.57MB(VRAM), controlnet 0.00MB(VRAM), pmid 0.00MB(VRAM)
[INFO ] stable-diffusion.cpp:856 - running in FLOW mode
[DEBUG] stable-diffusion.cpp:3152 - generate_image 512x512
[INFO ] stable-diffusion.cpp:3183 - sampling using Euler method
[INFO ] denoiser.hpp:364 - get_sigmas with discrete scheduler
[INFO ] stable-diffusion.cpp:3296 - TXT2IMG
[DEBUG] conditioner.hpp:1701 - parse '<|im_start|>user
a cat<|im_end|>
<|im_start|>assistant
' to [['<|im_start|>user
', 1], ['a cat', 1], ['<|im_end|>
<|im_start|>assistant
', 1], ]
[DEBUG] llm.hpp:259 - split prompt "<|im_start|>user
" to tokens ["<|im_start|>", "user", "Ċ", ]
[DEBUG] llm.hpp:259 - split prompt "a cat" to tokens ["a", "Ġcat", ]
[DEBUG] llm.hpp:259 - split prompt "<|im_end|>
<|im_start|>assistant
" to tokens ["<|im_end|>", "Ċ", "<|im_start|>", "assistant", "Ċ", ]
[DEBUG] ggml_extend.hpp:1697 - qwen3 compute buffer size: 1.09 MB(VRAM)
[DEBUG] conditioner.hpp:1914 - computing condition graph completed, taking 82 ms
[INFO ] stable-diffusion.cpp:2927 - get_learned_condition completed, taking 83 ms
[INFO ] stable-diffusion.cpp:3038 - generating image: 1/1 - seed 42
[DEBUG] ggml_extend.hpp:1697 - z_image compute buffer size: 206.34 MB(VRAM)
|==================================================| 8/8 - 2.46it/s
[INFO ] stable-diffusion.cpp:3080 - sampling completed, taking 3.32s
[INFO ] stable-diffusion.cpp:3091 - generating 1 latent images completed, taking 3.36s
[INFO ] stable-diffusion.cpp:3094 - decoding 1 latents
[DEBUG] ggml_extend.hpp:1697 - vae compute buffer size: 1664.25 MB(VRAM)
[DEBUG] stable-diffusion.cpp:2297 - computing vae decode graph completed, taking 0.25s
[INFO ] stable-diffusion.cpp:3104 - latent 1 decoded, taking 0.25s
[INFO ] stable-diffusion.cpp:3108 - decode_first_stage completed, taking 0.25s
[INFO ] stable-diffusion.cpp:3404 - generate_image completed in 3.71s
save result PNG image to 'output.png' (success))

@popters

popters commented Dec 8, 2025

.\build\bin\sd.exe --diffusion-model models/z-image/z_image_turbo-Q8_0.gguf --vae models/flux-extra/diffusion_pytorch_model.safetensors --llm models/z-image/qwen_3_4b.safetensors --preview proj -p "girl"

Why are the images I generate black? They are pure black, with no image content at all. Thanks.

@popters Please create an issue describing your environment, such as the commit you are using, the backend, etc., or search for existing issues and provide the relevant information.

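Latest update: if I switch to the z_image_turbo_bf16.safetensors model, images are generated correctly, but every other quantized GGUF model I switch to produces black images.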

@wbruna
Contributor

wbruna commented Dec 8, 2025

@popters , please create a separate issue, so your problem can be properly tracked.
