PR #5911: [ROCm] Unifying hip/cuda blas-lt APIs
Imported from GitHub PR #5911

This is a follow-up PR for these two issues: #4406, #3953

We unified the hip/cuda blas-lt APIs by providing a common virtual interface defined in xla/stream_executor/gpu/gpu_blas_lt.h/.cc, with implementations in xla/stream_executor/cuda/cuda_blas_lt.h/.cc and xla/stream_executor/rocm/hip_blas_lt.h/.cc, respectively.

The main design decision was to make the class MatmulPlan (originally defined in xla/service/gpu/matmul_utils.h/.cc) **polymorphic** and to move its interface declaration to gpu_blas_lt.h. There are two reasons for that:

1. MatmulPlan provided a public function **ExecuteOnStream** which was implemented using conditional compilation with the macros '#if GOOGLE_CUDA' or '#if TF_HIPBLASLT' in order to integrate library-specific data types. This function now becomes part of the gpu_blas_lt interface.

2. MatmulPlan contained a library-specific member variable 'plan_' of type 'se::gpu::BlasLt::MatmulPlan', which is basically a plain container of a MatmulDesc and several MatrixLayouts. These underlying types are again BLASLT library-specific and are **never** used directly, hence there is no need to expose BlasLt::MatmulDesc and BlasLt::MatrixLayout in the public interface.

Besides ExecuteOnStream, the class MatmulPlan also provides a number of overloaded 'DoMatmul' member functions (some of them function templates) which were extracted as the common part of the original BlasLt implementations. These DoMatmul functions are also required for the upcoming integration of the blas-lt interface into TensorFlow: see tensorflow\core\kernels\matmul_util.h/.cc.

We also extracted the library-specific argument type checks from the templated DoMatmul functions and moved them into a virtual function MatmulPlan::ValidateInputs().

The polymorphic class gpu::BlasLt (defined in gpu_blas_lt.h) is responsible for constructing objects of type MatmulPlan; the rest of the blas-lt functionality is handled solely by the MatmulPlan interface.
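The division of labor described above can be illustrated with a minimal, self-contained sketch. It is not the code from this PR: the namespace `sketch`, the `Matrix` struct, the `HostBlasLt` backend, the `GetMatmulPlan` factory name and all signatures are hypothetical simplifications. In particular, it assumes that the shared DoMatmul entry point calls ValidateInputs and then dispatches to ExecuteOnStream, and a naive host loop stands in for the cuBLASLt/hipBLASLt calls.

```cpp
#include <cstddef>
#include <memory>
#include <vector>

namespace sketch {

// Hypothetical stand-in for a device buffer: a row-major host matrix.
struct Matrix {
  std::size_t rows;
  std::size_t cols;
  std::vector<float> data;
};

// Simplified analogue of gpu::BlasLt: a factory for polymorphic plans.
class BlasLt {
 public:
  class MatmulPlan {
   public:
    virtual ~MatmulPlan() = default;

    // Common, non-virtual entry point shared by all backends (assumption):
    // run the backend-specific checks, then the backend-specific execution.
    bool DoMatmul(const Matrix& a, const Matrix& b, Matrix* c) const {
      if (c == nullptr || !ValidateInputs(a, b, *c)) return false;
      return ExecuteOnStream(a, b, c);
    }

   protected:
    // Backend-specific argument checks.
    virtual bool ValidateInputs(const Matrix& a, const Matrix& b,
                                const Matrix& c) const = 0;
    // Backend-specific execution; no library types leak into the interface.
    virtual bool ExecuteOnStream(const Matrix& a, const Matrix& b,
                                 Matrix* c) const = 0;
  };

  virtual ~BlasLt() = default;
  virtual std::unique_ptr<MatmulPlan> GetMatmulPlan(std::size_t m,
                                                    std::size_t n,
                                                    std::size_t k) const = 0;
};

// Hypothetical backend that computes C = A * B on the host.
class HostBlasLt : public BlasLt {
  class HostPlan : public MatmulPlan {
   public:
    HostPlan(std::size_t m, std::size_t n, std::size_t k)
        : m_(m), n_(n), k_(k) {}

   protected:
    bool ValidateInputs(const Matrix& a, const Matrix& b,
                        const Matrix& c) const override {
      // Shapes must match the plan, and buffers must match the shapes.
      return a.rows == m_ && a.cols == k_ && a.data.size() == m_ * k_ &&
             b.rows == k_ && b.cols == n_ && b.data.size() == k_ * n_ &&
             c.rows == m_ && c.cols == n_ && c.data.size() == m_ * n_;
    }

    bool ExecuteOnStream(const Matrix& a, const Matrix& b,
                         Matrix* c) const override {
      for (std::size_t i = 0; i < m_; ++i) {
        for (std::size_t j = 0; j < n_; ++j) {
          float acc = 0.0f;
          for (std::size_t p = 0; p < k_; ++p) {
            acc += a.data[i * k_ + p] * b.data[p * n_ + j];
          }
          c->data[i * n_ + j] = acc;
        }
      }
      return true;
    }

   private:
    std::size_t m_, n_, k_;
  };

 public:
  std::unique_ptr<MatmulPlan> GetMatmulPlan(std::size_t m, std::size_t n,
                                            std::size_t k) const override {
    return std::make_unique<HostPlan>(m, n, k);
  }
};

}  // namespace sketch
```

In this sketch a caller only holds a `BlasLt` reference and a `std::unique_ptr<BlasLt::MatmulPlan>`, so no '#if GOOGLE_CUDA' / '#if TF_HIPBLASLT' branches are needed at the call site; a second backend would simply be another class deriving from `BlasLt`.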
The instantiations of the gpu::BlasLt interface, as before, are defined in xla/stream_executor/cuda/cuda_blas.h and xla/stream_executor/rocm/rocm_blas.h, respectively.

We have also tried to compile the code with TF_HIPBLASLT=0 to make sure it works when no hipblas-lt is available.

@akuegel: can you perhaps have a look at our implementation?

Copybara import of the project:

--
daea33c by Pavel Emeliyanenko <pavel.emeliyanenko@amd.com>:

Unifying hip/cuda blas-lt APIs

work in progress

ongoing work

make sure the code runs with TF_HIPBLASLT=0

adaptions for CUDA compile

moving BlasLt and related stuff to se::gpu namespace

hipblas_lt interface cleanup

adapted the last blas-lt inteface changes for CUDA

--
b4ff019 by Pavel Emeliyanenko <pavel.emeliyanenko@amd.com>:

protected code by TF_HIPBLASLT macro to make sure code builds without hipblas-lt too

--
7248f69 by Pavel Emeliyanenko <pavel.emeliyanenko@amd.com>:

resolving conflicts

--
d48e6ee by Pavel Emeliyanenko <pavel.emeliyanenko@amd.com>:

appliyng reviewer changes

--
1d7cc54 by Pavel Emeliyanenko <pavel.emeliyanenko@amd.com>:

rebased and adapted API for TF blas-lt part

Merging this change closes #5911

COPYBARA_INTEGRATE_REVIEW=#5911 from ROCmSoftwarePlatform:unify_blaslt_APIs_v2 1d7cc54
PiperOrigin-RevId: 573136621
1 parent 5a9d240, commit f459e57
Showing 28 changed files with 1,424 additions and 1,153 deletions.