Skip to content

Commit

Permalink
Introduce miopen.xdlops_gemm op.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Jul 20, 2020
1 parent e3c108b commit 402d072
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 0 deletions.
16 changes: 16 additions & 0 deletions mlir/include/mlir/Dialect/MIOpen/MIOpenOps.td
Expand Up @@ -261,4 +261,20 @@ def MIOpen_MFMAOp:
}];
}

// xdlops_gemm
def MIOpen_XdlopsGemmOp:
MIOpen_Op<"xdlops_gemm">,
Arguments<(ins AnyMemRef:$matrixA,
AnyMemRef:$matrixB,
AnyMemRef:$matrixC,
Index:$threadOffsetA,
Index:$threadOffsetB,
Index:$threadOffsetC)> {
let summary = "XDLOPS GEMM";
let description = [{
The `miopen.xdlops_gemm` op is an abstraction of doing GEMM based on XDLOPS.
It would employ a series of `miopen.mfma` operations.
}];
}

#endif // MIOPEN_OPS
24 changes: 24 additions & 0 deletions mlir/lib/Dialect/MIOpen/MIOpenOps.cpp
Expand Up @@ -533,6 +533,30 @@ static LogicalResult verify(miopen::MFMAOp op) {
return success();
}

//===----------------------------------------------------------------------===//
// XdlopsGemmOp
//===----------------------------------------------------------------------===//

static ParseResult parseXdlopsGemmOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 6> ops;
SmallVector<Type, 6> types;
return failure(
parser.parseOperandList(ops, OpAsmParser::Delimiter::Paren) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonTypeList(types) ||
parser.resolveOperands(ops, types, parser.getNameLoc(), result.operands));
}

static void print(OpAsmPrinter &p, XdlopsGemmOp op) {
p << op.getOperationName() << "(" << op.getOperands() << ")";
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getOperandTypes();
}

static LogicalResult verify(XdlopsGemmOp op) {
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/MIOpen/ops_2.mlir
Expand Up @@ -241,3 +241,17 @@ func @miopen_mfma_bf16(%a : vector<2xbf16>, %b : vector<2xbf16>, %c : memref<64x

// CHECK-LABEL: func @miopen_mfma_bf16
// CHECK: miopen.mfma

func @miopen_xdlops_gemm(%A : memref<?x?xf32, 3>, %B : memref<?x?xf32, 3>, %C : memref<?x?xf32, 5>) {
%c0 = constant 0 : index
miopen.xdlops_gemm(%A, %B, %C, %c0, %c0, %c0) {
m_per_thread = 64,
n_per_thread = 64,
m_per_wave = 64,
n_per_wave = 64
} : memref<?x?xf32, 3>, memref<?x?xf32, 3>, memref<?x?xf32, 5>, index, index, index
return
}

// CHECK-LABEL: func @miopen_xdlops_gemm
// CHECK: miopen.xdlops_gemm

0 comments on commit 402d072

Please sign in to comment.