Skip to content

Commit

Permalink
[Linalg] Add countIncludePad support for averagepool
Browse files (browse the repository at this point in the history)
  • Loading branch information
AmosLewis committed May 16, 2024
1 parent 7405034 commit a24aa4a
Showing 1 changed file with 89 additions and 18 deletions.
107 changes: 89 additions & 18 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -574,6 +574,9 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
const TypeConverter *typeConverter = this->getTypeConverter();
Value self = adaptor.getSelf();

int64_t input_height =
cast<RankedTensorType>(self.getType()).getShape()[-2];
int64_t input_width = cast<RankedTensorType>(self.getType()).getShape()[-1];
Type inputElementType =
cast<RankedTensorType>(self.getType()).getElementType();
Type resultType = typeConverter->convertType(op.getType());
Expand All @@ -595,13 +598,6 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
return rewriter.notifyMatchFailure(
op, "count_include_pad must be a constant");

// If the padding is zero then there is no padding to include.
if (!countIncludePad &&
!llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) {
return rewriter.notifyMatchFailure(
op, "unimplemented: count_include_pad is expected to be true");
}

// `sumPool` contains the result of sumpool operation over the input.
Value sumPool, paddedInput;
SmallVector<Value, Dim + 2> outTensorShape;
Expand All @@ -611,17 +607,12 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
paddingInts, dilationInts, rewriter.getZeroAttr(inputElementType),
outTensorShape, paddedInput, sumPool)))
return rewriter.notifyMatchFailure(op, "unable to compute sumpool");
Value divisor;
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
Value kHtimeskW = rewriter.create<arith::MulIOp>(
loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
divisor = isa<Torch::NoneType>(op.getDivisorOverride().getType())
? kHtimeskW
: adaptor.getDivisorOverride();
} else {
divisor = kernelSizeIntValues[0];
}
divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType);

RankedTensorType sumPoolType = cast<RankedTensorType>(sumPool.getType());
// get rank of input (same as rank of output)
const int64_t rank = sumPoolType.getRank();
int dimH = toPositiveDim(-2, rank);
int dimW = toPositiveDim(-1, rank);

Value outputTensor = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(outTensorShape), resultElementType);
Expand All @@ -636,6 +627,86 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
/*indexingMaps=*/indexingMapsAvg,
/*iteratorTypes=*/iteratorTypesAvg,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value indexOh =
b.create<linalg::IndexOp>(loc, /*value=*/dimH);
Value oh = castIndexToInt64(b, loc, indexOh);
Value indexOw =
b.create<linalg::IndexOp>(loc, /*value=*/dimW);
Value ow = castIndexToInt64(b, loc, indexOw);

// int64_t ih0 = oh * dH - padH;
Value dH = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(strideInts[0]));
Value padH = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(paddingInts[0]));
Value oh_dH = b.create<arith::MulIOp>(loc, oh, dH);
Value ih0 = b.create<arith::SubIOp>(loc, oh_dH, padH);
// int64_t iw0 = ow * dW - padW;
Value dW = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(strideInts[1]));
Value padW = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(paddingInts[1]));
Value ow_dW = b.create<arith::MulIOp>(loc, ow, dW);
Value iw0 = b.create<arith::SubIOp>(loc, ow_dW, padW);
// int64_t ih1 = std::min(ih0 + kH, input_height + padH);
Value ih = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(input_height));
Value ih0_kH =
b.create<arith::AddIOp>(loc, ih0, kernelSizeIntValues[0]);
Value ih_padH = b.create<arith::AddIOp>(loc, ih, padH);
Value ih1 = b.create<arith::MinSIOp>(loc, ih0_kH, ih_padH);
// int64_t iw1 = std::min(iw0 + kW, input_width + padW);
Value iw = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(input_width));
Value iw0_kW =
b.create<arith::AddIOp>(loc, iw0, kernelSizeIntValues[1]);
Value iw_padW = b.create<arith::AddIOp>(loc, iw, padW);
Value iw1 = b.create<arith::MinSIOp>(loc, iw0_kW, iw_padW);
// int64_t pool_size = (ih1 - ih0) * (iw1 - iw0);
Value ih1_ih0 = b.create<arith::SubIOp>(loc, ih1, ih0);
Value iw1_iw0 = b.create<arith::SubIOp>(loc, iw1, iw0);
Value poolSize =
b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
// ih0 = std::max(ih0, 0);
Value cstZero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(0));
Value ih0Clamped =
b.create<arith::MaxSIOp>(loc, ih0, cstZero);
// iw0 = std::max(iw0, 0);
Value iw0Clamped =
b.create<arith::MaxSIOp>(loc, iw0, cstZero);
// ih1 = std::min(ih1, input_height);
Value ih1Clamped = b.create<arith::MinSIOp>(loc, ih1, ih);
// iw1 = std::min(iw1, input_width);
Value iw1Clamped = b.create<arith::MinSIOp>(loc, iw1, iw);

Value divisor;
// if (divisor_override.has_value()) {
// divisor = divisor_override.value();
// } else {
// if(count_include_pad) {
// divisor = pool_size;
// } else {
// divisor = (ih1 - ih0) * (iw1 - iw0);
// }
// }
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
if (!isa<Torch::NoneType>(
op.getDivisorOverride().getType()))
divisor = op.getDivisorOverride();
} else {
if (countIncludePad) {
divisor = convertScalarToDtype(b, loc, poolSize,
resultElementType);
} else {
Value ih1_ih0 =
b.create<arith::SubIOp>(loc, ih1Clamped, ih0Clamped);
Value iw1_iw0 =
b.create<arith::SubIOp>(loc, iw1Clamped, iw0Clamped);
divisor = b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
}
}

Value avg;
if (isa<mlir::IntegerType>(resultElementType))
avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
Expand Down

0 comments on commit a24aa4a

Please sign in to comment.