85 changes: 46 additions & 39 deletions mlir/test/Transforms/buffer-placement-preparation.mlir
@@ -17,11 +17,12 @@ func @func_signature_conversion(%arg0: tensor<4x8xf32>) {
// CHECK-LABEL: func @memref_in_function_results
func @memref_in_function_results(%arg0: tensor<5xf32>, %arg1: memref<10xf32>) -> (tensor<5xf32>, memref<10xf32>, memref<15xf32>) {
%0 = alloc() : memref<15xf32>
%1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 {
%1 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
ins(%arg0 : tensor<5xf32>) {
^bb0(%gen1_arg0: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
}: tensor<5xf32> -> tensor<5xf32>
} -> tensor<5xf32>
return %1, %arg1, %0 : tensor<5xf32>, memref<10xf32>, memref<15xf32>
}
// CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>, %[[RESULT:.*]]: memref<5xf32>)
@@ -97,23 +98,25 @@ func @func_and_block_signature_conversion(%arg0 : tensor<2xf32>, %cond : i1, %ar

// CHECK-LABEL: func @compute_allocs_position_simple
func @compute_allocs_position_simple(%cond: i1, %arg0: tensor<2xf32>) -> tensor<2xf32>{
%0 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 {
%0 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
ins(%arg0 : tensor<2xf32>) {
^bb0(%gen1_arg0: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
}: tensor<2xf32> -> tensor<2xf32>
%1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %0 {
} -> tensor<2xf32>
%1 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
ins(%0 : tensor<2xf32>) {
^bb0(%gen2_arg0: f32):
%tmp2 = exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
}: tensor<2xf32> -> tensor<2xf32>
} -> tensor<2xf32>
return %1 : tensor<2xf32>
}
// CHECK: (%{{.*}}: {{.*}}, %[[ARG0:.*]]: memref<2xf32>,
// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc()
// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[FIRST_ALLOC]]
// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ARG0]]{{.*}} outs(%[[FIRST_ALLOC]]
// CHECK: %[[SECOND_ALLOC:.*]] = alloc()
// CHECK-NEXT: linalg.generic {{.*}} %[[FIRST_ALLOC]], %[[SECOND_ALLOC]]
// CHECK-NEXT: linalg.generic {{.*}} ins(%[[FIRST_ALLOC]]{{.*}} outs(%[[SECOND_ALLOC]]

// -----

@@ -123,78 +126,86 @@

// CHECK-LABEL: func @compute_allocs_position
func @compute_allocs_position(%cond: i1, %arg0: tensor<2xf32>) -> tensor<2xf32>{
%0 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 {
%0 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
ins(%arg0 : tensor<2xf32>) {
^bb0(%gen1_arg0: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
}: tensor<2xf32> -> tensor<2xf32>
%1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %0 {
} -> tensor<2xf32>
%1 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
ins(%0 : tensor<2xf32>) {
^bb0(%gen2_arg0: f32):
%tmp2 = exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
}: tensor<2xf32> -> tensor<2xf32>
} -> tensor<2xf32>
cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>),
^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>)
^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>):
%2 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 {
%2 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
ins(%arg0 : tensor<2xf32>) {
^bb0(%gen3_arg0: f32):
%tmp3 = exp %gen3_arg0 : f32
linalg.yield %tmp3 : f32
}: tensor<2xf32> -> tensor<2xf32>
%3 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %2 {
} -> tensor<2xf32>
%3 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
ins(%2 : tensor<2xf32>) {
^bb0(%gen4_arg0: f32):
%tmp4 = exp %gen4_arg0 : f32
linalg.yield %tmp4 : f32
}: tensor<2xf32> -> tensor<2xf32>
} -> tensor<2xf32>
br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>)
^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>):
%4 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 {
%4 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
ins(%arg0 : tensor<2xf32>) {
^bb0(%gen5_arg0: f32):
%tmp5 = exp %gen5_arg0 : f32
linalg.yield %tmp5 : f32
}: tensor<2xf32> -> tensor<2xf32>
%5 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %4 {
} -> tensor<2xf32>
%5 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
ins(%4 : tensor<2xf32>) {
^bb0(%gen6_arg0: f32):
%tmp6 = exp %gen6_arg0 : f32
linalg.yield %tmp6 : f32
}: tensor<2xf32> -> tensor<2xf32>
} -> tensor<2xf32>
br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>)
^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>):
%6 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 {
%6 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
ins(%arg0 : tensor<2xf32>) {
^bb0(%gen7_arg0: f32):
%tmp7 = exp %gen7_arg0 : f32
linalg.yield %tmp7 : f32
}: tensor<2xf32> -> tensor<2xf32>
%7 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %6 {
} -> tensor<2xf32>
%7 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
ins(%6 : tensor<2xf32>) {
^bb0(%gen8_arg0: f32):
%tmp8 = exp %gen8_arg0 : f32
linalg.yield %tmp8 : f32
}: tensor<2xf32> -> tensor<2xf32>
} -> tensor<2xf32>
return %7 : tensor<2xf32>
}
// CHECK: (%{{.*}}: {{.*}}, %[[ARG0:.*]]: memref<2xf32>,
// CHECK-NEXT: %[[ALLOC0:.*]] = alloc()
// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC0]]
// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ARG0]]{{.*}} outs(%[[ALLOC0]]
// CHECK: %[[ALLOC1:.*]] = alloc()
// CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC0]], %[[ALLOC1]]
// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ALLOC0]]{{.*}} outs(%[[ALLOC1]]
// CHECK: cond_br %{{.*}}, ^[[BB0:.*]]({{.*}}), ^[[BB1:.*]](
// CHECK-NEXT: ^[[BB0]]
// CHECK-NEXT: %[[ALLOC2:.*]] = alloc()
// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC2]]
// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ARG0]]{{.*}} outs(%[[ALLOC2]]
// CHECK: %[[ALLOC3:.*]] = alloc()
// CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC2]], %[[ALLOC3]]
// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ALLOC2]]{{.*}} outs(%[[ALLOC3]]
// CHECK: br ^[[EXIT:.*]]({{.*}})
// CHECK-NEXT: ^[[BB1]]
// CHECK-NEXT: %[[ALLOC4:.*]] = alloc()
// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC4]]
// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ARG0]]{{.*}} outs(%[[ALLOC4]]
// CHECK: %[[ALLOC5:.*]] = alloc()
// CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC4]], %[[ALLOC5]]
// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ALLOC4]]{{.*}} outs(%[[ALLOC5]]
// CHECK: br ^[[EXIT]]
// CHECK-NEXT: ^[[EXIT]]
// CHECK-NEXT: %[[ALLOC6:.*]] = alloc()
// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC6]]
// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ARG0]]{{.*}} outs(%[[ALLOC6]]
// CHECK: %[[ALLOC7:.*]] = alloc()
// CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC6]], %[[ALLOC7]]
// CHECK-NEXT: linalg.generic {{.*}} ins(%[[ALLOC6]]{{.*}} outs(%[[ALLOC7]]

// -----

@@ -211,16 +222,12 @@ func @compute_allocs_position(%cond: i1, %arg0: tensor<2xf32>) -> tensor<2xf32>{

// CHECK-LABEL: func @callee
func @callee(%arg1: tensor<5xf32>) -> tensor<5xf32> {
%0 = linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
indexing_maps = [#map0, #map0],
iterator_types = ["parallel"]
} %arg1 {
%0 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
ins(%arg1 : tensor<5xf32>) {
^bb0(%gen1_arg0: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
}: tensor<5xf32> -> tensor<5xf32>
} -> tensor<5xf32>
return %0 : tensor<5xf32>
}
// CHECK: (%[[CALLEE_ARG:.*]]: memref<5xf32>, %[[CALLEE_RESULT:.*]]: memref<5xf32>)
218 changes: 114 additions & 104 deletions mlir/test/Transforms/buffer-placement.mlir

Large diffs are not rendered by default.
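The hunks for this file are not shown; they apply the same linalg.generic syntax migration as the other test files in this patch (an inference from the surrounding hunks, not the rendered diff). Schematically, a tensor-typed generic op in the old attribute form

    %0 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
                         indexing_maps = [#map0, #map0],
                         iterator_types = ["parallel"]} %arg0 {
    ^bb0(%a: f32):
      %e = exp %a : f32
      linalg.yield %e : f32
    }: tensor<4xf32> -> tensor<4xf32>

becomes the ins()/outs() form

    %0 = linalg.generic {indexing_maps = [#map0, #map0],
                         iterator_types = ["parallel"]}
        ins(%arg0 : tensor<4xf32>) {
    ^bb0(%a: f32):
      %e = exp %a : f32
      linalg.yield %e : f32
    } -> tensor<4xf32>

with the CHECK lines updated from positional operand lists to ins(...)/outs(...) captures. The tensor<4xf32> type and value names here are illustrative only; the actual cases in buffer-placement.mlir may differ.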

34 changes: 17 additions & 17 deletions mlir/test/Transforms/copy-removal.mlir
@@ -157,14 +157,14 @@ func @test_with_temp_usage_after_copy() -> memref<5xf32> {
%temp = alloc() : memref<5xf32>
linalg.copy(%ret, %temp) : memref<5xf32>, memref<5xf32>
linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
indexing_maps = [#map0, #map0],
iterator_types = ["parallel"]} %temp, %res {
iterator_types = ["parallel"]}
ins(%temp : memref<5xf32>)
outs(%res : memref<5xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
}: memref<5xf32>, memref<5xf32>
}
dealloc %ret : memref<5xf32>
return %temp : memref<5xf32>
}
@@ -231,18 +231,18 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>, %result: memref<2xf32>)
// CHECK-NOT: %{{.*}} = alloc
%temp = alloc() : memref<2xf32>
// CHECK-NEXT: linalg.generic
// CHECK-SAME: %[[ARG0]], %[[RES]]
// CHECK-SAME: ins(%[[ARG0]]{{.*}}outs(%[[RES]]
// CHECK-NOT: linalg.copy(%{{.*}}, %[[RES]])
// CHECK-NOT: dealloc %{{.*}}
linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
indexing_maps = [#map0, #map0],
iterator_types = ["parallel"]} %arg0, %temp {
iterator_types = ["parallel"]}
ins(%arg0 : memref<2xf32>)
outs(%temp : memref<2xf32>) {
^bb0(%gen2_arg0: f32, %gen2_arg1: f32):
%tmp2 = exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
}: memref<2xf32>, memref<2xf32>
}
"linalg.copy"(%temp, %result) : (memref<2xf32>, memref<2xf32>) -> ()
dealloc %temp : memref<2xf32>
// CHECK: return
@@ -261,23 +261,23 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>){
%to = alloc() : memref<2xf32>
%temp = alloc() : memref<2xf32>
linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
indexing_maps = [#map0, #map0],
iterator_types = ["parallel"]} %arg0, %temp {
iterator_types = ["parallel"]}
ins(%arg0 : memref<2xf32>)
outs(%temp : memref<2xf32>) {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
}: memref<2xf32>, memref<2xf32>
}
linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
indexing_maps = [#map0, #map0],
iterator_types = ["parallel"]} %arg0, %to {
iterator_types = ["parallel"]}
ins(%arg0 : memref<2xf32>)
outs(%to : memref<2xf32>) {
^bb0(%gen2_arg0: f32, %gen2_arg1: f32):
%tmp2 = exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
}: memref<2xf32>, memref<2xf32>
}
// CHECK: linalg.copy
"linalg.copy"(%temp, %to) : (memref<2xf32>, memref<2xf32>) -> ()
dealloc %temp : memref<2xf32>
73 changes: 46 additions & 27 deletions mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -39,6 +39,11 @@ struct TestBufferPlacementPreparationPass

/// Converts tensor-type generic linalg operations to memref ones using
/// buffer assignment.
/// TODO: Avoid the copy-pasta by exposing the pattern from BufferPlacement.h
/// This is limited by not wanting BufferPlacement to depend on Linalg. Fixing
/// this probably requires an OpConversionPattern over generic Operation*. For
/// now only RewritePattern, but not ConversionPattern, allows this.
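///
/// As a rough illustration (schematic IR; the types and value names are
/// placeholders, simplified from the tests exercising this pass), the
/// conversion turns
///
///   %res = linalg.generic {...} ins(%in : tensor<4xf32>) { ... } -> tensor<4xf32>
///
/// into
///
///   %buf = alloc() : memref<4xf32>
///   linalg.generic {...} ins(%in : memref<4xf32>)
///                        outs(%buf : memref<4xf32>) { ... }
///
/// and replaces all uses of %res with the newly allocated %buf.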

class GenericOpConverter
: public BufferAssignmentOpConversionPattern<linalg::GenericOp> {
public:
@@ -48,58 +53,72 @@ struct TestBufferPlacementPreparationPass
LogicalResult
matchAndRewrite(linalg::GenericOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
linalg::GenericOpAdaptor adaptor(operands,
op.getOperation()->getAttrDictionary());

// TODO: support ops with reduction.
if (!op.init_tensors().empty())
return failure();

// All inputs need to be turned into buffers first. Until then, bail out.
if (llvm::any_of(adaptor.inputs(), [](Value in) {
return !in.getType().isa<MemRefType>();
}))
return failure();

Location loc = op.getLoc();
ResultRange results = op.getOperation()->getResults();
SmallVector<Value, 2> newArgs, newResults;
newArgs.reserve(operands.size() + results.size());
newArgs.append(operands.begin(), operands.end());
newResults.reserve(results.size());
SmallVector<Value, 2> outputBuffers, newOutputBuffers;
outputBuffers.assign(adaptor.output_buffers().begin(),
adaptor.output_buffers().end());
newOutputBuffers.reserve(op.getNumOutputs());
newOutputBuffers.append(adaptor.output_buffers().begin(),
adaptor.output_buffers().end());

// Update all types to memref types.
for (auto result : results) {
ShapedType type = result.getType().cast<ShapedType>();
assert(type && "Generic operations with non-shaped typed results are "
"not currently supported.");
for (Type t : op.getResultTypes()) {
auto type = t.cast<ShapedType>();
if (!type.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "dynamic shapes not currently supported");
auto memrefType =
MemRefType::get(type.getShape(), type.getElementType());
auto alloc = rewriter.create<AllocOp>(loc, memrefType);
newArgs.push_back(alloc);
newResults.push_back(alloc);
newOutputBuffers.push_back(alloc);
}

// Generate a new linalg operation that works on buffers.
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()),
rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(),
op.iterator_types(), op.docAttr(), op.library_callAttr(),
op.symbol_sourceAttr());
loc,
/*resultTensorTypes=*/ArrayRef<Type>{},
/*inputs=*/adaptor.inputs(),
/*outputBuffers=*/newOutputBuffers,
/*initTensors=*/ValueRange{}, op.indexing_maps(), op.iterator_types(),
op.docAttr(), op.library_callAttr(), op.symbol_sourceAttr());

// Create a new block in the region of the new Generic Op.
Block &oldBlock = op.getRegion().front();
Region &newRegion = linalgOp.region();
Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
oldBlock.getArgumentTypes());

// Map the old block arguments to the new ones.
BlockAndValueMapping mapping;
mapping.map(oldBlock.getArguments(), newBlock->getArguments());

// Add the result arguments to the new block.
for (auto result : newResults)
newBlock->addArgument(
result.getType().cast<ShapedType>().getElementType());
for (Value v : newOutputBuffers)
newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());

// Clone the body of the old block to the new block.
BlockAndValueMapping mapping;
for (unsigned i = 0; i < oldBlock.getNumArguments(); i++)
mapping.map(oldBlock.getArgument(i), newBlock->getArgument(i));

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(newBlock);
for (auto &op : oldBlock.getOperations())
rewriter.clone(op, mapping);
for (auto &op : oldBlock.getOperations()) {
Operation *clonedOp = rewriter.clone(op, mapping);
mapping.map(op.getResults(), clonedOp->getResults());
}

// Replace the results of the old Generic Op with the results of the new
// one.
rewriter.replaceOp(op, newResults);
// Replace the results of the old op with the new output buffers.
rewriter.replaceOp(op, newOutputBuffers);
return success();
}
};
2 changes: 1 addition & 1 deletion mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1452,7 +1452,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyMemRef>:$output_buffers,
Variadic<AnyRankedTensor>:$init_tensors);
let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
let builders = [ OpBuilder<