5 changes: 5 additions & 0 deletions common/chat.cpp
@@ -579,6 +579,11 @@ common_chat_templates_ptr common_chat_templates_init(
"{%- if false %}");
}

// TODO @ngxson : hot fix for PaddleOCR
if (default_template_src.find("<|IMAGE_PLACEHOLDER|>") != std::string::npos) {
string_replace_all(default_template_src, "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>", "");
}

std::string token_bos = bos_token_override;
std::string token_eos = eos_token_override;
bool add_bos = false;
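The hunk above strips the literal <|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|> run from PaddleOCR's chat template before it is parsed; mtmd later wraps the real image embeddings in the start/end markers itself (see tools/mtmd/mtmd.cpp at the end of this diff). A minimal stand-in for string_replace_all, assuming it behaves like a plain in-place find-and-replace over the whole template string (not the actual llama.cpp helper):

#include <string>

// sketch only: replace every occurrence of "from" with "to", in place
static void replace_all(std::string & s, const std::string & from, const std::string & to) {
    for (size_t pos = 0; (pos = s.find(from, pos)) != std::string::npos; pos += to.size()) {
        s.replace(pos, from.size(), to);
    }
}

// usage: strip the placeholder run so the Jinja parser never sees it
// replace_all(default_template_src, "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>", "");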
44 changes: 43 additions & 1 deletion convert_hf_to_gguf.py
@@ -3234,7 +3234,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM")
@ModelBase.register("Ernie4_5_ForCausalLM", "Ernie4_5ForCausalLM", "PaddleOCRVLForConditionalGeneration")
class Ernie4_5Model(TextModel):
model_arch = gguf.MODEL_ARCH.ERNIE4_5

@@ -3250,6 +3250,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
if (head_dim := self.hparams.get("head_dim")) is None:
head_dim = self.hparams["hidden_size"] // num_heads

if "mlp_AR" in name or "vision_model" in name:
# skip vision model and projector tensors
return []

if "ernie." in name:
name = name.replace("ernie.", "model.")
# split the qkv weights
@@ -3368,6 +3372,44 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("SiglipVisionModel")
class PaddleOCRVisionModel(MmprojModel):
# PaddleOCR-VL uses a modified version of Siglip
min_pixels: int = 0
max_pixels: int = 0

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None
self.min_pixels = self.preprocessor_config["size"]["min_pixels"]
self.max_pixels = self.preprocessor_config["size"]["max_pixels"]
self.hparams_vision["image_size"] = int(math.sqrt(self.max_pixels))

def set_gguf_parameters(self):
super().set_gguf_parameters()
assert self.hparams_vision is not None
hparams = self.hparams_vision
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PADDLEOCR)
self.gguf_writer.add_vision_max_pixels(self.max_pixels)
self.gguf_writer.add_vision_min_pixels(self.min_pixels)
self.gguf_writer.add_vision_use_gelu(True)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("rms_norm_eps", 1e-6))

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
name = name.replace("visual.", "model.")

if "vision_model" in name or "mlp_AR" in name:
if "packing_position_embedding" in name:
return [] # unused
elif "vision_model.head" in name:
# we don't yet support image embeddings for this model
return []
else:
return [(self.map_tensor_name(name), data_torch)]
return [] # skip other tensors


@ModelBase.register(
"Qwen2VLModel",
"Qwen2VLForConditionalGeneration",
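In the converter changes above, the text-model class skips every vision_model / mlp_AR tensor, and the new PaddleOCRVisionModel derives the square vision image_size from the preprocessor's max_pixels budget via image_size = int(sqrt(max_pixels)). A hedged arithmetic sketch of that derivation (the pixel budget below is invented for illustration and is not the value from the real preprocessor_config):

#include <cmath>
#include <cstdio>

int main() {
    const int max_pixels = 1562500;                 // hypothetical budget; 1250 * 1250
    const int image_size = (int) std::sqrt((double) max_pixels);
    std::printf("image_size = %d\n", image_size);   // prints 1250
    return 0;
}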
3 changes: 3 additions & 0 deletions gguf-py/gguf/constants.py
@@ -265,6 +265,8 @@ class Clip:

class ClipVision:
IMAGE_SIZE = "clip.vision.image_size"
MAX_PIXELS = "clip.vision.max_pixels"
MIN_PIXELS = "clip.vision.min_pixels"
PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size"
PATCH_SIZE = "clip.vision.patch_size"
EMBEDDING_LENGTH = "clip.vision.embedding_length"
@@ -3062,6 +3064,7 @@ class VisionProjectorType:
VOXTRAL = "voxtral"
LFM2 = "lfm2"
KIMIVL = "kimivl"
PADDLEOCR = "paddleocr"


# Items here are (block size, type size)
6 changes: 6 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -1029,6 +1029,12 @@ def add_vision_projection_dim(self, value: int) -> None:
def add_vision_patch_size(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.PATCH_SIZE, value)

def add_vision_max_pixels(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.MAX_PIXELS, value)

def add_vision_min_pixels(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.MIN_PIXELS, value)

def add_vision_embedding_length(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.EMBEDDING_LENGTH, value)

3 changes: 3 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -1144,6 +1144,7 @@ class TensorNameMap:
MODEL_TENSOR.V_MMPROJ: (
"multi_modal_projector.linear_{bid}",
"visual.merger.mlp.{bid}", # qwen2vl
"mlp_AR.linear_{bid}", # PaddleOCR-VL
),

MODEL_TENSOR.V_MMPROJ_FC: (
@@ -1338,6 +1339,7 @@ class TensorNameMap:
"multi_modal_projector.layer_norm",
"multi_modal_projector.pre_norm",
"pre_mm_projector_norm",
"mlp_AR.pre_norm", # PaddleOCR-VL
),

MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
@@ -1362,6 +1364,7 @@

MODEL_TENSOR.V_RESMPL_ATTN_OUT: (
"resampler.attn.out_proj",
"model.vision_model.head.attention.out_proj",
),

MODEL_TENSOR.V_RESMPL_KV: (
16 changes: 10 additions & 6 deletions tools/mtmd/clip-impl.h
@@ -95,12 +95,14 @@
#define TN_TOK_GLM_EOI "adapter.eoi" // glm-edge (these embeddings are not in text model)

// minicpmv
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
#define TN_MINICPMV_QUERY "resampler.query"
#define TN_MINICPMV_PROJ "resampler.proj.weight"
#define TN_MINICPMV_KV_PROJ "resampler.kv.weight"
#define TN_MINICPMV_ATTN "resampler.attn.%s.%s"
#define TN_MINICPMV_LN "resampler.ln_%s.%s"
#define TN_RESAMPL_POS_EMBD_K "resampler.pos_embed_k"
#define TN_RESAMPL_QUERY "resampler.query"
#define TN_RESAMPL_PROJ "resampler.proj.weight"
#define TN_RESAMPL_KV_PROJ "resampler.kv.weight"
#define TN_RESAMPL_ATTN "resampler.attn.%s.%s"
#define TN_RESAMPL_LN "resampler.ln_%s.%s"
#define TN_RESAMPL_FFN_UP "resampler.ffn_up.%s"
#define TN_RESAMPL_FFN_DOWN "resampler.ffn_down.%s"

#define TN_GLM_ADAPER_CONV "adapter.conv.%s"
#define TN_GLM_ADAPTER_LINEAR "adapter.linear.linear.%s"
@@ -139,6 +141,7 @@ enum projector_type {
PROJECTOR_TYPE_VOXTRAL,
PROJECTOR_TYPE_LFM2,
PROJECTOR_TYPE_KIMIVL,
PROJECTOR_TYPE_PADDLEOCR,
PROJECTOR_TYPE_UNKNOWN,
};

@@ -161,6 +164,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
{ PROJECTOR_TYPE_LFM2, "lfm2"},
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
{ PROJECTOR_TYPE_PADDLEOCR, "paddleocr"},
};

static projector_type clip_projector_type_from_string(const std::string & str) {
126 changes: 106 additions & 20 deletions tools/mtmd/clip.cpp
@@ -1136,6 +1136,72 @@ struct clip_graph {
return gf;
}

ggml_cgraph * build_paddleocr() {
// 2D input positions
ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
ggml_set_name(pos_h, "pos_h");
ggml_set_input(pos_h);

ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
ggml_set_name(pos_w, "pos_w");
ggml_set_input(pos_w);

ggml_tensor * learned_pos_embd = resize_position_embeddings();

// build ViT with 2D position embeddings
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
// first half is X axis and second half is Y axis
return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
};

ggml_tensor * inp = build_inp();
ggml_tensor * cur = build_vit(
inp, n_patches,
NORM_TYPE_NORMAL,
hparams.ffn_op,
learned_pos_embd,
add_pos);

cb(cur, "vit_out", -1);

{
// mlp_AR
float proj_norm_eps = 1e-5; // PaddleOCR uses hard-coded value eps=1e-5 for Projector
cur = build_norm(cur,
model.mm_input_norm_w, model.mm_input_norm_b,
NORM_TYPE_NORMAL, proj_norm_eps, -1);
//cur = build_patch_merge_permute(cur, hparams.proj_scale_factor);

// stack and padding
int64_t stride = hparams.proj_scale_factor * hparams.proj_scale_factor;
int64_t n_embd = cur->ne[0];
int64_t n_tokens = cur->ne[1];
int64_t n_tokens_padded = CLIP_ALIGN(n_tokens, stride);
int64_t n_pad = n_tokens_padded - n_tokens;
if (n_pad > 0) {
cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
cur = ggml_pad(ctx0, cur, n_pad * n_embd, 0, 0, 0);
}
cur = ggml_view_2d(ctx0, cur,
n_embd * stride,
n_tokens_padded / stride,
ggml_row_size(cur->type, n_embd * stride), 0);
cb(cur, "after_stacked", -1);

cur = build_ffn(cur,
model.mm_1_w, model.mm_1_b,
nullptr, nullptr,
model.mm_2_w, model.mm_2_b,
hparams.ffn_op, -1);
cb(cur, "mlp_out", -1);
}

// build the graph
ggml_build_forward_expand(gf, cur);

return gf;
}
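
A hedged worked example of the stack-and-pad step inside build_paddleocr above, using the proj_scale_factor = 2 that the loader sets for this projector (so stride = 4); the embedding width and token count are invented for illustration. Each group of stride consecutive patch embeddings is concatenated into one wider vector, after zero-padding the token count up to a multiple of stride, before the two-layer mlp_AR projection. The same round-up drives the PROJECTOR_TYPE_PADDLEOCR case of clip_n_output_tokens further down.

#include <cstdio>

// round x up to the next multiple of n (same intent as CLIP_ALIGN in clip.cpp)
static int align_up(int x, int n) { return ((x + n - 1) / n) * n; }

int main() {
    const int stride   = 2 * 2;  // proj_scale_factor squared
    const int n_embd   = 1152;   // hypothetical ViT width
    const int n_tokens = 109;    // hypothetical patch count for one image
    const int padded   = align_up(n_tokens, stride);              // 112
    std::printf("pad %d zero tokens\n", padded - n_tokens);       // 3
    std::printf("project %d vectors of width %d\n",
                padded / stride, n_embd * stride);                // 28 x 4608
    return 0;
}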

// this graph is used by llava, granite and glm
// due to having embedding_stack (used by granite), we cannot reuse build_vit
ggml_cgraph * build_llava() {
@@ -2125,6 +2191,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
res = graph.build_kimivl();
} break;
case PROJECTOR_TYPE_PADDLEOCR:
{
res = graph.build_paddleocr();
} break;
default:
{
res = graph.build_llava();
@@ -2440,6 +2510,10 @@ struct clip_model_loader {
hparams.ffn_op = FFN_GELU_ERF;
log_ffn_op = "gelu_erf"; // temporary solution for logging
} break;
case PROJECTOR_TYPE_PADDLEOCR:
{
hparams.proj_scale_factor = 2;
} break;
default:
break;
}
@@ -2650,25 +2724,25 @@
} break;
case PROJECTOR_TYPE_MINICPMV:
{
// model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD);
model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K);
model.mm_model_query = get_tensor(TN_MINICPMV_QUERY);
model.mm_model_proj = get_tensor(TN_MINICPMV_PROJ);
model.mm_model_kv_proj = get_tensor(TN_MINICPMV_KV_PROJ);
model.mm_model_attn_q_w = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "weight"));
model.mm_model_attn_k_w = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "weight"));
model.mm_model_attn_v_w = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "weight"));
model.mm_model_attn_q_b = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "bias"));
model.mm_model_attn_k_b = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "bias"));
model.mm_model_attn_v_b = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "bias"));
model.mm_model_attn_o_w = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "weight"));
model.mm_model_attn_o_b = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "bias"));
model.mm_model_ln_q_w = get_tensor(string_format(TN_MINICPMV_LN, "q", "weight"));
model.mm_model_ln_q_b = get_tensor(string_format(TN_MINICPMV_LN, "q", "bias"));
model.mm_model_ln_kv_w = get_tensor(string_format(TN_MINICPMV_LN, "kv", "weight"));
model.mm_model_ln_kv_b = get_tensor(string_format(TN_MINICPMV_LN, "kv", "bias"));
model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight"));
model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias"));
// model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_RESAMPL_POS_EMBD);
model.mm_model_pos_embed_k = get_tensor(TN_RESAMPL_POS_EMBD_K);
model.mm_model_query = get_tensor(TN_RESAMPL_QUERY);
model.mm_model_proj = get_tensor(TN_RESAMPL_PROJ);
model.mm_model_kv_proj = get_tensor(TN_RESAMPL_KV_PROJ);
model.mm_model_attn_q_w = get_tensor(string_format(TN_RESAMPL_ATTN, "q", "weight"));
model.mm_model_attn_k_w = get_tensor(string_format(TN_RESAMPL_ATTN, "k", "weight"));
model.mm_model_attn_v_w = get_tensor(string_format(TN_RESAMPL_ATTN, "v", "weight"));
model.mm_model_attn_q_b = get_tensor(string_format(TN_RESAMPL_ATTN, "q", "bias"));
model.mm_model_attn_k_b = get_tensor(string_format(TN_RESAMPL_ATTN, "k", "bias"));
model.mm_model_attn_v_b = get_tensor(string_format(TN_RESAMPL_ATTN, "v", "bias"));
model.mm_model_attn_o_w = get_tensor(string_format(TN_RESAMPL_ATTN, "out", "weight"));
model.mm_model_attn_o_b = get_tensor(string_format(TN_RESAMPL_ATTN, "out", "bias"));
model.mm_model_ln_q_w = get_tensor(string_format(TN_RESAMPL_LN, "q", "weight"));
model.mm_model_ln_q_b = get_tensor(string_format(TN_RESAMPL_LN, "q", "bias"));
model.mm_model_ln_kv_w = get_tensor(string_format(TN_RESAMPL_LN, "kv", "weight"));
model.mm_model_ln_kv_b = get_tensor(string_format(TN_RESAMPL_LN, "kv", "bias"));
model.mm_model_ln_post_w = get_tensor(string_format(TN_RESAMPL_LN, "post", "weight"));
model.mm_model_ln_post_b = get_tensor(string_format(TN_RESAMPL_LN, "post", "bias"));
} break;
case PROJECTOR_TYPE_GLM_EDGE:
{
@@ -2702,6 +2776,7 @@ struct clip_model_loader {
} break;
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_PADDLEOCR:
{
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
@@ -3622,7 +3697,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->entries.push_back(std::move(img_f32));
return true;

} else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) {
} else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL
|| ctx->proj_type() == PROJECTOR_TYPE_PADDLEOCR
) {
clip_image_u8 resized_image;
auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
@@ -3864,6 +3941,13 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
int y_patch = CLIP_ALIGN(img->ny, out_patch_size) / out_patch_size;
n_patches = x_patch * y_patch;
} break;
case PROJECTOR_TYPE_PADDLEOCR:
{
// dynamic size
int scale_factor = ctx->model.hparams.proj_scale_factor;
int stride = scale_factor * scale_factor;
n_patches = CLIP_ALIGN(n_patches, stride) / stride;
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
// dynamic size
@@ -4247,6 +4331,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
} break;
case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_PADDLEOCR:
{
// set the 2D positions
int n_patches_per_col = image_size_width / patch_size;
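
For the shared PIXTRAL/KIMIVL/PADDLEOCR branch above, the encoder is fed one (row, column) index pair per patch through the pos_h / pos_w inputs that build_rope_2d consumes. A minimal sketch of that bookkeeping, assuming the natural raster (row-major) fill order and a made-up patch grid:

#include <cstdio>
#include <vector>

int main() {
    const int n_cols = 8;   // hypothetical: image_size_width  / patch_size
    const int n_rows = 6;   // hypothetical: image_size_height / patch_size
    std::vector<int> pos_h, pos_w;
    for (int y = 0; y < n_rows; ++y) {
        for (int x = 0; x < n_cols; ++x) {
            pos_h.push_back(y);  // Y index -> second half of the rope dimensions
            pos_w.push_back(x);  // X index -> first half of the rope dimensions
        }
    }
    std::printf("%zu position pairs for %d patches\n", pos_w.size(), n_rows * n_cols);
    return 0;
}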
@@ -4402,6 +4487,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_fc_w->ne[1];
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_PADDLEOCR:
return ctx->model.mm_2_w->ne[1];
default:
GGML_ABORT("Unknown projector type");
4 changes: 4 additions & 0 deletions tools/mtmd/mtmd.cpp
@@ -275,6 +275,10 @@ struct mtmd_context {
img_beg = "<img>";
img_end = "</img>";

} else if (proj == PROJECTOR_TYPE_PADDLEOCR) {
// <|IMAGE_START|> ... (image embeddings) ... <|IMAGE_END|>
img_beg = "<|IMAGE_START|>";
img_end = "<|IMAGE_END|>";
}
}
