forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PyTorch] Add Vulkan support for at::softmax 1,2,3 dimension tensors (p…
…ytorch#105012) Summary: Pull Request resolved: pytorch#105012 This rounds out the support for the [softmax function](https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html) on the Vulkan GPU backend. The test inputs of the 1,2,3 dimension cases are simply the truncated existing 4 dimension inputs. The existing shader algorithms are reused. Test Plan: 1. `buck2 run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1` on Apple M1 MacBook 2. Confirm all tests pass with no regression, and the added tests `*softmax*` pass under `-- --gtest_filter="*softmax*"` 2a. All tests P782531732 2b. `softmax` tests P782529114 ``` ~/fbsource » buck2 run --target-platforms ovr_config//platform/macos:arm64-fbsource //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 -- --gtest_filter="*softmax*" Buck UI: https://www.internalfb.com/buck2/692eb82d-c2ee-49bb-833f-3c11d6e2fea9 Jobs completed: 4. Time elapsed: 0.1s. Running main() from xplat/third-party/gmock/googletest-1.12.1/googletest/src/gtest_main.cc Note: Google Test filter = *softmax* [==========] Running 1 test from 1 test suite. [----------] Global test environment set-up. [----------] 1 test from VulkanAPITest [ RUN ] VulkanAPITest.softmax [ OK ] VulkanAPITest.softmax (42 ms) [ DISABLED ] VulkanAPITest.DISABLED_log_softmax [----------] 1 test from VulkanAPITest (42 ms total) [----------] Global test environment tear-down [==========] 1 test from 1 test suite ran. (42 ms total) [ PASSED ] 1 test. YOU HAVE 1 DISABLED TEST ``` Reviewed By: SS-JIA Differential Revision: D46985319 fbshipit-source-id: 8e30a691232e07262e85c043dac78c7ab358527e
- Loading branch information
1 parent
1a6619a
commit 632307f
Showing
2 changed files
with
122 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters