
Commit

Change how VGPRs are allocated.
whchung committed Jun 6, 2020
1 parent bbfbd34 commit 73799fc
Showing 1 changed file with 16 additions and 27 deletions.
43 changes: 16 additions & 27 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
@@ -1014,12 +1014,9 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm

auto zeroConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), 0);
auto oneConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), 1);
- auto twoConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), 2);

auto ldsMemorySpace = 3;
- auto ldsMemorySpaceConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), ldsMemorySpace);
auto registerMemorySpace = 5;
- auto registerMemorySpaceConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), registerMemorySpace);

// Obtain critical matrix dimensions.
int64_t K = op.getOperand(0).getType().template dyn_cast<MemRefType>().getShape()[0];
@@ -1044,7 +1041,6 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
auto elementType = op.output().getType().cast<MemRefType>().getElementType();

// Allocate LDS.
- auto ldsBlockSizeConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), ldsBlockSize);
auto ldsMemRefType =
MemRefType::get({ldsBlockSize}, elementType, {}, ldsMemorySpace);
auto ldsGpuAllocOp = b.create<miopen::GpuAllocOp>(op.getLoc(), ldsMemRefType);
@@ -1163,8 +1159,6 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm
int64_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
int64_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);

- auto threadCRegisterSize = (GemmMRepeat * MPerThread) * (GemmNRepeat * NPerThread);
- auto threadCRegisterSizeConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), threadCRegisterSize);
auto threadCRegisterMemRefType =
MemRefType::get({GemmMRepeat * MPerThread, GemmNRepeat * NPerThread}, elementType, {}, registerMemorySpace);
auto threadCAllocOp = b.create<miopen::GpuAllocOp>(op.getLoc(), threadCRegisterMemRefType);
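The per-thread C accumulator above is shaped as {GemmMRepeat * MPerThread, GemmNRepeat * NPerThread}, so its register footprint follows directly from the tuning factors computed just before it. A short, self-contained C++ sketch of that arithmetic (the tuning values are hypothetical, chosen only to make the numbers concrete, not taken from this commit):

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical tuning parameters; in the lowering they come from op attributes.
  int64_t MPerBlock = 128, NPerBlock = 128;
  int64_t MPerThread = 8, NPerThread = 8;
  int64_t MLevel0Cluster = 4, MLevel1Cluster = 4;
  int64_t NLevel0Cluster = 4, NLevel1Cluster = 4;

  // Same formulas as in the hunk above.
  int64_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
  int64_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);

  // Shape of threadCRegisterMemRefType: {GemmMRepeat * MPerThread, GemmNRepeat * NPerThread}.
  int64_t rows = GemmMRepeat * MPerThread;
  int64_t cols = GemmNRepeat * NPerThread;
  std::cout << rows << " x " << cols << " = " << rows * cols
            << " elements held in registers per thread\n"; // 8 x 8 = 64 here
  return 0;
}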
@@ -1188,17 +1182,16 @@ struct GridwiseGemmRewritePattern : public OpRewritePattern<miopen::GridwiseGemm

// Alloc for Matrix A / B on registers.
// TBD. compute thread A / B on registers from attributes.
- auto threadARegisterSize = 1024;
- auto threadARegisterSizeConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), threadARegisterSize);
+ int64_t ThreadSliceK = 8;
+ int64_t ThreadSliceM = 8;
+ int64_t ThreadSliceN = 8;
auto threadARegisterMemRefType =
-     MemRefType::get({threadARegisterSize}, elementType, {}, registerMemorySpace);
+     MemRefType::get({ThreadSliceK, ThreadSliceM}, elementType, {}, registerMemorySpace);
auto threadAEvenAllocOp = b.create<miopen::GpuAllocOp>(op.getLoc(), threadARegisterMemRefType);
auto threadAOddAllocOp = b.create<miopen::GpuAllocOp>(op.getLoc(), threadARegisterMemRefType);

- auto threadBRegisterSize = 1024;
- auto threadBRegisterSizeConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), threadBRegisterSize);
auto threadBRegisterMemRefType =
-     MemRefType::get({threadBRegisterSize}, elementType, {}, registerMemorySpace);
+     MemRefType::get({ThreadSliceK, ThreadSliceN}, elementType, {}, registerMemorySpace);
auto threadBEvenAllocOp = b.create<miopen::GpuAllocOp>(op.getLoc(), threadBRegisterMemRefType);
auto threadBOddAllocOp = b.create<miopen::GpuAllocOp>(op.getLoc(), threadBRegisterMemRefType);
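Compared with the old flat 1024-element scratch buffers, the A and B staging buffers above are now typed by their per-thread slice, so the number of VGPRs requested tracks the tile each thread actually works on. A minimal sketch of the new pattern, written as if it lived in this same header (OpBuilder, MemRefType, and miopen::GpuAllocOp are used exactly as in the hunk; the helper name and its parameters are illustrative only):

// Illustrative helper, not part of the commit: allocate a per-thread tile in
// the register address space (5), shaped by its thread-slice extents.
static miopen::GpuAllocOp allocThreadTile(OpBuilder &b, Location loc,
                                          Type elementType, int64_t rows,
                                          int64_t cols) {
  unsigned registerMemorySpace = 5; // per-thread (VGPR-backed) memory space
  auto tileType =
      MemRefType::get({rows, cols}, elementType, {}, registerMemorySpace);
  return b.create<miopen::GpuAllocOp>(loc, tileType);
}

// Usage matching the hunk: A slices are ThreadSliceK x ThreadSliceM, B slices
// are ThreadSliceK x ThreadSliceN; with the default slice of 8 each buffer
// holds 64 elements instead of the previous fixed 1024.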

@@ -1381,25 +1374,22 @@ struct BlockwiseGemmRewritePattern : public OpRewritePattern<miopen::BlockwiseGe
// Prepare some useful constants.
auto zeroConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), 0);
auto oneConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), 1);
- auto twoConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), 2);

auto registerMemorySpace = 5;
- auto registerMemorySpaceConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), registerMemorySpace);

auto elementType = op.matrixC().getType().cast<MemRefType>().getElementType();

// Alloc register for thread_a and thread_b.
// TBD compute actual size from attributes.
- auto threadARegisterSize = 1024;
- auto threadARegisterSizeConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), threadARegisterSize);
+ int64_t KPerThreadLoop = 8;
+ int64_t MPerThread = 8;
+ int64_t NPerThread = 8;
auto threadARegisterMemRefType =
-     MemRefType::get({threadARegisterSize}, elementType, {}, registerMemorySpace);
+     MemRefType::get({KPerThreadLoop, MPerThread}, elementType, {}, registerMemorySpace);
auto threadAAllocOp = b.create<miopen::GpuAllocOp>(op.getLoc(), threadARegisterMemRefType);

- auto threadBRegisterSize = 1024;
- auto threadBRegisterSizeConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), threadBRegisterSize);
auto threadBRegisterMemRefType =
-     MemRefType::get({threadARegisterSize}, elementType, {}, registerMemorySpace);
+     MemRefType::get({KPerThreadLoop, NPerThread}, elementType, {}, registerMemorySpace);
auto threadBAllocOp = b.create<miopen::GpuAllocOp>(op.getLoc(), threadBRegisterMemRefType);

// Main loop.
@@ -1498,11 +1488,11 @@ struct BlockwiseCopyRewritePattern : public OpRewritePattern<miopen::BlockwiseCo
if (sourceType.getMemorySpace() == 0 && destType.getMemorySpace() == 3) {
// TBD. compute register size from attributes and operands.
auto registerMemorySpace = 5;
- auto registerMemorySpaceConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), registerMemorySpace);
- auto threadRegisterSize = 1024;
- auto threadRegisterSizeConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), threadRegisterSize);

+ int64_t ThreadSliceRow = 8;
+ int64_t ThreadSliceCol = 8;
auto threadRegisterMemRefType =
-     MemRefType::get({threadRegisterSize}, elementType, {}, registerMemorySpace);
+     MemRefType::get({ThreadSliceRow, ThreadSliceCol}, elementType, {}, registerMemorySpace);
auto threadAllocOp = b.create<miopen::GpuAllocOp>(op.getLoc(), threadRegisterMemRefType);

// Threadwise copy from global (generic tensor) to register (naive tensor).
@@ -1640,7 +1630,7 @@ struct FillRewritePattern : public OpRewritePattern<miopen::FillOp> {

for (unsigned i = 0; i < inputShape[0]; ++i) {
auto iter = b.create<ConstantIndexOp>(loc, i);
- auto storeOp = lb.create<StoreOp>(loc, op.value(), op.input(), ValueRange{iter});
+ lb.create<StoreOp>(loc, op.value(), op.input(), ValueRange{iter});
}
} else if (inputShape.size() == 2) {
// Rank 2 loop.
@@ -1662,7 +1652,7 @@ struct FillRewritePattern : public OpRewritePattern<miopen::FillOp> {
for (unsigned j = 0; j < inputShape[1]; ++j) {
auto iter1 = b.create<ConstantIndexOp>(loc, j);

- auto storeOp = l1b.create<StoreOp>(loc, op.value(), op.input(), ValueRange{iter0, iter1});
+ l1b.create<StoreOp>(loc, op.value(), op.input(), ValueRange{iter0, iter1});
}
}
}
@@ -1711,7 +1701,6 @@ struct TransformRewritePattern : public OpRewritePattern<miopen::TransformOp> {
LogicalResult matchAndRewrite(miopen::TransformOp op,
PatternRewriter &b) const override {
auto loc = op.getLoc();
- auto inputType = op.input().getType().cast<MemRefType>();
auto outputType = op.output().getType().cast<MemRefType>();

// Pass the output affine map to users of this op.
