
Commit 470c798

Address suggestion from Diego, extend to linalg.unpack
1 parent 9006029 commit 470c798

3 files changed: 36 additions, 38 deletions

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 1 addition & 1 deletion
@@ -228,7 +228,7 @@ bool isLinearizableVector(VectorType type);
 /// Note: all read offsets are set to 0.
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                              ArrayRef<int64_t> inputVectorSizes,
-                             std::optional<Value> padValue,
+                             std::optional<Value> padValue = std::nullopt,
                              bool useInBoundsInsteadOfMasking = false,
                              ArrayRef<bool> inputScalableVecDims = {});
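
With padValue defaulted to std::nullopt, existing call sites keep compiling and new callers can simply omit the argument. A minimal usage sketch, assuming builder, loc, src, and sizes are in scope (these names are illustrative, not from the patch):

    // Equivalent calls: both ask the utility to supply its own pad value.
    // The test updates below show that fallback is ub.poison, not zero.
    Value r1 = vector::createReadOrMaskedRead(builder, loc, src, sizes,
                                              /*padValue=*/std::nullopt);
    Value r2 = vector::createReadOrMaskedRead(builder, loc, src, sizes);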

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 5 additions & 7 deletions
@@ -1770,7 +1770,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   rewriter.setInsertionPoint(packOp);

   Location loc = packOp.getLoc();
-  auto padValue = packOp.getPaddingValue();
+  std::optional<Value> padValue = packOp.getPaddingValue()
+                                      ? std::optional(packOp.getPaddingValue())
+                                      : std::nullopt;

   // If the input vector sizes are not provided, then the vector sizes are
   // determined by the result tensor shape. In case the vector sizes aren't
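
The ternary looks redundant but is deliberate: mlir::Value is a nullable handle, and packOp.getPaddingValue() returns a null Value when the op has no padding_value operand. A small sketch of the pitfall being avoided (toOptional is a hypothetical helper, not part of the patch):

    // An optional constructed directly from a null Value is *engaged* (it
    // holds a null handle), which is not the same as std::nullopt. The
    // explicit branch keeps "no padding value" disengaged.
    static std::optional<mlir::Value> toOptional(mlir::Value v) {
      return v ? std::optional(v) : std::nullopt;
    }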
@@ -1793,8 +1795,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   for (auto [idx, size] : enumerate(innerTiles))
     inputShape[innerDimsPos[idx]] *= size;
   auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, packOp.getSource(), inputShape,
-      padValue ? std::optional<Value>(padValue) : std::nullopt,
+      rewriter, loc, packOp.getSource(), inputShape, padValue,
       useInBoundsInsteadOfMasking,
       /*inputScalableVecSizes=*/{});

@@ -1932,11 +1933,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   }

   // -- Generate the read operation --
-  auto padValue = arith::ConstantOp::create(
-      rewriter, loc,
-      rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
   Value readResult = vector::createReadOrMaskedRead(
-      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
+      rewriter, loc, unpackOp.getSource(), readVectorSizes, std::nullopt,
       useInBoundsInsteadOfMasking, readScalableVectorFlags);

   // -- Generate the transpose operation --
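
Dropping the explicit zero constant is safe because the padded lanes of an unpack read are never consumed downstream; passing std::nullopt lets the utility pick its own pad value. A hedged sketch of that fallback (helper name and exact builder call are assumptions; the test updates below only confirm the emitted op is ub.poison):

    // Hypothetical fallback inside createReadOrMaskedRead when no pad value
    // is supplied: materialize poison, since padding lanes are never read.
    Value getPadValue(OpBuilder &builder, Location loc, Type elemType,
                      std::optional<Value> padValue) {
      if (padValue)
        return *padValue;
      return ub::PoisonOp::create(builder, loc, elemType); // assumed builder form
    }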

mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir

Lines changed: 30 additions & 30 deletions
@@ -1068,16 +1068,16 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
 // CHECK-SAME: %[[SRC:.*]]: tensor<?x?x16x2xf32>
 func.func @test_vectorize_dynamic_shapes_unpack_scalable_vec(%dest: tensor<?x?xf32>, %src: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
-// CHECK: %[[C01:.*]] = arith.constant 0
-// CHECK: %[[C02:.*]] = arith.constant 0
+// CHECK-DAG: %[[PAD:.*]] = ub.poison : f32
+// CHECK-DAG: %[[C01:.*]] = arith.constant 0
+// CHECK-DAG: %[[C02:.*]] = arith.constant 0
 // CHECK: %[[DIM4:.*]] = tensor.dim %[[SRC]], %[[C02]] : tensor<?x?x16x2xf32>
 // CHECK: %[[CNST14:.*]] = arith.constant 1
 // CHECK: %[[DIM6:.*]] = tensor.dim %[[SRC]], %[[CNST14]] : tensor<?x?x16x2xf32>
 // CHECK: %[[CNST16:.*]] = arith.constant 16 : index
 // CHECK: %[[CNST2:.*]] = arith.constant 2 : index
 // CHECK: %[[MASK_READ:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x[16]x2xi1>
-// CHECK: %[[READ:.*]] = vector.mask %[[MASK_READ]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x[16]x2xf32> } : vector<2x1x[16]x2xi1> -> vector<2x1x[16]x2xf32>
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK_READ]] {{.*}} vector.transfer_read %{{.*}} %[[PAD]] {{.*}}: tensor<?x?x16x2xf32>, vector<2x1x[16]x2xf32> } : vector<2x1x[16]x2xi1> -> vector<2x1x[16]x2xf32>
 // CHECK: %[[TR:.*]] = vector.transpose %[[READ]], [0, 3, 1, 2] : vector<2x1x[16]x2xf32> to vector<2x2x1x[16]xf32>
 // CHECK: %[[SC:.*]] = vector.shape_cast %[[TR]] : vector<2x2x1x[16]xf32> to vector<4x[16]xf32>
 // CHECK: %[[MASK_WRITE:.*]] = vector.create_mask {{.*}} : vector<4x[16]xi1>
@@ -1100,17 +1100,17 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
 // CHECK-SAME: %[[SRC:.*]]: tensor<?x?x?x2xf32>
 func.func @test_vectorize_dynamic_shapes_unpack_scalable_vec_and_tile_size(%dest: tensor<?x?xf32>, %src: tensor<?x?x?x2xf32>) -> tensor<?x?xf32> {
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
-// CHECK: %[[C01:.*]] = arith.constant 0
-// CHECK: %[[C02:.*]] = arith.constant 0
+// CHECK-DAG: %[[PAD:.*]] = ub.poison : f32
+// CHECK-DAG: %[[C01:.*]] = arith.constant 0
+// CHECK-DAG: %[[C02:.*]] = arith.constant 0
 // CHECK: %[[DIM4:.*]] = tensor.dim %[[SRC]], %[[C02]] : tensor<?x?x?x2xf32>
 // CHECK: %[[C1_2:.*]] = arith.constant 1
 // CHECK: %[[DIM6:.*]] = tensor.dim %[[SRC]], %[[C1_2]] : tensor<?x?x?x2xf32>
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[DIM_2:.*]] = tensor.dim %[[SRC]], %[[C2]] : tensor<?x?x?x2xf32>
 // CHECK: %[[C2_1:.*]] = arith.constant 2 : index
 // CHECK: %[[MASK_READ:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[DIM_2]], %[[C2_1]] : vector<2x1x[16]x2xi1>
-// CHECK: %[[READ:.*]] = vector.mask %[[MASK_READ]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x?x2xf32>, vector<2x1x[16]x2xf32> } : vector<2x1x[16]x2xi1> -> vector<2x1x[16]x2xf32>
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK_READ]] {{.*}} vector.transfer_read %{{.*}} %[[PAD]] {{.*}}: tensor<?x?x?x2xf32>, vector<2x1x[16]x2xf32> } : vector<2x1x[16]x2xi1> -> vector<2x1x[16]x2xf32>
 // CHECK: %[[TR:.*]] = vector.transpose %[[READ]], [0, 3, 1, 2] : vector<2x1x[16]x2xf32> to vector<2x2x1x[16]xf32>
 // CHECK: %[[SC:.*]] = vector.shape_cast %[[TR]] : vector<2x2x1x[16]xf32> to vector<4x[16]xf32>
 // CHECK: %[[MASK_WRITE:.*]] = vector.create_mask {{.*}} : vector<4x[16]xi1>
@@ -1138,14 +1138,14 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME: %[[SRC:.*]]: tensor<8x8x32x16xf32>
 // CHECK-SAME: %[[DEST:.*]]: tensor<256x128xf32>
 func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[C0:.*]]= arith.constant 0 : index
-// CHECK: %[[C8:.*]] = arith.constant 8 : index
-// CHECK: %[[C80:.*]] = arith.constant 8 : index
-// CHECK: %[[C32:.*]] = arith.constant 32 : index
-// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[PAD:.*]] = ub.poison : f32
+// CHECK-DAG: %[[C0:.*]]= arith.constant 0 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C80:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
 // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<16x8x32x16xi1>
-// CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] { vector.transfer_read %[[SRC]]{{.*}}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32>
+// CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] { vector.transfer_read %[[SRC]]{{.*}} %[[PAD]] {{.*}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32>
 // CHECK: %[[TRANSP0:.*]] = vector.transpose %[[READ0]], [0, 2, 1, 3] : vector<16x8x32x16xf32> to vector<16x32x8x16xf32>
 // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP0]] : vector<16x32x8x16xf32> to vector<512x128xf32>
 // CHECK: %[[C01:.*]] = arith.constant 0 : index
@@ -1171,9 +1171,9 @@ func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<2
 // CHECK-SAME: %[[SRC:.*]]: tensor<8x8x32x16xf32>
 // CHECK-SAME: %[[DEST:.*]]: tensor<256x128xf32>
 func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+// CHECK-DAG: %[[PAD:.*]] = ub.poison : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}} %[[PAD]] {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
 // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
 // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
 // CHECK: %[[C00:.*]] = arith.constant 0 : index
@@ -1196,9 +1196,9 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
 // CHECK-SAME: %[[SRC:.*]]: tensor<8x8x32x16xf32>
 // CHECK-SAME: %[[DEST:.*]]: tensor<256x128xf32>
 func.func @test_vectorize_unpack_with_outer_perm(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+// CHECK-DAG: %[[PAD:.*]] = ub.poison : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}} %[[PAD]] {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
 // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
 // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
 // CHECK: %[[C00:.*]] = arith.constant 0 : index
@@ -1221,9 +1221,9 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
 // CHECK-SAME: %[[SRC:.*]]: tensor<8x8x32x16xf32>
 // CHECK-SAME: %[[DEST:.*]]: tensor<256x128xf32>
 func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+// CHECK-DAG: %[[PAD:.*]] = ub.poison : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}} %[[PAD]] {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
 // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
 // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
 // CHECK: %[[C00:.*]] = arith.constant 0 : index
@@ -1246,9 +1246,9 @@ func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>,
 // CHECK-SAME: %[[SRC:.*]]: tensor<8x4x16x16xf32>
 // CHECK-SAME: %[[DEST:.*]]: tensor<64x127xf32>
 func.func @test_vectorize_unpack_no_vector_sizes_slice_output(%source: tensor<8x4x16x16xf32>, %dest: tensor<64x127xf32>) -> tensor<64x127xf32> {
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
+// CHECK-DAG: %[[PAD:.*]] = ub.poison : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}} %[[PAD]] {{.*}} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
 // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x4x16x16xf32> to vector<4x16x8x16xf32>
 // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<4x16x8x16xf32> to vector<64x128xf32>
 // CHECK: %[[C00:.*]] = arith.constant 0 : index
@@ -1275,9 +1275,9 @@ func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf
   %0 = linalg.unpack %source outer_dims_perm=[1, 0] inner_dims_pos = [1] inner_tiles = [4] into %dest : tensor<4x7x4xf32> -> tensor<7x16xf32>
   return %0 : tensor<7x16xf32>
 }
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}}} : tensor<4x7x4xf32>, vector<4x7x4xf32>
+// CHECK-DAG: %[[PAD:.*]] = ub.poison : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}} %[[PAD]] {{.*}} : tensor<4x7x4xf32>, vector<4x7x4xf32>
 // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 0, 2] : vector<4x7x4xf32> to vector<7x4x4xf32>
 // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<7x4x4xf32> to vector<7x16xf32>
 // CHECK: %[[C00:.*]] = arith.constant 0 : index
