Skip to content

Commit

Permalink
Compute loop iterations in GridwiseGemmOp lowering.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Jun 6, 2020
1 parent 0ca2533 commit 52a70a2
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions mlir/include/mlir/Dialect/MIOpenOps/LowerMIOpenOps.h
Expand Up @@ -970,6 +970,10 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
auto registerMemorySpace = 5;
auto registerMemorySpaceConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), registerMemorySpace);

// Obtain critical matrix dimensions.
int64_t K = op.getOperand(0).getType().dyn_cast<MemRefType>().getShape()[0];
int64_t M = op.getOperand(0).getType().dyn_cast<MemRefType>().getShape()[1];
int64_t N = op.getOperand(1).getType().dyn_cast<MemRefType>().getShape()[1];

// Obtain critical tuning parameters.
int64_t KPerBlock = op.getAttr("k_per_block").dyn_cast<IntegerAttr>().getInt();
Expand Down Expand Up @@ -1176,8 +1180,8 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
b.create<miopen::BlockwiseCopyOp>(op.getLoc(), op.getOperand(1), ldsBlockBEvenSubviewOp);

// Emit loop.
// TBD. compute loop iterations from attributes.
auto loopIteration = 15;
// Compute loop iterations from attributes.
auto loopIteration = K / (KPerBlock * 2);
auto loopIterationConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), loopIteration);
auto loopOp = b.create<loop::ForOp>(op.getLoc(), zeroConstantIndexOp, loopIterationConstantIndexOp, oneConstantIndexOp);

Expand Down

0 comments on commit 52a70a2

Please sign in to comment.