Skip to content

Commit

Permalink
Add logic to consider right padding.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Jun 6, 2020
1 parent 97d2b37 commit 3a68376
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 6 deletions.
Expand Up @@ -921,12 +921,11 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCFlags(ModuleOp m)
output << " -DCK_PARAM_PROBLEM_CONV_DILATION_H=" << dilationH;
output << " -DCK_PARAM_PROBLEM_CONV_DILATION_W=" << dilationW;

// TBD. compute left padding and right padding properly.
auto paddingAttr = op.getAttrOfType<ArrayAttr>("padding");
int64_t paddingHL = paddingAttr.getValue()[0].dyn_cast<IntegerAttr>().getInt();
int64_t paddingHR = paddingAttr.getValue()[0].dyn_cast<IntegerAttr>().getInt();
int64_t paddingWL = paddingAttr.getValue()[1].dyn_cast<IntegerAttr>().getInt();
int64_t paddingWR = paddingAttr.getValue()[1].dyn_cast<IntegerAttr>().getInt();
int64_t paddingHL = paddingAttr.getValue()[0].dyn_cast<ArrayAttr>().getValue()[0].dyn_cast<IntegerAttr>().getInt();
int64_t paddingWL = paddingAttr.getValue()[0].dyn_cast<ArrayAttr>().getValue()[1].dyn_cast<IntegerAttr>().getInt();
int64_t paddingHR = paddingAttr.getValue()[1].dyn_cast<ArrayAttr>().getValue()[0].dyn_cast<IntegerAttr>().getInt();
int64_t paddingWR = paddingAttr.getValue()[1].dyn_cast<ArrayAttr>().getValue()[1].dyn_cast<IntegerAttr>().getInt();

output << " -DCK_PARAM_PROBLEM_IN_LEFT_PAD_H=" << paddingHL;
output << " -DCK_PARAM_PROBLEM_IN_LEFT_PAD_W=" << paddingWL;
Expand Down
45 changes: 44 additions & 1 deletion mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp
Expand Up @@ -548,6 +548,46 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> {
transformedOutputAttrs.push_back(b.getNamedAttr("gridwise_gemm_argument_position", b.getI32IntegerAttr(2)));
auto gemmC = b.create<miopen::TransformOp>(op.getLoc(), transformedOutputMemRefType, op.output(), transformedOutputAttrs);

// compute right padding parameters.
auto leftPadH = paddingAttr.getValue()[0].dyn_cast<IntegerAttr>().getInt();
auto leftPadW = paddingAttr.getValue()[1].dyn_cast<IntegerAttr>().getInt();
auto dilationH = dilationsAttr.getValue()[0].dyn_cast<IntegerAttr>().getInt();
auto dilationW = dilationsAttr.getValue()[1].dyn_cast<IntegerAttr>().getInt();
auto strideH = stridesAttr.getValue()[0].dyn_cast<IntegerAttr>().getInt();
auto strideW = stridesAttr.getValue()[1].dyn_cast<IntegerAttr>().getInt();

// get y, x, ho, wo, hi, wi
int64_t y, x, ho, wo, hi, wi;
y = x = ho = wo = hi = wi = 0;
for (unsigned i = 0; i < 4; ++i) {
auto filterAttr = filterLayoutAttr.getValue()[i].dyn_cast<StringAttr>();
auto inputAttr = inputLayoutAttr.getValue()[i].dyn_cast<StringAttr>();
auto outputAttr = outputLayoutAttr.getValue()[i].dyn_cast<StringAttr>();

if (filterAttr.getValue() == "y") {
y = filterShape[i];
} else if (filterAttr.getValue() == "x") {
x = filterShape[i];
}

if (inputAttr.getValue() == "hi") {
hi = inputShape[i];
} else if (inputAttr.getValue() == "wi") {
wi = inputShape[i];
}

if (outputAttr.getValue() == "ho") {
ho = outputShape[i];
} else if (outputAttr.getValue() == "wo") {
wo = outputShape[i];
}
}

auto hiPadded = 1 + (y - 1) * dilationH + (ho - 1) * strideH;
auto wiPadded = 1 + (x - 1) * dilationW + (wo - 1) * strideW;
auto rightPadH = hiPadded > (leftPadH + hi) ? hiPadded - (leftPadH + hi) : 0;
auto rightPadW = wiPadded > (leftPadW + wi) ? wiPadded - (leftPadW + wi) : 0;

// Set attributes for gridwise_gemm op.
llvm::SmallVector<NamedAttribute, 8> gridwiseGemmAttrs {
b.getNamedAttr("filter_layout", filterLayoutAttr),
Expand All @@ -558,7 +598,10 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> {
b.getNamedAttr("output_dimension", b.getI64ArrayAttr(outputShape)),
b.getNamedAttr("dilations", dilationsAttr),
b.getNamedAttr("strides", stridesAttr),
b.getNamedAttr("padding", paddingAttr),
b.getNamedAttr("padding", b.getArrayAttr({
paddingAttr,
b.getI32ArrayAttr({rightPadH, rightPadW})
})),
};
// Emit miopen.gridwise_gemm op.
b.create<miopen::GridwiseGemmOp>(op.getLoc(), ArrayRef<Type>{}, ValueRange{gemmA, gemmB, gemmC}, gridwiseGemmAttrs);
Expand Down

0 comments on commit 3a68376

Please sign in to comment.