2 changes: 1 addition & 1 deletion mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -27,7 +27,7 @@ class MMAMatrixType;
#define GEN_PASS_DECL_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"

LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
Type convertMMAToLLVMType(gpu::MMAMatrixType type);

/// Configure target to convert from the GPU dialect to NVVM.
void configureGpuToNVVMConversionLegality(ConversionTarget &target);
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -114,7 +114,7 @@ def GPU_MMAMatrix : DialectType<
GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;

// Memref type acceptable to gpu.subgroup_mma_{load|store}_matrix ops.
def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, VectorOfRankAndType<[1], [I8, I32, F16, F32]>]>;
def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, F64, VectorOfRankAndType<[1], [I8, I32, F16, F32, F64]>]>;

class MMAMatrixOf<list<Type> allowedTypes> :
ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,
8 changes: 4 additions & 4 deletions mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1872,7 +1872,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
```
}];

let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32]>>:$src,
let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32, F64]>>:$src,
Arg<GPU_MMAMemRef, "",[MemWriteAt<0, FullEffect>]>:$dstMemref,
Variadic<Index>:$indices,
IndexAttr:$leadDimension,
@@ -1919,9 +1919,9 @@ def GPU_SubgroupMmaComputeOp
```
}];

let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opA,
Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opB,
Arg<MMAMatrixOf<[I32, F16, F32]>>:$opC,
let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32, F64]>>:$opA,
Arg<MMAMatrixOf<[SI8, UI8, F16, F32, F64]>>:$opB,
Arg<MMAMatrixOf<[I32, F16, F32, F64]>>:$opC,
OptionalAttr<UnitAttr>:$a_transpose,
OptionalAttr<UnitAttr>:$b_transpose);

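With the relaxed operand constraints above, f64 fragments become legal operands of the compute and store ops. A minimal MLIR sketch of the m8n8k4 double-precision shape (mirroring the integration test added later in this PR; the `%a`, `%b`, `%c`, `%d` memrefs and the `%c0` index are assumed to be defined as in that test):

```mlir
// f64 WMMA: 8x4 "AOp" times 4x8 "BOp", accumulated into an 8x8 "COp".
%A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 4 : index}
       : memref<8x4xf64> -> !gpu.mma_matrix<8x4xf64, "AOp">
%B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 8 : index}
       : memref<4x8xf64> -> !gpu.mma_matrix<4x8xf64, "BOp">
%C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 8 : index}
       : memref<8x8xf64> -> !gpu.mma_matrix<8x8xf64, "COp">
%R = gpu.subgroup_mma_compute %A, %B, %C
       : !gpu.mma_matrix<8x4xf64, "AOp">, !gpu.mma_matrix<4x8xf64, "BOp">
      -> !gpu.mma_matrix<8x8xf64, "COp">
gpu.subgroup_mma_store_matrix %R, %d[%c0, %c0] {leadDimension = 8 : index}
  : !gpu.mma_matrix<8x8xf64, "COp">, memref<8x8xf64>
```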
52 changes: 42 additions & 10 deletions mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"

using namespace mlir;

@@ -57,7 +58,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
if (type.getElementType().isF32())
return type.getOperand() == "COp" ? NVVM::MMATypes::f32
: NVVM::MMATypes::tf32;

if (type.getElementType().isF64())
return NVVM::MMATypes::f64;
if (type.getElementType().isSignedInteger(8))
return NVVM::MMATypes::s8;
if (type.getElementType().isUnsignedInteger(8))
@@ -212,8 +214,13 @@ struct WmmaMmaOpToNVVMLowering
// then passed on to the intrinsic call. Emit llvm ops to extract individual
// values form lowered memrefs.
SmallVector<Value> unpackedOps;

auto unpackOp = [&](Value operand) {
// f64 a and b fragments are not structs but scalars.
if (!isa<LLVM::LLVMStructType>(operand.getType())) {
unpackedOps.push_back(operand);
return;
}
// every other type is lowered to an LLVM struct, extract the values.
auto structType = cast<LLVM::LLVMStructType>(operand.getType());
for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i);
@@ -276,10 +283,16 @@ struct WmmaConstantOpToNVVMLowering
return failure();
Location loc = subgroupMmaConstantOp.getLoc();
Value cst = adaptor.getOperands()[0];
LLVM::LLVMStructType type = convertMMAToLLVMType(
Type type = convertMMAToLLVMType(
cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
// If the element is not a struct, it means it's a scalar f64.
LLVM::LLVMStructType structType = dyn_cast<LLVM::LLVMStructType>(type);
if (!structType) {
rewriter.replaceOp(subgroupMmaConstantOp, cst);
return success();
}
// If the element type is a vector create a vector from the operand.
if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
if (auto vecType = dyn_cast<VectorType>(structType.getBody()[0])) {
Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType);
for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
Value idx = LLVM::ConstantOp::create(rewriter, loc,
@@ -289,8 +302,8 @@
}
cst = vecCst;
}
Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type);
for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structType);
for (size_t i : llvm::seq(size_t(0), structType.getBody().size())) {
matrixStruct =
LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i);
}
@@ -354,10 +367,24 @@ struct WmmaElementwiseOpToNVVMLowering
return failure();
Location loc = subgroupMmaElementwiseOp.getLoc();
size_t numOperands = adaptor.getOperands().size();
LLVM::LLVMStructType destType = convertMMAToLLVMType(
Type destType = convertMMAToLLVMType(
cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType);
for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {

// If the element is not a struct, it means it's a scalar f64.
LLVM::LLVMStructType structDestTy =
dyn_cast<LLVM::LLVMStructType>(destType);
if (!structDestTy) {
SmallVector<Value> operands;
for (auto operand : adaptor.getOperands()) {
operands.push_back(operand);
}
Value element = createScalarOp(
rewriter, loc, subgroupMmaElementwiseOp.getOpType(), operands);
rewriter.replaceOp(subgroupMmaElementwiseOp, element);
return success();
}
Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structDestTy);
for (size_t i = 0, e = structDestTy.getBody().size(); i < e; ++i) {
SmallVector<Value> extractedOperands;
for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
extractedOperands.push_back(LLVM::ExtractValueOp::create(
Expand All @@ -377,13 +404,18 @@ struct WmmaElementwiseOpToNVVMLowering
} // namespace

/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
Type mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
NVVM::MMAFrag frag = convertOperand(type.getOperand());
NVVM::MMATypes eltType = getElementType(type);
auto nRow = type.getShape()[0];
auto nCol = type.getShape()[1];
std::pair<Type, unsigned> typeInfo =
NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
// Special handling for f64 a and b fragments
Type f64Ty = Float64Type::get(type.getContext());
if (typeInfo.first == f64Ty && typeInfo.second == 1) {
return f64Ty;
}
return LLVM::LLVMStructType::getLiteral(
type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
}
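The net effect of the `Type` return value and the scalar special cases above: f64 "AOp"/"BOp" fragments convert to a plain `f64` rather than an `!llvm.struct`, so the load lowering produces a bare scalar (as the new wmma-ops-to-nvvm.mlir test below checks), while "COp" fragments and all other element types keep going through the struct path. A hedged sketch of the lowered f64 load, with the `%gep` pointer and `%stride` constant assumed to have been computed as in that test:

```mlir
// f64 "AOp" fragment: the NVVM load returns a plain f64, so no
// llvm.extractvalue / llvm.insertvalue wrapping is emitted around it.
%frag = nvvm.wmma.load %gep, %stride
          {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>,
           k = 4 : i32, layout = #nvvm.mma_layout<row>,
           m = 8 : i32, n = 8 : i32} : (!llvm.ptr<3>) -> f64
```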
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -208,7 +208,7 @@ Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
return elementType.isF16() || elementType.isF32() ||
return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
elementType.isInteger(32);
}
@@ -225,7 +225,7 @@ MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,

if (!MMAMatrixType::isValidElementType(elementType))
return emitError()
<< "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
<< "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";

return success();
}
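To make the widened element-type check concrete, a small sketch of what the verifier now accepts versus still rejects (the `%wg` memref and `%i` index are assumed to be defined as in the surrounding tests):

```mlir
// Accepted after this change: f64 elements in an MMA matrix fragment.
%ok = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index}
        : memref<32x32xf64, 3> -> !gpu.mma_matrix<8x4xf64, "AOp">

// Still rejected (bf16 is not in the allowed set), now with the message:
//   "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64"
```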
22 changes: 22 additions & 0 deletions mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -79,6 +79,28 @@ gpu.module @test_module {

// -----

gpu.module @test_module {

// CHECK-LABEL: func @gpu_wmma_f64_load_op() ->
// CHECK-SAME: f64
// CHECK32-LABEL: func @gpu_wmma_f64_load_op() ->
func.func @gpu_wmma_f64_load_op() -> (!gpu.mma_matrix<8x4xf64, "AOp">) {
%wg = memref.alloca() {alignment = 32} : memref<32x32xf64, 3>
%i = arith.constant 16 : index
%j = arith.constant 16 : index
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf64, 3> -> !gpu.mma_matrix<8x4xf64, "AOp">
return %0 : !gpu.mma_matrix<8x4xf64, "AOp">
// CHECK: %[[MUL:.*]] = llvm.mul %{{.*}}, %{{.*}} : i64
// CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
// CHECK: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f64
// CHECK: %[[C32_I32:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK: %[[LOAD:.*]] = nvvm.wmma.load %[[GEP]], %[[C32_I32]] {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32} : (!llvm.ptr<3>) -> f64
// CHECK: llvm.return %[[LOAD]] : f64
}
}

// -----

gpu.module @test_module {

// CHECK-LABEL: func @gpu_wmma_store_op
4 changes: 2 additions & 2 deletions mlir/test/Dialect/GPU/invalid.mlir
@@ -688,7 +688,7 @@ func.func @mmamatrix_operand_type(){
func.func @mmamatrix_invalid_element_type(){
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
%i = arith.constant 16 : index
// expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, or F32}}
// expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64}}
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xbf16, "AOp">
return
}
@@ -708,7 +708,7 @@ func.func @mmaLoadOp_identity_layout(){
// -----

func.func @mma_invalid_memref_type(%src: memref<32x4xvector<4x8xf32>>, %i: index) {
// expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float values of ranks 1 values}}
// expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or 64-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or 64-bit float values of ranks 1 values}}
%0 = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4x8xf32>> -> !gpu.mma_matrix<16x16xf16, "AOp">
return
}
72 changes: 72 additions & 0 deletions mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir
@@ -0,0 +1,72 @@
// RUN: mlir-opt %s \
// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 cubin-format=%gpu_compilation_format" \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_cuda_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
// RUN: --entry-point-result=void \
// RUN: | FileCheck %s

#map0 = affine_map<(d0, d1) -> (d1, d0)>

func.func @main() {
%a = memref.alloc() : memref<8x4xf64>
%b = memref.alloc() : memref<4x8xf64>
%c = memref.alloc() : memref<8x8xf64>
%d = memref.alloc() : memref<8x8xf64>

%f1 = arith.constant 1.0e+00 : f64
%fcst = arith.constant 3.14e+00 : f64
%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index

  // Initialize the input matrices with ones.
scf.for %arg0 = %c0 to %c8 step %c1 {
scf.for %arg1 = %c0 to %c4 step %c1 {
memref.store %f1, %a[%arg0, %arg1] : memref<8x4xf64>
memref.store %f1, %b[%arg1, %arg0] : memref<4x8xf64>
}
}
// Initialize the accumulator matrix with a constant.
scf.for %arg0 = %c0 to %c8 step %c1 {
scf.for %arg1 = %c0 to %c8 step %c1 {
memref.store %fcst, %c[%arg0, %arg1] : memref<8x8xf64>
}
}

%2 = memref.cast %a : memref<8x4xf64> to memref<*xf64>
%20 = memref.cast %b : memref<4x8xf64> to memref<*xf64>
%33 = memref.cast %c : memref<8x8xf64> to memref<*xf64>
%34 = memref.cast %d : memref<8x8xf64> to memref<*xf64>

gpu.host_register %2 : memref<*xf64>
gpu.host_register %20 : memref<*xf64>
gpu.host_register %33 : memref<*xf64>

gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
%A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 4 : index} : memref<8x4xf64> -> !gpu.mma_matrix<8x4xf64, "AOp">
%B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 8 : index} : memref<4x8xf64> -> !gpu.mma_matrix<4x8xf64, "BOp">
%C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 8 : index} : memref<8x8xf64> -> !gpu.mma_matrix<8x8xf64, "COp">

%R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<8x4xf64, "AOp">, !gpu.mma_matrix<4x8xf64, "BOp"> -> !gpu.mma_matrix<8x8xf64, "COp">

gpu.subgroup_mma_store_matrix %R, %d[%c0, %c0] {leadDimension = 8 : index}: !gpu.mma_matrix<8x8xf64, "COp">, memref<8x8xf64>
gpu.terminator
}
// Print the memref after computation.
call @printMemrefF64(%34) : (memref<*xf64>) -> ()
// CHECK: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
// CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
// CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
// CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
// CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
// CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
// CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
// CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14]
return
}

func.func private @printMemrefF64(memref<*xf64>)