diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 55fe0d9eb4771..2a8c330b761be 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -798,6 +798,26 @@ LogicalResult MmaOp::verify() { " attribute"); } + // Validate layout combinations. According to the operation description, most + // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16 + // can use other layout combinations. + bool isM8N8K4_F16 = + (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 && + getMultiplicandAPtxType() == MMATypes::f16); + + if (!isM8N8K4_F16) { + // For all other shapes/types, layoutA must be row and layoutB must be col + if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) { + return emitOpError("requires layoutA = #nvvm.mma_layout and " + "layoutB = #nvvm.mma_layout for shape <") + << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2] + << "> with element types " + << stringifyEnum(*getMultiplicandAPtxType()) << " and " + << stringifyEnum(*getMultiplicandBPtxType()) + << ". Only m8n8k4 with f16 supports other layouts."; + } + } + return success(); } diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index b2fe2f53b5a4e..6cccfe424d293 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -568,6 +568,18 @@ llvm.func @clusterlaunchcontrol_query_cancel_get_first_cta_id_invalid_return_typ llvm.return } +// ----- + +// Test that ensures invalid row/col layouts for matrices A and B are not accepted +llvm.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> { + // expected-error@+1 {{Only m8n8k4 with f16 supports other layouts.}} + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, + intOverflowBehavior=#nvvm.mma_int_overflow, + shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> +} // -----