diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 5695d5d515d7f..601e966b49890 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -716,8 +716,30 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
       return getAttrs().getAs<ArrayAttr>("stride");
     }
 
+    ArrayAttr getBlockAttr() {
+      return getAttrs().getAs<ArrayAttr>("block");
+    }
+
   }];
 }
 
+def RowOriented : I32EnumAttrCase<"ROW", 0, "row">;
+def ColOriented : I32EnumAttrCase<"COL", 1, "col">;
+def MatrixAccessDirection :
+    I32EnumAttr<"MatrixAccessDirection",
+    "Matrix elements/vectors can have row or column direction", [
+      RowOriented, ColOriented
+]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::xegpu";
+}
+def MatrixAccessDirectionAttr :
+    EnumAttr<XeGPU_Dialect, MatrixAccessDirection, "matrix_access_direction">{
+  let summary = [{Describes the direction of memory access for load_matrix and store_matrix.}];
+  let assemblyFormat = "`<` $value `>`";
+}
+
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 73f9061f5debe..044a8ef22d891 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1298,14 +1298,16 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
 }
 
 def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
-                                 AllElementTypesMatch<["mem_desc", "res"]>,
-                                 AllRanksMatch<["mem_desc", "res"]>]> {
+                                 AllElementTypesMatch<["mem_desc", "res"]>]> {
   let arguments = (ins XeGPU_MemDesc:$mem_desc,
                        Variadic<Index>: $offsets,
                        DenseI64ArrayAttr: $const_offsets,
+                       OptionalAttr<I32Attr>:$vec_length,
+                       OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
+                       OptionalAttr<UnitAttr>:$subgroup_block_io,
                        OptionalAttr<DistributeLayoutAttr>:$layout
   );
-  let results = (outs XeGPU_ValueType:$res);
+  let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res);
   let assemblyFormat = [{ $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
                           prop-dict attr-dict `` `:` type(operands) `->` type(results)
@@ -1336,7 +1338,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
     }
 
     ArrayRef<int64_t> getDataShape() {
-      return getRes().getType().getShape();
+      auto resTy = getRes().getType();
+      if (auto vecTy = llvm::dyn_cast<VectorType>(resTy))
+        return vecTy.getShape();
+      return {};
     }
   }];
@@ -1344,13 +1349,15 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
 }
 
 def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
-                                 AllElementTypesMatch<["mem_desc", "data"]>,
-                                 AllRanksMatch<["mem_desc", "data"]>]> {
+                                 AllElementTypesMatch<["mem_desc", "data"]>]> {
   let arguments = (ins
-    XeGPU_ValueType:$data,
+    AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data,
     XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
+    OptionalAttr<I32Attr>:$vec_length,
+    OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
+    OptionalAttr<UnitAttr>:$subgroup_block_io,
     OptionalAttr<DistributeLayoutAttr>:$layout
   );
   let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
@@ -1378,7 +1385,10 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
     }
 
     ArrayRef<int64_t> getDataShape() {
-      return getData().getType().getShape();
+      auto DataTy = getData().getType();
+      if (auto vecTy = llvm::dyn_cast<VectorType>(DataTy))
+        return vecTy.getShape();
+      return {};
     }
   }];
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 84902b2039643..c261fbb576642 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -237,7 +237,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
       return MemDescType::get(getContext(), shape.value_or(getShape()),
                               elementType, getMemLayout());
     }
 
-    ArrayAttr getStrides() {
+    ArrayAttr getStridesAttr() {
       auto layout = getMemLayout();
       if (layout && layout.hasAttr("stride")) {
         return layout.getStrides();
@@ -250,6 +250,54 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
       Builder builder(getContext());
       return builder.getI64ArrayAttr(defaultStrides);
     }
+
+    /// Heuristic to determine if the MemDesc uses a column-major layout,
+    /// based on the rank and the value of the first stride dimension.
+    bool isColMajor() {
+      auto dim0 = dyn_cast<IntegerAttr>(getStridesAttr()[0]);
+      return getRank() == 2 && dim0 && dim0.getInt() == 1;
+    }
+
+    // Get the blocking shape of a MemDescType, which is represented
+    // as an attribute on the type. By default it is the shape of the
+    // MemDescType itself.
+    SmallVector<int64_t> getBlockSize() {
+      SmallVector<int64_t> size(getShape());
+      MemLayoutAttr layout = getMemLayout();
+      if (layout && layout.hasAttr("block")) {
+        ArrayAttr attr = layout.getBlockAttr();
+        size.clear();
+        llvm::for_each(attr, [&](Attribute elem) {
+          if (auto intElem = dyn_cast<IntegerAttr>(elem))
+            size.push_back(intElem.getInt());
+        });
+      }
+      return size;
+    }
+
+    // Get the strides as a vector of integers.
+    // If the layout contains a block attribute, the strides are blocked strides.
+    //
+    // The blocking is applied against the original matrix shape
+    // so that the linear offset is not impacted by the subview.
+    //
+    // It first computes the original matrix shape using the stride info,
+    // then computes the number of blocks in each dimension of the original
+    // shape, then computes the outer block shape and stride, and finally
+    // combines the inner and outer block shapes and strides.
+    // e.g.
for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]> + // its memory layout tuple is ([2,32,16,8],[128,256,1,16]) + // for mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1] + // its memory layout tuple is ([32,2,8,16],[256,128,16,1]) + SmallVector getStrides(); + + /// Generates instructions to compute the linearize offset + // if the memory descriptor is blocked, it returns linearize offset based on the blocked layout + // the strides of memory descriptor is always considered regardless of blocked or not + Value getLinearOffsets(OpBuilder &builder, + Location loc, ArrayRef offsets); + + }]; let hasCustomAssemblyFormat = true; diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt index 84b25809f1ed0..dd9edc43a1657 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt +++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt @@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM MLIRIndexDialect MLIRSCFDialect MLIRXeGPUDialect + MLIRXeGPUUtils MLIRPass MLIRTransforms MLIRSCFTransforms diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 9ead1d89069d6..05f26354e5a2a 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "llvm/Support/FormatVariadic.h" @@ -61,6 +62,7 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { case xegpu::MemorySpace::SLM: return static_cast(xevm::AddrSpace::SHARED); } + llvm_unreachable("Unknown XeGPU memory space"); } // Get same bitwidth flat vector type of new element type. @@ -182,8 +184,9 @@ class CreateNdDescToXeVMPattern SmallVector mixedSizes = op.getMixedSizes(); // Descriptor shape is expected to be 2D. int64_t rank = mixedSizes.size(); - if (rank != 2) + if (rank != 2) { return rewriter.notifyMatchFailure(op, "Expected 2D shape."); + } auto sourceTy = source.getType(); auto sourceMemrefTy = dyn_cast(sourceTy); // If source is a memref, we need to extract the aligned pointer as index. @@ -503,6 +506,187 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { } }; +// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions +// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than +// 32 bits will be converted to 32 bits. 
+class CreateMemDescOpPattern final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + TypedValue src = op.getSource(); + auto resTy = cast(op.getResult().getType()); + + // Create the result MemRefType with the same shape, element type, and + // memory space + auto newResTy = getTypeConverter()->convertType(resTy); + + Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); + auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, + Value(src), zero, ValueRange()); + rewriter.replaceOp(op, viewOp); + return success(); + } +}; + +class MemDescSubviewOpPattern final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::MemDescSubviewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriter.notifyMatchFailure( + op, "MemDescSubviewOp are not supported on Xe2/Xe3 architecture."); + } +}; + +template ::value>> +class LoadStoreMatrixToXeVMPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + SmallVector offsets = op.getMixedOffsets(); + if (offsets.empty()) + return rewriter.notifyMatchFailure(op, "Expected offset to be provided."); + + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + Value basePtrStruct = adaptor.getMemDesc(); + Value mdescVal = op.getMemDesc(); + // Load result or Store value Type can be vector or scalar. + Value data; + if constexpr (std::is_same_v) + data = op.getResult(); + else + data = adaptor.getData(); + VectorType valOrResVecTy = dyn_cast(data.getType()); + if (!valOrResVecTy) + valOrResVecTy = VectorType::get(1, data.getType()); + + int64_t elemBitWidth = + valOrResVecTy.getElementType().getIntOrFloatBitWidth(); + // Element type must be multiple of 8 bits. + if (elemBitWidth % 8 != 0) + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); + int64_t elemByteSize = elemBitWidth / 8; + + // Default memory space is SLM. + LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM)); + + auto mdescTy = cast(mdescVal.getType()); + + Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, loc, basePtrStruct); + + // Convert base pointer (ptr) to i64 + Value basePtrI64 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI64Type(), basePtrLLVM); + + Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); + linearOffset = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI64Type(), linearOffset); + basePtrI64 = + addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize); + + // convert base pointer (i64) to LLVM pointer type + basePtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); + + // if the size of valOrResVecTy is 1, it lowers to a scalar load/store + // operation. LLVM load/store does not support vector of size 1, so we need + // to handle this case separately. 
+ if (valOrResVecTy.getNumElements() == 1) { + Type scalarTy = valOrResVecTy.getElementType(); + if constexpr (std::is_same_v) { + Value loadOp = + LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM); + rewriter.replaceOp(op, loadOp); + } else { + LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM); + rewriter.eraseOp(op); + } + return success(); + } else { + // if the attribute 'subgroup_block_io' is set to true, it lowers to + // xevm.blockload + auto subgroupBlockIoAttr = op.getSubgroupBlockIoAttr(); + bool subgroup_block_io = static_cast(subgroupBlockIoAttr); + + // BlockLoadOp only supports integer types, so we need to bitcast + // Get integer type with matching bit width + Type elemTy = valOrResVecTy.getElementType(); + int64_t bitWidth = elemTy.getIntOrFloatBitWidth(); + Type intElemTy = rewriter.getIntegerType(bitWidth); + VectorType intVecTy = + VectorType::get(valOrResVecTy.getShape(), intElemTy); + + if (subgroup_block_io) { + if constexpr (std::is_same_v) { + Value loadOp = + xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM); + if (intVecTy != valOrResVecTy) { + loadOp = + vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp); + } + rewriter.replaceOp(op, loadOp); + } else { + Value dataToStore = adaptor.getData(); + if (valOrResVecTy != intVecTy) { + dataToStore = + vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore); + } + xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore, + nullptr); + rewriter.eraseOp(op); + } + } else { + // if the result is 1D vector, if the vector direction is Column, then + // the + // memory descriptor should be treated as column major + auto chipOpt = xegpu::getChipStr(op); + if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) { + // the lowering only works for pvc and bmg + return rewriter.notifyMatchFailure( + op, "The lowering is specific to pvc or bmg."); + } + xegpu::MatrixAccessDirectionAttr vecDirection = + op.getVecDirectionAttr(); + if (vecDirection && + vecDirection.getValue() == xegpu::MatrixAccessDirection::COL && + !mdescTy.isColMajor()) + return rewriter.notifyMatchFailure( + op, "mem_desc should be column major when " + "vec_direction is COLUMN for 1D result."); + if (vecDirection && + vecDirection.getValue() == xegpu::MatrixAccessDirection::ROW && + mdescTy.isColMajor()) + return rewriter.notifyMatchFailure( + op, "mem_desc should be row major when " + "vec_direction is ROW for 1D result."); + + if constexpr (std::is_same_v) { + Value loadOp = + LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM); + rewriter.replaceOp(op, loadOp); + } else { + LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM); + rewriter.eraseOp(op); + } + } + } + return success(); + } +}; + class PrefetchToXeVMPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -785,6 +969,13 @@ struct ConvertXeGPUToXeVMPass auto i32Type = IntegerType::get(&getContext(), 32); return VectorType::get(8, i32Type); }); + // Convert MemDescType into flattened MemRefType for SLM + typeConverter.addConversion([&](xegpu::MemDescType type) -> Type { + Type elemTy = type.getElementType(); + int numElems = type.getNumElements(); + return MemRefType::get(numElems, elemTy, AffineMap(), 3); + }); + typeConverter.addConversion([&](MemRefType type) -> Type { // Convert MemRefType to i64 type. 
       return IntegerType::get(&getContext(), 64);
@@ -919,6 +1110,10 @@ void mlir::populateXeGPUToXeVMConversionPatterns(
                LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
                LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
       typeConverter, patterns.getContext());
+  patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
+               LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
+               CreateMemDescOpPattern, MemDescSubviewOpPattern>(
+      typeConverter, patterns.getContext());
   patterns.add(typeConverter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 94c5509fd7c29..cccc8fab4adbc 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -726,6 +726,152 @@ void MemLayoutAttr::print(AsmPrinter &printer) const {
   }
   printer << ">";
 }
+// A helper utility to perform a binary operation on OpFoldResults.
+// If both a and b are attributes, it will simply return the result.
+// Otherwise, the corresponding arith op will be generated, and a
+// constant op will be created if one of them is an attribute.
+template <typename ArithOp>
+OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
+                      OpBuilder &builder) {
+  auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
+  auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
+  return builder.create<ArithOp>(loc, aVal, bVal).getResult();
+}
+
+// A helper utility to perform a division operation on an OpFoldResult and an int64_t.
+#define div(a, b) \
+  genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to perform a remainder operation on an OpFoldResult and an int64_t.
+#define rem(a, b) \
+  genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to perform a multiplication on an OpFoldResult and an int64_t.
+#define mul(a, b) \
+  genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to perform an addition on two OpFoldResults.
+#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
+
+// Block the given offsets according to the block shape.
+// Say the original offset is [y, x] and the block shape is [By, Bx];
+// then the blocked offset is [y/By, x/Bx, y%By, x%Bx].
+SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
+                                            ArrayRef<OpFoldResult> offsets,
+                                            ArrayRef<int64_t> blockShape) {
+
+  assert(offsets.size() == blockShape.size() &&
+         "offsets and blockShape must have the same size");
+  SmallVector<OpFoldResult> blockedOffsets;
+  SmallVector<OpFoldResult> divs, rems;
+
+  for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+    divs.push_back(div(offset, block));
+    rems.push_back(rem(offset, block));
+  }
+  blockedOffsets.append(divs.begin(), divs.end());
+  blockedOffsets.append(rems.begin(), rems.end());
+
+  return blockedOffsets;
+}
+
+// Get the strides as a vector of integers for MemDesc.
+SmallVector MemDescType::getStrides() { + + SmallVector matrixShape(getShape().begin(), getShape().end()); + + ArrayAttr strideAttr = getStridesAttr(); + SmallVector strides; + for (Attribute attr : strideAttr.getValue()) { + strides.push_back(cast(attr).getInt()); + } + + SmallVector innerBlkShape = getBlockSize(); + + // get perm from FCD to LCD + // perm[i] = the dim with i-th smallest stride + SmallVector perm = + llvm::to_vector<4>(llvm::seq(0, strides.size())); + llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; }); + + assert(strides[perm[0]] == 1 && "inner most dim must have stride 1"); + + SmallVector innerBlkStride(innerBlkShape.size()); + innerBlkStride[perm[0]] = 1; + for (size_t i = 1; i < perm.size(); ++i) + innerBlkStride[perm[i]] = + innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]]; + + // compute the original matrix shape using the stride info + // and compute the number of blocks in each dimension + // The shape of highest dim can't be derived from stride info, + // but doesn't impact the stride computation for blocked layout. + SmallVector matrixShapeOrig(matrixShape.size()); + SmallVector BlkShapeOrig(matrixShape.size()); + for (size_t i = 0; i < perm.size() - 1; ++i) { + matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]]; + BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]]; + } + + int64_t innerBlkSize = 1; + for (auto s : innerBlkShape) + innerBlkSize *= s; + + SmallVector outerBlkStride(matrixShape.size()); + outerBlkStride[perm[0]] = innerBlkSize; + for (size_t i = 0; i < perm.size() - 1; ++i) { + outerBlkStride[perm[i + 1]] = + outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]]; + } + + // combine the inner and outer strides + SmallVector blockedStrides; + blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end()); + blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end()); + + return blockedStrides; +} + +// Calculate the linear offset using the blocked offsets and stride +Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc, + ArrayRef offsets) { + + SmallVector matrixShape(getShape().begin(), getShape().end()); + SmallVector blockShape = getBlockSize(); + SmallVector strides = getStrides(); + + // blockshape equal to matrixshape means no blocking + if (llvm::equal(blockShape, matrixShape)) { + // remove the outer dims from strides + strides.erase(strides.begin(), strides.begin() + matrixShape.size()); + } else { + assert(offsets.size() == blockShape.size() && + "offsets and blockShape must have the same size"); + // say the original offset is [y, x], and the block shape is [By, Bx], + // then the blocked offset is [y/By, x/Bx, y%By, x%Bx] + SmallVector blockedOffsets; + SmallVector divs, rems; + + for (auto [offset, block] : llvm::zip(offsets, blockShape)) { + divs.push_back(div(offset, block)); + rems.push_back(rem(offset, block)); + } + blockedOffsets.append(divs.begin(), divs.end()); + blockedOffsets.append(rems.begin(), rems.end()); + + offsets = blockedOffsets; + } + + // Start with initial value as matrix descriptor's base offset. 
+ Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0); + for (size_t i = 0; i < offsets.size(); ++i) { + OpFoldResult mulResult = mul(offsets[i], strides[i]); + Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult); + linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset); + } + + return linearOffset; +} } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 81b5788d0b9b4..0bc7b3f06ec53 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -173,6 +173,51 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, return success(); } +LogicalResult IsValidStoreMatrixParams( + VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io, + MatrixAccessDirectionAttr vecDirection, IntegerAttr vecLength, + function_ref emitError) { + + if (!dataTy) + if (subgroup_block_io || vecDirection || vecLength) + return emitError() << "vec_length, vec_direction and subgroup_block_io " + "are only allowed when result is a 1D VectorType."; + else + return success(); + + if (mdescTy.getRank() != 2) + return emitError() << "mem_desc must be 2D."; + + ArrayRef dataShape = dataTy.getShape(); + ArrayRef mdescShape = mdescTy.getShape(); + + if (dataShape.size() == 2) { + if (subgroup_block_io || vecDirection || vecLength) + return emitError() << "vec_length, vec_direction and subgroup_block_io " + "are only allowed when result is a 1D VectorType."; + if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitError() << "data shape must not exceed mem_desc shape."; + } else if (dataShape.size() == 1) { + + SmallVector blockSize = mdescTy.getBlockSize(); + // if the subgroup_block_io attribute is set, mdescTy must have block + // attribute + if (subgroup_block_io && !blockSize.size()) + return emitError() << "mem_desc must have block attribute when " + "subgroup_block_io is set."; + // if the subgroup_block_io attribute is set, the memdesc should be row + // major + if (subgroup_block_io && mdescTy.isColMajor()) + return emitError() << "mem_desc should be row major when " + "subgroup_block_io is set."; + } else if (dataShape.size() == 0) { + return emitError() << "result shape must not be empty."; + } + + return success(); +} + //===----------------------------------------------------------------------===// // XeGPU_CreateNdDescOp //===----------------------------------------------------------------------===// @@ -1049,23 +1094,24 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res, llvm::SmallVector staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + // Call the generated builder with all parameters (including optional ones as + // nullptr/empty) build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr, - layout); + /*vec_length=*/nullptr, /*vec_direction=*/nullptr, + /*subgroup_block_io=*/nullptr, layout); } LogicalResult LoadMatrixOp::verify() { - VectorType resTy = getRes().getType(); - MemDescType mdescTy = getMemDesc().getType(); - if (mdescTy.getRank() != 2) - return emitOpError("mem_desc must be 2D."); + auto resTy = dyn_cast(getRes().getType()); + UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); + MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr(); + IntegerAttr vecLength = 
getVecLengthAttr(); + MemDescType mdescTy = getMemDesc().getType(); - ArrayRef valueShape = resTy.getShape(); - ArrayRef mdescShape = mdescTy.getShape(); - if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("result shape must not exceed mem_desc shape."); - return success(); + return IsValidStoreMatrixParams(resTy, mdescTy, subgroup_block_io, + vecDirection, vecLength, + [&]() { return emitError(); }); } //===----------------------------------------------------------------------===// @@ -1080,23 +1126,20 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data, dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr, - layout); + /*vec_length=*/nullptr, /*vec_direction=*/nullptr, + /*subgroup_block_io=*/nullptr, layout); } LogicalResult StoreMatrixOp::verify() { - VectorType dataTy = getData().getType(); - MemDescType mdescTy = getMemDesc().getType(); - if (mdescTy.getRank() != 2) - return emitOpError("mem_desc must be 2D."); - - ArrayRef dataShape = dataTy.getShape(); - ArrayRef mdescShape = mdescTy.getShape(); - if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitOpError("data shape must not exceed mem_desc shape."); - - return success(); + auto dataTy = dyn_cast(getData().getType()); + UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); + MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr(); + IntegerAttr vecLength = getVecLengthAttr(); + MemDescType mdescTy = getMemDesc().getType(); + return IsValidStoreMatrixParams(dataTy, mdescTy, subgroup_block_io, + vecDirection, vecLength, + [&]() { return emitError(); }); } //===----------------------------------------------------------------------===// @@ -1127,7 +1170,7 @@ LogicalResult MemDescSubviewOp::verify() { [](auto p) { return std::get<0>(p) > std::get<1>(p); })) return emitOpError("result shape must not exceed source shape."); - if (srcTy.getStrides() != resTy.getStrides()) + if (srcTy.getStridesAttr() != resTy.getStridesAttr()) return emitOpError("result must inherit the source strides."); return success(); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index a178d0fe4b0b0..6d17b27849a43 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -941,7 +941,7 @@ struct UnrollLoadMatrixOp : public UnrollPattern { LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - VectorType valueTy = op.getType(); + VectorType valueTy = llvm::dyn_cast(op.getType()); std::optional> targetShape = getTargetShape(op); if (!targetShape || targetShape->size() != (size_t)valueTy.getRank()) return failure(); @@ -984,7 +984,7 @@ struct UnrollStoreMatrixOp : public UnrollPattern { return failure(); Location loc = op.getLoc(); - VectorType valueTy = op.getData().getType(); + VectorType valueTy = llvm::dyn_cast(op.getData().getType()); ArrayRef shape = valueTy.getShape(); auto layout = dyn_cast(op.getLayoutAttr()); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index c28d2fc6c2b63..baee57c512ddf 100644 --- 
a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -991,7 +991,7 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern { return failure(); ArrayRef wgShape = op.getDataShape(); - VectorType valueTy = op.getRes().getType(); + VectorType valueTy = llvm::dyn_cast(op.getRes().getType()); Type elemTy = valueTy.getElementType(); xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir index e6f22f0a9acbb..a9ab0be00722c 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir @@ -1,17 +1,13 @@ // RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s -#sg_map_a_f16 = #xegpu.layout -#sg_map_b_f16 = #xegpu.layout -#sg_map_c_f32 = #xegpu.layout - -gpu.module @load_store_check { +gpu.module @test_kernel { // CHECK-LABEL: func.func @dpas( // CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32> func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> { // Loads are checked in a separate test. // CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = , types = } // CHECK-SAME: : (vector<8xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32> - %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32} + %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32> return %d : vector<8xf32> } diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir new file mode 100644 index 0000000000000..6302758195e51 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir @@ -0,0 +1,201 @@ +// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm -cse %s | FileCheck %s + +gpu.module @test_kernel [#xevm.target] { + + // e.g. for mem_desc<32x32xf16, @strides=[1, 16]> + // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1]) + //CHECK-LABEL: load_store_matrix_1 + gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32> + + //CHECK: %[[TID:.*]] = gpu.thread_id x + //CHECK: %[[C1:.*]] = arith.constant 1 : index + //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index + //CHECK: %[[C4:.*]] = arith.constant 4 : i64 + //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i64 + //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32 + + %tid_x = gpu.thread_id x + %c0 = arith.constant 0 : index + %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32 + + //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3> + + xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index + + gpu.return %1: f32 + } + +// e.g. 
for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]> + // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) + //CHECK-LABEL: load_store_matrix_2 + gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[tid_x:.*]] = gpu.thread_id x + //CHECK: %[[c13:.*]] = arith.constant 13 : index + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index + //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index + //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c512:.*]] = arith.constant 512 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index + //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index + + //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16 + + + %tid_x = gpu.thread_id x + %c13 = arith.constant 13 : index + %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> f16 + + //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3> + + xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index + gpu.return %1: f16 + } + + + // e.g. 
for mem_desc<32x64xf16, @block=[16, 16]> + // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) + //CHECK-LABEL: load_store_matrix_3 + gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 { + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> + + //CHECK: %[[tid_x:.*]] = gpu.thread_id x + //CHECK: %[[c19:.*]] = arith.constant 19 : index + %tid_x = gpu.thread_id x + %c19 = arith.constant 19: index + + //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index + //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64 + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index + //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index + //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[c1024:.*]] = arith.constant 1024 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index + //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index + + //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16 + %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> f16 + + //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3> + xegpu.store_matrix %1, %0[%c19, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index + + //CHECK: gpu.return %[[loaded]] : f16 + gpu.return %1: f16 + } + + // e.g. 
for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]> + // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) + //CHECK-LABEL: load_store_matrix_4 + gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> + + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[tid_x:.*]] = gpu.thread_id x + + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index + //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index + //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index + //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index + + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c512:.*]] = arith.constant 512 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index + //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index + + //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16> + + %tid_x = gpu.thread_id x + %c16 = arith.constant 16 : index + %1 = xegpu.load_matrix %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<8xf16> + + //CHECK: llvm.store %[[loaded]], {{.*}} : vector<8xf16>, !llvm.ptr<3> + xegpu.store_matrix %1, %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index + + gpu.return %1: vector<8xf16> + } + + + // e.g. 
for mem_desc<32x64xf16, @block=[16, 16]> + // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) + //CHECK-LABEL: load_store_matrix_5 + gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { + //CHECK: %[[c0:.*]] = arith.constant 0 : index + //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> + + %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> + + //CHECK: %[[c16:.*]] = arith.constant 16 : index + //CHECK: %[[c48:.*]] = arith.constant 48 : index + + %c16 = arith.constant 16 : index + %c48 = arith.constant 48 : index + + //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index + //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64 + //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index + //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index + //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index + //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index + //CHECK: %[[c1024:.*]] = arith.constant 1024 : index + //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index + //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index + //CHECK: %[[c256:.*]] = arith.constant 256 : index + //CHECK: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index + //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index + //CHECK: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index + //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index + //CHECK: %[[c1:.*]] = arith.constant 1 : index + //CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index + //CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index + //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i64 + //CHECK: %[[c2:.*]] = arith.constant 2 : i64 + //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i64 + //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i64 + //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i64 to !llvm.ptr<3> + //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16> + //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16> + + %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<8xf16> + + //CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16> + //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>) + + xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index + + gpu.return %1: vector<8xf16> + } + +} \ No newline at end of file diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 228ef69d9a478..fee3136195e1d 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -858,7 +858,7 @@ func.func @load_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16> // ----- func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{result shape must not exceed mem_desc shape}} + // expected-error@+1 {{data shape must not exceed mem_desc shape}} %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<32x16xf16> return } @@ -870,6 +870,21 
@@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) { return } +// ----- +func.func @load_mem_desc_invalid_attr1(%arg0: !xegpu.mem_desc<16x64xf16>) { + // expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}} + %data1 = xegpu.load_matrix %arg0[8, 8]<{vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16> + return +} + +// ----- +func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) { + // expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}} + %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16> + return +} + + // ----- func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) { // expected-error@+1 {{failed to verify that all of {mem_desc, data} have same element type}} @@ -891,6 +906,20 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve return } +// ----- +func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { + // expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}} + xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> + return +} + +// ----- +func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { + // expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}} + xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> + return +} + // ----- func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { // expected-error@+1 {{result shape must not exceed source shape}} diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index bb379024a34d7..eb5d653be8b9c 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -825,35 +825,76 @@ gpu.func @create_mem_desc_with_stride() { gpu.return } -// CHECK: gpu.func @load_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) -gpu.func @load_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>) { +// CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) +gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) { // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16> %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16> gpu.return } -// CHECK: gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) -gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { +// CHECK: gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) +gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8x16xf16> %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8x16xf16> gpu.return } +// CHECK: gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) +gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 16] : 
!xegpu.mem_desc<16x64xf16> -> vector<1xf16> + %data = xegpu.load_matrix %arg0[8, 16]: !xegpu.mem_desc<16x64xf16> -> vector<1xf16> + gpu.return +} + +// CHECK: gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) +gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> + %data = xegpu.load_matrix %arg0[8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> + gpu.return +} -// CHECK: gpu.func @store_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>) -gpu.func @store_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) { +// CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) +gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> + %data = xegpu.load_matrix %arg0[8, 8] <{vec_direction = #xegpu.matrix_access_direction, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> + gpu.return +} + +// CHECK: gpu.func @store_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>) +gpu.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) { // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> gpu.return } -// CHECK: gpu.func @store_mem_desc_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, [[ARG1:%.+]]: vector<16x16xf16>) -gpu.func @store_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<16x16xf16>) { +// CHECK: gpu.func @store_matrix_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, [[ARG1:%.+]]: vector<16x16xf16>) +gpu.func @store_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<16x16xf16>) { // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][0, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> xegpu.store_matrix %arg1, %arg0[0, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> gpu.return } +// CHECK: gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) { +gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] : vector<1xf16>, !xegpu.mem_desc<16x64xf16> + xegpu.store_matrix %arg1, %arg0[8, 16]: vector<1xf16>, !xegpu.mem_desc<16x64xf16> + gpu.return +} + +// CHECK: gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) +gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + xegpu.store_matrix %arg1, %arg0[8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + gpu.return +} + +// CHECK: gpu.func @simt_store_matrix_vector(%arg0: 
!xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { +gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction, vec_length = 8 : i32}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + xegpu.store_matrix %arg1, %arg0[8, 8] <{vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + gpu.return +} + // CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) { //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout>