Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,25 @@ using namespace mlir::linalg;
#define DEBUG_TYPE "linalg-fold-into-elementwise"

namespace {
struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
template <typename ProducerOpTy>
struct ElementwiseOpFolder {
static bool fold(OpOperand *operand, AffineMap consumerMap,
SmallVectorImpl<Value> &newIns,
SmallVectorImpl<AffineMap> &newMaps) {
auto producerOp = operand->get().getDefiningOp<ProducerOpTy>();
if (!producerOp)
return false;
newIns.push_back(producerOp.getInput());
// push in composed affine map
newMaps.push_back(
producerOp.getMatchingIndexingMap(producerOp.getDpsInputOperand(0))
.compose(consumerMap));
return true;
}
};

template <typename... ProducerOps>
struct FoldIntoElementwisePattern : public OpRewritePattern<ElementwiseOp> {
using OpRewritePattern<ElementwiseOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ElementwiseOp op,
Expand All @@ -38,20 +56,17 @@ struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
SmallVector<Value> newIns;
SmallVector<AffineMap> newMaps;
for (OpOperand *operand : op.getDpsInputOperands()) {
AffineMap map = op.getMatchingIndexingMap(operand);
auto transposeOp = operand->get().getDefiningOp<TransposeOp>();

if (!map.isIdentity() || !transposeOp) {
AffineMap consumerMap = op.getMatchingIndexingMap(operand);
const bool folded = (ElementwiseOpFolder<ProducerOps>::fold(
operand, consumerMap, newIns, newMaps) ||
...);
if (folded) {
changed = true;
} else {
// push in original operand and its map.
newIns.push_back(operand->get());
newMaps.push_back(map);
continue;
newMaps.push_back(consumerMap);
}
newIns.push_back(transposeOp.getInput());
// push in transposeOp's inverse permutation map.
newMaps.push_back(transposeOp.getMatchingIndexingMap(
transposeOp.getDpsInputOperand(0)));
changed = true;
}
if (!changed)
return failure();
Expand Down Expand Up @@ -83,5 +98,6 @@ struct LinalgFoldIntoElementwisePass

void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
RewritePatternSet &patterns) {
patterns.add<FoldTransposePattern>(patterns.getContext());
patterns.add<FoldIntoElementwisePattern<TransposeOp, BroadcastOp>>(
patterns.getContext());
}
162 changes: 155 additions & 7 deletions mlir/test/Dialect/Linalg/elementwise/fold.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
// CHECK-SAME: ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
//
func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
func.func @unary_transpose(%A: tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
%empty = tensor.empty() : tensor<8x16x32xf32>
%transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2]
%transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2]
%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
ins(%transposed_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
ins(%transposed_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %result : tensor<8x16x32xf32>
}

Expand All @@ -28,16 +28,164 @@ func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) ->
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
//
func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
func.func @binary_transposed(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>

%empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
%transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
%transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
return %result : tensor<?x?xf32>
}

// -----

// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
//
// CHECK: func.func @unary_broadcasted(%[[A:.+]]: tensor<8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
// CHECK-SAME: indexing_maps = [#[[BROADCASTED]], #[[IDENTITY]]]
// CHECK-SAME: ins(%[[A]] : tensor<8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
//
// Input IR: %A (8x32) is broadcast to 8x16x32 by adding dimension 1, and the
// result feeds a unary `exp` elementwise op. The expectations above require
// the broadcast to be folded away, leaving a single elementwise op that reads
// %A directly through the map (d0, d1, d2) -> (d0, d2).
func.func @unary_broadcasted(%A: tensor<8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
%empty = tensor.empty() : tensor<8x16x32xf32>
%broadcasted_A = linalg.broadcast ins(%A : tensor<8x32xf32>) outs(%empty : tensor<8x16x32xf32>) dimensions = [1]
%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
ins(%broadcasted_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %result : tensor<8x16x32xf32>
}

// -----

// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1) -> (d0)>
//
// CHECK: func.func @binary_broadcasted(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[BROADCASTED]], #[[IDENTITY]]]
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
//
// Input IR: %B (rank 1) is broadcast to rank 2 by adding dimension 1, and the
// result is the second operand of a binary `add` whose first operand is %A.
// The expectations above require the broadcast to be folded away, leaving a
// single elementwise op that reads %B through the map (d0, d1) -> (d0).
func.func @binary_broadcasted(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
  %dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>

  // The broadcast result is added element-by-element to %A, so its init tensor
  // takes %A's extents in the same order (dim0 x dim1), not swapped.
  %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
  %broadcasted_B = linalg.broadcast ins(%B : tensor<?xf32>) outs(%empty : tensor<?x?xf32>) dimensions = [1]
  %result = linalg.elementwise kind=#linalg.elementwise_kind<add>
                               ins(%A, %broadcasted_B : tensor<?x?xf32>, tensor<?x?xf32>)
                               outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %result : tensor<?x?xf32>
}

// -----

// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
//
// CHECK: func.func @fold_broadcast_after_transpose_fold(%[[A:.+]]: tensor<16xf32>, %[[B:.+]]: tensor<16x32xf32>) -> tensor<16x32xf32> {
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]]]
// CHECK-SAME: ins(%[[A]] : tensor<16xf32>) outs(%[[B]] : tensor<16x32xf32>) -> tensor<16x32xf32>
// CHECK-NEXT: return %[[RES]] : tensor<16x32xf32>
//
// Input IR: a producer chain broadcast -> transpose -> elementwise `exp`.
// %A (16) is broadcast to 32x16 (new dimension 0), then transposed to 16x32.
// The expectations above require both producers to be folded, leaving one
// elementwise op that reads %A through the composed map (d0, d1) -> (d0).
func.func @fold_broadcast_after_transpose_fold(%A: tensor<16xf32>, %B: tensor<16x32xf32>) -> tensor<16x32xf32> {
%empty_b = tensor.empty() : tensor<32x16xf32>
%broadcasted_A = linalg.broadcast ins(%A : tensor<16xf32>) outs(%empty_b : tensor<32x16xf32>) dimensions = [0]

%empty_t = tensor.empty() : tensor<16x32xf32>
%transposed_A = linalg.transpose ins(%broadcasted_A : tensor<32x16xf32>) outs(%empty_t : tensor<16x32xf32>) permutation = [1, 0]

%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
ins(%transposed_A : tensor<16x32xf32>) outs(%B : tensor<16x32xf32>) -> tensor<16x32xf32>
return %result : tensor<16x32xf32>
}

// -----

// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
//
// CHECK: func.func @fold_transpose_after_broadcast_fold(%[[A:.+]]: tensor<32x16xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]]]
// CHECK-SAME: ins(%[[A]] : tensor<32x16xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
//
// Input IR: a producer chain transpose -> broadcast -> elementwise `exp`.
// %A (32x16) is transposed to 16x32, then broadcast to 8x16x32 (new
// dimension 0). The expectations above require both producers to be folded,
// leaving one elementwise op that reads %A through the composed map
// (d0, d1, d2) -> (d2, d1).
func.func @fold_transpose_after_broadcast_fold(%A: tensor<32x16xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
%empty_t = tensor.empty() : tensor<16x32xf32>
%transposed_A = linalg.transpose ins(%A : tensor<32x16xf32>) outs(%empty_t : tensor<16x32xf32>) permutation = [1, 0]

%empty_b = tensor.empty() : tensor<8x16x32xf32>
%broadcasted_A = linalg.broadcast ins(%transposed_A : tensor<16x32xf32>) outs(%empty_b : tensor<8x16x32xf32>) dimensions = [0]

%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
ins(%broadcasted_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %result : tensor<8x16x32xf32>
}

// -----

// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
//
// CHECK: func.func @fold_broadcast_after_transpose_fold_binary(%[[A:.+]]: tensor<?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]], #[[IDENTITY]]]
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
//
// Dynamic-shape, binary variant of the broadcast -> transpose chain.
// The first `add` operand is %A broadcast (new dimension 0) and then
// transposed; the second operand is %B, used directly. The expectations above
// require both producers to be folded, so the first operand reads %A through
// (d0, d1) -> (d0) while %B keeps the identity map.
func.func @fold_broadcast_after_transpose_fold_binary(%A: tensor<?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %B, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %B, %c1 : tensor<?x?xf32>

// Broadcast init uses %B's extents swapped (dim1 x dim0); the transpose below
// swaps them back to dim0 x dim1.
%empty_b = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
%broadcasted_A = linalg.broadcast ins(%A : tensor<?xf32>) outs(%empty_b : tensor<?x?xf32>) dimensions = [0]

%empty_t = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
%transposed_A = linalg.transpose ins(%broadcasted_A : tensor<?x?xf32>) outs(%empty_t : tensor<?x?xf32>) permutation = [1, 0]

%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
ins(%transposed_A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
return %result : tensor<?x?xf32>
}

// -----

// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
//
// CHECK: func.func @fold_transpose_after_broadcast_fold_binary(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?x?xf32>, %[[C:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]], #[[IDENTITY]]]
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?x?xf32>) outs(%[[C]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK-NEXT: return %[[RES]] : tensor<?x?x?xf32>
//
// Dynamic-shape, binary variant of the transpose -> broadcast chain.
// The first `add` operand is %A transposed and then broadcast (new
// dimension 0); the second operand is %B, used directly. The expectations
// above require both producers to be folded, so the first operand reads %A
// through (d0, d1, d2) -> (d2, d1) while %B keeps the identity map.
func.func @fold_transpose_after_broadcast_fold_binary(%A: tensor<?x?xf32>, %B: tensor<?x?x?xf32>, %C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim0 = tensor.dim %B, %c0 : tensor<?x?x?xf32>
%dim1 = tensor.dim %B, %c1 : tensor<?x?x?xf32>
%dim2 = tensor.dim %B, %c2 : tensor<?x?x?xf32>

// Transpose init is sized from %B's trailing two extents (dim1 x dim2).
%empty_t = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
%transposed_A = linalg.transpose ins(%A : tensor<?x?xf32>) outs(%empty_t : tensor<?x?xf32>) permutation = [1, 0]

// Broadcast init is sized from all three of %B's extents.
%empty_b = tensor.empty(%dim0, %dim1, %dim2) : tensor<?x?x?xf32>
%broadcasted_A = linalg.broadcast ins(%transposed_A : tensor<?x?xf32>) outs(%empty_b : tensor<?x?x?xf32>) dimensions = [0]

%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
ins(%broadcasted_A, %B : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%C : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %result : tensor<?x?x?xf32>
}