Skip to content

feat: support audio input for minicpm-o-4_5#9147

Merged
Jintao-Huang merged 3 commits into
modelscope:mainfrom
fanqiNO1:minicpmo45_audio
Apr 20, 2026
Merged

feat: support audio input for minicpm-o-4_5#9147
Jintao-Huang merged 3 commits into
modelscope:mainfrom
fanqiNO1:minicpmo45_audio

Conversation

@fanqiNO1
Copy link
Copy Markdown
Contributor

@fanqiNO1 fanqiNO1 commented Apr 19, 2026

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

Support audio input for MiniCPM-O-4_5 with the help of Qoder, which reads modeling_minicpmo.py and processing_minicpmo.py and implements the MiniCPMO4_5Template.

Experiment results

I verified that the implementation is consistent with that of the officially provided script. The script is as follows:

image
script
import json
from copy import deepcopy

import librosa
import numpy as np
import torch
from loguru import logger
from minicpmo.utils import get_video_frame_audio_segments
from PIL import Image
from swift.model import get_processor
from swift.template import get_template
from transformers import AutoProcessor


# Test assets and shared prompt for the official-vs-swift parity check below.
video_path = "MiniCPM-O-4_5/assets/Skiing.mp4"
audio_path = "MiniCPM-O-4_5/assets/bajie.wav"  # loaded at 16 kHz mono below
user_prompt = "Please describe the video content and the audio content."


def minicpmo4_5_official_inputs():
    """Build model inputs through the official MiniCPM-O-4_5 preprocessing path.

    Mirrors the logic of ``modeling_minicpmo.MiniCPMO.chat``: video frames and
    the raw audio waveform are replaced by placeholder tags, rendered through
    the tokenizer's chat template, and passed to the processor.

    Returns:
        Processor output (``input_ids``, ``pixel_values``, ...) with
        ``image_sizes`` removed, for comparison against the swift pipeline.
    """
    video_frames, _, _ = get_video_frame_audio_segments(video_path)
    audio_input, _ = librosa.load(audio_path, sr=16000, mono=True)
    msgs = [
        {"role": "user", "content": video_frames + [user_prompt, audio_input, user_prompt]}
    ]

    processor = AutoProcessor.from_pretrained("OpenBMB/MiniCPM-O-4_5", trust_remote_code=True)
    tokenizer = processor.tokenizer
    # modified from modeling_minicpmo.MiniCPMO.chat

    batched = isinstance(msgs[0], list)
    msgs_list = msgs
    images_list = None

    if not batched:
        images_list, msgs_list = [images_list], [msgs_list]
    else:
        assert images_list is None, "Please integrate image to msgs when using batch inference."
        images_list = [None] * len(msgs_list)
    assert len(images_list) == len(msgs_list), "The batch dim of images_list and msgs_list should be the same."

    prompts_lists = []
    input_images_list = []
    input_audios_list = []
    audio_parts_list = []

    for image, msgs in zip(images_list, msgs_list):
        if isinstance(msgs, str):
            msgs = json.loads(msgs)
        copy_msgs = deepcopy(msgs)

        assert len(msgs) > 0, "msgs is empty"

        if image is not None and isinstance(copy_msgs[0]["content"], str):
            copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]

        images = []
        audios = []
        audio_parts = []
        for i, msg in enumerate(copy_msgs):
            role = msg["role"]
            content = msg["content"]
            assert role in ["system", "user", "assistant"]
            if i == 0:
                # NOTE(review): the assertion message omits the allowed "system" role.
                assert role in ["user", "system"], "The role of first msg should be user"
            # Normalize structured content (OpenAI format) to native format
            # content = normalize_content(content)
            cur_msgs = []
            for c in content:
                if isinstance(c, Image.Image):
                    images.append(c)
                    cur_msgs.append("<image>./</image>")
                elif isinstance(c, np.ndarray):  # audio
                    audios.append(c)
                    audio_parts.append(i)  # record which message the audio belongs to
                    cur_msgs.append("<audio>./</audio>")
                elif isinstance(c, str):
                    cur_msgs.append(c)

            msg["content"] = "\n".join(cur_msgs)

        prompts_lists.append(
            processor.tokenizer.apply_chat_template(
                copy_msgs,
                tokenize=False,
                add_generation_prompt=True,
                use_tts_template=False,  # modified to False
                enable_thinking=False,
            )
        )
        input_images_list.append(images)
        input_audios_list.append(audios)
        audio_parts_list.append(audio_parts)

    inputs = processor(
        prompts_lists,
        input_images_list,
        input_audios_list,
        audio_parts_list,
        max_slice_nums=1,
        use_image_id=False,
        stream_input=False,
        return_tensors="pt",
        max_length=8192,
    )
    # image_sizes is not produced by the swift pipeline, so drop it before comparison.
    inputs.pop("image_sizes")
    # input_string = processor.tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=False)

    def safe_decode(input_ids, tokenizer):
        """Decode token ids one at a time, collapsing runs that decode to "<unk>".

        A run of consecutive "<unk>" tokens is rendered as "[id * count]",
        where ``id`` is the most recent token id of the run.
        """
        tokens = []
        count = 0
        current_token_id = None
        for input_id in input_ids:
            result = tokenizer.decode(input_id, skip_special_tokens=False)
            if result == "<unk>":
                count += 1
                current_token_id = input_id
            else:
                if count > 0:
                    tokens.append(f"[{current_token_id} * {count}]")
                count = 0
                current_token_id = None
                tokens.append(result)
        if count > 0:  # flush a trailing "<unk>" run
            tokens.append(f"[{current_token_id} * {count}]")
        return "".join(tokens)

    input_string = safe_decode(inputs["input_ids"][0], tokenizer)
    logger.info("Official input_ids length: {}", len(inputs["input_ids"][0]))
    logger.info("Official Decoded input string: {}", input_string)
    return inputs


def minicpmo4_5_swift_inputs():
    """Produce MiniCPM-O-4_5 model inputs through the swift template pipeline."""
    proc = get_processor("OpenBMB/MiniCPM-O-4_5")
    tmpl = get_template(proc, enable_thinking=False)

    # Placeholder tags (<video>, <audio>) are expanded by the template encoder.
    request = {
        "messages": [
            {"role": "user", "content": f"<video>{user_prompt}\n<audio>\n{user_prompt}"}
        ],
        "videos": [video_path],
        "audios": [audio_path],
    }

    encoded = tmpl.encode(request)
    decoded = tmpl.safe_decode(encoded["input_ids"])
    logger.info("Swift input_ids length: {}", len(encoded["input_ids"]))
    logger.info("Swift Decoded input string: {}", decoded)

    return encoded


def is_equal(value1, value2, name):
    """Recursively compare two values (nested lists, tensors, or scalars).

    Args:
        value1: First value.
        value2: Second value.
        name: Label used in warning messages when tensors differ.

    Returns:
        bool: True when both values are structurally and element-wise equal.
    """
    if isinstance(value1, list) and isinstance(value2, list):
        if len(value1) != len(value2):
            return False
        for v1, v2 in zip(value1, value2):
            if not is_equal(v1, v2, name):
                return False
        return True
    elif isinstance(value1, torch.Tensor) and isinstance(value2, torch.Tensor):
        if value1.shape != value2.shape:
            logger.warning(f"Tensor {name} shapes differ: {value1.shape} vs {value2.shape}")
            return False
        if value1.dtype != value2.dtype:
            logger.warning(f"Tensor {name} dtypes differ: {value1.dtype} vs {value2.dtype}")
            return False
        if not torch.equal(value1, value2):
            logger.warning(f"Tensor {name} values differ at some positions.")
            return False
        return True
    elif isinstance(value1, torch.Tensor) != isinstance(value2, torch.Tensor):
        # Mixed tensor/non-tensor: `==` would broadcast to a bool tensor whose
        # truth value is ambiguous at the recursive call site — treat as unequal.
        return False
    else:
        # Coerce to a plain bool so callers can safely use `if not is_equal(...)`.
        return bool(value1 == value2)


if __name__ == "__main__":
    # Build inputs through both pipelines, then compare them key by key.
    official_inputs = minicpmo4_5_official_inputs()
    swift_inputs = minicpmo4_5_swift_inputs()

    for k in official_inputs.keys():
        if k not in swift_inputs:
            logger.warning(f"Key {k} not in swift inputs")
            continue

        official_v, swift_v = official_inputs[k], swift_inputs[k]

        if k == "input_ids":
            # The official path is batched; the swift path is a flat id list.
            official_v = official_v[0].tolist()
        elif k == "pixel_values":
            # Cast in place so dtypes match the swift output before comparing.
            for i, tensor in enumerate(official_v[0]):
                official_v[0][i] = tensor.to(torch.bfloat16)

        if is_equal(official_v, swift_v, name=k):
            logger.info(f"Value for key {k} is the same between official and swift inputs")
        else:
            logger.warning(f"Value for key {k} is different between official and swift inputs")

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the MiniCPMO4_5Template class, which extends MiniCPMV4_5Template to support audio and video processing. Key additions include audio feature extraction using Whisper, interleaved image and audio placeholder handling for video inputs, and specialized encoding and data collation logic for multimodal data. The review feedback identifies that temporal_ids are missing from both the _encode output and the _data_collator gathering process, which are necessary for the model's vision processing.

'loss_scale': loss_scale,
'image_bound': image_bound,
'pixel_values': image_inputs['pixel_values'],
'tgt_sizes': image_inputs['tgt_sizes'],
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The encoded dictionary is missing the temporal_ids key, which is present in the parent MiniCPMV4_5Template._encode implementation. This key is required for correct video and image slice processing in the model.

Suggested change
'tgt_sizes': image_inputs['tgt_sizes'],
'tgt_sizes': image_inputs['tgt_sizes'],
'temporal_ids': image_inputs.get('temporal_ids'),

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, there is no temporal_ids for MiniCPM-O-4_5.

image

Comment on lines +559 to +560
for k in ['pixel_values', 'image_bound', 'tgt_sizes']:
res[k] = self.gather_list(batch, k)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The _data_collator should also gather temporal_ids from the batch, as they are part of the vision data for this model architecture. This ensures consistency with the parent MiniCPMV4_5Template implementation.

Suggested change
for k in ['pixel_values', 'image_bound', 'tgt_sizes']:
res[k] = self.gather_list(batch, k)
# Vision data
for k in ['pixel_values', 'image_bound', 'tgt_sizes', 'temporal_ids']:
res[k] = self.gather_list(batch, k)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

@Jintao-Huang
Copy link
Copy Markdown
Collaborator

Thanks for the PR. Please take a look at Gemini's review suggestions first and see if any code changes are needed.
I will review it as soon as possible.

inputs.audios[index] = load_audio(inputs.audios[index], sampling_rate=self.SAMPLING_RATE)
return ['<|audio_start|><|audio_end|>']
elif media_type == 'video':
from minicpmo.utils import get_video_frame_audio_segments
Copy link
Copy Markdown
Contributor Author

@fanqiNO1 fanqiNO1 Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是直接 pip install minicpmo-utils 安装的,文档已补充

@Jintao-Huang
Copy link
Copy Markdown
Collaborator

thanks! LGTM

@Jintao-Huang Jintao-Huang merged commit 2261023 into modelscope:main Apr 20, 2026
3 checks passed
@Jintao-Huang Jintao-Huang mentioned this pull request May 6, 2026
1 task
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants