331 changes: 195 additions & 136 deletions mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp

Large diffs are not rendered by default.

10 changes: 7 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,16 +376,20 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
// 3. Create the tiled loops.
LinalgOp res = op;
SmallVector<Value, 4> ivs(loopRanges.size());
GenericLoopNestRangeBuilder<LoopTy>(ivs, loopRanges)([&] {
SmallVector<Attribute, 4> iteratorTypes =
llvm::to_vector<4>(op.iterator_types().cast<ArrayAttr>().getValue());
if (!options.interchangeVector.empty())
applyPermutationToVector(iteratorTypes, options.interchangeVector);
GenerateLoopNest<LoopTy>::doit(ivs, loopRanges, iteratorTypes, [&] {
auto &b = ScopedContext::getBuilderRef();
auto loc = ScopedContext::getLocation();
SmallVector<Value, 4> ivValues(ivs.begin(), ivs.end());

// If we have to apply a permutation to the tiled loop nest, we have to
// reorder the induction variables This permutation is the right one
// assuming that loopRanges have previously been permuted by
// (i,j,k)->(k,i,j) So this permutation should be the inversePermutation of
// that one: (d0,d1,d2)->(d2,d0,d1)
// (i,j,k)->(k,i,j) So this permutation should be the inversePermutation
// of that one: (d0,d1,d2)->(d2,d0,d1)
if (!options.interchangeVector.empty())
ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues);

Expand Down
19 changes: 13 additions & 6 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,19 @@ LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
return failure();
if (failed(promoteSubviewsPrecondition(op, options)))
return failure();
rewriter.updateRootInPlace(op, [&]() {
auto promotedOp = promoteSubViews(rewriter, op, options);
(void)promotedOp;
assert(promotedOp && "Unexpected pattern failure");
marker.replaceLinalgMarker(rewriter, op);
});

// TODO: We cannot use root update here. This pattern is creating other ops,
// so if the promotion fails, those need to be cleaned up, which doesn't seem
// to be happening here. So to fail properly, we should be cloning the op and
// deleting the previous op. This needs more investigation.
rewriter.startRootUpdate(op);
Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
if (!promotedOp) {
rewriter.cancelRootUpdate(op);
return op->emitError("subview promotion failed");
}
rewriter.finalizeRootUpdate(op);
marker.replaceLinalgMarker(rewriter, op);
return success();
}

Expand Down
89 changes: 89 additions & 0 deletions mlir/lib/Dialect/Linalg/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/SCF/EDSC/Builders.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AffineExpr.h"
Expand Down Expand Up @@ -101,3 +102,91 @@ mlir::linalg::getAssumedNonViewOperands(LinalgOp linalgOp) {
}
return res;
}

/// Returns true if `attr` is a string attribute holding the parallel iterator
/// type name.
bool mlir::linalg::isParallelIteratorType(Attribute attr) {
  auto strAttr = attr.dyn_cast<StringAttr>();
  return strAttr && strAttr.getValue() == getParallelIteratorTypeName();
}

/// Returns true if `attr` is a string attribute holding the reduction iterator
/// type name.
bool mlir::linalg::isReductionIteratorType(Attribute attr) {
  auto strAttr = attr.dyn_cast<StringAttr>();
  return strAttr && strAttr.getValue() == getReductionIteratorTypeName();
}

/// Returns true if `attr` is a string attribute holding the window iterator
/// type name.
bool mlir::linalg::isWindowIteratorType(Attribute attr) {
  auto strAttr = attr.dyn_cast<StringAttr>();
  return strAttr && strAttr.getValue() == getWindowIteratorTypeName();
}

/// Explicit instantiation of loop nest generator for different loop types. The
/// `doit` member is explicitly specialized for each of these loop types below.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;

/// Specialization of loop nest generator for scf.for loops: every loop in the
/// nest is generated as a sequential scf.for op; `iteratorTypes` is unused.
/// (The original comment here described the scf.parallel specialization, which
/// lives further below.)
template <>
void mlir::linalg::GenerateLoopNest<scf::ForOp>::doit(
MutableArrayRef<Value> allIvs, ArrayRef<SubViewOp::Range> loopRanges,
ArrayRef<Attribute> iteratorTypes, std::function<void(void)> fun) {
edsc::GenericLoopNestRangeBuilder<scf::ForOp>(allIvs, loopRanges)(fun);
}

/// Specialization of loop nest generator for affine loops: every loop in the
/// nest is generated as an affine.for op; `iteratorTypes` is unused.
template <>
void mlir::linalg::GenerateLoopNest<AffineForOp>::doit(
MutableArrayRef<Value> allIvs, ArrayRef<SubViewOp::Range> loopRanges,
ArrayRef<Attribute> iteratorTypes, std::function<void(void)> fun) {
edsc::GenericLoopNestRangeBuilder<AffineForOp>(allIvs, loopRanges)(fun);
}

/// Specialization of loop nest generator for scf.parallel loops: each maximal
/// run of leading parallel iterator types is emitted as a single scf.parallel
/// op; each non-parallel iterator type is emitted as a sequential scf.for op.
/// The function recurses until all of `loopRanges` has been covered, so mixed
/// nests like parallel-for-parallel-for are supported.
template <>
void mlir::linalg::GenerateLoopNest<scf::ParallelOp>::doit(
MutableArrayRef<Value> allIvs, ArrayRef<SubViewOp::Range> loopRanges,
ArrayRef<Attribute> iteratorTypes, std::function<void(void)> fun) {
// Check if there is nothing to do here. This is also the recursion
// termination.
if (loopRanges.empty())
return;
// Number of consecutive parallel iterator types at the front of the
// remaining loop nest.
size_t nOuterPar = iteratorTypes.take_front(loopRanges.size())
.take_while(isParallelIteratorType)
.size();
if (nOuterPar == 0 && loopRanges.size() == 1)
// Generate the sequential for loop for the remaining non-parallel loop.
return GenerateLoopNest<scf::ForOp>::doit(allIvs, loopRanges, iteratorTypes,
fun);
if (nOuterPar == 0) {
// The immediate outer loop is not parallel. Generate a scf.for op for this
// loop, but there might be subsequent loops that are parallel. Use
// recursion to find those.
auto nestedFn = [&]() {
GenerateLoopNest<scf::ParallelOp>::doit(allIvs.drop_front(),
loopRanges.drop_front(),
iteratorTypes.drop_front(), fun);
};
// Single-element ArrayRefs are built implicitly from allIvs[0] etc. here.
return GenerateLoopNest<scf::ForOp>::doit(allIvs[0], loopRanges[0],
iteratorTypes[0], nestedFn);
}
if (nOuterPar == loopRanges.size()) {
// All loops are parallel, so generate the scf.parallel op.
return edsc::GenericLoopNestRangeBuilder<scf::ParallelOp>(allIvs,
loopRanges)(fun);
}
// Generate scf.parallel for the outer parallel loops. The next inner loop is
// sequential, but there might be more parallel loops after that. So recurse
// into the same method.
auto nestedFn = [&]() {
GenerateLoopNest<scf::ParallelOp>::doit(
allIvs.drop_front(nOuterPar), loopRanges.drop_front(nOuterPar),
iteratorTypes.drop_front(nOuterPar), fun);
};
return GenerateLoopNest<scf::ParallelOp>::doit(
allIvs.take_front(nOuterPar), loopRanges.take_front(nOuterPar),
iteratorTypes.take_front(nOuterPar), nestedFn);
}
38 changes: 37 additions & 1 deletion mlir/test/Dialect/Linalg/parallel_loops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,42 @@ func @lower_outer_parallel(%A: memref<?x?x?x?xf32>, %B: memref<?x?x?xf32>) {
// CHECK-DAG: %[[D3:.*]] = dim %{{.*}}, 3
// CHECK: scf.parallel (%[[IV0:.*]], %[[IV1:.*]]) = (%[[C0]], %[[C0]]) to (%[[D0]], %[[D1]]) step (%[[C1]], %[[C1]])
// CHECK: scf.for %[[IV2:.*]] = %[[C0]] to %[[D2]] step %[[C1]]
// CHECK: scf.for %[[IV3:.*]] = %[[C0]] to %[[D3]] step %[[C1]]
// CHECK: scf.parallel (%[[IV3:.*]]) = (%[[C0]]) to (%[[D3]]) step (%[[C1]])
// CHECK: load %{{.*}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
// CHECK: store %{{.*}}, %{{.*}}[%[[IV0]], %[[IV1]], %[[IV3]]]

// -----

#accesses = [
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>
]
#trait = {
args_in = 1,
args_out = 1,
iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"],
indexing_maps = #accesses
}

func @lower_mixed_parallel(%A: memref<?x?x?x?x?x?xf32>, %B: memref<?x?x?x?xf32>) {
linalg.generic #trait %A, %B {
^bb0(%a: f32, %b: f32):
linalg.yield %a: f32
} : memref<?x?x?x?x?x?xf32>, memref<?x?x?x?xf32>
return
}
// CHECK-LABEL: @lower_mixed_parallel
// CHECK-DAG: %[[C0:.*]] = constant 0
// CHECK-DAG: %[[C1:.*]] = constant 1
// CHECK-DAG: %[[D0:.*]] = dim %{{.*}}, 0
// CHECK-DAG: %[[D1:.*]] = dim %{{.*}}, 1
// CHECK-DAG: %[[D2:.*]] = dim %{{.*}}, 2
// CHECK-DAG: %[[D3:.*]] = dim %{{.*}}, 3
// CHECK-DAG: %[[D4:.*]] = dim %{{.*}}, 4
// CHECK-DAG: %[[D5:.*]] = dim %{{.*}}, 5
// CHECK: scf.parallel (%[[IV0:.*]], %[[IV1:.*]]) = (%[[C0]], %[[C0]]) to (%[[D0]], %[[D1]]) step (%[[C1]], %[[C1]])
// CHECK: scf.for %[[IV2:.*]] = %[[C0]] to %[[D2]] step %[[C1]]
// CHECK: scf.parallel (%[[IV3:.*]], %[[IV4:.*]]) = (%[[C0]], %[[C0]]) to (%[[D3]], %[[D4]]) step (%[[C1]], %[[C1]])
// CHECK: scf.for %[[IV5:.*]] = %[[C0]] to %[[D5]] step %[[C1]]
// CHECK: load %{{.*}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]]]
// CHECK: store %{{.*}}, %{{.*}}[%[[IV0]], %[[IV2]], %[[IV4]], %[[IV5]]]
33 changes: 33 additions & 0 deletions mlir/test/Dialect/Linalg/promotion_options.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-promotion-options -split-input-file | FileCheck %s

func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
{
linalg.matmul(%a, %b, %c) {__internal_linalg_transform__ = "START"}
: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
return
}

// CHECK: func @gemm
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-DAG: %[[C42:.+]] = constant 4.200000e+01 : f32
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
// CHECK: %[[T7:.+]] = subview %[[ARG0]]
// CHECK: %[[T12:.+]] = subview %[[ARG1]]
// CHECK: %[[T17:.+]] = subview %[[ARG2]]
// CHECK: %[[T18:.+]] = alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32, 3>
// CHECK: %[[T19:.+]] = subview %[[T18]]
// CHECK: %[[T20:.+]] = alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32, 3>
// CHECK: %[[T21:.+]] = subview %[[T20]]
// CHECK: linalg.fill(%[[T19]], %[[C42]])
// CHECK: linalg.copy(%[[T7]], %[[T19]])
// CHECK: linalg.fill(%[[T21]], %[[C42]])
// CHECK: linalg.copy(%[[T17]], %[[T21]])
// CHECK: linalg.matmul(%[[T19]], %[[T12]], %[[T21]])
// CHECK-NOT: linalg.fill
// CHECK: linalg.copy(%[[T21]], %[[T17]])
// CHECK: dealloc %[[T18]]
// CHECK: dealloc %[[T20]]
108 changes: 108 additions & 0 deletions mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// RUN: mlir-opt %s -linalg-tile-to-parallel-loops="linalg-tile-sizes=2,4,8" -split-input-file | FileCheck %s
// RUN: mlir-opt %s -linalg-tile-to-parallel-loops="linalg-tile-sizes=2" -split-input-file | FileCheck %s -check-prefix=TILE1
// RUN: mlir-opt %s -linalg-tile-to-parallel-loops="linalg-tile-sizes=2,4" -split-input-file | FileCheck %s -check-prefix=TILE2

func @gemm(%arg0 : memref<?x?xf32>,
%arg1 : memref<?x?xf32>,
%arg2 : memref<?x?xf32>)
{
linalg.matmul(%arg0, %arg1, %arg2)
: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
return
}
// CHECK-LABEL: func @gemm
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[C4:.*]] = constant 4 : index
// CHECK-DAG: %[[C8:.*]] = constant 8 : index
// CHECK: scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) =
// CHECK-SAME: step (%[[C2]], %[[C4]])
// CHECK: scf.for %[[ARG5:.*]] =
// CHECK-SAME: step %[[C8]]
// CHECK: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]]
// CHECK: %[[SV2:.*]] = subview %{{.*}}[%[[ARG5]], %[[ARG4]]]
// CHECK: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]]
// CHECK: linalg.matmul(%[[SV1]], %[[SV2]], %[[SV3]])

// TILE1-LABEL: func @gemm
// TILE1-DAG: %[[C2:.*]] = constant 2 : index
// TILE1: scf.parallel (%[[ARG3:.*]]) =
// TILE1-SAME: step (%[[C2]])
// TILE1: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0]
// TILE1: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], 0]
// TILE1-NOT: subview
// TILE1: linalg.matmul(%[[SV1]], %{{.*}}, %[[SV3]])

// TILE2-LABEL: func @gemm
// TILE2-DAG: %[[C2:.*]] = constant 2 : index
// TILE2-DAG: %[[C4:.*]] = constant 4 : index
// TILE2: scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) =
// TILE2-SAME: step (%[[C2]], %[[C4]])
// TILE2: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0]
// TILE2: %[[SV2:.*]] = subview %{{.*}}[0, %[[ARG4]]]
// TILE2: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]]
// TILE2: linalg.matmul(%[[SV1]], %[[SV2]], %[[SV3]])

// -----

#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1)>
#accesses = [#map0, #map1, #map2]
#trait = {
args_in = 2 : i64,
args_out = 1 : i64,
iterator_types = ["reduction", "parallel", "reduction"],
indexing_maps = #accesses
}

func @reduction(%arg0 : memref<?x?x?xf32>,
%arg1 : memref<?x?xf32>,
%arg2 : memref<?xf32>)
{
linalg.generic #trait %arg0, %arg1, %arg2 {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
%0 = addf %arg3, %arg4 : f32
%1 = addf %0, %arg5 : f32
linalg.yield %1 : f32
} : memref<?x?x?xf32>, memref<?x?xf32>, memref<?xf32>
return
}

// CHECK-LABEL: func @reduction
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[C4:.*]] = constant 4 : index
// CHECK-DAG: %[[C8:.*]] = constant 8 : index
// CHECK: scf.for %[[ARG3:.*]] =
// CHECK-SAME: step %[[C2]]
// CHECK: scf.parallel (%[[ARG4:.*]]) =
// CHECK-SAME: step (%[[C4]])
// CHECK: scf.for %[[ARG5:.*]] =
// CHECK-SAME: step %[[C8]]
// CHECK: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]], %[[ARG5]]]
// CHECK: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]]
// CHECK: %[[SV3:.*]] = subview %{{.*}}[%[[ARG4]]]
// CHECK: linalg.generic
// CHECK-SAME: %[[SV1]], %[[SV2]], %[[SV3]]

// TILE1-LABEL: func @reduction
// TILE1-DAG: %[[C2:.*]] = constant 2 : index
// TILE1: scf.for %[[ARG3:.*]] =
// TILE1-SAME: step %[[C2]]
// TILE1: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0, 0]
// TILE1: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], 0]
// TILE1-NOT: subview
// TILE1: linalg.generic
// TILE1-SAME: %[[SV1]], %[[SV2]], %{{.*}}

// TILE2-LABEL: func @reduction
// TILE2-DAG: %[[C2:.*]] = constant 2 : index
// TILE2-DAG: %[[C4:.*]] = constant 4 : index
// TILE2: scf.for %[[ARG3:.*]] =
// TILE2-SAME: step %[[C2]]
// TILE2: scf.parallel (%[[ARG4:.*]]) =
// TILE2-SAME: step (%[[C4]])
// TILE2: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]], 0]
// TILE2: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], 0]
// TILE2: %[[SV3:.*]] = subview %{{.*}}[%[[ARG4]]]
// TILE2: linalg.generic
// TILE2-SAME: %[[SV1]], %[[SV2]], %[[SV3]]
25 changes: 24 additions & 1 deletion mlir/test/Dialect/Linalg/transform-patterns.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK-DAG: %[[c0:.*]] = constant 0 : index
// CHECK-DAG: %[[c5:.*]] = constant 5 : index
// CHECK-DAG: %[[c6:.*]] = constant 6 : index
// CHECK: scf.parallel {{.*}} step (%[[c5]], %[[c6]])
// CHECK: scf.parallel {{.*}} step (%[[c5]])
// CHECK: scf.for {{.*}} step %[[c6]]
// CHECK: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]>

func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
Expand Down Expand Up @@ -364,3 +365,25 @@ func @aligned_promote_fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
// CHECK: linalg.fill(%[[v0]], {{%.*}}) : memref<?x?xf32>, f32
// CHECK: linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
// CHECK: linalg.fill(%[[v0]], %[[cf]]) : memref<?x?xf32>, f32

func @tile_permute_parallel_loop(%arg0: memref<?x?xf32>,
%arg1: memref<?x?xf32>,
%arg2: memref<?x?xf32>) {
linalg.matmul(%arg0, %arg1, %arg2) {__internal_linalg_transform__ = "par__with_perm__"}
: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
return
}
// CHECK-LABEL: func @tile_permute_parallel_loop
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-DAG: %[[C16:.*]] = constant 16 : index
// CHECK-DAG: %[[C8:.*]] = constant 8 : index
// CHECK-DAG: %[[C4:.*]] = constant 4 : index
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[D0:.*]] = dim %[[ARG0]], 0
// CHECK-DAG: %[[D1:.*]] = dim %[[ARG0]], 1
// CHECK-DAG: %[[D2:.*]] = dim %[[ARG1]], 1
// CHECK: scf.parallel (%{{.*}}) = (%[[C0]]) to (%[[D2]]) step (%[[C8]])
// CHECK: scf.for %{{.*}} = %[[C0]] to %[[D1]] step %[[C4]]
// CHECK: scf.parallel (%{{.*}}) = (%[[C0]]) to (%[[D0]]) step (%[[C16]])
78 changes: 78 additions & 0 deletions mlir/test/lib/Transforms/TestLinalgTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ struct TestLinalgTransforms
"Test a fused pass that applies patterns from matmul to vectors via "
"2-d tiling"),
llvm::cl::init(false)};
Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
llvm::cl::desc("Test promotion options"),
llvm::cl::init(false)};
};
} // end anonymous namespace

Expand Down Expand Up @@ -101,6 +104,14 @@ static void applyPatterns(FuncOp funcOp) {
ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
LinalgMarker({"__with_perm__"}, "L1__with_perm__"));

patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx,
LinalgTilingOptions()
.setTileSizes({16, 8, 4})
.setInterchange({1, 2, 0})
.setLoopType(LinalgTilingLoopType::ParallelLoops),
LinalgMarker({"par__with_perm__"}, "after_par__with_perm__"));

//===--------------------------------------------------------------------===//
// Linalg to loops patterns.
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -189,10 +200,77 @@ static void fillL1TilingAndMatmulToVectorPatterns(
LinalgVectorizationPattern<CopyOp>>(context);
}

//===----------------------------------------------------------------------===//
// Test promotion callbacks
//===----------------------------------------------------------------------===//

// Allocation callback: allocates a fully-dynamic memref sized by the bounding
// subview, with the same element type as the promoted subview, in memory
// space 3 (hard-coded for this test).
static Optional<Value> allocCallBackFn(OpBuilder &b, SubViewOp subView,
ArrayRef<Value> boundingSubViewSize,
OperationFolder *folder) {
// All dimensions are dynamic (-1); the actual extents are supplied as the
// alloc's operands from `boundingSubViewSize`.
SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
return b
.create<AllocOp>(subView.getLoc(),
MemRefType::get(shape,
subView.getType().getElementType(),
/*affineMapComposition =*/{}, 3),
boundingSubViewSize)
.getResult();
}

// Deallocation callback: releases a buffer produced by allocCallBackFn.
// Always succeeds.
static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
b.create<DeallocOp>(buffer.getLoc(), buffer);
return success();
}

// Copy callback: copies `src` into `dst`. On the copy-in path (`isOutput` is
// false) the destination is first filled with the sentinel constant 42.0 so
// tests can detect uninitialized regions. Fails when the source element type
// is not a float type.
static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
bool isOutput) {
auto floatType = src.getType().cast<MemRefType>().getElementType();
if (!floatType.isa<FloatType>())
return failure();
if (!isOutput)
b.create<FillOp>(
src.getLoc(), dst,
b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0)));
b.create<CopyOp>(src.getLoc(), src, dst);
return success();
}

/// Populates `patterns` with a tile-then-promote pipeline for matmul that
/// exercises user-provided promotion callbacks: tiling applies to ops marked
/// "START" (re-marking them "PROMOTE"); promotion with the custom
/// alloc/dealloc/copy callbacks then applies to ops marked "PROMOTE".
/// Marked `static` for consistency with the other file-local helpers.
static void fillPromotionCallBackPatterns(MLIRContext *context,
                                          OwningRewritePatternList &patterns) {
  patterns.insert<LinalgTilingPattern<MatmulOp>>(
      context, LinalgTilingOptions().setTileSizes({16, 16, 16}),
      LinalgMarker({"START"}, "PROMOTE"));
  patterns.insert<LinalgPromotionPattern<MatmulOp>>(
      context,
      LinalgPromotionOptions()
          .setOperandsToPromote({0, 2})
          .setUseFullTileBuffers({false, false})
          .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
          .setCopyInOutFns(
              // Propagate the callback's LogicalResult instead of discarding
              // it and unconditionally returning success(), so a failing copy
              // correctly aborts the promotion.
              [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
                return copyCallBackFn(b, src, dst, /*isOutput=*/false);
              },
              [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
                return copyCallBackFn(b, src, dst, /*isOutput=*/true);
              }),
      LinalgMarker({"PROMOTE"}));
}

/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnFunction() {
if (testPatterns) {
applyPatterns(getFunction());
return;
}
if (testPromotionOptions) {
OwningRewritePatternList patterns;
fillPromotionCallBackPatterns(&getContext(), patterns);
applyPatternsAndFoldGreedily(getFunction(), patterns);
} else {
SmallVector<OwningRewritePatternList, 4> stage1Patterns;
if (testMatmulToVectorPatterns1dTiling) {
Expand Down