Skip to content

Commit

Permalink
Fix matrix size overflow issue
Browse files Browse the repository at this point in the history
Fix matrix size overflow issue when cast from int to size_t implicitly.
  • Loading branch information
abuccts committed Mar 25, 2023
1 parent c88c970 commit 857a8ba
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ template <typename T> cudaDataType_t get_datatype() {
}

template <typename Ta, typename Tb, typename Tout>
float timing_matmul_tn(int m, int n, int k, int batch, int warmup, int iter) {
float timing_matmul_tn(size_t m, size_t n, size_t k, int batch, int warmup, int iter) {
// init matrix
Ta *matrix_a = nullptr;
Tb *matrix_b = nullptr;
Expand All @@ -101,7 +101,7 @@ float timing_matmul_tn(int m, int n, int k, int batch, int warmup, int iter) {
init_matrix<Tb><<<216, 1024>>>(matrix_b, 2.f, k * n * std::max(batch, 1));

// init gemm
int lda = k, ldb = k, ldd = m;
size_t lda = k, ldb = k, ldd = m;
std::unique_ptr<cublasLtGemm> gemm = std::make_unique<cublasLtGemm>();
gemm->Init();
gemm->Setup(m, n, k, batch, lda, ldb, ldd, get_datatype<Ta>(), get_datatype<Tb>(), get_datatype<Tout>(),
Expand Down

0 comments on commit 857a8ba

Please sign in to comment.