Skip to content

Commit

Permalink
Rework cpp trans
Browse files Browse the repository at this point in the history
  • Loading branch information
zjing14 committed Jun 18, 2020
1 parent 1fd9e9b commit c8503c2
Showing 1 changed file with 49 additions and 6 deletions.
Expand Up @@ -294,7 +294,7 @@ static constexpr StringLiteral kCppEpiloguePart2Format = R"(
)";

static constexpr StringLiteral kGemmNameABlockCopySrcDataPerRead[] = {
"GemmK", // Conv2DOpType
"GemmK", // Conv2DOpType and Conv2DBwdWeightOpType
"GemmM", // Conv2DBwdDataOpType
};

Expand All @@ -306,6 +306,9 @@ void EmitCppPreamble(llvm::raw_ostream &output, miopen::ConvOpType opType) {
output << R"(#include "gridwise_convolution_implicit_gemm_v4r4_)";
} else if (opType == miopen::ConvOpType::Conv2DBwdDataOpType) {
output << R"(#include "gridwise_convolution_implicit_gemm_v1r1_)";
} else if (opType == miopen::ConvOpType::Conv2DBwdWeightOpType) {
output
<< R"(#include "gridwise_convolution_backward_weight_implicit_gemm_v4r4_)";
}

// Change to fixed "mlir".
Expand All @@ -320,6 +323,9 @@ void EmitCppPreamble(llvm::raw_ostream &output, miopen::ConvOpType opType) {
} else if (opType == miopen::ConvOpType::Conv2DBwdDataOpType) {
output << R"(
__launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void gridwise_convolution_backward_data_implicit_gemm_v1r1_)";
} else if (opType == miopen::ConvOpType::Conv2DBwdWeightOpType) {
output << R"(
__launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void gridwise_convolution_backward_weight_implicit_gemm_v4r4_)";
}
// Change to fixed "mlir".
output << "mlir";
Expand All @@ -334,6 +340,9 @@ void EmitCppPreamble(llvm::raw_ostream &output, miopen::ConvOpType opType) {
output << llvm::format(kCppPreamblePart3Format.data(),
argPOutGlobal.c_str(), argPWeiGlobal.c_str(),
argPInGlobal.c_str());
} else if (opType == miopen::ConvOpType::Conv2DBwdWeightOpType) {
output << llvm::format(kCppPreamblePart3Format.data(), argPInGlobal.c_str(),
argPOutGlobal.c_str(), argPWeiGlobal.c_str());
}
}

Expand All @@ -343,8 +352,11 @@ void EmitCppInterlude(llvm::raw_ostream &output, miopen::ConvOpType opType) {
gemmNameABlockCopySrcDataPerRead = kGemmNameABlockCopySrcDataPerRead[0].str();
} else if (opType == miopen::ConvOpType::Conv2DBwdDataOpType) {
gemmNameABlockCopySrcDataPerRead = kGemmNameABlockCopySrcDataPerRead[1].str();
} else if (opType == miopen::ConvOpType::Conv2DBwdWeightOpType) {
gemmNameABlockCopySrcDataPerRead = kGemmNameABlockCopySrcDataPerRead[0].str();
}
output << llvm::format(kCppInterludeFormat.data(),
gemmNameABlockCopySrcDataPerRead.c_str(),
gemmNameABlockCopySrcDataPerRead.c_str());
}

Expand All @@ -360,6 +372,9 @@ void EmitCppEpilogue(llvm::raw_ostream &output,
} else if (opType == miopen::ConvOpType::Conv2DBwdDataOpType) {
output << R"(
constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v1r1_)";
} else if (opType == miopen::ConvOpType::Conv2DBwdWeightOpType) {
output << R"(
constexpr auto gridwise_conv = GridwiseConvolutionBackwardWeightImplicitGemm_v4r4_)";
}

// Change to fixed "mlir".
Expand All @@ -379,6 +394,8 @@ void EmitCppEpilogue(llvm::raw_ostream &output,
gemmNameABlockCopySrcDataPerRead = kGemmNameABlockCopySrcDataPerRead[0].str();
} else if (opType == miopen::ConvOpType::Conv2DBwdDataOpType) {
gemmNameABlockCopySrcDataPerRead = kGemmNameABlockCopySrcDataPerRead[1].str();
} else if (opType == miopen::ConvOpType::Conv2DBwdWeightOpType) {
gemmNameABlockCopySrcDataPerRead = kGemmNameABlockCopySrcDataPerRead[0].str();
}
output << llvm::format(kCppEpiloguePart2Format.data(),
gemmNameABlockCopySrcDataPerRead.c_str(),
Expand All @@ -397,7 +414,7 @@ static constexpr StringLiteral kHeaderPreamblePart1Format = R"(
namespace ck {
// GemmM = %s
// GemmN = N * Ho * Wo
// GemmN = %s
// GemmK = %s
template <index_t GridSize,
index_t BlockSize,
Expand Down Expand Up @@ -449,6 +466,14 @@ static constexpr StringLiteral kHeaderPreamblePart2BwdData = R"(
{
)";

static constexpr StringLiteral kHeaderPreamblePart2BwdWeight = R"(
{
__device__ void Run(const Float* const __restrict__ p_in_global,
Float* __restrict__ p_wei_global,
const Float* const __restrict__ p_out_global) const
{
)";

static constexpr StringLiteral kHeaderPreamblePart3 = R"(
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
Expand Down Expand Up @@ -527,28 +552,39 @@ void EmitHeaderPreamble(llvm::raw_ostream &output,
miopen::ConvOpType opType) {
std::string headerIncludeGuard;
std::string commentGemmM;
std::string commentGemmN;
std::string commentGemmK;
std::string gemmNameABlockCopySrcDataPerRead;
if (opType == miopen::ConvOpType::Conv2DOpType) {
headerIncludeGuard = "IMPLICIT_GEMM_V4R4";
commentGemmM = "K";
commentGemmN = "N * H * W";
commentGemmK = "C * Y * X";
gemmNameABlockCopySrcDataPerRead = kGemmNameABlockCopySrcDataPerRead[0].str();
} else if (opType == miopen::ConvOpType::Conv2DBwdDataOpType) {
headerIncludeGuard = "BACKWARD_DATA_IMPLICIT_GEMM_V1R1";
commentGemmM = "C * Y * X";
commentGemmN = "N * H * W";
commentGemmK = "K";
gemmNameABlockCopySrcDataPerRead = kGemmNameABlockCopySrcDataPerRead[1].str();
} else if (opType == miopen::ConvOpType::Conv2DBwdWeightOpType) {
headerIncludeGuard = "BACKWARD_WEIGHT_IMPLICIT_GEMM_V4R4";
commentGemmM = "K";
commentGemmN = "C * Y * X";
commentGemmK = "N * H * W";
gemmNameABlockCopySrcDataPerRead = kGemmNameABlockCopySrcDataPerRead[0].str();
}
output << llvm::format(kHeaderPreamblePart1Format.data(),
headerIncludeGuard.c_str(), headerIncludeGuard.c_str(),
commentGemmM.c_str(), commentGemmK.c_str(),
gemmNameABlockCopySrcDataPerRead.c_str());
output << llvm::format(
kHeaderPreamblePart1Format.data(), headerIncludeGuard.c_str(),
headerIncludeGuard.c_str(), commentGemmM.c_str(), commentGemmN.c_str(),
commentGemmK.c_str(), gemmNameABlockCopySrcDataPerRead.c_str());

if (opType == miopen::ConvOpType::Conv2DOpType) {
output << R"(struct GridwiseConvolutionImplicitGemm_v4r4_)";
} else if (opType == miopen::ConvOpType::Conv2DBwdDataOpType) {
output << R"(struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_)";
} else if (opType == miopen::ConvOpType::Conv2DBwdWeightOpType) {
output << R"(struct GridwiseConvolutionBackwardWeightImplicitGemm_v4r4_)";
}

// Change to fixed "mlir".
Expand All @@ -558,6 +594,8 @@ void EmitHeaderPreamble(llvm::raw_ostream &output,
output << kHeaderPreamblePart2Forward;
} else if (opType == miopen::ConvOpType::Conv2DBwdDataOpType) {
output << kHeaderPreamblePart2BwdData;
} else if (opType == miopen::ConvOpType::Conv2DBwdWeightOpType) {
output << kHeaderPreamblePart2BwdWeight;
}
output << kHeaderPreamblePart3;

Expand Down Expand Up @@ -599,6 +637,9 @@ void EmitHeaderEpilogue(llvm::raw_ostream &output,
} else if (opType == miopen::ConvOpType::Conv2DBwdDataOpType) {
inMemOp = "in_memory_op";
gemmHeaderEpiloguePart2Sequence = "Sequence<0, 1>";
} else if (opType == miopen::ConvOpType::Conv2DBwdWeightOpType) {
inMemOp = "InMemoryDataOperation::Set";
gemmHeaderEpiloguePart2Sequence = "Sequence<1, 0>";
}
output << llvm::format(kHeaderEpiloguePart2.data(), inMemOp.c_str(),
gemmHeaderEpiloguePart2Sequence.c_str(),
Expand All @@ -617,6 +658,8 @@ void EmitHeaderEpilogue(llvm::raw_ostream &output,
gemmNameABlockCopySrcDataPerRead = kGemmNameABlockCopySrcDataPerRead[0].str();
} else if (opType == miopen::ConvOpType::Conv2DBwdDataOpType) {
gemmNameABlockCopySrcDataPerRead = kGemmNameABlockCopySrcDataPerRead[1].str();
} else if (opType == miopen::ConvOpType::Conv2DBwdWeightOpType) {
gemmNameABlockCopySrcDataPerRead = kGemmNameABlockCopySrcDataPerRead[0].str();
}
output << llvm::format(kHeaderEpiloguePart3Format.data(),
gemmNameABlockCopySrcDataPerRead.c_str());
Expand Down

0 comments on commit c8503c2

Please sign in to comment.