# In [None]:
# SPDX-License-Identifier: Apache-2.0
"""
此示例展示了如何使用 vLLM 运行 Qwen/Qwen2.5-VL-3B-Instruct 模型进行离线推断，
并使用正确的提示格式进行文本生成。
"""
import os
from contextlib import contextmanager
from dataclasses import asdict
from typing import NamedTuple, Optional

# from transformers import AutoTokenizer # Qwen2.5-VL 的 run_qwen2_5_vl 中不直接使用

from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset
# from vllm.assets.video import VideoAsset # 如果要测试视频，取消注释
from vllm.lora.request import LoRARequest # ModelRequestData 中需要，但此模型默认为 None
from vllm.utils import FlexibleArgumentParser


class ModelRequestData(NamedTuple):
    """Everything needed to run one model in this script.

    Returned by the per-model ``run_*`` builder and consumed by the test
    driver, which turns ``engine_args`` into ``LLM(...)`` keyword arguments.
    """

    # Engine construction arguments (model id, context length, multimodal limits...).
    engine_args: EngineArgs
    # One fully formatted prompt string per question, in question order.
    prompts: list[str]
    # Optional token ids forwarded to SamplingParams(stop_token_ids=...);
    # None means no explicit stop tokens.
    stop_token_ids: Optional[list[int]] = None
    # Optional LoRA adapter requests; None when no adapter is used.
    lora_requests: Optional[list[LoRARequest]] = None


# Qwen2.5-VL
def run_qwen2_5_vl(
    questions: list[str],
    modality: str,
    *,
    model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct",
) -> ModelRequestData:
    """Build the vLLM request data for Qwen2.5-VL.

    Args:
        questions: User questions; one chat-formatted prompt is built per
            question, in the same order.
        modality: Either ``"image"`` or ``"video"``; selects the vision
            placeholder token and the key of the per-prompt media limit.
        model_name: Model id passed to ``EngineArgs(model=...)``. Keyword-only;
            defaults to ``"Qwen/Qwen2.5-VL-3B-Instruct"``.

    Returns:
        A ``ModelRequestData`` holding the engine arguments, the prompts and
        ``stop_token_ids=None``.

    Raises:
        ValueError: If *modality* is not ``"image"`` or ``"video"``. The check
            runs before any ``EngineArgs`` are constructed.
    """
    # Validate the modality first so an unsupported value fails fast instead
    # of after building engine arguments keyed on a bogus modality.
    placeholder = {
        "image": "<|image_pad|>",
        "video": "<|video_pad|>",
    }.get(modality)
    if placeholder is None:
        raise ValueError(f"Qwen2.5-VL 不支持的模态: {modality}")

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,  # Adjust to your GPU.
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
            # fps only matters for video input; kept for parity with the
            # upstream example (presumably ignored for images).
            "fps": 1,
        },
        # At most one media item of the requested modality per prompt.
        limit_mm_per_prompt={modality: 1},
    )

    # Prompt layout: a system turn, then a user turn whose vision block
    # (placeholder wrapped in vision_start/vision_end) precedes the question,
    # then an open assistant turn for the model to complete.
    prompts = [
        (
            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
            f"{question}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        for question in questions
    ]

    # No explicit stop token ids: generation ends when the model stops on its
    # own or hits max_tokens. To force a stop, load the tokenizer and pass its
    # ids here, or use SamplingParams(stop=["<|im_end|>"]) instead.
    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
        stop_token_ids=None,
    )


def get_multi_modal_input(modality: str, num_frames: int = 16):
    """Return the sample media and the questions to ask about it.

    Args:
        modality: Only ``"image"`` is currently supported.
        num_frames: Reserved for the (disabled) video sample; unused otherwise.

    Returns:
        A dict ``{"data": <media object>, "questions": <list of str>}``.

    Raises:
        ValueError: For any modality other than ``"image"``.
    """
    if modality != "image":
        # Video is disabled in this script. To enable it, import VideoAsset and,
        # for modality == "video", return
        #   {"data": VideoAsset(name="baby_reading", num_frames=num_frames).np_ndarrays,
        #    "questions": ["Why is this video funny?"]}
        # before reaching this error.
        raise ValueError(f"当前测试脚本不支持的模态 {modality} 或未实现。")

    blossom_rgb = ImageAsset("cherry_blossom").pil_image.convert("RGB")
    return {
        "data": blossom_rgb,
        "questions": [
            "What is the content of this image?",
            "Describe the content of this image in detail.",
            "这张图片里有什么？请用中文回答。",
        ],
    }


@contextmanager
def time_counter(enable: bool):
    """Time the enclosed block and print the elapsed seconds.

    When *enable* is false this is a no-op context manager. Durations are
    measured with ``time.perf_counter()``, which is monotonic and
    high-resolution, instead of ``time.time()``, whose value can jump when the
    system clock is adjusted. The timing is printed only if the block
    completes without raising.

    Args:
        enable: Whether to measure and print the elapsed time.

    Yields:
        None.
    """
    if not enable:
        yield
        return

    import time

    start_time = time.perf_counter()
    yield
    elapsed_time = time.perf_counter() - start_time
    print("-" * 50)
    print(f"-- generate time = {elapsed_time:.2f} 秒")
    print("-" * 50)


def main_test_qwen2_5_vl():
    """Run a small offline-inference smoke test of Qwen2.5-VL with vLLM.

    Loads the sample media for the configured modality, selects the first
    ``num_prompts_to_run`` questions and builds the model's request data. It
    then creates an ``LLM`` engine, generates one reply per prompt, and prints
    each question with its reply, plus the generation time when enabled.
    Returns early, before creating the engine, if no questions are selected.
    """
    # ---- Configurable parameters ----
    target_model_func = run_qwen2_5_vl
    current_modality = "image"  # "image" or "video"
    num_prompts_to_run = 2  # How many questions from get_multi_modal_input to run.
    vllm_seed = 42  # Random seed for the vLLM engine.
    time_generation = True  # Whether to print the generation time.
    # ---------------------------------

    print(f"开始测试模型: Qwen/Qwen2.5-VL-3B-Instruct, 模态: {current_modality}")

    mm_input = get_multi_modal_input(modality=current_modality)
    data = mm_input["data"]
    questions = mm_input["questions"][:num_prompts_to_run]
    if not questions:
        print("没有提供问题，测试中止。")
        return

    req_data = target_model_func(questions, current_modality)

    # Start every modality at a limit of 0 and let the model's own limits
    # override them, so unused modalities are disabled (borrowed from the
    # upstream example, to save memory).
    default_limits = {"image": 0, "video": 0, "audio": 0}
    engine_args_dict = asdict(req_data.engine_args)
    engine_args_dict["limit_mm_per_prompt"] = default_limits | dict(
        req_data.engine_args.limit_mm_per_prompt or {}
    )

    # Overlay the run-specific settings on top of the model's engine args.
    engine_args_final = engine_args_dict | {
        "seed": vllm_seed,
        "disable_mm_preprocessor_cache": False,  # Set True to test without the cache.
        "trust_remote_code": True,  # Required by many Hugging Face models.
    }

    print(f"初始化 LLM 引擎，参数: {engine_args_final}")
    llm = LLM(**engine_args_final)

    # A small non-zero temperature gives slightly different outputs even for
    # identical prompts in a batch. Alternatively pass stop=["<|im_end|>"] to
    # stop on the string form of the end-of-turn marker.
    sampling_params = SamplingParams(
        temperature=0.2,
        max_tokens=256,  # Adjust as needed.
        stop_token_ids=req_data.stop_token_ids,
    )

    print(f"采样参数: {sampling_params}")

    # One request per prompt; every prompt is paired with the same media item.
    inputs_for_llm = [
        {
            "prompt": prompt,
            "multi_modal_data": {current_modality: data},
        }
        for prompt in req_data.prompts[: len(questions)]
    ]

    # Repeat the model's LoRA request(s) once per input when present; None
    # otherwise (the Qwen2.5-VL builder leaves lora_requests unset).
    lora_request_list = (
        req_data.lora_requests * len(inputs_for_llm)
        if req_data.lora_requests
        else None
    )

    print(f"\n开始生成 {len(inputs_for_llm)} 个回复...")
    with time_counter(time_generation):
        outputs = llm.generate(
            inputs_for_llm,
            sampling_params=sampling_params,
            lora_request=lora_request_list,
        )

    print("-" * 50)
    print("生成结果:")
    # Outputs are printed alongside the questions, in input order.
    for number, (question, output) in enumerate(zip(questions, outputs), start=1):
        print(f"\n--- 问题 {number} ---")
        print(f"问题文本: {question}")
        generated_text = output.outputs[0].text
        print(f"模型回复: {generated_text.strip()}")
        print("-" * 20)


if __name__ == "__main__":
    # Run the Qwen2.5-VL smoke test only when executed as a script, not on import.
    main_test_qwen2_5_vl()