[ROCm] add hipblaslt support (pytorch#114329)
Disabled by default. Enable with env var DISABLE_ADDMM_HIP_LT=0. Tested on both ROCm 5.7 and 6.0.

Pull Request resolved: pytorch#114329
Approved by: https://github.com/malfet
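
Usage sketch (not part of this commit): on a ROCm build of PyTorch that includes this change, opting in might look like the following. The env var name comes from the commit message; the torch.addmm call and tensor shapes are illustrative assumptions.

import os
# The hipBLASLt addmm path ships disabled; "0" opts in. Setting it before torch
# touches the GPU is the safe choice.
os.environ["DISABLE_ADDMM_HIP_LT"] = "0"

import torch

# On ROCm builds the HIP device is still exposed under the "cuda" device type.
a = torch.randn(128, 64, device="cuda", dtype=torch.float16)
b = torch.randn(64, 32, device="cuda", dtype=torch.float16)
bias = torch.randn(32, device="cuda", dtype=torch.float16)
out = torch.addmm(bias, a, b)  # may dispatch to the hipBLASLt gemm_and_bias path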
jeffdaily authored and dmenig committed Dec 21, 2023
1 parent 5407996 commit bad7ccb
Showing 9 changed files with 363 additions and 32 deletions.
120 changes: 98 additions & 22 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -11,14 +11,9 @@
#include <c10/macros/Export.h>
#include <c10/util/irange.h>

// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
// added bf16 support
#if !defined(USE_ROCM) && !defined(_MSC_VER)
#include <cublasLt.h>
#endif

#ifdef USE_ROCM
// until hipblas has an API to accept flags, we must use rocblas here
#include <hipblas/hipblas.h>
#include <rocblas/rocblas.h>
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
@@ -64,6 +59,7 @@ static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
// until we use hipblas v2
// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
// however hipblas v1 is still using its custom type
#ifndef HIPBLAS_V2
#define HIP_R_16F HIPBLAS_R_16F
#define HIP_R_32F HIPBLAS_R_32F
#define HIP_R_64F HIPBLAS_R_64F
@@ -81,6 +77,7 @@ static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
#define HIP_R_16BF HIPBLAS_R_16B
#define HIP_C_16BF HIPBLAS_C_16B
#endif
#endif

#define CUDABLAS_POSINT_CHECK(FD, X) \
TORCH_CHECK( \
@@ -167,6 +164,7 @@ static void _cublasAdjustLdLevel3(
}
}

#ifndef USE_ROCM
uint32_t _getAlignment(uintptr_t address) {
// alignment are in bytes
uint32_t alignment = 256;
@@ -176,18 +174,25 @@ uint32_t _getAlignment(uintptr_t address) {
}
}
}
#endif

static size_t _parseChosenWorkspaceSize() {
const char * val = getenv("CUBLASLT_WORKSPACE_SIZE");
#ifdef USE_ROCM
if (!val) {
// accept either env var
val = getenv("HIPBLASLT_WORKSPACE_SIZE");
}
#endif
size_t workspace_size = 1024; /* default size in KiB according to #73328 */
if (val) {
try {
workspace_size = std::stoi(val);
} catch(std::invalid_argument const& e) {
TORCH_WARN("invalid CUBLAS_LT_WORKSPACE_SIZE,",
TORCH_WARN("invalid CUBLASLT_WORKSPACE_SIZE,",
" using default workspace size of ", workspace_size, " bytes.");
} catch(std::out_of_range const& e) {
TORCH_WARN("CUBLAS_LT_WORKSPACE_SIZE out of range,",
TORCH_WARN("CUBLASLT_WORKSPACE_SIZE out of range,",
" using default workspace size of ", workspace_size, " bytes.");
}
}
@@ -341,12 +346,19 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
const float fbeta = beta;
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

#if defined(USE_ROCM) && ROCM_VERSION >= 60000
auto compute_type = CUBLAS_COMPUTE_32F;
#else
auto compute_type = CUDA_R_32F;
#endif
TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
b, CUDA_R_16BF, (int)ldb, strideb,
(void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
(int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
(int)num_batches,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

template <>
@@ -516,6 +528,11 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
if (!at::globalContext().allowBF16ReductionCuBLAS()) {
cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
}
#endif
#if defined(USE_ROCM) && ROCM_VERSION >= 60000
auto compute_type = CUBLAS_COMPUTE_32F;
#else
auto compute_type = CUDA_R_32F;
#endif
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
TORCH_CUDABLAS_CHECK(cublasGemmEx(
@@ -536,12 +553,62 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
c,
CUDA_R_16BF,
ldc,
CUDA_R_32F,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}

#if !defined(USE_ROCM) && !defined(_MSC_VER)
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)

#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
// only for rocm 5.7 where we first supported hipblaslt, it was difficult
// to hipify correctly without this change.
#define hipDataType hipblasDatatype_t
#endif

// hipblaslt custom types were a temporary work-around
#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && HIPBLASLT_CUSTOM_DATA_TYPE
hipblasltDatatype_t hipToLt(hipDataType type) {
switch (type) {
case HIP_R_32F: return HIPBLASLT_R_32F;
case HIP_R_64F: return HIPBLASLT_R_64F;
case HIP_R_16F: return HIPBLASLT_R_16F;
case HIP_R_8I: return HIPBLASLT_R_8I;
case HIP_C_32F: return HIPBLASLT_C_32F;
case HIP_C_64F: return HIPBLASLT_C_64F;
case HIP_C_16F: return HIPBLASLT_C_16F;
case HIP_C_8I: return HIPBLASLT_C_8I;
case HIP_R_8U: return HIPBLASLT_R_8U;
case HIP_C_8U: return HIPBLASLT_C_8U;
case HIP_R_32I: return HIPBLASLT_R_32I;
case HIP_C_32I: return HIPBLASLT_C_32I;
case HIP_R_32U: return HIPBLASLT_R_32U;
case HIP_C_32U: return HIPBLASLT_C_32U;
case HIP_R_16BF: return HIPBLASLT_R_16B;
case HIP_C_16BF: return HIPBLASLT_C_16B;
default: TORCH_CHECK(false);
}
}
#define HIPTOLT(type) hipToLt(type)
#else
#define HIPTOLT(type) type
#endif

#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && HIPBLASLT_CUSTOM_COMPUTE_TYPE
hipblasLtComputeType_t hipblasToLt(hipblasComputeType_t type) {
switch (type) {
case HIPBLAS_COMPUTE_32F: return HIPBLASLT_COMPUTE_F32;
case HIPBLAS_COMPUTE_32F_FAST_16F: return HIPBLASLT_COMPUTE_F32_FAST_F16;
case HIPBLAS_COMPUTE_32F_FAST_TF32: return HIPBLASLT_COMPUTE_F32_FAST_XF32;
case HIPBLAS_COMPUTE_64F: return HIPBLASLT_COMPUTE_F64;
case HIPBLAS_COMPUTE_32I: return HIPBLASLT_COMPUTE_I32;
default: TORCH_CHECK(false);
}
}
#define HIPCOMPTOLT(type) hipblasToLt(type)
#else
#define HIPCOMPTOLT(type) type
#endif

namespace {
// Following the pattern of CuSparseDescriptor
@@ -580,7 +647,7 @@ class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
cudaDataType_t scale_type) {
cublasLtMatmulDesc_t raw_descriptor = nullptr;
TORCH_CUDABLAS_CHECK(
cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
cublasLtMatmulDescCreate(&raw_descriptor, HIPCOMPTOLT(compute_type), HIPTOLT(scale_type)));
descriptor_.reset(raw_descriptor);
}
template <typename T>
@@ -601,7 +668,7 @@ class CuBlasLtMatrixLayout : public CuBlasLtDescriptor<
bool t = false) {
cublasLtMatrixLayout_t raw_descriptor = nullptr;
TORCH_CUDABLAS_CHECK(
cublasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld));
cublasLtMatrixLayoutCreate(&raw_descriptor, HIPTOLT(type), t ? cols : rows, t ? rows : cols, ld));
descriptor_.reset(raw_descriptor);
}
};
@@ -645,13 +712,19 @@ void gemm_and_bias(
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
cudaDataType_t scaleType = CUDA_R_32F;
if constexpr (std::is_same_v<Dtype, double>) {
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
abcType = CUDA_R_64F;
computeType = CUBLAS_COMPUTE_64F;
scaleType = CUDA_R_64F;
#else
TORCH_CHECK(false, "gemm_and_bias is only supported for double type on ROCm 6.0 and above");
#endif
} else if constexpr (std::is_same_v<Dtype, float>) {
#ifndef USE_ROCM
if (at::globalContext().allowTF32CuBLAS()) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
#endif
abcType = CUDA_R_32F;
} else if constexpr (std::is_same_v<Dtype, at::Half>) {
abcType = CUDA_R_16F;
@@ -668,7 +741,7 @@ void gemm_and_bias(
if (activation == GEMMAndBiasActivationEpilogue::RELU) {
epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
} else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
#if CUDA_VERSION >= 11040
#if CUDA_VERSION >= 11040 || defined(USE_ROCM)
epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
#endif
}
Expand All @@ -685,6 +758,7 @@ void gemm_and_bias(
size_t workspaceSize = _getWorkspaceSize();
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);

#ifndef USE_ROCM
uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat1_ptr));
uint32_t b_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat2_ptr));
uint32_t c_alignment = _getAlignment(reinterpret_cast<uintptr_t>(result_ptr));
@@ -693,14 +767,14 @@ preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment);
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment);
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
#endif

auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto workspace = allocator.allocate(workspaceSize);

cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
cublasLtHandle_t ltHandle =
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
@@ -876,8 +950,7 @@ void scaled_gemm(
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
cublasLtHandle_t ltHandle =
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
@@ -952,6 +1025,7 @@ void int8_gemm(
int64_t mat2_ld,
int32_t* result_ptr,
int64_t result_ld) {
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)

cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
cudaDataType_t scaleType = CUDA_R_32I;
@@ -970,8 +1044,7 @@ void int8_gemm(
CuBlasLtMatrixLayout Bdesc(abType, k, n, mat2_ld, transpose_mat2);
CuBlasLtMatrixLayout Cdesc(cType, m, n, result_ld);

cublasLtHandle_t ltHandle =
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();

// cublas team: alpha and beta need to be the same dtype as of scaleType
at::opmath_type<int32_t> alpha_val = 1;
@@ -1022,11 +1095,14 @@ void int8_gemm(
computeType,
" scaleType ",
scaleType);
#else
TORCH_CHECK(false, "int8_gemm is only supported for ROCm 6.0 and above");
#endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
}
#endif // !defined(USE_ROCM) && !defined(_MSC_VER)
#endif // (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)

// ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
#if defined(USE_ROCM) && ROCM_VERSION <= 56000
#if defined(USE_ROCM) && ROCM_VERSION <= 50600
#define ROCM_CONST_BUG
#else
#define ROCM_CONST_BUG const
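
Workspace sizing sketch (not part of the diff): the parser above reads CUBLASLT_WORKSPACE_SIZE and, on ROCm, accepts HIPBLASLT_WORKSPACE_SIZE as an alternative spelling; the value appears to be interpreted in KiB with a default of 1024. A hypothetical override:

import os
# Value in KiB per the comment in _parseChosenWorkspaceSize(); 4096 is an arbitrary
# example, not a recommendation.
os.environ["HIPBLASLT_WORKSPACE_SIZE"] = "4096"
# On CUDA builds the equivalent knob is CUBLASLT_WORKSPACE_SIZE.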
4 changes: 2 additions & 2 deletions aten/src/ATen/cuda/CUDABlas.h
@@ -62,7 +62,7 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));

#if !defined(USE_ROCM) && !defined(_MSC_VER)
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
enum GEMMAndBiasActivationEpilogue {
None,
RELU,
@@ -149,7 +149,7 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));

#if defined(USE_ROCM) && ROCM_VERSION <= 55000
#if defined(USE_ROCM) && ROCM_VERSION <= 50500
// ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
#define CUDABLAS_TRSM_ARGTYPES(Dtype) \
hipblasHandle_t handle, hipblasSideMode_t side, hipblasFillMode_t uplo, \
9 changes: 9 additions & 0 deletions aten/src/ATen/cuda/CUDAContextLight.h
@@ -7,6 +7,12 @@
#include <cusparse.h>
#include <cublas_v2.h>

// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
// added bf16 support
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
#include <cublasLt.h>
#endif

#ifdef CUDART_VERSION
#include <cusolverDn.h>
#endif
@@ -76,6 +82,9 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
/* Handles */
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
#endif

TORCH_CUDA_CPP_API void clearCublasWorkspaces();

(diffs for the remaining 6 changed files are not shown)
