Skip to content

Commit 857a8ba

Browse files
committed
Fix matrix size overflow issue
Fix matrix size overflow issue when cast from int to size_t implicitly.
1 parent c88c970 commit 857a8ba

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

superbench/benchmarks/micro_benchmarks/cublaslt_gemm/cublaslt_gemm.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ template <typename T> cudaDataType_t get_datatype() {
8888
}
8989

9090
template <typename Ta, typename Tb, typename Tout>
91-
float timing_matmul_tn(int m, int n, int k, int batch, int warmup, int iter) {
91+
float timing_matmul_tn(size_t m, size_t n, size_t k, int batch, int warmup, int iter) {
9292
// init matrix
9393
Ta *matrix_a = nullptr;
9494
Tb *matrix_b = nullptr;
@@ -101,7 +101,7 @@ float timing_matmul_tn(int m, int n, int k, int batch, int warmup, int iter) {
101101
init_matrix<Tb><<<216, 1024>>>(matrix_b, 2.f, k * n * std::max(batch, 1));
102102

103103
// init gemm
104-
int lda = k, ldb = k, ldd = m;
104+
size_t lda = k, ldb = k, ldd = m;
105105
std::unique_ptr<cublasLtGemm> gemm = std::make_unique<cublasLtGemm>();
106106
gemm->Init();
107107
gemm->Setup(m, n, k, batch, lda, ldb, ldd, get_datatype<Ta>(), get_datatype<Tb>(), get_datatype<Tout>(),

0 commit comments

Comments
 (0)