Skip to content

Commit

Permalink
Benchmarks: Revise Code - Add hipblasLt tuning to dist-inference cpp …
Browse files Browse the repository at this point in the history
…implementation (#616)

**Description**
Adds hipblasLt tuning to dist-inference cpp implementation.
  • Loading branch information
yzygitzh committed Apr 2, 2024
1 parent eeaa9b1 commit cc89ee5
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 25 deletions.
8 changes: 8 additions & 0 deletions superbench/benchmarks/micro_benchmarks/dist_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,12 @@ def add_parser_arguments(self):
required=False,
help='Whether to launch kernels in CUDA graph mode.',
)
self._parser.add_argument(
'--tune_gemm',
action='store_true',
required=False,
help='Whether to tune GEMM performance before testing.',
)

def _preprocess(self):
"""Preprocess/preparation operations before the benchmarking.
Expand Down Expand Up @@ -356,6 +362,8 @@ def _preprocess(self):
(self._args.num_layers, self._args.num_warmup, self._args.num_steps)
if self._args.use_cuda_graph:
args += ' --use_cuda_graph'
if self._args.tune_gemm:
args += ' --tune_gemm'
self._commands = ['%s %s' % (self.__bin_path, args)]

return True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
*
*******************************************************************************/

#include <algorithm>
#include <chrono>
#include <cstdio>
#include <cstdlib>
Expand Down Expand Up @@ -60,6 +61,21 @@ using cublasLtHalf = hipblasLtHalf;
#else
#define DIST_INF_HIP_COMPUTETYPE_F32 HIPBLASLT_COMPUTE_F32
#endif
#if HIP_VERSION >= 50700000
#include <hipblaslt/hipblaslt-ext.hpp>
#if HIP_VERSION >= 60000000
#define HIPBLASLT_GETINDEXFROMALGO(algo) hipblaslt_ext::getIndexFromAlgo(algo)
#else
static int getIndexFromAlgo(hipblasLtMatmulAlgo_t &algo) {
int *algo_ptr = (int *)algo.data;
if (*algo_ptr < 0) {
return -1;
}
return *algo_ptr;
}
#define HIPBLASLT_GETINDEXFROMALGO(algo) getIndexFromAlgo(algo)
#endif
#endif
#else
#include <cublasLt.h>
#include <cuda_fp16.h>
Expand Down Expand Up @@ -94,23 +110,26 @@ using cublasLtHalf = half;
#endif

static void ShowUsage(char *argv[]) {
std::cerr << "Usage: " << argv[0] << " <options>\n"
<< "options:\n"
<< "\t-h, --help\t\t\t\tShow this help message\n"
<< "\t-m \t\t\tm\t\tGEMM_STRIDED argument m\n"
<< "\t-n \t\t\tn\t\tGEMM_STRIDED argument n\n"
<< "\t-k \t\t\tk \t\tGEMM_STRIDED argument k\n"
<< "\t--alpha \t\talpha \t\tGEMM_STRIDED argument alpha\n"
<< "\t--beta \t\t\tbeta \t\tGEMM_STRIDED argument beta\n"
<< "\t--num_layers \t\t\tnum_layers \t\tNumber of layers in the model\n"
<< "\t--num_warmups \t\t\tnum_warmups \t\tNumber of warmup runs\n"
<< "\t--num_iters \t\t\tnum_iters \t\tNumber of test runs\n"
<< "\t--use_cuda_graph \t\t\tuse_cuda_graph \t\tWhether to launch kernels in CUDA graph mode\n"
<< std::endl;
std::cerr
<< "Usage: " << argv[0] << " <options>\n"
<< "options:\n"
<< "\t-h, --help\t\t\t\tShow this help message\n"
<< "\t-m \t\t\tm\t\tGEMM_STRIDED argument m\n"
<< "\t-n \t\t\tn\t\tGEMM_STRIDED argument n\n"
<< "\t-k \t\t\tk \t\tGEMM_STRIDED argument k\n"
<< "\t--alpha \t\talpha \t\tGEMM_STRIDED argument alpha\n"
<< "\t--beta \t\t\tbeta \t\tGEMM_STRIDED argument beta\n"
<< "\t--num_layers \t\t\tnum_layers \t\tNumber of layers in the model\n"
<< "\t--num_warmups \t\t\tnum_warmups \t\tNumber of warmup runs\n"
<< "\t--num_iters \t\t\tnum_iters \t\tNumber of test runs\n"
<< "\t--use_cuda_graph \t\t\tuse_cuda_graph \t\tWhether to launch kernels in CUDA graph mode\n"
<< "\t--tune_gemm \t\t\ttune_gemm \t\tWhether to tune GEMM before testing. Currently only work for hipblasLt.\n"
<< std::endl;
}

static int ParseArguments(int argc, char *argv[], int64_t *m, int64_t *n, int64_t *k, float *alpha, float *beta,
int32_t *num_layers, int32_t *num_warmups, int32_t *num_iters, bool *use_cuda_graph) {
int32_t *num_layers, int32_t *num_warmups, int32_t *num_iters, bool *use_cuda_graph,
bool *tune_gemm) {
if (argc >= 2) {
for (int i = 1; i < argc; ++i) {
std::string arg = argv[i];
Expand Down Expand Up @@ -143,6 +162,8 @@ static int ParseArguments(int argc, char *argv[], int64_t *m, int64_t *n, int64_
std::cerr << "not supported by current environment" << std::endl << std::endl;
return -1;
#endif
} else if (arg == "--tune_gemm") {
*tune_gemm = true;
} else {
std::cerr << "error with " << arg << std::endl;
std::cerr << "do not recognize option" << std::endl << std::endl;
Expand Down Expand Up @@ -182,10 +203,91 @@ void InitializeABCDEF(std::vector<cublasLtHalf> &ha, int64_t size_a, std::vector
}
}

#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 50700000
// Tune GEMM algorithm in local rank.
// Write <0 to ret_algo_time_in_ms if nothing found.
// Write >=0 to ret_algo_time_in_ms and write ret_algo if something is found.
void TuneHipblasLtGemmLocal(const hipblasLtHandle_t &handle, const hipblasLtMatmulDesc_t &matmul, float alpha, void *da,
const hipblasLtMatrixLayout_t &matA, void *db, const hipblasLtMatrixLayout_t &matB,
float beta, void *dc, const hipblasLtMatrixLayout_t &matC, void *dd,
const hipblasLtMatrixLayout_t &matD, void *d_workspace, uint64_t workspace_size,
const cudaStream_t &stream, int rank, int num_ranks, hipblasLtMatmulAlgo_t *ret_algo,
float *ret_algo_time_in_ms) {
std::vector<hipblasLtMatmulHeuristicResult_t> gemm_heuristics;
// Get all possible algorithms
CHECK_CUBLASLT_ERROR(hipblaslt_ext::getAllAlgos(
handle, hipblaslt_ext::GemmType::HIPBLASLT_GEMM, HIPBLAS_OP_N, HIPBLAS_OP_N, DIST_INF_HIP_DATATYPE_R_16F,
DIST_INF_HIP_DATATYPE_R_16F, DIST_INF_HIP_DATATYPE_R_16F, DIST_INF_HIP_DATATYPE_R_16F,
DIST_INF_HIP_COMPUTETYPE_F32, gemm_heuristics));
// Make sure the algorithm order is deterministic
std::sort(gemm_heuristics.begin(), gemm_heuristics.end(),
[](hipblasLtMatmulHeuristicResult_t &a, hipblasLtMatmulHeuristicResult_t &b) {
return HIPBLASLT_GETINDEXFROMALGO(a.algo) < HIPBLASLT_GETINDEXFROMALGO(b.algo);
});
// Timing utilities
cudaEvent_t start_event;
cudaEvent_t end_event;
const int kNumWarmups = 10;
const int kNumTestRuns = 100;
*ret_algo_time_in_ms = -1;
// Benchmark all algorithms in given shape
CHECK_CUDA_ERROR(cudaEventCreate(&start_event));
CHECK_CUDA_ERROR(cudaEventCreate(&end_event));
// Partition work evenly into different ranks
for (size_t algo_idx = rank; algo_idx < gemm_heuristics.size(); algo_idx += num_ranks) {
auto &algo = gemm_heuristics[algo_idx].algo;
size_t ret_workspace_size = 0;
auto status = hipblaslt_ext::matmulIsAlgoSupported(handle, matmul, &alpha, matA, matB, &beta, matC, matD, algo,
ret_workspace_size);
if (status != HIPBLAS_STATUS_SUCCESS || ret_workspace_size >= workspace_size) {
continue;
}
for (int i = 0; i < kNumWarmups; i++) {
CHECK_CUBLASLT_ERROR(hipblasLtMatmul(handle, matmul, &alpha, da, matA, db, matB, &beta, dc, matC, dd, matD,
&algo, d_workspace, workspace_size, stream));
}
CHECK_CUDA_ERROR(cudaEventRecord(start_event, stream));
for (int i = 0; i < kNumTestRuns; i++) {
CHECK_CUBLASLT_ERROR(hipblasLtMatmul(handle, matmul, &alpha, da, matA, db, matB, &beta, dc, matC, dd, matD,
&algo, d_workspace, workspace_size, stream));
}
CHECK_CUDA_ERROR(cudaEventRecord(end_event, stream));
CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
float time_in_ms = 0;
CHECK_CUDA_ERROR(cudaEventElapsedTime(&time_in_ms, start_event, end_event));
time_in_ms /= kNumTestRuns;
if (*ret_algo_time_in_ms < 0 || time_in_ms < *ret_algo_time_in_ms) {
*ret_algo = algo;
*ret_algo_time_in_ms = time_in_ms;
}
}
CHECK_CUDA_ERROR(cudaEventDestroy(start_event));
CHECK_CUDA_ERROR(cudaEventDestroy(end_event));
}

// Select global best GEMM algorithms across ranks. Write global_algo if something is found.
void TuneHipblasLtGemmGlobal(int num_ranks, const hipblasLtMatmulAlgo_t &local_algo, float local_time_in_ms,
hipblasLtMatmulAlgo_t *global_algo) {
std::vector<hipblasLtMatmulAlgo_t> coll_algo(num_ranks);
std::vector<float> coll_time_in_ms(num_ranks);
MPI_Allgather(&local_algo, sizeof(local_algo), MPI_BYTE, coll_algo.data(), sizeof(local_algo), MPI_BYTE,
MPI_COMM_WORLD);
MPI_Allgather(&local_time_in_ms, sizeof(local_time_in_ms), MPI_BYTE, coll_time_in_ms.data(),
sizeof(local_time_in_ms), MPI_BYTE, MPI_COMM_WORLD);
float min_time_in_ms = -1;
for (int i = 0; i < num_ranks; i++) {
if (coll_time_in_ms[i] >= 0 && (min_time_in_ms < 0 || coll_time_in_ms[i] < min_time_in_ms)) {
min_time_in_ms = coll_time_in_ms[i];
*global_algo = coll_algo[i];
}
}
}
#endif

// B[m, k] * A[k, n] + C[m, n] = D[m, n]
// E[k, m] * D[m, n] + F[k, n] = G[k, n]
void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t num_layers, int32_t num_warmups,
int32_t num_iters, bool use_cuda_graph, ncclComm_t nccl_comm) {
int32_t num_iters, bool use_cuda_graph, bool tune_gemm, ncclComm_t nccl_comm, int rank, int num_ranks) {
const int kNcclBufAlignment = 512;

int size_a = k * n;
Expand Down Expand Up @@ -230,7 +332,11 @@ void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t
CHECK_CUDA_ERROR(cudaMemcpy(de, he.data(), sizeof(cublasLtHalf) * size_e, cudaMemcpyHostToDevice));
CHECK_CUDA_ERROR(cudaMemcpy(df, hf.data(), sizeof(cublasLtHalf) * size_f, cudaMemcpyHostToDevice));

#if defined(__HIP_PLATFORM_AMD__)
uint64_t workspace_size = 256 * 1024 * 1024; // max workspace size allowed for hipblaslt
#else
uint64_t workspace_size = 1024 * 1024;
#endif
void *d_workspace;
CHECK_CUDA_ERROR(cudaMalloc(&d_workspace, workspace_size));
int returnedAlgoCount = 0;
Expand Down Expand Up @@ -279,8 +385,22 @@ void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t
// E[k, m] * D[m, n] + F[k, n] = G[k, n]
CHECK_CUBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, matmul1, matB, matA, matC, matD, pref, 1,
heuristicResult1, &returnedAlgoCount));
hipblasLtMatmulAlgo_t algo1 = heuristicResult1[0].algo;
CHECK_CUBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, matmul2, matE, matD, matF, matG, pref, 1,
heuristicResult2, &returnedAlgoCount));
hipblasLtMatmulAlgo_t algo2 = heuristicResult2[0].algo;
#if HIP_VERSION >= 50700000
if (tune_gemm) {
hipblasLtMatmulAlgo_t ret_algo;
float ret_algo_time_in_ms;
TuneHipblasLtGemmLocal(handle, matmul1, alpha, db, matB, da, matA, beta, dc, matC, dd, matD, d_workspace,
workspace_size, stream, rank, num_ranks, &ret_algo, &ret_algo_time_in_ms);
TuneHipblasLtGemmGlobal(num_ranks, ret_algo, ret_algo_time_in_ms, &algo1);
TuneHipblasLtGemmLocal(handle, matmul2, alpha, de, matE, dd, matD, beta, df, matF, dg, matG, d_workspace,
workspace_size, stream, rank, num_ranks, &ret_algo, &ret_algo_time_in_ms);
TuneHipblasLtGemmGlobal(num_ranks, ret_algo, ret_algo_time_in_ms, &algo2);
}
#endif
#else
cublasLtHandle_t handle;
cublasLtMatrixLayout_t matA, matB, matC, matD, matE, matF, matG;
Expand Down Expand Up @@ -328,13 +448,13 @@ void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t
// cublasLt is not well supported by ROCm hipify tools, explicitly define ROCm logic instead.
#if defined(__HIP_PLATFORM_AMD__)
CHECK_CUBLASLT_ERROR(hipblasLtMatmul(handle, matmul1, &alpha, db, matB, da, matA, &beta, dc, matC, dd, matD,
&heuristicResult1[0].algo, d_workspace, workspace_size, stream));
CHECK_CUBLASLT_ERROR(hipblasLtMatmul(handle, matmul1, &alpha, de, matE, dd, matD, &beta, df, matF, dg, matG,
&heuristicResult2[0].algo, d_workspace, workspace_size, stream));
&algo1, d_workspace, workspace_size, stream));
CHECK_CUBLASLT_ERROR(hipblasLtMatmul(handle, matmul2, &alpha, de, matE, dd, matD, &beta, df, matF, dg, matG,
&algo2, d_workspace, workspace_size, stream));
#else
CHECK_CUBLASLT_ERROR(cublasLtMatmul(handle, matmul1, &alpha, db, matB, da, matA, &beta, dc, matC, dd, matD,
&heuristicResult1[0].algo, d_workspace, workspace_size, stream));
CHECK_CUBLASLT_ERROR(cublasLtMatmul(handle, matmul1, &alpha, de, matE, dd, matD, &beta, df, matF, dg, matG,
CHECK_CUBLASLT_ERROR(cublasLtMatmul(handle, matmul2, &alpha, de, matE, dd, matD, &beta, df, matF, dg, matG,
&heuristicResult2[0].algo, d_workspace, workspace_size, stream));
#endif
CHECK_NCCL_ERROR(ncclAllReduce(dg, dg, size_g, ncclFloat16, ncclSum, nccl_comm, stream));
Expand Down Expand Up @@ -456,18 +576,21 @@ int main(int argc, char *argv[]) {
int32_t num_warmups = 20;
int32_t num_iters = 100;
bool use_cuda_graph = false;
bool tune_gemm = false;

if (ParseArguments(argc, argv, &m, &n, &k, &alpha, &beta, &num_layers, &num_warmups, &num_iters, &use_cuda_graph)) {
if (ParseArguments(argc, argv, &m, &n, &k, &alpha, &beta, &num_layers, &num_warmups, &num_iters, &use_cuda_graph,
&tune_gemm)) {
ShowUsage(argv);
return -1;
}

fprintf(stdout,
"Parameters: m=%ld, n=%ld, k=%ld, alpha=%f, beta=%f, num_layers=%d, num_warmups=%d, num_iters=%d, "
"use_cuda_graph=%d\n",
m, n, k, alpha, beta, num_layers, num_warmups, num_iters, (int)use_cuda_graph);
"use_cuda_graph=%d, tune_gemm=%d\n",
m, n, k, alpha, beta, num_layers, num_warmups, num_iters, (int)use_cuda_graph, (int)tune_gemm);

TestModel(m, n, k, alpha, beta, num_layers, num_warmups, num_iters, use_cuda_graph, nccl_comm);
TestModel(m, n, k, alpha, beta, num_layers, num_warmups, num_iters, use_cuda_graph, tune_gemm, nccl_comm, comm_rank,
comm_size);

CHECK_NCCL_ERROR(ncclCommDestroy(nccl_comm));

Expand Down
7 changes: 5 additions & 2 deletions tests/benchmarks/micro_benchmarks/test_dist_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_pytorch_dist_inference_normal():
assert (benchmark._args.distributed_impl == DistributedImpl.DDP)
assert (benchmark._args.distributed_backend == DistributedBackend.NCCL)
assert (benchmark._args.use_cuda_graph is False)
assert (benchmark._args.tune_gemm is False)

# Check results and metrics.
assert (benchmark.run_count == 1)
Expand Down Expand Up @@ -98,6 +99,7 @@ def test_pytorch_dist_inference_fake_distributed():
assert (benchmark._args.distributed_impl == DistributedImpl.DDP)
assert (benchmark._args.distributed_backend == DistributedBackend.NCCL)
assert (benchmark._args.use_cuda_graph is False)
assert (benchmark._args.tune_gemm is False)

# Check results and metrics.
assert (benchmark.run_count == 1)
Expand Down Expand Up @@ -136,7 +138,7 @@ def _test_dist_inference_command_generation(self, platform):
num_steps = 8
wrapper_params_format_str = \
'--batch_size %d --input_size %d --hidden_size %d ' \
'--alpha %g --beta %g --num_layers %d --num_warmup %d --num_steps %d --use_cuda_graph'
'--alpha %g --beta %g --num_layers %d --num_warmup %d --num_steps %d --use_cuda_graph --tune_gemm'
parameters = wrapper_params_format_str % (
batch_size, input_size, hidden_size, alpha, beta, num_layers, num_warmup, num_steps
)
Expand All @@ -161,14 +163,15 @@ def _test_dist_inference_command_generation(self, platform):
assert (benchmark._args.num_warmup == num_warmup)
assert (benchmark._args.num_steps == num_steps)
assert (benchmark._args.use_cuda_graph is True)
assert (benchmark._args.tune_gemm is True)

# Check command
assert (1 == len(benchmark._commands))
for cmd in benchmark._commands:
m, n, k = hidden_size, batch_size, input_size
bench_params_format_str = \
'%s -m %d -n %d -k %d --alpha %g --beta %g ' + \
'--num_layers %d --num_warmups %d --num_iters %d --use_cuda_graph'
'--num_layers %d --num_warmups %d --num_iters %d --use_cuda_graph --tune_gemm'
assert (
cmd == (
bench_params_format_str %
Expand Down

0 comments on commit cc89ee5

Please sign in to comment.