From 6a541298aa62f47ba760acae2649b83565750d50 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Tue, 18 Nov 2025 21:20:46 +0000 Subject: [PATCH] [MLIR][Conversion] XeGPU to XeVM: Use adaptor for getting base pointer from memref. --- mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 7 +++---- .../test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir | 6 +++--- mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir | 8 ++++---- mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir | 12 +++++++++--- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 705298f497d20..f16a8e239ee74 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -194,10 +194,9 @@ class CreateNdDescToXeVMPattern if (!sourceMemrefTy.hasRank()) { return rewriter.notifyMatchFailure(op, "Expected ranked Memref."); } - baseAddr = - memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source); - // Cast index to i64. - baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr); + // Access adaptor after failure check to avoid rolling back generated code + // for materialization cast. + baseAddr = adaptor.getSource(); } else { baseAddr = adaptor.getSource(); if (baseAddr.getType() != i64Ty) { diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir index 109312218afae..8b87b791c9fd3 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir @@ -7,6 +7,8 @@ gpu.module @create_nd_tdesc { // CHECK-SAME: %[[DYN:.*]]: memref) kernel { gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index, %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref) kernel { + // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref -> index + // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64 // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64 // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32> @@ -27,9 +29,9 @@ gpu.module @create_nd_tdesc { // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32> %srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32> - // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32> // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32> // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32 // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32 // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64 @@ -52,8 +54,6 @@ gpu.module @create_nd_tdesc { // CHECK: %[[C16:.*]] = arith.constant 16 : index %BLOCK_DMODEL = arith.constant 16 : index // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32> - // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref -> index - // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64 // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32 // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32 // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32 diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir index 7b4ad9ec2df03..aebec7fc27c78 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir @@ -9,11 +9,13 @@ gpu.module @load_store_check { // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[SRC]] : memref<512xf32, 1> to memref<512xf32> %srcce = memref.memory_space_cast %src : memref<512xf32, 1> to memref<512xf32> + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<512xf32> -> index + // CHECK: %[[INTPTR_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[DST]] : memref<256xf32, 1> to memref<256xf32> %dstte = memref.memory_space_cast %dst : memref<256xf32, 1> to memref<256xf32> + // CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<256xf32> -> index + // CHECK: %[[INTPTR1_I64:.*]] = arith.index_castui %[[INTPTR1]] : index to i64 - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<512xf32> -> index - // CHECK: %[[INTPTR_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64 %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<512xf32> -> !xegpu.tensor_desc<32xf32> // CHECK: %[[ADDR:.*]] = arith.addi %[[INTPTR_I64]], %[[C384]] : i64 // CHECK: %[[PTR:.*]] = llvm.inttoptr %[[ADDR]] : i64 to !llvm.ptr<1> @@ -22,8 +24,6 @@ gpu.module @load_store_check { %loaded = xegpu.load_nd %src_tdesc[96] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32xf32> -> vector<2xf32> - // CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<256xf32> -> index - // CHECK: %[[INTPTR1_I64:.*]] = arith.index_castui %[[INTPTR1]] : index to i64 %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<256xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr> // CHECK: %[[ADDR1:.*]] = arith.addi %[[INTPTR1_I64]], %[[C512]] : i64 // CHECK: %[[PTR1:.*]] = llvm.inttoptr %[[ADDR1]] : i64 to !llvm.ptr<1> diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir index 95774ca67c4f2..afeae8be24b72 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir @@ -1,11 +1,18 @@ // RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s gpu.module @load_store_check { + // CHECK-LABEL: gpu.func @load_store( + // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: memref<8x16xf32, 1>) kernel { gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel { %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32> %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32> - // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64 + // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32> + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index + // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[MEMSPACECAST_0:.*]] = memref.memory_space_cast %[[ARG1]] : memref<8x16xf32, 1> to memref<8x16xf32> + // CHECK: %[[INTPTR_1:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST_0]] : memref<8x16xf32> -> index + // CHECK: %[[ST_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR_1]] : index to i64 // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64> // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64> // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32> @@ -42,9 +49,8 @@ gpu.module @load_store_check { //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32> %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32> - // CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64 // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64> - // CHECK: %[[DESC_0:.*]] = vector.insert %[[PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64> + // CHECK: %[[DESC_0:.*]] = vector.insert %[[ST_PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64> // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32> // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32> // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>