diff --git a/README.md b/README.md index 0d2da62c..0a27bc1c 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ API and command-line option may change frequently.*** - [Qwen Image](./docs/qwen_image.md) - Image Edit Models - [FLUX.1-Kontext-dev](./docs/kontext.md) + - [Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md) - Video Models - [Wan2.1/Wan2.2](./docs/wan.md) - [PhotoMaker](https://github.com/TencentARC/PhotoMaker) support. @@ -298,6 +299,7 @@ arguments: --clip_vision path to the clip-vision encoder --t5xxl path to the t5xxl text encoder --qwen2vl path to the qwen2vl text encoder + --qwen2vl_vision path to the qwen2vl vit --vae [VAE] path to vae --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) --control-net [CONTROL_PATH] path to control net model diff --git a/assets/qwen/qwen_image_edit.png b/assets/qwen/qwen_image_edit.png new file mode 100644 index 00000000..c2a31eda Binary files /dev/null and b/assets/qwen/qwen_image_edit.png differ diff --git a/assets/qwen/qwen_image_edit_2509.png b/assets/qwen/qwen_image_edit_2509.png new file mode 100644 index 00000000..442ba9b3 Binary files /dev/null and b/assets/qwen/qwen_image_edit_2509.png differ diff --git a/conditioner.hpp b/conditioner.hpp index b25ef84f..abd6dbc3 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -15,28 +15,28 @@ struct SDCondition { : c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat) {} }; +struct ConditionerParams { + std::string text; + int clip_skip = -1; + int width = -1; + int height = -1; + int adm_in_channels = -1; + bool zero_out_masked = false; + int num_input_imgs = 0; // for photomaker + std::vector ref_images = {}; // for qwen image edit +}; + struct Conditioner { virtual SDCondition get_learned_condition(ggml_context* work_ctx, int n_threads, - const std::string& text, - int clip_skip, - int width, - int height, - int adm_in_channels = -1, - bool zero_out_masked = false) = 0; - virtual void alloc_params_buffer() = 0; - virtual void free_params_buffer() = 0; - virtual void get_param_tensors(std::map& tensors) = 0; - virtual size_t get_params_buffer_size() = 0; + const ConditionerParams& conditioner_params) = 0; + virtual void alloc_params_buffer() = 0; + virtual void free_params_buffer() = 0; + virtual void get_param_tensors(std::map& tensors) = 0; + virtual size_t get_params_buffer_size() = 0; virtual std::tuple> get_learned_condition_with_trigger(ggml_context* work_ctx, int n_threads, - const std::string& text, - int clip_skip, - int width, - int height, - int num_input_imgs, - int adm_in_channels = -1, - bool zero_out_masked = false) { + const ConditionerParams& conditioner_params) { GGML_ABORT("Not implemented yet!"); } virtual std::string remove_trigger_from_prompt(ggml_context* work_ctx, @@ -555,20 +555,14 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { std::tuple> get_learned_condition_with_trigger(ggml_context* work_ctx, int n_threads, - const std::string& text, - int clip_skip, - int width, - int height, - int num_input_imgs, - int adm_in_channels = -1, - bool zero_out_masked = false) { + const ConditionerParams& conditioner_params) { auto image_tokens = convert_token_to_id(trigger_word); // if(image_tokens.size() == 1){ // printf(" image token id is: %d \n", image_tokens[0]); // } GGML_ASSERT(image_tokens.size() == 1); - auto tokens_and_weights = tokenize_with_trigger_token(text, - num_input_imgs, + auto tokens_and_weights = tokenize_with_trigger_token(conditioner_params.text, + 
conditioner_params.num_input_imgs, image_tokens[0], true); std::vector& tokens = std::get<0>(tokens_and_weights); @@ -582,7 +576,15 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { // for(int i = 0; i < clsm.size(); ++i) // printf("%d ", clsm[i]?1:0); // printf("\n"); - auto cond = get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, zero_out_masked); + auto cond = get_learned_condition_common(work_ctx, + n_threads, + tokens, + weights, + conditioner_params.clip_skip, + conditioner_params.width, + conditioner_params.height, + conditioner_params.adm_in_channels, + conditioner_params.zero_out_masked); return std::make_tuple(cond, clsm); } @@ -600,16 +602,19 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { SDCondition get_learned_condition(ggml_context* work_ctx, int n_threads, - const std::string& text, - int clip_skip, - int width, - int height, - int adm_in_channels = -1, - bool zero_out_masked = false) { - auto tokens_and_weights = tokenize(text, true); + const ConditionerParams& conditioner_params) { + auto tokens_and_weights = tokenize(conditioner_params.text, true); std::vector& tokens = tokens_and_weights.first; std::vector& weights = tokens_and_weights.second; - return get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, zero_out_masked); + return get_learned_condition_common(work_ctx, + n_threads, + tokens, + weights, + conditioner_params.clip_skip, + conditioner_params.width, + conditioner_params.height, + conditioner_params.adm_in_channels, + conditioner_params.zero_out_masked); } }; @@ -974,14 +979,13 @@ struct SD3CLIPEmbedder : public Conditioner { SDCondition get_learned_condition(ggml_context* work_ctx, int n_threads, - const std::string& text, - int clip_skip, - int width, - int height, - int adm_in_channels = -1, - bool zero_out_masked = false) { - auto tokens_and_weights = tokenize(text, 77, true); - return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked); + const ConditionerParams& conditioner_params) { + auto tokens_and_weights = tokenize(conditioner_params.text, 77, true); + return get_learned_condition_common(work_ctx, + n_threads, + tokens_and_weights, + conditioner_params.clip_skip, + conditioner_params.zero_out_masked); } }; @@ -1174,14 +1178,13 @@ struct FluxCLIPEmbedder : public Conditioner { SDCondition get_learned_condition(ggml_context* work_ctx, int n_threads, - const std::string& text, - int clip_skip, - int width, - int height, - int adm_in_channels = -1, - bool zero_out_masked = false) { - auto tokens_and_weights = tokenize(text, chunk_len, true); - return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked); + const ConditionerParams& conditioner_params) { + auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, true); + return get_learned_condition_common(work_ctx, + n_threads, + tokens_and_weights, + conditioner_params.clip_skip, + conditioner_params.zero_out_masked); } }; @@ -1360,27 +1363,30 @@ struct T5CLIPEmbedder : public Conditioner { SDCondition get_learned_condition(ggml_context* work_ctx, int n_threads, - const std::string& text, - int clip_skip, - int width, - int height, - int adm_in_channels = -1, - bool zero_out_masked = false) { - auto tokens_and_weights = tokenize(text, chunk_len, true); - return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, 
zero_out_masked); + const ConditionerParams& conditioner_params) { + auto tokens_and_weights = tokenize(conditioner_params.text, chunk_len, true); + return get_learned_condition_common(work_ctx, + n_threads, + tokens_and_weights, + conditioner_params.clip_skip, + conditioner_params.zero_out_masked); } }; struct Qwen2_5_VLCLIPEmbedder : public Conditioner { Qwen::Qwen2Tokenizer tokenizer; std::shared_ptr qwenvl; - int prompt_template_encode_start_idx = 34; Qwen2_5_VLCLIPEmbedder(ggml_backend_t backend, bool offload_params_to_cpu, const String2GGMLType& tensor_types = {}, - const std::string prefix = "") { - qwenvl = std::make_shared(backend, offload_params_to_cpu, tensor_types, "text_encoders.qwen2vl"); + const std::string prefix = "", + bool enable_vision = false) { + qwenvl = std::make_shared(backend, + offload_params_to_cpu, + tensor_types, + "text_encoders.qwen2vl", + enable_vision); } void get_param_tensors(std::map& tensors) { @@ -1402,9 +1408,19 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner { } std::tuple, std::vector> tokenize(std::string text, - size_t max_length = 0, - bool padding = false) { - auto parsed_attention = parse_prompt_attention(text); + size_t max_length = 0, + size_t system_prompt_length = 0, + bool padding = false) { + std::vector> parsed_attention; + if (system_prompt_length > 0) { + parsed_attention.emplace_back(text.substr(0, system_prompt_length), 1.f); + auto new_parsed_attention = parse_prompt_attention(text.substr(system_prompt_length, text.size() - system_prompt_length)); + parsed_attention.insert(parsed_attention.end(), + new_parsed_attention.begin(), + new_parsed_attention.end()); + } else { + parsed_attention = parse_prompt_attention(text); + } { std::stringstream ss; @@ -1429,20 +1445,89 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner { tokenizer.pad_tokens(tokens, weights, max_length, padding); // for (int i = 0; i < tokens.size(); i++) { - // std::cout << tokens[i] << ":" << weights[i] << ", "; + // std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl; // } // std::cout << std::endl; return {tokens, weights}; } - SDCondition get_learned_condition_common(ggml_context* work_ctx, - int n_threads, - std::tuple, std::vector> token_and_weights, - int clip_skip, - bool zero_out_masked = false) { - auto& tokens = std::get<0>(token_and_weights); - auto& weights = std::get<1>(token_and_weights); + SDCondition get_learned_condition(ggml_context* work_ctx, + int n_threads, + const ConditionerParams& conditioner_params) { + std::string prompt; + std::vector> image_embeds; + size_t system_prompt_length = 0; + int prompt_template_encode_start_idx = 34; + if (qwenvl->enable_vision && conditioner_params.ref_images.size() > 0) { + LOG_INFO("QwenImageEditPlusPipeline"); + prompt_template_encode_start_idx = 64; + int image_embed_idx = 64 + 6; + + int min_pixels = 384 * 384; + int max_pixels = 560 * 560; + std::string placeholder = "<|image_pad|>"; + std::string img_prompt; + + for (int i = 0; i < conditioner_params.ref_images.size(); i++) { + sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); + double factor = qwenvl->params.vision.patch_size * qwenvl->params.vision.spatial_merge_size; + int height = image.height; + int width = image.width; + int h_bar = static_cast(std::round(height / factor)) * factor; + int w_bar = static_cast(std::round(width / factor)) * factor; + + if (static_cast(h_bar) * w_bar > max_pixels) { + double beta = std::sqrt((height * width) / static_cast(max_pixels)); + h_bar = 
std::max(static_cast(factor), + static_cast(std::floor(height / beta / factor)) * static_cast(factor)); + w_bar = std::max(static_cast(factor), + static_cast(std::floor(width / beta / factor)) * static_cast(factor)); + } else if (static_cast(h_bar) * w_bar < min_pixels) { + double beta = std::sqrt(static_cast(min_pixels) / (height * width)); + h_bar = static_cast(std::ceil(height * beta / factor)) * static_cast(factor); + w_bar = static_cast(std::ceil(width * beta / factor)) * static_cast(factor); + } + + LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar); + + sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar); + free(image.data); + image.data = nullptr; + + ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); + sd_image_f32_to_tensor(resized_image, image_tensor, false); + free(resized_image.data); + resized_image.data = nullptr; + + ggml_tensor* image_embed = nullptr; + qwenvl->encode_image(n_threads, image_tensor, &image_embed, work_ctx); + image_embeds.emplace_back(image_embed_idx, image_embed); + image_embed_idx += 1 + image_embed->ne[1] + 6; + + img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>"; // [24669, 220, index, 25, 220, 151652] + int64_t num_image_tokens = image_embed->ne[1]; + img_prompt.reserve(num_image_tokens * placeholder.size()); + for (int j = 0; j < num_image_tokens; j++) { + img_prompt += placeholder; + } + img_prompt += "<|vision_end|>"; + } + + prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n"; + + system_prompt_length = prompt.size(); + + prompt += img_prompt; + prompt += conditioner_params.text; + prompt += "<|im_end|>\n<|im_start|>assistant\n"; + } else { + prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n"; + } + + auto tokens_and_weights = tokenize(prompt, 0, system_prompt_length, false); + auto& tokens = std::get<0>(tokens_and_weights); + auto& weights = std::get<1>(tokens_and_weights); int64_t t0 = ggml_time_ms(); struct ggml_tensor* hidden_states = NULL; // [N, n_token, 3584] @@ -1451,6 +1536,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner { qwenvl->compute(n_threads, input_ids, + image_embeds, &hidden_states, work_ctx); { @@ -1486,19 +1572,6 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner { LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); return SDCondition(new_hidden_states, nullptr, nullptr); } - - SDCondition get_learned_condition(ggml_context* work_ctx, - int n_threads, - const std::string& text, - int clip_skip, - int width, - int height, - int adm_in_channels = -1, - bool zero_out_masked = false) { - std::string prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + text + "<|im_end|>\n<|im_start|>assistant\n"; - auto tokens_and_weights = tokenize(prompt, 0, false); - return 
get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked); - } }; #endif diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 69cd5748..6c38b58a 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -313,6 +313,8 @@ struct QwenImageModel : public DiffusionModel { diffusion_params.x, diffusion_params.timesteps, diffusion_params.context, + diffusion_params.ref_latents, + true, // increase_ref_index output, output_ctx); } diff --git a/docs/qwen_image_edit.md b/docs/qwen_image_edit.md new file mode 100644 index 00000000..3a5242f2 --- /dev/null +++ b/docs/qwen_image_edit.md @@ -0,0 +1,35 @@ +# How to Use + +## Download weights + +- Download Qwen Image + - Qwen Image Edit + - safetensors: https://huggingface.co/Comfy-Org/Qwen-Image-Edit_ComfyUI/tree/main/split_files/diffusion_models + - gguf: https://huggingface.co/QuantStack/Qwen-Image-Edit-GGUF/tree/main + - Qwen Image Edit 2509 + - safetensors: https://huggingface.co/Comfy-Org/Qwen-Image-Edit_ComfyUI/tree/main/split_files/diffusion_models + - gguf: https://huggingface.co/QuantStack/Qwen-Image-Edit-2509-GGUF/tree/main +- Download vae + - safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/vae +- Download qwen_2.5_vl 7b + - safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/text_encoders + - gguf: https://huggingface.co/mradermacher/Qwen2.5-VL-7B-Instruct-GGUF/tree/main + +## Examples + +### Qwen Image Edit + +``` +.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen_Image_Edit-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --qwen2vl ..\..\ComfyUI\models\text_encoders\qwen_2.5_vl_7b.safetensors --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'edit.cpp'" --seed 1118877715456453 +``` + +qwen_image_edit + + +### Qwen Image Edit 2509 + +``` +.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen-Image-Edit-2509-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --qwen2vl ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf --qwen2vl_vision ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct.mmproj-Q8_0.gguf --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'Qwen Image Edit 2509'" +``` + +qwen_image_edit_2509 \ No newline at end of file diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 5229876f..b1d83a06 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -62,6 +62,7 @@ struct SDParams { std::string clip_vision_path; std::string t5xxl_path; std::string qwen2vl_path; + std::string qwen2vl_vision_path; std::string diffusion_model_path; std::string high_noise_diffusion_model_path; std::string vae_path; @@ -148,6 +149,7 @@ void print_params(SDParams params) { printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str()); printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str()); printf(" qwen2vl_path: %s\n", params.qwen2vl_path.c_str()); + printf(" qwen2vl_vision_path: %s\n", params.qwen2vl_vision_path.c_str()); printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str()); printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str()); printf(" vae_path: %s\n", params.vae_path.c_str()); @@ -220,6 +222,7 @@ void print_usage(int argc, 
const char* argv[]) { printf(" --clip_vision path to the clip-vision encoder\n"); printf(" --t5xxl path to the t5xxl text encoder\n"); printf(" --qwen2vl path to the qwen2vl text encoder\n"); + printf(" --qwen2vl_vision path to the qwen2vl vit\n"); printf(" --vae [VAE] path to vae\n"); printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n"); printf(" --control-net [CONTROL_PATH] path to control net model\n"); @@ -490,6 +493,7 @@ void parse_args(int argc, const char** argv, SDParams& params) { {"", "--clip_vision", "", ¶ms.clip_vision_path}, {"", "--t5xxl", "", ¶ms.t5xxl_path}, {"", "--qwen2vl", "", ¶ms.qwen2vl_path}, + {"", "--qwen2vl_vision", "", ¶ms.qwen2vl_vision_path}, {"", "--diffusion-model", "", ¶ms.diffusion_model_path}, {"", "--high-noise-diffusion-model", "", ¶ms.high_noise_diffusion_model_path}, {"", "--vae", "", ¶ms.vae_path}, @@ -952,7 +956,7 @@ std::string get_image_params(SDParams params, int64_t seed) { parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler)); } parameter_string += ", "; - for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path}) { + for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path, params.qwen2vl_vision_path}) { if (!te.empty()) { parameter_string += "TE: " + sd_basename(te) + ", "; } @@ -1336,6 +1340,7 @@ int main(int argc, const char* argv[]) { params.clip_vision_path.c_str(), params.t5xxl_path.c_str(), params.qwen2vl_path.c_str(), + params.qwen2vl_vision_path.c_str(), params.diffusion_model_path.c_str(), params.high_noise_diffusion_model_path.c_str(), params.vae_path.c_str(), diff --git a/flux.hpp b/flux.hpp index 39543721..2ed41041 100644 --- a/flux.hpp +++ b/flux.hpp @@ -81,57 +81,6 @@ namespace Flux { } }; - __STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx, - struct ggml_tensor* x, - struct ggml_tensor* pe) { - // x: [N, L, n_head, d_head] - // pe: [L, d_head/2, 2, 2] - int64_t d_head = x->ne[0]; - int64_t n_head = x->ne[1]; - int64_t L = x->ne[2]; - int64_t N = x->ne[3]; - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, n_head, L, d_head] - x = ggml_reshape_4d(ctx, x, 2, d_head / 2, L, n_head * N); // [N * n_head, L, d_head/2, 2] - x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [2, N * n_head, L, d_head/2] - - int64_t offset = x->nb[2] * x->ne[2]; - auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0); // [N * n_head, L, d_head/2] - auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 1); // [N * n_head, L, d_head/2] - x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]); // [N * n_head, L, d_head/2, 1] - x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]); // [N * n_head, L, d_head/2, 1] - auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]); - x_0 = ggml_repeat(ctx, x_0, temp_x); // [N * n_head, L, d_head/2, 2] - x_1 = ggml_repeat(ctx, x_1, temp_x); // [N * n_head, L, d_head/2, 2] - - pe = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2)); // [2, L, d_head/2, 2] - offset = pe->nb[2] * pe->ne[2]; - auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0); // [L, d_head/2, 2] - auto pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 1); // [L, d_head/2, 2] - - auto x_out = ggml_add_inplace(ctx, 
ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1)); // [N * n_head, L, d_head/2, 2] - x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head * N); // [N*n_head, L, d_head] - return x_out; - } - - __STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx, - ggml_backend_t backend, - struct ggml_tensor* q, - struct ggml_tensor* k, - struct ggml_tensor* v, - struct ggml_tensor* pe, - struct ggml_tensor* mask, - bool flash_attn, - float kv_scale = 1.0f) { - // q,k,v: [N, L, n_head, d_head] - // pe: [L, d_head/2, 2, 2] - // return: [N, L, n_head*d_head] - q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head] - k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head] - - auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head] - return x; - } - struct SelfAttention : public GGMLBlock { public: int64_t num_heads; @@ -179,9 +128,9 @@ namespace Flux { // x: [N, n_token, dim] // pe: [n_token, d_head/2, 2, 2] // return [N, n_token, dim] - auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] - x = attention(ctx, backend, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim] - x = post_attention(ctx, x); // [N, n_token, dim] + auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] + x = Rope::attention(ctx, backend, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim] + x = post_attention(ctx, x); // [N, n_token, dim] return x; } }; @@ -369,8 +318,8 @@ namespace Flux { auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto attn = attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head] - attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] + auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head] + attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] auto txt_attn_out = ggml_view_3d(ctx, attn, attn->ne[0], @@ -504,7 +453,7 @@ namespace Flux { auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] q = norm->query_norm(ctx, q); k = norm->key_norm(ctx, k); - auto attn = attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size] + auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size] auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim] auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size] diff --git a/ggml_extend.hpp b/ggml_extend.hpp index a125357b..b1b9e8c8 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -197,8 +197,11 @@ __STATIC_INLINE__ float sd_image_get_f32(sd_image_t image, int iw, int ih, int i return value; } -__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic) { +__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic, bool scale = true) { float value = *(image.data + ih * image.width * image.channel + iw * image.channel + ic); + if (scale) { + value /= 255.f; + } return value; } @@ -456,24 +459,18 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data, } } 
-__STATIC_INLINE__ void sd_image_f32_to_tensor(const float* image_data, - struct ggml_tensor* output, +__STATIC_INLINE__ void sd_image_f32_to_tensor(sd_image_f32_t image, + ggml_tensor* tensor, bool scale = true) { - int64_t width = output->ne[0]; - int64_t height = output->ne[1]; - int64_t channels = output->ne[2]; - GGML_ASSERT(channels == 3 && output->type == GGML_TYPE_F32); - for (int iy = 0; iy < height; iy++) { - for (int ix = 0; ix < width; ix++) { - for (int k = 0; k < channels; k++) { - int value = *(image_data + iy * width * channels + ix * channels + k); - if (scale) { - value /= 255.f; - } - ggml_tensor_set_f32(output, value, ix, iy, k); - } - } - } + GGML_ASSERT(image.width == tensor->ne[0]); + GGML_ASSERT(image.height == tensor->ne[1]); + GGML_ASSERT(image.channel == tensor->ne[2]); + GGML_ASSERT(1 == tensor->ne[3]); + GGML_ASSERT(tensor->type == GGML_TYPE_F32); + ggml_tensor_iter(tensor, [&](ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float value = sd_image_get_f32(image, i0, i1, i2, scale); + ggml_tensor_set_f32(tensor, value, i0, i1, i2, i3); + }); } __STATIC_INLINE__ void ggml_split_tensor_2d(struct ggml_tensor* input, diff --git a/model.cpp b/model.cpp index 55b1abca..b45493cc 100644 --- a/model.cpp +++ b/model.cpp @@ -113,7 +113,6 @@ const char* unused_tensors[] = { "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training "text_encoders.qwen2vl.output.weight", "text_encoders.qwen2vl.lm_head.", - "text_encoders.qwen2vl.visual.", }; bool is_unused_tensor(std::string name) { @@ -212,6 +211,24 @@ std::unordered_map qwenvl_name_map{ {"output_norm.", "model.norm."}, }; +std::unordered_map qwenvl_vision_name_map{ + {"mm.", "merger.mlp."}, + {"v.post_ln.", "merger.ln_q."}, + {"v.patch_embd.weight", "patch_embed.proj.0.weight"}, + {"patch_embed.proj.0.weight.1", "patch_embed.proj.1.weight"}, + {"v.patch_embd.weight.1", "patch_embed.proj.1.weight"}, + {"v.blk.", "blocks."}, + {"attn_q.", "attn.q_proj."}, + {"attn_k.", "attn.k_proj."}, + {"attn_v.", "attn.v_proj."}, + {"attn_out.", "attn.proj."}, + {"ffn_down.", "mlp.down_proj."}, + {"ffn_gate.", "mlp.gate_proj."}, + {"ffn_up.", "mlp.up_proj."}, + {"ln1.", "norm1."}, + {"ln2.", "norm2."}, +}; + std::string convert_cond_model_name(const std::string& name) { std::string new_name = name; std::string prefix; @@ -270,10 +287,19 @@ std::string convert_cond_model_name(const std::string& name) { new_name.replace(pos, 11, "layer.0.SelfAttention.relative_attention_bias."); } } else if (contains(name, "qwen2vl")) { - for (auto kv : qwenvl_name_map) { - size_t pos = new_name.find(kv.first); - if (pos != std::string::npos) { - new_name.replace(pos, kv.first.size(), kv.second); + if (contains(name, "qwen2vl.visual")) { + for (auto kv : qwenvl_vision_name_map) { + size_t pos = new_name.find(kv.first); + if (pos != std::string::npos) { + new_name.replace(pos, kv.first.size(), kv.second); + } + } + } else { + for (auto kv : qwenvl_name_map) { + size_t pos = new_name.find(kv.first); + if (pos != std::string::npos) { + new_name.replace(pos, kv.first.size(), kv.second); + } } } } else if (name == "text_encoders.t5xxl.transformer.token_embd.weight") { diff --git a/qwen_image.hpp b/qwen_image.hpp index 90357afc..630e5536 100644 --- a/qwen_image.hpp +++ b/qwen_image.hpp @@ -94,12 +94,12 @@ namespace Qwen { blocks["norm_added_q"] = std::shared_ptr(new RMSNorm(dim_head, eps)); blocks["norm_added_k"] = std::shared_ptr(new RMSNorm(dim_head, eps)); - blocks["to_out.0"] = std::shared_ptr(new 
Linear(inner_dim, out_dim, out_bias)); - // to_out.1 is nn.Dropout - float scale = 1.f / 32.f; // The purpose of the scale here is to prevent NaN issues in certain situations. // For example when using CUDA but the weights are k-quants (not all prompts). + blocks["to_out.0"] = std::shared_ptr(new Linear(inner_dim, out_dim, out_bias, false, false, scale)); + // to_out.1 is nn.Dropout + blocks["to_add_out"] = std::shared_ptr(new Linear(inner_dim, out_context_dim, out_bias, false, false, scale)); } @@ -159,7 +159,7 @@ namespace Qwen { auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto attn = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head] + auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head] attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] auto txt_attn_out = ggml_view_3d(ctx, attn, @@ -389,6 +389,13 @@ namespace Qwen { return x; } + struct ggml_tensor* process_img(struct ggml_context* ctx, + struct ggml_tensor* x) { + x = pad_to_patch_size(ctx, x); + x = patchify(ctx, x); + return x; + } + struct ggml_tensor* unpatchify(struct ggml_context* ctx, struct ggml_tensor* x, int64_t h, @@ -449,7 +456,8 @@ namespace Qwen { struct ggml_tensor* x, struct ggml_tensor* timestep, struct ggml_tensor* context, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + std::vector ref_latents = {}) { // Forward pass of DiT. // x: [N, C, H, W] // timestep: [N,] @@ -462,13 +470,26 @@ namespace Qwen { int64_t C = x->ne[2]; int64_t N = x->ne[3]; - x = pad_to_patch_size(ctx, x); - x = patchify(ctx, x); + auto img = process_img(ctx, x); + uint64_t img_tokens = img->ne[1]; + + if (ref_latents.size() > 0) { + for (ggml_tensor* ref : ref_latents) { + ref = process_img(ctx, ref); + img = ggml_concat(ctx, img, ref, 1); + } + } int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size); int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size); - auto out = forward_orig(ctx, backend, x, timestep, context, pe); // [N, h_len*w_len, ph*pw*C] + auto out = forward_orig(ctx, backend, img, timestep, context, pe); // [N, h_len*w_len, ph*pw*C] + + if (out->ne[1] > img_tokens) { + out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size] + out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0); + out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size] + } out = unpatchify(ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w] @@ -495,6 +516,25 @@ namespace Qwen { bool flash_attn = false) : GGMLRunner(backend, offload_params_to_cpu) { qwen_image_params.flash_attn = flash_attn; + qwen_image_params.num_layers = 0; + for (auto pair : tensor_types) { + std::string tensor_name = pair.first; + if (tensor_name.find(prefix) == std::string::npos) + continue; + size_t pos = tensor_name.find("transformer_blocks."); + if (pos != std::string::npos) { + tensor_name = tensor_name.substr(pos); // remove prefix + auto items = split_string(tensor_name, '.'); + if (items.size() > 1) { + int block_index = atoi(items[1].c_str()); + if (block_index + 1 > qwen_image_params.num_layers) { + qwen_image_params.num_layers = block_index + 1; + } + } + 
continue; + } + } + LOG_ERROR("qwen_image_params.num_layers: %ld", qwen_image_params.num_layers); qwen_image = QwenImageModel(qwen_image_params); qwen_image.init(params_ctx, tensor_types, prefix); } @@ -509,7 +549,9 @@ namespace Qwen { struct ggml_cgraph* build_graph(struct ggml_tensor* x, struct ggml_tensor* timesteps, - struct ggml_tensor* context) { + struct ggml_tensor* context, + std::vector ref_latents = {}, + bool increase_ref_index = false) { GGML_ASSERT(x->ne[3] == 1); struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWEN_IMAGE_GRAPH_SIZE, false); @@ -517,18 +559,24 @@ namespace Qwen { context = to_backend(context); timesteps = to_backend(timesteps); + for (int i = 0; i < ref_latents.size(); i++) { + ref_latents[i] = to_backend(ref_latents[i]); + } + pe_vec = Rope::gen_qwen_image_pe(x->ne[1], x->ne[0], qwen_image_params.patch_size, x->ne[3], context->ne[1], + ref_latents, + increase_ref_index, qwen_image_params.theta, qwen_image_params.axes_dim); int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2; // LOG_DEBUG("pos_len %d", pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, qwen_image_params.axes_dim_sum / 2, pos_len); // pe->data = pe_vec.data(); - // print_ggml_tensor(pe); + // print_ggml_tensor(pe, true, "pe"); // pe->data = NULL; set_backend_tensor_data(pe, pe_vec.data()); @@ -537,7 +585,8 @@ namespace Qwen { x, timesteps, context, - pe); + pe, + ref_latents); ggml_build_forward_expand(gf, out); @@ -548,13 +597,15 @@ namespace Qwen { struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { + std::vector ref_latents = {}, + bool increase_ref_index = false, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL) { // x: [N, in_channels, h, w] // timesteps: [N, ] // context: [N, max_position, hidden_size] auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context); + return build_graph(x, timesteps, context, ref_latents, increase_ref_index); }; GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); @@ -586,7 +637,7 @@ namespace Qwen { struct ggml_tensor* out = NULL; int t0 = ggml_time_ms(); - compute(8, x, timesteps, context, &out, work_ctx); + compute(8, x, timesteps, context, {}, false, &out, work_ctx); int t1 = ggml_time_ms(); print_ggml_tensor(out); diff --git a/qwenvl.hpp b/qwenvl.hpp index 228452de..881f54d7 100644 --- a/qwenvl.hpp +++ b/qwenvl.hpp @@ -15,9 +15,11 @@ #include "clip.hpp" #include "ggml_extend.hpp" #include "json.hpp" +#include "rope.hpp" #include "tokenize_util.h" namespace Qwen { + constexpr int QWENVL_GRAPH_SIZE = 10240; class Qwen2Tokenizer { private: @@ -340,9 +342,9 @@ namespace Qwen { struct Qwen2_5_VLMLP : public GGMLBlock { public: Qwen2_5_VLMLP(int64_t hidden_size, int64_t intermediate_size, bool bias = false) { - blocks["gate_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, false)); - blocks["up_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, false)); - blocks["down_proj"] = std::shared_ptr(new Linear(intermediate_size, hidden_size, false)); + blocks["gate_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, bias)); + blocks["up_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, bias)); + blocks["down_proj"] = std::shared_ptr(new Linear(intermediate_size, hidden_size, bias)); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { 
@@ -359,6 +361,288 @@ namespace Qwen { } }; + struct Qwen2_5_VisionPatchEmbed : public GGMLBlock { + protected: + bool llama_cpp_style; + int patch_size; + int temporal_patch_size; + int64_t in_channels; + int64_t embed_dim; + + public: + Qwen2_5_VisionPatchEmbed(bool llama_cpp_style, + int patch_size = 14, + int temporal_patch_size = 2, + int64_t in_channels = 3, + int64_t embed_dim = 1152) + : llama_cpp_style(llama_cpp_style), + patch_size(patch_size), + temporal_patch_size(temporal_patch_size), + in_channels(in_channels), + embed_dim(embed_dim) { + if (llama_cpp_style) { + blocks["proj.0"] = std::shared_ptr(new Conv2d(in_channels, + embed_dim, + {patch_size, patch_size}, + {patch_size, patch_size}, // stride + {0, 0}, // padding + {1, 1}, // dilation + false)); + blocks["proj.1"] = std::shared_ptr(new Conv2d(in_channels, + embed_dim, + {patch_size, patch_size}, + {patch_size, patch_size}, // stride + {0, 0}, // padding + {1, 1}, // dilation + false)); + } else { + std::tuple kernel_size = {(int)temporal_patch_size, (int)patch_size, (int)patch_size}; + blocks["proj"] = std::shared_ptr(new Conv3d(in_channels, + embed_dim, + kernel_size, + kernel_size, // stride + {0, 0, 0}, // padding + {1, 1, 1}, // dilation + false)); + } + } + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + // x: [N*grid_t*grid_h*grid_w, in_channels, temporal_patch_size*patch_size*patch_size] + // return: [N*grid_t*grid_h*grid_w, embed_dim] + x = ggml_reshape_4d(ctx, + x, + patch_size, + patch_size, + temporal_patch_size, + ggml_nelements(x) / (temporal_patch_size * patch_size * patch_size)); + + if (llama_cpp_style) { + auto proj_0 = std::dynamic_pointer_cast(blocks["proj.0"]); + auto proj_1 = std::dynamic_pointer_cast(blocks["proj.1"]); + + auto x0 = ggml_slice(ctx, x, 2, 0, 1); + x0 = ggml_reshape_4d(ctx, x0, x0->ne[0], x0->ne[1], in_channels, x0->ne[3] / in_channels); + x0 = proj_0->forward(ctx, x0); + + auto x1 = ggml_slice(ctx, x, 2, 1, 2); + x1 = ggml_reshape_4d(ctx, x1, x1->ne[0], x1->ne[1], in_channels, x1->ne[3] / in_channels); + x1 = proj_1->forward(ctx, x1); + + x = ggml_add(ctx, x0, x1); + } else { + auto proj = std::dynamic_pointer_cast(blocks["proj"]); + + x = proj->forward(ctx, x); + } + + x = ggml_reshape_2d(ctx, x, embed_dim, ggml_nelements(x) / embed_dim); + return x; + } + }; + + struct Qwen2_5_VLPatchMerger : public GGMLBlock { + protected: + int64_t hidden_size; + + public: + Qwen2_5_VLPatchMerger(int64_t dim, + int64_t context_dim, + int64_t spatial_merge_size) { + hidden_size = context_dim * spatial_merge_size * spatial_merge_size; + blocks["ln_q"] = std::shared_ptr(new RMSNorm(context_dim, 1e-6f)); + blocks["mlp.0"] = std::shared_ptr(new Linear(hidden_size, hidden_size)); + // mlp.1 is nn.GELU() + blocks["mlp.2"] = std::shared_ptr(new Linear(hidden_size, dim)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + auto ln_q = std::dynamic_pointer_cast(blocks["ln_q"]); + auto mlp_0 = std::dynamic_pointer_cast(blocks["mlp.0"]); + auto mlp_2 = std::dynamic_pointer_cast(blocks["mlp.2"]); + + x = ln_q->forward(ctx, x); + x = ggml_reshape_2d(ctx, x, hidden_size, ggml_nelements(x) / hidden_size); + x = mlp_0->forward(ctx, x); + x = ggml_gelu(ctx, x); + x = mlp_2->forward(ctx, x); + return x; + } + }; + + struct Qwen2_5_VLVisionAttention : public GGMLBlock { + protected: + bool llama_cpp_style; + int64_t head_dim; + int64_t num_heads; + + public: + Qwen2_5_VLVisionAttention(bool llama_cpp_style, + int64_t hidden_size, + int64_t 
num_heads) + : llama_cpp_style(llama_cpp_style), num_heads(num_heads) { + head_dim = hidden_size / num_heads; + GGML_ASSERT(num_heads * head_dim == hidden_size); + if (llama_cpp_style) { + blocks["q_proj"] = std::shared_ptr(new Linear(hidden_size, hidden_size)); + blocks["k_proj"] = std::shared_ptr(new Linear(hidden_size, hidden_size)); + blocks["v_proj"] = std::shared_ptr(new Linear(hidden_size, hidden_size)); + } else { + blocks["qkv"] = std::shared_ptr(new Linear(hidden_size, hidden_size * 3)); + } + blocks["proj"] = std::shared_ptr(new Linear(hidden_size, hidden_size)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + ggml_backend_t backend, + struct ggml_tensor* x, + struct ggml_tensor* pe, + struct ggml_tensor* mask = nullptr) { + // x: [N, n_token, hidden_size] + int64_t n_token = x->ne[1]; + int64_t N = x->ne[2]; + auto proj = std::dynamic_pointer_cast(blocks["proj"]); + + std::vector qkv_vec; + if (llama_cpp_style) { + auto q_proj = std::dynamic_pointer_cast(blocks["q_proj"]); + auto k_proj = std::dynamic_pointer_cast(blocks["k_proj"]); + auto v_proj = std::dynamic_pointer_cast(blocks["v_proj"]); + + auto q = q_proj->forward(ctx, x); + auto k = k_proj->forward(ctx, x); + auto v = v_proj->forward(ctx, x); + + qkv_vec = {q, k, v}; + } else { + auto qkv_proj = std::dynamic_pointer_cast(blocks["qkv"]); + auto qkv = qkv_proj->forward(ctx, x); + qkv_vec = split_qkv(ctx, qkv); + } + + auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head] + auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head] + auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] + + x = Rope::attention(ctx, backend, q, k, v, pe, mask, false, 1.f, false); // [N, n_token, hidden_size] + + x = proj->forward(ctx, x); // [N, n_token, hidden_size] + return x; + } + }; + + struct Qwen2_5_VLVisionBlock : public GGMLBlock { + public: + Qwen2_5_VLVisionBlock(bool llama_cpp_style, + int64_t hidden_size, + int64_t intermediate_size, + int64_t num_heads, + float eps = 1e-6f) { + blocks["attn"] = std::shared_ptr(new Qwen2_5_VLVisionAttention(llama_cpp_style, hidden_size, num_heads)); + blocks["mlp"] = std::shared_ptr(new Qwen2_5_VLMLP(hidden_size, intermediate_size, true)); + blocks["norm1"] = std::shared_ptr(new RMSNorm(hidden_size, eps)); + blocks["norm2"] = std::shared_ptr(new RMSNorm(hidden_size, eps)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + ggml_backend_t backend, + struct ggml_tensor* x, + struct ggml_tensor* pe, + struct ggml_tensor* mask = nullptr) { + // x: [N, n_token, hidden_size] + auto attn = std::dynamic_pointer_cast(blocks["attn"]); + auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); + auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); + auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); + + auto residual = x; + x = norm1->forward(ctx, x); + x = attn->forward(ctx, backend, x, pe, mask); + x = ggml_add_inplace(ctx, x, residual); + + residual = x; + x = norm2->forward(ctx, x); + x = mlp->forward(ctx, x); + x = ggml_add_inplace(ctx, x, residual); + + return x; + } + }; + + struct Qwen2_5_VLVisionModel : public GGMLBlock { + protected: + int64_t num_layers; + int64_t spatial_merge_size; + std::set fullatt_block_indexes; + + public: + Qwen2_5_VLVisionModel(bool llama_cpp_style, + int64_t num_layers, + int64_t in_channels, + int64_t 
hidden_size, + int64_t out_hidden_size, + int64_t intermediate_size, + int64_t num_heads, + int64_t spatial_merge_size, + int64_t patch_size, + int64_t temporal_patch_size, + int64_t window_size, + std::set fullatt_block_indexes = {7, 15, 23, 31}, + float eps = 1e-6f) + : num_layers(num_layers), fullatt_block_indexes(fullatt_block_indexes), spatial_merge_size(spatial_merge_size) { + blocks["patch_embed"] = std::shared_ptr(new Qwen2_5_VisionPatchEmbed(llama_cpp_style, + patch_size, + temporal_patch_size, + in_channels, + hidden_size)); + for (int i = 0; i < num_layers; i++) { + blocks["blocks." + std::to_string(i)] = std::shared_ptr(new Qwen2_5_VLVisionBlock(llama_cpp_style, + hidden_size, + intermediate_size, + num_heads, + eps)); + } + blocks["merger"] = std::shared_ptr(new Qwen2_5_VLPatchMerger(out_hidden_size, hidden_size, spatial_merge_size)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + ggml_backend_t backend, + struct ggml_tensor* pixel_values, + struct ggml_tensor* pe, + struct ggml_tensor* window_index, + struct ggml_tensor* window_inverse_index, + struct ggml_tensor* window_mask) { + // pixel_values: [grid_t*(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw] + // window_index: [grid_t*(H/mh/ph)*(W/mw/pw)] + // window_inverse_index: [grid_t*(H/mh/ph)*(W/mw/pw)] + // window_mask: [grid_h*grid_w, grid_h*grid_w] + auto patch_embed = std::dynamic_pointer_cast(blocks["patch_embed"]); + auto merger = std::dynamic_pointer_cast(blocks["merger"]); + + auto x = patch_embed->forward(ctx, pixel_values); + + x = ggml_reshape_4d(ctx, x, x->ne[0] * spatial_merge_size * spatial_merge_size, x->ne[1] / spatial_merge_size / spatial_merge_size, x->ne[2], x->ne[3]); + x = ggml_get_rows(ctx, x, window_index); + x = ggml_reshape_4d(ctx, x, x->ne[0] / spatial_merge_size / spatial_merge_size, x->ne[1] * spatial_merge_size * spatial_merge_size, x->ne[2], x->ne[3]); + + for (int i = 0; i < num_layers; i++) { + auto block = std::dynamic_pointer_cast(blocks["blocks." 
+ std::to_string(i)]); + + auto mask = window_mask; + if (fullatt_block_indexes.find(i) != fullatt_block_indexes.end()) { + mask = nullptr; + } + x = block->forward(ctx, backend, x, pe, mask); + } + + x = merger->forward(ctx, x); + + x = ggml_get_rows(ctx, x, window_inverse_index); + + return x; + } + }; + struct Qwen2_5_VLAttention : public GGMLBlock { protected: int64_t head_dim; @@ -478,7 +762,8 @@ namespace Qwen { struct ggml_tensor* forward(struct ggml_context* ctx, ggml_backend_t backend, struct ggml_tensor* input_ids, - struct ggml_tensor* input_pos) { + struct ggml_tensor* input_pos, + std::vector> image_embeds) { // input_ids: [N, n_token] // return: [N, n_token, hidden_size] @@ -487,6 +772,45 @@ namespace Qwen { auto x = embed_tokens->forward(ctx, input_ids); + if (image_embeds.size() > 0) { + GGML_ASSERT(x->ne[2] == 1); // N == 1 + + auto raw_x = ggml_cast(ctx, x, image_embeds[0].second->type); + int64_t txt_token_start = 0; + int64_t txt_token_end = 0; + + ggml_tensor* input_embed = nullptr; + + for (int i = 0; i < image_embeds.size(); i++) { + if (i == 0) { + txt_token_start = 0; + } else { + txt_token_start = image_embeds[i - 1].first + image_embeds[i - 1].second->ne[1]; + } + txt_token_end = image_embeds[i].first; + + auto txt_embed = ggml_slice(ctx, raw_x, 1, txt_token_start, txt_token_end); + if (input_embed == nullptr) { + input_embed = txt_embed; + } else { + input_embed = ggml_concat(ctx, input_embed, txt_embed, 1); + } + + auto image_embed = image_embeds[i].second; + input_embed = ggml_concat(ctx, input_embed, image_embed, 1); + } + + txt_token_start = image_embeds[image_embeds.size() - 1].first + image_embeds[image_embeds.size() - 1].second->ne[1]; + txt_token_end = raw_x->ne[1]; + + auto final_txt_embed = ggml_slice(ctx, raw_x, 1, txt_token_start, txt_token_end); + + input_embed = ggml_concat(ctx, input_embed, final_txt_embed, 1); + GGML_ASSERT(raw_x->ne[1] == input_embed->ne[1]); + + x = input_embed; + } + for (int i = 0; i < num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["layers." 
+ std::to_string(i)]); @@ -498,6 +822,20 @@ namespace Qwen { } }; + struct Qwen2_5_VLVisionParams { + int64_t num_layers = 32; + int64_t hidden_size = 1280; + int64_t intermediate_size = 3420; + int64_t num_heads = 16; + int64_t in_channels = 3; + int64_t out_hidden_size = 3584; + int64_t temporal_patch_size = 2; + int64_t patch_size = 14; + int64_t spatial_merge_size = 2; + int64_t window_size = 112; + std::set fullatt_block_indexes = {7, 15, 23, 31}; + }; + struct Qwen2_5_VLParams { int64_t num_layers = 28; int64_t hidden_size = 3584; @@ -506,15 +844,17 @@ namespace Qwen { int64_t num_kv_heads = 4; int64_t vocab_size = 152064; float rms_norm_eps = 1e-06f; + Qwen2_5_VLVisionParams vision; }; struct Qwen2_5_VL : public GGMLBlock { + bool enable_vision; Qwen2_5_VLParams params; public: Qwen2_5_VL() {} - Qwen2_5_VL(Qwen2_5_VLParams params) - : params(params) { + Qwen2_5_VL(Qwen2_5_VLParams params, bool enable_vision = false, bool llama_cpp_style = false) + : enable_vision(enable_vision), params(params) { blocks["model"] = std::shared_ptr(new Qwen2_5_VLTextModel(params.num_layers, params.vocab_size, params.hidden_size, @@ -522,32 +862,90 @@ namespace Qwen { params.num_heads, params.num_kv_heads, params.rms_norm_eps)); + if (enable_vision) { + blocks["visual"] = std::shared_ptr(new Qwen2_5_VLVisionModel(llama_cpp_style, + params.vision.num_layers, + params.vision.in_channels, + params.vision.hidden_size, + params.vision.out_hidden_size, + params.vision.intermediate_size, + params.vision.num_heads, + params.vision.spatial_merge_size, + params.vision.patch_size, + params.vision.temporal_patch_size, + params.vision.window_size, + params.vision.fullatt_block_indexes)); + } } struct ggml_tensor* forward(struct ggml_context* ctx, ggml_backend_t backend, struct ggml_tensor* input_ids, - struct ggml_tensor* input_pos) { + struct ggml_tensor* input_pos, + std::vector> image_embeds) { // input_ids: [N, n_token] auto model = std::dynamic_pointer_cast(blocks["model"]); - auto x = model->forward(ctx, backend, input_ids, input_pos); + auto x = model->forward(ctx, backend, input_ids, input_pos, image_embeds); return x; } + + struct ggml_tensor* vision_forward(struct ggml_context* ctx, + ggml_backend_t backend, + struct ggml_tensor* pixel_values, + struct ggml_tensor* pe, + struct ggml_tensor* window_index, + struct ggml_tensor* window_inverse_index, + struct ggml_tensor* window_mask) { + GGML_ASSERT(enable_vision); + auto vision_model = std::dynamic_pointer_cast(blocks["visual"]); + return vision_model->forward(ctx, backend, pixel_values, pe, window_index, window_inverse_index, window_mask); + } }; struct Qwen2_5_VLRunner : public GGMLRunner { Qwen2_5_VLParams params; + bool enable_vision; Qwen2_5_VL model; std::vector input_pos_vec; + std::vector window_mask_vec; + std::vector window_index_vec; + std::vector window_inverse_index_vec; + std::vector pe_vec; Qwen2_5_VLRunner(ggml_backend_t backend, bool offload_params_to_cpu, const String2GGMLType& tensor_types, - const std::string prefix) - : GGMLRunner(backend, offload_params_to_cpu) { - model = Qwen2_5_VL(params); + const std::string prefix, + bool enable_vision_ = false) + : GGMLRunner(backend, offload_params_to_cpu), enable_vision(enable_vision_) { + bool have_vision_weight = false; + bool llama_cpp_style = false; + for (auto pair : tensor_types) { + std::string tensor_name = pair.first; + if (tensor_name.find(prefix) == std::string::npos) + continue; + size_t pos = tensor_name.find("visual."); + if (pos != std::string::npos) { + have_vision_weight = 
true; + if (contains(tensor_name, "attn.q_proj")) { + llama_cpp_style = true; + break; + } + } + } + if (enable_vision && !have_vision_weight) { + LOG_WARN("no vision weights detected, vision disabled"); + enable_vision = false; + } + if (enable_vision) { + LOG_DEBUG("enable qwen2vl vision"); + if (llama_cpp_style) { + LOG_DEBUG("llama.cpp style vision weight"); + } + } + model = Qwen2_5_VL(params, enable_vision, llama_cpp_style); model.init(params_ctx, tensor_types, prefix); } @@ -562,16 +960,32 @@ namespace Qwen { struct ggml_tensor* forward(struct ggml_context* ctx, ggml_backend_t backend, struct ggml_tensor* input_ids, - struct ggml_tensor* input_pos) { - auto hidden_states = model.forward(ctx, backend, input_ids, input_pos); // [N, n_token, hidden_size] + struct ggml_tensor* input_pos, + std::vector> image_embeds) { + auto hidden_states = model.forward(ctx, backend, input_ids, input_pos, image_embeds); // [N, n_token, hidden_size] return hidden_states; } - struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids) { + struct ggml_tensor* vision_forward(struct ggml_context* ctx, + ggml_backend_t backend, + struct ggml_tensor* pixel_values, + struct ggml_tensor* input_pos, + struct ggml_tensor* window_index, + struct ggml_tensor* window_inverse_index, + struct ggml_tensor* window_mask) { + auto hidden_states = model.vision_forward(ctx, backend, pixel_values, input_pos, window_index, window_inverse_index, window_mask); + return hidden_states; + } + + struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids, std::vector> image_embeds) { struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); input_ids = to_backend(input_ids); + for (auto& image_embed : image_embeds) { + image_embed.second = to_backend(image_embed.second); + } + int64_t n_tokens = input_ids->ne[0]; input_pos_vec.resize(n_tokens * 4); for (int i = 0; i < n_tokens; ++i) { @@ -586,7 +1000,7 @@ namespace Qwen { n_tokens * 4); set_backend_tensor_data(input_pos, input_pos_vec.data()); - struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, input_pos); + struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, input_pos, image_embeds); ggml_build_forward_expand(gf, hidden_states); @@ -595,13 +1009,183 @@ namespace Qwen { void compute(const int n_threads, struct ggml_tensor* input_ids, + std::vector> image_embeds, ggml_tensor** output, ggml_context* output_ctx = NULL) { auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(input_ids); + return build_graph(input_ids, image_embeds); }; GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); } + + int64_t get_num_image_tokens(int64_t t, int64_t h, int64_t w) { + int grid_t = 1; + int grid_h = h / params.vision.patch_size; + int grid_w = w / params.vision.patch_size; + int llm_grid_h = grid_h / params.vision.spatial_merge_size; + int llm_grid_w = grid_w / params.vision.spatial_merge_size; + return grid_t * grid_h * grid_w; + } + + struct ggml_tensor* process_image(struct ggml_context* ctx, struct ggml_tensor* image) { + // image: [C, H, W] + // return: [grid_t*(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw], grid_t == 1 + int64_t C = image->ne[2]; + int64_t H = image->ne[1]; + int64_t W = image->ne[0]; + int64_t mh = params.vision.spatial_merge_size; + int64_t mw = params.vision.spatial_merge_size; + int64_t pt = params.vision.temporal_patch_size; + int64_t ph = params.vision.patch_size; + int64_t pw = params.vision.patch_size; + + image = ggml_reshape_4d(ctx, image, pw, mw, (W / mw / pw), H * C); // 
[C*H, (W/mw/pw), mw, pw] + image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 3, 1)); // [mw, C*H, (W/mw/pw), pw] + image = ggml_reshape_4d(ctx, image, pw * (W / mw / pw), H, C, mw); // [mw, C, H, (W/mw/pw)*pw] + image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 3, 1)); // [H, mw, C, (W/mw/pw)*pw] + image = ggml_reshape_4d(ctx, image, pw, (W / mw / pw) * C * mw, ph, mh * (H / mh / ph)); // [(H/mh/ph)*mh, ph, mw*C*(W/mw/pw), pw] + image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph)*mh, mw*C*(W/mw/pw), ph, pw] + image = ggml_reshape_4d(ctx, image, pw * ph, (W / mw / pw), C, mw * mh * (H / mh / ph)); // [(H/mh/ph)*mh*mw, C, (W/mw/pw), ph*pw] + image = ggml_concat(ctx, image, image, 0); // [(H/mh/ph)*mh*mw, C, (W/mw/pw), pt*ph*pw] + image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph)*mh*mw, (W/mw/pw), C, pt*ph*pw] + image = ggml_reshape_4d(ctx, image, pw * ph * pt * C, (W / mw / pw), mw * mh, (H / mh / ph)); // [(H/mh/ph), mh*mw, (W/mw/pw), C*pt*ph*pw] + image = ggml_cont(ctx, ggml_torch_permute(ctx, image, 0, 2, 1, 3)); // [(H/mh/ph), (W/mw/pw), mh*mw, C*pt*ph*pw] + image = ggml_reshape_2d(ctx, image, pw * ph * pt * C, mw * mh * (W / mw / pw) * (H / mh / ph)); // [(H/mh/ph)*(W/mw/pw)*mh*mw, C*pt*ph*pw] + return image; + } + + struct ggml_cgraph* build_encode_image_graph(struct ggml_tensor* image) { + struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWENVL_GRAPH_SIZE, false); + + GGML_ASSERT(image->ne[1] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0); + GGML_ASSERT(image->ne[0] % (params.vision.patch_size * params.vision.spatial_merge_size) == 0); + + int grid_t = 1; + int grid_h = image->ne[1] / params.vision.patch_size; + int grid_w = image->ne[0] / params.vision.patch_size; + int llm_grid_h = grid_h / params.vision.spatial_merge_size; + int llm_grid_w = grid_w / params.vision.spatial_merge_size; + int vit_merger_window_size = params.vision.window_size / params.vision.patch_size / params.vision.spatial_merge_size; + + image = to_backend(image); + + auto pixel_values = process_image(compute_ctx, image); + + // window index + int inverse_index = 0; + window_index_vec.resize(llm_grid_h * llm_grid_w); + window_inverse_index_vec.resize(llm_grid_h * llm_grid_w); + std::vector seqlens; + for (int ih = 0; ih < llm_grid_h; ih += vit_merger_window_size) { + for (int iw = 0; iw < llm_grid_w; iw += vit_merger_window_size) { + int win_h = std::min(vit_merger_window_size, llm_grid_h - ih); + int win_w = std::min(vit_merger_window_size, llm_grid_w - iw); + for (int iy = 0; iy < win_h; iy++) { + for (int ix = 0; ix < win_w; ix++) { + int index = (ih + iy) * llm_grid_w + iw + ix; + window_index_vec[inverse_index] = index; + window_inverse_index_vec[index] = inverse_index; + inverse_index++; + } + } + seqlens.push_back(win_h * win_w * params.vision.spatial_merge_size * params.vision.spatial_merge_size); + } + } + // printf("window_index: "); + // for (int i : window_index_vec) { + // printf("%d ", i); + // } + // printf("\n"); + // printf("window_inverse_index: "); + // for (int i : window_inverse_index_vec) { + // printf("%d ", i); + // } + // printf("\n"); + // printf("seqlens: "); + // for (int i : seqlens) { + // printf("%d ", i); + // } + // printf("\n"); + auto window_index = ggml_new_tensor_1d(compute_ctx, + GGML_TYPE_I32, + llm_grid_h * llm_grid_w); + auto window_inverse_index = ggml_new_tensor_1d(compute_ctx, + GGML_TYPE_I32, + llm_grid_h * llm_grid_w); + 
set_backend_tensor_data(window_index, window_index_vec.data()); + set_backend_tensor_data(window_inverse_index, window_inverse_index_vec.data()); + + // window mask + int seq_window_size = (vit_merger_window_size * params.vision.spatial_merge_size) * (vit_merger_window_size * params.vision.spatial_merge_size); + window_mask_vec.resize((grid_h * grid_w) * (grid_h * grid_w)); + int window_start_index = 0; + for (int seq_index = 0; seq_index < seqlens.size(); seq_index++) { + int window_end_index = window_start_index + seqlens[seq_index]; + // LOG_DEBUG("%d %d", window_start_index, window_end_index); + GGML_ASSERT(window_end_index <= grid_h * grid_w); + for (int i = window_start_index; i < window_end_index; i++) { + for (int j = 0; j < grid_h * grid_w; j++) { + float mask_value = -INFINITY; + if (j >= window_start_index && j < window_end_index) { + mask_value = 0; + } + GGML_ASSERT((i * (grid_h * grid_w) + j) < window_mask_vec.size()); + window_mask_vec[i * (grid_h * grid_w) + j] = mask_value; + } + } + window_start_index = window_end_index; + // printf("\n"); + } + // printf("window_mask: \n"); + // for (int i = 0; i < grid_h*grid_w; i++) { + // for (int j = 0; j < grid_h*grid_w; j++) { + // printf("%f ", window_mask_vec[i * (grid_h * grid_w) + j]); + // } + // printf("\n"); + // } + auto window_mask = ggml_new_tensor_2d(compute_ctx, + GGML_TYPE_F32, + grid_h * grid_w, + grid_h * grid_w); + set_backend_tensor_data(window_mask, window_mask_vec.data()); + + // pe + int head_dim = params.vision.hidden_size / params.vision.num_heads; + pe_vec = Rope::gen_qwen2vl_pe(grid_h, + grid_w, + params.vision.spatial_merge_size, + window_inverse_index_vec, + 10000.f, + {head_dim / 2, head_dim / 2}); + int pos_len = pe_vec.size() / head_dim / 2; + // LOG_DEBUG("pos_len %d", pos_len); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, head_dim / 2, pos_len); + // pe->data = pe_vec.data(); + // print_ggml_tensor(pe); + // pe->data = NULL; + set_backend_tensor_data(pe, pe_vec.data()); + + struct ggml_tensor* hidden_states = vision_forward(compute_ctx, + runtime_backend, + pixel_values, + pe, + window_index, + window_inverse_index, + window_mask); + ggml_build_forward_expand(gf, hidden_states); + + return gf; + } + + void encode_image(const int n_threads, + struct ggml_tensor* image, + ggml_tensor** output, + ggml_context* output_ctx = NULL) { + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_encode_image_graph(image); + }; + GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + } }; struct Qwen2_5_VLEmbedder { @@ -611,8 +1195,9 @@ namespace Qwen { Qwen2_5_VLEmbedder(ggml_backend_t backend, bool offload_params_to_cpu, const String2GGMLType& tensor_types = {}, - const std::string prefix = "") - : model(backend, offload_params_to_cpu, tensor_types, prefix) { + const std::string prefix = "", + bool enable_vision = false) + : model(backend, offload_params_to_cpu, tensor_types, prefix, enable_vision) { } void get_param_tensors(std::map& tensors, const std::string prefix) { @@ -666,8 +1251,76 @@ namespace Qwen { struct ggml_context* work_ctx = ggml_init(params); GGML_ASSERT(work_ctx != NULL); + bool test_vit = true; + bool test_decoder_with_vit = true; + + if (test_decoder_with_vit) { + ggml_tensor* image_embed = nullptr; + { + auto image = load_tensor_from_file(work_ctx, "qwen2vl_normalized.bin"); + print_ggml_tensor(image, false, "image"); + struct ggml_tensor* out = NULL; + + int t0 = ggml_time_ms(); + model.encode_image(8, image, &out, work_ctx); + int t1 = 
ggml_time_ms(); + + print_ggml_tensor(out, false, "image_embed"); + image_embed = out; + LOG_DEBUG("qwen2vl encode_image test done in %dms", t1 - t0); + } - { + std::string placeholder = "<|image_pad|>"; + std::string img_prompt = "Picture 1: <|vision_start|>"; // [24669, 220, 16, 25, 220, 151652] + int64_t num_image_tokens = image_embed->ne[1]; + img_prompt.reserve(num_image_tokens * placeholder.size()); + for (int i = 0; i < num_image_tokens; i++) { + img_prompt += placeholder; + } + img_prompt += "<|vision_end|>"; + + std::vector> image_embeds; + image_embeds.emplace_back(64, image_embed); + + std::string text = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n"; + text += img_prompt; + text += "change 'flux.cpp' to 'edit.cpp'"; + text += "<|im_end|>\n<|im_start|>assistant\n"; + + auto tokens_and_weights = tokenize(text, 0, false); + std::vector& tokens = std::get<0>(tokens_and_weights); + std::vector& weights = std::get<1>(tokens_and_weights); + for (auto token : tokens) { + printf("%d ", token); + } + printf("\n"); + auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens); + struct ggml_tensor* out = NULL; + + int t0 = ggml_time_ms(); + model.compute(8, input_ids, image_embeds, &out, work_ctx); + int t1 = ggml_time_ms(); + + print_ggml_tensor(out); + LOG_DEBUG("qwen2vl test done in %dms", t1 - t0); + } else if (test_vit) { + // auto image = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 280, 280, 3); + // ggml_set_f32(image, 0.f); + auto image = load_tensor_from_file(work_ctx, "qwen2vl_normalized.bin"); + print_ggml_tensor(image, false, "image"); + struct ggml_tensor* out = NULL; + + int t0 = ggml_time_ms(); + model.encode_image(8, image, &out, work_ctx); + int t1 = ggml_time_ms(); + + print_ggml_tensor(out, false, "out"); + + // auto ref_out = load_tensor_from_file(work_ctx, "qwen2vl.bin"); + // ggml_tensor_diff(ref_out, out, 0.01f); + + LOG_DEBUG("qwen2vl test done in %dms", t1 - t0); + } else { std::string text("<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\na lovely cat<|im_end|>\n<|im_start|>assistant\n"); auto tokens_and_weights = tokenize(text, 0, false); std::vector& tokens = std::get<0>(tokens_and_weights); @@ -680,7 +1333,7 @@ namespace Qwen { struct ggml_tensor* out = NULL; int t0 = ggml_time_ms(); - model.compute(8, input_ids, &out, work_ctx); + model.compute(8, input_ids, {}, &out, work_ctx); int t1 = ggml_time_ms(); print_ggml_tensor(out); @@ -692,7 +1345,7 @@ namespace Qwen { // cpu f16: pass // ggml_backend_t backend = ggml_backend_cuda_init(0); ggml_backend_t backend = ggml_backend_cpu_init(); - ggml_type model_data_type = GGML_TYPE_Q8_0; + ggml_type model_data_type = GGML_TYPE_F16; ModelLoader model_loader; if (!model_loader.init_from_file(file_path, "qwen2vl.")) { @@ -708,7 +1361,11 @@ namespace Qwen { } } - std::shared_ptr qwenvl = std::shared_ptr(new Qwen2_5_VLEmbedder(backend, false, tensor_types, "qwen2vl")); + std::shared_ptr qwenvl = std::shared_ptr(new Qwen2_5_VLEmbedder(backend, + false, + tensor_types, + "qwen2vl", + true)); qwenvl->alloc_params_buffer(); std::map tensors; diff --git a/rope.hpp b/rope.hpp index 
5e3aaf93..295c9a21 100644 --- a/rope.hpp +++ b/rope.hpp @@ -4,9 +4,9 @@ #include #include "ggml_extend.hpp" -struct Rope { +namespace Rope { template - static std::vector linspace(T start, T end, int num) { + __STATIC_INLINE__ std::vector linspace(T start, T end, int num) { std::vector result(num); if (num == 1) { result[0] = start; @@ -19,7 +19,7 @@ struct Rope { return result; } - static std::vector> transpose(const std::vector>& mat) { + __STATIC_INLINE__ std::vector> transpose(const std::vector>& mat) { int rows = mat.size(); int cols = mat[0].size(); std::vector> transposed(cols, std::vector(rows)); @@ -31,7 +31,7 @@ struct Rope { return transposed; } - static std::vector flatten(const std::vector>& vec) { + __STATIC_INLINE__ std::vector flatten(const std::vector>& vec) { std::vector flat_vec; for (const auto& sub_vec : vec) { flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end()); @@ -39,7 +39,7 @@ struct Rope { return flat_vec; } - static std::vector> rope(const std::vector& pos, int dim, int theta) { + __STATIC_INLINE__ std::vector> rope(const std::vector& pos, int dim, int theta) { assert(dim % 2 == 0); int half_dim = dim / 2; @@ -72,11 +72,11 @@ struct Rope { } // Generate IDs for image patches and text - static std::vector> gen_txt_ids(int bs, int context_len) { + __STATIC_INLINE__ std::vector> gen_txt_ids(int bs, int context_len) { return std::vector>(bs * context_len, std::vector(3, 0.0)); } - static std::vector> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) { + __STATIC_INLINE__ std::vector> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) { int h_len = (h + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size; @@ -102,9 +102,9 @@ struct Rope { return img_ids_repeated; } - static std::vector> concat_ids(const std::vector>& a, - const std::vector>& b, - int bs) { + __STATIC_INLINE__ std::vector> concat_ids(const std::vector>& a, + const std::vector>& b, + int bs) { size_t a_len = a.size() / bs; size_t b_len = b.size() / bs; std::vector> ids(a.size() + b.size(), std::vector(3)); @@ -119,10 +119,10 @@ struct Rope { return ids; } - static std::vector embed_nd(const std::vector>& ids, - int bs, - int theta, - const std::vector& axes_dim) { + __STATIC_INLINE__ std::vector embed_nd(const std::vector>& ids, + int bs, + int theta, + const std::vector& axes_dim) { std::vector> trans_ids = transpose(ids); size_t pos_len = ids.size() / bs; int num_axes = axes_dim.size(); @@ -151,17 +151,11 @@ struct Rope { return flatten(emb); } - static std::vector> gen_flux_ids(int h, - int w, - int patch_size, - int bs, - int context_len, - std::vector ref_latents, - bool increase_ref_index) { - auto txt_ids = gen_txt_ids(bs, context_len); - auto img_ids = gen_img_ids(h, w, patch_size, bs); - - auto ids = concat_ids(txt_ids, img_ids, bs); + __STATIC_INLINE__ std::vector> gen_refs_ids(int patch_size, + int bs, + const std::vector& ref_latents, + bool increase_ref_index) { + std::vector> ids; uint64_t curr_h_offset = 0; uint64_t curr_w_offset = 0; int index = 1; @@ -189,25 +183,45 @@ struct Rope { return ids; } + __STATIC_INLINE__ std::vector> gen_flux_ids(int h, + int w, + int patch_size, + int bs, + int context_len, + const std::vector& ref_latents, + bool increase_ref_index) { + auto txt_ids = gen_txt_ids(bs, context_len); + auto img_ids = gen_img_ids(h, w, patch_size, bs); + + auto ids = concat_ids(txt_ids, img_ids, bs); + if (ref_latents.size() > 0) { + auto 
refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index); + ids = concat_ids(ids, refs_ids, bs); + } + return ids; + } + // Generate flux positional embeddings - static std::vector gen_flux_pe(int h, - int w, - int patch_size, - int bs, - int context_len, - std::vector ref_latents, - bool increase_ref_index, - int theta, - const std::vector& axes_dim) { + __STATIC_INLINE__ std::vector gen_flux_pe(int h, + int w, + int patch_size, + int bs, + int context_len, + const std::vector& ref_latents, + bool increase_ref_index, + int theta, + const std::vector& axes_dim) { std::vector> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index); return embed_nd(ids, bs, theta, axes_dim); } - static std::vector> gen_qwen_image_ids(int h, - int w, - int patch_size, - int bs, - int context_len) { + __STATIC_INLINE__ std::vector> gen_qwen_image_ids(int h, + int w, + int patch_size, + int bs, + int context_len, + const std::vector& ref_latents, + bool increase_ref_index) { int h_len = (h + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size; int txt_id_start = std::max(h_len, w_len); @@ -220,31 +234,37 @@ struct Rope { } auto img_ids = gen_img_ids(h, w, patch_size, bs); auto ids = concat_ids(txt_ids_repeated, img_ids, bs); + if (ref_latents.size() > 0) { + auto refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index); + ids = concat_ids(ids, refs_ids, bs); + } return ids; } // Generate qwen_image positional embeddings - static std::vector gen_qwen_image_pe(int h, - int w, - int patch_size, - int bs, - int context_len, - int theta, - const std::vector& axes_dim) { - std::vector> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len); + __STATIC_INLINE__ std::vector gen_qwen_image_pe(int h, + int w, + int patch_size, + int bs, + int context_len, + const std::vector& ref_latents, + bool increase_ref_index, + int theta, + const std::vector& axes_dim) { + std::vector> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index); return embed_nd(ids, bs, theta, axes_dim); } - static std::vector> gen_vid_ids(int t, - int h, - int w, - int pt, - int ph, - int pw, - int bs, - int t_offset = 0, - int h_offset = 0, - int w_offset = 0) { + __STATIC_INLINE__ std::vector> gen_vid_ids(int t, + int h, + int w, + int pt, + int ph, + int pw, + int bs, + int t_offset = 0, + int h_offset = 0, + int w_offset = 0) { int t_len = (t + (pt / 2)) / pt; int h_len = (h + (ph / 2)) / ph; int w_len = (w + (pw / 2)) / pw; @@ -276,18 +296,115 @@ struct Rope { } // Generate wan positional embeddings - static std::vector gen_wan_pe(int t, - int h, - int w, - int pt, - int ph, - int pw, - int bs, - int theta, - const std::vector& axes_dim) { + __STATIC_INLINE__ std::vector gen_wan_pe(int t, + int h, + int w, + int pt, + int ph, + int pw, + int bs, + int theta, + const std::vector& axes_dim) { std::vector> ids = gen_vid_ids(t, h, w, pt, ph, pw, bs); return embed_nd(ids, bs, theta, axes_dim); } -}; // struct Rope + + __STATIC_INLINE__ std::vector> gen_qwen2vl_ids(int grid_h, + int grid_w, + int merge_size, + const std::vector& window_index) { + std::vector> ids(grid_h * grid_w, std::vector(2, 0.0)); + int index = 0; + for (int ih = 0; ih < grid_h; ih += merge_size) { + for (int iw = 0; iw < grid_w; iw += merge_size) { + for (int iy = 0; iy < merge_size; iy++) { + for (int ix = 0; ix < merge_size; ix++) { + int inverse_index = window_index[index / (merge_size * merge_size)]; + int i = inverse_index * (merge_size * 
merge_size) + index % (merge_size * merge_size); + + GGML_ASSERT(i < grid_h * grid_w); + + ids[i][0] = ih + iy; + ids[i][1] = iw + ix; + index++; + } + } + } + } + return ids; + } + + // Generate qwen2vl positional embeddings + __STATIC_INLINE__ std::vector gen_qwen2vl_pe(int grid_h, + int grid_w, + int merge_size, + const std::vector& window_index, + int theta, + const std::vector& axes_dim) { + std::vector> ids = gen_qwen2vl_ids(grid_h, grid_w, merge_size, window_index); + return embed_nd(ids, 1, theta, axes_dim); + } + + __STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* pe, + bool rope_interleaved = true) { + // x: [N, L, n_head, d_head] + // pe: [L, d_head/2, 2, 2], [[cos, -sin], [sin, cos]] + int64_t d_head = x->ne[0]; + int64_t n_head = x->ne[1]; + int64_t L = x->ne[2]; + int64_t N = x->ne[3]; + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, n_head, L, d_head] + if (rope_interleaved) { + x = ggml_reshape_4d(ctx, x, 2, d_head / 2, L, n_head * N); // [N * n_head, L, d_head/2, 2] + x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [2, N * n_head, L, d_head/2] + } else { + x = ggml_reshape_4d(ctx, x, d_head / 2, 2, L, n_head * N); // [N * n_head, L, 2, d_head/2] + x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 0, 2, 3, 1)); // [2, N * n_head, L, d_head/2] + } + + int64_t offset = x->nb[2] * x->ne[2]; + auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0); // [N * n_head, L, d_head/2] + auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 1); // [N * n_head, L, d_head/2] + x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]); // [N * n_head, L, d_head/2, 1] + x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]); // [N * n_head, L, d_head/2, 1] + auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]); + x_0 = ggml_repeat(ctx, x_0, temp_x); // [N * n_head, L, d_head/2, 2] + x_1 = ggml_repeat(ctx, x_1, temp_x); // [N * n_head, L, d_head/2, 2] + + pe = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2)); // [2, L, d_head/2, 2] + offset = pe->nb[2] * pe->ne[2]; + auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0); // [L, d_head/2, 2] + auto pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 1); // [L, d_head/2, 2] + + auto x_out = ggml_add_inplace(ctx, ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1)); // [N * n_head, L, d_head/2, 2] + if (!rope_interleaved) { + x_out = ggml_cont(ctx, ggml_permute(ctx, x_out, 1, 0, 2, 3)); // [N * n_head, L, x, d_head/2] + } + x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head * N); // [N*n_head, L, d_head] + return x_out; + } + + __STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx, + ggml_backend_t backend, + struct ggml_tensor* q, + struct ggml_tensor* k, + struct ggml_tensor* v, + struct ggml_tensor* pe, + struct ggml_tensor* mask, + bool flash_attn, + float kv_scale = 1.0f, + bool rope_interleaved = true) { + // q,k,v: [N, L, n_head, d_head] + // pe: [L, d_head/2, 2, 2] + // return: [N, L, n_head*d_head] + q = apply_rope(ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head] + k = apply_rope(ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head] + + auto x = ggml_nn_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head] + return x; + } +}; // 
namespace Rope #endif // __ROPE_HPP__ diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 51f8cbe0..9667ba15 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -261,6 +261,13 @@ class StableDiffusionGGML { } } + if (strlen(SAFE_STR(sd_ctx_params->qwen2vl_vision_path)) > 0) { + LOG_INFO("loading qwen2vl vision from '%s'", sd_ctx_params->qwen2vl_vision_path); + if (!model_loader.init_from_file(sd_ctx_params->qwen2vl_vision_path, "text_encoders.qwen2vl.visual.")) { + LOG_WARN("loading qwen2vl vision from '%s' failed", sd_ctx_params->qwen2vl_vision_path); + } + } + if (strlen(SAFE_STR(sd_ctx_params->vae_path)) > 0) { LOG_INFO("loading vae from '%s'", sd_ctx_params->vae_path); if (!model_loader.init_from_file(sd_ctx_params->vae_path, "vae.")) { @@ -274,6 +281,15 @@ class StableDiffusionGGML { return false; } + auto& tensor_types = model_loader.tensor_storages_types; + for (auto& item : tensor_types) { + // LOG_DEBUG("%s %u", item.first.c_str(), item.second); + if (contains(item.first, "qwen2vl") && ends_with(item.first, "weight") && (item.second == GGML_TYPE_F32 || item.second == GGML_TYPE_BF16)) { + item.second = GGML_TYPE_F16; + // LOG_DEBUG(" change %s %u", item.first.c_str(), item.second); + } + } + LOG_INFO("Version: %s ", model_version_to_str[version]); ggml_type wtype = (int)sd_ctx_params->wtype < std::min(SD_TYPE_COUNT, GGML_TYPE_COUNT) ? (ggml_type)sd_ctx_params->wtype @@ -338,17 +354,7 @@ class StableDiffusionGGML { bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu; { - clip_backend = backend; - bool use_t5xxl = false; - if (sd_version_is_dit(version) && !sd_version_is_qwen_image(version)) { - use_t5xxl = true; - } - if (!clip_on_cpu && !ggml_backend_is_cpu(backend) && use_t5xxl) { - LOG_WARN( - "!!!It appears that you are using the T5 model. Some backends may encounter issues with it." 
- "If you notice that the generated images are completely black," - "try running the T5 model on the CPU using the --clip-on-cpu parameter."); - } + clip_backend = backend; if (clip_on_cpu && !ggml_backend_is_cpu(backend)) { LOG_INFO("CLIP: Using CPU backend"); clip_backend = ggml_backend_cpu_init(); @@ -427,9 +433,15 @@ class StableDiffusionGGML { clip_vision->get_param_tensors(tensors); } } else if (sd_version_is_qwen_image(version)) { + bool enable_vision = false; + if (!vae_decode_only) { + enable_vision = true; + } cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, - model_loader.tensor_storages_types); + model_loader.tensor_storages_types, + "", + enable_vision); diffusion_model = std::make_shared(backend, offload_params_to_cpu, model_loader.tensor_storages_types, @@ -600,7 +612,9 @@ class StableDiffusionGGML { if (vae_decode_only) { ignore_tensors.insert("first_stage_model.encoder"); + ignore_tensors.insert("first_stage_model.conv1"); ignore_tensors.insert("first_stage_model.quant"); + ignore_tensors.insert("text_encoders.qwen2vl.visual."); } if (version == VERSION_SVD) { ignore_tensors.insert("conditioner.embedders.3"); @@ -959,12 +973,12 @@ class StableDiffusionGGML { ggml_set_f32(output, 0.f); } else { sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(init_image); - sd_image_f32_t resized_image = clip_preprocess(image, clip_vision->vision_model.image_size); + sd_image_f32_t resized_image = clip_preprocess(image, clip_vision->vision_model.image_size, clip_vision->vision_model.image_size); free(image.data); image.data = NULL; ggml_tensor* pixel_values = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); - sd_image_f32_to_tensor(resized_image.data, pixel_values, false); + sd_image_f32_to_tensor(resized_image, pixel_values, false); free(resized_image.data); resized_image.data = NULL; @@ -1001,7 +1015,7 @@ class StableDiffusionGGML { sd_image_f32_t resized_image = resize_sd_image_f32_t(image, width, height); free(image.data); image.data = NULL; - sd_image_f32_to_tensor(resized_image.data, init_img, false); + sd_image_f32_to_tensor(resized_image, init_img, false); free(resized_image.data); resized_image.data = NULL; } else { @@ -1757,6 +1771,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { "clip_vision_path: %s\n" "t5xxl_path: %s\n" "qwen2vl_path: %s\n" + "qwen2vl_vision_path: %s\n" "diffusion_model_path: %s\n" "high_noise_diffusion_model_path: %s\n" "vae_path: %s\n" @@ -1785,6 +1800,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { SAFE_STR(sd_ctx_params->clip_vision_path), SAFE_STR(sd_ctx_params->t5xxl_path), SAFE_STR(sd_ctx_params->qwen2vl_path), + SAFE_STR(sd_ctx_params->qwen2vl_vision_path), SAFE_STR(sd_ctx_params->diffusion_model_path), SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path), SAFE_STR(sd_ctx_params->vae_path), @@ -1995,6 +2011,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, sd_image_t control_image, float control_strength, sd_pm_params_t pm_params, + std::vector ref_images, std::vector ref_latents, bool increase_ref_index, ggml_tensor* concat_latent = NULL, @@ -2027,6 +2044,14 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, ggml_tensor* init_img = NULL; SDCondition id_cond; std::vector class_tokens_mask; + + ConditionerParams condition_params; + condition_params.clip_skip = clip_skip; + condition_params.width = width; + condition_params.height = height; + condition_params.ref_images = ref_images; + condition_params.adm_in_channels = 
sd_ctx->sd->diffusion_model->get_adm_in_channels(); + if (sd_ctx->sd->stacked_id) { if (!sd_ctx->sd->pmid_lora->applied) { int64_t t0 = ggml_time_ms(); @@ -2049,7 +2074,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, std::vector processed_id_images; for (int i = 0; i < pm_params.id_images_count; i++) { sd_image_f32_t id_image = sd_image_t_to_sd_image_f32_t(pm_params.id_images[i]); - sd_image_f32_t processed_id_image = clip_preprocess(id_image, clip_image_size); + sd_image_f32_t processed_id_image = clip_preprocess(id_image, clip_image_size, clip_image_size); free(id_image.data); id_image.data = NULL; processed_id_images.push_back(processed_id_image); @@ -2066,17 +2091,15 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, } processed_id_images.clear(); - int64_t t0 = ggml_time_ms(); - auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx, - sd_ctx->sd->n_threads, prompt, - clip_skip, - width, - height, - pm_params.id_images_count, - sd_ctx->sd->diffusion_model->get_adm_in_channels()); - id_cond = std::get<0>(cond_tup); - class_tokens_mask = std::get<1>(cond_tup); // - struct ggml_tensor* id_embeds = NULL; + int64_t t0 = ggml_time_ms(); + condition_params.text = prompt; + condition_params.num_input_imgs = pm_params.id_images_count; + auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx, + sd_ctx->sd->n_threads, + condition_params); + id_cond = std::get<0>(cond_tup); + class_tokens_mask = std::get<1>(cond_tup); // + struct ggml_tensor* id_embeds = NULL; if (pmv2 && pm_params.id_embed_path != nullptr) { id_embeds = load_tensor_from_file(work_ctx, pm_params.id_embed_path); // print_ggml_tensor(id_embeds, true, "id_embeds:"); @@ -2102,14 +2125,12 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, } // Get learned condition - t0 = ggml_time_ms(); - SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, - sd_ctx->sd->n_threads, - prompt, - clip_skip, - width, - height, - sd_ctx->sd->diffusion_model->get_adm_in_channels()); + t0 = ggml_time_ms(); + condition_params.text = prompt; + condition_params.zero_out_masked = false; + SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, + sd_ctx->sd->n_threads, + condition_params); SDCondition uncond; if (guidance.txt_cfg != 1.0 || @@ -2118,14 +2139,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) { zero_out_masked = true; } - uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, - sd_ctx->sd->n_threads, - negative_prompt, - clip_skip, - width, - height, - sd_ctx->sd->diffusion_model->get_adm_in_channels(), - zero_out_masked); + condition_params.text = negative_prompt; + condition_params.zero_out_masked = zero_out_masked; + uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, + sd_ctx->sd->n_threads, + condition_params); } int64_t t1 = ggml_time_ms(); LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0); @@ -2546,13 +2564,42 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g std::vector ref_latents; for (int i = 0; i < ref_images.size(); i++) { - ggml_tensor* img = ggml_new_tensor_4d(work_ctx, - GGML_TYPE_F32, - ref_images[i]->width, - ref_images[i]->height, - 3, - 1); - sd_image_to_tensor(*ref_images[i], img); + ggml_tensor* img; + if (sd_version_is_qwen_image(sd_ctx->sd->version)) { + 
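// Editorial note (illustrative, not part of the patch): for Qwen Image Edit the
// reference image is resampled before VAE encoding so that its area is roughly
// min(1024*1024, width*height), where width and height are the generation target size,
// while keeping the reference image's own aspect ratio; each side is then rounded to
// the nearest multiple of 32 before the resized copy is converted to a tensor and
// encoded by the VAE below.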
sd_image_f32_t ref_image = sd_image_t_to_sd_image_f32_t(*ref_images[i]); + int VAE_IMAGE_SIZE = std::min(1024 * 1024, width * height); + double vae_width = sqrt(VAE_IMAGE_SIZE * ref_image.width / ref_image.height); + double vae_height = vae_width * ref_image.height / ref_image.width; + + vae_height = round(vae_height / 32) * 32; + vae_width = round(vae_width / 32) * 32; + + sd_image_f32_t resized_image = resize_sd_image_f32_t(ref_image, static_cast(vae_width), static_cast(vae_height)); + free(ref_image.data); + ref_image.data = nullptr; + + LOG_DEBUG("resize vae ref image %d from %dx%d to %dx%d", i, ref_image.height, ref_image.width, resized_image.height, resized_image.width); + + img = ggml_new_tensor_4d(work_ctx, + GGML_TYPE_F32, + resized_image.width, + resized_image.height, + 3, + 1); + sd_image_f32_to_tensor(resized_image, img); + free(resized_image.data); + resized_image.data = nullptr; + } else { + img = ggml_new_tensor_4d(work_ctx, + GGML_TYPE_F32, + ref_images[i]->width, + ref_images[i]->height, + 3, + 1); + sd_image_to_tensor(*ref_images[i], img); + } + + // print_ggml_tensor(img, false, "img"); ggml_tensor* latent = sd_ctx->sd->encode_first_stage(work_ctx, img); ref_latents.push_back(latent); @@ -2586,6 +2633,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g sd_img_gen_params->control_image, sd_img_gen_params->control_strength, sd_img_gen_params->pm_params, + ref_images, ref_latents, sd_img_gen_params->increase_ref_index, concat_latent, @@ -2843,30 +2891,25 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s } // Get learned condition - bool zero_out_masked = true; - int64_t t1 = ggml_time_ms(); - SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, - sd_ctx->sd->n_threads, - prompt, - sd_vid_gen_params->clip_skip, - width, - height, - sd_ctx->sd->diffusion_model->get_adm_in_channels(), - zero_out_masked); - cond.c_concat = concat_latent; - cond.c_vector = clip_vision_output; + ConditionerParams condition_params; + condition_params.clip_skip = sd_vid_gen_params->clip_skip; + condition_params.zero_out_masked = true; + condition_params.text = prompt; + + int64_t t1 = ggml_time_ms(); + SDCondition cond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, + sd_ctx->sd->n_threads, + condition_params); + cond.c_concat = concat_latent; + cond.c_vector = clip_vision_output; SDCondition uncond; if (sd_vid_gen_params->sample_params.guidance.txt_cfg != 1.0 || sd_vid_gen_params->high_noise_sample_params.guidance.txt_cfg != 1.0) { - uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, - sd_ctx->sd->n_threads, - negative_prompt, - sd_vid_gen_params->clip_skip, - width, - height, - sd_ctx->sd->diffusion_model->get_adm_in_channels(), - zero_out_masked); - uncond.c_concat = concat_latent; - uncond.c_vector = clip_vision_output; + condition_params.text = negative_prompt; + uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, + sd_ctx->sd->n_threads, + condition_params); + uncond.c_concat = concat_latent; + uncond.c_vector = clip_vision_output; } int64_t t2 = ggml_time_ms(); LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t2 - t1); diff --git a/stable-diffusion.h b/stable-diffusion.h index 4711b45a..1d3ed85c 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -132,6 +132,7 @@ typedef struct { const char* clip_vision_path; const char* t5xxl_path; const char* qwen2vl_path; + const char* qwen2vl_vision_path; const char* 
diffusion_model_path; const char* high_noise_diffusion_model_path; const char* vae_path; diff --git a/t5.hpp b/t5.hpp index 062e37bb..15f7af80 100644 --- a/t5.hpp +++ b/t5.hpp @@ -504,7 +504,9 @@ struct T5DenseGatedActDense : public UnaryBlock { T5DenseGatedActDense(int64_t model_dim, int64_t ff_dim) { blocks["wi_0"] = std::shared_ptr(new Linear(model_dim, ff_dim, false)); blocks["wi_1"] = std::shared_ptr(new Linear(model_dim, ff_dim, false)); - blocks["wo"] = std::shared_ptr(new Linear(ff_dim, model_dim, false)); + float scale = 1.f / 32.f; + // The purpose of the scale here is to prevent NaN issues on some backends(CUDA, ...). + blocks["wo"] = std::shared_ptr(new Linear(ff_dim, model_dim, false, false, false, scale)); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { diff --git a/util.cpp b/util.cpp index 5af6b1ec..1d0bbd2b 100644 --- a/util.cpp +++ b/util.cpp @@ -84,6 +84,7 @@ int round_up_to(int value, int base) { } #ifdef _WIN32 // code for windows +#define NOMINMAX #include bool file_exists(const std::string& filename) { @@ -298,7 +299,7 @@ std::string trim(const std::string& s) { static sd_log_cb_t sd_log_cb = NULL; void* sd_log_cb_data = NULL; -#define LOG_BUFFER_SIZE 1024 +#define LOG_BUFFER_SIZE 4096 void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...) { va_list args; @@ -387,10 +388,10 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int float original_x = (float)x * image.width / target_width; float original_y = (float)y * image.height / target_height; - int x1 = (int)original_x; - int y1 = (int)original_y; - int x2 = x1 + 1; - int y2 = y1 + 1; + uint32_t x1 = (uint32_t)original_x; + uint32_t y1 = (uint32_t)original_y; + uint32_t x2 = std::min(x1 + 1, image.width - 1); + uint32_t y2 = std::min(y1 + 1, image.height - 1); for (int k = 0; k < image.channel; k++) { float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k); @@ -427,23 +428,26 @@ float means[3] = {0.48145466, 0.4578275, 0.40821073}; float stds[3] = {0.26862954, 0.26130258, 0.27577711}; // Function to clip and preprocess sd_image_f32_t -sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size) { - float scale = (float)size / fmin(image.width, image.height); +sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int target_height) { + float width_scale = (float)target_width / image.width; + float height_scale = (float)target_height / image.height; + + float scale = std::fmax(width_scale, height_scale); // Interpolation - int new_width = (int)(scale * image.width); - int new_height = (int)(scale * image.height); - float* resized_data = (float*)malloc(new_width * new_height * image.channel * sizeof(float)); + int resized_width = (int)(scale * image.width); + int resized_height = (int)(scale * image.height); + float* resized_data = (float*)malloc(resized_width * resized_height * image.channel * sizeof(float)); - for (int y = 0; y < new_height; y++) { - for (int x = 0; x < new_width; x++) { - float original_x = (float)x * image.width / new_width; - float original_y = (float)y * image.height / new_height; + for (int y = 0; y < resized_height; y++) { + for (int x = 0; x < resized_width; x++) { + float original_x = (float)x * image.width / resized_width; + float original_y = (float)y * image.height / resized_height; - int x1 = (int)original_x; - int y1 = (int)original_y; - int x2 = x1 + 1; - int y2 = y1 + 1; + uint32_t x1 = (uint32_t)original_x; + uint32_t y1 = 
(uint32_t)original_y; + uint32_t x2 = std::min(x1 + 1, image.width - 1); + uint32_t y2 = std::min(y1 + 1, image.height - 1); for (int k = 0; k < image.channel; k++) { float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k); @@ -456,26 +460,28 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size) { float value = interpolate(v1, v2, v3, v4, x_ratio, y_ratio); - *(resized_data + y * new_width * image.channel + x * image.channel + k) = value; + *(resized_data + y * resized_width * image.channel + x * image.channel + k) = value; } } } // Clip and preprocess - int h = (new_height - size) / 2; - int w = (new_width - size) / 2; + int h_offset = std::max((int)(resized_height - target_height) / 2, 0); + int w_offset = std::max((int)(resized_width - target_width) / 2, 0); sd_image_f32_t result; - result.width = size; - result.height = size; + result.width = target_width; + result.height = target_height; result.channel = image.channel; - result.data = (float*)malloc(size * size * image.channel * sizeof(float)); + result.data = (float*)malloc(target_height * target_width * image.channel * sizeof(float)); for (int k = 0; k < image.channel; k++) { - for (int i = 0; i < size; i++) { - for (int j = 0; j < size; j++) { - *(result.data + i * size * image.channel + j * image.channel + k) = - fmin(fmax(*(resized_data + (i + h) * new_width * image.channel + (j + w) * image.channel + k), 0.0f), 255.0f) / 255.0f; + for (int i = 0; i < result.height; i++) { + for (int j = 0; j < result.width; j++) { + int src_y = std::min(i + h_offset, resized_height - 1); + int src_x = std::min(j + w_offset, resized_width - 1); + *(result.data + i * result.width * image.channel + j * image.channel + k) = + fmin(fmax(*(resized_data + src_y * resized_width * image.channel + src_x * image.channel + k), 0.0f), 255.0f) / 255.0f; } } } @@ -485,10 +491,10 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size) { // Normalize for (int k = 0; k < image.channel; k++) { - for (int i = 0; i < size; i++) { - for (int j = 0; j < size; j++) { + for (int i = 0; i < result.height; i++) { + for (int j = 0; j < result.width; j++) { // *(result.data + i * size * image.channel + j * image.channel + k) = 0.5f; - int offset = i * size * image.channel + j * image.channel + k; + int offset = i * result.width * image.channel + j * image.channel + k; float value = *(result.data + offset); value = (value - means[k]) / stds[k]; // value = 0.5f; diff --git a/util.h b/util.h index 1e8db6e3..17bcd1d3 100644 --- a/util.h +++ b/util.h @@ -42,7 +42,7 @@ sd_image_f32_t sd_image_t_to_sd_image_f32_t(sd_image_t image); sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int target_height); -sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size); +sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int target_height); std::string path_join(const std::string& p1, const std::string& p2); std::vector split_string(const std::string& str, char delimiter); diff --git a/wan.hpp b/wan.hpp index af829b1a..31fa90b3 100644 --- a/wan.hpp +++ b/wan.hpp @@ -1333,7 +1333,7 @@ namespace WAN { k = ggml_reshape_4d(ctx, k, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head] v = ggml_reshape_4d(ctx, v, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head] - x = Flux::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, dim] + x = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, dim] x = o_proj->forward(ctx, x); // 
[N, n_token, dim] return x;
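// Editorial sketch (hypothetical, not part of this patch): the call above now routes
// through the shared Rope::attention helper added to rope.hpp earlier in this diff
// instead of the old Flux::attention. A minimal call site might look like the function
// below; q/k/v are expected in [N, L, n_head, d_head] layout and pe in
// [L, d_head/2, 2, 2] layout (typically produced by Rope::gen_wan_pe / gen_flux_pe /
// gen_qwen_image_pe and uploaded by the caller). The function name is made up purely
// for illustration.
static struct ggml_tensor* attention_with_rope_sketch(struct ggml_context* ctx,
                                                      ggml_backend_t backend,
                                                      struct ggml_tensor* q,
                                                      struct ggml_tensor* k,
                                                      struct ggml_tensor* v,
                                                      struct ggml_tensor* pe,
                                                      bool flash_attn) {
    // RoPE is applied to q and k inside the helper before scaled dot-product attention.
    return Rope::attention(ctx, backend, q, k, v, pe, /*mask=*/nullptr, flash_attn);
}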