Skip to content

Commit

Permalink
Add threadwise GEMM op and its lowering logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Jun 6, 2020
1 parent be5e214 commit d8eb253
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 9 deletions.
60 changes: 52 additions & 8 deletions mlir/include/mlir/Dialect/MIOpenOps/LowerMIOpenOps.h
Expand Up @@ -1018,22 +1018,66 @@ struct BlockwiseGemmRewritePattern : public OpRewritePattern<miopen::BlockwiseGe
using OpRewritePattern<miopen::BlockwiseGemmOp>::OpRewritePattern;

PatternMatchResult naiveRewrite(miopen::BlockwiseGemmOp op, PatternRewriter &b) const {
// Prepare some useful constants.
auto zeroConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), 0);
auto oneConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), 1);
auto twoConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), 2);

auto registerMemorySpace = 5;
auto registerMemorySpaceConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), registerMemorySpace);

// Alloc register for thread_a and thread_b.
// TBD compute actual size from attributes.
auto threadARegisterSize = 1024;
auto threadARegisterSizeConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), threadARegisterSize);
auto threadARegisterMemRefType =
MemRefType::get({threadARegisterSize}, b.getIntegerType(8), {}, registerMemorySpace);
auto threadAAllocOp = b.create<miopen::GpuAllocOp>(op.getLoc(), threadARegisterMemRefType, threadARegisterSizeConstantIndexOp, registerMemorySpaceConstantIndexOp);

auto threadBRegisterSize = 1024;
auto threadBRegisterSizeConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), threadBRegisterSize);
auto threadBRegisterMemRefType =
MemRefType::get({threadARegisterSize}, b.getIntegerType(8), {}, registerMemorySpace);
auto threadBAllocOp = b.create<miopen::GpuAllocOp>(op.getLoc(), threadBRegisterMemRefType, threadBRegisterSizeConstantIndexOp, registerMemorySpaceConstantIndexOp);

// Main loop.
// TBD. compute loop iterations from attributes.
auto loopIteration = 15;
auto loopIterationConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), loopIteration);
auto loopOp = b.create<loop::ForOp>(op.getLoc(), zeroConstantIndexOp, loopIterationConstantIndexOp, oneConstantIndexOp);

// TBD alloc register for thread_a and thread_b.
// inside the main loop.
auto lb = loopOp.getBodyBuilder();

// read matrix A loop.
// TBD. compute loop iterations from attributes.
auto loopReadMatrixAIteration = 15;
auto loopReadMatrixAIterationConstantIndexOp = lb.create<ConstantIndexOp>(op.getLoc(), loopReadMatrixAIteration);
auto loopReadMatrixAOp = lb.create<loop::ForOp>(op.getLoc(), zeroConstantIndexOp, loopReadMatrixAIterationConstantIndexOp, oneConstantIndexOp);

// TBD loop.
// inside read matrix A loop.
auto lab = loopReadMatrixAOp.getBodyBuilder();

// TBD read matrix A loop.
// Threadwise copy from LDS (naive tensor) to register (generic tensor).
lab.create<miopen::ThreadwiseCopyOp>(op.getLoc(), op.getOperand(0), threadAAllocOp);

// TBD threadwise_copy.
// read matrix B loop.
// TBD. compute loop iterations from attributes.
auto loopReadMatrixBIteration = 15;
auto loopReadMatrixBIterationConstantIndexOp = lb.create<ConstantIndexOp>(op.getLoc(), loopReadMatrixBIteration);
auto loopReadMatrixBOp = lb.create<loop::ForOp>(op.getLoc(), zeroConstantIndexOp, loopReadMatrixBIterationConstantIndexOp, oneConstantIndexOp);

// TBD read matrix B loop.
// inside read matrix A loop.
auto lbb = loopReadMatrixBOp.getBodyBuilder();

// TBD threadwise_copy.
// Threadwise copy from LDS (naive tensor) to register (generic tensor).
lbb.create<miopen::ThreadwiseCopyOp>(op.getLoc(), op.getOperand(1), threadBAllocOp);

// TBD threadwise_gemm.
// Emit threadwise GEMM.
lb.create<miopen::ThreadwiseGemmOp>(op.getLoc(), threadAAllocOp, threadBAllocOp, op.getOperand(2));
// TBD add attributes.

//op.erase();
op.erase();
return matchSuccess();
}

Expand Down
16 changes: 15 additions & 1 deletion mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td
Expand Up @@ -186,10 +186,24 @@ def MIOpen_BlockwiseGemmOp:
MemRefRankOf<[F32], [2]>)> {
let summary = "Blockwise GEMM";
let description = [{
The `miopen.blockgemm` op does GEMM at workgroup (block) level.
The `miopen.block_gemm` op does GEMM at workgroup (block) level.
- Matrix A and Matrix B shall reside on LDS (naive tensor).
- Matrix C shall reside on register (naive tensor).
}];
}

// threadwise_gemm
def MIOpen_ThreadwiseGemmOp:
MIOpen_Op<"threadwise_gemm">,
Arguments<(ins AnyMemRef,
AnyMemRef,
AnyMemRef)> {
let summary = "Threadwise GEMM";
let description = [{
The `miopen.threadwise_gemm` op does GEMM at thread level.
- Matrix A and Matrix B shall reside on register (naive tensor).
- Matrix C shall reside on LDS (naive tensor).
}];
}

#endif // MIOPEN_OPS
24 changes: 24 additions & 0 deletions mlir/lib/Dialect/MIOpenOps/MIOpenOps.cpp
Expand Up @@ -292,6 +292,30 @@ static LogicalResult verify(BlockwiseGemmOp op) {
return success();
}

//===----------------------------------------------------------------------===//
// ThreadwiseGemmOp
//===----------------------------------------------------------------------===//

static ParseResult parseThreadwiseGemmOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 3> ops;
SmallVector<Type, 3> types;
return failure(
parser.parseOperandList(ops, OpAsmParser::Delimiter::Paren) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonTypeList(types) ||
parser.resolveOperands(ops, types, parser.getNameLoc(), result.operands));
}

static void print(OpAsmPrinter &p, ThreadwiseGemmOp op) {
p << op.getOperationName() << "(" << op.getOperands() << ")";
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.getOperandTypes();
}

static LogicalResult verify(ThreadwiseGemmOp op) {
return success();
}

//===----------------------------------------------------------------------===//
// BlockwiseCopyOp
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit d8eb253

Please sign in to comment.