From d9a53760818ecb943e3f1d006a52ceb3cbc1703c Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Tue, 11 Nov 2025 05:48:52 +0000 Subject: [PATCH 1/6] inital implementation --- .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 2 +- .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 45 +++++++++-------- .../XeGPUToXeVM/loadstore_matrix.mlir | 48 +++++++++++++++---- mlir/test/Dialect/XeGPU/ops.mlir | 21 ++++++++ 4 files changed, 86 insertions(+), 30 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 689ebd0d1179a..59dc0eade0744 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1308,7 +1308,7 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure, Results: - `mem_desc` : the memory descriptor. }]; - let arguments = (ins StaticShared1DMemRefOf<[I8]>:$source); + let arguments = (ins AnyTypeOf<[StaticShared1DMemRefOf<[I8]>, ConfinedType, [HasStaticShapePred, isSharedPred]>]>:$source); let results = (outs XeGPU_MemDesc:$mem_desc); let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))"; } diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 705298f497d20..02b60ce5f6a1d 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -590,16 +590,22 @@ class CreateMemDescOpPattern final matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto resTy = op.getMemDesc(); +// auto resTy = op.getMemDesc(); // Create the result MemRefType with the same shape, element type, and // memory space - auto newResTy = getTypeConverter()->convertType(resTy); +// auto newResTy = getTypeConverter()->convertType(resTy); - Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); - auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, - op.getSource(), zero, ValueRange()); - rewriter.replaceOp(op, viewOp); +// Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); +// auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, +// op.getSource(), zero, ValueRange()); + + Value baseAddr = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, op.getLoc(), op.getSource()); + auto baseAddr32 = arith::IndexCastUIOp::create( + rewriter, op.getLoc(), rewriter.getI32Type(), baseAddr); + + rewriter.replaceOp(op, baseAddr32); return success(); } }; @@ -619,7 +625,7 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern { auto loc = op.getLoc(); auto ctxt = rewriter.getContext(); - Value basePtrStruct = adaptor.getMemDesc(); + Value baseAddr32 = adaptor.getMemDesc(); Value mdescVal = op.getMemDesc(); // Load result or Store value Type can be vector or scalar. Value data; @@ -647,21 +653,21 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern { auto mdescTy = cast(mdescVal.getType()); - Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create( - rewriter, loc, basePtrStruct); + // Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create( + // rewriter, loc, basePtrStruct); // Convert base pointer (ptr) to i32 - Value basePtrI32 = arith::IndexCastUIOp::create( - rewriter, loc, rewriter.getI32Type(), basePtrLLVM); + //Value basePtrI32 = arith::IndexCastUIOp::create( + // rewriter, loc, rewriter.getI32Type(), baseAddr); Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); linearOffset = arith::IndexCastUIOp::create( rewriter, loc, rewriter.getI32Type(), linearOffset); - basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset, + Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32, linearOffset, elemByteSize); // convert base pointer (i32) to LLVM pointer type - basePtrLLVM = + Value basePtrLLVM = LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32); if (op.getSubgroupBlockIoAttr()) { @@ -1005,12 +1011,13 @@ struct ConvertXeGPUToXeVMPass auto i32Type = IntegerType::get(&getContext(), 32); return VectorType::get(8, i32Type); }); - // Convert MemDescType into flattened MemRefType for SLM - typeConverter.addConversion([&](xegpu::MemDescType type) -> Type { - Type elemTy = type.getElementType(); - int numElems = type.getNumElements(); - return MemRefType::get(numElems, elemTy, AffineMap(), 3); - }); + // Convert MemDescType into i32 for SLM + typeConverter.addConversion([&](xegpu::MemDescType type) -> Type { + // Type elemTy = type.getElementType(); + // int numElems = type.getNumElements(); + // return MemRefType::get(numElems, elemTy, AffineMap(), 3); + return IntegerType::get(&getContext(), 32); + }); typeConverter.addConversion([&](MemRefType type) -> Type { // Convert MemRefType to i64 type. diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir index d4cb493271d0d..2b5c308f9ab86 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir @@ -4,8 +4,8 @@ gpu.module @test_kernel [#xevm.target] { // e.g. for mem_desc<32x32xf16, @strides=[1, 16]> // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1]) - //CHECK-LABEL: load_store_matrix_1 - gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 { + //CHECK-LABEL: load_store_matrix_plain + gpu.func @load_store_matrix_plain(%arg0: memref<4096xi8, 3>) -> f32 { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32> //CHECK: %[[TID:.*]] = gpu.thread_id x @@ -26,10 +26,38 @@ gpu.module @test_kernel [#xevm.target] { gpu.return %1: f32 } + //CHECK-LABEL: load_store_matrix_plain_2d_input + gpu.func @load_store_matrix_plain(%arg0: memref<8192xi8, 3>) -> f32 { + + %view = memref.view %arg0[0][]: memref<8192xi8, 3> to memref<64x32xf32, 3> + + %subview = memref.subview %view[64, 0] [64, 128] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> + + %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32> + + //CHECK: %[[TID:.*]] = gpu.thread_id x + //CHECK: %[[C1:.*]] = arith.constant 1 : index + //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index + //CHECK: %[[C4:.*]] = arith.constant 4 : i32 + //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32 + //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32 + + %tid_x = gpu.thread_id x + %c0 = arith.constant 0 : index + %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32 + + //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3> + + xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index + + gpu.return %1: f32 + } + + // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]> // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) - //CHECK-LABEL: load_store_matrix_2 - gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 { + //CHECK-LABEL: load_store_matrix_blocked_strided + gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[tid_x:.*]] = gpu.thread_id x @@ -68,8 +96,8 @@ gpu.module @test_kernel [#xevm.target] { // e.g. for mem_desc<32x64xf16, @block=[16, 16]> // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) - //CHECK-LABEL: load_store_matrix_3 - gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 { + //CHECK-LABEL: load_store_matrix_blocked_nostride + gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 { //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> @@ -110,8 +138,8 @@ gpu.module @test_kernel [#xevm.target] { // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]> // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) - //CHECK-LABEL: load_store_matrix_4 - gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { + //CHECK-LABEL: load_store_matrix_blocked_strided_return_vector + gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> //CHECK: %[[c0:.*]] = arith.constant 0 : index @@ -150,8 +178,8 @@ gpu.module @test_kernel [#xevm.target] { // e.g. for mem_desc<32x64xf16, @block=[16, 16]> // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) - //CHECK-LABEL: load_store_matrix_5 - gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { + //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio + gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index 9b3829664108d..1e9738f44bb66 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -834,6 +834,27 @@ gpu.func @create_mem_desc_with_stride() { gpu.return } + +// CHECK-LABEL: gpu.func @create_mem_desc_from_2d_memref({{.*}}) { +gpu.func @create_mem_desc_from_2d_memref() { + //CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<16x64xf16, 3> + //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16> + %m = memref.alloca() {alignment = 1024} : memref<16x64xf16, 3> + %mem_desc = xegpu.create_mem_desc %m : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16> + gpu.return +} + +// CHECK-LABEL: gpu.func @create_mem_desc_with_stride_from_2d_memref({{.*}}) { +gpu.func @create_mem_desc_with_stride_from_2d_memref() { + //CHECK: %[[ALLOC:.+]] = memref.alloca() {alignment = 1024 : i64} : memref<32x64xf16, 3> + //CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16, 0] [16, 64] [1, 1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> + //CHECK: %{{.+}} = xegpu.create_mem_desc %[[SUBVIEW]] : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + %m = memref.alloca() {alignment = 1024} : memref<32x64xf16, 3> + %m_sub = memref.subview %m[16, 0][16, 64][1,1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> + %mem_desc = xegpu.create_mem_desc %m_sub : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + gpu.return +} + // CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) { // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16> From 18f8cb976d61b26e37e9b00ad6b9e770c288e73f Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Wed, 12 Nov 2025 21:41:09 +0000 Subject: [PATCH 2/6] adding tests --- .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 36 ++++------------ .../XeGPUToXeVM/loadstore_matrix.mlir | 42 +++++++++---------- 2 files changed, 27 insertions(+), 51 deletions(-) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 02b60ce5f6a1d..c01078ca6938e 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -590,20 +590,10 @@ class CreateMemDescOpPattern final matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { -// auto resTy = op.getMemDesc(); - - // Create the result MemRefType with the same shape, element type, and - // memory space -// auto newResTy = getTypeConverter()->convertType(resTy); - -// Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); -// auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, -// op.getSource(), zero, ValueRange()); - Value baseAddr = memref::ExtractAlignedPointerAsIndexOp::create( - rewriter, op.getLoc(), op.getSource()); - auto baseAddr32 = arith::IndexCastUIOp::create( - rewriter, op.getLoc(), rewriter.getI32Type(), baseAddr); + rewriter, op.getLoc(), op.getSource()); + auto baseAddr32 = arith::IndexCastUIOp::create( + rewriter, op.getLoc(), rewriter.getI32Type(), baseAddr); rewriter.replaceOp(op, baseAddr32); return success(); @@ -653,18 +643,11 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern { auto mdescTy = cast(mdescVal.getType()); - // Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create( - // rewriter, loc, basePtrStruct); - - // Convert base pointer (ptr) to i32 - //Value basePtrI32 = arith::IndexCastUIOp::create( - // rewriter, loc, rewriter.getI32Type(), baseAddr); - Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); linearOffset = arith::IndexCastUIOp::create( rewriter, loc, rewriter.getI32Type(), linearOffset); - Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32, linearOffset, - elemByteSize); + Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32, + linearOffset, elemByteSize); // convert base pointer (i32) to LLVM pointer type Value basePtrLLVM = @@ -1012,12 +995,9 @@ struct ConvertXeGPUToXeVMPass return VectorType::get(8, i32Type); }); // Convert MemDescType into i32 for SLM - typeConverter.addConversion([&](xegpu::MemDescType type) -> Type { - // Type elemTy = type.getElementType(); - // int numElems = type.getNumElements(); - // return MemRefType::get(numElems, elemTy, AffineMap(), 3); - return IntegerType::get(&getContext(), 32); - }); + typeConverter.addConversion([&](xegpu::MemDescType type) -> Type { + return IntegerType::get(&getContext(), 32); + }); typeConverter.addConversion([&](MemRefType type) -> Type { // Convert MemRefType to i64 type. diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir index 2b5c308f9ab86..5b3776ba0039c 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir @@ -27,11 +27,11 @@ gpu.module @test_kernel [#xevm.target] { } //CHECK-LABEL: load_store_matrix_plain_2d_input - gpu.func @load_store_matrix_plain(%arg0: memref<8192xi8, 3>) -> f32 { - - %view = memref.view %arg0[0][]: memref<8192xi8, 3> to memref<64x32xf32, 3> + gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 { + %c0 = arith.constant 0 : index + %view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3> - %subview = memref.subview %view[64, 0] [64, 128] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> + %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32> @@ -43,7 +43,7 @@ gpu.module @test_kernel [#xevm.target] { //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32 %tid_x = gpu.thread_id x - %c0 = arith.constant 0 : index + %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32 //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3> @@ -59,7 +59,7 @@ gpu.module @test_kernel [#xevm.target] { //CHECK-LABEL: load_store_matrix_blocked_strided gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> - //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[tid_x:.*]] = gpu.thread_id x //CHECK: %[[c13:.*]] = arith.constant 13 : index //CHECK: %[[c16:.*]] = arith.constant 16 : index @@ -67,7 +67,7 @@ gpu.module @test_kernel [#xevm.target] { //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index - + //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[c256:.*]] = arith.constant 256 : index //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index @@ -98,8 +98,9 @@ gpu.module @test_kernel [#xevm.target] { // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) //CHECK-LABEL: load_store_matrix_blocked_nostride gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 { - //CHECK: %[[c0:.*]] = arith.constant 0 : index - //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> + + //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index + //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> //CHECK: %[[tid_x:.*]] = gpu.thread_id x @@ -107,13 +108,12 @@ gpu.module @test_kernel [#xevm.target] { %tid_x = gpu.thread_id x %c19 = arith.constant 19: index - //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index - //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 //CHECK: %[[c16:.*]] = arith.constant 16 : index //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[c1024:.*]] = arith.constant 1024 : index //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index @@ -125,7 +125,6 @@ gpu.module @test_kernel [#xevm.target] { //CHECK: %[[c1:.*]] = arith.constant 1 : index //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index - //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16 %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> f16 @@ -142,15 +141,13 @@ gpu.module @test_kernel [#xevm.target] { gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> - //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[tid_x:.*]] = gpu.thread_id x - //CHECK: %[[c16:.*]] = arith.constant 16 : index //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index - + //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[c256:.*]] = arith.constant 256 : index //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index @@ -180,23 +177,22 @@ gpu.module @test_kernel [#xevm.target] { // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { - //CHECK: %[[c0:.*]] = arith.constant 0 : index - //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> - - %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> - + + //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index + //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> + + //CHECK: %[[c16:.*]] = arith.constant 16 : index //CHECK: %[[c48:.*]] = arith.constant 48 : index - %c16 = arith.constant 16 : index %c48 = arith.constant 48 : index - //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index - //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index + //CHECK: %[[c0:.*]] = arith.constant 0 : index //CHECK: %[[c1024:.*]] = arith.constant 1024 : index //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index From c0ba80e50f65af0fe7daa9e91afdfc271326665e Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Wed, 12 Nov 2025 22:05:35 +0000 Subject: [PATCH 3/6] fixing documentation --- mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 59dc0eade0744..cd5824b607aba 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1304,7 +1304,7 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure, as the underlying shared local memory. Arguments: - - `source` : a 1D statically shaped memref with element type i8, representing the raw SLM buffer. + - `source` : 1D or 2D statically shape memref, representing the raw SLM buffer. When the source is provided as 1D memref, its type must be i8. Results: - `mem_desc` : the memory descriptor. }]; From 281393e862ecea749044b111d61589f2f02b3d4c Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Tue, 18 Nov 2025 19:07:17 +0000 Subject: [PATCH 4/6] address feedback --- mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 8 +------- mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 11 +++++++++++ mlir/test/Dialect/XeGPU/invalid.mlir | 2 +- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 7ac5f495a9789..ed166260fb689 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1282,12 +1282,6 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou let hasCanonicalizer = 1; } -def isSharedPred : CPred<"isSharedMemory(llvm::cast($_self))">; -class StaticShared1DMemRefOf allowedTypes> : - ConfinedType, [HasStaticShapePred, isSharedPred], - "statically shaped " # MemRefOf.summary # " for shared memory", - "mlir::MemRefType">; - class SizeInBits : StrFunc<"llvm::cast($" # name # ".getType()).getNumElements()" "*llvm::cast($" # name # ".getType()).getElementTypeBitWidth()">; @@ -1308,7 +1302,7 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure, Results: - `mem_desc` : the memory descriptor. }]; - let arguments = (ins AnyTypeOf<[StaticShared1DMemRefOf<[I8]>, ConfinedType, [HasStaticShapePred, isSharedPred]>]>:$source); + let arguments = (ins AnyTypeOf<[StaticShared1DMemRefOf<[XeGPU_ScalarType]>, StaticShared2DMemRefOf<[XeGPU_ScalarType]>]>:$source); let results = (outs XeGPU_MemDesc:$mem_desc); let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))"; } diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index b1196fbe9c66a..7f7f7d065c50e 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -35,6 +35,17 @@ class XeGPUTypeDef traits = [], let mnemonic = typeMnemonic; } +def isSharedPred : CPred<"isSharedMemory(llvm::cast($_self))">; +class StaticShared1DMemRefOf allowedTypes> : + ConfinedType, [HasStaticShapePred, isSharedPred], + "reside in share memory and statically 1d shaped " # MemRefOf.summary # " ", + "mlir::MemRefType">; + +class StaticShared2DMemRefOf allowedTypes>: + ConfinedType, [HasStaticShapePred, isSharedPred], + "reside in share memory and statically 2d shaped " # MemRefOf.summary # " ", + "mlir::MemRefType">; + def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", [ShapedTypeInterface], "::mlir::TensorType"> { let summary = "TensorDesc describing regions of interested data."; diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 92f353717ac59..67faa60f2835e 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -836,7 +836,7 @@ func.func @slice_attr_repeat_dim() { // ----- func.func @create_mem_desc_non_slm() { %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 1> - // expected-error@+1 {{operand #0 must be statically shaped memref of 8-bit signless integer values for shared memory}} + // expected-error@+1 {{operand #0 must be reside in share memory and statically 1d shaped memref }} %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 1> -> !xegpu.mem_desc<16x64xf16> return } From 8138f3ad15272058b79830e8e4fb98781537b765 Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Tue, 18 Nov 2025 19:14:15 +0000 Subject: [PATCH 5/6] change documentation --- mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index ed166260fb689..1ae1feed177ae 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1298,7 +1298,8 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure, as the underlying shared local memory. Arguments: - - `source` : 1D or 2D statically shape memref, representing the raw SLM buffer. When the source is provided as 1D memref, its type must be i8. + - `source` : 1D or 2D statically shape memref, representing the raw SLM buffer. + The provided memref must be contiguous. Results: - `mem_desc` : the memory descriptor. }]; From 164c72f6744d89eb79b56525f211765dbd471d29 Mon Sep 17 00:00:00 2001 From: Jianhui Li Date: Tue, 18 Nov 2025 19:51:08 +0000 Subject: [PATCH 6/6] using implict type converter for memref input --- mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 13 +++---------- .../Conversion/XeGPUToXeVM/loadstore_matrix.mlir | 5 ++--- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index c01078ca6938e..4588cf46a9d46 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -579,9 +579,6 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { } }; -// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions -// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than -// 32 bits will be converted to 32 bits. class CreateMemDescOpPattern final : public OpConversionPattern { public: @@ -590,12 +587,7 @@ class CreateMemDescOpPattern final matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Value baseAddr = memref::ExtractAlignedPointerAsIndexOp::create( - rewriter, op.getLoc(), op.getSource()); - auto baseAddr32 = arith::IndexCastUIOp::create( - rewriter, op.getLoc(), rewriter.getI32Type(), baseAddr); - - rewriter.replaceOp(op, baseAddr32); + rewriter.replaceOp(op, adaptor.getSource()); return success(); } }; @@ -1000,7 +992,8 @@ struct ConvertXeGPUToXeVMPass }); typeConverter.addConversion([&](MemRefType type) -> Type { - // Convert MemRefType to i64 type. + if (type.getMemorySpaceAsInt() == 3) + return IntegerType::get(&getContext(), 32); return IntegerType::get(&getContext(), 64); }); diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir index 5b3776ba0039c..ac95a1a5707ea 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir @@ -179,9 +179,8 @@ gpu.module @test_kernel [#xevm.target] { gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index - //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 + //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32 %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> - //CHECK: %[[c16:.*]] = arith.constant 16 : index //CHECK: %[[c48:.*]] = arith.constant 48 : index @@ -207,7 +206,7 @@ gpu.module @test_kernel [#xevm.target] { //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32 //CHECK: %[[c2:.*]] = arith.constant 2 : i32 //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32 - //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i32 + //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32]], %[[byteOffset]] : i32 //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3> //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16> //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>