From 1e4fd1944634e0ab48f06c37a3a623d2ea593c60 Mon Sep 17 00:00:00 2001 From: JJJYmmm <1650675829@qq.com> Date: Sun, 26 Oct 2025 19:18:15 +0800 Subject: [PATCH 1/4] support qwen3vl series. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Thireus ☠ Co-authored-by: yairpatch Co-authored-by: LETS-BEE --- convert_hf_to_gguf.py | 233 +++++++++++++++++++++- ggml/include/ggml.h | 1 + ggml/src/ggml-cpu/ops.cpp | 34 ++-- ggml/src/ggml-cuda/rope.cu | 45 +++-- gguf-py/gguf/constants.py | 51 +++++ gguf-py/gguf/gguf_writer.py | 6 + gguf-py/gguf/tensor_mapping.py | 15 ++ include/llama.h | 1 + src/llama-arch.cpp | 42 ++++ src/llama-arch.h | 3 + src/llama-hparams.cpp | 2 +- src/llama-hparams.h | 3 + src/llama-kv-cache.cpp | 2 +- src/llama-model.cpp | 352 ++++++++++++++++++++++++++++++++- tools/mtmd/clip-impl.h | 3 + tools/mtmd/clip.cpp | 261 +++++++++++++++++++++++- tools/mtmd/mtmd.cpp | 2 +- 17 files changed, 1019 insertions(+), 37 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 05d791806df1e..596677661b751 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3843,7 +3843,43 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # process the experts separately name = name.replace("language_model.", "") # InternVL - if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"): + + # handle aggregated expert tensors + # GGUF stores dimensions reversed from PyTorch, so: + # PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A} + # Input shapes from HF: (n_expert, n_ff_exp, n_embd) or (n_expert, n_embd, n_ff_exp) + # Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down + if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"): + mapped = f"{name}.weight" if not name.endswith(".weight") else name + # Input: (n_expert=128, n_ff_exp=768, n_embd=2048) + # Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128} + # Need PyTorch: (128, 2048, 768) [reversed of GGML] + # So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768) + permuted = data_torch.permute(0, 2, 1).contiguous() + return [(self.map_tensor_name(mapped), permuted)] + + if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"): + if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0: + raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}") + split_dim = data_torch.shape[-1] // 2 + gate = data_torch[..., :split_dim].contiguous() + up = data_torch[..., split_dim:].contiguous() + # Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768) + # Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128} + # Need PyTorch: (128, 768, 2048) [reversed of GGML] + # So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048) + base_name = name.removesuffix(".weight") + base = base_name.rsplit('.', 1)[0] + mapped_gate = f"{base}.gate_proj.weight" + mapped_up = f"{base}.up_proj.weight" + perm_gate = gate.permute(0, 2, 1).contiguous() + perm_up = up.permute(0, 2, 1).contiguous() + return [ + (self.map_tensor_name(mapped_gate), perm_gate), + (self.map_tensor_name(mapped_up), perm_up), + ] + + if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or 
name.startswith("model.multi_modal_projector") or name.startswith("model.visual"): # skip visual tensors return [] if name.find("experts") != -1: @@ -3991,6 +4027,201 @@ def set_vocab(self): super().set_vocab() +@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration") +class Qwen3VLVisionModel(MmprojModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_vision is not None + # Compute image_size if not present + if "image_size" not in self.hparams_vision: + # For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings + num_pos = self.hparams_vision.get("num_position_embeddings", 2304) + patch_size = self.hparams_vision.get("patch_size", 16) + # num_position_embeddings = (image_size / patch_size) ** 2 + # So image_size = sqrt(num_position_embeddings) * patch_size + image_size = int(num_pos**0.5 * patch_size) + self.hparams_vision["image_size"] = image_size + + # Rename config values for compatibility + self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads") + self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth") + + self.deepstack_layers: list[int] = list(self.hparams_vision.get("deepstack_visual_indexes", [])) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL) + self.gguf_writer.add_vision_use_gelu(True) + + if self.hparams_vision is not None: + merge_size = self.hparams_vision.get("spatial_merge_size") + if merge_size is not None: + self.gguf_writer.add_vision_spatial_merge_size(int(merge_size)) + + # Use text config's rms_norm_eps for vision attention layernorm eps + rms_norm_eps = self.global_config.get("text_config", {}).get("rms_norm_eps", 1e-6) + self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps) + + if self.deepstack_layers: + self.gguf_writer.add_vision_deepstack_layers(self.deepstack_layers) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Skip text model tensors - they go in the text model file + if name.startswith("model.language_model.") or name.startswith("lm_head."): + return [] + + if name.startswith("model.visual."): + name = name.replace("model.visual.", "visual.", 1) + + if name.startswith("visual.deepstack_merger_list."): + prefix, rest = name.split(".", maxsplit=3)[2:] + idx = int(prefix) + target = rest + + tensor_type: gguf.MODEL_TENSOR + if target.startswith("norm."): + tensor_type = gguf.MODEL_TENSOR.V_DS_NORM + suffix = target.split(".", 1)[1] + elif target.startswith("linear_fc1."): + tensor_type = gguf.MODEL_TENSOR.V_DS_FC1 + suffix = target.split(".", 1)[1] + elif target.startswith("linear_fc2."): + tensor_type = gguf.MODEL_TENSOR.V_DS_FC2 + suffix = target.split(".", 1)[1] + else: + raise ValueError(f"Unexpected deepstack tensor: {name}") + + new_name = self.format_tensor_name(tensor_type, idx, suffix=f".{suffix}") + return [(new_name, data_torch)] + + if name.startswith("visual.merger."): + suffix = name.split(".", 2)[2] + if suffix.startswith("linear_fc"): + fc_idx_str, tail = suffix.split(".", 1) + fc_num = int(fc_idx_str.replace("linear_fc", "")) + # Qwen3VL has linear_fc1 and linear_fc2 + # Map to indices 0 and 2 (matching Qwen2VL which uses indices 0 and 2) + if fc_num == 1: + fc_idx = 0 + elif fc_num == 2: + fc_idx = 2 + else: + raise ValueError(f"unexpected fc index {fc_num} in {name}") + new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, 
fc_idx, suffix=f".{tail}") + elif suffix.startswith("norm."): + new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_POST_NORM, suffix=f".{suffix.split('.', 1)[1]}") + else: + raise ValueError(f"Unexpected merger tensor: {name}") + return [(new_name, data_torch)] + + if name == "visual.patch_embed.proj.weight": + # split Conv3D into Conv2Ds along temporal dimension + c1, c2, kt, _, _ = data_torch.shape + del c1, c2 + if kt != 2: + raise ValueError("Current implementation only supports temporal_patch_size of 2") + return [ + (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch[:, :, 0, ...]), + (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]), + ] + + if name == "visual.patch_embed.proj.bias": + # Include the bias - it's used by the C++ code + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)] + + if name.startswith("visual."): + if ".qkv." in name: + if data_torch.ndim == 2: + c3, _ = data_torch.shape + else: + c3 = data_torch.shape[0] + if c3 % 3 != 0: + raise ValueError(f"Unexpected QKV shape for {name}: {data_torch.shape}") + c = c3 // 3 + wq = data_torch[:c] + wk = data_torch[c: c * 2] + wv = data_torch[c * 2:] + base = name.replace("qkv", "{placeholder}") + return [ + (self.map_tensor_name(base.format(placeholder="q")), wq), + (self.map_tensor_name(base.format(placeholder="k")), wk), + (self.map_tensor_name(base.format(placeholder="v")), wv), + ] + + return [(self.map_tensor_name(name), data_torch)] + + # Fall back to parent class for other tensors + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("Qwen3VLForConditionalGeneration") +class Qwen3VLTextModel(Qwen3Model): + model_arch = gguf.MODEL_ARCH.QWEN3VL + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL + text_config = self.hparams.get("text_config", {}) + # rope_scaling is deprecated in V5, use rope_parameters instead + rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {} + + if rope_scaling.get("mrope_section"): + # mrope_section contains [time, height, width] dimensions + mrope_section = rope_scaling["mrope_section"] + # Pad to 4 dimensions [time, height, width, extra] + while len(mrope_section) < 4: + mrope_section.append(0) + self.gguf_writer.add_rope_dimension_sections(mrope_section[:4]) + + logger.info(f"MRoPE sections: {mrope_section[:4]}") + + vision_config = self.hparams.get("vision_config", {}) + deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", [])) + self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Skip vision tensors - they go in the mmproj file + if name.startswith("model.visual."): + return [] + + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("Qwen3VLMoeForConditionalGeneration") +class Qwen3VLMoeTextModel(Qwen3MoeModel): + model_arch = gguf.MODEL_ARCH.QWEN3VLMOE + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL + text_config = self.hparams.get("text_config", {}) + # rope_scaling is deprecated in V5, use rope_parameters instead + rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {} + + if rope_scaling.get("mrope_section"): + # mrope_section contains 
[time, height, width] dimensions + mrope_section = rope_scaling["mrope_section"] + # Pad to 4 dimensions [time, height, width, extra] + while len(mrope_section) < 4: + mrope_section.append(0) + self.gguf_writer.add_rope_dimension_sections(mrope_section[:4]) + + logger.info(f"MRoPE sections: {mrope_section[:4]}") + + vision_config = self.hparams.get("vision_config", {}) + deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", [])) + self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Skip vision tensors - they go in the mmproj file + if name.startswith("model.visual."): + return [] + + return super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("GPT2LMHeadModel") class GPT2Model(TextModel): model_arch = gguf.MODEL_ARCH.GPT2 diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index d948b00cc7f30..2311cdabe3ba4 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -242,6 +242,7 @@ #define GGML_ROPE_TYPE_NEOX 2 #define GGML_ROPE_TYPE_MROPE 8 #define GGML_ROPE_TYPE_VISION 24 +#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000 #define GGML_MROPE_SECTIONS 4 diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index b52f0f8472cfe..52d4b85877583 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -5474,7 +5474,7 @@ static void ggml_rope_cache_init( } static void ggml_mrope_cache_init( - float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects, + float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, float * cache, float sin_sign, float theta_scale) { // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py @@ -5509,14 +5509,26 @@ static void ggml_mrope_cache_init( } float theta = theta_t; - if (sector >= sections[0] && sector < sec_w) { - theta = theta_h; - } - else if (sector >= sec_w && sector < sec_w + sections[2]) { - theta = theta_w; - } - else if (sector >= sec_w + sections[2]) { - theta = theta_e; + if (is_imrope) { // qwen3vl applies interleaved mrope + if (sector % 3 == 1 && sector < 3 * sections[1]) { + theta = theta_h; + } else if (sector % 3 == 2 && sector < 3 * sections[2]) { + theta = theta_w; + } else if (sector % 3 == 0 && sector < 3 * sections[0]) { + theta = theta_t; + } else { + theta = theta_e; + } + } else { + if (sector >= sections[0] && sector < sec_w) { + theta = theta_h; + } + else if (sector >= sec_w && sector < sec_w + sections[2]) { + theta = theta_w; + } + else if (sector >= sec_w + sections[2]) { + theta = theta_e; + } } rope_yarn( @@ -5589,6 +5599,7 @@ static void ggml_compute_forward_rope_f32( const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding + const bool is_imrope = mode & GGML_ROPE_TYPE_IMROPE; // qwen3vl applies interleaved mrope const bool is_vision = mode == GGML_ROPE_TYPE_VISION; if (is_mrope) { @@ -5627,7 +5638,7 @@ static void ggml_compute_forward_rope_f32( const int64_t p_w = pos[i2 + ne2 * 2]; const int64_t p_e = pos[i2 + ne2 * 3]; ggml_mrope_cache_init( - p_t, p_h, p_w, p_e, sections, is_vision, + p_t, p_h, p_w, p_e, sections, is_imrope, is_vision, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); }
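For illustration only (not part of the patch): the interleaved M-RoPE hunk above assigns each pair of rotary dimensions (a "sector") to one of the four position axes. Below is a minimal Python sketch of that sector-to-axis mapping, matching the four-branch form used here; the sections order [t, h, w, e] follows ggml_mrope_cache_init, and the [24, 20, 20, 0] split is an assumed Qwen3-VL-style example, not a value stated in the patch.

def imrope_axis(sector: int, sections: list[int]) -> int:
    # sections = [t, h, w, e]; returns which position axis (0=t, 1=h, 2=w, 3=e)
    # drives this pair of rotary dimensions. Axes cycle t, h, w, t, h, w, ...
    # until an axis exhausts its budget of 3 * sections[axis] sectors;
    # any leftover sectors fall through to the extra axis.
    if sector % 3 == 1 and sector < 3 * sections[1]:
        return 1  # height
    if sector % 3 == 2 and sector < 3 * sections[2]:
        return 2  # width
    if sector % 3 == 0 and sector < 3 * sections[0]:
        return 0  # time
    return 3      # extra

print([imrope_axis(s, [24, 20, 20, 0]) for s in range(8)])
# -> [0, 1, 2, 0, 1, 2, 0, 1]: low and high rotary frequencies are spread
# across t/h/w, instead of each axis owning one contiguous frequency band
# as the plain (contiguous) M-RoPE branch gives it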
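Also for illustration (not part of the patch): the expert-tensor handling in the convert_hf_to_gguf.py hunk at the top of this patch hinges on GGUF writing dimensions in the reverse of PyTorch order. A toy numeric check of that reasoning, with names taken from the comments in that hunk:

import torch

n_expert, n_ff_exp, n_embd = 4, 6, 8

# down_proj arrives from HF as (n_expert, n_ff_exp, n_embd); GGUF writes the
# dims reversed, so to end up with GGML ne = {n_ff_exp, n_embd, n_expert} the
# tensor must be permuted to PyTorch shape (n_expert, n_embd, n_ff_exp):
down = torch.randn(n_expert, n_ff_exp, n_embd)
down_gguf = down.permute(0, 2, 1).contiguous()
assert tuple(reversed(down_gguf.shape)) == (n_ff_exp, n_embd, n_expert)

# gate_up_proj arrives fused as (n_expert, n_embd, 2 * n_ff_exp); split on the
# last dim, then the same permute yields GGML ne = {n_embd, n_ff_exp, n_expert}:
gate_up = torch.randn(n_expert, n_embd, 2 * n_ff_exp)
gate, up = gate_up[..., :n_ff_exp], gate_up[..., n_ff_exp:]
gate_gguf = gate.permute(0, 2, 1).contiguous()
assert tuple(reversed(gate_gguf.shape)) == (n_embd, n_ff_exp, n_expert)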
@@ -5775,6 +5786,7 @@ static void ggml_compute_forward_rope_f16( const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_imrope = mode & GGML_ROPE_TYPE_IMROPE; const bool is_vision = mode == GGML_ROPE_TYPE_VISION; if (is_mrope) { @@ -5813,7 +5825,7 @@ static void ggml_compute_forward_rope_f16( const int64_t p_w = pos[i2 + ne2 * 2]; const int64_t p_e = pos[i2 + ne2 * 3]; ggml_mrope_cache_init( - p_t, p_h, p_w, p_e, sections, is_vision, + p_t, p_h, p_w, p_e, sections, is_imrope, is_vision, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); } diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index d058504cd6cc0..6ace1776b5e70 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -125,7 +125,7 @@ template static __global__ void rope_multi( const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) { + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { @@ -152,17 +152,29 @@ static __global__ void rope_multi( const int sector = (i0 / 2) % sect_dims; float theta_base = 0.0; - if (sector < sections.v[0]) { - theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sections.v[0] && sector < sec_w) { - theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sec_w && sector < sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + if (is_imrope) { + if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h + theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); + } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w + theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); + } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t + theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + } else { // e + theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + } + } else { + if (sector < sections.v[0]) { + theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + } + else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); + } + else if (sector >= sec_w + sections.v[2]) { + theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + } } const float freq_factor = has_ff ?
freq_factors[i0/2] : 1.0f; @@ -276,7 +286,7 @@ template static void rope_multi_cuda( const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { + const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -287,11 +297,11 @@ static void rope_multi_cuda( if (freq_factors == nullptr) { rope_multi<<>>( x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, - attn_factor, corr_dims, theta_scale, freq_factors, sections); + attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } else { rope_multi<<>>( x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, - attn_factor, corr_dims, theta_scale, freq_factors, sections); + attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } } @@ -369,6 +379,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_imrope = mode & GGML_ROPE_TYPE_IMROPE; const bool is_vision = mode == GGML_ROPE_TYPE_VISION; if (is_mrope) { @@ -406,11 +417,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) if (src0->type == GGML_TYPE_F32) { rope_multi_cuda( (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); } else if (src0->type == GGML_TYPE_F16) { rope_multi_cuda( (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); } else { GGML_ABORT("fatal error"); } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1b71fb3749aaa..a5c36d7648c9e 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -111,6 +111,7 @@ class LLM: EXPERTS_PER_GROUP = "{arch}.experts_per_group" MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers" NEXTN_PREDICT_LAYERS = "{arch}.nextn_predict_layers" + NUM_DEEPSTACK_LAYERS = "{arch}.n_deepstack_layers" POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" @@ -277,6 +278,7 @@ class ClipVision: USE_GELU = "clip.use_gelu" USE_SILU = "clip.use_silu" N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl + DEEPSTACK_LAYERS = "clip.vision.deepstack_layers" class Attention: HEAD_COUNT = "clip.vision.attention.head_count" @@ -350,6 +352,8 @@ class MODEL_ARCH(IntEnum): QWEN2VL = auto() QWEN3 = auto() QWEN3MOE = auto() + QWEN3VL = auto() + QWEN3VLMOE = auto() PHI2 = auto() PHI3 = auto() PHIMOE = auto() @@ -430,6 +434,7 @@ class VISION_PROJECTOR_TYPE(IntEnum): GLM_EDGE = auto() MERGER = auto() GEMMA3 = auto() + QWEN3VL = auto() class MODEL_TENSOR(IntEnum): @@ 
-640,6 +645,9 @@ class MODEL_TENSOR(IntEnum): V_RESMPL_QUERY = auto() # minicpmv V_TOK_EMBD_IMG_BREAK = auto() # pixtral V_MM_PATCH_MERGER = auto() # mistral small 3.1 + V_DS_NORM = auto() # qwen3vl + V_DS_FC1 = auto() # qwen3vl + V_DS_FC2 = auto() # qwen3vl # audio (mtmd) A_ENC_EMBD_POS = auto() A_ENC_CONV1D = auto() @@ -695,6 +703,8 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.QWEN2VL: "qwen2vl", MODEL_ARCH.QWEN3: "qwen3", MODEL_ARCH.QWEN3MOE: "qwen3moe", + MODEL_ARCH.QWEN3VL: "qwen3vl", + MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe", MODEL_ARCH.PHI2: "phi2", MODEL_ARCH.PHI3: "phi3", MODEL_ARCH.PHIMOE: "phimoe", @@ -986,6 +996,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query", MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1 + MODEL_TENSOR.V_DS_NORM: "v.deepstack.{bid}.norm", + MODEL_TENSOR.V_DS_FC1: "v.deepstack.{bid}.fc1", + MODEL_TENSOR.V_DS_FC2: "v.deepstack.{bid}.fc2", # audio (mtmd) MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd", MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}", @@ -1054,6 +1067,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.V_RESMPL_QUERY, MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK, MODEL_TENSOR.V_MM_PATCH_MERGER, + MODEL_TENSOR.V_DS_NORM, + MODEL_TENSOR.V_DS_FC1, + MODEL_TENSOR.V_DS_FC2, # audio MODEL_TENSOR.A_ENC_EMBD_POS, MODEL_TENSOR.A_ENC_CONV1D, @@ -1495,6 +1511,40 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.QWEN3VL: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.QWEN3VLMOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.PLAMO: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -3055,6 +3105,7 @@ class VisionProjectorType: LLAMA4 = "llama4" QWEN2VL = "qwen2vl_merger" QWEN25VL = "qwen2.5vl_merger" + QWEN3VL = "qwen3vl_merger" ULTRAVOX = "ultravox" INTERNVL = "internvl" QWEN2A = "qwen2a" # audio diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index d52d4f40f7884..3897c83227711 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -860,6 +860,9 @@ def add_attn_temperature_length(self, value: int) -> None: def add_pooling_type(self, value: PoolingType) -> None: self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value) + def add_num_deepstack_layers(self, count: int) -> None: + self.add_uint32(Keys.LLM.NUM_DEEPSTACK_LAYERS.format(arch=self.arch), count) + def add_rope_dimension_count(self, count: int) -> None: self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count) @@ -1070,6 +1073,9 @@ def add_vision_projector_scale_factor(self, value: int) -> None: def add_vision_n_wa_pattern(self, value: int) -> None: self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value) + + def add_vision_deepstack_layers(self, layers: Sequence[int]) -> 
None: + self.add_array(Keys.ClipVision.DEEPSTACK_LAYERS, layers) # audio models diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index d7dcd8efb8426..1447c65bbc989 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1185,6 +1185,7 @@ class TensorNameMap: "model.vision_model.embeddings.position_embedding", # SmolVLM "vision_model.positional_embedding_vlm", # llama 4 "vision_tower.patch_embed.pos_emb", # kimi-vl + "visual.pos_embed", # qwen3vl ), MODEL_TENSOR.V_ENC_ATTN_Q: ( @@ -1282,6 +1283,7 @@ class TensorNameMap: "vision_model.model.layers.{bid}.mlp.fc1", # llama4 "visual.blocks.{bid}.mlp.fc1", # qwen2vl "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl + "visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1) ), @@ -1301,6 +1303,7 @@ class TensorNameMap: "vision_model.model.layers.{bid}.mlp.fc2", # llama4 "visual.blocks.{bid}.mlp.fc2", # qwen2vl "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl + "visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1) ), @@ -1397,6 +1400,18 @@ class TensorNameMap: "patch_merger.merging_layer", # mistral ), + MODEL_TENSOR.V_DS_NORM: ( + "model.visual.deepstack_merger_list.{bid}.norm", # deepstack in qwen3vl + ), + + MODEL_TENSOR.V_DS_FC1: ( + "model.visual.deepstack_merger_list.{bid}.linear_fc1", # deepstack in qwen3vl + ), + + MODEL_TENSOR.V_DS_FC2: ( + "model.visual.deepstack_merger_list.{bid}.linear_fc2", # deepstack in qwen3vl + ), + # audio (mtmd) MODEL_TENSOR.A_ENC_EMBD_POS: ( diff --git a/include/llama.h b/include/llama.h index a0a660bff88da..08e2ffa34c9f6 100644 --- a/include/llama.h +++ b/include/llama.h @@ -83,6 +83,7 @@ extern "C" { LLAMA_ROPE_TYPE_NORM = 0, LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX, LLAMA_ROPE_TYPE_MROPE = GGML_ROPE_TYPE_MROPE, + LLAMA_ROPE_TYPE_IMROPE = GGML_ROPE_TYPE_IMROPE, LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION, }; diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 8ca769c5fd2ef..4d92e7ff860dd 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -32,6 +32,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN2VL, "qwen2vl" }, { LLM_ARCH_QWEN3, "qwen3" }, { LLM_ARCH_QWEN3MOE, "qwen3moe" }, + { LLM_ARCH_QWEN3VL, "qwen3vl" }, + { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, @@ -145,6 +147,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERTS_PER_GROUP, "%s.experts_per_group" }, { LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" }, { LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" }, + { LLM_KV_NUM_DEEPSTACK_LAYERS, "%s.n_deepstack_layers" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, @@ -779,6 +782,45 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_QWEN3VL, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { 
LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_QWEN3VLMOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_PHI2, { diff --git a/src/llama-arch.h b/src/llama-arch.h index dea725c1a753a..4889120d70c67 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -36,6 +36,8 @@ enum llm_arch { LLM_ARCH_QWEN2VL, LLM_ARCH_QWEN3, LLM_ARCH_QWEN3MOE, + LLM_ARCH_QWEN3VL, + LLM_ARCH_QWEN3VLMOE, LLM_ARCH_PHI2, LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, @@ -149,6 +151,7 @@ enum llm_kv { LLM_KV_EXPERTS_PER_GROUP, LLM_KV_MOE_EVERY_N_LAYERS, LLM_KV_NEXTN_PREDICT_LAYERS, + LLM_KV_NUM_DEEPSTACK_LAYERS, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index db65d69eabdcb..514d653844c40 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -148,7 +148,7 @@ bool llama_hparams::is_recurrent(uint32_t il) const { } uint32_t llama_hparams::n_pos_per_embd() const { - return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1; + return rope_type == LLAMA_ROPE_TYPE_MROPE || rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1; } bool llama_hparams::is_swa(uint32_t il) const { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 6fcf91b7daa47..539fecb3f7817 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -183,6 +183,9 @@ struct llama_hparams { std::array xielu_beta; std::array xielu_eps; + // qwen3vl deepstack + uint32_t n_deepstack_layers = 0; + // needed by encoder-decoder models (e.g. 
T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 736693e174527..63c59bf20da66 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1340,7 +1340,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift( const auto & yarn_beta_slow = cparams.yarn_beta_slow; const auto & n_rot = hparams.n_rot; - const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE + const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE // @ngxson : this is a workaround // for M-RoPE, we want to rotate the whole vector when doing KV shift // a normal RoPE should work, we just need to use the correct ordering diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 2a83d66279b79..8450fae9b241f 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1026,6 +1026,20 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_QWEN3VL: + { + ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, 0); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 28: type = LLM_TYPE_1_7B; break; + case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } + // for deepstack patch, we consider the embd to be [main_embd, deepstack_embd_1, deepstack_embd_2, ...] + hparams.n_embd = hparams.n_embd * (hparams.n_deepstack_layers + 1); + } break; case LLM_ARCH_QWEN3MOE: { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); @@ -1037,6 +1051,20 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_QWEN3VLMOE: + { + ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, 0); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_30B_A3B; break; + case 94: type = LLM_TYPE_235B_A22B; break; + default: type = LLM_TYPE_UNKNOWN; + } + // for deepstack patch, we consider the embd to be [main_embd, deepstack_embd_1, deepstack_embd_2, ...] + hparams.n_embd = hparams.n_embd * (hparams.n_deepstack_layers + 1); + } break; case LLM_ARCH_PHI2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -3278,7 +3306,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_QWEN3: + case LLM_ARCH_QWEN3VL: { + int64_t n_embd = hparams.n_embd; + // for deepstack features, we consider the embd to be [main_embd, deepstack_embd_1, deepstack_embd_2, ...] + if (arch == LLM_ARCH_QWEN3VL) { + n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1); + } tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output @@ -3312,7 +3346,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_QWEN3MOE: + case LLM_ARCH_QWEN3VLMOE: { + // for deepstack features, we consider the embd to be [main_embd, deepstack_embd_1, deepstack_embd_2, ...] 
+ int64_t n_embd = hparams.n_embd; + if (arch == LLM_ARCH_QWEN3VL) { + n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1); + } tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output @@ -6377,6 +6417,10 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + // MRoPE (Multi-axis Rotary Position Embedding) sections + if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) { + LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); + } if (!classifier_labels.empty()) { LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); @@ -6442,7 +6486,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } - if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE) { + if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } @@ -9606,6 +9650,301 @@ struct llm_build_qwen3moe : public llm_graph_context { } }; +struct llm_build_qwen3vl : public llm_graph_context { + llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + + const int64_t n_embd_full = hparams.n_embd; // main embd + deepstack embds + const size_t n_deepstack_layers = hparams.n_deepstack_layers; + const int64_t n_embd = n_embd_full / (n_deepstack_layers + 1); + const int64_t n_embd_head = hparams.n_embd_head_v; + + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + std::vector deepstack_features(n_deepstack_layers, nullptr); + + if (ubatch.embd) { + // Image input: split main embd and deepstack embds + ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0); + for (size_t i = 0; i < n_deepstack_layers; i++) { + deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float)); + } + inpL = inpL_main; + } + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = 
ggml_rope_multi( + ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + if (ubatch.embd && (size_t)il < n_deepstack_layers) { + cur = ggml_add(ctx0, cur, deepstack_features[il]); + cb(cur, "deepstack_out", il); + } + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_qwen3vlmoe : public llm_graph_context { + llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_full = hparams.n_embd; // main embd + deepstack embds + const size_t n_deepstack_layers = hparams.n_deepstack_layers; + const int64_t n_embd = n_embd_full / (n_deepstack_layers + 1); + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + std::vector deepstack_features(n_deepstack_layers, nullptr); + + if (ubatch.embd) { + // Image input: split main embd and deepstack embds + ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0); + for (size_t i = 0; i < n_deepstack_layers; i++) { + deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float)); + } + inpL = inpL_main; + } + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, 
"Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + cur = moe_out; + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + if (ubatch.embd && (size_t)il < n_deepstack_layers) { + cur = ggml_add(ctx0, cur, deepstack_features[il]); + cb(cur, "deepstack_out", il); + } + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_phi2 : public llm_graph_context { llm_build_phi2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -19881,6 +20220,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_QWEN3VL: + { + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_QWEN3VLMOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_PHI2: { llm = std::make_unique(*this, params); @@ -20394,6 +20741,9 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN2VL: return LLAMA_ROPE_TYPE_MROPE; + case LLM_ARCH_QWEN3VL: + case LLM_ARCH_QWEN3VLMOE: + return LLAMA_ROPE_TYPE_IMROPE; // all model arches should be listed explicitly here case LLM_ARCH_UNKNOWN: diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 1669fad99b36b..77f3cb71e4612 100644 --- 
a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -39,6 +39,7 @@ #define KEY_FEATURE_LAYER "clip.vision.feature_layer" #define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor" #define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size" +#define KEY_DEEPSTACK_LAYERS "clip.vision.deepstack_layers" #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" @@ -127,6 +128,7 @@ enum projector_type { PROJECTOR_TYPE_MINICPMV, PROJECTOR_TYPE_GLM_EDGE, PROJECTOR_TYPE_QWEN2VL, + PROJECTOR_TYPE_QWEN3VL, PROJECTOR_TYPE_GEMMA3, PROJECTOR_TYPE_IDEFICS3, PROJECTOR_TYPE_PIXTRAL, @@ -150,6 +152,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_GLM_EDGE, "adapter"}, { PROJECTOR_TYPE_QWEN2VL, "qwen2vl_merger"}, { PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"}, + { PROJECTOR_TYPE_QWEN3VL, "qwen3vl_merger"}, { PROJECTOR_TYPE_GEMMA3, "gemma3"}, { PROJECTOR_TYPE_IDEFICS3, "idefics3"}, { PROJECTOR_TYPE_PIXTRAL, "pixtral"}, diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index f2abf88523843..6e32d44cb2eae 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -196,6 +196,8 @@ struct clip_hparams { int32_t n_wa_pattern = 0; int32_t spatial_merge_size = 0; + std::vector deepstack_layers; // qwen3vl deepstack layers + // audio int32_t n_mel_bins = 0; // whisper preprocessor int32_t proj_stack_factor = 0; // ultravox @@ -359,6 +361,17 @@ struct clip_model { ggml_tensor * mm_norm_pre_w = nullptr; ggml_tensor * mm_norm_mid_w = nullptr; + // qwen3vl deepstack + struct deepstack_merger { + ggml_tensor * norm_w = nullptr; + ggml_tensor * norm_b = nullptr; + ggml_tensor * fc1_w = nullptr; + ggml_tensor * fc1_b = nullptr; + ggml_tensor * fc2_w = nullptr; + ggml_tensor * fc2_b = nullptr; + }; + std::vector deepstack_mergers; + bool audio_has_avgpool() const { return proj_type == PROJECTOR_TYPE_QWEN2A || proj_type == PROJECTOR_TYPE_VOXTRAL; @@ -831,6 +844,201 @@ struct clip_graph { return gf; } + // Qwen3VL + ggml_cgraph * build_qwen3vl() { + GGML_ASSERT(model.patch_bias != nullptr); + GGML_ASSERT(model.position_embeddings != nullptr); + GGML_ASSERT(model.class_embedding == nullptr); + GGML_ASSERT(!hparams.deepstack_layers.empty()); + + const int batch_size = 1; + const int n_pos = n_patches; + const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position + + norm_type norm_t = NORM_TYPE_NORMAL; + + int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; + + ggml_tensor * inp_raw = build_inp_raw(); + ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + + GGML_ASSERT(img.nx % (patch_size * 2) == 0); + GGML_ASSERT(img.ny % (patch_size * 2) == 0); + + // second conv dimension + { + auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + inp = ggml_add(ctx0, inp, inp_1); + + inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b] + inp = ggml_cont_4d( + ctx0, inp, + n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); + inp = ggml_reshape_4d( + ctx0, inp, + n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); + inp = ggml_permute(ctx0, inp, 0, 2, 1, 3); + inp = ggml_cont_3d( + ctx0, inp, + n_embd, n_patches_x * n_patches_y, batch_size); + } + + // add patch bias + if (model.patch_bias != nullptr) { + inp = ggml_add(ctx0, inp, model.patch_bias); + cb(inp, "patch_bias", -1); + } + + // calculate absolute position embedding and apply + ggml_tensor * 
learned_pos_embd = resize_position_embeddings(); + learned_pos_embd = ggml_cont_4d( + ctx0, learned_pos_embd, + n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); + learned_pos_embd = ggml_reshape_4d( + ctx0, learned_pos_embd, + n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); + learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3); + learned_pos_embd = ggml_cont_3d( + ctx0, learned_pos_embd, + n_embd, n_patches_x * n_patches_y, batch_size); + inp = ggml_add(ctx0, inp, learned_pos_embd); + cb(inp, "inp_pos_emb", -1); + + ggml_tensor * inpL = inp; + + ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + // pre-layernorm + if (model.pre_ln_w) { + inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1); + } + + // deepstack features (stack along the feature dimension), [n_embd * len(deepstack_layers), n_patches_x * n_patches_y, batch_size] + ggml_tensor * deepstack_features = nullptr; + const int merge_factor = hparams.spatial_merge_size > 0 ? hparams.spatial_merge_size * hparams.spatial_merge_size : 4; // default 2x2=4 for qwen3vl + + // loop over layers + for (int il = 0; il < n_layer; il++) { + auto & layer = model.layers[il]; + + ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states + + // layernorm1 + cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il); + cb(cur, "ln1", il); + + // self-attention + { + ggml_tensor * Qcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b); + ggml_tensor * Kcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b); + ggml_tensor * Vcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b); + + Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches); + Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches); + Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // apply M-RoPE + Qcur = ggml_rope_multi( + ctx0, Qcur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + Kcur = ggml_rope_multi( + ctx0, Kcur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + + cb(Qcur, "Qcur_rope", il); + cb(Kcur, "Kcur_rope", il); + + cur = build_attn(layer.o_w, layer.o_b, + Qcur, Kcur, Vcur, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, inpL); + + inpL = cur; // inpL = residual, cur = hidden_states + + cb(cur, "ffn_inp", il); + + // layernorm2 + cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il); + cb(cur, "ffn_inp_normed", il); + + // ffn + cur = build_ffn(cur, + layer.ff_up_w, layer.ff_up_b, + layer.ff_gate_w, layer.ff_gate_b, + layer.ff_down_w, layer.ff_down_b, + hparams.ffn_op, il); + + cb(cur, "ffn_out", il); + + // residual 2 + cur = ggml_add(ctx0, inpL, cur); + cb(cur, "layer_out", il); + + if (std::find(hparams.deepstack_layers.begin(), hparams.deepstack_layers.end(), il) != hparams.deepstack_layers.end()) { + const int deepstack_idx = std::find(hparams.deepstack_layers.begin(), hparams.deepstack_layers.end(), il) - hparams.deepstack_layers.begin(); + auto & merger = model.deepstack_mergers[deepstack_idx]; + ggml_tensor * feat = ggml_dup(ctx0, cur); + feat = ggml_reshape_3d(ctx0, feat, n_embd * merge_factor, n_pos / merge_factor, batch_size); + + 
feat = build_norm(feat, merger.norm_w, merger.norm_b, norm_t, eps, il); + feat = ggml_mul_mat(ctx0, merger.fc1_w, feat); + feat = ggml_add(ctx0, feat, merger.fc1_b); + + feat = ggml_gelu(ctx0, feat); + + feat = ggml_mul_mat(ctx0, merger.fc2_w, feat); + feat = ggml_add(ctx0, feat, merger.fc2_b); + + if(!deepstack_features) { + deepstack_features = feat; + } else { + // concat along the feature dimension + deepstack_features = ggml_concat(ctx0, deepstack_features, feat, 0); + } + } + + inpL = cur; + } + + // post-layernorm + if (model.post_ln_w) { + inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer); + } + + // multimodal projection + ggml_tensor * embeddings = inpL; + embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size); + + embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); + + // GELU activation + embeddings = ggml_gelu(ctx0, embeddings); + + // Second linear layer + embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); + + embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); // concat along the feature dimension + + // build the graph + ggml_build_forward_expand(gf, embeddings); + + return gf; + } + ggml_cgraph * build_minicpmv() { const int batch_size = 1; @@ -2103,6 +2311,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { res = graph.build_qwen2vl(); } break; + case PROJECTOR_TYPE_QWEN3VL: + { + res = graph.build_qwen3vl(); + } break; case PROJECTOR_TYPE_MINICPMV: { res = graph.build_minicpmv(); @@ -2421,6 +2633,13 @@ struct clip_model_loader { hparams.warmup_image_size = hparams.patch_size * 8; get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern); } break; + case PROJECTOR_TYPE_QWEN3VL: + { + hparams.image_size = 1024; // still need this? + hparams.warmup_image_size = hparams.patch_size * 8; + get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false); + get_arr_int(KEY_DEEPSTACK_LAYERS, hparams.deepstack_layers, false); + } break; case PROJECTOR_TYPE_LLAMA4: { hparams.rope_theta = 10000.0f; @@ -2459,6 +2678,15 @@ struct clip_model_loader { LOG_INF("%s: minicpmv_version: %d\n", __func__, hparams.minicpmv_version); LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor); LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern); + if (hparams.spatial_merge_size > 0) { + LOG_INF("%s: spatial_merge_size: %d\n", __func__, hparams.spatial_merge_size); + } + if (!hparams.deepstack_layers.empty()) { + LOG_INF("%s: deepstack_layers: ", __func__); + for (size_t i = 0; i < hparams.deepstack_layers.size(); i++) { + LOG_CNT("%d%s", hparams.deepstack_layers[i], i < hparams.deepstack_layers.size() - 1 ? 
", " : "\n"); + } + } } else if (is_audio) { LOG_INF("\n--- audio hparams ---\n"); LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins); @@ -2691,6 +2919,26 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; + case PROJECTOR_TYPE_QWEN3VL: + { + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); + + if (!hparams.deepstack_layers.empty()) { + model.deepstack_mergers.resize(hparams.deepstack_layers.size()); + for (size_t i = 0; i < hparams.deepstack_layers.size(); i++) { + auto & merger = model.deepstack_mergers[i]; + merger.norm_w = get_tensor(string_format("v.deepstack.%d.norm.weight", (int)i), false); + merger.norm_b = get_tensor(string_format("v.deepstack.%d.norm.bias", (int)i), false); + merger.fc1_w = get_tensor(string_format("v.deepstack.%d.fc1.weight", (int)i), false); + merger.fc1_b = get_tensor(string_format("v.deepstack.%d.fc1.bias", (int)i), false); + merger.fc2_w = get_tensor(string_format("v.deepstack.%d.fc2.weight", (int)i), false); + merger.fc2_b = get_tensor(string_format("v.deepstack.%d.fc2.bias", (int)i), false); + } + } + } break; case PROJECTOR_TYPE_GEMMA3: { model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); @@ -3554,7 +3802,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->grid_y = inst.grid_size.height; return true; - } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) { + } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) { clip_image_u8 resized; auto patch_size = params.patch_size * 2; auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size); @@ -3774,7 +4022,7 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) { int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) { const auto & params = ctx->model.hparams; const int n_total = clip_n_output_tokens(ctx, img); - if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) { + if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) { return img->nx / (params.patch_size * 2) + (int)(img->nx % params.patch_size > 0); } return n_total; @@ -3782,7 +4030,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) { const auto & params = ctx->model.hparams; - if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) { + if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) { return img->ny / (params.patch_size * 2) + (int)(img->ny % params.patch_size > 0); } return 1; @@ -3838,6 +4086,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im } break; case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: + case PROJECTOR_TYPE_QWEN3VL: { // dynamic size (2 conv, so double patch size) int patch_size = 
@@ -3838,6 +4086,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
             } break;
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
+        case PROJECTOR_TYPE_QWEN3VL:
             {
                 // dynamic size (2 conv, so double patch size)
                 int patch_size = params.patch_size * 2;
@@ -4142,6 +4391,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 set_input_f32("pos_embed", pos_embed);
             } break;
         case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN3VL:
             {
                 const int merge_ratio = 2;
                 const int pw = image_size_width / patch_size;
@@ -4387,6 +4637,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
             return ctx->model.mm_1_b->ne[0];
+        case PROJECTOR_TYPE_QWEN3VL:
+            return ctx->model.mm_1_b->ne[0] * ((int)ctx->model.hparams.deepstack_layers.size() + 1); // main path + deepstack paths
         case PROJECTOR_TYPE_GEMMA3:
             return ctx->model.mm_input_proj_w->ne[0];
         case PROJECTOR_TYPE_IDEFICS3:
@@ -4421,7 +4673,8 @@ bool clip_is_glm(const struct clip_ctx * ctx) {

 bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
     return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL
-        || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL;
+        || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL
+        || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL;
 }

 bool clip_is_llava(const struct clip_ctx * ctx) {
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index 4d487581ae0a0..66c3d1945fec0 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -258,7 +258,7 @@ struct mtmd_context {
         // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
         img_end = "[IMG_END]";

-    } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL) {
+    } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL) {
         // <|vision_start|> ... (image embeddings) ... <|vision_end|>
         img_beg = "<|vision_start|>";
         img_end = "<|vision_end|>";
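Note on the QWEN3VL case in clip_n_mmproj_embd above: the per-token embedding handed to the text model is the main projection with every deepstack merger output stacked onto it along the feature dimension, so the reported width is mm_1_b->ne[0] scaled by one plus the number of deepstack layers. The text model stores that concatenated width in n_embd and has to divide it back out when sizing tensors such as tok_embd, which is what the bugfix in the next patch touches. A minimal arithmetic check of that contract (the sizes are illustrative, not from a real Qwen3VL config):

// sketch: width contract between the mmproj output and the text model;
// n_embd_proj and n_deepstack are illustrative values only
#include <cassert>
#include <cstdio>

int main() {
    const int n_embd_proj = 2048; // width of one projector path (mm_1_b->ne[0])
    const int n_deepstack = 3;    // number of deepstack mergers

    // vision side: main path + deepstack paths, concatenated per token
    const int n_mmproj_embd = n_embd_proj * (1 + n_deepstack); // 8192

    // text side: hparams.n_embd holds the concatenated width, so the base
    // width is recovered by dividing by (n_deepstack_layers + 1)
    const int n_embd_base = n_mmproj_embd / (n_deepstack + 1);

    assert(n_embd_base == n_embd_proj);
    std::printf("embd layout: [main %d | %d x deepstack %d] = %d per token\n",
                n_embd_base, n_deepstack, n_embd_base, n_mmproj_embd);
    return 0;
}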
From f84bd67c804d8c5923d21cb7aa21dbb50fba9fe6 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Sun, 26 Oct 2025 20:00:33 +0800
Subject: [PATCH 2/4] bugfix: fix the arch check for qwen3vl-moe.

---
 src/llama-model.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 8450fae9b241f..d97c97737be89 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -3350,7 +3350,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 {
                     // for deepstack features, we consider the embd to be [main_embd, deepstack_embd_1, deepstack_embd_2, ...]
                     int64_t n_embd = hparams.n_embd;
-                    if (arch == LLM_ARCH_QWEN3VL) {
+                    if (arch == LLM_ARCH_QWEN3VLMOE) {
                         n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1);
                     }
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

From 0443a098f3370b2bbef01e95c9da313f5074a7a6 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 28 Oct 2025 20:45:43 +0800
Subject: [PATCH 3/4] use build_ffn

---
 tools/mtmd/clip.cpp | 26 ++++++++++----------------
 1 file changed, 10 insertions(+), 16 deletions(-)

diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 6e32d44cb2eae..a3cb7c3479627 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -993,13 +993,11 @@ struct clip_graph {
                 feat = ggml_reshape_3d(ctx0, feat, n_embd * merge_factor, n_pos / merge_factor, batch_size);

                 feat = build_norm(feat, merger.norm_w, merger.norm_b, norm_t, eps, il);
-                feat = ggml_mul_mat(ctx0, merger.fc1_w, feat);
-                feat = ggml_add(ctx0, feat, merger.fc1_b);
-
-                feat = ggml_gelu(ctx0, feat);
-
-                feat = ggml_mul_mat(ctx0, merger.fc2_w, feat);
-                feat = ggml_add(ctx0, feat, merger.fc2_b);
+                feat = build_ffn(feat,
+                    merger.fc1_w, merger.fc1_b,
+                    nullptr, nullptr,
+                    merger.fc2_w, merger.fc2_b,
+                    ffn_op_type::FFN_GELU, il);

                 if (!deepstack_features) {
                     deepstack_features = feat;
@@ -1021,15 +1019,11 @@ struct clip_graph {
         ggml_tensor * embeddings = inpL;
         embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);

-        embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
-        embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
-
-        // GELU activation
-        embeddings = ggml_gelu(ctx0, embeddings);
-
-        // Second linear layer
-        embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
-        embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+        embeddings = build_ffn(embeddings,
+            model.mm_0_w, model.mm_0_b,
+            nullptr, nullptr,
+            model.mm_1_w, model.mm_1_b,
+            ffn_op_type::FFN_GELU, -1);

         embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); // concat along the feature dimension
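The two hunks above replace the hand-rolled linear/GELU/linear chains with clip.cpp's build_ffn helper; with null gate tensors and FFN_GELU the helper performs exactly the sequence the removed lines spelled out. A standalone sketch of that reduction in plain ggml calls (the function name and parameter list here are placeholders for illustration, not the helper's real signature):

// sketch: what build_ffn(..., nullptr, nullptr, ..., FFN_GELU, il) boils
// down to when no gate is given; the weight tensors are placeholders
#include "ggml.h"

static ggml_tensor * ffn_gelu_no_gate(
        ggml_context * ctx,
        ggml_tensor  * cur,
        ggml_tensor  * fc1_w, ggml_tensor * fc1_b,
        ggml_tensor  * fc2_w, ggml_tensor * fc2_b) {
    cur = ggml_mul_mat(ctx, fc1_w, cur); // first linear layer
    cur = ggml_add(ctx, cur, fc1_b);
    cur = ggml_gelu(ctx, cur);           // GELU activation, no gate branch
    cur = ggml_mul_mat(ctx, fc2_w, cur); // second linear layer
    cur = ggml_add(ctx, cur, fc2_b);
    return cur;
}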
From 3271877207a4db6f584d0b18c3d1921f3101e815 Mon Sep 17 00:00:00 2001
From: JJJYmmm <1650675829@qq.com>
Date: Tue, 28 Oct 2025 21:58:46 +0800
Subject: [PATCH 4/4] optimize deepstack structure

---
 convert_hf_to_gguf.py       | 12 +++--
 gguf-py/gguf/constants.py   |  2 +-
 gguf-py/gguf/gguf_writer.py |  4 +-
 tools/mtmd/clip-impl.h      |  5 +-
 tools/mtmd/clip.cpp         | 99 ++++++++++++++++++++++---------------
 5 files changed, 73 insertions(+), 49 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 596677661b751..28cf4dcdb9f33 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -4046,7 +4046,9 @@ def __init__(self, *args, **kwargs):
         self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
         self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")

-        self.deepstack_layers: list[int] = list(self.hparams_vision.get("deepstack_visual_indexes", []))
+        self.is_deepstack_layers = [False] * int(self.hparams_vision["num_hidden_layers"] or 0)
+        for idx in self.hparams_vision.get("deepstack_visual_indexes", []):
+            self.is_deepstack_layers[idx] = True

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -4062,10 +4064,11 @@ def set_gguf_parameters(self):
         rms_norm_eps = self.global_config.get("text_config", {}).get("rms_norm_eps", 1e-6)
         self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)

-        if self.deepstack_layers:
-            self.gguf_writer.add_vision_deepstack_layers(self.deepstack_layers)
+        if self.is_deepstack_layers:
+            self.gguf_writer.add_vision_is_deepstack_layers(self.is_deepstack_layers)

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        assert self.hparams_vision is not None
         # Skip text model tensors - they go in the text model file
         if name.startswith("model.language_model.") or name.startswith("lm_head."):
             return []
@@ -4075,7 +4078,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

         if name.startswith("visual.deepstack_merger_list."):
             prefix, rest = name.split(".", maxsplit=3)[2:]
-            idx = int(prefix)
+            # prefix is the layer index, convert to absolute clip layer index!
+            idx = self.hparams_vision.get("deepstack_visual_indexes", [])[int(prefix)]
             target = rest

             tensor_type: gguf.MODEL_TENSOR
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index a5c36d7648c9e..884549c1b20f8 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -278,7 +278,7 @@ class ClipVision:
         USE_GELU = "clip.use_gelu"
         USE_SILU = "clip.use_silu"
         N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl
-        DEEPSTACK_LAYERS = "clip.vision.deepstack_layers"
+        IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers"

     class Attention:
         HEAD_COUNT = "clip.vision.attention.head_count"
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 3897c83227711..e33409402e56b 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -1074,8 +1074,8 @@ def add_vision_projector_scale_factor(self, value: int) -> None:
     def add_vision_n_wa_pattern(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)

-    def add_vision_deepstack_layers(self, layers: Sequence[int]) -> None:
-        self.add_array(Keys.ClipVision.DEEPSTACK_LAYERS, layers)
+    def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
+        self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)

     # audio models

diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index 77f3cb71e4612..9416aba3dfde9 100644
--- a/tools/mtmd/clip-impl.h
+++ b/tools/mtmd/clip-impl.h
@@ -39,7 +39,7 @@
 #define KEY_FEATURE_LAYER "clip.vision.feature_layer"
 #define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
 #define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
-#define KEY_DEEPSTACK_LAYERS "clip.vision.deepstack_layers"
+#define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers"
 #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
@@ -94,6 +94,9 @@
 #define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
 #define TN_TOK_GLM_BOI "adapter.boi" // glm-edge (these embeddings are not in text model)
 #define TN_TOK_GLM_EOI "adapter.eoi" // glm-edge (these embeddings are not in text model)
+#define TN_DEEPSTACK_NORM "v.deepstack.%d.norm.%s" // qwen3vl deepstack
+#define TN_DEEPSTACK_FC1 "v.deepstack.%d.fc1.%s" // qwen3vl deepstack
+#define TN_DEEPSTACK_FC2 "v.deepstack.%d.fc2.%s" // qwen3vl deepstack

 // mimicpmv
 #define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index a3cb7c3479627..dd5c22469f453 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -196,7 +196,7 @@ struct clip_hparams {
     int32_t n_wa_pattern = 0;
     int32_t spatial_merge_size = 0;

-    std::vector<int> deepstack_layers; // qwen3vl deepstack layers
+    std::vector<bool> is_deepstack_layers; // qwen3vl: whether the layer is a deepstack layer

     // audio
     int32_t n_mel_bins = 0; // whisper preprocessor
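The converter change above replaces the list of deepstack layer indexes with a per-layer boolean mask, so the C++ loader can test a layer with a single subscript instead of a std::find. The same conversion, sketched in C++ (the HF key deepstack_visual_indexes is real; the layer count and indexes are made up):

// sketch: deepstack_visual_indexes -> per-layer bool mask, mirroring the
// Python above; n_layer and the indexes are illustrative only
#include <cstdio>
#include <vector>

int main() {
    const int n_layer = 24;                             // vision depth, example
    const std::vector<int> deepstack_idx = {5, 11, 17}; // deepstack_visual_indexes

    std::vector<bool> is_deepstack(n_layer, false);
    for (int idx : deepstack_idx) {
        is_deepstack[idx] = true; // flag the absolute clip layer index
    }

    for (int il = 0; il < n_layer; ++il) {
        if (is_deepstack[il]) {
            std::printf("layer %d feeds a deepstack merger\n", il);
        }
    }
    return 0;
}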
@@ -241,6 +241,14 @@ struct clip_layer {
     // layer scale (no bias)
     ggml_tensor * ls_1_w = nullptr;
     ggml_tensor * ls_2_w = nullptr;
+
+    // qwen3vl deepstack merger
+    ggml_tensor * deepstack_norm_w = nullptr;
+    ggml_tensor * deepstack_norm_b = nullptr;
+    ggml_tensor * deepstack_fc1_w = nullptr;
+    ggml_tensor * deepstack_fc1_b = nullptr;
+    ggml_tensor * deepstack_fc2_w = nullptr;
+    ggml_tensor * deepstack_fc2_b = nullptr;
 };

 struct clip_model {
@@ -361,17 +369,6 @@ struct clip_model {
     ggml_tensor * mm_norm_pre_w = nullptr;
     ggml_tensor * mm_norm_mid_w = nullptr;

-    // qwen3vl deepstack
-    struct deepstack_merger {
-        ggml_tensor * norm_w = nullptr;
-        ggml_tensor * norm_b = nullptr;
-        ggml_tensor * fc1_w = nullptr;
-        ggml_tensor * fc1_b = nullptr;
-        ggml_tensor * fc2_w = nullptr;
-        ggml_tensor * fc2_b = nullptr;
-    };
-    std::vector<deepstack_merger> deepstack_mergers;
-
     bool audio_has_avgpool() const {
         return proj_type == PROJECTOR_TYPE_QWEN2A
             || proj_type == PROJECTOR_TYPE_VOXTRAL;
@@ -849,7 +846,6 @@ struct clip_graph {
         GGML_ASSERT(model.patch_bias != nullptr);
         GGML_ASSERT(model.position_embeddings != nullptr);
         GGML_ASSERT(model.class_embedding == nullptr);
-        GGML_ASSERT(!hparams.deepstack_layers.empty());

         const int batch_size = 1;
         const int n_pos = n_patches;
@@ -986,17 +982,13 @@ struct clip_graph {
             cur = ggml_add(ctx0, inpL, cur);
             cb(cur, "layer_out", il);

-            if (std::find(hparams.deepstack_layers.begin(), hparams.deepstack_layers.end(), il) != hparams.deepstack_layers.end()) {
-                const int deepstack_idx = std::find(hparams.deepstack_layers.begin(), hparams.deepstack_layers.end(), il) - hparams.deepstack_layers.begin();
-                auto & merger = model.deepstack_mergers[deepstack_idx];
-                ggml_tensor * feat = ggml_dup(ctx0, cur);
-                feat = ggml_reshape_3d(ctx0, feat, n_embd * merge_factor, n_pos / merge_factor, batch_size);
-
-                feat = build_norm(feat, merger.norm_w, merger.norm_b, norm_t, eps, il);
+            if (hparams.is_deepstack_layers[il]) {
+                ggml_tensor * feat = ggml_reshape_3d(ctx0, cur, n_embd * merge_factor, n_pos / merge_factor, batch_size);
+                feat = build_norm(feat, layer.deepstack_norm_w, layer.deepstack_norm_b, norm_t, eps, il);
                 feat = build_ffn(feat,
-                    merger.fc1_w, merger.fc1_b,
+                    layer.deepstack_fc1_w, layer.deepstack_fc1_b,
                     nullptr, nullptr,
-                    merger.fc2_w, merger.fc2_b,
+                    layer.deepstack_fc2_w, layer.deepstack_fc2_b,
                     ffn_op_type::FFN_GELU, il);

                 if (!deepstack_features) {
@@ -2571,6 +2563,9 @@ struct clip_model_loader {
             hparams.vision_feature_layer.insert(layer);
         }

+        // set default deepstack layers to false
+        hparams.is_deepstack_layers.resize(hparams.n_layer, false);
+
        // model-specific params
         switch (model.proj_type) {
             case PROJECTOR_TYPE_MINICPMV:
@@ -2632,7 +2627,7 @@ struct clip_model_loader {
                     hparams.image_size = 1024; // still need this?
                     hparams.warmup_image_size = hparams.patch_size * 8;
                     get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
-                    get_arr_int(KEY_DEEPSTACK_LAYERS, hparams.deepstack_layers, false);
+                    get_arr_bool(KEY_IS_DEEPSTACK_LAYERS, hparams.is_deepstack_layers, false);
                 } break;
             case PROJECTOR_TYPE_LLAMA4:
                 {
@@ -2675,10 +2670,19 @@ struct clip_model_loader {
             if (hparams.spatial_merge_size > 0) {
                 LOG_INF("%s: spatial_merge_size: %d\n", __func__, hparams.spatial_merge_size);
             }
-            if (!hparams.deepstack_layers.empty()) {
-                LOG_INF("%s: deepstack_layers: ", __func__);
-                for (size_t i = 0; i < hparams.deepstack_layers.size(); i++) {
-                    LOG_CNT("%d%s", hparams.deepstack_layers[i], i < hparams.deepstack_layers.size() - 1 ? ", " : "\n");
+            if (!hparams.is_deepstack_layers.empty()) {
+                LOG_INF("%s: deepstack enabled layers: ", __func__);
+                bool first = true;
+                for (size_t i = 0; i < hparams.is_deepstack_layers.size(); ++i) {
+                    if (hparams.is_deepstack_layers[i]) {
+                        LOG_CNT("%s%zu", first ? "" : ", ", i);
+                        first = false;
+                    }
+                }
+                if (first) {
+                    LOG_CNT("none\n");
+                } else {
+                    LOG_CNT("\n");
                 }
             }
         } else if (is_audio) {
@@ -2778,6 +2782,17 @@ struct clip_model_loader {
                 layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight"));
                 layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false);

+                // qwen3vl deepstack layer
+                if (hparams.is_deepstack_layers[il]) {
+                    layer.deepstack_norm_w = get_tensor(string_format(TN_DEEPSTACK_NORM, il, "weight"), false);
+                    layer.deepstack_norm_b = get_tensor(string_format(TN_DEEPSTACK_NORM, il, "bias"), false);
+                    layer.deepstack_fc1_w = get_tensor(string_format(TN_DEEPSTACK_FC1, il, "weight"), false);
+                    layer.deepstack_fc1_b = get_tensor(string_format(TN_DEEPSTACK_FC1, il, "bias"), false);
+                    layer.deepstack_fc2_w = get_tensor(string_format(TN_DEEPSTACK_FC2, il, "weight"), false);
+                    layer.deepstack_fc2_b = get_tensor(string_format(TN_DEEPSTACK_FC2, il, "bias"), false);
+                }
+
                 // some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
                 // note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
                 bool is_ffn_swapped = (
@@ -2919,19 +2934,6 @@ struct clip_model_loader {
                         model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
                         model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
                         model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
-
-                        if (!hparams.deepstack_layers.empty()) {
-                            model.deepstack_mergers.resize(hparams.deepstack_layers.size());
-                            for (size_t i = 0; i < hparams.deepstack_layers.size(); i++) {
-                                auto & merger = model.deepstack_mergers[i];
-                                merger.norm_w = get_tensor(string_format("v.deepstack.%d.norm.weight", (int)i), false);
-                                merger.norm_b = get_tensor(string_format("v.deepstack.%d.norm.bias", (int)i), false);
-                                merger.fc1_w = get_tensor(string_format("v.deepstack.%d.fc1.weight", (int)i), false);
-                                merger.fc1_b = get_tensor(string_format("v.deepstack.%d.fc1.bias", (int)i), false);
-                                merger.fc2_w = get_tensor(string_format("v.deepstack.%d.fc2.weight", (int)i), false);
-                                merger.fc2_b = get_tensor(string_format("v.deepstack.%d.fc2.bias", (int)i), false);
-                            }
-                        }
                     } break;
                 case PROJECTOR_TYPE_GEMMA3:
                     {
@@ -3139,6 +3141,21 @@ struct clip_model_loader {
         }
     }

+    void get_arr_bool(const std::string & key, std::vector<bool> & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+
+        const int n = gguf_get_arr_n(ctx_gguf.get(), i);
+        output.resize(n);
+        const bool * values = (const bool *)gguf_get_arr_data(ctx_gguf.get(), i);
+        for (int j = 0; j < n; ++j) {
+            output[j] = values[j];
+        }
+    }
+
     void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
         auto & hparams = model.hparams;
         for (int x = 1; x <= max_patches_per_side; x++) {
@@ -4632,7 +4649,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_QWEN25VL:
             return ctx->model.mm_1_b->ne[0];
         case PROJECTOR_TYPE_QWEN3VL:
-            return ctx->model.mm_1_b->ne[0] * ((int)ctx->model.hparams.deepstack_layers.size() + 1); // main path + deepstack paths
+            return ctx->model.mm_1_b->ne[0] * (1 + std::count(ctx->model.hparams.is_deepstack_layers.begin(), ctx->model.hparams.is_deepstack_layers.end(), true)); // main path + deepstack paths
         case PROJECTOR_TYPE_GEMMA3:
             return ctx->model.mm_input_proj_w->ne[0];
         case PROJECTOR_TYPE_IDEFICS3:
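Since the hparams no longer keep an index list, the hunk above recomputes the embedding multiplier by counting enabled flags; the result matches the old deepstack_layers.size() + 1 whenever every flagged layer has a merger. A trivial standalone check (the mask contents and the projector width are illustrative):

// sketch: the QWEN3VL multiplier in clip_n_mmproj_embd, recomputed from a
// boolean mask; n_embd_proj and the flagged layers are illustrative only
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    std::vector<bool> is_deepstack_layers(24, false);
    is_deepstack_layers[5]  = true;
    is_deepstack_layers[11] = true;
    is_deepstack_layers[17] = true;

    const auto n_deepstack = std::count(is_deepstack_layers.begin(),
                                        is_deepstack_layers.end(), true);
    const int n_embd_proj = 2048; // stands in for mm_1_b->ne[0]

    // main path + deepstack paths
    std::printf("n_mmproj_embd = %d\n", n_embd_proj * (int)(1 + n_deepstack));
    return 0;
}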