2 changes: 1 addition & 1 deletion common/arg.cpp
@@ -3203,7 +3203,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_IMATRIX}));
add_opt(common_arg(
{"--parse-special"},
string_format("prase special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"),
string_format("parse special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"),
[](common_params & params) {
params.parse_special = true;
}
252 changes: 250 additions & 2 deletions convert_hf_to_gguf.py
@@ -1528,7 +1528,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"]))
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads"]))

# preprocessor config
image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
@@ -3852,7 +3852,43 @@ def set_gguf_parameters(self):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
name = name.replace("language_model.", "") # InternVL
if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):

# handle aggregated expert tensors
# GGUF stores dimensions reversed from PyTorch, so:
# PyTorch (A,B,C) -> GGUF writes [C,B,A] -> GGML reads ne={C,B,A}
# Input shapes from HF: (n_expert, n_ff_exp, n_embd) or (n_expert, n_embd, n_ff_exp)
# Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down
if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
mapped = f"{name}.weight" if not name.endswith(".weight") else name
# Input: (n_expert=128, n_ff_exp=768, n_embd=2048)
# Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128}
# Need PyTorch: (128, 2048, 768) [reversed of GGML]
# So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768)
permuted = data_torch.permute(0, 2, 1).contiguous()
return [(self.map_tensor_name(mapped), permuted)]

if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0:
raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
split_dim = data_torch.shape[-1] // 2
gate = data_torch[..., :split_dim].contiguous()
up = data_torch[..., split_dim:].contiguous()
# Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768)
# Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128}
# Need PyTorch: (128, 768, 2048) [reversed of GGML]
# So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048)
base_name = name.removesuffix(".weight")
base = base_name.rsplit('.', 1)[0]
mapped_gate = f"{base}.gate_proj.weight"
mapped_up = f"{base}.up_proj.weight"
perm_gate = gate.permute(0, 2, 1).contiguous()
perm_up = up.permute(0, 2, 1).contiguous()
return [
(self.map_tensor_name(mapped_gate), perm_gate),
(self.map_tensor_name(mapped_up), perm_up),
]

if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"):
# skip visual tensors
return []
if name.find("experts") != -1:
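
A minimal standalone sketch of the shape bookkeeping described in the comments above; the sizes are toy stand-ins for (n_expert=128, n_ff_exp=768, n_embd=2048), not values read from a model:

import torch

n_expert, n_ff_exp, n_embd = 4, 6, 8  # toy sizes

# down_proj arrives from HF as (n_expert, n_ff_exp, n_embd)
down = torch.randn(n_expert, n_ff_exp, n_embd)
down_perm = down.permute(0, 2, 1).contiguous()  # -> (n_expert, n_embd, n_ff_exp)
# gguf writes dimensions reversed, so GGML reads ne = {n_ff_exp, n_embd, n_expert}
assert down_perm.shape == (n_expert, n_embd, n_ff_exp)

# gate_up_proj arrives fused as (n_expert, n_embd, 2 * n_ff_exp)
gate_up = torch.randn(n_expert, n_embd, 2 * n_ff_exp)
split_dim = gate_up.shape[-1] // 2
gate, up = gate_up[..., :split_dim], gate_up[..., split_dim:]
gate_perm = gate.permute(0, 2, 1).contiguous()  # -> (n_expert, n_ff_exp, n_embd)
# reversed on write: GGML reads ne = {n_embd, n_ff_exp, n_expert}
assert gate_perm.shape == (n_expert, n_ff_exp, n_embd)
assert up.permute(0, 2, 1).shape == (n_expert, n_ff_exp, n_embd)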
@@ -4004,6 +4040,187 @@ def set_vocab(self):
super().set_vocab()


@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
class Qwen3VLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
# Compute image_size if not present
if "image_size" not in self.hparams_vision:
# For Qwen3VL/Qwen3VLMoe, compute from num_position_embeddings
num_pos = self.hparams_vision.get("num_position_embeddings", 2304)
patch_size = self.hparams_vision.get("patch_size", 16)
# num_position_embeddings = (image_size / patch_size) ** 2
# So image_size = sqrt(num_position_embeddings) * patch_size
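# e.g. with the defaults above: 2304 ** 0.5 * 16 = 48 * 16 = 768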
image_size = int(num_pos**0.5 * patch_size)
self.hparams_vision["image_size"] = image_size

# Rename config values for compatibility
self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads")
self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth")

self.is_deepstack_layers = [False] * int(self.hparams_vision["num_hidden_layers"] or 0)
for idx in self.hparams_vision.get("deepstack_visual_indexes", []):
self.is_deepstack_layers[idx] = True

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN3VL)
self.gguf_writer.add_vision_use_gelu(True)

if self.hparams_vision is not None:
merge_size = self.hparams_vision.get("spatial_merge_size")
if merge_size is not None:
self.gguf_writer.add_vision_spatial_merge_size(int(merge_size))

# Use text config's rms_norm_eps for vision attention layernorm eps
rms_norm_eps = self.global_config.get("text_config", {}).get("rms_norm_eps", 1e-6)
self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)

if self.is_deepstack_layers:
self.gguf_writer.add_vision_is_deepstack_layers(self.is_deepstack_layers)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
assert self.hparams_vision is not None
# Skip text model tensors - they go in the text model file
if name.startswith("model.language_model.") or name.startswith("lm_head."):
return []

if name.startswith("model.visual."):
name = name.replace("model.visual.", "visual.", 1)

if name.startswith("visual.deepstack_merger_list."):
prefix, rest = name.split(".", maxsplit=3)[2:]
# prefix is the deepstack layer index; convert it to the absolute clip layer index
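# e.g. (illustrative) deepstack_visual_indexes = [7, 15, 23]: merger entry 1 -> clip layer 15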
idx = self.hparams_vision.get("deepstack_visual_indexes", [])[int(prefix)]
target = rest

tensor_type: gguf.MODEL_TENSOR
if target.startswith("norm."):
tensor_type = gguf.MODEL_TENSOR.V_DS_NORM
suffix = target.split(".", 1)[1]
elif target.startswith("linear_fc1."):
tensor_type = gguf.MODEL_TENSOR.V_DS_FC1
suffix = target.split(".", 1)[1]
elif target.startswith("linear_fc2."):
tensor_type = gguf.MODEL_TENSOR.V_DS_FC2
suffix = target.split(".", 1)[1]
else:
raise ValueError(f"Unexpected deepstack tensor: {name}")

new_name = self.format_tensor_name(tensor_type, idx, suffix=f".{suffix}")
return [(new_name, data_torch)]

if name.startswith("visual.merger."):
suffix = name.split(".", 2)[2]
if suffix.startswith("linear_fc"):
fc_idx_str, tail = suffix.split(".", 1)
fc_num = int(fc_idx_str.replace("linear_fc", ""))
# Qwen3VL has linear_fc1 and linear_fc2
# Map to indices 0 and 2 (matching Qwen2VL which uses indices 0 and 2)
if fc_num == 1:
fc_idx = 0
elif fc_num == 2:
fc_idx = 2
else:
raise ValueError(f"unexpected fc index {fc_num} in {name}")
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, fc_idx, suffix=f".{tail}")
elif suffix.startswith("norm."):
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_POST_NORM, suffix=f".{suffix.split('.', 1)[1]}")
else:
raise ValueError(f"Unexpected merger tensor: {name}")
return [(new_name, data_torch)]

if name == "visual.patch_embed.proj.weight":
# split Conv3D into Conv2Ds along temporal dimension
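# PyTorch Conv3d weight layout: (out_channels, in_channels, kt, kh, kw)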
c1, c2, kt, _, _ = data_torch.shape
del c1, c2
if kt != 2:
raise ValueError("Current implementation only supports temporal_patch_size of 2")
return [
(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight", data_torch[:, :, 0, ...]),
(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]),
]

if name == "visual.patch_embed.proj.bias":
# Include the bias - it's used by the C++ code
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".bias", data_torch)]

if name.startswith("visual."):
return [(self.map_tensor_name(name), data_torch)]

# Fall back to parent class for other tensors
return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Qwen3VLForConditionalGeneration")
class Qwen3VLTextModel(Qwen3Model):
model_arch = gguf.MODEL_ARCH.QWEN3VL

def set_gguf_parameters(self):
super().set_gguf_parameters()

# Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
text_config = self.hparams.get("text_config", {})
# rope_scaling is deprecated in V5; use rope_parameters instead
rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}

if rope_scaling.get("mrope_section"):
# mrope_section contains [time, height, width] dimensions
mrope_section = rope_scaling["mrope_section"]
# Pad to 4 dimensions [time, height, width, extra]
while len(mrope_section) < 4:
mrope_section.append(0)
self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])

logger.info(f"MRoPE sections: {mrope_section[:4]}")

vision_config = self.hparams.get("vision_config", {})
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Skip vision tensors - they go in the mmproj file
if name.startswith("model.visual."):
return []

return super().modify_tensors(data_torch, name, bid)
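
For illustration, a minimal sketch of the rope_scaling / rope_parameters fallback and the padding performed above (the mrope_section values here are made up, not taken from a real config):

text_config = {"rope_parameters": {"mrope_section": [24, 20, 20]}}  # hypothetical config
rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
mrope_section = rope_scaling["mrope_section"]
while len(mrope_section) < 4:
    mrope_section.append(0)  # pad [time, height, width] with zeros
assert mrope_section[:4] == [24, 20, 20, 0]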


@ModelBase.register("Qwen3VLMoeForConditionalGeneration")
class Qwen3VLMoeTextModel(Qwen3MoeModel):
model_arch = gguf.MODEL_ARCH.QWEN3VLMOE

def set_gguf_parameters(self):
super().set_gguf_parameters()

# Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
text_config = self.hparams.get("text_config", {})
# rope_scaling is deprecated in V5; use rope_parameters instead
rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}

if rope_scaling.get("mrope_section"):
# mrope_section contains [time, height, width] dimensions
mrope_section = rope_scaling["mrope_section"]
# Pad to 4 dimensions [time, height, width, extra]
while len(mrope_section) < 4:
mrope_section.append(0)
self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])

logger.info(f"MRoPE sections: {mrope_section[:4]}")

vision_config = self.hparams.get("vision_config", {})
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Skip vision tensors - they go in the mmproj file
if name.startswith("model.visual."):
return []

return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("GPT2LMHeadModel")
class GPT2Model(TextModel):
model_arch = gguf.MODEL_ARCH.GPT2
@@ -9493,6 +9710,37 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:

return [] # skip other tensors


@ModelBase.register("CogVLMForCausalLM")
class CogVLMVisionModel(MmprojModel):

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6))
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.COGVLM)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

if not name.startswith("model.vision."):
return []

return [(self.map_tensor_name(name), data_torch)]


@ModelBase.register("CogVLMForCausalLM")
class CogVLMModel(LlamaModel):
model_arch = gguf.MODEL_ARCH.COGVLM

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

# block vision tensors
if name.startswith("model.vision."):
return []

return [(self.map_tensor_name(name), data_torch)]

###### CONVERSION LOGIC ######


1 change: 1 addition & 0 deletions ggml/include/ggml.h
@@ -242,6 +242,7 @@
#define GGML_ROPE_TYPE_NEOX 2
#define GGML_ROPE_TYPE_MROPE 8
#define GGML_ROPE_TYPE_VISION 24
#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000
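// note: like GGML_ROPE_TYPE_VISION (24, binary 11000), this value keeps the MROPE bit (8) set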

#define GGML_MROPE_SECTIONS 4

5 changes: 0 additions & 5 deletions ggml/src/ggml-cpu/ggml-cpu.c
@@ -1613,13 +1613,8 @@ static void ggml_compute_forward_mul_mat_id(
chunk_size = 64;
}

#if defined(__aarch64__)
// disable for ARM
const bool disable_chunking = true;
#else
// disable for NUMA
const bool disable_chunking = ggml_is_numa();
#endif // defined(__aarch64__)

int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
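// e.g. nr0 = 100, chunk_size = 64 -> nchunk0 = (100 + 63) / 64 = 2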