@@ -180,26 +180,31 @@ static void adjustStridesForPermutation(AffineMap permMap,
180
180
strides = applyPermutation (strides, perms64);
181
181
}
182
182
183
- // Computes memory strides for vector transfer operations, handling both
184
- // static and dynamic memrefs while applying permutation transformations
185
- // for XeGPU lowering.
186
- static SmallVector<Value> computeStrides (VectorTransferOpInterface xferOp,
187
- PatternRewriter &rewriter) {
183
+ // Computes memory strides and a memref offset for vector transfer operations,
184
+ // handling both static and dynamic memrefs while applying permutation
185
+ // transformations for XeGPU lowering.
186
+ static std::pair< SmallVector<Value>, Value>
187
+ computeMemrefMeta (VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
188
188
SmallVector<Value> strides;
189
189
Value baseMemref = xferOp.getBase ();
190
190
AffineMap permMap = xferOp.getPermutationMap ();
191
191
MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType ());
192
192
193
193
Location loc = xferOp.getLoc ();
194
+ Value offsetVal = nullptr ;
194
195
if (memrefType.hasStaticShape ()) {
195
196
int64_t offset;
196
197
SmallVector<int64_t > intStrides;
197
198
if (failed (memrefType.getStridesAndOffset (intStrides, offset)))
198
- return {};
199
+ return {{}, offsetVal };
199
200
// Wrap static strides as MLIR values
200
201
for (int64_t s : intStrides)
201
202
strides.push_back (arith::ConstantIndexOp::create (rewriter, loc, s));
202
- } else {
203
+ if (!ShapedType::isDynamic (offset))
204
+ offsetVal = arith::ConstantIndexOp::create (rewriter, loc, offset);
205
+ }
206
+
207
+ if (strides.empty () || !offsetVal) {
203
208
// When the strides or the offset are not statically known, use
204
209
// memref.extract_strided_metadata to obtain the missing values
205
210
unsigned rank = memrefType.getRank ();
@@ -220,11 +225,16 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
220
225
221
226
auto meta = memref::ExtractStridedMetadataOp::create (
222
227
rewriter, loc, resultTypes, baseMemref);
223
- strides.append (meta.getStrides ().begin (), meta.getStrides ().end ());
228
+
229
+ if (strides.empty ())
230
+ strides.append (meta.getStrides ().begin (), meta.getStrides ().end ());
231
+
232
+ if (!offsetVal)
233
+ offsetVal = meta.getOffset ();
224
234
}
225
235
// Adjust strides according to the permutation map (e.g., for transpose)
226
236
adjustStridesForPermutation (permMap, strides);
227
- return strides;
237
+ return { strides, offsetVal} ;
228
238
}
229
239
230
240
// This function computes the vector of local offsets for scattered loads/stores.
@@ -254,10 +264,10 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
254
264
// %23 = arith.add %20, %21
255
265
// %local_offsets = arith.add %22, %23
256
266
// %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
257
- // %offsets = orig_offset + local_offsets
267
+ // %offsets = memref_offset + orig_offset + local_offsets
258
268
static Value computeOffsets (VectorTransferOpInterface xferOp,
259
- PatternRewriter &rewriter,
260
- ArrayRef< Value> strides ) {
269
+ PatternRewriter &rewriter, ArrayRef<Value> strides,
270
+ Value baseOffset ) {
261
271
Location loc = xferOp.getLoc ();
262
272
VectorType vectorType = xferOp.getVectorType ();
263
273
SmallVector<Value> indices (xferOp.getIndices ().begin (),
@@ -315,51 +325,30 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
315
325
arith::AddIOp::create (rewriter, loc, localOffsets, broadcasted[i]);
316
326
317
327
// Accumulate each transfer index times its stride into the base offset
318
- Value baseOffset = nullptr ;
319
- if (!indices.empty ()) {
320
- baseOffset = arith::ConstantIndexOp::create (rewriter, loc, 0 );
321
- for (size_t i = 0 ; i < indices.size (); ++i) {
322
- Value strideVal = strides[i];
323
- Value offsetContrib =
324
- arith::MulIOp::create (rewriter, loc, indices[i], strideVal);
325
- baseOffset =
326
- arith::AddIOp::create (rewriter, loc, baseOffset, offsetContrib);
327
- }
328
- // Broadcast base offset to match vector shape
329
- Value bcastBase = vector::BroadcastOp::create (
330
- rewriter, loc, fullIndexVectorType, baseOffset);
331
- localOffsets =
332
- arith::AddIOp::create (rewriter, loc, bcastBase, localOffsets);
328
+ for (size_t i = 0 ; i < indices.size (); ++i) {
329
+ Value strideVal = strides[i];
330
+ Value offsetContrib =
331
+ arith::MulIOp::create (rewriter, loc, indices[i], strideVal);
332
+ baseOffset =
333
+ arith::AddIOp::create (rewriter, loc, baseOffset, offsetContrib);
333
334
}
335
+ // Broadcast base offset to match vector shape
336
+ Value bcastBase = vector::BroadcastOp::create (
337
+ rewriter, loc, fullIndexVectorType, baseOffset);
338
+ localOffsets = arith::AddIOp::create (rewriter, loc, bcastBase, localOffsets);
334
339
return localOffsets;
335
340
}
336
341
337
- // Collapse memref shape to 1D
338
- static Value collapseMemrefTo1D (VectorTransferOpInterface xferOp,
339
- PatternRewriter &rewriter) {
342
+ // Extract the memref's aligned base pointer as an index and cast it to i64
343
+ static Value memrefToIndexPtr (VectorTransferOpInterface xferOp,
344
+ PatternRewriter &rewriter) {
340
345
Location loc = xferOp.getLoc ();
341
-
342
- Value baseMemref = xferOp.getBase ();
343
- MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType ());
344
- Type elementType = memrefType.getElementType ();
345
-
346
- // Compute the total number of elements in the memref
347
- MemRefType flatMemrefType;
348
- if (memrefType.hasStaticShape ()) {
349
- auto totalElements = memrefType.getNumElements ();
350
- flatMemrefType = MemRefType::get ({totalElements}, elementType);
351
- } else {
352
- flatMemrefType = MemRefType::get ({ShapedType::kDynamic }, elementType);
353
- }
354
-
355
- SmallVector<ReassociationIndices> reassociation;
356
- ReassociationIndices allDims =
357
- llvm::to_vector (llvm::seq<int64_t >(0 , memrefType.getRank ()));
358
- reassociation.push_back (allDims);
359
-
360
- auto collapseOp = memref::CollapseShapeOp::create (
361
- rewriter, loc, flatMemrefType, baseMemref, reassociation);
362
- return collapseOp;
346
+ auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create (
347
+ rewriter, loc, xferOp.getBase ())
348
+ .getResult ();
349
+ return arith::IndexCastOp::create (rewriter, loc, rewriter.getI64Type (),
350
+ indexPtr)
351
+ .getResult ();
363
352
}
364
353
365
354
static LogicalResult lowerToScatteredLoadOp (vector::TransferReadOp readOp,
@@ -372,13 +361,14 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
372
361
if (!memrefType)
373
362
return rewriter.notifyMatchFailure (readOp, " Expected memref source" );
374
363
375
- SmallVector<Value> strides = computeStrides (readOp, rewriter);
376
- if (strides .empty ())
364
+ auto meta = computeMemrefMeta (readOp, rewriter);
365
+ if (meta. first .empty ())
377
366
return rewriter.notifyMatchFailure (readOp, " Failed to compute strides" );
378
367
379
- Value localOffsets = computeOffsets (readOp, rewriter, strides);
368
+ Value localOffsets =
369
+ computeOffsets (readOp, rewriter, meta.first , meta.second );
380
370
381
- Value flatMemref = collapseMemrefTo1D (readOp, rewriter);
371
+ Value flatMemref = memrefToIndexPtr (readOp, rewriter);
382
372
383
373
Value mask = vector::ConstantMaskOp::create (
384
374
rewriter, loc, VectorType::get (vectorShape, rewriter.getI1Type ()),
@@ -405,11 +395,14 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
405
395
if (!memrefType)
406
396
return rewriter.notifyMatchFailure (writeOp, " Expected memref source" );
407
397
408
- SmallVector<Value> strides = computeStrides (writeOp, rewriter);
398
+ auto meta = computeMemrefMeta (writeOp, rewriter);
399
+ if (meta.first .empty ())
400
+ return rewriter.notifyMatchFailure (writeOp, " Failed to compute strides" );
409
401
410
- Value localOffsets = computeOffsets (writeOp, rewriter, strides);
402
+ Value localOffsets =
403
+ computeOffsets (writeOp, rewriter, meta.first , meta.second );
411
404
412
- Value flatMemref = collapseMemrefTo1D (writeOp, rewriter);
405
+ Value flatMemref = memrefToIndexPtr (writeOp, rewriter);
413
406
414
407
Value mask = vector::ConstantMaskOp::create (
415
408
rewriter, loc, VectorType::get (vectorShape, rewriter.getI1Type ()),
0 commit comments