11// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
22
3+ ///----------------------------------------------------------------------------------------
4+ /// Tests for linalg.dot
5+ ///----------------------------------------------------------------------------------------
6+
37// CHECK-LABEL: contraction_dot
48func.func @contraction_dot (%A: memref <1584 xf32 >, %B: memref <1584 xf32 >, %C: memref <f32 >) {
59
@@ -20,6 +24,10 @@ module attributes {transform.with_named_sequence} {
2024
2125// -----
2226
27+ ///----------------------------------------------------------------------------------------
28+ /// Tests for linalg.matvec
29+ ///----------------------------------------------------------------------------------------
30+
2331// CHECK-LABEL: contraction_matvec
2432func.func @contraction_matvec (%A: memref <1584 x1584 xf32 >, %B: memref <1584 xf32 >, %C: memref <1584 xf32 >) {
2533
@@ -41,6 +49,10 @@ module attributes {transform.with_named_sequence} {
4149
4250// -----
4351
52+ ///----------------------------------------------------------------------------------------
53+ /// Tests for linalg.matmul
54+ ///----------------------------------------------------------------------------------------
55+
4456// CHECK-LABEL: contraction_matmul
4557func.func @contraction_matmul (%A: memref <1584 x1584 xf32 >, %B: memref <1584 x1584 xf32 >, %C: memref <1584 x1584 xf32 >) {
4658// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584xf32>
@@ -138,6 +150,10 @@ module attributes {transform.with_named_sequence} {
138150
139151// -----
140152
153+ ///----------------------------------------------------------------------------------------
154+ /// Tests for linalg.batch_matmul
155+ ///----------------------------------------------------------------------------------------
156+
141157// CHECK-LABEL: contraction_batch_matmul
142158func.func @contraction_batch_matmul (%A: memref <1584 x1584 x1584 xf32 >, %B: memref <1584 x1584 x1584 xf32 >, %C: memref <1584 x1584 x1584 xf32 >) {
143159// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32>
@@ -159,6 +175,10 @@ module attributes {transform.with_named_sequence} {
159175
160176// -----
161177
178+ ///----------------------------------------------------------------------------------------
179+ /// Tests for linalg.cantract
180+ ///----------------------------------------------------------------------------------------
181+
162182// CHECK-LABEL: @matmul_as_contract
163183// CHECK-SAME: %[[A:.*]]: tensor<24x12xf32>
164184// CHECK-SAME: %[[B:.*]]: tensor<12x25xf32>
@@ -220,6 +240,10 @@ module attributes {transform.with_named_sequence} {
220240
221241// -----
222242
243+ ///----------------------------------------------------------------------------------------
244+ /// Tests for linalg.fill
245+ ///----------------------------------------------------------------------------------------
246+
223247// CHECK-LABEL: func @test_vectorize_fill
224248func.func @test_vectorize_fill (%A : memref <8 x16 xf32 >, %arg0 : f32 ) {
225249 // CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
@@ -259,70 +283,14 @@ module attributes {transform.with_named_sequence} {
259283
260284// -----
261285
262- // CHECK-LABEL: func @test_vectorize_copy
263- func.func @test_vectorize_copy (%A : memref <8 x16 xf32 >, %B : memref <8 x16 xf32 >) {
264- // CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
265- // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
266- memref.copy %A , %B : memref <8 x16 xf32 > to memref <8 x16 xf32 >
267- return
268- }
286+ ///----------------------------------------------------------------------------------------
287+ /// Tests for linalg.pack
288+ ///----------------------------------------------------------------------------------------
269289
270- module attributes {transform.with_named_sequence } {
271- transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
272- %0 = transform.structured.match ops {[" memref.copy" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
273- %1 = transform.get_parent_op %0 {isolated_from_above } : (!transform.any_op ) -> !transform.any_op
274- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op ) -> !transform.any_op
275- transform.yield
276- }
277- }
290+ // Note, see a similar test in:
291+ // * vectorization.mlir.
278292
279- // -----
280-
281- // CHECK-LABEL: func @test_vectorize_copy_0d
282- func.func @test_vectorize_copy_0d (%A : memref <f32 >, %B : memref <f32 >) {
283- // CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
284- // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32>
285- // CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32>
286- // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
287- // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32>
288- memref.copy %A , %B : memref <f32 > to memref <f32 >
289- return
290- }
291-
292- module attributes {transform.with_named_sequence } {
293- transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
294- %0 = transform.structured.match ops {[" memref.copy" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
295- %1 = transform.get_parent_op %0 {isolated_from_above } : (!transform.any_op ) -> !transform.any_op
296- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op ) -> !transform.any_op
297- transform.yield
298- }
299- }
300-
301- // -----
302-
303- // CHECK-LABEL: func @test_vectorize_copy_complex
304- // CHECK-NOT: vector<
305- func.func @test_vectorize_copy_complex (%A : memref <8 x16 xcomplex <f32 >>, %B : memref <8 x16 xcomplex <f32 >>) {
306- memref.copy %A , %B : memref <8 x16 xcomplex <f32 >> to memref <8 x16 xcomplex <f32 >>
307- return
308- }
309-
310- module attributes {transform.with_named_sequence } {
311- transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
312- %0 = transform.structured.match ops {[" memref.copy" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
313- %1 = transform.get_parent_op %0 {isolated_from_above } : (!transform.any_op ) -> !transform.any_op
314- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op ) -> !transform.any_op
315- transform.yield
316- }
317- }
318-
319- // -----
320-
321- // Input identical as the test in vectorization.mlir. Output is different -
322- // vector sizes are inferred (rather than user-specified) and hence _no_
323- // masking was used.
324-
325- func.func @test_vectorize_pack (%arg0: tensor <32 x8 x16 xf32 >, %arg1: tensor <4 x1 x32 x16 x2 xf32 >) -> tensor <4 x1 x32 x16 x2 xf32 > {
293+ func.func @pack_no_padding (%arg0: tensor <32 x8 x16 xf32 >, %arg1: tensor <4 x1 x32 x16 x2 xf32 >) -> tensor <4 x1 x32 x16 x2 xf32 > {
326294 %pack = linalg.pack %arg0 outer_dims_perm = [1 , 2 , 0 ] inner_dims_pos = [2 , 1 ] inner_tiles = [16 , 2 ] into %arg1 : tensor <32 x8 x16 xf32 > -> tensor <4 x1 x32 x16 x2 xf32 >
327295 return %pack : tensor <4 x1 x32 x16 x2 xf32 >
328296}
@@ -336,7 +304,7 @@ module attributes {transform.with_named_sequence} {
336304 }
337305}
338306
339- // CHECK-LABEL: func.func @test_vectorize_pack (
307+ // CHECK-LABEL: func.func @pack_no_padding (
340308// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x8x16xf32>,
341309// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
342310// CHECK-DAG: %[[VAL_2:.*]] = ub.poison : f32
@@ -349,13 +317,16 @@ module attributes {transform.with_named_sequence} {
349317
350318// -----
351319
352- func.func @test_vectorize_padded_pack (%arg0: tensor <32 x7 x15 xf32 >, %arg1: tensor <32 x4 x1 x16 x2 xf32 >) -> tensor <32 x4 x1 x16 x2 xf32 > {
320+ // Note, see a similar test in:
321+ // * vectorization.mlir.
322+
323+ func.func @pack_with_padding (%arg0: tensor <32 x7 x15 xf32 >, %arg1: tensor <32 x4 x1 x16 x2 xf32 >) -> tensor <32 x4 x1 x16 x2 xf32 > {
353324 %pad = arith.constant 0.000000e+00 : f32
354325 %pack = linalg.pack %arg0 padding_value (%pad : f32 ) inner_dims_pos = [2 , 1 ] inner_tiles = [16 , 2 ] into %arg1 : tensor <32 x7 x15 xf32 > -> tensor <32 x4 x1 x16 x2 xf32 >
355326 return %pack : tensor <32 x4 x1 x16 x2 xf32 >
356327}
357328
358- // CHECK-LABEL: func.func @test_vectorize_padded_pack (
329+ // CHECK-LABEL: func.func @pack_with_padding (
359330// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x7x15xf32>,
360331// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
361332// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
@@ -377,6 +348,10 @@ module attributes {transform.with_named_sequence} {
377348
378349// -----
379350
351+ ///----------------------------------------------------------------------------------------
352+ /// Tests for linalg.map
353+ ///----------------------------------------------------------------------------------------
354+
380355func.func @vectorize_map (%arg0: memref <64 xf32 >,
381356 %arg1: memref <64 xf32 >, %arg2: memref <64 xf32 >) {
382357 linalg.map ins (%arg0 , %arg1 : memref <64 xf32 >, memref <64 xf32 >)
@@ -403,6 +378,10 @@ module attributes {transform.with_named_sequence} {
403378
404379// -----
405380
381+ ///----------------------------------------------------------------------------------------
382+ /// Tests for linalg.transpose
383+ ///----------------------------------------------------------------------------------------
384+
406385func.func @vectorize_transpose (%arg0: memref <16 x32 x64 xf32 >,
407386 %arg1: memref <32 x64 x16 xf32 >) {
408387 linalg.transpose ins (%arg0 : memref <16 x32 x64 xf32 >)
@@ -424,6 +403,10 @@ module attributes {transform.with_named_sequence} {
424403
425404// -----
426405
406+ ///----------------------------------------------------------------------------------------
407+ /// Tests for linalg.reduce
408+ ///----------------------------------------------------------------------------------------
409+
427410func.func @vectorize_reduce (%arg0: memref <16 x32 x64 xf32 >,
428411 %arg1: memref <16 x64 xf32 >) {
429412 linalg.reduce ins (%arg0 : memref <16 x32 x64 xf32 >)
@@ -449,6 +432,10 @@ module attributes {transform.with_named_sequence} {
449432
450433// -----
451434
435+ ///----------------------------------------------------------------------------------------
436+ /// Tests for linalg.generic
437+ ///----------------------------------------------------------------------------------------
438+
452439#matmul_trait = {
453440 indexing_maps = [
454441 affine_map <(m , n , k ) -> (m , k )>,
@@ -1446,6 +1433,8 @@ module attributes {transform.with_named_sequence} {
14461433
14471434// -----
14481435
1436+ // TODO: Two Linalg Ops in one tests - either split or document "why".
1437+
14491438// CHECK-DAG: #[[$M6:.*]] = affine_map<(d0, d1) -> (d0, 0)>
14501439
14511440// CHECK-LABEL: func @fused_broadcast_red_2d
@@ -1896,3 +1885,65 @@ module attributes {transform.with_named_sequence} {
18961885 }
18971886}
18981887
1888+ // -----
1889+
1890+ ///----------------------------------------------------------------------------------------
1891+ /// Tests for memref.copy
1892+ ///----------------------------------------------------------------------------------------
1893+
1894+ // CHECK-LABEL: func @test_vectorize_copy
1895+ func.func @test_vectorize_copy (%A : memref <8 x16 xf32 >, %B : memref <8 x16 xf32 >) {
1896+ // CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
1897+ // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
1898+ memref.copy %A , %B : memref <8 x16 xf32 > to memref <8 x16 xf32 >
1899+ return
1900+ }
1901+
1902+ module attributes {transform.with_named_sequence } {
1903+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
1904+ %0 = transform.structured.match ops {[" memref.copy" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
1905+ %1 = transform.get_parent_op %0 {isolated_from_above } : (!transform.any_op ) -> !transform.any_op
1906+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op ) -> !transform.any_op
1907+ transform.yield
1908+ }
1909+ }
1910+
1911+ // -----
1912+
1913+ // CHECK-LABEL: func @test_vectorize_copy_0d
1914+ func.func @test_vectorize_copy_0d (%A : memref <f32 >, %B : memref <f32 >) {
1915+ // CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
1916+ // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32>
1917+ // CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32>
1918+ // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
1919+ // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32>
1920+ memref.copy %A , %B : memref <f32 > to memref <f32 >
1921+ return
1922+ }
1923+
1924+ module attributes {transform.with_named_sequence } {
1925+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
1926+ %0 = transform.structured.match ops {[" memref.copy" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
1927+ %1 = transform.get_parent_op %0 {isolated_from_above } : (!transform.any_op ) -> !transform.any_op
1928+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op ) -> !transform.any_op
1929+ transform.yield
1930+ }
1931+ }
1932+
1933+ // -----
1934+
1935+ // CHECK-LABEL: func @test_vectorize_copy_complex
1936+ // CHECK-NOT: vector<
1937+ func.func @test_vectorize_copy_complex (%A : memref <8 x16 xcomplex <f32 >>, %B : memref <8 x16 xcomplex <f32 >>) {
1938+ memref.copy %A , %B : memref <8 x16 xcomplex <f32 >> to memref <8 x16 xcomplex <f32 >>
1939+ return
1940+ }
1941+
1942+ module attributes {transform.with_named_sequence } {
1943+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
1944+ %0 = transform.structured.match ops {[" memref.copy" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
1945+ %1 = transform.get_parent_op %0 {isolated_from_above } : (!transform.any_op ) -> !transform.any_op
1946+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op ) -> !transform.any_op
1947+ transform.yield
1948+ }
1949+ }
0 commit comments