[mlir][TilingInterface] Enable tiling of `tensor.pad` using `TilingInterface`.

Update the `TilingInterface` implementation for `tensor.pad` operations
so the op can be tiled with the existing patterns for the interface.
Verify that tests which pass with the existing pad tiling patterns
produce the same results when run through the `TilingInterface`-based
patterns.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D132720
Mahesh Ravishankar committed Aug 26, 2022
1 parent e117137 commit a235562
Showing 6 changed files with 172 additions and 5 deletions.
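Before the per-file changes, here is a minimal sketch of how a downstream tool picks up the renamed registration; the helper name `setupTensorTiling` and the surrounding context setup are illustrative assumptions, not part of this commit:

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"

// Attach the TilingInterface external model for tensor.pad so that generic
// tiling drivers (such as the SCF tiling pattern exercised by the new test
// file) can tile the op.
static void setupTensorTiling(mlir::MLIRContext &context) {
  mlir::DialectRegistry registry;
  registry.insert<mlir::tensor::TensorDialect>();
  // Renamed in this commit from registerTilingOpInterfaceExternalModels.
  mlir::tensor::registerTilingInterfaceExternalModels(registry);
  context.appendDialectRegistry(registry);
}
```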
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h
@@ -54,7 +54,7 @@ Operation *bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
/// ops on `affine.apply` and Affine dialect already depends on TensorOps. In
/// order to break the cyclic dependency (TensorOps->AffineOps->TensorOps) the
/// implementation is moved to a separate library.
-void registerTilingOpInterfaceExternalModels(mlir::DialectRegistry &registry);
+void registerTilingInterfaceExternalModels(mlir::DialectRegistry &registry);

} // namespace tensor
} // namespace mlir
2 changes: 1 addition & 1 deletion mlir/include/mlir/InitAllDialects.h
@@ -125,7 +125,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerInferTypeOpInterfaceExternalModels(registry);
-tensor::registerTilingOpInterfaceExternalModels(registry);
+tensor::registerTilingInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);
}

15 changes: 14 additions & 1 deletion mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -22,6 +22,8 @@ namespace {
struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {

SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
ReifiedRankedShapedTypeDims reifiedShapes;
ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
@@ -69,6 +71,17 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
return {};
return {result};
}

LogicalResult
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) const {
resultOffsets.assign(offsets.begin(), offsets.end());
resultSizes.assign(sizes.begin(), sizes.end());
return success();
}
};

} // namespace
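For context, this is roughly how a tiling driver consumes the new hook through the interface; the free function below is a sketch and its name and surrounding setup are assumptions, not part of the commit:

```cpp
// Query where the tile chosen for a tensor.pad's iteration space lands in its
// (single) result. With the external model above this is an identity mapping:
// resultOffsets == offsets and resultSizes == sizes.
static LogicalResult
computePadResultTile(OpBuilder &b, Operation *op, ArrayRef<OpFoldResult> offsets,
                     ArrayRef<OpFoldResult> sizes,
                     SmallVector<OpFoldResult> &resultOffsets,
                     SmallVector<OpFoldResult> &resultSizes) {
  auto tilingOp = dyn_cast<TilingInterface>(op);
  if (!tilingOp)
    return failure();
  // tensor.pad has exactly one result, so resultNumber is 0.
  return tilingOp.getResultTilePosition(b, /*resultNumber=*/0, offsets, sizes,
                                        resultOffsets, resultSizes);
}
```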
@@ -281,7 +294,7 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
return createPadOfExtractSlice();
}

-void mlir::tensor::registerTilingOpInterfaceExternalModels(
+void mlir::tensor::registerTilingInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
140 changes: 140 additions & 0 deletions mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
@@ -0,0 +1,140 @@
// RUN: mlir-opt -test-tiling-interface=tile-using-scf-for -resolve-shaped-type-result-dims -cse -split-input-file %s | FileCheck %s

// 2D tiling of dynamic 2D pad tensor op.
func.func @dynamic_2d_pad_tensor(%input_tensor: tensor<?x?xf32>,
%pad_value: f32) -> tensor<?x?xf32> {
%0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %pad_value : f32
} {__internal_linalg_transform__ = "pad_2dtiling"}: tensor<?x?xf32> to tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 8)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 + 7)>
// CHECK: func @dynamic_2d_pad_tensor(
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: %[[DIM_IN0:.+]] = tensor.dim %[[IN]], %[[C0]]
// CHECK: %[[DIM0:.+]] = affine.apply #[[MAP0]]()[%[[DIM_IN0]]]
// CHECK: %[[DIM_IN1:.+]] = tensor.dim %[[IN]], %[[C1]]
// CHECK: %[[DIM1:.+]] = affine.apply #[[MAP1]]()[%[[DIM_IN1]]]
// CHECK: %[[RESULT:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[DIM0]] step %[[C2]]
// CHECK: scf.for {{.*}} = %[[C0]] to %[[DIM1]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
// CHECK: %[[SWAP_RESULT:.*]] = scf.if
// CHECK: tensor.generate
// CHECK: else
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[IN]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
// CHECK: %[[PAD:.*]] = tensor.pad %[[SLICE]]
// CHECK: tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
// CHECK: return %[[RESULT]]
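The two affine maps simply materialize the padded result sizes from the dynamic input: along dimension 0 the result size is the input size plus low padding 3 and high padding 5 (s0 + 8), and along dimension 1 it is the input size plus 4 + 3 (s0 + 7). The generated loops then walk that padded iteration space with the steps 2 and 3 registered for the "pad_2dtiling" filter.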

// -----

func.func @dynamic_2d_pad_tensor_inner_tiling(%input_tensor: tensor<?x?xf32>,
%pad_value: f32) -> tensor<?x?xf32> {
%0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %pad_value : f32
} {__internal_linalg_transform__ = "pad_inner_tiling"}: tensor<?x?xf32> to tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 8)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 7)>
// CHECK: func @dynamic_2d_pad_tensor_inner_tiling(
// CHECK-SAME: %[[IN:.*]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[DIM_IN0:.*]] = tensor.dim %[[IN]], %[[C0]]
// CHECK: %[[DIM0:.*]] = affine.apply #[[MAP0]]()[%[[DIM_IN0]]]
// CHECK: %[[DIM_IN1:.*]] = tensor.dim %[[IN]], %[[C1]]
// CHECK: %[[DIM1:.*]] = affine.apply #[[MAP1]]()[%[[DIM_IN1]]]
// CHECK: %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[DIM1]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
// CHECK: %[[SWAP_RESULT:.*]] = scf.if
// CHECK: tensor.generate
// CHECK: else
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[IN]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
// CHECK: %[[PAD:.*]] = tensor.pad %[[SLICE]] low[3, %{{.*}}] high[{{.*}}, {{.*}}]
// CHECK: tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][%[[C0]], {{.*}}] [%[[DIM0]], {{.*}}] [1, 1]
// CHECK: return %[[RESULT]]

// -----

func.func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
%pad_value: f32) -> tensor<15x16xf32> {
%0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %pad_value : f32
} {__internal_linalg_transform__ = "pad_2dtiling"} : tensor<7x9xf32> to tensor<15x16xf32>
return %0 : tensor<15x16xf32>
}
// CHECK-LABEL: func @static_pad_tensor(
// CHECK-SAME: %[[IN:.*]]: tensor<7x9xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
// CHECK: %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[C15]] step %[[C2]]
// CHECK: scf.for {{.*}} = %[[C0]] to %[[C16]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
// CHECK: %[[SWAP_RESULT:.*]] = scf.if
// CHECK: tensor.generate
// CHECK: else
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[IN]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
// CHECK: %[[PAD:.*]] = tensor.pad %[[SLICE]]
// CHECK: tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
// CHECK: return %[[RESULT]]
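In the static case the same arithmetic folds to constants: the result type 15x16 comes from 7 + 3 + 5 by 9 + 4 + 3, so the loop bounds are %c15 and %c16 with the same tile sizes 2 and 3.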

// -----

func.func @static_pad_tensor_inner_tiling(%input_tensor: tensor<7x9xf32>,
%pad_value: f32) -> tensor<15x16xf32> {
%0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %pad_value : f32
} {__internal_linalg_transform__ = "pad_inner_tiling"} : tensor<7x9xf32> to tensor<15x16xf32>
return %0 : tensor<15x16xf32>
}
// CHECK-LABEL: func @static_pad_tensor_inner_tiling(
// CHECK-SAME: %[[IN:.*]]: tensor<7x9xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
// CHECK: %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[C16]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
// CHECK: %[[SWAP_RESULT:.*]] = scf.if
// CHECK: tensor.generate
// CHECK: else
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[IN]][0, {{.*}}] [7, {{.*}}] [1, 1]
// CHECK: %[[PAD:.*]] = tensor.pad %[[SLICE]] low[3, %{{.*}}] high[5, {{.*}}]
// CHECK: tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][%[[C0]], {{.*}}] [%[[C15]], {{.*}}] [1, 1]
// CHECK: return %[[RESULT]]
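Because only the inner dimension is tiled under the "pad_inner_tiling" filter, the untiled outer dimension keeps offset %c0 and the full padded size %c15 in the inserted slice, and the low padding of the generated tensor.pad stays at its original value 3 along that dimension.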

/// The rest of the tests only check that they don't fail.

// -----

func.func @dynamic_2d_pad_tensor_outer_tiling(%input_tensor: tensor<?x?xf32>,
%pad_value: f32) -> tensor<?x?xf32> {
%0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %pad_value : f32
} {__internal_linalg_transform__ = "pad_outer_tiling"}: tensor<?x?xf32> to tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: func @dynamic_2d_pad_tensor_outer_tiling

// -----

func.func @static_pad_tensor_outer_tiling(%input_tensor: tensor<7x9xf32>,
%pad_value: f32) -> tensor<15x16xf32> {
%0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %pad_value : f32
} {__internal_linalg_transform__ = "pad_inner_tiling"} : tensor<7x9xf32> to tensor<15x16xf32>
return %0 : tensor<15x16xf32>
}
// CHECK-LABEL: func @static_pad_tensor_outer_tiling
1 change: 1 addition & 0 deletions mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt
@@ -12,4 +12,5 @@ add_mlir_library(MLIRTilingInterfaceTestPasses
MLIRSCFDialect
MLIRSCFTransforms
MLIRTensorDialect
MLIRTensorTilingInterfaceImpl
)
17 changes: 15 additions & 2 deletions mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -13,12 +13,14 @@

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@@ -117,9 +119,10 @@ struct TestTilingInterfacePass
TestTilingInterfacePass(const TestTilingInterfacePass &pass)
: PassWrapper(pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
-registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
-                tensor::TensorDialect>();
+registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
+                scf::SCFDialect, tensor::TensorDialect>();
linalg::registerTilingInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
}
StringRef getArgument() const final { return "test-tiling-interface"; }
StringRef getDescription() const final {
@@ -184,6 +187,16 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
// 6. Tiling + interchange of an operation
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
context, patterns, "gemm_interchange", {10, 20, 30}, {1, 2, 0});
// 7. Tiling for 2D pad tensor operations.
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
context, patterns, "pad_2dtiling", {2, 3});
// 8. Tiling inner dimension of 2d pad tensor operations.
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
context, patterns, "pad_inner_tiling", {0, 3});
// 9. Tiling outer dimension of 2d pad tensor operations.
addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
context, patterns, "pad_outer_tiling", {2, 3});

return;
}
if (testTileConsumerAndFuseProducer) {
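In the pattern registrations above, the filter strings correspond to the `__internal_linalg_transform__` attributes in the new test file: "pad_2dtiling" tiles both dimensions with tile sizes {2, 3}, "pad_inner_tiling" uses {0, 3} (a tile size of 0 conventionally leaves that loop untiled, so only the inner dimension is tiled), and "pad_outer_tiling" is exercised only to check that tiling does not fail.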
