From a51c6b13db9036085ff49b08a056c2578e77b522 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 25 Oct 2025 01:08:01 +0200 Subject: [PATCH 1/2] model : add LightOnOCR-1B model --- convert_hf_to_gguf.py | 26 ++++++++++++++++++++++++-- gguf-py/gguf/constants.py | 1 + tools/mtmd/clip-impl.h | 2 ++ tools/mtmd/clip.cpp | 26 +++++++++++++++++++++++--- tools/mtmd/mtmd.cpp | 5 +++++ 5 files changed, 55 insertions(+), 5 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ed99dc8477231..19db2f600d8bf 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2289,18 +2289,21 @@ def set_gguf_parameters(self): ) class LlavaVisionModel(MmprojModel): img_break_tok_id = -1 + use_break_tok = True def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.hparams.get("model_type") == "pixtral": # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5) - self.img_break_tok_id = self.get_token_id("[IMG_BREAK]") + if self.use_break_tok: + self.img_break_tok_id = self.get_token_id("[IMG_BREAK]") elif self.is_mistral_format: # hparams is already vision config here so norm_eps is only defined in global_config. self.hparams["norm_eps"] = self.global_config.get("norm_eps", None) assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json" - self.img_break_tok_id = self.find_vparam(["image_break_token_id"]) + if self.use_break_tok: + self.img_break_tok_id = self.find_vparam(["image_break_token_id"]) else: raise ValueError(f"Unsupported model type: {self.hparams['model_type']}") logger.info(f"Image break token id: {self.img_break_tok_id}") @@ -3791,6 +3794,10 @@ def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor: return torch.stack([true_row, false_row], dim=0) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if "model.vision_" in name: + # skip multimodal tensors + return [] + if self.is_rerank: is_tied_head = self.is_tied_embeddings and "embed_tokens" in name is_real_head = not self.is_tied_embeddings and "lm_head" in name @@ -9280,6 +9287,21 @@ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", " return super().map_tensor_name(name, try_suffixes) +@ModelBase.register("LightOnOCRForConditionalGeneration") +class LightOnOCRVisionModel(LlavaVisionModel): + is_mistral_format = False + use_break_tok = False + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LIGHTONOCR) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + name = name.replace("model.vision_encoder.", "vision_tower.") + name = name.replace("model.vision_projection.", "multi_modal_projector.") + return super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("KimiVLForConditionalGeneration") class KimiVLModel(MmprojModel): def __init__(self, *args, **kwargs): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1b71fb3749aaa..94fcfaf69cf09 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -3062,6 +3062,7 @@ class VisionProjectorType: VOXTRAL = "voxtral" LFM2 = "lfm2" KIMIVL = "kimivl" + LIGHTONOCR = "lightonocr" # Items here are (block size, type size) diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 1669fad99b36b..ad2108d1798ae 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -139,6 +139,7 @@ enum projector_type { PROJECTOR_TYPE_VOXTRAL, PROJECTOR_TYPE_LFM2, PROJECTOR_TYPE_KIMIVL, + PROJECTOR_TYPE_LIGHTONOCR, PROJECTOR_TYPE_UNKNOWN, }; @@ -161,6 +162,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_VOXTRAL, "voxtral"}, { PROJECTOR_TYPE_LFM2, "lfm2"}, { PROJECTOR_TYPE_KIMIVL, "kimivl"}, + { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index f2abf88523843..08167625302c6 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -621,7 +621,7 @@ struct clip_graph { } // arrangement of the [IMG_BREAK] token - { + if (model.token_embd_img_break) { // not efficient, but works // the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows] // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension @@ -2095,6 +2095,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 res = graph.build_siglip(); } break; case PROJECTOR_TYPE_PIXTRAL: + case PROJECTOR_TYPE_LIGHTONOCR: { res = graph.build_pixtral(); } break; @@ -2380,6 +2381,7 @@ struct clip_model_loader { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false); } break; case PROJECTOR_TYPE_PIXTRAL: + case PROJECTOR_TYPE_LIGHTONOCR: { hparams.rope_theta = 10000.0f; hparams.warmup_image_size = hparams.patch_size * 8; @@ -2722,6 +2724,15 @@ struct clip_model_loader { model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false); } break; + case PROJECTOR_TYPE_LIGHTONOCR: + { + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); + model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false); + } break; case PROJECTOR_TYPE_ULTRAVOX: { model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); @@ -3622,7 +3633,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->entries.push_back(std::move(img_f32)); return true; - } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) { + } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL + || ctx->proj_type() == PROJECTOR_TYPE_LIGHTONOCR + ) { clip_image_u8 resized_image; auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size); image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height); @@ -3865,12 +3878,17 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im n_patches = x_patch * y_patch; } break; case PROJECTOR_TYPE_PIXTRAL: + case PROJECTOR_TYPE_LIGHTONOCR: { // dynamic size int n_merge = params.spatial_merge_size; int n_patches_x = img->nx / patch_size / (n_merge > 0 ? n_merge : 1); int n_patches_y = img->ny / patch_size / (n_merge > 0 ? n_merge : 1); - n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row + if (ctx->model.token_embd_img_break) { + n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row + } else { + n_patches = n_patches_y * n_patches_x; + } } break; case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_ULTRAVOX: @@ -4247,6 +4265,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_LIGHTONOCR: { // set the 2D positions int n_patches_per_col = image_size_width / patch_size; @@ -4377,6 +4396,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_model_peg_0_b->ne[0]; case PROJECTOR_TYPE_MLP: case PROJECTOR_TYPE_PIXTRAL: + case PROJECTOR_TYPE_LIGHTONOCR: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_MLP_NORM: return ctx->model.mm_3_b->ne[0]; diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 4d487581ae0a0..3b901bfac8215 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -275,6 +275,11 @@ struct mtmd_context { img_beg = ""; img_end = ""; + } else if (proj == PROJECTOR_TYPE_LIGHTONOCR) { + // <|im_start|> ... (image embeddings) ... <|im_end|> + img_beg = "<|im_start|>"; + img_end = "<|im_end|>"; + } } From 62cc684bed5048c703b9c3b90a815bde2694c924 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 27 Oct 2025 15:36:45 +0100 Subject: [PATCH 2/2] add test --- tools/mtmd/tests.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index dbdf7656a66d9..5e33d127649a0 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -70,6 +70,7 @@ add_test_vision "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0" add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" add_test_vision "ggml-org/LFM2-VL-450M-GGUF:Q8_0" add_test_vision "ggml-org/granite-docling-258M-GGUF:Q8_0" +add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0" add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0" add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"