Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 236 additions & 1 deletion convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3843,7 +3843,43 @@ def set_gguf_parameters(self):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
name = name.replace("language_model.", "") # InternVL
if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):

# handle aggregated expert tensors
# GGUF stores dimensions reversed from PyTorch, so:
# PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A}
# Input shapes from HF: (n_expert, n_ff_exp, n_embd) or (n_expert, n_embd, n_ff_exp)
# Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down
if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
mapped = f"{name}.weight" if not name.endswith(".weight") else name
# Input: (n_expert=128, n_ff_exp=768, n_embd=2048)
# Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128}
# Need PyTorch: (128, 2048, 768) [reversed of GGML]
# So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768)
permuted = data_torch.permute(0, 2, 1).contiguous()
return [(self.map_tensor_name(mapped), permuted)]

if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0:
raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
split_dim = data_torch.shape[-1] // 2
gate = data_torch[..., :split_dim].contiguous()
up = data_torch[..., split_dim:].contiguous()
# Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768)
# Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128}
# Need PyTorch: (128, 768, 2048) [reversed of GGML]
# So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048)
base_name = name.removesuffix(".weight")
base = base_name.rsplit('.', 1)[0]
mapped_gate = f"{base}.gate_proj.weight"
mapped_up = f"{base}.up_proj.weight"
perm_gate = gate.permute(0, 2, 1).contiguous()
perm_up = up.permute(0, 2, 1).contiguous()
return [
(self.map_tensor_name(mapped_gate), perm_gate),
(self.map_tensor_name(mapped_up), perm_up),
]

if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"):
# skip visual tensors
return []
if name.find("experts") != -1:
Expand Down Expand Up @@ -3991,6 +4027,205 @@ def set_vocab(self):
super().set_vocab()


@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
class Qwen3VLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
# Compute image_size if not present
if "image_size" not in self.hparams_vision:
# For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings
num_pos = self.hparams_vision.get("num_position_embeddings", 2304)
patch_size = self.hparams_vision.get("patch_size", 16)
# num_position_embeddings = (image_size / patch_size) ** 2
# So image_size = sqrt(num_position_embeddings) * patch_size
image_size = int(num_pos**0.5 * patch_size)
self.hparams_vision["image_size"] = image_size

# Rename config values for compatibility
self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")

self.is_deepstack_layers = [False] * int(self.hparams_vision["num_hidden_layers"] or 0)
for idx in self.hparams_vision.get("deepstack_visual_indexes", []):
self.is_deepstack_layers[idx] = True

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
self.gguf_writer.add_vision_use_gelu(True)

if self.hparams_vision is not None:
merge_size = self.hparams_vision.get("spatial_merge_size")
if merge_size is not None:
self.gguf_writer.add_vision_spatial_merge_size(int(merge_size))

# Use text config's rms_norm_eps for vision attention layernorm eps
rms_norm_eps = self.global_config.get("text_config", {}).get("rms_norm_eps", 1e-6)
self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)

if self.is_deepstack_layers:
self.gguf_writer.add_vision_is_deepstack_layers(self.is_deepstack_layers)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
assert self.hparams_vision is not None
# Skip text model tensors - they go in the text model file
if name.startswith("model.language_model.") or name.startswith("lm_head."):
return []

if name.startswith("model.visual."):
name = name.replace("model.visual.", "visual.", 1)

if name.startswith("visual.deepstack_merger_list."):
prefix, rest = name.split(".", maxsplit=3)[2:]
# prefix is the layer index, convert to absolute clip layer index!
idx = self.hparams_vision.get("deepstack_visual_indexes", [])[int(prefix)]
target = rest

tensor_type: gguf.MODEL_TENSOR
if target.startswith("norm."):
tensor_type = gguf.MODEL_TENSOR.V_DS_NORM
suffix = target.split(".", 1)[1]
elif target.startswith("linear_fc1."):
tensor_type = gguf.MODEL_TENSOR.V_DS_FC1
suffix = target.split(".", 1)[1]
elif target.startswith("linear_fc2."):
tensor_type = gguf.MODEL_TENSOR.V_DS_FC2
suffix = target.split(".", 1)[1]
else:
raise ValueError(f"Unexpected deepstack tensor: {name}")

new_name = self.format_tensor_name(tensor_type, idx, suffix=f".{suffix}")
return [(new_name, data_torch)]

if name.startswith("visual.merger."):
suffix = name.split(".", 2)[2]
if suffix.startswith("linear_fc"):
fc_idx_str, tail = suffix.split(".", 1)
fc_num = int(fc_idx_str.replace("linear_fc", ""))
# Qwen3VL has linear_fc1 and linear_fc2
# Map to indices 0 and 2 (matching Qwen2VL which uses indices 0 and 2)
if fc_num == 1:
fc_idx = 0
elif fc_num == 2:
fc_idx = 2
else:
raise ValueError(f"unexpected fc index {fc_num} in {name}")
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, fc_idx, suffix=f".{tail}")
elif suffix.startswith("norm."):
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_POST_NORM, suffix=f".{suffix.split('.', 1)[1]}")
else:
raise ValueError(f"Unexpected merger tensor: {name}")
return [(new_name, data_torch)]

if name == "visual.patch_embed.proj.weight":
# split Conv3D into Conv2Ds along temporal dimension
c1, c2, kt, _, _ = data_torch.shape
del c1, c2
if kt != 2:
raise ValueError("Current implementation only supports temporal_patch_size of 2")
return [
(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch[:, :, 0, ...]),
(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]),
]

if name == "visual.patch_embed.proj.bias":
# Include the bias - it's used by the C++ code
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)]

if name.startswith("visual."):
if ".qkv." in name:
if data_torch.ndim == 2:
c3, _ = data_torch.shape
else:
c3 = data_torch.shape[0]
if c3 % 3 != 0:
raise ValueError(f"Unexpected QKV shape for {name}: {data_torch.shape}")
c = c3 // 3
wq = data_torch[:c]
wk = data_torch[c: c * 2]
wv = data_torch[c * 2:]
base = name.replace("qkv", "{placeholder}")
return [
(self.map_tensor_name(base.format(placeholder="q")), wq),
(self.map_tensor_name(base.format(placeholder="k")), wk),
(self.map_tensor_name(base.format(placeholder="v")), wv),
]

return [(self.map_tensor_name(name), data_torch)]

# Fall back to parent class for other tensors
return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Qwen3VLForConditionalGeneration")
class Qwen3VLTextModel(Qwen3Model):
model_arch = gguf.MODEL_ARCH.QWEN3VL

def set_gguf_parameters(self):
super().set_gguf_parameters()

# Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
text_config = self.hparams.get("text_config", {})
# rope_scaling is deprecated in V5, use rope_parameters instead
rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}

if rope_scaling.get("mrope_section"):
# mrope_section contains [time, height, width] dimensions
mrope_section = rope_scaling["mrope_section"]
# Pad to 4 dimensions [time, height, width, extra]
while len(mrope_section) < 4:
mrope_section.append(0)
self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])

logger.info(f"MRoPE sections: {mrope_section[:4]}")

vision_config = self.hparams.get("vision_config", {})
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Skip vision tensors - they go in the mmproj file
if name.startswith("model.visual."):
return []

return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Qwen3VLMoeForConditionalGeneration")
class Qwen3VLMoeTextModel(Qwen3MoeModel):
model_arch = gguf.MODEL_ARCH.QWEN3VLMOE

def set_gguf_parameters(self):
super().set_gguf_parameters()

# Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
text_config = self.hparams.get("text_config", {})
# rope_scaling is deprecated in V5, use rope_parameters instead
rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}

if rope_scaling.get("mrope_section"):
# mrope_section contains [time, height, width] dimensions
mrope_section = rope_scaling["mrope_section"]
# Pad to 4 dimensions [time, height, width, extra]
while len(mrope_section) < 4:
mrope_section.append(0)
self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])

logger.info(f"MRoPE sections: {mrope_section[:4]}")

vision_config = self.hparams.get("vision_config", {})
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Skip vision tensors - they go in the mmproj file
if name.startswith("model.visual."):
return []

return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("GPT2LMHeadModel")
class GPT2Model(TextModel):
model_arch = gguf.MODEL_ARCH.GPT2
Expand Down
1 change: 1 addition & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@
#define GGML_ROPE_TYPE_NEOX 2
#define GGML_ROPE_TYPE_MROPE 8
#define GGML_ROPE_TYPE_VISION 24
#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000

#define GGML_MROPE_SECTIONS 4

Expand Down
34 changes: 23 additions & 11 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5474,7 +5474,7 @@ static void ggml_rope_cache_init(
}

static void ggml_mrope_cache_init(
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
float * cache, float sin_sign, float theta_scale) {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
Expand Down Expand Up @@ -5509,14 +5509,24 @@ static void ggml_mrope_cache_init(
}

float theta = theta_t;
if (sector >= sections[0] && sector < sec_w) {
theta = theta_h;
}
else if (sector >= sec_w && sector < sec_w + sections[2]) {
theta = theta_w;
}
else if (sector >= sec_w + sections[2]) {
theta = theta_e;
if (is_imrope) { // qwen3vl apply interleaved mrope
if (sector % 3 == 1 && sector < 3 * sections[1]) {
theta = theta_h;
} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
theta = theta_w;
} else {
theta = theta_e;
}
Comment on lines +5512 to +5519
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like this will make KV shifting more complicated, as my hack from #13870 won't work anymore

} else {
if (sector >= sections[0] && sector < sec_w) {
theta = theta_h;
}
else if (sector >= sec_w && sector < sec_w + sections[2]) {
theta = theta_w;
}
else if (sector >= sec_w + sections[2]) {
theta = theta_e;
}
}

rope_yarn(
Expand Down Expand Up @@ -5589,6 +5599,7 @@ static void ggml_compute_forward_rope_f32(

const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
const bool is_imrope = mode & GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

if (is_mrope) {
Expand Down Expand Up @@ -5627,7 +5638,7 @@ static void ggml_compute_forward_rope_f32(
const int64_t p_w = pos[i2 + ne2 * 2];
const int64_t p_e = pos[i2 + ne2 * 3];
ggml_mrope_cache_init(
p_t, p_h, p_w, p_e, sections, is_vision,
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}

Expand Down Expand Up @@ -5775,6 +5786,7 @@ static void ggml_compute_forward_rope_f16(

const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_imrope = mode & GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

if (is_mrope) {
Expand Down Expand Up @@ -5813,7 +5825,7 @@ static void ggml_compute_forward_rope_f16(
const int64_t p_w = pos[i2 + ne2 * 2];
const int64_t p_e = pos[i2 + ne2 * 3];
ggml_mrope_cache_init(
p_t, p_h, p_w, p_e, sections, is_vision,
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}

Expand Down
Loading