add HyperClovaX Vision#44314

Open
jp1924 wants to merge 85 commits into huggingface:main from jp1924:feat/hcx-seed-32b

Conversation

Contributor

@jp1924 jp1924 commented Feb 27, 2026

What does this PR do?

Hello, Transformers team!

I submitted this PR to add naver-hyperclovax/HyperCLOVAX-SEED-Think-32B (hereafter HCX), developed by the Korean IT company Naver as part of the government's national AI model project.

The HCX code was written against Transformers 4.52.4, which leads to the following issues:

  1. Being based on an outdated Transformers version prevents the model from using the latest training optimizations supported by Transformers 5.0.0 (e.g., sequence parallelism).
  2. Some deprecated code and features may cause unexpected bugs on the latest Transformers version.
  3. The modeling code was overly complex, which hurt debugging and development convenience; experimental code left over from model development was also still in place.

Porting to Transformers 5.0.0 significantly improved the readability and development convenience of the modeling code. We aim to leverage this to add the HCX model to transformers.

TODO list

  • Add docstrings

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@zucchini-nlp @yonigozlan @molbap

@jp1924 jp1924 changed the title from "add HyperCLOVAX Vision" to "add HyperClovaX Vision" on Feb 27, 2026
Contributor Author

jp1924 commented Feb 27, 2026

When I ran the tests in my environment, the HCX tests passed without any failures.
The process for adding the model will likely proceed like this:

  • First, address all the review comments on the modeling code.
  • (Optional) Add modular code.
  • Once the modeling code is nearly complete, write the docstrings.
  • Then apply the style fixes.

Does that match how you expect the work to proceed?

Contributor Author

jp1924 commented Feb 27, 2026

Hmm... In my environment, torch_compilable_check passes normally, but it fails in this CI test. I'll need to look into this further.

@ArthurZucker (Collaborator)

Yep don't worry we'll fix this one on main!

Member

@zucchini-nlp zucchini-nlp left a comment


Hey @jp1924, thanks a lot for the PR!

I think we need to move all the code to a modular file first, since a lot of it is copied from other models. I left comments below about where exactly you can copy from. Ping me for another review when the modular file is ready.

Contributor Author

jp1924 commented Mar 19, 2026

@zucchini-nlp

modeling_auto.py

[#1] "text model?"
→ When loading naver-hyperclovax/HyperCLOVAX-SEED-Think-32B from the Hub, its config.json has "model_type": "vlm", which prevents the model from loading correctly (the expected model types are hyperclovax and hyperclovax_text). For the same reason, lookups via AutoModel, AutoModelForCausalLM, and AutoModelForImageTextToText also fail. I suspect similar issues have come up in other PRs adding models to Transformers — how have you handled this in those cases?

[#2] "this is weird. Usually vision_model matches to vision model, same for text. And just hyperclovax matches the generation model"
→ Done.


configuration_hyperclovax_vision.py

[#3] "ultra nit: 2026"
→ Done.

[#4] "we can't import from other models like this. If we want to inherit from Qwen, then config has to be defined entirely in modular"
→ Done.

[#5] "needs a rebase on main. Config classes are now dataclasses decorated with auto_doc and strict"
→ Done.

[#6] "let's use PreTrainedConfig with T"
→ Done.

[#7] "still relevant — yes, exactly. Inherit from similar configs and add/delete attributes if needed"
→ Done.


modular_hyperclovax_vision.py

[#8] "these all will be copied by modular as long as we copy from LlamaDecoderLayer in L70. So we can delete them"
→ HyperClovaXAttention and HyperClovaXDecoderLayer are kept since they are referenced in _can_record_outputs.

[#9] "cache position was removed :)"
→ Done.

[#10] "to delete Qwen2_5_VLVisionBlock, it will be added when recursing over children by PreTrainedModel"
→ Done.

[#11] "can be HCXVisionPreTrainedModel(LlamaPreTrainedModel) and allow us delete similar attributes"
→ Done.

[#12] "looks very much same as super(). Does it error out if we don't override _init_weights?"
→ Done.

[#13] "HyperClovaXTextModel(GraniteModel). Actually, the llm is same as Granite no? Especially since use_post_norm=False which means we just need to load Granite and Qwen as is via AutoModel"
→ Removed use_post_norm to further simplify the modular file.

[#14] "these then can be deleted as they are identical to granite"
→ Done.

[#15] "same here, forward looks identical and can be copied"
→ Done.

[#16] "also same as granite"
→ Done.

[#17] "we need to be sure that the attr exists. We have freedom to define it in config the correct way"
→ Changed to derive the hidden size from vision_config and text_config at the config level.

[#18] "for a single linear layer, we don't need a module. Can be inside ConditionalGeneration's __init__"
→ Removed the MultiModalProjector module and replaced it with a single nn.Linear layer.

[#19] "just AutoModel is enough"
→ Fixed so that Qwen2_5_VisionTransformerPretrainedModel is correctly resolved by AutoModel.from_config when using Qwen2_5_VLVisionConfig from qwen2_5_vl.

[#20] "prob this class and the cond generation can be copied from VideoLlama3, looks similar"
→ HCXVisionModel could be replaced with the VideoLlama version, but ConditionalGeneration could not: HCX supports position_ids while VideoLlama does not, so the test failed.

@jp1924 jp1924 requested a review from zucchini-nlp March 19, 2026 09:49
@zucchini-nlp (Member)

→ When loading naver-hyperclovax/HyperCLOVAX-SEED-Think-32B from the Hub, its config.json has "model_type": "vlm", which prevents the model from loading correctly (the expected model types are hyperclovax and hyperclovax_text). For the same reason, lookups via AutoModel, AutoModelForCausalLM, and AutoModelForImageTextToText also fail. I suspect similar issues have come up in other PRs adding models to Transformers — how have you handled this in those cases?

Can we ask the repo owners to change the model_type? Their own configuration file even has the correct model type, so I don't know how they ended up with "vlm":

https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Think-32B/blob/main/configuration_hyperclovax.py

Contributor Author

jp1924 commented Mar 19, 2026

This is because the model_type in config.json is set to vlm, so it overrides the value from the configuration class:
https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Think-32B/blob/main/config.json#L25

I'll open a pull request and ask the repo owner to make the change, but judging by their recent activity the repo looks pretty much abandoned, so it probably won't be easy.
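To illustrate the failure mode, here is a hypothetical, stripped-down sketch of auto-class dispatch: the Auto* machinery resolves the config class purely from the model_type string, so an unregistered value like "vlm" fails before any weights are considered. The mapping dict below is illustrative, not the real transformers table.

```python
# Toy stand-in for the transformers CONFIG_MAPPING (illustrative, not the real table).
CONFIG_MAPPING = {
    "hyperclovax": "HyperClovaXConfig",    # text-only model
    "hyperclovax_vlm": "HCXVisionConfig",  # multimodal model
}


def resolve_config_class(config_json):
    """Mimic AutoConfig-style dispatch: look up the config class by model_type."""
    model_type = config_json["model_type"]
    try:
        return CONFIG_MAPPING[model_type]
    except KeyError:
        raise KeyError(
            f"Unrecognized model_type {model_type!r}; expected one of {sorted(CONFIG_MAPPING)}"
        )


print(resolve_config_class({"model_type": "hyperclovax_vlm"}))  # HCXVisionConfig
```

With the Hub's current `{"model_type": "vlm"}`, the same lookup raises a KeyError, which is why the auto-classes cannot resolve the architecture.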

Contributor Author

jp1924 commented Mar 19, 2026

@zucchini-nlp
Also, the config file you're looking at is for the text-only model; for the multimodal model, check https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Think-32B/blob/main/configuration_vlm.py. As you can see there, the files are split into text-only and text+image models. Apparently they were separated so the model could be deployed on a vLLM-based server, NAVER-Cloud-HyperCLOVA-X/OmniServe.

The reason I'm working on this PR is that I need to use this model myself, and I'm submitting it to resolve this inconvenience. Please help me! 🙇‍♂

Contributor Author

jp1924 commented Mar 19, 2026

I suppose the cleanest solution would be to add a conversion script like convert_llava_weights_to_hf.py, but couldn't we just agree to add this mapping instead?

@zucchini-nlp (Member)

Okeee, I didn't scroll enough to see the multimodal config, I guess 😓

So to wrap it up: this only affects loading via AutoModelForImageTextToText and doesn't prevent us from loading with HyperClovaXForConditionalGeneration.from_pretrained. We can still use AutoModel in the code to load the vision and LM backbones. Am I getting it right?

If yes, then I suggest opening a PR on the Hub repo and adding a warning on the model doc page saying that the model has to be loaded with HyperClovaXForConditionalGeneration and not an auto-class.

Contributor Author

jp1924 commented Mar 20, 2026

@zucchini-nlp

I opened a PR on the Model Hub repo for the model_type change: https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Think-32B/discussions/11#69bcccb6b66162302dba072e

Best case is the repo maintainer approves/merges quickly, but activity has been low so responses may be slow. Therefore, for now I think we should proceed with the transformers PR while keeping model_type="vlm" on the Hub to avoid delays.

Loading tests with the current code show:

  • AutoClass APIs (e.g. AutoModel, AutoModelForImageTextToText) still do not work correctly (dimension mismatches or silent failures).
  • However, loading with our custom classes works:
from transformers import (
    HCXVisionForConditionalGeneration,
    HyperClovaXForCausalLM,
    HyperClovaXConfig,
    HCXVisionConfig,
)

HCXVisionForConditionalGeneration.from_pretrained("naver-hyperclovax/HyperCLOVAX-SEED-Think-32B")
HyperClovaXForCausalLM.from_pretrained("naver-hyperclovax/HyperCLOVAX-SEED-Think-32B")
HyperClovaXConfig.from_pretrained("naver-hyperclovax/HyperCLOVAX-SEED-Think-32B")
HCXVisionConfig.from_pretrained("naver-hyperclovax/HyperCLOVAX-SEED-Think-32B")

I recommend adding a strong warning in the model docstring / model card to reduce user confusion. Suggested wording:

Important Note on Loading
The current Hub config sets model_type to "vlm". Using AutoClass loaders (e.g., AutoModel, AutoModelForImageTextToText) may load an incorrect architecture. To be safe, please call HCXVisionForConditionalGeneration.from_pretrained() or HyperClovaXForCausalLM.from_pretrained() directly. (Once the Hub model_type is updated to "hyperclovax_vlm", AutoClass loading should work and we will remove/update this notice.)

When the Hub change is applied we can remove or update this warning — this keeps user confusion to a minimum in the meantime.

Also, regarding comment [#20]: I updated modular_hyperclovax_vision.py.
I tried to inherit from the VideoLlama family and keep HCXVisionModel minimal, but differences in how vision features are handled required reimplementing forward, get_image_features, and get_video_features specifically for HCX.

Main differences:

  • HCX: immediately after mm_projection we use the plain tensor and directly concatenate it into the language model input. We do not use post-processing/meta-info like video_merge_sizes or video_grid_thw.
  • VideoLlama: after extracting vision features, it performs post-processing (merge/grid reshape, etc.) and often returns a tuple; forward then unpacks that tuple and performs concat/processing.

Consequences:

  • Using VideoLlama's get_*_features as-is causes tuple vs plain tensor mismatches and raises errors.
  • The VideoLlama forward concat logic expects video_merge_sizes; without it shape/assertion checks fail.

So I overrode feature extraction and forward logic to implement HCX's specific behavior.
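A toy sketch of the "plain tensor" path described above: projected vision features are scattered directly into the text embeddings at the image-placeholder positions, with no grid/merge post-processing step. Shapes and the placeholder token id are made up for illustration; this is not the actual HCX code.

```python
import torch

hidden = 8
image_token_id = 99  # made-up placeholder id

input_ids = torch.tensor([[1, 99, 99, 2]])  # two image placeholders in the sequence
inputs_embeds = torch.zeros(1, 4, hidden)   # stand-in for embedded text tokens
image_features = torch.ones(2, hidden)      # stand-in for mm_projection output

# Scatter the projected features into the placeholder rows; no merge/grid metadata needed.
mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(mask, image_features)

print(inputs_embeds[0, 1].sum().item())  # 8.0: placeholder row now holds vision features
print(inputs_embeds[0, 0].sum().item())  # 0.0: text rows are untouched
```

The VideoLlama path, by contrast, would first reshape/merge the features using metadata like video_merge_sizes before a step like this, which is why its get_*_features cannot be reused as-is.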

Please review and let me know any feedback or requested changes.

Member

@zucchini-nlp zucchini-nlp left a comment


Oke, the model type thing only matters when users want AutoModel support, and we can just add a warning in the docs. Not a big deal.

I think there are still opportunities to inherit and delete copy-pasted code, e.g. no need to create a whole new LM and TextConfig. Same goes for the processor, as I didn't see any difference in the code. We can instead define the correct auto-mapping. Should work, no?

("vits", "VitsConfig"),
("vivit", "VivitConfig"),
("vjepa2", "VJEPA2Config"),
# ("vlm", "HCXVisionConfig"),

I guess ("hyperclovax_vlm", "HCXVisionConfig")

("qwen2", "Qwen2Config"),
("qwen2_5_omni", "Qwen2_5OmniConfig"),
("qwen2_5_vl", "Qwen2_5_VLConfig"),
("qwen2_5_vl_image", "Qwen2_5_VLVisionConfig"),

nit: can we use qwen2_5_vl_vision, just for consistency?

("hunyuan_v1_dense", "HunYuanDenseV1Config"),
("hunyuan_v1_moe", "HunYuanMoEV1Config"),
("hyperclovax", "HyperClovaXConfig"),
("hyperclovax_vlm", "HCXVisionConfig"),

same here: hyperclovax_vision

("glmasr_encoder", "glmasr"),
("hyperclovax", "hyperclovax_vision"),
("hyperclovax_vlm", "hyperclovax_vision"),
# ("vlm", "hyperclovax_vision"),

Let's delete the commented line. There are more, so I won't comment on each one :)

Comment on lines +46 to +62
    _no_split_modules = ["HyperClovaXDecoderLayer"]
    input_modalities = ("image", "video", "text")
    _can_record_outputs = {"hidden_states": HyperClovaXDecoderLayer, "attentions": HyperClovaXAttention}


@auto_docstring
class HyperClovaXModel(GraniteModel):
    config_class = HyperClovaXConfig
    input_modalities = ("text",)


@auto_docstring
class HyperClovaXForCausalLM(GraniteForCausalLM):
    accepts_loss_kwargs = False
    config_class = HyperClovaXConfig
    input_modalities = ("text",)


The hyperclovax LM is the same as Granite, right? I don't think we need a new LM class for it then.

Comment on lines +25 to +32
    image_processor_class = "Qwen2VLImageProcessorFast"
    video_processor_class = "Qwen2VLVideoProcessor"
    tokenizer_class = (
        "GPT2Tokenizer",
        "GPT2TokenizerFast",
        "PreTrainedTokenizer",
        "PreTrainedTokenizerFast",
    )

shouldn't be needed; instead we only add it in the auto mapping

Comment on lines +56 to +68
    def apply_chat_template(self, conversation, chat_template=None, **kwargs):
        conversation = copy.deepcopy(conversation)

        tokenize = kwargs.get("tokenize", False)
        if not tokenize:
            template = chat_template
            if template is None:
                template = (
                    self.chat_template["default"] if isinstance(self.chat_template, dict) else self.chat_template
                )
            return self.tokenizer.apply_chat_template(conversation, chat_template=template, **kwargs)

        return super().apply_chat_template(conversation, chat_template=chat_template, **kwargs)

same here, I think we can work with nested templates, no? If not, I can check, because it should be supported.
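The fallback that the overridden apply_chat_template implements can be reduced to a small standalone selection rule, which may help decide whether the base class already covers it. The function name and template strings below are illustrative, not the transformers API:

```python
def select_chat_template(chat_template, explicit=None):
    """Toy version of the fallback logic: an explicitly passed template wins;
    otherwise a dict of named templates falls back to its "default" entry."""
    if explicit is not None:
        return explicit
    if isinstance(chat_template, dict):
        return chat_template["default"]
    return chat_template


templates = {"default": "{{ messages }}", "tool_use": "{{ tools }}"}
print(select_chat_template(templates))            # picks the "default" entry
print(select_chat_template(templates, "custom"))  # an explicit template wins
```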

@@ -0,0 +1,225 @@
import copy

the class is the same as https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/processing_qwen2_vl.py. If yes, we can delete the whole file and just map hyperclovax to load a Qwen2VLProcessor in the auto mappings.

If not, we need to move this into modular and let it copy the identical code blocks.

Comment on lines +417 to +421
"vlm": [
WeightRenaming("mm_projector", "projector"),
WeightRenaming("language_model.model", "language_model"),
WeightRenaming("language_model.lm_head", "lm_head"),
],

hmm, no, we need to delete this one. Also we won't have a model_type = vlm, so it won't be a problem.
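For context, checkpoint-key renaming rules like the WeightRenaming entries in the diff above can be sketched as a simple prefix rewrite. The helper below is an illustrative stand-in, not the real WeightRenaming API:

```python
# Illustrative rename rules mirroring the diff above (first matching prefix wins).
RENAMES = [
    ("mm_projector", "projector"),
    ("language_model.model", "language_model"),
    ("language_model.lm_head", "lm_head"),
]


def rename_key(key):
    """Rewrite a checkpoint key by its first matching prefix, else return it unchanged."""
    for old, new in RENAMES:
        if key.startswith(old):
            return new + key[len(old):]
    return key


state_dict = {"mm_projector.weight": 0, "language_model.model.layers.0.bias": 1}
converted = {rename_key(k): v for k, v in state_dict.items()}
print(sorted(converted))  # ['language_model.layers.0.bias', 'projector.weight']
```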

Comment on lines -523 to +542
VLMS = ["detr"]
VLMS = [
"aria",
"ayavision",
"colpali",
"emu3",
"fuyu",

this was deleted on main, no need to revert :)

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, hyperclovax_vision, qwen2_5_vl

@zucchini-nlp (Member)

I want to link another very related PR here. A contributor wants to add the LM backbone, and I realized that the text-only LM actually applies norms, unlike the VLM. So IMO the two PRs will be merged/reviewed in sequence:

#44957
