[mlir] Add support for parallel dim *after* reduction dim in split reduction

Previously, the splitReduction transformation added the split parallel dimension
*before* the reduction dimension, which effectively tiles the reduction. This
commit adds an option to insert the parallel dimension *after* the reduction
dimension instead, allowing the op to be transformed into a vertical reduction
with SIMD parallelism.

Reviewed By: ThomasRaoux, dcaballe

Differential Revision: https://reviews.llvm.org/D134764
vmurali authored and dcaballe committed Sep 29, 2022
1 parent 5bdf22e commit 146c3ea
Showing 6 changed files with 164 additions and 29 deletions.
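
To make the difference concrete before diving into the diff, here is a minimal sketch (illustrative only, not taken from the patch; function and value names are hypothetical) of a 1-D sum reduction and the shape it takes after splitting by a factor of 4 with the new inner-parallel option:

// Original 1-D reduction: sum 32 elements into a scalar accumulator.
#in  = affine_map<(d0) -> (d0)>
#out = affine_map<(d0) -> ()>
func.func @sum(%in: tensor<32xf32>, %acc: tensor<f32>) -> tensor<f32> {
  %0 = linalg.generic {indexing_maps = [#in, #out],
                       iterator_types = ["reduction"]}
      ins(%in : tensor<32xf32>) outs(%acc : tensor<f32>) {
  ^bb0(%a: f32, %b: f32):
    %s = arith.addf %a, %b : f32
    linalg.yield %s : f32
  } -> tensor<f32>
  return %0 : tensor<f32>
}

// After splitReduction with ratio = 4 and innerParallel = true (sketch):
// the input is reshaped to tensor<8x4xf32>; the reduction dimension d0 keeps
// its position and the new parallel dimension d1 is appended *after* it,
// yielding 4 partial sums that a second, smaller reduction then combines.
//   partial op: indexing_maps = [(d0, d1) -> (d0, d1), (d0, d1) -> (d1)],
//               iterator_types = ["reduction", "parallel"]
//   combine op: indexing_maps = [(d0) -> (d0), (d0) -> ()],
//               iterator_types = ["reduction"]
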
@@ -464,6 +464,8 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
remaining reduction after splitting).
- insert_split_dimension: the dimension in the temporary tensor into
which the new parallel dimension is inserted.
- inner_parallel: specifies whether the parallel dimension is before or
after the reduction dimension in the splitting op.
- use_scaling_algorithm: whether to use a scaling based formulation that
does not create an ExpandShapeOp (default: do not use scaling)
- use_alloc: whether to use an alloc op to allocate the temporary
@@ -587,6 +589,7 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
let arguments = (ins PDL_Operation:$target,
DefaultValuedAttr<I64Attr, "{}">:$split_factor,
DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension,
UnitAttr:$inner_parallel,
UnitAttr:$use_scaling_algorithm,
UnitAttr:$use_alloc);
let results = (outs PDL_Operation:$init_or_alloc_op,
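
As a usage sketch for the new unit attribute (not part of the patch; the handle name, attribute values, and four-result form are assumptions based on the op definition above):

// Hypothetical transform-dialect invocation; %target is a handle to the
// linalg op whose reduction dimension should be split by 4, with the new
// parallel dimension placed after the remaining reduction.
%res:4 = transform.structured.split_reduction %target
    {split_factor = 4 : i64, insert_split_dimension = 2 : i64, inner_parallel}
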
20 changes: 14 additions & 6 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1363,14 +1363,22 @@ class TilingPatterns<OpTy, OpTypes...> {
}
};

/// Function signature to control reduction splitting. This returns a pair
/// containing a ratio and a dimension index. The ratio is used to split the
/// reduction dimension. The dimension index is used to control where the extra
/// dimension is added to the intermediate tensor shape. If the ratio value is
/// less or equal to 1 then nothing will be done.
/// Split Reduction options.
struct SplitReductionOptions {
// Ratio used to split the reduction dimension. If the ratio is <= 1, nothing
// will be done.
int64_t ratio = 0;
// Index where the extra dimension is added to the intermediate tensor shape.
unsigned index = 0;
// If the inner dimension after splitting is parallel or reduction.
bool innerParallel = false;
};

/// Function signature to control reduction splitting. This returns
/// `SplitReductionOptions`.
// TODO: don't use unsigned unless doing bit manipulation.
using ControlSplitReductionFn =
std::function<std::pair<int64_t, unsigned>(LinalgOp op)>;
std::function<SplitReductionOptions(LinalgOp op)>;

/// Patterns to apply `splitReduction` below.
void populateSplitReductionPattern(
5 changes: 3 additions & 2 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1001,8 +1001,9 @@ transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
ControlSplitReductionFn splitFn = [&](LinalgOp) {
return std::pair<int64_t, unsigned>(getSplitFactor(),
getInsertSplitDimension());
return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
unsigned(getInsertSplitDimension()),
/*innerParallel=*/false};
};
SimpleRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
69 changes: 49 additions & 20 deletions mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -84,9 +84,9 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(op);

std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
int64_t ratio = control.first;
unsigned insertSplitDimension = control.second;
SplitReductionOptions control = controlSplitReductionFn(op);
int64_t ratio = control.ratio;
unsigned insertSplitDimension = control.index;
if (ratio <= 1)
return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

@@ -125,17 +125,32 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
unsigned dim = map.getDimPosition(idx);
if (reductionDim == dim) {
newShape.push_back(ratio);
newShape.push_back(op.getShape(operand)[idx] / ratio);
if (control.innerParallel) {
newShape.push_back(op.getShape(operand)[idx] / ratio);
newShape.push_back(ratio);
} else {
newShape.push_back(ratio);
newShape.push_back(op.getShape(operand)[idx] / ratio);
}
reassociation.push_back({index++, index++});
exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
exprs.push_back(
b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
if (control.innerParallel) {
exprs.push_back(b.getAffineDimExpr(reductionDim));
exprs.push_back(b.getAffineDimExpr(reductionDim + 1));
} else {
exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
exprs.push_back(
b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
}
continue;
}
newShape.push_back(op.getShape(operand)[idx]);
exprs.push_back(
b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
if (control.innerParallel) {
exprs.push_back(
b.getAffineDimExpr(dim <= reductionDim ? dim : dim + 1));
} else {
exprs.push_back(
b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
}
reassociation.push_back({index++});
}
newMaps.push_back(
@@ -163,14 +178,23 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
if (idx == insertSplitDimension) {
newOutputShape.push_back(ratio);
outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
if (control.innerParallel) {
outputExpr.push_back(b.getAffineDimExpr(reductionDim + 1));
} else {
outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
}
continue;
}
unsigned oldDim = idx < insertSplitDimension ? idx : idx - 1;
newOutputShape.push_back(oldShape[oldDim]);
unsigned dim = oldOutputMap.getDimPosition(oldDim);
outputExpr.push_back(
b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
unsigned oldIdx = idx < insertSplitDimension ? idx : idx - 1;
newOutputShape.push_back(oldShape[oldIdx]);
unsigned dim = oldOutputMap.getDimPosition(oldIdx);
if (control.innerParallel) {
outputExpr.push_back(
b.getAffineDimExpr(dim <= reductionDim ? dim : dim + 1));
} else {
outputExpr.push_back(
b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
}
}
Value initOrAllocTensor;
if (useAlloc) {
@@ -192,9 +216,11 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
op.getContext()));
SmallVector<StringRef> newIteratorTypes;
for (auto &it : llvm::enumerate(op.iterator_types())) {
if (insertSplitDimension == it.index())
if (insertSplitDimension == it.index() && !control.innerParallel)
newIteratorTypes.push_back(getParallelIteratorTypeName());
newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
if (insertSplitDimension == it.index() && control.innerParallel)
newIteratorTypes.push_back(getParallelIteratorTypeName());
}
// Create the new op matching the original op with an extra parallel
// dimension.
@@ -275,9 +301,12 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
b.setInsertionPoint(op);

// Matcher part, enforce preconditions.
std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
int64_t splitFactor = control.first;
unsigned insertSplitDimension = control.second;
SplitReductionOptions control = controlSplitReductionFn(op);
if (control.innerParallel)
return b.notifyMatchFailure(op, "innerParallel not supported");

int64_t splitFactor = control.ratio;
unsigned insertSplitDimension = control.index;
if (splitFactor <= 1)
return b.notifyMatchFailure(op, "split factor needs to be greater than 1");

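
Condensing the effect of the innerParallel branches above for the matmul case exercised in the tests that follow, here is a sketch (operand names hypothetical) of the partial-reduction op produced when the 256-wide reduction is split by 4 with the parallel dimension placed after it:

#mapA   = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#mapB   = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
#mapAcc = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// d2 is the split (outer) reduction, d3 the new inner parallel dimension.
func.func @split_matmul_partial(%A: tensor<16x64x4xf32>, %B: tensor<64x4x32xf32>,
                                %acc: tensor<16x32x4xf32>) -> tensor<16x32x4xf32> {
  %0 = linalg.generic {indexing_maps = [#mapA, #mapB, #mapAcc],
        iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
      ins(%A, %B : tensor<16x64x4xf32>, tensor<64x4x32xf32>)
      outs(%acc : tensor<16x32x4xf32>) {
  ^bb0(%a: f32, %b: f32, %c: f32):
    %m = arith.mulf %a, %b : f32
    %s = arith.addf %m, %c : f32
    linalg.yield %s : f32
  } -> tensor<16x32x4xf32>
  return %0 : tensor<16x32x4xf32>
}
// A second generic with iterator_types ["parallel", "parallel", "reduction"]
// then folds the trailing size-4 dimension back into the 16x32 result.
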
74 changes: 74 additions & 0 deletions mlir/test/Dialect/Linalg/split_reduction.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-split-reduction -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-split-reduction-inner-parallel -split-input-file | FileCheck %s --check-prefix=INNERPARALLELCHECK

func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
@@ -31,6 +32,31 @@ func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
// CHECK: } -> tensor<16x32xf32>
// CHECK: return %[[R]] : tensor<16x32xf32>

// INNERPARALLELCHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
// INNERPARALLELCHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
// INNERPARALLELCHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// INNERPARALLELCHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// INNERPARALLELCHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// INNERPARALLELCHECK-LABEL: @matmul_split
// INNERPARALLELCHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
// INNERPARALLELCHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x64x4xf32>
// INNERPARALLELCHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<64x4x32xf32>
// INNERPARALLELCHECK-DAG: %[[INI:.*]] = linalg.init_tensor [16, 32, 4] : tensor<16x32x4xf32>
// INNERPARALLELCHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
// INNERPARALLELCHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
// INNERPARALLELCHECK-SAME: , iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
// INNERPARALLELCHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<16x64x4xf32>, tensor<64x4x32xf32>) outs(%[[F]] : tensor<16x32x4xf32>) {
// INNERPARALLELCHECK: arith.mulf
// INNERPARALLELCHECK: arith.addf
// INNERPARALLELCHECK: linalg.yield
// INNERPARALLELCHECK: } -> tensor<16x32x4xf32>
// INNERPARALLELCHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]],
// INNERPARALLELCHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[G]] : tensor<16x32x4xf32>) outs(%{{.*}} : tensor<16x32xf32>) {
// INNERPARALLELCHECK: arith.addf
// INNERPARALLELCHECK: linalg.yield %{{.*}} : f32
// INNERPARALLELCHECK: } -> tensor<16x32xf32>
// INNERPARALLELCHECK: return %[[R]] : tensor<16x32xf32>

// -----

func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: tensor<f32>) -> tensor<f32> {
@@ -73,6 +99,30 @@ func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: tensor<f32>) -> tensor<f32> {
// CHECK: } -> tensor<f32>
// CHECK: return %[[R]] : tensor<f32>

// INNERPARALLELCHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// INNERPARALLELCHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
// INNERPARALLELCHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
// INNERPARALLELCHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
// INNERPARALLELCHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
//INNERPARALLELCHECK-LABEL: @generic_split_1d
// INNERPARALLELCHECK: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
// INNERPARALLELCHECK: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<8x4xf32>
// INNERPARALLELCHECK: %[[INI:.*]] = linalg.init_tensor [4] : tensor<4xf32>
// INNERPARALLELCHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
// INNERPARALLELCHECK: %[[G:.*]] = linalg.generic
// INNERPARALLELCHECK: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
// INNERPARALLELCHECK: iterator_types = ["reduction", "parallel"]} ins(%[[I1]], %{{.*}} : tensor<8x4xf32>, tensor<f32>) outs(%[[F]] : tensor<4xf32>) {
// INNERPARALLELCHECK: arith.subf
// INNERPARALLELCHECK: math.exp
// INNERPARALLELCHECK: arith.mulf
// INNERPARALLELCHECK: linalg.yield
// INNERPARALLELCHECK: } -> tensor<4xf32>
// INNERPARALLELCHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["reduction"]} ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
// INNERPARALLELCHECK: arith.mulf
// INNERPARALLELCHECK: linalg.yield
// INNERPARALLELCHECK: } -> tensor<f32>
// INNERPARALLELCHECK: return %[[R]] : tensor<f32>

// -----

func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
@@ -117,3 +167,27 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
// CHECK: linalg.yield
// CHECK: } -> tensor<5x2xf32>
// CHECK: return %[[R]] : tensor<5x2xf32>

// INNERPARALLELCHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d0)>
// INNERPARALLELCHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>
// INNERPARALLELCHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
// INNERPARALLELCHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// INNERPARALLELCHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// INNERPARALLELCHECK-LABEL: func @generic_split_3d
// INNERPARALLELCHECK: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32
// INNERPARALLELCHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<8x4x2xf32>
// INNERPARALLELCHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x8x4xf32>
// INNERPARALLELCHECK: %[[INI:.*]] = linalg.init_tensor [5, 2, 4] : tensor<5x2x4xf32>
// INNERPARALLELCHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
// INNERPARALLELCHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
// INNERPARALLELCHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<8x4x2xf32>, tensor<5x8x4xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
// INNERPARALLELCHECK: arith.addf
// INNERPARALLELCHECK: arith.maxf
// INNERPARALLELCHECK: linalg.yield
// INNERPARALLELCHECK: } -> tensor<5x2x4xf32>
// INNERPARALLELCHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
// INNERPARALLELCHECK-SAME: ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
// INNERPARALLELCHECK: arith.maxf
// INNERPARALLELCHECK: linalg.yield
// INNERPARALLELCHECK: } -> tensor<5x2xf32>
// INNERPARALLELCHECK: return %[[R]] : tensor<5x2xf32>
22 changes: 21 additions & 1 deletion mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -103,6 +103,10 @@ struct TestLinalgTransforms
*this, "test-split-reduction",
llvm::cl::desc("Test split reduction transformation"),
llvm::cl::init(false)};
Option<bool> testSplitReductionInnerParallel{
*this, "test-split-reduction-inner-parallel",
llvm::cl::desc("Test split reduction with inner parallel transformation"),
llvm::cl::init(false)};
ListOption<int64_t> peeledLoops{
*this, "peeled-loops",
llvm::cl::desc("Loops to be peeled when test-tile-pattern")};
@@ -499,7 +503,21 @@ static void applySplitReduction(func::FuncOp funcOp) {
patterns,
[](LinalgOp op) {
unsigned insertDimIndex = op.getNumLoops() - 1;
return std::make_pair(4, insertDimIndex);
return SplitReductionOptions{4, insertDimIndex, false};
},
LinalgTransformationFilter(
ArrayRef<StringAttr>{},
StringAttr::get(funcOp.getContext(), "SPLIT")));
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applySplitReductionInnerParallel(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
linalg::populateSplitReductionPattern(
patterns,
[](LinalgOp op) {
unsigned insertDimIndex = op.getNumLoops() - 1;
return SplitReductionOptions{4, insertDimIndex, true};
},
LinalgTransformationFilter(
ArrayRef<StringAttr>{},
@@ -560,6 +578,8 @@ void TestLinalgTransforms::runOnOperation() {
/*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
if (testSplitReduction)
return applySplitReduction(getOperation());
if (testSplitReductionInnerParallel)
return applySplitReductionInnerParallel(getOperation());
if (testBubbleUpExtractSliceOpPattern)
return applyBubbleUpExtractSliceOpPattern(getOperation());
if (testSwapExtractSliceWithFill)
