Skip to content

Commit

Permalink
Revise the syntax of miopen.mfma so it takes memref instead of vector.
Browse files Browse the repository at this point in the history
Disable the lowering from miopen.mfma to gpu.mfma, and mark its test as XFAIL, for now, as the lowering needs to be revised.
  • Loading branch information
whchung committed Jul 8, 2020
1 parent aab131f commit f482141
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 41 deletions.
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/MIOpen/MIOpenOps.td
Expand Up @@ -254,8 +254,7 @@ def MIOpen_MFMAOp:
MIOpen_Op<"mfma">,
Arguments<(ins F32: $sourceA,
F32: $sourceB,
VectorOfRankAndType<[1], [F32]>: $destC)>,
Results<(outs VectorOfRankAndType<[1], [F32]>: $destD)> {
MemRefRankOf<[F32], [1]>: $destC)> {
let summary = "XDLOPS MFMA";
let description = [{
The `miopen.mfma` op is an abstraction of XDLOPS.
Expand Down
30 changes: 15 additions & 15 deletions mlir/lib/Conversion/MIOpenToGPU/MIOpenToGPU.cpp
Expand Up @@ -236,21 +236,21 @@ void LowerMIOpenOpsToGPUPass::runOnOperation() {
op.erase();
});

gpuFunc.walk([&](miopen::MFMAOp op) {
auto loc = op.getLoc();
OpBuilder b(op.getContext());
b.setInsertionPoint(op);
Value newOp =
b.create<gpu::MFMAOp>(loc, op.getType(), op.sourceA(), op.sourceB(),
op.destC());
auto gpuMfmaOp = cast<gpu::MFMAOp>(newOp.getDefiningOp());
if (op.getAttr("m_per_wave"))
gpuMfmaOp.setAttr("m_per_wave", op.getAttr("m_per_wave"));
if (op.getAttr("n_per_wave"))
gpuMfmaOp.setAttr("n_per_wave", op.getAttr("n_per_wave"));
op.replaceAllUsesWith(newOp);
op.erase();
});
// TODO: The miopen.mfma -> gpu.mfma lowering below is intentionally disabled.
// It was written for the value-returning form of miopen.mfma: it forwards
// op.getType() as the gpu.mfma result type and rewires uses through
// op.replaceAllUsesWith(). Revise it for the form of miopen.mfma that takes
// destC as a memref before re-enabling it.
//gpuFunc.walk([&](miopen::MFMAOp op) {
// auto loc = op.getLoc();
// OpBuilder b(op.getContext());
// b.setInsertionPoint(op);
// Value newOp =
// b.create<gpu::MFMAOp>(loc, op.getType(), op.sourceA(), op.sourceB(),
// op.destC());
// auto gpuMfmaOp = cast<gpu::MFMAOp>(newOp.getDefiningOp());
// if (op.getAttr("m_per_wave"))
// gpuMfmaOp.setAttr("m_per_wave", op.getAttr("m_per_wave"));
// if (op.getAttr("n_per_wave"))
// gpuMfmaOp.setAttr("n_per_wave", op.getAttr("n_per_wave"));
// op.replaceAllUsesWith(newOp);
// op.erase();
//});
});
}
}
Expand Down
24 changes: 7 additions & 17 deletions mlir/lib/Dialect/MIOpen/MIOpenOps.cpp
Expand Up @@ -512,29 +512,19 @@ static LogicalResult verify(ThreadwiseCopyOp op) {
//===----------------------------------------------------------------------===//

static ParseResult parseMFMAOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType sourceA, sourceB, destC;
Type destType;
SmallVector<OpAsmParser::OperandType, 3> ops;
SmallVector<Type, 3> types;
return failure(
parser.parseLParen() ||
parser.parseOperand(sourceA) ||
parser.parseComma() ||
parser.parseOperand(sourceB) ||
parser.parseComma() ||
parser.parseOperand(destC) ||
parser.parseRParen() ||
parser.parseOperandList(ops, OpAsmParser::Delimiter::Paren) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(destType) ||
parser.resolveOperand(sourceA, parser.getBuilder().getF32Type(), result.operands) ||
parser.resolveOperand(sourceB, parser.getBuilder().getF32Type(), result.operands) ||
parser.resolveOperand(destC, destType, result.operands) ||
parser.addTypeToList(destType, result.types));
return success();
parser.parseColonTypeList(types) ||
parser.resolveOperands(ops, types, parser.getNameLoc(), result.operands));
}

static void print(OpAsmPrinter &p, miopen::MFMAOp op) {
static void print(OpAsmPrinter &p, MFMAOp op) {
p << op.getOperationName() << "(" << op.getOperands() << ")";
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getType();
p << " : " << op.getOperandTypes();
}

static LogicalResult verify(miopen::MFMAOp op) {
Expand Down
5 changes: 3 additions & 2 deletions mlir/test/Conversion/MIOpenToGPU/mfma.mlir
@@ -1,9 +1,10 @@
// XFAIL: *
// RUN: mlir-opt -convert-miopen-to-gpu="kernel-name=mfma" %s | FileCheck %s
// RUN: mlir-opt -convert-miopen-to-gpu="kernel-name=mfma" -convert-gpu-to-rocdl %s | FileCheck %s --check-prefix=ROCDL

module {
func @mfma(%a : f32, %b : f32, %c : vector<32xf32>) {
%d = miopen.mfma(%a, %b, %c) { m_per_wave = 64, n_per_wave = 64 }: vector<32xf32>
func @mfma(%a : f32, %b : f32, %c : memref<64xf32>) {
miopen.mfma(%a, %b, %c) { m_per_wave = 64, n_per_wave = 64 }: memref<64xf32>
// CHECK: %{{.*}} = gpu.mfma(%{{.*}}, %{{.*}}, %{{.*}}) {m_per_wave = 64 : i64, n_per_wave = 64 : i64} : vector<32xf32>
// ROCDL: %{{.*}} = rocdl.mfma.f32.32x32x1f32 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.float, !llvm.float, !llvm<"<32 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">

Expand Down
7 changes: 2 additions & 5 deletions mlir/test/Dialect/MIOpen/ops_2.mlir
Expand Up @@ -218,13 +218,10 @@ func @miopen_threadwise_gemm(%lhs : memref<4x8xf32>, %rhs : memref<4x8xf32>, %ou
// CHECK-LABEL: func @miopen_threadwise_gemm
// CHECK: miopen.threadwise_gemm

func @miopen_mfma(%a : f32, %b : f32, %c : vector<32xf32>) {
miopen.mfma(%a, %b, %c) : vector<32xf32>
%d = miopen.mfma(%a, %b, %c) : vector<32xf32>

func @miopen_mfma(%a : f32, %b : f32, %c : memref<64xf32>) {
miopen.mfma(%a, %b, %c) { m_per_wave = 64, n_per_wave = 64 } : f32, f32, memref<64xf32>
return
}

// CHECK-LABEL: func @miopen_mfma
// CHECK: miopen.mfma
// CHECK-NEXT: miopen.mfma

0 comments on commit f482141

Please sign in to comment.