Skip to content

Commit

Permalink
Make blockwise_gemm take 1D memref so we don't need std.view op.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Jun 6, 2020
1 parent 5064128 commit d1e7877
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 49 deletions.
76 changes: 42 additions & 34 deletions mlir/include/mlir/Dialect/MIOpenOps/LowerMIOpenOps.h
Expand Up @@ -1054,18 +1054,19 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
// Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
auto lds2DMatrixAHeight = KPerBlock;
auto lds2DMatrixAWidth = MPerBlock;
auto lds2DMatrixAMemRefType =
MemRefType::get({lds2DMatrixAHeight, lds2DMatrixAWidth}, b.getF32Type(), {}, ldsMemorySpace);

llvm::SmallVector<int64_t, 2> lds2DMatrixADim {lds2DMatrixAHeight, lds2DMatrixAWidth};
llvm::SmallVector<NamedAttribute, 8> lds2DMatrixADimAttr {
b.getNamedAttr("dimensions", b.getI64ArrayAttr(lds2DMatrixADim)),
};
//auto lds2DMatrixAMemRefType =
// MemRefType::get({lds2DMatrixAHeight, lds2DMatrixAWidth}, b.getF32Type(), {}, ldsMemorySpace);

//llvm::SmallVector<int64_t, 2> lds2DMatrixADim {lds2DMatrixAHeight, lds2DMatrixAWidth};
//llvm::SmallVector<NamedAttribute, 8> lds2DMatrixADimAttr {
// b.getNamedAttr("dimensions", b.getI64ArrayAttr(lds2DMatrixADim)),
//};

auto lds2DMatrixAEvenSubviewOp = b.create<miopen::SubviewOp>(op.getLoc(), lds2DMatrixAMemRefType, ldsBlockAEvenSubviewOp, zeroConstantIndexOp);
lds2DMatrixAEvenSubviewOp.setAttrs(lds2DMatrixADimAttr);
auto lds2DMatrixAOddSubviewOp = b.create<miopen::SubviewOp>(op.getLoc(), lds2DMatrixAMemRefType, ldsBlockAOddSubviewOp, zeroConstantIndexOp);
lds2DMatrixAOddSubviewOp.setAttrs(lds2DMatrixADimAttr);
//auto lds2DMatrixAEvenSubviewOp = b.create<miopen::SubviewOp>(op.getLoc(), lds2DMatrixAMemRefType, ldsBlockAEvenSubviewOp, zeroConstantIndexOp);
//lds2DMatrixAEvenSubviewOp.setAttrs(lds2DMatrixADimAttr);
//auto lds2DMatrixAOddSubviewOp = b.create<miopen::SubviewOp>(op.getLoc(), lds2DMatrixAMemRefType, ldsBlockAOddSubviewOp, zeroConstantIndexOp);
//lds2DMatrixAOddSubviewOp.setAttrs(lds2DMatrixADimAttr);


// Subviews for Matrix B.
Expand Down Expand Up @@ -1096,18 +1097,19 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
// Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
auto lds2DMatrixBHeight = KPerBlock;
auto lds2DMatrixBWidth = NPerBlock;
auto lds2DMatrixBMemRefType =
MemRefType::get({lds2DMatrixBHeight, lds2DMatrixBWidth}, b.getF32Type(), {}, ldsMemorySpace);

llvm::SmallVector<int64_t, 2> lds2DMatrixBDim {lds2DMatrixBHeight, lds2DMatrixBWidth};
llvm::SmallVector<NamedAttribute, 1> lds2DMatrixBDimAttr {
b.getNamedAttr("dimensions", b.getI64ArrayAttr(lds2DMatrixBDim)),
};
//auto lds2DMatrixBMemRefType =
// MemRefType::get({lds2DMatrixBHeight, lds2DMatrixBWidth}, b.getF32Type(), {}, ldsMemorySpace);

//llvm::SmallVector<int64_t, 2> lds2DMatrixBDim {lds2DMatrixBHeight, lds2DMatrixBWidth};
//llvm::SmallVector<NamedAttribute, 1> lds2DMatrixBDimAttr {
// b.getNamedAttr("dimensions", b.getI64ArrayAttr(lds2DMatrixBDim)),
//};

auto lds2DMatrixBEvenSubviewOp = b.create<miopen::SubviewOp>(op.getLoc(), lds2DMatrixBMemRefType, ldsBlockBEvenSubviewOp, zeroConstantIndexOp);
lds2DMatrixBEvenSubviewOp.setAttrs(lds2DMatrixBDimAttr);
auto lds2DMatrixBOddSubviewOp = b.create<miopen::SubviewOp>(op.getLoc(), lds2DMatrixBMemRefType, ldsBlockBOddSubviewOp, zeroConstantIndexOp);
lds2DMatrixBOddSubviewOp.setAttrs(lds2DMatrixBDimAttr);
//auto lds2DMatrixBEvenSubviewOp = b.create<miopen::SubviewOp>(op.getLoc(), lds2DMatrixBMemRefType, ldsBlockBEvenSubviewOp, zeroConstantIndexOp);
//lds2DMatrixBEvenSubviewOp.setAttrs(lds2DMatrixBDimAttr);
//auto lds2DMatrixBOddSubviewOp = b.create<miopen::SubviewOp>(op.getLoc(), lds2DMatrixBMemRefType, ldsBlockBOddSubviewOp, zeroConstantIndexOp);
//lds2DMatrixBOddSubviewOp.setAttrs(lds2DMatrixBDimAttr);


// Alloc for Matrix C on registers.
Expand All @@ -1130,17 +1132,18 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
// Compute matrix C dimension from attributes.
auto register2DMatrixCHeight = (GemmMRepeat * MPerThread);
auto register2DMatrixCWidth = (GemmNRepeat * NPerThread);
auto register2DMatrixCMemRefType =
MemRefType::get({register2DMatrixCHeight, register2DMatrixCWidth}, b.getF32Type(), {}, registerMemorySpace);

llvm::SmallVector<int64_t, 2> register2DMatrixCDim {register2DMatrixCHeight, register2DMatrixCWidth};
llvm::SmallVector<NamedAttribute, 1> register2DMatrixCDimAttr {
b.getNamedAttr("dimensions", b.getI64ArrayAttr(register2DMatrixCDim)),
};
//auto register2DMatrixCMemRefType =
// MemRefType::get({register2DMatrixCHeight, register2DMatrixCWidth}, b.getF32Type(), {}, registerMemorySpace);

//llvm::SmallVector<int64_t, 2> register2DMatrixCDim {register2DMatrixCHeight, register2DMatrixCWidth};
//llvm::SmallVector<NamedAttribute, 1> register2DMatrixCDimAttr {
// b.getNamedAttr("dimensions", b.getI64ArrayAttr(register2DMatrixCDim)),
//};

//auto register2DMatrixCSubviewOp = b.create<miopen::SubviewOp>(op.getLoc(), register2DMatrixCMemRefType, threadCAllocOp, zeroConstantIndexOp);
//register2DMatrixCSubviewOp.setAttrs(register2DMatrixCDimAttr);

auto register2DMatrixCSubviewOp = b.create<miopen::SubviewOp>(op.getLoc(), register2DMatrixCMemRefType, threadCAllocOp, zeroConstantIndexOp);
register2DMatrixCSubviewOp.setAttrs(register2DMatrixCDimAttr);


// Alloc for Matrix A / B on registers.
// TBD. compute thread A / B on registers from attributes.
Expand Down Expand Up @@ -1253,7 +1256,8 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
blockwiseCopyOpBEven.setAttr("move_source_slice_window", b.getI32IntegerAttr(KPerBlock));

// Emit blockwise GEMM.
auto blockwiseGemmEvenOp = lb.create<miopen::BlockwiseGemmOp>(op.getLoc(), lds2DMatrixAEvenSubviewOp, lds2DMatrixBEvenSubviewOp, register2DMatrixCSubviewOp);
//auto blockwiseGemmEvenOp = lb.create<miopen::BlockwiseGemmOp>(op.getLoc(), lds2DMatrixAEvenSubviewOp, lds2DMatrixBEvenSubviewOp, register2DMatrixCSubviewOp);
auto blockwiseGemmEvenOp = lb.create<miopen::BlockwiseGemmOp>(op.getLoc(), ldsBlockAEvenSubviewOp, ldsBlockBEvenSubviewOp, threadCAllocOp);
affixBlockwiseGemmAttributes(blockwiseGemmEvenOp, op);

// Blockwise copy from register (naive tensor) to LDS (naive tensor).
Expand All @@ -1272,7 +1276,8 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
blockwiseCopyOpBOdd.setAttr("move_source_slice_window", b.getI32IntegerAttr(KPerBlock));

// Emit blockwise GEMM.
auto blockwiseGemmOddOp = lb.create<miopen::BlockwiseGemmOp>(op.getLoc(), lds2DMatrixAOddSubviewOp, lds2DMatrixBOddSubviewOp, register2DMatrixCSubviewOp);
//auto blockwiseGemmOddOp = lb.create<miopen::BlockwiseGemmOp>(op.getLoc(), lds2DMatrixAOddSubviewOp, lds2DMatrixBOddSubviewOp, register2DMatrixCSubviewOp);
auto blockwiseGemmOddOp = lb.create<miopen::BlockwiseGemmOp>(op.getLoc(), ldsBlockAOddSubviewOp, ldsBlockBOddSubviewOp, threadCAllocOp);
affixBlockwiseGemmAttributes(blockwiseGemmOddOp, op);

// Blockwise copy from register (naive tensor) to LDS (naive tensor).
Expand All @@ -1287,10 +1292,12 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm

// Emit blockwise GEMM for the loop tail.
if (loopIteration % 2) {
auto blockwiseGemmTailEvenOp = b.create<miopen::BlockwiseGemmOp>(op.getLoc(), lds2DMatrixAEvenSubviewOp, lds2DMatrixBEvenSubviewOp, register2DMatrixCSubviewOp);
//auto blockwiseGemmTailEvenOp = b.create<miopen::BlockwiseGemmOp>(op.getLoc(), lds2DMatrixAEvenSubviewOp, lds2DMatrixBEvenSubviewOp, register2DMatrixCSubviewOp);
auto blockwiseGemmTailEvenOp = b.create<miopen::BlockwiseGemmOp>(op.getLoc(), ldsBlockAEvenSubviewOp, ldsBlockBEvenSubviewOp, threadCAllocOp);
affixBlockwiseGemmAttributes(blockwiseGemmTailEvenOp, op);
} else {
auto blockwiseGemmTailOddOp = b.create<miopen::BlockwiseGemmOp>(op.getLoc(), lds2DMatrixAOddSubviewOp, lds2DMatrixBOddSubviewOp, register2DMatrixCSubviewOp);
//auto blockwiseGemmTailOddOp = b.create<miopen::BlockwiseGemmOp>(op.getLoc(), lds2DMatrixAOddSubviewOp, lds2DMatrixBOddSubviewOp, register2DMatrixCSubviewOp);
auto blockwiseGemmTailOddOp = b.create<miopen::BlockwiseGemmOp>(op.getLoc(), ldsBlockAOddSubviewOp, ldsBlockBOddSubviewOp, threadCAllocOp);
affixBlockwiseGemmAttributes(blockwiseGemmTailOddOp, op);
}

Expand All @@ -1312,7 +1319,8 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
// n_thread_data_on_global / N1,
// n_thread_data_on_global % N1})
// .Run(p_c_thread, p_c_global);
b.create<miopen::ThreadwiseCopyOp>(op.getLoc(), register2DMatrixCSubviewOp, op.getOperand(2));
//b.create<miopen::ThreadwiseCopyOp>(op.getLoc(), register2DMatrixCSubviewOp, op.getOperand(2));
b.create<miopen::ThreadwiseCopyOp>(op.getLoc(), threadCAllocOp, op.getOperand(2));

op.erase();

Expand Down
12 changes: 6 additions & 6 deletions mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td
Expand Up @@ -181,9 +181,9 @@ def MIOpen_ThreadwiseCopyOp:
// blockwise_gemm
def MIOpen_BlockwiseGemmOp:
MIOpen_Op<"blockwise_gemm">,
Arguments<(ins MemRefRankOf<[F32], [2]>,
MemRefRankOf<[F32], [2]>,
MemRefRankOf<[F32], [2]>)> {
Arguments<(ins AnyMemRef:$matrixA, // MemRefRankOf<[F32], [2]>,
AnyMemRef:$matrixB, // MemRefRankOf<[F32], [2]>,
AnyMemRef:$matrixC)> { // MemRefRankOf<[F32], [2]>)> {
let summary = "Blockwise GEMM";
let description = [{
The `miopen.blockwise_gemm` op does GEMM at workgroup (block) level.
Expand All @@ -195,9 +195,9 @@ def MIOpen_BlockwiseGemmOp:
// threadwise_gemm
def MIOpen_ThreadwiseGemmOp:
MIOpen_Op<"threadwise_gemm">,
Arguments<(ins AnyMemRef,
AnyMemRef,
AnyMemRef)> {
Arguments<(ins AnyMemRef:$matrixA,
AnyMemRef:$matrixB,
AnyMemRef:$matrixC)> {
let summary = "Threadwise GEMM";
let description = [{
The `miopen.threadwise_gemm` op does GEMM at thread level.
Expand Down
12 changes: 3 additions & 9 deletions mlir/lib/Dialect/MIOpenOps/LLVMOutput/ConvertMIOpenOpsToStd.cpp
Expand Up @@ -98,15 +98,9 @@ void LowerMIOpenOpsToStdPass::runOnModule() {
OpBuilder b(op.getContext());
b.setInsertionPoint(op);

if (outputShape.size() == 2) {
auto viewOp = b.create<ViewOp>(loc, op.output().getType(), op.input(), ArrayRef<Value>{});
op.replaceAllUsesWith(viewOp.getResult());
op.erase();
} else {
auto subviewOp = b.create<SubViewOp>(loc, op.output().getType(), op.input());
op.replaceAllUsesWith(subviewOp.getResult());
op.erase();
}
auto subviewOp = b.create<SubViewOp>(loc, op.output().getType(), op.input());
op.replaceAllUsesWith(subviewOp.getResult());
op.erase();
});

func.walk([&](miopen::LdsBarrierOp op) {
Expand Down

0 comments on commit d1e7877

Please sign in to comment.