diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 91b2ecf8922a3..da061b269daf7 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -82,6 +82,7 @@ #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h" #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h" +#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h" #include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" namespace mlir { diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 2058aba7f9e37..323af3e97e2d4 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1555,4 +1555,16 @@ def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> { let dependentDialects = ["LLVM::LLVMDialect"]; } +//===----------------------------------------------------------------------===// +// XeGPUToXeVM +//===----------------------------------------------------------------------===// + +def ConvertXeGPUToXeVMPass : Pass<"convert-xegpu-to-xevm"> { + let summary = "Convert XeGPU to XeVM dialect"; + let dependentDialects = ["xevm::XeVMDialect", "vector::VectorDialect", + "memref::MemRefDialect", "arith::ArithDialect", + "LLVM::LLVMDialect", "index::IndexDialect", + "gpu::GPUDialect", "scf::SCFDialect"]; +} + #endif // MLIR_CONVERSION_PASSES diff --git a/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h new file mode 100644 index 0000000000000..ddaaae82e03be --- /dev/null +++ b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h @@ -0,0 +1,27 @@ +//===-- XeGPUToXeVM.h - Convert XeGPU to XeVM dialect ---------_--*- C++-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVM_H_ +#define MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVM_H_ + +#include + +namespace mlir { +class DialectRegistry; +class LLVMTypeConverter; +class RewritePatternSet; +class Pass; + +#define GEN_PASS_DECL_CONVERTXEGPUTOXEVMPASS +#include "mlir/Conversion/Passes.h.inc" + +void populateXeGPUToXeVMConversionPatterns( + const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns); + +} // namespace mlir + +#endif // MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVM_H_ diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 171f7169fd41d..134fe8e14ca38 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -76,3 +76,4 @@ add_subdirectory(VectorToSCF) add_subdirectory(VectorToSPIRV) add_subdirectory(VectorToXeGPU) add_subdirectory(XeVMToLLVM) +add_subdirectory(XeGPUToXeVM) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt new file mode 100644 index 0000000000000..ed54b0bb5ee81 --- /dev/null +++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt @@ -0,0 +1,25 @@ +add_mlir_conversion_library(MLIRXeGPUToXeVM + XeGPUToXeVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeGPUToXeVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRFuncDialect + MLIRGPUDialect + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRXeVMDialect + MLIRVectorDialect + MLIRArithDialect + MLIRIndexDialect + MLIRXeGPUDialect + MLIRPass + MLIRTransforms +) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp new file mode 100644 index 0000000000000..d8dd09a6280c0 --- /dev/null +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -0,0 +1,1021 @@ +//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <numeric>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+// TODO: The values below are uArch-dependent; they should not be hardcoded.
+static constexpr int32_t systolicDepth{8};
+static constexpr int32_t executionSize{16};
+
+// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
+enum class NdTdescOffset : uint32_t {
+  BasePtr = 0,       // Base pointer (i64)
+  BaseShapeW = 2,    // Base shape width (i32)
+  BaseShapeH = 3,    // Base shape height (i32)
+  TensorOffsetW = 4, // Tensor offset W (i32)
+  TensorOffsetH = 5  // Tensor offset H (i32)
+};
+
+static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
+  switch (xeGpuMemspace) {
+  case xegpu::MemorySpace::Global:
+    return static_cast<int>(xevm::AddrSpace::GLOBAL);
+  case xegpu::MemorySpace::SLM:
+    return static_cast<int>(xevm::AddrSpace::SHARED);
+  }
+}
+
+// Get a flat vector type of the new element type with the same total bit width.
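+// For example, a vector<16xf16> re-encoded with an i32 element type becomes
+// vector<8xi32> (16 x 16 bits == 8 x 32 bits).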
+static VectorType encodeVectorTypeTo(VectorType currentVecType,
+                                     Type toElemType) {
+  auto elemType = currentVecType.getElementType();
+  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
+  auto newBitWidth = toElemType.getIntOrFloatBitWidth();
+  const int size =
+      currentVecType.getNumElements() * currentBitWidth / newBitWidth;
+  return VectorType::get(size, toElemType);
+}
+
+static xevm::LoadCacheControl
+translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+                            std::optional<xegpu::CachePolicy> L3hint) {
+  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
+  switch (L1hintVal) {
+  case xegpu::CachePolicy::CACHED:
+    if (L3hintVal == xegpu::CachePolicy::CACHED)
+      return xevm::LoadCacheControl::L1C_L2UC_L3C;
+    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::LoadCacheControl::L1C_L2UC_L3UC;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::UNCACHED:
+    if (L3hintVal == xegpu::CachePolicy::CACHED)
+      return xevm::LoadCacheControl::L1UC_L2UC_L3C;
+    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::STREAMING:
+    if (L3hintVal == xegpu::CachePolicy::CACHED)
+      return xevm::LoadCacheControl::L1S_L2UC_L3C;
+    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::LoadCacheControl::L1S_L2UC_L3UC;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::READ_INVALIDATE:
+    return xevm::LoadCacheControl::INVALIDATE_READ;
+  default:
+    llvm_unreachable("Unsupported cache control.");
+  }
+}
+
+static xevm::StoreCacheControl
+translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+                             std::optional<xegpu::CachePolicy> L3hint) {
+  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
+  switch (L1hintVal) {
+  case xegpu::CachePolicy::UNCACHED:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::STREAMING:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1S_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1S_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::WRITE_BACK:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::WRITE_THROUGH:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  default:
+    llvm_unreachable("Unsupported cache control.");
+  }
+}
+
+class CreateNdDescToXeVMPattern
+    : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op,
+                  xegpu::CreateNdDescOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto source = op.getSource();
+    // Op is lowered to a code sequence that populates the payload.
+    // Payload is an 8xi32 vector. Offsets to individual fields are defined in
+    // the NdTdescOffset enum.
+    Type payloadElemTy = rewriter.getI32Type();
+    VectorType payloadTy = VectorType::get(8, payloadElemTy);
+    Type i64Ty = rewriter.getI64Type();
+    // A 4xi64 view is used for inserting the base pointer.
+    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
+    // Initialize payload to zero.
+    Value payload = arith::ConstantOp::create(
+        rewriter, loc,
+        DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
+
+    Value baseAddr;
+    Value baseShapeW;
+    Value baseShapeH;
+    Value offsetW;
+    Value offsetH;
+
+    // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
+    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    // Descriptor shape is expected to be 2D.
+    int64_t rank = mixedSizes.size();
+    if (rank != 2)
+      return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
+    auto sourceTy = source.getType();
+    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
+    // If source is a memref, we need to extract the aligned pointer as index.
+    // Pointer type is passed as i32 or i64 by the type converter.
+    if (sourceMemrefTy) {
+      if (!sourceMemrefTy.hasStaticShape()) {
+        return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
+      }
+      baseAddr =
+          memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+    } else {
+      baseAddr = adaptor.getSource();
+    }
+    // Utility for creating offset values from an op fold result.
+    auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
+                            unsigned idx) -> Value {
+      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
+      val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
+      return val;
+    };
+    // Offsets can be either 2D or not provided (0 is used).
+    if (mixedOffsets.size() == 2) {
+      offsetW = createOffset(mixedOffsets, 1);
+      offsetH = createOffset(mixedOffsets, 0);
+    } else if (mixedOffsets.size() == 0) {
+      offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+      offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+    } else {
+      return rewriter.notifyMatchFailure(op,
+                                         "Expected 2D offsets or no offsets.");
+    }
+    // Get shape values from op fold results.
+    baseShapeW = createOffset(mixedSizes, 1);
+    baseShapeH = createOffset(mixedSizes, 0);
+    if (sourceMemrefTy) {
+      // Cast index to i64.
+      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
+    } else if (baseAddr.getType() != i64Ty) {
+      // Pointer type may be i32. Cast to i64 if needed.
+      baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
+    }
+    // Populate payload.
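+    // The payload is first viewed as 4xi64 so the 64-bit base address can be
+    // inserted as a single element; it is then bitcast back to 8xi32 to insert
+    // the 32-bit shape and offset fields.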
+ Value payLoadAsI64 = + vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload); + payLoadAsI64 = + vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64, + static_cast(NdTdescOffset::BasePtr)); + payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64); + payload = + vector::InsertOp::create(rewriter, loc, baseShapeW, payload, + static_cast(NdTdescOffset::BaseShapeW)); + payload = + vector::InsertOp::create(rewriter, loc, baseShapeH, payload, + static_cast(NdTdescOffset::BaseShapeH)); + payload = vector::InsertOp::create( + rewriter, loc, offsetW, payload, + static_cast(NdTdescOffset::TensorOffsetW)); + payload = vector::InsertOp::create( + rewriter, loc, offsetH, payload, + static_cast(NdTdescOffset::TensorOffsetH)); + rewriter.replaceOp(op, payload); + return success(); + } +}; + +class UpdateNdOffsetToXeVMPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::UpdateNdOffsetOp op, + xegpu::UpdateNdOffsetOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto mixedOffsets = op.getMixedOffsets(); + // Only 2D offsets are supported for now. + if (mixedOffsets.size() != 2) + return rewriter.notifyMatchFailure(op, "Expected 2D offsets."); + auto tdesc = adaptor.getTensorDesc(); + // Utility for updating payload offset values from op fold result. + auto updateOffset = [&](unsigned idx, int payloadPos) -> Value { + Value offset = + getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]); + offset = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI32Type(), offset); + Value oldOffset = + vector::ExtractOp::create(rewriter, loc, tdesc, payloadPos); + Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset); + return vector::InsertOp::create(rewriter, loc, newOffset, tdesc, + payloadPos); + }; + // Update offsets in the payload. + auto val = updateOffset(0, static_cast(NdTdescOffset::TensorOffsetH)); + val = updateOffset(1, static_cast(NdTdescOffset::TensorOffsetW)); + rewriter.replaceOp(op, val); + return success(); + } +}; + +template < + typename OpType, + typename = std::enable_if_t::value>> +class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + + auto tdesc = adaptor.getTensorDesc(); + auto tdescTy = op.getTensorDescType(); + if (tdescTy.getRank() != 2) + return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor."); + auto elemType = tdescTy.getElementType(); + auto elemBitSize = elemType.getIntOrFloatBitWidth(); + if (elemBitSize % 8 != 0) + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); + + VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type()); + Value payLoadAsI64 = + vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc); + Value basePtr = vector::ExtractOp::create( + rewriter, loc, payLoadAsI64, static_cast(NdTdescOffset::BasePtr)); + Value baseShapeW = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast(NdTdescOffset::BaseShapeW)); + Value baseShapeH = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast(NdTdescOffset::BaseShapeH)); + // Offsets provided in two ways: + // 1. 
Offsets are extracted from the tensor descriptor. + // 2. (Mixed) offsets which are provided by the op. + Value offsetW; + Value offsetH; + auto mixedOffsets = op.getMixedOffsets(); + int64_t opOffsetsSize = mixedOffsets.size(); + if (opOffsetsSize != 0 && opOffsetsSize != 2) + return rewriter.notifyMatchFailure(op, + "Expected 2D offsets or no offsets."); + if (opOffsetsSize) { + // If mixed offsets are provided by the op convert them to i32. + offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]); + offsetW = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI32Type(), offsetW); + offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); + offsetH = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI32Type(), offsetH); + } else { + // If offsets are not available, we need to extract them from the tensor + // descriptor. + offsetW = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast(NdTdescOffset::TensorOffsetW)); + offsetH = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast(NdTdescOffset::TensorOffsetH)); + } + // Get address space from tensor descriptor memory space. + auto ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); + // Convert base pointer (i64) to LLVM pointer type. + Value basePtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr); + // Compute element byte size and surface width in bytes. + Value elemByteSize = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI32Type(), elemBitSize / 8); + Value surfaceW = + arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize); + + // Get tile sizes and vblocks from the tensor descriptor type. + auto tileW = tdescTy.getDimSize(1); + auto tileH = tdescTy.getDimSize(0); + int32_t vblocks = tdescTy.getArrayLength(); + if constexpr (std::is_same_v) { + VectorType srcVecTy = dyn_cast(adaptor.getValue().getType()); + if (!srcVecTy) + return rewriter.notifyMatchFailure( + op, "Expected store value to be a vector type."); + auto storeCacheControl = + translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + Value src = adaptor.getValue(); + // Get flat vector type of integer type with matching element bit size. + VectorType newSrcVecTy = + encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize)); + if (srcVecTy != newSrcVecTy) + src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src); + xevm::BlockStore2dOp::create( + rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW, + offsetH, elemBitSize, tileW, tileH, src, + xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl)); + rewriter.eraseOp(op); + } else { + auto loadCacheControl = + translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + if constexpr (std::is_same_v) { + xevm::BlockPrefetch2dOp::create( + rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW, + offsetH, elemBitSize, tileW, tileH, vblocks, + xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); + rewriter.eraseOp(op); + } else { + VectorType dstVecTy = cast(op.getValue().getType()); + const bool vnni = op.getPacked().value_or(false); + auto transposeValue = op.getTranspose(); + bool transpose = + transposeValue.has_value() && transposeValue.value()[0] == 1; + VectorType loadedTy = encodeVectorTypeTo( + dstVecTy, vnni ? 
rewriter.getI32Type() + : rewriter.getIntegerType(elemBitSize)); + + Value resultFlatVec = xevm::BlockLoad2dOp::create( + rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH, + surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks, + transpose, vnni, + xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); + resultFlatVec = vector::BitCastOp::create( + rewriter, loc, + encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()), + resultFlatVec); + rewriter.replaceOp(op, resultFlatVec); + } + } + return success(); + } +}; + +// Add a builder that creates +// offset * elemByteSize + baseAddr +static Value addOffset(ConversionPatternRewriter &rewriter, Location loc, + Value baseAddr, Value offset, int64_t elemByteSize) { + Value byteSize = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI64Type(), elemByteSize); + Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize); + Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset); + return newAddr; +} + +class CreateDescToXeVMPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto eTy = op.getTensorDescType().getElementType(); + auto eBw = eTy.getIntOrFloatBitWidth(); + if (eBw % 8 != 0) + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); + auto loc = op.getLoc(); + // Offsets are provided as scalar i64 by type converter. + auto offsets = adaptor.getOffsets(); + // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32). + // But type converter will convert them to integer types. + Value addr = adaptor.getSource(); + // ui32 or i32 are passed as i32 so they need to be casted to i64. + if (addr.getType() != rewriter.getI64Type()) + addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr); + auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8); + rewriter.replaceOp(op, laneAddr); + return success(); + } +}; + +class UpdateOffsetToXeVMPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::UpdateOffsetOp op, + xegpu::UpdateOffsetOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto eTy = op.getTensorDescType().getElementType(); + auto eBw = eTy.getIntOrFloatBitWidth(); + if (eBw % 8 != 0) + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); + auto loc = op.getLoc(); + // Scatter descriptor is provided as scalar i64 by type converter. + // Offsets are provided as scalar i64 by type converter. + Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(), + adaptor.getOffsets(), eBw / 8); + rewriter.replaceOp(op, newOffset); + return success(); + } +}; + +template ::value>> +class LoadStoreToXeVMPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + auto tdescTy = op.getTensorDescType(); + Value basePtrI64; + // Load result or Store valye Type can be vector or scalar. 
+ Type valOrResTy; + if constexpr (std::is_same_v) + valOrResTy = op.getResult().getType(); + else + valOrResTy = adaptor.getValue().getType(); + VectorType valOrResVecTy = dyn_cast(valOrResTy); + bool hasScalarVal = !valOrResVecTy; + int64_t elemBitWidth = + hasScalarVal ? valOrResTy.getIntOrFloatBitWidth() + : valOrResVecTy.getElementType().getIntOrFloatBitWidth(); + // Element type must be multiple of 8 bits. + if (elemBitWidth % 8 != 0) + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); + int64_t elemByteSize = elemBitWidth / 8; + // Default memory space is global. + LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global)); + // If tensor descriptor is available, we use its memory space. + if (tdescTy) + ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); + // Base pointer can come from source (load) or dest (store). + // If they are memrefs, we use their memory space. + if constexpr (std::is_same_v) { + basePtrI64 = adaptor.getSource(); + if (auto memRefTy = dyn_cast(op.getSource().getType())) { + auto addrSpace = memRefTy.getMemorySpaceAsInt(); + if (addrSpace != 0) + ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace); + } + } else { + basePtrI64 = adaptor.getDest(); + if (auto memRefTy = dyn_cast(op.getDest().getType())) { + auto addrSpace = memRefTy.getMemorySpaceAsInt(); + if (addrSpace != 0) + ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace); + } + } + // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed. + if (basePtrI64.getType() != rewriter.getI64Type()) { + basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), + basePtrI64); + } + Value offsets = adaptor.getOffsets(); + Value mask = adaptor.getMask(); + if (offsets) { + if (dyn_cast(offsets.getType())) { + // Offset needs be scalar. Single element vector is converted to scalar + // by type converter. + return rewriter.notifyMatchFailure(op, + "Expected offsets to be a scalar."); + } else { + // If offsets are provided, we add them to the base pointer. + // Offsets are in number of elements, we need to multiply by + // element byte size. + basePtrI64 = + addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize); + } + } + // Convert base pointer (i64) to LLVM pointer type. + Value basePtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); + + Value maskForLane; + VectorType maskVecTy = dyn_cast(mask.getType()); + if (maskVecTy) { + // Mask needs be scalar. Single element vector is converted to scalar by + // type converter. + return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar."); + } else + maskForLane = mask; + if constexpr (std::is_same_v) { + scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy}, + maskForLane, true, true); + // If mask is true,- then clause - load from memory and yield. + rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + if (!hasScalarVal) + valOrResTy = VectorType::get({valOrResVecTy.getNumElements()}, + valOrResVecTy.getElementType()); + Value loaded = + LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM); + // Set cache control attribute on the load operation. 
+ loaded.getDefiningOp()->setAttr( + "cache_control", xevm::LoadCacheControlAttr::get( + ctxt, translateLoadXeGPUCacheHint( + op.getL1Hint(), op.getL3Hint()))); + scf::YieldOp::create(rewriter, loc, ValueRange{loaded}); + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + // If mask is false - else clause -yield a vector of zeros. + auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType(); + TypedAttr eVal; + if (eTy.isFloat()) + eVal = FloatAttr::get(eTy, 0.0); + else + eVal = IntegerAttr::get(eTy, 0); + if (hasScalarVal) + loaded = arith::ConstantOp::create(rewriter, loc, eVal); + else + loaded = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal)); + scf::YieldOp::create(rewriter, loc, ValueRange{loaded}); + rewriter.replaceOp(op, ifOp.getResult(0)); + } else { + // If mask is true, perform the store. + scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false); + auto body = ifOp.getBody(); + rewriter.setInsertionPointToStart(body); + auto storeOp = + LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM); + // Set cache control attribute on the store operation. + storeOp.getOperation()->setAttr( + "cache_control", xevm::StoreCacheControlAttr::get( + ctxt, translateStoreXeGPUCacheHint( + op.getL1Hint(), op.getL3Hint()))); + rewriter.eraseOp(op); + } + return success(); + } +}; + +class PrefetchToXeVMPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + auto tdescTy = op.getTensorDescType(); + Value basePtrI64 = adaptor.getSource(); + // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed. + if (basePtrI64.getType() != rewriter.getI64Type()) + basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), + basePtrI64); + Value offsets = adaptor.getOffsets(); + if (offsets) { + VectorType offsetsVecTy = dyn_cast(offsets.getType()); + if (offsetsVecTy) { + // Offset needs be scalar. + return rewriter.notifyMatchFailure(op, + "Expected offsets to be a scalar."); + } else { + int64_t elemBitWidth{0}; + int64_t elemByteSize; + // Element byte size can come from three sources: + if (tdescTy) { + // If tensor descriptor is available, we use its element type to + // determine element byte size. + elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth(); + } else if (auto memRefTy = dyn_cast(op.getSourceType())) { + // If memref is available, we use its element type to + // determine element byte size. + elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth(); + } else { + // Otherwise, we use the provided offset byte alignment. + elemByteSize = *op.getOffsetAlignByte(); + } + if (elemBitWidth != 0) { + if (elemBitWidth % 8 != 0) + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); + elemByteSize = elemBitWidth / 8; + } + basePtrI64 = + addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize); + } + } + // Default memory space is global. + LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global)); + // If tensor descriptor is available, we use its memory space. 
+ if (tdescTy) + ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); + // If source is a memref, we use its memory space. + if (auto memRefTy = dyn_cast(op.getSource().getType())) { + auto addrSpace = memRefTy.getMemorySpaceAsInt(); + if (addrSpace != 0) + ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace); + } + // Convert base pointer (i64) to LLVM pointer type. + Value ptrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); + // Create the prefetch op with cache control attribute. + xevm::PrefetchOp::create( + rewriter, loc, ptrLLVM, + xevm::LoadCacheControlAttr::get( + ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()))); + rewriter.eraseOp(op); + return success(); + } +}; + +class FenceToXeVMPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + xevm::MemScope memScope{xevm::MemScope::WORKGROUP}; + switch (op.getFenceScope()) { + case xegpu::FenceScope::Workgroup: + memScope = xevm::MemScope::WORKGROUP; + break; + case xegpu::FenceScope::GPU: + memScope = xevm::MemScope::DEVICE; + break; + } + xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL}; + switch (op.getMemoryKind()) { + case xegpu::MemorySpace::Global: + addrSpace = xevm::AddrSpace::GLOBAL; + break; + case xegpu::MemorySpace::SLM: + addrSpace = xevm::AddrSpace::SHARED; + break; + } + xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace); + rewriter.eraseOp(op); + return success(); + } +}; + +class DpasToXeVMPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + auto aTy = cast(op.getLhs().getType()); + auto bTy = cast(op.getRhs().getType()); + auto resultType = cast(op.getResultType()); + + auto encodePrecision = [&](Type type) -> xevm::ElemType { + if (type == rewriter.getBF16Type()) + return xevm::ElemType::BF16; + else if (type == rewriter.getF16Type()) + return xevm::ElemType::F16; + else if (type == rewriter.getTF32Type()) + return xevm::ElemType::TF32; + else if (type.isInteger(8)) { + if (type.isUnsignedInteger()) + return xevm::ElemType::U8; + return xevm::ElemType::S8; + } else if (type == rewriter.getF32Type()) + return xevm::ElemType::F32; + else if (type.isInteger(32)) + return xevm::ElemType::S32; + llvm_unreachable("add more support for ElemType"); + }; + xevm::ElemType precATy = encodePrecision(aTy.getElementType()); + xevm::ElemType precBTy = encodePrecision(bTy.getElementType()); + Value c = op.getAcc(); + if (!c) { + auto elementTy = resultType.getElementType(); + Attribute initValueAttr; + if (isa(elementTy)) + initValueAttr = FloatAttr::get(elementTy, 0.0); + else + initValueAttr = IntegerAttr::get(elementTy, 0); + c = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr)); + } + + Value aVec = op.getLhs(); + Value bVec = op.getRhs(); + auto cvecty = cast(c.getType()); + xevm::ElemType precCTy = encodePrecision(cvecty.getElementType()); + xevm::ElemType precDTy = encodePrecision(resultType.getElementType()); + VectorType cNty = + VectorType::get(cvecty.getNumElements(), cvecty.getElementType()); + if (cvecty != cNty) + c = 
vector::ShapeCastOp::create(rewriter, loc, cNty, c); + Value dpasRes = xevm::MMAOp::create( + rewriter, loc, cNty, aVec, bVec, c, + xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize, + systolicDepth * + getNumOperandsPerDword(precATy)), + xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy)); + if (cvecty != cNty) + dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes); + rewriter.replaceOp(op, dpasRes); + return success(); + } + +private: + static unsigned getNumOperandsPerDword(xevm::ElemType pTy) { + switch (pTy) { + case xevm::ElemType::TF32: + return 1; + case xevm::ElemType::BF16: + case xevm::ElemType::F16: + return 2; + case xevm::ElemType::U8: + case xevm::ElemType::S8: + return 4; + default: + llvm_unreachable("unsupported xevm::ElemType"); + } + } +}; + +static std::optional +matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) { + switch (arithKind) { + case arith::AtomicRMWKind::addf: + return LLVM::AtomicBinOp::fadd; + case arith::AtomicRMWKind::addi: + return LLVM::AtomicBinOp::add; + case arith::AtomicRMWKind::assign: + return LLVM::AtomicBinOp::xchg; + case arith::AtomicRMWKind::maximumf: + return LLVM::AtomicBinOp::fmax; + case arith::AtomicRMWKind::maxs: + return LLVM::AtomicBinOp::max; + case arith::AtomicRMWKind::maxu: + return LLVM::AtomicBinOp::umax; + case arith::AtomicRMWKind::minimumf: + return LLVM::AtomicBinOp::fmin; + case arith::AtomicRMWKind::mins: + return LLVM::AtomicBinOp::min; + case arith::AtomicRMWKind::minu: + return LLVM::AtomicBinOp::umin; + case arith::AtomicRMWKind::ori: + return LLVM::AtomicBinOp::_or; + case arith::AtomicRMWKind::andi: + return LLVM::AtomicBinOp::_and; + default: + return std::nullopt; + } +} + +class AtomicRMWToXeVMPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + auto tdesc = op.getTensorDesc().getType(); + auto ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace())); + Value basePtrI64 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc()); + Value basePtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); + VectorType srcOrDstVecTy = cast(op.getValue().getType()); + VectorType srcOrDstFlatVecTy = VectorType::get( + srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType()); + Value srcFlatVec = vector::ShapeCastOp::create( + rewriter, loc, srcOrDstFlatVecTy, op.getValue()); + auto atomicKind = matchSimpleAtomicOp(op.getKind()); + assert(atomicKind.has_value()); + Value resVec = srcFlatVec; + for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) { + auto val = vector::ExtractOp::create(rewriter, loc, resVec, i); + Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), + rewriter.getIndexAttr(i)); + Value currPtr = + LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM, + srcOrDstVecTy.getElementType(), basePtrLLVM, idx); + Value newVal = + LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr, + val, LLVM::AtomicOrdering::seq_cst); + resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i); + } + rewriter.replaceOp(op, resVec); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Pass Definition 
+//===----------------------------------------------------------------------===// + +struct ConvertXeGPUToXeVMPass + : public impl::ConvertXeGPUToXeVMPassBase { + using Base::Base; + + void runOnOperation() override { + LLVMTypeConverter typeConverter(&getContext()); + typeConverter.addConversion([&](VectorType type) -> Type { + unsigned rank = type.getRank(); + auto elemType = type.getElementType(); + // If the element type is index, convert it to i64. + if (llvm::isa(elemType)) + elemType = IntegerType::get(&getContext(), 64); + // If the vector is a scalar or has a single element, return the element + if (rank < 1 || type.getNumElements() == 1) + return elemType; + // Otherwise, convert the vector to a flat vector type. + int64_t sum = + std::accumulate(type.getShape().begin(), type.getShape().end(), + int64_t{1}, std::multiplies()); + return VectorType::get(sum, elemType); + }); + typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type { + if (type.isScattered()) + return IntegerType::get(&getContext(), 64); + auto i32Type = IntegerType::get(&getContext(), 32); + return VectorType::get(8, i32Type); + }); + typeConverter.addConversion([&](MemRefType type) -> Type { + // Convert MemRefType to i64 type. + return IntegerType::get(&getContext(), 64); + }); + + // LLVM type converter puts unrealized casts for the following cases: + // add materialization casts to handle them. + + // Materialization to convert memref to i64 + auto memrefMaterializationCast = [](OpBuilder &builder, Type type, + ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return {}; + auto input = inputs.front(); + if (auto memrefTy = dyn_cast(input.getType())) { + + Value addr = + memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input); + return arith::IndexCastUIOp::create(builder, loc, type, addr) + .getResult(); + } + return {}; + }; + + // Materialization to convert ui64 to i64 + auto ui64MaterializationCast = [](OpBuilder &builder, Type type, + ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return {}; + auto input = inputs.front(); + if (input.getType() == builder.getIntegerType(64, false)) { + Value cast = + index::CastUOp::create(builder, loc, builder.getIndexType(), input) + .getResult(); + return arith::IndexCastUIOp::create(builder, loc, type, cast) + .getResult(); + } + return {}; + }; + + // Materialization to convert ui32 to i32 + auto ui32MaterializationCast = [](OpBuilder &builder, Type type, + ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return {}; + auto input = inputs.front(); + if (input.getType() == builder.getIntegerType(32, false)) { + Value cast = + index::CastUOp::create(builder, loc, builder.getIndexType(), input) + .getResult(); + return arith::IndexCastUIOp::create(builder, loc, type, cast) + .getResult(); + } + return {}; + }; + + // Materialization to convert + // - single element 1D vector to scalar + // - bitcast vector of same rank + // - shape vector of different rank but same element type + auto vectorMaterializationCast = [](OpBuilder &builder, Type type, + ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return {}; + auto input = inputs.front(); + if (auto vecTy = dyn_cast(input.getType())) { + if (vecTy.getNumElements() == 1) { + // If the vector has a single element, return the element type. 
+ Value cast = + vector::ExtractOp::create(builder, loc, input, 0).getResult(); + if (vecTy.getElementType() == builder.getIndexType()) + cast = arith::IndexCastUIOp::create(builder, loc, type, cast) + .getResult(); + return cast; + } else if (auto targetVecTy = dyn_cast(type)) { + // If the target type is a vector of same rank, + // bitcast to the target type. + if (targetVecTy.getRank() == vecTy.getRank()) + return vector::BitCastOp::create(builder, loc, targetVecTy, input) + .getResult(); + else if (targetVecTy.getElementType() == vecTy.getElementType()) { + // If the target type is a vector of different rank but same element + // type, reshape to the target type. + return vector::ShapeCastOp::create(builder, loc, targetVecTy, input) + .getResult(); + } + } + } + return {}; + }; + typeConverter.addSourceMaterialization(memrefMaterializationCast); + typeConverter.addSourceMaterialization(ui64MaterializationCast); + typeConverter.addSourceMaterialization(ui32MaterializationCast); + typeConverter.addSourceMaterialization(vectorMaterializationCast); + typeConverter.addTargetMaterialization(memrefMaterializationCast); + typeConverter.addTargetMaterialization(ui32MaterializationCast); + typeConverter.addTargetMaterialization(ui64MaterializationCast); + typeConverter.addTargetMaterialization(vectorMaterializationCast); + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalDialect(); + + RewritePatternSet patterns(&getContext()); + populateXeGPUToXeVMConversionPatterns(typeConverter, patterns); + scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, + patterns, target); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Pattern Population +//===----------------------------------------------------------------------===// +void mlir::populateXeGPUToXeVMConversionPatterns( + const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { + patterns.add, + LoadStorePrefetchNdToXeVMPattern, + LoadStorePrefetchNdToXeVMPattern>( + typeConverter, patterns.getContext()); + patterns.add, + LoadStoreToXeVMPattern>( + typeConverter, patterns.getContext()); + patterns.add(typeConverter, + patterns.getContext()); +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir new file mode 100644 index 0000000000000..4ff95b40fe68c --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir @@ -0,0 +1,66 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s + +gpu.module @create_nd_tdesc { + // CHECK-LABEL: gpu.func @create_nd_tdesc + // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64, + // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index + gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index, + %stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel { + // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index + // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32> + // CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32 + // CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32 + // CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32 + // CHECK: 
%[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32 + // CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64> + // CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64> + // CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32> + // CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32> + // CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32> + // CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32> + // CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32> + %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2] + : ui64 -> !xegpu.tensor_desc<8x16xf32> + + // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32> + %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32> + + // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32> + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index + // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32 + // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32 + // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64 + // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %c16_i64 : i64 to i32 + // CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64 + // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %c8_i64 : i64 to i32 + // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64> + // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64> + // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32> + // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32> + // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32> + // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32> + // CHECK: %[[VAR20:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32> + %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + + // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32> + // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index + // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32 + // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32 + // CHECK: %[[C16_I64_6:.*]] = arith.constant 16 : i64 + // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C16_I64_6]] : i64 to i32 + // CHECK: %[[C8_I64_7:.*]] = arith.constant 8 : i64 + // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C8_I64_7]] : i64 to i32 + // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64 + // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64> + // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64> + // CHECK: %[[VAR28:.*]] = vector.bitcast %[[VAR27]] : vector<4xi64> to vector<8xi32> + // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR28]] [2] : i32 into vector<8xi32> + // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32> + // 
CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32> + // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32> + %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + gpu.return + } +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir new file mode 100644 index 0000000000000..e6f22f0a9acbb --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir @@ -0,0 +1,18 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s + +#sg_map_a_f16 = #xegpu.layout +#sg_map_b_f16 = #xegpu.layout +#sg_map_c_f32 = #xegpu.layout + +gpu.module @load_store_check { + // CHECK-LABEL: func.func @dpas( + // CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32> + func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> { + // Loads are checked in a separate test. + // CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = , types = } + // CHECK-SAME: : (vector<8xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32> + %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32} + : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32> + return %d : vector<8xf32> + } +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/fence.mlir b/mlir/test/Conversion/XeGPUToXeVM/fence.mlir new file mode 100644 index 0000000000000..cedfcace398a6 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/fence.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s + +gpu.module @fence_check { + gpu.func @fence(%dst: memref<8x16xf32, 1>) kernel { + %tid_x = gpu.thread_id x + %tid_x_i32 = arith.index_cast %tid_x : index to i32 + %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32 + + // CHECK: xevm.memfence <{addrspace = #xevm.addr_space, scope = #xevm.mem_scope}> + xegpu.fence memory_kind = global, fence_scope = workgroup + %c0 = arith.constant 0 : index + memref.store %tid_x_f32, %dst[%c0, %c0] : memref<8x16xf32, 1> + gpu.return + } +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir new file mode 100644 index 0000000000000..4c6bbf25b4728 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir @@ -0,0 +1,75 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s + +gpu.module @load_store_check { + gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel { + %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32> + %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32> + + // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64 + // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64> + // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64> + // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32> + // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32> + // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32> + // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32> + // CHECK: %[[LD_DESC:.*]] = 
vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32> + %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + + + //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64> + //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64> + //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32> + //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32> + //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64 + //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32 + //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64 + //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32 + //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1> + //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32 + //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32 + //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]], + //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]] + //CHECK-SAME: <{cache_control = #xevm.load_cache_control, elem_size_in_bits = 32 : i32, + //CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false, + //CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32> + //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32> + + %tid_x = gpu.thread_id x + %tid_x_i32 = arith.index_cast %tid_x : index to i32 + %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32 + //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32> + %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32> + + // CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64 + // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64> + // CHECK: %[[DESC_0:.*]] = vector.insert %[[PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64> + // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32> + // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32> + // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32> + // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32> + // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32> + %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + + //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64> + //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64> + //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32> + //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32> + //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64 + //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32 + //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64 + //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32 + //CHECK: %[[LLVMPTR:.*]] = 
llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1> + //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32 + //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32 + //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32> + //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]], + //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]] + //CHECK-SAME: <{cache_control = #xevm.store_cache_control, elem_size_in_bits = 32 : i32, + //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) + xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + gpu.return + } +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir new file mode 100644 index 0000000000000..0f67dc290689b --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir @@ -0,0 +1,261 @@ +// RUN: mlir-opt %s --split-input-file -convert-xegpu-to-xevm | FileCheck %s + +gpu.module @test { +// CHECK-LABEL: @load_gather_ui64_src_constant_offset +// CHECK-SAME: %[[ARG0:.*]]: ui64 +gpu.func @load_gather_ui64_src_constant_offset(%src: ui64) { + // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> + // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64 + %0 = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[CST_0:.*]] = arith.constant dense : vector<1xi1> + // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> + %1 = arith.constant dense<1>: vector<1xi1> + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64 + // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64 + %2 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex> + -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1> + // CHECK: %[[VAR8:.*]] = scf.if %[[VAR4]] -> (vector<2xf32>) { + // CHECK: %[[VAR9:.*]] = llvm.load %[[VAR7]] {cache_control = #xevm.load_cache_control} + // CHECK-SAME: : !llvm.ptr<1> -> vector<2xf32> + // CHECK: scf.yield %[[VAR9]] : vector<2xf32> + // CHECK: } else { + // CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> + // CHECK: scf.yield %[[CST_1]] : vector<2xf32> + %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr>, vector<1xi1> -> vector<2xf32> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @load_gather_memref_src_constant_offset +// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32> +gpu.func @load_gather_memref_src_constant_offset(%src: memref<256xf32>) { + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index + // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> + // CHECK: %[[VAR1:.*]] = 
arith.index_castui %[[VAR0]] : index to i64 + %0 = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[CST_0:.*]] = arith.constant dense : vector<1xi1> + // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> + %1 = arith.constant dense<1>: vector<1xi1> + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 + // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 + %2 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> + // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (f32) { + // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control} + // CHECK-SAME: : !llvm.ptr<1> -> vector<1xf32> + // CHECK: %[[VAR9:.*]] = vector.extract %[[VAR8]][0] : f32 from vector<1xf32> + // CHECK: scf.yield %[[VAR9]] : f32 + // CHECK: } else { + // CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32> + // CHECK: %[[VAR8:.*]] = vector.extract %[[CST_1:.*]][0] : f32 from vector<1xf32> + // CHECK: scf.yield %[[VAR8]] : f32 + %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1> -> vector<1xf32> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @load_gather_memref_src_value_offset +// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex> +gpu.func @load_gather_memref_src_value_offset(%src: memref<256xf16>, %offset: vector<1xindex>) { + // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index + // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense : vector<1xi1> + // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> + %1 = arith.constant dense<1>: vector<1xi1> + // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64 + // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64 + // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 + %2 = xegpu.create_tdesc %src, %offset : memref<256xf16>, vector<1xindex> + -> !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr> + // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> + // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (vector<8xf16>) { + // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control} + // CHECK-SAME: : !llvm.ptr<1> -> vector<8xf16> + // CHECK: scf.yield %[[VAR8]] : vector<8xf16> + // CHECK: } else { + // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<8xf16> + // CHECK: scf.yield %[[CST_0]] : vector<8xf16> + %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr>, vector<1xi1> -> vector<8xf16> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @store_scatter_ui64_src_constant_offset +// CHECK-SAME: %[[ARG0:.*]]: ui64 +gpu.func @store_scatter_ui64_src_constant_offset(%src: ui64) { + // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense<0> : 
vector<1xindex> + // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> + // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64 + %0 = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[CST_0:.*]] = arith.constant dense : vector<1xi1> + // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> + %1 = arith.constant dense<1>: vector<1xi1> + // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900000e+00> : vector<2xf32> + %2 = arith.constant dense<2.9>: vector<2xf32> + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64 + // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64 + %3 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex> + -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1> + // CHECK: scf.if %[[VAR4]] { + // CHECK: llvm.store %[[CST_1]], %[[VAR7]] {cache_control = #xevm.store_cache_control} + // CHECK-SAME: : vector<2xf32>, !llvm.ptr<1> + xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : vector<2xf32>, !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr>, vector<1xi1> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @store_scatter_memref_src_constant_offset +// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32> +gpu.func @store_scatter_memref_src_constant_offset(%src: memref<256xf32>) { + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index + // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 + %0 = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[CST_0:.*]] = arith.constant dense : vector<1xi1> + // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> + %1 = arith.constant dense<1>: vector<1xi1> + // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900390e+00> : vector<2xf16> + %2 = arith.constant dense<2.9>: vector<2xf16> + // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64 + // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64 + // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 + %3 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr> + // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> + // CHECK: scf.if %[[VAR2]] { + // CHECK: llvm.store %[[CST_1]], %[[VAR6]] {cache_control = #xevm.store_cache_control} + // CHECK-SAME: : vector<2xf16>, !llvm.ptr<1> + xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : vector<2xf16>, !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr>, vector<1xi1> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @store_scatter_memref_src_value_offset +// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex> +gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) { + // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index + // CHECK: %[[VAR3:.*]] = 
arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense : vector<1xi1> + // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> + %1 = arith.constant dense<1>: vector<1xi1> + // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32> + // CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f32 from vector<1xf32> + %2 = arith.constant dense<2.9>: vector<1xf32> + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 + // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 + %3 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> + // CHECK: scf.if %[[VAR2]] { + // CHECK: llvm.store %[[VAR7]], %[[VAR6]] {cache_control = #xevm.store_cache_control} + // CHECK-SAME: : f32, !llvm.ptr<1> + xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @prefetch_ui64_src_constant_offset +// CHECK-SAME: %[[ARG0:.*]]: ui64 +gpu.func @prefetch_ui64_src_constant_offset(%src: ui64) { + // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> + // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64 + %0 = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64 + // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR1]], %[[VAR4]] : i64 + %1 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex> + -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> + // CHECK: xevm.prefetch %[[VAR6]] <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr<1>) + xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @prefetch_memref_src_constant_offset +// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32> +gpu.func @prefetch_memref_src_constant_offset(%src: memref<256xf32>) { + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index + // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 + %0 = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 + // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64 + %1 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1> + // CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr<1>) + xegpu.prefetch 
%1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @prefetch_memref_src_value_offset +// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex> +gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) { + // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index + // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 + // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64 + %1 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1> + // CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr<1>) + xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + gpu.return +} +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir new file mode 100644 index 0000000000000..b28a8c2ccf843 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir @@ -0,0 +1,75 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm --split-input-file %s | FileCheck %s + +// This file contains tests for the materialization patterns that handle the custom type conversions +// added on top of the LLVM type converter.
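+// As a rough illustration of the casts exercised below (SSA names here are illustrative only):
+// a `ui64` source has no direct LLVM lowering, so it is expected to be routed through `index`
+// first, e.g.
+//   %0 = index.castu %arg0 : ui64 to index
+//   %1 = arith.index_castui %0 : index to i64
+// and single-element vectors such as `vector<1xindex>` are expected to be scalarized with
+// `vector.extract` before the scalar element is converted.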
+ +gpu.module @materializecast { + // CHECK-LABEL: gpu.func @materialize_memref + // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32> + gpu.func @materialize_memref(%src: memref<128xf32>) kernel { + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index + // CHECK: %[[CASTED:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + %offset = arith.constant dense<0> : vector<1xindex> + %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + gpu.return + } +} + +// ----- +gpu.module @materializecast { + // CHECK-LABEL: gpu.func @materialize_ui64 + // CHECK-SAME: %[[ARG0:.*]]: ui64 + gpu.func @materialize_ui64(%src: ui64) kernel { + // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 + %offset = arith.constant dense<0> : vector<1xindex> + %src_tdesc = xegpu.create_tdesc %src, %offset : ui64, vector<1xindex> + -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + gpu.return + } +} + +// ----- +gpu.module @materializecast { + // CHECK-LABEL: gpu.func @materialize_ui32 + // CHECK-SAME: %[[ARG0:.*]]: ui32 + gpu.func @materialize_ui32(%src: ui32) kernel { + // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui32 to index + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i32 + %offset = arith.constant dense<0> : vector<1xindex> + %src_tdesc = xegpu.create_tdesc %src, %offset : ui32, vector<1xindex> + -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + gpu.return + } +} + +// ----- +gpu.module @materializecast { + // CHECK-LABEL: gpu.func @materialize_single_index_vector + // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32> + gpu.func @materialize_single_index_vector(%src: memref<128xf32>) kernel { + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> + // CHECK: %[[VAR2:.*]] = arith.index_castui %[[VAR1]] : index to i64 + %offset = arith.constant dense<0> : vector<1xindex> + %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + gpu.return + } +} + +// ----- +gpu.module @materializecast { + // CHECK-LABEL: gpu.func @materialize_single_elem_vector + // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32> + gpu.func @materialize_single_elem_vector(%src: memref<128xf32>) kernel { + // CHECK: %[[CST:.*]] = arith.constant dense : vector<1xi1> + // CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> + %mask = arith.constant dense<1>: vector<1xi1> + %offset = arith.constant dense<0> : vector<1xindex> + %0 = xegpu.load %src[%offset], %mask <{chunk_size=8, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<8xf32> + gpu.return + } +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir new file mode 100644 index 0000000000000..873478aed57e3 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm -split-input-file %s | FileCheck %s + +gpu.module @fence_check { + gpu.func @fence(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel { + %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32> + %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to 
memref<8x16xf32> + + // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64 + // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64> + // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64> + // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32> + // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32> + // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32> + // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32> + // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32> + %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, + #xegpu.block_tdesc_attr, #xegpu.layout> + + //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64> + //CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64> + //CHECK: %[[PREF_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32> + //CHECK: %[[PREF_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32> + //CHECK: %[[PREF_TILE_W64:.*]] = arith.constant 0 : i64 + //CHECK: %[[PREF_TILE_W:.*]] = arith.trunci %[[PREF_TILE_W64]] : i64 to i32 + //CHECK: %[[PREF_TILE_H64:.*]] = arith.constant 0 : i64 + //CHECK: %[[PREF_TILE_H:.*]] = arith.trunci %[[PREF_TILE_H64]] : i64 to i32 + //CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1> + //CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32 + //CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[PREF_BASE_W]], %[[PREF_SIZEOF_F32]] : i32 + //CHECK: xevm.blockprefetch2d %[[PREF_LLVMPTR]], %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_BASE_H]], + //CHECK-SAME: %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_TILE_W]], %[[PREF_TILE_H]] + //CHECK-SAME: <{cache_control = #xevm.load_cache_control, elem_size_in_bits = 32 : i32, + //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}> + //CHECK-SAME: : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + xegpu.prefetch_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr, + #xegpu.layout> + + gpu.return + } +} + diff --git a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir new file mode 100644 index 0000000000000..6e59414c62582 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s + +gpu.module @update_offset { + // CHECK-LABEL: gpu.func @update_offset + // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32> + gpu.func @update_offset(%src: memref<128xf32>) kernel { + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index + // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> + %offset = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 + // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], 
%[[VAR3]] : i64 + %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + // CHECK: %[[C4_I64_0:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR1]], %[[C4_I64_0]] : i64 + // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR4]], %[[VAR5]] : i64 + %new_tdesc = xegpu.update_offset %src_tdesc, %offset : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + , vector<1xindex> + gpu.return + } +}
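+
+// Note: a sketch of the arithmetic exercised above, assuming an f32 element type and with
+// illustrative SSA names. The lowered descriptor keeps the base address as an i64, and
+// `xegpu.update_offset` is expected to reduce to plain pointer arithmetic:
+//   %byte_off = arith.muli %offset_i64, %c4_i64   // element offset scaled by sizeof(f32) = 4
+//   %new_addr = arith.addi %old_addr, %byte_off
+// mirroring the muli/addi pair in the CHECK lines.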