[mlir][tosa] Add tosa.depthwise lowering to existing linalg.depthwise_conv

Implements support for undilated depthwise convolution using the existing
depthwise convolution operation. Once convolutions migrate to YAML-defined
versions, this can be rewritten for a cleaner implementation.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D101579
rsuderman committed May 5, 2021
1 parent 1d767b1 commit 7abb56c
Showing 2 changed files with 216 additions and 142 deletions.
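For orientation, a hand-written before/after sketch of the new depthwise path. It is not taken from the patch's test file; the shapes are illustrative and the op spellings only approximate the TOSA/linalg dialects at this revision:

// An undilated tosa.depthwise_conv2d: 3x3 kernel, 2 input channels, channel
// multiplier of 3 (filter laid out KH x KW x C x M).
%0 = "tosa.depthwise_conv2d"(%input, %weight, %bias)
       {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]}
     : (tensor<1x5x5x2xf32>, tensor<3x3x2x3xf32>, tensor<6xf32>)
     -> tensor<1x3x3x6xf32>

// Lowers roughly to: broadcast the bias over the output, reshape to the 5-D
// layout that keeps channels and multiplier separate, run the existing
// linalg depthwise conv, then collapse back to the 4-D TOSA result type.
%init = linalg.init_tensor [1, 3, 3, 6] : tensor<1x3x3x6xf32>
%bcast = linalg.generic ... ins(%bias) outs(%init) ... -> tensor<1x3x3x6xf32>
%bias5 = "tosa.reshape"(%bcast) ... -> tensor<1x3x3x2x3xf32>
%conv = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
          {strides = dense<1> : tensor<2xi64>}
          ins(%input, %weight : tensor<1x5x5x2xf32>, tensor<3x3x2x3xf32>)
          outs(%bias5 : tensor<1x3x3x2x3xf32>) -> tensor<1x3x3x2x3xf32>
%1 = "tosa.reshape"(%conv) ... -> tensor<1x3x3x6xf32>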
321 changes: 187 additions & 134 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -59,6 +59,37 @@ static mlir::SelectOp clampHelper(Location loc, Value arg, mlir::ConstantOp min,
  return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
}

static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            Attribute padAttr, OpBuilder &rewriter) {
  // Input should be padded if necessary.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = input.getType().cast<ShapedType>();
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (int i = 0, s = inputShape.size(); i < s; i++) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = rewriter.create<ConstantOp>(loc, padAttr);

  return linalg::PadTensorOp::createPadScalarOp(
             RankedTensorType::get(paddedShape, inputETy), input, padValue,
             lowIndices, highIndices, loc, rewriter)
      .result();
}
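For reference, the helper above materializes a linalg.pad_tensor whose region yields the scalar pad value; a rough sketch of the resulting IR for a one-pixel spatial pad of an NHWC tensor, with the same caveat that the syntax is approximate for this vintage of linalg:

%cst = constant 0.000000e+00 : f32
%padded = linalg.pad_tensor %input low[0, 1, 1, 0] high[0, 1, 1, 0] {
^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
  linalg.yield %cst : f32
} : tensor<1x5x5x2xf32> to tensor<1x7x7x2xf32>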

static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                            ArrayRef<Type> resultTypes,
@@ -757,6 +788,138 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
  return success();
}

static LogicalResult
convolutionMatchAndRewriterHelper(Operation *op,
                                  ConversionPatternRewriter &rewriter) {
  Location loc = op->getLoc();
  Value input = op->getOperand(0);
  Value weight = op->getOperand(1);
  Value bias = op->getOperand(2);

  ShapedType inputTy = input.getType().cast<ShapedType>();
  ShapedType weightTy = weight.getType().cast<ShapedType>();
  ShapedType biasTy = bias.getType().cast<ShapedType>();
  ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

  Type inputETy = inputTy.getElementType();
  Type weightETy = weightTy.getElementType();
  Type biasETy = biasTy.getElementType();
  Type resultETy = resultTy.getElementType();

  auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
  auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
  auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();

  if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
      !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
    return rewriter.notifyMatchFailure(op,
                                       "tosa.conv ops require static shapes");

  auto weightShape = weightTy.getShape();
  auto resultShape = resultTy.getShape();

  // TODO(suderman): Support other types.
  if (!inputETy.isF32() || !weightETy.isF32() || !biasETy.isF32() ||
      !resultETy.isF32())
    return failure();

  // Apply padding as necessary.
  Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
  llvm::SmallVector<int64_t> pad;
  pad.resize(2, 0);
  getValuesFromIntArrayAttribute(padAttr, pad);
  pad.resize(pad.size() + 2, 0);
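  // The pad vector is now interleaved per NHWC dimension as (low, high)
  // pairs, {0, 0, top, bottom, left, right, 0, 0}, which is the layout
  // applyPad expects.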

  input = applyPad(loc, input, pad, zeroAttr, rewriter);

  // The Conv2DOp kernel must be transposed so its channel ordering matches
  // what the linalg convolution expects.
  // TODO(suderman): Eventually we will support specifying the filter channel
  // ordering, and then we can avoid transposing the kernel.
  if (isa<tosa::Conv2DOp>(op)) {
    int32_t weightRank = weightTy.getRank();
    SmallVector<int64_t> permutation, transposeWeightShape;
    permutation.resize(weightRank, 0);
    transposeWeightShape.resize(weightRank, 0);
    for (int i = 0; i < weightRank; i++) {
      permutation[i] = (i + 1) % weightRank;
      transposeWeightShape[i] = weightShape[permutation[i]];
    }
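    // This rotation turns the TOSA [output, height, width, input] filter
    // layout into the [height, width, input, output] layout consumed by
    // linalg's ConvInputNHWCFilterHWCFOp.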

    Value permutationValue = rewriter.create<ConstantOp>(
        loc, DenseIntElementsAttr::get(
                 RankedTensorType::get({weightRank}, rewriter.getI64Type()),
                 permutation));
    Type newWeightTy = RankedTensorType::get(transposeWeightShape, biasETy);

    weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                permutationValue);
  }

  // Broadcast the initial value to the output tensor before convolving.
  SmallVector<AffineMap, 4> indexingMaps;
  indexingMaps.push_back(AffineMap::get(
      /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
      {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
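  // The first map indexes the rank-1 bias by the channel dimension (d3),
  // while the identity map covers every output element, so the generic op
  // below broadcasts the bias across the batch and spatial dimensions.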

  Value initTensor = rewriter.create<linalg::InitTensorOp>(
      loc, resultTy.getShape(), resultTy.getElementType());

  Value biasBroadcast =
      rewriter
          .create<linalg::GenericOp>(
              loc, resultTy, bias, initTensor, indexingMaps,
              getNParallelLoopsAttrs(resultTy.getRank()),
              [&](OpBuilder &nestedBuilder, Location nestedLoc,
                  ValueRange args) {
                nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
              })
          .getResult(0);

  // Extract the attributes for convolution.
  llvm::SmallVector<int64_t> stride, dilation;
  getValuesFromIntArrayAttribute(strideTosaAttr, stride);
  getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);

  // Create the convolution op.
  auto strideAttr = DenseIntElementsAttr::get(
      RankedTensorType::get({2}, rewriter.getI64Type()), stride);
  auto dilationAttr = DenseIntElementsAttr::get(
      RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
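  // The linalg named convolutions take strides and dilations as dense
  // 2-element i64 vector attributes, one entry per spatial dimension.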

  if (isa<tosa::Conv2DOp>(op)) {
    rewriter.replaceOpWithNewOp<linalg::ConvInputNHWCFilterHWCFOp>(
        op, resultTy, ValueRange{input, weight}, ValueRange{biasBroadcast},
        dilationAttr, strideAttr);
    return success();
  }

  if (isa<tosa::DepthwiseConv2DOp>(op)) {
    if (llvm::any_of(dilation, [](int64_t d) { return d > 1; }))
      return failure();
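    // The existing linalg depthwise convolution op does not model dilation,
    // so only the undilated case is handled here.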

    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              resultETy);
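    // The linalg op computes a 5-D (N, H, W, C, M) result that keeps the
    // input channels and the channel multiplier separate; the reshape below
    // collapses it back to the 4-D (N, H, W, C * M) TOSA result type.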

    Value biasReshape =
        rewriter.create<tosa::ReshapeOp>(loc, linalgConvTy, biasBroadcast);
    Value conv = rewriter
                     .create<linalg::DepthwiseConvInputNHWCFilterHWCFOp>(
                         loc, linalgConvTy, ValueRange{input, weight},
                         ValueRange{biasReshape}, strideAttr)
                     .getResult(0);

    Value reshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv);
    rewriter.replaceOp(op, reshape);
    return success();
  }

  return failure();
}

namespace {

template <typename SrcOp>
@@ -770,6 +933,17 @@ class PointwiseConverter : public OpRewritePattern<SrcOp> {
  }
};

template <typename T>
class ConvConverter : public OpConversionPattern<T> {
public:
  using OpConversionPattern<T>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(T op, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    return convolutionMatchAndRewriterHelper(op, rewriter);
  }
};

class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
@@ -782,8 +956,8 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {

    auto outputTy = op.getType().cast<ShapedType>();
    auto outputElementTy = outputTy.getElementType();
-    auto zero_attr = rewriter.getZeroAttr(outputElementTy);
-    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
+    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
+    Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, outputTy.getShape(), outputTy.getElementType());
    Value zeroTensor =
@@ -862,108 +1036,6 @@ class FullyConnectedConverter
  }
};

-class Conv2DConverter : public OpConversionPattern<tosa::Conv2DOp> {
-public:
-  using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(tosa::Conv2DOp op, ArrayRef<Value> args,
-                  ConversionPatternRewriter &rewriter) const final {
-    Location loc = op.getLoc();
-    Value input = op.input();
-    Value weight = op.weight();
-    Value bias = op.bias();
-
-    ShapedType inputTy = input.getType().cast<ShapedType>();
-    ShapedType weightTy = weight.getType().cast<ShapedType>();
-    ShapedType biasTy = bias.getType().cast<ShapedType>();
-    ShapedType resultTy = op.getType().cast<ShapedType>();
-
-    Type inputETy = inputTy.getElementType();
-    Type weightETy = weightTy.getElementType();
-    Type biasETy = biasTy.getElementType();
-    Type resultETy = resultTy.getElementType();
-
-    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
-        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(op,
-                                         "tosa.conv2d requires static shapes");
-
-    auto inputShape = inputTy.getShape();
-    auto weightShape = weightTy.getShape();
-
-    // TODO(suderman): Support other types.
-    if (!inputETy.isF32() || !weightETy.isF32() || !biasETy.isF32() ||
-        !resultETy.isF32())
-      return failure();
-
-    // Broadcast the initial value to the output tensor before convolving.
-    SmallVector<AffineMap, 4> indexingMaps;
-    indexingMaps.push_back(AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
-                                          {rewriter.getAffineDimExpr(3)},
-                                          rewriter.getContext()));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
-
-    Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, resultTy.getShape(), resultTy.getElementType());
-    Value biasBroadcast =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, resultTy, bias, initTensor, indexingMaps,
-                getNParallelLoopsAttrs(resultTy.getRank()),
-                [&](OpBuilder &nestedBuilder, Location nestedLoc,
-                    ValueRange args) {
-                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
-                })
-            .getResult(0);
-
-    // Transpose weights tensor to be in dim order: spatial dims,
-    // input channels, and output channels.
-    SmallVector<int64_t> permutation{1, 2, 3, 0};
-    auto permutationAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({4}, rewriter.getI64Type()), permutation);
-    Value permutationValue = rewriter.create<ConstantOp>(loc, permutationAttr);
-
-    SmallVector<int64_t> newKernelShape{weightShape[1], weightShape[2],
-                                        weightShape[3], weightShape[0]};
-    Type newKernelTy = RankedTensorType::get(newKernelShape, biasETy);
-
-    Value transposedKernel = rewriter.create<tosa::TransposeOp>(
-        loc, newKernelTy, weight, permutationValue);
-
-    // Extract the attributes for convolution.
-    llvm::SmallVector<int64_t> stride, dilation, pad;
-    getValuesFromIntArrayAttribute(op.stride(), stride);
-    getValuesFromIntArrayAttribute(op.dilation(), dilation);
-    getValuesFromIntArrayAttribute(op.pad(), pad);
-
-    // Input should be padded if necessary.
-    if (llvm::any_of(pad, [](int64_t p) { return p; })) {
-      llvm::SmallVector<int64_t, 8> newPad{0, 0, pad[0], pad[1],
-                                           pad[2], pad[3], 0, 0};
-      auto padAttr = DenseIntElementsAttr::get(
-          RankedTensorType::get({4, 2}, rewriter.getI64Type()), newPad);
-      Value padValue = rewriter.create<ConstantOp>(loc, padAttr);
-
-      SmallVector<int64_t, 4> paddedShape{
-          inputShape[0], inputShape[1] + pad[0] + pad[1],
-          inputShape[2] + pad[2] + pad[3], inputShape[3]};
-      Type paddedTy = RankedTensorType::get(paddedShape, inputETy);
-      input = rewriter.create<tosa::PadOp>(loc, paddedTy, input, padValue);
-    }
-
-    auto strideAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
-    auto dilationAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
-
-    auto convOp = rewriter.create<linalg::ConvInputNHWCFilterHWCFOp>(
-        loc, resultTy, ValueRange{input, transposedKernel},
-        ValueRange{biasBroadcast}, dilationAttr, strideAttr);
-
-    rewriter.replaceOp(op, convOp.getResult(0));
-    return success();
-  }
-};

class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
@@ -2102,7 +2174,6 @@ class Pool2dConverter : public OpRewritePattern<SrcOp> {

    ShapedType resultTy = op.getType().template cast<ShapedType>();
    Type outElementTy = inputTy.getElementType();
-    int64_t rank = inputTy.getRank();

    if (!inputTy.hasStaticShape())
      return failure();
@@ -2127,43 +2198,24 @@ class Pool2dConverter : public OpRewritePattern<SrcOp> {
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

+    // Apply padding as necessary.
+    llvm::SmallVector<int64_t> pad;
+    pad.resize(2, 0);
+    getValuesFromIntArrayAttribute(op.pad(), pad);
+    pad.resize(pad.size() + 2, 0);
+    input = applyPad(loc, input, pad, initialAttr, rewriter);
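+    // Padding uses the pooling initial value, so for max pooling the padded
+    // cells hold the type's lowest value and cannot affect the result.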

    Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);

-    SmallVector<int64_t> kernel, stride, pad;
+    SmallVector<int64_t> kernel, stride;
    getValuesFromIntArrayAttribute(op.kernel(), kernel);
    getValuesFromIntArrayAttribute(op.stride(), stride);
-    getValuesFromIntArrayAttribute(op.pad(), pad);

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
    int64_t kernelSize = kernel[0] * kernel[1];

-    // If non-zero padding we need to pad the input
-    if (llvm::any_of(pad, [](int64_t v) { return v != 0; })) {
-      SmallVector<int64_t, 4> paddedShape;
-      for (int64_t i = 0; i < rank; i++)
-        paddedShape.push_back(inputTy.getDimSize(i));
-
-      paddedShape[1] += pad[0] + pad[1];
-      paddedShape[2] += pad[2] + pad[3];
-
-      OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
-      OpFoldResult heightLowPadIndex = rewriter.getIndexAttr(pad[0]);
-      OpFoldResult heightHighPadIndex = rewriter.getIndexAttr(pad[1]);
-      OpFoldResult widthLowPadIndex = rewriter.getIndexAttr(pad[2]);
-      OpFoldResult widthHighPadIndex = rewriter.getIndexAttr(pad[3]);
-
-      SmallVector<OpFoldResult, 4> lowIndices = {zeroIndex, heightLowPadIndex,
-                                                 widthLowPadIndex, zeroIndex};
-      SmallVector<OpFoldResult, 4> highIndices = {zeroIndex, heightHighPadIndex,
-                                                  widthHighPadIndex, zeroIndex};
-
-      input = linalg::PadTensorOp::createPadScalarOp(
-                  RankedTensorType::get(paddedShape, inElementTy), input,
-                  initialValue, lowIndices, highIndices, loc, rewriter)
-                  .result();
-    }

    // Create the linalg op that performs pooling.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultTy.getElementType());

@@ -2277,7 +2329,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
      ReduceConverter<tosa::ReduceProdOp>,
      ArgMaxConverter,
      ConcatConverter,
-      Conv2DConverter,
+      ConvConverter<tosa::Conv2DOp>,
+      ConvConverter<tosa::DepthwiseConv2DOp>,
      GatherConverter,
      PadConverter,
      ReshapeConverter,
