diff --git a/include/gc/Transforms/Utils/ValueUtils.h b/include/gc/Transforms/Utils/ValueUtils.h
index 409f563b..1709078e 100644
--- a/include/gc/Transforms/Utils/ValueUtils.h
+++ b/include/gc/Transforms/Utils/ValueUtils.h
@@ -53,6 +53,28 @@ Value flattenMemref(PatternRewriter &rewriter, Location loc, Value srcMemref);
 // Return true if the memref has shared memory space.
 bool hasSharedMemSpace(mlir::Value memref);
 
+// Walk the chain of parent 'memref.subview' ops of the given `memref` and
+// return the folded offsets of all subviews together with the root memref.
+void computeSubviewOffsets(PatternRewriter &rewriter, Location loc,
+                           Value memref, SmallVector<Value> &resultOffsets,
+                           Value &resultRootMemref);
+
+// Return the strides of the memref.
+SmallVector<OpFoldResult> getMemrefStrides(PatternRewriter &rewriter,
+                                           Location loc, Value memref);
+
+// Squeeze the leading unit dimensions of the given memref so that at most
+// 'maxDims' dimensions remain.
+FailureOr<Value> reduceMemrefDims(PatternRewriter &rewriter, Location loc,
+                                  Value memref, size_t maxDims = 2);
+
+// Squeeze the leading unit dimensions of the memref operands of the given
+// 'linalgOp'.
+LogicalResult maybeSqueezeDims(PatternRewriter &rewriter,
+                               linalg::LinalgOp linalgOp, size_t maxDims = 2);
+
+// Return true if a memref with the given shape can be squeezed down to
+// 'maxDims' dimensions. Only leading unit dimensions are considered
+// squeezable.
+bool canSqueezeDims(llvm::ArrayRef<int64_t> shape, size_t maxDims = 2);
+
 } // namespace utils
 } // namespace mlir
diff --git a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
index 344261d4..994e445e 100644
--- a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
+++ b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
@@ -62,28 +62,6 @@ static Value createFullMask(PatternRewriter &rewriter, Location loc,
   return res.getResult();
 }
 
-// Extracts the offsets from a subview operation as values.
-// The differense from mlir::getMixedOffsets is that this function
-// returns the offsets as mlir::Value that can already be used as an argument
-// for other mlir::Operations.
-static SmallVector<Value> extractOffsetsAsValues(PatternRewriter &rewriter,
-                                                 Location loc,
-                                                 memref::SubViewOp subview) {
-  SmallVector<Value> offsetValues;
-  auto staticOffsets = subview.getStaticOffsets();
-  auto dynamicOffsets = subview.getOffsets();
-  size_t dynIdx = 0;
-  for (size_t i = 0; i < staticOffsets.size(); i++) {
-    if (staticOffsets[i] == ShapedType::kDynamic)
-      offsetValues.push_back(dynamicOffsets[dynIdx++]);
-    else
-      offsetValues.push_back(
-          rewriter.create<arith::ConstantIndexOp>(loc, staticOffsets[i]));
-  }
-
-  return offsetValues;
-}
-
 // Max number of elements to load/store from SLM
 constexpr int64_t maxSLMTileSize = 32;
 
@@ -214,7 +192,8 @@ static LogicalResult isValidMemrefOperand(linalg::LinalgOp linalgOp,
         linalgOp, "Expect memref operand for XeGPU lowering");
   }
 
-  if (type.getShape().size() > maxDims) {
+  if (type.getShape().size() > maxDims &&
+      !utils::canSqueezeDims(type.getShape(), maxDims)) {
     return rewriter.notifyMatchFailure(
         linalgOp, "Too high dimensionality for XeGPU operations");
   }
@@ -856,43 +835,33 @@ static SmallVector<Value> createSLMDescTiles(PatternRewriter &rewriter,
   auto srcType = cast<MemRefType>(src.getType());
   assert(srcType.getRank() == 2 && "Expected a 2D memref");
 
-  SmallVector<int64_t> memrefStrides;
-  Value blockOffset;
-
+  SmallVector<Value> offsets;
+  Value rootMemref;
   // 'imex::ConvertGPUXToSPIRVPass' doesn't allow 'memref.subview' ops in the
   // GPU kernel. We have to merge the subview offsets into the descriptor
   // offset.
-  if (auto subView = dyn_cast<memref::SubViewOp>(src.getDefiningOp())) {
-    auto offsets = extractOffsetsAsValues(rewriter, loc, subView);
-    assert(offsets.size() == 2 && "Expected 2D subview offsets");
-
-    auto xIntOffs = offsets[0];
-    auto yIntOffs = offsets[1];
-
-    // compute 'blockOffset' (beginning of the subview block in the original
-    // flat memref)
-    auto rowStride =
-        cast<MemRefType>(subView.getOperand(0).getType()).getShape()[1];
-    auto rowStrideValue =
-        rewriter.create<arith::ConstantIndexOp>(loc, rowStride);
-
-    auto rowBlockOffset =
-        rewriter.create<arith::MulIOp>(loc, xIntOffs, rowStrideValue)
-            .getResult();
-    blockOffset = rewriter.create<arith::AddIOp>(loc, rowBlockOffset, yIntOffs)
-                      .getResult();
+  utils::computeSubviewOffsets(rewriter, loc, src, offsets, rootMemref);
+  auto rootStridesFold = utils::getMemrefStrides(rewriter, loc, rootMemref);
+  auto rootStrides =
+      getValueOrCreateConstantIndexOp(rewriter, loc, rootStridesFold);
 
-    memrefStrides = {rowStride, 1};
-    src = subView.getOperand(0);
-  } else {
-    // If the source is not a subview, then the blockOffset is 0
-    blockOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    memrefStrides = {srcType.getShape()[1], 1};
+  assert(rootStrides.size() == offsets.size() &&
+         "Expected same number of strides and offsets");
+
+  // blockOffset = sum(rootStrides[i] * offsets[i])
+  Value blockOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  for (size_t i = 0; i < rootStrides.size(); i++) {
+    auto mul = rewriter.create<arith::MulIOp>(loc, rootStrides[i], offsets[i]);
+    blockOffset = rewriter.create<arith::AddIOp>(loc, blockOffset, mul);
   }
 
-  // Scatter descriptors only work with 1D memrefs
-  src = utils::flattenMemref(rewriter, loc, src);
+  auto memrefStridesFold = utils::getMemrefStrides(rewriter, loc, src);
+  auto [memrefStrides, memrefStridesDynamic] =
+      decomposeMixedValues(memrefStridesFold);
+  assert(memrefStridesDynamic.size() == 0 &&
+         "Expected all values to be resolved");
+
+  // Scatter descriptors only work with 1D memrefs.
+  src = utils::flattenMemref(rewriter, loc, rootMemref);
   return createScatterDescriptorTiles(
       rewriter, loc, /*flatMemref=*/src, /*loadShape2D=*/loadShape,
       /*tileSize2D=*/descTile, /*memrefStrides=*/memrefStrides,
@@ -1839,6 +1808,11 @@ struct ConvertGemmLikeToXeGPU : public OpRewritePattern<LinalgOpTy> {
     if (failed(isOutputValid))
       return isOutputValid;
 
+    if (failed(mlir::utils::maybeSqueezeDims(rewriter, gemmLikeOp))) {
+      return rewriter.notifyMatchFailure(
+          gemmLikeOp, "Failed to squeeze dimensions of GEMM-like operation");
+    }
+
     // Ensure that reduction dimension tiling also works for smaller
     // workloads.
     auto aType = cast<ShapedType>(gemmLikeOp.getDpsInputs()[0].getType());
@@ -1894,6 +1868,12 @@ struct ConvertNamedEltwiseToXeGPU : public OpRewritePattern<LinalgOpTy> {
     if (failed(isOutputValid))
       return isOutputValid;
 
+    if (failed(utils::maybeSqueezeDims(rewriter, eltwiseOp))) {
+      return rewriter.notifyMatchFailure(
+          eltwiseOp,
+          "Could not squeeze dimensions of the elementwise operation");
+    }
+
     return createEltwiseKernel(eltwiseOp, rewriter);
   }
 
@@ -1988,6 +1968,12 @@ struct ConvertMemoryFillToXeGPU : public OpRewritePattern<LinalgOpTy> {
     if (failed(isOutputValid))
       return isOutputValid;
 
+    if (failed(utils::maybeSqueezeDims(rewriter, linalgOp))) {
+      return rewriter.notifyMatchFailure(
+          linalgOp,
+          "Could not squeeze dimensions of the memory fill operation");
+    }
+
     return createMemoryFillKernel(linalgOp, rewriter);
   }
 
diff --git a/lib/gc/Transforms/Utils/ValueUtils.cpp b/lib/gc/Transforms/Utils/ValueUtils.cpp
index 8e3de421..c6285df1 100644
--- a/lib/gc/Transforms/Utils/ValueUtils.cpp
+++ b/lib/gc/Transforms/Utils/ValueUtils.cpp
@@ -6,10 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <numeric>
+
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Matchers.h"
@@ -155,9 +159,10 @@ Value flattenMemref(PatternRewriter &rewriter, Location loc, Value srcMemref) {
   auto srcType = cast<MemRefType>(srcMemref.getType());
 
   assert(srcType && "Expected a memref type");
-  assert(srcType.getRank() == 2 && "Expected a 2D memref");
 
-  int64_t flatSize = srcType.getShape()[0] * srcType.getShape()[1];
+  auto shapeNd = srcType.getShape();
+  int64_t flatSize =
+      std::accumulate(shapeNd.begin(), shapeNd.end(), 1, std::multiplies<>());
 
   Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   Value size = rewriter.create<arith::ConstantIndexOp>(loc, flatSize);
@@ -193,5 +198,134 @@ bool hasSharedMemSpace(mlir::Value memref) {
   return false;
 }
 
+void computeSubviewOffsets(PatternRewriter &rewriter, Location loc,
+                           Value memref, SmallVector<Value> &resultOffsets,
+                           Value &resultRootMemref) {
+  auto fillVal = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto type = dyn_cast<MemRefType>(memref.getType());
+  assert(type && "Expected a memref type");
+
+  auto origShape = type.getShape();
+
+  resultOffsets.clear();
+  resultOffsets.append(origShape.size(), fillVal);
+  resultRootMemref = memref;
+
+  // Fold the offsets of each subview in the parent chain into the running
+  // offsets until the root memref is reached.
+  while (auto subViewOp =
+             resultRootMemref.getDefiningOp<memref::SubViewOp>()) {
+    auto currentOffsets = getAsOpFoldResult(resultOffsets);
+    resultOffsets.clear();
+
+    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+        rewriter, resultRootMemref.getLoc(), subViewOp.getMixedOffsets(),
+        subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
+        currentOffsets, resultOffsets);
+    resultRootMemref = subViewOp.getOperand(0);
+  }
+}
+
+SmallVector<OpFoldResult> getMemrefStrides(PatternRewriter &rewriter,
+                                           Location loc, Value memref) {
+  auto type = dyn_cast<MemRefType>(memref.getType());
+
+  auto stridedLayout = dyn_cast<StridedLayoutAttr>(type.getLayout());
+  if (stridedLayout) {
+    auto strides = stridedLayout.getStrides();
+    return getMixedValues(strides, {}, rewriter);
+  }
+
+  // No strided layout: derive the strides from the shape.
+  auto sizes = getMixedValues(type.getShape(), {}, rewriter);
+  auto strides = memref::computeStridesIRBlock(loc, rewriter, sizes);
+  return strides;
+}
+
+FailureOr<Value> reduceMemrefDims(PatternRewriter &rewriter, Location loc,
+                                  Value memref, size_t maxDims) {
+  auto type = dyn_cast<MemRefType>(memref.getType());
+  auto shape = type.getShape();
+
+  if (shape.size() <= maxDims)
+    return memref;
+
+  // Only leading unit dimensions can be dropped.
+  for (size_t i = 0; i < shape.size() - maxDims; i++)
+    if (shape[i] != 1)
+      return failure();
+
+  auto offsets =
+      getMixedValues(SmallVector<int64_t>(shape.size(), 0), {}, rewriter);
+  auto sizes = getMixedValues(shape, {}, rewriter);
+  auto staticStrides = utils::getStaticStrides(memref).value();
+  auto strides =
+      getMixedValues(SmallVector<int64_t>(shape.size(), 1), {}, rewriter);
+
+  SmallVector<int64_t> newShape(shape.begin() + shape.size() - maxDims,
+                                shape.end());
+  SmallVector<int64_t> newStrides(
+      staticStrides.begin() + shape.size() - maxDims, staticStrides.end());
+
+  int64_t newOffset = 0;
+  if (auto memrefLayout = dyn_cast<StridedLayoutAttr>(type.getLayout()))
+    newOffset = memrefLayout.getOffset();
+
+  auto newLayout = StridedLayoutAttr::get(
+      rewriter.getContext(), /*offset=*/newOffset, /*strides=*/newStrides);
+  MemRefType newMemRefType = MemRefType::get(newShape, type.getElementType(),
+                                             newLayout, type.getMemorySpace());
+
+  // Create a rank-reducing subview that keeps only the trailing 'maxDims'
+  // dimensions.
+  auto squeezedSubview =
+      rewriter
+          .create<memref::SubViewOp>(loc, newMemRefType, memref, offsets,
+                                     sizes, strides)
+          .getResult();
+  return squeezedSubview;
+}
+
+LogicalResult maybeSqueezeDims(PatternRewriter &rewriter,
+                               linalg::LinalgOp linalgOp, size_t maxDims) {
+  SmallVector<std::pair<size_t, Value>> newOperands;
+  auto operands = linalgOp->getOperands();
+  auto loc = linalgOp.getLoc();
+
+  for (size_t i = 0; i < operands.size(); i++) {
+    auto operand = operands[i];
+    auto type = dyn_cast<MemRefType>(operand.getType());
+    if (!type) {
+      // Skip non-memref operands
+      continue;
+    }
+
+    if (type.getShape().size() <= maxDims)
+      continue;
+
+    auto res = reduceMemrefDims(rewriter, loc, operand, maxDims);
+    if (failed(res)) {
+      return rewriter.notifyMatchFailure(
+          linalgOp,
+          "Can't squeeze memref to the desired number of dimensions");
+    }
+
+    auto flatSubview = res.value();
+    newOperands.emplace_back(i, flatSubview);
+  }
+
+  if (newOperands.empty())
+    return success();
+
+  rewriter.modifyOpInPlace(linalgOp, [&] {
+    for (auto [i, operand] : newOperands)
+      linalgOp->setOperand(i, operand);
+  });
+  return success();
+}
+
+bool canSqueezeDims(llvm::ArrayRef<int64_t> shape, size_t maxDims) {
+  if (shape.size() <= maxDims)
+    return true;
+
+  for (size_t i = 0; i < shape.size() - maxDims; i++)
+    if (shape[i] != 1)
+      return false;
+
+  return true;
+}
+
 } // namespace utils
 } // namespace mlir
diff --git a/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-squeeze.mlir b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-squeeze.mlir
new file mode 100644
index 00000000..362867fb
--- /dev/null
+++ b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-squeeze.mlir
@@ -0,0 +1,61 @@
+// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file -cse | FileCheck %s
+
+!input_type = memref<2x4x8x16xf16>
+!chunk_type = memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>>
+!slm_chunk = memref<1x1x8x16xf16, strided<[128, 128, 16, 1], offset: ?>, 3>
+
+// The map that computes the SLM offset for each thread.
+// CHECK: #map = affine_map<(d0, d1) -> (d0 * 4 + d1)>
+#map = affine_map<(xi, yi) -> (xi * 4 + yi)>
+
+func.func @entry(%arg0: !input_type, %arg1: !input_type, %arg2: !input_type) {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+
+  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c1, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c2, %arg13 = %c4, %arg14 = %c1) {
+    // CHECK: %[[ARG0_SB:.+]] = memref.subview %arg0[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1]
+    %arg0_sb = memref.subview %arg0[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : !input_type to !chunk_type
+    // CHECK: %[[ARG1_SB:.+]] = memref.subview %arg1[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1]
+    %arg1_sb = memref.subview %arg1[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : !input_type to !chunk_type
+    // CHECK: %[[ARG2_SB:.+]] = memref.subview %arg2[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1]
+    %arg2_sb = memref.subview %arg2[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : !input_type to !chunk_type
+
+    // CHECK: %[[SLM_BUFF:.+]] = memref.alloc() : memref<8x1x8x16xf16, 3>
+    %slm_root = memref.alloc() : memref<8x1x8x16xf16, 3>
+
+    %slm_idx = affine.apply #map(%arg6, %arg7)
+    %slm = memref.subview %slm_root[%slm_idx, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<8x1x8x16xf16, 3> to !slm_chunk
+
+    // Squeezing the arguments of 'linalg.mul'
+    // CHECK: %[[ARG0_SQUEEZ:.+]] = memref.subview %[[ARG0_SB]][0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] :
+    // CHECK-SAME: memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>> to memref<8x16xf16, strided<[16, 1], offset: ?>>
+
+    // CHECK: %[[ARG1_SQUEEZ:.+]] = memref.subview %[[ARG1_SB]][0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] :
+    // CHECK-SAME: memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>> to memref<8x16xf16, strided<[16, 1], offset: ?>>
+
+    // Verify that tensor descriptors are created from the squeezed memrefs
+    // CHECK: xegpu.create_nd_tdesc %[[ARG0_SQUEEZ]]
+    // CHECK: xegpu.create_nd_tdesc %[[ARG1_SQUEEZ]]
+
+    // Verify that the SLM output of linalg.mul is squeezed correctly
+    // CHECK-NOT: memref.subview %[[SLM_BUFF]]
+    // CHECK: %[[SLM_THREAD_OFF:.+]] = affine.apply #map(%arg6, %arg7)
+    // CHECK: %[[SLM_OFF:.+]] = arith.muli %[[SLM_THREAD_OFF]], %c128 : index
+    // CHECK: %[[FLAT_SLM:.+]] = memref.reinterpret_cast %[[SLM_BUFF]] to offset: [%c0], sizes: [%c1024], strides: [%c1] : memref<8x1x8x16xf16, 3> to memref<1024xf16, 3>
+    // CHECK: xegpu.create_tdesc %[[FLAT_SLM]]
+    linalg.mul ins(%arg0_sb, %arg1_sb : !chunk_type, !chunk_type) outs(%slm : !slm_chunk)
+
+    // Squeezing the result buffer of 'linalg.add'
+    // CHECK: %[[ARG2_SQUEEZ:.+]] = memref.subview %[[ARG2_SB]][0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] :
+    // CHECK-SAME: memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>> to memref<8x16xf16, strided<[16, 1], offset: ?>>
+
+    // Verify that tensor descriptors are created from the squeezed memrefs
+    // CHECK: xegpu.create_nd_tdesc %[[ARG2_SQUEEZ]]
+    linalg.add ins(%arg0_sb, %slm : !chunk_type, !slm_chunk) outs(%arg2_sb : !chunk_type)
+
+    gpu.terminator
+  } {SCFToGPU_visited}
+
+  return
+}
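For context, the rewrite exercised by this test reduces each over-rank memref operand to a 2D view before the XeGPU descriptors are built. A minimal before/after sketch of that squeeze for one operand, using hypothetical SSA names (%x, %y, %in, %squeezed); the shapes and strided layouts are the ones from the CHECK lines above:

  // Before: 4D operand with leading unit dimensions.
  %in = memref.subview %arg0[%x, %y, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1]
      : memref<2x4x8x16xf16> to memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>>

  // After: a zero-offset, unit-stride, rank-reducing subview drops the 1x1
  // leading dimensions so the operand fits the 2D limit of XeGPU descriptors.
  %squeezed = memref.subview %in[0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1]
      : memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>>
        to memref<8x16xf16, strided<[16, 1], offset: ?>>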