Which component has the problem?
CuTe DSL
Bug Report
Describe the bug
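When N equals 8, the SM90 WGMMA atom reports an incorrect TV Layout C: the value mode is printed as (2,2):(64,8) instead of (2,2,1):(64,8,512), so the size-1 N/8 mode defined by CLayout_64xN is missing.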
Steps/Code to reproduce bug
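import cutlass
import cutlass.cute as cute
import cutlass.utils.hopper_helpers as sm90_utils


@cute.jit
def make_mma(N: cutlass.Constexpr):
    # Build a single-atom SM90 WGMMA tiled MMA for a 64 x N tile and print its layouts.
    tiled_mma = sm90_utils.make_trivial_tiled_mma(
        cutlass.Float16,                  # A element type
        cutlass.Float16,                  # B element type
        cute.nvgpu.OperandMajorMode.K,    # A major mode
        cute.nvgpu.OperandMajorMode.K,    # B major mode
        cutlass.Float32,                  # accumulator type
        (1, 1, 1),                        # atom layout (M, N, K)
        tiler_mn=(64, N),
    )
    print(tiled_mma)


if __name__ == '__main__':
    # N = 8 is the case reported as incorrect; the larger values are printed for comparison.
    make_mma(8)
    make_mma(16)
    make_mma(32)
    make_mma(64)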
Expected behavior
Wrong result (current output for N = 8):
Tiled MMA
Thr Layout VMNK: (128,1,1,1):(1,0,0,0)
Permutation MNK: (_,_,_)
MMA Atom
ThrID: 128:1
Shape MNK: (64,8,16)
TV Layout A: (128,(64,16)):(0,(1,64))
TV Layout B: (128,(8,16)):(0,(1,8))
TV Layout C: ((4,8,4),(2,2)):((128,1,16),(64,8))
Right result (expected output for N = 8):
Tiled MMA
Thr Layout VMNK: (128,1,1,1):(1,0,0,0)
Permutation MNK: (_,_,_)
MMA Atom
ThrID: 128:1
Shape MNK: (64,8,16)
TV Layout A: (128,(64,16)):(0,(1,64))
TV Layout B: (128,(8,16)):(0,(1,8))
TV Layout C: ((4,8,4),(2,2,1)):((128,1,16),(64,8,512))
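The two printouts differ only in the value mode of TV Layout C. As a rough illustration, the plain-Python sketch below (no CUTLASS required; `flatten` and `offsets` are hypothetical helpers written for this issue, not CuTe DSL APIs) evaluates both layouts as functions from (thread, value) coordinates to linear offsets. Under the assumption that the layouts are read exactly as printed, it suggests the mapping itself is the same and only the rank of the value mode differs.

```python
from itertools import product


def flatten(t):
    """Flatten a nested tuple (or a bare int) into a flat list of ints."""
    if isinstance(t, int):
        return [t]
    return [x for item in t for x in flatten(item)]


def offsets(shape, stride):
    """Offsets of every coordinate, iterating with the leftmost mode fastest."""
    shp, strd = flatten(shape), flatten(stride)
    result = []
    # itertools.product varies its last axis fastest, so iterate over the
    # reversed shape and flip each coordinate back.
    for rev_coord in product(*(range(s) for s in reversed(shp))):
        coord = rev_coord[::-1]
        result.append(sum(c * d for c, d in zip(coord, strd)))
    return result


# TV Layout C as printed by the DSL for N = 8 (reported as wrong).
reported_shape, reported_stride = ((4, 8, 4), (2, 2)), ((128, 1, 16), (64, 8))
# TV Layout C as defined by CLayout_64xN with N = 8 (expected).
expected_shape, expected_stride = ((4, 8, 4), (2, 2, 1)), ((128, 1, 16), (64, 8, 512))

reported = offsets(reported_shape, reported_stride)
expected = offsets(expected_shape, expected_stride)

print("same mapping:", reported == expected)
print("value-mode rank:", len(reported_shape[1]), "vs", len(expected_shape[1]))
```

If that reading is right, the symptom is a dropped size-1 mode rather than a wrong address computation, which would mainly matter to code that indexes or partitions the value mode by rank.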
Reference Code:
// Accumulator layouts
template<int N>
using CLayout_64xN = Layout<Shape <Shape <  _4,_8, _4>,Shape < _2,_2,Int<N/8>>>,
                            Stride<Stride<_128,_1,_16>,Stride<_64,_8,   _512>>>;
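To make the expectation concrete for every N in the reproducer, here is a minimal plain-Python sketch that mirrors the C++ template above as nested tuples. The function names `clayout_64xN`, `fmt`, and `flatten` are hypothetical and exist only for this illustration; they are not part of CUTLASS or the CuTe DSL.

```python
from itertools import product


def clayout_64xN(n):
    """Shape and stride of CLayout_64xN, keeping the N/8 mode even when it is 1."""
    assert n % 8 == 0, "N is assumed to be a multiple of 8"
    shape = ((4, 8, 4), (2, 2, n // 8))
    stride = ((128, 1, 16), (64, 8, 512))
    return shape, stride


def fmt(t):
    """Format a nested tuple the way the layouts are printed in this issue."""
    if isinstance(t, int):
        return str(t)
    return "(" + ",".join(fmt(x) for x in t) + ")"


def flatten(t):
    """Flatten a nested tuple (or a bare int) into a flat list of ints."""
    if isinstance(t, int):
        return [t]
    return [x for item in t for x in flatten(item)]


def covers_tile(shape, stride, n):
    """Check that the layout hits each of the 64 * n offsets exactly once."""
    shp, strd = flatten(shape), flatten(stride)
    hit = sorted(
        sum(c * d for c, d in zip(coord, strd))
        for coord in product(*(range(s) for s in shp))
    )
    return hit == list(range(64 * n))


for n in (8, 16, 32, 64):
    shape, stride = clayout_64xN(n)
    print(f"N={n}: {fmt(shape)}:{fmt(stride)} covers tile: {covers_tile(shape, stride, n)}")
```

For N = 8 this sketch yields a rank-3 value mode, matching the expected printout in this issue; the other values of N only change the extent of the last value mode.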
Environment details (please complete the following information):
- nvidia-cutlass-dsl 4.5.1
- nvidia-cutlass-dsl-libs-base 4.5.1
- nvidia-cutlass-dsl-libs-cu13 4.5.1
Additional context
None.
Reference code location: cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp, lines 432 to 435 at commit 982cb9e