diff --git a/csrc/src/utils.h b/csrc/src/utils.h index 79d91fa..2e2df1b 100644 --- a/csrc/src/utils.h +++ b/csrc/src/utils.h @@ -497,7 +497,7 @@ __forceinline__ __device__ void copy( //////////////////////////////////////////////////////////////////////////////////////////////////// -template __forceinline__ __device__ void copy_ZOH( @@ -505,11 +505,11 @@ __forceinline__ __device__ void copy_ZOH( Tensor &D, Tensor const &identity_MN, const int max_M=0, const int max_N=0 ) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); // (MMA, M, N) - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); // (MMA, M, N) + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); // (MMA, MMA_M, MMA_N) + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); // (MMA, MMA_M, MMA_N) CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // N + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_N #pragma unroll for (int m = 0; m < size<1>(S); ++m) {