Skip to content

Commit

Permalink
Introduce miopen.mfma operator.
Browse files Browse the repository at this point in the history
This operator would be lowered to gpu.mfma operator.
  • Loading branch information
whchung committed Jul 6, 2020
1 parent 37597fd commit c445f8d
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 0 deletions.
16 changes: 16 additions & 0 deletions mlir/include/mlir/Dialect/MIOpen/MIOpenOps.td
Expand Up @@ -249,4 +249,20 @@ def MIOpen_ThreadwiseGemmOp:
}];
}

// mfma
def MIOpen_MFMAOp:
MIOpen_Op<"mfma">,
Arguments<(ins F32: $sourceA,
F32: $sourceB,
VectorOfRankAndType<[1], [F32]>: $destC,
I32: $cbsz,
I32: $abid,
I32: $blgp)>,
Results<(outs VectorOfRankAndType<[1], [F32]>: $destD)> {
let summary = "XDLOPS MFMA";
let description = [{
The `miopen.mfma` op is an abstraction of XDLOPS.
}];
}

#endif // MIOPEN_OPS
43 changes: 43 additions & 0 deletions mlir/lib/Dialect/MIOpen/MIOpenOps.cpp
Expand Up @@ -507,6 +507,49 @@ static LogicalResult verify(ThreadwiseCopyOp op) {
return success();
}

//===----------------------------------------------------------------------===//
// MFMAOp
//===----------------------------------------------------------------------===//

static ParseResult parseMFMAOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType sourceA, sourceB, destC, cbsz, abid, blgp;
Type destType;
return failure(
parser.parseLParen() ||
parser.parseOperand(sourceA) ||
parser.parseComma() ||
parser.parseOperand(sourceB) ||
parser.parseComma() ||
parser.parseOperand(destC) ||
parser.parseComma() ||
parser.parseOperand(cbsz) ||
parser.parseComma() ||
parser.parseOperand(abid) ||
parser.parseComma() ||
parser.parseOperand(blgp) ||
parser.parseRParen() ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(destType) ||
parser.resolveOperand(sourceA, parser.getBuilder().getF32Type(), result.operands) ||
parser.resolveOperand(sourceB, parser.getBuilder().getF32Type(), result.operands) ||
parser.resolveOperand(destC, destType, result.operands) ||
parser.resolveOperand(cbsz, parser.getBuilder().getI32Type(), result.operands) ||
parser.resolveOperand(abid, parser.getBuilder().getI32Type(), result.operands) ||
parser.resolveOperand(blgp, parser.getBuilder().getI32Type(), result.operands) ||
parser.addTypeToList(destType, result.types));
return success();
}

static void print(OpAsmPrinter &p, miopen::MFMAOp op) {
p << op.getOperationName() << "(" << op.getOperands() << ")";
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getType();
}

static LogicalResult verify(miopen::MFMAOp op) {
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
13 changes: 13 additions & 0 deletions mlir/test/Dialect/MIOpen/ops_2.mlir
Expand Up @@ -217,3 +217,16 @@ func @miopen_threadwise_gemm(%lhs : memref<4x8xf32>, %rhs : memref<4x8xf32>, %ou

// CHECK-LABEL: func @miopen_threadwise_gemm
// CHECK: miopen.threadwise_gemm

func @miopen_mfma(%a : f32, %b : f32, %c : vector<32xf32>) {
%c0 = constant 0 : i32
%c1 = constant 1 : i32
miopen.mfma(%a, %b, %c, %c0, %c0, %c1) : vector<32xf32>
%d = miopen.mfma(%a, %b, %c, %c0, %c0, %c1) : vector<32xf32>

return
}

// CHECK-LABEL: func @miopen_mfma
// CHECK: miopen.mfma
// CHECK-NEXT: miopen.mfma

0 comments on commit c445f8d

Please sign in to comment.