Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions conditioner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define __CONDITIONER_HPP__

#include "clip.hpp"
#include "qwen3.hpp"
#include "qwenvl.hpp"
#include "t5.hpp"

Expand Down Expand Up @@ -1830,4 +1831,123 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
}
};

// Conditioner for Z-Image.
//
// Wraps the prompt in a chat template, tokenizes it with the Qwen2 tokenizer
// (honouring prompt-attention weights), runs the Qwen3 runner to obtain hidden
// states, applies the per-token weights and returns the resulting tensor as
// the cross-attention condition.
struct ZImageConditioner : public Conditioner {
    Qwen::Qwen2Tokenizer tokenizer;
    std::shared_ptr<Qwen3::Qwen3Runner> qwen3;
    // "{}" is the placeholder that apply_chat_template() replaces with the prompt.
    std::string chat_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n";
    int64_t skip_token_count  = 0;  // Use full sequence for Z-Image
    // Tensor-name prefix the runner was built with; reused when registering
    // parameter tensors so both always agree.
    std::string tensor_prefix;

    // backend / offload_params_to_cpu / tensor_storage_map are forwarded
    // unchanged to the Qwen3 runner; prefix is the tensor-name prefix of the
    // text encoder weights.
    ZImageConditioner(ggml_backend_t backend,
                      bool offload_params_to_cpu,
                      const String2TensorStorage& tensor_storage_map = {},
                      const std::string& prefix                      = "text_encoders.qwen3")
        : tensor_prefix(prefix) {
        qwen3 = std::make_shared<Qwen3::Qwen3Runner>(backend,
                                                     offload_params_to_cpu,
                                                     tensor_storage_map,
                                                     prefix);
    }

    // Registers the text encoder's parameter tensors under the same prefix
    // that was handed to the constructor.
    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) override {
        qwen3->get_param_tensors(tensors, tensor_prefix);
    }

    void alloc_params_buffer() override {
        qwen3->alloc_params_buffer();
    }

    void free_params_buffer() override {
        qwen3->free_params_buffer();
    }

    size_t get_params_buffer_size() override {
        return qwen3->get_params_buffer_size();
    }

    // Returns chat_template with its first "{}" replaced by text. If the
    // template has no placeholder it is returned unchanged.
    std::string apply_chat_template(const std::string& text) {
        std::string result = chat_template;
        const size_t pos   = result.find("{}");
        if (pos != std::string::npos) {
            result.replace(pos, 2, text);
        }
        return result;
    }

    // Splits text into weighted segments with parse_prompt_attention(),
    // tokenizes each segment and returns the concatenated token ids together
    // with one weight per token. max_length and padding are passed through to
    // the tokenizer's pad_tokens().
    std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
                                                              size_t max_length = 0,
                                                              bool padding      = false) {
        auto parsed_attention = parse_prompt_attention(text);
        std::vector<int> tokens;
        std::vector<float> weights;
        for (const auto& item : parsed_attention) {
            const std::string& curr_text = item.first;
            const float curr_weight      = item.second;
            std::vector<int> curr_tokens = tokenizer.tokenize(curr_text, nullptr);
            tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
            weights.insert(weights.end(), curr_tokens.size(), curr_weight);
        }
        tokenizer.pad_tokens(tokens, weights, max_length, padding);
        return {tokens, weights};
    }

    // Computes the condition for conditioner_params.text.
    //
    // The returned SDCondition carries the (weighted, optionally truncated)
    // hidden states as its first member; the other two members are nullptr.
    // The result tensor is allocated in work_ctx.
    SDCondition get_learned_condition(ggml_context* work_ctx,
                                      int n_threads,
                                      const ConditionerParams& conditioner_params) override {
        std::string prompt      = apply_chat_template(conditioner_params.text);
        auto tokens_and_weights = tokenize(prompt, 0, false);
        auto& tokens            = std::get<0>(tokens_and_weights);
        auto& weights           = std::get<1>(tokens_and_weights);

        int64_t t0                        = ggml_time_ms();
        struct ggml_tensor* hidden_states = nullptr;

        auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);

        qwen3->compute(n_threads,
                       input_ids,
                       &hidden_states,
                       work_ctx);
        {
            // Scale every token's hidden vector by its prompt-attention weight
            // (ne[1] is indexed with the same index as weights), then rescale
            // the whole tensor so its mean matches the unweighted one.
            auto tensor               = hidden_states;
            const float original_mean = ggml_ext_tensor_mean(tensor);
            for (int64_t i2 = 0; i2 < tensor->ne[2]; i2++) {
                for (int64_t i1 = 0; i1 < tensor->ne[1]; i1++) {
                    const float token_weight = weights[i1];
                    for (int64_t i0 = 0; i0 < tensor->ne[0]; i0++) {
                        float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
                        value *= token_weight;
                        ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
                    }
                }
            }
            const float new_mean = ggml_ext_tensor_mean(tensor);
            // Only a zero mean makes the ratio undefined; a negative mean is a
            // perfectly valid value and must still be restored.
            if (new_mean != 0.0f) {
                ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
            }
        }

        // Drop the first skip_count positions of the sequence; fall back to the
        // full sequence when that would leave nothing.
        int64_t skip_count     = skip_token_count;
        int64_t output_seq_len = hidden_states->ne[1] - skip_count;
        if (output_seq_len <= 0) {
            // PRId64 keeps the format portable: int64_t is not `long long` on
            // every platform.
            LOG_WARN("ZImageConditioner: output sequence length would be %" PRId64 " (hidden_states seq=%" PRId64 ", skip=%" PRId64 "), using full sequence",
                     output_seq_len, hidden_states->ne[1], skip_count);
            output_seq_len = hidden_states->ne[1];
            skip_count     = 0;
        }

        ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
                                                            GGML_TYPE_F32,
                                                            hidden_states->ne[0],
                                                            output_seq_len,
                                                            hidden_states->ne[2]);

        ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
            float value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + skip_count, i2, i3);
            ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
        });

        int64_t t1 = ggml_time_ms();
        LOG_DEBUG("computing Z-Image condition graph completed, taking %" PRId64 " ms", t1 - t0);
        return {new_hidden_states, nullptr, nullptr};
    }
};

#endif
49 changes: 49 additions & 0 deletions denoiser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,55 @@ struct FluxFlowDenoiser : public Denoiser {
}
};

// Flow-matching denoiser for Z-Image.
//
// Keeps a table of TIMESTEPS sigma values produced by the shifted schedule in
// t_to_sigma() and implements the Denoiser hooks on top of it.
struct ZImageFlowDenoiser : public Denoiser {
    float sigmas[TIMESTEPS];
    float shift = 3.0f;

    // Stores the shift and tabulates the sigma of every discrete timestep.
    ZImageFlowDenoiser(float shift = 3.0f)
        : shift(shift) {
        int step = 0;
        while (step < TIMESTEPS) {
            sigmas[step] = t_to_sigma(step);
            ++step;
        }
    }

    // First entry of the sigma table.
    float sigma_min() override {
        return sigmas[0];
    }

    // Last entry of the sigma table.
    float sigma_max() override {
        return sigmas[TIMESTEPS - 1];
    }

    // Maps a sigma to the timestep value as 1 - sigma.
    // NOTE(review): this is not the inverse of t_to_sigma(); presumably the
    // model expects a reversed time in [0, 1] - confirm against the runner.
    float sigma_to_t(float sigma) override {
        return 1.0f - sigma;
    }

    // Shifted schedule: with u = (t + 1) / TIMESTEPS the result is
    // shift * u / (1 + (shift - 1) * u).
    float t_to_sigma(float t) override {
        const float u           = (t + 1) / TIMESTEPS;
        const float numerator   = shift * u;
        const float denominator = 1.0f + (shift - 1.0f) * u;
        return numerator / denominator;
    }

    // Returns {c_skip, c_out, c_in} = {1, sigma, 1}.
    std::vector<float> get_scalings(float sigma) override {
        return {1.0f, sigma, 1.0f};
    }

    // Blends in place: latent = (1 - sigma) * latent + sigma * noise.
    // Both input tensors are modified; the latent tensor is returned.
    ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) override {
        const float keep = 1.0f - sigma;
        ggml_ext_tensor_scale_inplace(noise, sigma);
        ggml_ext_tensor_scale_inplace(latent, keep);
        ggml_ext_tensor_add_inplace(latent, noise);
        return latent;
    }

    // Divides the latent in place by (1 - sigma) and returns it.
    ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) override {
        const float factor = 1.0f / (1.0f - sigma);
        ggml_ext_tensor_scale_inplace(latent, factor);
        return latent;
    }
};

typedef std::function<ggml_tensor*(ggml_tensor*, float, int)> denoise_cb_t;

// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
Expand Down
62 changes: 62 additions & 0 deletions diffusion_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "qwen_image.hpp"
#include "unet.hpp"
#include "wan.hpp"
#include "zimage.hpp"

struct DiffusionParams {
struct ggml_tensor* x = nullptr;
Expand Down Expand Up @@ -357,4 +358,65 @@ struct QwenImageModel : public DiffusionModel {
}
};

// DiffusionModel adapter around ZImage::ZImageRunner: every hook simply
// forwards to the wrapped runner.
struct ZImageDiffusionModel : public DiffusionModel {
    std::string prefix;           // tensor-name prefix used for the runner's parameters
    ZImage::ZImageRunner zimage;  // the wrapped runner

    // Remembers the prefix and builds the runner with the same arguments.
    ZImageDiffusionModel(ggml_backend_t backend,
                         bool offload_params_to_cpu,
                         const String2TensorStorage& tensor_storage_map = {},
                         const std::string prefix                       = "model.diffusion_model")
        : prefix(prefix),
          zimage(backend, offload_params_to_cpu, tensor_storage_map, prefix) {}

    // Description string reported by the runner.
    std::string get_desc() override { return zimage.get_desc(); }

    // Parameter / compute buffer management, delegated to the runner.
    void alloc_params_buffer() override { zimage.alloc_params_buffer(); }

    void free_params_buffer() override { zimage.free_params_buffer(); }

    void free_compute_buffer() override { zimage.free_compute_buffer(); }

    // Registers the runner's parameter tensors under the stored prefix.
    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) override {
        zimage.get_param_tensors(tensors, prefix);
    }

    size_t get_params_buffer_size() override { return zimage.get_params_buffer_size(); }

    // Always 0 for this model.
    int64_t get_adm_in_channels() override { return 0; }

    // Toggles flash attention on the runner.
    void set_flash_attn_enabled(bool enabled) { zimage.set_flash_attention_enabled(enabled); }

    // Runs the runner on diffusion_params.x with the given timesteps and
    // context. The height/width passed along are the input tensor's ne[1] and
    // ne[0] multiplied by 8.
    void compute(int n_threads,
                 DiffusionParams diffusion_params,
                 struct ggml_tensor** output     = nullptr,
                 struct ggml_context* output_ctx = nullptr) override {
        auto* latent = diffusion_params.x;
        int width    = latent->ne[0] * 8;
        int height   = latent->ne[1] * 8;

        zimage.compute(n_threads,
                       latent,
                       diffusion_params.timesteps,
                       diffusion_params.context,
                       height,
                       width,
                       output,
                       output_ctx);
    }
};

#endif
9 changes: 8 additions & 1 deletion examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ struct SDParams {
std::string t5xxl_path;
std::string qwen2vl_path;
std::string qwen2vl_vision_path;
std::string qwen3_path;
std::string diffusion_model_path;
std::string high_noise_diffusion_model_path;
std::string vae_path;
Expand Down Expand Up @@ -176,6 +177,7 @@ void print_params(SDParams params) {
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
printf(" qwen2vl_path: %s\n", params.qwen2vl_path.c_str());
printf(" qwen2vl_vision_path: %s\n", params.qwen2vl_vision_path.c_str());
printf(" qwen3_path: %s\n", params.qwen3_path.c_str());
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str());
printf(" vae_path: %s\n", params.vae_path.c_str());
Expand Down Expand Up @@ -540,6 +542,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
"--qwen2vl_vision",
"path to the qwen2vl vit",
&params.qwen2vl_vision_path},
{"",
"--qwen3",
"path to the qwen3 text encoder (for Z-Image)",
&params.qwen3_path},
{"",
"--diffusion-model",
"path to the standalone diffusion model",
Expand Down Expand Up @@ -1428,7 +1434,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
parameter_string += " " + std::string(sd_scheduler_name(params.sample_params.scheduler));
}
parameter_string += ", ";
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) {
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path, params.qwen3_path}) {
if (!te.empty()) {
parameter_string += "TE: " + sd_basename(te) + ", ";
}
Expand Down Expand Up @@ -1847,6 +1853,7 @@ int main(int argc, const char* argv[]) {
params.t5xxl_path.c_str(),
params.qwen2vl_path.c_str(),
params.qwen2vl_vision_path.c_str(),
params.qwen3_path.c_str(),
params.diffusion_model_path.c_str(),
params.high_noise_diffusion_model_path.c_str(),
params.vae_path.c_str(),
Expand Down
7 changes: 7 additions & 0 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,13 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("model.diffusion_model.transformer_blocks.0.img_mod.1.weight") != std::string::npos) {
return VERSION_QWEN_IMAGE;
}
if (tensor_storage.name.find("model.diffusion_model.context_refiner.") != std::string::npos ||
tensor_storage.name.find("model.diffusion_model.noise_refiner.") != std::string::npos ||
// Also check without prefix for safetensors files exported directly
tensor_storage.name.find("context_refiner.") == 0 ||
tensor_storage.name.find("noise_refiner.") == 0) {
return VERSION_ZIMAGE;
}
if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
is_wan = true;
}
Expand Down
11 changes: 10 additions & 1 deletion model.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ enum SDVersion {
VERSION_WAN2_2_I2V,
VERSION_WAN2_2_TI2V,
VERSION_QWEN_IMAGE,
VERSION_ZIMAGE,
VERSION_COUNT,
};

Expand Down Expand Up @@ -108,6 +109,13 @@ static inline bool sd_version_is_qwen_image(SDVersion version) {
return false;
}

// True only when version is VERSION_ZIMAGE.
static inline bool sd_version_is_zimage(SDVersion version) {
    return version == VERSION_ZIMAGE;
}

static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT ||
version == VERSION_SD2_INPAINT ||
Expand All @@ -123,7 +131,8 @@ static inline bool sd_version_is_dit(SDVersion version) {
if (sd_version_is_flux(version) ||
sd_version_is_sd3(version) ||
sd_version_is_wan(version) ||
sd_version_is_qwen_image(version)) {
sd_version_is_qwen_image(version) ||
sd_version_is_zimage(version)) {
return true;
}
return false;
Expand Down
Loading
Loading