@@ -27,8 +27,9 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
// LOAD-GATHER-COUNT-2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
- // LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
- // LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+ // LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+ // LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+ // LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>

}

@@ -62,8 +63,9 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
// LOAD-GATHER-COUNT-2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}}: vector<8x16xindex>
- // LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
- // LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+ // LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+ // LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+ // LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>

}

@@ -124,8 +126,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
// LOAD-GATHER-COUNT-2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
- // LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32>
- // LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+ // LOAD-GATHER: %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf32> -> index
+ // LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+ // LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>

}

@@ -164,8 +167,9 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// LOAD-GATHER-COUNT-2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[BROADIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER: %[[FINALIDX:.+]] = arith.addi %[[BROADIDX]], {{.*}} : vector<8x16xindex>
- // LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
- // LOAD-GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE]][%[[FINALIDX]]], %[[CST]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+ // LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<?x?x?xf32> -> index
+ // LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+ // LOAD-GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[FINALIDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
// LOAD-GATHER: gpu.return %[[RES]] : vector<8x16xf32>
}

@@ -195,8 +199,9 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
// LOAD-GATHER-COUNT-2: arith.addi {{.*}} : index
// LOAD-GATHER-DAG: %[[BCASTIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER-DAG: %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], {{.*}} : vector<8x16xindex>
- // LOAD-GATHER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
- // LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+ // LOAD-GATHER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<?x8x16xf32> -> index
+ // LOAD-GATHER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+ // LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>

}

@@ -224,8 +229,9 @@ gpu.func @load_dynamic_source3(%source: memref<?x?x?x?x?xf32>,
// LOAD-GATHER-COUNT-3: arith.addi {{.*}} : vector<2x4x8x16xindex>
// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x4x8x16xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x4x8x16xindex>
- // LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2, 3, 4]{{\]}} : memref<?x?x?x?x?xf32> into memref<?xf32>
- // LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<?xf32>, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
+ // LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?x?x?xf32> -> index
+ // LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+ // LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
// LOAD-GATHER: return %[[VEC]]
}

@@ -254,8 +260,9 @@ gpu.func @load_high_dim_vector(%source: memref<16x32x64xf32>,
// LOAD-GATHER-COUNT-2: arith.addi {{.*}} : vector<8x16x32xindex>
// LOAD-GATHER: %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
- // LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
- // LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
+ // LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<16x32x64xf32> -> index
+ // LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+ // LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>

}

@@ -283,8 +290,9 @@ gpu.func @load_transpose_f16(%source: memref<32x64xf16>,
// LOAD-GATHER-COUNT-2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
- // LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf16> into memref<2048xf16>
- // LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf16>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
+ // LOAD-GATHER: %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf16> -> index
+ // LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+ // LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
}

// -----
@@ -396,3 +404,40 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
// LOAD-GATHER: vector.transfer_read
}

+ // -----
+ gpu.module @xevm_module {
+ gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
+   %c0 = arith.constant 0.0 : f16
+   %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+   %0 = vector.transfer_read %subview[%off2, %off2], %c0
+     {in_bounds = [true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8xf16>
+   gpu.return %0 : vector<8xf16>
+ }
+
+ // LOAD-ND-LABEL: @load_from_subview(
+ // LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
+ // LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+ // LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+ // LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
+ // LOAD-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
+ // LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+ // LOAD-ND-SAME: boundary_check = false
+ // LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf16>
+ // LOAD-ND: return %[[VEC]]
+
+ // LOAD-GATHER-LABEL: @load_from_subview(
+ // LOAD-GATHER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
+ // LOAD-GATHER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+ // LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
+ // LOAD-GATHER: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+ // LOAD-GATHER: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+ // LOAD-GATHER: %[[STEP:.+]] = vector.step : vector<8xindex>
+ // LOAD-GATHER: arith.muli {{.*}} : index
+ // LOAD-GATHER: arith.addi %[[OFFSET]]{{.*}} : index
+ // LOAD-GATHER: arith.addi {{.*}} : index
+ // LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+ // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
+ // LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+ // LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+ // LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
+ }
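
Taken together, the hunks above apply one mechanical change to the gather path: instead of collapsing the source memref to 1-D and gathering from the collapsed memref, the lowering now extracts the source's aligned base pointer and gathers through it as a raw i64 address. A minimal before/after sketch of the emitted IR for the 1-D case in the first hunk (the %src, %indices, and %mask names are illustrative, not taken from the test):

    // Old lowering: flatten the memref, then gather from the flat memref.
    %flat = memref.collapse_shape %src [[0, 1, 2]]
        : memref<8x16x32xf32> into memref<4096xf32>
    %old = xegpu.load %flat[%indices], %mask
        : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>

    // New lowering: gather directly through the aligned base pointer.
    %base = memref.extract_aligned_pointer_as_index %src
        : memref<8x16x32xf32> -> index
    %addr = arith.index_cast %base : index to i64
    %new = xegpu.load %addr[%indices], %mask
        : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>

Because the aligned pointer does not include a memref's dynamic offset, the new subview test additionally checks that the offset produced by memref.extract_strided_metadata is folded into the index vector (the arith.addi on %[[OFFSET]]) before the load.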