Skip to content

Commit

Permalink
Revise mlir-miopen-driver command line arguments.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Jun 6, 2020
1 parent 6fc00bb commit 54de6de
Showing 1 changed file with 46 additions and 11 deletions.
57 changes: 46 additions & 11 deletions mlir/lib/Dialect/MIOpenOps/Driver/mlir-miopen-driver.cpp
Expand Up @@ -93,6 +93,36 @@ static cl::opt<int64_t> outputWidth("out_w", cl::desc("Output width"),
cl::value_desc("dimension value"),
cl::init(-1));

// dilation height
static cl::opt<int> dilationHeight("dilation_h", cl::desc("Dilation height"),
cl::value_desc("attribute value"),
cl::init(1));

// dilation width
static cl::opt<int> dilationWidth("dilation_w", cl::desc("Dilation width"),
cl::value_desc("attribute value"),
cl::init(1));

// stride height
static cl::opt<int> strideHeight("conv_stride_h", cl::desc("Stride height"),
cl::value_desc("attribute value"),
cl::init(1));

// stride width
static cl::opt<int> strideWidth("conv_stride_w", cl::desc("Stride width"),
cl::value_desc("attribute value"),
cl::init(1));

// padding height
static cl::opt<int> paddingHeight("padding_h", cl::desc("Padding height"),
cl::value_desc("attribute value"),
cl::init(0));

// padding width
static cl::opt<int> paddingWidth("padding_w", cl::desc("Padding width"),
cl::value_desc("attribute value"),
cl::init(0));

// populate default values
static cl::opt<bool> populateDefaultValues("p", cl::desc("To populate default values"),
cl::value_desc("To populate default values"),
Expand All @@ -109,13 +139,19 @@ int main(int argc, char **argv) {
if (populateDefaultValues.getValue() == true) {
batchSize.setValue(128);
inputChannel.setValue(8);
inputWidth.setValue(32);
inputHeight.setValue(32);
outputChannel.setValue(128);
outputWidth.setValue(30);
inputHeight.setValue(32);
inputWidth.setValue(32);
outputHeight.setValue(30);
filterWidth.setValue(3);
outputWidth.setValue(30);
filterHeight.setValue(3);
filterWidth.setValue(3);
dilationHeight.setValue(1);
dilationWidth.setValue(1);
strideHeight.setValue(1);
strideWidth.setValue(1);
paddingHeight.setValue(0);
paddingWidth.setValue(0);
}

// Construct a new ModuleOp.
Expand Down Expand Up @@ -193,18 +229,17 @@ int main(int argc, char **argv) {
builder.getNamedAttr("input_layout", builder.getArrayAttr(ArrayRef<Attribute>(inputLayoutSpec.begin(), inputLayoutSpec.end()))),
builder.getNamedAttr("output_layout", builder.getArrayAttr(ArrayRef<Attribute>(outputLayoutSpec.begin(), outputLayoutSpec.end()))),

// TBD: support dilations / strides / padding.
builder.getNamedAttr("dilations", builder.getArrayAttr({
builder.getI32IntegerAttr(1),
builder.getI32IntegerAttr(1),
builder.getI32IntegerAttr(dilationHeight.getValue()),
builder.getI32IntegerAttr(dilationWidth.getValue()),
})),
builder.getNamedAttr("strides", builder.getArrayAttr({
builder.getI32IntegerAttr(1),
builder.getI32IntegerAttr(1),
builder.getI32IntegerAttr(strideHeight.getValue()),
builder.getI32IntegerAttr(strideWidth.getValue()),
})),
builder.getNamedAttr("padding", builder.getArrayAttr({
builder.getI32IntegerAttr(0),
builder.getI32IntegerAttr(0),
builder.getI32IntegerAttr(paddingHeight.getValue()),
builder.getI32IntegerAttr(paddingWidth.getValue()),
})),
});
block->push_back(convOp);
Expand Down

0 comments on commit 54de6de

Please sign in to comment.