[mlir][tosa] Add broadcasting case for tosa.resize degenerate case
When the input is ?x1x1x?, the tosa.resize operation broadcasts the
input and (when quantized) applies a scaling factor. Updated the resize
lowering to avoid tensor.extract operations, instead broadcasting the
single positional value as necessary.
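
As a hypothetical illustration (the shapes and attribute values here are
invented for the example), such a degenerate resize looks like:

  %resized = "tosa.resize"(%input) {mode = "BILINEAR", scale = [4, 1, 4, 1],
      offset = [0, 0], border = [3, 3]}
      : (tensor<1x1x1x8xi8>) -> tensor<1x4x4x8xi32>

Every output element is the single input element, sign-extended and, in the
quantized bilinear case, multiplied by the scale numerators (4 * 4 here), so
no per-element tensor.extract is needed.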

Moved the tosa.resize tests to their own MLIR test file due to the
increased complexity. Also corrected a bug where tosa.resize in bilinear
floating-point mode was not applying the correct scaling.
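
For reference, the standard bilinear blend the floating-point path should
compute, with $d_x, d_y \in [0, 1)$ the fractional sample offsets, is

$$v = (1 - d_y)\bigl((1 - d_x)\,v_{00} + d_x\,v_{01}\bigr) + d_y\bigl((1 - d_x)\,v_{10} + d_x\,v_{11}\bigr)$$

The previous floating-point lowering weighted the samples with the
integer-mode terms (xScaleN - dx) and (yScaleN - dy); the diff below
replaces them with (1 - dx) and (1 - dy).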

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D136299
rsuderman committed Oct 20, 2022
1 parent 8c7a1f8 commit 4309bb2
Showing 3 changed files with 641 additions and 407 deletions.
192 changes: 155 additions & 37 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -1321,7 +1322,104 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
}
};

class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
// Handle the case where the resize operation is a regular broadcast. This is
// done separately to avoid generating tensor.extract operations, which are
// difficult to vectorize / optimize.
class BroadcastResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tosa::ResizeOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
ImplicitLocOpBuilder builder(loc, rewriter);
auto input = op.getInput();
auto inputTy = input.getType().cast<RankedTensorType>();
auto resultTy = op.getType().cast<RankedTensorType>();

auto imageH = inputTy.getDimSize(1);
auto imageW = inputTy.getDimSize(2);

if (imageH != 1 || imageW != 1) {
return rewriter.notifyMatchFailure(
op, "tosa.resize is not a pure broadcast operation");
}

// TODO(suderman): These string values should be declared in the TOSA dialect.
if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
return failure();

const bool isBilinear = op.getMode() == "BILINEAR";

SmallVector<int32_t> scale;
getValuesFromIntArrayAttribute(op.getScale(), scale);
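// The scale attribute holds {scale_y_n, scale_y_d, scale_x_n, scale_x_d};
// only the numerators (scale[0] and scale[2]) are used below.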

// Collapse the 1 dimensions away.
SmallVector<ReassociationExprs, 4> collapseMap(2);
collapseMap[0].push_back(builder.getAffineDimExpr(0));
collapseMap[1].push_back(builder.getAffineDimExpr(1));
collapseMap[1].push_back(builder.getAffineDimExpr(2));
collapseMap[1].push_back(builder.getAffineDimExpr(3));

auto collapseTy =
RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
inputTy.getElementType());
Value collapse =
builder.create<tensor::CollapseShapeOp>(collapseTy, input, collapseMap);
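// E.g. a tensor<?x1x1x?> input collapses to tensor<?x?>; the linalg.generic
// built below re-broadcasts it across the output's H and W dimensions.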

// Broadcast input to the output shape.
llvm::SmallVector<Value> outputDynSize;
if (inputTy.isDynamicDim(0))
outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));

if (inputTy.isDynamicDim(3))
outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

llvm::SmallVector<AffineExpr> inputExprs{
rewriter.getAffineDimExpr(0),
rewriter.getAffineDimExpr(3),
};

auto inputMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
inputExprs, builder.getContext());
auto resultMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
SmallVector<StringRef> iterators(4, getParallelIteratorTypeName());

Value empty = builder.create<tensor::EmptyOp>(
resultTy.getShape(), resultTy.getElementType(), outputDynSize);

auto generic = builder.create<linalg::GenericOp>(
resultTy, ValueRange{collapse}, ValueRange{empty},
ArrayRef<AffineMap>{inputMap, resultMap}, iterators,
[=](OpBuilder &b, Location loc, ValueRange args) {
Value value = args[0];
// This is the quantized case.
if (inputTy.getElementType() != resultTy.getElementType()) {
value =
b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);

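// Bilinear resize of a 1x1 image sums four identical samples weighted by
// the scale numerators, so the broadcast value is scaled by scale[0] (y)
// and scale[2] (x) to match the generic lowering.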
if (isBilinear && scale[0] != 0) {
Value scaleY = b.create<arith::ConstantOp>(
loc, b.getI32IntegerAttr(scale[0]));
value = b.create<arith::MulIOp>(loc, value, scaleY);
}

if (isBilinear && scale[2] != 0) {
Value scaleX = b.create<arith::ConstantOp>(
loc, b.getI32IntegerAttr(scale[2]));
value = b.create<arith::MulIOp>(loc, value, scaleX);
}
}

b.create<linalg::YieldOp>(loc, value);
});

rewriter.replaceOp(op, generic.getResult(0));
return success();
}
};

class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

@@ -1351,10 +1449,11 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
SmallVector<AffineMap, 2> affineMaps = {
rewriter.getMultiDimIdentityMap(resultTy.getRank())};

Value resize = input;
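// Track the lowered value locally; the op is replaced exactly once at the
// end so every control path below shares the same replaceOp.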
auto genericOp = rewriter.create<linalg::GenericOp>(
loc, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
getNParallelLoopsAttrs(resultTy.getRank()));
rewriter.replaceOp(op, genericOp.getResult(0));
resize = genericOp.getResult(0);

OpBuilder::InsertionGuard regionGuard(rewriter);
rewriter.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
@@ -1496,7 +1595,6 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
ix = rewriter.create<arith::AddIOp>(loc, ix, xOffset);

// Clamp the indices to be within the bounds of the input image.

iy = clampIntHelper(loc, iy, hwMin, hMax, rewriter);
ix = clampIntHelper(loc, ix, hwMin, wMax, rewriter);

@@ -1510,10 +1608,9 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
loc, input, ValueRange{batch, iy, ix, channel});

rewriter.create<linalg::YieldOp>(loc, result);

return success();
} else {
// The mode here must be BILINEAR. This has been checked above.
// The mode here must be BILINEAR.
assert(op.getMode() == "BILINEAR");
Value y0 = iy;
Value x0 = ix;

@@ -1548,7 +1645,9 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {

if (floatingPointMode) {
Value rightPart = dx;
Value leftPart = rewriter.create<arith::SubFOp>(loc, xScaleN, dx);
auto oneVal = rewriter.create<arith::ConstantOp>(
loc, rewriter.getF32FloatAttr(1.0f));
Value leftPart = rewriter.create<arith::SubFOp>(loc, oneVal, dx);

y0x0 = rewriter.create<arith::MulFOp>(loc, y0x0, leftPart);
y0x1 = rewriter.create<arith::MulFOp>(loc, y0x1, rightPart);
@@ -1559,46 +1658,59 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
Value bottomAcc = rewriter.create<arith::AddFOp>(loc, y1x0, y1x1);

Value bottomPart = dy;
Value topPart = rewriter.create<arith::SubFOp>(loc, yScaleN, dy);
Value topPart = rewriter.create<arith::SubFOp>(loc, oneVal, dy);
topAcc = rewriter.create<arith::MulFOp>(loc, topAcc, topPart);
bottomAcc = rewriter.create<arith::MulFOp>(loc, bottomAcc, bottomPart);
Value result = rewriter.create<arith::AddFOp>(loc, topAcc, bottomAcc);

rewriter.create<linalg::YieldOp>(loc, result);
return success();
}
y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);

if (resultElementTy.getIntOrFloatBitWidth() > 32) {
dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
}

Value rightPart = dx;
Value leftPart = rewriter.create<arith::SubIOp>(loc, xScaleN, dx);

y0x0 = rewriter.create<arith::MulIOp>(loc, y0x0, leftPart);
y0x1 = rewriter.create<arith::MulIOp>(loc, y0x1, rightPart);
Value topAcc = rewriter.create<arith::AddIOp>(loc, y0x0, y0x1);
} else {
// Perform in quantized space.
y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);

if (resultElementTy.getIntOrFloatBitWidth() > 32) {
dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
}

y1x0 = rewriter.create<arith::MulIOp>(loc, y1x0, leftPart);
y1x1 = rewriter.create<arith::MulIOp>(loc, y1x1, rightPart);
Value bottomAcc = rewriter.create<arith::AddIOp>(loc, y1x0, y1x1);
Value topAcc, bottomAcc;
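// A width-1 image has no second column to blend against, so scale by
// xScaleN directly; this matches the blended path, where
// leftPart + rightPart == xScaleN. The height-1 case below is symmetric.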
if (imageW == 1) {
topAcc = rewriter.create<arith::MulIOp>(loc, y0x0, xScaleN);
bottomAcc = rewriter.create<arith::MulIOp>(loc, y1x0, xScaleN);
} else {
Value rightPart = dx;
Value leftPart = rewriter.create<arith::SubIOp>(loc, xScaleN, dx);

y0x0 = rewriter.create<arith::MulIOp>(loc, y0x0, leftPart);
y0x1 = rewriter.create<arith::MulIOp>(loc, y0x1, rightPart);
topAcc = rewriter.create<arith::AddIOp>(loc, y0x0, y0x1);

y1x0 = rewriter.create<arith::MulIOp>(loc, y1x0, leftPart);
y1x1 = rewriter.create<arith::MulIOp>(loc, y1x1, rightPart);
bottomAcc = rewriter.create<arith::AddIOp>(loc, y1x0, y1x1);
}

Value bottomPart = dy;
Value topPart = rewriter.create<arith::SubIOp>(loc, yScaleN, dy);
topAcc = rewriter.create<arith::MulIOp>(loc, topAcc, topPart);
bottomAcc = rewriter.create<arith::MulIOp>(loc, bottomAcc, bottomPart);
Value result = rewriter.create<arith::AddIOp>(loc, topAcc, bottomAcc);
Value result;
if (imageH == 1) {
result = rewriter.create<arith::MulIOp>(loc, topAcc, yScaleN);
} else {
Value bottomPart = dy;
Value topPart = rewriter.create<arith::SubIOp>(loc, yScaleN, dy);
topAcc = rewriter.create<arith::MulIOp>(loc, topAcc, topPart);
bottomAcc =
rewriter.create<arith::MulIOp>(loc, bottomAcc, bottomPart);
result = rewriter.create<arith::AddIOp>(loc, topAcc, bottomAcc);
}

rewriter.create<linalg::YieldOp>(loc, result);
return success();
}
}

return failure();
rewriter.replaceOp(op, resize);
return success();
}
};

@@ -2210,6 +2322,13 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {

void mlir::tosa::populateTosaToLinalgConversionPatterns(
RewritePatternSet *patterns) {

// We have multiple resize converters to handle degenerate cases; the
// broadcast pattern has the higher benefit, so it is tried first and the
// generic lowering serves as the fallback.
patterns->add<GenericResizeConverter>(patterns->getContext(),
/*benefit=*/100);
patterns->add<BroadcastResizeConverter>(patterns->getContext(),
/*benefit=*/200);

patterns->add<
// clang-format off
PointwiseConverter<tosa::AddOp>,
@@ -2262,7 +2381,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
ReshapeConverterExpand,
ReshapeConverterCollapseExpand,
RescaleConverter,
ResizeConverter,
ReverseConverter,
TableConverter,
TileConverter,