Skip to content

Commit

Permalink
Merge pull request #26 from zjing14/tuning_fix
Browse files Browse the repository at this point in the history
Fixed tuning process
  • Loading branch information
whchung committed Jul 13, 2020
2 parents dab2441 + 3798bf9 commit 976a47d
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 19 deletions.
1 change: 0 additions & 1 deletion mlir/.clang-format
@@ -1,2 +1 @@
BasedOnStyle: LLVM
AlwaysBreakTemplateDeclarations: Yes
105 changes: 96 additions & 9 deletions mlir/include/mlir/Dialect/MIOpen/gridwise_gemm_params.h
Expand Up @@ -563,17 +563,27 @@ struct InitParamsNonXDL : InitParams {
int64_t blockSize;
};

// block gemm tuning params that sepcific the layout of thread-wise gemm in a
// workgroup
struct DerivedBlockGemmParams {
int64_t gemmMLevel0Cluster;
int64_t gemmNLevel0Cluster;
int64_t gemmMLevel1Cluster;
int64_t gemmNLevel1Cluster;
};

class PopulateParams : public PopulateParamsBase {
private:
// clang-format off
llvm::SmallVector<InitParamsNonXDL, 4> initParameters = {
// M/block N/block K/block M/thread N/thread blockSize
{128, 128, 8, 4, 4, 256},
{128, 64, 8, 4, 4, 128},
{64, 128, 4, 4, 4, 128},
{64, 64, 16, 4, 4, 64},
{32, 64, 16, 2, 4, 64},
{32, 32, 4, 2, 2, 64},
// M/block N/block K/block M/thread N/thread blockSize
{128, 128, 8, 4, 4, 256},
{128, 64, 8, 4, 4, 128},
{64, 128, 4, 4, 4, 128},
{64, 64, 16, 4, 4, 64},
{32, 64, 16, 2, 4, 64},
{64, 32, 16, 4, 2, 64},
{32, 32, 4, 2, 2, 64}
};
// clang-format on

Expand All @@ -594,7 +604,8 @@ class PopulateParams : public PopulateParamsBase {
derived);
}

int64_t calculateGemmCDestDataPerWrite(ConvolutionContext &ctx) {
int64_t calculateGemmCDestDataPerWrite(const InitParamsNonXDL &param,
ConvolutionContext &ctx) {
int64_t outputVecLen = 0;
if ((ctx.opType == miopen::ConvOpType::Conv2DOpType) &&
(ctx.dimIndexVal["ko"].first == 3)) {
Expand All @@ -608,6 +619,8 @@ class PopulateParams : public PopulateParamsBase {
obtainGemmCVecLen(ctx, outputVecLen);
}

outputVecLen = gcd(outputVecLen, param.gemmNPerThread);

if ((outputVecLen > 0) && (outputVecLen % 4 == 0)) {
return 4;
} else if ((outputVecLen > 0) && (outputVecLen % 2 == 0)) {
Expand All @@ -617,11 +630,76 @@ class PopulateParams : public PopulateParamsBase {
return 1;
}

LogicalResult
CalculateBlockGemmPerformanceParameters(const InitParamsNonXDL &param,
const ConvolutionContext &ctx,
DerivedBlockGemmParams &derived) {

derived.gemmMLevel0Cluster = 0;
derived.gemmNLevel0Cluster = 0;
derived.gemmMLevel1Cluster = 0;
derived.gemmNLevel1Cluster = 0;

if (param.blockSize == 64) {
derived.gemmMLevel0Cluster = 4;
derived.gemmNLevel0Cluster = 4;
derived.gemmMLevel1Cluster = 2;
derived.gemmNLevel1Cluster = 2;
} else if (param.blockSize == 128) {
derived.gemmMLevel0Cluster = 4;
derived.gemmNLevel0Cluster = 4;
derived.gemmMLevel1Cluster = 4;
derived.gemmNLevel1Cluster = 2;
} else if (param.blockSize == 256) {
derived.gemmMLevel0Cluster = 4;
derived.gemmNLevel0Cluster = 4;
derived.gemmMLevel1Cluster = 4;
derived.gemmNLevel1Cluster = 4;
} else {
return failure();
}

if (!(param.gemmMPerThread >= 2 && param.gemmMPerThread <= 4))
return failure();

if (!(param.gemmNPerThread >= 2 && param.gemmNPerThread <= 4))
return failure();

if (!(param.gemmMPerBlock % param.gemmMPerThread == 0 &&
param.gemmNPerBlock % param.gemmNPerThread == 0))
return failure();

const auto threadGemmMPerBlock =
param.gemmMPerBlock / param.gemmMPerThread;
const auto threadGemmNPerBlock =
param.gemmNPerBlock / param.gemmNPerThread;

const auto threadGemmMPerCluster =
derived.gemmMLevel0Cluster * derived.gemmMLevel1Cluster;
const auto threadGemmNPerCluster =
derived.gemmNLevel0Cluster * derived.gemmNLevel1Cluster;

if (!(threadGemmMPerBlock % threadGemmMPerCluster == 0) &&
(threadGemmNPerBlock % threadGemmNPerCluster == 0))
return failure();

const auto clusterMPerBlock = threadGemmMPerBlock / threadGemmMPerCluster;
const auto clusterNPerBlock = threadGemmNPerBlock / threadGemmNPerCluster;

// inline asm only support clusterMPerBlock = 2 andclusterNPerBlock =
// 2
if (!(clusterMPerBlock == 2 && clusterNPerBlock == 2))
return failure();

return success();
}

public:
LogicalResult paramsFromCtx(ConvolutionContext &ctx,
InitParamsNonXDL &validParams, GemmSize &gemmSize,
DerivedParams &gemmADerivedParam,
DerivedParams &gemmBDerivedParam,
DerivedBlockGemmParams &blockGemmDerivedParam,
int64_t &gemmCDstPerWrite, int64_t &gridSize) {
LogicalResult res(LogicalResult::Failure);

Expand Down Expand Up @@ -659,6 +737,15 @@ class PopulateParams : public PopulateParamsBase {
continue;
}

res = CalculateBlockGemmPerformanceParameters(params, ctx,
blockGemmDerivedParam);

if (failed(res)) {
LLVM_DEBUG(llvm::dbgs() << "Incoherent blockGemm tuning parameter "
<< " size.\n");
continue;
}

validParams = params;
break;
}
Expand All @@ -670,7 +757,7 @@ class PopulateParams : public PopulateParamsBase {
}

gridSize = obtainGridSize(gemmSize, &validParams);
gemmCDstPerWrite = calculateGemmCDestDataPerWrite(ctx);
gemmCDstPerWrite = calculateGemmCDestDataPerWrite(validParams, ctx);
return res;
}
};
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/MIOpen/Transforms/AffixTuningParameters.cpp
Expand Up @@ -125,13 +125,14 @@ void AffixTuningParameters::runOnFunction() {
GemmSize gemmSize;
DerivedParams gemmADerivedParam;
DerivedParams gemmBDerivedParam;
DerivedBlockGemmParams blockGemmDerivedParam;
int64_t gemmCDstPerWrite;
int64_t gridSize;

PopulateParams populateParams;
LogicalResult status = populateParams.paramsFromCtx(
convContext, validParams, gemmSize, gemmADerivedParam,
gemmBDerivedParam, gemmCDstPerWrite, gridSize);
gemmBDerivedParam, blockGemmDerivedParam, gemmCDstPerWrite, gridSize);
if (failed(status)) {
signalPassFailure();
}
Expand Down
Expand Up @@ -927,13 +927,14 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCFlags(ModuleOp m)
GemmSize gemmSize;
DerivedParams gemmADerivedParam;
DerivedParams gemmBDerivedParam;
DerivedBlockGemmParams blockGemmDerivedParam;
int64_t gemmCDstPerWrite;
int64_t gridSize;

PopulateParams populateParams;
populateParams.paramsFromCtx(ctx, validParams, gemmSize,
gemmADerivedParam, gemmBDerivedParam,
gemmCDstPerWrite, gridSize);
populateParams.paramsFromCtx(
ctx, validParams, gemmSize, gemmADerivedParam, gemmBDerivedParam,
blockGemmDerivedParam, gemmCDstPerWrite, gridSize);

std::map<std::string, int> parameters;

Expand Down Expand Up @@ -1013,11 +1014,14 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCFlags(ModuleOp m)
["CK_PARAM_TUNABLE_GEMM_C_THREAD_COPY_DST_DATA_PER_WRITE_GEMM_N1"] =
gemmCDstPerWrite;

// parameters fixed.
parameters["CK_PARAM_TUNABLE_GEMM_M_LEVEL0_CLUSTER"] = 4;
parameters["CK_PARAM_TUNABLE_GEMM_N_LEVEL0_CLUSTER"] = 4;
parameters["CK_PARAM_TUNABLE_GEMM_M_LEVEL1_CLUSTER"] = 4;
parameters["CK_PARAM_TUNABLE_GEMM_N_LEVEL1_CLUSTER"] = 4;
parameters["CK_PARAM_TUNABLE_GEMM_M_LEVEL0_CLUSTER"] =
blockGemmDerivedParam.gemmMLevel0Cluster;
parameters["CK_PARAM_TUNABLE_GEMM_N_LEVEL0_CLUSTER"] =
blockGemmDerivedParam.gemmNLevel0Cluster;
parameters["CK_PARAM_TUNABLE_GEMM_M_LEVEL1_CLUSTER"] =
blockGemmDerivedParam.gemmMLevel1Cluster;
parameters["CK_PARAM_TUNABLE_GEMM_N_LEVEL1_CLUSTER"] =
blockGemmDerivedParam.gemmNLevel1Cluster;

// Emit code-gen related macros.
parameters["CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM"] = 1;
Expand Down

0 comments on commit 976a47d

Please sign in to comment.