diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index c6f5ba6a04c54..f988ed3f12540 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -3577,15 +3577,17 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]
 
 
-@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
+@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration", "Eagle2_5_VLForConditionalGeneration")
 class Qwen2VLVisionModel(MmprojModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         assert self.hparams_vision is not None
         self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560)
-        # rename config.json values
-        self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
-        self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")
+        # rename config.json values for Qwen models
+        if self.hparams_vision.get("num_heads") is not None:
+            self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
+        if self.hparams_vision.get("depth") is not None:
+            self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")
         if "embed_dim" in self.hparams_vision:  # qwen2vl
             self.hparams_vision["intermediate_size"] = self.hparams_vision.get("hidden_size")
             self.hparams_vision["hidden_size"] = self.hparams_vision.get("embed_dim")
@@ -3612,6 +3614,44 @@ def set_gguf_parameters(self):
                     if fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] != n_wa_pattern:
                         raise ValueError(f"Invalid fullatt_block_indexes: {fullatt_block_indexes}")
             self.gguf_writer.add_vision_n_wa_pattern(n_wa_pattern)
+        elif model_type in ['eagle_2_5_vl', 'eagle2_vl', 'eagle2_5_vl']:
+            # Eagle2-VL uses an MLP projector with 2x2 patch merge
+            # Structure: Vision encoder → 2x2 patch merge → LayerNorm → Linear → GELU → Linear
+            self.gguf_writer.add_clip_projector_type("mlp")
+
+            # Add spatial_merge_size for the patch merge (stored as n_merge in hparams)
+            self.gguf_writer.add_vision_spatial_merge_size(2)
+
+            # Add grid dimensions so the runtime can calculate the merge
+            image_size = self.find_vparam(["image_size"])
+            patch_size = self.find_vparam(["patch_size"])
+            grid_h = grid_w = image_size // patch_size
+            self.gguf_writer.add_key_value("clip.vision.grid_h", grid_h, gguf.GGUFValueType.INT32)
+            self.gguf_writer.add_key_value("clip.vision.grid_w", grid_w, gguf.GGUFValueType.INT32)
+
+            # Eagle2-VL uses window attention similar to Qwen2.5-VL but does not provide fullatt_block_indexes
+            # Set a reasonable default window attention pattern (every 4th layer uses full attention)
+            n_wa_pattern = 4  # default value for Eagle2-VL, based on similar models
+            self.gguf_writer.add_vision_n_wa_pattern(n_wa_pattern)
+
+            # --- BEGIN: Eagle2 fallback for required vision metadata ---
+            assert self.hparams_vision is not None
+            hv = self.hparams_vision
+            # block_count (number of vision layers) fallback - check the original vision_config first
+            blk = hv.get('num_hidden_layers') or hv.get('num_layers') or hv.get('n_layers')
+            if blk is None:
+                # Try the original vision_config before any transformations
+                original_vision_config = self.global_config.get('vision_config', {})
+                blk = original_vision_config.get('num_hidden_layers') or original_vision_config.get('num_layers') or original_vision_config.get('n_layers')
+            if blk is None:
+                # As a last resort, try to infer from the config layout if present
+                # (keep it simple: raise with a clear message if still missing)
+                raise ValueError("Eagle2: missing vision block count (num_hidden_layers/num_layers/n_layers) in vision_config")
+            self.gguf_writer.add_vision_block_count(int(blk))
+            # (Optional) You can add other explicit fallbacks here only if they also turn out None later:
+            #   head_count = hv.get('num_attention_heads', hv.get('num_heads'))
+            #   if head_count is not None: self.gguf_writer.add_vision_head_count(int(head_count))
+            # --- END: Eagle2 fallback ---
         else:
             raise ValueError(f"Unknown QwenVL model type: {self.global_config['model_type']}")
         # default values below are taken from HF tranformers code
@@ -3624,8 +3664,47 @@ def tensor_force_quant(self, name, new_name, bid, n_dims):
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
-        if name.startswith("visual."):
+        if name.startswith("visual.") or name.startswith("vision_model.") or name.startswith("mlp1."):
+            # Skip all vision model head layers - not needed for mmproj
+            if ".head." in name:
+                return []
+
+            # Handle projector tensors (Eagle2-VL uses the mlp1.N.weight/bias pattern)
+            if name.startswith("mlp1."):
+                # Eagle2-VL has: LayerNorm(0) → Linear(1) → GELU(2) → Linear(3)
+                # The QWEN2VL projector expects: Linear(0) → GELU → Linear(2)
+                # So we need to remap: mlp1.1 → mm.0, mlp1.3 → mm.2
+                # Skip mlp1.0 (LayerNorm) as it's not used by the QWEN2VL projector type
+                if ".0." in name:
+                    # Skip LayerNorm layer
+                    return []
+                elif ".1." in name:
+                    # Map the first Linear layer (mlp1.1) to mm.0
+                    # Original: [896, 4608] -> needs to be transposed for GGML: [4608, 896]
+                    if ".weight" in name:
+                        new_name = name.replace("mlp1.1.", "mm.0.")
+                        return [(new_name, data_torch.T)]  # transpose the weight
+                    else:
+                        new_name = name.replace("mlp1.1.", "mm.0.")
+                        return [(new_name, data_torch)]
+                elif ".3." in name:
+                    # Map the second Linear layer (mlp1.3) to mm.2
+                    # Original: [896, 896] -> transposed for GGML: [896, 896] (square matrix)
+                    if ".weight" in name:
+                        new_name = name.replace("mlp1.3.", "mm.2.")
+                        return [(new_name, data_torch.T)]  # transpose the weight
+                    else:
+                        new_name = name.replace("mlp1.3.", "mm.2.")
+                        return [(new_name, data_torch)]
+                else:
+                    # Unknown mlp1 layer
+                    return []
+
             # process visual tensors
+            # Handle Eagle2-VL specific naming: vision_model.vision_model.* -> model.vision_model.*
+            if name.startswith("vision_model.vision_model."):
+                name = name.replace("vision_model.vision_model.", "model.vision_model.")
+
             # split QKV tensors if needed
             if ".qkv." in name:
                 if data_torch.ndim == 2: # weight
@@ -3653,6 +3732,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
                 ]
             else:
                 return [(self.map_tensor_name(name), data_torch)]
+        elif name.startswith("multi_modal_projector."):
+            # Handle projector tensors (for other Qwen2.5-VL models that use the multi_modal_projector prefix)
+            # Convert mm.model.mlp.N.weight/bias to the mm.N.weight/bias pattern
+            new_name = name.replace("multi_modal_projector.", "")
+            if "mm.model.mlp." in new_name:
+                new_name = new_name.replace("mm.model.mlp.", "mm.")
+            return [(new_name, data_torch)]
 
         return [] # skip other tensors
 
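Editor's note: the mlp1.* remapping above is easiest to follow as a small standalone sketch. The snippet below is illustrative only and is not part of the conversion script; the tensor shapes come from the comments in the diff, and the helper name is hypothetical.

    import torch

    def remap_eagle2_projector(name: str, t: torch.Tensor):
        # Mirror of the mlp1.* handling: drop the LayerNorm (mlp1.0),
        # map mlp1.1 -> mm.0 and mlp1.3 -> mm.2, and transpose Linear weights for GGML.
        if name.startswith("mlp1.0."):        # LayerNorm -> skipped
            return None
        if name.startswith("mlp1.1."):        # first Linear -> mm.0
            new_name = name.replace("mlp1.1.", "mm.0.")
        elif name.startswith("mlp1.3."):      # second Linear -> mm.2
            new_name = name.replace("mlp1.3.", "mm.2.")
        else:
            return None
        return new_name, (t.T if name.endswith(".weight") else t)

    # e.g. a [896, 4608] weight is renamed to mm.0.weight and stored as [4608, 896]
    name, w = remap_eagle2_projector("mlp1.1.weight", torch.zeros(896, 4608))
    print(name, tuple(w.shape))  # mm.0.weight (4608, 896)
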
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 99775cb3e351c..f14eafa5589be 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -1,7 +1,4 @@
-// NOTE: This is modified from clip.cpp only for LLaVA,
-// so there might be still unnecessary artifacts hanging around
-// I'll gradually clean and extend it
-// Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
+#include
 #include "clip.h"
 #include "clip-impl.h"
 #include "ggml.h"
@@ -12,6 +9,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -185,6 +183,11 @@ struct clip_hparams {
 
     patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT;
 
+    int32_t patch_merge_factor = 1;
+    std::string patch_merge_mode = "flat";
+    int32_t grid_h = 0;
+    int32_t grid_w = 0;
+
     float eps = 1e-6;
     float rope_theta = 0.0;
 
@@ -667,9 +670,28 @@ struct clip_graph {
 
         // LlavaMultiModalProjector (always using GELU activation)
         {
-            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
-            if (model.mm_1_b) {
-                cur = ggml_add(ctx0, cur, model.mm_1_b);
+            // Eagle2-VL: Apply patch merge before MLP projection if n_merge > 1
+            // Prefer clip.vision.spatial_merge_size; treat n_merge == 1 as no-merge
+            if (hparams.n_merge > 1 &&
+                (model.proj_type == PROJECTOR_TYPE_MLP || model.proj_type == PROJECTOR_TYPE_MLP_NORM)) {
+                const int scale_factor = hparams.n_merge;
+                cur = build_patch_merge_permute(cur, scale_factor);
+            }
+
+            // Use mm_0_w/mm_0_b if available (Eagle2-VL), otherwise mm_1_w/mm_1_b (standard LLaVA)
+            ggml_tensor * first_w = model.mm_0_w ? model.mm_0_w : model.mm_1_w;
+            ggml_tensor * first_b = model.mm_0_b ? model.mm_0_b : model.mm_1_b;
+
+            // Ensure 2D and correct orientation for matmul: first_w[out,in] x cur[in, tokens]
+            cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1]);
+            if (first_w && first_w->ne[1] != cur->ne[0]) {
+                // transpose to match expected [in, tokens]
+                cur = ggml_transpose(ctx0, cur);
+                cur = ggml_cont(ctx0, cur);
+            }
+            cur = ggml_mul_mat(ctx0, first_w, cur);
+            if (first_b) {
+                cur = ggml_add(ctx0, cur, first_b);
             }
 
             cur = ggml_gelu(ctx0, cur);
@@ -686,8 +708,8 @@ struct clip_graph {
         // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
         // after the concatenation, we have a tensor with shape [n_embd, n_patches_per_row + 1, n_rows]
 
-        const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
-        const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
+        const int p_y = n_merge > 1 ? n_patches_y / n_merge : n_patches_y;
+        const int p_x = n_merge > 1 ? n_patches_x / n_merge : n_patches_x;
         const int p_total = p_x * p_y;
         const int n_embd_text = cur->ne[0];
         const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row
@@ -710,7 +732,7 @@ struct clip_graph {
 
     // Qwen2VL and Qwen2.5VL use M-RoPE
     ggml_cgraph * build_qwen2vl() {
-        GGML_ASSERT(model.patch_bias == nullptr);
+        // Eagle2-VL and some variants may have patch bias
         GGML_ASSERT(model.class_embedding == nullptr);
 
         const int batch_size = 1;
@@ -749,6 +771,12 @@ struct clip_graph {
                 n_embd, n_patches_x * n_patches_y, batch_size);
         }
 
+        // add patch bias if present (Eagle2-VL has patch bias)
+        if (model.patch_bias != nullptr) {
+            inp = ggml_add(ctx0, inp, model.patch_bias);
+            cb(inp, "patch_bias", -1);
+        }
+
         ggml_tensor * inpL = inp;
         ggml_tensor * window_mask = nullptr;
         ggml_tensor * window_idx = nullptr;
@@ -867,10 +895,30 @@ struct clip_graph {
             inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer);
         }
 
+        // Apply patch merge based on metadata
+        // Preferred: clip.vision.spatial_merge_size -> hparams.n_merge
+        // Fallback: legacy keys clip.vision.patch_merge_factor/mode
+        bool did_spatial_merge = false;
+        if (hparams.n_merge > 1) {
+            // e.g. [1152, 1024] -> [1152 * n_merge^2, 1024 / n_merge^2]
+            inpL = build_patch_merge_permute(inpL, hparams.n_merge);
+            did_spatial_merge = true;
+        } else if (hparams.patch_merge_factor > 1 &&
+                   (hparams.patch_merge_mode == "concat2x2" || hparams.patch_merge_mode == "concat")) {
+            // legacy fallback (kept for backward compatibility)
+            inpL = build_patch_merge_permute(inpL, hparams.patch_merge_factor);
+            did_spatial_merge = true;
+        }
+
         // multimodal projection
         ggml_tensor * embeddings = inpL;
-        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
+
+        // Conditional reshape based on whether patch merge was applied
+        if (!did_spatial_merge) {
+            // Standard Qwen2VL path assumes 2x2 merge semantics without explicit permute
+            embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
+        }
+
         embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
         embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
 
@@ -1537,26 +1585,34 @@ struct clip_graph {
 
         // llava projector (also used by granite)
         if (ctx->model.hparams.has_llava_projector) {
+            // consume the full post-merge sequence directly; no row selection via patches
            embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
-
-            ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
-            ggml_set_name(patches, "patches");
-            ggml_set_input(patches);
-
-            // shape [1, 576, 1024]
-            // ne is whcn, ne = [1024, 576, 1, 1]
-            embeddings = ggml_get_rows(ctx0, embeddings, patches);
-
-            // print_tensor_info(embeddings, "embeddings");
+
+            // llava projector
             if (ctx->proj_type() == PROJECTOR_TYPE_MLP) {
-                embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+                // apply 2x2 patch merge on [C, T] layout directly when n_merge > 1
+                if (hparams.n_merge > 1) {
+                    // ensure contiguous before reshape/permutation in patch merge
+                    embeddings = ggml_cont(ctx0, embeddings);
+                    const int scale_factor = hparams.n_merge;
+                    embeddings = build_patch_merge_permute(embeddings, scale_factor);
+                }
+                ggml_tensor * w0 = model.mm_0_w;
+                // ensure projector weight orientation matches embeddings
+                if (w0->ne[0] != embeddings->ne[0] && w0->ne[1] == embeddings->ne[0]) {
+                    w0 = ggml_cont(ctx0, ggml_transpose(ctx0, w0));
+                }
+                embeddings = ggml_mul_mat(ctx0, w0, embeddings);
                 embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
 
                 embeddings = ggml_gelu(ctx0, embeddings);
                 if (model.mm_2_w) {
-                    embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
+                    ggml_tensor * w2 = model.mm_2_w;
+                    if (w2->ne[0] != embeddings->ne[0] && w2->ne[1] == embeddings->ne[0]) {
+                        w2 = ggml_cont(ctx0, ggml_transpose(ctx0, w2));
+                    }
+                    embeddings = ggml_mul_mat(ctx0, w2, embeddings);
                     embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
                 }
             }
@@ -2706,6 +2762,12 @@ struct clip_model_loader {
             if (mm_patch_merge_type == "spatial_unpad") {
                 hparams.mm_patch_merge_type = PATCH_MERGE_SPATIAL_UNPAD;
             }
+
+            // Load Eagle2-VL specific patch merge metadata
+            get_i32("clip.vision.patch_merge_factor", hparams.patch_merge_factor, false);
+            get_string("clip.vision.patch_merge_mode", hparams.patch_merge_mode, false);
+            get_i32("clip.vision.grid_h", hparams.grid_h, false);
+            get_i32("clip.vision.grid_w", hparams.grid_w, false);
         }
 
         if (is_vision) {
@@ -2735,6 +2797,14 @@ struct clip_model_loader {
 
             // model-specific params
             switch (model.proj_type) {
+                case PROJECTOR_TYPE_MLP:
+                case PROJECTOR_TYPE_MLP_NORM:
+                    {
+                        // Eagle2-VL: load the spatial merge size for patch merge
+                        get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
+                        (void) hparams.n_merge; // keep the variable referenced even if unused
+                    }
+                    break;
                 case PROJECTOR_TYPE_MINICPMV:
                     {
                         if (hparams.minicpmv_version == 0) {
@@ -3587,13 +3657,25 @@ void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny
 static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) {
     dst.nx = src.nx;
     dst.ny = src.ny;
-    dst.buf.resize(src.buf.size());
-
-    // TODO @ngxson : seems like this could be done more efficiently on cgraph
-    for (size_t i = 0; i < src.buf.size(); ++i) {
-        int c = i % 3; // rgb
-        dst.buf[i] = (static_cast<float>(src.buf[i]) / 255.0f - mean[c]) / std[c];
+    const size_t plane_sz = (size_t) dst.nx * (size_t) dst.ny;
+    dst.buf.resize(3 * plane_sz); // planar RGB
+
+    for (int y = 0; y < dst.ny; ++y) {
+        for (int x = 0; x < dst.nx; ++x) {
+            size_t base = (size_t) y * (size_t) dst.nx + (size_t) x;
+            for (int c = 0; c < 3; ++c) {
+                size_t src_idx = 3ull * base + (size_t) c;     // interleaved in src
+                float raw = static_cast<float>(src.buf[src_idx]) / 255.0f;
+                float v   = (raw - mean[c]) / std[c];
+                size_t dst_idx = (size_t) c * plane_sz + base; // planar in dst
+                dst.buf[dst_idx] = v;
+            }
+        }
     }
 }
 
 // set of tools to manupulate images
@@ -4462,7 +4544,12 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_MLP_NORM:
         case PROJECTOR_TYPE_JANUS_PRO:
             {
-                // do nothing
+                // account for spatial patch merge when present (e.g. Eagle2-VL)
+                // both X and Y are downscaled by the merge factor
+                const int scale_factor = ctx->model.hparams.n_merge;
+                if (scale_factor > 0) {
+                    n_patches /= (scale_factor * scale_factor);
+                }
             } break;
         case PROJECTOR_TYPE_LDP:
         case PROJECTOR_TYPE_LDPV2:
@@ -4962,7 +5049,19 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 for (int i = 0; i < num_patches; i++) {
                     patches[i] = i + patch_offset;
                 }
-                set_input_i32("patches", patches);
+                // Make patches optional: if the graph doesn't contain an input named "patches"
+                // (Eagle2-VL full-sequence path), skip it without aborting.
+                ggml_tensor * patches_tensor = ggml_graph_get_tensor(gf, "patches");
+                if (patches_tensor && (patches_tensor->flags & GGML_TENSOR_FLAG_INPUT)) {
+                    GGML_ASSERT(patches_tensor->type == GGML_TYPE_I32);
+                    GGML_ASSERT(ggml_nelements(patches_tensor) == (int64_t) patches.size());
+                    ggml_backend_tensor_set(patches_tensor, patches.data(), 0, ggml_nbytes(patches_tensor));
+                } else {
+                    // Only log in verbose contexts (llava projector present) to avoid spam for other models.
+                    if (ctx->model.hparams.has_llava_projector) {
+                        // no 'patches' tensor in the graph (full-sequence path)
+                    }
+                }
             } break;
         case PROJECTOR_TYPE_GEMMA3:
         case PROJECTOR_TYPE_IDEFICS3:
@@ -5013,6 +5112,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         return false;
     }
 
+
+
     // print debug nodes
     if (ctx->debug_graph) {
         LOG_INF("\n\n---\n\n");
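Editor's note: the rewritten normalize_image_u8_to_f32() above changes the destination layout from interleaved RGB to planar RGB (all R, then all G, then all B). A minimal pure-Python sketch of the same indexing, assuming a flat interleaved uint8 source buffer, may make the index math easier to verify:

    def normalize_u8_to_f32_planar(src, nx, ny, mean, std):
        # src is interleaved RGB (R,G,B,R,G,B,...); the result is planar per channel
        plane = nx * ny
        dst = [0.0] * (3 * plane)
        for y in range(ny):
            for x in range(nx):
                base = y * nx + x
                for c in range(3):
                    raw = src[3 * base + c] / 255.0                   # interleaved source index
                    dst[c * plane + base] = (raw - mean[c]) / std[c]  # planar destination index
        return dst

    # tiny 1x2 image: two pure-red pixels, zero mean / unit std
    print(normalize_u8_to_f32_planar([255, 0, 0, 255, 0, 0], 2, 1, [0.0] * 3, [1.0] * 3))
    # -> [1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
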
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index 325f7ff995e36..e599137769963 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -101,16 +101,17 @@ static clip_flash_attn_type mtmd_get_clip_flash_attn_type(enum llama_flash_attn_
 }
 
 mtmd_context_params mtmd_context_params_default() {
-    mtmd_context_params params;
-    params.use_gpu = true;
-    params.print_timings = true;
-    params.n_threads = 4;
-    params.verbosity = GGML_LOG_LEVEL_INFO;
-    params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
-    params.media_marker = mtmd_default_marker();
-    params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
-    params.image_min_tokens = -1;
-    params.image_max_tokens = -1;
+    mtmd_context_params params {
+        /* use_gpu          */ true,
+        /* print_timings    */ true,
+        /* n_threads        */ 4,
+        /* verbosity        */ GGML_LOG_LEVEL_INFO,
+        /* image_marker     */ MTMD_DEFAULT_IMAGE_MARKER,
+        /* media_marker     */ mtmd_default_marker(),
+        /* flash_attn_type  */ LLAMA_FLASH_ATTN_TYPE_AUTO,
+        /* image_min_tokens */ -1,
+        /* image_max_tokens */ -1,
+    };
     return params;
 }
 
@@ -162,7 +163,7 @@ struct mtmd_context {
         print_timings(ctx_params.print_timings),
         n_threads   (ctx_params.n_threads),
         media_marker (ctx_params.media_marker),
-        n_embd_text (llama_model_n_embd(text_model))
+        n_embd_text (llama_model_n_embd_inp(text_model))
     {
         if (std::string(ctx_params.image_marker) != MTMD_DEFAULT_IMAGE_MARKER) {
             throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead");
@@ -172,13 +173,13 @@ struct mtmd_context {
             throw std::runtime_error("media_marker must not be empty");
         }
 
-        clip_context_params ctx_clip_params;
-        ctx_clip_params.use_gpu = ctx_params.use_gpu;
-        ctx_clip_params.verbosity = ctx_params.verbosity;
-        ctx_clip_params.flash_attn_type = mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type);
-        // custom image token limits
-        ctx_clip_params.image_min_tokens = ctx_params.image_min_tokens;
-        ctx_clip_params.image_max_tokens = ctx_params.image_max_tokens;
+        clip_context_params ctx_clip_params {
+            /* use_gpu          */ ctx_params.use_gpu,
+            /* verbosity        */ ctx_params.verbosity,
+            /* flash_attn_type  */ CLIP_FLASH_ATTN_TYPE_AUTO,
+            /* image_min_tokens */ ctx_params.image_min_tokens,
+            /* image_max_tokens */ ctx_params.image_max_tokens,
+        };
 
         auto res = clip_init(mmproj_fname, ctx_clip_params);
         ctx_v = res.ctx_v;
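
Editor's note: the clip_n_output_tokens() change in clip.cpp can be sanity-checked with the token arithmetic it implies: with a spatial merge factor n_merge, both grid axes shrink by n_merge, so the number of output tokens drops by n_merge^2. A minimal sketch follows; the 14-pixel patch size is an assumption for illustration, while 560 is the default image_size set in the conversion code above and 2 is the spatial merge size written for Eagle2-VL.

    def n_output_tokens(image_size: int, patch_size: int, n_merge: int) -> int:
        n_patches = (image_size // patch_size) ** 2
        if n_merge > 0:  # mirrors the scale_factor > 0 guard in clip_n_output_tokens()
            n_patches //= n_merge * n_merge
        return n_patches

    # 560x560 image, 14-pixel patches (assumed), 2x2 merge -> (40*40)/4 = 400 tokens
    assert n_output_tokens(560, 14, 2) == 400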