diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 222f6ed6dc4..12baa198fef 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -620,6 +620,9 @@ def load_hparams(dir_model: Path, is_mistral_format: bool): if "thinker_config" in config: # rename for Qwen2.5-Omni config["text_config"] = config["thinker_config"]["text_config"] + if "language_config" in config: + # rename for DeepSeekOCR + config["text_config"] = config["language_config"] return config @classmethod @@ -1442,7 +1445,7 @@ class MmprojModel(ModelBase): preprocessor_config: dict[str, Any] global_config: dict[str, Any] - n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"] + n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "layers"] has_vision_encoder: bool = True # by default has_audio_encoder: bool = False @@ -1488,13 +1491,28 @@ def __init__(self, *args, **kwargs): # TODO @ngxson : this is a hack to support both vision and audio encoders have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True) + # FIXME: DeepseekOCRVisionModel specific hack + if self.block_count is None: + if isinstance(self, DeepseekOCRVisionModel): + clip_block_count = self.hparams['layers'] + if clip_block_count is not None: + self.block_count = clip_block_count + if self.block_count is None: + raise KeyError(f"could not find block count using any of: {self.n_block_keys}") self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count) # load preprocessor config self.preprocessor_config = {} if not self.is_mistral_format: - with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f: - self.preprocessor_config = json.load(f) + # check if preprocessor_config.json exists + if (self.dir_model / "preprocessor_config.json").is_file(): + with open(self.dir_model / "preprocessor_config.json", 
"r", encoding="utf-8") as f: + self.preprocessor_config = json.load(f) + else: + # try "processing_config" file if exists + if (self.dir_model / "processing_config.json").is_file(): + with open(self.dir_model / "processing_config.json", "r", encoding="utf-8") as f: + self.preprocessor_config = json.load(f) def get_vision_config(self) -> dict[str, Any] | None: config_name = "vision_config" if not self.is_mistral_format else "vision_encoder" @@ -5770,6 +5788,97 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [] # skip other tensors +@ModelBase.register("DeepseekOCRForCausalLM") +class DeepseekOCRVisionModel(MmprojModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + proc_fname = self.dir_model / "processor_config.json" + + if proc_fname.is_file(): + with open(proc_fname, "r") as f: + self.preprocessor_config = json.load(f) + + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.DEEPSEEKOCR) + # default values below are taken from HF tranformers code + self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6)) + self.gguf_writer.add_vision_use_gelu(True) + # calculate proj_scale_factor (used by tinygemma3 test model) + image_seq_length = self.preprocessor_config.get("image_seq_length", 256) + n_per_side = int(image_seq_length ** 0.5) + image_size = self.hparams["image_size"] + patch_size = self.hparams["patch_size"] + proj_scale_factor = (image_size // patch_size) // n_per_side + if proj_scale_factor > 0 and proj_scale_factor != 4: + # we only need to write this if it's not the default value + # in this case, we are converting a test model + self.gguf_writer.add_vision_projector_scale_factor(proj_scale_factor) + + # SAM configuration + sam_hparams = hparams['sam'] + self.gguf_writer.add_vision_sam_layers_count(sam_hparams['layers']) + 
self.gguf_writer.add_vision_sam_embedding_length(sam_hparams['width']) + + def get_vision_config(self) -> dict[str, Any]: + vision_config: dict[str, Any] | None = self.global_config.get("vision_config") + + if not vision_config: + raise ValueError("DeepseekOCR model requires 'vision_config' in the model configuration, but it was not found") + + vision_config['sam'] = vision_config['width']['sam_vit_b'] + vision_config.update(vision_config['width']['clip-l-14-224']) + vision_config['hidden_size'] = vision_config['width'] + vision_config['num_heads'] = vision_config['heads'] + vision_config['intermediate_size'] = vision_config['width'] * 4 + + return vision_config + + + def tensor_force_quant(self, name, new_name, bid, n_dims): + # related to https://github.com/ggml-org/llama.cpp/issues/13025 + if "input_projection" in name: + return gguf.GGMLQuantizationType.F16 + if ".embeddings." in name: + return gguf.GGMLQuantizationType.F32 + return super().tensor_force_quant(name, new_name, bid, n_dims) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Only process vision-related tensors, skip language model tensors + # Vision components: sam_model, vision_model, projector, image_newline, view_seperator + # Language model components to skip: lm_head, embed_tokens, layers, norm + if name.startswith(("lm_head.", "model.embed_tokens.", "model.layers.", "model.norm.")): + return [] + + if ".attn.rel_pos_h" in name or ".attn.rel_pos_w" in name: + return [(self.map_tensor_name(name, try_suffixes=("",)), data_torch)] + + if name.startswith("model.vision_model.transformer.layers."): + # process visual tensors + # split QKV tensors if needed + if ".qkv_proj." 
in name: + if data_torch.ndim == 2: # weight + c3, _ = data_torch.shape + else: # bias + c3 = data_torch.shape[0] + assert c3 % 3 == 0 + c = c3 // 3 + wq = data_torch[:c] + wk = data_torch[c: c * 2] + wv = data_torch[c * 2:] + return [ + (self.map_tensor_name(name.replace("qkv", "q")), wq), + (self.map_tensor_name(name.replace("qkv", "k")), wk), + (self.map_tensor_name(name.replace("qkv", "v")), wv), + ] + else: + return [(self.map_tensor_name(name), data_torch)] + + return [(self.map_tensor_name(name), data_torch)] + @ModelBase.register("Gemma3nForConditionalGeneration") class Gemma3NModel(Gemma3Model): @@ -6943,6 +7052,7 @@ def prepare_tensors(self): @ModelBase.register( "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", + "DeepseekOCRForCausalLM", "KimiVLForConditionalGeneration", ) class DeepseekV2Model(TextModel): @@ -7003,39 +7113,49 @@ def set_vocab(self): raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!") def set_gguf_parameters(self): + is_ocr = (self.hparams["num_hidden_layers"] == 12) - # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group) - self.hparams["num_key_value_heads"] = 1 + if is_ocr: + self.hparams['rope_theta'] = self.hparams.get('rope_theta', 10000.0) + self.hparams['rms_norm_eps'] = self.hparams.get('rms_norm_eps', 1e-6) + else: + # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group) + self.hparams["num_key_value_heads"] = 1 super().set_gguf_parameters() hparams = self.hparams + kv_lora_rank = hparams["kv_lora_rank"] if hparams.get("kv_lora_rank") is not None else 512 + routed_scaling_factor = hparams.get("routed_scaling_factor", 1.0) + norm_topk_prob = hparams.get("norm_topk_prob", False) + scoring_func = hparams.get("scoring_func", "softmax") self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) self.gguf_writer.add_vocab_size(hparams["vocab_size"]) if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: 
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) - self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) + if "kv_lora_rank" in hparams and hparams["kv_lora_rank"] is not None: + self.gguf_writer.add_kv_lora_rank(kv_lora_rank) # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA - self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"]) - self.gguf_writer.add_value_length(hparams["kv_lora_rank"]) - self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) - self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) + if not is_ocr: + self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(kv_lora_rank) + self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) - self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) - self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + self.gguf_writer.add_expert_weights_scale(routed_scaling_factor) + self.gguf_writer.add_expert_weights_norm(norm_topk_prob) - if hparams["scoring_func"] == "sigmoid": + if scoring_func == "sigmoid": self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) - elif hparams["scoring_func"] == "softmax": + elif scoring_func == "softmax": self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) else: - raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}") - - self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) + raise 
ValueError(f"Unsupported scoring_func value: {scoring_func}") rope_scaling = self.hparams.get("rope_scaling") or {} if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: @@ -7043,12 +7163,14 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6)) _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # skip vision tensors and remove "language_model." for Kimi-VL - if "vision_tower" in name or "multi_modal_projector" in name: + if "vision_" in name or "multi_modal_projector" in name \ + or "image_newline" in name or "model.projector" in name or "sam_model" in name or "view_seperator" in name: return [] if name.startswith("language_model."): diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index cefa39a57c8..ed181a1ab45 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -74,19 +74,19 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne } } for (int64_t i3 = 0; i3 < ne[3]; i3++) { - LOG(" [\n"); + LOG(" [\n"); for (int64_t i2 = 0; i2 < ne[2]; i2++) { if (i2 == n && ne[2] > 2*n) { - LOG(" ..., \n"); + LOG(" ..., \n"); i2 = ne[2] - n; } - LOG(" [\n"); + LOG(" [\n"); for (int64_t i1 = 0; i1 < ne[1]; i1++) { if (i1 == n && ne[1] > 2*n) { - LOG(" ..., \n"); + LOG(" ..., \n"); i1 = ne[1] - n; } - LOG(" ["); + LOG(" ["); for (int64_t i0 = 0; i0 < ne[0]; i0++) { if (i0 == n && ne[0] > 2*n) { LOG("..., "); @@ -98,10 +98,10 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne 
} LOG("],\n"); } - LOG(" ],\n"); + LOG(" ],\n"); } - LOG(" ]\n"); - LOG(" sum = %f\n", sum); + LOG(" ]\n"); + LOG(" sum = %f\n", sum); } // TODO: make this abort configurable/optional? @@ -136,7 +136,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str()); } - LOG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, + LOG("%s: %16s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, t->name, ggml_type_name(t->type), ggml_op_desc(t), src0->name, ggml_ne_string(src0).c_str(), src1 ? src1_str : "", diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 6b4b6c5ab07..c4294574631 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -287,6 +287,10 @@ class Attention: class Projector: SCALE_FACTOR = "clip.vision.projector.scale_factor" + class SAM: + BLOCK_COUNT = "clip.vision.sam.block_count" + EMBEDDING_LENGTH = "clip.vision.sam.embedding_length" + class ClipAudio: NUM_MEL_BINS = "clip.audio.num_mel_bins" EMBEDDING_LENGTH = "clip.audio.embedding_length" @@ -664,6 +668,22 @@ class MODEL_TENSOR(IntEnum): V_MM_GATE = auto() # cogvlm V_TOK_BOI = auto() # cogvlm V_TOK_EOI = auto() # cogvlm + V_SAM_POS_EMBD = auto() # Deepseek-OCR + V_SAM_PATCH_EMBD = auto() # Deepseek-OCR + V_SAM_PRE_NORM = auto() # Deepseek-OCR + V_SAM_POST_NORM = auto() # Deepseek-OCR + V_SAM_ATTN_POS_H = auto() # Deepseek-OCR + V_SAM_ATTN_POS_W = auto() # Deepseek-OCR + V_SAM_ATTN_QKV = auto() # Deepseek-OCR + V_SAM_ATTN_OUT = auto() # Deepseek-OCR + V_SAM_MLP_LIN_1 = auto() # Deepseek-OCR + V_SAM_MLP_LIN_2 = auto() # Deepseek-OCR + V_SAM_NECK = auto() # Deepseek-OCR + V_SAM_NET_2 = auto() # Deepseek-OCR + V_SAM_NET_3 = auto() # Deepseek-OCR + V_ENC_EMBD_IMGNL = auto() # Deepseek-OCR + V_ENC_EMBD_VSEP = auto() # Deepseek-OCR + # audio (mtmd) A_ENC_EMBD_POS = auto() A_ENC_CONV1D = auto() @@ -1030,6 +1050,22 @@ class MODEL_TENSOR(IntEnum): 
MODEL_TENSOR.V_MM_GATE: "mm.gate", MODEL_TENSOR.V_TOK_BOI: "v.boi", MODEL_TENSOR.V_TOK_EOI: "v.eoi", + # DeepSeek-OCR sam_model + MODEL_TENSOR.V_SAM_POS_EMBD: "v.sam.pos_embd", + MODEL_TENSOR.V_SAM_PATCH_EMBD: "v.sam.patch_embd", + MODEL_TENSOR.V_SAM_PRE_NORM: "v.sam.blk.{bid}.pre_ln", + MODEL_TENSOR.V_SAM_POST_NORM: "v.sam.blk.{bid}.post_ln", + MODEL_TENSOR.V_SAM_ATTN_POS_H: "v.sam.blk.{bid}.attn.pos_h", + MODEL_TENSOR.V_SAM_ATTN_POS_W: "v.sam.blk.{bid}.attn.pos_w", + MODEL_TENSOR.V_SAM_ATTN_QKV: "v.sam.blk.{bid}.attn.qkv", + MODEL_TENSOR.V_SAM_ATTN_OUT: "v.sam.blk.{bid}.attn.out", + MODEL_TENSOR.V_SAM_MLP_LIN_1: "v.sam.blk.{bid}.mlp.lin1", + MODEL_TENSOR.V_SAM_MLP_LIN_2: "v.sam.blk.{bid}.mlp.lin2", + MODEL_TENSOR.V_SAM_NECK: "v.sam.neck.{bid}", + MODEL_TENSOR.V_SAM_NET_2: "v.sam.net_2", + MODEL_TENSOR.V_SAM_NET_3: "v.sam.net_3", + MODEL_TENSOR.V_ENC_EMBD_IMGNL: "model.image_newline", # Deepseek-OCR + MODEL_TENSOR.V_ENC_EMBD_VSEP: "model.view_seperator", # Deepseek-OCR # audio (mtmd) MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd", MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}", @@ -1066,6 +1102,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_ENC_EMBD_CLS, MODEL_TENSOR.V_ENC_EMBD_PATCH, MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_EMBD_IMGNL, + MODEL_TENSOR.V_ENC_EMBD_VSEP, MODEL_TENSOR.V_ENC_INPUT_NORM, MODEL_TENSOR.V_ENC_ATTN_QKV, MODEL_TENSOR.V_ENC_ATTN_Q, @@ -1108,6 +1146,19 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_MM_GATE, MODEL_TENSOR.V_TOK_BOI, MODEL_TENSOR.V_TOK_EOI, + MODEL_TENSOR.V_SAM_POS_EMBD, + MODEL_TENSOR.V_SAM_PATCH_EMBD, + MODEL_TENSOR.V_SAM_PRE_NORM, + MODEL_TENSOR.V_SAM_POST_NORM, + MODEL_TENSOR.V_SAM_ATTN_POS_H, + MODEL_TENSOR.V_SAM_ATTN_POS_W, + MODEL_TENSOR.V_SAM_ATTN_QKV, + MODEL_TENSOR.V_SAM_ATTN_OUT, + MODEL_TENSOR.V_SAM_MLP_LIN_1, + MODEL_TENSOR.V_SAM_MLP_LIN_2, + MODEL_TENSOR.V_SAM_NECK, + MODEL_TENSOR.V_SAM_NET_2, + MODEL_TENSOR.V_SAM_NET_3, # audio MODEL_TENSOR.A_ENC_EMBD_POS, MODEL_TENSOR.A_ENC_CONV1D, @@ -2247,7 +2298,9 @@ 
class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_Q_B, MODEL_TENSOR.ATTN_KV_A_MQA, MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_V_B, MODEL_TENSOR.ATTN_Q_A_NORM, MODEL_TENSOR.ATTN_KV_A_NORM, @@ -3207,6 +3260,7 @@ class VisionProjectorType: LIGHTONOCR = "lightonocr" COGVLM = "cogvlm" JANUS_PRO = "janus_pro" + DEEPSEEKOCR = "deepseekocr" # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py index a051daeeb13..fca498a859d 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1077,6 +1077,12 @@ def add_vision_n_wa_pattern(self, value: int) -> None: def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None: self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers) + + def add_vision_sam_layers_count(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.SAM.BLOCK_COUNT, value) + + def add_vision_sam_embedding_length(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.SAM.EMBEDDING_LENGTH, value) # audio models def add_audio_projection_dim(self, value: int) -> None: diff --git a/gguf-py/gguf/tensor_mapping.py index 92940668761..2cf8110d293 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1177,6 +1179,7 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ_FC: ( "model.connector.modality_projection.proj", # SmolVLM "model.vision.linear_proj.linear_proj", # cogvlm + "model.projector.layers", # Deepseek-OCR ), MODEL_TENSOR.V_MMPROJ_MLP: ( @@ -1195,6 +1198,7 @@ class TensorNameMap: "model.vision_tower.embeddings.cls_token", # Intern-S1 "vision_model.class_embedding", # llama 4 "model.vision.patch_embedding.cls_embedding", # cogvlm + 
"model.vision_model.embeddings.class_embedding", # Deepseek-OCR ), MODEL_TENSOR.V_ENC_EMBD_PATCH: ( @@ -1208,6 +1212,7 @@ class TensorNameMap: "visual.patch_embed.proj", # qwen2vl "vision_tower.patch_embed.proj", # kimi-vl "model.vision.patch_embedding.proj", # cogvlm + "model.vision_model.embeddings.patch_embedding", # Deepseek-OCR CLIP ), MODEL_TENSOR.V_ENC_EMBD_POS: ( @@ -1220,10 +1225,19 @@ class TensorNameMap: "visual.pos_embed", # qwen3vl "model.vision.patch_embedding.position_embedding", # cogvlm ), + + MODEL_TENSOR.V_ENC_EMBD_IMGNL: ( + "model.image_newline", # Deepseek-OCR + ), + + MODEL_TENSOR.V_ENC_EMBD_VSEP: ( + "model.view_seperator", # Deepseek-OCR + ), MODEL_TENSOR.V_ENC_ATTN_QKV: ( "visual.blocks.{bid}.attn.qkv", # qwen3vl "model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm + "model.vision_model.transformer.layers.{bid}.self_attn.qkv_proj", # Deepseek-OCR CLIP ), MODEL_TENSOR.V_ENC_ATTN_Q: ( @@ -1236,6 +1250,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral "visual.blocks.{bid}.attn.q", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated + "model.vision_model.transformer.layers.{bid}.self_attn.q_proj", # Deepseek-OCR CLIP, generated ), MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( @@ -1253,6 +1268,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral "visual.blocks.{bid}.attn.k", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated + "model.vision_model.transformer.layers.{bid}.self_attn.k_proj", # Deepseek-OCR CLIP, generated ), MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( @@ -1270,6 +1286,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral "visual.blocks.{bid}.attn.v", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated + "model.vision_model.transformer.layers.{bid}.self_attn.v_proj", # Deepseek-OCR CLIP, generated ), 
MODEL_TENSOR.V_ENC_INPUT_NORM: ( @@ -1284,6 +1301,7 @@ class TensorNameMap: "visual.blocks.{bid}.norm1", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm + "model.vision_model.transformer.layers.{bid}.layer_norm1", # Deepseek-OCR CLIP ), MODEL_TENSOR.V_ENC_ATTN_O: ( @@ -1299,6 +1317,7 @@ class TensorNameMap: "visual.blocks.{bid}.attn.proj", # qwen2vl "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl "model.vision.transformer.layers.{bid}.attention.dense", # cogvlm + "model.vision_model.transformer.layers.{bid}.self_attn.out_proj", # Deepseek-OCR CLIP ), MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( @@ -1313,6 +1332,7 @@ class TensorNameMap: "visual.blocks.{bid}.norm2", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm + "model.vision_model.transformer.layers.{bid}.layer_norm2", # Deepseek-OCR CLIP ), MODEL_TENSOR.V_ENC_FFN_UP: ( @@ -1327,6 +1347,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl "visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1) + "model.vision_model.transformer.layers.{bid}.mlp.fc1", # Deepseek-OCR CLIP "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm ), @@ -1349,6 +1370,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm + "model.vision_model.transformer.layers.{bid}.mlp.fc2", # Deepseek-OCR CLIP ), MODEL_TENSOR.V_LAYER_SCALE_1: ( @@ -1366,6 +1388,7 @@ class TensorNameMap: "vision_tower.ln_pre", # pixtral-hf "vision_encoder.ln_pre", # pixtral "vision_model.layernorm_pre", # llama4 + "model.vision_model.pre_layrnorm", # Deepseek-OCR CLIP ), MODEL_TENSOR.V_POST_NORM: ( @@ -1457,6 +1480,58 @@ class TensorNameMap: 
"model.visual.deepstack_merger_list.{bid}.linear_fc2", # deepstack in qwen3vl ), + MODEL_TENSOR.V_SAM_POS_EMBD: ( + "model.sam_model.pos_embed", + ), + + MODEL_TENSOR.V_SAM_PATCH_EMBD: ( + "model.sam_model.patch_embed.proj", + ), + + MODEL_TENSOR.V_SAM_PRE_NORM: ( + "model.sam_model.blocks.{bid}.norm1", # deepstack in qwen3vl + ), + + MODEL_TENSOR.V_SAM_POST_NORM: ( + "model.sam_model.blocks.{bid}.norm2", # deepstack in qwen3vl + ), + + MODEL_TENSOR.V_SAM_ATTN_POS_H: ( + "model.sam_model.blocks.{bid}.attn.rel_pos_h", + ), + + MODEL_TENSOR.V_SAM_ATTN_POS_W: ( + "model.sam_model.blocks.{bid}.attn.rel_pos_w", + ), + + MODEL_TENSOR.V_SAM_ATTN_QKV: ( + "model.sam_model.blocks.{bid}.attn.qkv", + ), + + MODEL_TENSOR.V_SAM_ATTN_OUT: ( + "model.sam_model.blocks.{bid}.attn.proj", + ), + + MODEL_TENSOR.V_SAM_MLP_LIN_1: ( + "model.sam_model.blocks.{bid}.mlp.lin1", + ), + + MODEL_TENSOR.V_SAM_MLP_LIN_2: ( + "model.sam_model.blocks.{bid}.mlp.lin2", + ), + + MODEL_TENSOR.V_SAM_NECK: ( + "model.sam_model.neck.{bid}", + ), + + MODEL_TENSOR.V_SAM_NET_2: ( + "model.sam_model.net_2", + ), + + MODEL_TENSOR.V_SAM_NET_3: ( + "model.sam_model.net_3", + ), + MODEL_TENSOR.V_MM_POST_FC_NORM: ( "model.vision.linear_proj.norm1", # cogvlm ), diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index b7642b568df..ac3ab5cfa77 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1446,6 +1446,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index b199e94628f..4daf3f230b5 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1106,7 +1106,7 @@ ggml_tensor * 
llm_graph_context::build_moe_ffn( if (!weight_before_ffn) { experts = ggml_mul(ctx0, experts, weights); - cb(cur, "ffn_moe_weighted", il); + cb(experts, "ffn_moe_weighted", il); } ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 829f1e3c14f..79639c515eb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1562,12 +1562,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_DEEPSEEK2: { bool is_lite = (hparams.n_layer == 27); + bool is_ocr = (name.find("ocr") != std::string::npos || name.find("OCR") != std::string::npos); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); - if (!is_lite) { + if (!is_lite && !is_ocr) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + if (!is_ocr) { + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + } ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); @@ -1583,6 +1587,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); switch (hparams.n_layer) { + case 12: type = LLM_TYPE_3B; break; case 27: type = LLM_TYPE_16B; break; case 60: type = LLM_TYPE_236B; break; case 61: type = LLM_TYPE_671B; break; @@ -4550,6 +4555,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_DEEPSEEK2: { const bool is_lite = (hparams.n_layer == 27); + const bool is_ocr = (name.find("ocr") != std::string::npos || name.find("OCR") != std::string::npos); const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); @@ -4575,6 +4581,35 @@ bool 
llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; + if (is_ocr) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + + continue; + } + layer.attn_norm = 
create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); if (!is_lite) { layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 735c5d547f9..2634ab7c5ec 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -2347,6 +2347,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "_" || t.first == "<|end_of_text|>" || t.first == "" // smoldocling + || t.first == "<|end▁of▁sentence|>" // deepseek-ocr ) { special_eog_ids.insert(t.second); if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index 68f72f72bb6..f4a40d7d6e8 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -5,6 +5,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { bool is_lite = (hparams.n_layer == 27); + bool is_ocr = (model.name.find("ocr") != std::string::npos || model.name.find("OCR") != std::string::npos); const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); @@ -44,7 +45,38 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr cb(cur, "attn_norm", il); // self_attention - { + if (is_ocr) { + const int n_embed_head = hparams.n_embd / hparams.n_head(); + const int ocr_rope_type = GGML_ROPE_TYPE_NEOX; + GGML_ASSERT(n_embed_head == n_embd_head_k && n_embed_head == n_embd_head_v); + + ggml_tensor * Qcur = NULL; + ggml_tensor * Kcur = NULL; + ggml_tensor * Vcur = NULL; + + Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Qcur, "q", il); + cb(Kcur, "k", il); + cb(Vcur, "v", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embed_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, 
n_embed_head, n_head, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embed_head, n_head, n_tokens); + + GGML_ASSERT(fabs(freq_base - 10000.0) < 1e-4); + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_embed_head, ocr_rope_type, 0, freq_base, 1, 0, 1, 0, 0); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_embed_head, ocr_rope_type, 0, freq_base, 1, 0, 1, 0, 0); + cb(Qcur, "q_pe", il); + cb(Kcur, "k_pe", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + else { ggml_tensor * q = NULL; if (!is_lite) { q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 722b1a4948d..63d59055668 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -86,11 +86,12 @@ #define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s" #define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s" #define TN_IMAGE_NEWLINE "model.image_newline" +#define TN_IMAGE_SEPERATOR "model.view_seperator" #define TN_MM_INP_NORM "mm.input_norm.weight" #define TN_MM_INP_NORM_B "mm.input_norm.bias" #define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3 #define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3 -#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3 +#define TN_MM_PROJECTOR "mm.model.fc.%s" // idefics3, deepseekocr #define TN_MM_PATCH_MERGER "mm.patch_merger.weight" // mistral small 3.1 #define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral #define TN_TOK_GLM_BOI "adapter.boi" // glm-edge (these embeddings are not in text model) @@ -129,6 +130,20 @@ #define TN_TOK_BOI "v.boi" #define TN_TOK_EOI "v.eoi" +// deepseek-ocr +#define TN_SAM_POS_EMBD "v.sam.pos_embd" +#define TN_SAM_PATCH_EMBD "v.sam.patch_embd.%s" +#define TN_SAM_PRE_NORM "v.sam.blk.%d.pre_ln.%s" +#define TN_SAM_POST_NORM "v.sam.blk.%d.post_ln.%s" +#define TN_SAM_ATTN_POS_H "v.sam.blk.%d.attn.pos_h" +#define TN_SAM_ATTN_POS_W 
"v.sam.blk.%d.attn.pos_w" +#define TN_SAM_ATTN_QKV "v.sam.blk.%d.attn.qkv.%s" +#define TN_SAM_ATTN_OUT "v.sam.blk.%d.attn.out.%s" +#define TN_SAM_FFN_UP "v.sam.blk.%d.mlp.lin1.%s" +#define TN_SAM_FFN_DOWN "v.sam.blk.%d.mlp.lin2.%s" +#define TN_SAM_NECK "v.sam.neck.%d.%s" +#define TN_SAM_NET "v.sam.net_%d.%s" + // align x to upper multiple of n #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) @@ -156,6 +171,7 @@ enum projector_type { PROJECTOR_TYPE_LIGHTONOCR, PROJECTOR_TYPE_COGVLM, PROJECTOR_TYPE_JANUS_PRO, + PROJECTOR_TYPE_DEEPSEEKOCR, PROJECTOR_TYPE_UNKNOWN, }; @@ -182,6 +198,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, { PROJECTOR_TYPE_COGVLM, "cogvlm"}, { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, + { PROJECTOR_TYPE_DEEPSEEKOCR,"deepseekocr"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index d1423b67f98..82d0b46a47b 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -194,6 +194,8 @@ struct clip_hparams { int32_t attn_window_size = 0; int32_t n_wa_pattern = 0; + bool crop_mode = false; + // audio int32_t n_mel_bins = 0; // whisper preprocessor int32_t proj_stack_factor = 0; // ultravox @@ -222,6 +224,23 @@ struct clip_hparams { warmup_image_size = n_tok_per_side * patch_size * cur_merge; // TODO: support warmup size for custom token numbers } + + // sam vit deepseek-ocr + std::vector global_attn_indices() const { + return { 2, 5, 8, 11 }; + } + + bool is_global_attn(int32_t layer) const { + const auto indices = global_attn_indices(); + + for (const auto & idx : indices) { + if (layer == idx) { + return true; + } + } + + return false; + } }; struct clip_layer { @@ -271,6 +290,10 @@ struct clip_layer { bool has_deepstack() const { return deepstack_fc1_w != nullptr; } + + // sam rel_pos + ggml_tensor * rel_pos_w = nullptr; + ggml_tensor * rel_pos_h = nullptr; }; struct clip_model { @@ -295,7 +318,8 @@ struct 
clip_model { ggml_tensor * post_ln_w; ggml_tensor * post_ln_b; - ggml_tensor * projection; // TODO: rename it to fc (fully connected layer) + ggml_tensor * fc_w; + ggml_tensor * fc_b; ggml_tensor * mm_fc_w; ggml_tensor * mm_fc_b; @@ -308,6 +332,7 @@ struct clip_model { ggml_tensor * mm_2_b = nullptr; ggml_tensor * image_newline = nullptr; + ggml_tensor * view_seperator = nullptr; // Yi type models with mlp+normalization projection ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4 @@ -400,6 +425,11 @@ struct clip_model { ggml_tensor * mm_boi = nullptr; ggml_tensor * mm_eoi = nullptr; + // deepseek ocr sam + ggml_tensor * patch_embed_proj_w = nullptr; + ggml_tensor * patch_embed_proj_b = nullptr; + ggml_tensor * pos_embed = nullptr; + bool audio_has_avgpool() const { return proj_type == PROJECTOR_TYPE_QWEN2A || proj_type == PROJECTOR_TYPE_VOXTRAL; @@ -409,6 +439,19 @@ struct clip_model { return proj_type == PROJECTOR_TYPE_ULTRAVOX || proj_type == PROJECTOR_TYPE_VOXTRAL; } + ggml_tensor * neck_0_w; + ggml_tensor * neck_1_w; + ggml_tensor * neck_1_b; + ggml_tensor * neck_2_w; + ggml_tensor * neck_3_w; + ggml_tensor * neck_3_b; + ggml_tensor * net_2; + ggml_tensor * net_3; + + int32_t n_sam_layers = 12; // used by deepseek-ocr sam encoder + + std::vector sam_layers; + }; struct clip_ctx { @@ -521,9 +564,9 @@ struct clip_graph { hparams(model.hparams), img(img), patch_size(hparams.patch_size), - n_patches_x(img.nx / patch_size), - n_patches_y(img.ny / patch_size), - n_patches(n_patches_x * n_patches_y), + n_patches_x(img.nx / patch_size), // sam 1024 / 16 = 64 + n_patches_y(img.ny / patch_size), // sam 1024 / 16 = 64 + n_patches(n_patches_x * n_patches_y), // sam 64 * 64 = 4096 n_embd(hparams.n_embd), n_head(hparams.n_head), d_head(n_embd / n_head), @@ -583,7 +626,7 @@ struct clip_graph { // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578 const int 
scale_factor = model.hparams.n_merge;
             cur = build_patch_merge_permute(cur, scale_factor);
-            cur = ggml_mul_mat(ctx0, model.projection, cur);
+            cur = ggml_mul_mat(ctx0, model.fc_w, cur);
         } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
             // pixel unshuffle block
@@ -619,6 +662,274 @@
         return gf;
     }
 
+    // SAM ViT-B image encoder used by DeepSeek-OCR.
+    // Patchifies inp_raw with a 16x16 conv, adds absolute position embeddings
+    // (bicubic-resized when the target grid differs from the stored one), runs
+    // 12 transformer layers (windowed attention except on the global layers)
+    // with decomposed relative position bias, then applies the SAM neck convs.
+    ggml_tensor * build_sam_enc(ggml_tensor * inp_raw,
+        const int enc_image_size = 1024
+    ) {
+        constexpr int enc_n_embd = 768;
+        constexpr int _depth = 12;
+        constexpr int enc_n_heads = 12;
+        constexpr int enc_d_heads = enc_n_embd / enc_n_heads;
+        // constexpr int _prompt_n_embd = 256;
+        constexpr int enc_patch_size = 16;
+        // constexpr int _window_size = 14;
+
+        const int enc_n_patches = enc_image_size / enc_patch_size; // 64
+
+        ggml_tensor * inpL = build_enc_inp(inp_raw, enc_patch_size, enc_n_patches, enc_n_embd);
+        ggml_tensor * cur = nullptr;
+
+        const auto tgt_size = inpL->ne[1];
+        const auto str_size = model.pos_embed->ne[1];
+        if (str_size != tgt_size) {
+            ggml_tensor * new_pos_embed = ggml_interpolate(
+                ctx0,
+                model.pos_embed,
+                tgt_size,
+                tgt_size,
+                enc_n_embd,
+                1,
+                ggml_scale_mode::GGML_SCALE_MODE_BICUBIC
+            );
+            new_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embed, 2,1,0,3));
+            cur = ggml_add(ctx0, inpL, new_pos_embed);
+        } else {
+            cur = ggml_add(ctx0, inpL, model.pos_embed);
+        }
+
+        // loop over layers
+        for (int il = 0; il < _depth; il++) {
+            auto & layer = model.sam_layers[il];
+
+            // layernorm1
+            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
+            cb(cur, "enc_layer_inp_normed", il);
+
+            const int64_t w0 = cur->ne[1];
+            const int64_t h0 = cur->ne[2];
+
+            if (hparams.is_global_attn(il) == false) {
+                // local attention layer - apply window partition
+                // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172
+                //cur = ggml_win_part(ctx0, cur, 14);
+                cur = window_partition(ctx0, cur, 14);
+            }
+
+            const int64_t W = cur->ne[1];
+            const int64_t H = cur->ne[2];
+
+            // self-attention
+            {
+                cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
+                cur = ggml_add(ctx0, cur, layer.qkv_b);
+                const int B = cur->ne[3];
+
+                cur = ggml_reshape_4d(ctx0, cur, enc_n_embd, 3, W * H, B);
+                cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 3, 1, 2));
+
+                ggml_tensor * Qcur =
+                    ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 0);
+                Qcur = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, enc_n_heads, W * H, B);
+                Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
+                Qcur = ggml_reshape_3d(ctx0, Qcur, enc_d_heads, W * H, B * enc_n_heads);
+
+                ggml_tensor * Kcur =
+                    ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], cur->nb[3]);
+                Kcur = ggml_reshape_4d(ctx0, Kcur, enc_d_heads, enc_n_heads, W * H, B);
+                Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+                Kcur = ggml_reshape_3d(ctx0, Kcur, enc_d_heads, W * H, B * enc_n_heads);
+
+                ggml_tensor * Vcur =
+                    ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 2 * cur->nb[3]);
+                Vcur = ggml_reshape_4d(ctx0, Vcur, enc_d_heads, enc_n_heads, W * H, B);
+                Vcur = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3)); // transposed
+                Vcur = ggml_reshape_3d(ctx0, Vcur, W * H, enc_d_heads, B * enc_n_heads);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+
+
+                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcur, Qcur);
+
+                struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf(enc_d_heads));
+
+                struct ggml_tensor * rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W);
+                struct ggml_tensor * rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H);
+
+                struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads);
+
+                struct ggml_tensor * rel_w = ggml_cont(ctx0,ggml_permute(ctx0,
+                        ggml_mul_mat(ctx0,
+                            rw,
+                            ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))),
+                        0, 2, 1, 3));
+                struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r);
+
+                struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h);
+
+                struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn);
+
+                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcur, KQ_soft_max);
+
+                cur = ggml_reshape_4d(
+                    ctx0,
+                    ggml_cont(ctx0, ggml_permute(ctx0, ggml_reshape_4d(ctx0, KQV, enc_d_heads, W * H, enc_n_heads, B),
+                        0, 2, 1, 3)),
+                    enc_n_embd, W, H, B);
+
+                cur = ggml_mul_mat(ctx0, layer.o_w, cur);
+                cur = ggml_add_inplace(ctx0, cur, layer.o_b);
+            }
+
+            if (hparams.is_global_attn(il) == false) {
+                // local attention layer - reverse window partition
+                cur = window_unpartition(ctx0, cur, w0, h0, 14);
+            }
+
+            // re-add the layer input, e.g., residual
+            cur = ggml_add(ctx0, cur, inpL);
+
+            ggml_tensor * inpFF = cur;
+
+
+            cb(inpFF, "ffn_inp", il);
+
+            // layernorm2
+            cur = build_norm(inpFF, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
+            cb(cur, "ffn_inp_normed", il);
+
+            // ffn
+            cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, nullptr, nullptr, layer.ff_down_w,
+                layer.ff_down_b, hparams.ffn_op, il);
+
+            cb(cur, "ffn_out", il);
+
+
+            // residual 2
+            cur = ggml_add(ctx0, cur, inpFF);
+            cb(cur, "layer_out", il);
+
+            // FIX: carry this layer's output into the next iteration; without this,
+            // every residual above re-adds the raw patch embeddings and the neck
+            // below consumes inpL, discarding all transformer layers.
+            inpL = cur;
+        }
+
+        cur = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 2, 0, 1, 3));
+
+        cur = ggml_conv_2d_sk_p0(ctx0, model.neck_0_w, cur);
+
+        cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_1_w, model.neck_1_b, hparams.eps);
+
+        cur = ggml_conv_2d_s1_ph(ctx0, model.neck_2_w, cur);
+
+        cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_3_w, model.neck_3_b, hparams.eps);
+
+        cur = ggml_conv_2d(ctx0, model.net_2, cur, 2,2,1,1, 1,1);
+        cur = ggml_conv_2d(ctx0, model.net_3, cur, 2,2,1,1, 1,1);
+
+        ggml_build_forward_expand(gf, cur);
+        return cur;
+    }
+
+    ggml_tensor * sam_layer_norm_2d(ggml_context * ctx0,
+        ggml_tensor * layer,
+        int n_channels,
+        ggml_tensor * w,
+        ggml_tensor * b,
+        float eps) {
+        // LayerNorm2d
+        // normalize along channel dimension
+        // TODO: better implementation
+        layer = ggml_permute(ctx0, ggml_norm(ctx0, 
ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3)), eps), 2, 0, + 1, 3); + layer = ggml_cont(ctx0, layer); + + layer = + ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, ggml_reshape_3d(ctx0, w, 1, 1, n_channels), layer), layer), + ggml_repeat(ctx0, ggml_reshape_3d(ctx0, b, 1, 1, n_channels), layer)); + + return layer; + } + + ggml_cgraph * build_deepseek_ocr() { + //patch embedding + ggml_tensor * inp_raw = build_inp_raw(); + + + ggml_tensor * global_features_1 = build_sam_enc(inp_raw, std::max(img.nx, img.ny)); + + ggml_tensor * global_features_2 = build_dp_ocr_clip(global_features_1); + + // FIXME remove n_patches is hardcoded + + // torch global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1) + global_features_1 = ggml_cont(ctx0,ggml_permute(ctx0, global_features_1,2,1,0,3)); + int clip_n_patches = global_features_1->ne[1] * global_features_1->ne[2]; + + // flatten 2nd and 3rd dims + global_features_1 = ggml_reshape_2d(ctx0, global_features_1, global_features_1->ne[0], clip_n_patches); + + // remove CLS token + global_features_2 = ggml_view_2d(ctx0, global_features_2, + n_embd, clip_n_patches, + ggml_row_size(global_features_2->type, n_embd), 0); + + ggml_tensor * global_features = ggml_concat(ctx0, global_features_2, global_features_1, 1); + global_features = ggml_reshape_2d(ctx0, global_features, 2* n_embd,clip_n_patches); + global_features = ggml_cont(ctx0, global_features); + global_features = ggml_mul_mat(ctx0, model.fc_w, global_features); + global_features = ggml_add(ctx0, global_features, model.fc_b); + + global_features = build_global_local_features(ctx0,global_features); + global_features = ggml_cont(ctx0, ggml_permute(ctx0, global_features, 1, 0, 2, 3)); + ggml_build_forward_expand(gf, global_features); + return gf; + } + + // global_features: [n_dim, h*w] + // image_newline: [n_dim] + // view_separator: [n_dim] + + ggml_tensor * build_global_local_features(ggml_context * ctx0, + ggml_tensor 
* global_features) { + GGML_ASSERT(model.image_newline != nullptr); + GGML_ASSERT(model.view_seperator != nullptr); + + // 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim] + const auto h = static_cast(std::sqrt(static_cast(global_features->ne[1]))); + const auto w = h; + const auto n_dim = global_features->ne[0]; + ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, n_dim, h, w, 1); // (n_dim, w, h) + t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, n_dim) + ggml_tensor * nl = ggml_cont(ctx0,ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3)); + nl = ggml_repeat_4d(ctx0, nl, h, 1, n_dim, 1); // n_pos rows + + + // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim] + t = ggml_concat(ctx0, t, nl, 1); // (h, w+1, n_dim) + + t = ggml_reshape_2d(ctx0, t, n_dim, h* (h + 1)); // (n_dim, h*(w+1)) + + + // 5) append view_separator as an extra "token": + // view_separator: [n_dim] -> [n_dim, 1] + ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1) + + // concat along token dimension (dim=1): + t = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1) + + return t; + } + + + ggml_cgraph * build_pixtral() { const int n_merge = hparams.n_merge; @@ -1215,7 +1516,7 @@ struct clip_graph { norm_t, hparams.ffn_op, model.position_embeddings, - nullptr); + nullptr); // shape [1024, 16, 16] // remove CLS token cur = ggml_view_2d(ctx0, cur, @@ -1261,6 +1562,63 @@ struct clip_graph { return gf; } + ggml_tensor * build_dp_ocr_clip(ggml_tensor * patch_embeds) { + GGML_ASSERT(model.class_embedding != nullptr); + GGML_ASSERT(model.position_embeddings != nullptr); + + ggml_tensor * inp = ggml_cpy(ctx0, patch_embeds, ggml_dup_tensor(ctx0, patch_embeds)); + + + inp = ggml_cont(ctx0,ggml_permute(ctx0, inp,2,1,0,3)); + inp = ggml_reshape_2d(ctx0, inp, n_embd, inp->ne[1]*inp->ne[2]*inp->ne[3]); + + ggml_tensor * new_pos_embd = ggml_cpy(ctx0, model.position_embeddings, ggml_dup_tensor(ctx0, 
model.position_embeddings)); + + int n_pos = new_pos_embd->ne[1]; // +1 for [CLS] + const auto tgt_size = static_cast(std::sqrt(inp->ne[1])); + const auto src_size = static_cast(std::sqrt(n_pos - 1)); + + + if (tgt_size != src_size) { + //ggml_tensor * old_pos_embd = ggml_new_tensor_2d(ctx0, model.position_embeddings->type, model.position_embeddings->ne[0], str_size * str_size); + ggml_tensor * old_pos_embd = ggml_view_2d(ctx0, new_pos_embd, + new_pos_embd->ne[0], src_size * src_size, + ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), 0); + ggml_tensor * cls_tok = ggml_view_2d(ctx0, new_pos_embd, + new_pos_embd->ne[0], 1, + ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), src_size * src_size); + new_pos_embd = ggml_interpolate(ctx0, + old_pos_embd, + tgt_size, + tgt_size, + new_pos_embd->ne[0], 1, GGML_SCALE_MODE_BICUBIC); + new_pos_embd = ggml_reshape_3d(ctx0, new_pos_embd, n_embd, tgt_size * tgt_size, 1); + //new_pos_embd = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embd, 2,1,0,3)); + new_pos_embd = ggml_concat(ctx0, new_pos_embd, cls_tok, 1); + n_pos = tgt_size * tgt_size + 1; + } + + + + // add CLS token + inp = ggml_concat(ctx0, inp, model.class_embedding, 1); + + //TODO : check norm type for dp-ocr-clip + norm_type norm_t = NORM_TYPE_NORMAL; + + // for selecting learned pos embd, used by ViT + ggml_tensor * positions = ggml_cast(ctx0, ggml_arange(ctx0, 0, n_pos, 1), GGML_TYPE_I32); + ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, new_pos_embd, positions); + + + ggml_tensor * cur = build_vit(inp, n_pos, norm_t, hparams.ffn_op, learned_pos_embd, + nullptr); // shape [1024, 16, 16] + + ggml_build_forward_expand(gf, cur); + + return cur; + } + ggml_cgraph * build_llama4() { GGML_ASSERT(model.class_embedding != nullptr); GGML_ASSERT(model.position_embeddings != nullptr); @@ -2164,18 +2522,218 @@ struct clip_graph { return inpL; } + // attn: [q_h*q_w, k_h*k_w] + // rel_h: [q_h, q_w, k_h] + // rel_w: [q_h, q_w, k_w] + + static ggml_tensor * 
add_rel_pos_inplace( + ggml_context * ctx, + ggml_tensor * attn, + ggml_tensor * rel_w, + ggml_tensor * rel_h + ) { + const int k_w = rel_w->ne[0]; + const int k_h = rel_h->ne[0]; + const int q_w = rel_h->ne[1]; + const int q_h = rel_h->ne[2]; + + GGML_ASSERT(q_w == rel_w->ne[1]); + GGML_ASSERT(q_h == rel_w->ne[2]); + GGML_ASSERT(attn->ne[0] == k_h*k_w); + GGML_ASSERT(attn->ne[1] == q_h*q_w); + + ggml_tensor *attn_4d = ggml_reshape_4d(ctx, attn, k_w, k_h, attn->ne[1], attn->ne[2]); + + ggml_tensor *rel_h_4d = ggml_reshape_4d(ctx, rel_h, 1, k_h, attn->ne[1], attn->ne[2]); + + ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_5d + + ggml_tensor *rel_w_4d = ggml_reshape_4d(ctx, rel_w, k_w, 1, attn->ne[1], attn->ne[2]); + + ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_5d + + ggml_tensor * result = ggml_add_inplace(ctx, attn_4d, ggml_add_inplace(ctx, rel_h_rep, rel_w_rep)); + result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]); + + + return result; + } + + + static ggml_tensor * get_rel_pos( + ggml_context * ctx, + ggml_tensor * rel_pos, // [L, C] + int q_size, + int k_size + ) { + const int64_t C = rel_pos->ne[0]; // channels + const int64_t L = rel_pos->ne[1]; // length + + //GGML_ASSERT(2*std::max(q_size, k_size) - 1 == L); + + const auto max_rel_dist = 2*std::max(q_size, k_size) - 1; + ggml_tensor * rel_pos_resized = rel_pos; + + if (max_rel_dist != L) { + // Linear interpolation + const auto scale = L / static_cast(max_rel_dist); + ggml_tensor * indices = ggml_arange(ctx, 0.0f, static_cast(max_rel_dist), 1.0f); + indices = ggml_scale_inplace(ctx, indices, scale); + ggml_tensor * indices_floor= ggml_cast(ctx, ggml_floor(ctx, indices), GGML_TYPE_I32); + ggml_tensor * indices_ceil = ggml_cast(ctx, ggml_ceil(ctx, indices), GGML_TYPE_I32); + ggml_tensor * weights = ggml_sub(ctx, indices, indices_floor); + ggml_tensor * ws1 = ggml_scale_bias(ctx, weights, -1.0f, 
1.0f); + rel_pos_resized = ggml_cont(ctx , ggml_permute(ctx, rel_pos_resized, 1, 0, 2, 3)); // [C, L] for ggml_get_rows + ggml_tensor * rs1 = ggml_cont(ctx, ggml_permute(ctx, ggml_get_rows(ctx, rel_pos_resized, indices_floor), 1, 0, 2, 3)); // lower rows + rs1 = ggml_mul(ctx, rs1, ws1); // lower rows + ggml_tensor * rs2 = ggml_cont(ctx, ggml_permute(ctx, ggml_get_rows(ctx, rel_pos_resized, indices_ceil), 1, 0, 2, 3)); // upper rows + rs2 = ggml_mul(ctx, rs2, weights); // upper rows + rel_pos_resized = ggml_add(ctx,rs1, rs2); + } + + // ------------------------------------------------- + // 1) q_idx ← arange(0..q_size-1) [q_size] + // 2) k_idx ← arange(0..k_size-1) [k_size] + // ------------------------------------------------- + + // ggml_arange always returns FP32 tensor + ggml_tensor * q_coord = ggml_arange(ctx, 0.0f, static_cast(q_size), 1.0f); // [q_size] + ggml_tensor * k_coord = ggml_arange(ctx, 0.0f, static_cast(k_size), 1.0f); // [k_size] + ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, k_size, q_size); + + // broadcast reshape: + q_coord = ggml_cont(ctx, + ggml_repeat(ctx, + ggml_reshape_2d(ctx, q_coord, 1, q_size), // [q_size, 1] + rel + ) + ); // [q_size, k_size] + k_coord = ggml_cont(ctx, ggml_repeat(ctx, k_coord, rel)); // [q_size, k_size] + + float q_scale = std::max((float)k_size/q_size, 1.0f); + float k_scale = std::max((float)q_size/k_size, 1.0f); + + // This wouldn't be triggered in DeepSeek-OCR. Just for compatibility with + // the original implementation. 
+ if (q_size != k_size) { + q_coord = ggml_scale_inplace(ctx, q_coord, q_scale); + k_coord = ggml_scale_inplace(ctx, k_coord, k_scale); + } + + // ------------------------------------------------- + // relative_coords = q - k + (k_size - 1) // SAME as PyTorch when no scaling + // ------------------------------------------------- + + rel = ggml_sub(ctx, q_coord, k_coord); // [q_size, k_size] + rel = ggml_scale_bias(ctx, rel, 1.0f, (k_size - 1.0f)*k_scale); // [q_size, k_size] + // Clamp to [0, L-1] range for valid indexing + rel = ggml_clamp(ctx, rel, 0.0f, static_cast(rel_pos->ne[1] - 1)); + + // ------------------------------------------------- + // clamp to [0, L-1] and cast to int32 (for ggml_get_rows) + // ------------------------------------------------- + + ggml_tensor * idx_2d = ggml_cast(ctx, rel, GGML_TYPE_I32); // [q_size, k_size] + + // Gather from rel_pos → [qk, C] + // ------------------------------------------------- + + // flatten to 1D for ggml_get_rows + int qk = q_size * k_size; + ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk] + ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat); // [qk, C] + + // ------------------------------------------------- + // Gather from rel_pos → [qk, C] + // ------------------------------------------------- + + ggml_tensor * out = ggml_reshape_3d(ctx, gathered, C, k_size, q_size); // [qk, C] + + + return out; // [q_size, k_size, C] + } + + // Implementation based on approach suggested by Acly + // See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091 + static ggml_tensor* window_partition(ggml_context* ctx, ggml_tensor* x, int window) { + auto [c, w, h, b] = x->ne; + // same as + // x = ggml_win_part(m, x, window); + // x = ggml_reshape_3d(m, x, c, window * window, x->ne[3]); + + int64_t px = (window - w % window) % window; + int64_t py = (window - h % window) % window; + int64_t npw = (w + px) / window; + int64_t nph = (h + py) / window; + + if (px > 0 || py > 0) 
{ + x = ggml_pad(ctx, x, 0, int(px), int(py), 0); + } + x = ggml_reshape_4d(ctx, x, c * window, npw, window, nph * b); + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); + x = ggml_reshape_4d(ctx, x, c, window ,window, npw * nph * b); + return x; + } + + // Implementation based on approach suggested by Acly + // See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091 + static ggml_tensor* window_unpartition(ggml_context* m, ggml_tensor* x, int w, int h, int window) { + int64_t c = x->ne[0]; + // same as + // x = ggml_reshape_4d(m, x, c, window, window, x->ne[2]); + // x = ggml_win_unpart(m, x, w, h, window); + + int64_t px = (window - w % window) % window; + int64_t py = (window - h % window) % window; + int64_t npw = (w + px) / window; + int64_t nph = (h + py) / window; + + int64_t b = x->ne[3] / (npw * nph); + x = ggml_reshape_4d(m, x, c * window, window, npw, nph * b); + x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); + x = ggml_reshape_4d(m, x, c, w + px, h + py, b); + x = ggml_view_4d(m, x, x->ne[0], w, h, x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0); + x = ggml_cont(m, x); + return x; + } + + // build the input after conv2d (inp_raw --> patches) + // returns tensor with shape [n_embd, n_patches] + ggml_tensor * build_enc_inp(ggml_tensor * inp_raw, + const int enc_patch_size, + const int enc_n_patches, + const int enc_n_embd) { + GGML_ASSERT(model.patch_embed_proj_w != nullptr); + GGML_ASSERT(model.patch_embed_proj_b != nullptr); + // Image to Patch Embedding. 
+ // ggml_tensor * inp_raw = build_inp_raw(); // sam shape = [1024, 1024, 3] + // patch_embed_proj_w shape = [768, 3, 16, 16] + ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embed_proj_w, inp_raw, enc_patch_size, enc_patch_size, 0, 0, + 1, 1); // [64, 64, 768] + inp = ggml_reshape_2d(ctx0, inp, enc_n_patches * enc_n_patches, enc_n_embd); // [4096, 768] + inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); // [768, 4096] + inp = ggml_add(ctx0, inp, model.patch_embed_proj_b); + inp = ggml_cont(ctx0, inp); + inp = ggml_reshape_4d(ctx0, inp, enc_n_embd, enc_n_patches, enc_n_patches, 1); + cb(inp, "enc_patch_bias", -1); + return inp; + } + // build the input after conv2d (inp_raw --> patches) // returns tensor with shape [n_embd, n_patches] ggml_tensor * build_inp() { - ggml_tensor * inp_raw = build_inp_raw(); - ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); - inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); + // Image to Patch Embedding. 
+ ggml_tensor * inp_raw = build_inp_raw(); // sam shape = [1024, 1024, 3] + // sam patch_embeddings_0 shape = [768, 3, 16, 16] + ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); // sam shape = [64, 64, 768] + inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); // sam shape = [4096, 768] + inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); // sam shape = [768, 4096] if (model.patch_bias) { + // sam patch_bias shape = [768] inp = ggml_add(ctx0, inp, model.patch_bias); cb(inp, "patch_bias", -1); } - return inp; + return inp; // shape = [n_embd, n_patches] same as [768, 4096] } ggml_tensor * build_inp_raw(int channels = 3) { @@ -2524,6 +3082,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { res = graph.build_cogvlm(); } break; + case PROJECTOR_TYPE_DEEPSEEKOCR: + { + res = graph.build_deepseek_ocr(); + } break; default: { res = graph.build_llava(); @@ -2849,6 +3411,13 @@ struct clip_model_loader { hparams.ffn_op = FFN_GELU_ERF; log_ffn_op = "gelu_erf"; // temporary solution for logging } break; + case PROJECTOR_TYPE_DEEPSEEKOCR: + { + hparams.patch_size = 16; + hparams.image_size = 1024; + hparams.warmup_image_size = 1024; + hparams.crop_mode = false; + } break; default: break; } @@ -3141,7 +3710,7 @@ struct clip_model_loader { } break; case PROJECTOR_TYPE_IDEFICS3: { - model.projection = get_tensor(TN_MM_PROJECTOR); + model.fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight")); } break; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: @@ -3214,13 +3783,13 @@ struct clip_model_loader { } break; case PROJECTOR_TYPE_LLAMA4: { - model.mm_model_proj = get_tensor(TN_MM_PROJECTOR); + model.mm_model_proj = get_tensor(string_format(TN_MM_PROJECTOR, "weight")); model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight")); } break; case PROJECTOR_TYPE_COGVLM: { - 
model.mm_model_proj = get_tensor(TN_MM_PROJECTOR); + model.mm_model_proj = get_tensor(string_format(TN_MM_PROJECTOR, "weight")); model.mm_post_fc_norm_w = get_tensor(string_format(TN_MM_POST_FC_NORM, "weight")); model.mm_post_fc_norm_b = get_tensor(string_format(TN_MM_POST_FC_NORM, "bias")); model.mm_h_to_4h_w = get_tensor(string_format(TN_MM_H_TO_4H, "weight")); @@ -3236,6 +3805,45 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias")); } break; + case PROJECTOR_TYPE_DEEPSEEKOCR: + { + model.pos_embed = get_tensor(TN_SAM_POS_EMBD); + model.patch_embed_proj_w = get_tensor(string_format(TN_SAM_PATCH_EMBD, "weight")); + model.patch_embed_proj_b = get_tensor(string_format(TN_SAM_PATCH_EMBD, "bias")); + model.sam_layers.resize(model.n_sam_layers); + for (int il = 0; il < model.n_sam_layers; ++il) { + auto & layer = model.sam_layers[il]; + layer.qkv_w = get_tensor(string_format(TN_SAM_ATTN_QKV, il, "weight")); + layer.qkv_b = get_tensor(string_format(TN_SAM_ATTN_QKV, il, "bias")); + layer.o_w = get_tensor(string_format(TN_SAM_ATTN_OUT, il, "weight")); + layer.o_b = get_tensor(string_format(TN_SAM_ATTN_OUT, il, "bias")); + layer.ln_1_w = get_tensor(string_format(TN_SAM_PRE_NORM, il, "weight")); + layer.ln_1_b = get_tensor(string_format(TN_SAM_PRE_NORM, il, "bias")); + layer.ln_2_w = get_tensor(string_format(TN_SAM_POST_NORM, il, "weight")); + layer.ln_2_b = get_tensor(string_format(TN_SAM_POST_NORM, il, "bias")); + layer.rel_pos_h = get_tensor(string_format(TN_SAM_ATTN_POS_H, il)); + layer.rel_pos_w = get_tensor(string_format(TN_SAM_ATTN_POS_W, il)); + layer.ff_up_w = get_tensor(string_format(TN_SAM_FFN_UP, il, "weight")); + layer.ff_up_b = get_tensor(string_format(TN_SAM_FFN_UP, il, "bias")); + layer.ff_down_w = get_tensor(string_format(TN_SAM_FFN_DOWN, il, "weight")); + layer.ff_down_b = get_tensor(string_format(TN_SAM_FFN_DOWN, il, "bias")); + } + 
// Build every candidate tile grid (cols, rows) whose total tile count lies in
// [min_num, max_num], mirroring the Python reference used by DeepSeek-OCR's
// dynamic-resolution ("Gundam") preprocessing.
// The result is sorted by total tile count and deduplicated.
static std::vector<std::pair<int, int>> ds_build_target_ratios(const int min_num, const int max_num) {
    std::vector<std::pair<int, int>> ratios;
    for (int n = min_num; n <= max_num; ++n) {
        for (int i = 1; i <= n; ++i) {
            for (int j = 1; j <= n; ++j) {
                const int blocks = i * j;
                if (blocks >= min_num && blocks <= max_num) {
                    ratios.emplace_back(i, j); // (cols, rows)
                }
            }
        }
    }

    // sort by total blocks like in Python (key=lambda x: x[0] * x[1]);
    // tie-break on (cols, rows) so that the duplicate pairs produced by the
    // overlapping n-loops become adjacent — with a product-only key,
    // std::unique below would miss non-adjacent duplicates.
    std::sort(ratios.begin(), ratios.end(),
        [](const std::pair<int, int> & a, const std::pair<int, int> & b) {
            const int pa = a.first * a.second;
            const int pb = b.first * b.second;
            if (pa != pb) {
                return pa < pb;
            }
            return a < b;
        });

    // drop duplicates (each (i, j) is generated once for every n >= max(i, j))
    ratios.erase(std::unique(ratios.begin(), ratios.end()), ratios.end());
    return ratios;
}

// Pick the candidate grid whose aspect ratio (cols/rows) is closest to the
// image's aspect ratio. On an exact tie, prefer the later (larger) grid only
// when the image area is big enough to fill at least half of it — same rule
// as the Python reference.
// Returns the chosen grid as (cols, rows).
static std::pair<int, int> ds_find_closest_aspect_ratio(
        const float aspect_ratio,
        const std::vector<std::pair<int, int>> & target_ratios,
        const int width,
        const int height,
        const int image_size) {
    float best_diff = std::numeric_limits<float>::infinity();
    std::pair<int, int> best_ratio = {1, 1};
    const float area = static_cast<float>(width) * static_cast<float>(height);

    for (const auto & r : target_ratios) {
        const float target_ar = static_cast<float>(r.first) / static_cast<float>(r.second);
        const float diff = std::fabs(aspect_ratio - target_ar);

        if (diff < best_diff) {
            best_diff  = diff;
            best_ratio = r;
        } else if (diff == best_diff) {
            // same as python: prefer this ratio if the image area is "large enough"
            const float needed_area = 0.5f * image_size * image_size * r.first * r.second;
            if (area > needed_area) {
                best_ratio = r;
            }
        }
    }

    return best_ratio; // (cols, rows)
}
PIL default + + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*resized_img, *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + + res_imgs->grid_x = 1; + res_imgs->grid_y = 1; + } + else { + // BASE/LARGE MODE: Resize with aspect ratio + padding + // Resize maintaining aspect ratio, then pad to square + + float scale = std::min( + static_cast(image_size) / orig_w, + static_cast(image_size) / orig_h + ); + int new_w = static_cast(orig_w * scale); + int new_h = static_cast(orig_h * scale); + + clip_image_u8_ptr scaled_img(clip_image_u8_init()); + img_tool::resize(*img, *scaled_img, clip_image_size{new_w, new_h}, + img_tool::RESIZE_ALGO_BICUBIC); + + // Use mean color for padding + unsigned char pad_r = static_cast(params.image_mean[0] * 255.0f); + unsigned char pad_g = static_cast(params.image_mean[1] * 255.0f); + unsigned char pad_b = static_cast(params.image_mean[2] * 255.0f); + + // Step 2: Pad to image_size × image_size (center padding) + clip_image_u8_ptr padded_img(clip_image_u8_init()); + padded_img->nx = image_size; + padded_img->ny = image_size; + padded_img->buf.resize(image_size * image_size * 3); // black padding + + // Fill with mean color + for (int i = 0; i < image_size * image_size; ++i) { + padded_img->buf[i * 3 + 0] = pad_r; + padded_img->buf[i * 3 + 1] = pad_g; + padded_img->buf[i * 3 + 2] = pad_b; + } + + // Calculate padding offsets (center the image) + int pad_x = (image_size - new_w) / 2; + int pad_y = (image_size - new_h) / 2; + + // Copy scaled image into padded canvas + for (int y = 0; y < new_h; ++y) { + for (int x = 0; x < new_w; ++x) { + int src_idx = (y * new_w + x) * 3; + int dst_idx = ((y + pad_y) * image_size + (x + pad_x)) * 3; + padded_img->buf[dst_idx + 0] = scaled_img->buf[src_idx + 0]; + padded_img->buf[dst_idx + 1] = scaled_img->buf[src_idx + 1]; + padded_img->buf[dst_idx + 2] = scaled_img->buf[src_idx + 2]; + } + } + + // Step 3: Normalize and output + 
clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*padded_img, *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + + res_imgs->grid_x = 1; + res_imgs->grid_y = 1; + } + } + else { + /* Dynamic Resolution (Gundam/Gundam-M) */ + + // configurable, or read from params + const int min_num = 2; + const int max_num = 9; + const int image_size = params.image_size; // typically 640 + // const bool use_thumbnail = true; // mimic python's use_thumbnail + + // original image size + const int orig_w = original_size.width; + const int orig_h = original_size.height; + + // 1) build candidate grids (cols, rows) + auto target_ratios = ds_build_target_ratios(min_num, max_num); + + // 2) pick the grid that best matches the original aspect ratio + const float aspect_ratio = static_cast(orig_w) / static_cast(orig_h); + auto best = ds_find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_w, orig_h, image_size); + const int grid_cols = best.first; // how many tiles horizontally + const int grid_rows = best.second; // how many tiles vertically + + // 3) compute the target (forced) size — python did: + // target_width = image_size * cols + // target_height = image_size * rows + const clip_image_size refined_size{ image_size * grid_cols, image_size * grid_rows }; + + // 4) prepare slice instructions, same style as the idefics3 branch + llava_uhd::slice_instructions instructions; + instructions.overview_size = clip_image_size{ image_size, image_size }; // for thumbnail/global + instructions.refined_size = refined_size; + instructions.grid_size = clip_image_size{ grid_cols, grid_rows }; + + // in deepseek python they always produce *full* 640x640 blocks, + // so we can do a simple double loop over rows/cols: + for (int r = 0; r < grid_rows; ++r) { + for (int c = 0; c < grid_cols; ++c) { + const int x = c * image_size; + const int y = r * image_size; + + instructions.slices.push_back(llava_uhd::slice_coordinates{ + /* x 
*/ x, + /* y */ y, + /* size */ clip_image_size{ image_size, image_size } + }); + } + } + + // 5) run the actual slicing (this should: resize to refined_size, then crop every slice) + auto imgs = llava_uhd::slice_image(img, instructions); + + // 7) cast & normalize like the idefics3 branch + for (size_t i = 0; i < imgs.size(); ++i) { + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + } + + // keep the grid info — the model may need to know how to reassemble / attend + res_imgs->grid_x = grid_cols; + res_imgs->grid_y = grid_rows; + } + break; + default: LOG_ERR("%s: unsupported projector type %d\n", __func__, ctx->proj_type()); @@ -4587,6 +5411,15 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im { n_patches += 2; // for BOI and EOI token embeddings } break; + case PROJECTOR_TYPE_DEEPSEEKOCR: + { + int x_patch = img->nx / (params.patch_size); + + n_patches += x_patch + 1; + n_patches = 1280; + + + } break; default: GGML_ABORT("unsupported projector type"); } @@ -4921,6 +5754,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima { // do nothing } break; + case PROJECTOR_TYPE_DEEPSEEKOCR: + { + } break; case PROJECTOR_TYPE_LLAMA4: { // set the 2D positions @@ -5013,7 +5849,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_GEMMA3: return ctx->model.mm_input_proj_w->ne[0]; case PROJECTOR_TYPE_IDEFICS3: - return ctx->model.projection->ne[1]; + return ctx->model.fc_w->ne[1]; case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_VOXTRAL: return ctx->model.mm_2_w->ne[1]; @@ -5028,6 +5864,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_COGVLM: return ctx->model.mm_4h_to_h_w->ne[1]; + case PROJECTOR_TYPE_DEEPSEEKOCR: + return ctx->model.fc_w->ne[1]; default: GGML_ABORT("Unknown projector type"); } 
@@ -5058,6 +5896,10 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) { return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3; } +bool clip_is_deepseekocr(const struct clip_ctx * ctx) { + return ctx->proj_type() == PROJECTOR_TYPE_DEEPSEEKOCR; +} + bool clip_has_vision_encoder(const struct clip_ctx * ctx) { return ctx->model.modality == CLIP_MODALITY_VISION; } diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index 3e4c985f117..458ee98fc78 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -105,6 +105,8 @@ bool clip_is_glm(const struct clip_ctx * ctx); bool clip_is_qwen2vl(const struct clip_ctx * ctx); bool clip_is_llava(const struct clip_ctx * ctx); bool clip_is_gemma3(const struct clip_ctx * ctx); +bool clip_is_deepseekocr(const struct clip_ctx * ctx); + bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 3e19e95958a..5e6cc79f379 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -222,14 +222,20 @@ static std::string chat_add_and_format(mtmd_cli_context & ctx, common_chat_msg & static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) { bool add_bos = ctx.chat_history.empty(); - auto formatted_chat = chat_add_and_format(ctx, msg); - LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.c_str()); mtmd_input_text text; - text.text = formatted_chat.c_str(); + text.text = msg.content.c_str(); text.add_special = add_bos; text.parse_special = true; + std::string formatted_chat; + + if (!mtmd_is_deepseekocr(ctx.ctx_vision.get())) { + formatted_chat = chat_add_and_format(ctx, msg); + LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.c_str()); + text.text = formatted_chat.c_str(); + } + if (g_is_interrupted) return 0; mtmd::input_chunks chunks(mtmd_input_chunks_init()); @@ -312,8 +318,18 @@ int main(int argc, char ** argv) { if (is_single_turn) { g_is_generating = true; if 
(params.prompt.find(mtmd_default_marker()) == std::string::npos) { - for (size_t i = 0; i < params.image.size(); i++) { - params.prompt += mtmd_default_marker(); + if (mtmd_is_deepseekocr(ctx.ctx_vision.get())) { + std::string image_tokens = ""; + for (size_t i = 0; i < params.image.size(); i++) { + image_tokens += mtmd_default_marker(); + image_tokens += '\n'; + } + params.prompt = image_tokens + params.prompt; + } + else { + for (size_t i = 0; i < params.image.size(); i++) { + params.prompt += mtmd_default_marker(); + } } } common_chat_msg msg; @@ -332,6 +348,11 @@ int main(int argc, char ** argv) { } } else { + if (mtmd_is_deepseekocr(ctx.ctx_vision.get())) { + LOG_ERR("\n DeepSeek-OCR doesn't support chat mode."); + return 1; + } + LOG("\n Running in chat mode, available commands:"); if (mtmd_support_vision(ctx.ctx_vision.get())) { LOG("\n /image load an image"); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index e5991377699..994013bea91 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -810,7 +810,8 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) if (clip_is_llava(ctx_clip) || clip_is_minicpmv(ctx_clip) - || clip_is_glm(ctx_clip)) { + || clip_is_glm(ctx_clip) + || clip_is_deepseekocr(ctx_clip)) { // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode() const auto & entries = image_tokens->batch_f32.entries; for (size_t i = 0; i < entries.size(); i++) { @@ -863,6 +864,10 @@ int mtmd_get_audio_bitrate(mtmd_context * ctx) { return 16000; // 16kHz } +bool mtmd_is_deepseekocr(mtmd_context * ctx) { + return ctx->ctx_v && clip_is_deepseekocr(ctx->ctx_v); +} + // // public API functions // diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index 775fba6215c..99fdcd46501 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -117,6 +117,9 @@ MTMD_API bool mtmd_support_audio(mtmd_context * ctx); // return -1 if audio is not supported MTMD_API int 
mtmd_get_audio_bitrate(mtmd_context * ctx); +// whether the current model is DeepSeek-OCR +MTMD_API bool mtmd_is_deepseekocr(mtmd_context * ctx); + // mtmd_bitmap // // if bitmap is image: