diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 7f1ec17ce0ae8..9c99a24bea8cd 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -50,11 +50,10 @@ static constexpr int32_t executionSize{16};
 
 // Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
 enum class NdTdescOffset : uint32_t {
-  BasePtr = 0,       // Base pointer (i64)
-  BaseShapeW = 2,    // Base shape width (i32)
-  BaseShapeH = 3,    // Base shape height (i32)
-  TensorOffsetW = 4, // Tensor offset W (i32)
-  TensorOffsetH = 5  // Tensor offset H (i32)
+  BasePtr = 0,    // Base pointer (i64)
+  BaseShapeW = 2, // Base shape width (i32)
+  BaseShapeH = 3, // Base shape height (i32)
+  BasePitch = 4,  // Base pitch (i32)
 };
 
 static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
@@ -179,11 +178,10 @@ class CreateNdDescToXeVMPattern
     Value baseAddr;
     Value baseShapeW;
     Value baseShapeH;
-    Value offsetW;
-    Value offsetH;
     // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
     // Descriptor shape is expected to be 2D.
     int64_t rank = mixedSizes.size();
     auto sourceTy = source.getType();
@@ -216,12 +214,11 @@ class CreateNdDescToXeVMPattern
       val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
       return val;
     };
-    // Offsets are not supported (0 is used).
-    offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-    offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, 1);
     baseShapeH = createOffset(mixedSizes, 0);
+    // Get pitch value from op fold results.
+    Value basePitch = createOffset(mixedStrides, 0);
     // Populate payload.
     Value payLoadAsI64 =
         vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
@@ -235,12 +232,9 @@ class CreateNdDescToXeVMPattern
     payload =
         vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
                                  static_cast<int>(NdTdescOffset::BaseShapeH));
-    payload = vector::InsertOp::create(
-        rewriter, loc, offsetW, payload,
-        static_cast<int>(NdTdescOffset::TensorOffsetW));
-    payload = vector::InsertOp::create(
-        rewriter, loc, offsetH, payload,
-        static_cast<int>(NdTdescOffset::TensorOffsetH));
+    payload =
+        vector::InsertOp::create(rewriter, loc, basePitch, payload,
+                                 static_cast<int>(NdTdescOffset::BasePitch));
     rewriter.replaceOp(op, payload);
     return success();
   }
@@ -289,6 +283,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
     Value baseShapeH = vector::ExtractOp::create(
        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+    Value basePitch = vector::ExtractOp::create(
+        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
     // Offsets are provided by the op.
     // convert them to i32.
     Value offsetW =
@@ -303,8 +299,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     Value basePtrLLVM =
         LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
     // Compute width in bytes.
-    Value surfaceW =
+    Value baseWidthByte =
         arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+    // Compute pitch in bytes.
+    Value basePitchByte =
+        arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
     // Get tile width from the tensor descriptor type.
     auto tileW = tdescTy.getDimSize(tileRank - 1);
@@ -331,8 +330,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
       auto storeCacheControl =
          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
       xevm::BlockStore2dOp::create(
-          rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-          offsetH, elemBitSize, tileW, tileH, src,
+          rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
+          basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH, src,
           xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
       rewriter.eraseOp(op);
     } else {
@@ -340,9 +339,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
           translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
       if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
         xevm::BlockPrefetch2dOp::create(
-            rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
-            offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
-            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+            rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
+            basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
+            vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
         rewriter.eraseOp(op);
       } else {
         VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
@@ -355,9 +354,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
                 : rewriter.getIntegerType(elemBitSize));
         Value resultFlatVec = xevm::BlockLoad2dOp::create(
-            rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
-            surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
-            transpose, vnni,
+            rewriter, loc, loadedTy, basePtrLLVM, baseWidthByte, baseShapeH,
+            basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
+            vblocks, transpose, vnni,
             xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
         resultFlatVec = vector::BitCastOp::create(
             rewriter, loc,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 8b87b791c9fd3..9a1e2cb3c7de0 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -8,21 +8,19 @@ gpu.module @create_nd_tdesc {
   gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index, %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
     // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
-    // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
+    // CHECK: %[[DYN_ADDR:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
     // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
     // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
     // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32
-    // CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32
     // CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32
     // CHECK: %[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32
+    // CHECK: %[[PITCH:.*]] = arith.index_cast %[[ARG4]] : index to i32
     // CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
     // CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64>
     // CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
     // CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32>
     // CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32>
+    // CHECK: %[[VAR11:.*]] = vector.insert %[[PITCH]], %[[VAR10]] [4] : i32 into vector<8xi32>
     %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2] : ui64 -> !xegpu.tensor_desc<8x16xf32>
@@ -32,19 +30,18 @@ gpu.module @create_nd_tdesc {
     // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
     // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
     // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
     // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
-    // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
     // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
     // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
     // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
     // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
+    // CHECK: %[[C32_I64_2:.*]] = arith.constant 32 : i64
+    // CHECK: %[[PITCH2:.*]] = arith.trunci %[[C32_I64_2]] : i64 to i32
     // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
     // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
     // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
     // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
     // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+    // CHECK: %[[VAR19:.*]] = vector.insert %[[PITCH2]], %[[VAR18]] [4] : i32 into vector<8xi32>
     %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
 
     // CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -53,18 +50,16 @@ gpu.module @create_nd_tdesc {
     %size_x = arith.constant 64 : index
     // CHECK: %[[C16:.*]] = arith.constant 16 : index
     %BLOCK_DMODEL = arith.constant 16 : index
-    // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
-    // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
-    // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
-    // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
-    // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[VAR29:.*]] = vector.insert %[[C0_I32_6]], %[[VAR28]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[VAR30:.*]] = vector.insert %[[C0_I32_7]], %[[VAR29]] [5] : i32 into vector<8xi32>
+    // CHECK: %[[CST_3:.*]] = arith.constant dense<0> : vector<8xi32>
+    // CHECK: %[[SHAPE_W3:.*]] = arith.index_cast %[[C16]] : index to i32
+    // CHECK: %[[SHAPE_H3:.*]] = arith.index_cast %[[C64]] : index to i32
+    // CHECK: %[[PITCH3:.*]] = arith.index_cast %[[C16]] : index to i32
+    // CHECK: %[[VAR25:.*]] = vector.bitcast %[[CST_3]] : vector<8xi32> to vector<4xi64>
+    // CHECK: %[[VAR26:.*]] = vector.insert %[[DYN_ADDR]], %[[VAR25]] [0] : i64 into vector<4xi64>
+    // CHECK: %[[VAR27:.*]] = vector.bitcast %[[VAR26]] : vector<4xi64> to vector<8xi32>
+    // CHECK: %[[VAR28:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR27]] [2] : i32 into vector<8xi32>
+    // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR28]] [3] : i32 into vector<8xi32>
+    // CHECK: %[[VAR30:.*]] = vector.insert %[[PITCH3]], %[[VAR29]] [4] : i32 into vector<8xi32>
     %dyn_tdesc = xegpu.create_nd_tdesc %dyn, shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16>
     gpu.return
   }
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index afeae8be24b72..4c73c9c238b6e 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -1,78 +1,32 @@
-// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
 
 gpu.module @load_store_check {
   // CHECK-LABEL: gpu.func @load_store(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: memref<8x16xf32, 1>) kernel {
   gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+    // CHECK: %[[W_P_BYTES:.*]] = arith.constant 64 : i32
+    // CHECK: %[[ZERO:.*]] = arith.constant 0 : i32
+    // CHECK: %[[H:.*]] = arith.constant 8 : i32
     %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
     %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
-    // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
-    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
-    // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
-    // CHECK: %[[MEMSPACECAST_0:.*]] = memref.memory_space_cast %[[ARG1]] : memref<8x16xf32, 1> to memref<8x16xf32>
-    // CHECK: %[[INTPTR_1:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST_0]] : memref<8x16xf32> -> index
-    // CHECK: %[[ST_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR_1]] : index to i64
-    // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
     %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-
-    //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
-    //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
-    //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
-    //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
-    //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
-    //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
-    //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
-    //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
-    //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
-    //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
-    //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
+    //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]]
    //CHECK-SAME: <{cache_control = #xevm.load_cache_control, elem_size_in_bits = 32 : i32,
     //CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
     //CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
     %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
-    //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>
     %tid_x = gpu.thread_id x
     %tid_x_i32 = arith.index_cast %tid_x : index to i32
     %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
-    //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
     %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
-    // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[DESC_0:.*]] = vector.insert %[[ST_PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
     %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
-    //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
-    //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
-    //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
-    //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
-    //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
-    //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
-    //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
-    //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
-    //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
-    //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
-    //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
-    //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]]
-    //CHECK-SAME: <{cache_control = #xevm.store_cache_control, elem_size_in_bits = 32 : i32,
+    //CHECK: xevm.blockstore2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]], %{{.*}} <{
+    //CHECK-SAME: cache_control = #xevm.store_cache_control, elem_size_in_bits = 32 : i32,
     //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
     xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
index e4b303087ea9b..43df721fb77a0 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
@@ -3,27 +3,16 @@ gpu.module @prefetch_nd_check {
   // CHECK-LABEL: gpu.func @prefetch_nd
   gpu.func @prefetch_nd(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
-    // CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.constant 64 : i32
-    // CHECK: %[[LD_CREATE_DESC_I64:.*]] = arith.constant dense<0> : vector<4xi64>
-    // CHECK: %[[PREF_BASE_H:.*]] = arith.constant 8 : i32
-    // CHECK: %[[PREF_BASE_W:.*]] = arith.constant 16 : i32
+    // CHECK: %[[BASE_WIDTH_PITCH_BYTES:.*]] = arith.constant 64 : i32
     // CHECK: %[[OFFSET_ZERO:.*]] = arith.constant 0 : i32
+    // CHECK: %[[BASE_H:.*]] = arith.constant 8 : i32
     %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
-    // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
-    // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[LD_DESC_2:.*]] = vector.insert %[[PREF_BASE_W]], %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC_3:.*]] = vector.insert %[[PREF_BASE_H]], %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC_4:.*]] = vector.insert %[[OFFSET_ZERO]], %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC:.*]] = vector.insert %[[OFFSET_ZERO]], %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
     %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr, #xegpu.layout>
-    //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
-    //CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
-    //CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1>
-    //CHECK: xevm.blockprefetch2d %[[PREF_LLVMPTR]], %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_BASE_H]],
-    //CHECK-SAME: %[[PREF_BASE_ROW_IN_BYTES]], %[[OFFSET_ZERO]], %[[OFFSET_ZERO]]
+    //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<1>
+    //CHECK: xevm.blockprefetch2d %[[LLVMPTR]], %[[BASE_WIDTH_PITCH_BYTES]], %[[BASE_H]],
+    //CHECK-SAME: %[[BASE_WIDTH_PITCH_BYTES]], %[[OFFSET_ZERO]], %[[OFFSET_ZERO]]
     //CHECK-SAME: <{cache_control = #xevm.load_cache_control, elem_size_in_bits = 32 : i32,
     //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}>
     //CHECK-SAME: : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
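
A minimal sketch of what the new BasePitch payload field enables (not part of the patch; the %ptr, %c8, %c16, %c64, %c1 values and the expected byte values below are illustrative assumptions). With the pitch carried through the payload, the base width and row pitch of the 2D block operations can differ, so a descriptor over a non-contiguous view lowers correctly:

// Hypothetical strided view: shape [8, 16] with strides [64, 1], i.e. a
// 16-element-wide tile whose consecutive rows are 64 elements apart.
%tdesc = xegpu.create_nd_tdesc %ptr, shape:[%c8, %c16], strides:[%c64, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
// Expected lowering sketch: xevm.blockload2d should receive a base width of
// 16 * 4 = 64 bytes but a pitch of 64 * 4 = 256 bytes; the old payload, which
// reused the width as the pitch, could only describe contiguous rows.
%val = xegpu.load_nd %tdesc[0, 0] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>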