
Conversation


@silee2 silee2 commented Nov 18, 2025

The adaptor already lowers the memref operand to a base address.
Conversion patterns should use that value instead of generating code to recompute the base address from the memref.
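Concretely, the pattern previously re-emitted the base-address computation that the type converter already performs. A minimal sketch of the IR involved, reconstructed from the CHECK-line changes in this diff (value names are illustrative):

```mlir
// Ops the pattern used to generate per memref source:
%intptr = memref.extract_aligned_pointer_as_index %src : memref<512xf32> -> index
%base   = arith.index_castui %intptr : index to i64
```

After this change the pattern takes `adaptor.getSource()` directly, so these two ops are emitted once by the source materialization rather than duplicated by each conversion pattern; the test updates below reflect that the ops move earlier but do not disappear.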


llvmbot commented Nov 18, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Sang Ik Lee (silee2)

Changes

The adaptor already lowers the memref operand to a base address.
Conversion patterns should use that value instead of generating code to recompute the base address from the memref.


Full diff: https://github.com/llvm/llvm-project/pull/168610.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+3-4)
  • (modified) mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir (+3-3)
  • (modified) mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir (+4-4)
  • (modified) mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir (+9-3)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 705298f497d20..f16a8e239ee74 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -194,10 +194,9 @@ class CreateNdDescToXeVMPattern
       if (!sourceMemrefTy.hasRank()) {
         return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
       }
-      baseAddr =
-          memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
-      // Cast index to i64.
-      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
+      // Access adaptor after failure check to avoid rolling back generated code
+      // for materialization cast.
+      baseAddr = adaptor.getSource();
     } else {
       baseAddr = adaptor.getSource();
       if (baseAddr.getType() != i64Ty) {
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 109312218afae..8b87b791c9fd3 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -7,6 +7,8 @@ gpu.module @create_nd_tdesc {
   // CHECK-SAME: %[[DYN:.*]]: memref<?x?xf16>) kernel {
   gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
   %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
+        // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
+        // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
         // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
         // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
         // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
@@ -27,9 +29,9 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
         %srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>
 
-        // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
         // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
         // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+        // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
         // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
         // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
         // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
@@ -52,8 +54,6 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[C16:.*]] = arith.constant 16 : index
         %BLOCK_DMODEL = arith.constant 16 : index
         // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
-        // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
-        // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
         // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
         // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
         // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
index 7b4ad9ec2df03..aebec7fc27c78 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
@@ -9,11 +9,13 @@ gpu.module @load_store_check {
 
         // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[SRC]] : memref<512xf32, 1> to memref<512xf32>
         %srcce = memref.memory_space_cast %src : memref<512xf32, 1> to memref<512xf32>
+        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<512xf32> -> index
+        // CHECK: %[[INTPTR_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
         // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[DST]] : memref<256xf32, 1> to memref<256xf32>
         %dstte = memref.memory_space_cast %dst : memref<256xf32, 1> to memref<256xf32>
+        // CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<256xf32> -> index
+        // CHECK: %[[INTPTR1_I64:.*]] = arith.index_castui %[[INTPTR1]] : index to i64
 
-        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<512xf32> -> index
-        // CHECK: %[[INTPTR_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
         %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<512xf32> -> !xegpu.tensor_desc<32xf32>
         // CHECK: %[[ADDR:.*]] = arith.addi %[[INTPTR_I64]], %[[C384]] : i64
         // CHECK: %[[PTR:.*]] = llvm.inttoptr %[[ADDR]] : i64 to !llvm.ptr<1>
@@ -22,8 +24,6 @@ gpu.module @load_store_check {
         %loaded = xegpu.load_nd %src_tdesc[96] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
             : !xegpu.tensor_desc<32xf32> -> vector<2xf32>
 
-        // CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<256xf32> -> index
-        // CHECK: %[[INTPTR1_I64:.*]] = arith.index_castui %[[INTPTR1]] : index to i64
         %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<256xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>>
         // CHECK: %[[ADDR1:.*]] = arith.addi %[[INTPTR1_I64]], %[[C512]] : i64
         // CHECK: %[[PTR1:.*]] = llvm.inttoptr %[[ADDR1]] : i64 to !llvm.ptr<1>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index 95774ca67c4f2..afeae8be24b72 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -1,11 +1,18 @@
 // RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
 
 gpu.module @load_store_check {
+    // CHECK-LABEL: gpu.func @load_store(
+    // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: memref<8x16xf32, 1>) kernel {
     gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
         %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
         %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
 
-        // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
+        // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
+        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+        // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+        // CHECK: %[[MEMSPACECAST_0:.*]] = memref.memory_space_cast %[[ARG1]] : memref<8x16xf32, 1> to memref<8x16xf32>
+        // CHECK: %[[INTPTR_1:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST_0]] : memref<8x16xf32> -> index
+        // CHECK: %[[ST_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR_1]] : index to i64
         // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
         // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
         // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
@@ -42,9 +49,8 @@ gpu.module @load_store_check {
         //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
         %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
 
-        // CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
         // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[DESC_0:.*]] = vector.insert %[[PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[DESC_0:.*]] = vector.insert %[[ST_PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
         // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
         // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
         // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>

@github-actions

🐧 Linux x64 Test Results

  • 7083 tests passed
  • 594 tests skipped

Contributor

@Jianhui-Li Jianhui-Li left a comment


LGTM

@silee2 silee2 merged commit 7de59f0 into llvm:main Nov 20, 2025
13 checks passed
aadeshps-mcw pushed a commit to aadeshps-mcw/llvm-project that referenced this pull request Nov 26, 2025
…s from memref. (llvm#168610)

adaptor already lowers memref to base address.
Conversion patterns should use it instead of generating code to get base
address from memref.
Priyanshu3820 pushed a commit to Priyanshu3820/llvm-project that referenced this pull request Nov 26, 2025
…s from memref. (llvm#168610)

adaptor already lowers memref to base address.
Conversion patterns should use it instead of generating code to get base
address from memref.