diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 33e8f2ed1f6ed..bc360c1a6929b 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -264,8 +264,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     auto tdesc = adaptor.getTensorDesc();
     auto tdescTy = op.getTensorDescType();
-    if (tdescTy.getRank() != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
     auto elemType = tdescTy.getElementType();
     auto elemBitSize = elemType.getIntOrFloatBitWidth();
     if (elemBitSize % 8 != 0)
@@ -294,71 +292,138 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
-    // Convert base pointer (i64) to LLVM pointer type.
-    Value basePtrLLVM =
-        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
-    // Compute element byte size and surface width in bytes.
+    // Compute element byte size.
     Value elemByteSize = arith::ConstantIntOp::create(
         rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
-    Value surfaceW =
-        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
-
-    // Get tile sizes and vblocks from the tensor descriptor type.
-    auto tileW = tdescTy.getDimSize(1);
-    auto tileH = tdescTy.getDimSize(0);
-    int32_t vblocks = tdescTy.getArrayLength();
-    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
-      Value src = adaptor.getValue();
-      // If store value is a scalar, get value from op instead of adaptor.
-      // Adaptor might have optimized away single element vector
-      if (src.getType().isIntOrFloat()) {
-        src = op.getValue();
-      }
-      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
-      if (!srcVecTy)
-        return rewriter.notifyMatchFailure(
-            op, "Expected store value to be a vector type.");
-      // Get flat vector type of integer type with matching element bit size.
-      VectorType newSrcVecTy =
-          encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
-      if (srcVecTy != newSrcVecTy)
-        src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
-      auto storeCacheControl =
-          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      xevm::BlockStore2dOp::create(
-          rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-          offsetH, elemBitSize, tileW, tileH, src,
-          xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
-      rewriter.eraseOp(op);
-    } else {
-      auto loadCacheControl =
-          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
-        xevm::BlockPrefetch2dOp::create(
-            rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-            offsetH, elemBitSize, tileW, tileH, vblocks,
-            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
-        rewriter.eraseOp(op);
-      } else {
-        VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
-        const bool vnni = op.getPacked().value_or(false);
-        auto transposeValue = op.getTranspose();
-        bool transpose =
-            transposeValue.has_value() && transposeValue.value()[0] == 1;
-        VectorType loadedTy = encodeVectorTypeTo(
-            dstVecTy, vnni ? rewriter.getI32Type()
-                           : rewriter.getIntegerType(elemBitSize));
-
-        Value resultFlatVec = xevm::BlockLoad2dOp::create(
-            rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
-            surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
-            transpose, vnni,
-            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
-        resultFlatVec = vector::BitCastOp::create(
-            rewriter, loc,
-            encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
-            resultFlatVec);
-        rewriter.replaceOp(op, resultFlatVec);
-      }
-    }
+    auto tileRank = tdescTy.getRank();
+    // Get tile width from the tensor descriptor type.
+    auto tileW = tdescTy.getDimSize(tileRank - 1);
+    if (tileRank == 2) {
+      // Convert base pointer (i64) to LLVM pointer type.
+      Value basePtrLLVM =
+          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+      // Compute surface width in bytes.
+      Value surfaceW =
+          arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+
+      // Get tile height from the tensor descriptor type.
+      auto tileH = tdescTy.getDimSize(0);
+      // Get vblocks from the tensor descriptor type.
+      int32_t vblocks = tdescTy.getArrayLength();
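+      // vblocks (the descriptor's array_length) requests multiple adjacent
+      // blocks per 2D block message; it is only forwarded to the load and
+      // prefetch ops below, as the 2D block store takes no block count.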
+      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+        Value src = adaptor.getValue();
+        // If store value is a scalar, get value from op instead of adaptor.
+        // Adaptor might have optimized away single element vector.
+        if (src.getType().isIntOrFloat()) {
+          src = op.getValue();
+        }
+        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+        if (!srcVecTy)
+          return rewriter.notifyMatchFailure(
+              op, "Expected store value to be a vector type.");
+        // Get flat vector type of integer type with matching element bit size.
+        VectorType newSrcVecTy =
+            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+        if (srcVecTy != newSrcVecTy)
+          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+        auto storeCacheControl =
+            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        xevm::BlockStore2dOp::create(
+            rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
+            offsetW, offsetH, elemBitSize, tileW, tileH, src,
+            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+        rewriter.eraseOp(op);
+      } else {
+        auto loadCacheControl =
+            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+          xevm::BlockPrefetch2dOp::create(
+              rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
+              offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+          rewriter.eraseOp(op);
+        } else {
+          VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+          const bool vnni = op.getPacked().value_or(false);
+          auto transposeValue = op.getTranspose();
+          bool transpose =
+              transposeValue.has_value() && transposeValue.value()[0] == 1;
+          VectorType loadedTy = encodeVectorTypeTo(
+              dstVecTy, vnni ? rewriter.getI32Type()
+                             : rewriter.getIntegerType(elemBitSize));
+
+          Value resultFlatVec = xevm::BlockLoad2dOp::create(
+              rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
+              surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+              transpose, vnni,
+              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+          resultFlatVec = vector::BitCastOp::create(
+              rewriter, loc,
+              encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+              resultFlatVec);
+          rewriter.replaceOp(op, resultFlatVec);
+        }
+      }
+    } else {
+      // Compute the address from the base address and the offsets. Offsets
+      // are in numbers of elements, so they must be scaled by the element
+      // byte size. First compute the linear element offset:
+      //   linearOffset = offsetH * baseShapeW + offsetW
+      Value offsetHInElems =
+          rewriter.createOrFold<arith::MulIOp>(loc, offsetH, baseShapeW);
+      Value linearOffset =
+          rewriter.createOrFold<arith::AddIOp>(loc, offsetHInElems, offsetW);
+      // Then compute the byte offset:
+      //   byteOffset = linearOffset * elemByteSize
+      Value byteOffset =
+          rewriter.createOrFold<arith::MulIOp>(loc, linearOffset, elemByteSize);
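+      // For example, with the shapes used in the loadstore_1d.mlir test
+      // added below: loading a <32xf32> descriptor over an 8x64xf32 base at
+      // offsets [1, 32] yields linearOffset = 1 * 64 + 32 = 96 elements and
+      // byteOffset = 96 * 4 = 384 bytes; storing at [4, 0] into an
+      // 8x32xf32 base yields 4 * 32 * 4 = 512 bytes.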
+      // Final address = basePtr + byteOffset.
+      Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
+          loc, basePtr,
+          getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
+                                          byteOffset));
+      // Convert the final address (i64) to an LLVM pointer type.
+      Value finalPtrLLVM =
+          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
+      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+        Value src = adaptor.getValue();
+        // If store value is a scalar, get value from op instead of adaptor.
+        // Adaptor might have optimized away single element vector.
+        if (src.getType().isIntOrFloat()) {
+          src = op.getValue();
+        }
+        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+        if (!srcVecTy)
+          return rewriter.notifyMatchFailure(
+              op, "Expected store value to be a vector type.");
+        // Get flat vector type of integer type with matching element bit size.
+        VectorType newSrcVecTy =
+            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+        if (srcVecTy != newSrcVecTy)
+          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+        auto storeCacheControl =
+            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
+            op, finalPtrLLVM, src,
+            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+      } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+        auto loadCacheControl =
+            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        VectorType resTy = cast<VectorType>(op.getValue().getType());
+        VectorType loadedTy =
+            encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
+        Value load = xevm::BlockLoadOp::create(
+            rewriter, loc, loadedTy, finalPtrLLVM,
+            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+        if (loadedTy != resTy)
+          load = vector::BitCastOp::create(rewriter, loc, resTy, load);
+        rewriter.replaceOp(op, load);
+      } else {
+        return rewriter.notifyMatchFailure(
+            op, "Unsupported operation: xegpu.prefetch_nd with tensor "
+                "descriptor rank == 1");
+      }
+    }
     return success();
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
new file mode 100644
index 0000000000000..54804ee6cd033
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @load_store_check {
+  // CHECK-LABEL: @load_store(
+  // CHECK-SAME: %[[SRC:.*]]: memref<8x64xf32, 1>, %[[DST:.*]]: memref<8x32xf32, 1>
+  gpu.func @load_store(%src: memref<8x64xf32, 1>, %dst: memref<8x32xf32, 1>) kernel {
+    // CHECK: %[[C512:.*]] = arith.constant 512 : i64
+    // CHECK: %[[C32:.*]] = arith.constant 32 : i32
+    // CHECK: %[[C384:.*]] = arith.constant 384 : i64
+    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi64>
+    // CHECK: %[[C8:.*]] = arith.constant 8 : i32
+    // CHECK: %[[C64:.*]] = arith.constant 64 : i32
+    // CHECK: %[[C0:.*]] = arith.constant 0 : i32
+
+    // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[SRC]] : memref<8x64xf32, 1> to memref<8x64xf32>
+    %srcce = memref.memory_space_cast %src : memref<8x64xf32, 1> to memref<8x64xf32>
+    // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[DST]] : memref<8x32xf32, 1> to memref<8x32xf32>
+    %dstte = memref.memory_space_cast %dst : memref<8x32xf32, 1> to memref<8x32xf32>
+
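+    // The descriptors created below lower to a vector<4xi64> payload: the
+    // i64 base address sits in lane 0 and, reinterpreted as vector<8xi32>,
+    // lanes 2..5 hold the base width, base height, and the X/Y offsets.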
+    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<8x64xf32> -> index
+    // CHECK: %[[INTPTR_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+    // CHECK: %[[VEC1:.*]] = vector.insert %[[INTPTR_I64]], %[[CST]] [0] : i64 into vector<4xi64>
+    // CHECK: %[[VEC2:.*]] = vector.bitcast %[[VEC1]] : vector<4xi64> to vector<8xi32>
+    // CHECK: %[[VEC3:.*]] = vector.insert %[[C64]], %[[VEC2]] [2] : i32 into vector<8xi32>
+    // CHECK: %[[VEC4:.*]] = vector.insert %[[C8]], %[[VEC3]] [3] : i32 into vector<8xi32>
+    // CHECK: %[[VEC5:.*]] = vector.insert %[[C0]], %[[VEC4]] [4] : i32 into vector<8xi32>
+    // CHECK: %[[VEC6:.*]] = vector.insert %[[C0]], %[[VEC5]] [5] : i32 into vector<8xi32>
+    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x64xf32> -> !xegpu.tensor_desc<32xf32>
+    // CHECK: %[[VEC7:.*]] = vector.bitcast %[[VEC6]] : vector<8xi32> to vector<4xi64>
+    // CHECK: %[[EXTR:.*]] = vector.extract %[[VEC7]][0] : i64 from vector<4xi64>
+    // CHECK: %[[ADDR:.*]] = arith.addi %[[EXTR]], %[[C384]] : i64
+    // CHECK: %[[PTR:.*]] = llvm.inttoptr %[[ADDR]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[LOAD:.*]] = xevm.blockload %[[PTR]] <{cache_control = #xevm.load_cache_control}>
+    // CHECK-SAME: : (!llvm.ptr<1>) -> vector<2xi32>
+    %loaded = xegpu.load_nd %src_tdesc[1, 32] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
+      : !xegpu.tensor_desc<32xf32> -> vector<2xf32>
+
+    // CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<8x32xf32> -> index
+    // CHECK: %[[INTPTR1_I64:.*]] = arith.index_castui %[[INTPTR1]] : index to i64
+    // CHECK: %[[VEC1_1:.*]] = vector.insert %[[INTPTR1_I64]], %[[CST]] [0] : i64 into vector<4xi64>
+    // CHECK: %[[VEC2_1:.*]] = vector.bitcast %[[VEC1_1]] : vector<4xi64> to vector<8xi32>
+    // CHECK: %[[VEC3_1:.*]] = vector.insert %[[C32]], %[[VEC2_1]] [2] : i32 into vector<8xi32>
+    // CHECK: %[[VEC4_1:.*]] = vector.insert %[[C8]], %[[VEC3_1]] [3] : i32 into vector<8xi32>
+    // CHECK: %[[VEC5_1:.*]] = vector.insert %[[C0]], %[[VEC4_1]] [4] : i32 into vector<8xi32>
+    // CHECK: %[[VEC6_1:.*]] = vector.insert %[[C0]], %[[VEC5_1]] [5] : i32 into vector<8xi32>
+    %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr>
+    // CHECK: %[[VEC7_1:.*]] = vector.bitcast %[[VEC6_1]] : vector<8xi32> to vector<4xi64>
+    // CHECK: %[[EXTR1:.*]] = vector.extract %[[VEC7_1]][0] : i64 from vector<4xi64>
+    // CHECK: %[[ADDR1:.*]] = arith.addi %[[EXTR1]], %[[C512]] : i64
+    // CHECK: %[[PTR1:.*]] = llvm.inttoptr %[[ADDR1]] : i64 to !llvm.ptr<1>
+    // CHECK: xevm.blockstore %[[PTR1]], %[[LOAD]] <{cache_control = #xevm.store_cache_control}>
+    // CHECK-SAME: : (!llvm.ptr<1>, vector<2xi32>)
+    xegpu.store_nd %loaded, %dst_tdesc[4, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
+      : vector<2xf32>, !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr>
+    gpu.return
+  }
+}