Commit 40e85fc

[MLIR][XeGPU][VectorToXeGPU] Fix transfer_read/write cases with non-contiguous memrefs (#158126)
This PR fixes a case where the source memref of a `vector.transfer_read/write` is not contiguous, which violates the semantics of the `memref.collapse_shape` op used in the lowering.

<details><summary>An example of a failing test</summary>

```mlir
gpu.module @xevm_module {
gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
  %c0 = arith.constant 0.0 : f16
  %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
  %0 = vector.transfer_read %subview[%off2, %off2], %c0 {in_bounds = [true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8xf16>
  gpu.return %0 : vector<8xf16>
}
}
```

Fails with:

```
/home/user/llvm/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir:404:8: error: 'memref.collapse_shape' op invalid source layout map or collapsing non-contiguous dims
%0 = vector.transfer_read %subview[%off2, %off2], %c0
     ^
/home/user/llvm/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir:404:8: note: see current operation: %8 = "memref.collapse_shape"(%2) <{reassociation = [[0, 1]]}> : (memref<256x256xf16, strided<[4096, 1], offset: ?>>) -> memref<65536xf16>
```
</details>

The suggested fix, implemented in this PR, is to replace `memref.collapse_shape` with `memref.extract_aligned_pointer_as_index`. Since `extract_aligned_pointer` applied to a subview returns the original pointer without the subview's offset, this PR also adds logic to fold the offset obtained from `memref.extract_strided_metadata` into the `baseOffset` calculation in `computeOffsets`.

---------

Signed-off-by: dchigarev <dmitry.chigarev@intel.com>
1 parent 2740e4b commit 40e85fc
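In short, the gather lowering no longer flattens the memref at all: the subview's element offset comes from `memref.extract_strided_metadata`, and the base pointer comes from `memref.extract_aligned_pointer_as_index`, cast to `i64` for `xegpu.load`. A minimal sketch of the IR shape this produces for the failing example above (value names are illustrative, not actual pass output):

```mlir
// Offset of the subview within its allocation; folded into the per-lane
// offset vector %idx by computeOffsets.
%base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview
    : memref<256x256xf16, strided<[4096, 1], offset: ?>>
    -> memref<f16>, index, index, index, index, index
// Raw allocation pointer (carries no subview offset), cast to i64.
%ptr = memref.extract_aligned_pointer_as_index %subview
    : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
%ptr_i64 = arith.index_cast %ptr : index to i64
// Scattered load with per-lane element offsets.
%vec = xegpu.load %ptr_i64[%idx], %mask : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
```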

3 files changed: +174 -86 lines changed

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 52 additions & 59 deletions
```diff
@@ -180,26 +180,31 @@ static void adjustStridesForPermutation(AffineMap permMap,
   strides = applyPermutation(strides, perms64);
 }
 
-// Computes memory strides for vector transfer operations, handling both
-// static and dynamic memrefs while applying permutation transformations
-// for XeGPU lowering.
-static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
-                                         PatternRewriter &rewriter) {
+// Computes memory strides and a memref offset for vector transfer operations,
+// handling both static and dynamic memrefs while applying permutation
+// transformations for XeGPU lowering.
+static std::pair<SmallVector<Value>, Value>
+computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
   SmallVector<Value> strides;
   Value baseMemref = xferOp.getBase();
   AffineMap permMap = xferOp.getPermutationMap();
   MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
 
   Location loc = xferOp.getLoc();
+  Value offsetVal = nullptr;
   if (memrefType.hasStaticShape()) {
     int64_t offset;
     SmallVector<int64_t> intStrides;
     if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
-      return {};
+      return {{}, offsetVal};
     // Wrap static strides as MLIR values
     for (int64_t s : intStrides)
       strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
-  } else {
+    if (!ShapedType::isDynamic(offset))
+      offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
+  }
+
+  if (strides.empty() || !offsetVal) {
     // For dynamic shape memref, use memref.extract_strided_metadata to get
     // stride values
     unsigned rank = memrefType.getRank();
@@ -220,11 +225,16 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
 
     auto meta = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, resultTypes, baseMemref);
-    strides.append(meta.getStrides().begin(), meta.getStrides().end());
+
+    if (strides.empty())
+      strides.append(meta.getStrides().begin(), meta.getStrides().end());
+
+    if (!offsetVal)
+      offsetVal = meta.getOffset();
   }
   // Adjust strides according to the permutation map (e.g., for transpose)
   adjustStridesForPermutation(permMap, strides);
-  return strides;
+  return {strides, offsetVal};
 }
 
 // This function compute the vectors of localOffsets for scattered load/stores.
@@ -254,10 +264,10 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
 // %23 = arith.add %20, %21
 // %local_offsets = arith.add %22, %23
 // %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
-// %offsets = orig_offset + local_offsets
+// %offsets = memref_offset + orig_offset + local_offsets
 static Value computeOffsets(VectorTransferOpInterface xferOp,
-                            PatternRewriter &rewriter,
-                            ArrayRef<Value> strides) {
+                            PatternRewriter &rewriter, ArrayRef<Value> strides,
+                            Value baseOffset) {
   Location loc = xferOp.getLoc();
   VectorType vectorType = xferOp.getVectorType();
   SmallVector<Value> indices(xferOp.getIndices().begin(),
@@ -315,51 +325,30 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
         arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
 
   // Compute base offset from transfer read indices
-  Value baseOffset = nullptr;
-  if (!indices.empty()) {
-    baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
-    for (size_t i = 0; i < indices.size(); ++i) {
-      Value strideVal = strides[i];
-      Value offsetContrib =
-          arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
-      baseOffset =
-          arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
-    }
-    // Broadcast base offset to match vector shape
-    Value bcastBase = vector::BroadcastOp::create(
-        rewriter, loc, fullIndexVectorType, baseOffset);
-    localOffsets =
-        arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
+  for (size_t i = 0; i < indices.size(); ++i) {
+    Value strideVal = strides[i];
+    Value offsetContrib =
+        arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
+    baseOffset =
+        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
   }
+  // Broadcast base offset to match vector shape
+  Value bcastBase = vector::BroadcastOp::create(
+      rewriter, loc, fullIndexVectorType, baseOffset);
+  localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
   return localOffsets;
 }
 
-// Collapse memref shape to 1D
-static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
-                                PatternRewriter &rewriter) {
+// Convert memref to i64 base pointer
+static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
+                              PatternRewriter &rewriter) {
   Location loc = xferOp.getLoc();
-
-  Value baseMemref = xferOp.getBase();
-  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
-  Type elementType = memrefType.getElementType();
-
-  // Compute the total number of elements in the memref
-  MemRefType flatMemrefType;
-  if (memrefType.hasStaticShape()) {
-    auto totalElements = memrefType.getNumElements();
-    flatMemrefType = MemRefType::get({totalElements}, elementType);
-  } else {
-    flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
-  }
-
-  SmallVector<ReassociationIndices> reassociation;
-  ReassociationIndices allDims =
-      llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank()));
-  reassociation.push_back(allDims);
-
-  auto collapseOp = memref::CollapseShapeOp::create(
-      rewriter, loc, flatMemrefType, baseMemref, reassociation);
-  return collapseOp;
+  auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
+                      rewriter, loc, xferOp.getBase())
+                      .getResult();
+  return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
+                                    indexPtr)
+      .getResult();
 }
 
 static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
@@ -372,13 +361,14 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
   if (!memrefType)
     return rewriter.notifyMatchFailure(readOp, "Expected memref source");
 
-  SmallVector<Value> strides = computeStrides(readOp, rewriter);
-  if (strides.empty())
+  auto meta = computeMemrefMeta(readOp, rewriter);
+  if (meta.first.empty())
    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
 
-  Value localOffsets = computeOffsets(readOp, rewriter, strides);
+  Value localOffsets =
+      computeOffsets(readOp, rewriter, meta.first, meta.second);
 
-  Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
+  Value flatMemref = memrefToIndexPtr(readOp, rewriter);
 
   Value mask = vector::ConstantMaskOp::create(
       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
@@ -405,11 +395,14 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
   if (!memrefType)
     return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
 
-  SmallVector<Value> strides = computeStrides(writeOp, rewriter);
+  auto meta = computeMemrefMeta(writeOp, rewriter);
+  if (meta.first.empty())
+    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");
 
-  Value localOffsets = computeOffsets(writeOp, rewriter, strides);
+  Value localOffsets =
+      computeOffsets(writeOp, rewriter, meta.first, meta.second);
 
-  Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
+  Value flatMemref = memrefToIndexPtr(writeOp, rewriter);
 
   Value mask = vector::ConstantMaskOp::create(
       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
```
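To make the `%offsets = memref_offset + orig_offset + local_offsets` formula above concrete: for a 2-D source with transfer indices `%i0` and `%i1`, the ops emitted by `computeOffsets` look roughly like the hand-written sketch below (value names are illustrative; the pass builds the equivalent ops programmatically). The accumulation now seeds `baseOffset` with the memref's own offset instead of a zero constant:

```mlir
// baseOffset = memref_offset + i0 * stride0 + i1 * stride1
%mul0 = arith.muli %i0, %stride0 : index
%acc0 = arith.addi %memref_offset, %mul0 : index
%mul1 = arith.muli %i1, %stride1 : index
%acc1 = arith.addi %acc0, %mul1 : index
// Broadcast to the vector shape and add the per-lane offsets.
%bcast = vector.broadcast %acc1 : index to vector<8xindex>
%offsets = arith.addi %bcast, %local_offsets : vector<8xindex>
```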

mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir

Lines changed: 61 additions & 16 deletions
```diff
@@ -27,8 +27,9 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
 // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
 
 }
 
@@ -62,8 +63,9 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
 // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 
 }
 
@@ -124,8 +126,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32>
-// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 
 }
 
@@ -164,8 +167,9 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER: %[[BROADIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // LOAD-GATHER: %[[FINALIDX:.+]] = arith.addi %[[BROADIDX]], {{.*}} : vector<8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
-// LOAD-GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE]][%[[FINALIDX]]], %[[CST]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<?x?x?xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[FINALIDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 // LOAD-GATHER: gpu.return %[[RES]] : vector<8x16xf32>
 }
 
@@ -195,8 +199,9 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER-DAG: %[[BCASTIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // LOAD-GATHER-DAG: %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], {{.*}} : vector<8x16xindex>
-// LOAD-GATHER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<?x8x16xf32> -> index
+// LOAD-GATHER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 
 }
 
@@ -224,8 +229,9 @@ gpu.func @load_dynamic_source3(%source: memref<?x?x?x?x?xf32>,
 // LOAD-GATHER-COUNT3: arith.addi {{.*}} : vector<2x4x8x16xindex>
 // LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x4x8x16xindex>
 // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x4x8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2, 3, 4]{{\]}} : memref<?x?x?x?x?xf32> into memref<?xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<?xf32>, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?x?x?xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
 // LOAD-GATHER: return %[[VEC]]
 }
 
@@ -254,8 +260,9 @@ gpu.func @load_high_dim_vector(%source: memref<16x32x64xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
 // LOAD-GATHER: %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
 // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<16x32x64xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
 
 }
 
@@ -283,8 +290,9 @@ gpu.func @load_transpose_f16(%source: memref<32x64xf16>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf16> into memref<2048xf16>
-// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf16>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
+// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf16> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
 }
 
 // -----
@@ -396,3 +404,40 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
 // LOAD-GATHER: vector.transfer_read
 }
 
+// -----
+gpu.module @xevm_module {
+gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
+  %c0 = arith.constant 0.0 : f16
+  %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+  %0 = vector.transfer_read %subview[%off2, %off2], %c0
+    {in_bounds = [true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8xf16>
+  gpu.return %0 : vector<8xf16>
+}
+
+// LOAD-ND-LABEL: @load_from_subview(
+// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
+// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// LOAD-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
+// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+// LOAD-ND-SAME: boundary_check = false
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf16>
+// LOAD-ND: return %[[VEC]]
+
+// LOAD-GATHER-LABEL: @load_from_subview(
+// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
+// LOAD-GATHER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
+// LOAD-GATHER: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// LOAD-GATHER: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+// LOAD-GATHER: %[[STEP:.+]] = vector.step : vector<8xindex>
+// LOAD-GATHER: arith.muli {{.*}} : index
+// LOAD-GATHER: arith.addi %[[OFFSET]]{{.*}} : index
+// LOAD-GATHER: arith.addi {{.*}} : index
+// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
+}
```
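For contrast, the old lowering would have emitted roughly the following for this subview, which the `memref.collapse_shape` verifier rejects: with layout `strided<[4096, 1], offset: ?>`, rows of the 256x256 view are 4096 elements apart, so the two dims are not contiguous and cannot be collapsed. This sketch reproduces the op shown in the error message above; it is not part of the new output:

```mlir
// Invalid: collapsing non-contiguous dims (row stride 4096, inner size 256).
%flat = memref.collapse_shape %subview [[0, 1]]
    : memref<256x256xf16, strided<[4096, 1], offset: ?>> into memref<65536xf16>
```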
