Skip to content

Commit

Permalink
Implement logic to emit compilation flags for non-XDLOPS convolution.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Jun 6, 2020
1 parent 93b50ee commit a117780
Showing 1 changed file with 113 additions and 15 deletions.
Expand Up @@ -651,13 +651,6 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenHeaderXDLOPS(Modul
output << ");\n\n";
});

// TBD get tuning parameters.
//f.walk([&output](miopen::GridwiseGemmOp op) {
// // get op name.
// //output << "op name: " << op.getOperationName() << "\n";
// //op.dump();
//});

EmitHeaderEpilogue(output, gridwiseGemmArguments);
}

Expand Down Expand Up @@ -706,13 +699,6 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCppXDLOPS(ModuleOp

EmitCppInterlude(output);

// TBD get tuning parameters.
//f.walk([&output](miopen::GridwiseGemmOp op) {
// // get op name.
// //output << "op name: " << op.getOperationName() << "\n";
// //op.dump();
//});

EmitCppEpilogue(output, layoutStr, tensorDescs);
}

Expand All @@ -722,9 +708,121 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCppXDLOPS(ModuleOp

std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCFlagsXDLOPS(ModuleOp m) {
std::string resultStr;
resultStr.reserve(4096);
llvm::raw_string_ostream output(resultStr);

for (auto f : m.getOps<FuncOp>()) {
f.walk([&output](miopen::GridwiseGemmOp op) {
output << "-std=c++14";
output << " -D__HIP_PLATFORM_HCC__=1";

// TBD: be able to set data type.
output << " -DMIOPEN_USE_FP32=1 -DMIOPEN_USE_FP16=0";

// TBD: be able to set convolution direction.
output << " -DCK_PARAM_PROBLEM_DIRECTION=0";
output << " -DCK_PARAM_PROBLEM_CONV_DIRECTION_FORWARD=1";
output << " -DCK_PARAM_PROBLEM_CONV_DIRECTION_BACKWARD_DATA=0";
output << " -DCK_PARAM_PROBLEM_CONV_DIRECTION_BACKWARD_WEIGHT=0";

// Emit flags immediately determined from convolution configs.
auto inputLayoutAttr = op.getAttrOfType<ArrayAttr>("input_layout");
auto inputDimensionAttr = op.getAttrOfType<ArrayAttr>("input_dimension");
auto outputLayoutAttr = op.getAttrOfType<ArrayAttr>("output_layout");
auto outputDimensionAttr = op.getAttrOfType<ArrayAttr>("output_dimension");
auto filterLayoutAttr = op.getAttrOfType<ArrayAttr>("filter_layout");
auto filterDimensionAttr = op.getAttrOfType<ArrayAttr>("filter_dimension");

int64_t n = 0, k = 0, ho = 0, wo = 0;

for (size_t i = 0; i < 4; ++i) {
auto filterDim = filterLayoutAttr.getValue()[i].dyn_cast<StringAttr>().getValue();
auto inputDim = inputLayoutAttr.getValue()[i].dyn_cast<StringAttr>().getValue();
auto outputDim = outputLayoutAttr.getValue()[i].dyn_cast<StringAttr>().getValue();

if (filterDim.str() == "k") {
k = filterDimensionAttr.getValue()[i].dyn_cast<IntegerAttr>().getInt();
output << " -DCK_PARAM_PROBLEM_K=" << k;
} else if (filterDim.str() == "c") {
output << " -DCK_PARAM_PROBLEM_C=" << filterDimensionAttr.getValue()[i].dyn_cast<IntegerAttr>().getValue();
} else if (filterDim.str() == "y") {
output << " -DCK_PARAM_PROBLEM_Y=" << filterDimensionAttr.getValue()[i].dyn_cast<IntegerAttr>().getValue();
} else if (filterDim.str() == "x") {
output << " -DCK_PARAM_PROBLEM_X=" << filterDimensionAttr.getValue()[i].dyn_cast<IntegerAttr>().getValue();
}

if (inputDim.str() == "ni") {
n = inputDimensionAttr.getValue()[i].dyn_cast<IntegerAttr>().getInt();
output << " -DCK_PARAM_PROBLEM_N=" << n;
} else if (inputDim.str() == "hi") {
output << " -DCK_PARAM_PROBLEM_HI=" << inputDimensionAttr.getValue()[i].dyn_cast<IntegerAttr>().getValue();
} else if (inputDim.str() == "wi") {
output << " -DCK_PARAM_PROBLEM_WI=" << inputDimensionAttr.getValue()[i].dyn_cast<IntegerAttr>().getValue();
}

if (outputDim.str() == "ho") {
ho = outputDimensionAttr.getValue()[i].dyn_cast<IntegerAttr>().getInt();
output << " -DCK_PARAM_PROBLEM_HO=" << ho;
} else if (outputDim.str() == "wo") {
wo = outputDimensionAttr.getValue()[i].dyn_cast<IntegerAttr>().getInt();
output << " -DCK_PARAM_PROBLEM_WO=" << wo;
}
}

int64_t gemmMPerBlock = 128;
int64_t gemmNPerBlock = 128;
int64_t gemmKPerBlock = 8;
int64_t gemmM = k;
int64_t gemmN = n * ho * wo;
int64_t gridSize = (gemmM / gemmMPerBlock) * (gemmN / gemmNPerBlock);

auto strideAttr = op.getAttrOfType<ArrayAttr>("strides");
auto dilationAttr = op.getAttrOfType<ArrayAttr>("dilations");
auto paddingAttr = op.getAttrOfType<ArrayAttr>("padding");
output << " -DCK_PARAM_PROBLEM_CONV_STRIDE_H=" << strideAttr.getValue()[0].dyn_cast<IntegerAttr>().getValue();
output << " -DCK_PARAM_PROBLEM_CONV_STRIDE_W=" << strideAttr.getValue()[1].dyn_cast<IntegerAttr>().getValue();
output << " -DCK_PARAM_PROBLEM_CONV_DILATION_H=" << dilationAttr.getValue()[0].dyn_cast<IntegerAttr>().getValue();
output << " -DCK_PARAM_PROBLEM_CONV_DILATION_W=" << dilationAttr.getValue()[1].dyn_cast<IntegerAttr>().getValue();
output << " -DCK_PARAM_PROBLEM_IN_LEFT_PAD_H=" << paddingAttr.getValue()[0].dyn_cast<IntegerAttr>().getValue();
output << " -DCK_PARAM_PROBLEM_IN_LEFT_PAD_W=" << paddingAttr.getValue()[1].dyn_cast<IntegerAttr>().getValue();
output << " -DCK_PARAM_PROBLEM_IN_RIGHT_PAD_H=" << paddingAttr.getValue()[0].dyn_cast<IntegerAttr>().getValue();
output << " -DCK_PARAM_PROBLEM_IN_RIGHT_PAD_W=" << paddingAttr.getValue()[1].dyn_cast<IntegerAttr>().getValue();

// TBD: ditinguish between:
// - parameters truly need to be tuned.
// - parameters deducible via transformations.
// - parameters which have heuristic-based values.
// - parameters related to code generation.
output << " -DCK_PARAM_PROBLEM_CONV_GROUP_COUNTS=1";

output << " -DCK_PARAM_TUNABLE_GEMM_M_PER_BLOCK=" << gemmMPerBlock;
output << " -DCK_PARAM_TUNABLE_GEMM_N_PER_BLOCK=" << gemmNPerBlock;
output << " -DCK_PARAM_TUNABLE_GEMM_K_PER_BLOCK=" << gemmKPerBlock;
output << " -DCK_PARAM_DEPENDENT_GRID_SIZE=" << gridSize;
output << " -DCK_PARAM_TUNABLE_BLOCK_SIZE=256";

// [8, GEMM_M_PER_BLOCK], power of 2
output << " -DCK_PARAM_GEMM_M_PER_WAVE=" << gemmMPerBlock;
// [8, GEMM_N_PER_BLOCK], power of 2
output << " -DCK_PARAM_GEMM_N_PER_WAVE=" << gemmNPerBlock;

output << " -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K=4";
output << " -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N=4";

output << " -DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K=2";
output << " -DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M=4";
output << " -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N=1";
output << " -DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_K=1";
output << " -DCK_PARAM_TUNABLE_GEMM_C_THREAD_COPY_DATA_PER_ACCESS_N=1";
output << " -DCK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_N=1";
output << " -DCK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_M=1";

output << " -DCK_PARAM_KPACK_LENGTH=1";
output << " -DCK_USE_AMD_XDLOPS=1";
output << " -DCK_USE_AMD_XDLOPS_INLINE_ASM=1";
});
}

output.flush();
return std::make_unique<llvm::StringRef>(resultStr);
}

0 comments on commit a117780

Please sign in to comment.