Commit
PR #59515: Cublaslt fp8 matmul restriction work-around
Imported from GitHub PR tensorflow/tensorflow#59515

As of 2/1/2023, the cublasLt FP8 matmul only supports column-major input (the cuBLAS default), but when called from TF/XLA the inputs can arrive in any storage layout. This PR "canonicalizes" FP8 matmuls to lhs/rhs_contracting_dims={1,0} and adds the necessary transposes to the inputs. A reproducer of this restriction is located [here](https://github.com/wenscarl/fp8_gemm_test/blob/main/fp8_gemm_backward_fail.py). One restriction remains: the batch dimension must still be a leading dimension.

Copybara import of the project:

- 77de65a4d0c2c0c7db8e6b305d03e039c454c2e2 by wenscarl <shuw@nvidia.com>: Workaround cublasLt fp8 matmul restrictions.
- 706ad9adb3e5109acc7909d39fa89c9286c192f3 by shuw <shuw@nvidia.com>: Work around cublasLt fp8 matmul restrictions
- 8f4856968d826b621a3a083f7843ffb30b8d921f by wenscarl <shuw@nvidia.com>: Add MatrixIsColumnMajor
- e0e67a4223dda278463e2004f975b32c0d5ac3d8 by shuw <shuw@nvidia.com>: Remove dead code
- cfea2cb8dc60edb07bdf1fd3c2113cd401354701 by shuw <shuw@nvidia.com>: Add gemm_rewrite test
- 460f616103888ad6d80dc22cc3876690ab7e16c3 by shuw <shuw@nvidia.com>: Cover 32 cases
- 6f724ff30570d7f35dba03a13d6a9bf6aa903da5 by shuw <shuw@nvidia.com>: Abbreviate branches logics
- 64f153a32cb8ddd925f6f8d4a0ae88178990f24c by shuw <shuw@nvidia.com>: Parameterized test
- 6146558b21416ec3c5b8aa32f7f59c9f3b49ae87 by shuw <shuw@nvidia.com>: auto -> type names
- 96e758886d67fe868e4d263b44cc729af972fb71 by shuw <shuw@nvidia.com>: Update comments
- 07db882429e7e93afba335714d672c460483ec4f by shuw <shuw@nvidia.com>: vector -> array
- 3089065fb4d78033b85f46b68cd25a0643dd17e3 by shuw <shuw@nvidia.com>: fix typo
- 50eaec634ccf7d77a0e8315ade5e0a5bab06c07e by shuw <shuw@nvidia.com>: Convert to NN config and add batched matmul tests

Merging this change closes #59515

FUTURE_COPYBARA_INTEGRATE_REVIEW=tensorflow/tensorflow#59515 from wenscarl:cublaslt_fp8_matmul_war bff4bc682d5e0bc200e6261317ba1499960c3e45

PiperOrigin-RevId: 512191684
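The canonicalization idea above can be sketched with plain numpy: whatever contracting dimensions the original dot uses, inserting transposes reduces it to the one form the cublasLt FP8 matmul accepts (lhs contracts dim 1 against rhs dim 0, i.e. `lhs_contracting_dims={1}` / `rhs_contracting_dims={0}` in HLO terms). The function name and signature here are hypothetical illustrations, not the PR's actual C++ rewriter API.

```python
import numpy as np

def canonicalize_fp8_matmul(lhs, rhs, lhs_contracting_dim, rhs_contracting_dim):
    """Hypothetical sketch: transpose operands so the dot always contracts
    lhs dim 1 against rhs dim 0, the canonical form cublasLt's FP8 matmul
    can handle. The real PR does this at the HLO level in the GEMM rewriter."""
    if lhs_contracting_dim == 0:   # contracting dim must become lhs dim 1
        lhs = lhs.T
    if rhs_contracting_dim == 1:   # contracting dim must become rhs dim 0
        rhs = rhs.T
    return lhs, rhs

# Any combination of contracting dims reduces to the canonical form:
a = np.arange(6.0).reshape(2, 3)
b = np.arange(12.0).reshape(4, 3)
# The original dot contracts a's dim 1 with b's dim 1.
expected = np.einsum('ij,kj->ik', a, b)
la, rb = canonicalize_fp8_matmul(a, b, 1, 1)
assert np.allclose(la @ rb, expected)
```

The same trick covers all four contracting-dim combinations, which is why the PR's test matrix parameterizes over the operand layouts rather than special-casing each one.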
Showing 3 changed files with 222 additions and 44 deletions.