diff --git a/multimodal/dashinfer_vlm/api_server/config.py b/multimodal/dashinfer_vlm/api_server/config.py
index 09d2dd9c9..ba34ff821 100644
--- a/multimodal/dashinfer_vlm/api_server/config.py
+++ b/multimodal/dashinfer_vlm/api_server/config.py
@@ -58,7 +58,7 @@ def add_context_args(parser):
         choices=["cuda"],
         help="device (Default: cuda)",
     )
-    group.add_argument("--max_length", type=int, default=32000, help="model max length")
+    group.add_argument("--max_length", type=int, default=128000, help="model max length")
     group.add_argument("--max_batch", type=int, default=128, help="max batch")
     group.add_argument(
         "--parallel_size",
diff --git a/multimodal/dashinfer_vlm/api_server/server.py b/multimodal/dashinfer_vlm/api_server/server.py
index 266a44ae7..288116dbb 100644
--- a/multimodal/dashinfer_vlm/api_server/server.py
+++ b/multimodal/dashinfer_vlm/api_server/server.py
@@ -117,12 +117,16 @@ def init():
     context.set("tokenizer", tokenizer)
     cuda_devices = [str(i) for i in range(tensor_parallel_size)]
     compute_unit = "CUDA:" + ",".join(cuda_devices)
+    if hasattr(config, "max_position_embeddings"):
+        engine_max_len = min(config.max_position_embeddings, context.get("max_length"))
+    else:
+        engine_max_len = context.get("max_length")
     # allspark model config
     as_model_config = allspark.AsModelConfig(
         model_name=model_name,
         model_path=as_graph_path,
         weights_path=as_weight_path,
-        engine_max_length=context.get("max_length"),
+        engine_max_length=engine_max_len,
         engine_max_batch=context.get("max_batch"),
         compute_unit=compute_unit,
         enable_prefix_cache=True if context.get("enable_prefix_cache") else False,
diff --git a/multimodal/dashinfer_vlm/vl_inference/runtime/qwen_vl.py b/multimodal/dashinfer_vlm/vl_inference/runtime/qwen_vl.py
index e7c6b8fc9..1071b1b34 100644
--- a/multimodal/dashinfer_vlm/vl_inference/runtime/qwen_vl.py
+++ b/multimodal/dashinfer_vlm/vl_inference/runtime/qwen_vl.py
@@ -267,8 +267,8 @@ def __init__(
             "QWEN2-AL": "audio_attention_mask",
             "GUMMY-AL": "audio_attention_mask",
         }
-        self.max_input_len = int(getenv("DS_LLM_MAX_IN_TOKENS", "20000"))
-        self.max_total_len = int(getenv("DS_LLM_MAX_TOKENS", "32000"))
+        self.max_total_len = as_config.engine_max_length
+        self.max_input_len = self.max_total_len - 1
         """
         支持input_tokens格式: