diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 70b27fd84f2..ce4183f2f17 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -574,6 +574,9 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { const TypeConverter *typeConverter = this->getTypeConverter(); Value self = adaptor.getSelf(); + int64_t input_height = + cast(self.getType()).getShape()[-2]; + int64_t input_width = cast(self.getType()).getShape()[-1]; Type inputElementType = cast(self.getType()).getElementType(); Type resultType = typeConverter->convertType(op.getType()); @@ -595,13 +598,6 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "count_include_pad must be a constant"); - // If the padding is zero then there is no padding to include. - if (!countIncludePad && - !llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) { - return rewriter.notifyMatchFailure( - op, "unimplemented: count_include_pad is expected to be true"); - } - // `sumPool` contains the result of sumpool operation over the input. Value sumPool, paddedInput; SmallVector outTensorShape; @@ -611,17 +607,12 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { paddingInts, dilationInts, rewriter.getZeroAttr(inputElementType), outTensorShape, paddedInput, sumPool))) return rewriter.notifyMatchFailure(op, "unable to compute sumpool"); - Value divisor; - if constexpr (std::is_same()) { - Value kHtimeskW = rewriter.create( - loc, kernelSizeIntValues[0], kernelSizeIntValues[1]); - divisor = isa(op.getDivisorOverride().getType()) - ? kHtimeskW - : adaptor.getDivisorOverride(); - } else { - divisor = kernelSizeIntValues[0]; - } - divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType); + + RankedTensorType sumPoolType = cast(sumPool.getType()); + // get rank of input (same as rank of output) + const int64_t rank = sumPoolType.getRank(); + int dimH = toPositiveDim(-2, rank); + int dimW = toPositiveDim(-1, rank); Value outputTensor = rewriter.create( loc, getAsOpFoldResult(outTensorShape), resultElementType); @@ -636,6 +627,86 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { /*indexingMaps=*/indexingMapsAvg, /*iteratorTypes=*/iteratorTypesAvg, [&](OpBuilder &b, Location loc, ValueRange args) { + Value indexOh = + b.create(loc, /*value=*/dimH); + Value oh = castIndexToInt64(b, loc, indexOh); + Value indexOw = + b.create(loc, /*value=*/dimW); + Value ow = castIndexToInt64(b, loc, indexOw); + + // int64_t ih0 = oh * dH - padH; + Value dH = rewriter.create( + loc, rewriter.getI64IntegerAttr(strideInts[0])); + Value padH = rewriter.create( + loc, rewriter.getI64IntegerAttr(paddingInts[0])); + Value oh_dH = b.create(loc, oh, dH); + Value ih0 = b.create(loc, oh_dH, padH); + // int64_t iw0 = ow * dW - padW; + Value dW = rewriter.create( + loc, rewriter.getI64IntegerAttr(strideInts[1])); + Value padW = rewriter.create( + loc, rewriter.getI64IntegerAttr(paddingInts[1])); + Value ow_dW = b.create(loc, ow, dW); + Value iw0 = b.create(loc, ow_dW, padW); + // int64_t ih1 = std::min(ih0 + kH, input_height + padH); + Value ih = rewriter.create( + loc, rewriter.getI64IntegerAttr(input_height)); + Value ih0_kH = + b.create(loc, ih0, kernelSizeIntValues[0]); + Value ih_padH = b.create(loc, ih, padH); + Value ih1 = b.create(loc, ih0_kH, ih_padH); + // int64_t iw1 = std::min(iw0 + kW, input_width + padW); + Value iw = rewriter.create( + loc, rewriter.getI64IntegerAttr(input_width)); + Value iw0_kW = + b.create(loc, iw0, kernelSizeIntValues[1]); + Value iw_padW = b.create(loc, iw, padW); + Value iw1 = b.create(loc, iw0_kW, iw_padW); + // int64_t pool_size = (ih1 - ih0) * (iw1 - iw0); + Value ih1_ih0 = b.create(loc, ih1, ih0); + Value iw1_iw0 = b.create(loc, iw1, iw0); + Value poolSize = + b.create(loc, ih1_ih0, iw1_iw0); + // ih0 = std::max(ih0, 0); + Value cstZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value ih0Clamped = + b.create(loc, ih0, cstZero); + // iw0 = std::max(iw0, 0); + Value iw0Clamped = + b.create(loc, iw0, cstZero); + // ih1 = std::min(ih1, input_height); + Value ih1Clamped = b.create(loc, ih1, ih); + // iw1 = std::min(iw1, input_width); + Value iw1Clamped = b.create(loc, iw1, iw); + + Value divisor; + // if (divisor_override.has_value()) { + // divisor = divisor_override.value(); + // } else { + // if(count_include_pad) { + // divisor = pool_size; + // } else { + // divisor = (ih1 - ih0) * (iw1 - iw0); + // } + // } + if constexpr (std::is_same()) { + if (!isa( + op.getDivisorOverride().getType())) + divisor = op.getDivisorOverride(); + } else { + if (countIncludePad) { + divisor = convertScalarToDtype(b, loc, poolSize, + resultElementType); + } else { + Value ih1_ih0 = + b.create(loc, ih1Clamped, ih0Clamped); + Value iw1_iw0 = + b.create(loc, iw1Clamped, iw0Clamped); + divisor = b.create(loc, ih1_ih0, iw1_iw0); + } + } + Value avg; if (isa(resultElementType)) avg = b.create(loc, args[0], divisor);