Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
43a130b
mtmd: llama.cpp DeepSeekOCR support
sfallah Nov 14, 2025
b6b9f02
loading sam tensors
sfallah Nov 14, 2025
85c7cda
mtmd: fix vision model processing
bluebread Nov 15, 2025
578c8d7
Merge pull request #1 from bluebread/sf/deepseek-ocr
sfallah Nov 15, 2025
2aab52e
deepseek-ocr clip-vit model impl
sfallah Nov 15, 2025
eab28ed
mtmd: add DeepSeek-OCR LM support with standard attention
bluebread Nov 15, 2025
7630587
mtmd: successfully runs DeepSeek-OCR LM in llama-cli
bluebread Nov 16, 2025
2de3436
mtmd: Fix RoPE type for DeepSeek-OCR LM.
bluebread Nov 17, 2025
e8b2610
Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into s…
bluebread Nov 17, 2025
97e0907
loading LM
sfallah Nov 17, 2025
13dc6fb
Merge branch 'sf/deepseek-ocr' into sf/deepseek-ocr
sfallah Nov 17, 2025
b32bb5e
Merge pull request #2 from bluebread/sf/deepseek-ocr
sfallah Nov 17, 2025
790bbb9
sam warmup working
sfallah Nov 17, 2025
cec9a5c
sam erroneous return corrected
sfallah Nov 17, 2025
8b3d319
clip-vit: corrected cls_embd concat
sfallah Nov 17, 2025
1e08157
clip-vit: model convert qkv_proj split
sfallah Nov 17, 2025
331cea8
corrected combining of image encoders' results
sfallah Nov 18, 2025
6c0715b
fix: update callback for ffn_moe_weighted and add callback for attn_o…
bluebread Nov 18, 2025
a65ddf5
Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into s…
bluebread Nov 18, 2025
63a042f
concat image_newline and image_seperator tokens
sfallah Nov 18, 2025
89afda8
visual_model warmup (technically) works
sfallah Nov 18, 2025
88032f4
window partitioning using standard ggml ops
sfallah Nov 20, 2025
1268dc3
Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into s…
bluebread Nov 20, 2025
68b206b
sam implementation without using CPU only ops
sfallah Nov 21, 2025
8bce66d
clip: fixed warnings
bluebread Nov 21, 2025
5e6cf3c
Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into s…
bluebread Nov 21, 2025
7e9fbec
mtmd: fix get_rel_pos
bluebread Nov 21, 2025
0f5587d
Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into s…
bluebread Nov 21, 2025
7b8d735
mtmd: fixed the wrong scaler for get_rel_pos
bluebread Nov 21, 2025
86f111f
image encoding technically works but the output can't be checked sing…
sfallah Nov 21, 2025
effe669
mtmd: minor changes
bluebread Nov 22, 2025
f8f66a1
Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into s…
bluebread Nov 22, 2025
3fcfc3a
Merge pull request #3 from bluebread/sf/deepseek-ocr
sfallah Nov 22, 2025
ee8a148
mtmd: add native resolution support
bluebread Nov 22, 2025
4cfa15f
- image encoding debugged
sfallah Nov 22, 2025
3f71188
mtmd: correct token order
bluebread Nov 23, 2025
a594990
Merge pull request #5 from bluebread/dsocr-debug
sfallah Nov 23, 2025
6dfda99
Merge branch 'sf/deepseek-ocr' into sf/deepseek-ocr
sfallah Nov 23, 2025
7941f5d
Merge pull request #4 from bluebread/sf/deepseek-ocr
sfallah Nov 23, 2025
206f8ab
- dynamic resizing
sfallah Nov 23, 2025
40e7e6e
mtmd: quick fix token order
bluebread Nov 24, 2025
81533e4
mtmd: fix dangling pointer
bluebread Nov 24, 2025
8810940
Merge pull request #6 from bluebread/sf/deepseek-ocr
sfallah Nov 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 140 additions & 18 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,9 @@ def load_hparams(dir_model: Path, is_mistral_format: bool):
if "thinker_config" in config:
# rename for Qwen2.5-Omni
config["text_config"] = config["thinker_config"]["text_config"]
if "language_config" in config:
# rename for DeepSeekOCR
config["text_config"] = config["language_config"]
return config

@classmethod
Expand Down Expand Up @@ -1442,7 +1445,7 @@ class MmprojModel(ModelBase):
preprocessor_config: dict[str, Any]
global_config: dict[str, Any]

n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "layers"]

has_vision_encoder: bool = True # by default
has_audio_encoder: bool = False
Expand Down Expand Up @@ -1488,13 +1491,28 @@ def __init__(self, *args, **kwargs):
# TODO @ngxson : this is a hack to support both vision and audio encoders
have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder
self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True)
# FIXME: DeepseekOCRVisionModel specific hack
if self.block_count is None:
if isinstance(self, DeepseekOCRVisionModel):
clip_block_count = self.hparams['layers']
if clip_block_count is not None:
self.block_count = clip_block_count
if self.block_count is None:
raise KeyError(f"could not find block count using any of: {self.n_block_keys}")
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)

# load preprocessor config
self.preprocessor_config = {}
if not self.is_mistral_format:
with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
self.preprocessor_config = json.load(f)
# check if preprocessor_config.json exists
if (self.dir_model / "preprocessor_config.json").is_file():
with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
self.preprocessor_config = json.load(f)
else:
# try "processing_config" file if exists
if (self.dir_model / "processing_config.json").is_file():
with open(self.dir_model / "processing_config.json", "r", encoding="utf-8") as f:
self.preprocessor_config = json.load(f)

def get_vision_config(self) -> dict[str, Any] | None:
config_name = "vision_config" if not self.is_mistral_format else "vision_encoder"
Expand Down Expand Up @@ -5770,6 +5788,97 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

return [] # skip other tensors

@ModelBase.register("DeepseekOCRForCausalLM")
class DeepseekOCRVisionModel(MmprojModel):
    """Convert the DeepSeek-OCR vision stack (SAM ViT-B + CLIP-L towers) to a mmproj GGUF.

    The language-model half of the checkpoint is converted separately by
    DeepseekV2Model; this class only emits vision/projector tensors and skips
    everything belonging to the LM.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # DeepSeek-OCR ships its preprocessing settings in "processor_config.json",
        # a filename the MmprojModel base class does not probe for.
        # NOTE(review): the base class falls back to "processing_config.json" —
        # confirm which filename the published checkpoints actually use.
        proc_fname = self.dir_model / "processor_config.json"

        if proc_fname.is_file():
            with open(proc_fname, "r", encoding="utf-8") as f:
                self.preprocessor_config = json.load(f)

    def set_gguf_parameters(self):
        """Write the vision hyperparameters (CLIP tower + SAM encoder) to the GGUF header."""
        super().set_gguf_parameters()
        hparams = self.hparams
        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.DEEPSEEKOCR)
        # default values below are taken from HF transformers code
        self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6))
        self.gguf_writer.add_vision_use_gelu(True)

        # calculate proj_scale_factor (used by tinygemma3 test model);
        # only written when it differs from the default of 4
        image_seq_length = self.preprocessor_config.get("image_seq_length", 256)
        n_per_side = int(image_seq_length ** 0.5)
        image_size = self.hparams["image_size"]
        patch_size = self.hparams["patch_size"]
        proj_scale_factor = (image_size // patch_size) // n_per_side
        if proj_scale_factor > 0 and proj_scale_factor != 4:
            # we only need to write this if it's not the default value
            # in this case, we are converting a test model
            self.gguf_writer.add_vision_projector_scale_factor(proj_scale_factor)

        # SAM encoder configuration (kept separate from the CLIP tower keys)
        sam_hparams = hparams['sam']
        self.gguf_writer.add_vision_sam_layers_count(sam_hparams['layers'])
        self.gguf_writer.add_vision_sam_embedding_length(sam_hparams['width'])

    def get_vision_config(self) -> dict[str, Any]:
        """Flatten DeepSeek-OCR's nested vision_config into the flat keys MmprojModel expects.

        The HF config nests per-tower settings under
        ``vision_config["width"]["sam_vit_b"]`` and ``["clip-l-14-224"]``.
        NOTE: this mutates ``self.global_config["vision_config"]`` in place.

        Raises:
            ValueError: if the model configuration has no ``vision_config``.
        """
        vision_config: dict[str, Any] | None = self.global_config.get("vision_config")

        if not vision_config:
            raise ValueError("DeepseekOCR model requires 'vision_config' in the model configuration, but it was not found")

        vision_config['sam'] = vision_config['width']['sam_vit_b']
        # hoist the CLIP tower settings to the top level; after this update,
        # 'width' is the scalar CLIP embedding width (1024 for CLIP-L/14)
        vision_config.update(vision_config['width']['clip-l-14-224'])
        vision_config['hidden_size'] = vision_config['width']
        vision_config['num_heads'] = vision_config['heads']
        # FIX: the CLIP MLP expands by 4x the embedding width, not 4x the head
        # count (CLIP-L/14: width 1024 -> intermediate 4096; 'heads' * 4 gave 64)
        vision_config['intermediate_size'] = vision_config['width'] * 4

        return vision_config

    def tensor_force_quant(self, name, new_name, bid, n_dims):
        # related to https://github.com/ggml-org/llama.cpp/issues/13025
        if "input_projection" in name:
            return gguf.GGMLQuantizationType.F16
        if ".embeddings." in name:
            return gguf.GGMLQuantizationType.F32
        return super().tensor_force_quant(name, new_name, bid, n_dims)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Map vision tensors to GGUF names; split fused QKV; drop LM tensors.

        Vision components kept: sam_model, vision_model, projector,
        image_newline, view_seperator. Language-model components
        (lm_head, embed_tokens, layers, norm) are skipped here.
        """
        if name.startswith(("lm_head.", "model.embed_tokens.", "model.layers.", "model.norm.")):
            return []

        # SAM decomposed relative-position tables have no suffix to strip
        if ".attn.rel_pos_h" in name or ".attn.rel_pos_w" in name:
            return [(self.map_tensor_name(name, try_suffixes=("",)), data_torch)]

        if name.startswith("model.vision_model.transformer.layers."):
            # split the fused QKV projection into separate Q/K/V tensors;
            # dim 0 is the fused output dim for both weight (2-D) and bias (1-D)
            if ".qkv_proj." in name:
                if data_torch.ndim == 2:  # weight
                    c3, _ = data_torch.shape
                else:  # bias
                    c3 = data_torch.shape[0]
                assert c3 % 3 == 0
                c = c3 // 3
                wq = data_torch[:c]
                wk = data_torch[c: c * 2]
                wv = data_torch[c * 2:]
                return [
                    (self.map_tensor_name(name.replace("qkv", "q")), wq),
                    (self.map_tensor_name(name.replace("qkv", "k")), wk),
                    (self.map_tensor_name(name.replace("qkv", "v")), wv),
                ]
            else:
                return [(self.map_tensor_name(name), data_torch)]

        return [(self.map_tensor_name(name), data_torch)]


@ModelBase.register("Gemma3nForConditionalGeneration")
class Gemma3NModel(Gemma3Model):
Expand Down Expand Up @@ -6943,6 +7052,7 @@ def prepare_tensors(self):
@ModelBase.register(
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekOCRForCausalLM",
"KimiVLForConditionalGeneration",
)
class DeepseekV2Model(TextModel):
Expand Down Expand Up @@ -7003,52 +7113,64 @@ def set_vocab(self):
raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!")

def set_gguf_parameters(self):
is_ocr = (self.hparams["num_hidden_layers"] == 12)

# note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group)
self.hparams["num_key_value_heads"] = 1
if is_ocr:
self.hparams['rope_theta'] = self.hparams.get('rope_theta', 10000.0)
self.hparams['rms_norm_eps'] = self.hparams.get('rms_norm_eps', 1e-6)
else:
# note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group)
self.hparams["num_key_value_heads"] = 1

super().set_gguf_parameters()
hparams = self.hparams
kv_lora_rank = hparams["q_lora_rank"] if hparams["q_lora_rank"] is not None else 512
routed_scaling_factor = hparams.get("routed_scaling_factor", 1.0)
norm_topk_prob = hparams.get("norm_topk_prob", False)
scoring_func = hparams.get("scoring_func", "softmax")

self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
if "kv_lora_rank" in hparams and hparams["kv_lora_rank"] is not None:
self.gguf_writer.add_kv_lora_rank(kv_lora_rank)

# note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length(hparams["kv_lora_rank"])
self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
if not is_ocr:
self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length(kv_lora_rank)
self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])

self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
self.gguf_writer.add_expert_weights_norm(norm_topk_prob)

if hparams["scoring_func"] == "sigmoid":
if scoring_func == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif hparams["scoring_func"] == "softmax":
elif scoring_func == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")

self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
raise ValueError(f"Unsupported scoring_func value: {scoring_func}")

rope_scaling = self.hparams.get("rope_scaling") or {}
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))

_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# skip vision tensors and remove "language_model." for Kimi-VL
if "vision_tower" in name or "multi_modal_projector" in name:
if "vision_" in name or "multi_modal_projector" in name \
or "image_newline" in name or "model.projector" in name or "sam_model" in name or "view_seperator" in name:
return []

if name.startswith("language_model."):
Expand Down
18 changes: 9 additions & 9 deletions examples/eval-callback/eval-callback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,19 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
}
}
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
LOG(" [\n");
LOG(" [\n");
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
if (i2 == n && ne[2] > 2*n) {
LOG(" ..., \n");
LOG(" ..., \n");
i2 = ne[2] - n;
}
LOG(" [\n");
LOG(" [\n");
for (int64_t i1 = 0; i1 < ne[1]; i1++) {
if (i1 == n && ne[1] > 2*n) {
LOG(" ..., \n");
LOG(" ..., \n");
i1 = ne[1] - n;
}
LOG(" [");
LOG(" [");
for (int64_t i0 = 0; i0 < ne[0]; i0++) {
if (i0 == n && ne[0] > 2*n) {
LOG("..., ");
Expand All @@ -98,10 +98,10 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
}
LOG("],\n");
}
LOG(" ],\n");
LOG(" ],\n");
}
LOG(" ]\n");
LOG(" sum = %f\n", sum);
LOG(" ]\n");
LOG(" sum = %f\n", sum);
}

// TODO: make this abort configurable/optional?
Expand Down Expand Up @@ -136,7 +136,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str());
}

LOG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__,
LOG("%s: %16s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__,
t->name, ggml_type_name(t->type), ggml_op_desc(t),
src0->name, ggml_ne_string(src0).c_str(),
src1 ? src1_str : "",
Expand Down
54 changes: 54 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,10 @@ class Attention:
class Projector:
SCALE_FACTOR = "clip.vision.projector.scale_factor"

class SAM:
BLOCK_COUNT = "clip.vision.sam.block_count"
EMBEDDING_LENGTH = "clip.vision.sam.embedding_length"

class ClipAudio:
NUM_MEL_BINS = "clip.audio.num_mel_bins"
EMBEDDING_LENGTH = "clip.audio.embedding_length"
Expand Down Expand Up @@ -664,6 +668,22 @@ class MODEL_TENSOR(IntEnum):
V_MM_GATE = auto() # cogvlm
V_TOK_BOI = auto() # cogvlm
V_TOK_EOI = auto() # cogvlm
V_SAM_POS_EMBD = auto() # Deepseek-OCR
V_SAM_PATCH_EMBD = auto() # Deepseek-OCR
V_SAM_PRE_NORM = auto() # Deepseek-OCR
V_SAM_POST_NORM = auto() # Deepseek-OCR
V_SAM_ATTN_POS_H = auto() # Deepseek-OCR
V_SAM_ATTN_POS_W = auto() # Deepseek-OCR
V_SAM_ATTN_QKV = auto() # Deepseek-OCR
V_SAM_ATTN_OUT = auto() # Deepseek-OCR
V_SAM_MLP_LIN_1 = auto() # Deepseek-OCR
V_SAM_MLP_LIN_2 = auto() # Deepseek-OCR
V_SAM_NECK = auto() # Deepseek-OCR
V_SAM_NET_2 = auto() # Deepseek-OCR
V_SAM_NET_3 = auto() # Deepseek-OCR
V_ENC_EMBD_IMGNL = auto() # Deepseek-OCR
V_ENC_EMBD_VSEP = auto() # Deepseek-OCR

# audio (mtmd)
A_ENC_EMBD_POS = auto()
A_ENC_CONV1D = auto()
Expand Down Expand Up @@ -1030,6 +1050,22 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_MM_GATE: "mm.gate",
MODEL_TENSOR.V_TOK_BOI: "v.boi",
MODEL_TENSOR.V_TOK_EOI: "v.eoi",
# DeepSeek-OCR sam_model
MODEL_TENSOR.V_SAM_POS_EMBD: "v.sam.pos_embd",
MODEL_TENSOR.V_SAM_PATCH_EMBD: "v.sam.patch_embd",
MODEL_TENSOR.V_SAM_PRE_NORM: "v.sam.blk.{bid}.pre_ln",
MODEL_TENSOR.V_SAM_POST_NORM: "v.sam.blk.{bid}.post_ln",
MODEL_TENSOR.V_SAM_ATTN_POS_H: "v.sam.blk.{bid}.attn.pos_h",
MODEL_TENSOR.V_SAM_ATTN_POS_W: "v.sam.blk.{bid}.attn.pos_w",
MODEL_TENSOR.V_SAM_ATTN_QKV: "v.sam.blk.{bid}.attn.qkv",
MODEL_TENSOR.V_SAM_ATTN_OUT: "v.sam.blk.{bid}.attn.out",
MODEL_TENSOR.V_SAM_MLP_LIN_1: "v.sam.blk.{bid}.mlp.lin1",
MODEL_TENSOR.V_SAM_MLP_LIN_2: "v.sam.blk.{bid}.mlp.lin2",
MODEL_TENSOR.V_SAM_NECK: "v.sam.neck.{bid}",
MODEL_TENSOR.V_SAM_NET_2: "v.sam.net_2",
MODEL_TENSOR.V_SAM_NET_3: "v.sam.net_3",
MODEL_TENSOR.V_ENC_EMBD_IMGNL: "model.image_newline", # Deepseek-OCR
MODEL_TENSOR.V_ENC_EMBD_VSEP: "model.view_seperator", # Deepseek-OCR
# audio (mtmd)
MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
Expand Down Expand Up @@ -1066,6 +1102,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_ENC_EMBD_CLS,
MODEL_TENSOR.V_ENC_EMBD_PATCH,
MODEL_TENSOR.V_ENC_EMBD_POS,
MODEL_TENSOR.V_ENC_EMBD_IMGNL,
MODEL_TENSOR.V_ENC_EMBD_VSEP,
MODEL_TENSOR.V_ENC_INPUT_NORM,
MODEL_TENSOR.V_ENC_ATTN_QKV,
MODEL_TENSOR.V_ENC_ATTN_Q,
Expand Down Expand Up @@ -1108,6 +1146,19 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_MM_GATE,
MODEL_TENSOR.V_TOK_BOI,
MODEL_TENSOR.V_TOK_EOI,
MODEL_TENSOR.V_SAM_POS_EMBD,
MODEL_TENSOR.V_SAM_PATCH_EMBD,
MODEL_TENSOR.V_SAM_PRE_NORM,
MODEL_TENSOR.V_SAM_POST_NORM,
MODEL_TENSOR.V_SAM_ATTN_POS_H,
MODEL_TENSOR.V_SAM_ATTN_POS_W,
MODEL_TENSOR.V_SAM_ATTN_QKV,
MODEL_TENSOR.V_SAM_ATTN_OUT,
MODEL_TENSOR.V_SAM_MLP_LIN_1,
MODEL_TENSOR.V_SAM_MLP_LIN_2,
MODEL_TENSOR.V_SAM_NECK,
MODEL_TENSOR.V_SAM_NET_2,
MODEL_TENSOR.V_SAM_NET_3,
# audio
MODEL_TENSOR.A_ENC_EMBD_POS,
MODEL_TENSOR.A_ENC_CONV1D,
Expand Down Expand Up @@ -2247,7 +2298,9 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_B,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_V_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
Expand Down Expand Up @@ -3207,6 +3260,7 @@ class VisionProjectorType:
LIGHTONOCR = "lightonocr"
COGVLM = "cogvlm"
JANUS_PRO = "janus_pro"
DEEPSEEKOCR = "deepseekocr"


# Items here are (block size, type size)
Expand Down
Loading