6 changes: 4 additions & 2 deletions README.md
@@ -37,7 +37,8 @@ API and command-line option may change frequently.***
- SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
- [Some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
- [SD3/SD3.5](./docs/sd3.md)
- [Flux-dev/Flux-schnell](./docs/flux.md)
- [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
- [FLUX.2-dev](./docs/flux2.md)
- [Chroma](./docs/chroma.md)
- [Chroma1-Radiance](./docs/chroma_radiance.md)
- [Qwen Image](./docs/qwen_image.md)
@@ -118,7 +119,8 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe

- [SD1.x/SD2.x/SDXL](./docs/sd.md)
- [SD3/SD3.5](./docs/sd3.md)
- [Flux-dev/Flux-schnell](./docs/flux.md)
- [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
- [FLUX.2-dev](./docs/flux2.md)
- [FLUX.1-Kontext-dev](./docs/kontext.md)
- [Chroma](./docs/chroma.md)
- [🔥Qwen Image](./docs/qwen_image.md)
Binary file added assets/flux2/example.png
140 changes: 92 additions & 48 deletions conditioner.hpp
@@ -2,7 +2,7 @@
#define __CONDITIONER_HPP__

#include "clip.hpp"
#include "qwenvl.hpp"
#include "llm.hpp"
#include "t5.hpp"

struct SDCondition {
@@ -1623,61 +1623,72 @@ struct T5CLIPEmbedder : public Conditioner {
}
};

struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
Qwen::Qwen2Tokenizer tokenizer;
std::shared_ptr<Qwen::Qwen2_5_VLRunner> qwenvl;

Qwen2_5_VLCLIPEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map = {},
const std::string prefix = "",
bool enable_vision = false) {
qwenvl = std::make_shared<Qwen::Qwen2_5_VLRunner>(backend,
offload_params_to_cpu,
tensor_storage_map,
"text_encoders.qwen2vl",
enable_vision);
struct LLMEmbedder : public Conditioner {
SDVersion version;
std::shared_ptr<LLM::BPETokenizer> tokenizer;
std::shared_ptr<LLM::LLMRunner> llm;

LLMEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map = {},
SDVersion version = VERSION_QWEN_IMAGE,
const std::string prefix = "",
bool enable_vision = false)
: version(version) {
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
if (sd_version_is_flux2(version)) {
arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
}
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
tokenizer = std::make_shared<LLM::MistralTokenizer>();
} else {
tokenizer = std::make_shared<LLM::Qwen2Tokenizer>();
}
llm = std::make_shared<LLM::LLMRunner>(arch,
backend,
offload_params_to_cpu,
tensor_storage_map,
"text_encoders.llm",
enable_vision);
}

void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) override {
qwenvl->get_param_tensors(tensors, "text_encoders.qwen2vl");
llm->get_param_tensors(tensors, "text_encoders.llm");
}

void alloc_params_buffer() override {
qwenvl->alloc_params_buffer();
llm->alloc_params_buffer();
}

void free_params_buffer() override {
qwenvl->free_params_buffer();
llm->free_params_buffer();
}

size_t get_params_buffer_size() override {
size_t buffer_size = 0;
buffer_size += qwenvl->get_params_buffer_size();
buffer_size += llm->get_params_buffer_size();
return buffer_size;
}

void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (qwenvl) {
qwenvl->set_weight_adapter(adapter);
if (llm) {
llm->set_weight_adapter(adapter);
}
}

std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
size_t max_length = 0,
size_t system_prompt_length = 0,
bool padding = false) {
std::pair<int, int> attn_range,
size_t max_length = 0,
bool padding = false) {
std::vector<std::pair<std::string, float>> parsed_attention;
if (system_prompt_length > 0) {
parsed_attention.emplace_back(text.substr(0, system_prompt_length), 1.f);
auto new_parsed_attention = parse_prompt_attention(text.substr(system_prompt_length, text.size() - system_prompt_length));
parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
if (attn_range.second - attn_range.first > 0) {
auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
parsed_attention.insert(parsed_attention.end(),
new_parsed_attention.begin(),
new_parsed_attention.end());
} else {
parsed_attention = parse_prompt_attention(text);
}

parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
{
std::stringstream ss;
ss << "[";
@@ -1693,12 +1704,12 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
for (const auto& item : parsed_attention) {
const std::string& curr_text = item.first;
float curr_weight = item.second;
std::vector<int> curr_tokens = tokenizer.tokenize(curr_text, nullptr);
std::vector<int> curr_tokens = tokenizer->tokenize(curr_text, nullptr);
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
}

tokenizer.pad_tokens(tokens, weights, max_length, padding);
tokenizer->pad_tokens(tokens, weights, max_length, padding);

// for (int i = 0; i < tokens.size(); i++) {
// std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl;
@@ -1713,9 +1724,10 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
const ConditionerParams& conditioner_params) override {
std::string prompt;
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
size_t system_prompt_length = 0;
std::pair<int, int> prompt_attn_range;
int prompt_template_encode_start_idx = 34;
if (qwenvl->enable_vision && conditioner_params.ref_images.size() > 0) {
std::set<int> out_layers;
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
LOG_INFO("QwenImageEditPlusPipeline");
prompt_template_encode_start_idx = 64;
int image_embed_idx = 64 + 6;
@@ -1727,7 +1739,7 @@

for (int i = 0; i < conditioner_params.ref_images.size(); i++) {
sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]);
double factor = qwenvl->params.vision.patch_size * qwenvl->params.vision.spatial_merge_size;
double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size;
int height = image.height;
int width = image.width;
int h_bar = static_cast<int>(std::round(height / factor)) * factor;
@@ -1757,7 +1769,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
resized_image.data = nullptr;

ggml_tensor* image_embed = nullptr;
qwenvl->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
image_embeds.emplace_back(image_embed_idx, image_embed);
image_embed_idx += 1 + image_embed->ne[1] + 6;

Expand All @@ -1771,17 +1783,37 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
}

prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";

system_prompt_length = prompt.size();

prompt += img_prompt;

prompt_attn_range.first = prompt.size();
prompt += conditioner_params.text;
prompt_attn_range.second = prompt.size();

prompt += "<|im_end|>\n<|im_start|>assistant\n";
} else if (sd_version_is_flux2(version)) {
prompt_template_encode_start_idx = 0;
out_layers = {10, 20, 30};

prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";

prompt_attn_range.first = prompt.size();
prompt += conditioner_params.text;
prompt_attn_range.second = prompt.size();

prompt += "[/INST]";
} else {
prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n";
prompt_template_encode_start_idx = 34;

prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";

prompt_attn_range.first = prompt.size();
prompt += conditioner_params.text;
prompt_attn_range.second = prompt.size();

prompt += "<|im_end|>\n<|im_start|>assistant\n";
}

auto tokens_and_weights = tokenize(prompt, 0, system_prompt_length, false);
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
auto& tokens = std::get<0>(tokens_and_weights);
auto& weights = std::get<1>(tokens_and_weights);

@@ -1790,11 +1822,12 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {

auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);

qwenvl->compute(n_threads,
input_ids,
image_embeds,
&hidden_states,
work_ctx);
llm->compute(n_threads,
input_ids,
image_embeds,
out_layers,
&hidden_states,
work_ctx);
{
auto tensor = hidden_states;
float original_mean = ggml_ext_tensor_mean(tensor);
Expand All @@ -1813,14 +1846,25 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {

GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);

int64_t zero_pad_len = 0;
if (sd_version_is_flux2(version)) {
int64_t min_length = 512;
if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
zero_pad_len = min_length - hidden_states->ne[1] + prompt_template_encode_start_idx;
}
}

ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
GGML_TYPE_F32,
hidden_states->ne[0],
hidden_states->ne[1] - prompt_template_encode_start_idx,
hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len,
hidden_states->ne[2]);

ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
float value = 0.f;
if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1]) {
value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
}
ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
});

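A note on the new hidden-state handling in `LLMEmbedder`: the prompt-template prefix (`prompt_template_encode_start_idx` tokens; 0 for FLUX.2, 34 or 64 for Qwen-Image) is cropped off, and for FLUX.2 the result is then zero-padded up to a minimum of 512 tokens. A minimal standalone sketch of that crop-and-pad step, using a plain row-major float buffer instead of ggml tensors (the function name and buffer layout are illustrative, not part of the codebase):

```cpp
#include <cstddef>
#include <vector>

// Crop the first `start_idx` tokens off a [n_tokens, dim] row-major
// hidden-state buffer, then zero-pad the token axis up to `min_length`,
// mirroring the zero_pad_len logic in LLMEmbedder above.
// Pass min_length = 0 to reproduce the non-FLUX.2 path (crop only).
std::vector<float> crop_and_pad(const std::vector<float>& hidden,
                                int n_tokens, int dim,
                                int start_idx,         // prompt_template_encode_start_idx
                                int min_length = 512)  // FLUX.2 minimum sequence length
{
    int kept    = n_tokens - start_idx;
    int out_len = kept < min_length ? min_length : kept;
    std::vector<float> out(static_cast<std::size_t>(out_len) * dim, 0.0f);  // zeros = padding
    for (int i = 0; i < kept; i++)
        for (int d = 0; d < dim; d++)
            out[static_cast<std::size_t>(i) * dim + d] =
                hidden[static_cast<std::size_t>(i + start_idx) * dim + d];
    return out;
}
```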
44 changes: 40 additions & 4 deletions denoiser.hpp
@@ -356,7 +356,7 @@ struct Denoiser {
virtual ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) = 0;
virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) = 0;

virtual std::vector<float> get_sigmas(uint32_t n, scheduler_t scheduler_type, SDVersion version) {
virtual std::vector<float> get_sigmas(uint32_t n, int /*image_seq_len*/, scheduler_t scheduler_type, SDVersion version) {
auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
std::shared_ptr<SigmaScheduler> scheduler;
switch (scheduler_type) {
@@ -582,10 +582,14 @@ struct FluxFlowDenoiser : public Denoiser {
set_parameters(shift);
}

void set_parameters(float shift = 1.15f) {
void set_shift(float shift) {
this->shift = shift;
for (int i = 1; i < TIMESTEPS + 1; i++) {
sigmas[i - 1] = t_to_sigma(i / TIMESTEPS * TIMESTEPS);
}
}

void set_parameters(float shift) {
set_shift(shift);
for (int i = 0; i < TIMESTEPS; i++) {
sigmas[i] = t_to_sigma(i);
}
}

@@ -627,6 +631,38 @@
}
};

struct Flux2FlowDenoiser : public FluxFlowDenoiser {
Flux2FlowDenoiser() = default;

float compute_empirical_mu(uint32_t n, int image_seq_len) {
const float a1 = 8.73809524e-05f;
const float b1 = 1.89833333f;
const float a2 = 0.00016927f;
const float b2 = 0.45666666f;

if (image_seq_len > 4300) {
float mu = a2 * image_seq_len + b2;
return mu;
}

float m_200 = a2 * image_seq_len + b2;
float m_10 = a1 * image_seq_len + b1;

float a = (m_200 - m_10) / 190.0f;
float b = m_200 - 200.0f * a;
float mu = a * n + b;

return mu;
}

std::vector<float> get_sigmas(uint32_t n, int image_seq_len, scheduler_t scheduler_type, SDVersion version) override {
float mu = compute_empirical_mu(n, image_seq_len);
LOG_DEBUG("Flux2FlowDenoiser: set shift to %.3f", mu);
set_shift(mu);
return Denoiser::get_sigmas(n, image_seq_len, scheduler_type, version);
}
};

typedef std::function<ggml_tensor*(ggml_tensor*, float, int)> denoise_cb_t;

// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
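A note on `Flux2FlowDenoiser::compute_empirical_mu` above: it interpolates linearly in the step count `n` between two linear fits of mu against the image sequence length (one fit at n = 10, one at n = 200), and falls back to the n = 200 fit alone for sequences longer than 4300 tokens. A standalone sketch with a hypothetical driver (constants are copied from the diff; the 4096-token sequence length is an assumed example, roughly a 1024x1024 image):

```cpp
#include <cstdio>
#include <initializer_list>

// Same piecewise-linear fit as Flux2FlowDenoiser::compute_empirical_mu.
static float compute_empirical_mu(unsigned n, int image_seq_len) {
    const float a1 = 8.73809524e-05f, b1 = 1.89833333f;  // linear fit of mu at n = 10
    const float a2 = 0.00016927f, b2 = 0.45666666f;      // linear fit of mu at n = 200
    if (image_seq_len > 4300)
        return a2 * image_seq_len + b2;                  // long sequences: n = 200 fit only
    float m_10  = a1 * image_seq_len + b1;
    float m_200 = a2 * image_seq_len + b2;
    float a = (m_200 - m_10) / 190.0f;                   // interpolate between n = 10 and n = 200
    float b = m_200 - 200.0f * a;
    return a * n + b;
}

int main() {
    for (unsigned n : {10u, 20u, 50u, 200u})
        std::printf("steps=%3u seq_len=4096 -> mu=%.3f\n", n, compute_empirical_mu(n, 4096));
    return 0;
}
```

The computed mu is then used as the flow shift, so fewer sampling steps and longer image sequences both push the schedule toward higher shift values.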
21 changes: 21 additions & 0 deletions docs/flux2.md
@@ -0,0 +1,21 @@
# How to Use

## Download weights

- Download FLUX.2-dev
- gguf: https://huggingface.co/city96/FLUX.2-dev-gguf/tree/main
- Download vae
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
- Download Mistral-Small-3.2-24B-Instruct-2506-GGUF
- gguf: https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/tree/main

## Examples

```
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux2-dev-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Mistral-Small-3.2-24B-Instruct-2506-Q4_K_M.gguf -r .\kontext_input.png -p "change 'flux.cpp' to 'flux2-dev.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu
```

<img alt="flux2 example" src="../assets/flux2/example.png" />
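The same invocation on Linux/macOS, as a sketch (the binary and model paths are illustrative placeholders; the flags match the Windows example above):

```
./build/bin/sd --diffusion-model ./models/flux2-dev-Q4_K_S.gguf --vae ./models/flux2_ae.safetensors --llm ./models/Mistral-Small-3.2-24B-Instruct-2506-Q4_K_M.gguf -r ./kontext_input.png -p "change 'flux.cpp' to 'flux2-dev.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu
```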



2 changes: 1 addition & 1 deletion docs/qwen_image.md
@@ -14,7 +14,7 @@
## Examples

```
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --qwen2vl ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf -p '一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。她身后的玻璃板上手写体写着 “一、Qwen-Image的技术路线: 探索视觉生成基础模型的极限,开创理解与生成一体化的未来。二、Qwen-Image的模型特色:1、复杂文字渲染。支持中英渲染、自动布局; 2、精准图像编辑。支持文字编辑、物体增减、风格变换。三、Qwen-Image的未来愿景:赋能专业内容创作、助力生成式AI发展。”' --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu -H 1024 -W 1024 --diffusion-fa --flow-shift 3
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf -p '一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。她身后的玻璃板上手写体写着 “一、Qwen-Image的技术路线: 探索视觉生成基础模型的极限,开创理解与生成一体化的未来。二、Qwen-Image的模型特色:1、复杂文字渲染。支持中英渲染、自动布局; 2、精准图像编辑。支持文字编辑、物体增减、风格变换。三、Qwen-Image的未来愿景:赋能专业内容创作、助力生成式AI发展。”' --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu -H 1024 -W 1024 --diffusion-fa --flow-shift 3
```

<img alt="qwen example" src="../assets/qwen/example.png" />