From 43a130b4d0af45034a5abb2db6884c8fd08fee8b Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Fri, 14 Nov 2025 12:40:20 +0100 Subject: [PATCH 01/37] mtmd: llama.cpp DeepSeekOCR support init commit --- convert_hf_to_gguf.py | 107 ++++++- gguf-py/gguf/constants.py | 32 +++ gguf-py/gguf/tensor_mapping.py | 54 ++++ tools/mtmd/clip-impl.h | 20 ++ tools/mtmd/clip.cpp | 504 ++++++++++++++++++++++++++++++++- 5 files changed, 696 insertions(+), 21 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 222f6ed6dc4..82a6c95bdd0 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -620,6 +620,9 @@ def load_hparams(dir_model: Path, is_mistral_format: bool): if "thinker_config" in config: # rename for Qwen2.5-Omni config["text_config"] = config["thinker_config"]["text_config"] + if "language_config" in config: + # rename for DeepSeekOCR + config["text_config"] = config["language_config"] return config @classmethod @@ -1442,7 +1445,7 @@ class MmprojModel(ModelBase): preprocessor_config: dict[str, Any] global_config: dict[str, Any] - n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"] + n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "width.clip-l-14-224.layers", "sam_vit_b.layers"] has_vision_encoder: bool = True # by default has_audio_encoder: bool = False @@ -1488,13 +1491,31 @@ def __init__(self, *args, **kwargs): # TODO @ngxson : this is a hack to support both vision and audio encoders have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True) + # FIXME: DeepseekOCRVisionModel specific hack + if self.block_count is None: + if isinstance(self, DeepseekOCRVisionModel): + clip_block_count = self.hparams['width']['clip-l-14-224']['layers'] + sam_block_count = self.hparams['width']['sam_vit_b']['layers'] + if clip_block_count is not None: + self.block_count = clip_block_count + if sam_block_count is not None: + self.block_count = sam_block_count if self.block_count is None else self.block_count + sam_block_count + if self.block_count is None: + raise KeyError(f"could not find block count using any of: {self.n_block_keys}") self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count) # load preprocessor config self.preprocessor_config = {} if not self.is_mistral_format: - with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f: - self.preprocessor_config = json.load(f) + # check if preprocessor_config.json exists + if (self.dir_model / "preprocessor_config.json").is_file(): + with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f: + self.preprocessor_config = json.load(f) + else: + # try "processing_config" file if exists + if (self.dir_model / "processing_config.json").is_file(): + with open(self.dir_model / "processing_config.json", "r", encoding="utf-8") as f: + self.preprocessor_config = json.load(f) def get_vision_config(self) -> dict[str, Any] | None: config_name = "vision_config" if not self.is_mistral_format else "vision_encoder" @@ -5770,6 +5791,61 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [] # skip other tensors +@ModelBase.register("DeepseekOCRForCausalLM") +class DeepseekOCRVisionModel(MmprojModel): + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + 
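+        # Sketch of the proj_scale_factor math further down (illustrative values,
+        # not read from the checkpoint): image_seq_length = 256 gives a 16x16
+        # token grid, so proj_scale_factor = (image_size // patch_size) // 16;
+        # e.g. a hypothetical 224 px image with 14 px patches yields 16 // 16 = 1.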
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.DEEPSEEKOCR) + # default values below are taken from HF tranformers code + self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6)) + self.gguf_writer.add_vision_use_gelu(True) + # calculate proj_scale_factor (used by tinygemma3 test model) + image_seq_length = self.preprocessor_config.get("image_seq_length", 256) + n_per_side = int(image_seq_length ** 0.5) + image_size = self.hparams["image_size"] + patch_size = self.hparams["patch_size"] + proj_scale_factor = (image_size // patch_size) // n_per_side + if proj_scale_factor > 0 and proj_scale_factor != 4: + # we only need to write this if it's not the default value + # in this case, we are converting a test model + self.gguf_writer.add_vision_projector_scale_factor(proj_scale_factor) + + def get_vision_config(self) -> dict[str, Any]: + orig_vision_config = self.global_config.get("vision_config") + + super().get_vision_config() + + def tensor_force_quant(self, name, new_name, bid, n_dims): + # related to https://github.com/ggml-org/llama.cpp/issues/13025 + if "input_projection" in name: + return gguf.GGMLQuantizationType.F16 + if ".embeddings." in name: + return gguf.GGMLQuantizationType.F32 + return super().tensor_force_quant(name, new_name, bid, n_dims) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if "vision_model.head." in name: + return [] # skip redundant tensors for tinygemma3 + + if name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \ + or name.startswith("multimodal_projector.") or name.startswith("vision_model."): + # process vision tensors + name = name.replace("_weight", ".weight") + + # correct norm value ; only this "soft_emb_norm" need to be corrected as it's part of Gemma projector + # the other norm values are part of SigLIP model, and they are already correct + # ref code: Gemma3RMSNorm + if "soft_emb_norm.weight" in name: + logger.info(f"Correcting norm value for '{name}'") + data_torch = data_torch + 1 + + return [(self.map_tensor_name(name), data_torch)] + + return [] # skip other tensors + @ModelBase.register("Gemma3nForConditionalGeneration") class Gemma3NModel(Gemma3Model): @@ -6943,6 +7019,7 @@ def prepare_tensors(self): @ModelBase.register( "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", + "DeepseekOCRForCausalLM", "KimiVLForConditionalGeneration", ) class DeepseekV2Model(TextModel): @@ -7009,31 +7086,35 @@ def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams + kv_lora_rank = hparams["q_lora_rank"] if hparams["q_lora_rank"] is not None else 512 + routed_scaling_factor = hparams.get("routed_scaling_factor", 1.0) + norm_topk_prob = hparams.get("norm_topk_prob", False) + scoring_func = hparams.get("scoring_func", "softmax") self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) self.gguf_writer.add_vocab_size(hparams["vocab_size"]) if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) - self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) + self.gguf_writer.add_kv_lora_rank(kv_lora_rank) # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA - self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"]) - self.gguf_writer.add_value_length(hparams["kv_lora_rank"]) + self.gguf_writer.add_key_length(kv_lora_rank + 
hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(kv_lora_rank) self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) - self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) - self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + self.gguf_writer.add_expert_weights_scale(routed_scaling_factor) + self.gguf_writer.add_expert_weights_norm(norm_topk_prob) - if hparams["scoring_func"] == "sigmoid": + if scoring_func == "sigmoid": self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) - elif hparams["scoring_func"] == "softmax": + elif scoring_func == "softmax": self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) else: - raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}") + raise ValueError(f"Unsupported scoring_func value: {scoring_func}") self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) @@ -7043,12 +7124,14 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6)) _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # skip vision tensors and remove "language_model." 
for Kimi-VL - if "vision_tower" in name or "multi_modal_projector" in name: + if "vision_" in name or "multi_modal_projector" in name \ + or "image_newline" in name or "model.projector" in name or "sam_model" in name or "view_seperator" in name: return [] if name.startswith("language_model."): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 6b4b6c5ab07..dfd947083a2 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -664,6 +664,21 @@ class MODEL_TENSOR(IntEnum): V_MM_GATE = auto() # cogvlm V_TOK_BOI = auto() # cogvlm V_TOK_EOI = auto() # cogvlm + # DeepSeek-OCR sam_model + V_SAM_POS_EMBD = auto() + V_SAM_PATCH_EMBD = auto() + V_SAM_PRE_NORM = auto() + V_SAM_POST_NORM = auto() + V_SAM_ATTN_POS_H = auto() + V_SAM_ATTN_POS_W = auto() + V_SAM_ATTN_QKV = auto() + V_SAM_ATTN_OUT = auto() + V_SAM_MLP_LIN_1 = auto() + V_SAM_MLP_LIN_2 = auto() + V_SAM_NECK = auto() + V_SAM_NET_2 = auto() + V_SAM_NET_3 = auto() + # audio (mtmd) A_ENC_EMBD_POS = auto() A_ENC_CONV1D = auto() @@ -1030,6 +1045,20 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_MM_GATE: "mm.gate", MODEL_TENSOR.V_TOK_BOI: "v.boi", MODEL_TENSOR.V_TOK_EOI: "v.eoi", + # DeepSeek-OCR sam_model + MODEL_TENSOR.V_SAM_POS_EMBD: "v.sam.pos_embd", + MODEL_TENSOR.V_SAM_PATCH_EMBD: "v.sam.patch_embd", + MODEL_TENSOR.V_SAM_PRE_NORM: "v.sam.blk.{bid}.pre_ln", + MODEL_TENSOR.V_SAM_POST_NORM: "v.sam.blk.{bid}.post_ln", + MODEL_TENSOR.V_SAM_ATTN_POS_H: "v.sam.blk.{bid}.attn.pos_h", + MODEL_TENSOR.V_SAM_ATTN_POS_W: "v.sam.blk.{bid}.attn.pos_w", + MODEL_TENSOR.V_SAM_ATTN_QKV: "v.sam.blk.{bid}.attn.qkv", + MODEL_TENSOR.V_SAM_ATTN_OUT: "v.sam.blk.{bid}.attn.out", + MODEL_TENSOR.V_SAM_MLP_LIN_1: "v.sam.blk.{bid}.mlp.lin1", + MODEL_TENSOR.V_SAM_MLP_LIN_2: "v.sam.blk.{bid}.mlp.lin2", + MODEL_TENSOR.V_SAM_NECK: "v.sam.neck.{bid}", + MODEL_TENSOR.V_SAM_NET_2: "v.sam.net_2", + MODEL_TENSOR.V_SAM_NET_3: "v.sam.net_3", # audio (mtmd) MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd", MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}", @@ -2247,7 +2276,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_Q_B, MODEL_TENSOR.ATTN_KV_A_MQA, MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_V_B, MODEL_TENSOR.ATTN_Q_A_NORM, MODEL_TENSOR.ATTN_KV_A_NORM, @@ -3207,6 +3238,7 @@ class VisionProjectorType: LIGHTONOCR = "lightonocr" COGVLM = "cogvlm" JANUS_PRO = "janus_pro" + DEEPSEEKOCR = "deepseekocr" # Items here are (block size, type size) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 92940668761..f15ea0a02a0 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -2,6 +2,8 @@ from typing import Sequence +from numpy.f2py.auxfuncs import throw_error + from .constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS, TENSOR_NAMES @@ -1457,6 +1459,58 @@ class TensorNameMap: "model.visual.deepstack_merger_list.{bid}.linear_fc2", # deepstack in qwen3vl ), + MODEL_TENSOR.V_SAM_POS_EMBD: ( + "model.sam_model.pos_embed" + ), + + MODEL_TENSOR.V_SAM_PATCH_EMBD: ( + "model.sam_model.patch_embed.proj" + ), + + MODEL_TENSOR.V_SAM_PRE_NORM: ( + "model.sam_model.blocks.{bid}.norm1", # deepstack in qwen3vl + ), + + MODEL_TENSOR.V_SAM_POST_NORM: ( + "model.sam_model.blocks.{bid}.norm2", # deepstack in qwen3vl + ), + + MODEL_TENSOR.V_SAM_ATTN_POS_H: ( + "model.sam_model.blocks.{bid}.attn.rel_pos_h" + ), + + MODEL_TENSOR.V_SAM_ATTN_POS_W: ( + "model.sam_model.blocks.{bid}.attn.rel_pos_w" + ), + + MODEL_TENSOR.V_SAM_ATTN_QKV: ( + 
"model.sam_model.blocks.{bid}.attn.qkv" + ), + + MODEL_TENSOR.V_SAM_ATTN_OUT: ( + "model.sam_model.blocks.{bid}.attn.proj" + ), + + MODEL_TENSOR.V_SAM_MLP_LIN_1: ( + "model.sam_model.blocks.{bid}.mlp.lin1", + ), + + MODEL_TENSOR.V_SAM_MLP_LIN_2: ( + "model.sam_model.blocks.{bid}.mlp.lin2", + ), + + MODEL_TENSOR.V_SAM_NECK: ( + "model.sam_model.neck.{bid}" + ), + + MODEL_TENSOR.V_SAM_NET_2: ( + "model.sam_model.net_2" + ), + + MODEL_TENSOR.V_SAM_NET_3: ( + "model.sam_model.net_3" + ), + MODEL_TENSOR.V_MM_POST_FC_NORM: ( "model.vision.linear_proj.norm1", # cogvlm ), diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 722b1a4948d..8d1c7d0dff1 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -129,6 +129,24 @@ #define TN_TOK_BOI "v.boi" #define TN_TOK_EOI "v.eoi" +// deepseek-ocr +#define TN_SAM_POS_EMBD "sam.pos_embd" +#define TN_SAM_PATCH_EMBD "sam.patch_embd" +#define TN_SAM_PRE_NORM "sam.blk.%d.pre_ln" +#define TN_SAM_POST_NORM "sam.blk.%d.post_ln" +#define TN_SAM_ATTN_POS_H "sam.blk.%d.attn.pos_h" +#define TN_SAM_ATTN_POS_W "sam.blk.%d.attn.pos_w" +#define TN_SAM_ATTN_QKV "sam.blk.%d.attn.qkv" +#define TN_SAM_ATTN_OUT "sam.blk.%d.attn.out" +#define TN_SAM_MLP_LIN_1 "sam.blk.%d.mlp.lin1" +#define TN_SAM_MLP_LIN_2 "sam.blk.%d.mlp.lin2" +#define TN_SAM_NECK "sam.neck.%d" +#define TN_SAM_NET_2 "sam.net_2" +#define TN_SAM_NET_3 "sam.net_3" + + +#define TN_SAM_ATTN_OUT "sam.blk.%d.attn_out" + // align x to upper multiple of n #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) @@ -156,6 +174,7 @@ enum projector_type { PROJECTOR_TYPE_LIGHTONOCR, PROJECTOR_TYPE_COGVLM, PROJECTOR_TYPE_JANUS_PRO, + PROJECTOR_TYPE_DEEPSEEK_OCR, PROJECTOR_TYPE_UNKNOWN, }; @@ -182,6 +201,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, { PROJECTOR_TYPE_COGVLM, "cogvlm"}, { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, + { PROJECTOR_TYPE_DEEPSEEK_OCR,"deepseek_orc"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index d1423b67f98..0961b96fd6f 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -222,6 +222,33 @@ struct clip_hparams { warmup_image_size = n_tok_per_side * patch_size * cur_merge; // TODO: support warmup size for custom token numbers } + + // sam vit deepseek-ocr + std::vector global_attn_indices() const { + switch (n_embd) { + case 768: return { 2, 5, 8, 11 }; + case 1024: return { 5, 11, 17, 23 }; + case 1280: return { 7, 15, 23, 31 }; + default: + { + fprintf(stderr, "%s: unsupported n_enc_state = %d\n", __func__, n_embd); + } break; + }; + + return {}; + } + + bool is_global_attn(int32_t layer) const { + const auto indices = global_attn_indices(); + + for (const auto & idx : indices) { + if (layer == idx) { + return true; + } + } + + return false; + } }; struct clip_layer { @@ -271,6 +298,10 @@ struct clip_layer { bool has_deepstack() const { return deepstack_fc1_w != nullptr; } + + // sam rel_pos + ggml_tensor * rel_pos_w = nullptr; + ggml_tensor * rel_pos_h = nullptr; }; struct clip_model { @@ -308,6 +339,7 @@ struct clip_model { ggml_tensor * mm_2_b = nullptr; ggml_tensor * image_newline = nullptr; + ggml_tensor * view_seperator = nullptr; // Yi type models with mlp+normalization projection ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4 @@ -400,6 +432,11 @@ struct clip_model { ggml_tensor * mm_boi = nullptr; ggml_tensor * mm_eoi = nullptr; + // deepseek ocr sam + ggml_tensor * patch_embed_proj_w = nullptr; + 
ggml_tensor * patch_embed_proj_b = nullptr; + ggml_tensor * pos_embed = nullptr; + bool audio_has_avgpool() const { return proj_type == PROJECTOR_TYPE_QWEN2A || proj_type == PROJECTOR_TYPE_VOXTRAL; @@ -409,6 +446,15 @@ struct clip_model { return proj_type == PROJECTOR_TYPE_ULTRAVOX || proj_type == PROJECTOR_TYPE_VOXTRAL; } + ggml_tensor * neck_conv_0; + ggml_tensor * neck_norm_0_w; + ggml_tensor * neck_norm_0_b; + ggml_tensor * neck_conv_1; + ggml_tensor * neck_norm_1_w; + ggml_tensor * neck_norm_1_b; + + std::vector enc_layers; + }; struct clip_ctx { @@ -521,9 +567,9 @@ struct clip_graph { hparams(model.hparams), img(img), patch_size(hparams.patch_size), - n_patches_x(img.nx / patch_size), - n_patches_y(img.ny / patch_size), - n_patches(n_patches_x * n_patches_y), + n_patches_x(img.nx / patch_size), // sam 1024 / 16 = 64 + n_patches_y(img.ny / patch_size), // sam 1024 / 16 = 64 + n_patches(n_patches_x * n_patches_y), // sam 64 * 64 = 4096 n_embd(hparams.n_embd), n_head(hparams.n_head), d_head(n_embd / n_head), @@ -619,6 +665,244 @@ struct clip_graph { return gf; } + ggml_tensor * build_sam_enc(ggml_tensor * inp_raw, + const int enc_image_size = 1024 + ) { + constexpr int enc_n_embd = 768; + constexpr int _depth = 12; + constexpr int enc_n_heads = 12; + constexpr int enc_d_heads = enc_n_embd / enc_n_heads; + constexpr int _prompt_n_embd = 256; + constexpr int enc_patch_size = 16; + constexpr int _window_size = 14; + + const int enc_n_patches = enc_image_size / enc_patch_size; // 64 + + ggml_tensor * inpL = build_enc_inp(inp_raw, enc_patch_size, enc_image_size, enc_n_embd); + ggml_tensor * cur = ggml_add(ctx0, inpL, model.position_embeddings); + + // loop over layers + for (int il = 0; il < _depth; il++) { + auto & layer = model.enc_layers[il]; + + // layernorm1 + cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il); + cb(cur, "enc_layer_inp_normed", il); + + const int64_t w0 = cur->ne[1]; + const int64_t h0 = cur->ne[2]; + + if (hparams.is_global_attn(il) == false) { + // local attention layer - apply window partition + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172 + cur = ggml_win_part(ctx0, cur, 14); + } + + const int64_t W = cur->ne[1]; + const int64_t H = cur->ne[2]; + + // self-attention + { + cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); + cur = ggml_add(ctx0, cur, layer.qkv_b); + const int B = cur->ne[3]; + + cur = ggml_reshape_4d(ctx0, cur, enc_n_embd, 3, W * H, B); + cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 3, 1, 2)); + + ggml_tensor * Qcur = + ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 0); + Qcur = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, enc_n_heads, W * H, B); + Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3)); + Qcur = ggml_reshape_3d(ctx0, Qcur, enc_d_heads, W * H, B * enc_n_heads); + + ggml_tensor * Kcur = + ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 1 * cur->nb[3]); + Kcur = ggml_reshape_4d(ctx0, Kcur, enc_d_heads, enc_n_heads, W * H, B); + Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + Kcur = ggml_reshape_3d(ctx0, Kcur, enc_d_heads, W * H, B * enc_n_heads); + + ggml_tensor * Vcur = + ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 2 * cur->nb[3]); + Vcur = ggml_reshape_4d(ctx0, Vcur, enc_d_heads, enc_n_heads, W * H, B); + Vcur = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3)); // transposed + Vcur = ggml_reshape_3d(ctx0, Vcur, W * H, enc_d_heads, B * 
enc_n_heads); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + + + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcur, Qcur); + + struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf(enc_n_heads)); + + struct ggml_tensor * rw = ggml_get_rel_pos(ctx0, layer.rel_pos_w, W, W); + struct ggml_tensor * rh = ggml_get_rel_pos(ctx0, layer.rel_pos_h, H, H); + + struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_n_heads, W, H, B * enc_n_embd); + + struct ggml_tensor * rel_w = ggml_cont( + ctx0, + ggml_permute(ctx0, ggml_mul_mat(ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))), 0, + 2, 1, 3)); + struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r); + + struct ggml_tensor * attn = ggml_add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcur, KQ_soft_max); + + cur = ggml_reshape_4d( + ctx0, + ggml_cont(ctx0, ggml_permute(ctx0, ggml_reshape_4d(ctx0, KQV, enc_d_heads, W * H, enc_n_heads, B), + 0, 2, 1, 3)), + n_embd, W, H, B); + + cur = ggml_mul_mat(ctx0, layer.o_w, cur); + cur = ggml_add_inplace(ctx0, cur, layer.o_b); + } + + if (hparams.is_global_attn(il) == false) { + // local attention layer - reverse window partition + cur = ggml_win_unpart(ctx0, cur, w0, h0, 14); + } + + if (layer.ls_1_w) { + cur = ggml_mul(ctx0, cur, layer.ls_1_w); + cb(cur, "attn_out_scaled", il); + } + + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, inpL); + + cb(cur, "ffn_inp", il); + + // layernorm2 + cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il); + cb(cur, "ffn_inp_normed", il); + + // ffn + cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, layer.ff_gate_w, layer.ff_gate_b, layer.ff_down_w, + layer.ff_down_b, hparams.ffn_op, il); + + cb(cur, "ffn_out", il); + + if (layer.ls_2_w) { + cur = ggml_mul(ctx0, cur, layer.ls_2_w); + cb(cur, "ffn_out_scaled", il); + } + + // residual 2 + cur = ggml_add(ctx0, inpL, cur); + cb(cur, "layer_out", il); + + return cur; // B, 1024, 16, 16 + } + + cur = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 2, 0, 1, 3)); + + cur = ggml_conv_2d_sk_p0(ctx0, model.neck_conv_0, cur); + + cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_norm_0_w, model.neck_norm_0_b, hparams.eps); + + cur = ggml_conv_2d_s1_ph(ctx0, model.neck_conv_1, cur); + + cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_norm_1_w, model.neck_norm_1_b, hparams.eps); + + //cur = ggml_cpy(ctx0, cur, state.embd_img); + + ggml_build_forward_expand(gf, cur); + return cur; + } + + ggml_tensor * sam_layer_norm_2d(ggml_context * ctx0, + ggml_tensor * layer, + int n_channels, + ggml_tensor * w, + ggml_tensor * b, + float eps) { + // LayerNorm2d + // normalize along channel dimmension + // TODO: better implementation + layer = ggml_permute(ctx0, ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3)), eps), 2, 0, + 1, 3); + + layer = + ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, ggml_reshape_3d(ctx0, w, 1, 1, n_channels), layer), layer), + ggml_repeat(ctx0, ggml_reshape_3d(ctx0, b, 1, 1, n_channels), layer)); + + return layer; + } + + ggml_cgraph * build_deepseek_ocr() { + //patch embedding + ggml_tensor * inp_raw = build_inp_raw(); + + + ggml_tensor * global_features_1 = build_sam_enc(inp_raw); + + ggml_tensor * global_features_2 = build_dp_ocr_clip(inp_raw, global_features_1); + + // torch global_features = torch.cat((global_features_2[:, 1:], 
global_features_1.flatten(2).permute(0, 2, 1)), dim=-1) + ggml_tensor * global_features = ggml_concat(ctx0, global_features_1, global_features_2, 0); + global_features = build_global_local_features( + ctx0, + global_features, + n_patches_y, + n_patches_x, + n_embd + ); + + return gf; + } + + // global_features: [n_dim, h*w] + // image_newline: [n_dim] + // view_separator: [n_dim] + + ggml_tensor * build_global_local_features(ggml_context * ctx0, + ggml_tensor * global_features, + int h, + int w, + int n_dim) { + GGML_ASSERT(model.image_newline != nullptr); + GGML_ASSERT(model.view_seperator != nullptr); + GGML_ASSERT(global_features->ne[0] == (int64_t) n_dim); + GGML_ASSERT(global_features->ne[1] == (int64_t) (h * w)); + + // 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim] + ggml_tensor * t = ggml_reshape_3d(ctx0, global_features, n_dim, w, h); // (n_dim, w, h) + t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (h, w, n_dim) + + // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim] + ggml_tensor * nl = ggml_reshape_3d(ctx0, model.image_newline, 1, 1, n_dim); // (1, 1, n_dim) + + ggml_tensor * nl_target_shape = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, h, n_dim); // (1, h, n_dim) + nl = ggml_repeat(ctx0, nl, nl_target_shape); // (1, h, n_dim) + nl = ggml_permute(ctx0, nl, 1, 0, 2, 3); // (h, 1, n_dim) + + // 3) concat along width dimension (dim=1): (h, w, n_dim) + (h, 1, n_dim) -> (h, w+1, n_dim) + t = ggml_concat(ctx0, t, nl, 1); // (h, w+1, n_dim) + + // 4) flatten back to token axis: (h, w+1, n_dim) -> (n_dim, h*(w+1)) + t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (n_dim, w+1, h) + t = ggml_cont_2d(ctx0, t, n_dim, (w + 1) * h); // (n_dim, h*(w+1)) + + // 5) append view_separator as an extra "token": + // view_separator: [n_dim] -> [n_dim, 1] + ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1) + + // concat along token dimension (dim=1): + ggml_tensor * global_local_features = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1) + + return global_local_features; + } + + + ggml_cgraph * build_pixtral() { const int n_merge = hparams.n_merge; @@ -1215,7 +1499,7 @@ struct clip_graph { norm_t, hparams.ffn_op, model.position_embeddings, - nullptr); + nullptr); // shape [1024, 16, 16] // remove CLS token cur = ggml_view_2d(ctx0, cur, @@ -1261,6 +1545,65 @@ struct clip_graph { return gf; } + ggml_tensor * build_dp_ocr_clip(ggml_tensor * inpL, ggml_tensor * patch_embeds) { + GGML_ASSERT(model.class_embedding != nullptr); + GGML_ASSERT(model.position_embeddings != nullptr); + auto n_embd_vit_clip = 1024; + + const int n_pos = n_patches + 1; + ggml_tensor * inp = + ggml_cont_3d(ctx0, ggml_dup_tensor(ctx0, patch_embeds), patch_embeds->ne[0], n_patches_x, n_patches_y); + //ggml_tensor * inp = ggml_cpy(ctx0, inpL, ggml_dup_tensor(ctx0, inpL)); + + // add CLS token + inp = ggml_concat(ctx0, inp, model.class_embedding, 1); + + // The larger models use a different ViT, which uses RMS norm instead of layer norm + // ref: https://github.com/ggml-org/llama.cpp/pull/13443#issuecomment-2869786188 + norm_type norm_t = (hparams.n_embd == 3200 && hparams.n_layer == 45) ? 
+ NORM_TYPE_RMS // 6B ViT (Used by InternVL 2.5/3 - 26B, 38B, 78B) + : + NORM_TYPE_NORMAL; // 300M ViT (Used by all smaller InternVL models) + + ggml_tensor * cur = build_vit(inp, n_pos, norm_t, hparams.ffn_op, model.position_embeddings, + nullptr); // shape [1024, 16, 16] + + // remove CLS token + cur = ggml_view_2d(ctx0, cur, n_embd, n_patches, ggml_row_size(cur->type, n_embd), 0); + + // pixel shuffle + { + const int scale_factor = model.hparams.n_merge; + const int bsz = 1; // batch size, always 1 for now since we don't support batching + const int height = n_patches_y; + const int width = n_patches_x; + GGML_ASSERT(scale_factor > 0); + cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_cont_4d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, + width / scale_factor, bsz); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + // flatten to 2D + cur = ggml_cont_2d(ctx0, cur, n_embd * scale_factor * scale_factor, cur->ne[1] * cur->ne[2]); + } + + // projector (always using GELU activation) + { + // projector LayerNorm uses pytorch's default eps = 1e-5 + // ref: https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct/blob/a34d3e4e129a5856abfd6aa6de79776484caa14e/modeling_internvl_chat.py#L79 + cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1); + cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); + cur = ggml_add(ctx0, cur, model.mm_1_b); + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.mm_3_w, cur); + cur = ggml_add(ctx0, cur, model.mm_3_b); + } + + // build the graph + + return cur; + } + ggml_cgraph * build_llama4() { GGML_ASSERT(model.class_embedding != nullptr); GGML_ASSERT(model.position_embeddings != nullptr); @@ -2164,18 +2507,41 @@ struct clip_graph { return inpL; } + // build the input after conv2d (inp_raw --> patches) + // returns tensor with shape [n_embd, n_patches] + ggml_tensor * build_enc_inp(ggml_tensor * inp_raw, + const int enc_patch_size, + const int enc_n_patches, + const int enc_n_embd) { + GGML_ASSERT(model.patch_embed_proj_w != nullptr); + GGML_ASSERT(model.patch_embed_proj_b != nullptr); + // Image to Patch Embedding. + // ggml_tensor * inp_raw = build_inp_raw(); // sam shape = [1024, 1024, 3] + // patch_embed_proj_w shape = [768, 3, 16, 16] + ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embed_proj_w, inp_raw, enc_patch_size, enc_patch_size, 0, 0, + 1, 1); // [64, 64, 768] + inp = ggml_reshape_2d(ctx0, inp, enc_n_patches, enc_n_embd); // [4096, 768] + inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); // [768, 4096] + inp = ggml_add(ctx0, inp, model.patch_embed_proj_b); + cb(inp, "enc_patch_bias", -1); + return inp; + } + // build the input after conv2d (inp_raw --> patches) // returns tensor with shape [n_embd, n_patches] ggml_tensor * build_inp() { - ggml_tensor * inp_raw = build_inp_raw(); - ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); - inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); + // Image to Patch Embedding. 
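+        // conv2d with stride == patch_size acts as a non-overlapping patch
+        // embedding: each patch_size x patch_size tile becomes one n_embd vector,
+        // so the output is an (image_size / patch_size)^2 token grid
+        // (64x64 for the 1024 px SAM input, per the shape comments below)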
+ ggml_tensor * inp_raw = build_inp_raw(); // sam shape = [1024, 1024, 3] + // sam patch_embeddings_0 shape = [768, 3, 16, 16] + ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); // sam shape = [64, 64, 768] + inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); // sam shape = [4096, 768] + inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); // sam shape = [768, 4096] if (model.patch_bias) { + // sam patch_bias shape = [768] inp = ggml_add(ctx0, inp, model.patch_bias); cb(inp, "patch_bias", -1); } - return inp; + return inp; // shape = [n_embd, n_patches] same as [768, 4096] } ggml_tensor * build_inp_raw(int channels = 3) { @@ -3236,6 +3602,10 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias")); } break; + case PROJECTOR_TYPE_DEEPSEEK_OCR: + { + } + break; default: GGML_ASSERT(false && "unknown projector type"); } @@ -4192,6 +4562,59 @@ struct llava_uhd { } }; +static std::vector> ds_build_target_ratios(const int min_num, const int max_num) { + std::vector> ratios; + for (int n = min_num; n <= max_num; ++n) { + for (int i = 1; i <= n; ++i) { + for (int j = 1; j <= n; ++j) { + if (const int blocks = i * j; blocks >= min_num && blocks <= max_num) { + ratios.emplace_back(i, j); // (cols, rows) + } + } + } + } + + // sort by total blocks like in Python (key=lambda x: x[0] * x[1]) + std::sort(ratios.begin(), ratios.end(), + [](const auto &a, const auto &b) { + return (a.first * a.second) < (b.first * b.second); + }); + + // optional: dedup + ratios.erase(std::unique(ratios.begin(), ratios.end()), ratios.end()); + return ratios; +} + +static std::pair ds_find_closest_aspect_ratio( + const float aspect_ratio, + const std::vector> &target_ratios, + const int width, + const int height, + const int image_size +) { + float best_diff = std::numeric_limits::infinity(); + std::pair best_ratio = {1, 1}; + const float area = static_cast(width) * static_cast(height); + + for (const auto &r : target_ratios) { + const float target_ar = static_cast(r.first) / static_cast(r.second); + + if (const float diff = std::fabs(aspect_ratio - target_ar); diff < best_diff) { + best_diff = diff; + best_ratio = r; + } else if (diff == best_diff) { + // same as python: prefer this ratio if the image area is “large enough” + if (const float needed_area = 0.5f * image_size * image_size * r.first * r.second; area > needed_area) { + best_ratio = r; + } + } + } + + return best_ratio; // (cols, rows) +} + + + // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector // res_imgs memory is being allocated here, previous allocations will be freed if found bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) { @@ -4406,6 +4829,69 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str } } } break; + case PROJECTOR_TYPE_DEEPSEEK_OCR: + { + // configurable, or read from params + const int min_num = 2; + const int max_num = 9; + const int image_size = params.image_size; // typically 640 + const bool use_thumbnail = true; // mimic python's use_thumbnail + + // original image size + const int orig_w = original_size.width; + const int orig_h = original_size.height; + + // 1) build candidate grids (cols, rows) + auto target_ratios = ds_build_target_ratios(min_num, 
max_num); + + // 2) pick the grid that best matches the original aspect ratio + const float aspect_ratio = static_cast(orig_w) / static_cast(orig_h); + auto best = ds_find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_w, orig_h, image_size); + const int grid_cols = best.first; // how many tiles horizontally + const int grid_rows = best.second; // how many tiles vertically + + // 3) compute the target (forced) size — python did: + // target_width = image_size * cols + // target_height = image_size * rows + const clip_image_size refined_size{ image_size * grid_cols, image_size * grid_rows }; + + // 4) prepare slice instructions, same style as the idefics3 branch + llava_uhd::slice_instructions instructions; + instructions.overview_size = clip_image_size{ image_size, image_size }; // for thumbnail/global + instructions.refined_size = refined_size; + instructions.grid_size = clip_image_size{ grid_cols, grid_rows }; + + // in deepseek python they always produce *full* 640x640 blocks, + // so we can do a simple double loop over rows/cols: + for (int r = 0; r < grid_rows; ++r) { + for (int c = 0; c < grid_cols; ++c) { + const int x = c * image_size; + const int y = r * image_size; + + instructions.slices.push_back(llava_uhd::slice_coordinates{ + /* x */ x, + /* y */ y, + /* size */ clip_image_size{ image_size, image_size } + }); + } + } + + // 5) run the actual slicing (this should: resize to refined_size, then crop every slice) + auto imgs = llava_uhd::slice_image(img, instructions); + + // 7) cast & normalize like the idefics3 branch + for (size_t i = 0; i < imgs.size(); ++i) { + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + } + + // keep the grid info — the model may need to know how to reassemble / attend + res_imgs->grid_x = grid_cols; + res_imgs->grid_y = grid_rows; + } + break; + default: LOG_ERR("%s: unsupported projector type %d\n", __func__, ctx->proj_type()); From b6b9f02c8a2c8cb57ac07a5bb353c10f93220089 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Fri, 14 Nov 2025 20:51:48 +0100 Subject: [PATCH 02/37] loading sam tensors --- tools/mtmd/clip-impl.h | 17 +++++---- tools/mtmd/clip.cpp | 81 ++++++++++++++++++++++++++++-------------- 2 files changed, 63 insertions(+), 35 deletions(-) diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 8d1c7d0dff1..88535df55f6 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -131,18 +131,17 @@ // deepseek-ocr #define TN_SAM_POS_EMBD "sam.pos_embd" -#define TN_SAM_PATCH_EMBD "sam.patch_embd" -#define TN_SAM_PRE_NORM "sam.blk.%d.pre_ln" +#define TN_SAM_PATCH_EMBD "sam.patch_embd.%s" +#define TN_SAM_PRE_NORM "sam.blk.%d.pre_ln.%s" #define TN_SAM_POST_NORM "sam.blk.%d.post_ln" #define TN_SAM_ATTN_POS_H "sam.blk.%d.attn.pos_h" #define TN_SAM_ATTN_POS_W "sam.blk.%d.attn.pos_w" -#define TN_SAM_ATTN_QKV "sam.blk.%d.attn.qkv" -#define TN_SAM_ATTN_OUT "sam.blk.%d.attn.out" -#define TN_SAM_MLP_LIN_1 "sam.blk.%d.mlp.lin1" -#define TN_SAM_MLP_LIN_2 "sam.blk.%d.mlp.lin2" -#define TN_SAM_NECK "sam.neck.%d" -#define TN_SAM_NET_2 "sam.net_2" -#define TN_SAM_NET_3 "sam.net_3" +#define TN_SAM_ATTN_QKV "sam.blk.%d.attn.qkv.%s" +#define TN_SAM_ATTN_OUT "sam.blk.%d.attn.out.%s" +#define TN_SAM_FFN_UP "sam.blk.%d.mlp.lin1.%s" +#define TN_SAM_FFN_DOWN "sam.blk.%d.mlp.lin2.%s" +#define TN_SAM_NECK "sam.neck.%d.%s" +#define TN_SAM_NET "sam.net_%d.%s" #define 
TN_SAM_ATTN_OUT "sam.blk.%d.attn_out" diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 0961b96fd6f..039644b6884 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -446,14 +446,18 @@ struct clip_model { return proj_type == PROJECTOR_TYPE_ULTRAVOX || proj_type == PROJECTOR_TYPE_VOXTRAL; } - ggml_tensor * neck_conv_0; - ggml_tensor * neck_norm_0_w; - ggml_tensor * neck_norm_0_b; - ggml_tensor * neck_conv_1; - ggml_tensor * neck_norm_1_w; - ggml_tensor * neck_norm_1_b; + ggml_tensor * neck_0_w; + ggml_tensor * neck_1_w; + ggml_tensor * neck_1_b; + ggml_tensor * neck_2_w; + ggml_tensor * neck_3_w; + ggml_tensor * neck_3_b; + ggml_tensor * net_2; + ggml_tensor * net_3; - std::vector enc_layers; + int32_t n_sam_layers = 0; // used by deepseek-ocr sam encoder + + std::vector sam_layers; }; @@ -683,7 +687,7 @@ struct clip_graph { // loop over layers for (int il = 0; il < _depth; il++) { - auto & layer = model.enc_layers[il]; + auto & layer = model.sam_layers[il]; // layernorm1 cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il); @@ -770,33 +774,27 @@ struct clip_graph { cur = ggml_win_unpart(ctx0, cur, w0, h0, 14); } - if (layer.ls_1_w) { - cur = ggml_mul(ctx0, cur, layer.ls_1_w); - cb(cur, "attn_out_scaled", il); - } - // re-add the layer input, e.g., residual cur = ggml_add(ctx0, cur, inpL); - cb(cur, "ffn_inp", il); + ggml_tensor * inpFF = cur; + + + cb(inpFF, "ffn_inp", il); // layernorm2 - cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il); + cur = build_norm(inpFF, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il); cb(cur, "ffn_inp_normed", il); // ffn - cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, layer.ff_gate_w, layer.ff_gate_b, layer.ff_down_w, + cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, nullptr, nullptr, layer.ff_down_w, layer.ff_down_b, hparams.ffn_op, il); cb(cur, "ffn_out", il); - if (layer.ls_2_w) { - cur = ggml_mul(ctx0, cur, layer.ls_2_w); - cb(cur, "ffn_out_scaled", il); - } // residual 2 - cur = ggml_add(ctx0, inpL, cur); + cur = ggml_add(ctx0, cur, inpFF); cb(cur, "layer_out", il); return cur; // B, 1024, 16, 16 @@ -804,15 +802,17 @@ struct clip_graph { cur = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 2, 0, 1, 3)); - cur = ggml_conv_2d_sk_p0(ctx0, model.neck_conv_0, cur); + cur = ggml_conv_2d_sk_p0(ctx0, model.neck_0_w, cur); - cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_norm_0_w, model.neck_norm_0_b, hparams.eps); + cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_1_w, model.neck_1_b, hparams.eps); - cur = ggml_conv_2d_s1_ph(ctx0, model.neck_conv_1, cur); + cur = ggml_conv_2d_s1_ph(ctx0, model.neck_2_w, cur); - cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_norm_1_w, model.neck_norm_1_b, hparams.eps); + cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_3_w, model.neck_3_b, hparams.eps); - //cur = ggml_cpy(ctx0, cur, state.embd_img); + //TODO : check conv padding + cur = ggml_conv_2d_s1_ph(ctx0, model.net_2, cur); + cur = ggml_conv_2d_s1_ph(ctx0, model.net_3, cur); ggml_build_forward_expand(gf, cur); return cur; @@ -3604,6 +3604,35 @@ struct clip_model_loader { } break; case PROJECTOR_TYPE_DEEPSEEK_OCR: { + model.pos_embed = get_tensor(TN_SAM_POS_EMBD); + model.patch_embed_proj_w = get_tensor(string_format(TN_SAM_PATCH_EMBD, "weight")); + model.patch_embed_proj_b = get_tensor(string_format(TN_SAM_PATCH_EMBD, "bias")); + model.sam_layers.resize(model.n_sam_layers); + for (int il = 0; il < model.n_sam_layers; ++il) { + auto & layer = model.sam_layers[il]; + layer.qkv_w = 
get_tensor(string_format(TN_SAM_ATTN_QKV, il, "weight")); + layer.qkv_b = get_tensor(string_format(TN_SAM_ATTN_QKV, il, "bias")); + layer.o_w = get_tensor(string_format(TN_SAM_ATTN_OUT, il, "weight")); + layer.o_b = get_tensor(string_format(TN_SAM_ATTN_OUT, il, "bias")); + layer.ln_1_w = get_tensor(string_format(TN_SAM_PRE_NORM, il, "weight")); + layer.ln_1_b = get_tensor(string_format(TN_SAM_PRE_NORM, il, "bias")); + layer.ln_2_w = get_tensor(string_format(TN_SAM_POST_NORM, il, "weight")); + layer.ln_2_b = get_tensor(string_format(TN_SAM_POST_NORM, il, "bias")); + layer.rel_pos_h = get_tensor(string_format(TN_SAM_ATTN_POS_H, il)); + layer.rel_pos_w = get_tensor(string_format(TN_SAM_ATTN_POS_W, il)); + layer.ff_up_w = get_tensor(string_format(TN_SAM_FFN_UP, il, "weight")); + layer.ff_up_b = get_tensor(string_format(TN_SAM_FFN_UP, il, "bias")); + layer.ff_down_w = get_tensor(string_format(TN_SAM_FFN_DOWN, il, "weight")); + layer.ff_down_b = get_tensor(string_format(TN_SAM_FFN_DOWN, il, "bias")); + } + model.neck_0_w = get_tensor(string_format(TN_SAM_NECK, 0, "weight")); + model.neck_1_b = get_tensor(string_format(TN_SAM_NECK, 1, "bias")); + model.neck_1_w = get_tensor(string_format(TN_SAM_NECK, 1, "weight")); + model.neck_2_w = get_tensor(string_format(TN_SAM_NECK, 2, "weight")); + model.neck_3_b = get_tensor(string_format(TN_SAM_NECK, 3, "bias")); + model.neck_3_w = get_tensor(string_format(TN_SAM_NECK, 3, "weight")); + model.net_2 = get_tensor(string_format(TN_SAM_NET, 2, "weight")); + model.net_3 = get_tensor(string_format(TN_SAM_NET, 3, "weight")); } break; default: From 85c7cda8eb10fcbf3f6c941fb73d81df6bc14b0a Mon Sep 17 00:00:00 2001 From: bluebread Date: Sat, 15 Nov 2025 04:20:01 +0000 Subject: [PATCH 03/37] mtmd: fix vision model processing --- convert_hf_to_gguf.py | 63 +++++++++++++++++++++------------- gguf-py/gguf/constants.py | 50 +++++++++++++++++++-------- gguf-py/gguf/gguf_writer.py | 6 ++++ gguf-py/gguf/tensor_mapping.py | 39 ++++++++++++++++----- 4 files changed, 111 insertions(+), 47 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 82a6c95bdd0..77fc77e8234 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1445,7 +1445,7 @@ class MmprojModel(ModelBase): preprocessor_config: dict[str, Any] global_config: dict[str, Any] - n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "width.clip-l-14-224.layers", "sam_vit_b.layers"] + n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "layers"] has_vision_encoder: bool = True # by default has_audio_encoder: bool = False @@ -1494,8 +1494,8 @@ def __init__(self, *args, **kwargs): # FIXME: DeepseekOCRVisionModel specific hack if self.block_count is None: if isinstance(self, DeepseekOCRVisionModel): - clip_block_count = self.hparams['width']['clip-l-14-224']['layers'] - sam_block_count = self.hparams['width']['sam_vit_b']['layers'] + print(self.hparams) + clip_block_count = self.hparams['layers'] if clip_block_count is not None: self.block_count = clip_block_count if sam_block_count is not None: @@ -5793,6 +5793,16 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter @ModelBase.register("DeepseekOCRForCausalLM") class DeepseekOCRVisionModel(MmprojModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + proc_fname = self.dir_model / "processor_config.json" + + if proc_fname.is_file(): + with open(proc_fname, "r") as f: + self.preprocessor_config = json.load(f) + + def 
set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams @@ -5811,10 +5821,25 @@ def set_gguf_parameters(self): # in this case, we are converting a test model self.gguf_writer.add_vision_projector_scale_factor(proj_scale_factor) + # SAM configuration + sam_hparams = hparams['sam'] + self.gguf_writer.add_vision_sam_layers_count(sam_hparams['layers']) + self.gguf_writer.add_vision_sam_embedding_length(sam_hparams['width']) + def get_vision_config(self) -> dict[str, Any]: - orig_vision_config = self.global_config.get("vision_config") + vision_config: dict[str, Any] | None = self.global_config.get("vision_config") + + if not vision_config: + raise ValueError("DeepseekOCR model requires 'vision_config' in the model configuration, but it was not found") + + vision_config['sam'] = vision_config['width']['sam_vit_b'] + vision_config.update(vision_config['width']['clip-l-14-224']) + vision_config['hidden_size'] = vision_config['width'] + vision_config['num_heads'] = vision_config['heads'] + vision_config['intermediate_size'] = vision_config['heads'] * 4 + + return vision_config - super().get_vision_config() def tensor_force_quant(self, name, new_name, bid, n_dims): # related to https://github.com/ggml-org/llama.cpp/issues/13025 @@ -5825,27 +5850,17 @@ def tensor_force_quant(self, name, new_name, bid, n_dims): return super().tensor_force_quant(name, new_name, bid, n_dims) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - - if "vision_model.head." in name: - return [] # skip redundant tensors for tinygemma3 - - if name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \ - or name.startswith("multimodal_projector.") or name.startswith("vision_model."): - # process vision tensors - name = name.replace("_weight", ".weight") - - # correct norm value ; only this "soft_emb_norm" need to be corrected as it's part of Gemma projector - # the other norm values are part of SigLIP model, and they are already correct - # ref code: Gemma3RMSNorm - if "soft_emb_norm.weight" in name: - logger.info(f"Correcting norm value for '{name}'") - data_torch = data_torch + 1 - - return [(self.map_tensor_name(name), data_torch)] + # Only process vision-related tensors, skip language model tensors + # Vision components: sam_model, vision_model, projector, image_newline, view_seperator + # Language model components to skip: lm_head, embed_tokens, layers, norm + if name.startswith(("lm_head.", "model.embed_tokens.", "model.layers.", "model.norm.")): + return [] - return [] # skip other tensors + if ".attn.rel_pos_h" in name or ".attn.rel_pos_w" in name: + return [(self.map_tensor_name(name, try_suffixes=("",)), data_torch)] + return [(self.map_tensor_name(name), data_torch)] + @ModelBase.register("Gemma3nForConditionalGeneration") class Gemma3NModel(Gemma3Model): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index dfd947083a2..d3f51645ea9 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -287,6 +287,10 @@ class Attention: class Projector: SCALE_FACTOR = "clip.vision.projector.scale_factor" + class SAM: + BLOCK_COUNT = "clip.vision.sam.block_count" + EMBEDDING_LENGTH = "clip.vision.sam.embedding_length" + class ClipAudio: NUM_MEL_BINS = "clip.audio.num_mel_bins" EMBEDDING_LENGTH = "clip.audio.embedding_length" @@ -664,20 +668,21 @@ class MODEL_TENSOR(IntEnum): V_MM_GATE = auto() # cogvlm V_TOK_BOI = auto() # cogvlm V_TOK_EOI = auto() # cogvlm - # DeepSeek-OCR 
sam_model - V_SAM_POS_EMBD = auto() - V_SAM_PATCH_EMBD = auto() - V_SAM_PRE_NORM = auto() - V_SAM_POST_NORM = auto() - V_SAM_ATTN_POS_H = auto() - V_SAM_ATTN_POS_W = auto() - V_SAM_ATTN_QKV = auto() - V_SAM_ATTN_OUT = auto() - V_SAM_MLP_LIN_1 = auto() - V_SAM_MLP_LIN_2 = auto() - V_SAM_NECK = auto() - V_SAM_NET_2 = auto() - V_SAM_NET_3 = auto() + V_SAM_POS_EMBD = auto() # Deepseek-OCR + V_SAM_PATCH_EMBD = auto() # Deepseek-OCR + V_SAM_PRE_NORM = auto() # Deepseek-OCR + V_SAM_POST_NORM = auto() # Deepseek-OCR + V_SAM_ATTN_POS_H = auto() # Deepseek-OCR + V_SAM_ATTN_POS_W = auto() # Deepseek-OCR + V_SAM_ATTN_QKV = auto() # Deepseek-OCR + V_SAM_ATTN_OUT = auto() # Deepseek-OCR + V_SAM_MLP_LIN_1 = auto() # Deepseek-OCR + V_SAM_MLP_LIN_2 = auto() # Deepseek-OCR + V_SAM_NECK = auto() # Deepseek-OCR + V_SAM_NET_2 = auto() # Deepseek-OCR + V_SAM_NET_3 = auto() # Deepseek-OCR + V_ENC_EMBD_IMGNL = auto() # Deepseek-OCR + V_ENC_EMBD_VSEP = auto() # Deepseek-OCR # audio (mtmd) A_ENC_EMBD_POS = auto() @@ -1059,6 +1064,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_SAM_NECK: "v.sam.neck.{bid}", MODEL_TENSOR.V_SAM_NET_2: "v.sam.net_2", MODEL_TENSOR.V_SAM_NET_3: "v.sam.net_3", + MODEL_TENSOR.V_ENC_EMBD_IMGNL: "v.image_newline_embd", # Deepseek-OCR + MODEL_TENSOR.V_ENC_EMBD_VSEP: "v.view_separator_embd", # Deepseek-OCR # audio (mtmd) MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd", MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}", @@ -1095,6 +1102,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_ENC_EMBD_CLS, MODEL_TENSOR.V_ENC_EMBD_PATCH, MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_EMBD_IMGNL, + MODEL_TENSOR.V_ENC_EMBD_VSEP, MODEL_TENSOR.V_ENC_INPUT_NORM, MODEL_TENSOR.V_ENC_ATTN_QKV, MODEL_TENSOR.V_ENC_ATTN_Q, @@ -1137,6 +1146,19 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_MM_GATE, MODEL_TENSOR.V_TOK_BOI, MODEL_TENSOR.V_TOK_EOI, + MODEL_TENSOR.V_SAM_POS_EMBD, + MODEL_TENSOR.V_SAM_PATCH_EMBD, + MODEL_TENSOR.V_SAM_PRE_NORM, + MODEL_TENSOR.V_SAM_POST_NORM, + MODEL_TENSOR.V_SAM_ATTN_POS_H, + MODEL_TENSOR.V_SAM_ATTN_POS_W, + MODEL_TENSOR.V_SAM_ATTN_QKV, + MODEL_TENSOR.V_SAM_ATTN_OUT, + MODEL_TENSOR.V_SAM_MLP_LIN_1, + MODEL_TENSOR.V_SAM_MLP_LIN_2, + MODEL_TENSOR.V_SAM_NECK, + MODEL_TENSOR.V_SAM_NET_2, + MODEL_TENSOR.V_SAM_NET_3, # audio MODEL_TENSOR.A_ENC_EMBD_POS, MODEL_TENSOR.A_ENC_CONV1D, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index a051daeeb13..fca498a859d 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1077,6 +1077,12 @@ def add_vision_n_wa_pattern(self, value: int) -> None: def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None: self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers) + + def add_vision_sam_layers_count(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.SAM.BLOCK_COUNT, value) + + def add_vision_sam_embedding_length(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.SAM.EMBEDDING_LENGTH, value) # audio models def add_audio_projection_dim(self, value: int) -> None: diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index f15ea0a02a0..2cf8110d293 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1179,6 +1179,7 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ_FC: ( "model.connector.modality_projection.proj", # SmolVLM "model.vision.linear_proj.linear_proj", # cogvlm + "model.projector.layers", # Deepseek-OCR ), MODEL_TENSOR.V_MMPROJ_MLP: ( @@ -1197,6 +1198,7 @@ class TensorNameMap: "model.vision_tower.embeddings.cls_token", # 
Intern-S1 "vision_model.class_embedding", # llama 4 "model.vision.patch_embedding.cls_embedding", # cogvlm + "model.vision_model.embeddings.class_embedding", # Deepseek-OCR ), MODEL_TENSOR.V_ENC_EMBD_PATCH: ( @@ -1210,6 +1212,7 @@ class TensorNameMap: "visual.patch_embed.proj", # qwen2vl "vision_tower.patch_embed.proj", # kimi-vl "model.vision.patch_embedding.proj", # cogvlm + "model.vision_model.embeddings.patch_embedding", # Deepseek-OCR CLIP ), MODEL_TENSOR.V_ENC_EMBD_POS: ( @@ -1222,10 +1225,19 @@ class TensorNameMap: "visual.pos_embed", # qwen3vl "model.vision.patch_embedding.position_embedding", # cogvlm ), + + MODEL_TENSOR.V_ENC_EMBD_IMGNL: ( + "model.image_newline", # Deepseek-OCR + ), + + MODEL_TENSOR.V_ENC_EMBD_VSEP: ( + "model.view_seperator", # Deepseek-OCR + ), MODEL_TENSOR.V_ENC_ATTN_QKV: ( "visual.blocks.{bid}.attn.qkv", # qwen3vl "model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm + "model.vision_model.transformer.layers.{bid}.self_attn.qkv_proj", # Deepseek-OCR CLIP ), MODEL_TENSOR.V_ENC_ATTN_Q: ( @@ -1238,6 +1250,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral "visual.blocks.{bid}.attn.q", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated + "model.vision_model.transformer.layers.{bid}.self_attn.q_proj", # Deepseek-OCR CLIP, generated ), MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( @@ -1255,6 +1268,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral "visual.blocks.{bid}.attn.k", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated + "model.vision_model.transformer.layers.{bid}.self_attn.k_proj", # Deepseek-OCR CLIP, generated ), MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( @@ -1272,6 +1286,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral "visual.blocks.{bid}.attn.v", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated + "model.vision_model.transformer.layers.{bid}.self_attn.v_proj", # Deepseek-OCR CLIP, generated ), MODEL_TENSOR.V_ENC_INPUT_NORM: ( @@ -1286,6 +1301,7 @@ class TensorNameMap: "visual.blocks.{bid}.norm1", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm + "model.vision_model.transformer.layers.{bid}.layer_norm1", # Deepseek-OCR CLIP ), MODEL_TENSOR.V_ENC_ATTN_O: ( @@ -1301,6 +1317,7 @@ class TensorNameMap: "visual.blocks.{bid}.attn.proj", # qwen2vl "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl "model.vision.transformer.layers.{bid}.attention.dense", # cogvlm + "model.vision_model.transformer.layers.{bid}.self_attn.out_proj", # Deepseek-OCR CLIP ), MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( @@ -1315,6 +1332,7 @@ class TensorNameMap: "visual.blocks.{bid}.norm2", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm + "model.vision_model.transformer.layers.{bid}.layer_norm2", # Deepseek-OCR CLIP ), MODEL_TENSOR.V_ENC_FFN_UP: ( @@ -1329,6 +1347,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl "visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1) + "model.vision_model.transformer.layers.{bid}.mlp.fc1", # Deepseek-OCR CLIP "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm ), @@ -1351,6 +1370,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl 
"vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm + "model.vision_model.transformer.layers.{bid}.mlp.fc2", # Deepseek-OCR CLIP ), MODEL_TENSOR.V_LAYER_SCALE_1: ( @@ -1368,6 +1388,7 @@ class TensorNameMap: "vision_tower.ln_pre", # pixtral-hf "vision_encoder.ln_pre", # pixtral "vision_model.layernorm_pre", # llama4 + "model.vision_model.pre_layrnorm", # Deepseek-OCR CLIP ), MODEL_TENSOR.V_POST_NORM: ( @@ -1460,11 +1481,11 @@ class TensorNameMap: ), MODEL_TENSOR.V_SAM_POS_EMBD: ( - "model.sam_model.pos_embed" + "model.sam_model.pos_embed", ), MODEL_TENSOR.V_SAM_PATCH_EMBD: ( - "model.sam_model.patch_embed.proj" + "model.sam_model.patch_embed.proj", ), MODEL_TENSOR.V_SAM_PRE_NORM: ( @@ -1476,19 +1497,19 @@ class TensorNameMap: ), MODEL_TENSOR.V_SAM_ATTN_POS_H: ( - "model.sam_model.blocks.{bid}.attn.rel_pos_h" + "model.sam_model.blocks.{bid}.attn.rel_pos_h", ), MODEL_TENSOR.V_SAM_ATTN_POS_W: ( - "model.sam_model.blocks.{bid}.attn.rel_pos_w" + "model.sam_model.blocks.{bid}.attn.rel_pos_w", ), MODEL_TENSOR.V_SAM_ATTN_QKV: ( - "model.sam_model.blocks.{bid}.attn.qkv" + "model.sam_model.blocks.{bid}.attn.qkv", ), MODEL_TENSOR.V_SAM_ATTN_OUT: ( - "model.sam_model.blocks.{bid}.attn.proj" + "model.sam_model.blocks.{bid}.attn.proj", ), MODEL_TENSOR.V_SAM_MLP_LIN_1: ( @@ -1500,15 +1521,15 @@ class TensorNameMap: ), MODEL_TENSOR.V_SAM_NECK: ( - "model.sam_model.neck.{bid}" + "model.sam_model.neck.{bid}", ), MODEL_TENSOR.V_SAM_NET_2: ( - "model.sam_model.net_2" + "model.sam_model.net_2", ), MODEL_TENSOR.V_SAM_NET_3: ( - "model.sam_model.net_3" + "model.sam_model.net_3", ), MODEL_TENSOR.V_MM_POST_FC_NORM: ( From 2aab52e2c43a886e4b231c13e1ad6d27b0ae7fc0 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Sat, 15 Nov 2025 15:30:07 +0100 Subject: [PATCH 04/37] deepseek-ocr clip-vit model impl --- tools/mtmd/clip-impl.h | 5 +---- tools/mtmd/clip.cpp | 48 +++++++++--------------------------------- 2 files changed, 11 insertions(+), 42 deletions(-) diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 88535df55f6..4cb2808c262 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -141,10 +141,7 @@ #define TN_SAM_FFN_UP "sam.blk.%d.mlp.lin1.%s" #define TN_SAM_FFN_DOWN "sam.blk.%d.mlp.lin2.%s" #define TN_SAM_NECK "sam.neck.%d.%s" -#define TN_SAM_NET "sam.net_%d.%s" - - -#define TN_SAM_ATTN_OUT "sam.blk.%d.attn_out" +#define TN_SAM_NET "sam.net_%d.%s" // align x to upper multiple of n #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 039644b6884..d94d05b2f28 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1558,48 +1558,20 @@ struct clip_graph { // add CLS token inp = ggml_concat(ctx0, inp, model.class_embedding, 1); - // The larger models use a different ViT, which uses RMS norm instead of layer norm - // ref: https://github.com/ggml-org/llama.cpp/pull/13443#issuecomment-2869786188 - norm_type norm_t = (hparams.n_embd == 3200 && hparams.n_layer == 45) ? 
- NORM_TYPE_RMS // 6B ViT (Used by InternVL 2.5/3 - 26B, 38B, 78B) - : - NORM_TYPE_NORMAL; // 300M ViT (Used by all smaller InternVL models) + //TODO : check norm type for dp-ocr-clip + norm_type norm_t = NORM_TYPE_NORMAL; - ggml_tensor * cur = build_vit(inp, n_pos, norm_t, hparams.ffn_op, model.position_embeddings, - nullptr); // shape [1024, 16, 16] + // for selecting learned pos embd, used by ViT + struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, model.position_embeddings, positions); - // remove CLS token - cur = ggml_view_2d(ctx0, cur, n_embd, n_patches, ggml_row_size(cur->type, n_embd), 0); - // pixel shuffle - { - const int scale_factor = model.hparams.n_merge; - const int bsz = 1; // batch size, always 1 for now since we don't support batching - const int height = n_patches_y; - const int width = n_patches_x; - GGML_ASSERT(scale_factor > 0); - cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz); - cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_cont_4d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, - width / scale_factor, bsz); - cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - // flatten to 2D - cur = ggml_cont_2d(ctx0, cur, n_embd * scale_factor * scale_factor, cur->ne[1] * cur->ne[2]); - } - - // projector (always using GELU activation) - { - // projector LayerNorm uses pytorch's default eps = 1e-5 - // ref: https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct/blob/a34d3e4e129a5856abfd6aa6de79776484caa14e/modeling_internvl_chat.py#L79 - cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1); - cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); - cur = ggml_add(ctx0, cur, model.mm_1_b); - cur = ggml_gelu(ctx0, cur); - cur = ggml_mul_mat(ctx0, model.mm_3_w, cur); - cur = ggml_add(ctx0, cur, model.mm_3_b); - } + ggml_tensor * cur = build_vit(inp, n_pos, norm_t, hparams.ffn_op, learned_pos_embd, + nullptr); // shape [1024, 16, 16] - // build the graph + ggml_build_forward_expand(gf, cur); return cur; } From eab28ed318bc16cd8f967422758fd8ed7c7d50ae Mon Sep 17 00:00:00 2001 From: bluebread Date: Sat, 15 Nov 2025 17:28:18 +0000 Subject: [PATCH 05/37] mtmd: add DeepSeek-OCR LM support with standard attention --- convert_hf_to_gguf.py | 14 ++++++++------ gguf-py/gguf/gguf_writer.py | 2 +- src/llama-arch.cpp | 2 ++ src/llama-model.cpp | 30 ++++++++++++++++++++++++++++++ src/models/deepseek2.cpp | 29 ++++++++++++++++++++++++++++- 5 files changed, 69 insertions(+), 8 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 77fc77e8234..c8a48c01bfb 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1494,12 +1494,9 @@ def __init__(self, *args, **kwargs): # FIXME: DeepseekOCRVisionModel specific hack if self.block_count is None: if isinstance(self, DeepseekOCRVisionModel): - print(self.hparams) clip_block_count = self.hparams['layers'] if clip_block_count is not None: self.block_count = clip_block_count - if sam_block_count is not None: - self.block_count = sam_block_count if self.block_count is None else self.block_count + sam_block_count if self.block_count is None: raise KeyError(f"could not find block count using any of: {self.n_block_keys}") self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count) @@ -7095,10 +7092,15 @@ def set_vocab(self): raise NotImplementedError(f"Deepseek 
pre-tokenizer {tokpre!r} is not supported yet!") def set_gguf_parameters(self): + is_ocr = (self.hparams["num_hidden_layers"] == 12) - # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group) - self.hparams["num_key_value_heads"] = 1 - + if is_ocr: + self.hparams['rope_theta'] = self.hparams.get('rope_theta', 10000.0) + self.hparams['rms_norm_eps'] = self.hparams.get('rms_norm_eps', 1e-6) + else: + # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group) + self.hparams["num_key_value_heads"] = 1 + super().set_gguf_parameters() hparams = self.hparams kv_lora_rank = hparams["q_lora_rank"] if hparams["q_lora_rank"] is not None else 512 diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index fca498a859d..34ecb5e396f 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -813,7 +813,7 @@ def add_layer_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) def add_layer_norm_rms_eps(self, value: float) -> None: - self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value) + self.add_float64(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value) def add_group_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.GROUPNORM_EPS.format(arch=self.arch), value) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index b7642b568df..ac3ab5cfa77 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1446,6 +1446,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 829f1e3c14f..a21a3ce619a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4550,6 +4550,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_DEEPSEEK2: { const bool is_lite = (hparams.n_layer == 27); + const bool is_ocr = (name.find("ocr") != std::string::npos || name.find("OCR") != std::string::npos); const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); @@ -4575,6 +4576,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; + if (is_ocr) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); // TODO + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); // TODO + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + else { + layer.ffn_gate_inp = 
create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + + continue; + } + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); if (!is_lite) { layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index 68f72f72bb6..e649286cecf 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -5,6 +5,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { bool is_lite = (hparams.n_layer == 27); + bool is_ocr = (model.name.find("ocr") != std::string::npos || model.name.find("OCR") != std::string::npos); const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); @@ -44,7 +45,33 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr cb(cur, "attn_norm", il); // self_attention - { + if (is_ocr) { + ggml_tensor * Qcur = NULL; + ggml_tensor * Kcur = NULL; + ggml_tensor * Vcur = NULL; + + Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Qcur, "q", il); + cb(Kcur, "k", il); + cb(Vcur, "v", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head, n_tokens); + + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "q_pe", il); + cb(Kcur, "k_pe", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + } + else { ggml_tensor * q = NULL; if (!is_lite) { q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); From 76305878d52cb20de142369d348212cee5eb5436 Mon Sep 17 00:00:00 2001 From: bluebread Date: Sun, 16 Nov 2025 08:45:08 +0000 Subject: [PATCH 06/37] mtmd: successfully runs DeepSeek-OCR LM in llama-cli --- convert_hf_to_gguf.py | 15 ++++++++------- gguf-py/gguf/gguf_writer.py | 2 +- src/llama-model.cpp | 17 +++++++++++------ src/models/deepseek2.cpp | 15 +++++++++------ 4 files changed, 29 insertions(+), 20 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c8a48c01bfb..6d07b9acdb7 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7112,13 +7112,16 @@ def 
set_gguf_parameters(self): self.gguf_writer.add_vocab_size(hparams["vocab_size"]) if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) - self.gguf_writer.add_kv_lora_rank(kv_lora_rank) + if "kv_lora_rank" in hparams and hparams["kv_lora_rank"] is not None: + self.gguf_writer.add_kv_lora_rank(kv_lora_rank) # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA - self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"]) - self.gguf_writer.add_value_length(kv_lora_rank) - self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) - self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) + if not is_ocr: + self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(kv_lora_rank) + self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) @@ -7133,8 +7136,6 @@ def set_gguf_parameters(self): else: raise ValueError(f"Unsupported scoring_func value: {scoring_func}") - self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) - rope_scaling = self.hparams.get("rope_scaling") or {} if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 34ecb5e396f..fca498a859d 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -813,7 +813,7 @@ def add_layer_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) def add_layer_norm_rms_eps(self, value: float) -> None: - self.add_float64(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value) + self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value) def add_group_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.GROUPNORM_EPS.format(arch=self.arch), value) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a21a3ce619a..79639c515eb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1562,12 +1562,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_DEEPSEEK2: { bool is_lite = (hparams.n_layer == 27); + bool is_ocr = (name.find("ocr") != std::string::npos || name.find("OCR") != std::string::npos); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); - if (!is_lite) { + if (!is_lite && !is_ocr) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + if (!is_ocr) { + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + } ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); @@ -1583,6 +1587,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 
false); switch (hparams.n_layer) { + case 12: type = LLM_TYPE_3B; break; case 27: type = LLM_TYPE_16B; break; case 60: type = LLM_TYPE_236B; break; case 61: type = LLM_TYPE_671B; break; @@ -4578,10 +4583,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (is_ocr) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); // TODO - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); // TODO - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); if (i < (int) hparams.n_layer_dense_lead) { diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index e649286cecf..375f3594541 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -46,6 +46,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // self_attention if (is_ocr) { + const int n_embed_head = hparams.n_embd / hparams.n_head(); + GGML_ASSERT(n_embed_head == n_embd_head_k && n_embed_head == n_embd_head_v); + ggml_tensor * Qcur = NULL; ggml_tensor * Kcur = NULL; ggml_tensor * Vcur = NULL; @@ -57,13 +60,13 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr cb(Kcur, "k", il); cb(Vcur, "v", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head, n_tokens); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embed_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embed_head, n_head, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embed_head, n_head, n_tokens); - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); + GGML_ASSERT(fabs(freq_base - 10000.0) < 1e-4); + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_embed_head, rope_type, 0, freq_base, 1, 0, 1, 0, 0); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_embed_head, rope_type, 0, freq_base, 1, 0, 1, 0, 0); cb(Qcur, "q_pe", il); cb(Kcur, "k_pe", il); From 2de3436705a853c815daf5c2bab5dcae18ee47c1 Mon Sep 17 00:00:00 2001 From: bluebread Date: Mon, 17 Nov 2025 08:44:29 +0000 Subject: [PATCH 07/37] mtmd: Fix RoPE type for DeepSeek-OCR LM. 
--- examples/eval-callback/eval-callback.cpp | 18 +++++++++--------- src/models/deepseek2.cpp | 5 +++-- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index cefa39a57c8..ed181a1ab45 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -74,19 +74,19 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne } } for (int64_t i3 = 0; i3 < ne[3]; i3++) { - LOG(" [\n"); + LOG(" [\n"); for (int64_t i2 = 0; i2 < ne[2]; i2++) { if (i2 == n && ne[2] > 2*n) { - LOG(" ..., \n"); + LOG(" ..., \n"); i2 = ne[2] - n; } - LOG(" [\n"); + LOG(" [\n"); for (int64_t i1 = 0; i1 < ne[1]; i1++) { if (i1 == n && ne[1] > 2*n) { - LOG(" ..., \n"); + LOG(" ..., \n"); i1 = ne[1] - n; } - LOG(" ["); + LOG(" ["); for (int64_t i0 = 0; i0 < ne[0]; i0++) { if (i0 == n && ne[0] > 2*n) { LOG("..., "); @@ -98,10 +98,10 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne } LOG("],\n"); } - LOG(" ],\n"); + LOG(" ],\n"); } - LOG(" ]\n"); - LOG(" sum = %f\n", sum); + LOG(" ]\n"); + LOG(" sum = %f\n", sum); } // TODO: make this abort configurable/optional? @@ -136,7 +136,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str()); } - LOG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, + LOG("%s: %16s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, t->name, ggml_type_name(t->type), ggml_op_desc(t), src0->name, ggml_ne_string(src0).c_str(), src1 ? src1_str : "", diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index 375f3594541..bc1b2127acd 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -47,6 +47,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // self_attention if (is_ocr) { const int n_embed_head = hparams.n_embd / hparams.n_head(); + const int ocr_rope_type = GGML_ROPE_TYPE_NEOX; GGML_ASSERT(n_embed_head == n_embd_head_k && n_embed_head == n_embd_head_v); ggml_tensor * Qcur = NULL; @@ -65,8 +66,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr Vcur = ggml_reshape_3d(ctx0, Vcur, n_embed_head, n_head, n_tokens); GGML_ASSERT(fabs(freq_base - 10000.0) < 1e-4); - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_embed_head, rope_type, 0, freq_base, 1, 0, 1, 0, 0); - Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_embed_head, rope_type, 0, freq_base, 1, 0, 1, 0, 0); + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_embed_head, ocr_rope_type, 0, freq_base, 1, 0, 1, 0, 0); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_embed_head, ocr_rope_type, 0, freq_base, 1, 0, 1, 0, 0); cb(Qcur, "q_pe", il); cb(Kcur, "k_pe", il); From 97e0907c5b6a73d6f3e0614e4bb37e26e42ea17b Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Mon, 17 Nov 2025 11:07:33 +0100 Subject: [PATCH 08/37] loading LM testing Vision model loading --- convert_hf_to_gguf.py | 39 +++++++++++++++++++++------------------ src/llama-arch.cpp | 2 ++ src/llama-model.cpp | 39 +++++++++++++++++++++++++++++++++++++-- src/models/deepseek2.cpp | 32 +++++++++++++++++++++++++++++++- tools/mtmd/clip-impl.h | 28 ++++++++++++++-------------- tools/mtmd/clip.cpp | 19 ++++++++++++++----- 6 files changed, 119 insertions(+), 40 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 
77fc77e8234..385864dd11d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1494,12 +1494,9 @@ def __init__(self, *args, **kwargs): # FIXME: DeepseekOCRVisionModel specific hack if self.block_count is None: if isinstance(self, DeepseekOCRVisionModel): - print(self.hparams) clip_block_count = self.hparams['layers'] if clip_block_count is not None: self.block_count = clip_block_count - if sam_block_count is not None: - self.block_count = sam_block_count if self.block_count is None else self.block_count + sam_block_count if self.block_count is None: raise KeyError(f"could not find block count using any of: {self.n_block_keys}") self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count) @@ -5793,16 +5790,16 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter @ModelBase.register("DeepseekOCRForCausalLM") class DeepseekOCRVisionModel(MmprojModel): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - + proc_fname = self.dir_model / "processor_config.json" - + if proc_fname.is_file(): with open(proc_fname, "r") as f: self.preprocessor_config = json.load(f) - - + + def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams @@ -5860,7 +5857,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name, try_suffixes=("",)), data_torch)] return [(self.map_tensor_name(name), data_torch)] - + @ModelBase.register("Gemma3nForConditionalGeneration") class Gemma3NModel(Gemma3Model): @@ -7095,9 +7092,14 @@ def set_vocab(self): raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!") def set_gguf_parameters(self): + is_ocr = (self.hparams["num_hidden_layers"] == 12) - # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group) - self.hparams["num_key_value_heads"] = 1 + if is_ocr: + self.hparams['rope_theta'] = self.hparams.get('rope_theta', 10000.0) + self.hparams['rms_norm_eps'] = self.hparams.get('rms_norm_eps', 1e-6) + else: + # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group) + self.hparams["num_key_value_heads"] = 1 super().set_gguf_parameters() hparams = self.hparams @@ -7110,13 +7112,16 @@ def set_gguf_parameters(self): self.gguf_writer.add_vocab_size(hparams["vocab_size"]) if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) - self.gguf_writer.add_kv_lora_rank(kv_lora_rank) + if "kv_lora_rank" in hparams and hparams["kv_lora_rank"] is not None: + self.gguf_writer.add_kv_lora_rank(kv_lora_rank) # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA - self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"]) - self.gguf_writer.add_value_length(kv_lora_rank) - self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) - self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) + if not is_ocr: + self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(kv_lora_rank) + self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) 
self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) @@ -7131,8 +7136,6 @@ def set_gguf_parameters(self): else: raise ValueError(f"Unsupported scoring_func value: {scoring_func}") - self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) - rope_scaling = self.hparams.get("rope_scaling") or {} if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index b7642b568df..ac3ab5cfa77 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1446,6 +1446,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 829f1e3c14f..79639c515eb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1562,12 +1562,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_DEEPSEEK2: { bool is_lite = (hparams.n_layer == 27); + bool is_ocr = (name.find("ocr") != std::string::npos || name.find("OCR") != std::string::npos); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); - if (!is_lite) { + if (!is_lite && !is_ocr) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + if (!is_ocr) { + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + } ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); @@ -1583,6 +1587,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); switch (hparams.n_layer) { + case 12: type = LLM_TYPE_3B; break; case 27: type = LLM_TYPE_16B; break; case 60: type = LLM_TYPE_236B; break; case 61: type = LLM_TYPE_671B; break; @@ -4550,6 +4555,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_DEEPSEEK2: { const bool is_lite = (hparams.n_layer == 27); + const bool is_ocr = (name.find("ocr") != std::string::npos || name.find("OCR") != std::string::npos); const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); @@ -4575,6 +4581,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; + if (is_ocr) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", 
i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + + continue; + } + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); if (!is_lite) { layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index 68f72f72bb6..375f3594541 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -5,6 +5,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { bool is_lite = (hparams.n_layer == 27); + bool is_ocr = (model.name.find("ocr") != std::string::npos || model.name.find("OCR") != std::string::npos); const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); @@ -44,7 +45,36 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr cb(cur, "attn_norm", il); // self_attention - { + if (is_ocr) { + const int n_embed_head = hparams.n_embd / hparams.n_head(); + GGML_ASSERT(n_embed_head == n_embd_head_k && n_embed_head == n_embd_head_v); + + ggml_tensor * Qcur = NULL; + ggml_tensor * Kcur = NULL; + ggml_tensor * Vcur = NULL; + + Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Qcur, "q", il); + cb(Kcur, "k", il); + cb(Vcur, "v", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embed_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embed_head, n_head, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embed_head, n_head, n_tokens); + + GGML_ASSERT(fabs(freq_base - 10000.0) < 1e-4); + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_embed_head, rope_type, 0, freq_base, 1, 0, 1, 0, 0); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_embed_head, rope_type, 0, freq_base, 1, 0, 1, 0, 0); + cb(Qcur, "q_pe", il); + cb(Kcur, "k_pe", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + } + else { ggml_tensor * q = NULL; if (!is_lite) { q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 4cb2808c262..520e0cf5083 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -130,18 +130,18 @@ #define TN_TOK_EOI "v.eoi" // 
deepseek-ocr -#define TN_SAM_POS_EMBD "sam.pos_embd" -#define TN_SAM_PATCH_EMBD "sam.patch_embd.%s" -#define TN_SAM_PRE_NORM "sam.blk.%d.pre_ln.%s" -#define TN_SAM_POST_NORM "sam.blk.%d.post_ln" -#define TN_SAM_ATTN_POS_H "sam.blk.%d.attn.pos_h" -#define TN_SAM_ATTN_POS_W "sam.blk.%d.attn.pos_w" -#define TN_SAM_ATTN_QKV "sam.blk.%d.attn.qkv.%s" -#define TN_SAM_ATTN_OUT "sam.blk.%d.attn.out.%s" -#define TN_SAM_FFN_UP "sam.blk.%d.mlp.lin1.%s" -#define TN_SAM_FFN_DOWN "sam.blk.%d.mlp.lin2.%s" -#define TN_SAM_NECK "sam.neck.%d.%s" -#define TN_SAM_NET "sam.net_%d.%s" +#define TN_SAM_POS_EMBD "v.sam.pos_embd" +#define TN_SAM_PATCH_EMBD "v.sam.patch_embd.%s" +#define TN_SAM_PRE_NORM "v.sam.blk.%d.pre_ln.%s" +#define TN_SAM_POST_NORM "v.sam.blk.%d.post_ln" +#define TN_SAM_ATTN_POS_H "v.sam.blk.%d.attn.pos_h" +#define TN_SAM_ATTN_POS_W "v.sam.blk.%d.attn.pos_w" +#define TN_SAM_ATTN_QKV "v.sam.blk.%d.attn.qkv.%s" +#define TN_SAM_ATTN_OUT "v.sam.blk.%d.attn.out.%s" +#define TN_SAM_FFN_UP "v.sam.blk.%d.mlp.lin1.%s" +#define TN_SAM_FFN_DOWN "v.sam.blk.%d.mlp.lin2.%s" +#define TN_SAM_NECK "v.sam.neck.%d.%s" +#define TN_SAM_NET "v.sam.net_%d.%s" // align x to upper multiple of n #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) @@ -170,7 +170,7 @@ enum projector_type { PROJECTOR_TYPE_LIGHTONOCR, PROJECTOR_TYPE_COGVLM, PROJECTOR_TYPE_JANUS_PRO, - PROJECTOR_TYPE_DEEPSEEK_OCR, + PROJECTOR_TYPE_DEEPSEEKOCR, PROJECTOR_TYPE_UNKNOWN, }; @@ -197,7 +197,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, { PROJECTOR_TYPE_COGVLM, "cogvlm"}, { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, - { PROJECTOR_TYPE_DEEPSEEK_OCR,"deepseek_orc"}, + { PROJECTOR_TYPE_DEEPSEEKOCR,"deepseekocr"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index d94d05b2f28..5d4257ac841 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -682,8 +682,8 @@ struct clip_graph { const int enc_n_patches = enc_image_size / enc_patch_size; // 64 - ggml_tensor * inpL = build_enc_inp(inp_raw, enc_patch_size, enc_image_size, enc_n_embd); - ggml_tensor * cur = ggml_add(ctx0, inpL, model.position_embeddings); + ggml_tensor * inpL = build_enc_inp(inp_raw, enc_patch_size, enc_n_patches, enc_n_embd); + ggml_tensor * cur = ggml_add(ctx0, inpL, model.pos_embed); // loop over layers for (int il = 0; il < _depth; il++) { @@ -842,7 +842,7 @@ struct clip_graph { ggml_tensor * inp_raw = build_inp_raw(); - ggml_tensor * global_features_1 = build_sam_enc(inp_raw); + ggml_tensor * global_features_1 = build_sam_enc(inp_raw, std::max(img.nx, img.ny)); ggml_tensor * global_features_2 = build_dp_ocr_clip(inp_raw, global_features_1); @@ -2862,6 +2862,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { res = graph.build_cogvlm(); } break; + case PROJECTOR_TYPE_DEEPSEEKOCR: + { + res = graph.build_deepseek_ocr(); + } break; default: { res = graph.build_llava(); @@ -3187,6 +3191,11 @@ struct clip_model_loader { hparams.ffn_op = FFN_GELU_ERF; log_ffn_op = "gelu_erf"; // temporary solution for logging } break; + case PROJECTOR_TYPE_DEEPSEEKOCR: + { + hparams.set_limit_image_tokens(8, 1024); + hparams.set_warmup_n_tokens(256); // avoid OOM on warmup + } break; default: break; } @@ -3574,7 +3583,7 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias")); } break; - case PROJECTOR_TYPE_DEEPSEEK_OCR: 
+ case PROJECTOR_TYPE_DEEPSEEKOCR: { model.pos_embed = get_tensor(TN_SAM_POS_EMBD); model.patch_embed_proj_w = get_tensor(string_format(TN_SAM_PATCH_EMBD, "weight")); @@ -4830,7 +4839,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str } } } break; - case PROJECTOR_TYPE_DEEPSEEK_OCR: + case PROJECTOR_TYPE_DEEPSEEKOCR: { // configurable, or read from params const int min_num = 2; From 790bbb97d89858331f048748cd4a08f4a801b3b1 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Mon, 17 Nov 2025 15:27:00 +0100 Subject: [PATCH 09/37] sam warmup working --- tools/mtmd/clip-impl.h | 2 +- tools/mtmd/clip.cpp | 31 ++++++++++++------------------- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 520e0cf5083..fcaae246c73 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -133,7 +133,7 @@ #define TN_SAM_POS_EMBD "v.sam.pos_embd" #define TN_SAM_PATCH_EMBD "v.sam.patch_embd.%s" #define TN_SAM_PRE_NORM "v.sam.blk.%d.pre_ln.%s" -#define TN_SAM_POST_NORM "v.sam.blk.%d.post_ln" +#define TN_SAM_POST_NORM "v.sam.blk.%d.post_ln.%s" #define TN_SAM_ATTN_POS_H "v.sam.blk.%d.attn.pos_h" #define TN_SAM_ATTN_POS_W "v.sam.blk.%d.attn.pos_w" #define TN_SAM_ATTN_QKV "v.sam.blk.%d.attn.qkv.%s" diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 5d4257ac841..f4dc48e442c 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -225,17 +225,7 @@ struct clip_hparams { // sam vit deepseek-ocr std::vector global_attn_indices() const { - switch (n_embd) { - case 768: return { 2, 5, 8, 11 }; - case 1024: return { 5, 11, 17, 23 }; - case 1280: return { 7, 15, 23, 31 }; - default: - { - fprintf(stderr, "%s: unsupported n_enc_state = %d\n", __func__, n_embd); - } break; - }; - - return {}; + return { 2, 5, 8, 11 }; } bool is_global_attn(int32_t layer) const { @@ -455,7 +445,7 @@ struct clip_model { ggml_tensor * net_2; ggml_tensor * net_3; - int32_t n_sam_layers = 0; // used by deepseek-ocr sam encoder + int32_t n_sam_layers = 12; // used by deepseek-ocr sam encoder std::vector sam_layers; @@ -721,7 +711,7 @@ struct clip_graph { Qcur = ggml_reshape_3d(ctx0, Qcur, enc_d_heads, W * H, B * enc_n_heads); ggml_tensor * Kcur = - ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 1 * cur->nb[3]); + ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], cur->nb[3]); Kcur = ggml_reshape_4d(ctx0, Kcur, enc_d_heads, enc_n_heads, W * H, B); Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); Kcur = ggml_reshape_3d(ctx0, Kcur, enc_d_heads, W * H, B * enc_n_heads); @@ -740,12 +730,12 @@ struct clip_graph { struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcur, Qcur); - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf(enc_n_heads)); + struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf(enc_d_heads)); struct ggml_tensor * rw = ggml_get_rel_pos(ctx0, layer.rel_pos_w, W, W); struct ggml_tensor * rh = ggml_get_rel_pos(ctx0, layer.rel_pos_h, H, H); - struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_n_heads, W, H, B * enc_n_embd); + struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads); struct ggml_tensor * rel_w = ggml_cont( ctx0, @@ -763,7 +753,7 @@ struct clip_graph { ctx0, ggml_cont(ctx0, ggml_permute(ctx0, ggml_reshape_4d(ctx0, KQV, enc_d_heads, W * H, enc_n_heads, B), 0, 2, 1, 3)), - n_embd, W, H, B); + enc_n_embd, W, H, B); cur = 
ggml_mul_mat(ctx0, layer.o_w, cur); cur = ggml_add_inplace(ctx0, cur, layer.o_b); @@ -2492,9 +2482,11 @@ struct clip_graph { // patch_embed_proj_w shape = [768, 3, 16, 16] ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embed_proj_w, inp_raw, enc_patch_size, enc_patch_size, 0, 0, 1, 1); // [64, 64, 768] - inp = ggml_reshape_2d(ctx0, inp, enc_n_patches, enc_n_embd); // [4096, 768] + inp = ggml_reshape_2d(ctx0, inp, enc_n_patches * enc_n_patches, enc_n_embd); // [4096, 768] inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); // [768, 4096] inp = ggml_add(ctx0, inp, model.patch_embed_proj_b); + inp = ggml_cont(ctx0, inp); + inp = ggml_reshape_4d(ctx0, inp, enc_n_embd, enc_n_patches, enc_n_patches, 1); cb(inp, "enc_patch_bias", -1); return inp; } @@ -3193,8 +3185,9 @@ struct clip_model_loader { } break; case PROJECTOR_TYPE_DEEPSEEKOCR: { - hparams.set_limit_image_tokens(8, 1024); - hparams.set_warmup_n_tokens(256); // avoid OOM on warmup + hparams.patch_size = 16; + hparams.image_size = 1024; + hparams.warmup_image_size = 1024; } break; default: break; From cec9a5c6e0d0fc949ecc92e0eadccb2195174f3b Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Mon, 17 Nov 2025 18:59:40 +0100 Subject: [PATCH 10/37] sam erroneous return corrected --- tools/mtmd/clip.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index f4dc48e442c..1d29bc8afe5 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -786,8 +786,6 @@ struct clip_graph { // residual 2 cur = ggml_add(ctx0, cur, inpFF); cb(cur, "layer_out", il); - - return cur; // B, 1024, 16, 16 } cur = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 2, 0, 1, 3)); @@ -1538,12 +1536,17 @@ struct clip_graph { ggml_tensor * build_dp_ocr_clip(ggml_tensor * inpL, ggml_tensor * patch_embeds) { GGML_ASSERT(model.class_embedding != nullptr); GGML_ASSERT(model.position_embeddings != nullptr); - auto n_embd_vit_clip = 1024; const int n_pos = n_patches + 1; ggml_tensor * inp = ggml_cont_3d(ctx0, ggml_dup_tensor(ctx0, patch_embeds), patch_embeds->ne[0], n_patches_x, n_patches_y); - //ggml_tensor * inp = ggml_cpy(ctx0, inpL, ggml_dup_tensor(ctx0, inpL)); + + auto inp_n_elems = ggml_nelements(inp); + GGML_ASSERT(inp_n_elems == inp->ne[0] * inp->ne[1] * inp->ne[2]); + inp = ggml_permute(ctx0, inp, 2, 1,0,3); // [n_patches, n_embd] + inp = ggml_cont(ctx0, inp); + GGML_ASSERT(ggml_nelements(inp) == n_patches_x*patch_size*4*768); + inp= ggml_reshape_2d(ctx0,inp,n_patches_x*patch_size, 4*768); // add CLS token inp = ggml_concat(ctx0, inp, model.class_embedding, 1); From 8b3d319c032971e8c46af19a77c0be5ad96c860a Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Mon, 17 Nov 2025 20:57:51 +0100 Subject: [PATCH 11/37] clip-vit: corrected cls_embd concat --- tools/mtmd/clip.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 1d29bc8afe5..ecbc4fb04a4 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1538,15 +1538,11 @@ struct clip_graph { GGML_ASSERT(model.position_embeddings != nullptr); const int n_pos = n_patches + 1; - ggml_tensor * inp = - ggml_cont_3d(ctx0, ggml_dup_tensor(ctx0, patch_embeds), patch_embeds->ne[0], n_patches_x, n_patches_y); - - auto inp_n_elems = ggml_nelements(inp); - GGML_ASSERT(inp_n_elems == inp->ne[0] * inp->ne[1] * inp->ne[2]); - inp = ggml_permute(ctx0, inp, 2, 1,0,3); // [n_patches, n_embd] + ggml_tensor * 
inp = ggml_permute(ctx0, patch_embeds,2,1,0,3); inp = ggml_cont(ctx0, inp); - GGML_ASSERT(ggml_nelements(inp) == n_patches_x*patch_size*4*768); - inp= ggml_reshape_2d(ctx0,inp,n_patches_x*patch_size, 4*768); + inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches); + + // add CLS token inp = ggml_concat(ctx0, inp, model.class_embedding, 1); From 1e08157134ee0c13ba44cbbe821aab680f04b395 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Mon, 17 Nov 2025 21:19:51 +0100 Subject: [PATCH 12/37] clip-vit: model convert qkv_proj split --- convert_hf_to_gguf.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 390edfe864c..12baa198fef 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5856,6 +5856,27 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if ".attn.rel_pos_h" in name or ".attn.rel_pos_w" in name: return [(self.map_tensor_name(name, try_suffixes=("",)), data_torch)] + if name.startswith("model.vision_model.transformer.layers."): + # process visual tensors + # split QKV tensors if needed + if ".qkv_proj." in name: + if data_torch.ndim == 2: # weight + c3, _ = data_torch.shape + else: # bias + c3 = data_torch.shape[0] + assert c3 % 3 == 0 + c = c3 // 3 + wq = data_torch[:c] + wk = data_torch[c: c * 2] + wv = data_torch[c * 2:] + return [ + (self.map_tensor_name(name.replace("qkv", "q")), wq), + (self.map_tensor_name(name.replace("qkv", "k")), wk), + (self.map_tensor_name(name.replace("qkv", "v")), wv), + ] + else: + return [(self.map_tensor_name(name), data_torch)] + return [(self.map_tensor_name(name), data_torch)] @@ -7100,7 +7121,7 @@ def set_gguf_parameters(self): else: # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group) self.hparams["num_key_value_heads"] = 1 - + super().set_gguf_parameters() hparams = self.hparams kv_lora_rank = hparams["q_lora_rank"] if hparams["q_lora_rank"] is not None else 512 From 331cea8f8e8fe6d5521696cf7f60123f7ed52527 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Tue, 18 Nov 2025 05:59:37 +0100 Subject: [PATCH 13/37] corrected combining of image encoders' results --- gguf-py/gguf/constants.py | 6 +++--- tools/mtmd/clip-impl.h | 1 + tools/mtmd/clip.cpp | 21 ++++++++++++++++----- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index d3f51645ea9..c4294574631 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -290,7 +290,7 @@ class Projector: class SAM: BLOCK_COUNT = "clip.vision.sam.block_count" EMBEDDING_LENGTH = "clip.vision.sam.embedding_length" - + class ClipAudio: NUM_MEL_BINS = "clip.audio.num_mel_bins" EMBEDDING_LENGTH = "clip.audio.embedding_length" @@ -1064,8 +1064,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_SAM_NECK: "v.sam.neck.{bid}", MODEL_TENSOR.V_SAM_NET_2: "v.sam.net_2", MODEL_TENSOR.V_SAM_NET_3: "v.sam.net_3", - MODEL_TENSOR.V_ENC_EMBD_IMGNL: "v.image_newline_embd", # Deepseek-OCR - MODEL_TENSOR.V_ENC_EMBD_VSEP: "v.view_separator_embd", # Deepseek-OCR + MODEL_TENSOR.V_ENC_EMBD_IMGNL: "model.image_newline", # Deepseek-OCR + MODEL_TENSOR.V_ENC_EMBD_VSEP: "model.view_seperator", # Deepseek-OCR # audio (mtmd) MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd", MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}", diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index fcaae246c73..ba094cc25bb 100644 --- 
a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -86,6 +86,7 @@ #define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s" #define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s" #define TN_IMAGE_NEWLINE "model.image_newline" +#define TN_IMAGE_SEPERATOR "model.view_seperator" #define TN_MM_INP_NORM "mm.input_norm.weight" #define TN_MM_INP_NORM_B "mm.input_norm.bias" #define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3 diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index ecbc4fb04a4..8bd1eef4bfc 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -835,7 +835,15 @@ struct clip_graph { ggml_tensor * global_features_2 = build_dp_ocr_clip(inp_raw, global_features_1); // torch global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1) - ggml_tensor * global_features = ggml_concat(ctx0, global_features_1, global_features_2, 0); + global_features_1 = ggml_permute(ctx0, global_features_1,2,1,0,3); + global_features_1 = ggml_cont(ctx0, global_features_1); + global_features_1 = ggml_reshape_2d(ctx0, global_features_1, n_embd, n_patches); + // remove CLS token + global_features_2 = ggml_view_2d(ctx0, global_features_2, + n_embd, n_patches, + ggml_row_size(global_features_2->type, n_embd), 0); + + ggml_tensor * global_features = ggml_concat(ctx0, global_features_2, global_features_1, 1); global_features = build_global_local_features( ctx0, global_features, @@ -843,6 +851,7 @@ struct clip_graph { n_patches_x, n_embd ); + ggml_build_forward_expand(gf, global_features); return gf; } @@ -858,8 +867,8 @@ struct clip_graph { int n_dim) { GGML_ASSERT(model.image_newline != nullptr); GGML_ASSERT(model.view_seperator != nullptr); - GGML_ASSERT(global_features->ne[0] == (int64_t) n_dim); - GGML_ASSERT(global_features->ne[1] == (int64_t) (h * w)); + GGML_ASSERT(global_features->ne[0] == static_cast(n_dim)); + GGML_ASSERT(global_features->ne[1] == static_cast(2 * (h * w))); // 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim] ggml_tensor * t = ggml_reshape_3d(ctx0, global_features, n_dim, w, h); // (n_dim, w, h) @@ -1552,8 +1561,7 @@ struct clip_graph { // for selecting learned pos embd, used by ViT struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); - ggml_set_name(positions, "positions"); - ggml_set_input(positions); + cb(positions, "positions", -1); ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, model.position_embeddings, positions); @@ -3607,6 +3615,9 @@ struct clip_model_loader { model.net_2 = get_tensor(string_format(TN_SAM_NET, 2, "weight")); model.net_3 = get_tensor(string_format(TN_SAM_NET, 3, "weight")); } + model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false); + model.view_seperator = get_tensor(TN_IMAGE_SEPERATOR, false); + break; default: GGML_ASSERT(false && "unknown projector type"); From 6c0715befcab53eab7fb03cd82437c715729297e Mon Sep 17 00:00:00 2001 From: bluebread Date: Tue, 18 Nov 2025 06:19:38 +0000 Subject: [PATCH 14/37] fix: update callback for ffn_moe_weighted and add callback for attn_out in deepseek2 model --- src/llama-graph.cpp | 2 +- src/models/deepseek2.cpp | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index b199e94628f..4daf3f230b5 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1106,7 +1106,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( if (!weight_before_ffn) { experts = ggml_mul(ctx0, experts, weights); - cb(cur, "ffn_moe_weighted", il); + cb(experts, 
"ffn_moe_weighted", il); } ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr }; diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index bc1b2127acd..f4a40d7d6e8 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -74,6 +74,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); } else { ggml_tensor * q = NULL; From 63a042f21e19c90d51645741fbd18f2d14c9864f Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Tue, 18 Nov 2025 09:43:11 +0100 Subject: [PATCH 15/37] concat image_newline and image_seperator tokens --- tools/mtmd/clip-impl.h | 2 +- tools/mtmd/clip.cpp | 67 ++++++++++++++++++++---------------------- 2 files changed, 33 insertions(+), 36 deletions(-) diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index ba094cc25bb..63d59055668 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -91,7 +91,7 @@ #define TN_MM_INP_NORM_B "mm.input_norm.bias" #define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3 #define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3 -#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3 +#define TN_MM_PROJECTOR "mm.model.fc.%s" // idefics3, deepseekocr #define TN_MM_PATCH_MERGER "mm.patch_merger.weight" // mistral small 3.1 #define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral #define TN_TOK_GLM_BOI "adapter.boi" // glm-edge (these embeddings are not in text model) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 8bd1eef4bfc..99b5ab45d95 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -316,7 +316,8 @@ struct clip_model { ggml_tensor * post_ln_w; ggml_tensor * post_ln_b; - ggml_tensor * projection; // TODO: rename it to fc (fully connected layer) + ggml_tensor * fc_w; + ggml_tensor * fc_b; ggml_tensor * mm_fc_w; ggml_tensor * mm_fc_b; @@ -623,7 +624,7 @@ struct clip_graph { // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578 const int scale_factor = model.hparams.n_merge; cur = build_patch_merge_permute(cur, scale_factor); - cur = ggml_mul_mat(ctx0, model.projection, cur); + cur = ggml_mul_mat(ctx0, model.fc_w, cur); } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) { // pixel unshuffle block @@ -844,15 +845,12 @@ struct clip_graph { ggml_row_size(global_features_2->type, n_embd), 0); ggml_tensor * global_features = ggml_concat(ctx0, global_features_2, global_features_1, 1); - global_features = build_global_local_features( - ctx0, - global_features, - n_patches_y, - n_patches_x, - n_embd - ); + global_features = ggml_reshape_2d(ctx0, global_features, 2* n_embd, n_patches); + global_features = ggml_cont(ctx0, global_features); + global_features = ggml_mul_mat(ctx0, model.fc_w, global_features); + global_features = ggml_add(ctx0, global_features, model.fc_b); + global_features = build_global_local_features(ctx0,global_features); ggml_build_forward_expand(gf, global_features); - return gf; } @@ -861,41 +859,31 @@ struct clip_graph { // view_separator: [n_dim] ggml_tensor * build_global_local_features(ggml_context * ctx0, - ggml_tensor * global_features, - int h, - int w, - int n_dim) { + ggml_tensor * global_features) { GGML_ASSERT(model.image_newline != nullptr); GGML_ASSERT(model.view_seperator != nullptr); - GGML_ASSERT(global_features->ne[0] 
== static_cast(n_dim)); - GGML_ASSERT(global_features->ne[1] == static_cast(2 * (h * w))); // 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim] - ggml_tensor * t = ggml_reshape_3d(ctx0, global_features, n_dim, w, h); // (n_dim, w, h) - t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (h, w, n_dim) - - // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim] - ggml_tensor * nl = ggml_reshape_3d(ctx0, model.image_newline, 1, 1, n_dim); // (1, 1, n_dim) + ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, 1280, 64, 64, 1); // (n_dim, w, h) + t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (h, w, n_dim) + ggml_tensor * nl = ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3); + nl = ggml_repeat_4d(ctx0, nl, 64, 1, 1280, 1); // n_pos rows - ggml_tensor * nl_target_shape = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, h, n_dim); // (1, h, n_dim) - nl = ggml_repeat(ctx0, nl, nl_target_shape); // (1, h, n_dim) - nl = ggml_permute(ctx0, nl, 1, 0, 2, 3); // (h, 1, n_dim) - // 3) concat along width dimension (dim=1): (h, w, n_dim) + (h, 1, n_dim) -> (h, w+1, n_dim) + // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim] t = ggml_concat(ctx0, t, nl, 1); // (h, w+1, n_dim) - // 4) flatten back to token axis: (h, w+1, n_dim) -> (n_dim, h*(w+1)) - t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (n_dim, w+1, h) - t = ggml_cont_2d(ctx0, t, n_dim, (w + 1) * h); // (n_dim, h*(w+1)) + t = ggml_reshape_2d(ctx0, t, 1280, 64 * (64 + 1)); // (n_dim, h*(w+1)) + // 5) append view_separator as an extra "token": // view_separator: [n_dim] -> [n_dim, 1] - ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1) + ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, 1280, 1); // (n_dim, 1) // concat along token dimension (dim=1): - ggml_tensor * global_local_features = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1) + t = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1) - return global_local_features; + return t; } @@ -3488,7 +3476,7 @@ struct clip_model_loader { } break; case PROJECTOR_TYPE_IDEFICS3: { - model.projection = get_tensor(TN_MM_PROJECTOR); + model.fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight")); } break; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: @@ -3561,13 +3549,13 @@ struct clip_model_loader { } break; case PROJECTOR_TYPE_LLAMA4: { - model.mm_model_proj = get_tensor(TN_MM_PROJECTOR); + model.mm_model_proj = get_tensor(string_format(TN_MM_PROJECTOR, "weight")); model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight")); } break; case PROJECTOR_TYPE_COGVLM: { - model.mm_model_proj = get_tensor(TN_MM_PROJECTOR); + model.mm_model_proj = get_tensor(string_format(TN_MM_PROJECTOR, "weight")); model.mm_post_fc_norm_w = get_tensor(string_format(TN_MM_POST_FC_NORM, "weight")); model.mm_post_fc_norm_b = get_tensor(string_format(TN_MM_POST_FC_NORM, "bias")); model.mm_h_to_4h_w = get_tensor(string_format(TN_MM_H_TO_4H, "weight")); @@ -3617,6 +3605,9 @@ struct clip_model_loader { } model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false); model.view_seperator = get_tensor(TN_IMAGE_SEPERATOR, false); + model.fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight")); + model.fc_b = get_tensor(string_format(TN_MM_PROJECTOR, "bias")); + break; default: @@ -5086,6 +5077,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im { n_patches += 2; // for BOI and EOI token embeddings 
} break; + case PROJECTOR_TYPE_DEEPSEEKOCR: + { + n_patches += 2; + } break; default: GGML_ABORT("unsupported projector type"); } @@ -5512,7 +5507,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_GEMMA3: return ctx->model.mm_input_proj_w->ne[0]; case PROJECTOR_TYPE_IDEFICS3: - return ctx->model.projection->ne[1]; + return ctx->model.fc_w->ne[1]; case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_VOXTRAL: return ctx->model.mm_2_w->ne[1]; @@ -5527,6 +5522,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_COGVLM: return ctx->model.mm_4h_to_h_w->ne[1]; + case PROJECTOR_TYPE_DEEPSEEKOCR: + return ctx->model.fc_w->ne[1]; default: GGML_ABORT("Unknown projector type"); } From 89afda8da90024aaf908448a2bb8dafee739934c Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Tue, 18 Nov 2025 10:26:32 +0100 Subject: [PATCH 16/37] visual_model warmup (technically) works --- tools/mtmd/clip.cpp | 5 +++++ tools/mtmd/clip.h | 2 ++ tools/mtmd/mtmd.cpp | 3 ++- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 99b5ab45d95..797f921f509 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -5412,6 +5412,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_JANUS_PRO: case PROJECTOR_TYPE_COGVLM: + case PROJECTOR_TYPE_DEEPSEEKOCR: { // do nothing } break; @@ -5554,6 +5555,10 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) { return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3; } +bool clip_is_deepseekocr(const struct clip_ctx * ctx) { + return ctx->proj_type() == PROJECTOR_TYPE_DEEPSEEKOCR; +} + bool clip_has_vision_encoder(const struct clip_ctx * ctx) { return ctx->model.modality == CLIP_MODALITY_VISION; } diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index 3e4c985f117..458ee98fc78 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -105,6 +105,8 @@ bool clip_is_glm(const struct clip_ctx * ctx); bool clip_is_qwen2vl(const struct clip_ctx * ctx); bool clip_is_llava(const struct clip_ctx * ctx); bool clip_is_gemma3(const struct clip_ctx * ctx); +bool clip_is_deepseekocr(const struct clip_ctx * ctx); + bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index e5991377699..16349e8f406 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -810,7 +810,8 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) if (clip_is_llava(ctx_clip) || clip_is_minicpmv(ctx_clip) - || clip_is_glm(ctx_clip)) { + || clip_is_glm(ctx_clip) + || clip_is_deepseekocr(ctx_clip)) { // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode() const auto & entries = image_tokens->batch_f32.entries; for (size_t i = 0; i < entries.size(); i++) { From 88032f46b1cf496670fb029dfcfa071ea2e31e02 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Thu, 20 Nov 2025 10:07:54 +0100 Subject: [PATCH 17/37] window partitioning using standard ggml ops --- tools/mtmd/clip.cpp | 50 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 797f921f509..40b60cbfd5d 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -690,7 +690,8 @@ 
struct clip_graph { if (hparams.is_global_attn(il) == false) { // local attention layer - apply window partition // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172 - cur = ggml_win_part(ctx0, cur, 14); + //cur = ggml_win_part(ctx0, cur, 14); + cur = window_partition(ctx0, cur, 14); } const int64_t W = cur->ne[1]; @@ -762,7 +763,7 @@ struct clip_graph { if (hparams.is_global_attn(il) == false) { // local attention layer - reverse window partition - cur = ggml_win_unpart(ctx0, cur, w0, h0, 14); + cur = window_unpartition(ctx0, cur, w0, h0, 14); } // re-add the layer input, e.g., residual @@ -865,9 +866,10 @@ struct clip_graph { // 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim] ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, 1280, 64, 64, 1); // (n_dim, w, h) - t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (h, w, n_dim) - ggml_tensor * nl = ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3); + t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, n_dim) + ggml_tensor * nl = ggml_cont(ctx0,ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3)); nl = ggml_repeat_4d(ctx0, nl, 64, 1, 1280, 1); // n_pos rows + nl = ggml_cont(ctx0, nl); // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim] @@ -2464,6 +2466,46 @@ struct clip_graph { return inpL; } + static ggml_tensor* window_partition(ggml_context* ctx, ggml_tensor* x, int window) { + auto [c, w, h, b] = x->ne; + // same as + // x = ggml_win_part(m, x, window); + // x = ggml_reshape_3d(m, x, c, window * window, x->ne[3]); + + int64_t px = (window - w % window) % window; + int64_t py = (window - h % window) % window; + int64_t npw = (w + px) / window; + int64_t nph = (h + py) / window; + + if (px > 0 || py > 0) { + x = ggml_pad(ctx, x, 0, int(px), int(py), 0); + } + x = ggml_reshape_4d(ctx, x, c * window, npw, window, nph * b); + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); + x = ggml_reshape_4d(ctx, x, c, window ,window, npw * nph * b); + return x; + } + + static ggml_tensor* window_unpartition(ggml_context* m, ggml_tensor* x, int w, int h, int window) { + int64_t c = x->ne[0]; + // same as + // x = ggml_reshape_4d(m, x, c, window, window, x->ne[2]); + // x = ggml_win_unpart(m, x, w, h, window); + + int64_t px = (window - w % window) % window; + int64_t py = (window - h % window) % window; + int64_t npw = (w + px) / window; + int64_t nph = (h + py) / window; + + int64_t b = x->ne[3] / (npw * nph); + x = ggml_reshape_4d(m, x, c * window, window, npw, nph * b); + x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); + x = ggml_reshape_4d(m, x, c, w + px, h + py, b); + x = ggml_view_4d(m, x, x->ne[0], w, h, x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0); + x = ggml_cont(m, x); + return x; + } + // build the input after conv2d (inp_raw --> patches) // returns tensor with shape [n_embd, n_patches] ggml_tensor * build_enc_inp(ggml_tensor * inp_raw, From 68b206b65c29c4a2116c593cc2ad135b7d9f1565 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Fri, 21 Nov 2025 15:29:39 +0100 Subject: [PATCH 18/37] sam implementation without using CPU only ops --- tools/mtmd/clip.cpp | 109 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 103 insertions(+), 6 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 40b60cbfd5d..f8dbe39a25a 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -734,8 +734,8 @@ struct clip_graph { struct ggml_tensor * KQ_scaled = 
ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf(enc_d_heads)); - struct ggml_tensor * rw = ggml_get_rel_pos(ctx0, layer.rel_pos_w, W, W); - struct ggml_tensor * rh = ggml_get_rel_pos(ctx0, layer.rel_pos_h, H, H); + struct ggml_tensor * rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); + struct ggml_tensor * rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads); @@ -745,7 +745,7 @@ struct clip_graph { 2, 1, 3)); struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r); - struct ggml_tensor * attn = ggml_add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h); + struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h, W); struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn); @@ -837,9 +837,9 @@ struct clip_graph { ggml_tensor * global_features_2 = build_dp_ocr_clip(inp_raw, global_features_1); // torch global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1) - global_features_1 = ggml_permute(ctx0, global_features_1,2,1,0,3); - global_features_1 = ggml_cont(ctx0, global_features_1); + global_features_1 = ggml_cont(ctx0,ggml_permute(ctx0, global_features_1,2,1,0,3)); global_features_1 = ggml_reshape_2d(ctx0, global_features_1, n_embd, n_patches); + // remove CLS token global_features_2 = ggml_view_2d(ctx0, global_features_2, n_embd, n_patches, @@ -850,6 +850,7 @@ struct clip_graph { global_features = ggml_cont(ctx0, global_features); global_features = ggml_mul_mat(ctx0, model.fc_w, global_features); global_features = ggml_add(ctx0, global_features, model.fc_b); + global_features = build_global_local_features(ctx0,global_features); ggml_build_forward_expand(gf, global_features); return gf; @@ -869,7 +870,6 @@ struct clip_graph { t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, n_dim) ggml_tensor * nl = ggml_cont(ctx0,ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3)); nl = ggml_repeat_4d(ctx0, nl, 64, 1, 1280, 1); // n_pos rows - nl = ggml_cont(ctx0, nl); // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim] @@ -2466,6 +2466,103 @@ struct clip_graph { return inpL; } + // attn: [k_h*k_w, q_h*q_w] +// rel_h: [q_h, q_w, k_h] +// rel_w: [q_h, q_w, k_w] + +static ggml_tensor * add_rel_pos_inplace( + ggml_context * ctx, + ggml_tensor * attn, + ggml_tensor * rel_w, + ggml_tensor * rel_h, + int q_size +) { + + ggml_tensor *attn_4d = + ggml_reshape_4d(ctx, attn, q_size,q_size, attn->ne[1], attn->ne[2]); + + ggml_tensor *rel_h_4d = + ggml_reshape_4d(ctx, rel_h, 1, q_size, attn->ne[1], attn->ne[2]); + + ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_5d + + ggml_tensor *rel_w_4d = + ggml_reshape_4d(ctx, rel_w, q_size, 1, attn->ne[1], attn->ne[2]); + + ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_5d + + ggml_tensor * result = ggml_add(ctx, attn_4d, ggml_add(ctx, rel_h_rep, rel_w_rep)); + result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]); + + + return result; +} + + +static ggml_tensor * get_rel_pos( + ggml_context * ctx, + ggml_tensor * rel_pos, // [L, C] + int q_size, + int k_size +) { + + const auto dtype = rel_pos->type; + + const int64_t L = rel_pos->ne[0]; // length + const int64_t C = rel_pos->ne[1]; // channels + + // ------------------------------------------------- + // 1) q_idx ← arange(0..q_size-1) [q_size] + // 2) k_idx ← arange(0..k_size-1) [k_size] + // 
------------------------------------------------- + + + ggml_tensor * q_coord = ggml_cast(ctx, + ggml_arange(ctx, 0.0f, static_cast(q_size), 1.0f), + GGML_TYPE_F32); // [q_size] + ggml_tensor * k_coord = ggml_cast(ctx, + ggml_arange(ctx, 0.0f, static_cast(k_size), 1.0f), + GGML_TYPE_F32); // [k_size] + + ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, q_size, k_size); + q_coord = ggml_cont(ctx,ggml_repeat(ctx, q_coord, rel)); // [q_size, k_size] + + // broadcast reshape: + k_coord = ggml_reshape_2d(ctx, k_coord, 1, k_size); // [1, k_size] + k_coord = ggml_cont(ctx,ggml_repeat(ctx, k_coord, rel)); // [q_size, k_size] + + // ------------------------------------------------- + // relative_coords = q - k + (k_size - 1) // SAME as PyTorch when no scaling + // ------------------------------------------------- + rel = ggml_sub(ctx, k_coord, q_coord); // [q_size, k_size] + + rel = ggml_scale_bias(ctx, rel, 1.0f, static_cast(k_size) - 1.0f); // [q_size, k_size] + + // ------------------------------------------------- + // clamp to [0, L-1] and cast to int32 (for ggml_get_rows) + // ------------------------------------------------- + + ggml_tensor * rel_clamped = ggml_clamp(ctx, rel, 0, static_cast(L - 1)); + + ggml_tensor * idx_2d = ggml_cast(ctx, rel_clamped, GGML_TYPE_I32); // [q_size, k_size] + + // flatten to 1D for ggml_get_rows + const int64_t qk = static_cast(q_size) * static_cast(k_size); + ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk] + + // ------------------------------------------------- + // Gather from rel_pos → [qk, C] + // ------------------------------------------------- + ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat); // [qk, C] + + // reshape to final output → [q_size, k_size, C] + ggml_tensor * out = ggml_reshape_3d(ctx, gathered,rel_pos->ne[0], + q_size, + k_size); + + return out; // [q_size, k_size, C] +} + static ggml_tensor* window_partition(ggml_context* ctx, ggml_tensor* x, int window) { auto [c, w, h, b] = x->ne; // same as From 8bce66d5f2a76e4e638f07b40769fdc8a248ad7d Mon Sep 17 00:00:00 2001 From: bluebread Date: Fri, 21 Nov 2025 15:28:37 +0000 Subject: [PATCH 19/37] clip: fixed warnings --- tools/mtmd/clip.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 40b60cbfd5d..eb3d461dac1 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -667,9 +667,9 @@ struct clip_graph { constexpr int _depth = 12; constexpr int enc_n_heads = 12; constexpr int enc_d_heads = enc_n_embd / enc_n_heads; - constexpr int _prompt_n_embd = 256; + // constexpr int _prompt_n_embd = 256; constexpr int enc_patch_size = 16; - constexpr int _window_size = 14; + // constexpr int _window_size = 14; const int enc_n_patches = enc_image_size / enc_patch_size; // 64 @@ -834,7 +834,7 @@ struct clip_graph { ggml_tensor * global_features_1 = build_sam_enc(inp_raw, std::max(img.nx, img.ny)); - ggml_tensor * global_features_2 = build_dp_ocr_clip(inp_raw, global_features_1); + ggml_tensor * global_features_2 = build_dp_ocr_clip(global_features_1); // torch global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1) global_features_1 = ggml_permute(ctx0, global_features_1,2,1,0,3); @@ -1532,7 +1532,7 @@ struct clip_graph { return gf; } - ggml_tensor * build_dp_ocr_clip(ggml_tensor * inpL, ggml_tensor * patch_embeds) { + ggml_tensor * build_dp_ocr_clip(ggml_tensor * patch_embeds) { GGML_ASSERT(model.class_embedding != nullptr); 
GGML_ASSERT(model.position_embeddings != nullptr); @@ -2466,6 +2466,8 @@ struct clip_graph { return inpL; } + // Implementation based on approach suggested by Acly + // See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091 static ggml_tensor* window_partition(ggml_context* ctx, ggml_tensor* x, int window) { auto [c, w, h, b] = x->ne; // same as @@ -2486,6 +2488,8 @@ struct clip_graph { return x; } + // Implementation based on approach suggested by Acly + // See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091 static ggml_tensor* window_unpartition(ggml_context* m, ggml_tensor* x, int w, int h, int window) { int64_t c = x->ne[0]; // same as @@ -4881,7 +4885,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str const int min_num = 2; const int max_num = 9; const int image_size = params.image_size; // typically 640 - const bool use_thumbnail = true; // mimic python's use_thumbnail + // const bool use_thumbnail = true; // mimic python's use_thumbnail // original image size const int orig_w = original_size.width; From 5e6cf3c6a838af32c3debce73425d199477b7669 Mon Sep 17 00:00:00 2001 From: bluebread Date: Fri, 21 Nov 2025 15:36:45 +0000 Subject: [PATCH 20/37] Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into sf/deepseek-ocr --- tools/mtmd/clip.cpp | 109 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 103 insertions(+), 6 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index eb3d461dac1..45cc2328c8d 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -734,8 +734,8 @@ struct clip_graph { struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf(enc_d_heads)); - struct ggml_tensor * rw = ggml_get_rel_pos(ctx0, layer.rel_pos_w, W, W); - struct ggml_tensor * rh = ggml_get_rel_pos(ctx0, layer.rel_pos_h, H, H); + struct ggml_tensor * rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); + struct ggml_tensor * rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads); @@ -745,7 +745,7 @@ struct clip_graph { 2, 1, 3)); struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r); - struct ggml_tensor * attn = ggml_add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h); + struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h, W); struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn); @@ -837,9 +837,9 @@ struct clip_graph { ggml_tensor * global_features_2 = build_dp_ocr_clip(global_features_1); // torch global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1) - global_features_1 = ggml_permute(ctx0, global_features_1,2,1,0,3); - global_features_1 = ggml_cont(ctx0, global_features_1); + global_features_1 = ggml_cont(ctx0,ggml_permute(ctx0, global_features_1,2,1,0,3)); global_features_1 = ggml_reshape_2d(ctx0, global_features_1, n_embd, n_patches); + // remove CLS token global_features_2 = ggml_view_2d(ctx0, global_features_2, n_embd, n_patches, @@ -850,6 +850,7 @@ struct clip_graph { global_features = ggml_cont(ctx0, global_features); global_features = ggml_mul_mat(ctx0, model.fc_w, global_features); global_features = ggml_add(ctx0, global_features, model.fc_b); + global_features = build_global_local_features(ctx0,global_features); ggml_build_forward_expand(gf, global_features); return gf; @@ -869,7 +870,6 @@ struct clip_graph { t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, 
n_dim) ggml_tensor * nl = ggml_cont(ctx0,ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3)); nl = ggml_repeat_4d(ctx0, nl, 64, 1, 1280, 1); // n_pos rows - nl = ggml_cont(ctx0, nl); // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim] @@ -2466,6 +2466,103 @@ struct clip_graph { return inpL; } + // attn: [k_h*k_w, q_h*q_w] +// rel_h: [q_h, q_w, k_h] +// rel_w: [q_h, q_w, k_w] + +static ggml_tensor * add_rel_pos_inplace( + ggml_context * ctx, + ggml_tensor * attn, + ggml_tensor * rel_w, + ggml_tensor * rel_h, + int q_size +) { + + ggml_tensor *attn_4d = + ggml_reshape_4d(ctx, attn, q_size,q_size, attn->ne[1], attn->ne[2]); + + ggml_tensor *rel_h_4d = + ggml_reshape_4d(ctx, rel_h, 1, q_size, attn->ne[1], attn->ne[2]); + + ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_5d + + ggml_tensor *rel_w_4d = + ggml_reshape_4d(ctx, rel_w, q_size, 1, attn->ne[1], attn->ne[2]); + + ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_5d + + ggml_tensor * result = ggml_add(ctx, attn_4d, ggml_add(ctx, rel_h_rep, rel_w_rep)); + result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]); + + + return result; +} + + +static ggml_tensor * get_rel_pos( + ggml_context * ctx, + ggml_tensor * rel_pos, // [L, C] + int q_size, + int k_size +) { + + const auto dtype = rel_pos->type; + + const int64_t L = rel_pos->ne[0]; // length + const int64_t C = rel_pos->ne[1]; // channels + + // ------------------------------------------------- + // 1) q_idx ← arange(0..q_size-1) [q_size] + // 2) k_idx ← arange(0..k_size-1) [k_size] + // ------------------------------------------------- + + + ggml_tensor * q_coord = ggml_cast(ctx, + ggml_arange(ctx, 0.0f, static_cast(q_size), 1.0f), + GGML_TYPE_F32); // [q_size] + ggml_tensor * k_coord = ggml_cast(ctx, + ggml_arange(ctx, 0.0f, static_cast(k_size), 1.0f), + GGML_TYPE_F32); // [k_size] + + ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, q_size, k_size); + q_coord = ggml_cont(ctx,ggml_repeat(ctx, q_coord, rel)); // [q_size, k_size] + + // broadcast reshape: + k_coord = ggml_reshape_2d(ctx, k_coord, 1, k_size); // [1, k_size] + k_coord = ggml_cont(ctx,ggml_repeat(ctx, k_coord, rel)); // [q_size, k_size] + + // ------------------------------------------------- + // relative_coords = q - k + (k_size - 1) // SAME as PyTorch when no scaling + // ------------------------------------------------- + rel = ggml_sub(ctx, k_coord, q_coord); // [q_size, k_size] + + rel = ggml_scale_bias(ctx, rel, 1.0f, static_cast(k_size) - 1.0f); // [q_size, k_size] + + // ------------------------------------------------- + // clamp to [0, L-1] and cast to int32 (for ggml_get_rows) + // ------------------------------------------------- + + ggml_tensor * rel_clamped = ggml_clamp(ctx, rel, 0, static_cast(L - 1)); + + ggml_tensor * idx_2d = ggml_cast(ctx, rel_clamped, GGML_TYPE_I32); // [q_size, k_size] + + // flatten to 1D for ggml_get_rows + const int64_t qk = static_cast(q_size) * static_cast(k_size); + ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk] + + // ------------------------------------------------- + // Gather from rel_pos → [qk, C] + // ------------------------------------------------- + ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat); // [qk, C] + + // reshape to final output → [q_size, k_size, C] + ggml_tensor * out = ggml_reshape_3d(ctx, gathered,rel_pos->ne[0], + q_size, + k_size); + + return out; // [q_size, k_size, C] +} + // 
Implementation based on approach suggested by Acly // See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091 static ggml_tensor* window_partition(ggml_context* ctx, ggml_tensor* x, int window) { From 7e9fbeccc5c28a8464ace5e4e22dfef213cb3c66 Mon Sep 17 00:00:00 2001 From: bluebread Date: Fri, 21 Nov 2025 17:12:12 +0000 Subject: [PATCH 21/37] mtmd: fix get_rel_pos --- tools/mtmd/clip.cpp | 174 +++++++++++++++++++++++--------------------- 1 file changed, 90 insertions(+), 84 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 45cc2328c8d..a4bf717d0b0 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2467,101 +2467,107 @@ struct clip_graph { } // attn: [k_h*k_w, q_h*q_w] -// rel_h: [q_h, q_w, k_h] -// rel_w: [q_h, q_w, k_w] - -static ggml_tensor * add_rel_pos_inplace( - ggml_context * ctx, - ggml_tensor * attn, - ggml_tensor * rel_w, - ggml_tensor * rel_h, - int q_size -) { - - ggml_tensor *attn_4d = - ggml_reshape_4d(ctx, attn, q_size,q_size, attn->ne[1], attn->ne[2]); - - ggml_tensor *rel_h_4d = - ggml_reshape_4d(ctx, rel_h, 1, q_size, attn->ne[1], attn->ne[2]); - - ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_5d - - ggml_tensor *rel_w_4d = - ggml_reshape_4d(ctx, rel_w, q_size, 1, attn->ne[1], attn->ne[2]); - - ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_5d - - ggml_tensor * result = ggml_add(ctx, attn_4d, ggml_add(ctx, rel_h_rep, rel_w_rep)); - result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]); - - - return result; -} - - -static ggml_tensor * get_rel_pos( - ggml_context * ctx, - ggml_tensor * rel_pos, // [L, C] - int q_size, - int k_size -) { - - const auto dtype = rel_pos->type; - - const int64_t L = rel_pos->ne[0]; // length - const int64_t C = rel_pos->ne[1]; // channels - - // ------------------------------------------------- - // 1) q_idx ← arange(0..q_size-1) [q_size] - // 2) k_idx ← arange(0..k_size-1) [k_size] - // ------------------------------------------------- - - - ggml_tensor * q_coord = ggml_cast(ctx, - ggml_arange(ctx, 0.0f, static_cast(q_size), 1.0f), - GGML_TYPE_F32); // [q_size] - ggml_tensor * k_coord = ggml_cast(ctx, - ggml_arange(ctx, 0.0f, static_cast(k_size), 1.0f), - GGML_TYPE_F32); // [k_size] + // rel_h: [q_h, q_w, k_h] + // rel_w: [q_h, q_w, k_w] + + static ggml_tensor * add_rel_pos_inplace( + ggml_context * ctx, + ggml_tensor * attn, + ggml_tensor * rel_w, + ggml_tensor * rel_h, + int q_size + ) { - ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, q_size, k_size); - q_coord = ggml_cont(ctx,ggml_repeat(ctx, q_coord, rel)); // [q_size, k_size] + ggml_tensor *attn_4d = + ggml_reshape_4d(ctx, attn, q_size,q_size, attn->ne[1], attn->ne[2]); - // broadcast reshape: - k_coord = ggml_reshape_2d(ctx, k_coord, 1, k_size); // [1, k_size] - k_coord = ggml_cont(ctx,ggml_repeat(ctx, k_coord, rel)); // [q_size, k_size] + ggml_tensor *rel_h_4d = + ggml_reshape_4d(ctx, rel_h, 1, q_size, attn->ne[1], attn->ne[2]); - // ------------------------------------------------- - // relative_coords = q - k + (k_size - 1) // SAME as PyTorch when no scaling - // ------------------------------------------------- - rel = ggml_sub(ctx, k_coord, q_coord); // [q_size, k_size] + ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_5d - rel = ggml_scale_bias(ctx, rel, 1.0f, static_cast(k_size) - 1.0f); // [q_size, k_size] + ggml_tensor *rel_w_4d = + ggml_reshape_4d(ctx, rel_w, q_size, 
1, attn->ne[1], attn->ne[2]); - // ------------------------------------------------- - // clamp to [0, L-1] and cast to int32 (for ggml_get_rows) - // ------------------------------------------------- + ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_5d - ggml_tensor * rel_clamped = ggml_clamp(ctx, rel, 0, static_cast(L - 1)); + ggml_tensor * result = ggml_add(ctx, attn_4d, ggml_add(ctx, rel_h_rep, rel_w_rep)); + result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]); - ggml_tensor * idx_2d = ggml_cast(ctx, rel_clamped, GGML_TYPE_I32); // [q_size, k_size] - // flatten to 1D for ggml_get_rows - const int64_t qk = static_cast(q_size) * static_cast(k_size); - ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk] + return result; + } - // ------------------------------------------------- - // Gather from rel_pos → [qk, C] - // ------------------------------------------------- - ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat); // [qk, C] - // reshape to final output → [q_size, k_size, C] - ggml_tensor * out = ggml_reshape_3d(ctx, gathered,rel_pos->ne[0], - q_size, - k_size); + static ggml_tensor * get_rel_pos( + ggml_context * ctx, + ggml_tensor * rel_pos, // [L, C] + int q_size, + int k_size + ) { + const int64_t C = rel_pos->ne[0]; // channels + const int64_t L = rel_pos->ne[1]; // length + + GGML_ASSERT(2*std::max(q_size, k_size) - 1 == L); + + // ------------------------------------------------- + // 1) q_idx ← arange(0..q_size-1) [q_size] + // 2) k_idx ← arange(0..k_size-1) [k_size] + // ------------------------------------------------- + + // ggml_arange always returns FP32 tensor + ggml_tensor * q_coord = ggml_arange(ctx, 0.0f, static_cast(q_size), 1.0f); // [q_size] + ggml_tensor * k_coord = ggml_arange(ctx, 0.0f, static_cast(k_size), 1.0f); // [k_size] + ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, k_size, q_size); + + // broadcast reshape: + q_coord = ggml_cont(ctx, + ggml_repeat(ctx, + ggml_reshape_2d(ctx, q_coord, 1, q_size), // [q_size, 1] + rel + ) + ); // [q_size, k_size] + k_coord = ggml_cont(ctx, ggml_repeat(ctx, k_coord, rel)); // [q_size, k_size] + + // This wouldn't be triggered in DeepSeek-OCR. Just for compatibility with + // the original implementation. 
+ if (q_size != k_size) { + q_coord = ggml_scale_inplace(ctx, q_coord, std::max((float)k_size/q_size, 1.0f)); + k_coord = ggml_scale_inplace(ctx, k_coord, std::max((float)q_size/k_size, 1.0f)); + } - return out; // [q_size, k_size, C] -} + // ------------------------------------------------- + // relative_coords = q - k + (k_size - 1) // SAME as PyTorch when no scaling + // ------------------------------------------------- + + rel = ggml_sub(ctx, q_coord, k_coord); // [q_size, k_size] + rel = ggml_scale_bias(ctx, rel, 1.0f, static_cast(k_size) - 1.0f); // [q_size, k_size] + // Clamp to [0, L-1] range for valid indexing + rel = ggml_clamp(ctx, rel, 0.0f, static_cast(rel_pos->ne[1] - 1)); + + // ------------------------------------------------- + // clamp to [0, L-1] and cast to int32 (for ggml_get_rows) + // ------------------------------------------------- + + ggml_tensor * idx_2d = ggml_cast(ctx, rel, GGML_TYPE_I32); // [q_size, k_size] + + // Gather from rel_pos → [qk, C] + // ------------------------------------------------- + + // flatten to 1D for ggml_get_rows + int qk = q_size * k_size; + ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk] + ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat); // [qk, C] + + // ------------------------------------------------- + // Gather from rel_pos → [qk, C] + // ------------------------------------------------- + + ggml_tensor * out = ggml_reshape_3d(ctx, gathered, C, k_size, q_size); // [qk, C] + + + return out; // [q_size, k_size, C] + } // Implementation based on approach suggested by Acly // See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091 From 7b8d735c901666d91f211f380ca2edc625fd72c1 Mon Sep 17 00:00:00 2001 From: bluebread Date: Fri, 21 Nov 2025 18:04:01 +0000 Subject: [PATCH 22/37] mtmd: fixed the wrong scaler for get_rel_pos --- tools/mtmd/clip.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index a4bf717d0b0..f291894b6ea 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2529,11 +2529,14 @@ struct clip_graph { ); // [q_size, k_size] k_coord = ggml_cont(ctx, ggml_repeat(ctx, k_coord, rel)); // [q_size, k_size] + float q_scale = std::max((float)k_size/q_size, 1.0f); + float k_scale = std::max((float)q_size/k_size, 1.0f); + // This wouldn't be triggered in DeepSeek-OCR. Just for compatibility with // the original implementation. 
if (q_size != k_size) { - q_coord = ggml_scale_inplace(ctx, q_coord, std::max((float)k_size/q_size, 1.0f)); - k_coord = ggml_scale_inplace(ctx, k_coord, std::max((float)q_size/k_size, 1.0f)); + q_coord = ggml_scale_inplace(ctx, q_coord, q_scale); + k_coord = ggml_scale_inplace(ctx, k_coord, k_scale); } // ------------------------------------------------- @@ -2541,7 +2544,7 @@ struct clip_graph { // ------------------------------------------------- rel = ggml_sub(ctx, q_coord, k_coord); // [q_size, k_size] - rel = ggml_scale_bias(ctx, rel, 1.0f, static_cast(k_size) - 1.0f); // [q_size, k_size] + rel = ggml_scale_bias(ctx, rel, 1.0f, (k_size - 1.0f)*k_scale); // [q_size, k_size] // Clamp to [0, L-1] range for valid indexing rel = ggml_clamp(ctx, rel, 0.0f, static_cast(rel_pos->ne[1] - 1)); From 86f111f8b76ec0b696f5a1597b86b434ee71a828 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Fri, 21 Nov 2025 20:42:14 +0100 Subject: [PATCH 23/37] image encoding technically works but the output can't be checked singe image decoding fails --- tools/mtmd/clip.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index f8dbe39a25a..787f00acaaa 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -819,6 +819,7 @@ struct clip_graph { // TODO: better implementation layer = ggml_permute(ctx0, ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3)), eps), 2, 0, 1, 3); + layer = ggml_cont(ctx0, layer); layer = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, ggml_reshape_3d(ctx0, w, 1, 1, n_channels), layer), layer), @@ -1537,8 +1538,7 @@ struct clip_graph { GGML_ASSERT(model.position_embeddings != nullptr); const int n_pos = n_patches + 1; - ggml_tensor * inp = ggml_permute(ctx0, patch_embeds,2,1,0,3); - inp = ggml_cont(ctx0, inp); + ggml_tensor * inp = ggml_cont(ctx0,ggml_permute(ctx0, patch_embeds,2,1,0,3)); inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches); @@ -1550,7 +1550,7 @@ struct clip_graph { norm_type norm_t = NORM_TYPE_NORMAL; // for selecting learned pos embd, used by ViT - struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); + ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); cb(positions, "positions", -1); ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, model.position_embeddings, positions); @@ -5218,7 +5218,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im } break; case PROJECTOR_TYPE_DEEPSEEKOCR: { - n_patches += 2; + int x_patch = img->nx / (params.patch_size); + + n_patches += x_patch + 1; + } break; default: GGML_ABORT("unsupported projector type"); From effe66958e25d860ffb12715e00ff313d821b248 Mon Sep 17 00:00:00 2001 From: bluebread Date: Sat, 22 Nov 2025 02:09:37 +0000 Subject: [PATCH 24/37] mtmd: minor changed --- tools/mtmd/clip.cpp | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index f291894b6ea..23d86f95751 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -739,13 +739,14 @@ struct clip_graph { struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads); - struct ggml_tensor * rel_w = ggml_cont( - ctx0, - ggml_permute(ctx0, ggml_mul_mat(ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))), 0, - 2, 1, 3)); + struct ggml_tensor * rel_w = ggml_cont(ctx0,ggml_permute(ctx0, + ggml_mul_mat(ctx0, + rw, + 
ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))), + 0, 2, 1, 3)); struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r); - struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h, W); + struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h); struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn); @@ -2466,7 +2467,7 @@ struct clip_graph { return inpL; } - // attn: [k_h*k_w, q_h*q_w] + // attn: [q_h*q_w, k_h*k_w] // rel_h: [q_h, q_w, k_h] // rel_w: [q_h, q_w, k_w] @@ -2474,24 +2475,29 @@ struct clip_graph { ggml_context * ctx, ggml_tensor * attn, ggml_tensor * rel_w, - ggml_tensor * rel_h, - int q_size + ggml_tensor * rel_h ) { + const int k_w = rel_w->ne[0]; + const int k_h = rel_h->ne[0]; + const int q_w = rel_h->ne[1]; + const int q_h = rel_h->ne[2]; + + GGML_ASSERT(q_w == rel_w->ne[1]); + GGML_ASSERT(q_h == rel_w->ne[2]); + GGML_ASSERT(attn->ne[0] == k_h*k_w); + GGML_ASSERT(attn->ne[1] == q_h*q_w); - ggml_tensor *attn_4d = - ggml_reshape_4d(ctx, attn, q_size,q_size, attn->ne[1], attn->ne[2]); + ggml_tensor *attn_4d = ggml_reshape_4d(ctx, attn, k_w, k_h, attn->ne[1], attn->ne[2]); - ggml_tensor *rel_h_4d = - ggml_reshape_4d(ctx, rel_h, 1, q_size, attn->ne[1], attn->ne[2]); + ggml_tensor *rel_h_4d = ggml_reshape_4d(ctx, rel_h, 1, k_h, attn->ne[1], attn->ne[2]); ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_5d - ggml_tensor *rel_w_4d = - ggml_reshape_4d(ctx, rel_w, q_size, 1, attn->ne[1], attn->ne[2]); + ggml_tensor *rel_w_4d = ggml_reshape_4d(ctx, rel_w, k_w, 1, attn->ne[1], attn->ne[2]); ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_5d - ggml_tensor * result = ggml_add(ctx, attn_4d, ggml_add(ctx, rel_h_rep, rel_w_rep)); + ggml_tensor * result = ggml_add_inplace(ctx, attn_4d, ggml_add_inplace(ctx, rel_h_rep, rel_w_rep)); result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]); From ee8a1488f97c06fc13c7ee6e9233e78155780a06 Mon Sep 17 00:00:00 2001 From: bluebread Date: Sat, 22 Nov 2025 15:48:13 +0000 Subject: [PATCH 25/37] mtmd: add native resolution support --- tools/mtmd/clip.cpp | 115 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 109 insertions(+), 6 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 22441d0f694..1399c8a30e5 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -194,6 +194,8 @@ struct clip_hparams { int32_t attn_window_size = 0; int32_t n_wa_pattern = 0; + bool crop_mode = false; + // audio int32_t n_mel_bins = 0; // whisper preprocessor int32_t proj_stack_factor = 0; // ultravox @@ -3337,11 +3339,12 @@ struct clip_model_loader { log_ffn_op = "gelu_erf"; // temporary solution for logging } break; case PROJECTOR_TYPE_DEEPSEEKOCR: - { - hparams.patch_size = 16; - hparams.image_size = 1024; - hparams.warmup_image_size = 1024; - } break; + { + hparams.patch_size = 16; + hparams.image_size = 1024; + hparams.warmup_image_size = 1024; + hparams.crop_mode = false; + } break; default: break; } @@ -4992,7 +4995,107 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str } } break; case PROJECTOR_TYPE_DEEPSEEKOCR: - { + if (!params.crop_mode) { + /* Native Resolution (Tiny/Small/Base/Large) */ + + const int native_resolutions[] = { + 512 /* tiny */, 640 /* small */, 1024 /* base */, 1280 /* large */ + }; + // original image size + const int orig_w = original_size.width; + const int orig_h = original_size.height; + const int orig_area = 
orig_h * orig_w; + + // mode selection logic (find most suitable resolution) + int mode_i = 0; + int min_diff = orig_area; + + for (int i = 0; i < 4; i++) { + int r = native_resolutions[i]; + if (std::abs(orig_area - r*r) < min_diff) { + mode_i = i; + min_diff = std::abs(orig_area - r*r); + } + } + + const int image_size = native_resolutions[mode_i]; + + if (mode_i < 2) { + // TINY/SMALL MODE: Direct resize (no slicing) + // Just resize the image to image_size × image_size + + clip_image_u8_ptr resized_img(clip_image_u8_init()); + img_tool::resize(*img, *resized_img, + clip_image_size{image_size, image_size}, + img_tool::RESIZE_ALGO_BICUBIC); // Match PIL default + + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*resized_img, *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + + res_imgs->grid_x = 1; + res_imgs->grid_y = 1; + } + else { + // BASE/LARGE MODE: Resize with aspect ratio + padding + // Resize maintaining aspect ratio, then pad to square + + float scale = std::min( + static_cast(image_size) / orig_w, + static_cast(image_size) / orig_h + ); + int new_w = static_cast(orig_w * scale); + int new_h = static_cast(orig_h * scale); + + clip_image_u8_ptr scaled_img(clip_image_u8_init()); + img_tool::resize(*img, *scaled_img, clip_image_size{new_w, new_h}, + img_tool::RESIZE_ALGO_BICUBIC); + + // Use mean color for padding + unsigned char pad_r = static_cast(params.image_mean[0] * 255.0f); + unsigned char pad_g = static_cast(params.image_mean[1] * 255.0f); + unsigned char pad_b = static_cast(params.image_mean[2] * 255.0f); + + // Step 2: Pad to image_size × image_size (center padding) + clip_image_u8_ptr padded_img(clip_image_u8_init()); + padded_img->nx = image_size; + padded_img->ny = image_size; + padded_img->buf.resize(image_size * image_size * 3); // black padding + + // Fill with mean color + for (int i = 0; i < image_size * image_size; ++i) { + padded_img->buf[i * 3 + 0] = pad_r; + padded_img->buf[i * 3 + 1] = pad_g; + padded_img->buf[i * 3 + 2] = pad_b; + } + + // Calculate padding offsets (center the image) + int pad_x = (image_size - new_w) / 2; + int pad_y = (image_size - new_h) / 2; + + // Copy scaled image into padded canvas + for (int y = 0; y < new_h; ++y) { + for (int x = 0; x < new_w; ++x) { + int src_idx = (y * new_w + x) * 3; + int dst_idx = ((y + pad_y) * image_size + (x + pad_x)) * 3; + padded_img->buf[dst_idx + 0] = scaled_img->buf[src_idx + 0]; + padded_img->buf[dst_idx + 1] = scaled_img->buf[src_idx + 1]; + padded_img->buf[dst_idx + 2] = scaled_img->buf[src_idx + 2]; + } + } + + // Step 3: Normalize and output + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*padded_img, *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + + res_imgs->grid_x = 1; + res_imgs->grid_y = 1; + } + } + else { + /* Dynamic Resolution (Gundam/Gundam-M) */ + // configurable, or read from params const int min_num = 2; const int max_num = 9; From 4cfa15fcd718700f7cee0c8c619238d5b50d0348 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Sat, 22 Nov 2025 16:57:34 +0100 Subject: [PATCH 26/37] - image encoding debugged - issues fixed mainly related wrong config like n_patches etc. 
- configs need to be corrected in the converter --- tools/mtmd/clip.cpp | 67 ++++++++++++++++++++++++++++----------------- 1 file changed, 42 insertions(+), 25 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 22441d0f694..37e6e2a1067 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -739,8 +739,8 @@ struct clip_graph { struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads); - struct ggml_tensor * rel_w = ggml_cont(ctx0,ggml_permute(ctx0, - ggml_mul_mat(ctx0, + struct ggml_tensor * rel_w = ggml_cont(ctx0,ggml_permute(ctx0, + ggml_mul_mat(ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))), 0, 2, 1, 3)); @@ -801,9 +801,8 @@ struct clip_graph { cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_3_w, model.neck_3_b, hparams.eps); - //TODO : check conv padding - cur = ggml_conv_2d_s1_ph(ctx0, model.net_2, cur); - cur = ggml_conv_2d_s1_ph(ctx0, model.net_3, cur); + cur = ggml_conv_2d(ctx0, model.net_2, cur, 2,2,1,1, 1,1); + cur = ggml_conv_2d(ctx0, model.net_3, cur, 2,2,1,1, 1,1); ggml_build_forward_expand(gf, cur); return cur; @@ -838,22 +837,27 @@ struct clip_graph { ggml_tensor * global_features_2 = build_dp_ocr_clip(global_features_1); + // FIXME remove n_patches is hardcoded + int clip_n_patches = 256; // FIXME hardcoded for sam 1024x1024 with 16x16 patches + // torch global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1) global_features_1 = ggml_cont(ctx0,ggml_permute(ctx0, global_features_1,2,1,0,3)); - global_features_1 = ggml_reshape_2d(ctx0, global_features_1, n_embd, n_patches); + // flatten 2nd and 3rd dims + global_features_1 = ggml_reshape_2d(ctx0, global_features_1, global_features_1->ne[0], clip_n_patches); // remove CLS token global_features_2 = ggml_view_2d(ctx0, global_features_2, - n_embd, n_patches, + n_embd, clip_n_patches, ggml_row_size(global_features_2->type, n_embd), 0); ggml_tensor * global_features = ggml_concat(ctx0, global_features_2, global_features_1, 1); - global_features = ggml_reshape_2d(ctx0, global_features, 2* n_embd, n_patches); + global_features = ggml_reshape_2d(ctx0, global_features, 2* n_embd,clip_n_patches); global_features = ggml_cont(ctx0, global_features); global_features = ggml_mul_mat(ctx0, model.fc_w, global_features); global_features = ggml_add(ctx0, global_features, model.fc_b); global_features = build_global_local_features(ctx0,global_features); + global_features = ggml_cont(ctx0, ggml_permute(ctx0, global_features, 1, 0, 2, 3)); ggml_build_forward_expand(gf, global_features); return gf; } @@ -868,16 +872,16 @@ struct clip_graph { GGML_ASSERT(model.view_seperator != nullptr); // 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim] - ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, 1280, 64, 64, 1); // (n_dim, w, h) + ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, 1280, 16, 16, 1); // (n_dim, w, h) t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, n_dim) ggml_tensor * nl = ggml_cont(ctx0,ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3)); - nl = ggml_repeat_4d(ctx0, nl, 64, 1, 1280, 1); // n_pos rows + nl = ggml_repeat_4d(ctx0, nl, 16, 1, 1280, 1); // n_pos rows // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim] t = ggml_concat(ctx0, t, nl, 1); // (h, w+1, n_dim) - t = ggml_reshape_2d(ctx0, t, 1280, 64 * (64 + 1)); // (n_dim, h*(w+1)) + t = ggml_reshape_2d(ctx0, t, 1280, 16 * (16 + 1)); // (n_dim, h*(w+1)) // 5) append view_separator 
as an extra "token": @@ -1538,9 +1542,12 @@ struct clip_graph { GGML_ASSERT(model.class_embedding != nullptr); GGML_ASSERT(model.position_embeddings != nullptr); - const int n_pos = n_patches + 1; - ggml_tensor * inp = ggml_cont(ctx0,ggml_permute(ctx0, patch_embeds,2,1,0,3)); - inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches); + ggml_tensor * inp = ggml_cpy(ctx0, patch_embeds, ggml_dup_tensor(ctx0, patch_embeds)); + + + const int n_pos = 257; // +1 for [CLS] + inp = ggml_cont(ctx0,ggml_permute(ctx0, inp,2,1,0,3)); + inp = ggml_reshape_2d(ctx0, inp, n_embd, inp->ne[1]*inp->ne[2]*inp->ne[3]); @@ -1552,7 +1559,9 @@ struct clip_graph { // for selecting learned pos embd, used by ViT ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); - cb(positions, "positions", -1); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, model.position_embeddings, positions); @@ -2525,7 +2534,7 @@ struct clip_graph { ggml_tensor * q_coord = ggml_arange(ctx, 0.0f, static_cast(q_size), 1.0f); // [q_size] ggml_tensor * k_coord = ggml_arange(ctx, 0.0f, static_cast(k_size), 1.0f); // [k_size] ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, k_size, q_size); - + // broadcast reshape: q_coord = ggml_cont(ctx, ggml_repeat(ctx, @@ -2538,8 +2547,8 @@ struct clip_graph { float q_scale = std::max((float)k_size/q_size, 1.0f); float k_scale = std::max((float)q_size/k_size, 1.0f); - // This wouldn't be triggered in DeepSeek-OCR. Just for compatibility with - // the original implementation. + // This wouldn't be triggered in DeepSeek-OCR. Just for compatibility with + // the original implementation. if (q_size != k_size) { q_coord = ggml_scale_inplace(ctx, q_coord, q_scale); k_coord = ggml_scale_inplace(ctx, k_coord, k_scale); @@ -2548,7 +2557,7 @@ struct clip_graph { // ------------------------------------------------- // relative_coords = q - k + (k_size - 1) // SAME as PyTorch when no scaling // ------------------------------------------------- - + rel = ggml_sub(ctx, q_coord, k_coord); // [q_size, k_size] rel = ggml_scale_bias(ctx, rel, 1.0f, (k_size - 1.0f)*k_scale); // [q_size, k_size] // Clamp to [0, L-1] range for valid indexing @@ -2559,10 +2568,10 @@ struct clip_graph { // ------------------------------------------------- ggml_tensor * idx_2d = ggml_cast(ctx, rel, GGML_TYPE_I32); // [q_size, k_size] - + // Gather from rel_pos → [qk, C] // ------------------------------------------------- - + // flatten to 1D for ggml_get_rows int qk = q_size * k_size; ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk] @@ -5237,9 +5246,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im } break; case PROJECTOR_TYPE_DEEPSEEKOCR: { - int x_patch = img->nx / (params.patch_size); - - n_patches += x_patch + 1; + n_patches = 1280; } break; default: @@ -5573,10 +5580,20 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_JANUS_PRO: case PROJECTOR_TYPE_COGVLM: - case PROJECTOR_TYPE_DEEPSEEKOCR: { // do nothing } break; + case PROJECTOR_TYPE_DEEPSEEKOCR: + { + //FIXME we need correct this when all model configs are set correctly + //n_patch is not correct right now + int32_t n_pos = 16 * 16 + 1; //hardcode for now + std::vector positions(n_pos); + for (int i = 0; i < n_pos; i++) { + positions[i] = i; + } + set_input_i32("positions", positions); + } break; case PROJECTOR_TYPE_LLAMA4: { // set the 2D positions 
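[Editor's note on the native-resolution preprocessing introduced in PATCH 25 and kept through PATCH 26: the mode selection picks the candidate square resolution (512/640/1024/1280) whose area is closest to the input image's area, and the base/large path performs an aspect-preserving resize followed by mean-color center padding to a square canvas. The snippet below is a minimal standalone C++ sketch of just that arithmetic, for readability of the diff above; pick_native_mode, plan_pad_to_square and canvas_plan are illustrative names used only in this note and are not part of clip.cpp.]

    #include <algorithm>
    #include <cstdlib>

    // Candidate square resolutions (tiny/small/base/large), as in the patch above.
    static const int kNativeResolutions[4] = { 512, 640, 1024, 1280 };

    // Pick the mode whose squared resolution is closest in area to the original image.
    static int pick_native_mode(int orig_w, int orig_h) {
        const int orig_area = orig_w * orig_h;
        int best      = 0;
        int best_diff = std::abs(orig_area - kNativeResolutions[0] * kNativeResolutions[0]);
        for (int i = 1; i < 4; ++i) {
            const int r    = kNativeResolutions[i];
            const int diff = std::abs(orig_area - r * r);
            if (diff < best_diff) { best = i; best_diff = diff; }
        }
        return best;
    }

    // Plan for the base/large path: aspect-preserving resize, then center the result
    // on a square canvas pre-filled with the mean color (the padding used above).
    struct canvas_plan {
        int new_w, new_h;   // scaled image size
        int pad_x, pad_y;   // top-left offset of the scaled image inside the canvas
        int canvas;         // canvas side length
    };

    static canvas_plan plan_pad_to_square(int orig_w, int orig_h, int canvas) {
        const float scale = std::min((float) canvas / orig_w, (float) canvas / orig_h);
        canvas_plan p;
        p.new_w  = (int) (orig_w * scale);
        p.new_h  = (int) (orig_h * scale);
        p.pad_x  = (canvas - p.new_w) / 2;
        p.pad_y  = (canvas - p.new_h) / 2;
        p.canvas = canvas;
        return p;
    }

[For example, a 1000x1400 page has an area closest to 1280*1280, so pick_native_mode selects the large mode; plan_pad_to_square then scales it to 914x1280 and centers it with 183 pixels of padding on each horizontal side. End of editor's note.]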
From 3f71188303d9bdab9b1b51b786a7b3ecf55ee944 Mon Sep 17 00:00:00 2001 From: bluebread Date: Sun, 23 Nov 2025 09:22:00 +0000 Subject: [PATCH 27/37] mtmd: correct token order --- src/llama-vocab.cpp | 1 + tools/mtmd/mtmd-cli.cpp | 15 ++++++++++++--- tools/mtmd/mtmd.cpp | 4 ++++ tools/mtmd/mtmd.h | 3 +++ 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 735c5d547f9..2634ab7c5ec 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -2347,6 +2347,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "_" || t.first == "<|end_of_text|>" || t.first == "" // smoldocling + || t.first == "<|end▁of▁sentence|>" // deepseek-ocr ) { special_eog_ids.insert(t.second); if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 3e19e95958a..8ff93f08b9d 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -222,14 +222,18 @@ static std::string chat_add_and_format(mtmd_cli_context & ctx, common_chat_msg & static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) { bool add_bos = ctx.chat_history.empty(); - auto formatted_chat = chat_add_and_format(ctx, msg); - LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.c_str()); mtmd_input_text text; - text.text = formatted_chat.c_str(); + text.text = msg.content.c_str(); text.add_special = add_bos; text.parse_special = true; + if (!mtmd_is_deepseekocr(ctx.ctx_vision.get())) { + auto formatted_chat = chat_add_and_format(ctx, msg); + LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.c_str()); + text.text = formatted_chat.c_str(); + } + if (g_is_interrupted) return 0; mtmd::input_chunks chunks(mtmd_input_chunks_init()); @@ -332,6 +336,11 @@ int main(int argc, char ** argv) { } } else { + if (mtmd_is_deepseekocr(ctx.ctx_vision.get())) { + LOG_ERR("\n DeepSeek-OCR doesn't support chat mode."); + return 1; + } + LOG("\n Running in chat mode, available commands:"); if (mtmd_support_vision(ctx.ctx_vision.get())) { LOG("\n /image load an image"); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 16349e8f406..994013bea91 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -864,6 +864,10 @@ int mtmd_get_audio_bitrate(mtmd_context * ctx) { return 16000; // 16kHz } +bool mtmd_is_deepseekocr(mtmd_context * ctx) { + return ctx->ctx_v && clip_is_deepseekocr(ctx->ctx_v); +} + // // public API functions // diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index 775fba6215c..99fdcd46501 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -117,6 +117,9 @@ MTMD_API bool mtmd_support_audio(mtmd_context * ctx); // return -1 if audio is not supported MTMD_API int mtmd_get_audio_bitrate(mtmd_context * ctx); +// whether the current model is DeepSeek-OCR +MTMD_API bool mtmd_is_deepseekocr(mtmd_context * ctx); + // mtmd_bitmap // // if bitmap is image: From 206f8abc3c5d15736dde0779dfbb020b354632c3 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Sun, 23 Nov 2025 20:27:02 +0100 Subject: [PATCH 28/37] - dynamic resizing - changes are concerning PR https://github.com/sfallah/llama.cpp/pull/4 --- tools/mtmd/clip.cpp | 104 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 81 insertions(+), 23 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index cc8a4461892..82d0b46a47b 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -676,7 +676,25 @@ struct clip_graph { const int 
enc_n_patches = enc_image_size / enc_patch_size; // 64 ggml_tensor * inpL = build_enc_inp(inp_raw, enc_patch_size, enc_n_patches, enc_n_embd); - ggml_tensor * cur = ggml_add(ctx0, inpL, model.pos_embed); + ggml_tensor * cur = nullptr; + + const auto tgt_size = inpL->ne[1]; + const auto str_size = model.pos_embed->ne[1]; + if (str_size != tgt_size) { + ggml_tensor * new_pos_embed = ggml_interpolate( + ctx0, + model.pos_embed, + tgt_size, + tgt_size, + enc_n_embd, + 1, + ggml_scale_mode::GGML_SCALE_MODE_BICUBIC + ); + new_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embed, 2,1,0,3)); + cur = ggml_add(ctx0, inpL, new_pos_embed); + } else { + cur = ggml_add(ctx0, inpL, model.pos_embed); + } // loop over layers for (int il = 0; il < _depth; il++) { @@ -840,10 +858,11 @@ struct clip_graph { ggml_tensor * global_features_2 = build_dp_ocr_clip(global_features_1); // FIXME remove n_patches is hardcoded - int clip_n_patches = 256; // FIXME hardcoded for sam 1024x1024 with 16x16 patches // torch global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1) global_features_1 = ggml_cont(ctx0,ggml_permute(ctx0, global_features_1,2,1,0,3)); + int clip_n_patches = global_features_1->ne[1] * global_features_1->ne[2]; + // flatten 2nd and 3rd dims global_features_1 = ggml_reshape_2d(ctx0, global_features_1, global_features_1->ne[0], clip_n_patches); @@ -874,21 +893,24 @@ struct clip_graph { GGML_ASSERT(model.view_seperator != nullptr); // 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim] - ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, 1280, 16, 16, 1); // (n_dim, w, h) + const auto h = static_cast(std::sqrt(static_cast(global_features->ne[1]))); + const auto w = h; + const auto n_dim = global_features->ne[0]; + ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, n_dim, h, w, 1); // (n_dim, w, h) t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, n_dim) ggml_tensor * nl = ggml_cont(ctx0,ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3)); - nl = ggml_repeat_4d(ctx0, nl, 16, 1, 1280, 1); // n_pos rows + nl = ggml_repeat_4d(ctx0, nl, h, 1, n_dim, 1); // n_pos rows // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim] t = ggml_concat(ctx0, t, nl, 1); // (h, w+1, n_dim) - t = ggml_reshape_2d(ctx0, t, 1280, 16 * (16 + 1)); // (n_dim, h*(w+1)) + t = ggml_reshape_2d(ctx0, t, n_dim, h* (h + 1)); // (n_dim, h*(w+1)) // 5) append view_separator as an extra "token": // view_separator: [n_dim] -> [n_dim, 1] - ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, 1280, 1); // (n_dim, 1) + ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1) // concat along token dimension (dim=1): t = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1) @@ -1547,10 +1569,35 @@ struct clip_graph { ggml_tensor * inp = ggml_cpy(ctx0, patch_embeds, ggml_dup_tensor(ctx0, patch_embeds)); - const int n_pos = 257; // +1 for [CLS] inp = ggml_cont(ctx0,ggml_permute(ctx0, inp,2,1,0,3)); inp = ggml_reshape_2d(ctx0, inp, n_embd, inp->ne[1]*inp->ne[2]*inp->ne[3]); + ggml_tensor * new_pos_embd = ggml_cpy(ctx0, model.position_embeddings, ggml_dup_tensor(ctx0, model.position_embeddings)); + + int n_pos = new_pos_embd->ne[1]; // +1 for [CLS] + const auto tgt_size = static_cast(std::sqrt(inp->ne[1])); + const auto src_size = static_cast(std::sqrt(n_pos - 1)); + + + if (tgt_size != src_size) { + //ggml_tensor * old_pos_embd = ggml_new_tensor_2d(ctx0, model.position_embeddings->type, 
model.position_embeddings->ne[0], str_size * str_size); + ggml_tensor * old_pos_embd = ggml_view_2d(ctx0, new_pos_embd, + new_pos_embd->ne[0], src_size * src_size, + ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), 0); + ggml_tensor * cls_tok = ggml_view_2d(ctx0, new_pos_embd, + new_pos_embd->ne[0], 1, + ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), src_size * src_size); + new_pos_embd = ggml_interpolate(ctx0, + old_pos_embd, + tgt_size, + tgt_size, + new_pos_embd->ne[0], 1, GGML_SCALE_MODE_BICUBIC); + new_pos_embd = ggml_reshape_3d(ctx0, new_pos_embd, n_embd, tgt_size * tgt_size, 1); + //new_pos_embd = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embd, 2,1,0,3)); + new_pos_embd = ggml_concat(ctx0, new_pos_embd, cls_tok, 1); + n_pos = tgt_size * tgt_size + 1; + } + // add CLS token @@ -1560,11 +1607,8 @@ struct clip_graph { norm_type norm_t = NORM_TYPE_NORMAL; // for selecting learned pos embd, used by ViT - ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); - ggml_set_name(positions, "positions"); - ggml_set_input(positions); - - ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, model.position_embeddings, positions); + ggml_tensor * positions = ggml_cast(ctx0, ggml_arange(ctx0, 0, n_pos, 1), GGML_TYPE_I32); + ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, new_pos_embd, positions); ggml_tensor * cur = build_vit(inp, n_pos, norm_t, hparams.ffn_op, learned_pos_embd, @@ -2525,7 +2569,27 @@ struct clip_graph { const int64_t C = rel_pos->ne[0]; // channels const int64_t L = rel_pos->ne[1]; // length - GGML_ASSERT(2*std::max(q_size, k_size) - 1 == L); + //GGML_ASSERT(2*std::max(q_size, k_size) - 1 == L); + + const auto max_rel_dist = 2*std::max(q_size, k_size) - 1; + ggml_tensor * rel_pos_resized = rel_pos; + + if (max_rel_dist != L) { + // Linear interpolation + const auto scale = L / static_cast(max_rel_dist); + ggml_tensor * indices = ggml_arange(ctx, 0.0f, static_cast(max_rel_dist), 1.0f); + indices = ggml_scale_inplace(ctx, indices, scale); + ggml_tensor * indices_floor= ggml_cast(ctx, ggml_floor(ctx, indices), GGML_TYPE_I32); + ggml_tensor * indices_ceil = ggml_cast(ctx, ggml_ceil(ctx, indices), GGML_TYPE_I32); + ggml_tensor * weights = ggml_sub(ctx, indices, indices_floor); + ggml_tensor * ws1 = ggml_scale_bias(ctx, weights, -1.0f, 1.0f); + rel_pos_resized = ggml_cont(ctx , ggml_permute(ctx, rel_pos_resized, 1, 0, 2, 3)); // [C, L] for ggml_get_rows + ggml_tensor * rs1 = ggml_cont(ctx, ggml_permute(ctx, ggml_get_rows(ctx, rel_pos_resized, indices_floor), 1, 0, 2, 3)); // lower rows + rs1 = ggml_mul(ctx, rs1, ws1); // lower rows + ggml_tensor * rs2 = ggml_cont(ctx, ggml_permute(ctx, ggml_get_rows(ctx, rel_pos_resized, indices_ceil), 1, 0, 2, 3)); // upper rows + rs2 = ggml_mul(ctx, rs2, weights); // upper rows + rel_pos_resized = ggml_add(ctx,rs1, rs2); + } // ------------------------------------------------- // 1) q_idx ← arange(0..q_size-1) [q_size] @@ -5007,7 +5071,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str if (!params.crop_mode) { /* Native Resolution (Tiny/Small/Base/Large) */ - const int native_resolutions[] = { + const int native_resolutions[] = { 512 /* tiny */, 640 /* small */, 1024 /* base */, 1280 /* large */ }; // original image size @@ -5060,7 +5124,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str img_tool::resize(*img, *scaled_img, clip_image_size{new_w, new_h}, img_tool::RESIZE_ALGO_BICUBIC); - // Use mean color for padding + // Use mean color for 
padding unsigned char pad_r = static_cast(params.image_mean[0] * 255.0f); unsigned char pad_g = static_cast(params.image_mean[1] * 255.0f); unsigned char pad_b = static_cast(params.image_mean[2] * 255.0f); @@ -5352,6 +5416,8 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im int x_patch = img->nx / (params.patch_size); n_patches += x_patch + 1; + n_patches = 1280; + } break; default: @@ -5690,14 +5756,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } break; case PROJECTOR_TYPE_DEEPSEEKOCR: { - //FIXME we need correct this when all model configs are set correctly - //n_patch is not correct right now - int32_t n_pos = 16 * 16 + 1; //hardcode for now - std::vector positions(n_pos); - for (int i = 0; i < n_pos; i++) { - positions[i] = i; - } - set_input_i32("positions", positions); } break; case PROJECTOR_TYPE_LLAMA4: { From 40e7e6e706644f36136f3c701189058391196ae5 Mon Sep 17 00:00:00 2001 From: bluebread Date: Mon, 24 Nov 2025 08:16:32 +0000 Subject: [PATCH 29/37] mtmd: quick fix token order --- tools/mtmd/mtmd-cli.cpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 8ff93f08b9d..70a17dba39e 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -316,8 +316,18 @@ int main(int argc, char ** argv) { if (is_single_turn) { g_is_generating = true; if (params.prompt.find(mtmd_default_marker()) == std::string::npos) { - for (size_t i = 0; i < params.image.size(); i++) { - params.prompt += mtmd_default_marker(); + if (mtmd_is_deepseekocr(ctx.ctx_vision.get())) { + std::string image_tokens = ""; + for (size_t i = 0; i < params.image.size(); i++) { + image_tokens += mtmd_default_marker(); + image_tokens += '\n'; + } + params.prompt = image_tokens + params.prompt; + } + else { + for (size_t i = 0; i < params.image.size(); i++) { + params.prompt += mtmd_default_marker(); + } } } common_chat_msg msg; From 81533e494e859432c861e0abb8561fad1d28ab7a Mon Sep 17 00:00:00 2001 From: bluebread Date: Mon, 24 Nov 2025 09:02:03 +0000 Subject: [PATCH 30/37] mtmd: fix danling pointer --- tools/mtmd/mtmd-cli.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 70a17dba39e..5e6cc79f379 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -228,10 +228,12 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) { text.add_special = add_bos; text.parse_special = true; + std::string formatted_chat; + if (!mtmd_is_deepseekocr(ctx.ctx_vision.get())) { - auto formatted_chat = chat_add_and_format(ctx, msg); + formatted_chat = chat_add_and_format(ctx, msg); LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.c_str()); - text.text = formatted_chat.c_str(); + text.text = formatted_chat.c_str(); } if (g_is_interrupted) return 0; From a488b495f7d8efccd8fc9ee84a2c7ca8c61428df Mon Sep 17 00:00:00 2001 From: bluebread Date: Sat, 29 Nov 2025 02:17:49 +0000 Subject: [PATCH 31/37] mtmd: SAM numerically works --- convert_hf_to_gguf.py | 12 ++- ggml/src/ggml.c | 1 + tools/mtmd/clip-impl.h | 182 +++++++++++++++++++++++++++++++++- tools/mtmd/clip.cpp | 217 ++++++++++++++++++++++------------------- 4 files changed, 307 insertions(+), 105 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 12baa198fef..2ef5430c618 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5839,12 +5839,14 @@ def get_vision_config(self) -> 
dict[str, Any]: def tensor_force_quant(self, name, new_name, bid, n_dims): + # TODO: increase numercial stability. maybe delete later. + return gguf.GGMLQuantizationType.F32 # related to https://github.com/ggml-org/llama.cpp/issues/13025 - if "input_projection" in name: - return gguf.GGMLQuantizationType.F16 - if ".embeddings." in name: - return gguf.GGMLQuantizationType.F32 - return super().tensor_force_quant(name, new_name, bid, n_dims) + # if "input_projection" in name: + # return gguf.GGMLQuantizationType.F16 + # if ".embeddings." in name: + # return gguf.GGMLQuantizationType.F32 + # return super().tensor_force_quant(name, new_name, bid, n_dims) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # Only process vision-related tensors, skip language model tensors diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 9be35c1be84..837b1751e37 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5081,6 +5081,7 @@ struct ggml_tensor * ggml_flash_attn_ext( GGML_ASSERT(q->ne[3] == v->ne[3]); if (mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 63d59055668..c35d7a08d42 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -449,6 +450,33 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { // debugging // + +static std::string to_ne_string(const ggml_tensor * t) { + std::string str; + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + str += std::to_string(t->ne[i]); + if (i + 1 < GGML_MAX_DIMS) { + str += ", "; + } + } + return str; +} + +static void print_tensor_info(ggml_tensor * t) { + const struct ggml_tensor * src0 = t->src[0]; + const struct ggml_tensor * src1 = t->src[1]; + + char src1_str[128] = {0}; + if (src1) { + snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, to_ne_string(src1).c_str()); + } + + printf("%s: %s = %s(%s{%s}, %s)\n", + t->name, ggml_type_name(t->type), ggml_op_desc(t), + src0->name, to_ne_string(src0).c_str(), + src1 ? 
src1_str : ""); +} + static void print_tensor_shape(ggml_tensor * t) { printf("%s.shape = [", t->name); for (int i = 0; i < ggml_n_dims(t); ++i) { @@ -460,12 +488,50 @@ static void print_tensor_shape(ggml_tensor * t) { printf("]\n"); } +static void print_tensor_sum(ggml_tensor * t, uint8_t * data, int64_t n) { + (void) n; // unused parameter + ggml_type type = t->type; + int64_t * ne = t->ne; + size_t * nb = t->nb; + double sum = 0.0; + for (int64_t i3 = 0; i3 < ne[3]; i3++) { + for (int64_t i2 = 0; i2 < ne[2]; i2++) { + for (int64_t i1 = 0; i1 < ne[1]; i1++) { + for (int64_t i0 = 0; i0 < ne[0]; i0++) { + size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; + float v; + if (type == GGML_TYPE_F16) { + v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]); + } else if (type == GGML_TYPE_F32) { + v = *(float *) &data[i]; + } else if (type == GGML_TYPE_I32) { + v = (float) *(int32_t *) &data[i]; + } else if (type == GGML_TYPE_I16) { + v = (float) *(int16_t *) &data[i]; + } else if (type == GGML_TYPE_I8) { + v = (float) *(int8_t *) &data[i]; + } else { + GGML_ABORT("fatal error"); + } + sum += v; + } + } + } + } + printf("%s.sum = %.6f\n", t->name, sum); +} + static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) { ggml_type type = t->type; int64_t * ne = t->ne; size_t * nb = t->nb; + printf("%s.data: [\n", t->name); for (int64_t i3 = 0; i3 < ne[3]; i3++) { - printf("%s.data: [\n", t->name); + if (i3 == n && ne[3] > 2*n) { + printf(" ..., \n"); + i3 = ne[3] - n; + } + printf(" [\n"); for (int64_t i2 = 0; i2 < ne[2]; i2++) { if (i2 == n && ne[2] > 2*n) { printf(" ..., \n"); @@ -507,6 +573,120 @@ static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) { } printf(" ]\n"); } + printf(" ]\n"); +} + +static void save_tensor_to_file(const struct ggml_tensor * tensor) { + char filename[512]; + snprintf(filename, sizeof(filename), "%s_cpp.txt", tensor->name); + + FILE * f = fopen(filename, "w"); + if (!f) { + fprintf(stderr, "Failed to open %s\n", filename); + return; + } + + // Check tensor size and warn if too large + int64_t total_elements = ggml_nelements(tensor); + fprintf(stderr, "Saving tensor %s (%lld elements) to %s\n", + tensor->name, (long long)total_elements, filename); + + if (total_elements > 10000000) { // 10M elements + fprintf(stderr, "Warning: tensor is very large (%lld elements), this may take time\n", + (long long)total_elements); + } + + uint8_t * data = (uint8_t *) tensor->data; + ggml_type type = tensor->type; + const int64_t * ne = tensor->ne; + const size_t * nb = tensor->nb; + + // Use a buffer to reduce I/O calls + const size_t BUF_SIZE = 8192; + char * buf = (char *) malloc(BUF_SIZE); + if (!buf) { + fprintf(stderr, "Failed to allocate buffer\n"); + fclose(f); + return; + } + size_t buf_pos = 0; + + // Helper lambda to flush buffer + auto flush_buf = [&]() { + if (buf_pos > 0) { + fwrite(buf, 1, buf_pos, f); + buf_pos = 0; + } + }; + + // Helper to append to buffer + auto append = [&](const char * str, size_t len) { + if (buf_pos + len >= BUF_SIZE) { + flush_buf(); + } + if (len >= BUF_SIZE) { + // String too large for buffer, write directly + fwrite(str, 1, len, f); + } else { + memcpy(buf + buf_pos, str, len); + buf_pos += len; + } + }; + + auto append_str = [&](const char * str) { + append(str, strlen(str)); + }; + + char num_buf[32]; + + // Write header once for all batches + append_str(tensor->name); + append_str(".data: [\n"); + + for (int64_t i3 = 0; i3 < ne[3]; i3++) { + append_str(" [\n"); // Start of batch + for (int64_t i2 = 
0; i2 < ne[2]; i2++) { + append_str(" [\n"); + for (int64_t i1 = 0; i1 < ne[1]; i1++) { + append_str(" ["); + for (int64_t i0 = 0; i0 < ne[0]; i0++) { + size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; + float v; + if (type == GGML_TYPE_F16) { + v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]); + } else if (type == GGML_TYPE_F32) { + v = *(float *) &data[i]; + } else if (type == GGML_TYPE_I32) { + v = (float) *(int32_t *) &data[i]; + } else if (type == GGML_TYPE_I16) { + v = (float) *(int16_t *) &data[i]; + } else if (type == GGML_TYPE_I8) { + v = (float) *(int8_t *) &data[i]; + } else { + GGML_ABORT("fatal error"); + } + int len = snprintf(num_buf, sizeof(num_buf), "%8.4f", v); + append(num_buf, len); + if (i0 < ne[0] - 1) append_str(", "); + } + append_str("],\n"); + } + append_str(" ],\n"); + } + append_str(" ]"); // End of batch + if (i3 < ne[3] - 1) { + append_str(",\n"); // Comma between batches + } else { + append_str("\n"); + } + } + + append_str("]\n"); // Close the top-level array + + flush_buf(); + free(buf); + fclose(f); + fprintf(stderr, "Tensor saved successfully\n"); } // diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 82d0b46a47b..4b7a4a563f8 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -681,16 +681,19 @@ struct clip_graph { const auto tgt_size = inpL->ne[1]; const auto str_size = model.pos_embed->ne[1]; if (str_size != tgt_size) { + ggml_tensor * old_pos_embed = nullptr; + old_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, model.pos_embed, 2, 0, 1, 3)); + // TODO: ggml_interpolate doesn't support bicubic model for CUDA backend ggml_tensor * new_pos_embed = ggml_interpolate( ctx0, - model.pos_embed, + old_pos_embed, tgt_size, tgt_size, enc_n_embd, 1, ggml_scale_mode::GGML_SCALE_MODE_BICUBIC ); - new_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embed, 2,1,0,3)); + new_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embed, 1, 2, 0, 3)); cur = ggml_add(ctx0, inpL, new_pos_embed); } else { cur = ggml_add(ctx0, inpL, model.pos_embed); @@ -699,10 +702,10 @@ struct clip_graph { // loop over layers for (int il = 0; il < _depth; il++) { auto & layer = model.sam_layers[il]; + ggml_tensor * shortcut = cur; // layernorm1 cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il); - cb(cur, "enc_layer_inp_normed", il); const int64_t w0 = cur->ne[1]; const int64_t h0 = cur->ne[2]; @@ -711,7 +714,7 @@ struct clip_graph { // local attention layer - apply window partition // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172 //cur = ggml_win_part(ctx0, cur, 14); - cur = window_partition(ctx0, cur, 14); + cur = window_partition(ctx0, cur, 14); // TODO: make this configurable } const int64_t W = cur->ne[1]; @@ -719,110 +722,93 @@ struct clip_graph { // self-attention { + const int B = cur->ne[3]; + cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); cur = ggml_add(ctx0, cur, layer.qkv_b); - const int B = cur->ne[3]; - - cur = ggml_reshape_4d(ctx0, cur, enc_n_embd, 3, W * H, B); - cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 3, 1, 2)); - - ggml_tensor * Qcur = - ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 0); - Qcur = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, enc_n_heads, W * H, B); - Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3)); - Qcur = ggml_reshape_3d(ctx0, Qcur, enc_d_heads, W * H, B * enc_n_heads); - - ggml_tensor * Kcur = - ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], cur->nb[3]); 
- Kcur = ggml_reshape_4d(ctx0, Kcur, enc_d_heads, enc_n_heads, W * H, B); - Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); - Kcur = ggml_reshape_3d(ctx0, Kcur, enc_d_heads, W * H, B * enc_n_heads); - - ggml_tensor * Vcur = - ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 2 * cur->nb[3]); - Vcur = ggml_reshape_4d(ctx0, Vcur, enc_d_heads, enc_n_heads, W * H, B); - Vcur = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3)); // transposed - Vcur = ggml_reshape_3d(ctx0, Vcur, W * H, enc_d_heads, B * enc_n_heads); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - - - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcur, Qcur); - - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf(enc_d_heads)); - - struct ggml_tensor * rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); - struct ggml_tensor * rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); - - struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads); - - struct ggml_tensor * rel_w = ggml_cont(ctx0,ggml_permute(ctx0, - ggml_mul_mat(ctx0, - rw, - ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))), - 0, 2, 1, 3)); - struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r); - - struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h); - - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn); - - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcur, KQ_soft_max); - - cur = ggml_reshape_4d( - ctx0, - ggml_cont(ctx0, ggml_permute(ctx0, ggml_reshape_4d(ctx0, KQV, enc_d_heads, W * H, enc_n_heads, B), - 0, 2, 1, 3)), - enc_n_embd, W, H, B); - + cur = ggml_cont(ctx0, cur); // Ensure tensor is contiguous before reshape + cur = ggml_reshape_4d(ctx0, cur, enc_n_embd, 3, W*H, B); + + ggml_tensor * Q; + ggml_tensor * K; + ggml_tensor * V; + + Q = ggml_view_3d (ctx0, cur, enc_n_embd, W*H, B, cur->nb[2], cur->nb[3], 0*cur->nb[1]); + Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q), enc_d_heads, enc_n_heads, W*H, B); + Q = ggml_cont (ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); // [B, enc_n_heads, H*W, enc_d_heads] + + K = ggml_view_3d (ctx0, cur, enc_n_embd, W*H, B, cur->nb[2], cur->nb[3], 1*cur->nb[1]); + K = ggml_reshape_4d(ctx0, ggml_cont(ctx0, K), enc_d_heads, enc_n_heads, W*H, B); + K = ggml_cont (ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); // [B, enc_n_heads, H*W, enc_d_heads] + + V = ggml_view_3d (ctx0, cur, enc_n_embd, W*H, B, cur->nb[2], cur->nb[3], 2*cur->nb[1]); + V = ggml_reshape_4d(ctx0, ggml_cont(ctx0, V), enc_d_heads, enc_n_heads, W*H, B); + V = ggml_cont (ctx0, ggml_permute(ctx0, V, 0, 2, 1, 3)); // [B, enc_n_heads, H*W, enc_d_heads] + + ggml_tensor * mask; + ggml_tensor * rw; + ggml_tensor * rh; + ggml_tensor * qr; + + rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); // [W, W, C] + rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); // [H, H, C] + qr = ggml_reshape_4d(ctx0, Q, enc_d_heads, W, H, B*enc_n_heads); + + const int WH_pad = GGML_PAD(W*H, GGML_KQ_MASK_PAD) - W*H; + + rw = ggml_mul_mat (ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, qr, 0, 2, 1, 3))); // [B*enc_n_heads, W, H, W] + rw = ggml_cont (ctx0, ggml_permute(ctx0, rw, 0, 2, 1, 3)); // [B*enc_n_heads, H, W, W] + rw = ggml_reshape_4d(ctx0, rw, W, 1, W*H, enc_n_heads*B); + rw = ggml_repeat_4d (ctx0, rw, W, H, W*H, enc_n_heads*B); + rh = ggml_mul_mat (ctx0, rh, qr); // [B*enc_n_heads, H, W, H] + rh = ggml_reshape_4d(ctx0, rh, 1, H, W*H, enc_n_heads*B); + mask = ggml_add (ctx0, rw, rh); // [B*enc_n_heads, H*W, H, W] + mask = ggml_reshape_4d(ctx0, 
mask, W*H, W*H, enc_n_heads, B); + mask = ggml_pad (ctx0, mask, 0, WH_pad, 0, 0); + mask = ggml_cast (ctx0, mask, GGML_TYPE_F16); + + float scale = 1.0f / sqrtf((float)enc_d_heads); + cur = ggml_flash_attn_ext(ctx0, Q, K, V, mask, scale, 0.0f, 0.0f); // [B, H*W, enc_n_heads, enc_d_heads] + + cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), enc_n_embd, W, H, B); cur = ggml_mul_mat(ctx0, layer.o_w, cur); cur = ggml_add_inplace(ctx0, cur, layer.o_b); } if (hparams.is_global_attn(il) == false) { // local attention layer - reverse window partition - cur = window_unpartition(ctx0, cur, w0, h0, 14); + cur = window_unpartition(ctx0, cur, w0, h0, 14); // TODO: make window size configurable } // re-add the layer input, e.g., residual - cur = ggml_add(ctx0, cur, inpL); + cur = ggml_add(ctx0, cur, shortcut); ggml_tensor * inpFF = cur; - - cb(inpFF, "ffn_inp", il); - // layernorm2 cur = build_norm(inpFF, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il); - cb(cur, "ffn_inp_normed", il); // ffn cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, nullptr, nullptr, layer.ff_down_w, layer.ff_down_b, hparams.ffn_op, il); - cb(cur, "ffn_out", il); - - // residual 2 cur = ggml_add(ctx0, cur, inpFF); - cb(cur, "layer_out", il); + cb(cur, "sam_layer_out", il); } - cur = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 2, 0, 1, 3)); - - cur = ggml_conv_2d_sk_p0(ctx0, model.neck_0_w, cur); + cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); - cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_1_w, model.neck_1_b, hparams.eps); + const int out_chans = model.neck_0_w->ne[3]; - cur = ggml_conv_2d_s1_ph(ctx0, model.neck_2_w, cur); + cur = ggml_conv_2d(ctx0, model.neck_0_w, cur, 1, 1, 0, 0, 1, 1); + cur = sam_layer_norm_2d(ctx0, cur, out_chans, model.neck_1_w, model.neck_1_b, hparams.eps); + cur = ggml_conv_2d(ctx0, model.neck_2_w, cur, 1, 1, 1, 1, 1, 1); + cur = sam_layer_norm_2d(ctx0, cur, out_chans, model.neck_3_w, model.neck_3_b, hparams.eps); - cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_3_w, model.neck_3_b, hparams.eps); - - cur = ggml_conv_2d(ctx0, model.net_2, cur, 2,2,1,1, 1,1); - cur = ggml_conv_2d(ctx0, model.net_3, cur, 2,2,1,1, 1,1); + cur = ggml_conv_2d(ctx0, model.net_2, cur, 2, 2, 1, 1, 1, 1); + cur = ggml_conv_2d(ctx0, model.net_3, cur, 2, 2, 1, 1, 1, 1); + cb(cur, "sam_output", -1); ggml_build_forward_expand(gf, cur); return cur; @@ -2576,19 +2562,27 @@ struct clip_graph { if (max_rel_dist != L) { // Linear interpolation - const auto scale = L / static_cast(max_rel_dist); - ggml_tensor * indices = ggml_arange(ctx, 0.0f, static_cast(max_rel_dist), 1.0f); - indices = ggml_scale_inplace(ctx, indices, scale); - ggml_tensor * indices_floor= ggml_cast(ctx, ggml_floor(ctx, indices), GGML_TYPE_I32); - ggml_tensor * indices_ceil = ggml_cast(ctx, ggml_ceil(ctx, indices), GGML_TYPE_I32); - ggml_tensor * weights = ggml_sub(ctx, indices, indices_floor); - ggml_tensor * ws1 = ggml_scale_bias(ctx, weights, -1.0f, 1.0f); - rel_pos_resized = ggml_cont(ctx , ggml_permute(ctx, rel_pos_resized, 1, 0, 2, 3)); // [C, L] for ggml_get_rows - ggml_tensor * rs1 = ggml_cont(ctx, ggml_permute(ctx, ggml_get_rows(ctx, rel_pos_resized, indices_floor), 1, 0, 2, 3)); // lower rows - rs1 = ggml_mul(ctx, rs1, ws1); // lower rows - ggml_tensor * rs2 = ggml_cont(ctx, ggml_permute(ctx, ggml_get_rows(ctx, rel_pos_resized, indices_ceil), 1, 0, 2, 3)); // upper rows - rs2 = ggml_mul(ctx, rs2, weights); // upper rows - rel_pos_resized = ggml_add(ctx,rs1, rs2); + int64_t ne0 = rel_pos_resized->ne[0]; + int64_t ne1 = 
rel_pos_resized->ne[1]; + int64_t ne2 = rel_pos_resized->ne[2]; + int64_t ne3 = rel_pos_resized->ne[3]; + + rel_pos_resized = ggml_reshape_3d( + ctx, + ggml_cont(ctx, ggml_permute(ctx, rel_pos_resized, 1, 0, 2, 3)), + ne1, 1, ne0*ne2*ne3 + ); + rel_pos_resized = ggml_reshape_4d( + ctx, + ggml_interpolate( + ctx, + rel_pos_resized, + max_rel_dist, 1, ne0*ne2*ne3, 1, + ggml_scale_mode::GGML_SCALE_MODE_BILINEAR + ), + max_rel_dist, ne0, ne2, ne3 + ); + rel_pos_resized = ggml_cont(ctx, ggml_permute(ctx, rel_pos_resized, 1, 0, 2, 3)); } // ------------------------------------------------- @@ -2627,7 +2621,7 @@ struct clip_graph { rel = ggml_sub(ctx, q_coord, k_coord); // [q_size, k_size] rel = ggml_scale_bias(ctx, rel, 1.0f, (k_size - 1.0f)*k_scale); // [q_size, k_size] // Clamp to [0, L-1] range for valid indexing - rel = ggml_clamp(ctx, rel, 0.0f, static_cast(rel_pos->ne[1] - 1)); + rel = ggml_clamp(ctx, rel, 0.0f, static_cast(rel_pos_resized->ne[1] - 1)); // ------------------------------------------------- // clamp to [0, L-1] and cast to int32 (for ggml_get_rows) @@ -2641,7 +2635,7 @@ struct clip_graph { // flatten to 1D for ggml_get_rows int qk = q_size * k_size; ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk] - ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat); // [qk, C] + ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos_resized, idx_flat); // [qk, C] // ------------------------------------------------- // Gather from rel_pos → [qk, C] @@ -2671,7 +2665,7 @@ struct clip_graph { } x = ggml_reshape_4d(ctx, x, c * window, npw, window, nph * b); x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); - x = ggml_reshape_4d(ctx, x, c, window ,window, npw * nph * b); + x = ggml_reshape_4d(ctx, x, c, window, window, npw * nph * b); return x; } @@ -5078,6 +5072,11 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str const int orig_w = original_size.width; const int orig_h = original_size.height; const int orig_area = orig_h * orig_w; + std::array color; + + for (int i = 0; i < 3; i++) { + color[i] = (int)(255 * params.image_mean[i]); + } // mode selection logic (find most suitable resolution) int mode_i = 0; @@ -5100,7 +5099,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str clip_image_u8_ptr resized_img(clip_image_u8_init()); img_tool::resize(*img, *resized_img, clip_image_size{image_size, image_size}, - img_tool::RESIZE_ALGO_BICUBIC); // Match PIL default + img_tool::RESIZE_ALGO_BICUBIC, true, color); // Match PIL default clip_image_f32_ptr res(clip_image_f32_init()); normalize_image_u8_to_f32(*resized_img, *res, params.image_mean, params.image_std); @@ -5122,7 +5121,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str clip_image_u8_ptr scaled_img(clip_image_u8_init()); img_tool::resize(*img, *scaled_img, clip_image_size{new_w, new_h}, - img_tool::RESIZE_ALGO_BICUBIC); + img_tool::RESIZE_ALGO_BICUBIC, true, color); // Use mean color for padding unsigned char pad_r = static_cast(params.image_mean[0] * 255.0f); @@ -5801,8 +5800,28 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima for (ggml_tensor * t : ctx->debug_print_tensors) { std::vector data(ggml_nbytes(t)); ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t)); + print_tensor_info(t); print_tensor_shape(t); - print_tensor_data(t, data.data(), 3); + print_tensor_sum(t, data.data(), 3); + std::string tname_s = std::string(t->name); + + bool is_stored = false; + std::vector 
patterns = { + /* Add tensor names here to dump (e.g. "sam_output") */ + "sam_output" + }; + + for (auto & p : patterns) { + if (tname_s == p) { + save_tensor_to_file(t); + is_stored = true; + break; + } + } + + if (!is_stored) { + print_tensor_data(t, data.data(), 3); + } } } From ccb2f2385ec9b8b688eb0a85d01bc0432514f753 Mon Sep 17 00:00:00 2001 From: bluebread Date: Sat, 29 Nov 2025 07:04:14 +0000 Subject: [PATCH 32/37] mtmd: debug CLIP-L (vit_pre_ln) --- tools/mtmd/clip.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 4b7a4a563f8..57ab543b858 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1555,8 +1555,8 @@ struct clip_graph { ggml_tensor * inp = ggml_cpy(ctx0, patch_embeds, ggml_dup_tensor(ctx0, patch_embeds)); - inp = ggml_cont(ctx0,ggml_permute(ctx0, inp,2,1,0,3)); - inp = ggml_reshape_2d(ctx0, inp, n_embd, inp->ne[1]*inp->ne[2]*inp->ne[3]); + inp = ggml_reshape_2d(ctx0, inp, inp->ne[0]*inp->ne[1], inp->ne[2]); + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); ggml_tensor * new_pos_embd = ggml_cpy(ctx0, model.position_embeddings, ggml_dup_tensor(ctx0, model.position_embeddings)); @@ -1587,7 +1587,7 @@ struct clip_graph { // add CLS token - inp = ggml_concat(ctx0, inp, model.class_embedding, 1); + inp = ggml_concat(ctx0, model.class_embedding, inp, 1); //TODO : check norm type for dp-ocr-clip norm_type norm_t = NORM_TYPE_NORMAL; @@ -1596,7 +1596,6 @@ struct clip_graph { ggml_tensor * positions = ggml_cast(ctx0, ggml_arange(ctx0, 0, n_pos, 1), GGML_TYPE_I32); ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, new_pos_embd, positions); - ggml_tensor * cur = build_vit(inp, n_pos, norm_t, hparams.ffn_op, learned_pos_embd, nullptr); // shape [1024, 16, 16] @@ -2395,7 +2394,7 @@ struct clip_graph { // pre-layernorm if (model.pre_ln_w) { inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1); - cb(inpL, "pre_ln", -1); + cb(inpL, "vit_pre_ln", -1); } // loop over layers @@ -5808,7 +5807,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima bool is_stored = false; std::vector patterns = { /* Add tensor names here to dump (e.g. 
"sam_output") */ - "sam_output" + "vit_pre_ln" }; for (auto & p : patterns) { From 841a4a88df610a2e3df1711499974546bc7eb45e Mon Sep 17 00:00:00 2001 From: bluebread Date: Sat, 29 Nov 2025 16:40:50 +0000 Subject: [PATCH 33/37] mtmd: debug CLIP-L & first working DeepSeek-OCR model --- tools/mtmd/clip.cpp | 77 +++++++++++++++++++++------------------------ 1 file changed, 35 insertions(+), 42 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 57ab543b858..807f22f4b1b 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -837,34 +837,32 @@ struct clip_graph { ggml_cgraph * build_deepseek_ocr() { //patch embedding ggml_tensor * inp_raw = build_inp_raw(); - - ggml_tensor * global_features_1 = build_sam_enc(inp_raw, std::max(img.nx, img.ny)); - ggml_tensor * global_features_2 = build_dp_ocr_clip(global_features_1); - + // FIXME remove n_patches is hardcoded - + // torch global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1) - global_features_1 = ggml_cont(ctx0,ggml_permute(ctx0, global_features_1,2,1,0,3)); + global_features_1 = ggml_cont(ctx0,ggml_permute(ctx0, global_features_1, 1, 2, 0, 3)); int clip_n_patches = global_features_1->ne[1] * global_features_1->ne[2]; - + // flatten 2nd and 3rd dims global_features_1 = ggml_reshape_2d(ctx0, global_features_1, global_features_1->ne[0], clip_n_patches); - + // remove CLS token - global_features_2 = ggml_view_2d(ctx0, global_features_2, - n_embd, clip_n_patches, - ggml_row_size(global_features_2->type, n_embd), 0); - - ggml_tensor * global_features = ggml_concat(ctx0, global_features_2, global_features_1, 1); + global_features_2 = ggml_view_2d(ctx0, global_features_2, n_embd, clip_n_patches, + global_features_2->nb[1], global_features_2->nb[1]); + + ggml_tensor * global_features = ggml_concat(ctx0, global_features_2, global_features_1, 0); global_features = ggml_reshape_2d(ctx0, global_features, 2* n_embd,clip_n_patches); global_features = ggml_cont(ctx0, global_features); global_features = ggml_mul_mat(ctx0, model.fc_w, global_features); global_features = ggml_add(ctx0, global_features, model.fc_b); global_features = build_global_local_features(ctx0,global_features); - global_features = ggml_cont(ctx0, ggml_permute(ctx0, global_features, 1, 0, 2, 3)); + + cb(global_features, "dsocr_output", -1); + ggml_build_forward_expand(gf, global_features); return gf; } @@ -878,30 +876,23 @@ struct clip_graph { GGML_ASSERT(model.image_newline != nullptr); GGML_ASSERT(model.view_seperator != nullptr); - // 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim] const auto h = static_cast(std::sqrt(static_cast(global_features->ne[1]))); const auto w = h; const auto n_dim = global_features->ne[0]; - ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, n_dim, h, w, 1); // (n_dim, w, h) - t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, n_dim) - ggml_tensor * nl = ggml_cont(ctx0,ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3)); - nl = ggml_repeat_4d(ctx0, nl, h, 1, n_dim, 1); // n_pos rows + ggml_tensor * cur; + ggml_tensor * imgnl; + ggml_tensor * vs; - // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim] - t = ggml_concat(ctx0, t, nl, 1); // (h, w+1, n_dim) - - t = ggml_reshape_2d(ctx0, t, n_dim, h* (h + 1)); // (n_dim, h*(w+1)) - - - // 5) append view_separator as an extra "token": - // view_separator: [n_dim] -> [n_dim, 1] - ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1) - - // 
concat along token dimension (dim=1): - t = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1) + cur = ggml_reshape_3d(ctx0, global_features, n_dim, w, h); + imgnl = ggml_repeat_4d(ctx0, model.image_newline, n_dim, 1, h, 1); + cur = ggml_reshape_2d(ctx0, ggml_concat(ctx0, cur, imgnl, 1), n_dim, (w+1)*h); + cb(cur, "insert_imgnl", -1); + vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1) + cur = ggml_concat(ctx0, cur, vs, 1); // (n_dim, h*(w+1) + 1) + cb(cur, "insert_vs", -1); - return t; + return cur; } @@ -1596,8 +1587,8 @@ struct clip_graph { ggml_tensor * positions = ggml_cast(ctx0, ggml_arange(ctx0, 0, n_pos, 1), GGML_TYPE_I32); ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, new_pos_embd, positions); - ggml_tensor * cur = build_vit(inp, n_pos, norm_t, hparams.ffn_op, learned_pos_embd, - nullptr); // shape [1024, 16, 16] + ggml_tensor * cur = build_vit(inp, n_pos, norm_t, ffn_op_type::FFN_GELU_QUICK, + learned_pos_embd, nullptr); // shape [1024, 16, 16] ggml_build_forward_expand(gf, cur); @@ -2394,7 +2385,7 @@ struct clip_graph { // pre-layernorm if (model.pre_ln_w) { inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1); - cb(inpL, "vit_pre_ln", -1); + cb(inpL, "pre_ln", -1); } // loop over layers @@ -5411,12 +5402,15 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im } break; case PROJECTOR_TYPE_DEEPSEEKOCR: { - int x_patch = img->nx / (params.patch_size); - - n_patches += x_patch + 1; - n_patches = 1280; - - + // SAM encoder applies two stride-2 convolutions (net_2 and net_3) + // which reduces spatial dimensions by 4x in each direction (16x total) + // E.g., 64x64 -> 16x16 patches + n_patches /= 16; + + // build_global_local_features adds image newlines and view separator + // Formula: h*(w+1) + 1 where h = w = sqrt(n_patches) + int h = static_cast(std::sqrt(static_cast(n_patches))); + n_patches = h * (h + 1) + 1; } break; default: GGML_ABORT("unsupported projector type"); @@ -5807,7 +5801,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima bool is_stored = false; std::vector patterns = { /* Add tensor names here to dump (e.g. 
"sam_output") */ - "vit_pre_ln" }; for (auto & p : patterns) { From c5f4c64fe4814cf1746cdcf3443432cdcbfc94e6 Mon Sep 17 00:00:00 2001 From: bluebread Date: Sun, 30 Nov 2025 16:57:19 +0000 Subject: [PATCH 34/37] mtmd : add --dsocr-mode CLI argument for DeepSeek-OCR resolution control & all native resolution modes work --- common/arg.cpp | 15 +++ common/common.h | 1 + ggml/src/ggml-cuda/upscale.cu | 2 + tools/mtmd/clip-impl.h | 4 +- tools/mtmd/clip.cpp | 188 ++++++++++++++++++---------------- tools/mtmd/clip.h | 11 ++ tools/mtmd/mtmd-cli.cpp | 1 + tools/mtmd/mtmd.cpp | 22 ++++ tools/mtmd/mtmd.h | 3 + 9 files changed, 159 insertions(+), 88 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index d2b81c331ca..458ca407952 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1824,6 +1824,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.image_max_tokens = value; } ).set_examples(mmproj_examples).set_env("LLAMA_ARG_IMAGE_MAX_TOKENS")); + add_opt(common_arg( + {"--dsocr-mode"}, "MODE", + "DeepSeek-OCR resolution mode, one of:\n" + "- auto (default): automatically select resolution\n" + "- tiny, small, base, large: native resolution\n" + "- gundam, gundam-master: dynamic resolution", + [](common_params & params, const std::string & value) { + if (value == "auto" || value == "tiny" || value == "small" || value == "base" || + value == "large" || value == "gundam" || value == "gundam-master") { + params.dsocr_mode = value; + } else { + throw std::invalid_argument("invalid value"); + } + } + ).set_examples(mmproj_examples).set_env("LLAMA_ARG_DSOCR_MODE")); if (llama_supports_rpc()) { add_opt(common_arg( {"--rpc"}, "SERVERS", diff --git a/common/common.h b/common/common.h index 2f23d0baa83..82d3989f10a 100644 --- a/common/common.h +++ b/common/common.h @@ -433,6 +433,7 @@ struct common_params { std::vector image; // path to image file(s) int image_min_tokens = -1; int image_max_tokens = -1; + std::string dsocr_mode = "auto"; // DeepSeek-OCR resolution mode: auto, tiny, small, base, large, gundam, gundam-master // finetune struct lr_opt lr; diff --git a/ggml/src/ggml-cuda/upscale.cu b/ggml/src/ggml-cuda/upscale.cu index 687c669304d..944d00a2adc 100644 --- a/ggml/src/ggml-cuda/upscale.cu +++ b/ggml/src/ggml-cuda/upscale.cu @@ -214,5 +214,7 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, pixel_offset, stream); + } else { + GGML_ABORT("fatal error"); } } diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index b3ddf773f86..a486ee13840 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -569,7 +569,7 @@ static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) { printf(" ]\n"); } -static void save_tensor_to_file(const struct ggml_tensor * tensor) { +static void save_tensor_to_file(const struct ggml_tensor * tensor, const uint8_t * data_ptr) { char filename[512]; snprintf(filename, sizeof(filename), "%s_cpp.txt", tensor->name); @@ -589,7 +589,7 @@ static void save_tensor_to_file(const struct ggml_tensor * tensor) { (long long)total_elements); } - uint8_t * data = (uint8_t *) tensor->data; + const uint8_t * data = (data_ptr) ? 
data_ptr : (uint8_t *) tensor->data; ggml_type type = tensor->type; const int64_t * ne = tensor->ne; const size_t * nb = tensor->nb; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 730e1318a92..a590c067269 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -193,8 +193,6 @@ struct clip_hparams { int32_t attn_window_size = 0; int32_t n_wa_pattern = 0; - bool crop_mode = false; - // audio int32_t n_mel_bins = 0; // whisper preprocessor int32_t proj_stack_factor = 0; // ultravox @@ -208,6 +206,9 @@ struct clip_hparams { int32_t custom_image_min_tokens = -1; int32_t custom_image_max_tokens = -1; + // DeepSeek-OCR resolution mode + enum clip_dsocr_mode dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_AUTO; + void set_limit_image_tokens(int n_tokens_min, int n_tokens_max) { const int cur_merge = n_merge == 0 ? 1 : n_merge; const int patch_area = patch_size * patch_size * cur_merge * cur_merge; @@ -512,6 +513,7 @@ struct clip_ctx { if (ctx_params.image_max_tokens > 0) { model.hparams.custom_image_max_tokens = ctx_params.image_max_tokens; } + model.hparams.dsocr_mode = ctx_params.dsocr_mode; backend_ptrs.push_back(backend_cpu); backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu)); @@ -3403,7 +3405,6 @@ struct clip_model_loader { hparams.patch_size = 16; hparams.image_size = 1024; hparams.warmup_image_size = 1024; - hparams.crop_mode = false; } break; default: break; @@ -5054,9 +5055,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str } } break; case PROJECTOR_TYPE_DEEPSEEKOCR: - if (!params.crop_mode) { - /* Native Resolution (Tiny/Small/Base/Large) */ - + { const int native_resolutions[] = { 512 /* tiny */, 640 /* small */, 1024 /* base */, 1280 /* large */ }; @@ -5065,29 +5064,44 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str const int orig_h = original_size.height; const int orig_area = orig_h * orig_w; std::array color; - + for (int i = 0; i < 3; i++) { color[i] = (int)(255 * params.image_mean[i]); } - - // mode selection logic (find most suitable resolution) + int mode_i = 0; - int min_diff = orig_area; - - for (int i = 0; i < 4; i++) { - int r = native_resolutions[i]; - if (std::abs(orig_area - r*r) < min_diff) { - mode_i = i; - min_diff = std::abs(orig_area - r*r); + + if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_TINY) { + mode_i = 0; + } else if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_SMALL) { + mode_i = 1; + } else if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_BASE) { + mode_i = 2; + } else if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_LARGE) { + mode_i = 3; + } else if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_GUNDAM) { + mode_i = 4; + } else if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_GUNDAM_MASTER) { + mode_i = 5; + } else { + if (params.dsocr_mode != clip_dsocr_mode::CLIP_DSOCR_MODE_AUTO) { + LOG_WRN("%s: unknown dsocr_mode, using auto mode\n", __func__); + } + int min_diff = orig_area; + for (int i = 0; i < 4; i++) { + int r = native_resolutions[i]; + if (std::abs(orig_area - r*r) < min_diff) { + mode_i = i; + min_diff = std::abs(orig_area - r*r); + } } } - const int image_size = native_resolutions[mode_i]; - if (mode_i < 2) { - // TINY/SMALL MODE: Direct resize (no slicing) + /* Native Resolution (Tiny/Small) */ + const int image_size = native_resolutions[mode_i]; + // Just resize the image to image_size × image_size - clip_image_u8_ptr resized_img(clip_image_u8_init()); img_tool::resize(*img, 
*resized_img, clip_image_size{image_size, image_size}, @@ -5100,10 +5114,11 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->grid_x = 1; res_imgs->grid_y = 1; } - else { - // BASE/LARGE MODE: Resize with aspect ratio + padding + else if (mode_i < 4) { + /* Native Resolution (Base/Large) */ + const int image_size = native_resolutions[mode_i]; + // Resize maintaining aspect ratio, then pad to square - float scale = std::min( static_cast(image_size) / orig_w, static_cast(image_size) / orig_h @@ -5120,7 +5135,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str unsigned char pad_g = static_cast(params.image_mean[1] * 255.0f); unsigned char pad_b = static_cast(params.image_mean[2] * 255.0f); - // Step 2: Pad to image_size × image_size (center padding) + // Pad to image_size × image_size (center padding) clip_image_u8_ptr padded_img(clip_image_u8_init()); padded_img->nx = image_size; padded_img->ny = image_size; @@ -5148,7 +5163,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str } } - // Step 3: Normalize and output + // Normalize and output clip_image_f32_ptr res(clip_image_f32_init()); normalize_image_u8_to_f32(*padded_img, *res, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(res)); @@ -5156,68 +5171,69 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->grid_x = 1; res_imgs->grid_y = 1; } - } - else { - /* Dynamic Resolution (Gundam/Gundam-M) */ - - // configurable, or read from params - const int min_num = 2; - const int max_num = 9; - const int image_size = params.image_size; // typically 640 - // const bool use_thumbnail = true; // mimic python's use_thumbnail - - // original image size - const int orig_w = original_size.width; - const int orig_h = original_size.height; - - // 1) build candidate grids (cols, rows) - auto target_ratios = ds_build_target_ratios(min_num, max_num); - - // 2) pick the grid that best matches the original aspect ratio - const float aspect_ratio = static_cast(orig_w) / static_cast(orig_h); - auto best = ds_find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_w, orig_h, image_size); - const int grid_cols = best.first; // how many tiles horizontally - const int grid_rows = best.second; // how many tiles vertically - - // 3) compute the target (forced) size — python did: - // target_width = image_size * cols - // target_height = image_size * rows - const clip_image_size refined_size{ image_size * grid_cols, image_size * grid_rows }; - - // 4) prepare slice instructions, same style as the idefics3 branch - llava_uhd::slice_instructions instructions; - instructions.overview_size = clip_image_size{ image_size, image_size }; // for thumbnail/global - instructions.refined_size = refined_size; - instructions.grid_size = clip_image_size{ grid_cols, grid_rows }; - - // in deepseek python they always produce *full* 640x640 blocks, - // so we can do a simple double loop over rows/cols: - for (int r = 0; r < grid_rows; ++r) { - for (int c = 0; c < grid_cols; ++c) { - const int x = c * image_size; - const int y = r * image_size; - - instructions.slices.push_back(llava_uhd::slice_coordinates{ - /* x */ x, - /* y */ y, - /* size */ clip_image_size{ image_size, image_size } - }); + else { + GGML_ABORT("DeepSeek-OCR: Gundam/Gundam-Master haven't been tested yet.\n"); + /* Dynamic Resolution (Gundam/Gundam-Master) */ + + // configurable, or read from params + const int min_num = 2; + const int 
max_num = 9; + const int image_size = params.image_size; // typically 640 + // const bool use_thumbnail = true; // mimic python's use_thumbnail + + // original image size + const int orig_w = original_size.width; + const int orig_h = original_size.height; + + // 1) build candidate grids (cols, rows) + auto target_ratios = ds_build_target_ratios(min_num, max_num); + + // 2) pick the grid that best matches the original aspect ratio + const float aspect_ratio = static_cast(orig_w) / static_cast(orig_h); + auto best = ds_find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_w, orig_h, image_size); + const int grid_cols = best.first; // how many tiles horizontally + const int grid_rows = best.second; // how many tiles vertically + + // 3) compute the target (forced) size — python did: + // target_width = image_size * cols + // target_height = image_size * rows + const clip_image_size refined_size{ image_size * grid_cols, image_size * grid_rows }; + + // 4) prepare slice instructions, same style as the idefics3 branch + llava_uhd::slice_instructions instructions; + instructions.overview_size = clip_image_size{ image_size, image_size }; // for thumbnail/global + instructions.refined_size = refined_size; + instructions.grid_size = clip_image_size{ grid_cols, grid_rows }; + + // in deepseek python they always produce *full* 640x640 blocks, + // so we can do a simple double loop over rows/cols: + for (int r = 0; r < grid_rows; ++r) { + for (int c = 0; c < grid_cols; ++c) { + const int x = c * image_size; + const int y = r * image_size; + + instructions.slices.push_back(llava_uhd::slice_coordinates{ + /* x */ x, + /* y */ y, + /* size */ clip_image_size{ image_size, image_size } + }); + } } + + // 5) run the actual slicing (this should: resize to refined_size, then crop every slice) + auto imgs = llava_uhd::slice_image(img, instructions); + + // 7) cast & normalize like the idefics3 branch + for (size_t i = 0; i < imgs.size(); ++i) { + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + } + + // keep the grid info — the model may need to know how to reassemble / attend + res_imgs->grid_x = grid_cols; + res_imgs->grid_y = grid_rows; } - - // 5) run the actual slicing (this should: resize to refined_size, then crop every slice) - auto imgs = llava_uhd::slice_image(img, instructions); - - // 7) cast & normalize like the idefics3 branch - for (size_t i = 0; i < imgs.size(); ++i) { - clip_image_f32_ptr res(clip_image_f32_init()); - normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std); - res_imgs->entries.push_back(std::move(res)); - } - - // keep the grid info — the model may need to know how to reassemble / attend - res_imgs->grid_x = grid_cols; - res_imgs->grid_y = grid_rows; } break; @@ -5807,7 +5823,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima for (auto & p : patterns) { if (tname_s == p) { - save_tensor_to_file(t); + save_tensor_to_file(t, data.data()); is_stored = true; break; } diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index eb96b389cfb..c0b191dcf30 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -29,11 +29,22 @@ enum clip_flash_attn_type { CLIP_FLASH_ATTN_TYPE_ENABLED = 1, }; +enum clip_dsocr_mode { + CLIP_DSOCR_MODE_AUTO, + CLIP_DSOCR_MODE_TINY, + CLIP_DSOCR_MODE_SMALL, + CLIP_DSOCR_MODE_BASE, + CLIP_DSOCR_MODE_LARGE, + CLIP_DSOCR_MODE_GUNDAM, + CLIP_DSOCR_MODE_GUNDAM_MASTER, +}; + 
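+// clip_dsocr_mode is consumed by clip_image_preprocess(): CLIP_DSOCR_MODE_AUTO
+// (the default) picks the native resolution whose area is closest to the input
+// image, TINY/SMALL/BASE/LARGE force one of the native resolutions, and
+// GUNDAM/GUNDAM_MASTER select the dynamic (tiled) resolution path.
+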
struct clip_context_params { bool use_gpu; enum clip_flash_attn_type flash_attn_type; int image_min_tokens; int image_max_tokens; + enum clip_dsocr_mode dsocr_mode; }; struct clip_init_result { diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index bd52341e357..f30ec1bcbf4 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -138,6 +138,7 @@ struct mtmd_cli_context { mparams.flash_attn_type = params.flash_attn_type; mparams.image_min_tokens = params.image_min_tokens; mparams.image_max_tokens = params.image_max_tokens; + mparams.dsocr_mode = params.dsocr_mode.c_str(); ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams)); if (!ctx_vision.get()) { LOG_ERR("Failed to load vision model from %s\n", clip_path); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index f6ac40ba911..0c360f13741 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -110,6 +110,7 @@ mtmd_context_params mtmd_context_params_default() { /* flash_attn_type */ LLAMA_FLASH_ATTN_TYPE_AUTO, /* image_min_tokens */ -1, /* image_max_tokens */ -1, + /* dsocr_mode */ "auto", }; return params; } @@ -172,11 +173,32 @@ struct mtmd_context { throw std::runtime_error("media_marker must not be empty"); } + enum clip_dsocr_mode dsocr_mode; + + if (std::string(ctx_params.dsocr_mode) == "auto") { + dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_AUTO; + } else if (std::string(ctx_params.dsocr_mode) == "tiny") { + dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_TINY; + } else if (std::string(ctx_params.dsocr_mode) == "small") { + dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_SMALL; + } else if (std::string(ctx_params.dsocr_mode) == "base") { + dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_BASE; + } else if (std::string(ctx_params.dsocr_mode) == "large") { + dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_LARGE; + } else if (std::string(ctx_params.dsocr_mode) == "gundam") { + dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_GUNDAM; + } else if (std::string(ctx_params.dsocr_mode) == "gundam-master") { + dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_GUNDAM_MASTER; + } else { + throw std::invalid_argument("invalid value"); + } + clip_context_params ctx_clip_params { /* use_gpu */ ctx_params.use_gpu, /* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO, /* image_min_tokens */ ctx_params.image_min_tokens, /* image_max_tokens */ ctx_params.image_max_tokens, + /* dsocr_mode */ dsocr_mode, }; auto res = clip_init(mmproj_fname, ctx_clip_params); diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index bc4c9a57bda..3dc34ae3b77 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -86,6 +86,9 @@ struct mtmd_context_params { // limit number of image tokens, only for vision models with dynamic resolution int image_min_tokens; // minimum number of tokens for image input (default: read from metadata) int image_max_tokens; // maximum number of tokens for image input (default: read from metadata) + + // DeepSeek-OCR resolution mode + const char * dsocr_mode; // one of: auto, tiny, small, base, large, gundam, gundam-master }; MTMD_API const char * mtmd_default_marker(void); From 95239f92b985ab1d7ceb8c4e07de48a2cbb98007 Mon Sep 17 00:00:00 2001 From: bluebread Date: Mon, 1 Dec 2025 07:31:24 +0000 Subject: [PATCH 35/37] mtmd: simplify SAM patch embedding --- tools/mtmd/clip.cpp | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index a590c067269..f46ea33678a 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -663,28 +663,24 
@@ struct clip_graph { return gf; } - ggml_tensor * build_sam_enc(ggml_tensor * inp_raw, - const int enc_image_size = 1024 - ) { + ggml_tensor * build_sam_enc(ggml_tensor * inp_raw) { constexpr int enc_n_embd = 768; constexpr int _depth = 12; constexpr int enc_n_heads = 12; constexpr int enc_d_heads = enc_n_embd / enc_n_heads; - // constexpr int _prompt_n_embd = 256; - constexpr int enc_patch_size = 16; - // constexpr int _window_size = 14; - - const int enc_n_patches = enc_image_size / enc_patch_size; // 64 - - ggml_tensor * inpL = build_enc_inp(inp_raw, enc_patch_size, enc_n_patches, enc_n_embd); - ggml_tensor * cur = nullptr; + ggml_tensor * inpL; + + inpL = ggml_conv_2d_sk_p0(ctx0, model.patch_embed_proj_w, inp_raw); + inpL = ggml_add(ctx0, inpL, ggml_reshape_3d(ctx0, model.patch_embed_proj_b, 1, 1, enc_n_embd)); + inpL = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 1, 2, 0, 3)); + + ggml_tensor * cur; const auto tgt_size = inpL->ne[1]; const auto str_size = model.pos_embed->ne[1]; if (str_size != tgt_size) { ggml_tensor * old_pos_embed = nullptr; old_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, model.pos_embed, 2, 0, 1, 3)); - // TODO: ggml_interpolate doesn't support bicubic model for CUDA backend ggml_tensor * new_pos_embed = ggml_interpolate( ctx0, old_pos_embed, @@ -838,7 +834,7 @@ struct clip_graph { ggml_cgraph * build_deepseek_ocr() { //patch embedding ggml_tensor * inp_raw = build_inp_raw(); - ggml_tensor * global_features_1 = build_sam_enc(inp_raw, std::max(img.nx, img.ny)); + ggml_tensor * global_features_1 = build_sam_enc(inp_raw); ggml_tensor * global_features_2 = build_dp_ocr_clip(global_features_1); // FIXME remove n_patches is hardcoded @@ -5819,6 +5815,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima bool is_stored = false; std::vector patterns = { /* Add tensor names here to dump (e.g. 
"sam_output") */ + "inpL", "inp_raw_cpy" }; for (auto & p : patterns) { From c914e0540549a699e8d6b0a7667e7af90d1c3d92 Mon Sep 17 00:00:00 2001 From: bluebread Date: Wed, 3 Dec 2025 05:18:39 +0000 Subject: [PATCH 36/37] mtmd: adapt Pillow image resizing function --- tools/mtmd/clip.cpp | 215 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 212 insertions(+), 3 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index f46ea33678a..3fea05cdc5d 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -4218,6 +4218,7 @@ struct img_tool { enum resize_algo { RESIZE_ALGO_BILINEAR, RESIZE_ALGO_BICUBIC, + RESIZE_ALGO_BICUBIC_PILLOW, // RESIZE_ALGO_LANCZOS, // TODO }; @@ -4247,6 +4248,9 @@ struct img_tool { case RESIZE_ALGO_BICUBIC: resize_bicubic(src, dst, target_resolution.width, target_resolution.height); break; + case RESIZE_ALGO_BICUBIC_PILLOW: + resize_bicubic_pillow(src, dst, target_resolution.width, target_resolution.height); + break; default: throw std::runtime_error("Unsupported resize algorithm"); } @@ -4266,6 +4270,9 @@ struct img_tool { case RESIZE_ALGO_BICUBIC: resize_bicubic(src, resized_image, new_width, new_height); break; + case RESIZE_ALGO_BICUBIC_PILLOW: + resize_bicubic_pillow(src, resized_image, new_width, new_height); + break; default: throw std::runtime_error("Unsupported resize algorithm"); } @@ -4475,6 +4482,209 @@ struct img_tool { return true; } + // Bicubic resize function using Pillow's ImagingResample algorithm + // Adapted from https://github.com/python-pillow/Pillow/blob/main/src/libImaging/Resample.c + static bool resize_bicubic_pillow(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) { + const int PRECISION_BITS = 32 - 8 - 2; + + // Bicubic filter function + auto bicubic_filter = [](double x) -> double { + constexpr double a = -0.5; + if (x < 0.0) { + x = -x; + } + if (x < 1.0) { + return ((a + 2.0) * x - (a + 3.0)) * x * x + 1; + } + if (x < 2.0) { + return (((x - 5) * x + 8) * x - 4) * a; + } + return 0.0; + }; + + constexpr double filter_support = 2.0; + + // Clipping function for 8-bit values + auto clip8 = [](int val) -> uint8_t { + if (val < 0) return 0; + if (val > 255) return 255; + return static_cast(val); + }; + + // Precompute coefficients + auto precompute_coeffs = [&](int inSize, double in0, double in1, int outSize, + std::vector & bounds, std::vector & kk) -> int { + double support, scale, filterscale; + double center, ww, ss; + int xx, x, ksize, xmin, xmax; + + filterscale = scale = (in1 - in0) / outSize; + if (filterscale < 1.0) { + filterscale = 1.0; + } + + support = filter_support * filterscale; + ksize = static_cast(std::ceil(support)) * 2 + 1; + + std::vector prekk(outSize * ksize); + bounds.resize(outSize * 2); + + for (xx = 0; xx < outSize; xx++) { + center = in0 + (xx + 0.5) * scale; + ww = 0.0; + ss = 1.0 / filterscale; + + xmin = static_cast(center - support + 0.5); + if (xmin < 0) { + xmin = 0; + } + + xmax = static_cast(center + support + 0.5); + if (xmax > inSize) { + xmax = inSize; + } + xmax -= xmin; + + double * k = &prekk[xx * ksize]; + for (x = 0; x < xmax; x++) { + double w = bicubic_filter((x + xmin - center + 0.5) * ss); + k[x] = w; + ww += w; + } + + for (x = 0; x < xmax; x++) { + if (ww != 0.0) { + k[x] /= ww; + } + } + + for (; x < ksize; x++) { + k[x] = 0; + } + + bounds[xx * 2 + 0] = xmin; + bounds[xx * 2 + 1] = xmax; + } + + // Normalize coefficients to fixed-point + kk.resize(outSize * ksize); + for (int i = 0; i < outSize * ksize; i++) { + if (prekk[i] < 
0) { + kk[i] = static_cast(-0.5 + prekk[i] * (1 << PRECISION_BITS)); + } else { + kk[i] = static_cast(0.5 + prekk[i] * (1 << PRECISION_BITS)); + } + } + + return ksize; + }; + + // Horizontal resampling + auto resample_horizontal = [&](const clip_image_u8 & imIn, clip_image_u8 & imOut, + int ksize, const std::vector & bounds, const std::vector & kk) { + imOut.ny = imIn.ny; + imOut.buf.resize(3 * imOut.nx * imOut.ny); + + for (int yy = 0; yy < imOut.ny; yy++) { + for (int xx = 0; xx < imOut.nx; xx++) { + int xmin = bounds[xx * 2 + 0]; + int xmax = bounds[xx * 2 + 1]; + const int32_t * k = &kk[xx * ksize]; + + int32_t ss0 = 1 << (PRECISION_BITS - 1); + int32_t ss1 = 1 << (PRECISION_BITS - 1); + int32_t ss2 = 1 << (PRECISION_BITS - 1); + + for (int x = 0; x < xmax; x++) { + int src_idx = ((yy * imIn.nx) + (x + xmin)) * 3; + ss0 += static_cast(imIn.buf[src_idx + 0]) * k[x]; + ss1 += static_cast(imIn.buf[src_idx + 1]) * k[x]; + ss2 += static_cast(imIn.buf[src_idx + 2]) * k[x]; + } + + int dst_idx = (yy * imOut.nx + xx) * 3; + imOut.buf[dst_idx + 0] = clip8(ss0 >> PRECISION_BITS); + imOut.buf[dst_idx + 1] = clip8(ss1 >> PRECISION_BITS); + imOut.buf[dst_idx + 2] = clip8(ss2 >> PRECISION_BITS); + } + } + }; + + // Vertical resampling + auto resample_vertical = [&](const clip_image_u8 & imIn, clip_image_u8 & imOut, + int ksize, const std::vector & bounds, const std::vector & kk) { + imOut.nx = imIn.nx; + imOut.buf.resize(3 * imOut.nx * imOut.ny); + + for (int yy = 0; yy < imOut.ny; yy++) { + int ymin = bounds[yy * 2 + 0]; + int ymax = bounds[yy * 2 + 1]; + const int32_t * k = &kk[yy * ksize]; + + for (int xx = 0; xx < imOut.nx; xx++) { + int32_t ss0 = 1 << (PRECISION_BITS - 1); + int32_t ss1 = 1 << (PRECISION_BITS - 1); + int32_t ss2 = 1 << (PRECISION_BITS - 1); + + for (int y = 0; y < ymax; y++) { + int src_idx = ((y + ymin) * imIn.nx + xx) * 3; + ss0 += static_cast(imIn.buf[src_idx + 0]) * k[y]; + ss1 += static_cast(imIn.buf[src_idx + 1]) * k[y]; + ss2 += static_cast(imIn.buf[src_idx + 2]) * k[y]; + } + + int dst_idx = (yy * imOut.nx + xx) * 3; + imOut.buf[dst_idx + 0] = clip8(ss0 >> PRECISION_BITS); + imOut.buf[dst_idx + 1] = clip8(ss1 >> PRECISION_BITS); + imOut.buf[dst_idx + 2] = clip8(ss2 >> PRECISION_BITS); + } + } + }; + + // Main resampling logic + const int src_width = img.nx; + const int src_height = img.ny; + + dst.nx = target_width; + dst.ny = target_height; + + bool need_horizontal = (target_width != src_width); + bool need_vertical = (target_height != src_height); + + // Precompute coefficients for both passes + std::vector bounds_horiz, bounds_vert; + std::vector kk_horiz, kk_vert; + int ksize_horiz = 0, ksize_vert = 0; + + if (need_horizontal) { + ksize_horiz = precompute_coeffs(src_width, 0.0, src_width, target_width, bounds_horiz, kk_horiz); + } + + if (need_vertical) { + ksize_vert = precompute_coeffs(src_height, 0.0, src_height, target_height, bounds_vert, kk_vert); + } + + // Perform two-pass resampling + if (need_horizontal && need_vertical) { + // Both horizontal and vertical + clip_image_u8 temp; + temp.nx = target_width; + resample_horizontal(img, temp, ksize_horiz, bounds_horiz, kk_horiz); + resample_vertical(temp, dst, ksize_vert, bounds_vert, kk_vert); + } else if (need_horizontal) { + // Only horizontal + resample_horizontal(img, dst, ksize_horiz, bounds_horiz, kk_horiz); + } else if (need_vertical) { + // Only vertical + resample_vertical(img, dst, ksize_vert, bounds_vert, kk_vert); + } else { + // No resampling needed + dst.buf = img.buf; + } + + return true; + } + 
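+    // resize_bicubic_pillow() above mirrors Pillow's ImagingResample: a separable
+    // two-pass filter (horizontal pass, then vertical pass) that accumulates each
+    // channel in 32-bit fixed point with PRECISION_BITS fractional bits and clips
+    // the result back to 8-bit values.
+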
static inline int clip(int x, int lower, int upper) { return std::max(lower, std::min(x, upper)); } @@ -5101,7 +5311,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str clip_image_u8_ptr resized_img(clip_image_u8_init()); img_tool::resize(*img, *resized_img, clip_image_size{image_size, image_size}, - img_tool::RESIZE_ALGO_BICUBIC, true, color); // Match PIL default + img_tool::RESIZE_ALGO_BICUBIC_PILLOW, false, color); // Match PIL default clip_image_f32_ptr res(clip_image_f32_init()); normalize_image_u8_to_f32(*resized_img, *res, params.image_mean, params.image_std); @@ -5124,7 +5334,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str clip_image_u8_ptr scaled_img(clip_image_u8_init()); img_tool::resize(*img, *scaled_img, clip_image_size{new_w, new_h}, - img_tool::RESIZE_ALGO_BICUBIC, true, color); + img_tool::RESIZE_ALGO_BICUBIC_PILLOW, true, color); // Use mean color for padding unsigned char pad_r = static_cast(params.image_mean[0] * 255.0f); @@ -5815,7 +6025,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima bool is_stored = false; std::vector patterns = { /* Add tensor names here to dump (e.g. "sam_output") */ - "inpL", "inp_raw_cpy" }; for (auto & p : patterns) { From e20857ba59d6dd4d4be9fd363dbc10cacbc339c0 Mon Sep 17 00:00:00 2001 From: bluebread Date: Wed, 3 Dec 2025 07:51:12 +0000 Subject: [PATCH 37/37] mtmd: simplify DeepSeek-OCR dynamic resolution preprocessing --- tools/mtmd/clip.cpp | 81 +++++++++++++++++++++------------------------ 1 file changed, 37 insertions(+), 44 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 3fea05cdc5d..a1b0d914663 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -5016,7 +5016,7 @@ static std::vector> ds_build_target_ratios(const int min_num return ratios; } -static std::pair ds_find_closest_aspect_ratio( +static std::pair ds_find_closest_ratio( const float aspect_ratio, const std::vector> &target_ratios, const int width, @@ -5382,60 +5382,53 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str /* Dynamic Resolution (Gundam/Gundam-Master) */ // configurable, or read from params - const int min_num = 2; - const int max_num = 9; - const int image_size = params.image_size; // typically 640 - // const bool use_thumbnail = true; // mimic python's use_thumbnail - + const int min_num = 2; + const int max_num = 9; + const int image_size = (mode_i == 4) ? 
+
         // original image size
-        const int orig_w = original_size.width;
-        const int orig_h = original_size.height;
+        const int orig_w = original_size.width;
+        const int orig_h = original_size.height;
 
-        // 1) build candidate grids (cols, rows)
+        // create overview image (thumbnail)
+        clip_image_u8_ptr overview_img(clip_image_u8_init());
+        img_tool::resize(*img, *overview_img, { image_size, image_size },
+                         img_tool::RESIZE_ALGO_BICUBIC_PILLOW, true, color);
+        clip_image_f32_ptr overview_f32(clip_image_f32_init());
+        normalize_image_u8_to_f32(*overview_img, *overview_f32, params.image_mean, params.image_std);
+        res_imgs->entries.push_back(std::move(overview_f32));
+
+        // build candidate grids (cols, rows)
         auto target_ratios = ds_build_target_ratios(min_num, max_num);
 
-        // 2) pick the grid that best matches the original aspect ratio
+        // pick the grid that best matches the original aspect ratio
        const float aspect_ratio = static_cast<float>(orig_w) / static_cast<float>(orig_h);
-        auto best = ds_find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_w, orig_h, image_size);
+        auto best = ds_find_closest_ratio(aspect_ratio, target_ratios, orig_w, orig_h, image_size);
         const int grid_cols = best.first;  // how many tiles horizontally
         const int grid_rows = best.second; // how many tiles vertically
-
-        // 3) compute the target (forced) size — python did:
-        //    target_width  = image_size * cols
-        //    target_height = image_size * rows
-        const clip_image_size refined_size{ image_size * grid_cols, image_size * grid_rows };
-
-        // 4) prepare slice instructions, same style as the idefics3 branch
-        llava_uhd::slice_instructions instructions;
-        instructions.overview_size = clip_image_size{ image_size, image_size }; // for thumbnail/global
-        instructions.refined_size  = refined_size;
-        instructions.grid_size     = clip_image_size{ grid_cols, grid_rows };
-
-        // in deepseek python they always produce *full* 640x640 blocks,
-        // so we can do a simple double loop over rows/cols:
+
+        // resize to refined size (no padding, direct resize)
+        clip_image_u8_ptr refined_img(clip_image_u8_init());
+        img_tool::resize(*img, *refined_img, { image_size * grid_cols, image_size * grid_rows },
+                         img_tool::RESIZE_ALGO_BICUBIC_PILLOW, false);
+
+        // crop slices from the refined image
         for (int r = 0; r < grid_rows; ++r) {
             for (int c = 0; c < grid_cols; ++c) {
                 const int x = c * image_size;
                 const int y = r * image_size;
-
-                instructions.slices.push_back(llava_uhd::slice_coordinates{
-                    /* x    */ x,
-                    /* y    */ y,
-                    /* size */ clip_image_size{ image_size, image_size }
-                });
+
+                // crop the slice
+                clip_image_u8_ptr slice_img(clip_image_u8_init());
+                img_tool::crop(*refined_img, *slice_img, x, y, image_size, image_size);
+
+                // normalize and add to results
+                clip_image_f32_ptr slice_f32(clip_image_f32_init());
+                normalize_image_u8_to_f32(*slice_img, *slice_f32, params.image_mean, params.image_std);
+                res_imgs->entries.push_back(std::move(slice_f32));
             }
         }
-
-        // 5) run the actual slicing (this should: resize to refined_size, then crop every slice)
-        auto imgs = llava_uhd::slice_image(img, instructions);
-
-        // 7) cast & normalize like the idefics3 branch
-        for (size_t i = 0; i < imgs.size(); ++i) {
-            clip_image_f32_ptr res(clip_image_f32_init());
-            normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
-            res_imgs->entries.push_back(std::move(res));
-        }
-
+
         // keep the grid info — the model may need to know how to reassemble / attend
         res_imgs->grid_x = grid_cols;
         res_imgs->grid_y = grid_rows;
@@ -5971,8 +5964,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 // do nothing
             } break;
         case PROJECTOR_TYPE_DEEPSEEKOCR:
-        {
-        } break;
+            {
+            } break;
         case PROJECTOR_TYPE_LLAMA4:
             {
                 // set the 2D positions