Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
f88daa5
add qwen tokenizer
leejet Sep 20, 2025
fe4e731
add qwen2.5 vl support
leejet Sep 20, 2025
d8d4c26
mv qwen.hpp -> qwenvl.hpp
leejet Sep 21, 2025
d232509
add qwen image model
leejet Sep 21, 2025
cf19c6e
add qwen image t2i pipeline
leejet Sep 22, 2025
477911f
fix qwen image flash attn
leejet Sep 22, 2025
feb0279
add qwen image i2i pipline
leejet Sep 22, 2025
5af0bb0
change encoding of vocab_qwen.hpp to utf8
leejet Sep 22, 2025
a8d3aa0
Merge branch 'master' into qwen_image
leejet Sep 24, 2025
a3a2b2d
fix get_first_stage_encoding
leejet Sep 24, 2025
178a415
Merge branch 'master' into qwen_image
leejet Sep 25, 2025
94f4f29
Merge branch 'master' into qwen_image
leejet Sep 25, 2025
4e48e6b
add ref latent support for qwen image
leejet Sep 23, 2025
95cae28
optimize clip_preprocess and fix get_first_stage_encoding
leejet Sep 24, 2025
58e81ad
add qwen2vl vit support
leejet Sep 29, 2025
40752b6
add qwen image edit support
leejet Oct 8, 2025
887055e
fix qwen image edit pipeline
leejet Oct 8, 2025
9fa817f
add mmproj file support
leejet Oct 9, 2025
a123e25
support dynamic number of Qwen image transformer blocks
leejet Oct 10, 2025
70654d0
revert Rope::gen_qwen_image_ids
leejet Oct 10, 2025
d19d4a5
Merge branch 'master' into qwen_image
leejet Oct 10, 2025
6ea2a75
apply jeffbolz f32 patch
leejet Oct 10, 2025
b769da2
Merge branch 'qwen_image' into qwen_image_edit
leejet Oct 10, 2025
47c0f8e
set prompt_template_encode_start_idx every time
leejet Oct 11, 2025
98d6e71
fix the issue that occurs when using CUDA with k-quants weights
leejet Oct 12, 2025
cc064a0
optimize the handling of the FeedForward precision fix
leejet Oct 12, 2025
0741f14
Merge branch 'qwen_image' into qwen_image_edit
leejet Oct 12, 2025
7519e2f
to_add_out precision fix
leejet Oct 12, 2025
b4b5b4c
Merge branch 'qwen_image' into qwen_image_edit
leejet Oct 12, 2025
d21d1aa
update docs
leejet Oct 12, 2025
ca14940
T5DenseGatedActDense precision fix
leejet Oct 12, 2025
74e020e
Merge branch 'master' into t5_fix
leejet Oct 12, 2025
17f0125
remove dup line
leejet Oct 12, 2025
162d5ce
Merge branch 't5_fix' into qwen_image_edit
leejet Oct 12, 2025
4edc3ad
to_out.0 precision fix
leejet Oct 13, 2025
c47affc
update docs
leejet Oct 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ API and command-line option may change frequently.***
- [Qwen Image](./docs/qwen_image.md)
- Image Edit Models
- [FLUX.1-Kontext-dev](./docs/kontext.md)
- [Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md)
- Video Models
- [Wan2.1/Wan2.2](./docs/wan.md)
- [PhotoMaker](https://github.com/TencentARC/PhotoMaker) support.
Expand Down Expand Up @@ -298,6 +299,7 @@ arguments:
--clip_vision path to the clip-vision encoder
--t5xxl path to the t5xxl text encoder
--qwen2vl path to the qwen2vl text encoder
--qwen2vl_vision path to the qwen2vl vit
--vae [VAE] path to vae
--taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
--control-net [CONTROL_PATH] path to control net model
Expand Down
Binary file added assets/qwen/qwen_image_edit.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/qwen/qwen_image_edit_2509.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
245 changes: 159 additions & 86 deletions conditioner.hpp

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions diffusion_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,8 @@ struct QwenImageModel : public DiffusionModel {
diffusion_params.x,
diffusion_params.timesteps,
diffusion_params.context,
diffusion_params.ref_latents,
true, // increase_ref_index
output,
output_ctx);
}
Expand Down
35 changes: 35 additions & 0 deletions docs/qwen_image_edit.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# How to Use

## Download weights

- Download Qwen Image
- Qwen Image Edit
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image-Edit_ComfyUI/tree/main/split_files/diffusion_models
- gguf: https://huggingface.co/QuantStack/Qwen-Image-Edit-GGUF/tree/main
- Qwen Image Edit 2509
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image-Edit_ComfyUI/tree/main/split_files/diffusion_models
- gguf: https://huggingface.co/QuantStack/Qwen-Image-Edit-2509-GGUF/tree/main
- Download vae
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/vae
- Download qwen_2.5_vl 7b
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/text_encoders
- gguf: https://huggingface.co/mradermacher/Qwen2.5-VL-7B-Instruct-GGUF/tree/main

## Examples

### Qwen Image Edit

```
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen_Image_Edit-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --qwen2vl ..\..\ComfyUI\models\text_encoders\qwen_2.5_vl_7b.safetensors --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'edit.cpp'" --seed 1118877715456453
```

<img alt="qwen_image_edit" src="../assets/qwen/qwen_image_edit.png" />


### Qwen Image Edit 2509

```
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen-Image-Edit-2509-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --qwen2vl ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf --qwen2vl_vision ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct.mmproj-Q8_0.gguf --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'Qwen Image Edit 2509'"
```

<img alt="qwen_image_edit_2509" src="../assets/qwen/qwen_image_edit_2509.png" />
7 changes: 6 additions & 1 deletion examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ struct SDParams {
std::string clip_vision_path;
std::string t5xxl_path;
std::string qwen2vl_path;
std::string qwen2vl_vision_path;
std::string diffusion_model_path;
std::string high_noise_diffusion_model_path;
std::string vae_path;
Expand Down Expand Up @@ -148,6 +149,7 @@ void print_params(SDParams params) {
printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str());
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
printf(" qwen2vl_path: %s\n", params.qwen2vl_path.c_str());
printf(" qwen2vl_vision_path: %s\n", params.qwen2vl_vision_path.c_str());
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str());
printf(" vae_path: %s\n", params.vae_path.c_str());
Expand Down Expand Up @@ -220,6 +222,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" --clip_vision path to the clip-vision encoder\n");
printf(" --t5xxl path to the t5xxl text encoder\n");
printf(" --qwen2vl path to the qwen2vl text encoder\n");
printf(" --qwen2vl_vision path to the qwen2vl vit\n");
printf(" --vae [VAE] path to vae\n");
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
printf(" --control-net [CONTROL_PATH] path to control net model\n");
Expand Down Expand Up @@ -490,6 +493,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--clip_vision", "", &params.clip_vision_path},
{"", "--t5xxl", "", &params.t5xxl_path},
{"", "--qwen2vl", "", &params.qwen2vl_path},
{"", "--qwen2vl_vision", "", &params.qwen2vl_vision_path},
{"", "--diffusion-model", "", &params.diffusion_model_path},
{"", "--high-noise-diffusion-model", "", &params.high_noise_diffusion_model_path},
{"", "--vae", "", &params.vae_path},
Expand Down Expand Up @@ -952,7 +956,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler));
}
parameter_string += ", ";
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path}) {
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) {
if (!te.empty()) {
parameter_string += "TE: " + sd_basename(te) + ", ";
}
Expand Down Expand Up @@ -1336,6 +1340,7 @@ int main(int argc, const char* argv[]) {
params.clip_vision_path.c_str(),
params.t5xxl_path.c_str(),
params.qwen2vl_path.c_str(),
params.qwen2vl_vision_path.c_str(),
params.diffusion_model_path.c_str(),
params.high_noise_diffusion_model_path.c_str(),
params.vae_path.c_str(),
Expand Down
63 changes: 6 additions & 57 deletions flux.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,57 +81,6 @@ namespace Flux {
}
};

__STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* pe) {
// x: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2]
int64_t d_head = x->ne[0];
int64_t n_head = x->ne[1];
int64_t L = x->ne[2];
int64_t N = x->ne[3];
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, n_head, L, d_head]
x = ggml_reshape_4d(ctx, x, 2, d_head / 2, L, n_head * N); // [N * n_head, L, d_head/2, 2]
x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [2, N * n_head, L, d_head/2]

int64_t offset = x->nb[2] * x->ne[2];
auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0); // [N * n_head, L, d_head/2]
auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 1); // [N * n_head, L, d_head/2]
x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]); // [N * n_head, L, d_head/2, 1]
x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]); // [N * n_head, L, d_head/2, 1]
auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]);
x_0 = ggml_repeat(ctx, x_0, temp_x); // [N * n_head, L, d_head/2, 2]
x_1 = ggml_repeat(ctx, x_1, temp_x); // [N * n_head, L, d_head/2, 2]

pe = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2)); // [2, L, d_head/2, 2]
offset = pe->nb[2] * pe->ne[2];
auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0); // [L, d_head/2, 2]
auto pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 1); // [L, d_head/2, 2]

auto x_out = ggml_add_inplace(ctx, ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1)); // [N * n_head, L, d_head/2, 2]
x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head * N); // [N*n_head, L, d_head]
return x_out;
}

__STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx,
ggml_backend_t backend,
struct ggml_tensor* q,
struct ggml_tensor* k,
struct ggml_tensor* v,
struct ggml_tensor* pe,
struct ggml_tensor* mask,
bool flash_attn,
float kv_scale = 1.0f) {
// q,k,v: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2]
// return: [N, L, n_head*d_head]
q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head]
k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head]

auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head]
return x;
}

struct SelfAttention : public GGMLBlock {
public:
int64_t num_heads;
Expand Down Expand Up @@ -179,9 +128,9 @@ namespace Flux {
// x: [N, n_token, dim]
// pe: [n_token, d_head/2, 2, 2]
// return [N, n_token, dim]
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
x = attention(ctx, backend, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
x = Rope::attention(ctx, backend, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
return x;
}
};
Expand Down Expand Up @@ -369,8 +318,8 @@ namespace Flux {
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]

auto attn = attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx,
attn,
attn->ne[0],
Expand Down Expand Up @@ -504,7 +453,7 @@ namespace Flux {
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
q = norm->query_norm(ctx, q);
k = norm->key_norm(ctx, k);
auto attn = attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size]
auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size]

auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim]
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
Expand Down
33 changes: 15 additions & 18 deletions ggml_extend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,11 @@ __STATIC_INLINE__ float sd_image_get_f32(sd_image_t image, int iw, int ih, int i
return value;
}

__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic) {
__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic, bool scale = true) {
float value = *(image.data + ih * image.width * image.channel + iw * image.channel + ic);
if (scale) {
value /= 255.f;
}
return value;
}

Expand Down Expand Up @@ -456,24 +459,18 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
}
}

__STATIC_INLINE__ void sd_image_f32_to_tensor(const float* image_data,
struct ggml_tensor* output,
__STATIC_INLINE__ void sd_image_f32_to_tensor(sd_image_f32_t image,
ggml_tensor* tensor,
bool scale = true) {
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
GGML_ASSERT(channels == 3 && output->type == GGML_TYPE_F32);
for (int iy = 0; iy < height; iy++) {
for (int ix = 0; ix < width; ix++) {
for (int k = 0; k < channels; k++) {
int value = *(image_data + iy * width * channels + ix * channels + k);
if (scale) {
value /= 255.f;
}
ggml_tensor_set_f32(output, value, ix, iy, k);
}
}
}
GGML_ASSERT(image.width == tensor->ne[0]);
GGML_ASSERT(image.height == tensor->ne[1]);
GGML_ASSERT(image.channel == tensor->ne[2]);
GGML_ASSERT(1 == tensor->ne[3]);
GGML_ASSERT(tensor->type == GGML_TYPE_F32);
ggml_tensor_iter(tensor, [&](ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = sd_image_get_f32(image, i0, i1, i2, scale);
ggml_tensor_set_f32(tensor, value, i0, i1, i2, i3);
});
}

__STATIC_INLINE__ void ggml_split_tensor_2d(struct ggml_tensor* input,
Expand Down
36 changes: 31 additions & 5 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ const char* unused_tensors[] = {
"text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
"text_encoders.qwen2vl.output.weight",
"text_encoders.qwen2vl.lm_head.",
"text_encoders.qwen2vl.visual.",
};

bool is_unused_tensor(std::string name) {
Expand Down Expand Up @@ -212,6 +211,24 @@ std::unordered_map<std::string, std::string> qwenvl_name_map{
{"output_norm.", "model.norm."},
};

std::unordered_map<std::string, std::string> qwenvl_vision_name_map{
{"mm.", "merger.mlp."},
{"v.post_ln.", "merger.ln_q."},
{"v.patch_embd.weight", "patch_embed.proj.0.weight"},
{"patch_embed.proj.0.weight.1", "patch_embed.proj.1.weight"},
{"v.patch_embd.weight.1", "patch_embed.proj.1.weight"},
{"v.blk.", "blocks."},
{"attn_q.", "attn.q_proj."},
{"attn_k.", "attn.k_proj."},
{"attn_v.", "attn.v_proj."},
{"attn_out.", "attn.proj."},
{"ffn_down.", "mlp.down_proj."},
{"ffn_gate.", "mlp.gate_proj."},
{"ffn_up.", "mlp.up_proj."},
{"ln1.", "norm1."},
{"ln2.", "norm2."},
};

std::string convert_cond_model_name(const std::string& name) {
std::string new_name = name;
std::string prefix;
Expand Down Expand Up @@ -270,10 +287,19 @@ std::string convert_cond_model_name(const std::string& name) {
new_name.replace(pos, 11, "layer.0.SelfAttention.relative_attention_bias.");
}
} else if (contains(name, "qwen2vl")) {
for (auto kv : qwenvl_name_map) {
size_t pos = new_name.find(kv.first);
if (pos != std::string::npos) {
new_name.replace(pos, kv.first.size(), kv.second);
if (contains(name, "qwen2vl.visual")) {
for (auto kv : qwenvl_vision_name_map) {
size_t pos = new_name.find(kv.first);
if (pos != std::string::npos) {
new_name.replace(pos, kv.first.size(), kv.second);
}
}
} else {
for (auto kv : qwenvl_name_map) {
size_t pos = new_name.find(kv.first);
if (pos != std::string::npos) {
new_name.replace(pos, kv.first.size(), kv.second);
}
}
}
} else if (name == "text_encoders.t5xxl.transformer.token_embd.weight") {
Expand Down
Loading
Loading