Skip to content

Commit

Permalink
fixed comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zjing14 committed Jul 9, 2020
1 parent 8b865d6 commit 5b4be26
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 36 deletions.
1 change: 0 additions & 1 deletion mlir/.clang-format
@@ -1,2 +1 @@
BasedOnStyle: LLVM
AlwaysBreakTemplateDeclarations: Yes
47 changes: 15 additions & 32 deletions mlir/include/mlir/Dialect/MIOpen/gridwise_gemm_params.h
Expand Up @@ -566,20 +566,23 @@ struct InitParamsNonXDL : InitParams {
// block gemm tuning params that sepcific the layout of thread-wise gemm in a
// workgroup
struct DerivedBlockGemmParams {
int64_t gemmMLevel0Cluster;
int64_t gemmNLevel0Cluster;
int64_t gemmMLevel1Cluster;
int64_t gemmNLevel1Cluster;
int64_t gemmMLevel0Cluster;
int64_t gemmNLevel0Cluster;
int64_t gemmMLevel1Cluster;
int64_t gemmNLevel1Cluster;
};

class PopulateParams : public PopulateParamsBase {
private:
// clang-format off
llvm::SmallVector<InitParamsNonXDL, 4> initParameters = {
// M/block N/block K/block M/thread N/thread blockSize
{128, 128, 8, 4, 4, 256}, {128, 64, 8, 4, 4, 128},
{64, 128, 4, 4, 4, 128}, {64, 64, 16, 4, 4, 64},
{32, 64, 16, 2, 4, 64}, {64, 32, 16, 4, 2, 64},
{128, 128, 8, 4, 4, 256},
{128, 64, 8, 4, 4, 128},
{64, 128, 4, 4, 4, 128},
{64, 64, 16, 4, 4, 64},
{32, 64, 16, 2, 4, 64},
{64, 32, 16, 4, 2, 64},
{32, 32, 4, 2, 2, 64}
};
// clang-format on
Expand All @@ -601,7 +604,8 @@ class PopulateParams : public PopulateParamsBase {
derived);
}

int64_t calculateGemmCDestDataPerWrite(InitParamsNonXDL *param, ConvolutionContext &ctx) {
int64_t calculateGemmCDestDataPerWrite(const InitParamsNonXDL *param,
const ConvolutionContext &ctx) {
int64_t outputVecLen = 0;
if ((ctx.opType == miopen::ConvOpType::Conv2DOpType) &&
(ctx.dimIndexVal["ko"].first == 3)) {
Expand All @@ -615,7 +619,7 @@ class PopulateParams : public PopulateParamsBase {
obtainGemmCVecLen(ctx, outputVecLen);
}

outputVecLen = std::__gcd(outputVecLen, param->gemmNPerThread);
outputVecLen = gcd(outputVecLen, param->gemmNPerThread);

if ((outputVecLen > 0) && (outputVecLen % 4 == 0)) {
return 4;
Expand All @@ -627,8 +631,8 @@ class PopulateParams : public PopulateParamsBase {
}

LogicalResult
CalculateBlockGemmPerformanceParameters(InitParamsNonXDL *param,
ConvolutionContext &ctx,
CalculateBlockGemmPerformanceParameters(const InitParamsNonXDL *param,
const ConvolutionContext &ctx,
DerivedBlockGemmParams &derived) {

derived.gemmMLevel0Cluster = 0;
Expand Down Expand Up @@ -690,26 +694,6 @@ class PopulateParams : public PopulateParamsBase {
return success();
}

void obtainGemmCWriteVecLen(ConvolutionContext &ctx, InitParamsNonXDL &params,
int64_t &vecLen) {
// Output tensor.
int64_t outputVecLen = 1;

if ((ctx.opType == miopen::ConvOpType::Conv2DOpType) &&
(ctx.dimIndexVal["ko"].first == 3)) {
// gemmM vectorizable. However, there is no parameters for vectorizing
// gemmM dimension for matrix C. Do nothing here.
} else if ((ctx.opType == miopen::ConvOpType::Conv2DBwdDataOpType) &&
(ctx.dimIndexVal["ci"].first == 3)) {
// gemmM vectorizable. However, there is no parameters for vectorizing
// gemmM dimension for matrix C. Do nothing here.
} else {
obtainGemmCVecLen(ctx, outputVecLen);
}

vecLen = std::__gcd(outputVecLen, params.gemmNPerThread);
}

public:
LogicalResult paramsFromCtx(ConvolutionContext &ctx,
InitParamsNonXDL &validParams, GemmSize &gemmSize,
Expand Down Expand Up @@ -747,7 +731,6 @@ class PopulateParams : public PopulateParamsBase {
res = calculateGemmBBlockCopyPerformanceParameters(&params, ctx,
gemmBDerivedParam);


if (failed(res)) {
LLVM_DEBUG(llvm::dbgs() << "Incoherent gemmB tuning parameter "
<< " size.\n");
Expand Down
Expand Up @@ -932,9 +932,9 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCFlags(ModuleOp m)
int64_t gridSize;

PopulateParams populateParams;
populateParams.paramsFromCtx(ctx, validParams, gemmSize,
gemmADerivedParam, gemmBDerivedParam, blockGemmDerivedParam,
gemmCDstPerWrite, gridSize);
populateParams.paramsFromCtx(
ctx, validParams, gemmSize, gemmADerivedParam, gemmBDerivedParam,
blockGemmDerivedParam, gemmCDstPerWrite, gridSize);

std::map<std::string, int> parameters;

Expand Down

0 comments on commit 5b4be26

Please sign in to comment.