Skip to content

Commit

Permalink
[ONNX] Fix padding attributes for onnx.AveragePool
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis committed Apr 25, 2024
1 parent 38627d4 commit 4fbcef7
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,16 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
for (int64_t i : padding) {
// Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…]
// Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all axes x.
int64_t paddingSizeHalf = padding.size()/2;
for (int64_t i = 0; i < paddingSizeHalf; ++i) {
// Check if onnx padding attribute is symmetric.
if(padding[i] != padding[i + paddingSizeHalf])
return rewriter.notifyMatchFailure(
binder.op, "onnx padding attribute is not symmetric");
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
}
for (int64_t i : strides) {
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
Expand Down

0 comments on commit 4fbcef7

Please sign in to comment.