@@ -100,6 +100,18 @@ func.func @scalable_gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x[3]xin
100100 return %0 : vector<2x[3]xf32>
101101}
102102
103+ // CHECK-LABEL: @scalable_gather_memref_2d_with_alignment
104+ // CHECK: vector.gather
105+ // CHECK-SAME: {alignment = 8 : i64}
106+ // CHECK: vector.gather
107+ // CHECK-SAME: {alignment = 8 : i64}
108+ func.func @scalable_gather_memref_2d_with_alignment(%base: memref<?x?xf32>, %v: vector<2x[3]xindex>, %mask: vector<2x[3]xi1>, %pass_thru: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
109+ %c0 = arith.constant 0 : index
110+ %c1 = arith.constant 1 : index
111+ %0 = vector.gather %base[%c0, %c1] [%v], %mask, %pass_thru {alignment = 8} : memref<?x?xf32>, vector<2x[3]xindex>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
112+ return %0 : vector<2x[3]xf32>
113+ }
114+
103115// CHECK-LABEL: @scalable_gather_cant_unroll
104116// CHECK-NOT: extract
105117// CHECK: vector.gather
@@ -234,7 +246,7 @@ func.func @strided_gather(%base : memref<100x3xf32>,
234246 %mask = arith.constant dense<true> : vector<4xi1>
235247 %pass_thru = arith.constant dense<0.000000e+00> : vector<4xf32>
236248 // Gather of a strided MemRef
237- %res = vector.gather %subview[%c0] [%idxs], %mask, %pass_thru : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
249+ %res = vector.gather %subview[%c0] [%idxs], %mask, %pass_thru {alignment = 8} : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
238250 return %res : vector<4xf32>
239251}
240252// CHECK-LABEL: func.func @strided_gather(
@@ -250,22 +262,22 @@ func.func @strided_gather(%base : memref<100x3xf32>,
250262
251263// CHECK: %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
252264// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
253- // CHECK: %[[M_0:.*]] = vector.load %[[COLLAPSED]][%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
265+ // CHECK: %[[M_0:.*]] = vector.load %[[COLLAPSED]][%[[IDX_0]]] {alignment = 8 : i64} : memref<300xf32>, vector<1xf32>
254266// CHECK: %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
255267
256268// CHECK: %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
257269// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
258- // CHECK: %[[M_1:.*]] = vector.load %[[COLLAPSED]][%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
270+ // CHECK: %[[M_1:.*]] = vector.load %[[COLLAPSED]][%[[IDX_1]]] {alignment = 8 : i64} : memref<300xf32>, vector<1xf32>
259271// CHECK: %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>
260272
261273// CHECK: %[[IDX_2:.*]] = vector.extract %[[NEW_IDXS]][2] : index from vector<4xindex>
262274// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
263- // CHECK: %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] : memref<300xf32>, vector<1xf32>
275+ // CHECK: %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] {alignment = 8 : i64} : memref<300xf32>, vector<1xf32>
264276// CHECK: %[[V_2:.*]] = vector.extract %[[M_2]][0] : f32 from vector<1xf32>
265277
266278// CHECK: %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
267279// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
268- // CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
280+ // CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] {alignment = 8 : i64} : memref<300xf32>, vector<1xf32>
269281// CHECK: %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>
270282
271283// CHECK-LABEL: @scalable_gather_1d
0 commit comments