Support FP8 in op flip, index_put, and index.Tensor #2190
Pull Request Overview
This PR adds FP8 support to XPU tensor operations including flip, index, and index_put operations by migrating from legacy dispatch macros to the newer AT_DISPATCH_V2 system.
- Migrates flip_kernel from AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3 to AT_DISPATCH_V2
- Migrates index_kernel and index_put_kernel from AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4 to AT_DISPATCH_V2
- Adds FP8 type support through AT_FLOAT8_TYPES expansion
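The effect of the migration can be pictured with a toy model. `AT_DISPATCH_V2` takes the dtype, a kernel name, the wrapped body, and then a variable-length list of types, whereas the legacy `..._AND3` / `..._AND4` macros accept a fixed number of extra types. The Python sketch below only imitates that idea: `dispatch`, `flip_legacy`, `flip_v2`, and the type tuples are hypothetical stand-ins for illustration, not ATen APIs, and the float8 tuple is an illustrative subset rather than the real contents of `AT_FLOAT8_TYPES`.

```python
# Toy model only: every name here is a hypothetical stand-in, not an ATen API.
ALL_TYPES_AND_COMPLEX = (
    "uint8", "int8", "int16", "int32", "int64",
    "float32", "float64", "complex64", "complex128",
)
EXTRA_TYPES = ("bool", "float16", "bfloat16")
FLOAT8_TYPES = ("float8_e5m2", "float8_e4m3fn")  # illustrative subset


def dispatch(dtype, name, body, *type_lists):
    """Run body(dtype) if dtype appears in any of the given type lists."""
    supported = {t for group in type_lists for t in group}
    if dtype not in supported:
        raise NotImplementedError(f'"{name}" not implemented for {dtype!r}')
    return body(dtype)


def flip_legacy(dtype):
    # Fixed set of extra types, loosely mirroring a legacy "..._AND3" macro.
    return dispatch(dtype, "flip", lambda t: f"flip<{t}>",
                    ALL_TYPES_AND_COMPLEX, EXTRA_TYPES)


def flip_v2(dtype):
    # Open-ended type list: float8 support is one more list appended at the end.
    return dispatch(dtype, "flip", lambda t: f"flip<{t}>",
                    ALL_TYPES_AND_COMPLEX, EXTRA_TYPES, FLOAT8_TYPES)
```

In this model, `flip_v2("float8_e4m3fn")` reaches the kernel body, while `flip_legacy("float8_e4m3fn")` raises `NotImplementedError`. That is the kind of gap the PR closes for the real kernels.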
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| src/ATen/native/xpu/sycl/TensorTransformationsKernels.cpp | Updates flip_kernel to use AT_DISPATCH_V2 with FP8 support |
| src/ATen/native/xpu/sycl/Indexing.cpp | Updates index_kernel and index_put_kernel to use AT_DISPATCH_V2 with FP8 support |

To solve #2207
Extends support for float8 data types across various XPU tensor indexing and transformation kernels, ensuring these operations are compatible with the new types. It also adds a regression test for flipping float8 tensors and removes the skip for float8 indexing tests.
Float8 type support:
- Updates XPUScalar.cpp and Indexing.cpp to include AT_FLOAT8_TYPES, enabling float8 support in scalar extraction, indexing, index_put, and deterministic index_put kernels.
- Updates flip_kernel in TensorTransformationsKernels.cpp to support float8 and barebones unsigned types, updating the dispatch mechanism accordingly.
- Includes Dispatch_v2.h for the updated dispatch macros.

Testing improvements:
- Adds a float8 flip regression test to test_index_and_index_put.py to verify correctness of the operation on XPU.
- Removes the float8 skip in test_indexing_xpu.py, re-enabling these tests now that support is implemented.
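A possible way to reason about such tests: flip, index, and index_put without accumulation only move elements, so a float8 result should match a reference bit for bit. The pure-Python sketch below models float8 storage as raw bytes and is not taken from the PR's test files. `ref_flip`, `ref_index`, `ref_index_put`, and `decode_e4m3fn` are hypothetical helper names, and the decoder assumes the float8_e4m3fn layout (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, NaN but no infinities).

```python
import math
from typing import Sequence


def decode_e4m3fn(byte: int) -> float:
    """Decode one float8_e4m3fn bit pattern into a Python float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:
        return math.nan  # the only NaN encoding; this format has no infinities
    if exp == 0:
        return sign * (mant / 8.0) * 2.0 ** -6  # subnormal range
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)


def ref_flip(data: bytes) -> bytes:
    """Reference 1-D flip: reverse the element order."""
    return data[::-1]


def ref_index(data: bytes, indices: Sequence[int]) -> bytes:
    """Reference 1-D gather: out[k] = data[indices[k]]."""
    return bytes(data[i] for i in indices)


def ref_index_put(data: bytes, indices: Sequence[int], values: bytes) -> bytes:
    """Reference 1-D scatter without accumulation: out[indices[k]] = values[k]."""
    out = bytearray(data)
    for i, v in zip(indices, values):
        out[i] = v
    return bytes(out)


# Bit patterns for 1.0, 2.0, 1.5, and -1.0 in float8_e4m3fn.
sample = bytes([0x38, 0x40, 0x3C, 0xB8])
flipped = ref_flip(sample)
```

Because the helpers operate on raw bytes, comparing a device result with the reference needs no floating-point tolerance. The decoder is only there to make the values readable while debugging.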