diff --git a/common/chat.cpp b/common/chat.cpp index 8587140e1ff0a..9dcb820d08072 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -579,6 +579,11 @@ common_chat_templates_ptr common_chat_templates_init( "{%- if false %}"); } + // TODO @ngxson : hot fix for PaddleOCR + if (default_template_src.find("<|IMAGE_PLACEHOLDER|>") != std::string::npos) { + string_replace_all(default_template_src, "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>", ""); + } + std::string token_bos = bos_token_override; std::string token_eos = eos_token_override; bool add_bos = false; diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ed99dc8477231..2d8c4c6099dbe 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3234,7 +3234,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield from super().modify_tensors(data_torch, name, bid) -@ModelBase.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM") +@ModelBase.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM", "PaddleOCRVLForConditionalGeneration") class Ernie4_5Model(TextModel): model_arch = gguf.MODEL_ARCH.ERNIE4_5 @@ -3250,6 +3250,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if (head_dim := self.hparams.get("head_dim")) is None: head_dim = self.hparams["hidden_size"] // num_heads + if "mlp_AR" in name or "vision_model" in name: + # skip vision model and projector tensors + return [] + if "ernie." in name: name = name.replace("ernie.", "model.") # split the qkv weights @@ -3368,6 +3372,44 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("SiglipVisionModel") +class PaddleOCRVisionModel(MmprojModel): + # PaddleOCR-VL uses a modified version of Siglip + min_pixels: int = 0 + max_pixels: int = 0 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_vision is not None + self.min_pixels = self.preprocessor_config["size"]["min_pixels"] + self.max_pixels = self.preprocessor_config["size"]["max_pixels"] + self.hparams_vision["image_size"] = int(math.sqrt(self.max_pixels)) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + assert self.hparams_vision is not None + hparams = self.hparams_vision + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PADDLEOCR) + self.gguf_writer.add_vision_max_pixels(self.max_pixels) + self.gguf_writer.add_vision_min_pixels(self.min_pixels) + self.gguf_writer.add_vision_use_gelu(True) + self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("rms_norm_eps", 1e-6)) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + name = name.replace("visual.", "model.") + + if "vision_model" in name or "mlp_AR" in name: + if "packing_position_embedding" in name: + return [] # unused + elif "vision_model.head" in name: + # we don't yet support image embeddings for this model + return [] + else: + return [(self.map_tensor_name(name), data_torch)] + return [] # skip other tensors + + @ModelBase.register( "Qwen2VLModel", "Qwen2VLForConditionalGeneration", diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1b71fb3749aaa..ff947503afc17 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -265,6 +265,8 @@ class Clip: class ClipVision: IMAGE_SIZE = "clip.vision.image_size" + MAX_PIXELS = "clip.vision.max_pixels" + MIN_PIXELS = "clip.vision.min_pixels" PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size" PATCH_SIZE = "clip.vision.patch_size" EMBEDDING_LENGTH = "clip.vision.embedding_length" @@ -3062,6 +3064,7 @@ class VisionProjectorType: VOXTRAL = "voxtral" LFM2 = "lfm2" KIMIVL = "kimivl" + PADDLEOCR = "paddleocr" # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index d52d4f40f7884..602df3cb7b56f 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1029,6 +1029,12 @@ def add_vision_projection_dim(self, value: int) -> None: def add_vision_patch_size(self, value: int) -> None: self.add_uint32(Keys.ClipVision.PATCH_SIZE, value) + def add_vision_max_pixels(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.MAX_PIXELS, value) + + def add_vision_min_pixels(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.MIN_PIXELS, value) + def add_vision_embedding_length(self, value: int) -> None: self.add_uint32(Keys.ClipVision.EMBEDDING_LENGTH, value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index d7dcd8efb8426..b0591ddb276fc 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1144,6 +1144,7 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ: ( "multi_modal_projector.linear_{bid}", "visual.merger.mlp.{bid}", # qwen2vl + "mlp_AR.linear_{bid}", # PaddleOCR-VL ), MODEL_TENSOR.V_MMPROJ_FC: ( @@ -1338,6 +1339,7 @@ class TensorNameMap: "multi_modal_projector.layer_norm", "multi_modal_projector.pre_norm", "pre_mm_projector_norm", + "mlp_AR.pre_norm", # PaddleOCR-VL ), MODEL_TENSOR.V_MM_SOFT_EMB_NORM: ( @@ -1362,6 +1364,7 @@ class TensorNameMap: MODEL_TENSOR.V_RESMPL_ATTN_OUT: ( "resampler.attn.out_proj", + "model.vision_model.head.attention.out_proj", ), MODEL_TENSOR.V_RESMPL_KV: ( diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 1669fad99b36b..534b806f88be4 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -95,12 +95,14 @@ #define TN_TOK_GLM_EOI "adapter.eoi" // glm-edge (these embeddings are not in text model) // mimicpmv -#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k" -#define TN_MINICPMV_QUERY "resampler.query" -#define TN_MINICPMV_PROJ "resampler.proj.weight" -#define TN_MINICPMV_KV_PROJ "resampler.kv.weight" -#define TN_MINICPMV_ATTN "resampler.attn.%s.%s" -#define TN_MINICPMV_LN "resampler.ln_%s.%s" +#define TN_RESAMPL_POS_EMBD_K "resampler.pos_embed_k" +#define TN_RESAMPL_QUERY "resampler.query" +#define TN_RESAMPL_PROJ "resampler.proj.weight" +#define TN_RESAMPL_KV_PROJ "resampler.kv.weight" +#define TN_RESAMPL_ATTN "resampler.attn.%s.%s" +#define TN_RESAMPL_LN "resampler.ln_%s.%s" +#define TN_RESAMPL_FFN_UP "resampler.ffn_up.%s" +#define TN_RESAMPL_FFN_DOWN "resampler.ffn_down.%s" #define TN_GLM_ADAPER_CONV "adapter.conv.%s" #define TN_GLM_ADAPTER_LINEAR "adapter.linear.linear.%s" @@ -139,6 +141,7 @@ enum projector_type { PROJECTOR_TYPE_VOXTRAL, PROJECTOR_TYPE_LFM2, PROJECTOR_TYPE_KIMIVL, + PROJECTOR_TYPE_PADDLEOCR, PROJECTOR_TYPE_UNKNOWN, }; @@ -161,6 +164,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_VOXTRAL, "voxtral"}, { PROJECTOR_TYPE_LFM2, "lfm2"}, { PROJECTOR_TYPE_KIMIVL, "kimivl"}, + { PROJECTOR_TYPE_PADDLEOCR, "paddleocr"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index f2abf88523843..ad62ff4340912 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1136,6 +1136,72 @@ struct clip_graph { return gf; } + ggml_cgraph * build_paddleocr() { + // 2D input positions + ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); + ggml_set_name(pos_h, "pos_h"); + ggml_set_input(pos_h); + + ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); + ggml_set_name(pos_w, "pos_w"); + ggml_set_input(pos_w); + + ggml_tensor * learned_pos_embd = resize_position_embeddings(); + + // build ViT with 2D position embeddings + auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { + // first half is X axis and second half is Y axis + return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false); + }; + + ggml_tensor * inp = build_inp(); + ggml_tensor * cur = build_vit( + inp, n_patches, + NORM_TYPE_NORMAL, + hparams.ffn_op, + learned_pos_embd, + add_pos); + + cb(cur, "vit_out", -1); + + { + // mlp_AR + float proj_norm_eps = 1e-5; // PaddleOCR uses hard-coded value eps=1e-5 for Projector + cur = build_norm(cur, + model.mm_input_norm_w, model.mm_input_norm_b, + NORM_TYPE_NORMAL, proj_norm_eps, -1); + //cur = build_patch_merge_permute(cur, hparams.proj_scale_factor); + + // stack and padding + int64_t stride = hparams.proj_scale_factor * hparams.proj_scale_factor; + int64_t n_embd = cur->ne[0]; + int64_t n_tokens = cur->ne[1]; + int64_t n_tokens_padded = CLIP_ALIGN(n_tokens, stride); + int64_t n_pad = n_tokens_padded - n_tokens; + if (n_pad > 0) { + cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0); + cur = ggml_pad(ctx0, cur, n_pad * n_embd, 0, 0, 0); + } + cur = ggml_view_2d(ctx0, cur, + n_embd * stride, + n_tokens_padded / stride, + ggml_row_size(cur->type, n_embd * stride), 0); + cb(cur, "after_stacked", -1); + + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + hparams.ffn_op, -1); + cb(cur, "mlp_out", -1); + } + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; + } + // this graph is used by llava, granite and glm // due to having embedding_stack (used by granite), we cannot reuse build_vit ggml_cgraph * build_llava() { @@ -2125,6 +2191,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { res = graph.build_kimivl(); } break; + case PROJECTOR_TYPE_PADDLEOCR: + { + res = graph.build_paddleocr(); + } break; default: { res = graph.build_llava(); @@ -2440,6 +2510,10 @@ struct clip_model_loader { hparams.ffn_op = FFN_GELU_ERF; log_ffn_op = "gelu_erf"; // temporary solution for logging } break; + case PROJECTOR_TYPE_PADDLEOCR: + { + hparams.proj_scale_factor = 2; + } break; default: break; } @@ -2650,25 +2724,25 @@ struct clip_model_loader { } break; case PROJECTOR_TYPE_MINICPMV: { - // model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD); - model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K); - model.mm_model_query = get_tensor(TN_MINICPMV_QUERY); - model.mm_model_proj = get_tensor(TN_MINICPMV_PROJ); - model.mm_model_kv_proj = get_tensor(TN_MINICPMV_KV_PROJ); - model.mm_model_attn_q_w = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "weight")); - model.mm_model_attn_k_w = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "weight")); - model.mm_model_attn_v_w = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "weight")); - model.mm_model_attn_q_b = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "bias")); - model.mm_model_attn_k_b = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "bias")); - model.mm_model_attn_v_b = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "bias")); - model.mm_model_attn_o_w = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "weight")); - model.mm_model_attn_o_b = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "bias")); - model.mm_model_ln_q_w = get_tensor(string_format(TN_MINICPMV_LN, "q", "weight")); - model.mm_model_ln_q_b = get_tensor(string_format(TN_MINICPMV_LN, "q", "bias")); - model.mm_model_ln_kv_w = get_tensor(string_format(TN_MINICPMV_LN, "kv", "weight")); - model.mm_model_ln_kv_b = get_tensor(string_format(TN_MINICPMV_LN, "kv", "bias")); - model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight")); - model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias")); + // model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_RESAMPL_POS_EMBD); + model.mm_model_pos_embed_k = get_tensor(TN_RESAMPL_POS_EMBD_K); + model.mm_model_query = get_tensor(TN_RESAMPL_QUERY); + model.mm_model_proj = get_tensor(TN_RESAMPL_PROJ); + model.mm_model_kv_proj = get_tensor(TN_RESAMPL_KV_PROJ); + model.mm_model_attn_q_w = get_tensor(string_format(TN_RESAMPL_ATTN, "q", "weight")); + model.mm_model_attn_k_w = get_tensor(string_format(TN_RESAMPL_ATTN, "k", "weight")); + model.mm_model_attn_v_w = get_tensor(string_format(TN_RESAMPL_ATTN, "v", "weight")); + model.mm_model_attn_q_b = get_tensor(string_format(TN_RESAMPL_ATTN, "q", "bias")); + model.mm_model_attn_k_b = get_tensor(string_format(TN_RESAMPL_ATTN, "k", "bias")); + model.mm_model_attn_v_b = get_tensor(string_format(TN_RESAMPL_ATTN, "v", "bias")); + model.mm_model_attn_o_w = get_tensor(string_format(TN_RESAMPL_ATTN, "out", "weight")); + model.mm_model_attn_o_b = get_tensor(string_format(TN_RESAMPL_ATTN, "out", "bias")); + model.mm_model_ln_q_w = get_tensor(string_format(TN_RESAMPL_LN, "q", "weight")); + model.mm_model_ln_q_b = get_tensor(string_format(TN_RESAMPL_LN, "q", "bias")); + model.mm_model_ln_kv_w = get_tensor(string_format(TN_RESAMPL_LN, "kv", "weight")); + model.mm_model_ln_kv_b = get_tensor(string_format(TN_RESAMPL_LN, "kv", "bias")); + model.mm_model_ln_post_w = get_tensor(string_format(TN_RESAMPL_LN, "post", "weight")); + model.mm_model_ln_post_b = get_tensor(string_format(TN_RESAMPL_LN, "post", "bias")); } break; case PROJECTOR_TYPE_GLM_EDGE: { @@ -2702,6 +2776,7 @@ struct clip_model_loader { } break; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_PADDLEOCR: { model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B); @@ -3622,7 +3697,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->entries.push_back(std::move(img_f32)); return true; - } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) { + } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL + || ctx->proj_type() == PROJECTOR_TYPE_PADDLEOCR + ) { clip_image_u8 resized_image; auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size); image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height); @@ -3864,6 +3941,13 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im int y_patch = CLIP_ALIGN(img->ny, out_patch_size) / out_patch_size; n_patches = x_patch * y_patch; } break; + case PROJECTOR_TYPE_PADDLEOCR: + { + // dynamic size + int scale_factor = ctx->model.hparams.proj_scale_factor; + int stride = scale_factor * scale_factor; + n_patches = CLIP_ALIGN(n_patches, stride) / stride; + } break; case PROJECTOR_TYPE_PIXTRAL: { // dynamic size @@ -4247,6 +4331,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_PADDLEOCR: { // set the 2D positions int n_patches_per_col = image_size_width / patch_size; @@ -4402,6 +4487,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_fc_w->ne[1]; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_PADDLEOCR: return ctx->model.mm_2_w->ne[1]; default: GGML_ABORT("Unknown projector type"); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 4d487581ae0a0..0c461ad12070d 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -275,6 +275,10 @@ struct mtmd_context { img_beg = ""; img_end = ""; + } else if (proj == PROJECTOR_TYPE_PADDLEOCR) { + // <|IMAGE_START|> ... (image embeddings) ... <|IMAGE_END|> + img_beg = "<|IMAGE_START|>"; + img_end = "<|IMAGE_END|>"; } }