diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp index b1c0c3b161b20..e7674ecd7101e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp @@ -41,16 +41,54 @@ struct FoldTransposePattern : public OpRewritePattern { AffineMap map = op.getMatchingIndexingMap(operand); auto transposeOp = operand->get().getDefiningOp(); - if (!map.isIdentity() || !transposeOp) { + if (!transposeOp) { // push in original operand and its map. newIns.push_back(operand->get()); newMaps.push_back(map); continue; } newIns.push_back(transposeOp.getInput()); - // push in transposeOp's inverse permutation map. - newMaps.push_back(transposeOp.getMatchingIndexingMap( - transposeOp.getDpsInputOperand(0))); + // push in composed affine map. + newMaps.push_back( + transposeOp.getMatchingIndexingMap(transposeOp.getDpsInputOperand(0)) + .compose(map)); + changed = true; + } + if (!changed) + return failure(); + newMaps.push_back(op.getIndexingMapsArray().back()); + + rewriter.replaceOpWithNewOp( + op, newIns, op.getDpsInits()[0], op.getKindAttr(), + rewriter.getAffineMapArrayAttr(newMaps)); + return success(); + } +}; + +struct FoldBroadcastPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ElementwiseOp op, + PatternRewriter &rewriter) const override { + bool changed = false; + SmallVector newIns; + SmallVector newMaps; + for (OpOperand *operand : op.getDpsInputOperands()) { + AffineMap map = op.getMatchingIndexingMap(operand); + auto broadcastOp = operand->get().getDefiningOp(); + + if (!broadcastOp) { + // push in original operand and its map. + newIns.push_back(operand->get()); + newMaps.push_back(map); + continue; + } + + newIns.push_back(broadcastOp.getInput()); + // push in composed affine map. + newMaps.push_back( + broadcastOp.getMatchingIndexingMap(broadcastOp.getDpsInputOperand(0)) + .compose(map)); changed = true; } if (!changed) @@ -84,4 +122,5 @@ struct LinalgFoldIntoElementwisePass void mlir::linalg::populateLinalgFoldIntoElementwisePatterns( RewritePatternSet &patterns) { patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); } diff --git a/mlir/test/Dialect/Linalg/elementwise/fold.mlir b/mlir/test/Dialect/Linalg/elementwise/fold.mlir index e83c32fb6a2cf..732b8a90f51d2 100644 --- a/mlir/test/Dialect/Linalg/elementwise/fold.mlir +++ b/mlir/test/Dialect/Linalg/elementwise/fold.mlir @@ -9,11 +9,11 @@ // CHECK-SAME: ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> // CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32> // -func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> { +func.func @unary_transpose(%A: tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> { %empty = tensor.empty() : tensor<8x16x32xf32> - %transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2] + %transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2] %result = linalg.elementwise kind=#linalg.elementwise_kind - ins(%transposed_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> + ins(%transposed_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> return %result : tensor<8x16x32xf32> } @@ -28,16 +28,164 @@ func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> // CHECK-SAME: ins(%[[A]], %[[B]] : tensor, tensor) outs(%[[C]] : tensor) -> tensor // CHECK-NEXT: return %[[RES]] : tensor // -func.func @binary_transposed(%A : tensor, %B: tensor, %C: tensor) -> tensor { +func.func @binary_transposed(%A: tensor, %B: tensor, %C: tensor) -> tensor { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %dim0 = tensor.dim %A, %c0 : tensor %dim1 = tensor.dim %A, %c1 : tensor %empty = tensor.empty(%dim1, %dim0) : tensor - %transposed_B = linalg.transpose ins(%B : tensor) outs(%empty : tensor) permutation = [1, 0] + %transposed_B = linalg.transpose ins(%B : tensor) outs(%empty : tensor) permutation = [1, 0] %result = linalg.elementwise kind=#linalg.elementwise_kind - ins(%A, %transposed_B : tensor, tensor) - outs(%C: tensor) -> tensor + ins(%A, %transposed_B : tensor, tensor) + outs(%C : tensor) -> tensor return %result : tensor } + +// ----- + +// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// +// CHECK: func.func @unary_broadcasted(%[[A:.+]]: tensor<8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> { +// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind +// CHECK-SAME: indexing_maps = [#[[BROADCASTED]], #[[IDENTITY]]] +// CHECK-SAME: ins(%[[A]] : tensor<8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> +// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32> +// +func.func @unary_broadcasted(%A: tensor<8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> { + %empty = tensor.empty() : tensor<8x16x32xf32> + %broadcasted_A = linalg.broadcast ins(%A : tensor<8x32xf32>) outs(%empty : tensor<8x16x32xf32>) dimensions = [1] + %result = linalg.elementwise kind=#linalg.elementwise_kind + ins(%broadcasted_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> + return %result : tensor<8x16x32xf32> +} + +// ----- + +// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1) -> (d0)> +// +// CHECK: func.func @binary_broadcasted(%[[A:.+]]: tensor, %[[B:.+]]: tensor, %[[C:.+]]: tensor) -> tensor { +// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind +// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[BROADCASTED]], #[[IDENTITY]]] +// CHECK-SAME: ins(%[[A]], %[[B]] : tensor, tensor) outs(%[[C]] : tensor) -> tensor +// CHECK-NEXT: return %[[RES]] : tensor +// +func.func @binary_broadcasted(%A: tensor, %B: tensor, %C: tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %A, %c0 : tensor + %dim1 = tensor.dim %A, %c1 : tensor + + %empty = tensor.empty(%dim1, %dim0) : tensor + %broadcasted_B = linalg.broadcast ins(%B : tensor) outs(%empty : tensor) dimensions = [1] + %result = linalg.elementwise kind=#linalg.elementwise_kind + ins(%A, %broadcasted_B : tensor, tensor) + outs(%C : tensor) -> tensor + return %result : tensor +} + +// ----- + +// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1) -> (d0)> +// +// CHECK: func.func @fold_broadcast_after_transpose_fold(%[[A:.+]]: tensor<16xf32>, %[[B:.+]]: tensor<16x32xf32>) -> tensor<16x32xf32> { +// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind +// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]]] +// CHECK-SAME: ins(%[[A]] : tensor<16xf32>) outs(%[[B]] : tensor<16x32xf32>) -> tensor<16x32xf32> +// CHECK-NEXT: return %[[RES]] : tensor<16x32xf32> +// +func.func @fold_broadcast_after_transpose_fold(%A: tensor<16xf32>, %B: tensor<16x32xf32>) -> tensor<16x32xf32> { + %empty_b = tensor.empty() : tensor<32x16xf32> + %broadcasted_A = linalg.broadcast ins(%A : tensor<16xf32>) outs(%empty_b : tensor<32x16xf32>) dimensions = [0] + + %empty_t = tensor.empty() : tensor<16x32xf32> + %transposed_A = linalg.transpose ins(%broadcasted_A : tensor<32x16xf32>) outs(%empty_t : tensor<16x32xf32>) permutation = [1, 0] + + %result = linalg.elementwise kind=#linalg.elementwise_kind + ins(%transposed_A : tensor<16x32xf32>) outs(%B : tensor<16x32xf32>) -> tensor<16x32xf32> + return %result : tensor<16x32xf32> +} + +// ----- + +// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// +// CHECK: func.func @fold_transpose_after_broadcast_fold(%[[A:.+]]: tensor<32x16xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> { +// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind +// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]]] +// CHECK-SAME: ins(%[[A]] : tensor<32x16xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> +// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32> +// +func.func @fold_transpose_after_broadcast_fold(%A: tensor<32x16xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> { + %empty_t = tensor.empty() : tensor<16x32xf32> + %transposed_A = linalg.transpose ins(%A : tensor<32x16xf32>) outs(%empty_t : tensor<16x32xf32>) permutation = [1, 0] + + %empty_b = tensor.empty() : tensor<8x16x32xf32> + %broadcasted_A = linalg.broadcast ins(%transposed_A : tensor<16x32xf32>) outs(%empty_b : tensor<8x16x32xf32>) dimensions = [0] + + %result = linalg.elementwise kind=#linalg.elementwise_kind + ins(%broadcasted_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> + return %result : tensor<8x16x32xf32> +} + +// ----- + +// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1) -> (d0)> +// +// CHECK: func.func @fold_broadcast_after_transpose_fold_binary(%[[A:.+]]: tensor, %[[B:.+]]: tensor, %[[C:.+]]: tensor) -> tensor { +// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind +// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]], #[[IDENTITY]]] +// CHECK-SAME: ins(%[[A]], %[[B]] : tensor, tensor) outs(%[[C]] : tensor) -> tensor +// CHECK-NEXT: return %[[RES]] : tensor +// +func.func @fold_broadcast_after_transpose_fold_binary(%A: tensor, %B: tensor, %C: tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %B, %c0 : tensor + %dim1 = tensor.dim %B, %c1 : tensor + + %empty_b = tensor.empty(%dim1, %dim0) : tensor + %broadcasted_A = linalg.broadcast ins(%A : tensor) outs(%empty_b : tensor) dimensions = [0] + + %empty_t = tensor.empty(%dim0, %dim1) : tensor + %transposed_A = linalg.transpose ins(%broadcasted_A : tensor) outs(%empty_t : tensor) permutation = [1, 0] + + %result = linalg.elementwise kind=#linalg.elementwise_kind + ins(%transposed_A, %B : tensor, tensor) outs(%C : tensor) -> tensor + return %result : tensor +} + +// ----- + +// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// +// CHECK: func.func @fold_transpose_after_broadcast_fold_binary(%[[A:.+]]: tensor, %[[B:.+]]: tensor, %[[C:.+]]: tensor) -> tensor { +// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind +// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]], #[[IDENTITY]]] +// CHECK-SAME: ins(%[[A]], %[[B]] : tensor, tensor) outs(%[[C]] : tensor) -> tensor +// CHECK-NEXT: return %[[RES]] : tensor +// +func.func @fold_transpose_after_broadcast_fold_binary(%A: tensor, %B: tensor, %C: tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %dim0 = tensor.dim %B, %c0 : tensor + %dim1 = tensor.dim %B, %c1 : tensor + %dim2 = tensor.dim %B, %c2 : tensor + + %empty_t = tensor.empty(%dim1, %dim2) : tensor + %transposed_A = linalg.transpose ins(%A : tensor) outs(%empty_t : tensor) permutation = [1, 0] + + %empty_b = tensor.empty(%dim0, %dim1, %dim2) : tensor + %broadcasted_A = linalg.broadcast ins(%transposed_A : tensor) outs(%empty_b : tensor) dimensions = [0] + + %result = linalg.elementwise kind=#linalg.elementwise_kind + ins(%broadcasted_A, %B : tensor, tensor) outs(%C : tensor) -> tensor + return %result : tensor +}