From 7a7b8799ae29fa76dfe91bc1e3c5f6ec41b9a545 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Wed, 10 Sep 2025 15:57:22 +0800 Subject: [PATCH] update patches Signed-off-by: gc-fu --- vllm/patches/vllm_for_multi_arc.patch | 1929 ++++++++++++++++++++++++- 1 file changed, 1924 insertions(+), 5 deletions(-) diff --git a/vllm/patches/vllm_for_multi_arc.patch b/vllm/patches/vllm_for_multi_arc.patch index e851c75..eebd051 100644 --- a/vllm/patches/vllm_for_multi_arc.patch +++ b/vllm/patches/vllm_for_multi_arc.patch @@ -8750,10 +8750,20 @@ index 000000000..2d8cd49ed + mm_kwargs=mm_kwargs) + vllm_model.apply_model(valid_func) diff --git a/tests/models/registry.py b/tests/models/registry.py -index 84ca0bc60..3e36c18af 100644 +index 84ca0bc60..6d32122a4 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py -@@ -373,7 +373,10 @@ _MULTIMODAL_EXAMPLE_MODELS = { +@@ -267,6 +267,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { + "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), + "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), + "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), ++ "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501 ++ trust_remote_code=True, ++ is_available_online=False), + "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501 + "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), + "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"), +@@ -373,7 +376,10 @@ _MULTIMODAL_EXAMPLE_MODELS = { transformers_version_reason="HF model is not compatible."), # noqa: E501 "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B", extras={"2B": "OpenGVLab/InternVL2-2B", @@ -8765,7 +8775,7 @@ index 84ca0bc60..3e36c18af 100644 trust_remote_code=True), "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501 {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501 -@@ -397,7 +400,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { +@@ -397,7 +403,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { "MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6", trust_remote_code=True), "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", @@ -8814,6 +8824,471 @@ index 34b1b6c2e..4c8082646 100644 output = llm.generate_greedy(["The capital of France is"], max_tokens=32) assert output +diff --git a/tests/tool_use/test_seed_oss_tool_parser.py b/tests/tool_use/test_seed_oss_tool_parser.py +new file mode 100644 +index 000000000..d85bc9bbf +--- /dev/null ++++ b/tests/tool_use/test_seed_oss_tool_parser.py +@@ -0,0 +1,459 @@ ++# SPDX-License-Identifier: Apache-2.0 ++# SPDX-FileCopyrightText: Copyright contributors to the vLLM project ++# ruff: noqa: E501 ++ ++import json ++from collections.abc import Generator ++from typing import Optional ++ ++import pytest ++ ++from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ++ ChatCompletionToolsParam, ++ DeltaMessage, FunctionCall, ++ ToolCall) ++from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser ++from vllm.transformers_utils.detokenizer import detokenize_incrementally ++from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer ++ ++# Use a common model that is likely to be available ++MODEL = "ByteDance-Seed/Seed-OSS-36B-Instruct" ++ ++ ++@pytest.fixture(scope="module") ++def seed_oss_tokenizer(): ++ return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True) ++ ++ ++@pytest.fixture ++def seed_oss_tool_parser(seed_oss_tokenizer): ++ return 
SeedOssToolParser(seed_oss_tokenizer) ++ ++ ++@pytest.fixture ++def sample_tools(): ++ return [ ++ ChatCompletionToolsParam( ++ type="function", ++ function={ ++ "name": "get_weather", ++ "description": "Get current temperature for a given location.", ++ "parameters": { ++ "type": "object", ++ "properties": { ++ "location": { ++ "type": "string", ++ "description": ++ "City and country e.g. Bogotá, Colombia" ++ }, ++ "unit": { ++ "type": "string", ++ "description": "this is the unit of temperature" ++ } ++ }, ++ "required": ["location"], ++ "additionalProperties": False ++ }, ++ "returns": { ++ "type": "object", ++ "properties": { ++ "temperature": { ++ "type": "number", ++ "description": "temperature in celsius" ++ } ++ }, ++ "required": ["temperature"], ++ "additionalProperties": False ++ }, ++ "strict": True ++ }), ++ ] ++ ++ ++def assert_tool_calls(actual_tool_calls: list[ToolCall], ++ expected_tool_calls: list[ToolCall]): ++ assert len(actual_tool_calls) == len(expected_tool_calls) ++ ++ for actual_tool_call, expected_tool_call in zip(actual_tool_calls, ++ expected_tool_calls): ++ # Seed-OSS tool call will not generate id ++ assert actual_tool_call.type == "function" ++ assert actual_tool_call.function == expected_tool_call.function ++ ++ assert actual_tool_call.function.name == expected_tool_call.function.name ++ assert actual_tool_call.function.arguments == expected_tool_call.function.arguments ++ ++ ++def test_extract_tool_calls_no_tools(seed_oss_tool_parser): ++ model_output = "This is a test response without any tool calls" ++ extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls( ++ model_output, request=None) # type: ignore[arg-type] ++ ++ assert not extracted_tool_calls.tools_called ++ assert extracted_tool_calls.tool_calls == [] ++ assert extracted_tool_calls.content == model_output ++ ++ ++@pytest.mark.parametrize( ++ ids=[ ++ "tool_call_0_thinking_budget", ++ "tool_call_512_thinkg_budget", ++ "tool_call_unlimited_thinking_budget", ++ ], ++ argnames=["model_output", "expected_tool_calls", "expected_content"], ++ argvalues=[ ++ ("""\n\n\n""" ++ """The current thinking budget is 0, so I will directly start answering the question.\n\n""" ++ """\n\n""" ++ """Barcelona, Spain\n\n""", ++ [ ++ ToolCall(function=FunctionCall( ++ name="get_weather", ++ arguments=json.dumps({ ++ "location": "Barcelona, Spain", ++ }, ), ++ ), ++ type='function') ++ ], ++ """\n\n\n""" ++ """The current thinking budget is 0, so I will directly start answering the question.\n\n""" ++ ), ++ ( ++ """The user\'s current thinking budget is 512.\nLet me analyze the """ ++ """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ ++ """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ ++ """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ ++ """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ ++ """country). \nI have used 131 tokens, and there are 381 tokens remaining for use.""" ++ """\n Since the unit isn\'t specified, the function will default to Celsius, which """ ++ """is fine. \n\nThere\'s no need to ask for more information because the location is clear. 
So I should call """ ++ """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ ++ """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ ++ """with a comma).\nI have used 257 tokens, and there are 255 tokens remaining for """ ++ """use.\n The unit parameter can be omitted since it\'s optional.\n""" ++ """\n\nBarcelona, Spain\n""" ++ """\n""", ++ [ ++ ToolCall(function=FunctionCall( ++ name="get_weather", ++ arguments=json.dumps({ ++ "location": "Barcelona, Spain", ++ }, ), ++ ), ++ type='function') ++ ], ++ """The user\'s current thinking budget is 512.\nLet me analyze the """ ++ """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ ++ """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ ++ """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ ++ """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ ++ """country). \nI have used 131 tokens, and there are 381 tokens remaining for use.""" ++ """\n Since the unit isn\'t specified, the function will default to Celsius, which """ ++ """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """ ++ """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ ++ """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ ++ """with a comma).\nI have used 257 tokens, and there are 255 tokens remaining for """ ++ """use.\n The unit parameter can be omitted since it\'s optional.\n""", ++ ), ++ ( ++ """\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ ++ """First, I need to remember the function I can use: get_weather. The function requires a """ ++ """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ ++ """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ ++ """let me check the function docstring again. Oh, the function says unit is optional, and """ ++ """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ ++ """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ ++ """The format is \n\nBarcelona, """ ++ """Spain\ncelsius\n\n. """ ++ """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ ++ """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ ++ """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ ++ """call should be as above. Then wait for the result to come back and tell the user the """ ++ """temperature in Celsius.\n\n""" ++ """Barcelona, Spain\ncelsius\n\n""", ++ [ ++ ToolCall(function=FunctionCall( ++ name="get_weather", ++ arguments=json.dumps( ++ { ++ "location": "Barcelona, Spain", ++ "unit": "celsius", ++ }, ), ++ ), ++ type='function') ++ ], ++ """\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ ++ """First, I need to remember the function I can use: get_weather. The function requires a """ ++ """location (city and country) which is "Barcelona, Spain" here, and unit is optional. 
Since """ ++ """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ ++ """let me check the function docstring again. Oh, the function says unit is optional, and """ ++ """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ ++ """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ ++ """The format is \n\nBarcelona, """ ++ """Spain\ncelsius\n\n. """ ++ """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ ++ """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ ++ """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ ++ """call should be as above. Then wait for the result to come back and tell the user the """ ++ """temperature in Celsius.""", ++ ), ++ ], ++) ++def test_extract_tool_calls(seed_oss_tool_parser, sample_tools, model_output, ++ expected_tool_calls, expected_content): ++ request = ChatCompletionRequest(model=MODEL, ++ messages=[], ++ tools=sample_tools) ++ extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls( ++ model_output, request=request) # type: ignore[arg-type] ++ assert extracted_tool_calls.tools_called ++ ++ assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) ++ ++ assert extracted_tool_calls.content == expected_content ++ ++ ++def test_streaming_tool_calls_no_tools(seed_oss_tool_parser): ++ model_output = "This is a test response without any tool calls" ++ ++ result = seed_oss_tool_parser.extract_tool_calls_streaming( ++ previous_text="his is a test response", ++ current_text=model_output, ++ delta_text=" without any tool calls.", ++ previous_token_ids=[], ++ current_token_ids=[], ++ delta_token_ids=[], ++ request=None, ++ ) ++ ++ # Should return the delta text as content ++ assert result is not None ++ assert hasattr(result, 'content') ++ assert result.content == " without any tool calls." 
++ ++ ++def stream_delta_message_generator( ++ seed_oss_tool_parser: SeedOssToolParser, ++ seed_oss_tokenizer: AnyTokenizer, ++ model_output: str, ++ request: Optional[ChatCompletionRequest] = None ++) -> Generator[DeltaMessage, None, None]: ++ all_token_ids = seed_oss_tokenizer.encode(model_output, ++ add_special_tokens=False) ++ ++ previous_text = "" ++ previous_tokens = None ++ prefix_offset = 0 ++ read_offset = 0 ++ for i, delta_token in enumerate(all_token_ids): ++ delta_token_ids = [delta_token] ++ previous_token_ids = all_token_ids[:i] ++ current_token_ids = all_token_ids[:i + 1] ++ ++ (new_tokens, delta_text, new_prefix_offset, ++ new_read_offset) = detokenize_incrementally( ++ tokenizer=seed_oss_tokenizer, ++ all_input_ids=current_token_ids, ++ prev_tokens=previous_tokens, ++ prefix_offset=prefix_offset, ++ read_offset=read_offset, ++ skip_special_tokens=False, ++ spaces_between_special_tokens=True, ++ ) ++ ++ current_text = previous_text + delta_text ++ ++ delta_message = seed_oss_tool_parser.extract_tool_calls_streaming( ++ previous_text, ++ current_text, ++ delta_text, ++ previous_token_ids, ++ current_token_ids, ++ delta_token_ids, ++ request=request, ++ ) ++ if delta_message: ++ yield delta_message ++ ++ previous_text = current_text ++ previous_tokens = (previous_tokens + ++ new_tokens if previous_tokens else new_tokens) ++ prefix_offset = new_prefix_offset ++ read_offset = new_read_offset ++ ++ ++@pytest.mark.parametrize( ++ ids=[ ++ "tool_call_0_thinking_budget", ++ "tool_call_512_thinkg_budget", ++ "tool_call_unlimited_thinking_budget", ++ ], ++ argnames=["model_output", "expected_tool_calls", "expected_content"], ++ argvalues=[ ++ ("""\n\n\n""" ++ """The current thinking budget is 0, so I will directly start answering the question.\n\n""" ++ """\n\n""" ++ """Barcelona, Spain\n\n""", ++ [ ++ ToolCall(function=FunctionCall( ++ name="get_weather", ++ arguments=json.dumps({ ++ "location": "Barcelona, Spain", ++ }, ), ++ ), ++ type='function') ++ ], ++ """\n\n\n""" ++ """The current thinking budget is 0, so I will directly start answering the question.\n\n""" ++ ), ++ ( ++ """The user\'s current thinking budget is 512.\nLet me analyze the """ ++ """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ ++ """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ ++ """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ ++ """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ ++ """country). \nI have used 131 tokens, and there are 381 tokens remaining for use.""" ++ """\n Since the unit isn\'t specified, the function will default to Celsius, which """ ++ """is fine. \n\nThere\'s no need to ask for more information because the location is clear. 
So I should call """ ++ """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ ++ """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ ++ """with a comma).\nI have used 257 tokens, and there are 255 tokens remaining for """ ++ """use.\n The unit parameter can be omitted since it\'s optional.\n""" ++ """\n\nBarcelona, Spain\n""" ++ """\n""", ++ [ ++ ToolCall(function=FunctionCall( ++ name="get_weather", ++ arguments=json.dumps({ ++ "location": "Barcelona, Spain", ++ }, ), ++ ), ++ type='function') ++ ], ++ """The user\'s current thinking budget is 512.\nLet me analyze the """ ++ """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ ++ """there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """ ++ """check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """ ++ """optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """ ++ """country). \nI have used 131 tokens, and there are 381 tokens remaining for use.""" ++ """\n Since the unit isn\'t specified, the function will default to Celsius, which """ ++ """is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """ ++ """the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """ ++ """user\'s input has a space, but the function might accept either; to be safe, using the standard format """ ++ """with a comma).\nI have used 257 tokens, and there are 255 tokens remaining for """ ++ """use.\n The unit parameter can be omitted since it\'s optional.\n""", ++ ), ++ ( ++ """\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ ++ """First, I need to remember the function I can use: get_weather. The function requires a """ ++ """location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """ ++ """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ ++ """let me check the function docstring again. Oh, the function says unit is optional, and """ ++ """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ ++ """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ ++ """The format is \n\nBarcelona, """ ++ """Spain\ncelsius\n\n. """ ++ """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ ++ """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ ++ """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ ++ """call should be as above. Then wait for the result to come back and tell the user the """ ++ """temperature in Celsius.\n\n""" ++ """Barcelona, Spain\ncelsius\n\n""", ++ [ ++ ToolCall(function=FunctionCall( ++ name="get_weather", ++ arguments=json.dumps( ++ { ++ "location": "Barcelona, Spain", ++ "unit": "celsius", ++ }, ), ++ ), ++ type='function') ++ ], ++ """\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ ++ """First, I need to remember the function I can use: get_weather. The function requires a """ ++ """location (city and country) which is "Barcelona, Spain" here, and unit is optional. 
Since """ ++ """the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """ ++ """let me check the function docstring again. Oh, the function says unit is optional, and """ ++ """returns temperature in Celsius. So I should call get_weather with location "Barcelona, """ ++ """Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """ ++ """The format is \n\nBarcelona, """ ++ """Spain\ncelsius\n\n. """ ++ """Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """ ++ """of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """ ++ """it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """ ++ """call should be as above. Then wait for the result to come back and tell the user the """ ++ """temperature in Celsius.""", ++ ), ++ ], ++) ++def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer, ++ sample_tools, model_output, expected_tool_calls, ++ expected_content): ++ """Test incremental streaming behavior""" ++ request = ChatCompletionRequest(model=MODEL, ++ messages=[], ++ tools=sample_tools) ++ ++ other_content = '' ++ tool_states = {} # Track state per tool index ++ ++ for delta_message in stream_delta_message_generator( ++ seed_oss_tool_parser, seed_oss_tokenizer, model_output, request): ++ # role should never be streamed from tool parser ++ assert not delta_message.role ++ ++ if delta_message.content: ++ other_content += delta_message.content ++ ++ if delta_message.tool_calls: ++ for tool_call in delta_message.tool_calls: ++ idx = tool_call.index ++ ++ # Initialize state for new tool ++ if idx not in tool_states: ++ tool_states[idx] = { ++ "id": None, ++ "name": None, ++ "arguments": "", ++ "type": None ++ } ++ ++ # First chunk should have id, name, and type ++ if tool_call.id: ++ tool_states[idx]["id"] = tool_call.id ++ ++ if tool_call.type: ++ assert tool_call.type == "function" ++ tool_states[idx]["type"] = tool_call.type ++ ++ if tool_call.function: ++ if tool_call.function.name: ++ # Should only be set once ++ assert tool_states[idx]["name"] is None ++ tool_states[idx]["name"] = tool_call.function.name ++ ++ if tool_call.function.arguments is not None: ++ # Accumulate arguments incrementally ++ tool_states[idx][ ++ "arguments"] += tool_call.function.arguments ++ ++ # Verify final content ++ assert other_content == expected_content ++ ++ # Verify we got all expected tool calls ++ assert len(tool_states) == len(expected_tool_calls) ++ ++ # Verify each tool call ++ for idx, expected_tool in enumerate(expected_tool_calls): ++ state = tool_states[idx] ++ assert state["id"] is not None ++ assert state["type"] == "function" ++ assert state["name"] == expected_tool.function.name ++ ++ # Parse accumulated arguments ++ arguments_str = state["arguments"] ++ assert arguments_str is not None ++ actual_args = json.loads(arguments_str) ++ expected_args = json.loads(expected_tool.function.arguments) ++ assert actual_args == expected_args diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 30cfbdda5..f3ce33fcf 100644 --- a/tests/v1/attention/utils.py @@ -10647,6 +11122,708 @@ index 2f766a2da..680733966 100644 def _embedding_score( self, tokenizer: AnyTokenizer, +diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py +index 88c8aa929..10b9c63e0 100644 +--- a/vllm/entrypoints/openai/tool_parsers/__init__.py ++++ 
b/vllm/entrypoints/openai/tool_parsers/__init__.py +@@ -18,6 +18,8 @@ from .mistral_tool_parser import MistralToolParser + from .phi4mini_tool_parser import Phi4MiniJsonToolParser + from .pythonic_tool_parser import PythonicToolParser + from .qwen3coder_tool_parser import Qwen3CoderToolParser ++from .seed_oss_tool_parser import SeedOssToolParser ++# from .step3_tool_parser import Step3ToolParser + from .xlam_tool_parser import xLAMToolParser + + __all__ = [ +@@ -40,4 +42,6 @@ __all__ = [ + "HunyuanA13BToolParser", + "Glm4MoeModelToolParser", + "Qwen3CoderToolParser", ++ "SeedOssToolParser", ++ # "Step3ToolParser", + ] +diff --git a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py +new file mode 100644 +index 000000000..69cf2e68f +--- /dev/null ++++ b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py +@@ -0,0 +1,676 @@ ++# SPDX-License-Identifier: Apache-2.0 ++# SPDX-FileCopyrightText: Copyright contributors to the vLLM project ++# Adapted from qwen3coder xml parser, All rights reserved. ++# ruff: noqa: E501 ++ ++import ast ++import json ++import uuid ++from collections.abc import Sequence ++from typing import Any, Optional, Union ++ ++import regex as re ++ ++from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ++ ChatCompletionToolsParam, ++ DeltaFunctionCall, DeltaMessage, ++ DeltaToolCall, ++ ExtractedToolCallInformation, ++ FunctionCall, ToolCall) ++from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ++ ToolParser, ToolParserManager) ++from vllm.logger import init_logger ++from vllm.transformers_utils.tokenizer import AnyTokenizer ++ ++logger = init_logger(__name__) ++ ++ ++@ToolParserManager.register_module("seed_oss") ++class SeedOssToolParser(ToolParser): ++ TOOL_CALL_START = "" ++ TOOL_CALL_END = "" ++ ++ def __init__(self, tokenizer: AnyTokenizer): ++ super().__init__(tokenizer) ++ ++ # --- streaming state --- ++ self._reset_streaming_state() ++ self.prev_tool_call_arr: list[dict] = [] ++ ++ self.tool_call_start_token: str = self.TOOL_CALL_START ++ self.tool_call_end_token: str = self.TOOL_CALL_END ++ # Sentinel tokens for streaming mode ++ self.tool_call_prefix: str = " or its closing tag.") ++ ++ tool_start_re = re.escape(self.tool_call_start_token) ++ tool_end_re = re.escape(self.tool_call_end_token) ++ ++ self.tool_call_complete_regex = re.compile( ++ rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL) ++ self.tool_call_regex = re.compile( ++ rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$", ++ re.DOTALL) ++ ++ self.tool_call_function_regex = re.compile( ++ r"|| str: ++ """Generate a unique tool call ID.""" ++ return f"call_{uuid.uuid4().hex[:24]}" ++ ++ def _reset_streaming_state(self): ++ """Reset all streaming state.""" ++ self.current_tool_index = 0 ++ self.is_tool_call_started = False ++ self.header_sent = False ++ self.current_tool_id = -1 ++ self.current_function_name = None ++ self.current_param_name = None ++ self.current_param_value = "" ++ self.param_count = 0 ++ self.in_param = False ++ self.in_function = False ++ self.accumulated_text = "" ++ self.json_started = False ++ self.json_closed = False ++ ++ def _parse_xml_function_call( ++ self, function_call_str: str, ++ tools: Optional[list[ChatCompletionToolsParam]] ++ ) -> Optional[ToolCall]: ++ ++ def get_arguments_config(func_name: str) -> dict: ++ if tools is None: ++ return {} ++ for config in tools: ++ if not hasattr(config, "type") or not ( ++ hasattr(config, "function") ++ and 
hasattr(config.function, "name")): ++ continue ++ if (config.type == "function" ++ and config.function.name == func_name): ++ if not hasattr(config.function, "parameters"): ++ return {} ++ params = config.function.parameters ++ if isinstance(params, dict) and "properties" in params: ++ return params["properties"] ++ elif isinstance(params, dict): ++ return params ++ else: ++ return {} ++ logger.warning("Tool '%s' is not defined in the tools list.", ++ func_name) ++ return {} ++ ++ def convert_param_value(param_value: str, param_name: str, ++ param_config: dict, func_name: str) -> Any: ++ # Handle null value for any type ++ if param_value.lower() == "null": ++ return None ++ ++ if param_name not in param_config: ++ if param_config != {}: ++ logger.warning( ++ "Parsed parameter '%s' is not defined in " ++ "the tool parameters for tool '%s', " ++ "directly returning the string value.", param_name, ++ func_name) ++ return param_value ++ ++ if (isinstance(param_config[param_name], dict) ++ and "type" in param_config[param_name]): ++ param_type = str( ++ param_config[param_name]["type"]).strip().lower() ++ else: ++ param_type = "string" ++ if param_type in [ ++ "string", "str", "text", "varchar", "char", "enum" ++ ]: ++ return param_value ++ elif (param_type.startswith("int") or param_type.startswith("uint") ++ or param_type.startswith("long") ++ or param_type.startswith("short") ++ or param_type.startswith("unsigned")): ++ try: ++ param_value = int(param_value) # type: ignore ++ except (ValueError, TypeError): ++ logger.warning( ++ "Parsed value '%s' of parameter '%s' is not an integer in tool " ++ "'%s', degenerating to string.", param_value, ++ param_name, func_name) ++ return param_value ++ elif param_type.startswith("num") or param_type.startswith( ++ "float"): ++ try: ++ float_param_value = float(param_value) ++ param_value = float_param_value if float_param_value - int( ++ float_param_value) != 0 else int( ++ float_param_value) # type: ignore ++ except (ValueError, TypeError): ++ logger.warning( ++ "Parsed value '%s' of parameter '%s' is not a float in tool " ++ "'%s', degenerating to string.", param_value, ++ param_name, func_name) ++ return param_value ++ elif param_type in ["boolean", "bool", "binary"]: ++ param_value = param_value.lower() ++ if param_value not in ["true", "false"]: ++ logger.warning( ++ "Parsed value '%s' of parameter '%s' is not a boolean " ++ "(`true` of `false`) in tool '%s', degenerating to false.", ++ param_value, param_name, func_name) ++ return param_value == "true" ++ else: ++ if param_type == "object" or param_type.startswith("dict"): ++ try: ++ param_value = json.loads(param_value) ++ return param_value ++ except (ValueError, TypeError, json.JSONDecodeError): ++ logger.warning( ++ "Parsed value '%s' of parameter '%s' is not a valid JSON " ++ "object in tool '%s', will try other methods to parse it.", ++ param_value, param_name, func_name) ++ try: ++ param_value = ast.literal_eval(param_value) ++ except (ValueError, SyntaxError): ++ logger.warning( ++ "Parsed value '%s' of parameter '%s' cannot be converted via " ++ "Python `ast.literal_eval()` in tool '%s', degenerating to string.", ++ param_value, param_name, func_name) ++ return param_value ++ ++ # Extract function name ++ end_index = function_call_str.index(">") ++ function_name = function_call_str[:end_index] ++ param_config = get_arguments_config(function_name) ++ parameters = function_call_str[end_index + 1:] ++ param_dict = {} ++ for match in self.tool_call_parameter_regex.findall(parameters): ++ 
match_text = match[0] if match[0] else match[1] ++ idx = match_text.index(">") ++ param_name = match_text[:idx] ++ param_value = str(match_text[idx + 1:]) ++ # Remove prefix and trailing \n ++ if param_value.startswith("\n"): ++ param_value = param_value[1:] ++ if param_value.endswith("\n"): ++ param_value = param_value[:-1] ++ ++ param_dict[param_name] = convert_param_value( ++ param_value, param_name, param_config, function_name) ++ return ToolCall( ++ type="function", ++ function=FunctionCall(name=function_name, ++ arguments=json.dumps(param_dict, ++ ensure_ascii=False)), ++ ) ++ ++ def _get_function_calls(self, model_output: str) -> list[str]: ++ # Find all tool calls ++ matched_ranges = self.tool_call_regex.findall(model_output) ++ raw_tool_calls = [ ++ match[0] if match[0] else match[1] for match in matched_ranges ++ ] ++ ++ # Back-off strategy if no tool_call tags found ++ if len(raw_tool_calls) == 0: ++ raw_tool_calls = [model_output] ++ ++ raw_function_calls = [] ++ for tool_call in raw_tool_calls: ++ raw_function_calls.extend( ++ self.tool_call_function_regex.findall(tool_call)) ++ ++ function_calls = [ ++ match[0] if match[0] else match[1] for match in raw_function_calls ++ ] ++ return function_calls ++ ++ def extract_tool_calls( ++ self, ++ model_output: str, ++ request: ChatCompletionRequest, ++ ) -> ExtractedToolCallInformation: ++ # Quick check to avoid unnecessary processing ++ if self.tool_call_prefix not in model_output: ++ return ExtractedToolCallInformation(tools_called=False, ++ tool_calls=[], ++ content=model_output) ++ ++ # Check if both think start and end tokens are present ++ if (self.think_start_token in model_output ++ and self.think_end_token in model_output): ++ # Find the position of think end token ++ think_end_index = model_output.find(self.think_end_token) + len( ++ self.think_end_token) ++ # Extract content after think end token ++ result_content = model_output[think_end_index:] ++ thinking_content = model_output[:think_end_index] ++ ++ try: ++ function_calls = self._get_function_calls(result_content) ++ if len(function_calls) == 0: ++ return ExtractedToolCallInformation(tools_called=False, ++ tool_calls=[], ++ content=model_output) ++ ++ tool_calls = [ ++ self._parse_xml_function_call(function_call_str, request.tools) ++ for function_call_str in function_calls ++ ] ++ ++ # Populate prev_tool_call_arr for serving layer to set finish_reason ++ self.prev_tool_call_arr.clear() # Clear previous calls ++ for tool_call in tool_calls: ++ if tool_call: ++ self.prev_tool_call_arr.append({ ++ "name": ++ tool_call.function.name, ++ "arguments": ++ tool_call.function.arguments, ++ }) ++ ++ # Extract content before tool calls ++ tool_call_start_index = result_content.find( ++ self.tool_call_start_token) ++ tool_call_start_index = ( ++ tool_call_start_index if tool_call_start_index >= 0 else ++ result_content.find(self.tool_call_prefix)) ++ content = thinking_content + result_content[:tool_call_start_index] ++ ++ return ExtractedToolCallInformation( ++ tools_called=(len(tool_calls) > 0), ++ tool_calls=tool_calls, ++ content=content if content else None, ++ ) ++ ++ except Exception: ++ logger.exception("Error in extracting tool call from response.") ++ return ExtractedToolCallInformation(tools_called=False, ++ tool_calls=[], ++ content=model_output) ++ ++ def extract_tool_calls_streaming( ++ self, ++ previous_text: str, ++ current_text: str, ++ delta_text: str, ++ previous_token_ids: Sequence[int], ++ current_token_ids: Sequence[int], ++ delta_token_ids: Sequence[int], 
++ request: ChatCompletionRequest, ++ ) -> Union[DeltaMessage, None]: ++ # If no delta text, return None unless ++ # it's an EOS token after tool calls ++ if not delta_text: ++ # Check if this is an EOS token after all tool calls are complete ++ # We check for tool calls in the text even if is_tool_call_started ++ # is False because it might have been reset after processing all tools ++ if (delta_token_ids ++ and self.tool_call_end_token_id not in delta_token_ids): ++ # Count complete tool calls ++ complete_calls = len( ++ self.tool_call_complete_regex.findall(current_text)) ++ ++ # If we have completed tool calls and populated prev_tool_call_arr ++ if complete_calls > 0 and len(self.prev_tool_call_arr) > 0: ++ # Check if all tool calls are closed ++ open_calls = current_text.count( ++ self.tool_call_start_token) - current_text.count( ++ self.tool_call_end_token) ++ if open_calls == 0: ++ # Return empty delta message to allow finish_reason processing ++ return DeltaMessage(content="") ++ elif not self.is_tool_call_started and current_text: ++ # This is a regular content response that's now complete ++ return DeltaMessage(content="") ++ return None ++ ++ # Check if this is the first call (reset state if needed) ++ if not previous_text: ++ self._reset_streaming_state() ++ ++ # Update accumulated text ++ self.accumulated_text = current_text ++ ++ # Check if we need to advance to next tool ++ if self.json_closed and not self.in_function: ++ # Check if this tool call has ended ++ tool_ends = current_text.count(self.tool_call_end_token) ++ if tool_ends > self.current_tool_index: ++ # This tool has ended, advance to next ++ self.current_tool_index += 1 ++ self.header_sent = False ++ self.param_count = 0 ++ self.json_started = False ++ self.json_closed = False ++ ++ # Check if there are more tool calls ++ if self.current_tool_index >= current_text.count( ++ self.tool_call_start_token): ++ # No more tool calls ++ self.is_tool_call_started = False ++ # Continue processing next tool ++ return None ++ ++ # Check if end thinking ++ if (not self.is_thinking_end ++ and (self.think_end_token_id in delta_token_ids ++ or self.think_end_token in delta_text)): ++ self.is_thinking_end = True ++ ++ # If thinking hasn't ended yet, don't process any tool calls ++ if not self.is_thinking_end: ++ return DeltaMessage(content=delta_text) ++ ++ # Handle normal content before tool calls ++ if not self.is_tool_call_started: ++ # Check if tool call is starting ++ if (self.tool_call_start_token_id in delta_token_ids ++ or self.tool_call_start_token in delta_text): ++ self.is_tool_call_started = True ++ # Return any content before the tool call ++ if self.tool_call_start_token in delta_text: ++ content_before = delta_text[:delta_text.index( ++ self.tool_call_start_token)] ++ if content_before: ++ return DeltaMessage(content=content_before) ++ return None ++ else: ++ # Check if we're between tool calls - skip whitespace ++ if (current_text.rstrip().endswith(self.tool_call_end_token) ++ and delta_text.strip() == ""): ++ # We just ended a tool call, skip whitespace ++ return None ++ # Normal content, no tool call ++ return DeltaMessage(content=delta_text) ++ ++ # Check if we're between tool calls (waiting for next one) ++ # Count tool calls we've seen vs processed ++ tool_starts_count = current_text.count(self.tool_call_start_token) ++ if self.current_tool_index >= tool_starts_count: ++ # We're past all tool calls, shouldn't be here ++ return None ++ ++ # We're in a tool call, find the current tool call portion ++ # Need to 
find the correct tool call based on current_tool_index ++ # Only process tool calls after think_end_token ++ think_end_index = current_text.find(self.think_end_token) + len( ++ self.think_end_token ++ ) if self.think_end_token in current_text else 0 ++ tool_starts: list[int] = [] ++ idx = think_end_index ++ while True: ++ idx = current_text.find(self.tool_call_start_token, idx) ++ if idx == -1: ++ break ++ tool_starts.append(idx) ++ idx += len(self.tool_call_start_token) ++ ++ if self.current_tool_index >= len(tool_starts): ++ # No more tool calls to process yet ++ return None ++ ++ tool_start_idx = tool_starts[self.current_tool_index] ++ # Find where this tool call ends (or current position if not ended yet) ++ tool_end_idx = current_text.find(self.tool_call_end_token, ++ tool_start_idx) ++ if tool_end_idx == -1: ++ tool_text = current_text[tool_start_idx:] ++ else: ++ tool_text = current_text[tool_start_idx:tool_end_idx + ++ len(self.tool_call_end_token)] ++ ++ # Looking for function header ++ if not self.header_sent: ++ if self.tool_call_prefix in tool_text: ++ func_start = tool_text.find(self.tool_call_prefix) + len( ++ self.tool_call_prefix) ++ func_end = tool_text.find(">", func_start) ++ ++ if func_end != -1: ++ # Found complete function name ++ self.current_function_name = tool_text[func_start:func_end] ++ self.current_tool_id = self._generate_tool_call_id( ++ ) # type: ignore ++ self.header_sent = True ++ self.in_function = True ++ ++ # IMPORTANT: Add to prev_tool_call_arr immediately when we detect a tool call ++ # This ensures finish_reason="tool_calls" even if parsing isn't complete ++ already_added = any( ++ tool.get("name") == self.current_function_name ++ for tool in self.prev_tool_call_arr) ++ if not already_added: ++ self.prev_tool_call_arr.append({ ++ "name": self.current_function_name, ++ "arguments": ++ "{}", # Placeholder, will be updated later ++ }) ++ ++ # Send header with function info ++ return DeltaMessage(tool_calls=[ ++ DeltaToolCall( ++ index=self.current_tool_index, ++ id=self.current_tool_id, ++ function=DeltaFunctionCall( ++ name=self.current_function_name, arguments=""), ++ type="function", ++ ) ++ ]) ++ return None ++ ++ # We've sent header, now handle function body ++ if self.in_function: ++ # Send opening brace if not sent yet ++ if (not self.json_started ++ and self.parameter_prefix not in delta_text): ++ self.json_started = True ++ return DeltaMessage(tool_calls=[ ++ DeltaToolCall( ++ index=self.current_tool_index, ++ function=DeltaFunctionCall(arguments="{"), ++ ) ++ ]) ++ ++ # Make sure json_started is set if we're processing parameters ++ if not self.json_started: ++ self.json_started = True ++ ++ # Check for function end in accumulated text ++ if not self.json_closed and self.function_end_token in tool_text: ++ # Close JSON ++ self.json_closed = True ++ ++ # Extract the complete tool call to update prev_tool_call_arr with final arguments ++ # Find the function content ++ func_start = tool_text.find(self.tool_call_prefix) + len( ++ self.tool_call_prefix) ++ func_content_end = tool_text.find(self.function_end_token, ++ func_start) ++ if func_content_end != -1: ++ func_content = tool_text[func_start:func_content_end] ++ # Parse to get the complete arguments ++ try: ++ parsed_tool = self._parse_xml_function_call( ++ func_content, request.tools if request else None) ++ if parsed_tool: ++ # Update existing entry in prev_tool_call_arr with complete arguments ++ for i, tool in enumerate(self.prev_tool_call_arr): ++ if tool.get( ++ "name") == 
parsed_tool.function.name: ++ self.prev_tool_call_arr[i]["arguments"] = ( ++ parsed_tool.function.arguments) ++ break ++ except Exception: ++ logger.warning( ++ "Failed to parse tool arguments during streaming.", ++ exc_info=True) ++ ++ result = DeltaMessage(tool_calls=[ ++ DeltaToolCall( ++ index=self.current_tool_index, ++ function=DeltaFunctionCall(arguments="}"), ++ ) ++ ]) ++ ++ # Reset state for next tool ++ self.in_function = False ++ self.json_closed = True ++ ++ return result ++ ++ # Look for parameters ++ # Count how many complete parameters we have processed ++ complete_params = tool_text.count(self.parameter_end_token) ++ ++ # Check if we should start a new parameter ++ if not self.in_param and self.param_count < complete_params: ++ # Find the unprocessed parameter ++ # Count parameter starts ++ param_starts = [] ++ idx = 0 ++ while True: ++ idx = tool_text.find(self.parameter_prefix, idx) ++ if idx == -1: ++ break ++ param_starts.append(idx) ++ idx += len(self.parameter_prefix) ++ ++ if len(param_starts) > self.param_count: ++ # Process the next parameter ++ param_idx = param_starts[self.param_count] ++ param_start = param_idx + len(self.parameter_prefix) ++ remaining = tool_text[param_start:] ++ ++ if ">" in remaining: ++ # We have the complete parameter name ++ name_end = remaining.find(">") ++ self.current_param_name = remaining[:name_end] ++ ++ # Find the parameter value ++ value_start = param_start + name_end + 1 ++ value_text = tool_text[value_start:] ++ if value_text.startswith("\n"): ++ value_text = value_text[1:] ++ ++ # Find where this parameter ends ++ param_end_idx = value_text.find( ++ self.parameter_end_token) ++ if param_end_idx != -1: ++ # Complete parameter found ++ param_value = value_text[:param_end_idx] ++ if param_value.endswith("\n"): ++ param_value = param_value[:-1] ++ ++ # Build complete JSON fragment for this parameter ++ if self.param_count == 0: ++ json_fragment = ( ++ '"' + self.current_param_name + '": "' + ++ json.dumps(param_value)[1:-1] + '"') ++ else: ++ json_fragment = ( ++ ', "' + self.current_param_name + '": "' + ++ json.dumps(param_value)[1:-1] + '"') ++ ++ self.param_count += 1 ++ ++ return DeltaMessage(tool_calls=[ ++ DeltaToolCall( ++ index=self.current_tool_index, ++ function=DeltaFunctionCall( ++ arguments=json_fragment), ++ ) ++ ]) ++ ++ # Continue parameter value ++ if self.in_param: ++ if self.parameter_end_token in delta_text: ++ # End of parameter ++ end_idx = delta_text.find(self.parameter_end_token) ++ value_chunk = delta_text[:end_idx] ++ ++ # Skip past > if at start ++ if not self.current_param_value and ">" in value_chunk: ++ gt_idx = value_chunk.find(">") ++ value_chunk = value_chunk[gt_idx + 1:] ++ ++ if not self.current_param_value and value_chunk.startswith( ++ "\n"): ++ value_chunk = value_chunk[1:] ++ ++ # Calculate incremental JSON ++ full_value = self.current_param_value + value_chunk ++ prev_escaped = (json.dumps(self.current_param_value)[1:-1] ++ if self.current_param_value else "") ++ full_escaped = json.dumps(full_value)[1:-1] ++ delta_escaped = full_escaped[len(prev_escaped):] ++ ++ self.in_param = False ++ self.current_param_value = "" ++ ++ return DeltaMessage(tool_calls=[ ++ DeltaToolCall( ++ index=self.current_tool_index, ++ function=DeltaFunctionCall( ++ arguments=delta_escaped + '"'), ++ ) ++ ]) ++ else: ++ # Continue accumulating value ++ value_chunk = delta_text ++ ++ # Handle first chunk after param name ++ if not self.current_param_value and ">" in value_chunk: ++ gt_idx = value_chunk.find(">") ++ 
value_chunk = value_chunk[gt_idx + 1:] ++ ++ if not self.current_param_value and value_chunk.startswith( ++ "\n"): ++ value_chunk = value_chunk[1:] ++ ++ if value_chunk: ++ # Stream the escaped delta ++ prev_escaped = (json.dumps( ++ self.current_param_value)[1:-1] ++ if self.current_param_value else "") ++ self.current_param_value += value_chunk ++ full_escaped = json.dumps( ++ self.current_param_value)[1:-1] ++ delta_escaped = full_escaped[len(prev_escaped):] ++ ++ if delta_escaped: ++ return DeltaMessage(tool_calls=[ ++ DeltaToolCall( ++ index=self.current_tool_index, ++ function=DeltaFunctionCall( ++ arguments=delta_escaped), ++ ) ++ ]) ++ ++ return None diff --git a/vllm/envs.py b/vllm/envs.py index 5c414e82d..56a8d7253 100755 --- a/vllm/envs.py @@ -14172,10 +15349,16 @@ index 12899c280..951215ee0 100644 q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py -index 2aaac7798..4759030be 100644 +index 2aaac7798..6a832ca27 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py -@@ -127,6 +127,7 @@ _TEXT_GENERATION_MODELS = { +@@ -122,11 +122,13 @@ _TEXT_GENERATION_MODELS = { + "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"), + "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"), + "RWForCausalLM": ("falcon", "FalconForCausalLM"), ++ "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"), + "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), + "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "SolarForCausalLM": ("solar", "SolarForCausalLM"), "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), @@ -14336,6 +15519,499 @@ index c6b411644..feb549d44 100644 + position_ids[offset:offset+length] = \ + create_position_ids_from_input_ids(tokens, padding_idx) + offset = offset + length +diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py +new file mode 100644 +index 000000000..34a87a6a6 +--- /dev/null ++++ b/vllm/model_executor/models/seed_oss.py +@@ -0,0 +1,487 @@ ++# SPDX-License-Identifier: Apache-2.0 ++# SPDX-FileCopyrightText: Copyright contributors to the vLLM project ++ ++# Copyright 2025 The Seed team. ++# Copyright 2023 The vLLM team. ++# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. ++# ++# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX ++# and OPT implementations in this library. It has been modified from its ++# original forms to accommodate minor architectural differences compared ++# to GPT-NeoX and OPT used by the Meta AI team that trained the model. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
++"""Inference-only SeedOss model compatible with HuggingFace weights.""" ++from collections.abc import Iterable ++from typing import Optional, Union ++ ++import torch ++from torch import nn ++from transformers import PretrainedConfig as SeedOssConfig ++ ++from vllm.attention import Attention, AttentionType ++from vllm.compilation.decorators import support_torch_compile ++from vllm.config import CacheConfig, VllmConfig ++from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size ++from vllm.logger import init_logger ++from vllm.model_executor.layers.activation import SiluAndMul ++from vllm.model_executor.layers.layernorm import RMSNorm ++from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ++ QKVParallelLinear, ++ RowParallelLinear) ++from vllm.model_executor.layers.logits_processor import LogitsProcessor ++from vllm.model_executor.layers.quantization import QuantizationConfig ++from vllm.model_executor.layers.rotary_embedding import get_rope ++from vllm.model_executor.layers.vocab_parallel_embedding import ( ++ ParallelLMHead, VocabParallelEmbedding) ++from vllm.model_executor.model_loader.weight_utils import ( ++ default_weight_loader, maybe_remap_kv_scale_name) ++from vllm.model_executor.sampling_metadata import SamplingMetadata ++from vllm.sequence import IntermediateTensors ++ ++from .interfaces import SupportsLoRA, SupportsPP ++from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, ++ make_empty_intermediate_tensors_factory, make_layers, ++ maybe_prefix) ++ ++logger = init_logger(__name__) ++ ++ ++class SeedOssMLP(nn.Module): ++ ++ def __init__( ++ self, ++ hidden_size: int, ++ intermediate_size: int, ++ hidden_act: str, ++ quant_config: Optional[QuantizationConfig] = None, ++ prefix: str = "", ++ ) -> None: ++ super().__init__() ++ self.gate_up_proj = MergedColumnParallelLinear( ++ hidden_size, ++ [intermediate_size] * 2, ++ bias=False, ++ quant_config=quant_config, ++ prefix=f"{prefix}.gate_up_proj", ++ ) ++ self.down_proj = RowParallelLinear( ++ intermediate_size, ++ hidden_size, ++ bias=False, ++ quant_config=quant_config, ++ prefix=f"{prefix}.down_proj", ++ ) ++ if hidden_act != "silu": ++ raise ValueError(f"Unsupported activation: {hidden_act}. " ++ "Only silu is supported for now.") ++ self.act_fn = SiluAndMul() ++ ++ def forward(self, x): ++ gate_up, _ = self.gate_up_proj(x) ++ x = self.act_fn(gate_up) ++ x, _ = self.down_proj(x) ++ return x ++ ++ ++class SeedOssAttention(nn.Module): ++ ++ def __init__( ++ self, ++ hidden_size: int, ++ num_heads: int, ++ num_kv_heads: int, ++ head_dim: int, ++ max_position: int = 4096 * 32, ++ rope_theta: float = 10000, ++ cache_config: Optional[CacheConfig] = None, ++ quant_config: Optional[QuantizationConfig] = None, ++ rope_scaling: Optional[tuple] = None, ++ prefix: str = "", ++ attn_type: str = AttentionType.DECODER, ++ ) -> None: ++ super().__init__() ++ self.hidden_size = hidden_size ++ tp_size = get_tensor_model_parallel_world_size() ++ self.total_num_heads = num_heads ++ assert self.total_num_heads % tp_size == 0 ++ self.num_heads = self.total_num_heads // tp_size ++ self.total_num_kv_heads = num_kv_heads ++ self.head_dim = head_dim ++ if self.total_num_kv_heads >= tp_size: ++ # Number of KV heads is greater than TP size, so we partition ++ # the KV heads across multiple tensor parallel GPUs. ++ assert self.total_num_kv_heads % tp_size == 0 ++ else: ++ # Number of KV heads is less than TP size, so we replicate ++ # the KV heads across multiple tensor parallel GPUs. 
++ assert tp_size % self.total_num_kv_heads == 0 ++ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) ++ self.q_size = self.num_heads * self.head_dim ++ self.kv_size = self.num_kv_heads * self.head_dim ++ self.scaling = self.head_dim**-0.5 ++ self.rope_theta = rope_theta ++ ++ self.qkv_proj = QKVParallelLinear( ++ hidden_size, ++ self.head_dim, ++ self.total_num_heads, ++ self.total_num_kv_heads, ++ bias=True, ++ quant_config=quant_config, ++ prefix=f"{prefix}.qkv_proj", ++ ) ++ self.o_proj = RowParallelLinear( ++ self.total_num_heads * self.head_dim, ++ hidden_size, ++ bias=False, ++ quant_config=quant_config, ++ prefix=f"{prefix}.o_proj", ++ ) ++ ++ self.rotary_emb = get_rope( ++ self.head_dim, ++ rotary_dim=self.head_dim, ++ max_position=max_position, ++ base=self.rope_theta, ++ rope_scaling=rope_scaling, ++ ) ++ self.attn = Attention( ++ self.num_heads, ++ self.head_dim, ++ self.scaling, ++ num_kv_heads=self.num_kv_heads, ++ cache_config=cache_config, ++ quant_config=quant_config, ++ attn_type=attn_type, ++ prefix=f"{prefix}.attn", ++ ) ++ ++ def forward( ++ self, ++ positions: torch.Tensor, ++ hidden_states: torch.Tensor, ++ ) -> torch.Tensor: ++ qkv, _ = self.qkv_proj(hidden_states) ++ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) ++ q, k = self.rotary_emb(positions, q, k) ++ attn_output = self.attn(q, k, v) ++ output, _ = self.o_proj(attn_output) ++ return output ++ ++ ++class SeedOssDecoderLayer(nn.Module): ++ ++ def __init__( ++ self, ++ config: SeedOssConfig, ++ cache_config: Optional[CacheConfig] = None, ++ quant_config: Optional[QuantizationConfig] = None, ++ prefix: str = "", ++ ) -> None: ++ super().__init__() ++ self.hidden_size = config.hidden_size ++ # Requires transformers > 4.32.0 ++ rope_theta = getattr(config, "rope_theta", 1000000) ++ rope_scaling = getattr(config, "rope_scaling", None) ++ ++ # By default, SeedOss uses causal attention as it is a ++ # decoder-only model. 
++ # You can override the HF config with `is_causal=False` to enable ++ # bidirectional attention, which is used in some embedding models ++ if getattr(config, "is_causal", True): ++ attn_type = AttentionType.DECODER ++ else: ++ attn_type = AttentionType.ENCODER_ONLY ++ ++ self.self_attn = SeedOssAttention( ++ hidden_size=self.hidden_size, ++ num_heads=config.num_attention_heads, ++ max_position=config.max_position_embeddings, ++ num_kv_heads=config.num_key_value_heads, ++ head_dim=config.head_dim, ++ rope_theta=rope_theta, ++ cache_config=cache_config, ++ quant_config=quant_config, ++ rope_scaling=rope_scaling, ++ prefix=f"{prefix}.self_attn", ++ attn_type=attn_type, ++ ) ++ self.mlp = SeedOssMLP( ++ hidden_size=self.hidden_size, ++ intermediate_size=config.intermediate_size, ++ hidden_act=config.hidden_act, ++ quant_config=quant_config, ++ prefix=f"{prefix}.mlp", ++ ) ++ self.input_layernorm = RMSNorm(config.hidden_size, ++ eps=config.rms_norm_eps) ++ self.post_attention_layernorm = RMSNorm(config.hidden_size, ++ eps=config.rms_norm_eps) ++ ++ def forward( ++ self, ++ positions: torch.Tensor, ++ hidden_states: torch.Tensor, ++ residual: Optional[torch.Tensor], ++ ) -> tuple[torch.Tensor, torch.Tensor]: ++ # Self Attention ++ if residual is None: ++ residual = hidden_states ++ hidden_states = self.input_layernorm(hidden_states) ++ else: ++ hidden_states, residual = self.input_layernorm( ++ hidden_states, residual) ++ hidden_states = self.self_attn( ++ positions=positions, ++ hidden_states=hidden_states, ++ ) ++ ++ # Fully Connected ++ hidden_states, residual = self.post_attention_layernorm( ++ hidden_states, residual) ++ hidden_states = self.mlp(hidden_states) ++ return hidden_states, residual ++ ++ ++@support_torch_compile( ++ dynamic_arg_dims={ ++ "input_ids": 0, ++ "positions": -1, ++ "intermediate_tensors": 0, ++ "inputs_embeds": 0, ++ }) ++class SeedOssModel(nn.Module): ++ ++ def __init__(self, ++ *, ++ vllm_config: VllmConfig, ++ prefix: str = "", ++ decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer): ++ super().__init__() ++ ++ config = vllm_config.model_config.hf_config ++ cache_config = vllm_config.cache_config ++ quant_config = vllm_config.quant_config ++ ++ # TODO (@robertgshaw2): see if this can be moved out ++ if (cache_config.sliding_window is not None ++ and hasattr(config, "max_window_layers")): ++ assert config.max_window_layers == config.num_hidden_layers, ( ++ "Sliding window for some but all layers is not supported. " ++ "This model uses sliding window but `max_window_layers` = {} " ++ "is less than `num_hidden_layers` = {}. 
Please open an issue " ++ "to discuss this feature.".format( ++ config.max_window_layers, ++ config.num_hidden_layers, ++ )) ++ ++ self.config = config ++ self.quant_config = quant_config ++ self.vocab_size = config.vocab_size ++ ++ if get_pp_group().is_first_rank or (config.tie_word_embeddings ++ and get_pp_group().is_last_rank): ++ self.embed_tokens = VocabParallelEmbedding( ++ config.vocab_size, ++ config.hidden_size, ++ quant_config=quant_config, ++ prefix=f"{prefix}.embed_tokens", ++ ) ++ else: ++ self.embed_tokens = PPMissingLayer() ++ ++ # Use the provided decoder layer type or default to SeedDecoderLayer ++ decoder_layer_type = decoder_layer_type or SeedOssDecoderLayer ++ self.start_layer, self.end_layer, self.layers = make_layers( ++ config.num_hidden_layers, ++ lambda prefix: decoder_layer_type(config=config, ++ cache_config=cache_config, ++ quant_config=quant_config, ++ prefix=prefix), ++ prefix=f"{prefix}.layers", ++ ) ++ ++ self.make_empty_intermediate_tensors = ( ++ make_empty_intermediate_tensors_factory( ++ ["hidden_states", "residual"], config.hidden_size)) ++ if get_pp_group().is_last_rank: ++ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ else: ++ self.norm = PPMissingLayer() ++ ++ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: ++ return self.embed_tokens(input_ids) ++ ++ def forward( ++ self, ++ input_ids: torch.Tensor, ++ positions: torch.Tensor, ++ intermediate_tensors: Optional[IntermediateTensors] = None, ++ inputs_embeds: Optional[torch.Tensor] = None, ++ ) -> Union[torch.Tensor, IntermediateTensors]: ++ if get_pp_group().is_first_rank: ++ if inputs_embeds is not None: ++ hidden_states = inputs_embeds ++ else: ++ hidden_states = self.get_input_embeddings(input_ids) ++ residual = None ++ else: ++ assert intermediate_tensors is not None ++ hidden_states = intermediate_tensors["hidden_states"] ++ residual = intermediate_tensors["residual"] ++ for layer in self.layers[self.start_layer:self.end_layer]: ++ hidden_states, residual = layer( ++ positions, ++ hidden_states, ++ residual, ++ ) ++ if not get_pp_group().is_last_rank: ++ return IntermediateTensors({ ++ "hidden_states": hidden_states, ++ "residual": residual ++ }) ++ hidden_states, _ = self.norm(hidden_states, residual) ++ return hidden_states ++ ++ def load_weights(self, weights: Iterable[tuple[str, ++ torch.Tensor]]) -> set[str]: ++ stacked_params_mapping = [ ++ # (param_name, shard_name, shard_id) ++ ("qkv_proj", "q_proj", "q"), ++ ("qkv_proj", "k_proj", "k"), ++ ("qkv_proj", "v_proj", "v"), ++ ("gate_up_proj", "gate_proj", 0), ++ ("gate_up_proj", "up_proj", 1), ++ ] ++ params_dict = dict(self.named_parameters(remove_duplicate=False)) ++ loaded_params: set[str] = set() ++ for name, loaded_weight in weights: ++ if "rotary_emb.inv_freq" in name: ++ continue ++ if (self.quant_config is not None and ++ (scale_name := self.quant_config.get_cache_scale(name))): ++ # Loading kv cache quantization scales ++ param = params_dict[scale_name] ++ weight_loader = getattr(param, "weight_loader", ++ default_weight_loader) ++ loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else ++ loaded_weight[0]) ++ weight_loader(param, loaded_weight) ++ loaded_params.add(scale_name) ++ continue ++ for (param_name, weight_name, shard_id) in stacked_params_mapping: ++ if weight_name not in name: ++ continue ++ name = name.replace(weight_name, param_name) ++ # Skip loading extra bias for GPTQ models. 
++ if name.endswith(".bias") and name not in params_dict: ++ continue ++ if is_pp_missing_parameter(name, self): ++ continue ++ param = params_dict[name] ++ weight_loader = param.weight_loader ++ weight_loader(param, loaded_weight, shard_id) ++ break ++ else: ++ # Skip loading extra bias for GPTQ models. ++ if name.endswith(".bias") and name not in params_dict: ++ continue ++ # Remapping the name of FP8 kv-scale. ++ name = maybe_remap_kv_scale_name(name, params_dict) ++ if name is None: ++ continue ++ if is_pp_missing_parameter(name, self): ++ continue ++ param = params_dict[name] ++ weight_loader = getattr(param, "weight_loader", ++ default_weight_loader) ++ weight_loader(param, loaded_weight) ++ loaded_params.add(name) ++ return loaded_params ++ ++ ++class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ++ packed_modules_mapping = { ++ "qkv_proj": [ ++ "q_proj", ++ "k_proj", ++ "v_proj", ++ ], ++ "gate_up_proj": [ ++ "gate_proj", ++ "up_proj", ++ ], ++ } ++ ++ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ++ super().__init__() ++ config = vllm_config.model_config.hf_config ++ quant_config = vllm_config.quant_config ++ lora_config = vllm_config.lora_config ++ ++ self.config = config ++ self.lora_config = lora_config ++ ++ self.quant_config = quant_config ++ self.model = SeedOssModel(vllm_config=vllm_config, ++ prefix=maybe_prefix(prefix, "model")) ++ ++ if get_pp_group().is_last_rank: ++ if config.tie_word_embeddings: ++ self.lm_head = self.model.embed_tokens ++ else: ++ self.lm_head = ParallelLMHead(config.vocab_size, ++ config.hidden_size, ++ quant_config=quant_config, ++ prefix=maybe_prefix( ++ prefix, "lm_head")) ++ else: ++ self.lm_head = PPMissingLayer() ++ ++ self.logits_processor = LogitsProcessor(config.vocab_size) ++ ++ self.make_empty_intermediate_tensors = ( ++ self.model.make_empty_intermediate_tensors) ++ ++ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: ++ return self.model.get_input_embeddings(input_ids) ++ ++ def forward( ++ self, ++ input_ids: torch.Tensor, ++ positions: torch.Tensor, ++ intermediate_tensors: Optional[IntermediateTensors] = None, ++ inputs_embeds: Optional[torch.Tensor] = None, ++ ) -> Union[torch.Tensor, IntermediateTensors]: ++ hidden_states = self.model(input_ids, positions, intermediate_tensors, ++ inputs_embeds) ++ return hidden_states ++ ++ def compute_logits( ++ self, ++ hidden_states: torch.Tensor, ++ sampling_metadata: SamplingMetadata, ++ ) -> Optional[torch.Tensor]: ++ logits = self.logits_processor(self.lm_head, hidden_states, ++ sampling_metadata) ++ return logits ++ ++ def load_weights(self, weights: Iterable[tuple[str, ++ torch.Tensor]]) -> set[str]: ++ loader = AutoWeightsLoader( ++ self, ++ skip_prefixes=(["lm_head."] ++ if self.config.tie_word_embeddings else None), ++ ) ++ return loader.load_weights(weights) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 3630f59f5..62566d8f6 100644 --- a/vllm/model_executor/models/siglip.py @@ -14611,6 +16287,249 @@ index 8d1f59e6e..0d96bcfef 100644 def thinker_uses_mrope(config: PretrainedConfig) -> bool: +diff --git a/vllm/utils/tensor_schema.py b/vllm/utils/tensor_schema.py +new file mode 100644 +index 000000000..0959d4a00 +--- /dev/null ++++ b/vllm/utils/tensor_schema.py +@@ -0,0 +1,236 @@ ++# SPDX-License-Identifier: Apache-2.0 ++# SPDX-FileCopyrightText: Copyright contributors to the vLLM project ++from typing import (Annotated, Any, Optional, Union, get_args, get_origin, ++ get_type_hints) ++ 
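++# The classes below let a schema class annotate tensor-valued fields with
++# `Annotated[..., TensorShape(...)]` and have their ranks, fixed dims and
++# shared symbolic dims checked at construction time via
++# `TensorSchema.validate()`.
++#
++# A minimal usage sketch (the class and field names are illustrative only):
++#
++#     class ImagePixelInputs(TensorSchema):
++#         pixel_values: Annotated[torch.Tensor,
++#                                 TensorShape("bn", 3, "h", "w")]
++#         image_sizes: Annotated[torch.Tensor, TensorShape("bn", 2)]
++#
++#     ImagePixelInputs(pixel_values=pv, image_sizes=sizes)  # runs validate()
++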
++import torch ++ ++from vllm.logger import init_logger ++ ++logger = init_logger(__name__) ++ ++ ++class TensorShape: ++ ++ def __init__( ++ self, ++ *dims: Union[int, str], ++ dynamic_dims: Optional[set[str]] = None, ++ ) -> None: ++ super().__init__() ++ ++ self.dims = dims ++ self.dynamic_dims = dynamic_dims if dynamic_dims else set() ++ ++ def resolve(self, **bindings: dict[str, ++ int]) -> tuple[Union[int, str], ...]: ++ resolved = [] ++ for dim in self.dims: ++ if isinstance(dim, str) and dim in bindings: ++ resolved.append(bindings[dim]) ++ else: ++ resolved.append(dim) ++ return tuple(resolved) ++ ++ def __str__(self) -> str: ++ """Return a string representation of the tensor shape.""" ++ dim_strs = [] ++ for dim in self.dims: ++ if isinstance(dim, str): ++ if dim in self.dynamic_dims: ++ dim_strs.append( ++ f"{dim}*") # Mark dynamic dimensions with * ++ else: ++ dim_strs.append(dim) ++ else: ++ dim_strs.append(str(dim)) ++ return f"({', '.join(dim_strs)})" ++ ++ ++class TensorSchema: ++ ++ def __init__( ++ self, ++ *, ++ validate: bool = True, ++ resolve_bindings: Optional[dict[str, int]] = None, ++ **kwargs: Any, ++ ) -> None: ++ super().__init__() ++ ++ self._resolve_bindings = resolve_bindings if resolve_bindings else {} ++ ++ for key, value in kwargs.items(): ++ setattr(self, key, value) ++ ++ if validate: ++ self.validate() ++ ++ def __getitem__(self, key: str) -> Any: ++ return getattr(self, key) ++ ++ def get(self, key: str, default: Any = None) -> Any: ++ return getattr(self, key, default) ++ ++ def _match_shape_with_dynamic( ++ self, ++ actual: tuple[int, ...], ++ reference: tuple[int, ...], ++ expected_shape: tuple[Union[int, str], ...], ++ dynamic_dims: set[str], ++ ) -> bool: ++ if len(actual) != len(reference) or len(actual) > len(expected_shape): ++ return False ++ ++ for i, (a, r) in enumerate(zip(actual, reference)): ++ # When validating list inputs, we match shape suffixes only ++ # (e.g. "p", 3, "h", "w"), assuming the list length corresponds ++ # to the leading symbolic dim (e.g. "bn"). This allows comparing ++ # only the trailing dimensions of each element in the list. 
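++            # For example, a list element of shape (p, 3, h, w) is compared
++            # against an expected shape of ("bn", "p", 3, "h", "w") using
++            # only its trailing four entries.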
++ dim = expected_shape[-len(actual) + i] ++ # Skip this dimension if it's marked dynamic ++ if dim in dynamic_dims: ++ continue ++ if a != r: ++ return False ++ return True ++ ++ def _validate_nested_tensors( ++ self, ++ value: Union[list[torch.Tensor], tuple[torch.Tensor, ...]], ++ field_name: str, ++ expected_shape: tuple[Union[int, str], ...], ++ dynamic_dims: set[str], ++ ) -> tuple[int, ...]: ++ """Validate a list/tuple of tensors and return the actual shape.""" ++ # Ensure all tensors in the list have the same ++ # shape, besides dynamic dimensions ++ first = value[0] ++ for i, v in enumerate(value): ++ if not isinstance(v, torch.Tensor): ++ raise ValueError(f"{field_name}[{i}] is not a " ++ f"torch.Tensor") ++ if not self._match_shape_with_dynamic( ++ v.shape, ++ first.shape, ++ expected_shape, ++ dynamic_dims, ++ ): ++ raise ValueError(f"{field_name} contains inconsistent " ++ f"shapes: {first.shape} vs {v.shape} " ++ f"at index {i}") ++ ++ # Treat the list as a stacked tensor: ++ # shape = (len(list), *tensor.shape) ++ return (len(value), ) + first.shape ++ ++ def _validate_tensor_shape_expected( ++ self, ++ actual_shape: tuple[int, ...], ++ expected_shape: tuple[Union[int, str], ...], ++ field_name: str, ++ shape_env: dict[str, int], ++ dynamic_dims: set[str], ++ ) -> None: ++ """Validate that the actual tensor shape matches the expected shape.""" ++ ++ if len(actual_shape) != len(expected_shape): ++ raise ValueError(f"{field_name} has rank {len(actual_shape)} " ++ f"but expected {len(expected_shape)}") ++ ++ for i, dim in enumerate(expected_shape): ++ if dim in dynamic_dims: ++ continue ++ elif isinstance(dim, int): ++ if actual_shape[i] != dim: ++ raise ValueError(f"{field_name} dim[{i}] expected " ++ f"{dim}, got {actual_shape[i]}") ++ elif isinstance(dim, str): ++ if dim in shape_env: ++ if actual_shape[i] != shape_env[dim]: ++ raise ValueError(f"{field_name} dim[{i}] expected " ++ f"'{dim}'={shape_env[dim]}, got " ++ f"{actual_shape[i]}") ++ else: ++ shape_env[dim] = actual_shape[i] ++ else: ++ raise TypeError(f"{field_name} dim[{i}] has unsupported " ++ f"type: {type(dim)}") ++ ++ def validate(self) -> None: ++ type_hints = get_type_hints(self.__class__, include_extras=True) ++ shape_env = {} ++ ++ for field_name, field_type in type_hints.items(): ++ # Check if field is missing ++ if (not hasattr(self, field_name) ++ or getattr(self, field_name) is None): ++ # Check if field is marked as optional ++ actual_type = field_type ++ if get_origin(field_type) is Annotated: ++ args = get_args(field_type) ++ actual_type = args[0] ++ ++ # Check arg was provided as Union ++ if get_origin(actual_type) is Union: ++ args = get_args(actual_type) ++ # Skip validation when Union contains None ++ if type(None) in args: ++ continue ++ # Otherwise field is required, raise error ++ raise ValueError(f"Required field '{field_name}' is missing") ++ ++ # Field exists, proceed with validation ++ value = getattr(self, field_name) ++ if get_origin(field_type) is not None: ++ args = get_args(field_type) ++ ++ for arg in args: ++ if isinstance(arg, TensorShape): ++ expected_shape = arg.resolve(**self._resolve_bindings) ++ if isinstance(value, (list, tuple)): ++ # list/tuple of Tensors → shape = (len(value), ...) 
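++                            # e.g. a list of "bn" tensors with consistent
++                            # shapes is validated as a single stacked tensor
++                            # of shape (bn, *element_shape).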
++ if value and isinstance(value[0], torch.Tensor): ++ actual_shape = self._validate_nested_tensors( ++ value, field_name, expected_shape, ++ arg.dynamic_dims) ++ elif value: ++ # list/tuple of scalars → shape = (len(value),) ++ actual_shape = (len(value), ) ++ else: ++ raise ValueError( ++ f"{field_name} is an empty list") ++ ++ # Tensor → shape = tensor.shape ++ elif isinstance(value, torch.Tensor): ++ actual_shape = value.shape ++ ++ # Otherwise, it's an unsupported type ++ else: ++ type_names = [] ++ for arg in args: ++ if hasattr(arg, "__name__"): ++ type_names.append(str(arg.__name__)) ++ else: ++ type_names.append(str(arg)) ++ ++ expected_types = ", ".join(type_names) ++ raise ValueError( ++ f"{field_name} is not one of the expected " ++ f"types: {expected_types}") ++ ++ self._validate_tensor_shape_expected( ++ actual_shape, expected_shape, field_name, ++ shape_env, arg.dynamic_dims) ++ ++ def print_shapes(self) -> None: ++ """Print TensorShape annotations for debugging.""" ++ logger.debug("Shapes in %s:", self.__class__.__name__) ++ type_hints = get_type_hints(self.__class__, include_extras=True) ++ ++ for field_name, field_type in type_hints.items(): ++ if get_origin(field_type) is not None: ++ args = get_args(field_type) ++ for arg in args: ++ if isinstance(arg, TensorShape): ++ logger.debug(" %s: %s", field_name, str(arg)) +\ No newline at end of file diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 5fe274f2c..4a8657ee5 100755 --- a/vllm/v1/attention/backends/flash_attn.py