diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 4f483859ac18d..cccdc2c368d6d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2014,6 +2014,9 @@ class MMA_LDST_OPS<list<GEOM> Geom, list<string> Frags, list<string> Types> {
 // llvm supports and can be extended as needed.
 class NVVM_MMA_OPS {
   // "wmma" operations
+  list<list<WMMA_REGS>> fp64_wmma_ops = MMA_OPS<
+            [GEOM<8, 8, 4>],
+            ["f64"], [], ["f64"], []>.ret;
   list<list<WMMA_REGS>> tf32_wmma_ops = MMA_OPS<
             [GEOM<16, 16, 8>],
             ["tf32"], [], ["f32"], []>.ret;
@@ -2024,6 +2027,7 @@ class NVVM_MMA_OPS {
             [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
             ["s8","u8"], [], ["s32"], []>.ret;
   list<list<WMMA_REGS>> all_wmma_ops = !listconcat(
+            fp64_wmma_ops,
             tf32_wmma_ops,
             fp_wmma_ops,
             i8_wmma_ops);
@@ -2040,9 +2044,17 @@ class NVVM_MMA_OPS {
   list<WMMA_REGS> ldst_tf32_cd_ops = MMA_LDST_OPS<
             [GEOM<16, 16, 8>],
             ["c", "d"], ["f32"]>.ret;
+  list<WMMA_REGS> ldst_f64_ab_ops = MMA_LDST_OPS<
+            [GEOM<8, 8, 4>],
+            ["a", "b"], ["f64"]>.ret;
+  list<WMMA_REGS> ldst_f64_cd_ops = MMA_LDST_OPS<
+            [GEOM<8, 8, 4>],
+            ["c", "d"], ["f64"]>.ret;
   list<WMMA_REGS> all_ldst_ops = !listconcat(ldst_ab_ops, ldst_cd_ops,
                                              ldst_tf32_ab_ops,
-                                             ldst_tf32_cd_ops);
+                                             ldst_tf32_cd_ops,
+                                             ldst_f64_ab_ops,
+                                             ldst_f64_cd_ops);
   // Separate A/B/C fragments (loads) from D (stores).
   list<WMMA_REGS> all_ld_ops = !filter(op, all_ldst_ops, !ne(op.frag, "d"));
   list<WMMA_REGS> all_st_ops = !filter(op, all_ldst_ops, !eq(op.frag, "d"));
@@ -2349,7 +2361,7 @@ def MMAFragAttr : EnumAttr<NVVM_Dialect, MMAFrag, "mma_frag"> {
 }
 
 def NVVM_WMMALoadOp: NVVM_Op<"wmma.load">,
-  Results<(outs LLVM_AnyStruct:$res)>,
+  Results<(outs AnyTypeOf<[LLVM_AnyStruct, F64]>:$res)>,
   Arguments<(ins LLVM_AnyPointer: $ptr, I32: $stride, I32Attr:$m,
              I32Attr:$n, I32Attr:$k, MMALayoutAttr:$layout,
              MMATypesAttr:$eltype, MMAFragAttr:$frag)> {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index f0de4dbcc1d4b..e08a47efedef7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -896,6 +896,12 @@ std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
   } else if (type == NVVM::MMATypes::f32) {
     elementType = builder.getF32Type();
     numberElements = 8;
+  } else if (type == NVVM::MMATypes::f64) {
+    elementType = builder.getF64Type();
+    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
+      numberElements = 1;
+    else
+      numberElements = 2;
   } else if (type == NVVM::MMATypes::tf32) {
     elementType = builder.getI32Type();
     numberElements = 4;
@@ -954,6 +960,14 @@ LogicalResult NVVM::WMMALoadOp::verify() {
     return emitOpError() << "invalid attribute combination";
   std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
       getEltype(), getFrag(), getM(), getN(), getK(), getContext());
+  // Special case: f64 a/b fragments hold a single f64 rather than a struct.
+  Type f64Ty = Float64Type::get(getContext());
+  if (typeInfo.first == f64Ty && typeInfo.second == 1) {
+    if (getType() != f64Ty)
+      return emitOpError("expected destination type to be f64");
+    return success();
+  }
+  // Everything else is a struct.
   Type dstType = LLVM::LLVMStructType::getLiteral(
       getContext(), SmallVector<Type>(typeInfo.second, typeInfo.first));
   if (getType() != dstType)
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 09b8f593154b5..42aa2210eae1a 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -621,3 +621,14 @@ func.func @invalid_range_equal_bounds() {
   %0 = nvvm.read.ptx.sreg.warpsize range <i32, 32, 32> : i32
   return
 }
+
+// -----
+
+// Test that the result type of an f64 wmma.load of fragment a is verified.
+llvm.func @nvvm_wmma_load_a_f64(%arg0: !llvm.ptr, %arg1 : i32) {
+  // expected-error @below {{'nvvm.wmma.load' op expected destination type to be f64}}
+  %0 = nvvm.wmma.load %arg0, %arg1
+    {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
+    : (!llvm.ptr) -> !llvm.struct<(f64)>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 594ae4849e3eb..9115de65ff0e8 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -463,6 +463,43 @@ llvm.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 :
   llvm.return
 }
 
+// CHECK-LABEL: @nvvm_wmma_load_a_f64
+llvm.func @nvvm_wmma_load_a_f64(%arg0: !llvm.ptr, %arg1 : i32) {
+  // CHECK: call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p0(ptr %{{.*}}, i32 %{{.*}})
+  %0 = nvvm.wmma.load %arg0, %arg1
+    {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
+    : (!llvm.ptr) -> f64
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_wmma_load_c_f64
+llvm.func @nvvm_wmma_load_c_f64(%arg0: !llvm.ptr, %arg1 : i32) {
+  // CHECK: call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p0(ptr %{{.*}}, i32 %{{.*}})
+  %0 = nvvm.wmma.load %arg0, %arg1
+    {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<c>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
+    : (!llvm.ptr) -> !llvm.struct<(f64, f64)>
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_wmma_mma_f64
+llvm.func @nvvm_wmma_mma_f64(%0 : f64, %1 : f64, %2 : f64, %3 : f64) {
+  // CHECK: { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.col.f64(double %{{.*}}, double %{{.*}}, double %{{.*}}, double %{{.*}})
+  %r = nvvm.wmma.mma %0, %1, %2, %3
+    {eltypeA = #nvvm.mma_type<f64>, eltypeB = #nvvm.mma_type<f64>, k = 4 : i32, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, m = 8 : i32, n = 8 : i32}
+    : (f64, f64, f64, f64)
+    -> !llvm.struct<(f64, f64)>
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_wmma_store_d_f64
+llvm.func @nvvm_wmma_store_d_f64(%arg0: !llvm.ptr, %arg1 : i32, %arg2 : f64, %arg3 : f64) {
+  // CHECK: call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p0(ptr %{{.*}}, double %{{.*}}, double %{{.*}}, i32 %{{.*}})
+  nvvm.wmma.store %arg0, %arg1, %arg2, %arg3
+    {eltype = #nvvm.mma_type<f64>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
+    : !llvm.ptr, f64, f64
+  llvm.return
+}
+
 // CHECK-LABEL: @cp_async
 llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
   // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})