diff --git a/swift/llm/template/template/minicpm.py b/swift/llm/template/template/minicpm.py
index 6a0d69d024..5b658f9058 100644
--- a/swift/llm/template/template/minicpm.py
+++ b/swift/llm/template/template/minicpm.py
@@ -1,7 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Callable, Dict, List, Literal, Optional
 
 import torch
 from torch import nn
@@ -12,7 +12,7 @@
 from ..register import TemplateMeta, register_template
 from ..template_inputs import StdTemplateInputs
 from ..utils import Context, Prompt, findall
-from ..vision_utils import load_video_minicpmv_mplug_owl3
+from ..vision_utils import load_video_minicpmv_4_5, load_video_minicpmv_mplug_owl3
 from .llama import Llama3TemplateMeta
 from .qwen import Qwen2_5TemplateMeta, Qwen3Template, QwenTemplateMeta
 from .utils import ChatmlTemplateMeta
@@ -244,6 +244,41 @@ def _get_new_tokens(i):
 
 class MiniCPMV4_5Template(MiniCPMV2_6Template, Qwen3Template):
 
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        assert media_type in {'image', 'video'}
+        max_num_frames = get_env_args('max_num_frames', int, 64)
+        max_num_packing = get_env_args('max_num_packing', int, 3)
+        choose_fps = get_env_args('choose_fps', int, 5)
+        time_scale = get_env_args('time_scale', float, 0.1)
+        load_video = partial(
+            load_video_minicpmv_4_5,
+            max_num_frames=max_num_frames,
+            max_num_packing=max_num_packing,
+            choose_fps=choose_fps,
+            time_scale=time_scale)
+        image_context = super().replace_tag('image', index, inputs)
+        if media_type == 'image':
+            return image_context
+        elif media_type == 'video':
+            return self.replace_video2image(load_video, inputs, lambda i: image_context)
+
+    def replace_video2image(self, load_video_func, inputs, replace_tag: Callable) -> List[Context]:
+        context_list = []
+        if self.mode in {'vllm', 'lmdeploy'}:
+            video = inputs.videos.pop(inputs.video_idx)
+            inputs.video_idx -= 1
+        else:
+            video = inputs.videos[inputs.video_idx]
+        images = inputs.images
+        new_images, temporal_ids = load_video_func(video)  # the 4.5 loader also returns packed temporal ids
+        inputs.images = images[:inputs.image_idx] + new_images + images[inputs.image_idx:]
+        for i in range(len(new_images)):
+            context_list += replace_tag(i)
+        inputs.image_idx += len(new_images)
+        inputs.extra_kwargs['temporal_ids'] = temporal_ids  # consumed by the image processor in _encode
+        return context_list
+
     def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         encoded = Template._encode(self, inputs)
         images = inputs.images
@@ -259,8 +294,11 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         idx_list = findall(input_ids, -100)
 
         image_processor = self.processor.image_processor
-        image_inputs = image_processor([images], return_tensors='pt',
-                                       max_slice_nums=max_slice_nums).to(self.model_info.torch_dtype)
+        processor_kwargs = {}
+        if 'temporal_ids' in inputs.extra_kwargs:
+            processor_kwargs['temporal_ids'] = inputs.extra_kwargs['temporal_ids']
+        image_inputs = image_processor([images], return_tensors='pt', max_slice_nums=max_slice_nums,
+                                       **processor_kwargs).to(self.model_info.torch_dtype)
 
         def _get_new_tokens(i):
             placeholder = image_processor.get_slice_image_placeholder(
diff --git a/swift/llm/template/vision_utils.py b/swift/llm/template/vision_utils.py
index e081a5c315..7284270933 100644
--- a/swift/llm/template/vision_utils.py
+++ b/swift/llm/template/vision_utils.py
@@ -266,6 +266,68 @@ def uniform_sample(_l, _n):
     return frames
 
 
+def load_video_minicpmv_4_5(video: Union[str, bytes],
+                            max_num_frames=64,
+                            max_num_packing=3,
+                            choose_fps=3,
+                            time_scale=0.1,
+                            force_packing=None):
+
+    from decord import VideoReader, cpu  # pip install decord
+    from scipy.spatial import cKDTree
+
+    def uniform_sample(seq, n):
+        gap = len(seq) / n
+        idxs = [int(i * gap + gap / 2) for i in range(n)]
+        return [seq[i] for i in idxs]
+
+    def map_to_nearest_scale(values, scale):
+        tree = cKDTree(np.asarray(scale)[:, None])
+        _, indices = tree.query(np.asarray(values)[:, None])
+        return np.asarray(scale)[indices]
+
+    def group_array(arr, size):
+        return [arr[i:i + size] for i in range(0, len(arr), size)]
+
+    video_io = load_file(video)
+    vr = VideoReader(video_io, ctx=cpu(0))
+    fps = vr.get_avg_fps()
+    video_duration = len(vr) / fps
+
+    if choose_fps * int(video_duration) <= max_num_frames:
+        packing_nums = 1
+        choose_frames = round(min(choose_fps, round(fps)) * min(max_num_frames, video_duration))
+
+    else:
+        packing_nums = math.ceil(video_duration * choose_fps / max_num_frames)
+        if packing_nums <= max_num_packing:
+            choose_frames = round(video_duration * choose_fps)
+        else:
+            choose_frames = round(max_num_frames * max_num_packing)
+            packing_nums = max_num_packing
+
+    frame_idx = [i for i in range(0, len(vr))]
+    frame_idx = np.array(uniform_sample(frame_idx, choose_frames))
+
+    if force_packing:
+        packing_nums = min(force_packing, max_num_packing)
+
+    frames = vr.get_batch(frame_idx).asnumpy()
+
+    frame_idx_ts = frame_idx / fps
+    scale = np.arange(0, video_duration, time_scale)
+
+    frame_ts_id = map_to_nearest_scale(frame_idx_ts, scale) / time_scale
+    frame_ts_id = frame_ts_id.astype(np.int32)
+
+    assert len(frames) == len(frame_ts_id)
+
+    frames = [Image.fromarray(v.astype('uint8')).convert('RGB') for v in frames]
+    frame_ts_id_group = group_array(frame_ts_id, packing_nums)
+
+    return frames, frame_ts_id_group
+
+
 def load_audio(audio: Union[str, bytes], sampling_rate: int, return_sr: bool = False):
     import librosa
     audio_io = load_file(audio)
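Usage note (not part of the patch): a minimal sketch of how the new loader can be exercised on its own, assuming decord and scipy are installed. The video path and parameter values below are illustrative only; the values mirror the env-arg defaults read in replace_tag.

from swift.llm.template.vision_utils import load_video_minicpmv_4_5

# 'demo.mp4' is a hypothetical local video file.
frames, temporal_ids = load_video_minicpmv_4_5(
    'demo.mp4', max_num_frames=64, max_num_packing=3, choose_fps=5, time_scale=0.1)
# frames: list of RGB PIL images; temporal_ids: per-frame time-scale ids grouped by packing number.
print(len(frames), [len(group) for group in temporal_ids])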