46 changes: 42 additions & 4 deletions swift/llm/template/template/minicpm.py
@@ -1,7 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Callable, Dict, List, Literal, Optional

import torch
from torch import nn
@@ -12,7 +12,7 @@
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, Prompt, findall
from ..vision_utils import load_video_minicpmv_mplug_owl3
from ..vision_utils import load_video_minicpmv_4_5, load_video_minicpmv_mplug_owl3
from .llama import Llama3TemplateMeta
from .qwen import Qwen2_5TemplateMeta, Qwen3Template, QwenTemplateMeta
from .utils import ChatmlTemplateMeta
@@ -244,6 +244,41 @@ def _get_new_tokens(i):

class MiniCPMV4_5Template(MiniCPMV2_6Template, Qwen3Template):

def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index,
inputs: StdTemplateInputs) -> List[Context]:
assert media_type in {'image', 'video'}
max_num_frames = get_env_args('max_num_frames', int, 64)
max_num_packing = get_env_args('max_num_packing', int, 3)
choose_fps = get_env_args('choose_fps', int, 5)
time_scale = get_env_args('time_scale', float, 0.1)
load_video = partial(
load_video_minicpmv_4_5,
max_num_frames=max_num_frames,
max_num_packing=max_num_packing,
choose_fps=choose_fps,
time_scale=time_scale)
image_context = super().replace_tag('image', index, inputs)
if media_type == 'image':
return image_context
elif media_type == 'video':
return self.replace_video2image(load_video, inputs, lambda i: image_context)

def replace_video2image(self, load_video_func, inputs, replace_tag: Callable) -> List[Context]:
context_list = []
if self.mode in {'vllm', 'lmdeploy'}:
video = inputs.videos.pop(inputs.video_idx)
inputs.video_idx -= 1
else:
video = inputs.videos[inputs.video_idx]
images = inputs.images
new_images, temporal_ids = load_video_func(video)  # the 4.5 loader also returns grouped temporal ids
inputs.images = images[:inputs.image_idx] + new_images + images[inputs.image_idx:]
for i in range(len(new_images)):
context_list += replace_tag(i)
inputs.image_idx += len(new_images)
inputs.extra_kwargs['temporal_ids'] = temporal_ids  # consumed later by _encode and forwarded to the image processor
return context_list

def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
encoded = Template._encode(self, inputs)
images = inputs.images
@@ -259,8 +294,11 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
idx_list = findall(input_ids, -100)

image_processor = self.processor.image_processor
image_inputs = image_processor([images], return_tensors='pt',
max_slice_nums=max_slice_nums).to(self.model_info.torch_dtype)
processor_kwargs = {}
if 'temporal_ids' in inputs.extra_kwargs:
processor_kwargs['temporal_ids'] = inputs.extra_kwargs['temporal_ids']
image_inputs = image_processor([images], return_tensors='pt', max_slice_nums=max_slice_nums,
**processor_kwargs).to(self.model_info.torch_dtype)

def _get_new_tokens(i):
placeholder = image_processor.get_slice_image_placeholder(
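For context, the four sampling knobs read in replace_tag above (max_num_frames, max_num_packing, choose_fps, time_scale) come from get_env_args, so they can be tuned without code changes. A minimal sketch, assuming ms-swift's convention that get_env_args('max_num_frames', ...) falls back to an upper-cased MAX_NUM_FRAMES environment variable (the variable names are inferred from that convention, not stated in this diff):

import os

# Assumed env-var names, derived from the get_env_args keys in replace_tag.
os.environ['MAX_NUM_FRAMES'] = '64'   # upper bound on frames sampled per video
os.environ['MAX_NUM_PACKING'] = '3'   # max frames packed into one image slot
os.environ['CHOOSE_FPS'] = '5'        # target sampling rate, frames per second
os.environ['TIME_SCALE'] = '0.1'      # temporal-id granularity, in seconds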
62 changes: 62 additions & 0 deletions swift/llm/template/vision_utils.py
@@ -266,6 +266,68 @@ def uniform_sample(_l, _n):
return frames


def load_video_minicpmv_4_5(video: Union[str, bytes],
max_num_frames=64,
max_num_packing=3,
choose_fps=3,
time_scale=0.1,
force_packing=None):

from decord import VideoReader, cpu # pip install decord
from scipy.spatial import cKDTree

def uniform_sample(seq, n):
gap = len(seq) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [seq[i] for i in idxs]

def map_to_nearest_scale(values, scale):
tree = cKDTree(np.asarray(scale)[:, None])
_, indices = tree.query(np.asarray(values)[:, None])
return np.asarray(scale)[indices]

def group_array(arr, size):
return [arr[i:i + size] for i in range(0, len(arr), size)]

video_io = load_file(video)
vr = VideoReader(video_io, ctx=cpu(0))
fps = vr.get_avg_fps()
video_duration = len(vr) / fps

if choose_fps * int(video_duration) <= max_num_frames:
packing_nums = 1
choose_frames = round(min(choose_fps, round(fps)) * min(max_num_frames, video_duration))

else:
packing_nums = math.ceil(video_duration * choose_fps / max_num_frames)
if packing_nums <= max_num_packing:
choose_frames = round(video_duration * choose_fps)
else:
choose_frames = round(max_num_frames * max_num_packing)
packing_nums = max_num_packing

frame_idx = [i for i in range(0, len(vr))]
frame_idx = np.array(uniform_sample(frame_idx, choose_frames))

if force_packing:
packing_nums = min(force_packing, max_num_packing)

frames = vr.get_batch(frame_idx).asnumpy()

frame_idx_ts = frame_idx / fps
scale = np.arange(0, video_duration, time_scale)

frame_ts_id = map_to_nearest_scale(frame_idx_ts, scale) / time_scale
frame_ts_id = frame_ts_id.astype(np.int32)

assert len(frames) == len(frame_ts_id)

frames = [Image.fromarray(v.astype('uint8')).convert('RGB') for v in frames]
frame_ts_id_group = group_array(frame_ts_id, packing_nums)

return frames, frame_ts_id_group


def load_audio(audio: Union[str, bytes], sampling_rate: int, return_sr: bool = False):
import librosa
audio_io = load_file(audio)
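As a worked example of the packing arithmetic in load_video_minicpmv_4_5: for a 60 s clip with choose_fps=5 and max_num_frames=64, choose_fps * 60 = 300 exceeds 64, so packing_nums = ceil(60 * 5 / 64) = 5; that in turn exceeds max_num_packing=3, so the loader samples round(64 * 3) = 192 frames and emits their temporal ids in groups of 3, i.e. 64 packed entries.

A minimal usage sketch of the new loader (the clip path is a placeholder; decord and scipy must be installed):

from swift.llm.template.vision_utils import load_video_minicpmv_4_5

# 'clip.mp4' is a hypothetical local video used only for illustration.
frames, temporal_id_groups = load_video_minicpmv_4_5(
    'clip.mp4', max_num_frames=64, max_num_packing=3, choose_fps=5, time_scale=0.1)
# frames: list of RGB PIL.Image frames
# temporal_id_groups: list of int32 arrays, each holding up to max_num_packing
# timestamp indices (in units of time_scale seconds), one group per packed image
print(len(frames), [len(g) for g in temporal_id_groups])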