Skip to content

Commit

Permalink
add tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
zjing14 committed Jun 18, 2020
1 parent c8503c2 commit 5553233
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 30 deletions.
98 changes: 68 additions & 30 deletions mlir/include/mlir/Target/MIOpenCPP.h
Expand Up @@ -144,37 +144,15 @@ class PopulateParamsBase {
} else {
input1GemmKVectorizable = false;
}
}
}

static void
obtainGemmAVecLen(mlir::miopen::ConvOpType opType,
llvm::StringMap<std::pair<size_t, int64_t>> &dimIndexVal,
int64_t &vecLen) {
// Vectorization length logic is the same for forward and bwd_data
if (dimIndexVal["k"].first == 3) {
vecLen = dimIndexVal["k"].second;
} else if (dimIndexVal["k"].first == 0) {
// dimKF is the lowest changing dimension, which means dimC/dimY/dimX
vecLen = dimIndexVal["c"].second * dimIndexVal["y"].second *
dimIndexVal["x"].second;
} else if (dimIndexVal["k"].first == 1) {
// K's position is at 1, vectorization legnth is last two dimension
if (dimIndexVal["c"].first == 0) {
vecLen = dimIndexVal["y"].second * dimIndexVal["x"].second;
} else if (dimIndexVal["y"].first == 0) {
vecLen = dimIndexVal["c"].second * dimIndexVal["x"].second;
} else {
vecLen = dimIndexVal["c"].second * dimIndexVal["y"].second;
}
} else {
// K's position is 2, vectorization legnth is last dimension
if (dimIndexVal["c"].first == 3) {
vecLen = dimIndexVal["c"].second;
} else if (dimIndexVal["y"].first == 3) {
vecLen = dimIndexVal["y"].second;
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdWeightOpType) {
// When K is the fastest changing dimension,
// gemmM dimension is vectorizable, gemmK is not, and vice versa.
// Vectorization width depending on which among N, and HoWo be the fastest
// changing dimension.
if (dimIndexVal["k"].first == 3) {
input1GemmKVectorizable = false;
} else {
vecLen = dimIndexVal["x"].second;
input1GemmKVectorizable = true;
}
}
}
Expand Down Expand Up @@ -204,6 +182,41 @@ class PopulateParamsBase {
} else {
input2GemmKVectorizable = false;
}
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdWeightOpType) {
// For input tensor
// currently, fix that GemmK (NHiWi) is always vectorizable
input2GemmKVectorizable = false;
}
}

static void
obtainFilterVecLen(llvm::StringMap<std::pair<size_t, int64_t>> &dimIndexVal,
int64_t &vecLen) {
// Vectorization length logic is the same for forward and bwd_data
if (dimIndexVal["k"].first == 3) {
vecLen = dimIndexVal["k"].second;
} else if (dimIndexVal["k"].first == 0) {
// dimKF is the lowest changing dimension, which means dimC/dimY/dimX
vecLen = dimIndexVal["c"].second * dimIndexVal["y"].second *
dimIndexVal["x"].second;
} else if (dimIndexVal["k"].first == 1) {
// K's position is at 1, vectorization legnth is last two dimension
if (dimIndexVal["c"].first == 0) {
vecLen = dimIndexVal["y"].second * dimIndexVal["x"].second;
} else if (dimIndexVal["y"].first == 0) {
vecLen = dimIndexVal["c"].second * dimIndexVal["x"].second;
} else {
vecLen = dimIndexVal["c"].second * dimIndexVal["y"].second;
}
} else {
// K's position is 2, vectorization legnth is last dimension
if (dimIndexVal["c"].first == 3) {
vecLen = dimIndexVal["c"].second;
} else if (dimIndexVal["y"].first == 3) {
vecLen = dimIndexVal["y"].second;
} else {
vecLen = dimIndexVal["x"].second;
}
}
}

Expand Down Expand Up @@ -250,6 +263,19 @@ class PopulateParamsBase {
}
}

static void
obtainGemmAVecLen(mlir::miopen::ConvOpType opType,
llvm::StringMap<std::pair<size_t, int64_t>> &dimIndexVal,
int64_t &vecLen) {
if (opType == mlir::miopen::ConvOpType::Conv2DOpType) {
obtainFilterVecLen(dimIndexVal, vecLen);
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdDataOpType) {
obtainFilterVecLen(dimIndexVal, vecLen);
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdWeightOpType) {
obtainOutputVecLen(dimIndexVal, vecLen);
}
}

static void
obtainGemmBVecLen(mlir::miopen::ConvOpType opType,
llvm::StringMap<std::pair<size_t, int64_t>> &dimIndexVal,
Expand All @@ -258,6 +284,8 @@ class PopulateParamsBase {
obtainInputVecLen(dimIndexVal, vecLen);
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdDataOpType) {
obtainOutputVecLen(dimIndexVal, vecLen);
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdWeightOpType) {
obtainInputVecLen(dimIndexVal, vecLen);
}
}

Expand All @@ -269,6 +297,8 @@ class PopulateParamsBase {
obtainOutputVecLen(dimIndexVal, vecLen);
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdDataOpType) {
obtainInputVecLen(dimIndexVal, vecLen);
} else if (opType == mlir::miopen::ConvOpType::Conv2DBwdWeightOpType) {
obtainInputVecLen(dimIndexVal, vecLen);
}
}

Expand Down Expand Up @@ -364,6 +394,14 @@ class PopulateParamsBase {
ctx.dimIndexVal["ho"].second *
ctx.dimIndexVal["wo"].second;
gemmSize.gemmK = ctx.dimIndexVal["k"].second;
} else if (ctx.opType == mlir::miopen::ConvOpType::Conv2DBwdWeightOpType) {
gemmSize.gemmM = ctx.dimIndexVal["k"].second;
gemmSize.gemmK = ctx.dimIndexVal["no"].second *
ctx.dimIndexVal["ho"].second *
ctx.dimIndexVal["wo"].second;
gemmSize.gemmN = ctx.dimIndexVal["c"].second *
ctx.dimIndexVal["y"].second *
ctx.dimIndexVal["x"].second;
}
}

Expand Down
Expand Up @@ -681,6 +681,9 @@ void EmitHeaderEpilogue(llvm::raw_ostream &output,
} else if (opType == miopen::ConvOpType::Conv2DBwdDataOpType) {
output << llvm::format(kHeaderEpiloguePart4.data(), argPWeiGlobal.c_str(),
argPOutGlobal.c_str(), argPInGlobal.c_str());
} else if (opType == miopen::ConvOpType::Conv2DBwdWeightOpType) {
output << llvm::format(kHeaderEpiloguePart4.data(), argPOutGlobal.c_str(),
argPInGlobal.c_str(), argPWeiGlobal.c_str());
}
}

Expand Down

0 comments on commit 5553233

Please sign in to comment.