Skip to content

Commit

Permalink
Revise mlir-miopen-driver command line arguments.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Jun 6, 2020
1 parent 7cb2738 commit aa8c5ed
Showing 1 changed file with 93 additions and 5 deletions.
98 changes: 93 additions & 5 deletions mlir/lib/Dialect/MIOpenOps/Driver/mlir-miopen-driver.cpp
Expand Up @@ -36,18 +36,64 @@ static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
cl::value_desc("filename"),
cl::init("-"));

static cl::opt<std::string> filterLayout("f", cl::desc("Filter layout"),
static cl::opt<std::string> filterLayout("fil_layout", cl::desc("Filter layout"),
cl::value_desc("layout string"),
cl::init("kcyx"));

static cl::opt<std::string> inputLayout("i", cl::desc("Input layout"),
static cl::opt<std::string> inputLayout("in_layout", cl::desc("Input layout"),
cl::value_desc("layout string"),
cl::init("nchw"));

static cl::opt<std::string> outputLayout("d", cl::desc("Output layout"),
static cl::opt<std::string> outputLayout("out_layout", cl::desc("Output layout"),
cl::value_desc("layout string"),
cl::init("nkhw"));

// N
static cl::opt<int64_t> batchSize("batchsize", cl::desc("Batch size"),
cl::value_desc("dimension value"),
cl::init(-1));

// C
static cl::opt<int64_t> inputChannel("in_channels", cl::desc("Input channels"),
cl::value_desc("dimension value"),
cl::init(-1));

// Hi
static cl::opt<int64_t> inputHeight("in_h", cl::desc("Input height"),
cl::value_desc("dimension value"),
cl::init(-1));

// Wi
static cl::opt<int64_t> inputWidth("in_w", cl::desc("Input width"),
cl::value_desc("dimension value"),
cl::init(-1));

// K
static cl::opt<int64_t> outputChannel("out_channels", cl::desc("Output channels"),
cl::value_desc("dimension value"),
cl::init(-1));

// Y
static cl::opt<int64_t> filterWidth("fil_w", cl::desc("Filter width"),
cl::value_desc("dimension value"),
cl::init(-1));

// X
static cl::opt<int64_t> filterHeight("fil_h", cl::desc("Filter height"),
cl::value_desc("dimension value"),
cl::init(-1));

// Ho
static cl::opt<int64_t> outputHeight("out_h", cl::desc("Output height"),
cl::value_desc("dimension value"),
cl::init(-1));

// Wo
static cl::opt<int64_t> outputWidth("out_w", cl::desc("Output width"),
cl::value_desc("dimension value"),
cl::init(-1));


int main(int argc, char **argv) {
InitLLVM y(argc, argv);

Expand All @@ -59,9 +105,51 @@ int main(int argc, char **argv) {
OpBuilder builder(&context);
auto module = ModuleOp::create(builder.getUnknownLoc());

// Determine dimensions.
llvm::SmallVector<int64_t, 4> filterDimension;
llvm::SmallVector<int64_t, 4> inputDimension;
llvm::SmallVector<int64_t, 4> outputDimension;
for (size_t i = 0; i < 4; ++i) {
auto &filterDim = filterLayout.getValue()[i];
auto &inputDim = inputLayout.getValue()[i];
auto &outputDim = outputLayout.getValue()[i];

if (filterDim == 'k') {
filterDimension.push_back(outputChannel.getValue());
} else if (filterDim == 'c') {
filterDimension.push_back(inputChannel.getValue());
} else if (filterDim == 'y') {
filterDimension.push_back(filterWidth.getValue());
} else if (filterDim == 'x') {
filterDimension.push_back(filterHeight.getValue());
}

if (inputDim == 'n') {
inputDimension.push_back(batchSize.getValue());
} else if (inputDim == 'c') {
inputDimension.push_back(inputChannel.getValue());
} else if (inputDim == 'h') {
inputDimension.push_back(inputWidth.getValue());
} else if (inputDim == 'w') {
inputDimension.push_back(inputHeight.getValue());
}

if (outputDim == 'n') {
outputDimension.push_back(batchSize.getValue());
} else if (outputDim == 'k') {
outputDimension.push_back(outputChannel.getValue());
} else if (outputDim == 'h') {
outputDimension.push_back(outputWidth.getValue());
} else if (outputDim == 'w') {
outputDimension.push_back(outputHeight.getValue());
}
}

// Construct a new FuncOp.
auto argType = MemRefType::get({-1, -1, -1, -1}, builder.getF32Type());
auto funcType = builder.getFunctionType({argType, argType, argType}, {});
auto filterArgType = MemRefType::get(ArrayRef<int64_t>(filterDimension.begin(), filterDimension.end()), builder.getF32Type());
auto inputArgType = MemRefType::get(ArrayRef<int64_t>(inputDimension.begin(), inputDimension.end()), builder.getF32Type());
auto outputArgType = MemRefType::get(ArrayRef<int64_t>(outputDimension.begin(), outputDimension.end()), builder.getF32Type());
auto funcType = builder.getFunctionType({filterArgType, inputArgType, outputArgType}, {});
auto func = FuncOp::create(builder.getUnknownLoc(), "miopen_conv2d_" + filterLayout + "_" + inputLayout + "_" + outputLayout, funcType);
module.push_back(func);

Expand Down

0 comments on commit aa8c5ed

Please sign in to comment.