From 3285944bc917ff6f659384ad0519a7cd4640ec95 Mon Sep 17 00:00:00 2001 From: gc-fu Date: Tue, 28 Oct 2025 14:01:57 +0800 Subject: [PATCH 1/8] update patches Signed-off-by: gc-fu --- vllm/patches/vllm_for_multi_arc.patch | 23452 +++++++----------------- 1 file changed, 7042 insertions(+), 16410 deletions(-) diff --git a/vllm/patches/vllm_for_multi_arc.patch b/vllm/patches/vllm_for_multi_arc.patch index e961e2a..25e1071 100644 --- a/vllm/patches/vllm_for_multi_arc.patch +++ b/vllm/patches/vllm_for_multi_arc.patch @@ -1,5 +1,5 @@ diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -index b98d42aa7..b2a1ebef2 100644 +index 792f355c4..af2c24c4c 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do @@ -10,6 +10,107 @@ index b98d42aa7..b2a1ebef2 100644 + --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend=mp,trust_remote_code=true,max_model_len=4096,enforce_eager=true,max_num_batched_tokens=4096" \ --tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \ --batch_size "$BATCH_SIZE" +diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml +index 8c6ef7817..a1de41652 100644 +--- a/.buildkite/release-pipeline.yaml ++++ b/.buildkite/release-pipeline.yaml +@@ -1,22 +1,24 @@ + steps: + # aarch64 + CUDA builds. PyTorch 2.8 aarch64 + CUDA wheel is only available on CUDA 12.9 + - label: "Build arm64 wheel - CUDA 12.9" +- depends_on: ~ + id: build-wheel-arm64-cuda-12-9 + agents: + queue: arm64_cpu_queue_postmerge + commands: + # #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here: + # https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7 +- - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg VLLM_MAIN_CUDA_VERSION=12.9 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." ++ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." 
+ - "mkdir artifacts" + - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" + - "bash .buildkite/scripts/upload-wheels.sh" + env: + DOCKER_BUILDKIT: "1" + ++ - block: "Build CUDA 12.8 wheel" ++ key: block-build-cu128-wheel ++ + - label: "Build wheel - CUDA 12.8" +- depends_on: ~ ++ depends_on: block-build-cu128-wheel + id: build-wheel-cuda-12-8 + agents: + queue: cpu_queue_postmerge +@@ -28,8 +30,12 @@ steps: + env: + DOCKER_BUILDKIT: "1" + +- - label: "Build wheel - CUDA 12.6" ++ - block: "Build CUDA 12.6 wheel" ++ key: block-build-cu126-wheel + depends_on: ~ ++ ++ - label: "Build wheel - CUDA 12.6" ++ depends_on: block-build-cu126-wheel + id: build-wheel-cuda-12-6 + agents: + queue: cpu_queue_postmerge +@@ -96,6 +102,8 @@ steps: + depends_on: + - create-multi-arch-manifest + - build-wheel-cuda-12-8 ++ - build-wheel-cuda-12-6 ++ - build-wheel-cuda-12-9 + id: annotate-release-workflow + agents: + queue: cpu_queue_postmerge +diff --git a/.buildkite/scripts/annotate-release.sh b/.buildkite/scripts/annotate-release.sh +index fde48603a..94e0ac239 100755 +--- a/.buildkite/scripts/annotate-release.sh ++++ b/.buildkite/scripts/annotate-release.sh +@@ -14,33 +14,18 @@ buildkite-agent annotate --style 'info' --context 'release-workflow' << EOF + To download the wheel: + \`\`\` + aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux1_x86_64.whl . +-aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux2014_aarch64.whl . +- + aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu126/vllm-${RELEASE_VERSION}+cu126-cp38-abi3-manylinux1_x86_64.whl . +-aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu129/vllm-${RELEASE_VERSION}+cu129-cp38-abi3-manylinux1_x86_64.whl . ++aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu118/vllm-${RELEASE_VERSION}+cu118-cp38-abi3-manylinux1_x86_64.whl . 
+ \`\`\` + + To download and upload the image: + + \`\`\` +-docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}-x86_64 +-docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}-aarch64 +- +-docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}-x86_64 vllm/vllm-openai:x86_64 +-docker tag vllm/vllm-openai:x86_64 vllm/vllm-openai:latest-x86_64 +-docker tag vllm/vllm-openai:x86_64 vllm/vllm-openai:v${RELEASE_VERSION}-x86_64 +-docker push vllm/vllm-openai:latest-x86_64 +-docker push vllm/vllm-openai:v${RELEASE_VERSION}-x86_64 +- +-docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}-aarch64 vllm/vllm-openai:aarch64 +-docker tag vllm/vllm-openai:aarch64 vllm/vllm-openai:latest-aarch64 +-docker tag vllm/vllm-openai:aarch64 vllm/vllm-openai:v${RELEASE_VERSION}-aarch64 +-docker push vllm/vllm-openai:latest-aarch64 +-docker push vllm/vllm-openai:v${RELEASE_VERSION}-aarch64 +- +-docker manifest create vllm/vllm-openai:latest vllm/vllm-openai:latest-x86_64 vllm/vllm-openai:latest-aarch64 --amend +-docker manifest create vllm/vllm-openai:v${RELEASE_VERSION} vllm/vllm-openai:v${RELEASE_VERSION}-x86_64 vllm/vllm-openai:v${RELEASE_VERSION}-aarch64 --amend +-docker manifest push vllm/vllm-openai:latest +-docker manifest push vllm/vllm-openai:v${RELEASE_VERSION} ++docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT} ++docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT} vllm/vllm-openai ++docker tag vllm/vllm-openai vllm/vllm-openai:latest ++docker tag vllm/vllm-openai vllm/vllm-openai:v${RELEASE_VERSION} ++docker push vllm/vllm-openai:latest ++docker push vllm/vllm-openai:v${RELEASE_VERSION} + \`\`\` + EOF +\ No newline at end of file diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 000000000..aef250abe @@ -388,93 +489,8 @@ index 000000000..eaa2f332a + else + echo "✅ All benchmarks passed" + fi -diff --git a/.github/workflows/cleanup_pr_body.yml b/.github/workflows/cleanup_pr_body.yml -deleted file mode 100644 -index d5c6b8d43..000000000 ---- a/.github/workflows/cleanup_pr_body.yml -+++ /dev/null -@@ -1,31 +0,0 @@ --name: Cleanup PR Body -- --on: -- pull_request_target: -- types: [opened, reopened, edited] -- --permissions: -- pull-requests: write -- --jobs: -- update-description: -- runs-on: ubuntu-latest -- -- steps: -- - name: Checkout repository -- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 -- -- - name: Set up Python -- uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 -- with: -- python-version: '3.12' -- -- - name: Install Python dependencies -- run: | -- python3 -m pip install --upgrade pip -- python3 -m pip install regex -- -- - name: Update PR description -- env: -- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -- run: bash .github/scripts/cleanup_pr_body.sh "${{ github.event.number }}" -diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml -deleted file mode 100644 -index 16ae1aadb..000000000 ---- a/.github/workflows/reminder_comment.yml -+++ /dev/null -@@ -1,27 +0,0 @@ --name: PR Reminder Comment Bot --permissions: -- pull-requests: write --on: -- pull_request_target: -- types: [opened] --jobs: -- pr_reminder: -- runs-on: ubuntu-latest -- steps: -- - name: Remind to run full CI on PR -- uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 -- with: -- script: | -- github.rest.issues.createComment({ -- owner: 
context.repo.owner, -- repo: context.repo.repo, -- issue_number: context.issue.number, -- body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' + -- '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' + -- 'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org.\n\n' + -- 'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' + -- 'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' + -- '🚀' -- }) -- env: -- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -diff --git a/CMakeLists.txt b/CMakeLists.txt -index 98ed682fe..5dd6e907c 100644 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -95,6 +95,10 @@ if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND - NOT VLLM_TARGET_DEVICE STREQUAL "rocm") - if (VLLM_TARGET_DEVICE STREQUAL "cpu") - include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake) -+ elseif(VLLM_TARGET_DEVICE STREQUAL "xpu") -+ message(STATUS "Building XPU") -+ set(VLLM_GPU_LANG "SYCL") -+ include(${CMAKE_CURRENT_LIST_DIR}/cmake/xpu_extension.cmake) - else() - return() - endif() diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py -index c7229dbb8..72531f3fc 100644 +index ba7c733be..61a9eeb91 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -18,7 +18,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizer @@ -486,13648 +502,6499 @@ index c7229dbb8..72531f3fc 100644 @dataclass -diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py -index c597fb106..5bad6645b 100644 ---- a/benchmarks/benchmark_serving.py -+++ b/benchmarks/benchmark_serving.py -@@ -256,10 +256,11 @@ async def benchmark( - raise ValueError(f"Unknown backend: {backend}") +diff --git a/docker/Dockerfile b/docker/Dockerfile +index d4761e84f..307e9658f 100644 +--- a/docker/Dockerfile ++++ b/docker/Dockerfile +@@ -196,7 +196,6 @@ ARG SCCACHE_S3_NO_CREDENTIALS=0 + + # Flag to control whether to use pre-built vLLM wheels + ARG VLLM_USE_PRECOMPILED="" +-ARG VLLM_MAIN_CUDA_VERSION="" + + # if USE_SCCACHE is set, use sccache to speed up compilation + RUN --mount=type=cache,target=/root/.cache/uv \ +@@ -214,7 +213,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \ + && export SCCACHE_IDLE_TIMEOUT=0 \ + && export CMAKE_BUILD_TYPE=Release \ + && export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" \ +- && export VLLM_MAIN_CUDA_VERSION="${VLLM_MAIN_CUDA_VERSION}" \ + && export VLLM_DOCKER_BUILD_CONTEXT=1 \ + && sccache --show-stats \ + && python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \ +diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu +index ef4223525..ffa7c6ea7 100644 +--- a/docker/Dockerfile.xpu ++++ b/docker/Dockerfile.xpu +@@ -62,7 +62,7 @@ FROM vllm-base AS vllm-openai - print("Starting initial single prompt test run...") -+ # set test_output_len=10 to avoid long prompt test run - test_prompt, test_prompt_len, test_output_len, 
test_mm_content = ( - input_requests[0].prompt, - input_requests[0].prompt_len, -- input_requests[0].expected_output_len, -+ 10, - input_requests[0].multi_modal_data, - ) + # install additional dependencies for openai api server + RUN --mount=type=cache,target=/root/.cache/pip \ +- pip install accelerate hf_transfer pytest pytest_asyncio lm_eval[api] modelscope ++ pip install accelerate hf_transfer pytest pytest_asyncio lm_eval[api] 'modelscope!=1.15.0' -diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py -index 14461121f..e9b9f0b77 100644 ---- a/benchmarks/benchmark_throughput.py -+++ b/benchmarks/benchmark_throughput.py -@@ -44,6 +44,7 @@ def run_vllm( - n: int, - engine_args: EngineArgs, - disable_detokenize: bool = False, -+ do_profile: bool = False, - ) -> tuple[float, Optional[list[RequestOutput]]]: - from vllm import LLM, SamplingParams + RUN --mount=type=cache,target=/root/.cache/pip \ + pip uninstall oneccl oneccl-devel -y +diff --git a/docs/features/quantization/fp8.md b/docs/features/quantization/fp8.md +index 834c03cbe..439e1e0d7 100644 +--- a/docs/features/quantization/fp8.md ++++ b/docs/features/quantization/fp8.md +@@ -134,4 +134,4 @@ print(result[0].outputs[0].text) + ``` -@@ -89,10 +90,14 @@ def run_vllm( - outputs = None - if not use_beam_search: - start = time.perf_counter() -+ if do_profile: -+ llm.start_profile() - outputs = llm.generate( - prompts, sampling_params, lora_request=lora_requests, use_tqdm=True - ) - end = time.perf_counter() -+ if do_profile: -+ llm.stop_profile() - else: - assert lora_requests is None, "BeamSearch API does not support LoRA" - prompts = [request.prompt for request in requests] -@@ -410,6 +415,7 @@ def main(args: argparse.Namespace): - args.n, - EngineArgs.from_cli_args(args), - args.disable_detokenize, -+ args.profile - ) - elif args.backend == "hf": - assert args.tensor_parallel_size == 1 -@@ -647,6 +653,10 @@ def create_argument_parser(): - parser.add_argument( - "--num-prompts", type=int, default=1000, help="Number of prompts to process." - ) -+ parser.add_argument("--profile", -+ action='store_true', -+ default=False, -+ help="whether run with profiler.") - parser.add_argument( - "--hf-max-batch-size", - type=int, -diff --git a/cmake/utils.cmake b/cmake/utils.cmake -index 621179a70..9e1f4e9c7 100644 ---- a/cmake/utils.cmake -+++ b/cmake/utils.cmake -@@ -445,7 +445,7 @@ function (define_gpu_extension_target GPU_MOD_NAME) - GPU - "WITH_SOABI" - "DESTINATION;LANGUAGE;USE_SABI" -- "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES") -+ "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES;LINK_FLAGS") + !!! warning +- Currently, we load the model at original precision before quantizing down to 8-bits, so you need enough memory to load the whole model. ++ Currently, by default we load the model at original precision before quantizing down to 8-bits, so you need enough memory to load the whole model. To avoid this, adding `VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT=1` can allow offloading weights to cpu before quantization and quantized weights will be kept in device. +diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md +index db3dd2c25..7d3577b14 100644 +--- a/docs/models/supported_models.md ++++ b/docs/models/supported_models.md +@@ -340,6 +340,7 @@ th { + | `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. 
| ✅︎ | ✅︎ | ✅︎ | + | `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ | ✅︎ | + | `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | ✅︎ | ++| `DotsOCRForCausalLM` | dots_ocr | `rednote-hilab/dots.ocr` | | ✅︎ | ✅︎ | + | `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | ✅︎ | + | `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | ✅︎ | + | `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +@@ -667,6 +668,9 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen + | `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + IE+ + VE+ | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | + | `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + IE+ + VE+ | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | + | `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-3B`, `Qwen/Qwen2.5-Omni-7B` | ✅︎ | ✅︎ | ✅︎ | ++| `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + IE+ + VE+ | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | ++| `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + IE+ + VE+ | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | ++| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, `Qwen/Qwen3-Omni-30B-A3B-Thinking` | ✅︎ | ✅︎ | ✅︎ | + | `RForConditionalGeneration` | R-VL-4B | T + IE+ | `YannQi/R-4B` | | ✅︎ | ✅︎ | + | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | + | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | +@@ -757,8 +761,7 @@ Some models are supported only via the [Transformers backend](#transformers). Th + Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1. - # Add hipify preprocessing step when building with HIP/ROCm. - if (GPU_LANGUAGE STREQUAL "HIP") -@@ -487,6 +487,11 @@ function (define_gpu_extension_target GPU_MOD_NAME) + !!! note +- For Qwen2.5-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) +- is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1. ++ For Qwen2.5-Omni and Qwen3-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) is currently work in progress and not yet supported. - target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES}) + #### Transcription -+ if (GPU_LANGUAGE STREQUAL "SYCL") -+ target_compile_options(${GPU_MOD_NAME} PRIVATE ${GPU_COMPILE_FLAGS}) -+ target_link_options(${GPU_MOD_NAME} PRIVATE ${GPU_LINK_FLAGS}) -+ endif() -+ - # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of - # dependencies that are not necessary and may not be installed. 
- if (GPU_LANGUAGE STREQUAL "CUDA") -diff --git a/cmake/xpu_extension.cmake b/cmake/xpu_extension.cmake +diff --git a/examples/bmg/reasoning.py b/examples/bmg/reasoning.py new file mode 100644 -index 000000000..fd671a6bf +index 000000000..04f91786e --- /dev/null -+++ b/cmake/xpu_extension.cmake -@@ -0,0 +1,62 @@ -+set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -+ -+# -+# Define environment variables for special configurations -+# -+# TODO: detect Intel GPU Architecture(PVC or Arc) to add AOT flag. ++++ b/examples/bmg/reasoning.py +@@ -0,0 +1,27 @@ ++from openai import OpenAI + -+# -+# Check the compile flags -+# -+# append_cmake_prefix_path("intel_extension_for_pytorch" "intel_extension_for_pytorch.cmake_prefix_path") -+# find_package(IPEX REQUIRED) -+# IPEX will overwrite TORCH_LIBRARIES, so re-add torch_python lib. -+append_torchlib_if_found(torch_python) -+# include_directories(${IPEX_INCLUDE_DIRS}) -+set(CMPLR_ROOT $ENV{CMPLR_ROOT}) -+set(CMAKE_CXX_COMPILER icpx) -+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing") -+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") -+set(VLLM_EXTRA_INCLUDE_DIRECTORIES ${CMPLR_ROOT}/include/sycl) -+ -+list(APPEND VLLM_GPU_FLAGS "-fsycl" "-fsycl-targets=spir64") -+list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64") -+list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "dnnl" ) -+ -+# -+# Define extension targets -+# ++# Modify OpenAI's API key and API base to use vLLM's API server. ++openai_api_key = "EMPTY" ++openai_api_base = "http://0.0.0.0:8000/v1" + -+# -+# _C extension -+# -+set(VLLM_EXT_SRC -+ "csrc/xpu/activation_xpu.cpp" -+ "csrc/xpu/attention_xpu.cpp" -+ "csrc/xpu/attention_xpu_fp8.cpp" -+ "csrc/xpu/cache_ops_xpu.cpp" -+ "csrc/xpu/cache_ops_xpu_fp8.cpp" -+ "csrc/xpu/gemm_kernels_xpu.cpp" -+ "csrc/xpu/layernorm_xpu.cpp" -+ "csrc/xpu/pos_encoding_xpu.cpp" -+ "csrc/xpu/utils.cpp" -+ "csrc/xpu/fused_moe.cpp" -+ "csrc/xpu/pybind.cpp") -+ -+define_gpu_extension_target( -+ _C -+ DESTINATION vllm -+ LANGUAGE ${VLLM_GPU_LANG} -+ SOURCES ${VLLM_EXT_SRC} -+ COMPILE_FLAGS ${VLLM_GPU_FLAGS} -+ LINK_FLAGS ${VLLM_GPU_LINK_FLAGS} -+ ARCHITECTURES ${VLLM_GPU_ARCHES} -+ INCLUDE_DIRECTORIES ${VLLM_EXTRA_INCLUDE_DIRECTORIES} -+ LIBRARIES ${VLLM_LINK_LIBRARIES} -+ WITH_SOABI ++client = OpenAI( ++ api_key=openai_api_key, ++ base_url=openai_api_base, +) + -+add_custom_target(default_xpu) -+message(STATUS "Enabling C extension.") -+add_dependencies(default_xpu _C) -+ -diff --git a/csrc/xpu/activation_xpu.cpp b/csrc/xpu/activation_xpu.cpp -new file mode 100644 -index 000000000..6f98ddbb3 ---- /dev/null -+++ b/csrc/xpu/activation_xpu.cpp -@@ -0,0 +1,278 @@ -+// clang-format off -+#ifdef VLLM_DEV -+#undef __SYCL_DEVICE_ONLY__ -+#endif -+#include -+// clang-format on -+#include "xpu_types.h" -+ -+#include -+#include "utils.h" -+ -+template -+__inline__ T silu_xpu(const T& x) { -+ // x * sigmoid(x) -+ return (T)(((float)x) / (1.0f + sycl::exp((float)-x))); -+} -+ -+template -+__inline__ T gelu_xpu(const T& x) { -+ // Equivalent to PyTorch GELU with 'none' approximation. 
-+ // Refer to: -+ // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38 -+ const float f = (float) x; -+ constexpr float ALPHA = M_SQRT1_2; -+ return (T) (f * 0.5f * (1.0f + sycl::erf(f * ALPHA))); -+} -+ -+template -+__inline__ T gelu_tanh_xpu(const T& x) { -+ const float f = (float) x; -+ constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; -+ constexpr float KAPPA = 0.044715; -+ float x_cube = f * f * f; -+ float inner = BETA * (f + KAPPA * x_cube); -+ return (T) (0.5f * f * (1.0f + ::tanhf(inner))); -+} -+ -+template -+void silu_and_mul_kernel( -+ scalar_t* __restrict__ out, // [..., d] -+ const scalar_t* __restrict__ input, // [..., 2, d] -+ const int d, -+ const sycl::nd_item<3>& item_ct1) { -+ const int64_t token_idx = item_ct1.get_group(2); -+ for (int64_t idx = item_ct1.get_local_id(2); idx < d; -+ idx += item_ct1.get_local_range(2)) { -+ const scalar_t x = input[token_idx * 2 * d + idx]; -+ const scalar_t y = input[token_idx * 2 * d + d + idx]; -+ out[token_idx * d + idx] = silu_xpu(x) * y; -+ } -+} -+ -+template -+void gelu_and_mul_kernel( -+ scalar_t* __restrict__ out, // [..., d] -+ const scalar_t* __restrict__ input, // [..., 2, d] -+ const int d, -+ const sycl::nd_item<3>& item_ct1) { -+ const int64_t token_idx = item_ct1.get_group(2); -+ for (int64_t idx = item_ct1.get_local_id(2); idx < d; -+ idx += item_ct1.get_local_range(2)) { -+ const scalar_t x = input[token_idx * 2 * d + idx]; -+ const scalar_t y = input[token_idx * 2 * d + d + idx]; -+ out[token_idx * d + idx] = gelu_xpu(x) * y; -+ } -+} -+ -+template -+void gelu_tanh_and_mul_kernel( -+ scalar_t* __restrict__ out, // [..., d] -+ const scalar_t* __restrict__ input, // [..., 2, d] -+ const int d, -+ const sycl::nd_item<3>& item_ct1) { -+ const int64_t token_idx = item_ct1.get_group(2); -+ for (int64_t idx = item_ct1.get_local_id(2); idx < d; -+ idx += item_ct1.get_local_range(2)) { -+ const scalar_t x = input[token_idx * 2 * d + idx]; -+ const scalar_t y = input[token_idx * 2 * d + d + idx]; -+ out[token_idx * d + idx] = gelu_tanh_xpu(x) * y; -+ } -+} -+ -+ -+template -+void call_silu_and_mul_kernel( -+ int num_tokens, -+ int d, -+ const scalar_t* __restrict__ input, -+ scalar_t* __restrict__ output) { -+ using sycl_t = vllm::xpu::SyclTypeTrait::Type; -+ sycl::range<3> grid(1, 1, num_tokens); -+ sycl::range<3> block(1, 1, std::min(d, 1024)); -+ auto& queue = vllm::xpu::vllmGetQueue(); -+ queue.submit([&](sycl::handler& cgh) { -+ cgh.parallel_for( -+ sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1) { -+ silu_and_mul_kernel( -+ (sycl_t*)output, (const sycl_t*)input, d, item_ct1); -+ }); -+ }); -+} ++models = client.models.list() ++model = models.data[0].id + -+template -+void call_gelu_and_mul_kernel( -+ int num_tokens, -+ int d, -+ const scalar_t* __restrict__ input, -+ scalar_t* __restrict__ output) { -+ using sycl_t = vllm::xpu::SyclTypeTrait::Type; -+ sycl::range<3> grid(1, 1, num_tokens); -+ sycl::range<3> block(1, 1, std::min(d, 1024)); -+ auto& queue = vllm::xpu::vllmGetQueue(); -+ queue.submit([&](sycl::handler& cgh) { -+ cgh.parallel_for( -+ sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1) { -+ gelu_and_mul_kernel( -+ (sycl_t*)output, (const sycl_t*)input, d, item_ct1); -+ }); -+ }); -+} -+ -+template -+void call_gelu_tanh_and_mul_kernel( -+ int num_tokens, -+ int d, -+ const scalar_t* __restrict__ input, -+ scalar_t* __restrict__ output) { -+ using sycl_t = 
vllm::xpu::SyclTypeTrait::Type; -+ sycl::range<3> grid(1, 1, num_tokens); -+ sycl::range<3> block(1, 1, std::min(d, 1024)); -+ auto& queue = vllm::xpu::vllmGetQueue(); -+ queue.submit([&](sycl::handler& cgh) { -+ cgh.parallel_for( -+ sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1) { -+ gelu_tanh_and_mul_kernel( -+ (sycl_t*)output, (const sycl_t*)input, d, item_ct1); -+ }); -+ }); -+} ++# Round 1 ++messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] ++# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}` ++# For Qwen3 series, if you want to disable thinking in reasoning mode, add: ++# extra_body={"chat_template_kwargs": {"enable_thinking": False}} ++response = client.chat.completions.create(model=model, messages=messages) + -+void silu_and_mul(torch::Tensor& out, torch::Tensor& input) { -+ int num_tokens = input.numel() / input.size(-1); -+ int d = input.size(-1) / 2; -+ -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ input.scalar_type(), "call_silu_and_mul_kernel", [&] { -+ call_silu_and_mul_kernel( -+ num_tokens, -+ d, -+ input.data_ptr(), -+ out.data_ptr()); -+ }); -+} ++reasoning_content = response.choices[0].message.reasoning_content ++content = response.choices[0].message.content + -+// Element-wise activation kernel template. -+template -+void activation_kernel( -+ scalar_t* __restrict__ out, // [..., d] -+ const scalar_t* __restrict__ input, // [..., d] -+ const int d, -+ const sycl::nd_item<3>& item_ct1) { -+ const int64_t token_idx = item_ct1.get_group(2); -+ for (int64_t idx = item_ct1.get_local_id(2); idx < d; -+ idx += item_ct1.get_local_range(2)) { -+ const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); -+ out[token_idx * d + idx] = ACT_FN(x); -+ } -+} ++print("reasoning_content:", reasoning_content) ++print("content:", content) + -+template -+__inline__ T gelu_new_kernel(const T& x) { -+ const float x3 = (float)(x * x * x); -+ const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); -+ return ((T)0.5) * x * (((T)1.0) + t); -+} +diff --git a/examples/bmg/tooling.py b/examples/bmg/tooling.py +new file mode 100644 +index 000000000..bf8375831 +--- /dev/null ++++ b/examples/bmg/tooling.py +@@ -0,0 +1,37 @@ ++import json + -+template -+__inline__ T gelu_fast_kernel(const T& x) { -+ const float f = (float)x; -+ const T t = -+ (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); -+ return ((T)0.5) * x * (((T)1.0) + t); -+} ++client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="dummy") ++ ++def get_weather(location: str, unit: str): ++ return f"Getting the weather for {location} in {unit}..." 
++tool_functions = {"get_weather": get_weather} ++ ++tools = [{ ++ "type": "function", ++ "function": { ++ "name": "get_weather", ++ "description": "Get the current weather in a given location", ++ "parameters": { ++ "type": "object", ++ "properties": { ++ "location": {"type": "string", "description": "City and state, e.g., 'San Francisco, CA'"}, ++ "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]} ++ }, ++ "required": ["location", "unit"] ++ } ++ } ++}] ++ ++response = client.chat.completions.create( ++ model=client.models.list().data[0].id, ++ messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}], ++ tools=tools, ++ temperature=0, ++ tool_choice="auto" ++) + -+template -+void call_gelu_new_activation_kernel(torch::Tensor& out, torch::Tensor& input) { -+ using sycl_t = vllm::xpu::SyclTypeTrait::Type; -+ int d = input.size(-1); -+ int64_t num_tokens = input.numel() / d; -+ auto out_ptr = out.data_ptr(); -+ auto input_ptr = input.data_ptr(); -+ sycl::range<3> grid(1, 1, num_tokens); -+ sycl::range<3> block(1, 1, std::min(d, 1024)); -+ auto& queue = vllm::xpu::vllmGetQueue(); -+ queue.submit([&](sycl::handler& cgh) { -+ cgh.parallel_for( -+ sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1) { -+ activation_kernel( -+ (sycl_t* __restrict__)out_ptr, -+ (const sycl_t* __restrict__)input_ptr, -+ d, -+ item_ct1); -+ }); -+ }); -+} ++tool_call = response.choices[0].message.tool_calls[0].function ++print(f"Function called: {tool_call.name}") ++print(f"Arguments: {tool_call.arguments}") ++print(f"Result: {tool_functions[tool_call.name](**json.loads(tool_call.arguments))}") 30,22 Bot ++ +diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py +index 36d805a32..2a4233b6a 100644 +--- a/examples/offline_inference/data_parallel.py ++++ b/examples/offline_inference/data_parallel.py +@@ -96,6 +96,13 @@ def parse_args(): + "--quantization", + type=str, + ) ++ parser.add_argument( ++ "--disable-expert-parallel", ++ dest="enable_expert_parallel", ++ action="store_false", ++ help="Disable expert parallel (default: enabled).", ++ ) ++ parser.set_defaults(enable_expert_parallel=True) + return parser.parse_args() + + +@@ -108,6 +115,7 @@ def main( + dp_master_port, + GPUs_per_dp_rank, + enforce_eager, ++ enable_expert_parallel, + trust_remote_code, + max_num_seqs, + max_model_len, +@@ -162,7 +170,7 @@ def main( + model=model, + tensor_parallel_size=GPUs_per_dp_rank, + enforce_eager=enforce_eager, +- enable_expert_parallel=True, ++ enable_expert_parallel=enable_expert_parallel, + trust_remote_code=trust_remote_code, + max_num_seqs=max_num_seqs, + max_model_len=max_model_len, +@@ -222,6 +230,7 @@ if __name__ == "__main__": + dp_master_port, + tp_size, + args.enforce_eager, ++ args.enable_expert_parallel, + args.trust_remote_code, + args.max_num_seqs, + args.max_model_len, +diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py +index b104113b8..58fb423e8 100644 +--- a/examples/offline_inference/vision_language.py ++++ b/examples/offline_inference/vision_language.py +@@ -126,6 +126,23 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData: + ) + + ++# Dots-OCR ++def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData: ++ assert modality == "image" + -+template -+void call_gelu_fast_activation_kernel( -+ torch::Tensor& out, -+ torch::Tensor& input) { -+ using sycl_t = vllm::xpu::SyclTypeTrait::Type; -+ int d = 
input.size(-1); -+ int64_t num_tokens = input.numel() / d; -+ auto out_ptr = out.data_ptr(); -+ auto input_ptr = input.data_ptr(); -+ sycl::range<3> grid(1, 1, num_tokens); -+ sycl::range<3> block(1, 1, std::min(d, 1024)); -+ auto& queue = vllm::xpu::vllmGetQueue(); -+ queue.submit([&](sycl::handler& cgh) { -+ cgh.parallel_for( -+ sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1) { -+ activation_kernel( -+ (sycl_t* __restrict__)out_ptr, -+ (const sycl_t* __restrict__)input_ptr, -+ d, -+ item_ct1); -+ }); -+ }); -+} ++ prompts = [f"<|img|><|imgpad|><|endofimg|>{question}" for question in questions] ++ engine_args = EngineArgs( ++ model="rednote-hilab/dots.ocr", ++ limit_mm_per_prompt={modality: 1}, ++ trust_remote_code=True, ++ ) + -+void gelu_new(torch::Tensor& out, torch::Tensor& input) { -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ out.scalar_type(), "call_gelu_new_activation_kernel", [&] { -+ call_gelu_new_activation_kernel(out, input); -+ }); -+} ++ return ModelRequestData( ++ engine_args=engine_args, ++ prompts=prompts, ++ ) + -+void gelu_fast(torch::Tensor& out, torch::Tensor& input) { -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ out.scalar_type(), "call_gelu_fast_activation_kernel", [&] { -+ call_gelu_fast_activation_kernel( -+ out, input); -+ }); -+} + -+void gelu_and_mul( -+ torch::Tensor& out, // [..., d] -+ torch::Tensor& input) // [..., 2 * d] -+{ -+ int num_tokens = input.numel() / input.size(-1); -+ int d = input.size(-1) / 2; -+ -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ input.scalar_type(), "call_gelu_and_mul_kernel", [&] { -+ call_gelu_and_mul_kernel( -+ num_tokens, -+ d, -+ input.data_ptr(), -+ out.data_ptr()); -+ }); -+} + def run_command_a_vision(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + +@@ -1431,7 +1448,9 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData: + "max_pixels": 1280 * 28 * 28, + "fps": 1, + }, +- limit_mm_per_prompt={modality: 1}, ++ limit_mm_per_prompt={"image": 1}, ++ enforce_eager=True, ++ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + ) + + if modality == "image": +@@ -1497,6 +1516,80 @@ def run_qwen2_5_omni(questions: list[str], modality: str): + ) + + ++# Qwen3-VL-Dense ++def run_qwen3_vl(questions: list[str], modality: str) -> ModelRequestData: ++ model_name = "Qwen/Qwen3-VL-4B-Instruct" + -+void gelu_tanh_and_mul( -+ torch::Tensor& out, // [..., d] -+ torch::Tensor& input) // [..., 2 * d] -+{ -+ int num_tokens = input.numel() / input.size(-1); -+ int d = input.size(-1) / 2; -+ -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ input.scalar_type(), "call_gelu_tanh_and_mul_kernel", [&] { -+ call_gelu_tanh_and_mul_kernel( -+ num_tokens, -+ d, -+ input.data_ptr(), -+ out.data_ptr()); -+ }); -+} -\ No newline at end of file -diff --git a/csrc/xpu/attention_generic.h b/csrc/xpu/attention_generic.h -new file mode 100644 -index 000000000..ab3688c82 ---- /dev/null -+++ b/csrc/xpu/attention_generic.h -@@ -0,0 +1,64 @@ -+/* -+ * Copyright (c) 2023, The vLLM team. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+#pragma once -+ -+#include -+#include -+#include -+ -+namespace vllm { -+ -+// A vector type to store Q, K, V elements. -+template -+struct Vec {}; -+ -+// A vector type to store FP32 accumulators. -+template -+struct FloatVec {}; -+ -+// Template vector operations. -+template -+inline Acc mul(A a, B b); -+ -+template -+inline float sum(T v); -+ -+template -+inline float dot(T a, T b) { -+ return sum(mul(a, b)); -+} ++ engine_args = EngineArgs( ++ model=model_name, ++ max_model_len=4096, ++ max_num_seqs=5, ++ mm_processor_kwargs={ ++ "min_pixels": 28 * 28, ++ "max_pixels": 1280 * 28 * 28, ++ "fps": 1, ++ }, ++ limit_mm_per_prompt={modality: 1}, ++ ) + -+template -+inline float dot(T a, T b) { -+ return sum(mul(a, b)); -+} ++ if modality == "image": ++ placeholder = "<|image_pad|>" ++ elif modality == "video": ++ placeholder = "<|video_pad|>" + -+template -+inline void zero(T& dst) { -+ constexpr int WORDS = (sizeof(T) / 4) == 0 ? 1 : (sizeof(T) / 4); -+ union { -+ T raw; -+ uint32_t words[WORDS]; -+ } tmp; -+ -+#pragma unroll -+ for (int ii = 0; ii < WORDS; ++ii) { -+ tmp.words[ii] = 0u; -+ } -+ dst = tmp.raw; -+} ++ prompts = [ ++ ( ++ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" ++ f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" ++ f"{question}<|im_end|>\n" ++ "<|im_start|>assistant\n" ++ ) ++ for question in questions ++ ] + -+} // namespace vllm -\ No newline at end of file -diff --git a/csrc/xpu/attention_xpu.cpp b/csrc/xpu/attention_xpu.cpp -new file mode 100644 -index 000000000..97d5c0c21 ---- /dev/null -+++ b/csrc/xpu/attention_xpu.cpp -@@ -0,0 +1,3031 @@ -+// clang-format off -+#ifdef VLLM_DEV -+#undef __SYCL_DEVICE_ONLY__ -+#endif -+#include -+#include -+#include -+ -+// clang-format on -+#include -+#include -+#include -+#include "utils.h" -+#include "xpu_types.h" -+// #include "dtype_bfloat16.dp.hpp" -+#include "dtype_float16.h" -+#include "dtype_float32.h" -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+#include -+#endif -+ -+#include -+// #include -+ -+#define WARP_SIZE 32 -+#define MAX(a, b) ((a) > (b) ? (a) : (b)) -+#define MIN(a, b) ((a) < (b) ? (a) : (b)) -+#define DIVIDE_ROUND_UP(a, b) (((a) + (b)-1) / (b)) -+using namespace sycl::ext::intel::esimd; -+ -+template -+static inline T attn_softcapping(T qk, float attn_logit_softcapping) { -+ qk = qk / attn_logit_softcapping; -+ qk = (sycl::exp(qk) - sycl::exp(-qk)) / (sycl::exp(qk) + sycl::exp(-qk)); -+ qk = qk * attn_logit_softcapping; -+ return qk; -+} ++ return ModelRequestData( ++ engine_args=engine_args, ++ prompts=prompts, ++ ) + -+template -+struct Float_Trait { -+ using Type = T; -+}; -+ -+template <> -+struct Float_Trait { -+ using Type = uint16_t; -+}; -+ -+template <> -+struct Float_Trait { -+ using Type = sycl::ext::oneapi::bfloat16; -+}; -+ -+namespace vllm { -+ -+// Q*K^T operation. -+template -+inline float qk_dot_( -+ const Vec* q, -+ const Vec* k, -+ const sycl::nd_item<3>& item_ct1) { -+ using A_vec = typename FloatVec::Type; -+ // Compute the parallel products for Q*K^T (treat vector lanes separately). -+ A_vec qk_vec = mul(q[0], k[0]); -+#pragma unroll -+ for (int ii = 1; ii < N; ++ii) { -+ qk_vec = fma(q[ii], k[ii], qk_vec); -+ } -+ -+ // Finalize the reduction across lanes. 
-+ float qk = sum(qk_vec); -+#pragma unroll -+ for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { -+ -+ qk += dpct::permute_sub_group_by_xor( -+ item_ct1.get_sub_group(), qk, mask); -+ } -+ return qk; -+} + -+template -+struct Qk_dot { -+ template -+ static inline float dot( -+ const Vec* q, -+ const Vec* k, -+ const sycl::nd_item<3>& item_ct1) { -+ return qk_dot_(q, k, item_ct1); -+ } -+}; -+ -+template -+inline float block_sum( -+ float* red_smem, -+ float sum, -+ const sycl::nd_item<3>& item_ct1) { -+ // Decompose the thread index into warp / lane. -+ int warp = item_ct1.get_local_id(2) / WARP_SIZE; -+ int lane = item_ct1.get_local_id(2) % WARP_SIZE; -+ -+ // Compute the sum per warp. -+#pragma unroll -+ for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { -+ -+ /* -+ DPCT1096:42: The right-most dimension of the work-group used in the SYCL -+ kernel that calls this function may be less than "32". The function -+ "dpct::permute_sub_group_by_xor" may return an unexpected result on the CPU -+ device. Modify the size of the work-group to ensure that the value of the -+ right-most dimension is a multiple of "32". -+ */ -+ sum += dpct::permute_sub_group_by_xor( -+ item_ct1.get_sub_group(), sum, mask); -+ } -+ -+ // Warp leaders store the data to shared memory. -+ if (lane == 0) { -+ red_smem[warp] = sum; -+ } -+ -+ // Make sure the data is in shared memory. -+ -+ item_ct1.barrier(sycl::access::fence_space::local_space); ++# Qwen3-VL-MOE ++def run_qwen3_vl_moe(questions: list[str], modality: str) -> ModelRequestData: ++ model_name = "Qwen/Qwen3-VL-30B-A3B-Instruct" + -+ // The warps compute the final sums. -+ if (lane < NUM_WARPS) { -+ sum = red_smem[lane]; -+ } ++ engine_args = EngineArgs( ++ model=model_name, ++ max_model_len=4096, ++ max_num_seqs=5, ++ mm_processor_kwargs={ ++ "min_pixels": 28 * 28, ++ "max_pixels": 1280 * 28 * 28, ++ "fps": 1, ++ }, ++ limit_mm_per_prompt={modality: 1}, ++ ) + -+ // Parallel reduction inside the warp. -+#pragma unroll -+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { -+ -+ /* -+ DPCT1096:43: The right-most dimension of the work-group used in the SYCL -+ kernel that calls this function may be less than "32". The function -+ "dpct::permute_sub_group_by_xor" may return an unexpected result on the CPU -+ device. Modify the size of the work-group to ensure that the value of the -+ right-most dimension is a multiple of "32". -+ */ -+ sum += dpct::permute_sub_group_by_xor( -+ item_ct1.get_sub_group(), sum, mask); -+ } -+ -+ // Broadcast to other threads. -+ -+ /* -+ DPCT1096:44: The right-most dimension of the work-group used in the SYCL -+ kernel that calls this function may be less than "32". The function -+ "dpct::select_from_sub_group" may return an unexpected result on the CPU -+ device. Modify the size of the work-group to ensure that the value of the -+ right-most dimension is a multiple of "32". 
-+ */ -+ return dpct::select_from_sub_group( -+ item_ct1.get_sub_group(), sum, 0); -+} ++ if modality == "image": ++ placeholder = "<|image_pad|>" ++ elif modality == "video": ++ placeholder = "<|video_pad|>" + -+template -+void context_attention_kernel_v1_reshaped( -+ void* query, void* key, void* value, const void* block_tables, -+ const float scale, const void* query_start_loc, const void* seq_lens, -+ const void* context_lens, const int block_size, -+ // const int x, // x in kv_cache -+ void* out, // output -+ const int block_table_stride_batch, const int block_table_stride_seq, -+ const int query_stride_bs, const int query_stride_head, -+ const int query_stride_dim, const int k_cache_stride_tokens, -+ const int k_cache_stride_head, const int k_cache_stride_block_size, -+ const int k_cache_stride_dim, -+ const int v_cache_stride_tokens, const int v_cache_stride_head, -+ const int v_cache_stride_block_size, const int v_cache_stride_dim, -+ const int out_stride_tokens, const int out_stride_head, -+ const int num_queries_per_kv, const int max_input_length, -+ const int batch_size, const int num_heads) { -+ static_assert(GS * HD * sizeof(scalar_t) * 2 < 64 * 1024); -+ -+ const size_t key_slm_offset = 0; -+ const size_t value_slm_offset = GS * HD * sizeof(scalar_t); -+ sycl::queue& queue = vllm::xpu::vllmGetQueue(); -+ -+ // Get the maximum seq_lens -+ sycl::range<3> global_size(batch_size, num_heads, -+ (max_input_length + GS - 1) / GS * GS); -+ sycl::range<3> local_size(1, 1, GS); -+ -+ auto cgf = [&](sycl::handler& handle) { -+ handle.parallel_for( -+ sycl::nd_range<3>(global_size, local_size), -+ [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL { -+ slm_init(); -+ -+ const size_t bsz_idx = item.get_global_id(0); -+ const size_t head_idx = item.get_global_id(1); -+ // Assuming we have 32 query head and 8 kv_heads. Then -+ // num_queries_per_group should be 4 For head_idx 13, then -+ // kv_head_idx = 13 / 4 = 3, which is correct -+ const size_t kv_head_idx = head_idx / num_queries_per_kv; -+ const int32_t seq_idx = item.get_global_id(2); -+ const size_t gid = item.get_group(2); -+ const size_t tid = item.get_local_id(2); -+ -+ // const int64_t * seq_len = (const int64_t *) seq_lens; -+ const int32_t* seq_len = (const int32_t*)seq_lens; -+ int32_t seq_bound = seq_len[bsz_idx]; -+ -+ const int32_t* query_loc = (const int32_t*)query_start_loc; -+ // There is a possibility that the current token index pass -+ // over the seq_len, therefore: token_idx is the position in -+ // the query -+ int32_t token_idx = -+ query_loc[bsz_idx] + std::min(seq_idx, seq_bound - 1); -+ -+ const int32_t* context_len_pointer = (const int32_t*)context_lens; -+ -+ const int* block_tables_ptr = (const int*)block_tables; -+ const int* block_table = -+ block_tables_ptr + bsz_idx * block_table_stride_batch; -+ // I guess this context_len should be 0... 
-+ const int32_t context_len = context_len_pointer[bsz_idx]; -+ -+ // Position in the sequence -+ // context + seq_idx -+ // const int32_t token_position = -+ // context_len + std::min(seq_idx, seq_bound - 1); -+ const int32_t token_position = context_len + seq_idx; -+ -+ const scalar_t* query_head = (const scalar_t*)query + -+ token_idx * query_stride_bs + -+ head_idx * query_stride_head; -+ // Target output -+ scalar_t* out_head = -+ (scalar_t*)out + -+ (query_loc[bsz_idx] + seq_idx) * out_stride_tokens + -+ head_idx * out_stride_head; -+ -+ int32_t context_groups = context_len / GS; -+ -+ // Each token load its query_row -+ simd query_row = -+ block_load(query_head) * scale; -+ simd accv = 0; -+ simd softmaxv = 0; -+ scalar_t max_attn = -sycl::detail::max_v(); -+ -+ // ################# Handle n * GS context part ###################### -+ int32_t n = context_len / GS; -+ int32_t context_offset = context_len % GS; -+ -+ for (int32_t group = 0; group < n; ++group) { -+ size_t target_key_position = group * GS + tid; -+ int which_block = target_key_position / block_size; -+ int which_slot = target_key_position % block_size; -+ -+ int physical_block_number = block_table[which_block]; -+ // Now key shape is [num_blocks, num_heads, block_size, head_dim] -+ const scalar_t* key_head = -+ (const scalar_t*)key + -+ physical_block_number * k_cache_stride_tokens + -+ kv_head_idx * k_cache_stride_head + -+ which_slot * k_cache_stride_block_size; -+ simd key_row = block_load(key_head); -+ slm_block_store(key_slm_offset + tid * HD * sizeof(scalar_t), key_row); -+ -+ const scalar_t* value_head = -+ (const scalar_t*)value + -+ physical_block_number * v_cache_stride_tokens + -+ kv_head_idx * v_cache_stride_head + which_slot * v_cache_stride_block_size; -+ simd value_row = block_load(value_head); -+ slm_block_store(value_slm_offset + tid * HD * sizeof(scalar_t), -+ value_row); -+ barrier(); -+ -+ // Calculate QK^T for this group... -+ simd attnv; -+#pragma unroll -+ for (size_t r = 0; r < GS; ++r) { -+ simd key_row = slm_block_load( -+ key_slm_offset + r * HD * sizeof(scalar_t)); -+ scalar_t attn = -+ sycl::ext::intel::esimd::detail::sum( -+ query_row * key_row); -+ attnv[r] = attn; -+ } -+ scalar_t new_max_attn = -+ std::max(hmax(attnv), max_attn); -+ scalar_t attn_exp = exp(max_attn - new_max_attn); -+ accv = accv * attn_exp; -+ softmaxv = softmaxv * attn_exp; -+ max_attn = new_max_attn; -+ const simd attn_expv = exp(attnv - max_attn); -+#pragma unorll -+ for (size_t r = 0; r < GS; ++r) { -+ simd value_row = slm_block_load( -+ value_slm_offset + r * HD * sizeof(scalar_t)); -+ accv += value_row * attn_expv[r]; -+ } -+ softmaxv += attn_expv; -+ barrier(); -+ } -+ -+ // ########## End for handling context n * GS part ########### -+ -+ // ########## Handle n * GS ################ -+ for (size_t group = 0; group < gid; ++group) { -+ // 1. 
begins to load each position's key and value -+ size_t target_key_position = context_len + group * GS + tid; -+ int which_block = target_key_position / block_size; -+ int which_slot = target_key_position % block_size; -+ -+ int physical_block_number = block_table[which_block]; -+ const scalar_t* key_head = -+ (const scalar_t*)key + -+ physical_block_number * k_cache_stride_tokens + -+ kv_head_idx * k_cache_stride_head + -+ which_slot * k_cache_stride_block_size; -+ simd key_row = block_load(key_head); -+ slm_block_store(key_slm_offset + tid * HD * sizeof(scalar_t), -+ key_row); -+ const scalar_t* value_head = -+ (const scalar_t*)value + -+ physical_block_number * v_cache_stride_tokens + -+ kv_head_idx * v_cache_stride_head + which_slot * v_cache_stride_block_size; -+ simd value_row = block_load(value_head); -+ slm_block_store(value_slm_offset + tid * HD * sizeof(scalar_t), -+ value_row); -+ barrier(); -+ simd attnv; -+#pragma unroll -+ for (size_t r = 0; r < GS; ++r) { -+ simd key_row = slm_block_load( -+ key_slm_offset + r * HD * sizeof(scalar_t)); -+ scalar_t attn = -+ sycl::ext::intel::esimd::detail::sum( -+ query_row * key_row); -+ attnv[r] = attn; -+ } ++ prompts = [ ++ ( ++ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" ++ f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" ++ f"{question}<|im_end|>\n" ++ "<|im_start|>assistant\n" ++ ) ++ for question in questions ++ ] + -+ scalar_t new_max_attn = -+ std::max(hmax(attnv), max_attn); -+ scalar_t attn_exp = exp(max_attn - new_max_attn); -+ accv = accv * attn_exp; -+ -+ softmaxv = softmaxv * attn_exp; -+ max_attn = new_max_attn; -+ const simd attn_expv = exp(attnv - max_attn); -+#pragma unroll -+ for (size_t r = 0; r < GS; ++r) { -+ simd value_row = slm_block_load( -+ value_slm_offset + r * HD * sizeof(scalar_t)); -+ accv += value_row * attn_expv[r]; -+ } -+ softmaxv += attn_expv; -+ barrier(); -+ } -+ -+ // ######### End of handle n * GS part ########## -+ -+ // ################ Handle offset part #################### -+ scalar_t softmax = -+ sycl::ext::intel::esimd::detail::sum( -+ softmaxv); -+ -+ // ########### handle context offset ############ -+ if (tid < context_offset) { -+ size_t target_key_position = n * GS + tid; -+ int which_block = target_key_position / block_size; -+ int which_slot = target_key_position % block_size; -+ -+ int physical_block_number = block_table[which_block]; -+ const scalar_t* key_head = -+ (const scalar_t*)key + -+ physical_block_number * k_cache_stride_tokens + -+ kv_head_idx * k_cache_stride_head + -+ which_slot * k_cache_stride_block_size; -+ simd key_row = block_load(key_head); -+ slm_block_store(key_slm_offset + tid * HD * sizeof(scalar_t), -+ key_row); -+ -+ const scalar_t* value_head = -+ (const scalar_t*)value + -+ physical_block_number * v_cache_stride_tokens + -+ kv_head_idx * v_cache_stride_head + -+ which_slot * v_cache_stride_block_size; -+ simd value_row = block_load(value_head); -+ slm_block_store(value_slm_offset + tid * HD * sizeof(scalar_t), -+ value_row); -+ } -+ -+ barrier(); -+ -+ if (token_position < seq_bound) { -+#pragma unroll -+ for (size_t r = 0; r < context_offset; ++r) { -+ simd key_row = slm_block_load( -+ key_slm_offset + r * HD * sizeof(scalar_t)); -+ simd value_row = slm_block_load( -+ value_slm_offset + r * HD * sizeof(scalar_t)); -+ scalar_t attn = -+ sycl::ext::intel::esimd::detail::sum( -+ query_row * key_row); -+ if (attn <= max_attn) { -+ scalar_t attn_exp = -+ sycl::ext::intel::esimd::exp(attn - max_attn); -+ accv += value_row * 
attn_exp; -+ softmax += attn_exp; -+ } else { -+ scalar_t attn_exp = -+ sycl::ext::intel::esimd::exp(max_attn - attn); -+ accv = accv * attn_exp + value_row; -+ softmax = softmax * attn_exp + 1; -+ max_attn = attn; -+ } -+ } -+ } -+ barrier(); -+ -+ // ############## handle seq offset ################# -+ if (token_position < seq_bound) { -+ const int64_t which_block = -+ static_cast(token_position / block_size); -+ const int64_t which_slot = -+ static_cast(token_position % block_size); -+ -+ const int64_t physical_block_number = -+ static_cast(block_table[which_block]); -+ -+ const scalar_t* key_head = -+ (const scalar_t*)key + -+ physical_block_number * k_cache_stride_tokens + -+ kv_head_idx * k_cache_stride_head + -+ which_slot * k_cache_stride_block_size; -+ simd key_row = block_load(key_head); -+ slm_block_store(key_slm_offset + tid * HD * sizeof(scalar_t), -+ key_row); -+ -+ // [num_blocks, num_kv_heads, head_size, block_size] -+ const scalar_t* value_head = -+ (const scalar_t*)value + -+ physical_block_number * v_cache_stride_tokens + -+ kv_head_idx * v_cache_stride_head + -+ which_slot * v_cache_stride_block_size; -+ simd value_row = block_load(value_head); -+ slm_block_store(value_slm_offset + tid * HD * sizeof(scalar_t), -+ value_row); -+ } -+ barrier(); -+ -+ if (token_position < seq_bound) { -+ for (size_t r = 0; r <= tid; ++r) { -+ simd key_row = slm_block_load( -+ key_slm_offset + r * HD * sizeof(scalar_t)); -+ simd value_row = slm_block_load( -+ value_slm_offset + r * HD * sizeof(scalar_t)); -+ scalar_t attn = -+ sycl::ext::intel::esimd::detail::sum( -+ query_row * key_row); -+ if (attn <= max_attn) { -+ scalar_t attn_exp = -+ sycl::ext::intel::esimd::exp(attn - max_attn); -+ accv += value_row * attn_exp; -+ softmax += attn_exp; -+ } else { -+ scalar_t attn_exp = -+ sycl::ext::intel::esimd::exp(max_attn - attn); -+ accv = accv * attn_exp + value_row; -+ softmax = softmax * attn_exp + 1; -+ max_attn = attn; -+ } -+ } ++ return ModelRequestData( ++ engine_args=engine_args, ++ prompts=prompts, ++ ) + -+ if (softmax > 0) { -+ simd result = accv / softmax; -+ block_store(out_head, result); -+ } else { -+ simd result = 0; -+ block_store(out_head, result); -+ } -+ } -+ // ######## Ending of handling seq offset ########## -+ }); -+ }; -+ queue.submit(cgf); -+} + -+// How about implement a first edition that can be used with non-chunked -+// prefill requests, so that we can make sure the reference for heads is -+// correct -+template -+void context_attention_kernel_v1( -+ void* query, void* key, void* value, const void* block_tables, -+ const float scale, const void* query_start_loc, const void* seq_lens, -+ const void* context_lens, const int block_size, -+ const int x, // x in kv_cache -+ void* out, // output -+ const int block_table_stride_batch, const int block_table_stride_seq, -+ const int query_stride_bs, const int query_stride_head, -+ const int query_stride_dim, const int k_cache_stride_tokens, -+ const int k_cache_stride_head, const int k_cache_stride_dim, -+ const int k_cache_stride_block_size, const int k_cache_stride_x, -+ const int v_cache_stride_tokens, const int v_cache_stride_head, -+ const int v_cache_stride_dim, const int v_cache_stride_block_size, -+ const int out_stride_tokens, const int out_stride_head, -+ const int num_queries_per_kv, const int max_input_length, -+ const int batch_size, const int num_heads) { -+ static_assert(GS * HD * sizeof(scalar_t) * 2 < 64 * 1024); -+ -+ const size_t key_slm_offset = 0; -+ const size_t value_slm_offset = GS * HD * 
sizeof(scalar_t); -+ sycl::queue& queue = vllm::xpu::vllmGetQueue(); -+ -+ // Get the maximum seq_lens -+ sycl::range<3> global_size(batch_size, num_heads, -+ (max_input_length + GS - 1) / GS * GS); -+ sycl::range<3> local_size(1, 1, GS); -+ -+ auto cgf = [&](sycl::handler& handle) { -+ handle.parallel_for( -+ sycl::nd_range<3>(global_size, local_size), -+ [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL { -+ slm_init(); -+ -+ const size_t bsz_idx = item.get_global_id(0); -+ const size_t head_idx = item.get_global_id(1); -+ // Assuming we have 32 query head and 8 kv_heads. Then -+ // num_queries_per_group should be 4 For head_idx 13, then -+ // kv_head_idx = 13 / 4 = 3, which is correct -+ const size_t kv_head_idx = head_idx / num_queries_per_kv; -+ const int32_t seq_idx = item.get_global_id(2); -+ const size_t gid = item.get_group(2); -+ const size_t tid = item.get_local_id(2); -+ -+ // const int64_t * seq_len = (const int64_t *) seq_lens; -+ const int32_t* seq_len = (const int32_t*)seq_lens; -+ int32_t seq_bound = seq_len[bsz_idx]; -+ -+ const int32_t* query_loc = (const int32_t*)query_start_loc; -+ // There is a possibility that the current token index pass -+ // over the seq_len, therefore: token_idx is the position in -+ // the query -+ int32_t token_idx = -+ query_loc[bsz_idx] + std::min(seq_idx, seq_bound - 1); -+ -+ const int32_t* context_len_pointer = (const int32_t*)context_lens; -+ -+ const int* block_tables_ptr = (const int*)block_tables; -+ const int* block_table = -+ block_tables_ptr + bsz_idx * block_table_stride_batch; -+ // I guess this context_len should be 0... -+ const int32_t context_len = context_len_pointer[bsz_idx]; -+ -+ // Position in the sequence -+ // context + seq_idx -+ // const int32_t token_position = -+ // context_len + std::min(seq_idx, seq_bound - 1); -+ const int32_t token_position = context_len + seq_idx; -+ -+ // static const CONSTANT char FMT[] = -+ // "Invoke target function...\n "; -+ -+ // sycl::ext::oneapi::experimental::printf(FMT); -+ // static const CONSTANT char FMT[] = -+ // "GroupID = %6d bsz_idx = %6d seq_len = %6d seq_idx = -+ // %6d" "local_id = " -+ // "%6d " -+ // "token_idx = %6d " -+ // "context_len = %6d " -+ // "v_cache_stride_head_dim = %6d " -+ // "token_position = %6d\n"; -+ // sycl::ext::oneapi::experimental::printf( -+ // FMT, gid, bsz_idx, seq_bound, seq_idx, tid, -+ // token_idx, context_len, v_cache_stride_dim, -+ // token_position); -+ -+ const scalar_t* query_head = (const scalar_t*)query + -+ token_idx * query_stride_bs + -+ head_idx * query_stride_head; -+ // Target output -+ scalar_t* out_head = -+ (scalar_t*)out + -+ (query_loc[bsz_idx] + seq_idx) * out_stride_tokens + -+ head_idx * out_stride_head; -+ -+ int32_t context_groups = context_len / GS; -+ -+ // Each token load its query_row -+ simd query_row = -+ block_load(query_head) * scale; -+ simd accv = 0; -+ simd softmaxv = 0; -+ scalar_t max_attn = -sycl::detail::max_v(); -+ -+ // ################# Handle n * GS context part ###################### -+ int32_t n = context_len / GS; -+ int32_t context_offset = context_len % GS; -+ -+ for (int32_t group = 0; group < n; ++group) { -+ size_t target_key_position = group * GS + tid; -+ int which_block = target_key_position / block_size; -+ int which_slot = target_key_position % block_size; -+ -+ int physical_block_number = block_table[which_block]; -+ const scalar_t* key_head = -+ (const scalar_t*)key + -+ physical_block_number * k_cache_stride_tokens + -+ kv_head_idx * k_cache_stride_head + -+ which_slot * 
k_cache_stride_block_size; -+ for (int i = 0; i < HD / x; i++) { -+ // Load 8 elements, decided by x -+ simd key_row = -+ block_load(key_head + i * k_cache_stride_dim); -+ slm_block_store(key_slm_offset + tid * HD * sizeof(scalar_t) + -+ 8 * i * sizeof(scalar_t), -+ key_row); -+ } + # R-4B + def run_r_vl(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" +@@ -1662,6 +1755,7 @@ model_example_map = { + "aya_vision": run_aya_vision, + "blip-2": run_blip2, + "chameleon": run_chameleon, ++ "dots_ocr": run_dots_ocr, + "command_a_vision": run_command_a_vision, + "deepseek_vl_v2": run_deepseek_vl2, + "ernie45_vl": run_ernie45_vl, +@@ -1707,6 +1801,8 @@ model_example_map = { + "qwen2_vl": run_qwen2_vl, + "qwen2_5_vl": run_qwen2_5_vl, + "qwen2_5_omni": run_qwen2_5_omni, ++ "qwen3_vl": run_qwen3_vl, ++ "qwen3_vl_moe": run_qwen3_vl_moe, + "rvl": run_r_vl, + "skywork_chat": run_skyworkr1v, + "smolvlm": run_smolvlm, +@@ -1716,6 +1812,15 @@ model_example_map = { + } + + ++MODELS_NEED_VIDEO_METADATA = [ ++ "glm4_1v", ++ "glm4_5v", ++ "glm4_5v_fp8", ++ "qwen3_vl", ++ "qwen3_vl_moe", ++] + -+ const scalar_t* value_head = -+ (const scalar_t*)value + -+ physical_block_number * v_cache_stride_tokens + -+ kv_head_idx * v_cache_stride_head + which_slot; -+ for (int i = 0; i < HD; i++) { -+ scalar_t temp_value = value_head[i * v_cache_stride_dim]; -+ slm_scalar_store(value_slm_offset + -+ tid * HD * sizeof(scalar_t) + -+ i * sizeof(scalar_t), -+ temp_value); -+ } -+ barrier(); -+ -+ // Calculate QK^T for this group... -+ simd attnv; -+#pragma unroll -+ for (size_t r = 0; r < GS; ++r) { -+ simd key_row = slm_block_load( -+ key_slm_offset + r * HD * sizeof(scalar_t)); -+ scalar_t attn = -+ sycl::ext::intel::esimd::detail::sum( -+ query_row * key_row); -+ attnv[r] = attn; -+ } -+ scalar_t new_max_attn = -+ std::max(hmax(attnv), max_attn); -+ scalar_t attn_exp = exp(max_attn - new_max_attn); -+ accv = accv * attn_exp; -+ softmaxv = softmaxv * attn_exp; -+ max_attn = new_max_attn; -+ const simd attn_expv = exp(attnv - max_attn); -+#pragma unorll -+ for (size_t r = 0; r < GS; ++r) { -+ simd value_row = slm_block_load( -+ value_slm_offset + r * HD * sizeof(scalar_t)); -+ accv += value_row * attn_expv[r]; -+ } -+ softmaxv += attn_expv; -+ barrier(); -+ } -+ -+ // ########## End for handling context n * GS part ########### -+ -+ // ########## Handle n * GS ################ -+ for (size_t group = 0; group < gid; ++group) { -+ // 1. 
begins to load each position's key and value -+ size_t target_key_position = context_len + group * GS + tid; -+ int which_block = target_key_position / block_size; -+ int which_slot = target_key_position % block_size; -+ -+ int physical_block_number = block_table[which_block]; -+ const scalar_t* key_head = -+ (const scalar_t*)key + -+ physical_block_number * k_cache_stride_tokens + -+ kv_head_idx * k_cache_stride_head + -+ which_slot * k_cache_stride_block_size; -+ for (int i = 0; i < HD / x; i++) { -+ // Load 8 elements -+ simd key_row = -+ block_load(key_head + i * k_cache_stride_dim); -+ slm_block_store(key_slm_offset + tid * HD * sizeof(scalar_t) + -+ 8 * i * sizeof(scalar_t), -+ key_row); -+ } + -+ const scalar_t* value_head = -+ (const scalar_t*)value + -+ physical_block_number * v_cache_stride_tokens + -+ kv_head_idx * v_cache_stride_head + which_slot; -+ for (int i = 0; i < HD; i++) { -+ scalar_t temp_value = value_head[i * v_cache_stride_dim]; -+ slm_scalar_store(value_slm_offset + -+ tid * HD * sizeof(scalar_t) + -+ i * sizeof(scalar_t), -+ temp_value); -+ } -+ barrier(); -+ simd attnv; -+#pragma unroll -+ for (size_t r = 0; r < GS; ++r) { -+ simd key_row = slm_block_load( -+ key_slm_offset + r * HD * sizeof(scalar_t)); -+ scalar_t attn = -+ sycl::ext::intel::esimd::detail::sum( -+ query_row * key_row); -+ attnv[r] = attn; -+ } + def get_multi_modal_input(args): + """ + return { +diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py +index 01c2905cf..2649c992b 100644 +--- a/examples/offline_inference/vision_language_multi_image.py ++++ b/examples/offline_inference/vision_language_multi_image.py +@@ -982,12 +982,14 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData: + ) + smart_resize = None + +- model_name = "Qwen/Qwen2.5-VL-3B-Instruct" ++ model_name = "Qwen/Qwen2.5-VL-7B-Instruct" + + engine_args = EngineArgs( + model=model_name, + max_model_len=32768 if smart_resize is None else 4096, +- max_num_seqs=5, ++ max_num_seqs=2, ++ enforce_eager=True, ++ gpu_memory_utilization=0.8, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + +diff --git a/examples/online_serving/structured_outputs/structured_outputs.py b/examples/online_serving/structured_outputs/structured_outputs.py +index 2a8f46372..990b47f22 100644 +--- a/examples/online_serving/structured_outputs/structured_outputs.py ++++ b/examples/online_serving/structured_outputs/structured_outputs.py +@@ -225,7 +225,7 @@ async def cli(): + ) + args = parser.parse_args() + +- base_url = os.getenv("OPENAI_BASE_URL", "http://localhost:8000/v1") ++ base_url = os.getenv("OPENAI_BASE_URL", "http://0.0.0.0:8000/v1") + client = openai.AsyncOpenAI(base_url=base_url, api_key="EMPTY") + constraints = list(PARAMS) if "*" in args.constraint else list(set(args.constraint)) + model = (await client.models.list()).data[0].id +@@ -236,6 +236,7 @@ async def cli(): + client.chat.completions.create( + model=model, + max_tokens=1024, ++ temperature=0, + stream=True, + **PARAMS[name], + ) +@@ -250,6 +251,7 @@ async def cli(): + client.chat.completions.create( + model=model, + max_tokens=1024, ++ temperature=0, + stream=False, + **PARAMS[name], + ) +diff --git a/requirements/common.txt b/requirements/common.txt +index b8665104b..a52745f69 100644 +--- a/requirements/common.txt ++++ b/requirements/common.txt +@@ -24,7 +24,7 @@ outlines_core == 0.2.11 + # required for outlines backend disk cache + diskcache == 5.6.3 + lark == 1.2.2 +-xgrammar == 
0.1.23; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" ++xgrammar == 0.1.25; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" + typing_extensions >= 4.10 + filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 + partial-json-parser # used for parsing partial JSON outputs +diff --git a/requirements/xpu.txt b/requirements/xpu.txt +index 74f5b05b2..c0203a754 100644 +--- a/requirements/xpu.txt ++++ b/requirements/xpu.txt +@@ -11,9 +11,10 @@ jinja2>=3.1.6 + datasets # for benchmark scripts + numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding + nixl==0.3.0 # for PD disaggregation + -+ scalar_t new_max_attn = -+ std::max(hmax(attnv), max_attn); -+ scalar_t attn_exp = exp(max_attn - new_max_attn); -+ accv = accv * attn_exp; -+ -+ softmaxv = softmaxv * attn_exp; -+ max_attn = new_max_attn; -+ const simd attn_expv = exp(attnv - max_attn); -+#pragma unroll -+ for (size_t r = 0; r < GS; ++r) { -+ simd value_row = slm_block_load( -+ value_slm_offset + r * HD * sizeof(scalar_t)); -+ accv += value_row * attn_expv[r]; -+ } -+ softmaxv += attn_expv; -+ barrier(); -+ } -+ -+ // ######### End of handle n * GS part ########## -+ -+ // ################ Handle offset part #################### -+ scalar_t softmax = -+ sycl::ext::intel::esimd::detail::sum( -+ softmaxv); -+ -+ // ########### handle context offset ############ -+ if (tid < context_offset) { -+ size_t target_key_position = n * GS + tid; -+ int which_block = target_key_position / block_size; -+ int which_slot = target_key_position % block_size; -+ -+ int physical_block_number = block_table[which_block]; -+ const scalar_t* key_head = -+ (const scalar_t*)key + -+ physical_block_number * k_cache_stride_tokens + -+ kv_head_idx * k_cache_stride_head + -+ which_slot * k_cache_stride_block_size; -+ for (int i = 0; i < HD / x; i++) { -+ // Load 8 elements -+ simd key_row = -+ block_load(key_head + i * k_cache_stride_dim); -+ slm_block_store(key_slm_offset + tid * HD * sizeof(scalar_t) + -+ 8 * i * sizeof(scalar_t), -+ key_row); -+ } + torch==2.8.0+xpu + torchaudio + torchvision + --extra-index-url=https://download.pytorch.org/whl/xpu + +-intel-extension-for-pytorch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.8.10.post0%2Bxpu-cp312-cp312-linux_x86_64.whl ++intel-extension-for-pytorch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.8.10.post1%2Bxpu-cp312-cp312-linux_x86_64.whl +diff --git a/setup.py b/setup.py +index 67f65d9b9..eb313b7d2 100644 +--- a/setup.py ++++ b/setup.py +@@ -56,6 +56,8 @@ elif (sys.platform.startswith("linux") and torch.version.cuda is None + # fallback to cpu + VLLM_TARGET_DEVICE = "cpu" + ++MAIN_CUDA_VERSION = "12.8" + -+ const scalar_t* value_head = -+ (const scalar_t*)value + -+ physical_block_number * v_cache_stride_tokens + -+ kv_head_idx * v_cache_stride_head + which_slot; -+ for (int i = 0; i < HD; i++) { -+ // Seems to have an error here -+ scalar_t temp_value = value_head[i * v_cache_stride_dim]; -+ slm_scalar_store(value_slm_offset + -+ tid * HD * sizeof(scalar_t) + -+ i * sizeof(scalar_t), -+ temp_value); -+ } -+ } -+ -+ barrier(); -+ -+ if (token_position < seq_bound) { -+#pragma unroll -+ for (size_t r = 0; r < context_offset; ++r) { -+ simd key_row = slm_block_load( -+ key_slm_offset + r * HD * sizeof(scalar_t)); -+ simd value_row = 
slm_block_load( -+ value_slm_offset + r * HD * sizeof(scalar_t)); -+ scalar_t attn = -+ sycl::ext::intel::esimd::detail::sum( -+ query_row * key_row); -+ if (attn <= max_attn) { -+ scalar_t attn_exp = -+ sycl::ext::intel::esimd::exp(attn - max_attn); -+ accv += value_row * attn_exp; -+ softmax += attn_exp; -+ } else { -+ scalar_t attn_exp = -+ sycl::ext::intel::esimd::exp(max_attn - attn); -+ accv = accv * attn_exp + value_row; -+ softmax = softmax * attn_exp + 1; -+ max_attn = attn; -+ } -+ } -+ } -+ barrier(); -+ -+ // ############## handle seq offset ################# -+ if (token_position < seq_bound) { -+ const int64_t which_block = -+ static_cast(token_position / block_size); -+ const int64_t which_slot = -+ static_cast(token_position % block_size); -+ -+ const int64_t physical_block_number = -+ static_cast(block_table[which_block]); -+ -+ const scalar_t* key_head = -+ (const scalar_t*)key + -+ physical_block_number * k_cache_stride_tokens + -+ kv_head_idx * k_cache_stride_head + -+ which_slot * k_cache_stride_block_size; -+ -+ for (int i = 0; i < HD / x; i++) { -+ // Load 8 elements -+ simd key_row = -+ block_load(key_head + i * k_cache_stride_dim); -+ slm_block_store(key_slm_offset + tid * HD * sizeof(scalar_t) + -+ 8 * i * sizeof(scalar_t), -+ key_row); -+ } + + def is_sccache_available() -> bool: + return which("sccache") is not None and \ +@@ -505,7 +507,7 @@ def get_vllm_version() -> str: + version += f"{sep}precompiled" + else: + cuda_version = str(get_nvcc_cuda_version()) +- if cuda_version != envs.VLLM_MAIN_CUDA_VERSION: ++ if cuda_version != MAIN_CUDA_VERSION: + cuda_version_str = cuda_version.replace(".", "")[:3] + # skip this for source tarball, required for pypi + if "sdist" not in sys.argv: +@@ -513,7 +515,7 @@ def get_vllm_version() -> str: + elif _is_hip(): + # Get the Rocm Version + rocm_version = get_rocm_version() or torch.version.hip +- if rocm_version and rocm_version != envs.VLLM_MAIN_CUDA_VERSION: ++ if rocm_version and rocm_version != MAIN_CUDA_VERSION: + version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}" + elif _is_tpu(): + version += f"{sep}tpu" +diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py +index 29a3b40d2..72819f31d 100644 +--- a/tests/entrypoints/openai/test_vision.py ++++ b/tests/entrypoints/openai/test_vision.py +@@ -34,11 +34,11 @@ EXPECTED_MM_BEAM_SEARCH_RES = [ + ], + [ + "The image shows a Venn diagram with three over", +- "The image shows a Venn diagram with three intersect", ++ "This image shows a Venn diagram with three over", + ], + [ + "This image displays a gradient of colors ranging from", +- "The image displays a gradient of colors ranging from", ++ "This image displays a gradient of colors forming a spectrum", + ], + ] + +diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py +index c01ea3299..d37b968ed 100644 +--- a/tests/kernels/attention/test_mha_attn.py ++++ b/tests/kernels/attention/test_mha_attn.py +@@ -36,31 +36,52 @@ def test_mha_attn_platform(device: str): + torch.set_default_dtype(torch.float16) + + if device == "cpu": +- with patch("vllm.attention.selector.current_platform", +- CpuPlatform()), \ +- patch("vllm.platforms.current_platform", CpuPlatform()): ++ with patch("vllm.attention.layer.current_platform", CpuPlatform()), \ ++ patch("vllm.model_executor.models.vision.current_platform", ++ CpuPlatform()): + attn = MultiHeadAttention(16, 64, scale=1) +- assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1 ++ assert 
attn.attn_backend == _Backend.TORCH_SDPA + elif device == "hip": +- with patch("vllm.attention.selector.current_platform", +- RocmPlatform()), \ +- patch("vllm.platforms.current_platform", RocmPlatform()), \ +- patch("vllm.attention.layer.current_platform", RocmPlatform()): ++ with patch("vllm.attention.layer.current_platform", RocmPlatform()), \ ++ patch("vllm.model_executor.models.vision.current_platform", ++ RocmPlatform()): + attn = MultiHeadAttention(16, 64, scale=1) + assert attn.attn_backend == _Backend.TORCH_SDPA + else: +- with patch("vllm.attention.selector.current_platform", +- CudaPlatform()), \ +- patch("vllm.platforms.current_platform", CudaPlatform()): ++ # Test CUDA with head_size=64 (divisible by 32) ++ # - should use vLLM's FlashAttention ++ with patch("vllm.attention.layer.current_platform", CudaPlatform()), \ ++ patch("vllm.model_executor.models.vision.current_platform", ++ CudaPlatform()): + attn = MultiHeadAttention(16, 64, scale=1) +- assert attn.attn_backend == _Backend.XFORMERS ++ assert attn.attn_backend == _Backend.FLASH_ATTN + +- with patch("vllm.attention.selector.current_platform", ++ # Test CUDA with head_size=72 (not divisible by 32) ++ # - with upstream FA not available ++ # - should use xformers ++ with patch("vllm.attention.layer.current_platform", CudaPlatform()), \ ++ patch("vllm.model_executor.models.vision.current_platform", + CudaPlatform()), \ +- patch("vllm.platforms.current_platform", CudaPlatform()): ++ patch("vllm.attention.layer.check_upstream_fa_availability", ++ return_value=False): + attn = MultiHeadAttention(16, 72, scale=1) + assert attn.attn_backend == _Backend.XFORMERS + ++ # Test CUDA with head_size=72 (not divisible by 32) ++ # - with upstream FA available ++ # - should use upstream FA ++ with patch("vllm.attention.layer.current_platform", CudaPlatform()), \ ++ patch("vllm.model_executor.models.vision.current_platform", ++ CudaPlatform()), \ ++ patch("vllm.attention.layer.check_upstream_fa_availability", ++ return_value=True), \ ++ patch.dict('sys.modules', {'flash_attn': type('MockFlashAttn', (), ++ { ++ 'flash_attn_varlen_func': lambda *args, **kwargs: None ++ })()}): ++ attn = MultiHeadAttention(16, 72, scale=1) ++ assert attn.attn_backend == _Backend.FLASH_ATTN ++ + + def ref_attention( + query: torch.Tensor, +diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py +index ced0ab337..404854f54 100644 +--- a/tests/models/multimodal/processing/test_common.py ++++ b/tests/models/multimodal/processing/test_common.py +@@ -31,6 +31,7 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: + """ + # Ensure video metadata is included + if "video" in mm_data: ++ # GLM4.1V doesn't support multiple videos + video = mm_data["video"] + mm_data["video"] = (video, { + "total_num_frames": len(video), +@@ -41,6 +42,34 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: + return mm_data + + ++def qwen3_vl_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: ++ """ ++ Patch the multimodal data for Qwen3-VL model. 
++ """ + -+ // [num_blocks, num_kv_heads, head_size, block_size] -+ const scalar_t* value_head = -+ (const scalar_t*)value + -+ physical_block_number * v_cache_stride_tokens + -+ kv_head_idx * v_cache_stride_head + which_slot; -+ for (int i = 0; i < HD; i++) { -+ scalar_t temp_value = value_head[i * v_cache_stride_dim]; -+ slm_scalar_store(value_slm_offset + -+ tid * HD * sizeof(scalar_t) + -+ i * sizeof(scalar_t), -+ temp_value); -+ } -+ } -+ barrier(); -+ -+ if (token_position < seq_bound) { -+ for (size_t r = 0; r <= tid; ++r) { -+ simd key_row = slm_block_load( -+ key_slm_offset + r * HD * sizeof(scalar_t)); -+ simd value_row = slm_block_load( -+ value_slm_offset + r * HD * sizeof(scalar_t)); -+ scalar_t attn = -+ sycl::ext::intel::esimd::detail::sum( -+ query_row * key_row); -+ if (attn <= max_attn) { -+ scalar_t attn_exp = -+ sycl::ext::intel::esimd::exp(attn - max_attn); -+ accv += value_row * attn_exp; -+ softmax += attn_exp; -+ } else { -+ scalar_t attn_exp = -+ sycl::ext::intel::esimd::exp(max_attn - attn); -+ accv = accv * attn_exp + value_row; -+ softmax = softmax * attn_exp + 1; -+ max_attn = attn; -+ } -+ } ++ def create_metadata(frames: np.ndarray): ++ num_frames = len(frames) ++ return { ++ "total_num_frames": num_frames, ++ "fps": 2.0, ++ "duration": num_frames / 2.0, ++ "video_backend": "opencv", ++ "frames_indices": list(range(num_frames)), ++ "do_sample_frames": True, ++ } + -+ if (softmax > 0) { -+ simd result = accv / softmax; -+ block_store(out_head, result); -+ } else { -+ simd result = 0; -+ block_store(out_head, result); -+ } -+ } -+ // ######## Ending of handling seq offset ########## -+ }); -+ }; -+ queue.submit(cgf); -+} ++ # Ensure video metadata is included ++ if "video" in mm_data: ++ video = mm_data["video"] ++ if isinstance(video, list): ++ # multiple videos ++ mm_data["video"] = [(vid, create_metadata(vid)) for vid in video] ++ else: ++ # single video ++ mm_data["video"] = (video, create_metadata(video)) ++ return mm_data + -+template -+void context_attention_kernel_v2( -+ void* query, void* key, void* value, const void* block_tables, -+ const float scale, const void* query_start_loc, const void* seq_lens, -+ const void* context_lens, const int block_size, -+ const int x, // x in kv_cache -+ void* out, // output -+ const int block_table_stride_batch, const int block_table_stride_seq, -+ const int query_stride_bs, const int query_stride_head, -+ const int query_stride_dim, const int k_cache_stride_tokens, -+ const int k_cache_stride_head, const int k_cache_stride_dim, -+ const int k_cache_stride_block_size, const int k_cache_stride_x, -+ const int v_cache_stride_tokens, const int v_cache_stride_head, -+ const int v_cache_stride_dim, const int v_cache_stride_block_size, -+ const int out_stride_tokens, const int out_stride_head, -+ const int num_queries_per_kv, const int max_input_length, -+ const int batch_size, const int num_heads, const int num_tokens, -+ const int max_context_len, const int max_q_len) { -+ constexpr int BLOCK_SIZE = 8; -+ constexpr int NUM_THREADS = 128; -+ // Each wrap handles one context block, therefore, each thread_group_size is -+ // this. 
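// Illustrative sketch, not part of the patch: the sizing arithmetic used just
// below, assuming WARP_SIZE == 32 (the sub-group size requested elsewhere in
// this file). One warp covers one KV block of BLOCK_SIZE tokens, so each token
// is served by a thread group of WARP_SIZE / BLOCK_SIZE threads; the kWarpSize
// and kBlockSize names here are illustrative only.
#include <algorithm>

constexpr int kWarpSize = 32;   // assumed sub-group width
constexpr int kBlockSize = 8;   // matches "constexpr int BLOCK_SIZE = 8" above
constexpr int kThreadGroupSize = std::max(kWarpSize / kBlockSize, 1);
static_assert(kThreadGroupSize == 4,
              "8-token KV blocks -> 4 threads cooperate on each token");
// With fp16 elements each thread then loads 16 / (4 * 2) == 2 elements per
// 16-byte fetch, which is the "Assume TGS=4 then 16 / 4 / sizeof(half) = 2"
// comment that follows.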
-+ constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); -+ // Each query, and key thread_group loads 16 bytes -+ // Assume TGS=4 then 16 / 4 / sizeof(half) = 2 -+ constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1); -+ using sycl_t = vllm::xpu::SyclTypeTrait::Type; -+ using Q_Vec = typename Vec::Type; -+ -+ // Assuming HD = 128, TGS = 2, then 128 / 2 / 2 = 32 -+ int num_vecs_per_thread = HD / THREAD_GROUP_SIZE / VEC_SIZE; -+ sycl_t* out_p = reinterpret_cast(out); -+ sycl_t* query_ptr = reinterpret_cast(query); -+ sycl_t* key_cache_ptr = reinterpret_cast(key); -+ sycl_t* value_cache_ptr = reinterpret_cast(value); -+ const int* query_loc_ptr = reinterpret_cast(query_start_loc); -+ const int* block_tables_ptr = reinterpret_cast(block_tables); -+ const int* context_lens_ptr = reinterpret_cast(context_lens); -+ const int* seq_lens_ptr = reinterpret_cast(seq_lens); -+ -+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; -+ int padded_max_context_len = -+ DIVIDE_ROUND_UP(max_context_len + 1 + max_q_len, BLOCK_SIZE) * BLOCK_SIZE; -+ int logits_size = padded_max_context_len * sizeof(float); -+ int outputs_size = (NUM_WARPS / 2) * HD * sizeof(float); -+ // Python-side check in -+ // vllm.worker.worker._check_if_can_support_max_seq_len Keep that in -+ // sync with the logic here! -+ int shared_mem_size = std::max(logits_size, outputs_size); -+ // WARN: we have changed this... -+ sycl::range<3> grid(batch_size, num_heads, max_q_len); -+ // One work-group that is executing on the device -+ sycl::range<3> block(1, 1, NUM_THREADS); -+ sycl::queue& queue = vllm::xpu::vllmGetQueue(); -+ -+ auto cgf = [&](sycl::handler& handle) { -+ sycl::local_accessor dpct_local_acc_ct1( -+ sycl::range<1>(shared_mem_size), handle); -+ sycl::local_accessor q_vecs_acc_ct1( -+ sycl::range<1>(THREAD_GROUP_SIZE * num_vecs_per_thread), handle); -+ sycl::local_accessor red_smem_acc_ct1( -+ sycl::range<1>(2 * NUM_WARPS), handle); -+ -+ handle.parallel_for( -+ sycl::nd_range<3>(grid * block, block), -+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { -+ const int bsz_idx = item_ct1.get_group(0); -+ const int seq_idx = item_ct1.get_group(2); -+ constexpr bool USE_PARTITIONING = false; -+ int context_len = context_lens_ptr[bsz_idx] + seq_idx; -+ const int seq_len = seq_lens_ptr[bsz_idx]; -+ uint8_t* dpct_local = dpct_local_acc_ct1.get_pointer(); -+ Q_Vec* q_vecs = q_vecs_acc_ct1.get_pointer(); -+ float* red_smem = red_smem_acc_ct1.get_pointer(); -+ -+ // output_stream << "Original context_len: " << -+ // context_lens_ptr[bsz_idx] << sycl::endl; output_stream << -+ // "Batch_idx: " << bsz_idx << " Seq_idx: " << seq_idx -+ // << " Context_len: " << context_len << " Original context_len: " -+ // << context_lens_ptr[bsz_idx] << " Seq_len: " << seq_len -+ // << " Max input length: " << max_input_length -+ // << sycl::endl; -+ if (context_len >= seq_len) { -+ return; -+ } -+ -+ context_len = context_len + 1; -+ -+ const int num_context_blocks = -+ DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); -+ const int num_blocks_per_partition = num_context_blocks; -+ -+ const int start_block_idx = 0; -+ const int end_block_idx = -+ MIN(start_block_idx + num_context_blocks, num_context_blocks); -+ -+ const int num_blocks = end_block_idx - start_block_idx; -+ const int start_token_idx = start_block_idx * BLOCK_SIZE; -+ const int end_token_idx = -+ MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); -+ const int num_tokens = end_token_idx - start_token_idx; -+ constexpr int THREAD_GROUP_SIZE = 
MAX(WARP_SIZE / BLOCK_SIZE, 1); -+ constexpr int NUM_THREAD_GROUPS = -+ NUM_THREADS / -+ THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE -+ constexpr int NUM_TOKENS_PER_THREAD_GROUP = -+ DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); -+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; -+ const int thread_idx = item_ct1.get_local_id(2); -+ const int warp_idx = thread_idx / WARP_SIZE; -+ const int lane = thread_idx % WARP_SIZE; -+ const int head_idx = item_ct1.get_group(1); -+ const int num_heads = item_ct1.get_group_range(1); -+ const int kv_head_idx = head_idx / num_queries_per_kv; -+ // TODO: consider alibi_slope later -+ constexpr int NUM_ELEMS_PER_THREAD = HD / THREAD_GROUP_SIZE; -+ constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; -+ const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; -+ const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; -+ const sycl_t* q_ptr = -+ query_ptr + (query_loc_ptr[bsz_idx] + seq_idx) * query_stride_bs + -+ head_idx * HD; -+ -+#pragma unroll -+ for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; -+ i += NUM_THREAD_GROUPS) { -+ const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; -+ q_vecs[thread_group_offset * NUM_VECS_PER_THREAD + i] = -+ *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); -+ } -+ // Loaded q_vecs -+ item_ct1.barrier(sycl::access::fence_space::local_space); -+ auto shared_mem = (char*)dpct_local; -+ float* logits = reinterpret_cast(shared_mem); -+ constexpr int x = 16 / sizeof(sycl_t); -+ float qk_max = -FLT_MAX; -+ const int* block_table = -+ block_tables_ptr + bsz_idx * block_table_stride_batch; -+ -+ // Loading key -+ for (int block_idx = start_block_idx + warp_idx; -+ block_idx < end_block_idx; block_idx += NUM_WARPS) { -+ const int64_t physical_block_number = -+ static_cast(block_table[block_idx]); -+ for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { -+ const int physical_block_offset = -+ (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; -+ const int token_idx = -+ block_idx * BLOCK_SIZE + physical_block_offset; -+ -+ Q_Vec k_vecs[NUM_VECS_PER_THREAD]; -+ -+#pragma unroll -+ for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { -+ const sycl_t* k_ptr = -+ key_cache_ptr + -+ physical_block_number * k_cache_stride_tokens + -+ kv_head_idx * k_cache_stride_head + -+ physical_block_offset * x; -+ -+ const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; -+ const int offset1 = (vec_idx * VEC_SIZE) / x; -+ const int offset2 = (vec_idx * VEC_SIZE) % x; -+ k_vecs[j] = *reinterpret_cast( -+ k_ptr + offset1 * BLOCK_SIZE * x + offset2); -+ } -+ -+ // Compute dot product. -+ // This includes a reduction across the threads in the -+ // same thread group. Q_Vec_t -+ // q_vec_[NUM_VECS_PER_THREAD] = q_vecs + -+ // thread_group_offset * THREAD_GROUP_SIZE; -+ float qk = scale * -+ Qk_dot::template dot< -+ Q_Vec, NUM_VECS_PER_THREAD>( -+ q_vecs + thread_group_offset * NUM_VECS_PER_THREAD, -+ k_vecs, item_ct1); -+ -+ if (thread_group_offset == 0) { -+ // Store the partial reductions to shared memory. -+ // NOTE(woosuk): It is required to zero out the -+ // masked logits. -+ const bool mask = token_idx > context_len; -+ logits[token_idx - start_token_idx] = mask ? 0.f : qk; -+ qk_max = mask ? qk_max : sycl::fmax(qk_max, qk); -+ } -+ } -+ } -+#pragma unroll -+ for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { -+ /* -+ DPCT1096:38: The right-most dimension of the work-group used -+ in the SYCL kernel that calls this function may be less than -+ "32". 
The function "dpct::permute_sub_group_by_xor" may -+ return an unexpected result on the CPU device. Modify the -+ size of the work-group to ensure that the value of the -+ right-most dimension is a multiple of "32". -+ */ -+ qk_max = -+ sycl::fmax(qk_max, dpct::permute_sub_group_by_xor( -+ item_ct1.get_sub_group(), qk_max, mask)); -+ } -+ if (lane == 0) { -+ red_smem[warp_idx] = qk_max; -+ } -+ item_ct1.barrier(sycl::access::fence_space::local_space); -+ // TODO(woosuk): Refactor this part. -+ // Get the max qk value for the sequence. -+ qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; -+#pragma unroll -+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { -+ /* -+ DPCT1096:39: The right-most dimension of the work-group used -+ in the SYCL kernel that calls this function may be less than -+ "32". The function "dpct::permute_sub_group_by_xor" may -+ return an unexpected result on the CPU device. Modify the -+ size of the work-group to ensure that the value of the -+ right-most dimension is a multiple of "32". -+ */ -+ qk_max = -+ sycl::fmax(qk_max, dpct::permute_sub_group_by_xor( -+ item_ct1.get_sub_group(), qk_max, mask)); -+ } -+ qk_max = -+ dpct::select_from_sub_group(item_ct1.get_sub_group(), qk_max, 0); -+ -+ // Get the sum of the exp values. -+ float exp_sum = 0.f; -+ for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { -+ float val = sycl::exp(logits[i] - qk_max); -+ logits[i] = val; -+ exp_sum += val; -+ } -+ exp_sum = -+ block_sum(&red_smem[NUM_WARPS], exp_sum, item_ct1); -+ // Compute softmax. -+ const float inv_sum = 1.f / (exp_sum + 1e-6f); -+#pragma unroll -+ for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { -+ logits[i] *= inv_sum; -+ } -+ -+ item_ct1.barrier(sycl::access::fence_space::local_space); -+ constexpr int V_VEC_SIZE = MIN(16 / sizeof(sycl_t), BLOCK_SIZE); -+ using V_vec = typename Vec::Type; -+ using L_vec = typename Vec::Type; -+ using Float_L_vec = typename FloatVec::Type; -+ constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; -+ constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; -+ constexpr int NUM_ROWS_PER_THREAD = -+ DIVIDE_ROUND_UP(HD, NUM_ROWS_PER_ITER); -+ // NOTE(woosuk): We use FP32 for the accumulator for better -+ // accuracy. -+ float accs[NUM_ROWS_PER_THREAD]; -+#pragma unroll -+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -+ accs[i] = 0.f; -+ } -+ -+ sycl_t zero_value; -+ zero(zero_value); -+ for (int block_idx = start_block_idx + warp_idx; -+ block_idx < end_block_idx; block_idx += NUM_WARPS) { -+ // NOTE(woosuk): The block number is stored in int32. -+ // However, we cast it to int64 because int32 can lead to -+ // overflow when this variable is multiplied by large -+ // numbers (e.g., kv_block_stride). 
-+ const int64_t physical_block_number = -+ static_cast(block_table[block_idx]); -+ const int physical_block_offset = -+ (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; -+ const int token_idx = -+ block_idx * BLOCK_SIZE + physical_block_offset; -+ L_vec logits_vec; -+ vllm::from_float( -+ logits_vec, *reinterpret_cast(logits + token_idx - -+ start_token_idx)); -+ -+ const sycl_t* v_ptr = -+ value_cache_ptr + -+ physical_block_number * v_cache_stride_tokens + -+ kv_head_idx * v_cache_stride_head; -+#pragma unroll -+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -+ const int row_idx = -+ lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -+ if (row_idx < HD) { -+ const int offset = row_idx * BLOCK_SIZE + physical_block_offset; -+ V_vec v_vec = *reinterpret_cast(v_ptr + offset); -+ if (block_idx == num_context_blocks - 1) { -+ // NOTE(woosuk): When v_vec contains the tokens -+ // that are out of the context, we should -+ // explicitly zero out the values since they may -+ // contain NaNs. See -+ // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 -+ sycl_t* v_vec_ptr = reinterpret_cast(&v_vec); -+#pragma unroll -+ for (int j = 0; j < V_VEC_SIZE; j++) { -+ v_vec_ptr[j] = -+ token_idx + j < context_len ? v_vec_ptr[j] : zero_value; -+ } -+ } -+ accs[i] += vllm::dot(logits_vec, v_vec); -+ } -+ } -+ } -+ // Perform reduction within each warp. -+#pragma unroll -+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -+ float acc = accs[i]; -+#pragma unroll -+ for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { -+ /* -+ DPCT1096:41: The right-most dimension of the work-group -+ used in the SYCL kernel that calls this function may be -+ less than "32". The function -+ "dpct::permute_sub_group_by_xor" may return an -+ unexpected result on the CPU device. Modify the size of -+ the work-group to ensure that the value of the -+ right-most dimension is a multiple of "32". -+ */ -+ acc += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), -+ acc, mask); -+ } -+ accs[i] = acc; -+ } -+ -+ // NOTE(woosuk): A barrier is required because the shared memory -+ // space for logits is reused for the output. -+ -+ item_ct1.barrier(sycl::access::fence_space::local_space); -+ -+ // Perform reduction across warps. -+ float* out_smem = reinterpret_cast(shared_mem); -+#pragma unroll -+ for (int i = NUM_WARPS; i > 1; i /= 2) { -+ int mid = i / 2; -+ // Upper warps write to shared memory. -+ if (warp_idx >= mid && warp_idx < i) { -+ float* dst = &out_smem[(warp_idx - mid) * HD]; -+#pragma unroll -+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -+ const int row_idx = -+ lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -+ if (row_idx < HD && lane % NUM_V_VECS_PER_ROW == 0) { -+ dst[row_idx] = accs[i]; -+ } -+ } -+ } + -+ item_ct1.barrier(sycl::access::fence_space::local_space); -+ -+ // Lower warps update the output. -+ if (warp_idx < mid) { -+ const float* src = &out_smem[warp_idx * HD]; -+#pragma unroll -+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -+ const int row_idx = -+ lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -+ if (row_idx < HD && lane % NUM_V_VECS_PER_ROW == 0) { -+ accs[i] += src[row_idx]; -+ } -+ } -+ } -+ -+ item_ct1.barrier(sycl::access::fence_space::local_space); -+ } -+ -+ // Write the final output. 
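// Illustrative sketch, not part of the patch: the xor-shuffle loops above act
// as a butterfly reduction inside one sub-group. Modelling a 32-lane sub-group
// as an array, each round adds the partner lane whose index differs in exactly
// one bit, so after log2(32) rounds every lane holds the full sum; the second
// stage above then combines the NUM_WARPS per-warp results through shared
// memory by repeated halving.
#include <array>

constexpr int kLanes = 32;  // assumed sub-group width

inline float butterfly_sum(std::array<float, kLanes> lanes) {
  for (int mask = kLanes / 2; mask >= 1; mask /= 2) {
    std::array<float, kLanes> next = lanes;
    for (int i = 0; i < kLanes; ++i) {
      next[i] = lanes[i] + lanes[i ^ mask];  // add the partner lane's value
    }
    lanes = next;
  }
  return lanes[0];  // identical in every lane at this point
}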
-+ if (warp_idx == 0) { -+ sycl_t* out_ptr = -+ out_p + (query_loc_ptr[bsz_idx] + seq_idx) * out_stride_tokens + -+ head_idx * out_stride_head; -+ -+#pragma unroll -+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -+ const int row_idx = -+ lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -+ if (row_idx < HD && lane % NUM_V_VECS_PER_ROW == 0) { -+ vllm::from_float(*(out_ptr + row_idx), accs[i]); -+ } -+ } -+ } -+ }); -+ // Each thread_group handles one token -+ }; -+ queue.submit(cgf); -+} + def _test_processing_correctness( + model_id_or_arch: str, + hit_rate: float, +@@ -181,8 +210,10 @@ _IGNORE_MM_KEYS = { + } + + MM_DATA_PATCHES = { +- # GLM4.1V requires video metadata to be included in the input ++ # GLM4.1V and Qwen3-VL requires video metadata to be included in the input + "glm4v": glm4_1v_patch_mm_data, ++ "qwen3_vl": qwen3_vl_patch_mm_data, ++ "qwen3_vl_moe": qwen3_vl_patch_mm_data, + } + + +@@ -328,6 +359,8 @@ def _test_processing_correctness_one( + "Qwen/Qwen2.5-VL-3B-Instruct", + "Qwen/Qwen2-Audio-7B-Instruct", + "Qwen/Qwen2.5-Omni-3B", ++ "Qwen/Qwen3-VL-4B-Instruct", ++ "Qwen/Qwen3-VL-30B-A3B-Instruct", + "YannQi/R-4B", + "Skywork/Skywork-R1V-38B", + "HuggingFaceTB/SmolVLM2-2.2B-Instruct", +diff --git a/tests/models/registry.py b/tests/models/registry.py +index 0c77ec5ef..696aee3cc 100644 +--- a/tests/models/registry.py ++++ b/tests/models/registry.py +@@ -449,6 +449,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { + max_transformers_version="4.48", # noqa: E501 + transformers_version_reason="HF model is not compatible.", # noqa: E501 + hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 ++ "DotsOCRForCausalLM": _HfExamplesInfo("rednote-hilab/dots.ocr", ++ trust_remote_code=True), + "Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), + "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501 + trust_remote_code=True), +@@ -559,6 +561,12 @@ _MULTIMODAL_EXAMPLE_MODELS = { + max_model_len=4096), + "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"), + "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501 ++ "Qwen3VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-4B-Instruct", # noqa: E501 ++ max_model_len=4096, ++ min_transformers_version="4.57"), # noqa: E501 ++ "Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", # noqa: E501 ++ max_model_len=4096, ++ min_transformers_version="4.57"), + "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", + trust_remote_code=True), + "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B", +diff --git a/tests/quantization/test_cpu_offload.py b/tests/quantization/test_cpu_offload.py +index 08d9573ec..82a0e0cd8 100644 +--- a/tests/quantization/test_cpu_offload.py ++++ b/tests/quantization/test_cpu_offload.py +@@ -1,4 +1,4 @@ +-# SPDX-License-Identifier: Apache-2.0 ++# SPDX-License-Identifier: Apache-2.0 + # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + # Expanded quantized model tests for CPU offloading +@@ -11,6 +11,16 @@ from tests.quantization.utils import is_quant_method_supported + from ..utils import compare_two_settings + + ++@pytest.mark.skipif(not is_quant_method_supported("fp8"), ++ reason="fp8 is not supported on this GPU type.") ++def test_offload_weights_before_quant_fp8(): ++ # Test quantization of an unquantized checkpoint ++ compare_two_settings("meta-llama/Llama-3.2-1B-Instruct", ++ ["--quantization", "fp8"], 
["--quantization", "fp8"], ++ {"VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT": "1"}, ++ max_wait_seconds=480) ++ ++ + @pytest.mark.skipif(not is_quant_method_supported("fp8"), + reason="fp8 is not supported on this GPU type.") + def test_cpu_offload_fp8(): +diff --git a/tests/quantization/test_ipex_quant.py b/tests/quantization/test_ipex_quant.py +index 34b1b6c2e..4c8082646 100644 +--- a/tests/quantization/test_ipex_quant.py ++++ b/tests/quantization/test_ipex_quant.py +@@ -25,7 +25,7 @@ DTYPE = ["bfloat16"] + @pytest.mark.parametrize("model", MODELS) + @pytest.mark.parametrize("dtype", DTYPE) + def test_ipex_quant(vllm_runner, model, dtype): +- with vllm_runner(model, dtype=dtype) as llm: ++ with vllm_runner(model, dtype=dtype, enforce_eager=True, block_size=64) as llm: + output = llm.generate_greedy(["The capital of France is"], + max_tokens=32) + assert output +diff --git a/tests/utils.py b/tests/utils.py +index 16e1e6039..514da44f4 100644 +--- a/tests/utils.py ++++ b/tests/utils.py +@@ -1140,6 +1140,8 @@ def get_attn_backend_list_based_on_platform() -> list[str]: + print("Skip FLASH_ATTN_VLLM_V1 on ROCm as aiter is not installed") + + return attn_backend_list ++ elif current_platform.is_xpu(): ++ return ["FLASH_ATTN"] + else: + raise ValueError("Unsupported platform") + +diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py +index 4dfe1d3bb..56f102253 100644 +--- a/tests/v1/e2e/test_correctness_sliding_window.py ++++ b/tests/v1/e2e/test_correctness_sliding_window.py +@@ -18,7 +18,7 @@ class TestConfig: + + model_config = { + "bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)), +- "google/gemma-3-1b-it": TestConfig(4096, (400, 800)), ++ #"google/gemma-3-1b-it": TestConfig(4096, (400, 800)), + } + + +@@ -26,7 +26,7 @@ model_config = { + "model", + [ + "bigcode/starcoder2-3b", # sliding window only +- "google/gemma-3-1b-it", # sliding window + full attention ++ #"google/gemma-3-1b-it", # sliding window + full attention + ]) + @pytest.mark.parametrize("batch_size", [5]) + @pytest.mark.parametrize("seed", [1]) +@@ -46,7 +46,9 @@ def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed, + + llm = LLM( + model=model, +- disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager) ++ disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager, ++ enforce_eager=True, ++ block_size=64) + sampling_params = SamplingParams(temperature=0.0, max_tokens=100) + + prompts, answer, indices = prep_prompts(batch_size, +diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py +index 0b240b7d4..1ebd4fde4 100644 +--- a/tests/v1/e2e/test_spec_decode.py ++++ b/tests/v1/e2e/test_spec_decode.py +@@ -90,7 +90,7 @@ def test_ngram_correctness( + m.setenv("VLLM_USE_V1", "1") + test_prompts = get_test_prompts(mm_enabled=False) + +- ref_llm = LLM(model=model_name, max_model_len=1024) ++ ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=True, block_size=64, dtype="float16") + ref_outputs = ref_llm.chat(test_prompts, sampling_config) + del ref_llm + torch.cuda.empty_cache() +@@ -105,6 +105,10 @@ def test_ngram_correctness( + "num_speculative_tokens": 3, + }, + max_model_len=1024, ++ enforce_eager=True, ++ block_size=64, ++ dtype="float16", ++ gpu_memory_utilization=0.6, + ) + spec_outputs = spec_llm.chat(test_prompts, sampling_config) + matches = 0 +@@ -125,30 +129,22 @@ def test_ngram_correctness( + cleanup_dist_env_and_memory() + + +-@pytest.mark.parametrize(["model_setup", "mm_enabled"], [ +- (("eagle3", 
"Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), +- (("eagle", "meta-llama/Llama-3.1-8B-Instruct", +- "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), +- (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", +- "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), +- pytest.param( +- ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", +- "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), +- False, +- marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), +- pytest.param( +- ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", +- "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), +- True, +- marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), +- (("eagle", "eagle618/deepseek-v3-random", +- "eagle618/eagle-deepseek-v3-random", 1), False), +-], +- ids=[ +- "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", +- "llama4_eagle", "llama4_eagle_mm", +- "deepseek_eagle" +- ]) ++@pytest.mark.parametrize( ++ ["model_setup", "mm_enabled"], ++ [ ++ # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 ++ # (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), ++ (("eagle", "meta-llama/Llama-3.1-8B-Instruct", ++ "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), ++ (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", ++ "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), ++ ], ++ ids=[ ++ # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 ++ # "qwen3_eagle3", ++ "llama3_eagle", ++ "llama3_eagle3", ++ ]) + @pytest.mark.parametrize("attn_backend", + get_attn_backend_list_based_on_platform()) + def test_eagle_correctness( +@@ -188,7 +184,12 @@ def test_eagle_correctness( + + ref_llm = LLM(model=model_name, + max_model_len=2048, +- tensor_parallel_size=tp_size) ++ tensor_parallel_size=tp_size, ++ enforce_eager=True, ++ block_size=64, ++ dtype="float16", ++ gpu_memory_utilization=0.6, ++ ) + ref_outputs = ref_llm.chat(test_prompts, sampling_config) + del ref_llm + torch.cuda.empty_cache() +@@ -204,6 +205,10 @@ def test_eagle_correctness( + "num_speculative_tokens": 3, + "max_model_len": 2048, + }, ++ enforce_eager=True, ++ block_size=64, ++ dtype="float16", ++ gpu_memory_utilization=0.6, + max_model_len=2048, + ) + spec_outputs = spec_llm.chat(test_prompts, sampling_config) +diff --git a/tests/v1/kv_connector/nixl_integration/run_xpu_disagg_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_xpu_disagg_accuracy_test.sh +new file mode 100644 +index 000000000..ae4909b29 +--- /dev/null ++++ b/tests/v1/kv_connector/nixl_integration/run_xpu_disagg_accuracy_test.sh +@@ -0,0 +1,156 @@ ++#!/bin/bash ++set -e + -+template < -+ typename scalar_t, -+ typename Q_Vec_t, -+ int HEAD_SIZE, -+ int BLOCK_SIZE, -+ int NUM_THREADS, -+ int VEC_SIZE, -+ int PARTITION_SIZE = 0> // Zero means no partitioning. 
-+void paged_attention_kernel( -+ float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] -+ float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] -+ scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, -+ // head_size] -+ const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] -+ const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, -+ // head_size/x, block_size, x] -+ const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, -+ // head_size, block_size] -+ const int num_kv_heads, // [num_heads] -+ const float scale, -+ const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] -+ const int* __restrict__ context_lens, // [num_seqs] -+ const int max_num_blocks_per_seq, -+ const float* __restrict__ alibi_slopes, // [num_heads] -+ const int q_stride, -+ const int kv_block_stride, -+ const int kv_head_stride, -+ const float attn_logit_softcapping, -+ const sycl::nd_item<3>& item_ct1, -+ uint8_t* dpct_local, -+ Q_Vec_t* q_vecs, -+ float* red_smem) { -+ const int seq_idx = item_ct1.get_group(1); -+ const int partition_idx = item_ct1.get_group(0); -+ const int max_num_partitions = item_ct1.get_group_range(0); -+ constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; -+ const int context_len = context_lens[seq_idx]; -+ if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { -+ // No work to do. Terminate the thread block. -+ return; -+ } -+ -+ const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); -+ const int num_blocks_per_partition = -+ USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; -+ -+ // [start_block_idx, end_block_idx) is the range of blocks to process. -+ const int start_block_idx = -+ USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; -+ const int end_block_idx = -+ MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); -+ const int num_blocks = end_block_idx - start_block_idx; -+ -+ // [start_token_idx, end_token_idx) is the range of tokens to process. -+ const int start_token_idx = start_block_idx * BLOCK_SIZE; -+ const int end_token_idx = -+ MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); -+ const int num_tokens = end_token_idx - start_token_idx; -+ -+ constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); -+ constexpr int NUM_THREAD_GROUPS = -+ NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE -+ // divides NUM_THREADS -+ assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); -+ constexpr int NUM_TOKENS_PER_THREAD_GROUP = -+ DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); -+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; -+ const int thread_idx = item_ct1.get_local_id(2); -+ const int warp_idx = thread_idx / WARP_SIZE; -+ const int lane = thread_idx % WARP_SIZE; -+ -+ const int head_idx = item_ct1.get_group(2); -+ const int num_heads = item_ct1.get_group_range(2); -+ const int num_queries_per_kv = num_heads / num_kv_heads; -+ -+ const int kv_head_idx = head_idx / num_queries_per_kv; -+ ; -+ const float alibi_slope = -+ alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; -+ -+ // A vector type to store a part of a key or a query. -+ // The vector size is configured in such a way that the threads in a thread -+ // group fetch or compute 16 bytes at a time. For example, if the size of a -+ // thread group is 4 and the data type is half, then the vector size is 16 / -+ // (4 * sizeof(half)) == 2. 
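// Worked check, not part of the patch, of the 16-byte rule described above: a
// thread group always fetches 16 bytes of a key per step, so the per-thread
// vector length depends on the group size and element width. vec_size() here
// is illustrative, not the kernel's own helper.
constexpr int vec_size(int thread_group_size, int elem_bytes) {
  int v = 16 / (thread_group_size * elem_bytes);
  return v > 0 ? v : 1;
}
static_assert(vec_size(4, 2) == 2, "group of 4, half  -> 2 elements per thread");
static_assert(vec_size(4, 4) == 1, "group of 4, float -> 1 element per thread");
static_assert(vec_size(2, 2) == 4, "group of 2, half  -> 4 elements per thread");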
-+ -+ // constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), -+ // 1); -+ -+ constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; -+ constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; -+ -+ const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; -+ const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; -+ -+ // Load the query to registers. -+ // Each thread in a thread group has a different part of the query. -+ // For example, if the the thread group size is 4, then the first thread in -+ // the group has 0, 4, 8, ... th vectors of the query, and the second thread -+ // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because -+ // q is split from a qkv tensor, it may not be contiguous. -+ const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; -+ -+#pragma unroll -+ for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; -+ i += NUM_THREAD_GROUPS) { -+ const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; -+ q_vecs[thread_group_offset * NUM_VECS_PER_THREAD + i] = -+ *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); -+ } -+ /* -+ DPCT1065:5: Consider replacing sycl::nd_item::barrier() with -+ sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better -+ performance if there is no access to global memory. -+ */ -+ item_ct1.barrier(sycl::access::fence_space::local_space); // TODO(naed90): possible speedup if this is replaced with -+ // a memory wall right before we use q_vecs -+ -+ // Memory planning. -+ auto shared_mem = (char*)dpct_local; -+ // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. -+ float* logits = reinterpret_cast(shared_mem); -+ // Workspace for reduction. -+ -+ // x == THREAD_GROUP_SIZE * VEC_SIZE -+ // Each thread group fetches x elements from the key at a time. -+ constexpr int x = 16 / sizeof(scalar_t); -+ float qk_max = -FLT_MAX; -+ -+ // Iterate over the key blocks. -+ // Each warp fetches a block of keys for each iteration. -+ // Each thread group in a warp fetches a key from the block, and computes -+ // dot product with the query. -+ const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; -+ -+ for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; -+ block_idx += NUM_WARPS) { -+ // NOTE(woosuk): The block number is stored in int32. However, we cast it to -+ // int64 because int32 can lead to overflow when this variable is multiplied -+ // by large numbers (e.g., kv_block_stride). -+ const int64_t physical_block_number = -+ static_cast(block_table[block_idx]); -+ -+ // Load a key to registers. -+ // Each thread in a thread group has a different part of the key. -+ // For example, if the the thread group size is 4, then the first thread in -+ // the group has 0, 4, 8, ... th vectors of the key, and the second thread -+ // has 1, 5, 9, ... th vectors of the key, and so on. 
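// Illustrative sketch, not part of the patch: the round-robin ownership
// described above. Thread t of a G-thread group owns vectors t, t+G, t+2G,
// ..., which is exactly the vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE
// formula in the loop below; owned_vec() is an illustrative name.
constexpr int owned_vec(int thread_group_offset, int j, int thread_group_size) {
  return thread_group_offset + j * thread_group_size;
}
// With a group of 4 threads: thread 0 reads vectors 0, 4, 8, ... and
// thread 1 reads vectors 1, 5, 9, ... as in the comment above.
static_assert(owned_vec(0, 2, 4) == 8, "thread 0, third vector");
static_assert(owned_vec(1, 2, 4) == 9, "thread 1, third vector");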
-+ -+ for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { -+ const int physical_block_offset = -+ (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; -+ const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; -+ -+ Q_Vec_t k_vecs[NUM_VECS_PER_THREAD]; -+ -+#pragma unroll -+ for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { -+ const scalar_t* k_ptr = k_cache + -+ physical_block_number * kv_block_stride + -+ kv_head_idx * kv_head_stride + physical_block_offset * x; -+ -+ const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; -+ const int offset1 = (vec_idx * VEC_SIZE) / x; -+ const int offset2 = (vec_idx * VEC_SIZE) % x; -+ k_vecs[j] = *reinterpret_cast( -+ k_ptr + offset1 * BLOCK_SIZE * x + offset2); -+ } -+ -+ // Compute dot product. -+ // This includes a reduction across the threads in the same thread group. -+ // Q_Vec_t q_vec_[NUM_VECS_PER_THREAD] = q_vecs + thread_group_offset * -+ // THREAD_GROUP_SIZE; -+ float qk = scale * -+ Qk_dot:: -+ template dot( -+ q_vecs + thread_group_offset * NUM_VECS_PER_THREAD, -+ k_vecs, -+ item_ct1); -+ // Add the ALiBi bias if slopes are given. -+ qk += -+ (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; -+ -+ // Add the attn_logit_softcapp if given. -+ if (attn_logit_softcapping != 0.0) { -+ qk = attn_softcapping(qk, attn_logit_softcapping); -+ } -+ if (thread_group_offset == 0) { -+ // Store the partial reductions to shared memory. -+ // NOTE(woosuk): It is required to zero out the masked logits. -+ const bool mask = token_idx >= context_len; -+ logits[token_idx - start_token_idx] = mask ? 0.f : qk; -+ // Update the max value. -+ qk_max = mask ? qk_max : sycl::fmax(qk_max, qk); -+ } -+ } -+ } ++# Hosts / ports ++PREFILL_HOST=${PREFILL_HOST:-"localhost"} ++PREFILL_PORT=${PREFILL_PORT:-8100} ++PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577} ++DECODE_HOST=${DECODE_HOST:-"localhost"} ++DECODE_PORT=${DECODE_PORT:-8200} ++PROXY_HOST=${PROXY_HOST:-"localhost"} ++PROXY_PORT=${PROXY_PORT:-8192} ++BASELINE_HOST=${BASELINE_HOST:-"localhost"} ++BASELINE_PORT=${BASELINE_PORT:-9290} + -+ // Perform reduction across the threads in the same warp to get the -+ // max qk value for each "warp" (not across the thread block yet). -+ // The 0-th thread of each thread group already has its max qk value. -+#pragma unroll -+ for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { -+ -+ /* -+ DPCT1096:38: The right-most dimension of the work-group used in the SYCL -+ kernel that calls this function may be less than "32". The function -+ "dpct::permute_sub_group_by_xor" may return an unexpected result on the CPU -+ device. Modify the size of the work-group to ensure that the value of the -+ right-most dimension is a multiple of "32". -+ */ -+ qk_max = sycl::fmax( -+ qk_max, -+ dpct::permute_sub_group_by_xor( -+ item_ct1.get_sub_group(), qk_max, mask)); -+ } -+ if (lane == 0) { -+ red_smem[warp_idx] = qk_max; -+ } -+ -+ item_ct1.barrier(sycl::access::fence_space::local_space); + -+ // TODO(woosuk): Refactor this part. -+ // Get the max qk value for the sequence. -+ qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; -+#pragma unroll -+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { -+ -+ /* -+ DPCT1096:39: The right-most dimension of the work-group used in the SYCL -+ kernel that calls this function may be less than "32". The function -+ "dpct::permute_sub_group_by_xor" may return an unexpected result on the CPU -+ device. 
Modify the size of the work-group to ensure that the value of the -+ right-most dimension is a multiple of "32". -+ */ -+ qk_max = sycl::fmax( -+ qk_max, -+ dpct::permute_sub_group_by_xor( -+ item_ct1.get_sub_group(), qk_max, mask)); -+ } -+ // Broadcast the max qk value to all threads. -+ -+ /* -+ DPCT1096:40: The right-most dimension of the work-group used in the SYCL -+ kernel that calls this function may be less than "32". The function -+ "dpct::select_from_sub_group" may return an unexpected result on the CPU -+ device. Modify the size of the work-group to ensure that the value of the -+ right-most dimension is a multiple of "32". -+ */ -+ qk_max = dpct::select_from_sub_group( -+ item_ct1.get_sub_group(), qk_max, 0); -+ -+ // Get the sum of the exp values. -+ float exp_sum = 0.f; -+ for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { -+ float val = sycl::exp(logits[i] - qk_max); -+ logits[i] = val; -+ exp_sum += val; -+ } -+ exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum, item_ct1); -+ -+ // Compute softmax. -+ const float inv_sum = 1.f / (exp_sum + 1e-6f); -+ for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { -+ logits[i] *= inv_sum; -+ } -+ -+ item_ct1.barrier(sycl::access::fence_space::local_space); -+ -+ // If partitioning is enabled, store the max logit and exp_sum. -+ if (USE_PARTITIONING && thread_idx == 0) { -+ float* max_logits_ptr = max_logits + -+ seq_idx * num_heads * max_num_partitions + -+ head_idx * max_num_partitions + partition_idx; -+ *max_logits_ptr = qk_max; -+ float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + -+ head_idx * max_num_partitions + partition_idx; -+ *exp_sums_ptr = exp_sum; -+ } -+ -+ // Each thread will fetch 16 bytes from the value cache at a time. -+ constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); -+ using V_vec = typename Vec::Type; -+ using L_vec = typename Vec::Type; -+ using Float_L_vec = typename FloatVec::Type; -+ -+ constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; -+ constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; -+ constexpr int NUM_ROWS_PER_THREAD = -+ DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); -+ -+ // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. -+ float accs[NUM_ROWS_PER_THREAD]; -+#pragma unroll -+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -+ accs[i] = 0.f; -+ } -+ -+ scalar_t zero_value; -+ zero(zero_value); -+ for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; -+ block_idx += NUM_WARPS) { -+ // NOTE(woosuk): The block number is stored in int32. However, we cast it to -+ // int64 because int32 can lead to overflow when this variable is multiplied -+ // by large numbers (e.g., kv_block_stride). 
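// Illustrative sketch, not part of the patch: the flat element offset implied
// by the [num_blocks, num_kv_heads, head_size, block_size] value-cache layout
// that the v_ptr / row_idx / physical_block_offset arithmetic below walks.
// value_elem_offset() and its parameter names are illustrative only.
#include <cstdint>

constexpr int64_t value_elem_offset(int64_t block, int64_t kv_head, int64_t row,
                                    int64_t slot, int64_t num_kv_heads,
                                    int64_t head_size, int64_t block_size) {
  return ((block * num_kv_heads + kv_head) * head_size + row) * block_size + slot;
}
// e.g. 8 KV heads, head_size 128, block_size 8: one block spans 8*128*8
// elements, and the 8 slots of a given row are contiguous.
static_assert(value_elem_offset(1, 0, 0, 0, 8, 128, 8) == 8 * 128 * 8,
              "start of the second block");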
-+ const int64_t physical_block_number = -+ static_cast(block_table[block_idx]); -+ const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; -+ const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; -+ L_vec logits_vec; -+ vllm::from_float( -+ logits_vec, -+ *reinterpret_cast(logits + token_idx - start_token_idx)); -+ -+ const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride + -+ kv_head_idx * kv_head_stride; -+#pragma unroll -+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -+ if (row_idx < HEAD_SIZE) { -+ const int offset = row_idx * BLOCK_SIZE + physical_block_offset; -+ V_vec v_vec = *reinterpret_cast(v_ptr + offset); -+ if (block_idx == num_context_blocks - 1) { -+ // NOTE(woosuk): When v_vec contains the tokens that are out of the -+ // context, we should explicitly zero out the values since they may -+ // contain NaNs. See -+ // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 -+ scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); -+#pragma unroll -+ for (int j = 0; j < V_VEC_SIZE; j++) { -+ v_vec_ptr[j] = -+ token_idx + j < context_len ? v_vec_ptr[j] : zero_value; -+ } -+ } -+ accs[i] += vllm::dot(logits_vec, v_vec); -+ } -+ } -+ } -+ -+ // Perform reduction within each warp. -+#pragma unroll -+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -+ float acc = accs[i]; -+#pragma unroll -+ for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { -+ -+ /* -+ DPCT1096:41: The right-most dimension of the work-group used in the SYCL -+ kernel that calls this function may be less than "32". The function -+ "dpct::permute_sub_group_by_xor" may return an unexpected result on the -+ CPU device. Modify the size of the work-group to ensure that the value of -+ the right-most dimension is a multiple of "32". -+ */ -+ acc += dpct::permute_sub_group_by_xor( -+ item_ct1.get_sub_group(), acc, mask); -+ } -+ accs[i] = acc; -+ } -+ -+ // NOTE(woosuk): A barrier is required because the shared memory space for -+ // logits is reused for the output. -+ -+ item_ct1.barrier(sycl::access::fence_space::local_space); -+ -+ // Perform reduction across warps. -+ float* out_smem = reinterpret_cast(shared_mem); -+#pragma unroll -+ for (int i = NUM_WARPS; i > 1; i /= 2) { -+ int mid = i / 2; -+ // Upper warps write to shared memory. -+ if (warp_idx >= mid && warp_idx < i) { -+ float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; -+#pragma unroll -+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { -+ dst[row_idx] = accs[i]; -+ } -+ } -+ } -+ -+ item_ct1.barrier(sycl::access::fence_space::local_space); -+ -+ // Lower warps update the output. -+ if (warp_idx < mid) { -+ const float* src = &out_smem[warp_idx * HEAD_SIZE]; -+#pragma unroll -+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { -+ accs[i] += src[row_idx]; -+ } -+ } -+ } -+ -+ item_ct1.barrier(sycl::access::fence_space::local_space); -+ } -+ -+ // Write the final output. 
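// Illustrative sketch, not part of the patch: the out_ptr computed just below
// addresses a [num_seqs, num_heads, max_num_partitions, head_size] tensor;
// warp 0 of each work-group writes one head_size-long row of it.
// out_row_offset() is an illustrative name.
#include <cstdint>

constexpr int64_t out_row_offset(int64_t seq, int64_t head, int64_t partition,
                                 int64_t num_heads, int64_t max_partitions,
                                 int64_t head_size) {
  return ((seq * num_heads + head) * max_partitions + partition) * head_size;
}
static_assert(out_row_offset(2, 3, 0, /*num_heads=*/32, /*max_partitions=*/1,
                             /*head_size=*/128) == 2 * 32 * 128 + 3 * 128,
              "row start for sequence 2, head 3, single partition");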
-+ if (warp_idx == 0) { -+ scalar_t* out_ptr = out + -+ seq_idx * num_heads * max_num_partitions * HEAD_SIZE + -+ head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; -+#pragma unroll -+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { -+ vllm::from_float(*(out_ptr + row_idx), accs[i]); -+ } -+ } -+ } -+} ++# Model to run. ++MODEL_NAME=${MODEL_NAME:-"Qwen/Qwen3-0.6B"} ++MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024} ++BLOCK_SIZE=${BLOCK_SIZE:-16} + -+// Grid: (num_heads, num_seqs, 1). -+template < -+ typename scalar_t, -+ typename Q_Vec_t, -+ int HEAD_SIZE, -+ int BLOCK_SIZE, -+ int NUM_THREADS, -+ int VEC_SIZE> -+void paged_attention_v1_kernel( -+ scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] -+ const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] -+ const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, -+ // head_size/x, block_size, x] -+ const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, -+ // head_size, block_size] -+ const int num_kv_heads, // [num_heads] -+ const float scale, -+ const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] -+ const int* __restrict__ context_lens, // [num_seqs] -+ const int max_num_blocks_per_seq, -+ const float* __restrict__ alibi_slopes, // [num_heads] -+ const int q_stride, -+ const int kv_block_stride, -+ const int kv_head_stride, -+ const float attn_logit_softcapping, -+ const sycl::nd_item<3>& item_ct1, -+ uint8_t* dpct_local, -+ Q_Vec_t* q_vecs, -+ float* red_smem) { -+ paged_attention_kernel< -+ scalar_t, -+ Q_Vec_t, -+ HEAD_SIZE, -+ BLOCK_SIZE, -+ NUM_THREADS, -+ VEC_SIZE>( -+ /* exp_sums */ nullptr, -+ /* max_logits */ nullptr, -+ out, -+ q, -+ k_cache, -+ v_cache, -+ num_kv_heads, -+ scale, -+ block_tables, -+ context_lens, -+ max_num_blocks_per_seq, -+ alibi_slopes, -+ q_stride, -+ kv_block_stride, -+ kv_head_stride, -+ attn_logit_softcapping, -+ item_ct1, -+ dpct_local, -+ q_vecs, -+ red_smem); -+} + -+#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ -+ paged_attention_xpu_v1_impl::call( \ -+ out_ptr, \ -+ query_ptr, \ -+ key_cache_ptr, \ -+ value_cache_ptr, \ -+ num_kv_heads, \ -+ scale, \ -+ block_tables_ptr, \ -+ context_lens_ptr, \ -+ max_num_blocks_per_seq, \ -+ alibi_slopes_ptr, \ -+ q_stride, \ -+ kv_block_stride, \ -+ kv_head_stride, \ -+ num_seqs, \ -+ num_heads, \ -+ num_blocks); -+ -+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ -+ event = queue.submit([&](sycl::handler& cgh) { \ -+ sycl::local_accessor dpct_local_acc_ct1( \ -+ sycl::range<1>(shared_mem_size), cgh); \ -+ sycl::local_accessor q_vecs_acc_ct1( \ -+ sycl::range<1>(THREAD_GROUP_SIZE * num_vecs_per_thread), cgh); \ -+ sycl::local_accessor red_smem_acc_ct1( \ -+ sycl::range<1>(2 * NUM_WARPS), cgh); \ -+ \ -+ auto out_ptr_ct0 = out_ptr; \ -+ auto query_ptr_ct1 = query_ptr; \ -+ auto key_cache_ptr_ct2 = key_cache_ptr; \ -+ auto value_cache_ptr_ct3 = value_cache_ptr; \ -+ auto scale_ct5 = scale; \ -+ auto block_tables_ptr_ct6 = block_tables_ptr; \ -+ auto context_lens_ptr_ct7 = context_lens_ptr; \ -+ auto max_num_blocks_per_seq_ct8 = max_num_blocks_per_seq; \ -+ auto alibi_slopes_ptr_ct9 = alibi_slopes_ptr; \ -+ auto q_stride_ct10 = q_stride; \ -+ auto kv_block_stride_ct11 = kv_block_stride; \ -+ auto kv_head_stride_ct12 = kv_head_stride; \ -+ auto attn_logit_softcapping_ct13 = attn_logit_softcapping; \ -+ \ -+ cgh.parallel_for( \ -+ 
sycl::nd_range<3>(grid * block, block), \ -+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { \ -+ paged_attention_v1_kernel< \ -+ sycl_t, \ -+ Q_Vec, \ -+ HEAD_SIZE, \ -+ BLOCK_SIZE, \ -+ NUM_THREADS, \ -+ VEC_SIZE>( \ -+ out_ptr_ct0, \ -+ query_ptr_ct1, \ -+ key_cache_ptr_ct2, \ -+ value_cache_ptr_ct3, \ -+ num_kv_heads, \ -+ scale_ct5, \ -+ block_tables_ptr_ct6, \ -+ context_lens_ptr_ct7, \ -+ max_num_blocks_per_seq_ct8, \ -+ alibi_slopes_ptr_ct9, \ -+ q_stride_ct10, \ -+ kv_block_stride_ct11, \ -+ kv_head_stride_ct12, \ -+ attn_logit_softcapping_ct13, \ -+ item_ct1, \ -+ dpct_local_acc_ct1.get_pointer(), \ -+ q_vecs_acc_ct1.get_pointer(), \ -+ red_smem_acc_ct1.get_pointer()); \ -+ }); \ -+ }); -+ -+template -+void paged_attention_xpu_v1_impl_launcher( -+ torch::Tensor& out, -+ torch::Tensor& query, -+ torch::Tensor& key_cache, -+ torch::Tensor& value_cache, -+ int num_kv_heads, -+ float scale, -+ torch::Tensor& block_tables, -+ torch::Tensor& context_lens, -+ int max_context_len, -+ const c10::optional& alibi_slopes, -+ const float attn_logit_softcapping) { -+ int num_seqs = query.size(0); -+ int num_heads = query.size(1); -+ int head_size = query.size(2); -+ int max_num_blocks_per_seq = block_tables.size(1); -+ int q_stride = query.stride(0); -+ int kv_block_stride = key_cache.stride(0); -+ int kv_head_stride = key_cache.stride(1); -+ -+ constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); -+ constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1); -+ using sycl_t = vllm::xpu::SyclTypeTrait::Type; -+ using Q_Vec = typename Vec::Type; -+ -+ int num_vecs_per_thread = head_size / THREAD_GROUP_SIZE / VEC_SIZE; -+ assert(head_size % THREAD_GROUP_SIZE == 0); -+ -+ // NOTE: alibi_slopes is optional. -+ const float* alibi_slopes_ptr = alibi_slopes -+ ? reinterpret_cast(alibi_slopes.value().data_ptr()) -+ : nullptr; -+ -+ sycl_t* out_ptr = reinterpret_cast(out.data_ptr()); -+ sycl_t* query_ptr = reinterpret_cast(query.data_ptr()); -+ sycl_t* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); -+ sycl_t* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); -+ int* block_tables_ptr = block_tables.data_ptr(); -+ int* context_lens_ptr = context_lens.data_ptr(); -+ -+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; -+ int padded_max_context_len = -+ DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; -+ -+ int logits_size = padded_max_context_len * sizeof(float); -+ int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); -+ // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len -+ // Keep that in sync with the logic here! -+ int shared_mem_size = std::max(logits_size, outputs_size); -+ -+ sycl::range<3> grid(1, num_seqs, num_heads); -+ sycl::range<3> block(1, 1, NUM_THREADS); -+ sycl::queue& queue = vllm::xpu::vllmGetQueue(); -+ sycl::event event; -+ -+ switch (head_size) { -+ // NOTE(woosuk): To reduce the compilation time, we only compile for the -+ // head sizes that we use in the model. However, we can easily extend this -+ // to support any head size which is a multiple of 16. 
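    // [Editor's annotation - not a patch hunk line; illustrative numbers only.]
    // Worked example of the local-memory sizing above: the kernels request a
    // sub-group size of 32, so with, say, NUM_THREADS = 128 we get
    // NUM_WARPS = 4. For max_context_len = 2048 and BLOCK_SIZE = 16,
    // padded_max_context_len = 2048 and logits_size = 2048 * 4 = 8192 bytes,
    // while outputs_size for head_size = 128 is (4 / 2) * 128 * 4 = 1024
    // bytes, so shared_mem_size = max(8192, 1024) = 8192 bytes.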
-+ case 64: -+ LAUNCH_PAGED_ATTENTION_V1(64); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v1", event); -+#endif -+ break; -+ case 80: -+ LAUNCH_PAGED_ATTENTION_V1(80); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v1", event); -+#endif -+ break; -+ case 96: -+ LAUNCH_PAGED_ATTENTION_V1(96); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v1", event); -+#endif -+ break; -+ case 112: -+ LAUNCH_PAGED_ATTENTION_V1(112); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v1", event); -+#endif -+ break; -+ case 128: -+ LAUNCH_PAGED_ATTENTION_V1(128); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v1", event); -+#endif -+ break; -+ case 256: -+ LAUNCH_PAGED_ATTENTION_V1(256); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v1", event); -+#endif -+ break; -+ default: -+ TORCH_CHECK(false, "Unsupported head size: ", head_size); -+ break; -+ } -+ // queue.wait(); -+} ++# execution env ++GIT_ROOT=$(git rev-parse --show-toplevel) ++EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration" + -+#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ -+ vllm::paged_attention_xpu_v1_impl_launcher( \ -+ out, \ -+ query, \ -+ key_cache, \ -+ value_cache, \ -+ num_kv_heads, \ -+ scale, \ -+ block_tables, \ -+ context_lens, \ -+ max_context_len, \ -+ alibi_slopes, \ -+ attn_logit_softcapping); -+ -+#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ -+ switch (block_size) { \ -+ case 8: \ -+ CALL_KERNEL_LAUNCHER(T, 8); \ -+ break; \ -+ case 16: \ -+ CALL_KERNEL_LAUNCHER(T, 16); \ -+ break; \ -+ case 32: \ -+ CALL_KERNEL_LAUNCHER(T, 32); \ -+ break; \ -+ case 64: \ -+ CALL_KERNEL_LAUNCHER(T, 64); \ -+ break; \ -+ default: \ -+ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ -+ break; \ -+ } -+ -+// Grid: (num_heads, num_seqs). -+template < -+ typename scalar_t, -+ int HEAD_SIZE, -+ int NUM_THREADS, -+ int PARTITION_SIZE> -+void paged_attention_v2_reduce_kernel( -+ scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] -+ const float* __restrict__ exp_sums, // [num_seqs, num_heads, -+ // max_num_partitions] -+ const float* __restrict__ max_logits, // [num_seqs, num_heads, -+ // max_num_partitions] -+ const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, -+ // max_num_partitions, head_size] -+ const int* __restrict__ context_lens, // [num_seqs] -+ const int max_num_partitions, -+ const sycl::nd_item<3>& item_ct1, -+ uint8_t* dpct_local, -+ float* red_smem) { -+ const int num_heads = item_ct1.get_group_range(2); -+ const int head_idx = item_ct1.get_group(2); -+ const int seq_idx = item_ct1.get_group(1); -+ const int context_len = context_lens[seq_idx]; -+ const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); -+ if (num_partitions == 1) { -+ // No need to reduce. Only copy tmp_out to out. 
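  // [Editor's annotation - not a patch hunk line.]
  // With a single partition the main v2 kernel has already normalized its
  // logits by the partition-local exp sum, so tmp_out holds the final
  // attention output for this (seq, head); a plain copy into out suffices and
  // the exp_sums / max_logits buffers are not consulted.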
-+ scalar_t* out_ptr = -+ out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; -+ const scalar_t* tmp_out_ptr = tmp_out + -+ seq_idx * num_heads * max_num_partitions * HEAD_SIZE + -+ head_idx * max_num_partitions * HEAD_SIZE; -+ for (int i = item_ct1.get_local_id(2); i < HEAD_SIZE; -+ i += item_ct1.get_local_range(2)) { -+ out_ptr[i] = tmp_out_ptr[i]; -+ } -+ // Terminate the thread block. -+ return; -+ } -+ -+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; -+ const int warp_idx = item_ct1.get_local_id(2) / WARP_SIZE; -+ const int lane = item_ct1.get_local_id(2) % WARP_SIZE; -+ -+ // Size: 2 * num_partitions. -+ auto shared_mem = (char*)dpct_local; -+ // Workspace for reduction. -+ -+ // Load max logits to shared memory. -+ float* shared_max_logits = reinterpret_cast(shared_mem); -+ const float* max_logits_ptr = max_logits + -+ seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; -+ float max_logit = -FLT_MAX; -+ for (int i = item_ct1.get_local_id(2); i < num_partitions; -+ i += item_ct1.get_local_range(2)) { -+ const float l = max_logits_ptr[i]; -+ shared_max_logits[i] = l; -+ max_logit = sycl::fmax(max_logit, (float)l); -+ } -+ -+ item_ct1.barrier(sycl::access::fence_space::local_space); ++OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.xpu_accuracy_test_outputs.txt"} + -+ // Get the global max logit. -+ // Reduce within the warp. -+#pragma unroll -+ for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { -+ -+ /* -+ DPCT1096:45: The right-most dimension of the work-group used in the SYCL -+ kernel that calls this function may be less than "32". The function -+ "dpct::permute_sub_group_by_xor" may return an unexpected result on the CPU -+ device. Modify the size of the work-group to ensure that the value of the -+ right-most dimension is a multiple of "32". -+ */ -+ max_logit = sycl::fmax( -+ max_logit, -+ dpct::permute_sub_group_by_xor( -+ item_ct1.get_sub_group(), max_logit, mask)); -+ } -+ if (lane == 0) { -+ red_smem[warp_idx] = max_logit; -+ } -+ -+ item_ct1.barrier(sycl::access::fence_space::local_space); -+ // Reduce across warps. -+ max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; -+#pragma unroll -+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { -+ -+ /* -+ DPCT1096:46: The right-most dimension of the work-group used in the SYCL -+ kernel that calls this function may be less than "32". The function -+ "dpct::permute_sub_group_by_xor" may return an unexpected result on the CPU -+ device. Modify the size of the work-group to ensure that the value of the -+ right-most dimension is a multiple of "32". -+ */ -+ max_logit = sycl::fmax( -+ max_logit, -+ dpct::permute_sub_group_by_xor( -+ item_ct1.get_sub_group(), max_logit, mask)); -+ } -+ // Broadcast the max value to all threads. -+ -+ /* -+ DPCT1096:47: The right-most dimension of the work-group used in the SYCL -+ kernel that calls this function may be less than "32". The function -+ "dpct::select_from_sub_group" may return an unexpected result on the CPU -+ device. Modify the size of the work-group to ensure that the value of the -+ right-most dimension is a multiple of "32". -+ */ -+ max_logit = dpct::select_from_sub_group( -+ item_ct1.get_sub_group(), max_logit, 0); -+ -+ // Load rescaled exp sums to shared memory. 
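  // [Editor's annotation - not a patch hunk line.]
  // The rescaling below is the standard log-sum-exp merge across partitions:
  // each partition i stored (m_i, s_i) with s_i = sum_j exp(l_j - m_i); given
  // the global max M computed above, s_i is rescaled to s_i * exp(m_i - M) so
  // all partial sums share one reference point, and the final output becomes
  // sum_i tmp_out_i * s_i * exp(m_i - M) / S, where S is the rescaled total.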
-+ float* shared_exp_sums = -+ reinterpret_cast(shared_mem + sizeof(float) * num_partitions); -+ const float* exp_sums_ptr = exp_sums + -+ seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; -+ float global_exp_sum = 0.0f; -+ for (int i = item_ct1.get_local_id(2); i < num_partitions; -+ i += item_ct1.get_local_range(2)) { -+ float l = shared_max_logits[i]; -+ float rescaled_exp_sum = exp_sums_ptr[i] * sycl::exp(l - max_logit); -+ global_exp_sum += rescaled_exp_sum; -+ shared_exp_sums[i] = rescaled_exp_sum; -+ } -+ -+ item_ct1.barrier(sycl::access::fence_space::local_space); -+ global_exp_sum = -+ block_sum(&red_smem[NUM_WARPS], global_exp_sum, item_ct1); -+ const float inv_global_exp_sum = 1.0f / (global_exp_sum + 1e-6f); -+ -+ // Aggregate tmp_out to out. -+ const scalar_t* tmp_out_ptr = tmp_out + -+ seq_idx * num_heads * max_num_partitions * HEAD_SIZE + -+ head_idx * max_num_partitions * HEAD_SIZE; -+ scalar_t* out_ptr = -+ out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; -+#pragma unroll -+ for (int i = item_ct1.get_local_id(2); i < HEAD_SIZE; i += NUM_THREADS) { -+ float acc = 0.0f; -+ for (int j = 0; j < num_partitions; ++j) { -+ acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * -+ inv_global_exp_sum; -+ } -+ from_float(out_ptr[i], acc); -+ } -+} ++# Trap the SIGINT signal (triggered by Ctrl+C) ++trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + -+// Grid: (num_heads, num_seqs, max_num_partitions). -+template < -+ typename scalar_t, -+ typename Q_Vec_t, -+ int HEAD_SIZE, -+ int BLOCK_SIZE, -+ int NUM_THREADS, -+ int VEC_SIZE, -+ int PARTITION_SIZE> -+void paged_attention_v2_kernel( -+ float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] -+ float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] -+ scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, -+ // head_size] -+ const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] -+ const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, -+ // head_size/x, block_size, x] -+ const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, -+ // head_size, block_size] -+ const int num_kv_heads, // [num_heads] -+ const float scale, -+ const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] -+ const int* __restrict__ context_lens, // [num_seqs] -+ const int max_num_blocks_per_seq, -+ const float* __restrict__ alibi_slopes, // [num_heads] -+ const int q_stride, -+ const int kv_block_stride, -+ const int kv_head_stride, -+ const float attn_logit_softcapping, -+ const sycl::nd_item<3>& item_ct1, -+ uint8_t* dpct_local, -+ Q_Vec_t* q_vecs, -+ float* red_smem) { -+ paged_attention_kernel< -+ scalar_t, -+ Q_Vec_t, -+ HEAD_SIZE, -+ BLOCK_SIZE, -+ NUM_THREADS, -+ VEC_SIZE, -+ PARTITION_SIZE>( -+ exp_sums, -+ max_logits, -+ tmp_out, -+ q, -+ k_cache, -+ v_cache, -+ num_kv_heads, -+ scale, -+ block_tables, -+ context_lens, -+ max_num_blocks_per_seq, -+ alibi_slopes, -+ q_stride, -+ kv_block_stride, -+ kv_head_stride, -+ attn_logit_softcapping, -+ item_ct1, -+ dpct_local, -+ q_vecs, -+ red_smem); ++cleanup() { ++ echo "Cleaning up any running vLLM instances..." 
++ pkill -f "vllm serve" || true ++ sleep 2 +} + -+#define LAUNCH_PAGED_ATTENTION_V2_FIRST_HALF(HEAD_SIZE) \ -+ event = queue.submit([&](sycl::handler& cgh) { \ -+ sycl::local_accessor dpct_local_acc_ct1( \ -+ sycl::range<1>(shared_mem_size), cgh); \ -+ sycl::local_accessor q_vecs_acc_ct1( \ -+ sycl::range<1>(THREAD_GROUP_SIZE * num_vecs_per_thread), cgh); \ -+ sycl::local_accessor red_smem_acc_ct1( \ -+ sycl::range<1>(2 * NUM_WARPS), cgh); \ -+ \ -+ auto exp_sums_ptr_ct0 = exp_sums_ptr; \ -+ auto max_logits_ptr_ct1 = max_logits_ptr; \ -+ auto tmp_out_ptr_ct2 = tmp_out_ptr; \ -+ auto query_ptr_ct3 = query_ptr; \ -+ auto key_cache_ptr_ct4 = key_cache_ptr; \ -+ auto value_cache_ptr_ct5 = value_cache_ptr; \ -+ auto scale_ct7 = scale; \ -+ auto block_tables_ptr_ct8 = block_tables_ptr; \ -+ auto context_lens_ptr_ct9 = context_lens_ptr; \ -+ auto max_num_blocks_per_seq_ct10 = max_num_blocks_per_seq; \ -+ auto alibi_slopes_ptr_ct11 = alibi_slopes_ptr; \ -+ auto q_stride_ct12 = q_stride; \ -+ auto kv_block_stride_ct13 = kv_block_stride; \ -+ auto kv_head_stride_ct14 = kv_head_stride; \ -+ auto attn_logit_softcapping_ct15 = attn_logit_softcapping; \ -+ \ -+ cgh.parallel_for( \ -+ sycl::nd_range<3>(grid * block, block), \ -+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { \ -+ vllm::paged_attention_v2_kernel< \ -+ sycl_t, \ -+ Q_Vec, \ -+ HEAD_SIZE, \ -+ BLOCK_SIZE, \ -+ NUM_THREADS, \ -+ VEC_SIZE, \ -+ PARTITION_SIZE>( \ -+ exp_sums_ptr_ct0, \ -+ max_logits_ptr_ct1, \ -+ tmp_out_ptr_ct2, \ -+ query_ptr_ct3, \ -+ key_cache_ptr_ct4, \ -+ value_cache_ptr_ct5, \ -+ num_kv_heads, \ -+ scale_ct7, \ -+ block_tables_ptr_ct8, \ -+ context_lens_ptr_ct9, \ -+ max_num_blocks_per_seq_ct10, \ -+ alibi_slopes_ptr_ct11, \ -+ q_stride_ct12, \ -+ kv_block_stride_ct13, \ -+ kv_head_stride_ct14, \ -+ attn_logit_softcapping_ct15, \ -+ item_ct1, \ -+ dpct_local_acc_ct1.get_pointer(), \ -+ q_vecs_acc_ct1.get_pointer(), \ -+ red_smem_acc_ct1.get_pointer()); \ -+ }); \ -+ }); -+ -+#define LAUNCH_PAGED_ATTENTION_V2_SECOND_HALF(HEAD_SIZE) \ -+ event2 = queue.submit([&](sycl::handler& cgh) { \ -+ sycl::local_accessor dpct_local_acc_ct1( \ -+ sycl::range<1>(reduce_shared_mem_size), cgh); \ -+ sycl::local_accessor red_smem_acc_ct1( \ -+ sycl::range<1>(2 * NUM_WARPS), cgh); \ -+ \ -+ auto out_ptr_ct0 = out_ptr; \ -+ auto exp_sums_ptr_ct1 = exp_sums_ptr; \ -+ auto max_logits_ptr_ct2 = max_logits_ptr; \ -+ auto tmp_out_ptr_ct3 = tmp_out_ptr; \ -+ auto context_lens_ptr_ct4 = context_lens_ptr; \ -+ auto max_num_partitions_ct5 = max_num_partitions; \ -+ \ -+ cgh.parallel_for( \ -+ sycl::nd_range<3>(reduce_grid * block, block), \ -+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { \ -+ vllm::paged_attention_v2_reduce_kernel< \ -+ sycl_t, \ -+ HEAD_SIZE, \ -+ NUM_THREADS, \ -+ PARTITION_SIZE>( \ -+ out_ptr_ct0, \ -+ exp_sums_ptr_ct1, \ -+ max_logits_ptr_ct2, \ -+ tmp_out_ptr_ct3, \ -+ context_lens_ptr_ct4, \ -+ max_num_partitions_ct5, \ -+ item_ct1, \ -+ dpct_local_acc_ct1.get_pointer(), \ -+ red_smem_acc_ct1.get_pointer()); \ -+ }); \ -+ }); -+ -+template < -+ typename T, -+ int BLOCK_SIZE, -+ int NUM_THREADS = 512, -+ int PARTITION_SIZE = 512> -+void paged_attention_v2_launcher( -+ torch::Tensor& out, -+ torch::Tensor& exp_sums, -+ torch::Tensor& max_logits, -+ torch::Tensor& tmp_out, -+ torch::Tensor& query, -+ torch::Tensor& key_cache, -+ torch::Tensor& value_cache, -+ int num_kv_heads, -+ float scale, -+ torch::Tensor& block_tables, -+ torch::Tensor& context_lens, -+ int max_context_len, -+ 
const c10::optional& alibi_slopes, -+ const float attn_logit_softcapping) { -+ int num_seqs = query.size(0); -+ int num_heads = query.size(1); -+ int head_size = query.size(2); -+ int max_num_blocks_per_seq = block_tables.size(1); -+ int q_stride = query.stride(0); -+ int kv_block_stride = key_cache.stride(0); -+ int kv_head_stride = key_cache.stride(1); -+ -+ constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); -+ assert(head_size % THREAD_GROUP_SIZE == 0); -+ constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1); -+ using sycl_t = vllm::xpu::SyclTypeTrait::Type; -+ using Q_Vec = typename Vec::Type; -+ -+ int num_vecs_per_thread = head_size / THREAD_GROUP_SIZE / VEC_SIZE; -+ assert(head_size % THREAD_GROUP_SIZE == 0); -+ -+ // NOTE: alibi_slopes is optional. -+ const float* alibi_slopes_ptr = alibi_slopes -+ ? reinterpret_cast(alibi_slopes.value().data_ptr()) -+ : nullptr; -+ -+ sycl_t* out_ptr = reinterpret_cast(out.data_ptr()); -+ float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); -+ float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); -+ sycl_t* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); -+ sycl_t* query_ptr = reinterpret_cast(query.data_ptr()); -+ sycl_t* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); -+ sycl_t* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); -+ int* block_tables_ptr = block_tables.data_ptr(); -+ int* context_lens_ptr = context_lens.data_ptr(); -+ -+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; -+ int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); -+ -+ int logits_size = PARTITION_SIZE * sizeof(float); -+ int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); -+ -+ // For paged attention v2 kernel. -+ sycl::range<3> grid(max_num_partitions, num_seqs, num_heads); -+ int shared_mem_size = std::max(logits_size, outputs_size); -+ // For paged attention v2 reduce kernel. -+ sycl::range<3> reduce_grid(1, num_seqs, num_heads); -+ -+ int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); -+ -+ sycl::range<3> block(1, 1, NUM_THREADS); -+ sycl::queue& queue = vllm::xpu::vllmGetQueue(); -+ sycl::event event; -+ sycl::event event2; -+ switch (head_size) { -+ // NOTE(woosuk): To reduce the compilation time, we only compile for the -+ // head sizes that we use in the model. However, we can easily extend this -+ // to support any head size which is a multiple of 16. 
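    // [Editor's annotation - not a patch hunk line; illustrative numbers only.]
    // Sizing sketch for the two launches configured above: with
    // PARTITION_SIZE = 512 (the template default), max_context_len = 4096
    // gives max_num_partitions = 8, so the main kernel runs on an
    // (8, num_seqs, num_heads) grid and the reduce kernel needs
    // reduce_shared_mem_size = 2 * 8 * sizeof(float) = 64 bytes of local
    // memory (one max-logit / exp-sum pair per partition).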
-+ case 64: -+ LAUNCH_PAGED_ATTENTION_V2_FIRST_HALF(64); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v2", event); -+#endif -+ LAUNCH_PAGED_ATTENTION_V2_SECOND_HALF(64); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v2", event2); -+#endif -+ break; -+ case 80: -+ LAUNCH_PAGED_ATTENTION_V2_FIRST_HALF(80); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v2", event); -+#endif -+ LAUNCH_PAGED_ATTENTION_V2_SECOND_HALF(80); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v2", event2); -+#endif -+ break; -+ case 96: -+ LAUNCH_PAGED_ATTENTION_V2_FIRST_HALF(96); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v2", event); -+#endif -+ LAUNCH_PAGED_ATTENTION_V2_SECOND_HALF(96); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v2", event2); -+#endif -+ break; -+ case 112: -+ LAUNCH_PAGED_ATTENTION_V2_FIRST_HALF(112); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v2", event); -+#endif -+ LAUNCH_PAGED_ATTENTION_V2_SECOND_HALF(112); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v2", event2); -+#endif -+ break; -+ case 128: -+ LAUNCH_PAGED_ATTENTION_V2_FIRST_HALF(128); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v2", event); -+#endif -+ LAUNCH_PAGED_ATTENTION_V2_SECOND_HALF(128); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v2", event2); -+#endif -+ break; -+ case 256: -+ LAUNCH_PAGED_ATTENTION_V2_FIRST_HALF(256); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v2", event); -+#endif -+ LAUNCH_PAGED_ATTENTION_V2_SECOND_HALF(256); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(event_desc, event); // Uncomment when needed -+#else -+ ::xpu::profiler_record("paged attn v2", event2); -+#endif -+ break; -+ default: -+ TORCH_CHECK(false, "Unsupported head size: ", head_size); -+ break; -+ } ++wait_for_server() { ++ local host=$1 ++ local port=$2 ++ timeout 1200 bash -c " ++ until curl -s ${host}:${port}/v1/completions > /dev/null; do ++ sleep 1 ++ done" && return 0 || return 1 +} + -+#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \ -+ vllm::paged_attention_v2_launcher( \ -+ out, \ -+ exp_sums, \ -+ max_logits, \ -+ tmp_out, \ -+ query, \ -+ key_cache, \ -+ value_cache, \ -+ 
num_kv_heads, \ -+ scale, \ -+ block_tables, \ -+ context_lens, \ -+ max_context_len, \ -+ alibi_slopes, \ -+ attn_logit_softcapping); -+ -+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ -+ switch (block_size) { \ -+ case 8: \ -+ CALL_V2_LAUNCHER(T, 8); \ -+ break; \ -+ case 16: \ -+ CALL_V2_LAUNCHER(T, 16); \ -+ break; \ -+ case 32: \ -+ CALL_V2_LAUNCHER(T, 32); \ -+ break; \ -+ case 64: \ -+ CALL_V2_LAUNCHER(T, 64); \ -+ break; \ -+ default: \ -+ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ -+ break; \ -+ } -+ -+} // namespace vllm -+ -+void paged_attention_v1( -+ torch::Tensor& out, -+ torch::Tensor& query, -+ torch::Tensor& key_cache, -+ torch::Tensor& value_cache, -+ int num_kv_heads, -+ float scale, -+ torch::Tensor& block_tables, -+ torch::Tensor& context_lens, -+ int block_size, -+ int max_context_len, -+ const c10::optional& alibi_slopes, -+ const std::string& kv_cache_dtype, -+ const float kv_scale, -+ const float attn_logit_softcapping) { -+ VLLM_XPU_DISPATCH_FLOATING_TYPES_FLOAT_ONLY( -+ query.scalar_type(), "paged_attention_xpu_v1_impl", [&] { -+ CALL_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); -+ }); ++launch_baseline() { ++ BASELINE_BASE_CMD=" ++ ONEAPI_DEVICE_SELECTOR=level_zero:0 \ ++ VLLM_USE_V1=1 \ ++ VLLM_WORKER_MULTIPROC_METHOD=spawn \ ++ VLLM_ENABLE_V1_MULTIPROCESSING=1 vllm serve $MODEL_NAME \ ++ --host ${BASELINE_HOST} \ ++ --port ${BASELINE_PORT} \ ++ --max-model-len ${MAX_MODEL_LEN}\ ++ --seed 42 \ ++ -tp 1 \ ++ --block-size ${BLOCK_SIZE} \ ++ --gpu-memory-utilization 0.8 \ ++ --disable-log-requests \ ++ --dtype float16 \ ++ --enforce-eager" ++ echo ${BASELINE_BASE_CMD} ++ bash -c "${BASELINE_BASE_CMD}" & ++ sleep 10 ++ wait_for_server ${BASELINE_HOST} ${BASELINE_PORT} +} + -+void paged_attention_v2( -+ torch::Tensor& out, -+ torch::Tensor& exp_sums, -+ torch::Tensor& max_logits, -+ torch::Tensor& tmp_out, -+ torch::Tensor& query, -+ torch::Tensor& key_cache, -+ torch::Tensor& value_cache, -+ int num_kv_heads, -+ float scale, -+ torch::Tensor& block_tables, -+ torch::Tensor& context_lens, -+ int block_size, -+ int max_context_len, -+ const c10::optional& alibi_slopes, -+ const std::string& kv_cache_dtype, -+ const float kv_scale, -+ const float attn_logit_softcapping) { -+ VLLM_XPU_DISPATCH_FLOATING_TYPES_FLOAT_ONLY( -+ query.scalar_type(), "paged_attention_xpu_v2_impl", [&] { -+ CALL_V2_LAUNCHER_BLOCK_SIZE(scalar_t); -+ }); -+} ++launch_pd() { ++ PREFILL_BASE_CMD=" ++ ONEAPI_DEVICE_SELECTOR=level_zero:0 \ ++ VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \ ++ VLLM_USE_V1=1 \ ++ VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \ ++ VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \ ++ VLLM_WORKER_MULTIPROC_METHOD=spawn \ ++ VLLM_ENABLE_V1_MULTIPROCESSING=1 vllm serve $MODEL_NAME \ ++ --host ${PREFILL_HOST} \ ++ --port ${PREFILL_PORT} \ ++ --max-model-len ${MAX_MODEL_LEN}\ ++ --seed 42 \ ++ --block-size ${BLOCK_SIZE} \ ++ --enforce-eager \ ++ --dtype float16 \ ++ -tp 1 \ ++ --gpu-memory-utilization 0.8 \ ++ --disable-log-requests \ ++ --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'" + -+torch::Tensor context_attention_forward_v2( -+ torch::Tensor query, // [num_tokens, num_kv_head, head_dim] -+ torch::Tensor key, // [num_tokens, num_kv_heads * head_size] -+ torch::Tensor value, // [num_tokens, num_kv_heads * head_size] -+ torch::Tensor block_tables, torch::Tensor query_start_loc, -+ torch::Tensor seq_lens, torch::Tensor context_lens, int max_input_length, -+ int max_context_length, int max_q_length) 
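// [Editor's annotation - not a patch hunk line.]
// Preconditions enforced by the asserts in the body below: fp16 inputs with
// matching dtypes, a 5-D key cache laid out as
// [num_blocks, num_kv_heads, head_size/x, block_size, x], a value cache of
// [num_blocks, num_kv_heads, head_size, block_size], and block_size == 16.
// Note the assert pins head_dim to 128 even though the switch also lists
// 64, 80 and 96.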
{ -+ // Currently, only support fp16 here -+ int64_t num_tokens = query.size(0); -+ int64_t num_heads = query.size(1); -+ int64_t head_dim = query.size(2); -+ int64_t batch_size = seq_lens.size(0); -+ int num_kv_heads = value.size(1); -+ -+ int key_dimension = key.dim(); -+ auto output = at::empty({query.size(0), query.size(1), query.size(2)}, -+ at::device(query.device()).dtype(query.dtype())); -+ -+ assert(key_dimension == 5); -+ assert(query.scalar_type() == key.scalar_type() && -+ query.scalar_type() == value.scalar_type()); -+ assert(head_dim == 128); -+ assert(query.scalar_type() == at::ScalarType::Half); -+ -+ int query_stride_token = query.stride(0); -+ int query_stride_head = query.stride(1); -+ int query_stride_dim = query.stride(2); -+ const float attn_scale = 1 / std::sqrt((float)head_dim); -+ -+ assert(num_heads % num_kv_heads == 0); -+ int num_queries_per_kv = num_heads / num_kv_heads; -+ -+ -+ // key: num_blocks, num_kv_heads, head_size // x, num_blocks, x) -+ // value: [num_blocks, num_kv_heads, head_size, block_dim] -+ int block_size = value.size(3); -+ // Currently, only block_size 16 is supported... -+ assert(block_size == 16); -+ int x = key.size(4); -+ int block_table_stride_bsz = block_tables.stride(0); -+ int block_table_stride_seq = block_tables.stride(1); -+ int k_cache_stride_token = key.stride(0); -+ int k_cache_stride_head = key.stride(1); -+ int k_cache_stride_head_dim = key.stride(2); -+ int k_cache_stride_block = key.stride(3); -+ int k_cache_stride_x = key.stride(4); -+ -+ int v_cache_stride_token = value.stride(0); -+ int v_cache_stride_head = value.stride(1); -+ int v_cache_stride_head_dim = value.stride(2); -+ int v_cache_stride_block = value.stride(3); -+ switch(head_dim) { -+ case 128: -+ vllm::context_attention_kernel_v2( -+ query.data_ptr(), key.data_ptr(), value.data_ptr(), -+ block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), -+ seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, -+ output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, -+ query_stride_token, query_stride_head, query_stride_dim, -+ k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, -+ k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, -+ v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, -+ output.stride(0), output.stride(1), num_queries_per_kv, -+ max_input_length, batch_size, num_heads, query.size(0), -+ max_context_length, max_q_length); -+ break; -+ case 64: -+ vllm::context_attention_kernel_v2( -+ query.data_ptr(), key.data_ptr(), value.data_ptr(), -+ block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), -+ seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, -+ output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, -+ query_stride_token, query_stride_head, query_stride_dim, -+ k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, -+ k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, -+ v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, -+ output.stride(0), output.stride(1), num_queries_per_kv, -+ max_input_length, batch_size, num_heads, query.size(0), -+ max_context_length, max_q_length); -+ break; -+ case 80: -+ vllm::context_attention_kernel_v2( -+ query.data_ptr(), key.data_ptr(), value.data_ptr(), -+ block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), -+ seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, -+ output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, -+ query_stride_token, 
query_stride_head, query_stride_dim, -+ k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, -+ k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, -+ v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, -+ output.stride(0), output.stride(1), num_queries_per_kv, -+ max_input_length, batch_size, num_heads, query.size(0), -+ max_context_length, max_q_length); -+ break; -+ case 96: -+ vllm::context_attention_kernel_v2( -+ query.data_ptr(), key.data_ptr(), value.data_ptr(), -+ block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), -+ seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, -+ output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, -+ query_stride_token, query_stride_head, query_stride_dim, -+ k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, -+ k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, -+ v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, -+ output.stride(0), output.stride(1), num_queries_per_kv, -+ max_input_length, batch_size, num_heads, query.size(0), -+ max_context_length, max_q_length); -+ break; -+ default: throw std::runtime_error("unsupported head_dim"); -+ } -+ return output; -+} + -+torch::Tensor context_attention_forward_v1( -+ torch::Tensor query, // [num_tokens, num_kv_head, head_dim] -+ torch::Tensor key, // [num_tokens, num_kv_heads * head_size] -+ torch::Tensor value, // [num_tokens, num_kv_heads * head_size] -+ torch::Tensor block_tables, torch::Tensor query_start_loc, -+ torch::Tensor seq_lens, torch::Tensor context_lens, int max_input_length, -+ int max_context_length) { -+ // Currently, only support fp16 -+ int64_t num_tokens = query.size(0); -+ int64_t num_heads = query.size(1); -+ int64_t head_dim = query.size(2); -+ int64_t batch_size = seq_lens.size(0); -+ int num_kv_heads = value.size(1); -+ -+ int key_dimension = key.dim(); -+ auto output = at::empty({query.size(0), query.size(1), query.size(2)}, -+ at::device(query.device()).dtype(query.dtype())); -+ -+ // key should be in shape: -+ // 1. [num_blocks, num_heads, block_size, head_dim] -+ // 2. 
[num_blocks, num_heads, head_dim / x, block_size, x] -+ assert(key_dimension == 4 or key_dimension == 5); -+ assert(query.scalar_type() == key.scalar_type() && -+ query.scalar_type() == value.scalar_type()); -+ assert(query.scalar_type() == at::ScalarType::Half); -+ -+ int query_stride_token = query.stride(0); -+ int query_stride_head = query.stride(1); -+ int query_stride_dim = query.stride(2); -+ const float attn_scale = 1 / std::sqrt((float)head_dim); -+ -+ assert(num_heads % num_kv_heads == 0); -+ int num_queries_per_kv = num_heads / num_kv_heads; -+ int block_table_stride_bsz = block_tables.stride(0); -+ int block_table_stride_seq = block_tables.stride(1); -+ if (key_dimension == 4) { -+ // key/value: num_blocks, num_kv_heads, num_blocks, head_dim) -+ int block_size = value.size(2); -+ int k_cache_stride_0 = key.stride(0); -+ int k_cache_stride_1 = key.stride(1); -+ int k_cache_stride_2 = key.stride(2); -+ int k_cache_stride_3 = key.stride(3); -+ -+ int v_cache_stride_0 = value.stride(0); -+ int v_cache_stride_1 = value.stride(1); -+ int v_cache_stride_2 = value.stride(2); -+ int v_cache_stride_3 = value.stride(3); -+ switch (head_dim) { -+ case 128: -+ vllm::context_attention_kernel_v1_reshaped( -+ query.data_ptr(), key.data_ptr(), value.data_ptr(), -+ block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), -+ seq_lens.data_ptr(), context_lens.data_ptr(), block_size, -+ output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, -+ query_stride_token, query_stride_head, query_stride_dim, -+ k_cache_stride_0, k_cache_stride_1, k_cache_stride_2, -+ k_cache_stride_3, v_cache_stride_0, v_cache_stride_1, -+ v_cache_stride_2, v_cache_stride_3, output.stride(0), -+ output.stride(1), num_queries_per_kv, max_input_length, batch_size, -+ num_heads); -+ break; -+ case 64: -+ vllm::context_attention_kernel_v1_reshaped( -+ query.data_ptr(), key.data_ptr(), value.data_ptr(), -+ block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), -+ seq_lens.data_ptr(), context_lens.data_ptr(), block_size, -+ output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, -+ query_stride_token, query_stride_head, query_stride_dim, -+ k_cache_stride_0, k_cache_stride_1, k_cache_stride_2, -+ k_cache_stride_3, v_cache_stride_0, v_cache_stride_1, -+ v_cache_stride_2, v_cache_stride_3, output.stride(0), -+ output.stride(1), num_queries_per_kv, max_input_length, batch_size, -+ num_heads); -+ break; -+ default: -+ throw std::runtime_error("unsupported head_dim"); -+ } -+ } else { -+ int x = key.size(4); -+ int block_size = value.size(3); -+ int k_cache_stride_token = key.stride(0); -+ int k_cache_stride_head = key.stride(1); -+ int k_cache_stride_head_dim = key.stride(2); -+ int k_cache_stride_block = key.stride(3); -+ int k_cache_stride_x = key.stride(4); -+ -+ int v_cache_stride_token = value.stride(0); -+ int v_cache_stride_head = value.stride(1); -+ int v_cache_stride_head_dim = value.stride(2); -+ int v_cache_stride_block = value.stride(3); -+ switch (head_dim) { -+ case 128: -+ vllm::context_attention_kernel_v1( -+ query.data_ptr(), key.data_ptr(), value.data_ptr(), -+ block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), -+ seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, -+ output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, -+ query_stride_token, query_stride_head, query_stride_dim, -+ k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, -+ k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, -+ v_cache_stride_head, 
v_cache_stride_head_dim, v_cache_stride_block, -+ output.stride(0), output.stride(1), num_queries_per_kv, -+ max_input_length, batch_size, num_heads); -+ break; -+ case 64: -+ vllm::context_attention_kernel_v1( -+ query.data_ptr(), key.data_ptr(), value.data_ptr(), -+ block_tables.data_ptr(), attn_scale, query_start_loc.data_ptr(), -+ seq_lens.data_ptr(), context_lens.data_ptr(), block_size, x, -+ output.data_ptr(), block_table_stride_bsz, block_table_stride_seq, -+ query_stride_token, query_stride_head, query_stride_dim, -+ k_cache_stride_token, k_cache_stride_head, k_cache_stride_head_dim, -+ k_cache_stride_block, k_cache_stride_x, v_cache_stride_token, -+ v_cache_stride_head, v_cache_stride_head_dim, v_cache_stride_block, -+ output.stride(0), output.stride(1), num_queries_per_kv, -+ max_input_length, batch_size, num_heads); -+ break; -+ default: -+ throw std::runtime_error("unsupported head_dim"); -+ } -+ } -+ return output; -+} ++ DECODE_BASE_CMD=" ++ ONEAPI_DEVICE_SELECTOR=level_zero:1 \ ++ VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \ ++ VLLM_USE_V1=1 \ ++ VLLM_WORKER_MULTIPROC_METHOD=spawn \ ++ VLLM_ENABLE_V1_MULTIPROCESSING=1 vllm serve $MODEL_NAME \ ++ --host ${DECODE_HOST} \ ++ --port ${DECODE_PORT} \ ++ --max-model-len ${MAX_MODEL_LEN}\ ++ --seed 42 \ ++ --block-size ${BLOCK_SIZE} \ ++ --enforce-eager \ ++ -tp 1 \ ++ --dtype float16 \ ++ --gpu-memory-utilization 0.8 \ ++ --disable-log-requests \ ++ --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'" + -+template -+void gqa_1_kernel( -+ const void * query, // [num_seqs, num_heads, head_size] -+ const void * key, // [num_blocks, num_kv_heads, head_size, block_size] -+ const void * value, // [num_blocks, num_kv_heads, head_size, block_size] -+ const void* block_tables, // [num_seqs, max_num_blocks_per_seq] -+ const void* context_lens, // [num_seqs] -+ void * o_a_s, -+ void * o_accs, -+ const int64_t query_bsz_stride, -+ const int64_t query_head_stride, -+ const int64_t kv_token_stride, -+ const int64_t kv_head_stride, -+ const int64_t kv_block_stride, -+ const int64_t block_table_stride_batch, -+ const int64_t o_a_s_bsz_stride, -+ const int64_t o_a_s_head_stride, -+ const int64_t o_accs_bsz_stride, -+ const int64_t o_accs_head_stride, -+ const float scale, -+ const int block_size, -+ const int bsz, -+ const int num_heads, -+ const int num_kv_heads, -+ const int block_num, -+ const at::Device & device -+) { -+ const int group_size = num_heads / num_kv_heads; -+ const int sub_rows = VS / group_size; -+ const int rem_rows = VS % group_size; -+ -+ const float attn_scale = scale; -+ -+ sycl::range<3> global_size(bsz, num_heads, block_num); -+ sycl::range<3> local_size(1, group_size, 1); -+ -+ auto cgf = [&](sycl::handler& handle) { -+ handle.parallel_for( -+ sycl::nd_range<3>(global_size, local_size), -+ [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL { -+ slm_init(); -+ -+ const int bsz_idx = item.get_global_id(0); -+ const int head_idx = item.get_global_id(1); -+ const int kv_head_idx = item.get_group(1); -+ const int tid = item.get_local_id(1); -+ const int vid = item.get_global_id(2); -+ -+ const IT * query_head = (const IT *)query + bsz_idx * query_bsz_stride -+ + head_idx * query_head_stride; -+ -+ IT * o_accs_head = (IT *)o_accs + bsz_idx * o_accs_bsz_stride -+ + head_idx * o_accs_head_stride; -+ float * o_a_s_head = (float *)o_a_s + bsz_idx * o_a_s_bsz_stride -+ + head_idx * o_a_s_head_stride; -+ -+ const int* block_tables_ptr = (const int*)block_tables; -+ const int* 
block_table = -+ block_tables_ptr + bsz_idx * block_table_stride_batch; -+ -+ const int* context_lens_ptr = (const int*)context_lens; -+ const int context_length = context_lens_ptr[bsz_idx]; -+ -+ simd query_row = block_load(query_head) * attn_scale; -+ -+ // copy k_cache to slm -+ int start_row = std::min(vid * VS + tid * sub_rows + std::min(tid, rem_rows), context_length); -+ int end_row = std::min(start_row + sub_rows + (tid < rem_rows), context_length); -+ for (int r = start_row; r < end_row; ++r) { -+ int which_block = r / block_size; -+ int which_slot = r % block_size; -+ int physical_block_number = block_table[which_block]; -+ -+ const IT * key_head = (const IT *)key + physical_block_number * kv_token_stride + -+ kv_head_idx * kv_head_stride + -+ which_slot * kv_block_stride; -+ -+ simd key_row = block_load(key_head); -+ slm_block_store((r - vid * VS) * HD * sizeof(IT), key_row); -+ } -+ barrier(); -+ -+ simd attns = -sycl::detail::max_v(); -+ int row_num = (vid + 1) * VS > context_length ? context_length % VS : VS; -+ // q @ k -+ for (int r = 0; r < row_num; ++r) { -+ simd key_row = slm_block_load(r * HD * sizeof(IT)); -+ float attn = sycl::ext::intel::esimd::detail::sum(query_row * key_row); -+ attns[r] = attn; -+ } -+ -+ float max_attn = hmax(attns); -+ const simd attn_exp = exp(attns - max_attn); -+ barrier(); -+ -+ // copy v_cache to slm -+ for (int r = start_row; r < end_row; ++r) { -+ int which_block = r / block_size; -+ int which_slot = r % block_size; -+ int physical_block_number = block_table[which_block]; -+ -+ const IT * value_head = (const IT *)value + physical_block_number * kv_token_stride + -+ kv_head_idx * kv_head_stride + -+ which_slot * kv_block_stride; -+ -+ simd value_row = block_load(value_head); -+ slm_block_store((r - vid * VS) * HD * sizeof(IT), value_row); -+ } -+ barrier(); -+ -+ // attn @ v -+ simd accs = 0; -+ for (int r = 0; r < row_num; ++r) { -+ simd value_row = slm_block_load(r * HD * sizeof(IT)); -+ accs = accs + value_row * attn_exp[r]; -+ } -+ -+ float softmax = sycl::ext::intel::esimd::detail::sum(attn_exp); -+ -+ block_store(o_accs_head + vid * HD, accs); -+ block_store(o_a_s_head + vid * 2, max_attn); -+ block_store(o_a_s_head + vid * 2 + 1, softmax); -+ } -+ ); -+ }; ++ echo ${PREFILL_BASE_CMD} ++ echo ${DECODE_BASE_CMD} ++ sleep 2 + -+ utils::submit_kernel(cgf, device, "gqa kernel 1/2"); ++ # execute on hosts ++ bash -c "${PREFILL_BASE_CMD}" & ++ bash -c "${DECODE_BASE_CMD}" & ++ sleep 1 ++ wait_for_server ${PREFILL_HOST} ${PREFILL_PORT} ++ sleep 1 ++ wait_for_server ${DECODE_HOST} ${DECODE_PORT} ++ sleep 1 +} + -+template -+void gqa_2_kernel( -+ void * o_a_s, -+ void * o_accs, -+ void * output, -+ const void* context_lens, // [num_seqs] -+ const int64_t o_a_s_bsz_stride, -+ const int64_t o_a_s_head_stride, -+ const int64_t o_accs_bsz_stride, -+ const int64_t o_accs_head_stride, -+ const int64_t output_bsz_stride, -+ const int64_t output_head_stride, -+ const int bsz, -+ const int num_heads, -+ const int row_block_num, -+ const at::Device & device -+) { -+ constexpr int SUB_HD = 8; -+ static_assert(HD % SUB_HD == 0); -+ static_assert(HD / SUB_HD <= GS); -+ -+ const int sub_rows = row_block_num / GS; -+ const int rem_rows = row_block_num % GS; -+ -+ constexpr int accs_slm_offset = 0; -+ constexpr int attn_slm_offset = GS * HD * sizeof(float); -+ constexpr int softmax_slm_offset = attn_slm_offset + GS * sizeof(float); -+ -+ sycl::range<3> global_size(bsz, num_heads, GS); -+ sycl::range<3> local_size(1, 1, GS); -+ -+ auto cgf = 
[&](sycl::handler& handle) { -+ handle.parallel_for( -+ sycl::nd_range<3>(global_size, local_size), -+ [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL { -+ slm_init(); -+ -+ const int bsz_idx = item.get_global_id(0); -+ const int head_idx = item.get_global_id(1); -+ const int tid = item.get_global_id(2); -+ -+ const int* context_lens_ptr = (const int*)context_lens; -+ const int context_length = context_lens_ptr[bsz_idx]; -+ constexpr int VS = 32; -+ const int cur_row_block_num = (context_length + VS - 1) / VS; -+ const int cur_sub_rows = cur_row_block_num / GS; -+ const int cur_rem_rows = cur_row_block_num % GS; -+ -+ const float * o_a_s_head = (const float *)o_a_s + bsz_idx * o_a_s_bsz_stride -+ + head_idx * o_a_s_head_stride; -+ const IT * o_accs_head = (const IT *)o_accs + bsz_idx * o_accs_bsz_stride -+ + head_idx * o_accs_head_stride; -+ IT * output_head = (IT *)output + bsz_idx * output_bsz_stride -+ + head_idx * output_head_stride; -+ -+ int start_row = std::min(tid * cur_sub_rows + std::min(tid, cur_rem_rows), cur_row_block_num); -+ int end_row = std::min(start_row + cur_sub_rows + (tid < cur_rem_rows), cur_row_block_num); -+ -+ float max_attn = -sycl::detail::max_v(); -+ float softmax = 0; -+ simd accs = 0; -+ for (int r = start_row; r < end_row; ++r) { -+ float sub_attn = o_a_s_head[2 * r]; -+ float sub_softmax = o_a_s_head[2 * r + 1]; -+ simd sub_accs = block_load(o_accs_head + r * HD); -+ float new_max_attn = std::max(max_attn, sub_attn); -+ float exp1 = exp(max_attn - new_max_attn); -+ float exp2 = exp(sub_attn - new_max_attn); -+ accs = accs * exp1 + sub_accs * exp2; -+ softmax = softmax * exp1 + sub_softmax * exp2; -+ max_attn = new_max_attn; -+ } -+ -+ slm_block_store(accs_slm_offset + tid * HD * sizeof(float), accs); -+ slm_block_store(attn_slm_offset + tid * sizeof(float), max_attn); -+ slm_block_store(softmax_slm_offset + tid * sizeof(float), softmax); -+ barrier(); -+ -+ if (tid < HD / SUB_HD) { -+ simd max_attns = slm_block_load(attn_slm_offset); -+ const simd scales = exp(max_attns - hmax(max_attns)); -+ simd softmaxs = slm_block_load(softmax_slm_offset); -+ float softmax_sum = sycl::ext::intel::esimd::detail::sum(softmaxs * scales); -+ -+ simd result = 0; -+ #pragma unroll -+ for (int r = 0; r < GS; ++r) { -+ simd sub_accs = slm_block_load( -+ accs_slm_offset + (r * HD + tid * SUB_HD) * sizeof(float) -+ ); -+ result = result + sub_accs * scales[r]; -+ } -+ result = result / softmax_sum; -+ block_store(output_head + tid * SUB_HD, result); -+ } -+ } -+ ); -+ }; -+ -+ utils::submit_kernel(cgf, device, "gqa kernel 2/2"); ++launch_pd_proxy(){ ++ PROXY_BASE_CMD=" ++ python3 ${EXP_ROOT}/toy_proxy_server.py \ ++ --prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \ ++ --decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \ ++ --host=${PROXY_HOST} --port ${PROXY_PORT}" ++ echo ${PROXY_BASE_CMD} ++ bash -c "${PROXY_BASE_CMD}" & ++ sleep 2 +} + -+using AT = at::ScalarType; -+using fp16 = sycl::half; -+template -+auto dispatch_gqa_kernel(AT it) { -+ switch (it) { -+ case AT::Float: return std::make_tuple(gqa_1_kernel, gqa_2_kernel); -+ case AT::Half: return std::make_tuple(gqa_1_kernel, gqa_2_kernel); -+ default: throw std::runtime_error("unsupported dtype, only fp32 and fp16 are supported"); -+ } ++run_tests(){ ++ local service_url=$1 ++ local mode=$2 ++ python3 ${EXP_ROOT}/test_disagg_accuracy.py --service_url=${service_url} --model_name=${MODEL_NAME} --mode=${mode} --file_name=${OUTPUT_FILE} +} + -+void paged_attention_gqa( -+ torch::Tensor output, -+ 
torch::Tensor query, -+ torch::Tensor key_cache, -+ torch::Tensor value_cache, -+ int64_t bsz, -+ int64_t num_heads, -+ int64_t num_kv_heads, -+ float scale, -+ torch::Tensor& block_tables, -+ torch::Tensor& context_lens, -+ int block_size, -+ int64_t head_dim, -+ int max_seq_len -+) { -+ constexpr int VS = 32; -+ constexpr int GS = 32; -+ -+ const int row_block_num = (max_seq_len + VS - 1) / VS; -+ auto o_a_s = torch::empty({bsz, num_heads, 1, row_block_num * 2}, -+ torch::device(query.device()).dtype(torch::kFloat32)); -+ auto o_accs = torch::empty({bsz, num_heads, 1, row_block_num * head_dim}, -+ torch::device(query.device()).dtype(query.dtype())); -+ -+ auto [func1, func2] = [&](){ -+ switch (head_dim) { -+ case 128: return dispatch_gqa_kernel(query.scalar_type()); -+ case 96: return dispatch_gqa_kernel(query.scalar_type()); -+ case 80: return dispatch_gqa_kernel(query.scalar_type()); -+ case 64: return dispatch_gqa_kernel(query.scalar_type()); -+ default: throw std::runtime_error("unsupported head_dim, only 128, 96, 80 and 64 are supported"); -+ } -+ }(); -+ -+ func1( -+ query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), -+ block_tables.data_ptr(), context_lens.data_ptr(), o_a_s.data_ptr(), o_accs.data_ptr(), -+ query.stride(0), query.stride(1), key_cache.stride(0), key_cache.stride(1), key_cache.stride(2), block_tables.stride(0), -+ o_a_s.stride(0), o_a_s.stride(1), o_accs.stride(0), o_accs.stride(1), -+ scale, block_size, bsz, num_heads, num_kv_heads, row_block_num, -+ query.device() -+ ); -+ -+ func2( -+ o_a_s.data_ptr(), o_accs.data_ptr(), output.data_ptr(), context_lens.data_ptr(), -+ o_a_s.stride(0), o_a_s.stride(1), -+ o_accs.stride(0), o_accs.stride(1), -+ output.stride(0), output.stride(1), -+ bsz, num_heads, row_block_num, -+ query.device() -+ ); -+} -diff --git a/csrc/xpu/attention_xpu_fp8.cpp b/csrc/xpu/attention_xpu_fp8.cpp -new file mode 100644 -index 000000000..a2ea5819b ---- /dev/null -+++ b/csrc/xpu/attention_xpu_fp8.cpp -@@ -0,0 +1,324 @@ -+// clang-format off -+#ifdef VLLM_DEV -+#undef __SYCL_DEVICE_ONLY__ -+#endif -+#include -+#include -+#include -+#include "kv.h" -+ -+// clang-format on -+#include -+#include -+#include -+#include "utils.h" -+#include "xpu_types.h" -+// #include "dtype_bfloat16.dp.hpp" -+#include "dtype_float16.h" -+#include "dtype_float32.h" -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+#include -+#endif -+ -+#include -+// #include -+ -+using namespace sycl::ext::intel::esimd; -+using AT = at::ScalarType; -+ -+template -+void gqa_1_kernel_fp8( -+ const void* query, // [num_seqs, num_heads, head_size] -+ const void* key, // [num_blocks, num_kv_heads, head_size, block_size] -+ const void* value, // [num_blocks, num_kv_heads, head_size, block_size] -+ const void* block_tables, // [num_seqs, max_num_blocks_per_seq] -+ const void* context_lens, // [num_seqs] -+ void* o_a_s, void* o_accs, const int64_t query_bsz_stride, -+ const int64_t query_head_stride, const int64_t kv_token_stride, -+ const int64_t kv_head_stride, const int64_t kv_block_stride, -+ const int64_t block_table_stride_batch, const int64_t o_a_s_bsz_stride, -+ const int64_t o_a_s_head_stride, const int64_t o_accs_bsz_stride, -+ const int64_t o_accs_head_stride, const float scale, const int block_size, -+ const int bsz, const int num_heads, const int num_kv_heads, -+ const int block_num, const at::Device& device) { -+ const int group_size = num_heads / num_kv_heads; -+ const int sub_rows = VS / group_size; -+ const int rem_rows = VS % group_size; -+ -+ const 
float attn_scale = scale; -+ -+ sycl::range<3> global_size(bsz, num_heads, block_num); -+ sycl::range<3> local_size(1, group_size, 1); -+ -+ auto cgf = [&](sycl::handler& handle) { -+ handle.parallel_for( -+ sycl::nd_range<3>(global_size, local_size), -+ [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL { -+ slm_init(); -+ -+ const int bsz_idx = item.get_global_id(0); -+ const int head_idx = item.get_global_id(1); -+ const int kv_head_idx = item.get_group(1); -+ const int tid = item.get_local_id(1); -+ const int vid = item.get_global_id(2); -+ -+ const IT* query_head = (const IT*)query + bsz_idx * query_bsz_stride + -+ head_idx * query_head_stride; -+ -+ IT* o_accs_head = (IT*)o_accs + bsz_idx * o_accs_bsz_stride + -+ head_idx * o_accs_head_stride; -+ float* o_a_s_head = (float*)o_a_s + bsz_idx * o_a_s_bsz_stride + -+ head_idx * o_a_s_head_stride; -+ -+ const int* block_tables_ptr = (const int*)block_tables; -+ const int* block_table = -+ block_tables_ptr + bsz_idx * block_table_stride_batch; -+ -+ const int* context_lens_ptr = (const int*)context_lens; -+ const int context_length = context_lens_ptr[bsz_idx]; -+ -+ simd query_row = block_load(query_head) * attn_scale; -+ -+ // copy k_cache to slm -+ int start_row = -+ std::min(vid * VS + tid * sub_rows + std::min(tid, rem_rows), -+ context_length); -+ int end_row = -+ std::min(start_row + sub_rows + (tid < rem_rows), context_length); -+ for (int r = start_row; r < end_row; ++r) { -+ int which_block = r / block_size; -+ int which_slot = r % block_size; -+ int physical_block_number = block_table[which_block]; -+ -+ // Load elements in uint8_t -+ const uint8_t* key_head = -+ (const uint8_t*)key + physical_block_number * kv_token_stride + -+ kv_head_idx * kv_head_stride + which_slot * kv_block_stride; -+ -+ simd key_row = block_load(key_head); -+ simd key_dequantized = dequantize_key_row(key_row); -+ slm_block_store((r - vid * VS) * HD * sizeof(IT), key_dequantized); -+ } -+ barrier(); -+ -+ simd attns = -sycl::detail::max_v(); -+ int row_num = -+ (vid + 1) * VS > context_length ? 
context_length % VS : VS; -+ // q @ k -+ for (int r = 0; r < row_num; ++r) { -+ simd key_row = slm_block_load(r * HD * sizeof(IT)); -+ float attn = sycl::ext::intel::esimd::detail::sum( -+ query_row * key_row); -+ attns[r] = attn; -+ } -+ -+ float max_attn = hmax(attns); -+ const simd attn_exp = exp(attns - max_attn); -+ barrier(); -+ -+ // copy v_cache to slm -+ for (int r = start_row; r < end_row; ++r) { -+ int which_block = r / block_size; -+ int which_slot = r % block_size; -+ int physical_block_number = block_table[which_block]; -+ -+ const uint8_t* value_head = -+ (const uint8_t*)value + physical_block_number * kv_token_stride + -+ kv_head_idx * kv_head_stride + which_slot * kv_block_stride; -+ -+ simd value_row = block_load(value_head); -+ simd value_dequantized = dequantize_value_row(value_row); -+ slm_block_store((r - vid * VS) * HD * sizeof(IT), -+ value_dequantized); -+ } -+ barrier(); -+ -+ // attn @ v -+ simd accs = 0; -+ for (int r = 0; r < row_num; ++r) { -+ simd value_row = -+ slm_block_load(r * HD * sizeof(IT)); -+ accs = accs + value_row * attn_exp[r]; -+ } -+ -+ float softmax = -+ sycl::ext::intel::esimd::detail::sum(attn_exp); -+ -+ block_store(o_accs_head + vid * HD, accs); -+ block_store(o_a_s_head + vid * 2, max_attn); -+ block_store(o_a_s_head + vid * 2 + 1, softmax); -+ }); -+ }; -+ -+ utils::submit_kernel(cgf, device, "gqa kernel 1/2"); -+} + -+template -+void gqa_2_kernel_fp8(void* o_a_s, void* o_accs, void* output, -+ const void* context_lens, // [num_seqs] -+ const int64_t o_a_s_bsz_stride, -+ const int64_t o_a_s_head_stride, -+ const int64_t o_accs_bsz_stride, -+ const int64_t o_accs_head_stride, -+ const int64_t output_bsz_stride, -+ const int64_t output_head_stride, const int bsz, -+ const int num_heads, const int row_block_num, -+ const at::Device& device) { -+ constexpr int SUB_HD = 8; -+ static_assert(HD % SUB_HD == 0); -+ static_assert(HD / SUB_HD <= GS); -+ -+ const int sub_rows = row_block_num / GS; -+ const int rem_rows = row_block_num % GS; -+ -+ constexpr int accs_slm_offset = 0; -+ constexpr int attn_slm_offset = GS * HD * sizeof(float); -+ constexpr int softmax_slm_offset = attn_slm_offset + GS * sizeof(float); -+ -+ sycl::range<3> global_size(bsz, num_heads, GS); -+ sycl::range<3> local_size(1, 1, GS); -+ -+ auto cgf = [&](sycl::handler& handle) { -+ handle.parallel_for( -+ sycl::nd_range<3>(global_size, local_size), -+ [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL { -+ slm_init(); -+ -+ const int bsz_idx = item.get_global_id(0); -+ const int head_idx = item.get_global_id(1); -+ const int tid = item.get_global_id(2); -+ -+ const int* context_lens_ptr = (const int*)context_lens; -+ const int context_length = context_lens_ptr[bsz_idx]; -+ constexpr int VS = 32; -+ const int cur_row_block_num = (context_length + VS - 1) / VS; -+ const int cur_sub_rows = cur_row_block_num / GS; -+ const int cur_rem_rows = cur_row_block_num % GS; -+ -+ const float* o_a_s_head = (const float*)o_a_s + -+ bsz_idx * o_a_s_bsz_stride + -+ head_idx * o_a_s_head_stride; -+ const IT* o_accs_head = (const IT*)o_accs + -+ bsz_idx * o_accs_bsz_stride + -+ head_idx * o_accs_head_stride; -+ IT* output_head = (IT*)output + bsz_idx * output_bsz_stride + -+ head_idx * output_head_stride; -+ -+ int start_row = -+ std::min(tid * cur_sub_rows + std::min(tid, cur_rem_rows), -+ cur_row_block_num); -+ int end_row = -+ std::min(start_row + cur_sub_rows + (tid < cur_rem_rows), -+ cur_row_block_num); -+ -+ float max_attn = -sycl::detail::max_v(); -+ float softmax = 0; -+ simd accs = 0; -+ for 
(int r = start_row; r < end_row; ++r) { -+ float sub_attn = o_a_s_head[2 * r]; -+ float sub_softmax = o_a_s_head[2 * r + 1]; -+ simd sub_accs = block_load(o_accs_head + r * HD); -+ float new_max_attn = std::max(max_attn, sub_attn); -+ float exp1 = exp(max_attn - new_max_attn); -+ float exp2 = exp(sub_attn - new_max_attn); -+ accs = accs * exp1 + sub_accs * exp2; -+ softmax = softmax * exp1 + sub_softmax * exp2; -+ max_attn = new_max_attn; -+ } -+ -+ slm_block_store(accs_slm_offset + tid * HD * sizeof(float), -+ accs); -+ slm_block_store(attn_slm_offset + tid * sizeof(float), -+ max_attn); -+ slm_block_store(softmax_slm_offset + tid * sizeof(float), -+ softmax); -+ barrier(); -+ -+ if (tid < HD / SUB_HD) { -+ simd max_attns = -+ slm_block_load(attn_slm_offset); -+ const simd scales = -+ exp(max_attns - hmax(max_attns)); -+ simd softmaxs = -+ slm_block_load(softmax_slm_offset); -+ float softmax_sum = -+ sycl::ext::intel::esimd::detail::sum( -+ softmaxs * scales); -+ -+ simd result = 0; -+#pragma unroll -+ for (int r = 0; r < GS; ++r) { -+ simd sub_accs = slm_block_load( -+ accs_slm_offset + (r * HD + tid * SUB_HD) * sizeof(float)); -+ result = result + sub_accs * scales[r]; -+ } -+ result = result / softmax_sum; -+ block_store(output_head + tid * SUB_HD, result); -+ } -+ }); -+ }; -+ -+ utils::submit_kernel(cgf, device, "gqa kernel 2/2"); -+} -+ -+template -+auto dispatch_gqa_kernel_fp8(AT it) { -+ switch (it) { -+ case AT::Float: -+ return std::make_tuple(gqa_1_kernel_fp8, -+ gqa_2_kernel_fp8); -+ case AT::Half: -+ return std::make_tuple(gqa_1_kernel_fp8, -+ gqa_2_kernel_fp8); -+ default: -+ throw std::runtime_error( -+ "unsupported dtype, only fp32 and fp16 are supported"); -+ } -+} -+ -+void paged_attention_gqa_fp8(torch::Tensor output, torch::Tensor query, -+ torch::Tensor key_cache, torch::Tensor value_cache, -+ int64_t bsz, int64_t num_heads, int64_t num_kv_heads, -+ float scale, torch::Tensor& block_tables, -+ torch::Tensor& context_lens, int block_size, -+ int64_t head_dim, int max_seq_len) { -+ constexpr int VS = 32; -+ constexpr int GS = 32; -+ -+ const int row_block_num = (max_seq_len + VS - 1) / VS; -+ auto o_a_s = -+ torch::empty({bsz, num_heads, 1, row_block_num * 2}, -+ torch::device(query.device()).dtype(torch::kFloat32)); -+ auto o_accs = -+ torch::empty({bsz, num_heads, 1, row_block_num * head_dim}, -+ torch::device(query.device()).dtype(query.dtype())); -+ -+ auto [func1, func2] = [&]() { -+ switch (head_dim) { -+ case 128: -+ return dispatch_gqa_kernel_fp8(query.scalar_type()); -+ case 96: -+ return dispatch_gqa_kernel_fp8(query.scalar_type()); -+ case 80: -+ return dispatch_gqa_kernel_fp8(query.scalar_type()); -+ case 64: -+ return dispatch_gqa_kernel_fp8(query.scalar_type()); -+ default: -+ throw std::runtime_error( -+ "unsupported head_dim, only 128, 96, 80 and 64 are supported"); -+ } -+ }(); -+ -+ func1(query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), -+ block_tables.data_ptr(), context_lens.data_ptr(), o_a_s.data_ptr(), -+ o_accs.data_ptr(), query.stride(0), query.stride(1), -+ key_cache.stride(0), key_cache.stride(1), key_cache.stride(2), -+ block_tables.stride(0), o_a_s.stride(0), o_a_s.stride(1), -+ o_accs.stride(0), o_accs.stride(1), scale, block_size, bsz, num_heads, -+ num_kv_heads, row_block_num, query.device()); -+ -+ func2(o_a_s.data_ptr(), o_accs.data_ptr(), output.data_ptr(), -+ context_lens.data_ptr(), o_a_s.stride(0), o_a_s.stride(1), -+ o_accs.stride(0), o_accs.stride(1), output.stride(0), output.stride(1), -+ bsz, num_heads, 
row_block_num, query.device()); -+} -diff --git a/csrc/xpu/base.hpp b/csrc/xpu/base.hpp -new file mode 100644 -index 000000000..c364c62e6 ---- /dev/null -+++ b/csrc/xpu/base.hpp -@@ -0,0 +1,118 @@ -+#pragma once -+ -+#include -+#include -+ -+#include "common.h" -+ -+using namespace sycl::ext::intel::esimd; -+using fp16 = sycl::half; -+ -+constexpr int QK = 64; -+constexpr int SBS = 4; -+ -+constexpr int BLOCK_SIZES[GGML_TYPE_COUNT] = { -+ [GGML_TYPE_Q4_0] = QK / 2, -+ [GGML_TYPE_Q4_0_WOQ] = QK / 2, -+ [GGML_TYPE_FP8E5] = QK, -+}; -+ -+constexpr int SCALE_SIZES[GGML_TYPE_COUNT] = { -+ [GGML_TYPE_Q4_0] = sizeof(fp16), -+ [GGML_TYPE_Q4_0_WOQ] = sizeof(fp16), -+ [GGML_TYPE_FP8E5] = 0, -+}; -+ -+template -+ESIMD_INLINE auto load_qblocks(const uint8_t * weight, const uint8_t * scale); -+ -+template<> -+ESIMD_INLINE auto load_qblocks(const uint8_t * weight, const uint8_t * scale) { -+ constexpr int BLOCK_SIZE = BLOCK_SIZES[GGML_TYPE_Q4_0]; -+ simd ybytes = block_load(weight); -+ const simd scales = block_load((const fp16 *)scale); -+ -+ simd yvs; -+ #pragma unroll -+ for (int i = 0; i < SBS; ++i) { -+ simd uyv; -+ uyv.select(0) = ybytes.template select(i * QK / 2) & (uint8_t)0xF; -+ uyv.select(QK / 2) = ybytes.template select(i * QK / 2) >> (uint8_t)4; -+ yvs.template select(i * QK) = (uyv.bit_cast_view() - (int8_t)8) * scales[i]; -+ } -+ return yvs; -+} -+ -+template<> -+ESIMD_INLINE auto load_qblocks(const uint8_t * weight, const uint8_t * scale) { -+ constexpr int BLOCK_SIZE = BLOCK_SIZES[GGML_TYPE_Q4_0_WOQ]; -+ simd ybytes = block_load(weight); -+ const simd scales = block_load((const fp16 *)scale); -+ -+ simd yvs; -+ #pragma unroll -+ for (int i = 0; i < SBS; ++i) { -+ simd uyv; -+ uyv.select(0) = ybytes.template select(i * QK / 2) & (uint8_t)0xF; -+ uyv.select(1) = ybytes.template select(i * QK / 2) >> (uint8_t)4; -+ yvs.template select(i * QK) = (uyv.bit_cast_view() - (int8_t)8) * scales[i]; -+ } -+ return yvs; -+} ++# run non-disagg. baseline & save outputs ++launch_baseline ++run_tests "http://${BASELINE_HOST}:${BASELINE_PORT}" "baseline" ++cleanup ++sleep 10 + + -+template<> -+ESIMD_INLINE auto load_qblocks(const uint8_t * weight, const uint8_t * scale) { -+ constexpr int BLOCK_SIZE = BLOCK_SIZES[GGML_TYPE_FP8E5]; -+ simd ybytes = block_load(weight); ++# run disagg. 
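The removed gqa_1/gqa_2 pair above is a split-KV decode scheme: the first pass emits, per 32-token KV chunk, a partial score maximum, exp-sum and value accumulator, and the second pass merges those partials by rescaling with exp(m_i - m). A minimal NumPy sketch of that merge for a single query head (hypothetical shapes, not the ESIMD kernels themselves):

import numpy as np

def split_kv_attention(q, k, v, chunk=32):
    # pass 1: per-chunk partials (max score, exp-sum, weighted value sum),
    # mirroring the (max_attn, softmax, accs) triple stored by the first kernel
    partials = []
    for s in range(0, k.shape[0], chunk):
        scores = k[s:s + chunk] @ q
        m = scores.max()
        e = np.exp(scores - m)
        partials.append((m, e.sum(), e @ v[s:s + chunk]))

    # pass 2: running merge with exp(old_max - new_max) rescaling,
    # as in the (exp1, exp2) update of the second kernel
    m_run, s_run, acc = -np.inf, 0.0, np.zeros_like(v[0])
    for m_i, s_i, acc_i in partials:
        m_new = max(m_run, m_i)
        e1, e2 = np.exp(m_run - m_new), np.exp(m_i - m_new)
        acc = acc * e1 + acc_i * e2
        s_run = s_run * e1 + s_i * e2
        m_run = m_new
    return acc / s_run

rng = np.random.default_rng(0)
q, k, v = rng.normal(size=64), rng.normal(size=(100, 64)), rng.normal(size=(100, 64))
scores = k @ q
weights = np.exp(scores - scores.max()) / np.exp(scores - scores.max()).sum()
assert np.allclose(split_kv_attention(q, k, v), weights @ v)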
& do exact-match with the outputs from baseline ++launch_pd ++launch_pd_proxy ++run_tests "http://${PROXY_HOST}:${PROXY_PORT}" "disagg" ++echo "-----P/D success----" + -+ simd yvs; -+ yvs.template bit_cast_view().template select(0) = 0x80; -+ yvs.template bit_cast_view().template select(1) = ybytes; -+ return yvs; -+} ++rm ${OUTPUT_FILE} ++cleanup + ++exit 0 +diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py +index c2868c040..f6627808c 100644 +--- a/vllm/_ipex_ops.py ++++ b/vllm/_ipex_ops.py +@@ -207,6 +207,12 @@ class ipex_ops: + is_causal, return_softmax, + gen_) + else: # XPU build ++ if max_seqlen_q is None: ++ assert seqlen_q is not None ++ max_seqlen_q = int((seqlen_q[1:] - seqlen_q[:-1]).max().item()) ++ if max_seqlen_k is None: ++ assert seqlen_k is not None ++ max_seqlen_k = int((seqlen_k[1:] - seqlen_k[:-1]).max().item()) + ipex.llm.functional.varlen_attention( + query.contiguous(), key.contiguous(), value.contiguous(), out, + seqlen_q.int(), seqlen_k.int(), alibi_slopes, max_seqlen_q, +@@ -300,6 +306,7 @@ class ipex_ops: + causal, + block_table, + alibi_slopes, ++ sink=s_aux, + softcap=softcap, + window_size_left=real_window_size[0], + window_size_right=real_window_size[1], +diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py +index bb05b468f..f1d657315 100644 +--- a/vllm/attention/layer.py ++++ b/vllm/attention/layer.py +@@ -23,6 +23,7 @@ from vllm.model_executor.layers.linear import UnquantizedLinearMethod + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod ++from vllm.model_executor.models.vision import get_vit_attn_backend + from vllm.platforms import _Backend, current_platform + from vllm.utils import direct_register_custom_op + +@@ -30,6 +31,15 @@ logger = init_logger(__name__) + USE_XFORMERS_OPS = None + + + -+// C++ doesn't support function template partial specialization, so write a new version for SBS=1 -+template -+ESIMD_INLINE auto load_qblock(const uint8_t * weight, const uint8_t * scale); ++def check_upstream_fa_availability(dtype: torch.dtype): ++ if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda( ++ ) and current_platform.has_device_capability(80): ++ from transformers.utils import is_flash_attn_2_available ++ return is_flash_attn_2_available() ++ return False + -+template<> -+ESIMD_INLINE auto load_qblock(const uint8_t * weight, const uint8_t * scale) { -+ constexpr int BLOCK_SIZE = BLOCK_SIZES[GGML_TYPE_Q4_0]; -+ simd ybytes = block_load(weight); -+ fp16 scales = *(const fp16 *)scale; + -+ simd uyv; -+ uyv.select(0) = ybytes & (uint8_t)0xF; -+ uyv.select(QK / 2) = ybytes >> (uint8_t)4; -+ simd yv = (uyv.bit_cast_view() - (int8_t)8) * scales; + def check_xformers_availability(): + global USE_XFORMERS_OPS + if USE_XFORMERS_OPS is not None: +@@ -349,29 +359,55 @@ class MultiHeadAttention(nn.Module): + f"divisible by num_kv_heads ({self.num_kv_heads})" + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + ++ # During model initialization, the default dtype is set as the model ++ # weight and activation dtype. 
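The _ipex_ops fallback above derives the missing max_seqlen values from the cumulative sequence-length tensors, and the flash-attention path added further down builds those same cumulative tensors for an equal-length batch. A short PyTorch illustration of both conversions, with hypothetical sizes:

import torch

bsz, q_len = 3, 5

# cu_seqlens for bsz sequences of identical length q_len: [0, 5, 10, 15]
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)

# recovering max_seqlen when the caller did not pass it: adjacent
# differences of the cumulative tensor are the per-sequence lengths
lengths = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_q = int(lengths.max().item())
assert max_seqlen_q == q_len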
+ dtype = torch.get_default_dtype() +- attn_backend = get_attn_backend(head_size, +- dtype, +- kv_cache_dtype=None, +- block_size=16, +- is_attention_free=False) +- backend = backend_name_to_enum(attn_backend.get_name()) ++ ++ # Determine the attention backend ++ backend = get_vit_attn_backend(head_size=head_size, dtype=dtype) ++ ++ # Some auto-selected backends can be upgraded ++ # to upstream flash attention if available. ++ # If vllm native fa is selected, we use it directly. ++ use_upstream_fa = False ++ if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( ++ dtype): ++ backend = _Backend.FLASH_ATTN ++ use_upstream_fa = True ++ + if current_platform.is_rocm(): + # currently, only torch_sdpa is supported on rocm + self.attn_backend = _Backend.TORCH_SDPA + else: + -+ return yv; -+} + self.attn_backend = backend if backend in { + _Backend.TORCH_SDPA, + _Backend.TORCH_SDPA_VLLM_V1, + _Backend.XFORMERS, + _Backend.PALLAS_VLLM_V1, + _Backend.ROCM_AITER_FA, +- } else current_platform.get_vit_attn_backend() ++ _Backend.FLASH_ATTN, ++ _Backend.FLASH_ATTN_VLLM_V1, ++ } else _Backend.TORCH_SDPA + + if (self.attn_backend == _Backend.XFORMERS + and not check_xformers_availability()): + self.attn_backend = _Backend.TORCH_SDPA + ++ if self.attn_backend in { ++ _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1 ++ }: ++ if use_upstream_fa: ++ from flash_attn import flash_attn_varlen_func ++ self._flash_attn_varlen_func = flash_attn_varlen_func ++ else: ++ from vllm.vllm_flash_attn import flash_attn_varlen_func ++ self._flash_attn_varlen_func = flash_attn_varlen_func + -+template<> -+ESIMD_INLINE auto load_qblock(const uint8_t * weight, const uint8_t * scale) { -+ constexpr int BLOCK_SIZE = BLOCK_SIZES[GGML_TYPE_Q4_0_WOQ]; -+ simd ybytes = block_load(weight); -+ fp16 scales = *(const fp16 *)scale; ++ logger.info_once( ++ f"MultiHeadAttention attn_backend: {self.attn_backend}, " ++ f"use_upstream_fa: {use_upstream_fa}") + -+ simd uyv; -+ uyv.select(0) = ybytes & (uint8_t)0xF; -+ uyv.select(1) = ybytes >> (uint8_t)4; -+ simd yv = (uyv.bit_cast_view() - (int8_t)8) * scales; + def forward( + self, + query: torch.Tensor, +@@ -380,7 +416,7 @@ class MultiHeadAttention(nn.Module): + ) -> torch.Tensor: + """Input shape: batch_size x seq_len x hidden_size""" + # TODO(Isotr0py): Use existing backend implementations and support FA3 +- bsz, q_len, _ = query.size() ++ bsz, q_len = query.size()[:2] + kv_len = key.size(1) + + query = query.view(bsz, q_len, self.num_heads, self.head_size) +@@ -392,7 +428,31 @@ class MultiHeadAttention(nn.Module): + key = torch.repeat_interleave(key, num_repeat, dim=2) + value = torch.repeat_interleave(value, num_repeat, dim=2) + +- if self.attn_backend == _Backend.XFORMERS: ++ if self.attn_backend in { ++ _Backend.FLASH_ATTN, ++ _Backend.FLASH_ATTN_VLLM_V1, ++ }: ++ ++ cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, ++ step=q_len, ++ dtype=torch.int32, ++ device=query.device) ++ cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len, ++ step=kv_len, ++ dtype=torch.int32, ++ device=key.device) + -+ return yv; -+} ++ out = self._flash_attn_varlen_func( ++ query.flatten(0, 1), ++ key.flatten(0, 1), ++ value.flatten(0, 1), ++ cu_seqlens_q=cu_seqlens_q, ++ cu_seqlens_k=cu_seqlens_k, ++ max_seqlen_q=q_len, ++ max_seqlen_k=kv_len, ++ softmax_scale=self.scale, ++ ) ++ elif self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + + out = xops.memory_efficient_attention_forward(query, +@@ -400,7 +460,8 @@ class MultiHeadAttention(nn.Module): + value, + 
scale=self.scale) + elif (self.attn_backend == _Backend.TORCH_SDPA +- or self.attn_backend == _Backend.TORCH_SDPA_VLLM_V1): ++ or self.attn_backend == _Backend.TORCH_SDPA_VLLM_V1 ++ or self.attn_backend == _Backend.IPEX): + query, key, value = (x.transpose(1, 2) + for x in (query, key, value)) + out = F.scaled_dot_product_attention(query, +diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py +index a98eb2a78..14095ca4d 100644 +--- a/vllm/benchmarks/serve.py ++++ b/vllm/benchmarks/serve.py +@@ -430,7 +430,8 @@ async def benchmark( + test_prompt, test_prompt_len, test_output_len, test_mm_content = ( + input_requests[0].prompt, + input_requests[0].prompt_len, +- input_requests[0].expected_output_len, ++ #input_requests[0].expected_output_len, ++ 10, + input_requests[0].multi_modal_data, + ) + +diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py +index 067315deb..b236bae26 100644 +--- a/vllm/distributed/device_communicators/xpu_communicator.py ++++ b/vllm/distributed/device_communicators/xpu_communicator.py +@@ -25,6 +25,12 @@ class XpuCommunicator(DeviceCommunicatorBase): + super().__init__(cpu_group, device, device_group, unique_name) + if self.use_all2all: + all2all_backend = envs.VLLM_ALL2ALL_BACKEND ++ if all2all_backend != "naive": ++ logger.warning( ++ "`%s` all2all manager is not supported on XPU." ++ "Falling back to `naive` all2all manager for XPU.", ++ all2all_backend) ++ all2all_backend = "naive" + if all2all_backend == "naive": + from .all2all import NaiveAll2AllManager + self.all2all_manager = NaiveAll2AllManager(self.cpu_group) +@@ -67,3 +73,16 @@ class XpuCommunicator(DeviceCommunicatorBase): + + def broadcast(self, input_: torch.Tensor, src: int = 0) -> None: + dist.broadcast(input_, src=src, group=self.device_group) ++ ++ def dispatch( ++ self, hidden_states: torch.Tensor, ++ router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: ++ assert self.all2all_manager is not None ++ hidden_states, router_logits = self.all2all_manager.dispatch( ++ hidden_states, router_logits) ++ return hidden_states, router_logits ++ ++ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: ++ assert self.all2all_manager is not None ++ hidden_states = self.all2all_manager.combine(hidden_states) ++ return hidden_states +diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py +index b53dbfb3a..48d205856 100644 +--- a/vllm/entrypoints/chat_utils.py ++++ b/vllm/entrypoints/chat_utils.py +@@ -431,6 +431,51 @@ def resolve_mistral_chat_template( + return None + + ++_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], Optional[str]]() ++""" ++Used in `_try_get_processor_chat_template` to avoid calling ++`cached_get_processor` again if the processor fails to be loaded. + ++This is needed because `lru_cache` does not cache when an exception happens. 
++""" + -+template<> -+ESIMD_INLINE auto load_qblock(const uint8_t * weight, const uint8_t * scale) { -+ constexpr int BLOCK_SIZE = BLOCK_SIZES[GGML_TYPE_FP8E5]; -+ simd ybytes = block_load(weight); + -+ simd yvs; -+ yvs.template bit_cast_view().template select(0) = 0x80; -+ yvs.template bit_cast_view().template select(1) = ybytes; -+ return yvs; -+} -diff --git a/csrc/xpu/cache_ops_xpu.cpp b/csrc/xpu/cache_ops_xpu.cpp -new file mode 100644 -index 000000000..a3451c0e7 ---- /dev/null -+++ b/csrc/xpu/cache_ops_xpu.cpp -@@ -0,0 +1,579 @@ -+// clang-format off -+#ifdef VLLM_DEV -+#undef __SYCL_DEVICE_ONLY__ -+#endif -+#include -+#include -+#include -+// clang-format on -+#include "xpu_types.h" -+ -+#include -+#include "utils.h" -+ -+using fp16 = sycl::half; -+using namespace sycl::ext::intel::esimd; -+ -+template -+void reshape_and_cache_kernel( -+ const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] -+ const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] -+ scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, -+ // block_size, x] -+ scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, -+ // block_size] -+ const int64_t* __restrict__ slot_mapping, // [num_tokens] -+ const int key_stride, -+ const int value_stride, -+ const int num_heads, -+ const int head_size, -+ const int block_size, -+ const int x, -+ const sycl::nd_item<3>& item_ct1) { -+ const int64_t token_idx = item_ct1.get_group(2); -+ const int64_t slot_idx = slot_mapping[token_idx]; -+ if (slot_idx < 0) { -+ // Padding token that should be ignored. -+ return; -+ } -+ -+ const int64_t block_idx = slot_idx / block_size; -+ const int64_t block_offset = slot_idx % block_size; -+ -+ const int n = num_heads * head_size; -+ for (int i = item_ct1.get_local_id(2); i < n; -+ i += item_ct1.get_local_range(2)) { -+ const int64_t src_key_idx = token_idx * key_stride + i; -+ const int64_t src_value_idx = token_idx * value_stride + i; -+ -+ const int head_idx = i / head_size; -+ const int head_offset = i % head_size; -+ const int x_idx = head_offset / x; -+ const int x_offset = head_offset % x; -+ -+ const int64_t tgt_key_idx = -+ block_idx * num_heads * (head_size / x) * block_size * x + -+ head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + -+ block_offset * x + x_offset; -+ const int64_t tgt_value_idx = -+ block_idx * num_heads * head_size * block_size + -+ head_idx * head_size * block_size + head_offset * block_size + -+ block_offset; -+ key_cache[tgt_key_idx] = key[src_key_idx]; -+ value_cache[tgt_value_idx] = value[src_value_idx]; -+ } -+} ++def _try_get_processor_chat_template( ++ tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], ++ model_config: ModelConfig, ++) -> Optional[str]: ++ cache_key = (tokenizer.name_or_path, model_config.trust_remote_code) ++ if cache_key in _PROCESSOR_CHAT_TEMPLATES: ++ return _PROCESSOR_CHAT_TEMPLATES[cache_key] + -+template -+void call_reshape_and_cache_kernel( -+ const scalar_t* __restrict__ key, -+ const scalar_t* __restrict__ value, -+ scalar_t* __restrict__ key_cache, -+ scalar_t* __restrict__ value_cache, -+ const int64_t* __restrict__ slot_mapping, -+ const int num_tokens, -+ const int key_stride, -+ const int value_stride, -+ const int num_heads, -+ const int head_size, -+ const int block_size, -+ const int x) { -+ using sycl_t = vllm::xpu::SyclTypeTrait::Type; -+ sycl::range<3> grid(1, 1, num_tokens); -+ sycl::range<3> block(1, 1, std::min(num_heads * head_size, 512)); -+ auto& queue = 
vllm::xpu::vllmGetQueue(); -+ queue.submit([&](sycl::handler& cgh) { -+ cgh.parallel_for( -+ sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1) { -+ reshape_and_cache_kernel( -+ (const sycl_t* __restrict__)key, -+ (const sycl_t* __restrict__)value, -+ (sycl_t* __restrict__)key_cache, -+ (sycl_t* __restrict__)value_cache, -+ slot_mapping, -+ key_stride, -+ value_stride, -+ num_heads, -+ head_size, -+ block_size, -+ x, -+ item_ct1); -+ }); -+ }); -+} ++ try: ++ processor = cached_get_processor( ++ tokenizer.name_or_path, ++ processor_cls=( ++ PreTrainedTokenizer, ++ PreTrainedTokenizerFast, ++ ProcessorMixin, ++ ), ++ trust_remote_code=model_config.trust_remote_code, ++ ) ++ if ( ++ isinstance(processor, ProcessorMixin) ++ and hasattr(processor, "chat_template") ++ and (chat_template := processor.chat_template) is not None ++ ): ++ _PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template ++ return chat_template ++ except Exception: ++ logger.debug( ++ "Failed to load AutoProcessor chat template for %s", ++ tokenizer.name_or_path, ++ exc_info=True, ++ ) + -+void reshape_and_cache( -+ torch::Tensor& key, -+ torch::Tensor& value, -+ torch::Tensor& key_cache, -+ torch::Tensor& value_cache, -+ torch::Tensor& slot_mapping, -+ const std::string& kv_cache_dtype, -+ const float kv_scale) { -+ int num_tokens = key.size(0); -+ int num_heads = key.size(1); -+ int head_size = key.size(2); -+ int block_size = key_cache.size(3); -+ int x = key_cache.size(4); -+ -+ int key_stride = key.stride(0); -+ int value_stride = value.stride(0); -+ -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ key.scalar_type(), "call_reshape_and_cache_kernel", [&] { -+ call_reshape_and_cache_kernel( -+ key.data_ptr(), -+ value.data_ptr(), -+ key_cache.data_ptr(), -+ value_cache.data_ptr(), -+ slot_mapping.data_ptr(), -+ num_tokens, -+ key_stride, -+ value_stride, -+ num_heads, -+ head_size, -+ block_size, -+ x); -+ }); -+} ++ _PROCESSOR_CHAT_TEMPLATES[cache_key] = None ++ return None ++ ++ + def resolve_hf_chat_template( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + chat_template: Optional[str], +@@ -444,28 +489,10 @@ def resolve_hf_chat_template( + + # 2nd priority: AutoProcessor chat template, unless tool calling is enabled + if tools is None: +- try: +- processor = cached_get_processor( +- tokenizer.name_or_path, +- processor_cls=( +- PreTrainedTokenizer, +- PreTrainedTokenizerFast, +- ProcessorMixin, +- ), +- trust_remote_code=model_config.trust_remote_code, +- ) +- if ( +- isinstance(processor, ProcessorMixin) +- and hasattr(processor, "chat_template") +- and processor.chat_template is not None +- ): +- return processor.chat_template +- except Exception: +- logger.debug( +- "Failed to load AutoProcessor chat template for %s", +- tokenizer.name_or_path, +- exc_info=True, +- ) # noqa: E501 ++ chat_template = _try_get_processor_chat_template(tokenizer, ++ model_config) ++ if chat_template is not None: ++ return chat_template + + # 3rd priority: AutoTokenizer chat template + try: +diff --git a/vllm/envs.py b/vllm/envs.py +index ac770ac4c..487fdcbfa 100755 +--- a/vllm/envs.py ++++ b/vllm/envs.py +@@ -70,7 +70,6 @@ if TYPE_CHECKING: + VLLM_VIDEO_LOADER_BACKEND: str = "opencv" + VLLM_MM_INPUT_CACHE_GIB: int = 4 + VLLM_TARGET_DEVICE: str = "cuda" +- VLLM_MAIN_CUDA_VERSION: str = "12.8" + MAX_JOBS: Optional[str] = None + NVCC_THREADS: Optional[str] = None + VLLM_USE_PRECOMPILED: bool = False +@@ -176,6 +175,8 @@ if TYPE_CHECKING: + VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False + 
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False + VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True ++ VLLM_XPU_FP8_DTYPE: str = "e5m2" ++ VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT: bool = False + + + def get_default_cache_root(): +@@ -247,11 +248,6 @@ environment_variables: dict[str, Callable[[], Any]] = { + "VLLM_TARGET_DEVICE": + lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda").lower(), + +- # Main CUDA version of vLLM, supporting [12.6, 12.8, 12.9], +- # 12.8 is the default. This follows PyTorch but can be overridden. +- "VLLM_MAIN_CUDA_VERSION": +- lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower() or "12.8", +- + # Maximum number of compilation jobs to run in parallel. + # By default this is the number of CPUs + "MAX_JOBS": +@@ -1247,6 +1243,14 @@ environment_variables: dict[str, Callable[[], Any]] = { + # raw bytes. Defaults to True for backward compatibility. + "VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES": + lambda: bool(int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1"))), + -+template -+void reshape_and_cache_ipexllm_kernel( -+ const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] -+ const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] -+ scalar_t* __restrict__ key_cache, // [num_blocks, num_kv_heads, block_size, head_size] -+ scalar_t* __restrict__ value_cache, // [num_blocks, num_kv_heads, block_size, head_size] -+ const int64_t* __restrict__ slot_mapping, // [num_tokens] -+ const int key_stride, -+ const int value_stride, -+ const int num_heads, -+ const int head_size, -+ const int block_size, -+ const int x, -+ const sycl::nd_item<3>& item_ct1) { -+ const int64_t token_idx = item_ct1.get_group(2); -+ const int64_t slot_idx = slot_mapping[token_idx]; -+ if (slot_idx < 0) { -+ // Padding token that should be ignored. 
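The two XPU knobs registered above are plain environment strings; a quick check of how the lambdas resolve them (the offload flag only turns on for the literal string "1", and the fp8 dtype defaults to "e5m2"):

import os

os.environ.pop("VLLM_XPU_FP8_DTYPE", None)
os.environ["VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT"] = "1"

fp8_dtype = os.environ.get("VLLM_XPU_FP8_DTYPE", "e5m2")                   # -> "e5m2"
offload = os.environ.get("VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT", "0") == "1"  # -> True
assert fp8_dtype == "e5m2" and offload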
-+ return; -+ } -+ -+ const int64_t block_idx = slot_idx / block_size; -+ const int64_t block_offset = slot_idx % block_size; -+ -+ const int n = num_heads * head_size; -+ for (int i = item_ct1.get_local_id(2); i < n; -+ i += item_ct1.get_local_range(2)) { -+ const int64_t src_key_idx = token_idx * key_stride + i; -+ const int64_t src_value_idx = token_idx * value_stride + i; -+ -+ const int head_idx = i / head_size; -+ const int head_offset = i % head_size; -+ -+ // const int64_t tgt_key_idx = -+ // block_idx * num_heads * (head_size / x) * block_size * x + -+ // head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + -+ // block_offset * x + x_offset; -+ -+ // const int64_t tgt_value_idx = -+ // block_idx * num_heads * head_size * block_size + -+ // head_idx * head_size * block_size + head_offset * block_size + -+ // block_offset; -+ -+ const int64_t tgt_value_idx = -+ block_idx * num_heads * head_size * block_size + -+ head_idx * head_size * block_size + -+ block_offset * head_size + -+ head_offset; -+ const int64_t tgt_key_idx = tgt_value_idx; -+ key_cache[tgt_key_idx] = key[src_key_idx]; -+ value_cache[tgt_value_idx] = value[src_value_idx]; -+ } -+} ++ # fp8 dtype for XPU platform ++ "VLLM_XPU_FP8_DTYPE": ++ lambda: os.environ.get("VLLM_XPU_FP8_DTYPE", "e5m2"), + -+template -+void call_reshape_and_cache_ipexllm_kernel( -+ const scalar_t* __restrict__ key, -+ const scalar_t* __restrict__ value, -+ scalar_t* __restrict__ key_cache, -+ scalar_t* __restrict__ value_cache, -+ const int64_t* __restrict__ slot_mapping, -+ const int num_tokens, -+ const int key_stride, -+ const int value_stride, -+ const int num_heads, -+ const int head_size, -+ const int block_size, -+ const int x) { -+ using sycl_t = vllm::xpu::SyclTypeTrait::Type; -+ sycl::range<3> grid(1, 1, num_tokens); -+ sycl::range<3> block(1, 1, std::min(num_heads * head_size, 512)); -+ auto& queue = vllm::xpu::vllmGetQueue(); -+ queue.submit([&](sycl::handler& cgh) { -+ cgh.parallel_for( -+ sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1) { -+ reshape_and_cache_ipexllm_kernel( -+ (const sycl_t* __restrict__)key, -+ (const sycl_t* __restrict__)value, -+ (sycl_t* __restrict__)key_cache, -+ (sycl_t* __restrict__)value_cache, -+ slot_mapping, -+ key_stride, -+ value_stride, -+ num_heads, -+ head_size, -+ block_size, -+ x, -+ item_ct1); -+ }); -+ }); -+} ++ # Offload model weights to cpu before online fp8 quantization ++ "VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT": ++ lambda: os.environ.get("VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT", "0") == "1", + } + + # --8<-- [end:env-vars-definition] +diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py +index a90a71159..5638da392 100644 +--- a/vllm/model_executor/layers/fused_moe/layer.py ++++ b/vllm/model_executor/layers/fused_moe/layer.py +@@ -601,7 +601,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): + logical_replica_count is not None: + raise NotImplementedError("Expert load balancing is not supported " + "for XPU.") +- assert custom_routing_function is None + return layer.ipex_fusion( + x, + use_grouped_topk, +@@ -610,6 +609,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): + renormalize, + topk_group, + num_expert_group, ++ custom_routing_function=custom_routing_function + ) + + def forward_tpu( +diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py +index 3d94626e5..72c77e15c 100644 +--- 
a/vllm/model_executor/layers/quantization/fp8.py ++++ b/vllm/model_executor/layers/quantization/fp8.py +@@ -309,10 +309,14 @@ class Fp8LinearMethod(LinearMethodBase): + if self.quant_config.is_checkpoint_fp8_serialized else + params_dtype) + ++ # Force offloading weights to cpu if VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT ++ # enabled, otherwise use original device config which can be gpu or cpu ++ # (may happen when cpu_offload_gb > 0) + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, +- dtype=weight_dtype), ++ dtype=weight_dtype, ++ device="cpu" if envs.VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT else None), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) +@@ -631,8 +635,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, +- dtype=params_dtype), ++ dtype=params_dtype, ++ device="cpu" if envs.VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT else None), + requires_grad=False) + -+void reshape_and_cache_ipexllm( -+ torch::Tensor& key, -+ torch::Tensor& value, -+ torch::Tensor& key_cache, -+ torch::Tensor& value_cache, -+ torch::Tensor& slot_mapping, -+ const std::string& kv_cache_dtype, -+ const float kv_scale) { -+ int num_tokens = key.size(0); -+ int num_heads = key.size(1); -+ int head_size = key.size(2); -+ int block_size = key_cache.size(2); -+ // int x = key_cache.size(4); -+ int x = 1; -+ -+ int key_stride = key.stride(0); -+ int value_stride = value.stride(0); -+ -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ key.scalar_type(), "call_reshape_and_cache_ipexllm_kernel", [&] { -+ call_reshape_and_cache_ipexllm_kernel( -+ key.data_ptr(), -+ value.data_ptr(), -+ key_cache.data_ptr(), -+ value_cache.data_ptr(), -+ slot_mapping.data_ptr(), -+ num_tokens, -+ key_stride, -+ value_stride, -+ num_heads, -+ head_size, -+ block_size, -+ x); -+ }); -+} + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + +@@ -640,7 +646,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): + num_experts, + hidden_size, + intermediate_size_per_partition, +- dtype=params_dtype), ++ dtype=params_dtype, ++ device="cpu" if envs.VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT else None), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) +diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py +index 5f9d48142..6364d5cf5 100644 +--- a/vllm/model_executor/layers/quantization/ipex_quant.py ++++ b/vllm/model_executor/layers/quantization/ipex_quant.py +@@ -9,6 +9,7 @@ from torch.nn import Module + from torch.nn.parameter import Parameter + + from vllm._ipex_ops import ipex_ops as ops ++import vllm.envs as envs + from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase, + FusedMoeWeightScaleSupported) + from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, +@@ -45,6 +46,7 @@ class IPEXConfig(QuantizationConfig): + modules_to_not_convert: Optional[list[str]] = None, + desc_act: Optional[bool] = None, + lm_head_quantized: Optional[bool] = None, ++ is_qweight_sym: Optional[bool] = None, + ) -> None: + super().__init__() + self.method = method +@@ -62,6 +64,7 @@ class IPEXConfig(QuantizationConfig): + if self.method not in ["awq", "gptq"]: + raise ValueError(f"IPEX quantization supports [awq, gptq], " + f"but got {self.method}.") ++ self.is_qweight_sym = is_qweight_sym + + def __repr__(self) -> str: + return 
(f"IPEXConfig(method={self.method}," +@@ -96,16 +99,18 @@ class IPEXConfig(QuantizationConfig): + ["q_group_size", "group_size"]) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None) ++ is_qweight_sym = not cls.get_from_keys_or(config, ["zero_point"], default=False) + return cls(method, weight_bits, group_size, modules_to_not_convert, +- False, False) ++ False, False, is_qweight_sym) + # otherwise for gptq + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False) ++ is_qweight_sym = cls.get_from_keys_or(config, ["sym"], default=True) + return cls(method, weight_bits, group_size, [], desc_act, +- lm_head_quantized) ++ lm_head_quantized, is_qweight_sym) + + @classmethod + def override_quantization_method( +@@ -183,7 +188,8 @@ class IPEXGPTQLinearMethod(GPTQLinearMethod): + g_idx=g_idx, + bias=bias, + group_size=self.quant_config.group_size, +- quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"] ++ quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"], ++ weight_qscheme="sym" if self.quant_config.is_qweight_sym else "asym", + ) + + def apply(self, +@@ -249,7 +255,8 @@ class IPEXAWQLinearMethod(AWQLinearMethod): + qconfig=qconfig, + bias=bias, + group_size=self.quant_config.group_size, +- quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"] # type: ignore ++ quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"], ++ weight_qscheme="sym" if self.quant_config.is_qweight_sym else "asym", + ) + + def apply(self, +@@ -302,12 +309,12 @@ class XPUFp8MoEMethod(FusedMoEMethodBase): + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None +- # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, +- dtype=params_dtype), ++ dtype=params_dtype, ++ device="cpu" if envs.VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT else None), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) +@@ -316,7 +323,8 @@ class XPUFp8MoEMethod(FusedMoEMethodBase): + num_experts, + hidden_size, + intermediate_size_per_partition, +- dtype=params_dtype), ++ dtype=params_dtype, ++ device="cpu" if envs.VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT else None), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) +diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py +index f935bdd84..9a80b80e7 100644 +--- a/vllm/model_executor/layers/quantization/mxfp4.py ++++ b/vllm/model_executor/layers/quantization/mxfp4.py +@@ -95,6 +95,9 @@ def get_mxfp4_backend(): + else: + logger.info_once("Using Triton backend") + return Mxfp4Backend.TRITON ++ elif current_platform.is_xpu(): ++ logger.info_once("Using ipex marlin backend on XPU") ++ return Mxfp4Backend.MARLIN + elif current_platform.is_rocm() and has_triton_kernels(): + logger.info_once("Using Triton backend") + return Mxfp4Backend.TRITON +@@ -140,7 +143,10 @@ class Mxfp4Config(QuantizationConfig): + return UnquantizedLinearMethod() + raise NotImplementedError("Mxfp4 linear layer is not implemented") + elif isinstance(layer, FusedMoE): +- return Mxfp4MoEMethod(layer.moe_config) ++ if current_platform.is_xpu(): ++ return IpexFp4MoeMethod(layer.moe_config) ++ else: ++ 
return Mxfp4MoEMethod(layer.moe_config) + elif isinstance(layer, Attention): + raise NotImplementedError( + "Mxfp4 attention layer is not implemented") +@@ -165,6 +171,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): ++ self.original_hidden_size = hidden_size + self.num_experts = num_experts + weight_dtype = torch.uint8 + scale_dtype = torch.uint8 +@@ -192,7 +199,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): + # k = intermediate_size_per_partition_after_pad + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 128) +- hidden_size = round_up(hidden_size, 256) ++ if current_platform.is_xpu(): ++ hidden_size = round_up(hidden_size, 128) ++ else: ++ hidden_size = round_up(hidden_size, 256) + + layer.params_dtype = params_dtype + layer.num_experts = num_experts +@@ -949,3 +959,63 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): + ) + else: + raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") + + -+template -+void copy_blocks_kernel( -+ int64_t* key_cache_ptrs, -+ int64_t* value_cache_ptrs, -+ const int64_t* __restrict__ block_mapping, -+ const int numel_per_block, -+ const sycl::nd_item<3>& item_ct1) { -+ const int layer_idx = item_ct1.get_group(2); -+ const int pair_idx = item_ct1.get_group(1); -+ -+ scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); -+ scalar_t* value_cache = -+ reinterpret_cast(value_cache_ptrs[layer_idx]); -+ int64_t src_block_number = block_mapping[2 * pair_idx]; -+ int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; -+ -+ const int64_t src_block_offset = src_block_number * numel_per_block; -+ const int64_t dst_block_offset = dst_block_number * numel_per_block; -+ for (int i = item_ct1.get_local_id(2); i < numel_per_block; -+ i += item_ct1.get_local_range(2)) { -+ int64_t src_offset = src_block_offset + i; -+ int64_t dst_offset = dst_block_offset + i; -+ key_cache[dst_offset] = key_cache[src_offset]; -+ } -+ for (int i = item_ct1.get_local_id(2); i < numel_per_block; -+ i += item_ct1.get_local_range(2)) { -+ int64_t src_offset = src_block_offset + i; -+ int64_t dst_offset = dst_block_offset + i; -+ value_cache[dst_offset] = value_cache[src_offset]; -+ } -+} ++class IpexFp4MoeMethod(Mxfp4MoEMethod): + -+template -+void call_copy_blocks_kernel( -+ std::vector& key_caches, -+ std::vector& value_caches, -+ const std::map>& block_mapping) { -+ using sycl_t = vllm::xpu::SyclTypeTrait::Type; -+ int num_layers = key_caches.size(); -+ TORCH_CHECK(num_layers == value_caches.size()); -+ if (num_layers == 0) { -+ return; -+ } -+ torch::Device cache_device = key_caches[0].device(); -+ TORCH_CHECK(cache_device.is_xpu()); -+ // Create data structures for the kernel. -+ // Create an array of pointers to the key and value caches. -+ int64_t key_cache_ptrs[num_layers]; -+ int64_t value_cache_ptrs[num_layers]; -+ for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { -+ key_cache_ptrs[layer_idx] = -+ reinterpret_cast(key_caches[layer_idx].data_ptr()); -+ value_cache_ptrs[layer_idx] = -+ reinterpret_cast(value_caches[layer_idx].data_ptr()); -+ } -+ // Create block mapping array. 
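On the XPU branch above the hidden size is rounded up to a multiple of 128 rather than 256, and the Ipex MoE path zero-pads activations to that width before the fused call and slices the result back. A toy version of the pad/slice round trip, with a hypothetical hidden size and a stand-in for the fused kernel:

import torch
import torch.nn.functional as F

def round_up(x: int, multiple: int) -> int:
    return (x + multiple - 1) // multiple * multiple

hidden_size = 2880                                   # not a multiple of 128
hidden_size_pad = round_up(hidden_size, 128)         # 2944

x = torch.randn(4, hidden_size)
x_pad = F.pad(x, (0, hidden_size_pad - x.size(-1)))  # zero-pad the last dim
y_pad = 2.0 * x_pad                                  # stand-in for the fused MoE kernel
y = y_pad[..., :hidden_size].contiguous()            # drop the padding again
assert y.shape == x.shape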
-+ std::vector block_mapping_vec; -+ for (const auto& pair : block_mapping) { -+ int64_t src_block_number = pair.first; -+ for (int64_t dst_block_number : pair.second) { -+ block_mapping_vec.push_back(src_block_number); -+ block_mapping_vec.push_back(dst_block_number); -+ } -+ } -+ int64_t* block_mapping_array = block_mapping_vec.data(); -+ int num_pairs = block_mapping_vec.size() / 2; -+ // Move the data structures to the GPU. -+ // NOTE: This synchronizes the CPU and GPU. -+ torch::Tensor key_cache_ptrs_tensor = -+ torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64) -+ .to(cache_device); -+ torch::Tensor value_cache_ptrs_tensor = -+ torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64) -+ .to(cache_device); -+ torch::Tensor block_mapping_tensor = -+ torch::from_blob(block_mapping_array, {2 * num_pairs}, torch::kInt64) -+ .to(cache_device); -+ auto k_ptr = key_cache_ptrs_tensor.data_ptr(); -+ auto v_ptr = value_cache_ptrs_tensor.data_ptr(); -+ auto b_ptr = block_mapping_tensor.data_ptr(); -+ // Launch the kernel. -+ const int numel_per_block = key_caches[0][0].numel(); -+ -+ sycl::range<3> grid(1, num_pairs, num_layers); -+ sycl::range<3> block(1, 1, std::min(1024, numel_per_block)); -+ auto& queue = vllm::xpu::vllmGetQueue(); -+ queue.submit([&](sycl::handler& cgh) { -+ cgh.parallel_for( -+ sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1) { -+ copy_blocks_kernel( -+ k_ptr, v_ptr, b_ptr, numel_per_block, item_ct1); -+ }); -+ }); -+} ++ def __init__(self, moe_config: FusedMoEConfig): ++ super().__init__(moe_config) ++ self.moe_config = moe_config ++ self.alpha = 1.702 ++ self.limit = 7.0 + -+void copy_blocks( -+ std::vector& key_caches, -+ std::vector& value_caches, -+ const std::map>& block_mapping) { -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ key_caches[0].scalar_type(), "call_copy_blocks_kernel", [&] { -+ call_copy_blocks_kernel( -+ key_caches, value_caches, block_mapping); -+ }); -+} ++ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ++ import intel_extension_for_pytorch as ipex ++ layer.w13_weight.data = layer.w13_weight.data.view(torch.int32) ++ layer.w2_weight.data = layer.w2_weight.data.view(torch.int32) ++ layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( ++ layer.w13_weight, ++ layer.w2_weight, ++ w1_scale_inv=layer.w13_weight_scale, ++ w2_scale_inv=layer.w2_weight_scale, ++ w13_bias=layer.w13_bias, ++ w2_bias=layer.w2_bias, ++ is_mxfp4=True, ++ ) + -+void swap_blocks( -+ torch::Tensor& src, -+ torch::Tensor& dst, -+ const std::map& block_mapping) { -+ char* src_ptr = (char*)src.data_ptr(); -+ char* dst_ptr = (char*)dst.data_ptr(); -+ -+ const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); -+ auto& queue = vllm::xpu::vllmGetQueue(); -+ -+ // NOTE(woosuk): This can be slow if the number of blocks is large. 
-+ for (const auto& pair : block_mapping) { -+ int64_t src_block_number = pair.first; -+ int64_t dst_block_number = pair.second; -+ int64_t src_offset = src_block_number * block_size_in_bytes; -+ int64_t dst_offset = dst_block_number * block_size_in_bytes; -+ queue.memcpy( -+ dst_ptr + dst_offset, src_ptr + src_offset, block_size_in_bytes); -+ } -+ queue.wait(); -+} ++ def apply( ++ self, ++ layer: torch.nn.Module, ++ x: torch.Tensor, ++ router_logits: torch.Tensor, ++ top_k: int, ++ renormalize: bool, ++ use_grouped_topk: bool = False, ++ topk_group: Optional[int] = None, ++ num_expert_group: Optional[int] = None, ++ global_num_experts: int = -1, ++ expert_map: Optional[torch.Tensor] = None, ++ custom_routing_function: Optional[Callable] = None, ++ scoring_func: str = "softmax", ++ routed_scaling_factor: float = 1.0, ++ e_score_correction_bias: Optional[torch.Tensor] = None, ++ apply_router_weight_on_input: bool = False, ++ activation: str = "silu", ++ enable_eplb: bool = False, ++ expert_load_view: Optional[torch.Tensor] = None, ++ logical_to_physical_map: Optional[torch.Tensor] = None, ++ logical_replica_count: Optional[torch.Tensor] = None, ++ ) -> torch.Tensor: ++ hidden_size_pad = round_up(self.original_hidden_size, 128) ++ x_pad = torch.nn.functional.pad( ++ x, (0, hidden_size_pad - x.size(-1))) ++ hidden_states = layer.ipex_fusion(x_pad, ++ use_grouped_topk, ++ top_k, ++ router_logits, ++ renormalize, ++ topk_group, ++ num_expert_group, ++ activation="swiglu_oai") ++ hidden_states = hidden_states[..., :self.original_hidden_size].contiguous() ++ return hidden_states +diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py +index 564f9a5c0..c9653aa9e 100644 +--- a/vllm/model_executor/layers/rotary_embedding/__init__.py ++++ b/vllm/model_executor/layers/rotary_embedding/__init__.py +@@ -103,6 +103,8 @@ def get_rope( + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], ++ mrope_interleaved=rope_scaling.get("mrope_interleaved", ++ False), + ) + else: + rotary_emb = RotaryEmbedding( +diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +index 7ac2e4bb6..450d0cee1 100644 +--- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py ++++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +@@ -138,3 +138,12 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return self.forward_native(positions, query, key, offsets) + -+template -+void gather_cached_kv_kernel( -+ scalar_t* __restrict__ key, // [num_tokens, [stride], num_heads, head_size] -+ scalar_t* __restrict__ value, // [num_tokens, [stride], num_heads, -+ // head_size] -+ const scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, -+ // head_size/x, block_size, x] -+ const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, -+ // head_size, block_size] -+ const int* __restrict__ slot_mapping, // [num_tokens] -+ const int key_stride, -+ const int value_stride, -+ const int num_heads, -+ const int head_size, -+ const int block_size, -+ const int x, -+ const sycl::nd_item<3>& item_ct1) { -+ const int token_idx = item_ct1.get_group(2); -+ const int slot_idx = slot_mapping[token_idx]; -+ const int block_idx = slot_idx / block_size; -+ const int block_offset = slot_idx % block_size; -+ -+ const int 
num_tokens = num_heads * head_size; -+ for (int i = item_ct1.get_local_id(2); i < num_tokens; -+ i += item_ct1.get_local_range(2)) { -+ const int tgt_key_idx = token_idx * key_stride + i; -+ const int tgt_value_idx = token_idx * value_stride + i; -+ -+ const int head_idx = i / head_size; -+ const int head_offset = i % head_size; -+ const int x_idx = -+ head_offset / x; // the offset of the [head_size/x] dimension -+ const int x_offset = head_offset % x; -+ -+ // const int src_key_idx = -+ // block_idx * num_heads * (head_size / x) * block_size * x + -+ // head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + -+ // block_offset * x + x_offset; -+ // const int src_value_idx = block_idx * num_heads * head_size * block_size + -+ // head_idx * head_size * block_size + head_offset * block_size + -+ // block_offset; -+ -+ const int src_value_idx = -+ block_idx * num_heads * head_size * block_size + -+ head_idx * head_size * block_size + -+ block_offset * head_size + -+ head_offset; -+ const int src_key_idx = src_value_idx; -+ -+ key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]); -+ value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]); -+ } -+} ++ def forward_xpu( ++ self, ++ positions: torch.Tensor, ++ query: torch.Tensor, ++ key: Optional[torch.Tensor] = None, ++ offsets: Optional[torch.Tensor] = None, ++ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ++ return self.forward_native(positions, query, key, offsets) +diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py +index 0acb5ea74..c4b8c66eb 100644 +--- a/vllm/model_executor/layers/rotary_embedding/mrope.py ++++ b/vllm/model_executor/layers/rotary_embedding/mrope.py +@@ -177,6 +177,18 @@ def triton_mrope( + return q, k + + ++def apply_interleaved_rope(x: torch.Tensor, ++ mrope_section: list[int]) -> torch.Tensor: ++ """Apply interleaved MRoPE to 3D rotary embeddings. ++ Reorganizes frequency layout from chunked [TTT...HHH...WWW] to ++ interleaved [THTHWHTHW...TT], preserving frequency continuity. 
++ """ ++ x_t = x[0].clone() ++ x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3] ++ x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3] ++ return x_t + -+template -+void gather_cached_kv_kernel_optimized( -+ scalar_t* __restrict__ key, // [num_tokens, [stride], num_heads, head_size] -+ scalar_t* __restrict__ value, // [num_tokens, [stride], num_heads, -+ // head_size] -+ const scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, -+ // head_size/x, block_size, x] -+ const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, -+ // head_size, block_size] -+ const int* __restrict__ slot_mapping, // [num_tokens] -+ const int key_stride, -+ const int value_stride, -+ const int num_heads, -+ const int head_size, -+ const int block_size, -+ const int x, -+ const sycl::nd_item<3>& item_ct1) { -+ const int token_idx = item_ct1.get_group(2); -+ const int slot_idx = slot_mapping[token_idx]; -+ const int block_idx = slot_idx / block_size; -+ const int block_offset = slot_idx % block_size; -+ -+ const int dim = num_heads * head_size; -+ assert(dim % 4 == 0); // this is true for known use cases -+ const int unroll_factor = 4; -+ const int unrolled_dim = dim / unroll_factor; -+ -+ for (int i = item_ct1.get_local_id(2); i < unrolled_dim; -+ i += item_ct1.get_local_range(2)) { -+ int tgt_key_indices[unroll_factor]; -+ int tgt_value_indices[unroll_factor]; -+ int src_key_indices[unroll_factor]; -+ int src_value_indices[unroll_factor]; -+ scalar_t keys_to_store[unroll_factor]; -+ scalar_t values_to_store[unroll_factor]; -+ -+#pragma unroll -+ for (int j = 0; j < unroll_factor; ++j) { -+ int index = i + j * unrolled_dim; -+ -+ const int tgt_key_idx = token_idx * key_stride + index; -+ const int tgt_value_idx = token_idx * value_stride + index; -+ -+ const int head_idx = index / head_size; -+ const int head_offset = index % head_size; -+ -+ const int src_value_idx = -+ block_idx * num_heads * head_size * block_size + -+ head_idx * head_size * block_size + -+ block_offset * head_size + -+ head_offset; -+ const int src_key_idx = src_value_idx; -+ -+ tgt_key_indices[j] = tgt_key_idx; -+ tgt_value_indices[j] = tgt_value_idx; -+ src_key_indices[j] = src_key_idx; -+ src_value_indices[j] = src_value_idx; -+ -+ keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]); -+ values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]); -+ } + -+#pragma unroll -+ for (int j = 0; j < unroll_factor; ++j) { -+ key[tgt_key_indices[j]] = keys_to_store[j]; -+ value[tgt_value_indices[j]] = values_to_store[j]; -+ } -+ } -+} + class MRotaryEmbedding(RotaryEmbedding): + """Rotary Embedding with Multimodal Sections.""" + +@@ -189,6 +201,7 @@ class MRotaryEmbedding(RotaryEmbedding): + is_neox_style: bool, + dtype: torch.dtype, + mrope_section: Optional[list[int]] = None, ++ mrope_interleaved: Optional[bool] = False, + ) -> None: + # In Qwen2.5-VL, the maximum index value is related to the duration of + # the input video. 
We enlarge max_position_embeddings to 4 times to get +@@ -198,6 +211,7 @@ class MRotaryEmbedding(RotaryEmbedding): + base, is_neox_style, dtype) + + self.mrope_section = mrope_section ++ self.mrope_interleaved = mrope_interleaved + if self.mrope_section: + assert sum(self.mrope_section) == rotary_dim // 2 + +@@ -225,17 +239,20 @@ class MRotaryEmbedding(RotaryEmbedding): + cos, sin = cos_sin.chunk(2, dim=-1) + if positions.ndim == 2: + assert self.mrope_section +- +- cos = torch.cat([ +- m[i] +- for i, m in enumerate(cos.split(self.mrope_section, dim=-1)) +- ], +- dim=-1) +- sin = torch.cat([ +- m[i] +- for i, m in enumerate(sin.split(self.mrope_section, dim=-1)) +- ], +- dim=-1) ++ if self.mrope_interleaved: ++ cos = apply_interleaved_rope(cos, self.mrope_section) ++ sin = apply_interleaved_rope(sin, self.mrope_section) ++ else: ++ cos = torch.cat([ ++ m[i] for i, m in enumerate( ++ cos.split(self.mrope_section, dim=-1)) ++ ], ++ dim=-1) ++ sin = torch.cat([ ++ m[i] for i, m in enumerate( ++ sin.split(self.mrope_section, dim=-1)) ++ ], ++ dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) +@@ -265,6 +282,10 @@ class MRotaryEmbedding(RotaryEmbedding): + assert positions.ndim == 1 or positions.ndim == 2 + assert key is not None + ++ if self.mrope_interleaved: ++ # TODO: add triton implementation to support mrope-interleaved ++ return self.forward_native(positions, query, key) + -+template -+void call_gather_cached_kv_kernel_optimized( -+ torch::Tensor& key, -+ torch::Tensor& value, -+ torch::Tensor& key_cache, -+ torch::Tensor& value_cache, -+ torch::Tensor& slot_mapping) { -+ using sycl_t = vllm::xpu::SyclTypeTrait::Type; -+ int num_tokens = key.size(0); -+ int num_heads = key.size(1); -+ int head_size = key.size(2); -+ int block_size = key_cache.size(2); -+ // int x = key_cache.size(4); -+ int x = 1; -+ -+ int key_stride = key.stride(0); -+ int value_stride = value.stride(0); -+ auto key_ptr = key.data_ptr(); -+ auto value_ptr = value.data_ptr(); -+ auto key_cache_ptr = key_cache.data_ptr(); -+ auto value_cache_ptr = value_cache.data_ptr(); -+ auto slot_mapping_ptr = slot_mapping.data_ptr(); -+ sycl::range<3> grid(1, 1, num_tokens); -+ sycl::range<3> block(1, 1, std::min(num_heads * head_size, 512)); -+ auto& queue = vllm::xpu::vllmGetQueue(); -+ queue.submit([&](sycl::handler& cgh) { -+ cgh.parallel_for( -+ sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1) { -+ gather_cached_kv_kernel_optimized( -+ (sycl_t* __restrict__)key_ptr, -+ (sycl_t* __restrict__)value_ptr, -+ (const sycl_t* __restrict__)key_cache_ptr, -+ (const sycl_t* __restrict__)value_cache_ptr, -+ slot_mapping_ptr, -+ key_stride, -+ value_stride, -+ num_heads, -+ head_size, -+ block_size, -+ x, -+ item_ct1); -+ }); -+ }); -+} + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) +@@ -300,6 +321,15 @@ class MRotaryEmbedding(RotaryEmbedding): + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + ++ def forward_xpu( ++ self, ++ positions: torch.Tensor, ++ query: torch.Tensor, ++ key: Optional[torch.Tensor] = None, ++ offsets: Optional[torch.Tensor] = None, ++ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ++ return self.forward_native(positions, query, key, offsets) + -+void gather_cached_kv( -+ torch::Tensor& key, -+ torch::Tensor& value, -+ torch::Tensor& key_cache, -+ torch::Tensor& value_cache, -+ torch::Tensor& slot_mapping) { -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ 
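The apply_interleaved_rope helper in this hunk selects temporal, height and width frequencies with a stride-3 interleave instead of contiguous chunks. A tiny check of which stream lands in each frequency slot, assuming a hypothetical mrope_section of [4, 3, 3]:

import torch

mrope_section = [4, 3, 3]            # t/h/w split; sums to rotary_dim // 2
dim = sum(mrope_section)
# x[i] is filled with its stream index: 0 = temporal, 1 = height, 2 = width
x = torch.stack([torch.full((dim,), float(i)) for i in range(3)])

out = x[0].clone()
out[1:mrope_section[1] * 3:3] = x[1, 1:mrope_section[1] * 3:3]
out[2:mrope_section[2] * 3:3] = x[2, 2:mrope_section[2] * 3:3]
print(out.tolist())  # [0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0] -> T H W T H W T H W T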
key_cache[0].scalar_type(), -+ "call_gather_cached_kv_kernel_optimized", -+ [&] { -+ call_gather_cached_kv_kernel_optimized( -+ key, value, key_cache, value_cache, slot_mapping); -+ }); -+} -diff --git a/csrc/xpu/cache_ops_xpu_fp8.cpp b/csrc/xpu/cache_ops_xpu_fp8.cpp -new file mode 100644 -index 000000000..e4a0001fe ---- /dev/null -+++ b/csrc/xpu/cache_ops_xpu_fp8.cpp -@@ -0,0 +1,170 @@ -+// clang-format off -+#ifdef VLLM_DEV -+#undef __SYCL_DEVICE_ONLY__ -+#endif -+#include -+#include -+#include -+// clang-format on -+#include "xpu_types.h" -+ -+#include -+#include "utils.h" -+#include "kv.h" -+ -+using fp16 = sycl::half; -+using namespace sycl::ext::intel::esimd; -+ -+// scalar_t is key.scalar_type() -> half -+template -+void reshape_and_cache_ipexllm_kernel_fp8( -+ const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] -+ const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] -+ uint8_t * __restrict__ key_cache, // [num_blocks, num_kv_heads, block_size, -+ // head_size] -+ uint8_t * __restrict__ value_cache, // [num_blocks, num_kv_heads, -+ // block_size, head_size] -+ const int64_t* __restrict__ slot_mapping, // [num_tokens] -+ const int key_stride, const int value_stride, -+ const int key_head_stride, const int value_head_stride, -+ const int num_heads, -+ const int head_size, const int block_size, const int x, -+ const sycl::nd_item<3>& item_ct1) { -+ -+ // New Implementation // -+ const size_t token_idx = item_ct1.get_global_id(0); -+ const size_t head_idx = item_ct1.get_global_id(1); -+ const int64_t slot_idx = slot_mapping[token_idx]; -+ if (slot_idx < 0) { -+ return; -+ } -+ const int64_t block_idx = slot_idx / block_size; -+ const int64_t block_offset = slot_idx % block_size; -+ // The thread is responsible for the HD elements within key/value -+ const scalar_t * key_head = key + token_idx * key_stride + head_idx * key_head_stride; -+ -+ const scalar_t * value_head = value + token_idx * value_stride + head_idx * value_head_stride; -+ -+ uint8_t * key_output_head = key_cache + block_idx * num_heads * head_size * block_size + -+ head_idx * head_size * block_size + block_offset * head_size; -+ uint8_t * value_output_head = value_cache + block_idx * num_heads * head_size * block_size + -+ head_idx * head_size * block_size + block_offset * head_size; -+ -+ simd key_row = block_load(key_head); -+ simd key_result = quantize_key_row(key_row); -+ block_store(key_output_head, key_result); -+ -+ simd value_row = block_load(value_head); -+ simd value_result = quantize_value_row(value_row); -+ block_store(value_output_head, value_result); -+} -+ -+ -+template -+void call_reshape_and_cache_ipexllm_kernel_fp8( -+ const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, -+ uint8_t* __restrict__ key_cache, uint8_t* __restrict__ value_cache, -+ const int64_t* __restrict__ slot_mapping, const int num_tokens, -+ const int key_stride, const int value_stride, -+ const int key_head_stride, const int value_head_stride, -+ const int num_heads, -+ const int head_size, const int block_size, const int x) { -+ using sycl_t = vllm::xpu::SyclTypeTrait::Type; -+ sycl::range<3> grid(num_tokens, num_heads, 1); -+ sycl::range<3> block(1, 1, 1); -+ auto& queue = vllm::xpu::vllmGetQueue(); -+ queue.submit([&](sycl::handler& cgh) { -+ cgh.parallel_for( -+ sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1) SYCL_ESIMD_KERNEL { -+ reshape_and_cache_ipexllm_kernel_fp8( -+ (const sycl_t* __restrict__)key, -+ (const sycl_t* __restrict__)value, -+ 
(uint8_t* __restrict__)key_cache, -+ (uint8_t* __restrict__)value_cache, slot_mapping, key_stride, -+ value_stride, key_head_stride, value_head_stride, -+ num_heads, head_size, block_size, x, item_ct1); -+ }); -+ }); -+} -+ -+void reshape_and_cache_ipexllm_fp8(torch::Tensor& key, torch::Tensor& value, -+ torch::Tensor& key_cache, -+ torch::Tensor& value_cache, -+ torch::Tensor& slot_mapping, -+ const std::string& kv_cache_dtype, -+ const float kv_scale) { -+ int num_tokens = key.size(0); -+ int num_heads = key.size(1); -+ int head_size = key.size(2); -+ int block_size = key_cache.size(2); -+ // int x = key_cache.size(4); -+ int x = 1; -+ -+ int key_stride = key.stride(0); -+ int value_stride = value.stride(0); -+ -+ int key_head_stride = key.stride(1); -+ int value_head_stride = value.stride(1); -+ -+ // This actually dispatches on scalar_type, we will then need to dispatch on Head Dim... -+switch (head_size) { -+ case 64: -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ key.scalar_type(), "call_reshape_and_cache_ipexllm_kernel_fp8", [&] { -+ call_reshape_and_cache_ipexllm_kernel_fp8( -+ key.data_ptr(), value.data_ptr(), -+ key_cache.data_ptr(), value_cache.data_ptr(), -+ slot_mapping.data_ptr(), num_tokens, key_stride, -+ value_stride, key_head_stride, value_head_stride, num_heads, -+ head_size, block_size, x); -+ }); -+ break; -+ case 128: -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ key.scalar_type(), "call_reshape_and_cache_ipexllm_kernel_fp8", [&] { -+ call_reshape_and_cache_ipexllm_kernel_fp8( -+ key.data_ptr(), value.data_ptr(), -+ key_cache.data_ptr(), value_cache.data_ptr(), -+ slot_mapping.data_ptr(), num_tokens, key_stride, -+ value_stride, key_head_stride, value_head_stride, num_heads, -+ head_size, block_size, x); -+ }); -+ break; -+ case 96: -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ key.scalar_type(), "call_reshape_and_cache_ipexllm_kernel_fp8", [&] { -+ call_reshape_and_cache_ipexllm_kernel_fp8( -+ key.data_ptr(), value.data_ptr(), -+ key_cache.data_ptr(), value_cache.data_ptr(), -+ slot_mapping.data_ptr(), num_tokens, key_stride, -+ value_stride, key_head_stride, value_head_stride, num_heads, -+ head_size, block_size, x); -+ }); -+ break; -+ case 80: -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ key.scalar_type(), "call_reshape_and_cache_ipexllm_kernel_fp8", [&] { -+ call_reshape_and_cache_ipexllm_kernel_fp8( -+ key.data_ptr(), value.data_ptr(), -+ key_cache.data_ptr(), value_cache.data_ptr(), -+ slot_mapping.data_ptr(), num_tokens, key_stride, -+ value_stride, key_head_stride, value_head_stride, num_heads, -+ head_size, block_size, x); -+ }); -+ break; -+ default: -+ TORCH_CHECK(false, "Unsupported head_dim: ", head_size); -+} -+ // VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ // key.scalar_type(), "call_reshape_and_cache_ipexllm_kernel_fp8", [&] { -+ // call_reshape_and_cache_ipexllm_kernel_fp8( -+ // key.data_ptr(), value.data_ptr(), -+ // key_cache.data_ptr(), value_cache.data_ptr(), -+ // slot_mapping.data_ptr(), num_tokens, key_stride, -+ // value_stride, key_head_stride, value_head_stride, -+ // num_heads, head_size, block_size, x); -+ // }); -+} + @classmethod + def get_input_positions( + cls, +@@ -370,6 +400,15 @@ class MRotaryEmbedding(RotaryEmbedding): + context_len=context_len, + seq_len=seq_len, + ) ++ elif hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe"]: ++ return cls._qwen3vl_get_input_positions_tensor( ++ input_tokens=input_tokens, ++ hf_config=hf_config, ++ image_grid_thw=image_grid_thw, ++ video_grid_thw=video_grid_thw, ++ context_len=context_len, ++ seq_len=seq_len, ++ ) + elif 
hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]: + return cls._ernie_get_input_positions_tensor( + input_tokens=input_tokens, +@@ -508,6 +547,98 @@ class MRotaryEmbedding(RotaryEmbedding): + len(input_tokens)).item() + return llm_positions, mrope_position_delta + ++ @classmethod ++ def _qwen3vl_get_input_positions_tensor( ++ cls, ++ input_tokens: list[int], ++ hf_config: PretrainedConfig, ++ image_grid_thw: Union[list[list[int]], torch.Tensor], ++ video_grid_thw: Union[list[list[int]], torch.Tensor], ++ context_len: int = 0, ++ seq_len: Optional[int] = None, ++ ) -> tuple[torch.Tensor, int]: ++ """Get mrope input positions and delta value.""" ++ ++ video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw ++ for _ in range(t)] ++ ++ image_token_id = hf_config.image_token_id ++ video_token_id = hf_config.video_token_id ++ vision_start_token_id = hf_config.vision_start_token_id ++ spatial_merge_size = hf_config.vision_config.spatial_merge_size ++ ++ input_tokens_tensor = torch.tensor(input_tokens) ++ vision_start_indices = torch.argwhere( ++ input_tokens_tensor == vision_start_token_id).squeeze(1) ++ vision_tokens = input_tokens_tensor[vision_start_indices + 1] ++ image_nums = (vision_tokens == image_token_id).sum() ++ video_nums = (vision_tokens == video_token_id).sum() ++ llm_pos_ids_list: list = [] ++ ++ st = 0 ++ remain_images, remain_videos = image_nums, video_nums ++ ++ image_index, video_index = 0, 0 ++ for _ in range(image_nums + video_nums): ++ if image_token_id in input_tokens and remain_images > 0: ++ ed_image = input_tokens.index(image_token_id, st) ++ else: ++ ed_image = len(input_tokens) + 1 ++ if video_token_id in input_tokens and remain_videos > 0: ++ ed_video = input_tokens.index(video_token_id, st) ++ else: ++ ed_video = len(input_tokens) + 1 ++ if ed_image < ed_video: ++ t, h, w = ( ++ image_grid_thw[image_index][0], ++ image_grid_thw[image_index][1], ++ image_grid_thw[image_index][2], ++ ) ++ image_index += 1 ++ remain_images -= 1 ++ ed = ed_image ++ else: ++ t, h, w = ( ++ video_grid_thw[video_index][0], ++ video_grid_thw[video_index][1], ++ video_grid_thw[video_index][2], ++ ) ++ video_index += 1 ++ remain_videos -= 1 ++ ed = ed_video ++ ++ llm_grid_t, llm_grid_h, llm_grid_w = \ ++ t, h // spatial_merge_size, w // spatial_merge_size ++ text_len = ed - st ++ ++ st_idx = llm_pos_ids_list[-1].max() + 1 if len( ++ llm_pos_ids_list) > 0 else 0 ++ llm_pos_ids_list.append( ++ torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) ++ ++ t_index = torch.arange(llm_grid_t).view(-1, 1).expand( ++ -1, llm_grid_h * llm_grid_w).flatten() ++ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( ++ llm_grid_t, -1, llm_grid_w).flatten() ++ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( ++ llm_grid_t, llm_grid_h, -1).flatten() ++ llm_pos_ids_list.append( ++ torch.stack([t_index, h_index, w_index]) + text_len + st_idx) ++ st = ed + llm_grid_t * llm_grid_h * llm_grid_w ++ ++ if st < len(input_tokens): ++ st_idx = llm_pos_ids_list[-1].max() + 1 if len( ++ llm_pos_ids_list) > 0 else 0 ++ text_len = len(input_tokens) - st ++ llm_pos_ids_list.append( ++ torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) ++ ++ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) ++ mrope_position_delta = (llm_positions.max() + 1 - ++ len(input_tokens)).item() ++ llm_positions = llm_positions[:, context_len:seq_len] ++ return llm_positions, mrope_position_delta + + @classmethod + def _ernie_get_input_positions_tensor( + cls, +@@ -715,15 +846,23 @@ class 
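
The qwen3_vl hunk above builds a 3-row position tensor (temporal, height, width): text tokens advance all three rows together, while each vision segment enumerates its merged patch grid before the text index resumes. A minimal sketch of that indexing pattern with made-up sizes (2 text tokens followed by a single 1x4x4 image at spatial_merge_size=2, i.e. a 1x2x2 merged grid):

# Illustrative only; sizes below are invented, not taken from a real model.
import torch

text_len, st_idx = 2, 0
llm_grid_t, llm_grid_h, llm_grid_w = 1, 2, 2

pos = [torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx]

t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
pos.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)

print(torch.cat(pos, dim=1))
# tensor([[0, 1, 2, 2, 2, 2],
#         [0, 1, 2, 2, 3, 3],
#         [0, 1, 2, 3, 2, 3]])
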
MRotaryEmbedding(RotaryEmbedding): + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( +- torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) +- +- t_index = (torch.arange(llm_grid_t).view(-1, 1).expand( +- -1, llm_grid_h * llm_grid_w)).long().flatten() +- +- h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( +- llm_grid_t, -1, llm_grid_w).flatten() +- w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( +- llm_grid_t, llm_grid_h, -1).flatten() ++ torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx ++ ) ++ t_index = ( ++ torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) ++ ).flatten() ++ h_index = ( ++ torch.arange(llm_grid_h) ++ .view(1, -1, 1) ++ .expand(llm_grid_t, -1, llm_grid_w) ++ .flatten() ++ ) ++ w_index = ( ++ torch.arange(llm_grid_w) ++ .view(1, 1, -1) ++ .expand(llm_grid_t, llm_grid_h, -1) ++ .flatten() ++ ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w +@@ -772,7 +911,6 @@ class MRotaryEmbedding(RotaryEmbedding): + + st = 0 + remain_images, remain_videos = image_nums, video_nums +- + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + video_second_per_grid_t = 0.0 +@@ -819,16 +957,25 @@ class MRotaryEmbedding(RotaryEmbedding): + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( +- torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) +- +- t_index = (torch.arange(llm_grid_t).view(-1, 1).expand( +- -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t * +- tokens_per_second).long().flatten() +- +- h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( +- llm_grid_t, -1, llm_grid_w).flatten() +- w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( +- llm_grid_t, llm_grid_h, -1).flatten() ++ torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx ++ ) ++ t_index = ( ++ torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) ++ * video_second_per_grid_t ++ * tokens_per_second ++ ).flatten() ++ h_index = ( ++ torch.arange(llm_grid_h) ++ .view(1, -1, 1) ++ .expand(llm_grid_t, -1, llm_grid_w) ++ .flatten() ++ ) ++ w_index = ( ++ torch.arange(llm_grid_w) ++ .view(1, 1, -1) ++ .expand(llm_grid_t, llm_grid_h, -1) ++ .flatten() ++ ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w +@@ -847,6 +994,339 @@ class MRotaryEmbedding(RotaryEmbedding): + + return llm_positions, mrope_position_delta + ++ @classmethod ++ def _omni3_get_input_positions_tensor( ++ cls, ++ config, ++ input_ids: torch.Tensor, ++ image_grid_thw: torch.Tensor, ++ video_grid_thw: torch.Tensor, ++ use_audio_in_video: bool = False, ++ audio_seqlens: Optional[torch.Tensor] = None, ++ second_per_grids: Optional[torch.Tensor] = None, ++ ) -> tuple[torch.Tensor, torch.Tensor]: ++ def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor): ++ input_lengths_leave = input_lengths % 100 ++ feat_lengths = (input_lengths_leave - 1) // 2 + 1 ++ output_lengths = ( ++ ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 ++ ) ++ return output_lengths + ++ if input_ids is None or input_ids.ndim != 1: ++ raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids") + -diff --git a/csrc/xpu/common.h b/csrc/xpu/common.h -new file mode 100644 -index 000000000..17d6ef643 ---- /dev/null -+++ 
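
The reformatted hunks above keep the same temporal scaling for video grids but drop the early .long() cast, so the scaled index stays floating point until it is consumed. A rough illustration with made-up numbers (4 temporal grid steps, 0.5 s per step, 25 position ids per second):

# Values are invented for illustration.
import torch

grid_t, second_per_grid_t, tokens_per_second = 4, 0.5, 25
t_index = torch.arange(grid_t) * second_per_grid_t * tokens_per_second
print(t_index)  # tensor([ 0.0000, 12.5000, 25.0000, 37.5000])
# The result is now float (the .long() cast was removed in the hunk), so any
# rounding happens only where the positions are actually consumed.
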
b/csrc/xpu/common.h -@@ -0,0 +1,312 @@ -+#pragma once -+ -+#include -+#include -+ -+typedef union half_t { -+ uint16_t u; -+ sycl::half f; -+} __half_t; -+ -+typedef union ufloat32 { -+ unsigned u; -+ float f; -+} __float_t; -+ -+#define QK4_0 64 -+#define QR4_0 2 -+#define QK4_1 64 -+#define QR4_1 2 -+#define QK5_0 64 -+#define QR5_0 2 -+#define QK5_1 64 -+#define QR5_1 2 -+#define QK8_0 64 -+#define QR8_0 1 -+#define QK8_1 32 -+#define QR8_1 1 -+#define QI8_1 (QK8_1 / (4 * QR8_1)) // 8 -+#define QKFP8 64 -+#define QRFP8 1 -+#define QKFP6 64 -+// for iq2 quantization -+#define WARP_SIZE 32 -+#define QK_K 256 -+#define QK4_K 32 -+#define QR4_K 2 -+#define QK6_K 16 -+#define QKFP6_K 16 -+#define QR2_XXS 8 -+#define QI2_XXS (QK_K / (4*QR2_XXS)) // 8 -+#define QR2_XS 8 -+#define QI2_XS (QK_K / (4*QR2_XS)) // 8 -+#define QR2_K 4 -+#define QI2_K (QK_K / (4*QR2_K)) // 16 -+#define QR1_S 8 -+#define QI1_S (QK_K / (4*QR1_S)) // 8 -+ -+typedef struct { -+ sycl::half d; // delta -+ uint8_t qs[QK4_0 / 2]; // nibbles / quants -+} block_q4_0; -+ -+typedef struct { -+ uint8_t qs[QK4_0 / 2]; // nibbles / quants -+} block_q4_0_qs; -+ -+typedef struct { -+ uint8_t qs[QK4_1 / 2]; // nibbles / quants -+} block_q4_1_qs; -+ -+typedef struct { -+ sycl::half d; // delta -+ sycl::half m; // min -+ uint8_t qs[QK4_1 / 2]; // nibbles / quants -+} block_q4_1; -+ -+typedef struct { -+ sycl::half d; -+ uint8_t qh[8]; -+ uint8_t qs[QK5_0 / 2]; -+} block_q5_0; -+ -+typedef struct { -+ sycl::half d; // delta -+ sycl::half m; // min -+ uint8_t qh[8]; // 5-th bit of quants -+ uint8_t qs[QK5_1 / 2]; // nibbles / quants -+} block_q5_1; -+ -+typedef struct { -+ sycl::half d; // delta -+ uint8_t qh[8]; // 3-th bit of quants -+ uint8_t qs[QK4_0 / 4]; // nibbles / quants -+} block_nf3; -+ -+typedef struct { -+ uint8_t qh[8]; // 3-th bit of quants -+ uint8_t qs[QK4_0 / 4]; // nibbles / quants -+} block_nf3_qs; -+ -+typedef struct { -+ float d; // delta -+ int8_t qs[QK8_0]; // quants -+} block_q8_0; -+ -+typedef struct { -+ int8_t qs[QK8_0]; // quants -+} block_q8_0_qs; -+ -+typedef struct { -+ sycl::half d; -+ sycl::half sum; -+ int8_t qs[QK8_1]; // quants -+} block_q8_1; -+ -+typedef struct { -+ uint8_t qs[QKFP8]; -+} block_fp8_qs; -+ -+typedef struct { -+ float d; -+ uint8_t qs[QKFP8]; -+} block_fp8; -+ -+typedef struct { -+ sycl::half d; -+ uint16_t qs[QK_K/8]; // 32 -+} block_iq2_xxs; -+ -+typedef struct { -+ sycl::half d; -+ uint16_t qs[QK_K/8]; // 32 -+ uint8_t scales[QK_K/32]; // 8 -+} block_iq2_xs; -+ -+typedef struct { -+ uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits -+ uint8_t qs[QK_K/4]; // quants -+ sycl::half d; // super-block scale for quantized scales -+ sycl::half min; // super-block min for quantized mins -+} block_q2_K; -+ -+typedef struct { -+ sycl::half d; // super-block scale for quantized scales -+ sycl::half dmin; // super-block scale for quantized mins -+ uint8_t scales[16]; // scales and mins, quantized with 8 bits -+ uint8_t qs[QK_K/2]; // 4--bit quants -+} block_q4_K; -+ -+typedef struct { -+ uint8_t qs[QK_K/2]; // 4-bit quants -+} block_q4_K_qs; -+ -+typedef struct { -+ uint8_t qs[QK4_K/2]; // 4-bit quants -+} block_q4_K_qs_block; -+ -+typedef struct { -+ uint8_t scales[16]; // scales and mins, quantized with 8 bits -+} block_q4_K_scales; -+ -+typedef struct { -+ sycl::half d; // super-block scale for quantized scales -+ sycl::half dmin; // super-block scale for quantized mins -+ uint8_t scales[12]; // scales and mins, quantized with 6 bits -+ uint8_t qh[QK_K/8]; // quants, high bit 
-+ uint8_t qs[QK_K/2]; // quants, low 4 bits -+} block_q5_K; -+ -+typedef struct { -+ uint8_t ql[QK_K/2]; // quants, lower 4 bits -+ uint8_t qh[QK_K/4]; // quants, upper 2 bits -+ int8_t scales[QK_K/16]; // scales -+ sycl::half d; // delta -+} block_q6_K; -+ -+typedef struct { -+ uint32_t qh[QK_K/16]; // quants, upper 2 bits -+} block_q6_K_qh; -+ -+typedef struct { -+ uint32_t ql[QK_K/8]; // quants, lower 4 bits -+} block_q6_K_ql; -+ -+typedef struct { -+ int8_t scales[QK_K/16]; // scales, quantized with 8 bits -+} block_q6_K_scales; -+ -+typedef struct { -+ uint8_t ql[QK_K/2]; // quants, lower 4 bits -+ uint8_t qh[QK_K/4]; // quants, upper 2 bits -+ int8_t scales[QK_K/16]; // scales, quantized with 8 bits -+ sycl::half d; // super-block scale -+} block_fp6_K; -+static_assert(sizeof(block_fp6_K) == sizeof(sycl::half) + QK_K / 16 + 3*QK_K/4, "wrong fp6_K block size/padding"); -+ -+typedef struct { -+ uint32_t ql[QK_K/8]; // quants, lower 4 bits -+} block_fp6_k_ql; -+ -+typedef struct { -+ uint32_t qh[QK_K/16]; // quants, upper 2 bits -+} block_fp6_k_qh; -+ -+typedef struct { -+ int8_t scales[QK_K/16]; // scales, quantized with 8 bits, 16 -+} block_fp6_k_scales; -+ -+typedef struct { -+ uint32_t ql[QKFP6_K/8]; // upper 2 bits, 2 -+} block_base_fp6_k_ql; -+ -+typedef struct { -+ uint32_t qh[QKFP6_K/16]; // upper 2 bits, 1 -+} block_base_fp6_k_qh; -+ -+#define NGRID_IQ1S 2048 -+#define IQ1S_DELTA 0.125f -+#define IQ1M_DELTA 0.125f -+ -+typedef struct { -+ sycl::half d; -+ uint8_t qs[QK_K/8]; -+ uint16_t qh[QK_K/32]; -+} block_iq1_s; -+ -+// 1.8125 bpw -+typedef struct { -+ uint8_t qs[QK_K/8]; // grid index, low 8 bits -+ uint8_t qh[QK_K/16]; // grid index, high 3 bits + grid shift bit (for two groups of 8) -+ uint8_t scales[QK_K/32]; // 4-bit block scales -+} block_iq1_m; -+ -+typedef struct { -+ uint8_t ql[QKFP6/2]; // lower 4 bits, 32 -+ uint8_t qh[QKFP6/4]; // upper 2 bits, 16 -+ sycl::half d; // delta -+} block_fp6; -+ -+typedef struct { -+ uint32_t qh[QKFP6/16]; // upper 2 bits, 4 -+} block_fp6_32_qh; -+ -+typedef struct { -+ uint32_t ql[QKFP6/8]; // lower 4 bits, 8 -+} block_fp6_32_ql; -+ -+enum ggml_type { -+ GGML_TYPE_Q4_0 = 2, -+ GGML_TYPE_Q4_1 = 3, -+ GGML_TYPE_Q5_0 = 6, -+ GGML_TYPE_Q5_1 = 7, -+ GGML_TYPE_Q8_0 = 8, -+ GGML_TYPE_Q8_1 = 9, -+ GGML_TYPE_NF4 = 10, -+ GGML_TYPE_NF3 = 11, -+ GGML_TYPE_FP8E4 = 15, -+ GGML_TYPE_FP4 = 16, -+ GGML_TYPE_FP8E5 = 19, -+ GGML_TYPE_IQ2_XXS = 21, -+ GGML_TYPE_IQ2_XS = 22, -+ GGML_TYPE_Q2_K = 23, -+ GGML_TYPE_IQ1_S = 24, -+ GGML_TYPE_IQ1_M = 25, -+ GGML_TYPE_Q6_K = 26, -+ GGML_TYPE_Q4_K = 27, -+ GGML_TYPE_Q5_K = 28, -+ GGML_TYPE_FP6 = 29, -+ GGML_TYPE_FP6_K = 30, -+ GGML_TYPE_Q4_0_WOQ = 34, -+ GGML_TYPE_COUNT -+}; -+ -+static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { -+ [GGML_TYPE_Q4_0] = QK4_0, -+ [GGML_TYPE_Q4_1] = QK4_1, -+ [GGML_TYPE_Q5_0] = QK5_0, -+ [GGML_TYPE_Q5_1] = QK5_1, -+ [GGML_TYPE_NF4] = QK4_0, -+ [GGML_TYPE_NF3] = QK4_0, -+ [GGML_TYPE_Q8_0] = QK8_0, -+ [GGML_TYPE_Q8_1] = QK8_1, -+ [GGML_TYPE_FP8E4] = QKFP8, -+ [GGML_TYPE_FP4] = QK4_0, -+ [GGML_TYPE_FP6] = QKFP6, -+ [GGML_TYPE_FP8E5] = QKFP8, -+ [GGML_TYPE_IQ2_XXS] = QK_K, -+ [GGML_TYPE_IQ2_XS] = QK_K, -+ [GGML_TYPE_Q2_K] = QK_K, -+ [GGML_TYPE_IQ1_S] = QK_K, -+ [GGML_TYPE_IQ1_M] = QK_K, -+ [GGML_TYPE_Q6_K] = QK_K, -+ [GGML_TYPE_Q4_K] = QK_K, -+ [GGML_TYPE_Q5_K] = QK_K, -+ [GGML_TYPE_FP6_K] = QK_K, -+ [GGML_TYPE_Q4_0_WOQ] = QK4_0, -+}; -+ -+static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { -+ [GGML_TYPE_Q4_0] = sizeof(block_q4_0), -+ [GGML_TYPE_Q4_1] = sizeof(block_q4_1), -+ 
[GGML_TYPE_Q5_0] = sizeof(block_q5_1), -+ [GGML_TYPE_Q5_1] = sizeof(block_q5_1), -+ [GGML_TYPE_NF4] = sizeof(block_q4_0), -+ [GGML_TYPE_NF3] = sizeof(block_nf3), -+ [GGML_TYPE_Q8_0] = sizeof(block_q8_0), -+ [GGML_TYPE_Q8_1] = sizeof(block_q8_1), -+ [GGML_TYPE_FP8E4]= sizeof(block_fp8), -+ [GGML_TYPE_FP4] = sizeof(block_q4_0), -+ [GGML_TYPE_FP6] = sizeof(block_fp6), -+ [GGML_TYPE_FP8E5] = sizeof(block_fp8), -+ [GGML_TYPE_IQ2_XXS] = sizeof(block_iq2_xxs), -+ [GGML_TYPE_IQ2_XS] = sizeof(block_iq2_xs), -+ [GGML_TYPE_Q2_K] = sizeof(block_q2_K), -+ [GGML_TYPE_IQ1_S] = sizeof(block_iq1_s), -+ [GGML_TYPE_IQ1_M] = sizeof(block_iq1_m), -+ [GGML_TYPE_Q6_K] = sizeof(block_q6_K), -+ [GGML_TYPE_Q4_K] = sizeof(block_q4_K), -+ [GGML_TYPE_Q5_K] = sizeof(block_q5_K), -+ [GGML_TYPE_FP6_K] = sizeof(block_fp6_K), -+ [GGML_TYPE_Q4_0_WOQ] = sizeof(block_q4_0), -+}; -diff --git a/csrc/xpu/dequantize.h b/csrc/xpu/dequantize.h -new file mode 100644 -index 000000000..9a967312e ---- /dev/null -+++ b/csrc/xpu/dequantize.h -@@ -0,0 +1,74 @@ -+#include -+#include -+#include "utils.h" -+/* -+Adapted from https://github.com/mit-han-lab/llm-awq -+Modified from NVIDIA FasterTransformer: -+https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -+@article{lin2023awq, -+ title={AWQ: Activation-aware Weight Quantization for LLM Compression and -+Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, -+Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} -+} -+*/ -+ -+#pragma once -+ -+namespace vllm { -+namespace awq { -+ -+sycl::uint4 dequantize_s4_to_fp16x2(uint32_t const& source) { -+ sycl::uint4 result; -+ -+ uint32_t* h = reinterpret_cast(&result); -+ uint32_t const i4s = reinterpret_cast(source); -+ -+ // First, we extract the i4s and construct an intermediate fp16 number. -+ static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; -+ static constexpr uint32_t BOTTOM_MASK = 0x000f000f; -+ static constexpr uint32_t TOP_MASK = 0x00f000f0; -+ static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; -+ -+ // Note that the entire sequence only requires 1 shift instruction. This is -+ // thanks to the register packing format and the fact that we force our -+ // integers to be unsigned, and account for this in the fp16 subtractions. In -+ // addition, I exploit the fact that sub and fma have the same throughput in -+ // order to convert elt_23 and elt_67 to fp16 without having to shift them to -+ // the bottom bits before hand. -+ -+ // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW -+ // dependency if we issue immediately before required. -+ const uint32_t top_i4s = i4s >> 8; -+ h[0] = (i4s & BOTTOM_MASK) | I4s_TO_F16s_MAGIC_NUM; -+ h[1] = (i4s & TOP_MASK) | I4s_TO_F16s_MAGIC_NUM; -+ h[2] = (top_i4s & BOTTOM_MASK) | I4s_TO_F16s_MAGIC_NUM; -+ h[3] = (top_i4s & TOP_MASK) | I4s_TO_F16s_MAGIC_NUM; -+ -+ // This is the half2 {1032, 1032} represented as an integer. -+ // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; -+ // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] -+ static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; -+ // This is the half2 {1 / 16, 1 / 16} represented as an integer. -+ static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; -+ // This is the half2 {-72, -72} represented as an integer. -+ // static constexpr uint32_t NEG_72 = 0xd480d480; -+ // Haotian: Let's use {-64, -64}. 
-+ static constexpr uint32_t NEG_64 = 0xd400d400; -+ *(sycl::half2*)(&h[0]) = sycl_half_sub2( -+ *(sycl::half2*)(&h[0]), *(sycl::half2*)(&FP16_TOP_MAGIC_NUM)); -+ *(sycl::half2*)(&h[1]) = sycl_half_fma2( -+ *(sycl::half2*)(&h[1]), -+ *(sycl::half2*)(&ONE_SIXTEENTH), -+ *(sycl::half2*)(&NEG_64)); -+ *(sycl::half2*)(&h[2]) = sycl_half_sub2( -+ *(sycl::half2*)(&h[2]), *(sycl::half2*)(&FP16_TOP_MAGIC_NUM)); -+ *(sycl::half2*)(&h[3]) = sycl_half_fma2( -+ *(sycl::half2*)(&h[3]), -+ *(sycl::half2*)(&ONE_SIXTEENTH), -+ *(sycl::half2*)(&NEG_64)); -+ -+ return result; -+} ++ seq_len = input_ids.shape[0] ++ device = input_ids.device ++ dtype = input_ids.dtype + -+} // namespace awq -+} // namespace vllm -\ No newline at end of file -diff --git a/csrc/xpu/dtype_float16.h b/csrc/xpu/dtype_float16.h -new file mode 100644 -index 000000000..1b9c1f248 ---- /dev/null -+++ b/csrc/xpu/dtype_float16.h -@@ -0,0 +1,458 @@ -+/* -+ * Adapted from -+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp -+ * and -+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h -+ * Copyright (c) 2023, The vLLM team. -+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+#pragma once -+ -+#include -+#include -+#include "attention_generic.h" -+#include "dtype_float32.h" -+#include "utils.h" -+ -+#include -+ -+namespace vllm { -+ -+// FP16 vector types for Q, K, V. -+template <> -+struct Vec { -+ using Type = sycl::half; -+}; -+template <> -+struct Vec { -+ using Type = sycl::half2; -+}; -+template <> -+struct Vec { -+ using Type = sycl::half4; -+}; -+template <> -+struct Vec { -+ using Type = sycl::half8; -+}; -+ -+template <> -+struct FloatVec { -+ using Type = float; -+}; -+template <> -+struct FloatVec { -+ using Type = sycl::float2; -+}; -+ -+template <> -+struct FloatVec { -+ using Type = Float4_; -+}; -+template <> -+struct FloatVec { -+ using Type = Float8_; -+}; -+ -+// Utility functions for type conversions. 
-+inline sycl::half2 h0_h0(sycl::half a) { -+ return sycl::half2{a, a}; -+} ++ if image_grid_thw is not None: ++ image_grid_thw = image_grid_thw.to(device=device, dtype=torch.long) ++ if video_grid_thw is not None: ++ video_grid_thw = video_grid_thw.to(device=device, dtype=torch.long) + -+inline float half_to_float(sycl::half h) { -+ return float(h); -+} ++ if second_per_grids is None: ++ if video_grid_thw is not None and video_grid_thw.numel() > 0: ++ second_per_grids = torch.ones( ++ video_grid_thw.shape[0], dtype=torch.float32, device=device ++ ) ++ else: ++ second_per_grids = torch.tensor([], dtype=torch.float32, device=device) ++ else: ++ second_per_grids = second_per_grids.to(device=device, dtype=torch.float32) ++ ++ if audio_seqlens is not None: ++ audio_seqlens = audio_seqlens.to(device=device, dtype=torch.long) ++ ++ spatial_merge_size = config.vision_config.spatial_merge_size ++ image_token_id = config.image_token_id ++ video_token_id = config.video_token_id ++ audio_token_id = config.audio_token_id ++ vision_start_token_id = config.vision_start_token_id ++ audio_start_token_id = config.audio_start_token_id ++ position_id_per_seconds = config.position_id_per_seconds ++ ++ vision_start_indices = torch.argwhere( ++ input_ids == vision_start_token_id ++ ).squeeze(1) ++ if vision_start_indices.numel() > 0: ++ vision_tokens = input_ids[vision_start_indices + 1] ++ else: ++ vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype) ++ audio_nums = torch.sum(input_ids == audio_start_token_id) ++ image_nums = (vision_tokens == image_token_id).sum() ++ video_nums = ( ++ (vision_tokens == audio_start_token_id).sum() ++ if use_audio_in_video ++ else (vision_tokens == video_token_id).sum() ++ ) + -+inline sycl::float2 half2_to_float2(sycl::half2 v) { ++ input_tokens = input_ids.tolist() ++ llm_pos_ids_list: list[torch.Tensor] = [] ++ st = 0 ++ image_idx = 0 ++ video_idx = 0 ++ audio_idx = 0 ++ remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums # noqa: E501 ++ multimodal_nums = ( ++ image_nums + audio_nums ++ if use_audio_in_video ++ else image_nums + video_nums + audio_nums ++ ) # noqa: E501 ++ ++ for _ in range(multimodal_nums): ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ if (image_token_id in input_tokens or video_token_id in input_tokens) and ( ++ remain_videos > 0 or remain_images > 0 ++ ): ++ ed_vision_start = input_tokens.index(vision_start_token_id, st) ++ else: ++ ed_vision_start = len(input_tokens) + 1 ++ if audio_token_id in input_tokens and remain_audios > 0: ++ ed_audio_start = input_tokens.index(audio_start_token_id, st) ++ else: ++ ed_audio_start = len(input_tokens) + 1 ++ min_ed = min(ed_vision_start, ed_audio_start) ++ ++ if min_ed == ed_audio_start: ++ text_len = min_ed - st ++ if text_len != 0: ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ llm_pos_ids_list.append( ++ torch.arange(text_len, device=device, dtype=torch.long) ++ .view(1, -1) ++ .expand(3, -1) ++ + st_idx ++ ) ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ bos_len = 1 ++ llm_pos_ids_list.append( ++ torch.arange(bos_len, device=device, dtype=torch.long) ++ .view(1, -1) ++ .expand(3, -1) ++ + st_idx ++ ) ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx]) ++ llm_pos_ids = ( ++ torch.arange(audio_len, device=device, dtype=torch.long) ++ .view(1, -1) ++ .expand(3, -1) ++ + st_idx ++ ) ++ 
llm_pos_ids_list.append(llm_pos_ids) ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ eos_len = 1 ++ llm_pos_ids_list.append( ++ torch.arange(eos_len, device=device, dtype=torch.long) ++ .view(1, -1) ++ .expand(3, -1) ++ + st_idx ++ ) ++ st += text_len + bos_len + audio_len + eos_len ++ audio_idx += 1 ++ remain_audios -= 1 ++ elif ( ++ min_ed == ed_vision_start ++ and input_ids[ed_vision_start + 1] == image_token_id ++ ): ++ text_len = min_ed - st ++ if text_len != 0: ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ llm_pos_ids_list.append( ++ torch.arange(text_len, device=device, dtype=torch.long) ++ .view(1, -1) ++ .expand(3, -1) ++ + st_idx ++ ) ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ bos_len = 1 ++ llm_pos_ids_list.append( ++ torch.arange(bos_len, device=device, dtype=torch.long) ++ .view(1, -1) ++ .expand(3, -1) ++ + st_idx ++ ) ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ grid_t = image_grid_thw[image_idx][0] ++ grid_hs = image_grid_thw[:, 1] ++ grid_ws = image_grid_thw[:, 2] ++ t_index = torch.arange(grid_t, device=device) * position_id_per_seconds ++ llm_pos_ids = cls._get_llm_pos_ids_for_vision( ++ st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws ++ ) ++ image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) ++ llm_pos_ids_list.append(llm_pos_ids) ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ eos_len = 1 ++ llm_pos_ids_list.append( ++ torch.arange(eos_len, device=device, dtype=torch.long) ++ .view(1, -1) ++ .expand(3, -1) ++ + st_idx ++ ) ++ st += text_len + bos_len + image_len + eos_len ++ image_idx += 1 ++ remain_images -= 1 ++ elif ( ++ min_ed == ed_vision_start ++ and input_ids[ed_vision_start + 1] == video_token_id ++ and not use_audio_in_video ++ ): ++ text_len = min_ed - st ++ if text_len != 0: ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ llm_pos_ids_list.append( ++ torch.arange(text_len, device=device, dtype=torch.long) ++ .view(1, -1) ++ .expand(3, -1) ++ + st_idx ++ ) ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ bos_len = 1 ++ llm_pos_ids_list.append( ++ torch.arange(bos_len, device=device, dtype=torch.long) ++ .view(1, -1) ++ .expand(3, -1) ++ + st_idx ++ ) ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ grid_t = video_grid_thw[video_idx][0] ++ grid_hs = video_grid_thw[:, 1] ++ grid_ws = video_grid_thw[:, 2] ++ t_index = ( ++ torch.arange(grid_t, device=device) ++ * float(second_per_grids[video_idx].item()) ++ * position_id_per_seconds ++ ) ++ llm_pos_ids = cls._get_llm_pos_ids_for_vision( ++ st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws ++ ) ++ video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) ++ llm_pos_ids_list.append(llm_pos_ids) ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ eos_len = 1 ++ llm_pos_ids_list.append( ++ torch.arange(eos_len, device=device, dtype=torch.long) ++ .view(1, -1) ++ .expand(3, -1) ++ + st_idx ++ ) ++ st += text_len + bos_len + video_len + eos_len ++ video_idx += 1 ++ remain_videos -= 1 ++ elif ( ++ min_ed == ed_vision_start ++ and ed_vision_start + 1 == ed_audio_start ++ and use_audio_in_video ++ ): ++ text_len = min_ed - st ++ if text_len != 0: ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ llm_pos_ids_list.append( ++ torch.arange(text_len, device=device, dtype=torch.long) ++ .view(1, -1) ++ .expand(3, 
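
For the audio branch above, the number of reserved position slots comes from the nested _get_feat_extract_output_lengths helper. A worked example of that formula with arbitrary feature lengths (80 and 250):

# Re-statement of the nested helper for a quick numeric check; inputs are arbitrary.
import torch

def get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
    input_lengths_leave = input_lengths % 100
    feat_lengths = (input_lengths_leave - 1) // 2 + 1
    return ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13

print(get_feat_extract_output_lengths(torch.tensor([80, 250])))
# tensor([10, 33]) -> audio position slots reserved per clip
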
-1) ++ + st_idx ++ ) ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ bos_len = 1 ++ bos_block = ( ++ torch.arange(bos_len, device=device, dtype=torch.long) ++ .view(1, -1) ++ .expand(3, -1) ++ + st_idx ++ ) ++ llm_pos_ids_list.append(bos_block) ++ llm_pos_ids_list.append(bos_block) ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx]) ++ audio_llm_pos_ids = ( ++ torch.arange(audio_len, device=device, dtype=torch.long) ++ .view(1, -1) ++ .expand(3, -1) ++ + st_idx ++ ) ++ grid_t = video_grid_thw[video_idx][0] ++ grid_hs = video_grid_thw[:, 1] ++ grid_ws = video_grid_thw[:, 2] ++ t_index = ( ++ torch.arange(grid_t, device=device) ++ * float(second_per_grids[video_idx].item()) ++ * position_id_per_seconds ++ ) ++ video_llm_pos_ids = cls._get_llm_pos_ids_for_vision( ++ st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws ++ ) ++ video_data_index, audio_data_index = 0, 0 ++ while ( ++ video_data_index < video_llm_pos_ids.shape[-1] ++ and audio_data_index < audio_llm_pos_ids.shape[-1] ++ ): ++ if ( ++ video_llm_pos_ids[0][video_data_index] ++ <= audio_llm_pos_ids[0][audio_data_index] ++ ): ++ llm_pos_ids_list.append( ++ video_llm_pos_ids[ ++ :, video_data_index : video_data_index + 1 ++ ] ++ ) ++ video_data_index += 1 ++ else: ++ llm_pos_ids_list.append( ++ audio_llm_pos_ids[ ++ :, audio_data_index : audio_data_index + 1 ++ ] ++ ) ++ audio_data_index += 1 ++ if video_data_index < video_llm_pos_ids.shape[-1]: ++ llm_pos_ids_list.append( ++ video_llm_pos_ids[ ++ :, video_data_index : video_llm_pos_ids.shape[-1] ++ ] ++ ) ++ if audio_data_index < audio_llm_pos_ids.shape[-1]: ++ llm_pos_ids_list.append( ++ audio_llm_pos_ids[ ++ :, audio_data_index : audio_llm_pos_ids.shape[-1] ++ ] ++ ) ++ video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ eos_len = 1 ++ eos_block = ( ++ torch.arange(eos_len, device=device, dtype=torch.long) ++ .view(1, -1) ++ .expand(3, -1) ++ + st_idx ++ ) ++ llm_pos_ids_list.append(eos_block) ++ llm_pos_ids_list.append(eos_block) ++ st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 # noqa: E501 ++ audio_idx += 1 ++ video_idx += 1 ++ remain_videos -= 1 ++ remain_audios -= 1 ++ ++ if st < len(input_tokens): ++ st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 ++ text_len = len(input_tokens) - st ++ llm_pos_ids_list.append( ++ torch.arange(text_len, device=device, dtype=torch.long) ++ .view(1, -1) ++ .expand(3, -1) ++ + st_idx ++ ) + -+ return sycl::float2(half_to_float(v.x()), half_to_float(v.y())); -+} ++ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) ++ if llm_positions.shape[1] != seq_len: ++ raise RuntimeError("Position ids length mismatch with input ids length") + -+inline sycl::half float_to_half(float f) { -+ return sycl::half(f); -+} ++ position_ids = llm_positions.to(device=device, dtype=dtype) ++ mrope_position_delta = llm_positions.max() + 1 - seq_len ++ return position_ids, mrope_position_delta + -+inline sycl::half2 float2_to_half2(sycl::float2 f) { -+ return sycl::half2{float_to_half(f.x()), float_to_half(f.y())}; -+} + @classmethod + def _omni_get_input_positions_tensor( + cls, +@@ -879,7 +1359,38 @@ class MRotaryEmbedding(RotaryEmbedding): + # TODO(fyabc): refactor and share more code with + # _vl_get_input_positions_tensor. 
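
When audio is embedded in a video, the branch above merges the two position streams by their temporal row, smaller timestamp first, much like one merge step of a merge sort. A toy version with hand-written 3xN position columns (the values are illustrative, not real model output):

# Toy merge of two position streams by their temporal (first) row.
import torch

video = torch.tensor([[0, 25, 50], [0, 0, 0], [0, 1, 2]])
audio = torch.tensor([[0, 13, 38], [0, 13, 38], [0, 13, 38]])

cols, vi, ai = [], 0, 0
while vi < video.shape[-1] and ai < audio.shape[-1]:
    if video[0, vi] <= audio[0, ai]:
        cols.append(video[:, vi:vi + 1]); vi += 1
    else:
        cols.append(audio[:, ai:ai + 1]); ai += 1
cols.append(video[:, vi:])
cols.append(audio[:, ai:])
print(torch.cat(cols, dim=1)[0])  # merged temporal row: 0, 0, 13, 25, 38, 50
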
+ ++ model_type = hf_config.model_type + thinker_config = hf_config.thinker_config ++ ++ if isinstance(image_grid_thw, list): ++ image_grid_thw = torch.tensor(image_grid_thw) ++ if isinstance(video_grid_thw, list): ++ video_grid_thw = torch.tensor(video_grid_thw) ++ ++ if "qwen3_omni" in model_type: ++ input_tensor = torch.tensor(input_tokens) ++ audio_lengths_tensor = audio_feature_lengths ++ if audio_lengths_tensor is not None and not isinstance( ++ audio_lengths_tensor, torch.Tensor ++ ): ++ audio_lengths_tensor = torch.as_tensor( ++ audio_lengths_tensor, dtype=torch.long ++ ) ++ second_per_grids_tensor = ( ++ torch.tensor(second_per_grid_ts) if second_per_grid_ts else None ++ ) + -+// Vector addition. -+inline sycl::half add(sycl::half a, sycl::half b) { -+ return sycl_half_add(a,b); -+} ++ llm_positions, mrope_position_delta = cls._omni3_get_input_positions_tensor( # noqa: E501 ++ thinker_config, ++ input_tensor, ++ image_grid_thw, ++ video_grid_thw, ++ use_audio_in_video, ++ audio_lengths_tensor, ++ second_per_grids_tensor, ++ ) ++ return llm_positions, mrope_position_delta ++ + audio_token_id = thinker_config.audio_token_index + image_token_id = thinker_config.image_token_index + video_token_id = thinker_config.video_token_index +@@ -892,11 +1403,6 @@ class MRotaryEmbedding(RotaryEmbedding): + tokens_per_second = getattr(thinker_config.vision_config, + "tokens_per_second", 25) + +- if isinstance(image_grid_thw, list): +- image_grid_thw = torch.tensor(image_grid_thw) +- if isinstance(video_grid_thw, list): +- video_grid_thw = torch.tensor(video_grid_thw) +- + src_item = input_tokens + audio_seqlens = audio_feature_lengths + if not second_per_grid_ts: +@@ -940,7 +1446,7 @@ class MRotaryEmbedding(RotaryEmbedding): + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] +- t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() ++ t_index = torch.arange(grid_t) * 1 * tokens_per_second + llm_pos_ids = cls._get_llm_pos_ids_for_vision( + start_idx, image_idx, spatial_merge_size, t_index, grid_hs, + grid_ws) +@@ -953,9 +1459,11 @@ class MRotaryEmbedding(RotaryEmbedding): + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] +- t_index = (torch.arange(grid_t) * +- second_per_grid_ts[video_idx] * +- tokens_per_second).long() ++ t_index = ( ++ torch.arange(grid_t) ++ * second_per_grid_ts[video_idx] ++ * tokens_per_second ++ ) + llm_pos_ids = cls._get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_index, grid_hs, + grid_ws) +@@ -976,9 +1484,11 @@ class MRotaryEmbedding(RotaryEmbedding): + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) +- t_index = (torch.arange(grid_t) * +- second_per_grid_ts[video_idx] * +- tokens_per_second).long() ++ t_index = ( ++ torch.arange(grid_t) ++ * second_per_grid_ts[video_idx] ++ * tokens_per_second ++ ) + t_index_split_chunk = cls._split_list_into_ranges( + t_index, t_ntoken_per_chunk) + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 +@@ -1117,10 +1627,8 @@ class MRotaryEmbedding(RotaryEmbedding): + grid_h = video_grid_thw[1] + grid_w = video_grid_thw[2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) +- t_index = (torch.arange(grid_t) * video_second_per_grid_t * +- tokens_per_second).long() +- t_index_split_chunk = cls._split_list_into_ranges( +- t_index, t_ntoken_per_chunk) ++ t_index = torch.arange(grid_t) * 
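
The value returned next to the positions, mrope_position_delta, is what lets decoding continue the sequence: the next token's position on all three rows is the prompt length plus the delta, i.e. one past the largest prompt position. A tiny numeric check with a hand-written 3x4 position tensor:

# Hand-written positions, only to show how the delta is used.
import torch

llm_positions = torch.tensor([[0, 1, 2, 2], [0, 1, 2, 3], [0, 1, 3, 2]])
seq_len = llm_positions.shape[1]                        # 4 prompt tokens
delta = int(llm_positions.max().item()) + 1 - seq_len   # 3 + 1 - 4 = 0
next_pos = seq_len + delta                              # position 4 on all rows
print(delta, next_pos)                                  # 0 4
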
video_second_per_grid_t * tokens_per_second ++ t_index_split_chunk = cls._split_list_into_ranges(t_index, t_ntoken_per_chunk) + + updates = [audio_start_token_id] + added_audio_len = 0 +diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py +index 0c2441a6d..d1747f2d3 100644 +--- a/vllm/model_executor/model_loader/utils.py ++++ b/vllm/model_executor/model_loader/utils.py +@@ -15,6 +15,7 @@ from typing_extensions import assert_never + from vllm.attention import Attention + from vllm.config import (ModelConfig, ModelImpl, VllmConfig, + set_current_vllm_config) ++from vllm.envs import VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT + from vllm.logger import init_logger + from vllm.model_executor.layers.linear import QKVCrossParallelLinear + from vllm.model_executor.layers.quantization.base_config import ( +@@ -144,26 +145,30 @@ def device_loading_context(module: torch.nn.Module, + yield module + + finally: +- # Restore parameters to their original devices, ignoring new parameters +- pin_memory = is_pin_memory_available() +- for name, p in module.named_parameters(): +- if name in original_device_states: +- original_device: torch.device = original_device_states[name] +- if original_device.type == "cpu": +- # `torch.empty_like` does not support `pin_memory` argument +- cpu_data = torch.empty_strided( +- size=p.data.size(), +- stride=p.data.stride(), +- dtype=p.data.dtype, +- layout=p.data.layout, +- device="cpu", +- pin_memory=pin_memory, +- ) +- cpu_data.copy_(p.data) +- p.data = cpu_data +- else: +- p.data = p.data.to(original_device) +- # New parameters or parameters already on target device are untouched ++ # If weights were loaded onto the CPU for FP8 online quantization, there ++ # is no need to move them back to the original device. 
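
The audio-in-video hunk above buckets the scaled time index into windows of t_ntoken_per_chunk before interleaving audio placeholders. cls._split_list_into_ranges itself is not shown in this diff, so the snippet below is only a plausible stand-in that groups values by which window they fall into (chunk size and sampling values are made up):

# Assumed behaviour of the range splitter, for illustration only.
import torch

t_ntoken_per_chunk = 50                       # e.g. 25 tokens/s * 2 s per chunk
t_index = torch.arange(6) * 0.5 * 25          # [0, 12.5, 25, 37.5, 50, 62.5]

buckets: dict[int, list[float]] = {}
for v in t_index.tolist():
    buckets.setdefault(int(v // t_ntoken_per_chunk), []).append(v)
print(buckets)  # {0: [0.0, 12.5, 25.0, 37.5], 1: [50.0, 62.5]}
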
++ if not VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT: ++ # Restore parameters to their original devices, ignoring new parameters # noqa: E501 ++ pin_memory = is_pin_memory_available() ++ for name, p in module.named_parameters(): ++ if name in original_device_states: ++ original_device: torch.device = original_device_states[ ++ name] ++ if original_device.type == "cpu": ++ # `torch.empty_like` does not support `pin_memory` argument # noqa: E501 ++ cpu_data = torch.empty_strided( ++ size=p.data.size(), ++ stride=p.data.stride(), ++ dtype=p.data.dtype, ++ layout=p.data.layout, ++ device="cpu", ++ pin_memory=pin_memory, ++ ) ++ cpu_data.copy_(p.data) ++ p.data = cpu_data ++ else: ++ p.data = p.data.to(original_device) ++ # New parameters or parameters already on target device are untouched # noqa: E501 + + + def get_model_architecture( +diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py +new file mode 100644 +index 000000000..f24cb6d52 +--- /dev/null ++++ b/vllm/model_executor/models/dots_ocr.py +@@ -0,0 +1,861 @@ ++# SPDX-License-Identifier: Apache-2.0 ++# SPDX-FileCopyrightText: Copyright contributors to the vLLM project ++from collections.abc import Iterable, Mapping ++from typing import Literal, Optional, TypedDict, Union + -+inline sycl::half2 add(sycl::half2 a, sycl::half2 b) { -+ auto val = sycl_half_add2(a, b); -+ return (val); -+} ++import torch ++import torch.nn as nn ++import torch.nn.functional as F ++from torch.nn import LayerNorm ++from transformers.modeling_utils import PreTrainedModel ++from transformers.models.qwen2_vl import Qwen2VLProcessor + -+inline sycl::half4 add(sycl::half4 a, sycl::half4 b) { -+ sycl::half4 c; -+ c.x() = add(a.x(), b.x()); -+ c.y() = add(a.y(), b.y()); -+ c.z() = add(a.z(), b.z()); -+ c.w() = add(a.w(), b.w()); -+ return c; -+} ++from vllm.attention.layer import check_upstream_fa_availability ++from vllm.config import VllmConfig ++from vllm.model_executor.layers.activation import SiluAndMul ++from vllm.model_executor.layers.layernorm import RMSNorm ++from vllm.model_executor.layers.linear import (ColumnParallelLinear, ++ MergedColumnParallelLinear, ++ QKVParallelLinear, ++ RowParallelLinear) ++from vllm.model_executor.layers.quantization import QuantizationConfig ++from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, ++ SupportsMultiModal, ++ SupportsPP) ++from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM ++from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder, ++ Qwen2VLMultiModalProcessor, ++ Qwen2VLProcessingInfo) ++from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, ++ init_vllm_registered_model, ++ maybe_prefix, ++ merge_multimodal_embeddings) ++from vllm.model_executor.models.vision import get_vit_attn_backend ++from vllm.multimodal import MULTIMODAL_REGISTRY ++from vllm.multimodal.inputs import MultiModalDataDict ++from vllm.platforms import _Backend ++from vllm.sequence import IntermediateTensors ++from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig, ++ DotsVisionConfig) + -+inline sycl::half8 add(sycl::half8 a, sycl::half8 b) { -+ sycl::half8 c; -+ c.s0() = add(a.s0(), b.s0()); -+ c.s1() = add(a.s1(), b.s1()); -+ c.s2() = add(a.s2(), b.s2()); -+ c.s3() = add(a.s3(), b.s3()); -+ c.s4() = add(a.s4(), b.s4()); -+ c.s5() = add(a.s5(), b.s5()); -+ c.s6() = add(a.s6(), b.s6()); -+ c.s7() = add(a.s7(), b.s7()); -+ return c; -+} ++IMAGE_TOKEN = "<|imgpad|>" + -+inline sycl::float2 add(sycl::half2 a, sycl::float2 fb) { -+ 
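
The model-loader hunk above gates the restore step on VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT: when weights were deliberately staged on the CPU for FP8 online quantization, moving them back after loading would only waste time and memory. A minimal, self-contained sketch of that pattern (loading_context and skip_restore are illustrative names, not the real vLLM helpers):

# Illustrative context manager; not the vLLM implementation.
import contextlib
import torch

@contextlib.contextmanager
def loading_context(module: torch.nn.Module, target: torch.device,
                    skip_restore: bool):
    original = {n: p.device for n, p in module.named_parameters()}
    module.to(target)
    try:
        yield module
    finally:
        if not skip_restore:
            # Only move parameters back when offloading was not intentional.
            for n, p in module.named_parameters():
                p.data = p.data.to(original[n])

lin = torch.nn.Linear(4, 4)
with loading_context(lin, torch.device("cpu"), skip_restore=True):
    pass  # weights stay wherever the loader put them afterwards
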
sycl::float2 fa = half2_to_float2(a); -+ return add(fa, fb); -+} + -+inline Float4_ add(sycl::half4 a, Float4_ fb) { -+ Float4_ fc; -+ fc.x = add(sycl::half2{a.x(), a.y()}, fb.x); -+ fc.y = add(sycl::half2{a.z(), a.w()}, fb.y); -+ return fc; -+} ++class DotsOCRImagePixelInputs(TypedDict): ++ type: Literal["pixel_values", "image_grid_thw"] + -+inline Float8_ add(sycl::half8 a, Float8_ fb) { -+ Float8_ fc; -+ fc.x = add(sycl::half2{a.s0(), a.s1()}, fb.x); -+ fc.y = add(sycl::half2{a.s2(), a.s3()}, fb.y); -+ fc.z = add(sycl::half2{a.s4(), a.s5()}, fb.z); -+ fc.w = add(sycl::half2{a.s6(), a.s7()}, fb.w); -+ return fc; -+} ++ pixel_values: torch.Tensor ++ image_grid_thw: torch.Tensor ++ ++ ++class DotsOCRImageEmbeddingInputs(TypedDict): ++ type: Literal["image_embeds", "image_grid_thw"] ++ image_embeds: torch.Tensor ++ """Supported types: ++ - List[`torch.Tensor`]: A list of tensors holding all images' features. ++ Each tensor holds an image's features. ++ - `torch.Tensor`: A tensor holding all images' features ++ (concatenation of all images' feature tensors). ++ Tensor shape: `(num_image_features, hidden_size)` ++ - `num_image_features` varies based on ++ the number and resolution of the images. ++ - `hidden_size` must match the hidden size of language model backbone. ++ """ + -+// Vector multiplication. -+template <> -+inline sycl::half mul(sycl::half a, sycl::half b) { -+ auto val = sycl_half_mul((a), (b)); -+ return (val); -+} ++ image_grid_thw: torch.Tensor + -+template <> -+inline sycl::half2 mul(sycl::half2 a, sycl::half2 b) { -+ auto val = sycl_half_mul2((a), (b)); -+ return (val); -+} + -+template <> -+inline sycl::half2 mul(sycl::half a, sycl::half2 b) { -+ return mul(h0_h0(a), b); -+} ++DotsOCRImageInputs = Union[DotsOCRImagePixelInputs, ++ DotsOCRImageEmbeddingInputs] + + -+template <> -+inline sycl::half4 mul(sycl::half4 a, sycl::half4 b) { -+ sycl::half4 c; -+ c.x() = mul(a.x(), b.x()); -+ c.y() = mul(a.y(), b.y()); -+ c.z() = mul(a.z(), b.z()); -+ c.w() = mul(a.w(), b.w()); -+ return c; -+} ++class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder): + -+template <> -+inline sycl::half4 mul(sycl::half a, sycl::half4 b) { -+ sycl::half4 c; -+ c.x() = mul(a, b.x()); -+ c.y() = mul(a, b.y()); -+ c.z() = mul(a, b.z()); -+ c.w() = mul(a, b.w()); -+ return c; -+} ++ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: ++ num_images = mm_counts.get("image", 0) ++ return IMAGE_TOKEN * num_images + -+template <> -+inline sycl::half8 mul(sycl::half8 a, sycl::half8 b) { -+ sycl::half8 c; -+ c.s0() = mul(a.s0(), b.s0()); -+ c.s1() = mul(a.s1(), b.s1()); -+ c.s2() = mul(a.s2(), b.s2()); -+ c.s3() = mul(a.s3(), b.s3()); -+ c.s4() = mul(a.s4(), b.s4()); -+ c.s5() = mul(a.s5(), b.s5()); -+ c.s6() = mul(a.s6(), b.s6()); -+ c.s7() = mul(a.s7(), b.s7()); -+ return c; -+} ++ def get_dummy_mm_data( ++ self, ++ seq_len: int, ++ mm_counts: Mapping[str, int], ++ ) -> MultiModalDataDict: ++ num_images = mm_counts.get("image", 0) + -+template <> -+inline sycl::half8 mul(sycl::half a, sycl::half8 b) { -+ sycl::half8 c; -+ c.s0() = mul(a, b.s0()); -+ c.s1() = mul(a, b.s1()); -+ c.s2() = mul(a, b.s2()); -+ c.s3() = mul(a, b.s3()); -+ c.s4() = mul(a, b.s4()); -+ c.s5() = mul(a, b.s5()); -+ c.s6() = mul(a, b.s6()); -+ c.s7() = mul(a, b.s7()); -+ return c; -+} ++ target_width, target_height = self.info.get_image_size_with_most_features( # noqa: E501 ++ ) + -+template <> -+inline float mul(sycl::half a, sycl::half b) { -+ float fa = half_to_float(a); -+ float fb = half_to_float(b); -+ return fa * fb; 
-+} ++ return { ++ "image": ++ self._get_dummy_images(width=target_width, ++ height=target_height, ++ num_images=num_images), ++ } + -+template <> -+inline sycl::float2 mul(sycl::half2 a, sycl::half2 b) { -+ sycl::float2 fa = half2_to_float2(a); -+ sycl::float2 fb = half2_to_float2(b); -+ return mul(fa, fb); -+} + -+template <> -+inline sycl::float2 mul(sycl::half a, sycl::half2 b) { -+ return mul(h0_h0(a), b); -+} ++class DotsOCRProcessingInfo(Qwen2VLProcessingInfo): + -+template <> -+inline Float4_ mul(sycl::half4 a, sycl::half4 b) { -+ Float4_ fc; -+ fc.x = mul( -+ sycl::half2{a.x(), a.y()}, sycl::half2{b.x(), b.y()}); -+ fc.y = mul( -+ sycl::half2{a.z(), a.w()}, sycl::half2{b.z(), b.w()}); -+ return fc; -+} ++ def get_hf_config(self) -> DotsOCRConfig: ++ config = self.ctx.get_hf_config() ++ if not config.__class__.__name__ == 'DotsOCRConfig': ++ raise TypeError(f"Expected DotsOCRConfig, got {type(config)}") + -+template <> -+inline Float4_ mul(sycl::half a, sycl::half4 b) { -+ sycl::half2 s = h0_h0(a); -+ Float4_ fc; ++ if hasattr(config, "vision_config") and isinstance( ++ config.vision_config, dict): ++ config.vision_config = DotsVisionConfig(**config.vision_config) + -+ fc.x = -+ mul(s, sycl::half2{b.x(), b.y()}); -+ fc.y = -+ mul(s, sycl::half2{b.z(), b.w()}); -+ return fc; -+} ++ return config + -+template <> -+inline Float8_ mul(sycl::half8 a, sycl::half8 b) { -+ Float8_ fc; -+ fc.x = mul( -+ sycl::half2{a.s0(), a.s1()}, sycl::half2{b.s0(), b.s1()}); -+ fc.y = mul( -+ sycl::half2{a.s2(), a.s3()}, sycl::half2{b.s2(), b.s3()}); -+ fc.z = mul( -+ sycl::half2{a.s4(), a.s5()}, sycl::half2{b.s4(), b.s5()}); -+ fc.w = mul( -+ sycl::half2{a.s6(), a.s7()}, sycl::half2{b.s6(), b.s7()}); -+ return fc; -+} ++ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: ++ return {"image": None} + -+template <> -+inline Float8_ mul(sycl::half a, sycl::half8 b) { -+ sycl::half2 s = h0_h0(a); -+ Float8_ fc; -+ fc.x = mul( -+ s, sycl::half2{b.s0(), b.s1()}); -+ fc.y = mul( -+ s, sycl::half2{b.s2(), b.s3()}); -+ fc.z = mul( -+ s, sycl::half2{b.s4(), b.s5()}); -+ fc.w = mul( -+ s, sycl::half2{b.s6(), b.s7()}); -+ return fc; -+} ++ def get_mm_max_tokens_per_item( ++ self, ++ seq_len: int, ++ mm_counts: Mapping[str, int], ++ ) -> Mapping[str, int]: ++ max_image_tokens = self.get_max_image_tokens() ++ return {"image": max_image_tokens} + -+// Vector fused multiply-add. 
-+inline sycl::half2 fma(sycl::half2 a, sycl::half2 b, sycl::half2 c) { -+ auto val = sycl_half_fma2((a), (b), (c)); -+ return (val); -+} ++ def get_hf_processor( ++ self, ++ **kwargs: object, ++ ) -> Qwen2VLProcessor: ++ self.get_tokenizer( ++ ).image_token = IMAGE_TOKEN # Ensure image token is set ++ processor = self.ctx.get_hf_processor( ++ Qwen2VLProcessor, ++ **kwargs, ++ ) ++ processor.image_token = IMAGE_TOKEN ++ processor.video_token = "<|video_pad|>" ++ return processor + -+inline sycl::half2 fma(sycl::half a, sycl::half2 b, sycl::half2 c) { -+ return fma(h0_h0(a), b, c); -+} + -+inline sycl::half4 fma(sycl::half4 a, sycl::half4 b, sycl::half4 c) { -+ sycl::half4 d; -+ d.x() = fma(a.x(), b.x(), c.x()); -+ d.y() = fma(a.y(), b.y(), c.y()); -+ d.z() = fma(a.z(), b.z(), c.z()); -+ d.w() = fma(a.w(), b.w(), c.w()); -+ return d; -+} ++def rotate_half(x): ++ """Rotates half the hidden dims of the input.""" ++ x1 = x[..., :x.shape[-1] // 2] ++ x2 = x[..., x.shape[-1] // 2:] ++ return torch.cat((-x2, x1), dim=-1) + -+inline sycl::half4 fma(sycl::half a, sycl::half4 b, sycl::half4 c) { -+ sycl::half4 s = sycl::half4{a, a, a, a}; -+ return fma(s, b, c); -+} + -+inline sycl::half8 fma(sycl::half8 a, sycl::half8 b, sycl::half8 c) { -+ sycl::half8 d; -+ d.s0() = fma(a.s0(), b.s0(), c.s0()); -+ d.s1() = fma(a.s1(), b.s1(), c.s1()); -+ d.s2() = fma(a.s2(), b.s2(), c.s2()); -+ d.s3() = fma(a.s3(), b.s3(), c.s3()); -+ d.s4() = fma(a.s4(), b.s4(), c.s4()); -+ d.s5() = fma(a.s5(), b.s5(), c.s5()); -+ d.s6() = fma(a.s6(), b.s6(), c.s6()); -+ d.s7() = fma(a.s7(), b.s7(), c.s7()); -+ return d; -+} ++def apply_rotary_pos_emb_vision(tensor: torch.Tensor, ++ freqs: torch.Tensor) -> torch.Tensor: ++ orig_dtype = tensor.dtype ++ tensor = tensor.float() + -+inline sycl::half8 fma(sycl::half a, sycl::half8 b, sycl::half8 c) { -+ sycl::half8 d; -+ d.s0() = fma(a, b.s0(), c.s0()); -+ d.s1() = fma(a, b.s1(), c.s1()); -+ d.s2() = fma(a, b.s2(), c.s2()); -+ d.s3() = fma(a, b.s3(), c.s3()); -+ d.s4() = fma(a, b.s4(), c.s4()); -+ d.s5() = fma(a, b.s5(), c.s5()); -+ d.s6() = fma(a, b.s6(), c.s6()); -+ d.s7() = fma(a, b.s7(), c.s7()); -+ return d; -+} ++ cos = freqs.cos() ++ sin = freqs.sin() + -+inline float fma(sycl::half a, sycl::half b, float fc) { -+ float fa = half_to_float(a); -+ float fb = half_to_float(b); -+ return sycl::fma(fa, fb, fc); -+} ++ cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() ++ sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + -+inline sycl::float2 fma(sycl::half2 a, sycl::half2 b, sycl::float2 fc) { -+ sycl::float2 fa = half2_to_float2(a); -+ sycl::float2 fb = half2_to_float2(b); -+ return fma(fa, fb, fc); -+} ++ output = (tensor * cos) + (rotate_half(tensor) * sin) + -+inline sycl::float2 fma(sycl::half a, sycl::half2 b, sycl::float2 fc) { -+ return fma(h0_h0(a), b, fc); -+} ++ output = output.to(orig_dtype) + -+inline Float4_ fma(sycl::half4 a, sycl::half4 b, Float4_ fc) { -+ Float4_ fd; -+ fd.x = fma(sycl::half2{a.x(), a.y()}, sycl::half2{b.x(), b.y()}, fc.x); -+ fd.y = fma(sycl::half2{a.z(), a.w()}, sycl::half2{b.z(), b.w()}, fc.y); -+ return fd; -+} ++ return output + -+inline Float4_ fma(sycl::half a, sycl::half4 b, Float4_ fc) { -+ sycl::half4 s = sycl::half4{a, a, a, a}; + -+ return fma(s, b, fc); -+} ++class VisionRotaryEmbedding(nn.Module): + -+inline Float8_ fma(sycl::half8 a, sycl::half8 b, Float8_ fc) { -+ Float8_ fd; -+ fd.x = fma(sycl::half2{a.s0(), a.s1()}, sycl::half2{b.s0(), b.s1()}, fc.x); -+ fd.y = fma(sycl::half2{a.s2(), a.s3()}, 
sycl::half2{b.s2(), b.s3()}, fc.y); -+ fd.z = fma(sycl::half2{a.s4(), a.s5()}, sycl::half2{b.s4(), b.s5()}, fc.z); -+ fd.w = fma(sycl::half2{a.s6(), a.s7()}, sycl::half2{b.s6(), b.s7()}, fc.w); -+ return fd; -+} ++ def __init__(self, dim: int, theta: float = 10000.0) -> None: ++ super().__init__() ++ inv_freq = 1.0 / (theta ++ **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) ++ self.register_buffer("inv_freq", inv_freq, persistent=False) + -+inline Float8_ fma(sycl::half a, sycl::half8 b, Float8_ fc) { -+ sycl::half8 s = sycl::half8{a, a, a, a, a, a, a, a}; ++ def forward(self, seqlen: int) -> torch.Tensor: ++ seq = torch.arange(seqlen, ++ device=self.inv_freq.device, ++ dtype=self.inv_freq.dtype) ++ freqs = torch.outer(seq, self.inv_freq) ++ return freqs + -+ return fma(s, b, fc); -+} + -+// Vector sum. -+template <> -+inline float sum(sycl::half v) { -+ return half_to_float(v); -+} ++class PatchMerger(nn.Module): + -+template <> -+inline float sum(sycl::half2 v) { -+ sycl::float2 tmp = half2_to_float2(v); -+ return tmp.x() + tmp.y(); -+} ++ def __init__( ++ self, ++ dim: int, ++ context_dim: int, ++ spatial_merge_size: int = 2, ++ pre_norm="layernorm", ++ ) -> None: ++ super().__init__() ++ self.hidden_size = context_dim * (spatial_merge_size**2) ++ self.pre_norm = pre_norm ++ if self.pre_norm == "layernorm": ++ self.ln_q = LayerNorm(context_dim, eps=1e-6) ++ elif self.pre_norm == "rmsnorm": ++ self.ln_q = RMSNorm(context_dim, eps=1e-6) ++ else: ++ print("no norm in patch merger") ++ ++ self.mlp = nn.Sequential( ++ ColumnParallelLinear(self.hidden_size, ++ self.hidden_size, ++ bias=True, ++ return_bias=False, ++ disable_tp=True), ++ nn.GELU(), ++ RowParallelLinear(self.hidden_size, ++ dim, ++ bias=True, ++ return_bias=False, ++ disable_tp=True), ++ ) + -+template <> -+inline float sum(sycl::half4 v) { -+ sycl::half2 c = add(sycl::half2{v.x(), v.y()}, sycl::half2{v.z(), v.w()}); -+ return sum(c); -+} ++ def forward(self, x: torch.Tensor) -> torch.Tensor: ++ if self.pre_norm: ++ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) ++ else: ++ x = self.mlp(x.view(-1, self.hidden_size)) ++ return x + -+template <> -+inline float sum(sycl::half8 v) { -+ return add( -+ sum(sycl::half4{v.s0(), v.s1(), v.s2(), v.s3()}), -+ sum(sycl::half4{v.s4(), v.s5(), v.s6(), v.s7()})); -+} + -+inline void from_float(sycl::half& dst, float src) { -+ dst = sycl::half(src); -+} ++class DotsVisionAttention(nn.Module): + -+inline void from_float(sycl::half2& dst, sycl::float2 src) { -+ dst = float2_to_half2(src); -+} ++ def __init__(self, ++ config, ++ dim: int, ++ num_heads: int = 16, ++ bias: bool = True, ++ *, ++ quant_config: Optional[QuantizationConfig] = None, ++ prefix: str = "") -> None: ++ super().__init__() ++ from vllm.distributed import (parallel_state, ++ tensor_model_parallel_all_gather) ++ from vllm.distributed import utils as dist_utils + -+inline void from_float(sycl::half4& dst, Float4_ src) { -+ sycl::half2 h0 = float2_to_half2(src.x); -+ sycl::half2 h1 = float2_to_half2(src.y); -+ dst.x() = h0.x(); -+ dst.y() = h0.y(); -+ dst.z() = h1.x(); -+ dst.w() = h1.y(); -+} ++ self.embed_dim = dim ++ self.num_heads = num_heads ++ self.head_dim = dim // num_heads ++ self.tp_size = parallel_state.get_tensor_model_parallel_world_size() ++ self.tp_rank = parallel_state.get_tensor_model_parallel_rank() ++ self.num_heads_per_partition = dist_utils.divide( ++ num_heads, self.tp_size) ++ ++ # qkv/proj follow Qwen2-VL style; bias controlled by arg ++ self.qkv = QKVParallelLinear(hidden_size=dim, ++ 
head_size=dim // num_heads, ++ total_num_heads=num_heads, ++ bias=bias, ++ quant_config=quant_config, ++ prefix=f"{prefix}.qkv") ++ self.proj = RowParallelLinear(input_size=dim, ++ output_size=dim, ++ bias=bias, ++ quant_config=quant_config, ++ prefix=f"{prefix}.proj") ++ self._all_gather = tensor_model_parallel_all_gather ++ self._split_last = dist_utils.split_tensor_along_last_dim ++ ++ # Select attention backend ++ self.attn_backend = get_vit_attn_backend(self.head_dim, ++ torch.get_default_dtype()) ++ self.use_upstream_fa = False ++ if self.attn_backend != _Backend.FLASH_ATTN and \ ++ check_upstream_fa_availability(torch.get_default_dtype()): ++ self.attn_backend = _Backend.FLASH_ATTN ++ self.use_upstream_fa = True ++ if self.attn_backend not in { ++ _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, ++ _Backend.ROCM_AITER_FA, _Backend.IPEX ++ }: ++ raise RuntimeError( ++ f"Unsupported vision attention backend: {self.attn_backend}") ++ self.is_flash_attn_backend = self.attn_backend in { ++ _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA ++ } + -+inline void from_float(sycl::half8& dst, Float8_ src) { -+ dst.s0() = float2_to_half2(src.x).x(); -+ dst.s1() = float2_to_half2(src.x).y(); -+ dst.s2() = float2_to_half2(src.y).x(); -+ dst.s3() = float2_to_half2(src.y).y(); -+ dst.s4() = float2_to_half2(src.z).x(); -+ dst.s5() = float2_to_half2(src.z).y(); -+ dst.s6() = float2_to_half2(src.w).x(); -+ dst.s7() = float2_to_half2(src.w).y(); -+} ++ def _split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: ++ # qkv: [S, B, 3*dim] ++ seq_len, bs, _ = qkv.shape ++ if self.tp_size > 1: ++ qkv = self._all_gather(qkv) ++ q, k, v = qkv.chunk(3, dim=2) ++ if self.tp_size > 1: ++ q = self._split_last(q, num_partitions=self.tp_size)[self.tp_rank] ++ k = self._split_last(k, num_partitions=self.tp_size)[self.tp_rank] ++ v = self._split_last(v, num_partitions=self.tp_size)[self.tp_rank] ++ new_shape = (seq_len, bs, self.num_heads_per_partition, self.head_dim) ++ return (q.view(*new_shape), k.view(*new_shape), v.view(*new_shape)) + -+// From float16 to float32. -+inline float to_float(sycl::half u) { -+ return half_to_float(u); -+} -+ -+inline sycl::float2 to_float(sycl::half2 u) { -+ return half2_to_float2(u); -+} -+ -+inline Float4_ to_float(sycl::half4 u) { -+ Float4_ tmp; -+ tmp.x = half2_to_float2(sycl::half2{u.x(), u.y()}); -+ tmp.y = half2_to_float2(sycl::half2{u.z(), u.w()}); -+ return tmp; -+} -+ -+inline Float8_ to_float(sycl::half8 u) { -+ Float8_ tmp; -+ tmp.x = half2_to_float2(sycl::half2{u.s0(), u.s1()}); -+ tmp.y = half2_to_float2(sycl::half2{u.s2(), u.s3()}); -+ tmp.z = half2_to_float2(sycl::half2{u.s4(), u.s5()}); -+ tmp.w = half2_to_float2(sycl::half2{u.s6(), u.s7()}); -+ return tmp; -+} -+ -+// Zero-out a variable. -+inline void zero(sycl::half& dst) { -+ dst = sycl::half(0); -+} -+ -+} // namespace vllm -\ No newline at end of file -diff --git a/csrc/xpu/dtype_float32.h b/csrc/xpu/dtype_float32.h -new file mode 100644 -index 000000000..7b70e4efc ---- /dev/null -+++ b/csrc/xpu/dtype_float32.h -@@ -0,0 +1,268 @@ -+/* -+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp -+ * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h -+ * Copyright (c) 2023, The vLLM team. -+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 
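
This vision attention module runs over a packed sequence of patches from several images, with cu_seqlens marking the boundaries; when no variable-length flash kernel is available it falls back to per-image SDPA. A toy version of that fallback with made-up shapes (two images of 5 and 3 patch tokens, 2 heads, head_dim 8, no tensor parallelism or rotary embedding):

# Shapes and data are invented; this only mirrors the cu_seqlens slicing idea.
import torch
import torch.nn.functional as F

heads, head_dim = 2, 8
q = k = v = torch.randn(1, 8, heads, head_dim)   # [B, S_total, H, D]
cu_seqlens = torch.tensor([0, 5, 8])

outs = []
for i in range(1, len(cu_seqlens)):
    s, e = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
    q_i = q[:, s:e].permute(0, 2, 1, 3)          # [B, H, S_i, D]
    o_i = F.scaled_dot_product_attention(q_i,
                                         k[:, s:e].permute(0, 2, 1, 3),
                                         v[:, s:e].permute(0, 2, 1, 3))
    outs.append(o_i.permute(0, 2, 1, 3))         # back to [B, S_i, H, D]
out = torch.cat(outs, dim=1)
print(out.shape)                                 # torch.Size([1, 8, 2, 8])
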
-+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. -+ */ -+#pragma once -+ -+#include -+#include -+#include "attention_generic.h" -+ -+#include -+ -+namespace vllm { -+ -+// Define custom FP32 vector data types. -+struct Float4_ { -+ sycl::float2 x; -+ sycl::float2 y; -+}; -+ -+struct Float8_ { -+ sycl::float2 x; -+ sycl::float2 y; -+ sycl::float2 z; -+ sycl::float2 w; -+}; -+ -+// FP32 vector types for Q, K, V. -+template<> -+struct Vec { -+ using Type = float; -+}; -+template<> -+struct Vec { -+ using Type = sycl::float2; -+}; -+template<> -+struct Vec { -+ using Type = sycl::float4; -+}; -+ -+// FP32 accumulator vector types corresponding to Vec. -+template<> -+struct FloatVec { -+ using Type = float; -+}; -+template <> struct FloatVec { -+ using Type = sycl::float2; -+}; -+template <> struct FloatVec { -+ using Type = sycl::float4; -+}; -+ -+// Vector addition. -+inline float add(float a, float b) { -+ return a + b; -+} -+ -+inline sycl::float2 add(sycl::float2 a, sycl::float2 b) { -+ sycl::float2 c; -+ c.x() = add(a.x(), b.x()); -+ c.y() = add(a.y(), b.y()); -+ return c; -+} -+ -+inline sycl::float4 add(sycl::float4 a, sycl::float4 b) { -+ sycl::float4 c; -+ c.x() = add(a.x(), b.x()); -+ c.y() = add(a.y(), b.y()); -+ c.z() = add(a.z(), b.z()); -+ c.w() = add(a.w(), b.w()); -+ return c; -+} -+ -+// Vector multiplication. 
-+template<> -+inline float mul(float a, float b) { -+ return a * b; -+} -+ -+template <> inline sycl::float2 mul(sycl::float2 a, sycl::float2 b) { -+ sycl::float2 c; -+ c.x() = a.x() * b.x(); -+ c.y() = a.y() * b.y(); -+ return c; -+} -+ -+template <> inline sycl::float2 mul(float a, sycl::float2 b) { -+ sycl::float2 c; -+ c.x() = a * b.x(); -+ c.y() = a * b.y(); -+ return c; -+} -+ -+template <> inline sycl::float4 mul(sycl::float4 a, sycl::float4 b) { -+ sycl::float4 c; -+ c.x() = a.x() * b.x(); -+ c.y() = a.y() * b.y(); -+ c.z() = a.z() * b.z(); -+ c.w() = a.w() * b.w(); -+ return c; -+} -+ -+template <> inline sycl::float4 mul(float a, sycl::float4 b) { -+ sycl::float4 c; -+ c.x() = a * b.x(); -+ c.y() = a * b.y(); -+ c.z() = a * b.z(); -+ c.w() = a * b.w(); -+ return c; -+} ++ def forward( ++ self, ++ hidden_states: torch.Tensor, ++ cu_seqlens: torch.Tensor, ++ rotary_pos_emb: Optional[torch.Tensor] = None, ++ *, ++ max_seqlen: Optional[int] = None, ++ seqlens: Optional[list[int]] = None, ++ ) -> torch.Tensor: ++ # [S, C] -> [S, B=1, C] ++ x = hidden_states.unsqueeze(1) ++ x, _ = self.qkv(x) ++ q, k, v = self._split_qkv(x) ++ bs = q.shape[1] ++ # [S,B,H,D] -> [B,S,H,D] ++ q = q.permute(1, 0, 2, 3).contiguous() ++ k = k.permute(1, 0, 2, 3).contiguous() ++ v = v.permute(1, 0, 2, 3).contiguous() ++ ++ if rotary_pos_emb is not None: ++ qk_concat = torch.cat([q, k], dim=0) ++ qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) ++ q, k = torch.chunk(qk_rotated, 2, dim=0) ++ ++ if self.is_flash_attn_backend: ++ if self.attn_backend == _Backend.ROCM_AITER_FA: ++ from aiter import flash_attn_varlen_func ++ else: ++ if self.use_upstream_fa: ++ from flash_attn import flash_attn_varlen_func ++ else: ++ from vllm.vllm_flash_attn import flash_attn_varlen_func ++ q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3]) ++ k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3]) ++ v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3]) ++ output = flash_attn_varlen_func(q_, ++ k_, ++ v_, ++ cu_seqlens_q=cu_seqlens, ++ cu_seqlens_k=cu_seqlens, ++ max_seqlen_q=max_seqlen, ++ max_seqlen_k=max_seqlen, ++ dropout_p=0.0, ++ causal=False) ++ context_layer = output.view(bs, -1, self.num_heads_per_partition, ++ self.head_dim) ++ elif self.attn_backend == _Backend.TORCH_SDPA: ++ outputs = [] ++ for i in range(1, len(cu_seqlens)): ++ s = int(cu_seqlens[i - 1]) ++ e = int(cu_seqlens[i]) ++ q_i = q[:, s:e].permute(0, 2, 1, 3) ++ k_i = k[:, s:e].permute(0, 2, 1, 3) ++ v_i = v[:, s:e].permute(0, 2, 1, 3) ++ out_i = F.scaled_dot_product_attention(q_i, ++ k_i, ++ v_i, ++ dropout_p=0.0) ++ out_i = out_i.permute(0, 2, 1, 3) ++ outputs.append(out_i) ++ context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] ++ elif self.attn_backend == _Backend.XFORMERS: ++ from xformers import ops as xops ++ from xformers.ops.fmha.attn_bias import BlockDiagonalMask ++ attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, ++ kv_seqlen=None, ++ device=q.device) ++ context_layer = xops.memory_efficient_attention_forward( ++ q, k, v, attn_bias=attn_bias, p=0, scale=None) ++ elif self.attn_backend == _Backend.IPEX: ++ q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3]) ++ k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3]) ++ v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3]) ++ output = torch.empty_like(q_) + -+// Vector fused multiply-add. 
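The scalar and packed-vector overloads above (and the fused multiply-add, sum, and dot helpers that follow) are the building blocks the removed attention kernels use to accumulate query/key dot products. Their semantics, restated as a plain-Python sketch for reference (list-based, no SYCL vector types):

    from typing import Sequence

    def fma(a, b, c):
        # d = a * b + c, applied elementwise; a scalar `a` broadcasts over b and c.
        if isinstance(a, (int, float)):
            a = [a] * len(b)
        return [ai * bi + ci for ai, bi, ci in zip(a, b, c)]

    def dot(a: Sequence[float], b: Sequence[float]) -> float:
        # Multiply elementwise, then accumulate: mul() followed by sum().
        return sum(fma(a, b, [0.0] * len(a)))

    assert dot([1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]) == 20.0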
-+inline float fma(float a, float b, float c) { -+ return a * b + c; -+} ++ from vllm._ipex_ops import ipex_ops ++ ipex_ops.varlen_attention( ++ q_.contiguous(), # query ++ k_.contiguous(), # key ++ v_.contiguous(), # value ++ output, # out ++ cu_seqlens.int(), # seqlen_q ++ cu_seqlens.int(), # seqlen_k ++ None, # alibi_slopes ++ max_seqlen, # max_seqlen_q ++ max_seqlen, # max_seqlen_k ++ 0.0, # pdropout ++ 1.0 / (q.shape[-1] ** 0.5), # softmax_scale ++ False, # zero_tensors ++ False, # is_causal ++ False, # return_softmax ++ None, # gen_ ++ -1, # window_size_left ++ -1, # window_size_right ++ -1, # logits_soft_cap ++ ) ++ context_layer = output.view(bs, -1, self.num_heads_per_partition, ++ self.head_dim) ++ else: ++ raise RuntimeError("Unsupported attention backend") + -+inline sycl::float2 fma(sycl::float2 a, sycl::float2 b, sycl::float2 c) { -+ sycl::float2 d; -+ d.x() = fma(a.x(), b.x(), c.x()); -+ d.y() = fma(a.y(), b.y(), c.y()); -+ return d; -+} ++ # [B,S,H,D] -> [S,B,H*D] -> [S, C] ++ context_layer = context_layer.permute(1, 0, 2, 3).contiguous() ++ context_layer = context_layer.view(context_layer.shape[0], bs, -1) ++ out, _ = self.proj(context_layer) ++ return out.squeeze(1) + -+inline sycl::float2 fma(float a, sycl::float2 b, sycl::float2 c) { -+ sycl::float2 d; -+ d.x() = fma(a, b.x(), c.x()); -+ d.y() = fma(a, b.y(), c.y()); -+ return d; -+} + -+inline sycl::float4 fma(sycl::float4 a, sycl::float4 b, sycl::float4 c) { -+ sycl::float4 d; -+ d.x() = fma(a.x(), b.x(), c.x()); -+ d.y() = fma(a.y(), b.y(), c.y()); -+ d.z() = fma(a.z(), b.z(), c.z()); -+ d.w() = fma(a.w(), b.w(), c.w()); -+ return d; -+} ++class DotsSwiGLUFFN(nn.Module): + -+inline sycl::float4 fma(float a, sycl::float4 b, sycl::float4 c) { -+ sycl::float4 d; -+ d.x() = fma(a, b.x(), c.x()); -+ d.y() = fma(a, b.y(), c.y()); -+ d.z() = fma(a, b.z(), c.z()); -+ d.w() = fma(a, b.w(), c.w()); -+ return d; -+} ++ def __init__(self, ++ config, ++ *, ++ quant_config: Optional[QuantizationConfig] = None, ++ prefix: str = ""): ++ super().__init__() ++ hidden_features = config.intermediate_size ++ in_features = config.embed_dim ++ bias = config.use_bias ++ ++ # Referenced aimv2.py AIMv2SwiGLUFFN ++ self.fc13 = MergedColumnParallelLinear(in_features, ++ [hidden_features] * 2, ++ bias=bias, ++ quant_config=quant_config, ++ prefix=f"{prefix}.fc13", ++ disable_tp=True) ++ self.fc2 = RowParallelLinear(hidden_features, ++ in_features, ++ bias=bias, ++ quant_config=quant_config, ++ prefix=f"{prefix}.fc2", ++ disable_tp=True) ++ self.act_fn = SiluAndMul() + -+inline Float4_ fma(float a, Float4_ b, Float4_ c) { -+ Float4_ d; -+ d.x = fma(a, b.x, c.x); -+ d.y = fma(a, b.y, c.y); -+ return d; -+} ++ def forward(self, x: torch.Tensor) -> torch.Tensor: ++ x, _ = self.fc13(x) ++ x = self.act_fn(x) ++ x, _ = self.fc2(x) ++ return x + -+inline Float8_ fma(float a, Float8_ b, Float8_ c) { -+ Float8_ d; -+ d.x = fma(a, b.x, c.x); -+ d.y = fma(a, b.y, c.y); -+ d.z = fma(a, b.z, c.z); -+ d.w = fma(a, b.w, c.w); -+ return d; -+} ++ def load_weights(self, weights: Iterable[tuple[str, ++ torch.Tensor]]) -> set[str]: ++ params = dict(self.named_parameters()) ++ loaded: set[str] = set() ++ for name, w in weights: ++ # Map fc1 -> fc13 (shard 0) ++ if name.startswith("fc1."): ++ tgt = name.replace("fc1.", "fc13.") ++ if tgt in params: ++ params[tgt].weight_loader(params[tgt], w, 0) ++ loaded.add(tgt) ++ continue ++ # Map fc3 -> fc13 (shard 1) ++ if name.startswith("fc3."): ++ tgt = name.replace("fc3.", "fc13.") ++ if tgt in params: ++ 
params[tgt].weight_loader(params[tgt], w, 1) ++ loaded.add(tgt) ++ continue ++ # Pass-through for fc2 and others ++ if name in params: ++ params[name].weight_loader(params[name], w) ++ loaded.add(name) ++ return loaded + -+// Vector sum. -+template<> -+inline float sum(float v) { -+ return v; -+} + -+template <> inline float sum(sycl::float2 v) { -+ return v.x() + v.y(); -+} ++class DotsPatchEmbed(nn.Module): + -+template <> inline float sum(sycl::float4 v) { -+ return v.x() + v.y() + v.z() + v.w(); -+} ++ def __init__(self, config): ++ super().__init__() ++ self.num_channels = config.num_channels ++ self.patch_size = config.patch_size ++ self.temporal_patch_size = config.temporal_patch_size ++ self.embed_dim = config.embed_dim ++ self.config = config ++ self.proj = nn.Conv2d( ++ config.num_channels, ++ config.embed_dim, ++ kernel_size=(config.patch_size, config.patch_size), ++ stride=(config.patch_size, config.patch_size), ++ ) ++ self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) + -+template<> -+inline float sum(Float4_ v) { -+ return v.x.x() + v.x.y() + v.y.x() + v.y.y(); -+} ++ def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor: ++ x = x.view(-1, self.num_channels, self.temporal_patch_size, ++ self.patch_size, self.patch_size)[:, :, 0] ++ x = self.proj(x).view(-1, self.embed_dim) ++ x = self.norm(x) ++ return x + -+template<> -+inline float sum(Float8_ v) { -+ return v.x.x() + v.x.y() + v.y.x() + v.y.y() + v.z.x() + v.z.y() + v.w.x() + -+ v.w.y(); -+} + -+// Vector dot product. -+inline float dot(float a, float b) { -+ return a * b; -+} ++class DotsViTPreprocessor(nn.Module): + -+inline float dot(sycl::float2 a, sycl::float2 b) { -+ sycl::float2 c = mul(a, b); -+ return c.x() + c.y(); -+} ++ def __init__(self, config): ++ super().__init__() ++ self.patch_h = config.patch_size ++ self.patch_w = config.patch_size ++ self.embed_dim = config.embed_dim ++ self.config = config ++ self.patchifier = DotsPatchEmbed(config) + -+inline float dot(Float4_ a, Float4_ b) { -+ sycl::float2 acc = mul(a.x, b.x); -+ acc = fma(a.y, b.y, acc); -+ return acc.x() + acc.y(); -+} ++ def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor: ++ tokens = self.patchifier(x, grid_thw) ++ return tokens + -+inline float dot(Float8_ a, Float8_ b) { -+ sycl::float2 acc = mul(a.x, b.x); -+ acc = fma(a.y, b.y, acc); -+ acc = fma(a.z, b.z, acc); -+ acc = fma(a.w, b.w, acc); -+ return acc.x() + acc.y(); -+} + -+// From float to float. 
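For reference on the FFN above: the checkpoint ships separate fc1 (gate) and fc3 (up) projections, which load_weights maps onto one merged fc13 matrix as shard 0 and shard 1, and the forward pass applies silu(gate) * up before fc2. A minimal torch sketch of that layout, assuming plain weight matrices (no MergedColumnParallelLinear, no quantization):

    import torch
    import torch.nn.functional as F

    def merge_fc1_fc3(fc1_w: torch.Tensor, fc3_w: torch.Tensor) -> torch.Tensor:
        # Rows 0..H hold the gate (shard 0), rows H..2H the up projection (shard 1).
        return torch.cat([fc1_w, fc3_w], dim=0)

    def swiglu_ffn(x, fc13_w, fc2_w):
        gate_up = x @ fc13_w.t()             # [*, 2H], one fused GEMM
        gate, up = gate_up.chunk(2, dim=-1)  # what SiluAndMul splits internally
        return (F.silu(gate) * up) @ fc2_w.t()

    x = torch.randn(2, 16)
    fc1, fc3, fc2 = torch.randn(32, 16), torch.randn(32, 16), torch.randn(16, 32)
    assert swiglu_ffn(x, merge_fc1_fc3(fc1, fc3), fc2).shape == (2, 16)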
-+inline void from_float(float& dst, float src) { -+ dst = src; -+} ++class DotsVisionBlock(nn.Module): + -+inline void from_float(sycl::float2 &dst, sycl::float2 src) { -+ dst = src; -+} ++ def __init__(self, ++ config, ++ *, ++ quant_config: Optional[QuantizationConfig] = None, ++ prefix: str = ""): ++ super().__init__() + -+inline void from_float(sycl::float4 &dst, sycl::float4 src) { -+ dst = src; -+} ++ self.attn = DotsVisionAttention( ++ config, ++ config.embed_dim, ++ num_heads=config.num_attention_heads, ++ bias=config.use_bias, ++ quant_config=quant_config, ++ prefix=f"{prefix}.attn", ++ ) ++ self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) ++ self.mlp = DotsSwiGLUFFN(config, ++ quant_config=quant_config, ++ prefix=f"{prefix}.mlp") ++ self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) ++ ++ def forward(self, ++ hidden_states: torch.Tensor, ++ *, ++ cu_seqlens: torch.Tensor, ++ rotary_pos_emb: torch.Tensor, ++ max_seqlen: Optional[int] = None, ++ seqlens: Optional[list[int]] = None) -> torch.Tensor: ++ hidden_states = hidden_states + self.attn( ++ self.norm1(hidden_states), ++ cu_seqlens=cu_seqlens, ++ rotary_pos_emb=rotary_pos_emb, ++ max_seqlen=max_seqlen, ++ seqlens=seqlens, ++ ) ++ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) ++ return hidden_states + -+// From float to float. -+inline float to_float(float u) { -+ return u; -+} + -+inline sycl::float2 to_float(sycl::float2 u) { -+ return u; -+} ++class DotsVisionTransformer(PreTrainedModel): + -+inline sycl::float4 to_float(sycl::float4 u) { -+ return u; -+} ++ def __init__( ++ self, ++ config: DotsVisionConfig, ++ quant_config: Optional[QuantizationConfig] = None, ++ *, ++ num_hidden_layers_override: Optional[int] = None, ++ require_post_norm: Optional[bool] = None, ++ prefix: str = "", ++ ) -> None: ++ super().__init__(config) ++ self.config = config ++ self.spatial_merge_size = config.spatial_merge_size ++ ++ self.patch_embed = DotsViTPreprocessor(config) ++ ++ head_dim = config.embed_dim // config.num_attention_heads ++ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) ++ self.attn_backend = get_vit_attn_backend( ++ head_size=head_dim, dtype=torch.get_default_dtype()) ++ if self.attn_backend != _Backend.FLASH_ATTN and \ ++ check_upstream_fa_availability(torch.get_default_dtype()): ++ self.attn_backend = _Backend.FLASH_ATTN ++ ++ # Keep blocks for compatibility with other vision towers ++ num_layers = (config.num_hidden_layers if num_hidden_layers_override ++ is None else num_hidden_layers_override) ++ self.blocks = nn.ModuleList([ ++ DotsVisionBlock(config, ++ quant_config=quant_config, ++ prefix=f"{prefix}.blocks.{i}") ++ for i in range(num_layers) ++ ]) ++ if require_post_norm is None: ++ require_post_norm = (len(self.blocks) == config.num_hidden_layers) ++ if require_post_norm and self.config.post_norm: ++ self.post_trunk_norm = RMSNorm(config.embed_dim, ++ eps=config.rms_norm_eps) ++ else: ++ self.post_trunk_norm = None + -+inline Float4_ to_float(Float4_ u) { -+ return u; -+} ++ self.merger = PatchMerger( ++ dim=config.hidden_size, ++ context_dim=config.embed_dim, ++ spatial_merge_size=config.spatial_merge_size, ++ ) + -+inline Float8_ to_float(Float8_ u) { -+ return u; -+} ++ @property ++ def dtype(self) -> torch.dtype: ++ return self.patch_embed.patchifier.proj.weight.dtype ++ ++ @property ++ def device(self) -> torch.device: ++ return self.patch_embed.patchifier.proj.weight.device ++ ++ def get_pos_ids_by_grid(self, grid_thw): ++ pos_ids = [] ++ for t, h, w 
in grid_thw: ++ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) ++ hpos_ids = hpos_ids.reshape( ++ h // self.spatial_merge_size, ++ self.spatial_merge_size, ++ w // self.spatial_merge_size, ++ self.spatial_merge_size, ++ ) ++ hpos_ids = hpos_ids.permute(0, 2, 1, 3) ++ hpos_ids = hpos_ids.flatten() ++ ++ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) ++ wpos_ids = wpos_ids.reshape( ++ h // self.spatial_merge_size, ++ self.spatial_merge_size, ++ w // self.spatial_merge_size, ++ self.spatial_merge_size, ++ ) ++ wpos_ids = wpos_ids.permute(0, 2, 1, 3) ++ wpos_ids = wpos_ids.flatten() ++ pos_ids.append( ++ torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) ++ ++ return pos_ids ++ ++ def rot_pos_emb(self, grid_thw): ++ pos_ids = self.get_pos_ids_by_grid(grid_thw) ++ pos_ids = torch.cat(pos_ids, dim=0) ++ max_grid_size = grid_thw[:, 1:].max() ++ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) ++ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) ++ return rotary_pos_emb ++ ++ def compute_attn_mask_seqlen( ++ self, cu_seqlens: torch.Tensor ++ ) -> tuple[Optional[int], Optional[list[int]]]: ++ max_seqlen, seqlens = None, None ++ if self.attn_backend == _Backend.FLASH_ATTN: ++ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() ++ elif self.attn_backend == _Backend.XFORMERS: ++ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() ++ return max_seqlen, seqlens ++ ++ def forward(self, hidden_states: torch.Tensor, ++ grid_thw: torch.Tensor) -> torch.Tensor: ++ hidden_states = hidden_states.to(self.dtype) ++ hidden_states = self.patch_embed(hidden_states, grid_thw) ++ ++ rotary_pos_emb = self.rot_pos_emb(grid_thw) ++ ++ cu_seqlens = torch.repeat_interleave( ++ grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( ++ dim=0, ++ dtype=grid_thw.dtype ++ if torch.jit.is_tracing() else torch.int32, ++ ) ++ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + -+// Zero-out a variable. 
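The vision-transformer forward above flattens all image patches into a single varlen batch: each (t, h, w) grid contributes t sequences of h*w patches, and their cumulative lengths become cu_seqlens for the varlen attention backends (max_seqlen for the flash-attention path, per-sequence lengths for xFormers). A small reproduction of that bookkeeping with a made-up grid:

    import torch
    import torch.nn.functional as F

    def build_cu_seqlens(grid_thw: torch.Tensor) -> torch.Tensor:
        # grid_thw: [num_images, 3] holding (t, h, w) per image.
        seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                           grid_thw[:, 0])
        cu_seqlens = seq_lens.cumsum(dim=0, dtype=torch.int32)
        return F.pad(cu_seqlens, (1, 0), value=0)   # [0, l1, l1+l2, ...]

    grid = torch.tensor([[1, 4, 6], [2, 2, 2]])     # a 4x6 image and a 2-frame 2x2 clip
    cu = build_cu_seqlens(grid)
    print(cu.tolist())                              # [0, 24, 28, 32]
    max_seqlen = (cu[1:] - cu[:-1]).max().item()    # flash-attention path
    seqlens = (cu[1:] - cu[:-1]).tolist()           # xFormers path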
-+inline void zero(float& dst) { -+ dst = 0.f; -+} ++ max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) ++ for blk in self.blocks: ++ hidden_states = blk(hidden_states, ++ cu_seqlens=cu_seqlens, ++ rotary_pos_emb=rotary_pos_emb, ++ max_seqlen=max_seqlen, ++ seqlens=seqlens) + -+} // namespace vllm -\ No newline at end of file -diff --git a/csrc/xpu/fused_moe.cpp b/csrc/xpu/fused_moe.cpp -new file mode 100644 -index 000000000..3a39d0e13 ---- /dev/null -+++ b/csrc/xpu/fused_moe.cpp -@@ -0,0 +1,269 @@ -+#include "utils.h" -+#include "base.hpp" ++ if self.post_trunk_norm is not None: ++ hidden_states = self.post_trunk_norm(hidden_states) + -+using ST = at::ScalarType; ++ hidden_states = self.merger(hidden_states) ++ return hidden_states + -+#include -+#include "xpu_types.h" -+#include + -+template -+__inline__ T silu_xpu(const T& x) { -+ // x * sigmoid(x) -+ return (T)(((float)x) / (1.0f + sycl::exp((float)-x))); -+} ++@MULTIMODAL_REGISTRY.register_processor( ++ Qwen2VLMultiModalProcessor, ++ info=DotsOCRProcessingInfo, ++ dummy_inputs=DotsOCRDummyInputsBuilder, ++) ++class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ++ hf_to_vllm_mapper = WeightsMapper( ++ orig_to_new_substr={ ++ ".attn.qkv_proj.": ".attn.qkv.", ++ ".attn.out_proj.": ".attn.proj.", ++ }, ++ orig_to_new_prefix={ ++ "lm_head.": "language_model.lm_head.", ++ "model.": "language_model.model.", ++ }, ++ ) + -+template -+void silu_and_mul_kernel( -+ scalar_t* __restrict__ out, // [..., d] -+ const scalar_t* __restrict__ input, // [..., 2, d] -+ const int d, -+ const sycl::nd_item<3>& item_ct1) { -+ const int64_t token_idx = item_ct1.get_group(2); -+ for (int64_t idx = item_ct1.get_local_id(2); idx < d; -+ idx += item_ct1.get_local_range(2)) { -+ const scalar_t x = input[token_idx * 2 * d + idx]; -+ const scalar_t y = input[token_idx * 2 * d + d + idx]; -+ out[token_idx * d + idx] = silu_xpu(x) * y; -+ } -+} ++ @classmethod ++ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: ++ if modality.startswith("image"): ++ return "<|img|><|imgpad|><|endofimg|>" + -+template -+void call_silu_and_mul_kernel( -+ int num_tokens, -+ int d, -+ const scalar_t* __restrict__ input, -+ scalar_t* __restrict__ output) { -+ using sycl_t = vllm::xpu::SyclTypeTrait::Type; -+ sycl::range<3> grid(1, 1, num_tokens); -+ sycl::range<3> block(1, 1, std::min(d, 1024)); -+ auto& queue = vllm::xpu::vllmGetQueue(); -+ queue.submit([&](sycl::handler& cgh) { -+ cgh.parallel_for( -+ sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1) { -+ silu_and_mul_kernel( -+ (sycl_t*)output, (const sycl_t*)input, d, item_ct1); -+ }); -+ }); -+} ++ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ++ super().__init__() + -+void _silu_and_mul(torch::Tensor& out, torch::Tensor& input) { -+ int num_tokens = input.numel() / input.size(-1); -+ int d = input.size(-1) / 2; -+ -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ input.scalar_type(), "call_silu_and_mul_kernel", [&] { -+ call_silu_and_mul_kernel( -+ num_tokens, -+ d, -+ input.data_ptr(), -+ out.data_ptr()); -+ }); -+} ++ self.config: DotsOCRConfig = vllm_config.model_config.hf_config ++ self.quant_config = vllm_config.quant_config ++ self.multimodal_config = vllm_config.model_config.multimodal_config + -+template -+static void moe_forward_kernel( -+ const void* input_ptr, -+ const int64_t* indexs, -+ const uint64_t* qweights, -+ void * output_ptr, -+ const int num_tokens, -+ const int state_size, -+ const int output_size, -+ at::Device device -+) { -+ 
static_assert(ES == 8 || ES == 16 || ES == 32); -+ assert(output_size % VS == 0); -+ -+ const int nb = state_size / QK; -+ const int nsb = nb / SBS; -+ -+ constexpr int BLOCK_SIZE = BLOCK_SIZES[QTYPE]; -+ constexpr int SCALE_SIZE = SCALE_SIZES[QTYPE]; -+ -+ sycl::range<2> global_size(num_tokens, output_size / VS * GS); -+ sycl::range<2> local_size(1, GS); -+ -+ auto cgf = [&](sycl::handler& handle) { -+ handle.parallel_for( -+ sycl::nd_range<2>(global_size, local_size), -+ [=](sycl::nd_item<2> item) SYCL_ESIMD_KERNEL { -+ slm_init(); -+ -+ const int eid = item.get_global_id(0); -+ const int tid = item.get_local_id(1); -+ const int vid = item.get_group(1) * VS; -+ -+ if (indexs[eid] >= 0) { -+ const uint8_t* weight = (const uint8_t *)(qweights[indexs[eid]]); -+ const uint8_t* scales = weight + (int64_t)output_size * nb * BLOCK_SIZE; -+ const IT* input = static_cast(input_ptr) + eid * state_size; -+ IT* output = static_cast(output_ptr) + eid * output_size; -+ -+ const uint8_t * weight_base = weight + nb * BLOCK_SIZE * vid; -+ const uint8_t * scale_base = scales + nb * SCALE_SIZE * vid; -+ -+ simd accvs{}; -+ -+ for (int s = tid; s < nsb; s += GS) { -+ simd xvs = block_load(input + s * SBS * QK); -+ -+ #pragma unroll -+ for (int v = 0; v < VS; ++v) { -+ simd yvs = load_qblocks( -+ weight_base + v * nb * BLOCK_SIZE + s * SBS * BLOCK_SIZE, -+ scale_base + v * nb * SCALE_SIZE + s * SBS * SCALE_SIZE -+ ); -+ -+ #pragma unroll -+ for (int i = 0; i < SBS * QK; i += ES) { -+ accvs.template select(v * ES) += -+ xvs.template select(i) * -+ yvs.template select(i); -+ } -+ } -+ } -+ -+ for (int b = nsb * SBS + tid; b < nb; b += GS) { -+ simd xv = block_load(input + b * QK); -+ -+ #pragma unroll -+ for (int v = 0; v < VS; ++v) { -+ simd yv = load_qblock( -+ weight_base + v * nb * BLOCK_SIZE + b * BLOCK_SIZE, -+ scale_base + v * nb * SCALE_SIZE + b * SCALE_SIZE -+ ); -+ -+ #pragma unroll -+ for (int i = 0; i < QK; i += ES) { -+ accvs.template select(v * ES) += -+ xv.template select(i) * -+ yv.template select(i); -+ } -+ } -+ } -+ -+ simd accs; -+ #pragma unroll -+ for(int v = 0; v < VS; ++v) { -+ accs[v] = sycl::ext::intel::esimd::detail::sum( -+ accvs.template select(v * ES) -+ ); -+ } -+ -+ slm_block_store(tid * VS * sizeof(float), accs); -+ -+ barrier(); -+ -+ if (tid == 0) { -+ #pragma unroll -+ for (int i = 1; i < GS; ++i) { -+ accs += slm_block_load(i * VS * sizeof(float)); -+ } -+ -+ block_store(output + vid, accs); -+ } -+ } -+ -+ -+ } -+ ); -+ }; ++ if isinstance(self.config.vision_config, dict): ++ vision_config = DotsVisionConfig(**self.config.vision_config) ++ self.config.vision_config = vision_config ++ else: ++ vision_config = self.config.vision_config + -+ utils::submit_kernel(cgf, device, "moe forward down kernel"); -+} ++ self.vision_tower = DotsVisionTransformer( ++ vision_config, ++ quant_config=self.quant_config, ++ prefix=maybe_prefix(prefix, "vision_tower"), ++ ) ++ self.language_model: Qwen2ForCausalLM = init_vllm_registered_model( ++ vllm_config=vllm_config, ++ hf_config=self.config, ++ prefix=maybe_prefix(prefix, "language_model"), ++ architectures=["Qwen2ForCausalLM"], ++ ) + ++ def _validate_and_reshape_mm_tensor(self, mm_input: object, ++ name: str) -> torch.Tensor: ++ if not isinstance(mm_input, (torch.Tensor, list)): ++ raise ValueError(f"Incorrect type of {name}. " ++ f"Got type: {type(mm_input)}") ++ if isinstance(mm_input, torch.Tensor): ++ if mm_input.ndim == 2: ++ return mm_input ++ if mm_input.ndim != 3: ++ raise ValueError(f"{name} should be 2D or batched 3D tensor. 
" ++ f"Got ndim: {mm_input.ndim} " ++ f"(shape={mm_input.shape})") ++ return torch.concat(list(mm_input)) ++ else: ++ return torch.concat(mm_input) + -+template -+static auto dispatch_moe_forward(ST scalar_t) { -+ switch (scalar_t) { -+ case ST::Float: return std::make_tuple(moe_forward_kernel); -+ case ST::Half: return std::make_tuple(moe_forward_kernel); -+ default: throw std::runtime_error("unsupported dtype, only fp32 and fp16 are supported"); -+ } -+} ++ def _parse_and_validate_image_input( ++ self, **kwargs: object) -> Optional[DotsOCRImageInputs]: ++ pixel_values = kwargs.pop("pixel_values", None) ++ image_embeds = kwargs.pop("image_embeds", None) ++ image_grid_thw = kwargs.pop("image_grid_thw", None) + ++ if pixel_values is None and image_embeds is None: ++ return None + -+torch::Tensor moe_forward( -+ torch::Tensor input, -+ torch::Tensor indexs, -+ torch::Tensor qweights_attr, -+ int64_t state_size, -+ int64_t output_size, -+ int64_t qtype -+) { -+ auto [func] = [&] () { -+ switch (qtype) { -+ case GGML_TYPE_Q4_0: -+ return dispatch_moe_forward(input.scalar_type()); -+ case GGML_TYPE_Q4_0_WOQ: -+ return dispatch_moe_forward(input.scalar_type()); -+ case GGML_TYPE_FP8E5: -+ return dispatch_moe_forward(input.scalar_type()); -+ default: throw std::runtime_error("unsupported qtype: " + std::to_string(qtype)); -+ } -+ } (); ++ if pixel_values is not None: ++ pixel_values = self._validate_and_reshape_mm_tensor( ++ pixel_values, "image pixel values") ++ image_grid_thw = self._validate_and_reshape_mm_tensor( ++ image_grid_thw, "image grid_thw") + -+ int64_t num_tokens = indexs.numel(); ++ if not isinstance(pixel_values, (torch.Tensor, list)): ++ raise ValueError("Incorrect type of image pixel values. " ++ f"Got type: {type(pixel_values)}") + -+ torch::Tensor output = torch::zeros({num_tokens, output_size}, -+ torch::device(input.device()).dtype(input.dtype())); ++ return DotsOCRImagePixelInputs(type="pixel_values", ++ pixel_values=pixel_values, ++ image_grid_thw=image_grid_thw) + -+ func( -+ input.data_ptr(), indexs.data_ptr(), -+ qweights_attr.data_ptr(), output.data_ptr(), -+ num_tokens, state_size, output_size, input.device() -+ ); ++ if image_embeds is not None: ++ image_embeds = self._validate_and_reshape_mm_tensor( ++ image_embeds, "image embeds") ++ image_grid_thw = self._validate_and_reshape_mm_tensor( ++ image_grid_thw, "image grid_thw") + -+ return output; -+} ++ if not isinstance(image_embeds, torch.Tensor): ++ raise ValueError("Incorrect type of image embeddings. 
" ++ f"Got type: {type(image_embeds)}") ++ return DotsOCRImageEmbeddingInputs(type="image_embeds", ++ image_embeds=image_embeds, ++ image_grid_thw=image_grid_thw) + ++ def _process_image_input( ++ self, image_input: DotsOCRImageInputs) -> tuple[torch.Tensor, ...]: ++ grid_thw = image_input["image_grid_thw"] ++ assert grid_thw.ndim == 2 ++ grid_thw_list = grid_thw.tolist() + -+torch::Tensor fused_moe_forward( -+ torch::Tensor input, -+ torch::Tensor indexs, -+ torch::Tensor qweights1_attr, -+ torch::Tensor qweights2_attr, -+ int64_t hidden_size, -+ int64_t intermediate_size, -+ int64_t qtype -+) { -+ auto [gmm_func] = [&] () { -+ switch (qtype) { -+ case GGML_TYPE_Q4_0: -+ return dispatch_moe_forward(input.scalar_type()); -+ case GGML_TYPE_Q4_0_WOQ: -+ return dispatch_moe_forward(input.scalar_type()); -+ case GGML_TYPE_FP8E5: -+ return dispatch_moe_forward(input.scalar_type()); -+ default: throw std::runtime_error("unsupported qtype: " + std::to_string(qtype)); -+ } -+ } (); ++ if image_input["type"] == "image_embeds": ++ image_embeds = image_input["image_embeds"].type( ++ self.vision_tower.dtype) ++ else: ++ pixel_values = image_input["pixel_values"].type( ++ self.vision_tower.dtype) ++ image_embeds = self.vision_tower( ++ pixel_values, grid_thw)[:, :self.config.hidden_size] + -+ int64_t num_tokens = indexs.numel(); ++ # Split concatenated embeddings for each image item. ++ merge_size = self.vision_tower.spatial_merge_size ++ sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // ++ (merge_size * merge_size)).tolist() + -+ torch::Tensor w1_output = torch::zeros({num_tokens, intermediate_size * 2}, -+ torch::device(input.device()).dtype(input.dtype())); -+ -+ torch::Tensor tmp = torch::zeros({num_tokens, intermediate_size}, -+ torch::device(input.device()).dtype(input.dtype())); -+ -+ torch::Tensor w2_output = torch::zeros({num_tokens, hidden_size}, -+ torch::device(input.device()).dtype(input.dtype())); ++ return image_embeds.split(sizes) + -+ gmm_func( -+ input.data_ptr(), indexs.data_ptr(), -+ qweights1_attr.data_ptr(), w1_output.data_ptr(), -+ num_tokens, hidden_size, intermediate_size * 2, input.device() -+ ); ++ def get_language_model(self) -> torch.nn.Module: ++ return self.language_model + -+ _silu_and_mul(tmp, w1_output); ++ def get_multimodal_embeddings( ++ self, **kwargs: object) -> Optional[MultiModalEmbeddings]: ++ image_input = self._parse_and_validate_image_input(**kwargs) ++ if image_input is None: ++ return [] ++ vision_embeddings = self._process_image_input(image_input) ++ return vision_embeddings + -+ gmm_func( -+ tmp.data_ptr(), indexs.data_ptr(), -+ qweights2_attr.data_ptr(), w2_output.data_ptr(), -+ num_tokens, intermediate_size, hidden_size, input.device() -+ ); ++ def get_input_embeddings( ++ self, ++ input_ids: torch.Tensor, ++ multimodal_embeddings: Optional[MultiModalEmbeddings] = None, ++ ) -> torch.Tensor: ++ inputs_embeds = self.language_model.get_input_embeddings(input_ids) ++ if multimodal_embeddings is not None: ++ inputs_embeds = merge_multimodal_embeddings( ++ input_ids, ++ inputs_embeds, ++ multimodal_embeddings, ++ self.config.image_token_id, ++ ) + -+ return w2_output; -+} -diff --git a/csrc/xpu/gemm_kernels_xpu.cpp b/csrc/xpu/gemm_kernels_xpu.cpp -new file mode 100644 -index 000000000..d96aa5880 ---- /dev/null -+++ b/csrc/xpu/gemm_kernels_xpu.cpp -@@ -0,0 +1,125 @@ -+/* -+Adapted from https://github.com/mit-han-lab/llm-awq -+@article{lin2023awq, -+ title={AWQ: Activation-aware Weight Quantization for LLM Compression and -+Acceleration}, 
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, -+Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} -+} -+ */ -+ -+#include -+#include -+#include -+//#include -+#include "dequantize.h" -+#include "utils.h" -+#include "xpu_types.h" -+ -+void awq_dequantize_impl( -+ int* __restrict__ input, -+ sycl::half* __restrict__ scaling_factors, -+ int* __restrict__ zeros, -+ sycl::half* __restrict__ output, -+ int G, -+ sycl::nd_item<3> item_ct1) { -+ int j_factors1 = 4; -+ int row_stride2 = 4; -+ int split_k_iters = 1; -+ sycl::half2 ZERO_HALF2{0, 0}; -+ sycl::half input_shared[8]; -+ -+ int N = item_ct1.get_local_range(2) * item_ct1.get_group_range(2); -+ int col = item_ct1.get_group(2) * item_ct1.get_local_range(2) + -+ item_ct1.get_local_id(2); -+ int row = item_ct1.get_group(1) * item_ct1.get_local_range(1) + -+ item_ct1.get_local_id(1); -+ int index1 = 8 * col + 8 * row * N; -+ sycl::half* output_ptr2 = output + index1; -+ -+ int index2 = col + row * N; -+ int* input_ptr2 = input + index2; -+ -+ int index3 = col + (int)(row / G) * N; -+ int* zeros_ptr2 = zeros + index3; -+ int index4 = 8 * col + (int)(row / G) * N * 8; -+ sycl::half* scale_loaded = scaling_factors + index4; -+ -+ uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2); -+ sycl::uint4 zero_loaded_u4 = vllm::awq::dequantize_s4_to_fp16x2(zeros_loaded); -+ // sycl::uint4 scale_loaded_u4 = *(sycl::uint4*)(scaling_factors_ptr2); -+ // int j = 0; -+ -+ uint32_t input_loaded = *(uint32_t*)(input_ptr2); -+ sycl::uint4 input_loaded_fp16 = -+ vllm::awq::dequantize_s4_to_fp16x2(input_loaded); -+ -+ sycl::half2* input_loaded_h2 = (sycl::half2*)(&input_loaded_fp16); -+ sycl::half2* zero_loaded_h2 = (sycl::half2*)(&zero_loaded_u4); -+ sycl::half2* scale_loaded_h2 = (sycl::half2*)scale_loaded; -+ for (int i = 0; i < 4; i++) { -+ input_loaded_h2[i] = sycl_half_sub2(input_loaded_h2[i], zero_loaded_h2[i]); -+ input_loaded_h2[i] = -+ sycl_half_fma2(input_loaded_h2[i], scale_loaded_h2[i], ZERO_HALF2); -+ } -+ *(sycl::uint4*)(input_shared) = input_loaded_fp16; -+ -+ for (int i = 0; i < 8; ++i) { -+ *(output_ptr2 + i) = input_shared[i]; -+ } -+} ++ return inputs_embeds + -+torch::Tensor awq_dequantize( -+ torch::Tensor _kernel, -+ torch::Tensor _scaling_factors, -+ torch::Tensor _zeros, -+ int split_k_iters, -+ int thx, -+ int thy) { -+ int in_c = _kernel.size(0); -+ int qout_c = _kernel.size(1); -+ int out_c = qout_c * 8; -+ int G = in_c / _scaling_factors.size(0); -+ -+ int x_thread = thx; -+ int y_thread = thy; -+ -+ int x_blocks = 1; -+ int y_blocks = 1; -+ if (thx == 0) { -+ x_thread = qout_c; -+ } -+ if (thy == 0) { -+ y_thread = in_c; -+ } -+ if (thx == 0 && thy == 0) { -+ x_thread = 8; -+ y_thread = 8; -+ x_blocks = (int)(qout_c / 8); -+ y_blocks = (int)(in_c / 8); -+ } -+ -+ auto options = torch::TensorOptions() -+ .dtype(_scaling_factors.dtype()) -+ .device(_scaling_factors.device()); -+ at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); -+ auto kernel = reinterpret_cast(_kernel.data_ptr()); -+ auto de_kernel = -+ reinterpret_cast(_de_kernel.data_ptr()); -+ auto scaling_factors = -+ reinterpret_cast(_scaling_factors.data_ptr()); -+ auto zeros = reinterpret_cast(_zeros.data_ptr()); -+ -+ sycl::range<3> num_blocks(1, y_blocks, x_blocks); -+ sycl::range<3> threads_per_block(1, y_thread, x_thread); -+ auto& queue = vllm::xpu::vllmGetQueue(); -+ -+ queue.submit([&](sycl::handler& cgh) { -+ cgh.parallel_for( -+ sycl::nd_range<3>(num_blocks * threads_per_block, threads_per_block), -+ [=](sycl::nd_item<3> 
item_ct1) { -+ awq_dequantize_impl( -+ kernel, scaling_factors, zeros, de_kernel, G, item_ct1); -+ }); -+ }); -+ return _de_kernel; -+} -\ No newline at end of file -diff --git a/csrc/xpu/kv.h b/csrc/xpu/kv.h -new file mode 100644 -index 000000000..9616ad7ef ---- /dev/null -+++ b/csrc/xpu/kv.h -@@ -0,0 +1,76 @@ -+#pragma once -+ -+#include -+#include -+ -+using fp16 = sycl::half; -+ -+constexpr uint8_t FP16_EXP_OFFSET = 15; -+constexpr uint8_t K_EXP_OFFSET = 9; -+constexpr uint8_t V_EXP_OFFSET = 12; -+constexpr uint8_t K_OFFSET = (FP16_EXP_OFFSET - K_EXP_OFFSET) << 3; -+constexpr uint8_t V_OFFSET = (FP16_EXP_OFFSET - V_EXP_OFFSET) << 3; -+constexpr uint16_t K_MAX = -+ (uint16_t)0x3FC0 + ((uint16_t)(FP16_EXP_OFFSET - K_EXP_OFFSET) << 10); -+constexpr uint16_t K_MIN = -+ (uint16_t)0x0040 + ((uint16_t)(FP16_EXP_OFFSET - K_EXP_OFFSET) << 10); -+constexpr uint16_t V_MAX = -+ (uint16_t)0x3FC0 + ((uint16_t)(FP16_EXP_OFFSET - V_EXP_OFFSET) << 10); -+constexpr uint16_t V_MIN = -+ (uint16_t)0x0040 + ((uint16_t)(FP16_EXP_OFFSET - V_EXP_OFFSET) << 10); -+ -+template -+ESIMD_INLINE __ESIMD_NS::simd quantize_key_row( -+ __ESIMD_NS::simd key_row) { -+ const __ESIMD_NS::simd kmax = sycl::bit_cast(K_MAX); -+ const __ESIMD_NS::simd kmin = sycl::bit_cast(K_MIN); -+ __ESIMD_NS::simd key = -+ __ESIMD_NS::max(__ESIMD_NS::min(__ESIMD_NS::abs(key_row), kmax), kmin); -+ key.template bit_cast_view() <<= 1; -+ __ESIMD_NS::simd sign = -+ key_row.template bit_cast_view().template select(1) & -+ (uint8_t)0x80; -+ return (key.template bit_cast_view().template select(1) - -+ K_OFFSET) | -+ sign; -+} ++ def forward( ++ self, ++ input_ids: Optional[torch.Tensor], ++ positions: torch.Tensor, ++ intermediate_tensors: Optional[IntermediateTensors] = None, ++ inputs_embeds: Optional[torch.Tensor] = None, ++ **kwargs, ++ ) -> Union[torch.Tensor, IntermediateTensors]: ++ if intermediate_tensors is not None: ++ inputs_embeds = None ++ elif inputs_embeds is None and kwargs.get("pixel_values") is not None: ++ image_input = self._parse_and_validate_image_input(**kwargs) ++ if image_input is None: ++ inputs_embeds = None ++ else: ++ assert input_ids is not None ++ inputs_embeds = self.get_multimodal_embeddings( ++ input_ids, ++ image_input=image_input, ++ ) ++ input_ids = None + -+template -+ESIMD_INLINE __ESIMD_NS::simd quantize_value_row( -+ __ESIMD_NS::simd value_row) { -+ const __ESIMD_NS::simd vmax = sycl::bit_cast(V_MAX); -+ const __ESIMD_NS::simd vmin = sycl::bit_cast(V_MIN); -+ __ESIMD_NS::simd value = -+ __ESIMD_NS::max(__ESIMD_NS::min(__ESIMD_NS::abs(value_row), vmax), vmin); -+ value.template bit_cast_view() <<= 1; -+ __ESIMD_NS::simd sign = -+ value_row.template bit_cast_view().template select(1) & -+ (uint8_t)0x80; -+ return (value.template bit_cast_view().template select(1) - -+ V_OFFSET) | -+ sign; -+} ++ hidden_states = self.language_model( ++ input_ids=input_ids, ++ positions=positions, ++ intermediate_tensors=intermediate_tensors, ++ inputs_embeds=inputs_embeds, ++ ) + -+template -+ESIMD_INLINE __ESIMD_NS::simd dequantize_key_row( -+ const __ESIMD_NS::simd& key_row) { -+ __ESIMD_NS::simd result = 0x80; -+ result.template bit_cast_view().template select(1) = -+ (key_row & (uint8_t)0x7F) + K_OFFSET; -+ result >>= 1; -+ __ESIMD_NS::simd sign = key_row & (uint8_t)0x80; -+ result.template bit_cast_view().template select(1) |= sign; -+ return result.template bit_cast_view(); -+} ++ return hidden_states + -+template -+ESIMD_INLINE __ESIMD_NS::simd dequantize_value_row( -+ const __ESIMD_NS::simd& value_row) { -+ 
__ESIMD_NS::simd result = 0x80; -+ result.template bit_cast_view().template select(1) = -+ (value_row & (uint8_t)0x7F) + V_OFFSET; -+ result >>= 1; -+ __ESIMD_NS::simd sign = value_row & (uint8_t)0x80; -+ result.template bit_cast_view().template select(1) |= sign; -+ return result.template bit_cast_view(); -+} -\ No newline at end of file -diff --git a/csrc/xpu/layernorm_xpu.cpp b/csrc/xpu/layernorm_xpu.cpp -new file mode 100644 -index 000000000..9a6a2af0a ---- /dev/null -+++ b/csrc/xpu/layernorm_xpu.cpp -@@ -0,0 +1,188 @@ -+// clang-format off -+#ifdef VLLM_DEV -+#undef __SYCL_DEVICE_ONLY__ -+#endif -+#include -+#include -+ -+#include -+#include -+#include "utils.h" -+#include "xpu_types.h" -+#include "reduction_utils.h" -+ -+namespace vllm { -+ -+template -+void rms_norm_kernel( -+ scalar_t* __restrict__ out, // [..., hidden_size] -+ const scalar_t* __restrict__ input, // [..., hidden_size] -+ const scalar_t* __restrict__ weight, // [hidden_size] -+ const float epsilon, -+ const int num_tokens, -+ const int hidden_size, -+ const sycl::nd_item<3>& item_ct1, -+ float* s_variance, -+ float* shared_vals) { -+ float variance = 0.0f; -+ -+ for (int idx = item_ct1.get_local_id(2); idx < hidden_size; -+ idx += item_ct1.get_local_range(2)) { -+ const float x = (float)input[item_ct1.get_group(2) * hidden_size + idx]; -+ variance += x * x; -+ } -+ -+ variance = blockReduceSum(variance, item_ct1, shared_vals); -+ if (item_ct1.get_local_id(2) == 0) { -+ *s_variance = sycl::rsqrt(variance / hidden_size + epsilon); -+ } -+ -+ // item_ct1.barrier(); -+ item_ct1.barrier(sycl::access::fence_space::local_space); -+ -+ for (int idx = item_ct1.get_local_id(2); idx < hidden_size; -+ idx += item_ct1.get_local_range(2)) { -+ float x = (float)input[item_ct1.get_group(2) * hidden_size + idx]; -+ out[item_ct1.get_group(2) * hidden_size + idx] = -+ ((scalar_t)(x * (*s_variance))) * weight[idx]; -+ } -+} ++ # def compute_logits( ++ # self, ++ # hidden_states: torch.Tensor, ++ # ) -> Optional[torch.Tensor]: ++ # return self.language_model.compute_logits(hidden_states) + -+template -+void call_rms_norm_kernel( -+ torch::Tensor& out, -+ torch::Tensor& input, -+ torch::Tensor& weight, -+ float epsilon) { -+ using sycl_t = vllm::xpu::SyclTypeTrait::Type; -+ int hidden_size = input.size(-1); -+ int num_tokens = input.numel() / hidden_size; -+ auto out_ptr = out.data_ptr(); -+ auto input_ptr = input.data_ptr(); -+ auto weight_ptr = weight.data_ptr(); -+ sycl::range<3> grid(1, 1, num_tokens); -+ sycl::range<3> block(1, 1, std::min(hidden_size, 1024)); -+ auto& queue = vllm::xpu::vllmGetQueue(); -+ queue.submit([&](sycl::handler& cgh) { -+ sycl::local_accessor shared_vals( sycl::range<1>(32), cgh); -+ sycl::local_accessor s_variance( sycl::range<1>(1), cgh); -+ cgh.parallel_for( -+ sycl::nd_range<3>(grid * block, block), -+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { -+ rms_norm_kernel( -+ (sycl_t*)out_ptr, -+ (const sycl_t*)input_ptr, -+ (const sycl_t*)weight_ptr, -+ epsilon, -+ num_tokens, -+ hidden_size, -+ item_ct1, -+ s_variance.get_pointer(), -+ shared_vals.get_pointer()); -+ }); -+ }); -+} ++ from vllm.v1.sample.metadata import SamplingMetadata ++ def compute_logits( ++ self, ++ hidden_states: torch.Tensor, ++ sampling_metadata: Optional[SamplingMetadata] = None, ++ ) -> Optional[torch.Tensor]: ++ return self.language_model.compute_logits(hidden_states, sampling_metadata) + ++ def load_weights(self, weights: Iterable[tuple[str, ++ torch.Tensor]]) -> set[str]: ++ loader = AutoWeightsLoader(self) 
++ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) +diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py +index 97aace5a2..bcff65a71 100644 +--- a/vllm/model_executor/models/ernie45_vl.py ++++ b/vllm/model_executor/models/ernie45_vl.py +@@ -34,6 +34,7 @@ import torch.nn.functional as F + from einops import rearrange, repeat + from transformers import BatchFeature + ++from vllm.attention.layer import check_upstream_fa_availability + from vllm.config import VllmConfig + from vllm.distributed import parallel_state + from vllm.distributed import utils as dist_utils +@@ -170,7 +171,16 @@ class Ernie4_5_VisionAttention(nn.Module): + prefix=f"{prefix}.proj") + + # Detect attention implementation. +- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) ++ self.attn_backend = get_vit_attn_backend( ++ head_size=self.hidden_size_per_attention_head, ++ dtype=torch.get_default_dtype()) + -+template -+void fused_add_rms_norm_kernel( -+ scalar_t* __restrict__ input, // [..., hidden_size] -+ scalar_t* __restrict__ residual, // [..., hidden_size] -+ const scalar_t* __restrict__ weight, // [hidden_size] -+ const float epsilon, -+ const int num_tokens, -+ const int hidden_size, -+ const sycl::nd_item<3>& item_ct1, -+ float* s_variance, -+ float* shared_vals) { -+ float variance = 0.0f; -+ -+ for (int idx = item_ct1.get_local_id(2); idx < hidden_size; -+ idx += item_ct1.get_local_range(2)) { -+ float x = (float)input[item_ct1.get_group(2) * hidden_size + idx]; -+ x+=(float)residual[item_ct1.get_group(2) * hidden_size + idx]; -+ variance += x * x; -+ residual[item_ct1.get_group(2) * hidden_size + idx] = (scalar_t)x; -+ } -+ -+ variance = blockReduceSum(variance, item_ct1, shared_vals); -+ if (item_ct1.get_local_id(2) == 0) { -+ *s_variance = sycl::rsqrt(variance / hidden_size + epsilon); -+ } -+ -+ // item_ct1.barrier(); -+ item_ct1.barrier(sycl::access::fence_space::local_space); -+ -+ for (int idx = item_ct1.get_local_id(2); idx < hidden_size; -+ idx += item_ct1.get_local_range(2)) { -+ float x = (float)residual[item_ct1.get_group(2) * hidden_size + idx]; -+ input[item_ct1.get_group(2) * hidden_size + idx] = -+ ((scalar_t)(x * (*s_variance))) * weight[idx]; -+ } -+} ++ self.use_upstream_fa = False ++ if self.attn_backend != _Backend.FLASH_ATTN and \ ++ check_upstream_fa_availability(torch.get_default_dtype()): ++ self.attn_backend = _Backend.FLASH_ATTN ++ self.use_upstream_fa = True + -+template -+void call_fused_add_rms_norm_kernel( -+ torch::Tensor& input, -+ torch::Tensor& residual, -+ torch::Tensor& weight, -+ float epsilon){ -+ using sycl_t = vllm::xpu::SyclTypeTrait::Type; -+ int hidden_size = input.size(-1); -+ int num_tokens = input.numel() / hidden_size; -+ auto input_ptr = input.data_ptr(); -+ auto residual_ptr = residual.data_ptr(); -+ auto weight_ptr = weight.data_ptr(); -+ sycl::range<3> grid(1, 1, num_tokens); -+ sycl::range<3> block(1, 1, std::min(hidden_size, 1024)); -+ auto& queue = vllm::xpu::vllmGetQueue(); -+ queue.submit([&](sycl::handler& cgh) { -+ sycl::local_accessor shared_vals( sycl::range<1>(32), cgh); -+ sycl::local_accessor s_variance( sycl::range<1>(1), cgh); -+ cgh.parallel_for( -+ sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1)[[intel::reqd_sub_group_size(32)]] { -+ fused_add_rms_norm_kernel( -+ (sycl_t*)input_ptr, -+ (sycl_t*)residual_ptr, -+ (const sycl_t*)weight_ptr, -+ epsilon, -+ num_tokens, -+ hidden_size, -+ item_ct1, -+ s_variance.get_pointer(), -+ 
shared_vals.get_pointer()); -+ }); -+ }); -+} + if self.attn_backend not in { + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, + _Backend.ROCM_AITER_FA +@@ -233,7 +243,10 @@ class Ernie4_5_VisionAttention(nn.Module): + if self.attn_backend == _Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + else: +- from flash_attn import flash_attn_varlen_func ++ if self.use_upstream_fa: ++ from flash_attn import flash_attn_varlen_func ++ else: ++ from vllm.vllm_flash_attn import flash_attn_varlen_func + + q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + +@@ -457,7 +470,11 @@ class Ernie4_5_VisionTransformer(nn.Module): + ), "vit's config.hidden must be equal to config.embed_dim" + self.ln = nn.LayerNorm(hidden_size, eps=1e-6) + +- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) ++ self.attn_backend = get_vit_attn_backend( ++ head_size=head_dim, dtype=torch.get_default_dtype()) ++ if self.attn_backend != _Backend.FLASH_ATTN and \ ++ check_upstream_fa_availability(torch.get_default_dtype()): ++ self.attn_backend = _Backend.FLASH_ATTN + + @property + def dtype(self) -> torch.dtype: +diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py +index 539381b61..279f458df 100644 +--- a/vllm/model_executor/models/glm4_1v.py ++++ b/vllm/model_executor/models/glm4_1v.py +@@ -44,6 +44,7 @@ from transformers.models.glm4v.video_processing_glm4v import ( + Glm4vVideoProcessor) + from transformers.video_utils import VideoMetadata + ++from vllm.attention.layer import check_upstream_fa_availability + from vllm.config import VllmConfig + from vllm.distributed import (get_tensor_model_parallel_world_size, + parallel_state) +@@ -260,7 +261,15 @@ class Glm4vVisionAttention(nn.Module): + ) + + # Detect attention implementation. +- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) ++ self.attn_backend = get_vit_attn_backend( ++ head_size=self.hidden_size_per_attention_head, ++ dtype=torch.get_default_dtype()) ++ self.use_upstream_fa = False ++ if self.attn_backend != _Backend.FLASH_ATTN and \ ++ check_upstream_fa_availability(torch.get_default_dtype()): ++ self.attn_backend = _Backend.FLASH_ATTN ++ self.use_upstream_fa = True + -+} // namespace vllm -+ -+void rms_norm( -+ torch::Tensor& out, -+ torch::Tensor& input, -+ torch::Tensor& weight, -+ float epsilon) { -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ input.scalar_type(), "call_rms_norm_kernel", [&] { -+ vllm::call_rms_norm_kernel(out, input, weight, epsilon); -+ }); -+} + if self.attn_backend not in { + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, +@@ -310,7 +319,10 @@ class Glm4vVisionAttention(nn.Module): + if self.attn_backend == _Backend.FLASH_ATTN: + # from vllm_flash_attn.flash_attn_interface import ( + # flash_attn_varlen_func) +- from flash_attn import flash_attn_varlen_func ++ if self.use_upstream_fa: ++ from flash_attn import flash_attn_varlen_func ++ else: ++ from vllm.vllm_flash_attn import flash_attn_varlen_func + + q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) + +@@ -715,7 +727,11 @@ class Glm4vVisionTransformer(nn.Module): + self.post_layernorm = RMSNorm(vision_config.hidden_size, + eps=vision_config.rms_norm_eps) + +- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) ++ self.attn_backend = get_vit_attn_backend( ++ head_size=head_dim, dtype=torch.get_default_dtype()) ++ if self.attn_backend != _Backend.FLASH_ATTN and \ ++ check_upstream_fa_availability(torch.get_default_dtype()): ++ self.attn_backend = _Backend.FLASH_ATTN + + @property + def dtype(self) -> torch.dtype: +diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py +index e0b4df772..d85d30d91 100644 +--- a/vllm/model_executor/models/gpt_oss.py ++++ b/vllm/model_executor/models/gpt_oss.py +@@ -311,9 +311,6 @@ class GptOssModel(nn.Module): + if is_pp_missing_parameter(name, self): + continue + +- # FIXME(woosuk): Remove this after testing. +- weight = weight.cuda() +- + if ".w13_weight_scale" in name: + # Handle MLP gate and up projection weights scale + if use_ep: +diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py +index 710b805ac..04824db1b 100644 +--- a/vllm/model_executor/models/keye.py ++++ b/vllm/model_executor/models/keye.py +@@ -17,6 +17,7 @@ from transformers.modeling_outputs import (BaseModelOutput, + BaseModelOutputWithPooling) + from transformers.utils import torch_int + ++from vllm.attention.layer import check_upstream_fa_availability + from vllm.config import VllmConfig + from vllm.distributed import get_tensor_model_parallel_world_size + from vllm.logger import init_logger +@@ -374,7 +375,16 @@ class KeyeSiglipAttention(nn.Module): + ) + + # Detect attention implementation. +- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) ++ self.attn_backend = get_vit_attn_backend( ++ head_size=self.head_dim, dtype=torch.get_default_dtype()) + -+void fused_add_rms_norm( -+ torch::Tensor& input, -+ torch::Tensor& residual, -+ torch::Tensor& weight, -+ float epsilon) { -+ int hidden_size = input.size(-1); -+ int num_tokens = input.numel() / hidden_size; -+ -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ input.scalar_type(), "call_fused_add_rms_norm_kernel", [&] { -+ vllm::call_fused_add_rms_norm_kernel( -+ input, -+ residual, -+ weight, -+ epsilon); -+ }); -+} ++ self.use_upstream_fa = False ++ if self.attn_backend != _Backend.FLASH_ATTN and \ ++ check_upstream_fa_availability( ++ torch.get_default_dtype()): ++ self.attn_backend = _Backend.FLASH_ATTN ++ self.use_upstream_fa = True + -diff --git a/csrc/xpu/pos_encoding_xpu.cpp b/csrc/xpu/pos_encoding_xpu.cpp -new file mode 100644 -index 000000000..3232cacbc ---- /dev/null -+++ b/csrc/xpu/pos_encoding_xpu.cpp -@@ -0,0 +1,333 @@ -+// clang-format off -+#ifdef VLLM_DEV -+#undef __SYCL_DEVICE_ONLY__ -+#endif -+#include -+// clang-format on -+#include "xpu_types.h" -+ -+#include -+#include "utils.h" -+ -+template -+inline void apply_rotary_embedding( -+ scalar_t* __restrict__ arr, -+ const scalar_t* __restrict__ cos_ptr, -+ const scalar_t* __restrict__ sin_ptr, -+ int rot_offset, -+ int embed_dim) { -+ int x_index, y_index; -+ scalar_t cos, sin; -+ if (IS_NEOX) { -+ // GPT-NeoX style rotary embedding. -+ x_index = rot_offset; -+ y_index = embed_dim + rot_offset; -+ cos = VLLM_LDG(cos_ptr + x_index); -+ sin = VLLM_LDG(sin_ptr + x_index); -+ } else { -+ // GPT-J style rotary embedding. 
-+ x_index = 2 * rot_offset; -+ y_index = 2 * rot_offset + 1; -+ cos = VLLM_LDG(cos_ptr + x_index / 2); -+ sin = VLLM_LDG(sin_ptr + x_index / 2); -+ } -+ -+ const scalar_t x = arr[x_index]; -+ const scalar_t y = arr[y_index]; -+ arr[x_index] = x * cos - y * sin; -+ arr[y_index] = y * cos + x * sin; -+} -+ -+template -+void rotary_embedding_kernel( -+ const int64_t* __restrict__ positions, // [batch_size, seq_len] or -+ // [num_tokens] -+ scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] -+ // or [num_tokens, num_heads, head_size] -+ scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, -+ // head_size] or [num_tokens, num_kv_heads, -+ // head_size] -+ const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // -+ // 2] -+ const int rot_dim, -+ const int query_stride, -+ const int key_stride, -+ const int num_heads, -+ const int num_kv_heads, -+ const int head_size, -+ const sycl::nd_item<3>& item_ct1) { -+ // Each thread block is responsible for one token. -+ const int token_idx = item_ct1.get_group(2); -+ int64_t pos = positions[token_idx]; -+ const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; -+ -+ const int embed_dim = rot_dim / 2; -+ const scalar_t* cos_ptr = cache_ptr; -+ const scalar_t* sin_ptr = cache_ptr + embed_dim; -+ -+ const int nq = num_heads * embed_dim; -+ for (int i = item_ct1.get_local_id(2); i < nq; -+ i += item_ct1.get_local_range(2)) { -+ const int head_idx = i / embed_dim; -+ const int token_head = token_idx * query_stride + head_idx * head_size; -+ const int rot_offset = i % embed_dim; -+ apply_rotary_embedding( -+ query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); -+ } -+ -+ const int nk = num_kv_heads * embed_dim; -+ for (int i = item_ct1.get_local_id(2); i < nk; -+ i += item_ct1.get_local_range(2)) { -+ const int head_idx = i / embed_dim; -+ const int token_head = token_idx * key_stride + head_idx * head_size; -+ const int rot_offset = i % embed_dim; -+ apply_rotary_embedding( -+ key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); -+ } -+} -+ -+template -+void batched_rotary_embedding_kernel( -+ const int64_t* __restrict__ positions, // [batch_size, seq_len] or -+ // [num_tokens] -+ scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] -+ // or [num_tokens, num_heads, head_size] -+ scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, -+ // head_size] or [num_tokens, num_kv_heads, -+ // head_size] -+ const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // -+ // 2] -+ const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens] -+ const int rot_dim, -+ const int query_stride, -+ const int key_stride, -+ const int num_heads, -+ const int num_kv_heads, -+ const int head_size, -+ const sycl::nd_item<3>& item_ct1) { -+ // Each thread block is responsible for one token. 
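The removed kernel above applies rotary embeddings in either GPT-NeoX style (rotate the head's two halves against each other) or GPT-J style (rotate adjacent even/odd pairs); the per-token loop below repeats the same rotation for every query and key head. A torch reference of the two pairings, assuming cos/sin have already been gathered for the token's position:

    import torch

    def rotary_reference(x, cos, sin, is_neox: bool):
        # x: [..., rot_dim]; cos/sin: [..., rot_dim // 2]
        if is_neox:
            x1, x2 = x.chunk(2, dim=-1)          # pair i with i + rot_dim // 2
        else:
            x1, x2 = x[..., 0::2], x[..., 1::2]  # pair 2i with 2i + 1
        o1 = x1 * cos - x2 * sin
        o2 = x2 * cos + x1 * sin
        if is_neox:
            return torch.cat([o1, o2], dim=-1)
        return torch.stack([o1, o2], dim=-1).flatten(-2)

    x, cos, sin = torch.randn(2, 8), torch.randn(2, 4), torch.randn(2, 4)
    assert rotary_reference(x, cos, sin, True).shape == (2, 8)
    assert rotary_reference(x, cos, sin, False).shape == (2, 8)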
-+ const int token_idx = item_ct1.get_group(2); -+ int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; -+ int64_t pos = positions[token_idx]; -+ const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; -+ -+ const int embed_dim = rot_dim / 2; -+ const scalar_t* cos_ptr = cache_ptr; -+ const scalar_t* sin_ptr = cache_ptr + embed_dim; -+ -+ const int nq = num_heads * embed_dim; -+ for (int i = item_ct1.get_local_id(2); i < nq; -+ i += item_ct1.get_local_range(2)) { -+ const int head_idx = i / embed_dim; -+ const int token_head = token_idx * query_stride + head_idx * head_size; -+ const int rot_offset = i % embed_dim; -+ apply_rotary_embedding( -+ query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); -+ } -+ -+ const int nk = num_kv_heads * embed_dim; -+ for (int i = item_ct1.get_local_id(2); i < nk; -+ i += item_ct1.get_local_range(2)) { -+ const int head_idx = i / embed_dim; -+ const int token_head = token_idx * key_stride + head_idx * head_size; -+ const int rot_offset = i % embed_dim; -+ apply_rotary_embedding( -+ key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); -+ } -+} -+ -+template -+void call_rotary_embedding_kernel( -+ const int64_t* __restrict__ positions, // [num_tokens] -+ scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size] -+ scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size] -+ const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // -+ // 2] -+ const int rot_dim, -+ const int query_stride, -+ const int key_stride, -+ const int num_heads, -+ const int num_kv_heads, -+ const int head_size, -+ const int num_tokens, -+ const int sin_cos_dim, -+ bool is_neox) { -+ using sycl_t = vllm::xpu::SyclTypeTrait::Type; -+ sycl::range<3> grid(1, 1, num_tokens); -+ sycl::range<3> block(1, 1, std::min(num_heads * rot_dim / 2, 512)); -+ auto& queue = vllm::xpu::vllmGetQueue(); -+ if (is_neox) { -+ queue.submit([&](sycl::handler& cgh) { -+ cgh.parallel_for( -+ sycl::nd_range<3>(grid * block, block), -+ [=](sycl::nd_item<3> item_ct1) { -+ rotary_embedding_kernel( -+ positions, -+ (sycl_t* __restrict__)query, -+ (sycl_t* __restrict__)key, -+ (const sycl_t* __restrict__)cos_sin_cache, -+ rot_dim, -+ query_stride, -+ key_stride, -+ num_heads, -+ num_kv_heads, -+ head_size, -+ item_ct1); -+ }); -+ }); -+ } else { -+ queue.submit([&](sycl::handler& cgh) { -+ cgh.parallel_for( -+ sycl::nd_range<3>(grid * block, block), -+ [=](sycl::nd_item<3> item_ct1) { -+ rotary_embedding_kernel( -+ positions, -+ (sycl_t* __restrict__)query, -+ (sycl_t* __restrict__)key, -+ (const sycl_t* __restrict__)cos_sin_cache, -+ rot_dim, -+ query_stride, -+ key_stride, -+ num_heads, -+ num_kv_heads, -+ head_size, -+ item_ct1); -+ }); -+ }); -+ } -+} -+ -+template -+void call_batched_rotary_embedding_kernel( -+ const int64_t* __restrict__ positions, // [num_tokens] -+ scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size] -+ scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size] -+ const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // -+ // 2] -+ const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens] -+ const int rot_dim, -+ const int query_stride, -+ const int key_stride, -+ const int num_heads, -+ const int num_kv_heads, -+ const int head_size, -+ const int num_tokens, -+ const int sin_cos_dim, -+ bool is_neox) { -+ using sycl_t = vllm::xpu::SyclTypeTrait::Type; -+ sycl::range<3> grid(1, 1, num_tokens); -+ sycl::range<3> block(1, 1, 
std::min(num_heads * rot_dim / 2, 512)); -+ auto& queue = vllm::xpu::vllmGetQueue(); -+ if (is_neox) { -+ queue.submit([&](sycl::handler& cgh) { -+ cgh.parallel_for( -+ sycl::nd_range<3>(grid * block, block), -+ [=](sycl::nd_item<3> item_ct1) { -+ batched_rotary_embedding_kernel( -+ positions, -+ (sycl_t* __restrict__)query, -+ (sycl_t* __restrict__)key, -+ (const sycl_t* __restrict__)cos_sin_cache, -+ cos_sin_cache_offsets, -+ rot_dim, -+ query_stride, -+ key_stride, -+ num_heads, -+ num_kv_heads, -+ head_size, -+ item_ct1); -+ }); -+ }); -+ } else { -+ queue.submit([&](sycl::handler& cgh) { -+ cgh.parallel_for( -+ sycl::nd_range<3>(grid * block, block), -+ [=](sycl::nd_item<3> item_ct1) { -+ batched_rotary_embedding_kernel( -+ positions, -+ (sycl_t* __restrict__)query, -+ (sycl_t* __restrict__)key, -+ (const sycl_t* __restrict__)cos_sin_cache, -+ cos_sin_cache_offsets, -+ rot_dim, -+ query_stride, -+ key_stride, -+ num_heads, -+ num_kv_heads, -+ head_size, -+ item_ct1); -+ }); -+ }); -+ } -+} -+ -+void rotary_embedding( -+ torch::Tensor& positions, -+ torch::Tensor& query, -+ torch::Tensor& key, -+ int head_size, -+ torch::Tensor& cos_sin_cache, -+ bool is_neox) { -+ -+ int num_tokens = query.numel() / query.size(-1); -+ int rot_dim = cos_sin_cache.size(1); -+ int num_heads = query.size(-1) / head_size; -+ int num_kv_heads = key.size(-1) / head_size; -+ int key_stride = key.stride(-2); -+ int query_stride = query.stride(-2); -+ int cos_sin_dim = cos_sin_cache.size(0); -+ -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ query.scalar_type(), "call_rotary_embedding_kernel", [&] { -+ call_rotary_embedding_kernel( -+ positions.data_ptr(), -+ query.data_ptr(), -+ key.data_ptr(), -+ cos_sin_cache.data_ptr(), -+ rot_dim, -+ query_stride, -+ key_stride, -+ num_heads, -+ num_kv_heads, -+ head_size, -+ num_tokens, -+ cos_sin_dim, -+ is_neox); -+ }); -+} -+ -+void batched_rotary_embedding( -+ torch::Tensor& positions, -+ torch::Tensor& query, -+ torch::Tensor& key, -+ int head_size, -+ torch::Tensor& cos_sin_cache, -+ bool is_neox, -+ int rot_dim, -+ torch::Tensor& cos_sin_cache_offsets) { -+ int64_t num_tokens = cos_sin_cache_offsets.size(0); -+ int num_heads = query.size(-1) / head_size; -+ int num_kv_heads = key.size(-1) / head_size; -+ int key_stride = key.stride(-2); -+ int query_stride = query.stride(-2); -+ int cos_sin_dim = cos_sin_cache.size(0); -+ -+ VLLM_XPU_DISPATCH_FLOATING_TYPES( -+ query.scalar_type(), "call_batched_rotary_embedding_kernel", [&] { -+ call_batched_rotary_embedding_kernel( -+ positions.data_ptr(), -+ query.data_ptr(), -+ key.data_ptr(), -+ cos_sin_cache.data_ptr(), -+ cos_sin_cache_offsets.data_ptr(), -+ rot_dim, -+ query_stride, -+ key_stride, -+ num_heads, -+ num_kv_heads, -+ head_size, -+ num_tokens, -+ cos_sin_dim, -+ is_neox); -+ }); -+} -\ No newline at end of file -diff --git a/csrc/xpu/pybind.cpp b/csrc/xpu/pybind.cpp -new file mode 100644 -index 000000000..bf9e94612 ---- /dev/null -+++ b/csrc/xpu/pybind.cpp -@@ -0,0 +1,112 @@ -+// #include "cache.h" -+#include "xpu_ops.h" -+#include -+ -+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { -+ // vLLM custom ops -+ pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); -+ -+ // Attention ops -+ ops.def( -+ "paged_attention_v1", -+ &paged_attention_v1, -+ "Compute the attention between an input query and the cached keys/values using PagedAttention."); -+ ops.def( -+ "paged_attention_v2", -+ &paged_attention_v2, -+ "PagedAttention V2."); -+ -+ ops.def("context_attention_forward_v1", &context_attention_forward_v1, -+ 
"Context attention forward_v1"); -+ -+ ops.def("context_attention_forward_v2", &context_attention_forward_v2, -+ "Context attention forward_v2"); -+ -+ ops.def( -+ "paged_attention_gqa", -+ &paged_attention_gqa, -+ "PagedAttention GQA."); -+ -+ ops.def("paged_attention_gqa_fp8", &paged_attention_gqa_fp8, "PagedAttention GQA fp8."); -+ -+ // Activation ops -+ ops.def( -+ "silu_and_mul", -+ &silu_and_mul, -+ "Activation function used in SwiGLU."); -+ ops.def( -+ "gelu_and_mul", -+ &gelu_and_mul, -+ "Activation function used in GeGLU with `none` approximation."); -+ ops.def( -+ "gelu_tanh_and_mul", -+ &gelu_tanh_and_mul, -+ "Activation function used in GeGLU with `tanh` approximation."); -+ ops.def( -+ "gelu_new", -+ &gelu_new, -+ "GELU implementation used in GPT-2."); -+ ops.def( -+ "gelu_fast", -+ &gelu_fast, -+ "Approximate GELU implementation."); -+ -+ // Layernorm -+ ops.def( -+ "rms_norm", -+ &rms_norm, -+ "Apply Root Mean Square (RMS) Normalization to the input tensor."); -+ -+ ops.def( -+ "fused_add_rms_norm", -+ &fused_add_rms_norm, -+ "In-place fused Add and RMS Normalization"); -+ -+ // Rotary embedding -+ ops.def( -+ "rotary_embedding", -+ &rotary_embedding, -+ "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); -+ -+ // Cache ops -+ pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); -+ cache_ops.def( -+ "swap_blocks", -+ &swap_blocks, -+ "Swap in (out) the cache blocks from src to dst"); -+ cache_ops.def( -+ "copy_blocks", -+ ©_blocks, -+ "Copy the cache blocks from src to dst"); -+ cache_ops.def( -+ "reshape_and_cache", -+ &reshape_and_cache, -+ "Reshape the key and value tensors and cache them"); -+ cache_ops.def( -+ "reshape_and_cache_ipexllm", -+ &reshape_and_cache_ipexllm, -+ "Reshape the key and value tensors and cache them for ipex_llm"); -+ -+ cache_ops.def( -+ "reshape_and_cache_ipexllm_fp8", -+ &reshape_and_cache_ipexllm_fp8, -+ "Reshape the key and value tensors and cache them for ipex_llm with fp8"); -+ -+ // Quant -+ ops.def( -+ "awq_dequantize", -+ &awq_dequantize, -+ "dequant method for awq"); -+ -+ -+ ops.def( -+ "moe_forward", -+ &moe_forward, -+ "PagedAttention GQA."); -+ -+ ops.def( -+ "fused_moe_forward", -+ &fused_moe_forward, -+ "PagedAttention GQA."); -+ -+} -diff --git a/csrc/xpu/reduction_utils.h b/csrc/xpu/reduction_utils.h -new file mode 100644 -index 000000000..93c64d759 ---- /dev/null -+++ b/csrc/xpu/reduction_utils.h -@@ -0,0 +1,56 @@ -+/* -+ * Copyright (c) 2023, The vLLM team. -+ * -+ * Licensed under the Apache License, Version 2.0 (the "License"); -+ * you may not use this file except in compliance with the License. -+ * You may obtain a copy of the License at -+ * -+ * http://www.apache.org/licenses/LICENSE-2.0 -+ * -+ * Unless required by applicable law or agreed to in writing, software -+ * distributed under the License is distributed on an "AS IS" BASIS, -+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+ * See the License for the specific language governing permissions and -+ * limitations under the License. 
-+ */ -+#pragma once -+ -+#include -+#include -+#include -+ -+namespace vllm { -+ -+template -+__inline__ T warpReduceSum(T val, const sycl::nd_item<3>& item_ct1) { -+#pragma unroll -+ for (int mask = 16; mask > 0; mask >>= 1) -+ val += dpct::permute_sub_group_by_xor( -+ item_ct1.get_sub_group(), val, mask, 32); -+ return val; -+} -+ -+/* Calculate the sum of all elements in a block */ -+template -+__inline__ T blockReduceSum(T val, const sycl::nd_item<3> &item_ct1, T *shared) { -+ -+ int lane = item_ct1.get_local_id(2) & 0x1f; -+ int wid = item_ct1.get_local_id(2) >> 5; -+ -+ val = warpReduceSum(val, item_ct1); -+ -+ if (lane == 0) { -+ shared[wid] = val; -+ } -+ item_ct1.barrier(); -+ -+ // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent -+ // blockDim.x is not divided by 32 -+ val = (item_ct1.get_local_id(2) < (item_ct1.get_local_range(2) / 32.f)) -+ ? shared[lane] -+ : (T)(0.0f); -+ val = warpReduceSum(val, item_ct1); -+ return val; -+} -+ -+} // namespace vllm -\ No newline at end of file -diff --git a/csrc/xpu/utils.cpp b/csrc/xpu/utils.cpp -new file mode 100644 -index 000000000..5f613af55 ---- /dev/null -+++ b/csrc/xpu/utils.cpp -@@ -0,0 +1,34 @@ -+#include "utils.h" -+#include -+ -+sycl::half sycl_half_mul(sycl::half a, sycl::half b) { -+ return sycl::ext::intel::math::hmul(a, b); -+} -+sycl::half sycl_half_add(sycl::half a, sycl::half b) { -+ return sycl::ext::intel::math::hadd(a, b); -+} -+sycl::half sycl_half_sub(sycl::half a, sycl::half b) { -+ return sycl::ext::intel::math::hsub(a, b); -+} -+sycl::half sycl_half_fma(sycl::half a, sycl::half b, sycl::half c) { -+ return sycl::ext::intel::math::hfma(a, b, c); -+} -+ -+sycl::half2 sycl_half_mul2(sycl::half2 a, sycl::half2 b) { -+ return sycl::ext::intel::math::hmul2(a, b); -+} -+sycl::half2 sycl_half_add2(sycl::half2 a, sycl::half2 b) { -+ return sycl::ext::intel::math::hadd2(a, b); -+} -+sycl::half2 sycl_half_sub2(sycl::half2 a, sycl::half2 b) { -+ return sycl::ext::intel::math::hsub2(a, b); -+} -+ -+sycl::half2 sycl_half_fma2(sycl::half2 a, sycl::half2 b, sycl::half2 c) { -+ return sycl::ext::intel::math::hfma2(a, b, c); -+} -+ -+int get_max_shared_memory_per_block_device_attribute(int device_id) { -+ const sycl::device& device = vllm::xpu::vllmGetQueue().get_device(); -+ return device.get_info(); -+} -diff --git a/csrc/xpu/utils.h b/csrc/xpu/utils.h -new file mode 100644 -index 000000000..fa3ead51c ---- /dev/null -+++ b/csrc/xpu/utils.h -@@ -0,0 +1,82 @@ -+#pragma once -+ -+#include -+#include -+#include -+// #include -+#include -+#include -+ -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+#include -+#endif -+ -+ -+#define VLLM_LDG(arg) *(arg) -+namespace vllm { -+namespace xpu { -+ -+static inline sycl::queue& vllmGetQueue() { -+ auto device_type = c10::DeviceType::XPU; -+ c10::impl::VirtualGuardImpl impl(device_type); -+ c10::Stream c10_stream = impl.getStream(c10::Device(device_type)); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ return at::xpu::XPUStream(c10_stream).queue(); -+#else -+ return ::xpu::get_queue_from_stream(c10_stream); -+#endif -+} -+template -+struct SyclTypeTrait{ -+ using Type = T; -+}; -+ -+template <> -+struct SyclTypeTrait{ -+ using Type = sycl::half; -+}; -+ -+template <> -+struct SyclTypeTrait{ -+ using Type = sycl::ext::oneapi::bfloat16; -+}; -+ -+ -+} // namespace xpu -+ -+} // namespace vllm -+ -+SYCL_EXTERNAL sycl::half sycl_half_mul(sycl::half a, sycl::half b); -+SYCL_EXTERNAL sycl::half sycl_half_add(sycl::half a, sycl::half b); -+SYCL_EXTERNAL 
sycl::half sycl_half_sub(sycl::half a, sycl::half b); -+SYCL_EXTERNAL sycl::half sycl_half_fma(sycl::half a, sycl::half b, sycl::half c); -+ -+SYCL_EXTERNAL sycl::half2 sycl_half_mul2(sycl::half2 a, sycl::half2 b); -+SYCL_EXTERNAL sycl::half2 sycl_half_add2(sycl::half2 a, sycl::half2 b); -+SYCL_EXTERNAL sycl::half2 sycl_half_sub2(sycl::half2 a, sycl::half2 b); -+SYCL_EXTERNAL sycl::half2 sycl_half_fma2(sycl::half2 a, sycl::half2 b, sycl::half2 c); -+ -+int get_max_shared_memory_per_block_device_attribute(int device_id); -+ -+namespace utils { -+ static inline sycl::queue& get_queue(const at::Device& device) { -+ c10::impl::VirtualGuardImpl impl(device.type()); -+ c10::Stream c10_stream = impl.getStream(c10::Device(device)); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ return at::xpu::XPUStream(c10_stream).queue(); -+#else -+ return ::xpu::get_queue_from_stream(c10_stream); -+#endif -+ } -+ -+ static inline sycl::event submit_kernel(std::function kernel, const at::Device& device, const char * desc) { -+ sycl::queue& queue = get_queue(device); -+ sycl::event event = queue.submit(kernel); -+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 -+ // xpu::profiler_record(desc, event); -+#else -+ ::xpu::profiler_record(desc, event); -+#endif -+ return event; -+ } -+} -diff --git a/csrc/xpu/xpu_ops.h b/csrc/xpu/xpu_ops.h -new file mode 100644 -index 000000000..603d4f23d ---- /dev/null -+++ b/csrc/xpu/xpu_ops.h -@@ -0,0 +1,194 @@ -+#pragma once -+#include -+ -+void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, -+ torch::Tensor &key, int head_size, -+ torch::Tensor &cos_sin_cache, bool is_neox); -+void batched_rotary_embedding( -+ torch::Tensor& positions, -+ torch::Tensor& query, -+ torch::Tensor& key, -+ int head_size, -+ torch::Tensor& cos_sin_cache, -+ bool is_neox, -+ int rot_dim, -+ torch::Tensor& cos_sin_cache_offsets); -+ -+void silu_and_mul(torch::Tensor &out, torch::Tensor &input); -+void gelu_and_mul(torch::Tensor &out, torch::Tensor &input); -+ -+void gelu_new(torch::Tensor &out, torch::Tensor &input); -+ -+void gelu_fast(torch::Tensor &out, torch::Tensor &input); -+ -+ -+void gelu_tanh_and_mul( -+ torch::Tensor& out, -+ torch::Tensor& input); -+ -+void paged_attention_v1( -+ torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, -+ torch::Tensor &value_cache, int num_kv_heads, float scale, -+ torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, -+ int max_context_len, const c10::optional &alibi_slopes, -+ const std::string& kv_cache_dtype, const float kv_scale, const float attn_logit_softcapping); -+ -+void paged_attention_v2( -+ torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, -+ torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, -+ torch::Tensor &value_cache, int num_kv_heads, float scale, -+ torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, -+ int max_context_len, const c10::optional &alibi_slopes, -+ const std::string& kv_cache_dtype, const float kv_scale, const float attn_logit_softcapping); -+ -+torch::Tensor context_attention_forward_v1( -+ torch::Tensor query, // [num_tokens, num_kv_head, head_dim] -+ torch::Tensor key, // [num_tokens, num_kv_heads * head_size] -+ torch::Tensor value, // [num_tokens, num_kv_heads * head_size] -+ torch::Tensor block_tables, torch::Tensor query_start_loc, -+ torch::Tensor seq_lens, torch::Tensor context_lens, int max_input_length, -+ int max_context_length); -+ -+torch::Tensor context_attention_forward_v2( 
-+ torch::Tensor query, // [num_tokens, num_kv_head, head_dim] -+ torch::Tensor key, // [num_tokens, num_kv_heads * head_size] -+ torch::Tensor value, // [num_tokens, num_kv_heads * head_size] -+ torch::Tensor block_tables, torch::Tensor query_start_loc, -+ torch::Tensor seq_lens, torch::Tensor context_lens, int max_input_length, -+ int max_context_length, int max_q_length); -+ -+void copy_blocks( -+ std::vector &key_caches, -+ std::vector &value_caches, -+ const std::map> &block_mapping); -+ -+void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, -+ torch::Tensor &key_cache, torch::Tensor &value_cache, -+ torch::Tensor &slot_mapping, -+ const std::string& kv_cache_dtype, const float kv_scale); -+void reshape_and_cache_ipexllm(torch::Tensor &key, torch::Tensor &value, -+ torch::Tensor &key_cache, torch::Tensor &value_cache, -+ torch::Tensor &slot_mapping, -+ const std::string& kv_cache_dtype, const float kv_scale); -+ -+void reshape_and_cache_ipexllm_fp8(torch::Tensor& key, torch::Tensor& value, -+ torch::Tensor& key_cache, -+ torch::Tensor& value_cache, -+ torch::Tensor& slot_mapping, -+ const std::string& kv_cache_dtype, -+ const float kv_scale); -+ -+void moe_align_block_size( -+ torch::Tensor topk_ids, -+ int num_experts, -+ int block_size, -+ torch::Tensor sorted_token_ids, -+ torch::Tensor experts_ids, -+ torch::Tensor num_tokens_post_pad) { -+ TORCH_CHECK(false, "moe_align_block_size is not supported on XPU."); -+} -+void swap_blocks(torch::Tensor &src, torch::Tensor &dst, -+ const std::map &block_mapping); -+ -+void gather_cached_kv(torch::Tensor &key, torch::Tensor &value, -+ torch::Tensor &key_cache, torch::Tensor &value_cache, -+ torch::Tensor &slot_mapping); -+ -+void convert_fp8_e5m2(torch::Tensor& src_cache, torch::Tensor& dst_cache) { -+ TORCH_CHECK(false, "Quantization is not supported on XPU."); -+} -+ -+void rms_norm(torch::Tensor &out, torch::Tensor &input, -+ torch::Tensor &weight, float epsilon); -+ -+void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, -+ torch::Tensor &weight, float epsilon); -+ -+torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, -+ torch::Tensor _scaling_factors, torch::Tensor _zeros, -+ int split_k_iters) { -+ TORCH_CHECK(false, "awq_gemm is not supported on XPU."); -+} -+ -+torch::Tensor marlin_gemm( -+ torch::Tensor& a, -+ torch::Tensor& b_q_weight, -+ torch::Tensor& b_scales, -+ torch::Tensor& workspace, -+ int64_t size_m, -+ int64_t size_n, -+ int64_t size_k) { -+ TORCH_CHECK(false, "marlin_gemm is not supported on XPU."); -+} -+ -+torch::Tensor awq_dequantize(torch::Tensor _kernel, -+ torch::Tensor _scaling_factors, -+ torch::Tensor _zeros, -+ int split_k_iters, -+ int thx, -+ int thy); -+ -+void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, -+ torch::Tensor mul, torch::Tensor lookup_table) { -+ TORCH_CHECK(false, "squeezellm_gemm is not supported on XPU."); -+} -+ -+torch::Tensor gptq_gemm( -+ torch::Tensor a, -+ torch::Tensor b_q_weight, -+ torch::Tensor b_gptq_qzeros, -+ torch::Tensor b_gptq_scales, -+ torch::Tensor b_g_idx, -+ bool use_exllama, -+ int bit) { -+ TORCH_CHECK(false, "gptq_gemm is not supported on XPU."); -+} -+ -+void gptq_shuffle( -+ torch::Tensor q_weight, -+ torch::Tensor q_perm, -+ int bit) { -+ TORCH_CHECK(false, "gptq_shuffle is not supported on XPU."); -+} -+ -+void paged_attention_gqa( -+ torch::Tensor output, -+ torch::Tensor query, -+ torch::Tensor key_cache, -+ torch::Tensor value_cache, -+ int64_t bsz, -+ int64_t num_heads, -+ int64_t num_kv_heads, -+ float 
scale, -+ torch::Tensor& block_tables, -+ torch::Tensor& context_lens, -+ int block_size, -+ int64_t head_dim, -+ int max_seq_len -+); -+ -+ -+torch::Tensor moe_forward( -+ torch::Tensor input, -+ torch::Tensor indexs, -+ torch::Tensor qweights_attr, -+ int64_t state_size, -+ int64_t output_size, -+ int64_t qtype -+); -+ -+torch::Tensor fused_moe_forward( -+ torch::Tensor input, -+ torch::Tensor indexs, -+ torch::Tensor qweights1_attr, -+ torch::Tensor qweights2_attr, -+ int64_t hidden_size, -+ int64_t intermediate_size, -+ int64_t qtype -+); -+void paged_attention_gqa_fp8(torch::Tensor output, torch::Tensor query, -+ torch::Tensor key_cache, torch::Tensor value_cache, -+ int64_t bsz, int64_t num_heads, int64_t num_kv_heads, -+ float scale, torch::Tensor& block_tables, -+ torch::Tensor& context_lens, int block_size, -+ int64_t head_dim, int max_seq_len); -diff --git a/csrc/xpu/xpu_types.h b/csrc/xpu/xpu_types.h -new file mode 100644 -index 000000000..23f5b805c ---- /dev/null -+++ b/csrc/xpu/xpu_types.h -@@ -0,0 +1,25 @@ -+ -+#ifndef XPU_TYPES_H -+#define XPU_TYPES_H -+ -+#include -+ -+// FIXME: FP16 is not fully supported in Torch-CPU -+#define VLLM_XPU_DISPATCH_CASE_FLOATING_TYPES(...) \ -+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ -+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ -+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) -+ -+#define VLLM_XPU_DISPATCH_CASE_FLOATING_TYPES_FLOAT_ONLY(...) \ -+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ -+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) -+ -+#define VLLM_XPU_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ -+ AT_DISPATCH_SWITCH( \ -+ TYPE, NAME, VLLM_XPU_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) -+ -+#define VLLM_XPU_DISPATCH_FLOATING_TYPES_FLOAT_ONLY(TYPE, NAME, ...) \ -+ AT_DISPATCH_SWITCH( \ -+ TYPE, NAME, VLLM_XPU_DISPATCH_CASE_FLOATING_TYPES_FLOAT_ONLY(__VA_ARGS__)) -+ -+#endif -\ No newline at end of file -diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu -index 7d5a589eb..25a9fd7cd 100644 ---- a/docker/Dockerfile.xpu -+++ b/docker/Dockerfile.xpu -@@ -1,9 +1,10 @@ --# oneapi 2025.0.2 docker base image use rolling 2448 package. https://dgpu-docs.intel.com/releases/packages.html?release=Rolling+2448.13&os=Ubuntu+22.04, and we don't need install driver manually. 
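Before the Dockerfile.xpu hunk continues below, here is an illustrative, hedged sketch of how the XPU custom ops declared in `csrc/xpu/xpu_ops.h` and registered in `csrc/xpu/pybind.cpp` above would typically be driven from Python. The import path `vllm._C` and the `xpu` device string are assumptions for illustration; only the `rotary_embedding` signature and the tensor layouts come from the declarations above.

```python
import torch

# Assumed import path for the extension built from csrc/xpu/pybind.cpp,
# which registers the kernels under an `ops` submodule. Adjust to the actual
# extension name produced by the build.
from vllm._C import ops  # assumption

num_tokens, num_heads, num_kv_heads, head_size = 4, 8, 2, 64
rot_dim = head_size  # cos_sin_cache.size(1) in the host wrapper above

positions = torch.arange(num_tokens, dtype=torch.long, device="xpu")
query = torch.randn(num_tokens, num_heads * head_size,
                    dtype=torch.float16, device="xpu")
key = torch.randn(num_tokens, num_kv_heads * head_size,
                  dtype=torch.float16, device="xpu")
# One row of precomputed cos/sin values per position, rot_dim wide.
cos_sin_cache = torch.randn(4096, rot_dim, dtype=torch.float16, device="xpu")

# In-place GPT-NeoX style rotary embedding applied to query and key.
ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, True)
```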
--FROM intel/deep-learning-essentials:2025.0.2-0-devel-ubuntu22.04 AS vllm-base -+FROM intel/deep-learning-essentials:2025.1.3-0-devel-ubuntu24.04 AS vllm-base - --RUN rm /etc/apt/sources.list.d/intel-graphics.list -+RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \ -+ echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \ -+ add-apt-repository -y ppa:kobuk-team/intel-graphics + if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}: + raise RuntimeError( + f"Keye-VL does not support {self.attn_backend} backend now.") +@@ -428,7 +438,10 @@ class KeyeSiglipAttention(nn.Module): + ) --RUN apt-get update -y && \ -+RUN apt clean && apt-get update -y && \ - apt-get install -y --no-install-recommends --fix-missing \ - curl \ - ffmpeg \ -@@ -14,15 +15,29 @@ RUN apt-get update -y && \ - libgl1 \ - lsb-release \ - numactl \ -- python3 \ -- python3-dev \ -- python3-pip \ -- wget -+ wget \ -+ vim \ -+ python3.12 \ -+ python3.12-dev \ -+ python3-pip -+ -+RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1 -+RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1 -+ -+RUN apt install -y libze1=1.23.1-1~24.04~ppa1 libze-dev=1.23.1-1~24.04~ppa1 libze-intel-gpu1=25.27.34303.9-1~24.04~ppa1 intel-opencl-icd=25.27.34303.9-1~24.04~ppa1 libze-intel-gpu-raytracing=1.1.0-114~u24.04 -+ -+RUN wget https://github.com/uxlfoundation/oneCCL/releases/download/2021.15.4/intel-oneccl-2021.15.4.11_offline.sh -+RUN bash intel-oneccl-2021.15.4.11_offline.sh -a --silent --eula accept && echo "source /opt/intel/oneapi/setvars.sh --force" >> /root/.bashrc -+SHELL ["bash", "-c"] -+CMD ["bash", "-c", "source /root/.bashrc && exec bash"] + if self.attn_backend == _Backend.FLASH_ATTN: +- from flash_attn import flash_attn_varlen_func ++ if self.use_upstream_fa: ++ from flash_attn import flash_attn_varlen_func ++ else: ++ from vllm.vllm_flash_attn import flash_attn_varlen_func - WORKDIR /workspace/vllm - COPY requirements/xpu.txt /workspace/vllm/requirements/xpu.txt - COPY requirements/common.txt /workspace/vllm/requirements/common.txt + q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) -+# suppress the python externally managed environment error -+RUN python3 -m pip config set global.break-system-packages true -+ - RUN --mount=type=cache,target=/root/.cache/pip \ - pip install --no-cache-dir \ - -r requirements/xpu.txt -@@ -47,10 +62,11 @@ FROM vllm-base AS vllm-openai +diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py +index a1c452053..a74e8cdb7 100644 +--- a/vllm/model_executor/models/phi4mm_audio.py ++++ b/vllm/model_executor/models/phi4mm_audio.py +@@ -550,10 +550,11 @@ class TransformerEncoderBase(abc.ABC, nn.Module): + enc_streaming_mask = self._streaming_mask(seq_len, batch_size, + self.chunk_size, + self.left_chunk) +- +- if xs_pad.is_cuda: +- enc_streaming_mask = enc_streaming_mask.cuda() +- xs_pad = xs_pad.cuda() ++ ++ device = xs_pad.device ++ if device.type != "cpu": ++ enc_streaming_mask = enc_streaming_mask.to(device) ++ xs_pad = xs_pad.to(device) - # install additional dependencies for openai api server - RUN --mount=type=cache,target=/root/.cache/pip \ -- pip install accelerate hf_transfer pytest pytest_asyncio lm_eval[api] modelscope -+ pip install accelerate hf_transfer pytest pytest_asyncio lm_eval[api] 'modelscope!=1.15.0' -+ -+RUN --mount=type=cache,target=/root/.cache/pip \ -+ pip uninstall oneccl oneccl-devel -y + input_tensor = xs_pad + input_tensor, masks = self._forward_embeddings_core( +@@ -570,8 +571,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module): + if chunk_size_nc is not None: + enc_streaming_mask_nc = self._streaming_mask( + seq_len, batch_size, chunk_size_nc, left_chunk_nc) +- if xs_pad.is_cuda: +- enc_streaming_mask_nc = enc_streaming_mask_nc.cuda() ++ if device.type != "cpu": ++ enc_streaming_mask_nc = enc_streaming_mask_nc.to(device) + if masks is not None: + hs_mask_nc = masks & enc_streaming_mask_nc + else: +diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py +index 54dc0bebd..e13e87b93 100644 +--- a/vllm/model_executor/models/qwen2.py ++++ b/vllm/model_executor/models/qwen2.py +@@ -285,7 +285,7 @@ class Qwen2Model(nn.Module): + decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer): + super().__init__() --ENV VLLM_USAGE_SOURCE production-docker-image \ -- TRITON_XPU_PROFILE 1 - # install development dependencies (for testing) - RUN python3 -m pip install -e tests/vllm_test_utils - ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] -diff --git a/docs/features/quantization/fp8.md b/docs/features/quantization/fp8.md -index 0661933ac..469d88a05 100644 ---- a/docs/features/quantization/fp8.md -+++ b/docs/features/quantization/fp8.md -@@ -134,4 +134,4 @@ print(result[0].outputs[0].text) - ``` +- config = vllm_config.model_config.hf_config ++ config = vllm_config.model_config.hf_config.get_text_config() + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config - !!! warning -- Currently, we load the model at original precision before quantizing down to 8-bits, so you need enough memory to load the whole model. -+ Currently, by default we load the model at original precision before quantizing down to 8-bits, so you need enough memory to load the whole model. To avoid this, adding `VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT=1` can allow offloading weights to cpu before quantization and quantized weights will be kept in device. 
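The FP8 note above introduces `VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT=1` for staging full-precision weights on the CPU before online quantization. A minimal usage sketch follows; the model name is a placeholder, and the flag must be set before vLLM initializes the model.

```python
import os

# Per the note above: offload weights to CPU before in-place FP8 quantization,
# so only the quantized weights stay resident on the device.
os.environ["VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT"] = "1"

from vllm import LLM, SamplingParams

# Placeholder model; any checkpoint served with quantization="fp8" applies.
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct", quantization="fp8")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```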
-diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md -index c8b6c6c86..404045306 100644 ---- a/docs/models/supported_models.md -+++ b/docs/models/supported_models.md -@@ -592,7 +592,8 @@ Specified using `--task generate`. - | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | - | `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ | - | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ | --| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | -+| `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ | -+| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | - | `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ | - | `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ | - | `Llama4ForConditionalGeneration` | Llama 4 | T + I+ | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ | -@@ -602,7 +603,7 @@ Specified using `--task generate`. - | `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ | - | `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I+ + V+ | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ | - | `MiniCPMO` | MiniCPM-O | T + IE+ + VE+ + AE+ | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | --| `MiniCPMV` | MiniCPM-V | T + IE+ + VE+ | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | | ✅︎ | -+| `MiniCPMV` | MiniCPM-V | T + IE+ + VE+ | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, `openbmb/MiniCPM-V-4_5`, etc. | ✅︎ | | ✅︎ | - | `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + IE+ | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ | - | `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ | - | `MllamaForConditionalGeneration` | Llama 3.2 | T + I+ | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | | -@@ -646,6 +647,15 @@ Specified using `--task generate`. 
+diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py +index 8aa777557..429516cce 100644 +--- a/vllm/model_executor/models/qwen2_5_vl.py ++++ b/vllm/model_executor/models/qwen2_5_vl.py +@@ -38,6 +38,7 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor + from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( + Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) - This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends. ++from vllm.attention.layer import check_upstream_fa_availability + from vllm.config import VllmConfig + from vllm.distributed import parallel_state + from vllm.distributed import utils as dist_utils +@@ -298,10 +299,19 @@ class Qwen2_5_VisionAttention(nn.Module): + disable_tp=use_data_parallel) -+!!! note -+ `Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its -+ MobileNet-v5 vision backbone. -+ -+ Performance is not yet fully optimized mainly due to: -+ -+ - Both audio and vision MM encoders use `transformers.AutoModel` implementation. -+ - There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups. + # Detect attention implementation. +- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) ++ self.attn_backend = get_vit_attn_backend( ++ head_size=self.hidden_size_per_attention_head, ++ dtype=torch.get_default_dtype()) ++ self.use_upstream_fa = False ++ if self.attn_backend != _Backend.FLASH_ATTN and \ ++ check_upstream_fa_availability( ++ torch.get_default_dtype()): ++ self.attn_backend = _Backend.FLASH_ATTN ++ self.use_upstream_fa = True + - !!! note - Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently. + if self.attn_backend not in { + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, +- _Backend.ROCM_AITER_FA ++ _Backend.ROCM_AITER_FA, _Backend.IPEX + }: + raise RuntimeError( + f"Qwen2.5-VL does not support {self.attn_backend} backend now." 
+@@ -359,7 +369,10 @@ class Qwen2_5_VisionAttention(nn.Module): + if self.attn_backend == _Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + else: +- from flash_attn import flash_attn_varlen_func ++ if self.use_upstream_fa: ++ from flash_attn import flash_attn_varlen_func ++ else: ++ from vllm.vllm_flash_attn import flash_attn_varlen_func -diff --git a/examples/offline_inference/basic/reward.py b/examples/offline_inference/basic/reward.py -new file mode 100644 -index 000000000..aec3481d2 ---- /dev/null -+++ b/examples/offline_inference/basic/reward.py -@@ -0,0 +1,55 @@ -+# SPDX-License-Identifier: Apache-2.0 -+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -+ -+from argparse import Namespace -+ -+from vllm import LLM, EngineArgs -+from vllm.utils import FlexibleArgumentParser -+ -+ -+def parse_args(): -+ parser = FlexibleArgumentParser() -+ parser = EngineArgs.add_cli_args(parser) -+ # Set example specific arguments -+ parser.set_defaults( -+ model="internlm/internlm2-1_8b-reward", -+ #runner="pooling", -+ task="reward", -+ enforce_eager=True, -+ max_model_len=1024, -+ trust_remote_code=True, -+ ) -+ return parser.parse_args() -+ -+ -+def main(args: Namespace): -+ # Sample prompts. -+ prompts = [ -+ "Hello, my name is", -+ "The president of the United States is", -+ "The capital of France is", -+ "The future of AI is", -+ ] -+ -+ # Create an LLM. -+ # You should pass runner="pooling" for reward models -+ llm = LLM(**vars(args)) + q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + +@@ -376,6 +389,38 @@ class Qwen2_5_VisionAttention(nn.Module): + context_layer = rearrange(output, + "(b s) ... -> b s ...", + b=batch_size) ++ elif self.attn_backend == _Backend.IPEX: ++ from vllm._ipex_ops import ipex_ops + -+ # Generate rewards. The output is a list of PoolingRequestOutput. -+ outputs = llm.reward(prompts) ++ q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + -+ # Print the outputs. -+ print("\nGenerated Outputs:\n" + "-" * 60) -+ for prompt, output in zip(prompts, outputs): -+ rewards = output.outputs.data -+ rewards_trimmed = ( -+ (str(rewards[:16])[:-1] + ", ...]") if len(rewards) > 16 else rewards -+ ) -+ print(f"Prompt: {prompt!r} \nReward: {rewards_trimmed} (size={len(rewards)})") -+ print("-" * 60) -+ -+ -+if __name__ == "__main__": -+ args = parse_args() -+ main(args) -+ -diff --git a/examples/offline_inference/multilora_inference.py b/examples/offline_inference/multilora_inference.py -index f0c00bcaa..c8fa36295 100644 ---- a/examples/offline_inference/multilora_inference.py -+++ b/examples/offline_inference/multilora_inference.py -@@ -30,7 +30,7 @@ def create_test_prompts( - ( - "A robot may not injure a human being", - SamplingParams( -- temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128 -+ temperature=0.0, logprobs=1, max_tokens=128 - ), - None, - ), -@@ -46,7 +46,7 @@ def create_test_prompts( - SamplingParams( - temperature=0.0, - logprobs=1, -- prompt_logprobs=1, -+ #prompt_logprobs=1, - max_tokens=128, - stop_token_ids=[32003], - ), -@@ -57,7 +57,7 @@ def create_test_prompts( - SamplingParams( - temperature=0.0, - logprobs=1, -- prompt_logprobs=1, -+ #prompt_logprobs=1, - max_tokens=128, - stop_token_ids=[32003], - ), -@@ -99,14 +99,14 @@ def initialize_engine() -> LLMEngine: - # numbers will cause higher memory usage. If you know that all LoRAs will - # use the same rank, it is recommended to set this as low as possible. - # max_cpu_loras: controls the size of the CPU LoRA cache. 
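(The multi-LoRA example resumes just below.) The Qwen2.5-VL vision-attention hunks above select an attention backend and optionally fall back to the upstream `flash_attn` package or the IPEX ops. The snippet below is a simplified, standalone stand-in for that selection order, not the patch's actual `get_vit_attn_backend`/`check_upstream_fa_availability` logic.

```python
from enum import Enum, auto


class Backend(Enum):
    FLASH_ATTN = auto()
    IPEX = auto()
    TORCH_SDPA = auto()


def pick_vit_backend() -> tuple[Backend, bool]:
    """Return (backend, use_upstream_fa), probing by import availability."""
    try:
        from vllm.vllm_flash_attn import flash_attn_varlen_func  # noqa: F401
        return Backend.FLASH_ATTN, False  # vLLM's bundled flash-attn
    except ImportError:
        pass
    try:
        from flash_attn import flash_attn_varlen_func  # noqa: F401
        return Backend.FLASH_ATTN, True  # upstream flash-attn fallback
    except ImportError:
        pass
    try:
        from vllm._ipex_ops import ipex_ops  # noqa: F401
        return Backend.IPEX, False  # IPEX varlen attention path
    except ImportError:
        return Backend.TORCH_SDPA, False  # last-resort SDPA


print(pick_vit_backend())
```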
-- engine_args = EngineArgs( -- model="meta-llama/Llama-2-7b-hf", -- enable_lora=True, -- max_loras=1, -- max_lora_rank=8, -- max_cpu_loras=2, -- max_num_seqs=256, -- ) -+ engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf", -+ enable_lora=True, -+ max_loras=1, -+ max_lora_rank=8, -+ max_cpu_loras=2, -+ max_num_seqs=256, -+ enforce_eager=True, -+ block_size=64) - return LLMEngine.from_engine_args(engine_args) ++ output = torch.empty( ++ q.shape, ++ dtype=q.dtype, ++ device=q.device) ++ ipex_ops.varlen_attention( ++ q, ++ k, ++ v, ++ output, ++ cu_seqlens, ++ cu_seqlens, ++ None, ++ max_seqlen, ++ max_seqlen, ++ pdropout=0.0, ++ softmax_scale=1.0/(q.shape[-1] ** 0.5), ++ zero_tensors=False, ++ is_causal=False, ++ return_softmax=False, ++ gen_=None, ++ window_size_left=-1, ++ window_size_right=-1, ++ logits_soft_cap=-1, ++ ) ++ context_layer = rearrange(output, ++ "(b s) ... -> b s ...", ++ b=batch_size) + elif self.attn_backend == _Backend.TORCH_SDPA: + # Execute attention entry by entry for speed & less VRAM. + outputs = [] +@@ -628,7 +673,12 @@ class Qwen2_5_VisionTransformer(nn.Module): + prefix=f"{prefix}.merger", + use_data_parallel=use_data_parallel, + ) +- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) ++ self.attn_backend = get_vit_attn_backend( ++ head_size=head_dim, dtype=torch.get_default_dtype()) ++ if self.attn_backend != _Backend.FLASH_ATTN and \ ++ check_upstream_fa_availability( ++ torch.get_default_dtype()): ++ self.attn_backend = _Backend.FLASH_ATTN + + @property + def dtype(self) -> torch.dtype: +@@ -714,6 +764,8 @@ class Qwen2_5_VisionTransformer(nn.Module): + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() ++ elif self.attn_backend == _Backend.IPEX: ++ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + return max_seqlen, seqlens + + @staticmethod +@@ -1210,10 +1262,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, + if image_input is None and video_input is None: + inputs_embeds = None + else: +- if uses_mrope(self.config): +- assert positions.ndim == 2 and positions.size(0) == 3, ( +- "multimodal section rotary embedding requires " +- f"(3, seq_len) positions, but got {positions.size()}") ++ # if uses_mrope(self.config): ++ # assert positions.ndim == 2 and positions.size(0) == 3, ( ++ # "multimodal section rotary embedding requires " ++ # f"(3, seq_len) positions, but got {positions.size()}") + inputs_embeds = self.get_input_embeddings_v0( + input_ids, + image_input=image_input, +diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py +index 90a1ad2a6..e6da04df4 100644 +--- a/vllm/model_executor/models/qwen2_vl.py ++++ b/vllm/model_executor/models/qwen2_vl.py +@@ -41,6 +41,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize + from transformers.models.qwen2_vl.video_processing_qwen2_vl import ( + Qwen2VLVideoProcessor) ++from vllm.attention.layer import check_upstream_fa_availability + from vllm.config import VllmConfig + from vllm.distributed import parallel_state, tensor_model_parallel_all_gather + from vllm.distributed import utils as dist_utils +@@ -82,7 +83,7 @@ from .vision import get_vit_attn_backend + logger = init_logger(__name__) -diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py -index 4fdc7a3cf..b6007b9f4 100644 ---- 
a/examples/offline_inference/prithvi_geospatial_mae.py -+++ b/examples/offline_inference/prithvi_geospatial_mae.py -@@ -3,12 +3,12 @@ - import argparse - import datetime - import os --import re - from typing import Union + # For profile run +-_MAX_FRAMES_PER_VIDEO = 16 ++_MAX_FRAMES_PER_VIDEO = 600 - import albumentations - import numpy as np - import rasterio -+import regex as re - import torch - from einops import rearrange - from terratorch.datamodules import Sen1Floods11NonGeoDataModule -diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py -index e4811c023..fe4393bcf 100644 ---- a/examples/offline_inference/vision_language.py -+++ b/examples/offline_inference/vision_language.py -@@ -389,6 +389,39 @@ def run_tarsier(questions: list[str], modality: str) -> ModelRequestData: - ) + # === Vision Inputs === # +@@ -314,10 +315,19 @@ class Qwen2VisionAttention(nn.Module): + prefix=f"{prefix}.proj") -+# Intern-S1 -+def run_interns1(questions: list[str], modality: str) -> ModelRequestData: -+ model_name = "internlm/Intern-S1" -+ -+ engine_args = EngineArgs( -+ model=model_name, -+ trust_remote_code=True, -+ max_model_len=8192, -+ max_num_seqs=2, -+ limit_mm_per_prompt={modality: 1}, -+ enforce_eager=True, -+ ) -+ -+ if modality == "image": -+ placeholder = "" -+ elif modality == "video": -+ placeholder = "