Skip to content

Commit 628f5c9

Browse files
committed
[mlir] Add a roundtrip test for 'linalg.tiled_loop' on buffers.
https://llvm.discourse.group/t/rfc-add-linalg-tileop/2833 Differential Revision: https://reviews.llvm.org/D98900
1 parent 74ffe8d commit 628f5c9

File tree

2 files changed

+63
-1
lines changed

2 files changed

+63
-1
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1863,7 +1863,9 @@ static ParseResult parseTiledLoopOp(OpAsmParser &parser,
18631863
if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc,
18641864
result.operands))
18651865
return failure();
1866-
result.addTypes(outputTypes);
1866+
for (Type outputType : outputTypes)
1867+
if (outputType.isa<RankedTensorType>())
1868+
result.addTypes(outputType);
18671869
}
18681870

18691871
// Parse attributes.

mlir/test/Dialect/Linalg/roundtrip.mlir

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
// Test that we can lower all the way to LLVM without crashing, don't check results here.
77
// DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
88

9+
// CHECK-DAG: #[[$id_2d:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
10+
// CHECK-DAG: #[[$id_1d:.*]] = affine_map<(d0, d1, d2) -> (d1)>
911
// CHECK-DAG: #[[$permute_0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
1012
// CHECK-DAG: #[[$permute_1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
1113
// CHECK-DAG: #[[$reshape5D01:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
@@ -881,3 +883,61 @@ func @tiled_loop_reduction(%input_3d: tensor<16x24x32xf32>,
881883
}
882884
// CHECK-LABEL: func @tiled_loop_reduction
883885
// CHECK: iterators[
886+
887+
// -----
888+
889+
#trait_6 = {
890+
indexing_maps = [
891+
#id_3d,
892+
#id_2d,
893+
#id_1d,
894+
#id_1d
895+
],
896+
iterator_types = ["reduction", "parallel", "reduction"]
897+
}
898+
#map_1 = affine_map<(d0, d1, d2)[s0] -> (d0 * 768 + s0 + d1 * 32 + d2)>
899+
#map_2 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>
900+
#map_3 = affine_map<(d0)[s0] -> (d0 + s0)>
901+
902+
func @tiled_loop_on_buffers(%input_3d: memref<16x24x32xf32>,
903+
%input_2d: memref<16x32xf32>,
904+
%input_1d: memref<24xf32>,
905+
%output: memref<24xf32>) {
906+
%c0 = constant 0 : index
907+
%c1 = constant 1 : index
908+
%c2 = constant 2 : index
909+
%c4 = constant 4 : index
910+
%c8 = constant 8 : index
911+
%X = memref.dim %input_3d, %c0 : memref<16x24x32xf32>
912+
%Y = memref.dim %input_3d, %c1 : memref<16x24x32xf32>
913+
%Z = memref.dim %input_3d, %c2 : memref<16x24x32xf32>
914+
linalg.tiled_loop (%i, %j, %k) = (%c0, %c0, %c0)
915+
to (%X, %Y, %Z) step (%c2, %c4, %c8)
916+
ins(%input_3d, %input_2d: memref<16x24x32xf32>, memref<16x32xf32>)
917+
outs( %output: memref<24xf32>)
918+
iterators["reduction", "parallel", "reduction"] {
919+
%sub_3d = memref.subview %input_3d[%i, %j, %k][2, 4, 8][1, 1, 1]
920+
: memref<16x24x32xf32> to memref<2x4x8xf32, #map_1>
921+
%sub_2d = memref.subview %input_2d[%i, %k][2, 8][1, 1]
922+
: memref<16x32xf32> to memref<2x8xf32, #map_2>
923+
%sub_1d = memref.subview %input_1d[%j] [4] [1]
924+
: memref<24xf32> to memref<4xf32, #map_3>
925+
%sub_out = memref.subview %output[%j] [4] [1]
926+
: memref<24xf32> to memref<4xf32, #map_3>
927+
linalg.generic #trait_6
928+
ins(%sub_3d, %sub_2d, %sub_1d
929+
: memref<2x4x8xf32, #map_1>,
930+
memref<2x8xf32, #map_2>,
931+
memref<4xf32, #map_3>)
932+
outs(%sub_out : memref<4xf32, #map_3>) {
933+
^bb0(%i3d: f32, %i2d: f32, %i1d: f32, %o: f32):
934+
%0 = addf %i3d, %i2d : f32
935+
%1 = addf %0, %i1d : f32
936+
linalg.yield %1 : f32
937+
}
938+
linalg.yield
939+
}
940+
return
941+
}
942+
// CHECK-LABEL: func @tiled_loop_on_buffers
943+
// CHECK: iterators[

0 commit comments

Comments
 (0)