[mlir][Linalg] Miscellaneous enhancements to cover more fusion cases.
Adds support for
- Dropping unit dimension loops for indexed_generic ops.
- Folding consecutive collapsing (or expanding) reshapes when the result
  (or src) is a scalar (see the IR sketch below).
- Fixes to indexed_generic -> generic fusion when zero-dim tensors are
  involved.

Differential Revision: https://reviews.llvm.org/D90118
MaheshRavishankar committed Oct 26, 2020
1 parent 0b2f4cd commit 78f37b7
Showing 6 changed files with 252 additions and 43 deletions.
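
For illustration only (this sketch is not part of the commit): the reshape-folding
change lets two consecutive collapsing reshapes whose final result is a scalar fold
into a single reshape. The IR below mirrors the collapsing_tensor_reshapes_to_zero_dim
test added to canonicalize.mlir further down; the function name here is made up.

// Input IR: two collapsing reshapes ending in a rank-0 tensor.
func @fold_collapsing_reshapes_to_scalar(%arg0 : tensor<1x1x1xf32>) -> tensor<f32> {
  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
      tensor<1x1x1xf32> into tensor<1xf32>
  %1 = linalg.tensor_reshape %0 [] : tensor<1xf32> into tensor<f32>
  return %1 : tensor<f32>
}
// Expected result after canonicalization: a single reshape straight to the scalar.
//   %0 = linalg.tensor_reshape %arg0 [] : tensor<1x1x1xf32> into tensor<f32>
//   return %0 : tensor<f32>
The expanding direction (tensor<f32> back to tensor<1x1x1xf32>) folds analogously.
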
7 changes: 5 additions & 2 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -461,6 +461,10 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
ArrayRef<AffineMap> mapsConsumer,
MLIRContext *context) {
// Handle the corner case of the result being a rank 0 shaped type. Return an
// empty ArrayAttr.
if (mapsConsumer.empty() && !mapsProducer.empty())
return ArrayAttr::get(ArrayRef<Attribute>(), context);
if (mapsProducer.empty() || mapsConsumer.empty() ||
mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
mapsProducer.size() != mapsConsumer[0].getNumDims())
@@ -500,8 +504,7 @@ struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
ShapedType intermediateType,
ShapedType smallerType) -> bool {
return largerType.getRank() > intermediateType.getRank() &&
intermediateType.getRank() > smallerType.getRank() &&
smallerType.getRank() > 0;
intermediateType.getRank() > smallerType.getRank();
};
// Check if producer and consumer are both expanding dims.
if (areReshapeOpsFoldable(reshapeOp.getResultType(), reshapeOp.getSrcType(),
103 changes: 69 additions & 34 deletions mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -26,6 +26,8 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#include <set>

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
@@ -145,15 +147,42 @@ static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
context);
}

/// Modify the region of an indexed_generic op to drop the arguments that
/// correspond to loops with unit trip count.
template <typename OpTy>
static LogicalResult
replaceBlockArgForUnitDimLoops(OpTy op, const DenseSet<unsigned> &unitDims,
PatternRewriter &rewriter) {
return success();
}

template <>
LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
IndexedGenericOp op, const DenseSet<unsigned> &unitDims,
PatternRewriter &rewriter) {
OpBuilder::InsertionGuard guard(rewriter);
Block *entryBlock = &op.getOperation()->getRegion(0).front();
rewriter.setInsertionPointToStart(entryBlock);
Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
for (unsigned unitDimLoop : unitDims) {
entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
}
std::set<unsigned> orderedUnitDims(unitDims.begin(), unitDims.end());
for (unsigned i : llvm::reverse(orderedUnitDims))
entryBlock->eraseArgument(i);
return success();
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
// TODO: Generalize this to indexed-generic as well by modifying the region args
// as well.
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp genericOp,
template <typename GenericOpTy>
struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
using OpRewritePattern<GenericOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOpTy op,
PatternRewriter &rewriter) const override {
SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
if (indexingMaps.empty())
return failure();

@@ -164,10 +193,10 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
if (!invertedMap)
return failure();
SmallVector<int64_t, 4> dims;
for (ShapedType shapedType : genericOp.getInputOutputShapedTypes())
for (ShapedType shapedType : op.getInputOutputShapedTypes())
dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
DenseSet<unsigned> unitDims;
ArrayAttr iteratorTypes = genericOp.iterator_types();
ArrayAttr iteratorTypes = op.iterator_types();
for (auto expr : enumerate(invertedMap.getResults())) {
if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
if (dims[dimExpr.getPosition()] == 1 &&
@@ -183,7 +212,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
ArrayAttr newIndexingMapAttr =
replaceUnitDims(unitDims, indexingMaps, context);
if (!newIndexingMapAttr)
return genericOp.emitError("unable to compute modified indexing_maps");
return op.emitError("unable to compute modified indexing_maps");

// Compute the iterator types of the modified op by dropping the one-trip
// count loops.
@@ -193,10 +222,11 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
newIteratorTypes.push_back(attr.value());
}

rewriter.startRootUpdate(genericOp);
genericOp.indexing_mapsAttr(newIndexingMapAttr);
genericOp.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
rewriter.finalizeRootUpdate(genericOp);
rewriter.startRootUpdate(op);
op.indexing_mapsAttr(newIndexingMapAttr);
op.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
rewriter.finalizeRootUpdate(op);
return success();
}
};
@@ -263,25 +293,27 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
namespace {

/// Pattern to replace tensor operands/results that have unit extents.
struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp genericOp,
template <typename GenericOpTy>
struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
using OpRewritePattern<GenericOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOpTy op,
PatternRewriter &rewriter) const override {
// TODO: support init_tensors and reductions.
if (!genericOp.hasTensorSemantics() || !genericOp.init_tensors().empty())
if (!op.hasTensorSemantics() || !op.init_tensors().empty())
return failure();

MLIRContext *context = rewriter.getContext();
Location loc = genericOp.getLoc();
Location loc = op.getLoc();

SmallVector<AffineMap, 4> newIndexingMaps;
SmallVector<ArrayAttr, 4> reassociationMaps;
SmallVector<ShapedType, 4> newInputOutputTypes;
bool doCanonicalization = false;
for (auto it : llvm::zip(genericOp.getIndexingMaps(),
genericOp.getInputOutputShapedTypes())) {
for (auto it :
llvm::zip(op.getIndexingMaps(), op.getInputOutputShapedTypes())) {
auto replacementInfo = replaceUnitExtents(
std::get<0>(it), std::get<1>(it).cast<RankedTensorType>(), context);
std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
context);
reassociationMaps.push_back(replacementInfo.reassociation);
newIndexingMaps.push_back(replacementInfo.indexMap);
newInputOutputTypes.push_back(replacementInfo.type);
@@ -313,41 +345,40 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
return res;
};

SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
SmallVector<Value, 4> newInputs = insertReshapes(op.inputs());
SmallVector<Value, 4> newOutputBuffers =
insertReshapes(genericOp.output_buffers());
SmallVector<Value, 4> newInitTensors =
insertReshapes(genericOp.init_tensors());
insertReshapes(op.output_buffers());
SmallVector<Value, 4> newInitTensors = insertReshapes(op.init_tensors());

// If any result type change, insert a reshape to convert from the original
// type to the new type.
SmallVector<Type, 4> resultTypes;
resultTypes.reserve(genericOp.getNumResults());
for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
GenericOp replacementOp = rewriter.create<GenericOp>(
resultTypes.reserve(op.getNumResults());
for (unsigned i : llvm::seq<unsigned>(0, op.getNumResults()))
resultTypes.push_back(newInputOutputTypes[i + op.getNumInputs()]);
GenericOpTy replacementOp = rewriter.create<GenericOpTy>(
loc, resultTypes, newInputs, newOutputBuffers, newInitTensors,
newIndexingMaps,
llvm::to_vector<4>(
genericOp.iterator_types().getAsValueRange<StringAttr>()));
rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
op.iterator_types().template getAsValueRange<StringAttr>()));
rewriter.inlineRegionBefore(op.region(), replacementOp.region(),
replacementOp.region().begin());

// If any result tensor has a modified shape, then add reshape to recover
// the original shape.
SmallVector<Value, 4> resultReplacements;
for (auto result : llvm::enumerate(replacementOp.getResults())) {
unsigned index = result.index() + replacementOp.getNumOperands();
RankedTensorType origResultType = genericOp.getResult(result.index())
RankedTensorType origResultType = op.getResult(result.index())
.getType()
.cast<RankedTensorType>();
.template cast<RankedTensorType>();
if (origResultType != result.value().getType())
resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
loc, origResultType, result.value(), reassociationMaps[index]));
else
resultReplacements.push_back(result.value());
}
rewriter.replaceOp(genericOp, resultReplacements);
rewriter.replaceOp(op, resultReplacements);
return success();
}
};
@@ -467,7 +498,10 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
/// broadcasting.
void mlir::populateLinalgFoldUnitExtentDimsPatterns(
MLIRContext *context, OwningRewritePatternList &patterns) {
patterns.insert<FoldUnitDimLoops, ReplaceUnitExtentTensors>(context);
patterns
.insert<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
ReplaceUnitExtentTensors<GenericOp>,
ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
patterns.insert<FoldReshapeOpWithUnitExtent>(context);
}
@@ -481,7 +515,8 @@ struct LinalgFoldUnitExtentDimsPass
FuncOp funcOp = getFunction();
MLIRContext *context = funcOp.getContext();
if (foldOneTripLoopsOnly)
patterns.insert<FoldUnitDimLoops>(context);
patterns.insert<FoldUnitDimLoops<GenericOp>,
FoldUnitDimLoops<IndexedGenericOp>>(context);
else
populateLinalgFoldUnitExtentDimsPatterns(context, patterns);
applyPatternsAndFoldGreedily(funcOp.getBody(), patterns);
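
For context only (not part of the diff): a minimal sketch of how these patterns are
consumed, mirroring LinalgFoldUnitExtentDimsPass above. FoldUnitDimLoops and
ReplaceUnitExtentTensors sit in an anonymous namespace, so out-of-tree users go
through populateLinalgFoldUnitExtentDimsPatterns; the helper name below is made up.

// Sketch: apply the unit-extent folding patterns (now covering both
// linalg.generic and linalg.indexed_generic) to a function.
static void foldUnitExtentDims(FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  OwningRewritePatternList patterns;
  populateLinalgFoldUnitExtentDimsPatterns(context, patterns);
  applyPatternsAndFoldGreedily(funcOp.getBody(), patterns);
}
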
20 changes: 13 additions & 7 deletions mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -109,13 +109,19 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
// consumer's operand.
// If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
// generic op. In this case, there are no indices in block arguments.
unsigned numProducerIndices =
isa<IndexedGenericOp>(producer.getOperation()) ? nloops : 0;
unsigned numConsumerIndices =
isa<IndexedGenericOp>(consumer.getOperation()) ? nloops : 0;
unsigned numProducerIndices = isa<IndexedGenericOp>(producer.getOperation())
? producer.getNumLoops()
: 0;
unsigned numConsumerIndices = isa<IndexedGenericOp>(consumer.getOperation())
? consumer.getNumLoops()
: 0;
unsigned numFusedOpIndices =
(isa<IndexedGenericOp>(producer.getOperation()) ||
isa<IndexedGenericOp>(consumer.getOperation()))
? std::max(producer.getNumLoops(), consumer.getNumLoops())
: 0;
// Firstly, add all the indices to the block arguments.
for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices);
i < e; ++i)
for (unsigned i = 0, e = numFusedOpIndices; i < e; ++i)
fusedBlock->addArgument(rewriter.getIndexType());
// Map the arguments for the unmodified args from the consumer.
for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
@@ -129,7 +135,7 @@ static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
auto newIndex = rewriter.create<mlir::AffineApplyOp>(
producer.getLoc(),
consumerToProducerLoopsMap.getSubMap(producerArg.index()),
fusedBlock->getArguments().take_front(nloops));
fusedBlock->getArguments().take_front(numFusedOpIndices));
mapper.map(producerArg.value(), newIndex);
} else {
mapper.map(producerArg.value(),
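
A worked example of the index-count fix above (illustrative numbers, not from the
patch): fusing an indexed_generic producer over a zero-dim tensor (0 loops) into a
consumer with 2 loops now gives the fused block max(0, 2) = 2 leading index
arguments, and the producer-index remapping via AffineApplyOp draws its operands
from those numFusedOpIndices arguments; previously both counts were derived from
nloops, which is what this commit fixes for the zero-dim-tensor cases noted in the
commit message.
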
55 changes: 55 additions & 0 deletions mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -43,6 +43,34 @@ func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf3

// -----

// -----

func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
-> tensor<f32> {
%0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
tensor<1x1x1xf32> into tensor<1xf32>
%1 = linalg.tensor_reshape %0 [] : tensor<1xf32> into tensor<f32>
return %1 : tensor<f32>
}
// CHECK-LABEL: collapsing_tensor_reshapes_to_zero
// CHECK: linalg.tensor_reshape %{{.*}} []
// CHECK-SAME: tensor<1x1x1xf32> into tensor<f32>

// -----

func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
-> memref<f32> {
%0 = linalg.reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
memref<1x1x1xf32> into memref<1xf32>
%1 = linalg.reshape %0 [] : memref<1xf32> into memref<f32>
return %1 : memref<f32>
}
// CHECK-LABEL: collapsing_memref_reshapes_to_zero
// CHECK: linalg.reshape %{{.*}} []
// CHECK-SAME: memref<1x1x1xf32> into memref<f32>

// -----

func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x?x?x?x?xf32>
{
%0 = linalg.tensor_reshape %arg0
@@ -106,6 +134,33 @@ func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>) -> memref<?x?x?x?x?xf32

// -----

func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
-> tensor<1x1x1xf32> {
%0 = linalg.tensor_reshape %arg0 [] : tensor<f32> into tensor<1xf32>
%1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
tensor<1xf32> into tensor<1x1x1xf32>
return %1 : tensor<1x1x1xf32>
}
// CHECK-LABEL: expanding_tensor_reshapes_to_zero
// CHECK: linalg.tensor_reshape %{{.*}} []
// CHECK-SAME: tensor<f32> into tensor<1x1x1xf32>

// -----

func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
-> memref<1x1x1xf32> {
%0 = linalg.reshape %arg0 [] : memref<f32> into memref<1xf32>
%1 = linalg.reshape %0
[affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
memref<1xf32> into memref<1x1x1xf32>
return %1 : memref<1x1x1xf32>
}
// CHECK-LABEL: expanding_memref_reshapes_to_zero
// CHECK: linalg.reshape %{{.*}} []
// CHECK-SAME: memref<f32> into memref<1x1x1xf32>

// -----

func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
{
%0 = linalg.tensor_reshape %arg0
