[Feature] Add Molmo2 support (image + video inference, LoRA SFT)#9063

Merged
Jintao-Huang merged 8 commits into modelscope:main from Kagura-0001:codex/add-molmo2-support-pr on Apr 19, 2026

Conversation

@Kagura-0001 (Contributor) commented Apr 10, 2026

Summary

Add support for Molmo2 (4B/8B/O-7B) multimodal models with image and video understanding capabilities.

Changes (6 files, +153/-395):

  • swift/template/templates/molmo.py (+117) — Add Molmo2Template class
  • swift/model/models/mllm.py (+36) — Add Molmo2Loader class + register_model
  • swift/template/templates/__init__.py — Remove molmo2 import (merged into molmo)
  • swift/model/models/__init__.py — Remove molmo2 import (merged into mllm)
  • swift/template/templates/molmo2.py — DELETED, merged into molmo.py
  • swift/model/models/molmo2.py — DELETED, merged into mllm.py

Architecture

Template (Molmo2Template):

  • ChatML format via ChatmlTemplateMeta (auto_add_bos=True)
  • replace_tag → <|image|> / <|video|> placeholders
  • _prepare_mm_inputs → processor.image_processor for images, processor.video_processor for videos
  • _encode → standard super()._encode() + _extend_tokens pattern
  • _build_token_type_ids → marks image tokens as 1
  • _data_collator_mm_data → concat_tensor for image_grids, video_grids, image_token_pooling
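The _extend_tokens pattern above can be sketched in plain Python. This is an illustrative stand-in, not the actual ms-swift helpers: findall, the placeholder id 9, and the patch-token ids 100/101 are all assumptions made up for the example.

```python
from typing import List

def findall(token_ids: List[int], target: int) -> List[int]:
    """Indices at which a single placeholder token id occurs."""
    return [i for i, t in enumerate(token_ids) if t == target]

def extend_tokens(input_ids: List[int], placeholder: int,
                  expansion: List[int]) -> List[int]:
    """Replace each placeholder token with its expanded media-token run."""
    out: List[int] = []
    for t in input_ids:
        out.extend(expansion if t == placeholder else [t])
    return out

# 9 stands in for the '<|image|>' placeholder id; 100/101 for patch tokens.
expanded = extend_tokens([1, 9, 2], placeholder=9, expansion=[100, 101, 100])
```

After expansion the single placeholder is replaced in place by the full media-token run, which is what makes the label/loss_scale alignment in the base _encode work.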

Model (Molmo2Loader):

  • AutoModelForImageTextToText as auto_model_cls
  • get_class_from_dynamic_module for trust_remote_code (Molmo2 not yet upstream in transformers)
  • _no_split_modules += ["MolmoSequentialBlock"]
  • patch_output_clone on wte embedding
  • Requires: transformers>=4.57.1,<5

Supported Models:

  • allenai/Molmo2-4B (LLM-Research/Molmo2-4B)
  • allenai/Molmo2-8B (LLM-Research/Molmo2-8B)
  • allenai/Molmo2-O-7B (LLM-Research/Molmo2-O-7B)

Tests

All on H100 80GB, transformers==4.57.6, torch==2.8.0+cu128:

  • Image inference (TransformersEngine): PASS (correct image description)
  • Video inference (TransformersEngine): PASS (correct video description)
  • LoRA SFT training: PASS (4 steps, loss 2.0→1.76, 14.3M trainable params (0.29%), 22.96 GiB VRAM)

@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces support for the Molmo2 model family, including registration of 4B, 8B, and O-7B variants, along with a dedicated template for image and video understanding. Key additions include the Molmo2Loader with compatibility patches for transformers and vision attention, and the Molmo2Template for handling multi-modal inputs. Feedback focuses on correcting a version requirement typo, improving the robustness of module splitting logic, preventing potential division-by-zero errors in FPS calculation, and replacing assertions with explicit value errors for input validation.

Comment thread swift/model/models/molmo2.py Outdated
model_arch=ModelArch.molmo,
architectures=['Molmo2ForConditionalGeneration'],
tags=['vision', 'video'],
requires=['transformers>=4.57.1', 'decord'],
Contributor

high

The version requirement transformers>=4.57.1 appears to be a typo, as this version does not exist yet (the current stable version is around 4.48). Molmo models typically require transformers>=4.45.0.

Suggested change
requires=['transformers>=4.57.1', 'decord'],
requires=['transformers>=4.45.0', 'decord'],

Contributor Author

I kept transformers>=4.57.1 here. 4.57.1 is a real released version now, and the local Molmo2 smoke validation for this PR succeeded with transformers==4.57.3. I would prefer to keep the newer minimum for the current Molmo2 processor/runtime path rather than relax it to 4.45.0 without additional compatibility coverage.

Comment thread swift/model/models/molmo2.py Outdated
def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
from transformers import AutoModelForImageTextToText
model_cls = get_class_from_dynamic_module('modeling_molmo2.Molmo2ForConditionalGeneration', model_dir)
model_cls._no_split_modules = getattr(model_cls, '_no_split_modules', []) or ['MolmoSequentialBlock']
Contributor

medium

The current logic for setting _no_split_modules will overwrite the default list if it's empty, but it won't append to it if it already contains other modules. It's safer to ensure MolmoSequentialBlock is included in the list without discarding existing entries.

Suggested change
model_cls._no_split_modules = getattr(model_cls, '_no_split_modules', []) or ['MolmoSequentialBlock']
no_split_modules = getattr(model_cls, '_no_split_modules', []) or []
if 'MolmoSequentialBlock' not in no_split_modules:
    model_cls._no_split_modules = no_split_modules + ['MolmoSequentialBlock']

Contributor Author

Updated in commit 6eaf502: I now preserve existing _no_split_modules entries and only append MolmoSequentialBlock when it is missing.
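A minimal sketch of that append-if-missing pattern; ensure_no_split_module and FakeModel are hypothetical stand-ins for the fix and for the dynamically loaded model class, not the actual commit contents.

```python
def ensure_no_split_module(model_cls, name: str = 'MolmoSequentialBlock') -> None:
    """Append `name` to _no_split_modules without discarding existing entries."""
    modules = list(getattr(model_cls, '_no_split_modules', []) or [])
    if name not in modules:
        modules.append(name)
    model_cls._no_split_modules = modules

class FakeModel:  # hypothetical stand-in for Molmo2ForConditionalGeneration
    _no_split_modules = ['SomeOtherBlock']

ensure_no_split_module(FakeModel)
ensure_no_split_module(FakeModel)  # idempotent: no duplicate entry added
```

Copying into a fresh list before mutating also avoids accidentally mutating a class-level default shared with a parent class.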

Comment thread swift/template/templates/molmo2.py Outdated
if len(timestamps) <= 1:
    fps = 1.0
else:
    fps = 1.0 / float(np.median(np.diff(timestamps)))
Contributor

medium

Calculating fps by dividing by the median of timestamp differences can lead to a ZeroDivisionError if the input timestamps are identical or if the median difference is zero. It's safer to add a check or a fallback value.

Suggested change
fps = 1.0 / float(np.median(np.diff(timestamps)))
median_diff = np.median(np.diff(timestamps))
fps = 1.0 / float(median_diff) if median_diff > 0 else 1.0

Contributor Author

Updated in commit 6eaf502: I added a guard for zero/non-positive median timestamp deltas and fall back to fps = 1.0 in that case.
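The guarded computation can be sketched without numpy; statistics.median stands in for np.median here, and safe_fps is an illustrative helper name, not the code from the commit.

```python
from statistics import median
from typing import List

def safe_fps(timestamps: List[float]) -> float:
    """fps from median frame spacing, falling back to 1.0 on degenerate input."""
    if len(timestamps) <= 1:
        return 1.0
    diffs = [b - a for a, b in zip(timestamps, timestamps[1:])]
    median_diff = median(diffs)
    # Guard against identical/duplicated timestamps (median diff of zero).
    return 1.0 / float(median_diff) if median_diff > 0 else 1.0
```

Frames spaced 0.5 s apart yield 2.0 fps, while identical or single timestamps take the 1.0 fallback instead of raising ZeroDivisionError.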

Comment thread swift/template/templates/molmo2.py Outdated
media_inputs.update(image_inputs)

if inputs.videos:
    assert len(inputs.videos) == 1, 'Molmo2 currently only supports single-video samples.'
Contributor

medium

Using assert for input validation is discouraged as it can be optimized away in production (when running with -O). It is better to raise a ValueError to provide a clear error message to the user.

Suggested change
assert len(inputs.videos) == 1, 'Molmo2 currently only supports single-video samples.'
if len(inputs.videos) != 1:
    raise ValueError('Molmo2 currently only supports single-video samples.')

Contributor Author

Updated in commit 6eaf502: I replaced the assert with an explicit ValueError so the validation is preserved in optimized runs as well.

@Tohrusky (Contributor)

/gemini review

@liandanlao

I cannot wait to try it.

@Kagura-0001 force-pushed the codex/add-molmo2-support-pr branch from 6eaf502 to af896cb on April 16, 2026 04:05
…ransformers<5

Changes:
- Merge Molmo2Template into swift/template/templates/molmo.py (was standalone molmo2.py)
- Merge Molmo2Loader into swift/model/models/mllm.py (was standalone molmo2.py)
- Delete standalone swift/template/templates/molmo2.py and swift/model/models/molmo2.py
- Remove __init__.py imports of deleted molmo2 modules
- Remove all transformers>=5.0 compatibility patches:
  - ProcessorMixin.__init__ monkey-patch for unknown kwargs
  - ROPE_INIT_FUNCTIONS['default'] injection
  - prepare_inputs_for_generation cache_position fix
- Pin version to transformers>=4.57.1,<5 (verified working)
- Simplify Molmo2Loader from ~80 lines to ~10 lines

Template (Molmo2Template):
- ChatML format via ChatmlTemplateMeta with auto_add_bos=True
- replace_tag returns '<|image|>' / '<|video|>' placeholders
- _prepare_mm_inputs: processor.image_processor for images,
  processor.video_processor for videos (with return_metadata=True)
- _encode: super()._encode() + _extend_tokens to expand placeholders
- _build_token_type_ids: marks image_token_ids as 1
- _data_collator_mm_data: handles image_grids, video_grids, etc.

Model (Molmo2Loader):
- AutoModelForImageTextToText as auto_model_cls
- get_class_from_dynamic_module for trust_remote_code
- _no_split_modules += MolmoSequentialBlock
- patch_output_clone on wte embedding

Tested on 4x H100 with transformers==4.57.6, torch==2.8.0:
- Image inference (TransformersEngine): PASS
- Video inference (TransformersEngine): PASS
- LoRA SFT training (4 steps, rank=8): PASS
@Kagura-0001 Kagura-0001 changed the title [Feature] Add Molmo2 model and template support [Feature] Add Molmo2 support (image + video inference, LoRA SFT) Apr 16, 2026
@maxLWS commented Apr 16, 2026

@gemini-code-assist review

@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces support for the Molmo2 model, including model registration, template implementation for image and video understanding, and unit tests. The changes involve adding molmo2 constants, a new Molmo2Loader, and a Molmo2Template class. Feedback highlights critical issues with placeholder expansion where tokenizer.encode is used instead of convert_tokens_to_ids, potential shape mismatches for token_type_ids during truncation, and redundant image loading logic in the template.

Comment thread swift/template/templates/molmo.py Outdated
Comment on lines +146 to +147
image_placeholder = self.tokenizer.encode('<|image|>', add_special_tokens=False)
idx_list = findall(input_ids, image_placeholder)
Contributor

high

The tokenizer.encode method returns a list of token IDs. Since findall expects a single integer as the second argument (as seen in the MolmoTemplate implementation at line 29), passing a list will cause idx_list to be empty, and the image placeholders will not be expanded. You should use convert_tokens_to_ids to get the single token ID for <|image|>. This is a critical correctness issue for multimodal inference.

Suggested change
image_placeholder = self.tokenizer.encode('<|image|>', add_special_tokens=False)
idx_list = findall(input_ids, image_placeholder)
image_placeholder = self.tokenizer.convert_tokens_to_ids('<|image|>')
idx_list = findall(input_ids, image_placeholder)
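A toy reproduction of the mismatch, assuming (as the review states) that findall compares each element against a single int. The tokenizer, its vocab, and the token ids are mocks made up for the example, not the real Molmo2 tokenizer.

```python
from typing import List

class MockTokenizer:
    """Mock: encode returns a list of ids, convert_tokens_to_ids a single int."""
    vocab = {'<|image|>': 9}

    def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
        return [self.vocab[text]]

    def convert_tokens_to_ids(self, token: str) -> int:
        return self.vocab[token]

def findall(token_ids: List[int], target) -> List[int]:
    # int-based comparison, mirroring the usage the review refers to
    return [i for i, t in enumerate(token_ids) if t == target]

tok = MockTokenizer()
input_ids = [1, 9, 2]
found_with_encode = findall(input_ids, tok.encode('<|image|>'))
found_with_convert = findall(input_ids, tok.convert_tokens_to_ids('<|image|>'))
```

The list [9] never compares equal to any int element, so the encode-based lookup silently finds nothing, which is exactly why the placeholder expansion would be skipped.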

Comment thread swift/template/templates/molmo.py Outdated
Comment on lines +153 to +154
video_placeholder = self.tokenizer.encode('<|video|>', add_special_tokens=False)
idx_list = findall(input_ids, video_placeholder)
Contributor

high

Similar to the image placeholder expansion, tokenizer.encode returns a list which is incompatible with findall expecting an integer. Use convert_tokens_to_ids for the <|video|> token to ensure the placeholder indices are correctly identified.

Suggested change
video_placeholder = self.tokenizer.encode('<|video|>', add_special_tokens=False)
idx_list = findall(input_ids, video_placeholder)
video_placeholder = self.tokenizer.convert_tokens_to_ids('<|video|>')
idx_list = findall(input_ids, video_placeholder)

encoded['input_ids'] = input_ids
encoded['labels'] = labels
encoded['loss_scale'] = loss_scale
encoded['token_type_ids'] = self._build_token_type_ids(input_ids)
Contributor

high

Adding token_type_ids here may lead to a shape mismatch if truncation occurs. The Template._encode_truncated method in base.py only truncates input_ids, labels, and loss_scale. Any additional sequence-length keys added in _encode will retain their original length in the encoded dictionary, causing errors during batch collation or model forward pass when the sequence exceeds max_length. Consider calculating token_type_ids in the data collator or ensuring it is also truncated.
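One way to address this, sketched under the assumption that every sequence-aligned key can simply be cut to the same length; the helper name and the key tuple beyond input_ids/labels/loss_scale are illustrative, not the actual base.py implementation.

```python
from typing import Any, Dict

SEQ_KEYS = ('input_ids', 'labels', 'loss_scale', 'token_type_ids')

def truncate_encoded(encoded: Dict[str, Any], max_length: int) -> Dict[str, Any]:
    """Truncate every sequence-length-aligned key together, keeping shapes in sync."""
    for key in SEQ_KEYS:
        value = encoded.get(key)
        if value is not None:
            encoded[key] = value[:max_length]
    return encoded

batch = {'input_ids': [1, 2, 3, 4], 'labels': [-100, 2, 3, 4],
         'token_type_ids': [0, 1, 1, 0], 'loss_scale': None}
truncate_encoded(batch, max_length=3)
```

Truncating all keys in one place keeps token_type_ids the same length as input_ids even when the sample exceeds max_length.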

Comment thread swift/template/templates/molmo.py Outdated
Comment on lines +106 to +107
images = [img if isinstance(img, Image.Image) else Image.open(img).convert('RGB')
          for img in inputs.images]
Contributor

medium

The inputs.images list is already pre-processed into PIL.Image objects by Template._preprocess_inputs (which calls _load_image) before _encode is executed. Since load_images defaults to True in the Template class, inputs.images should already contain loaded images. This explicit check and Image.open call are redundant.

            images = inputs.images

@Kagura-0001 Kagura-0001 marked this pull request as ready for review April 16, 2026 11:37
@Kagura-0001 (Contributor Author)

@Tohrusky (Contributor) commented Apr 16, 2026

@hjh0119

Comment thread swift/template/templates/molmo.py Outdated
return ['<|image|>']
if media_type == 'video':
    return ['<|video|>']
return []
Collaborator

Use if / elif / else, and raise a ValueError for unexpected media types.

@hjh0119 (Collaborator) commented Apr 16, 2026

Thanks for your PR :)

Please pass the lint test.

@hjh0119 hjh0119 requested a review from Jintao-Huang April 16, 2026 14:30
Comment thread swift/model/models/mllm.py Outdated
MLLMModelType.molmo2,
[
    ModelGroup([
        Model('LLM-Research/Molmo2-4B', 'allenai/Molmo2-4B'),
Collaborator

LLM-Research/Molmo2-4B: this model cannot be found.

@Jintao-Huang (Collaborator)

/gemini review

@gemini-code-assist (Bot) left a comment

Code Review

This pull request adds support for the Molmo2 model family, including the 4B, 8B, and O-7B variants. It implements the Molmo2Loader and Molmo2Template to handle image and video processing within the ChatML format. Feedback points out that video_metadata is missing from the media inputs and suggests including video tokens in the token_type_ids generation.

Comment on lines +121 to +128
video_metadata = video_inputs.pop('video_metadata')
for video_grid, metadata in zip(video_inputs['video_grids'], video_metadata):
    video_string = self.processor.get_video_string(
        video_grid.cpu().numpy(),
        np.asarray(metadata.timestamps, dtype=np.float32),
    )
    video_expansions.append(tokenizer.encode(video_string, add_special_tokens=False))
media_inputs.update(video_inputs)
Contributor

high

The video_metadata is popped from video_inputs but never added to media_inputs. This will cause video_metadata to be missing from the encoded batch, which might lead to issues if the model or collator expects it. Although the collator attempts to gather it, it won't find anything if it's not in the batch.

Suggested change
video_metadata = video_inputs.pop('video_metadata')
for video_grid, metadata in zip(video_inputs['video_grids'], video_metadata):
    video_string = self.processor.get_video_string(
        video_grid.cpu().numpy(),
        np.asarray(metadata.timestamps, dtype=np.float32),
    )
    video_expansions.append(tokenizer.encode(video_string, add_special_tokens=False))
media_inputs.update(video_inputs)
video_metadata = video_inputs.pop('video_metadata')
for video_grid, metadata in zip(video_inputs['video_grids'], video_metadata):
    video_string = self.processor.get_video_string(
        video_grid.cpu().numpy(),
        np.asarray(metadata.timestamps, dtype=np.float32),
    )
    video_expansions.append(tokenizer.encode(video_string, add_special_tokens=False))
media_inputs.update(video_inputs)
media_inputs['video_metadata'] = video_metadata

Contributor Author

Investigated the official Molmo2 source: processing_molmo2.py consumes video_metadata inside the processor to build the timestamped video_string, but modeling_molmo2.py never accepts or uses video_metadata as a model input. So video_metadata is processor-side metadata, not something forwarded into the model. Our template follows the same behavior. I removed the dead collator code that was trying to gather/pass it (commit 2cbe18382).

Comment on lines +132 to +134
def _build_token_type_ids(self, input_ids: List[int]) -> List[int]:
    image_token_ids = {int(token_id) for token_id in getattr(self.processor, 'image_token_ids', [])}
    return [1 if token_id in image_token_ids else 0 for token_id in input_ids]
Contributor

medium

The _build_token_type_ids method currently only marks image tokens as 1. For multimodal models supporting video, it is often necessary to also mark video tokens (e.g., as 1 or 2 depending on the model's requirements). Please check if video tokens should also be included in the token_type_ids mask.

Suggested change
def _build_token_type_ids(self, input_ids: List[int]) -> List[int]:
    image_token_ids = {int(token_id) for token_id in getattr(self.processor, 'image_token_ids', [])}
    return [1 if token_id in image_token_ids else 0 for token_id in input_ids]
def _build_token_type_ids(self, input_ids: List[int]) -> List[int]:
    image_token_ids = {int(token_id) for token_id in getattr(self.processor, 'image_token_ids', [])}
    video_token_ids = {int(token_id) for token_id in getattr(self.processor, 'video_token_ids', [])}
    media_token_ids = image_token_ids | video_token_ids
    return [1 if token_id in media_token_ids else 0 for token_id in input_ids]

Contributor Author

Investigated the official Molmo2 source: the processor does not define a separate video_token_ids. In processing_molmo2.py, token_type_ids is built by checking membership in self.image_token_ids, and IMAGE_TOKENS already includes the video-related frame markers (FRAME_START_TOKEN, FRAME_END_TOKEN) as well as the patch/column tokens used in video strings. So our _build_token_type_ids is correct: checking image_token_ids already covers both image and video multimodal tokens.
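The mechanism in that reply can be illustrated with a small sketch; the token ids below are made up for the example, but the point matches the reply: a single image_token_ids set that already contains the frame markers covers video tokens too, so no separate video set is needed.

```python
from typing import Iterable, List

def build_token_type_ids(input_ids: List[int],
                         image_token_ids: Iterable[int]) -> List[int]:
    """Mark every multimodal token (image patches and video frame markers) as 1."""
    marks = {int(t) for t in image_token_ids}
    return [1 if t in marks else 0 for t in input_ids]

# Hypothetical ids: 100/101 = image patch tokens, 200/201 = frame start/end.
IMAGE_TOKEN_IDS = {100, 101, 200, 201}
mask = build_token_type_ids([1, 200, 100, 101, 201, 2], IMAGE_TOKEN_IDS)
```

Because the frame markers are members of the same set, video-only sequences are masked correctly by the image-token check alone.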

@Jintao-Huang (Collaborator)

Thanks, LGTM. Are there any other commits that need to be submitted?

@Kagura-0001 (Contributor Author)

No further commits planned. All review feedback has been addressed. Ready to merge. Thanks!

@Jintao-Huang Jintao-Huang merged commit fb021df into modelscope:main Apr 19, 2026
2 of 3 checks passed


6 participants