diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 19a52317956d2..5695d5d515d7f 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -712,14 +712,10 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> { return getAttrs().contains(name); } - ArrayAttr getStrideAttr() { + ArrayAttr getStrides() { return getAttrs().getAs("stride"); } - ArrayAttr getBlockAttr() { - return getAttrs().getAs("block"); - } - }]; } diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 426377fcf598f..73f9061f5debe 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1298,14 +1298,14 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure, } def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, - AllElementTypesMatch<["mem_desc", "res"]>]> { + AllElementTypesMatch<["mem_desc", "res"]>, + AllRanksMatch<["mem_desc", "res"]>]> { let arguments = (ins XeGPU_MemDesc:$mem_desc, Variadic: $offsets, DenseI64ArrayAttr: $const_offsets, - OptionalAttr:$subgroup_block_io, OptionalAttr:$layout ); - let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res); + let results = (outs XeGPU_ValueType:$res); let assemblyFormat = [{ $mem_desc `` custom($offsets, $const_offsets) prop-dict attr-dict `` `:` type(operands) `->` type(results) @@ -1319,9 +1319,6 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, Arguments: - `mem_desc`: the memory descriptor identifying the SLM region. - `offsets`: the coordinates within the matrix to read from. - - `subgroup_block_io`: [optional] An attribute indicating that the operation can be - lowered to a subgroup block load. When this attribute is present, - the offsets are subgroup-uniform across all lanes. - `layout`: [optional] An attribute for guiding distributions among subgroups and/or work-items. It currently can accept either LayoutAttr or SliceAttr. @@ -1339,10 +1336,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, } ArrayRef getDataShape() { - auto resTy = getRes().getType(); - if (auto vecTy = llvm::dyn_cast(resTy)) - return vecTy.getShape(); - return {}; + return getRes().getType().getShape(); } }]; @@ -1350,13 +1344,13 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, } def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, - AllElementTypesMatch<["mem_desc", "data"]>]> { + AllElementTypesMatch<["mem_desc", "data"]>, + AllRanksMatch<["mem_desc", "data"]>]> { let arguments = (ins - AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data, + XeGPU_ValueType:$data, XeGPU_MemDesc:$mem_desc, Variadic: $offsets, DenseI64ArrayAttr: $const_offsets, - OptionalAttr:$subgroup_block_io, OptionalAttr:$layout ); let assemblyFormat = [{ $data `,` $mem_desc `` custom($offsets, $const_offsets) @@ -1370,9 +1364,6 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, - `mem_desc`: the memory descriptor specifying the SLM region. - `offsets`: the coordinates within the matrix where the data will be written. - `data`: the values to be stored in the matrix. - - `subgroup_block_io`: [optional] An attribute indicating that the operation can be - lowered to a subgroup block store. When this attribute is present, - the offsets are subgroup-uniform across all lanes. 
- `layout`: [optional] An attribute for guiding distributions among subgroups and/or work-items. It currently can accept either LayoutAttr or SliceAttr. @@ -1387,10 +1378,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, } ArrayRef getDataShape() { - auto DataTy = getData().getType(); - if (auto vecTy = llvm::dyn_cast(DataTy)) - return vecTy.getShape(); - return {}; + return getData().getType().getShape(); } }]; @@ -1398,4 +1386,41 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, let hasVerifier = 1; } +def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview", + [Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> { + let description = [{ + Creates a subview of a memory descriptor. The resulting memory descriptor can have + a lower rank than the source; in this case, the result dimensions correspond to the + higher-order dimensions of the source memory descriptor. + + Arguments: + - `src` : a memory descriptor. + - `offsets` : the coordinates within the matrix the subview will be created from. + + Results: + - `res` : a memory descriptor with smaller size. + + }]; + let arguments = (ins XeGPU_MemDesc:$src, + Variadic:$offsets, + DenseI64ArrayAttr:$const_offsets); + let results = (outs XeGPU_MemDesc:$res); + let assemblyFormat = [{$src `` custom($offsets, $const_offsets) prop-dict + attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}]; + let builders = [ + OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef": $offsets)> + ]; + + let extraClassDeclaration = [{ + mlir::Value getViewSource() { return getSrc(); } + + SmallVector getMixedOffsets() { + return getMixedValues(getConstOffsets(), getOffsets(), getContext()); + } + }]; + + let hasVerifier = 1; +} + + #endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index b1196fbe9c66a..84902b2039643 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -237,11 +237,12 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout()); } - ArrayAttr getStrideAttr() { + ArrayAttr getStrides() { auto layout = getMemLayout(); if (layout && layout.hasAttr("stride")) { - return layout.getStrideAttr(); + return layout.getStrides(); } + // derive and return default strides SmallVector defaultStrides; llvm::append_range(defaultStrides, getShape().drop_front()); @@ -249,63 +250,6 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m Builder builder(getContext()); return builder.getI64ArrayAttr(defaultStrides); } - - ArrayAttr getBlockAttr() { - auto layout = getMemLayout(); - if (layout && layout.hasAttr("block")) { - return layout.getBlockAttr(); - } - Builder builder(getContext()); - return builder.getI64ArrayAttr({}); - } - - /// Heuristic to determine if the MemDesc uses column-major layout, - /// based on the rank and the value of the first stride dimension. - bool isColMajor() { - auto dim0 = dyn_cast(getStrideAttr()[0]); - return getRank() == 2 && dim0.getInt() == 1; - } - - // Get the Blocking shape for a MemDescType, Which is represented - // as an attribute in MemDescType. 
By default it is the shape - // of the mdescTy - SmallVector getBlockShape() { - SmallVector size(getShape()); - ArrayAttr blockAttr = getBlockAttr(); - if (!blockAttr.empty()) { - size.clear(); - for (auto attr : blockAttr.getValue()) { - size.push_back(cast(attr).getInt()); - } - } - return size; - } - - // Get strides as vector of integer. - // If it contains block attribute, the strides are blocked strides. - // - // The blocking is applied to the base matrix shape derived from the - // memory descriptor's stride information. If the matrix described by - // the memory descriptor is not contiguous, it is assumed that the base - // matrix is contiguous and follows the same memory layout. - // - // It first computes the original matrix shape using the stride info, - // then computes the number of blocks in each dimension of original shape, - // then compute the outer block shape and stride, - // then combines the inner and outer block shape and stride - // e.g. for `mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>` - // its memory layout tuple is ([2,32,16,8],[128,256,1,16]) - // for `mem_desc<256x32xf16, @block=[8, 16]>` with default @stride[32, 1] - // its memory layout tuple is ([32,2,8,16],[256,128,16,1]) - SmallVector getStrideShape(); - - /// Generates instructions to compute the linearize offset - // if the memory descriptor is blocked, it returns linearize offset based on the blocked layout - // the strides of memory descriptor is always considered regardless of blocked or not - Value getLinearOffsets(OpBuilder &builder, - Location loc, ArrayRef offsets); - - }]; let hasCustomAssemblyFormat = true; diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt index dd9edc43a1657..84b25809f1ed0 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt +++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt @@ -21,7 +21,6 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM MLIRIndexDialect MLIRSCFDialect MLIRXeGPUDialect - MLIRXeGPUUtils MLIRPass MLIRTransforms MLIRSCFTransforms diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index fcbf66dbe9e45..ddcbc44f2652a 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -22,7 +22,6 @@ #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" -#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" @@ -64,7 +63,6 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { case xegpu::MemorySpace::SLM: return static_cast(xevm::AddrSpace::SHARED); } - llvm_unreachable("Unknown XeGPU memory space"); } // Get same bitwidth flat vector type of new element type. @@ -188,7 +186,6 @@ class CreateNdDescToXeVMPattern int64_t rank = mixedSizes.size(); if (rank != 2) return rewriter.notifyMatchFailure(op, "Expected 2D shape."); - auto sourceTy = source.getType(); auto sourceMemrefTy = dyn_cast(sourceTy); // If source is a memref, we need to extract the aligned pointer as index. 
@@ -367,11 +364,10 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern { // Add a builder that creates // offset * elemByteSize + baseAddr -static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter, - Location loc, Value baseAddr, Value offset, - int64_t elemByteSize) { +static Value addOffset(ConversionPatternRewriter &rewriter, Location loc, + Value baseAddr, Value offset, int64_t elemByteSize) { Value byteSize = arith::ConstantIntOp::create( - rewriter, loc, baseAddr.getType(), elemByteSize); + rewriter, loc, rewriter.getI64Type(), elemByteSize); Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize); Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset); return newAddr; @@ -447,8 +443,7 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { // If offset is provided, we add them to the base pointer. // Offset is in number of elements, we need to multiply by // element byte size. - basePtrI64 = - addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize); + basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize); } // Convert base pointer (i64) to LLVM pointer type. Value basePtrLLVM = @@ -511,147 +506,6 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { } }; -// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions -// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than -// 32 bits will be converted to 32 bits. -class CreateMemDescOpPattern final - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - auto resTy = op.getMemDesc(); - - // Create the result MemRefType with the same shape, element type, and - // memory space - auto newResTy = getTypeConverter()->convertType(resTy); - - Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); - auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, - op.getSource(), zero, ValueRange()); - rewriter.replaceOp(op, viewOp); - return success(); - } -}; - -template ::value>> -class LoadStoreMatrixToXeVMPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - SmallVector offsets = op.getMixedOffsets(); - if (offsets.empty()) - return rewriter.notifyMatchFailure(op, "Expected offset to be provided."); - - auto loc = op.getLoc(); - auto ctxt = rewriter.getContext(); - Value basePtrStruct = adaptor.getMemDesc(); - Value mdescVal = op.getMemDesc(); - // Load result or Store value Type can be vector or scalar. - Value data; - if constexpr (std::is_same_v) - data = op.getResult(); - else - data = adaptor.getData(); - VectorType valOrResVecTy = dyn_cast(data.getType()); - if (!valOrResVecTy) - valOrResVecTy = VectorType::get(1, data.getType()); - - int64_t elemBitWidth = - valOrResVecTy.getElementType().getIntOrFloatBitWidth(); - // Element type must be multiple of 8 bits. - if (elemBitWidth % 8 != 0) - return rewriter.notifyMatchFailure( - op, "Expected element type bit width to be multiple of 8."); - int64_t elemByteSize = elemBitWidth / 8; - - // Default memory space is SLM. 
- LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( - ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM)); - - auto mdescTy = cast(mdescVal.getType()); - - Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create( - rewriter, loc, basePtrStruct); - - // Convert base pointer (ptr) to i32 - Value basePtrI32 = arith::IndexCastUIOp::create( - rewriter, loc, rewriter.getI32Type(), basePtrLLVM); - - Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); - linearOffset = arith::IndexCastUIOp::create( - rewriter, loc, rewriter.getI32Type(), linearOffset); - basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset, - elemByteSize); - - // convert base pointer (i32) to LLVM pointer type - basePtrLLVM = - LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32); - - if (op.getSubgroupBlockIoAttr()) { - // if the attribute 'subgroup_block_io' is set to true, it lowers to - // xevm.blockload - - Type intElemTy = rewriter.getIntegerType(elemBitWidth); - VectorType intVecTy = - VectorType::get(valOrResVecTy.getShape(), intElemTy); - - if constexpr (std::is_same_v) { - Value loadOp = - xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM); - if (intVecTy != valOrResVecTy) { - loadOp = - vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp); - } - rewriter.replaceOp(op, loadOp); - } else { - Value dataToStore = adaptor.getData(); - if (valOrResVecTy != intVecTy) { - dataToStore = - vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore); - } - xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore, - nullptr); - rewriter.eraseOp(op); - } - return success(); - } - - if (valOrResVecTy.getNumElements() >= 1) { - auto chipOpt = xegpu::getChipStr(op); - if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) { - // the lowering for chunk load only works for pvc and bmg - return rewriter.notifyMatchFailure( - op, "The lowering is specific to pvc or bmg."); - } - } - - if constexpr (std::is_same_v) { - // if the size of valOrResVecTy is 1, it lowers to a scalar load/store - // operation. LLVM load/store does not support vector of size 1, so we - // need to handle this case separately. - auto scalarTy = valOrResVecTy.getElementType(); - LLVM::LoadOp loadOp; - if (valOrResVecTy.getNumElements() == 1) - loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM); - else - loadOp = - LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM); - rewriter.replaceOp(op, loadOp); - } else { - LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM); - rewriter.eraseOp(op); - } - return success(); - } -}; - class PrefetchToXeVMPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -694,8 +548,8 @@ class PrefetchToXeVMPattern : public OpConversionPattern { op, "Expected element type bit width to be multiple of 8."); elemByteSize = elemBitWidth / 8; } - basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets, - elemByteSize); + basePtrI64 = + addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize); } } // Default memory space is global. 
@@ -932,13 +786,6 @@ struct ConvertXeGPUToXeVMPass auto i32Type = IntegerType::get(&getContext(), 32); return VectorType::get(8, i32Type); }); - // Convert MemDescType into flattened MemRefType for SLM - typeConverter.addConversion([&](xegpu::MemDescType type) -> Type { - Type elemTy = type.getElementType(); - int numElems = type.getNumElements(); - return MemRefType::get(numElems, elemTy, AffineMap(), 3); - }); - typeConverter.addConversion([&](MemRefType type) -> Type { // Convert MemRefType to i64 type. return IntegerType::get(&getContext(), 64); @@ -1093,9 +940,6 @@ void mlir::populateXeGPUToXeVMConversionPatterns( LoadStoreToXeVMPattern, LoadStoreToXeVMPattern>( typeConverter, patterns.getContext()); - patterns.add, - LoadStoreMatrixToXeVMPattern, - CreateMemDescOpPattern>(typeConverter, patterns.getContext()); patterns.add(typeConverter, patterns.getContext()); } diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 1cfae28f31188..9beb22d517473 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -727,152 +727,6 @@ void MemLayoutAttr::print(AsmPrinter &printer) const { } printer << ">"; } -// a helper utility to perform binary operation on OpFoldResult. -// If both a and b are attributes, it will simply return the result. -// Otherwise, the corresponding arith op will be generated, and an -// contant op will be created if one of them is an attribute. -template -OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc, - OpBuilder &builder) { - auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a); - auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b); - return builder.create(loc, aVal, bVal).getResult(); -} - -// a helper utility to perform division operation on OpFoldResult and int64_t. -#define div(a, b) \ - genBinOp(a, builder.getIndexAttr(b), loc, builder) - -// a helper utility to perform reminder operation on OpFoldResult and int64_t. -#define rem(a, b) \ - genBinOp(a, builder.getIndexAttr(b), loc, builder) - -// a helper utility to perform multiply operation on OpFoldResult and int64_t. -#define mul(a, b) \ - genBinOp(a, builder.getIndexAttr(b), loc, builder) - -// a helper utility to perform addition operation on two OpFoldResult. -#define add(a, b) genBinOp(a, b, loc, builder) - -// block the given offsets according to the block shape -// say the original offset is [y, x], and the block shape is [By, Bx], -// then the blocked offset is [y/By, x/Bx, y%By, x%Bx] -SmallVector getBlockedOffsets(OpBuilder &builder, Location loc, - ArrayRef offsets, - ArrayRef blockShape) { - - assert(offsets.size() == blockShape.size() && - "offsets and blockShape must have the same size"); - SmallVector blockedOffsets; - SmallVector divs, rems; - - for (auto [offset, block] : llvm::zip(offsets, blockShape)) { - divs.push_back(div(offset, block)); - rems.push_back(rem(offset, block)); - } - blockedOffsets.append(divs.begin(), divs.end()); - blockedOffsets.append(rems.begin(), rems.end()); - - return blockedOffsets; -} - -// Get strides as vector of integer for MemDesc. 
-SmallVector MemDescType::getStrideShape() { - - SmallVector matrixShape(getShape().begin(), getShape().end()); - - ArrayAttr strideAttr = getStrideAttr(); - SmallVector strides; - for (Attribute attr : strideAttr.getValue()) { - strides.push_back(cast(attr).getInt()); - } - - SmallVector innerBlkShape = getBlockShape(); - - // get perm from FCD to LCD - // perm[i] = the dim with i-th smallest stride - SmallVector perm = - llvm::to_vector<4>(llvm::seq(0, strides.size())); - llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; }); - - assert(strides[perm[0]] == 1 && "inner most dim must have stride 1"); - - SmallVector innerBlkStride(innerBlkShape.size()); - innerBlkStride[perm[0]] = 1; - for (size_t i = 1; i < perm.size(); ++i) - innerBlkStride[perm[i]] = - innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]]; - - // compute the original matrix shape using the stride info - // and compute the number of blocks in each dimension - // The shape of highest dim can't be derived from stride info, - // but doesn't impact the stride computation for blocked layout. - SmallVector matrixShapeOrig(matrixShape.size()); - SmallVector BlkShapeOrig(matrixShape.size()); - for (size_t i = 0; i < perm.size() - 1; ++i) { - matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]]; - BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]]; - } - - int64_t innerBlkSize = 1; - for (auto s : innerBlkShape) - innerBlkSize *= s; - - SmallVector outerBlkStride(matrixShape.size()); - outerBlkStride[perm[0]] = innerBlkSize; - for (size_t i = 0; i < perm.size() - 1; ++i) { - outerBlkStride[perm[i + 1]] = - outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]]; - } - - // combine the inner and outer strides - SmallVector blockedStrides; - blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end()); - blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end()); - - return blockedStrides; -} - -// Calculate the linear offset using the blocked offsets and stride -Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc, - ArrayRef offsets) { - - SmallVector matrixShape(getShape().begin(), getShape().end()); - SmallVector blockShape = getBlockShape(); - SmallVector strides = getStrideShape(); - - // blockshape equal to matrixshape means no blocking - if (llvm::equal(blockShape, matrixShape)) { - // remove the outer dims from strides - strides.erase(strides.begin(), strides.begin() + matrixShape.size()); - } else { - assert(offsets.size() == blockShape.size() && - "offsets and blockShape must have the same size"); - // say the original offset is [y, x], and the block shape is [By, Bx], - // then the blocked offset is [y/By, x/Bx, y%By, x%Bx] - SmallVector blockedOffsets; - SmallVector divs, rems; - - for (auto [offset, block] : llvm::zip(offsets, blockShape)) { - divs.push_back(div(offset, block)); - rems.push_back(rem(offset, block)); - } - blockedOffsets.append(divs.begin(), divs.end()); - blockedOffsets.append(rems.begin(), rems.end()); - - offsets = blockedOffsets; - } - - // Start with initial value as matrix descriptor's base offset. 
- Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0); - for (size_t i = 0; i < offsets.size(); ++i) { - OpFoldResult mulResult = mul(offsets[i], strides[i]); - Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult); - linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset); - } - - return linearOffset; -} } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index abd12e2e69ac0..e0a8ac40648e0 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -173,49 +173,6 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, return success(); } -LogicalResult -IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, - UnitAttr subgroup_block_io, - function_ref emitError) { - - if (!dataTy) { - if (subgroup_block_io) - return emitError() << "subgroup_block_io " - "are only allowed when result is a 1D VectorType."; - else - return success(); - } - - if (mdescTy.getRank() != 2) - return emitError() << "mem_desc must be 2D."; - - ArrayRef dataShape = dataTy.getShape(); - ArrayRef mdescShape = mdescTy.getShape(); - - if (dataShape.size() == 2) { - if (subgroup_block_io) - return emitError() << "subgroup_block_io " - "are only allowed when result is a 1D VectorType."; - if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), - [](auto p) { return std::get<0>(p) > std::get<1>(p); })) - return emitError() << "data shape must not exceed mem_desc shape."; - } else { - SmallVector blockShape = mdescTy.getBlockShape(); - // if the subgroup_block_io attribute is set, mdescTy must have block - // attribute - if (subgroup_block_io && !blockShape.size()) - return emitError() << "mem_desc must have block attribute when " - "subgroup_block_io is set."; - // if the subgroup_block_io attribute is set, the memdesc should be row - // major - if (subgroup_block_io && mdescTy.isColMajor()) - return emitError() << "mem_desc should be row major when " - "subgroup_block_io is set."; - } - - return success(); -} - //===----------------------------------------------------------------------===// // XeGPU_CreateNdDescOp //===----------------------------------------------------------------------===// @@ -1092,20 +1049,23 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res, llvm::SmallVector staticOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); - // Call the generated builder with all parameters (including optional ones as - // nullptr/empty) build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr, - /*subgroup_block_io=*/nullptr, layout); + layout); } LogicalResult LoadMatrixOp::verify() { - - auto resTy = dyn_cast(getRes().getType()); - UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); + VectorType resTy = getRes().getType(); MemDescType mdescTy = getMemDesc().getType(); - return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io, - [&]() { return emitError(); }); + if (mdescTy.getRank() != 2) + return emitOpError("mem_desc must be 2D."); + + ArrayRef valueShape = resTy.getShape(); + ArrayRef mdescShape = mdescTy.getShape(); + if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("result shape must not exceed mem_desc shape."); + return success(); } 
//===----------------------------------------------------------------------===// @@ -1120,16 +1080,57 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data, dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr, - /*subgroup_block_io=*/nullptr, layout); + layout); } LogicalResult StoreMatrixOp::verify() { - - auto dataTy = dyn_cast(getData().getType()); - UnitAttr subgroup_block_io = getSubgroupBlockIoAttr(); + VectorType dataTy = getData().getType(); MemDescType mdescTy = getMemDesc().getType(); - return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io, - [&]() { return emitError(); }); + + if (mdescTy.getRank() != 2) + return emitOpError("mem_desc must be 2D."); + + ArrayRef dataShape = dataTy.getShape(); + ArrayRef mdescShape = mdescTy.getShape(); + if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("data shape must not exceed mem_desc shape."); + + return success(); +} + +//===----------------------------------------------------------------------===// +// XeGPU_MemDescSubviewOp +//===----------------------------------------------------------------------===// + +void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state, + Type resTy, Value src, + llvm::ArrayRef offsets) { + llvm::SmallVector dynamicOffsets; + llvm::SmallVector staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr); +} + +LogicalResult MemDescSubviewOp::verify() { + MemDescType srcTy = getSrc().getType(); + MemDescType resTy = getRes().getType(); + ArrayRef srcShape = srcTy.getShape(); + ArrayRef resShape = resTy.getShape(); + + if (srcTy.getRank() < resTy.getRank()) + return emitOpError("result rank must not exceed source rank."); + + if (llvm::any_of( + llvm::zip_equal(resShape, srcShape.take_back(resShape.size())), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("result shape must not exceed source shape."); + + if (srcTy.getStrides() != resTy.getStrides()) + return emitOpError("result must inherit the source strides."); + + return success(); } namespace mlir { diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index aafa1b7deb84b..a178d0fe4b0b0 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -941,9 +941,7 @@ struct UnrollLoadMatrixOp : public UnrollPattern { LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - VectorType valueTy = llvm::dyn_cast(op.getType()); - assert(valueTy && "the value type must be vector type!"); - + VectorType valueTy = op.getType(); std::optional> targetShape = getTargetShape(op); if (!targetShape || targetShape->size() != (size_t)valueTy.getRank()) return failure(); @@ -986,8 +984,7 @@ struct UnrollStoreMatrixOp : public UnrollPattern { return failure(); Location loc = op.getLoc(); - VectorType valueTy = llvm::dyn_cast(op.getData().getType()); - assert(valueTy && "the value type must be vector type!"); + VectorType valueTy = op.getData().getType(); ArrayRef shape = valueTy.getShape(); auto layout 
= dyn_cast(op.getLayoutAttr()); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 31a967dcd04c7..c28d2fc6c2b63 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -991,8 +991,7 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern { return failure(); ArrayRef wgShape = op.getDataShape(); - VectorType valueTy = llvm::dyn_cast(op.getRes().getType()); - assert(valueTy && "the value type must be vector type!"); + VectorType valueTy = op.getRes().getType(); Type elemTy = valueTy.getElementType(); xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir index a9ab0be00722c..e6f22f0a9acbb 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir @@ -1,13 +1,17 @@ // RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s -gpu.module @test_kernel { +#sg_map_a_f16 = #xegpu.layout +#sg_map_b_f16 = #xegpu.layout +#sg_map_c_f32 = #xegpu.layout + +gpu.module @load_store_check { // CHECK-LABEL: func.func @dpas( // CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32> func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> { // Loads are checked in a separate test. // CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = , types = } // CHECK-SAME: : (vector<8xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32> - %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded + %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32} : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32> return %d : vector<8xf32> } diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir deleted file mode 100644 index d4cb493271d0d..0000000000000 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir +++ /dev/null @@ -1,201 +0,0 @@ -// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm -cse %s | FileCheck %s - -gpu.module @test_kernel [#xevm.target] { - - // e.g. for mem_desc<32x32xf16, @strides=[1, 16]> - // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1]) - //CHECK-LABEL: load_store_matrix_1 - gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 { - %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32> - - //CHECK: %[[TID:.*]] = gpu.thread_id x - //CHECK: %[[C1:.*]] = arith.constant 1 : index - //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index - //CHECK: %[[C4:.*]] = arith.constant 4 : i32 - //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32 - //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32 - - %tid_x = gpu.thread_id x - %c0 = arith.constant 0 : index - %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32 - - //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3> - - xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index - - gpu.return %1: f32 - } - -// e.g. 
for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]> - // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) - //CHECK-LABEL: load_store_matrix_2 - gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 { - %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> - //CHECK: %[[c0:.*]] = arith.constant 0 : index - //CHECK: %[[tid_x:.*]] = gpu.thread_id x - //CHECK: %[[c13:.*]] = arith.constant 13 : index - //CHECK: %[[c16:.*]] = arith.constant 16 : index - //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index - //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index - //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index - //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index - - //CHECK: %[[c256:.*]] = arith.constant 256 : index - //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index - //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index - //CHECK: %[[c512:.*]] = arith.constant 512 : index - //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index - //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index - //CHECK: %[[c1:.*]] = arith.constant 1 : index - //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index - //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index - //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index - //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index - - //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16 - - - %tid_x = gpu.thread_id x - %c13 = arith.constant 13 : index - %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> f16 - - //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3> - - xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index - gpu.return %1: f16 - } - - - // e.g. 
for mem_desc<32x64xf16, @block=[16, 16]> - // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) - //CHECK-LABEL: load_store_matrix_3 - gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 { - //CHECK: %[[c0:.*]] = arith.constant 0 : index - //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> - %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> - - //CHECK: %[[tid_x:.*]] = gpu.thread_id x - //CHECK: %[[c19:.*]] = arith.constant 19 : index - %tid_x = gpu.thread_id x - %c19 = arith.constant 19: index - - //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index - //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 - //CHECK: %[[c16:.*]] = arith.constant 16 : index - //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index - //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index - //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index - //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index - //CHECK: %[[c1024:.*]] = arith.constant 1024 : index - //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index - //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index - //CHECK: %[[c256:.*]] = arith.constant 256 : index - //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index - //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index - //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index - //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index - //CHECK: %[[c1:.*]] = arith.constant 1 : index - //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index - //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index - - //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16 - %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> f16 - - //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3> - xegpu.store_matrix %1, %0[%c19, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index - - //CHECK: gpu.return %[[loaded]] : f16 - gpu.return %1: f16 - } - - // e.g. 
for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]> - // its memory layout tuple is ([2,4,16,16],[256,512,1,16]) - //CHECK-LABEL: load_store_matrix_4 - gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { - %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> - - //CHECK: %[[c0:.*]] = arith.constant 0 : index - //CHECK: %[[tid_x:.*]] = gpu.thread_id x - - //CHECK: %[[c16:.*]] = arith.constant 16 : index - //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index - //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index - //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index - //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index - - //CHECK: %[[c256:.*]] = arith.constant 256 : index - //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index - //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index - //CHECK: %[[c512:.*]] = arith.constant 512 : index - //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index - //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index - //CHECK: %[[c1:.*]] = arith.constant 1 : index - //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index - //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index - //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index - //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index - - //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16> - - %tid_x = gpu.thread_id x - %c16 = arith.constant 16 : index - %1 = xegpu.load_matrix %0[%c16, %tid_x] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<8xf16> - - //CHECK: llvm.store %[[loaded]], {{.*}} : vector<8xf16>, !llvm.ptr<3> - xegpu.store_matrix %1, %0[%c16, %tid_x] : vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index - - gpu.return %1: vector<8xf16> - } - - - // e.g. 
for mem_desc<32x64xf16, @block=[16, 16]> - // its memory layout tuple is ([2,4,16,16],[1024,256,16,1]) - //CHECK-LABEL: load_store_matrix_5 - gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { - //CHECK: %[[c0:.*]] = arith.constant 0 : index - //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3> - - %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout> - - //CHECK: %[[c16:.*]] = arith.constant 16 : index - //CHECK: %[[c48:.*]] = arith.constant 48 : index - - %c16 = arith.constant 16 : index - %c48 = arith.constant 48 : index - - //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index - //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 - //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index - //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index - //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index - //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index - //CHECK: %[[c1024:.*]] = arith.constant 1024 : index - //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index - //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index - //CHECK: %[[c256:.*]] = arith.constant 256 : index - //CHECK: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index - //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index - //CHECK: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index - //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index - //CHECK: %[[c1:.*]] = arith.constant 1 : index - //CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index - //CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index - //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32 - //CHECK: %[[c2:.*]] = arith.constant 2 : i32 - //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32 - //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i32 - //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3> - //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16> - //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16> - - %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index -> vector<8xf16> - - //CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16> - //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>) - - xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout>, index, index - - gpu.return %1: vector<8xf16> - } - -} diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index ebbe3ce0ec0d0..228ef69d9a478 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -858,7 +858,7 @@ func.func @load_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16> // ----- func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{data shape must not exceed mem_desc shape}} + // expected-error@+1 {{result shape must not exceed mem_desc shape}} %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<32x16xf16> return } @@ -870,14 +870,6 @@ func.func 
@load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) { return } -// ----- -func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) { - // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} - %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16> - return -} - - // ----- func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) { // expected-error@+1 {{failed to verify that all of {mem_desc, data} have same element type}} @@ -900,16 +892,30 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve } // ----- -func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { - // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} - xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> +func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { + // expected-error@+1 {{result shape must not exceed source shape}} + %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<32x16xf16> + return +} + +// ----- +func.func @mem_desc_subview_layout_mismatch(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { + // expected-error@+1 {{result must inherit the source strides}} + %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> !xegpu.mem_desc<8x16xf16> + return +} + +// ----- +func.func @mem_desc_subview_element_type_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { + // expected-error@+1 {{failed to verify that all of {src, res} have same element type}} + %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf32, #xegpu.mem_layout> return } // ----- -func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) { - // expected-error@+1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}} - xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> +func.func @mem_desc_subview_rank_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { + // expected-error@+1 {{result rank must not exceed source rank}} + %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<4x8x16xf16> return } diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index 0a10f6814ae96..bb379024a34d7 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -825,73 +825,53 @@ gpu.func @create_mem_desc_with_stride() { gpu.return } -// CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) -gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) { +// CHECK: gpu.func @load_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) +gpu.func @load_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>) { // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16> %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16> gpu.return } -// CHECK: gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) -gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { +// CHECK: gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) +gpu.func 
@load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8x16xf16> %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8x16xf16> gpu.return } -// CHECK: gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) -gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) { - // CHECK: xegpu.load_matrix [[ARG0]][8, 16] : !xegpu.mem_desc<16x64xf16> -> vector<1xf16> - %data = xegpu.load_matrix %arg0[8, 16]: !xegpu.mem_desc<16x64xf16> -> vector<1xf16> - gpu.return -} - -// CHECK: gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) -gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { - // CHECK: xegpu.load_matrix [[ARG0]][8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> - %data = xegpu.load_matrix %arg0[8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> - gpu.return -} - -// CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) -gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { - // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> - %data = xegpu.load_matrix %arg0[8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8xf16> - gpu.return -} -// CHECK: gpu.func @store_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>) -gpu.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) { +// CHECK: gpu.func @store_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>) +gpu.func @store_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) { // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> gpu.return } -// CHECK: gpu.func @store_matrix_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, [[ARG1:%.+]]: vector<16x16xf16>) -gpu.func @store_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<16x16xf16>) { +// CHECK: gpu.func @store_mem_desc_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, [[ARG1:%.+]]: vector<16x16xf16>) +gpu.func @store_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<16x16xf16>) { // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][0, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> xegpu.store_matrix %arg1, %arg0[0, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> gpu.return } -// CHECK: gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) { -gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) { - // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] : vector<1xf16>, !xegpu.mem_desc<16x64xf16> - xegpu.store_matrix %arg1, %arg0[8, 16]: vector<1xf16>, !xegpu.mem_desc<16x64xf16> +// CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) +gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) { + //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout> + 
%data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout> gpu.return } -// CHECK: gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) -gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { - // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> - xegpu.store_matrix %arg1, %arg0[8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> +// CHECK: gpu.func @mem_desc_subview_lower_rank([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) +gpu.func @mem_desc_subview_lower_rank(%arg0: !xegpu.mem_desc<16x64xf16>) { + //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout> + %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout> gpu.return } -// CHECK: gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { -gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<8xf16>) { - // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> - xegpu.store_matrix %arg1, %arg0[8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> +// CHECK: gpu.func @mem_desc_subview_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) +gpu.func @mem_desc_subview_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { + //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout> + %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout> gpu.return }
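
Reviewer aid (not part of the patch): a minimal MLIR sketch of how the restored xegpu.mem_desc_subview composes with the simplified load_matrix / store_matrix forms above, mirroring the assembly syntax exercised in the updated ops.mlir tests. The module and function names are hypothetical, and the printed form of the stride layout attribute on the subview result (needed to satisfy the "result must inherit the source strides" verifier rule) is an assumption that should be checked against the dialect printer.

// Hypothetical usage sketch, not part of this diff.
gpu.module @subview_example {
  gpu.func @subview_then_load(%arg0: !xegpu.mem_desc<16x64xf16>) {
    // Take an 8x16 view at offset [8, 8]; the result keeps the source's
    // row-major strides [64, 1], as required by MemDescSubviewOp::verify().
    %sub = xegpu.mem_desc_subview %arg0[8, 8]
        : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
    // After this patch, load_matrix/store_matrix take vector values whose rank
    // matches the mem_desc (the scalar and subgroup_block_io variants are gone).
    %tile = xegpu.load_matrix %sub[0, 0]
        : !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>> -> vector<8x16xf16>
    xegpu.store_matrix %tile, %sub[0, 0]
        : vector<8x16xf16>, !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
    gpu.return
  }
}

The shapes here satisfy the restored verifiers: the subview's 8x16 shape does not exceed the trailing dimensions of the 16x64 source, and the loaded/stored vector matches the subview's rank and does not exceed its shape.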