From 9dd135edb6bb4f87eeb5191e443a706fdceacf09 Mon Sep 17 00:00:00 2001 From: dchigarev Date: Thu, 28 Nov 2024 12:39:41 +0000 Subject: [PATCH 1/7] [LinalgToXeGPU] Support squeezable any-D memrefs Signed-off-by: dchigarev --- include/gc/Transforms/Utils/ValueUtils.h | 21 ++++ lib/gc/Transforms/GPU/LinalgToXeGPU.cpp | 90 +++++++--------- lib/gc/Transforms/Utils/ValueUtils.cpp | 132 ++++++++++++++++++++++- 3 files changed, 188 insertions(+), 55 deletions(-) diff --git a/include/gc/Transforms/Utils/ValueUtils.h b/include/gc/Transforms/Utils/ValueUtils.h index 409f563b..a512d2e7 100644 --- a/include/gc/Transforms/Utils/ValueUtils.h +++ b/include/gc/Transforms/Utils/ValueUtils.h @@ -53,6 +53,27 @@ Value flattenMemref(PatternRewriter &rewriter, Location loc, Value srcMemref); // Return true if the memref has shared memory space. bool hasSharedMemSpace(mlir::Value memref); +// Go through all parent 'memref.subview' ops for the given `memref` +// and return the folded offsets of all subviews and the root memref. +std::tuple, Value> +computeSubviewOffsets(PatternRewriter &rewriter, Location loc, Value memref); + +// Return the strides of the memref +SmallVector getMemrefStrides(PatternRewriter &rewriter, + Location loc, Value memref); + +// Squeeze the leading dimensions of a given memref up to 'maxDims'. +FailureOr squeezeMemref(PatternRewriter &rewriter, Location loc, + Value memref, size_t maxDims = 2); + +// Squeeze the leading dimensions of memref operands of a given 'linalgOp'. +LogicalResult maybeSqueezeDims(PatternRewriter &rewriter, + linalg::LinalgOp linalgOp, size_t maxDims = 2); + +// Return if a memref with the given shape can be squeezed to the shape of +// 'maxDims'. Only leading dimensions are considered squeezable. +bool canSqueezeDims(llvm::ArrayRef shape, size_t maxDims = 2); + } // namespace utils } // namespace mlir diff --git a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp index 344261d4..2b6841ec 100644 --- a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp +++ b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp @@ -62,28 +62,6 @@ static Value createFullMask(PatternRewriter &rewriter, Location loc, return res.getResult(); } -// Extracts the offsets from a subview operation as values. -// The differense from mlir::getMixedOffsets is that this function -// returns the offsets as mlir::Value that can already be used as an argument -// for other mlir::Operations. -static SmallVector extractOffsetsAsValues(PatternRewriter &rewriter, - Location loc, - memref::SubViewOp subview) { - SmallVector offsetValues; - auto staticOffsets = subview.getStaticOffsets(); - auto dynamicOffsets = subview.getOffsets(); - size_t dynIdx = 0; - for (size_t i = 0; i < staticOffsets.size(); i++) { - if (staticOffsets[i] == ShapedType::kDynamic) - offsetValues.push_back(dynamicOffsets[dynIdx++]); - else - offsetValues.push_back( - rewriter.create(loc, staticOffsets[i])); - } - - return offsetValues; -} - // Max number of elements to load/store from SLM constexpr int64_t maxSLMTileSize = 32; @@ -214,7 +192,8 @@ static LogicalResult isValidMemrefOperand(linalg::LinalgOp linalgOp, linalgOp, "Expect memref operand for XeGPU lowering"); } - if (type.getShape().size() > maxDims) { + if (type.getShape().size() > maxDims && + !utils::canSqueezeDims(type.getShape(), maxDims)) { return rewriter.notifyMatchFailure( linalgOp, "Too high dimensionality for XeGPU operations"); } @@ -856,43 +835,31 @@ static SmallVector createSLMDescTiles(PatternRewriter &rewriter, auto srcType = cast(src.getType()); assert(srcType.getRank() == 2 && "Expected a 2D memref"); - SmallVector memrefStrides; - Value blockOffset; - // 'imex::ConvertGPUXToSPIRVPass' doesn't allow 'memref.subview' ops in the // GPU kernel. We have to merge the subview offsets into the descriptor // offset. - if (auto subView = dyn_cast(src.getDefiningOp())) { - auto offsets = extractOffsetsAsValues(rewriter, loc, subView); - assert(offsets.size() == 2 && "Expected 2D subview offsets"); - - auto xIntOffs = offsets[0]; - auto yIntOffs = offsets[1]; - - // compute 'blockOffset' (beginning of the subview block in the original - // flat memref) - auto rowStride = - cast(subView.getOperand(0).getType()).getShape()[1]; - auto rowStrideValue = - rewriter.create(loc, rowStride); - - auto rowBlockOffset = - rewriter.create(loc, xIntOffs, rowStrideValue) - .getResult(); - blockOffset = rewriter.create(loc, rowBlockOffset, yIntOffs) - .getResult(); + auto [offsets, rootMemref] = utils::computeSubviewOffsets(rewriter, loc, src); + auto rootStridesFold = utils::getMemrefStrides(rewriter, loc, rootMemref); + auto rootStrides = + getValueOrCreateConstantIndexOp(rewriter, loc, rootStridesFold); - memrefStrides = {rowStride, 1}; - src = subView.getOperand(0); - } else { - // If the source is not a subview, then the blockOffset is 0 - blockOffset = rewriter.create(loc, 0); - memrefStrides = {srcType.getShape()[1], 1}; + assert(rootStrides.size() == offsets.size() && + "Expected same number of strides and offsets"); + + // blockOffset = sum(rootStrides[i] * offsets[i]) + Value blockOffset = rewriter.create(loc, 0); + for (size_t i = 0; i < rootStrides.size(); i++) { + auto mul = rewriter.create(loc, rootStrides[i], offsets[i]); + blockOffset = rewriter.create(loc, blockOffset, mul); } - // Scatter descriptors only work with 1D memrefs - src = utils::flattenMemref(rewriter, loc, src); + auto memrefStridesFold = utils::getMemrefStrides(rewriter, loc, src); + auto [memrefStrides, memrefStridesDynamic] = + decomposeMixedValues(memrefStridesFold); + assert(memrefStridesDynamic.size() == 0 && + "Expected all values to be resolved"); + src = utils::flattenMemref(rewriter, loc, rootMemref); return createScatterDescriptorTiles( rewriter, loc, /*flatMemref=*/src, /*loadShape2D=*/loadShape, /*tileSize2D=*/descTile, /*memrefStrides=*/memrefStrides, @@ -1839,6 +1806,11 @@ struct ConvertGemmLikeToXeGPU : public OpRewritePattern { if (failed(isOutputValid)) return isOutputValid; + if (failed(mlir::utils::maybeSqueezeDims(rewriter, gemmLikeOp))) { + return rewriter.notifyMatchFailure( + gemmLikeOp, "Failed to squeeze dimensions of GEMM-like operation"); + } + // Ensure that reduction dimension tiling also works for smaller // workloads. auto aType = cast(gemmLikeOp.getDpsInputs()[0].getType()); @@ -1894,6 +1866,12 @@ struct ConvertNamedEltwiseToXeGPU : public OpRewritePattern { if (failed(isOutputValid)) return isOutputValid; + if (failed(utils::maybeSqueezeDims(rewriter, eltwiseOp))) { + return rewriter.notifyMatchFailure( + eltwiseOp, + "Could not squeeze dimensions of the elementwise operation"); + } + return createEltwiseKernel(eltwiseOp, rewriter); } @@ -1988,6 +1966,12 @@ struct ConvertMemoryFillToXeGPU : public OpRewritePattern { if (failed(isOutputValid)) return isOutputValid; + if (failed(utils::maybeSqueezeDims(rewriter, linalgOp))) { + return rewriter.notifyMatchFailure( + linalgOp, + "Could not squeeze dimensions of the memory fill operation"); + } + return createMemoryFillKernel(linalgOp, rewriter); } diff --git a/lib/gc/Transforms/Utils/ValueUtils.cpp b/lib/gc/Transforms/Utils/ValueUtils.cpp index 8e3de421..fa50b1fe 100644 --- a/lib/gc/Transforms/Utils/ValueUtils.cpp +++ b/lib/gc/Transforms/Utils/ValueUtils.cpp @@ -6,10 +6,14 @@ // //===----------------------------------------------------------------------===// +#include + +#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Matchers.h" @@ -155,9 +159,10 @@ Value flattenMemref(PatternRewriter &rewriter, Location loc, Value srcMemref) { auto srcType = cast(srcMemref.getType()); assert(srcType && "Expected a memref type"); - assert(srcType.getRank() == 2 && "Expected a 2D memref"); - int64_t flatSize = srcType.getShape()[0] * srcType.getShape()[1]; + auto shapeNd = srcType.getShape(); + int64_t flatSize = + std::accumulate(shapeNd.begin(), shapeNd.end(), 1, std::multiplies<>()); Value offset = rewriter.create(loc, 0); Value size = rewriter.create(loc, flatSize); @@ -193,5 +198,128 @@ bool hasSharedMemSpace(mlir::Value memref) { return false; } +std::tuple, Value> +computeSubviewOffsets(PatternRewriter &rewriter, Location loc, Value memref) { + auto fillVal = rewriter.create(loc, 0); + auto origShape = dyn_cast(memref.getType()).getShape(); + + SmallVector resolvedOffsets(origShape.size(), fillVal); + + while (auto subViewOp = memref.getDefiningOp()) { + auto currentOffsets = getAsOpFoldResult(resolvedOffsets); + resolvedOffsets.clear(); + + affine::resolveIndicesIntoOpWithOffsetsAndStrides( + rewriter, memref.getLoc(), subViewOp.getMixedOffsets(), + subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), currentOffsets, + resolvedOffsets); + memref = subViewOp.getOperand(0); + } + + return std::make_tuple(resolvedOffsets, memref); +} + +SmallVector getMemrefStrides(PatternRewriter &rewriter, + Location loc, Value memref) { + auto type = dyn_cast(memref.getType()); + + auto stridedLayout = dyn_cast(type.getLayout()); + if (stridedLayout) { + auto strides = stridedLayout.getStrides(); + return getMixedValues(strides, {}, rewriter); + } + + auto sizes = getMixedValues(type.getShape(), {}, rewriter); + auto strides = memref::computeStridesIRBlock(loc, rewriter, sizes); + return strides; +} + +FailureOr squeezeMemref(PatternRewriter &rewriter, Location loc, + Value memref, size_t maxDims = 2) { + auto type = dyn_cast(memref.getType()); + auto shape = type.getShape(); + + if (shape.size() <= maxDims) + return memref; + + for (size_t i = 0; i < shape.size() - maxDims; i++) + if (shape[i] != 1) + return failure(); + + auto offsets = + getMixedValues(SmallVector(shape.size(), 0), {}, rewriter); + auto sizes = getMixedValues(shape, {}, rewriter); + auto staticStrides = utils::getStaticStrides(memref).value(); + auto strides = + getMixedValues(SmallVector(shape.size(), 1), {}, rewriter); + + SmallVector newShape(shape.begin() + shape.size() - maxDims, + shape.end()); + SmallVector newStrides( + staticStrides.begin() + shape.size() - maxDims, staticStrides.end()); + + int64_t newOffset = 0; + if (auto memrefLayout = dyn_cast(type.getLayout())) + newOffset = memrefLayout.getOffset(); + + auto newLayout = StridedLayoutAttr::get( + rewriter.getContext(), /*offset=*/newOffset, /*strides=*/newStrides); + MemRefType newMemRefType = MemRefType::get(newShape, type.getElementType(), + newLayout, type.getMemorySpace()); + + auto squeezedSubview = + rewriter + .create(loc, newMemRefType, memref, offsets, sizes, + strides) + .getResult(); + return squeezedSubview; +} + +LogicalResult maybeSqueezeDims(PatternRewriter &rewriter, + linalg::LinalgOp linalgOp, size_t maxDims) { + SmallVector> newOperands; + auto operands = linalgOp->getOperands(); + auto loc = linalgOp.getLoc(); + + for (size_t i = 0; i < operands.size(); i++) { + auto operand = operands[i]; + auto type = dyn_cast(operand.getType()); + if (!type) { + // maybe should 'continue' here instead and skip non-memref operands? + // TODO: replace this with 'continue' if such case would appear someday + return rewriter.notifyMatchFailure( + linalgOp, "Expect memref operand for XeGPU lowering"); + } + + if (type.getShape().size() <= maxDims) + continue; + + auto res = squeezeMemref(rewriter, loc, operand, maxDims); + if (failed(res)) { + return rewriter.notifyMatchFailure( + linalgOp, "Can't squeeze memref to the desired number of dimensions"); + } + + auto flatSubview = res.value(); + newOperands.emplace_back(i, flatSubview); + } + + for (auto [i, operand] : newOperands) + linalgOp->setOperand(i, operand); + + return success(); +} + +bool canSqueezeDims(llvm::ArrayRef shape, size_t maxDims) { + if (shape.size() <= maxDims) + return true; + + for (size_t i = 0; i < shape.size() - maxDims; i++) + if (shape[i] != 1) + return false; + + return true; +} + } // namespace utils } // namespace mlir From 324168f4d182f35cebe8a56442fb7ebe5a17cfde Mon Sep 17 00:00:00 2001 From: dchigarev Date: Fri, 29 Nov 2024 14:06:56 +0000 Subject: [PATCH 2/7] Fix linalg.fill case Signed-off-by: dchigarev --- lib/gc/Transforms/Utils/ValueUtils.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lib/gc/Transforms/Utils/ValueUtils.cpp b/lib/gc/Transforms/Utils/ValueUtils.cpp index fa50b1fe..2b1c83a0 100644 --- a/lib/gc/Transforms/Utils/ValueUtils.cpp +++ b/lib/gc/Transforms/Utils/ValueUtils.cpp @@ -285,10 +285,8 @@ LogicalResult maybeSqueezeDims(PatternRewriter &rewriter, auto operand = operands[i]; auto type = dyn_cast(operand.getType()); if (!type) { - // maybe should 'continue' here instead and skip non-memref operands? - // TODO: replace this with 'continue' if such case would appear someday - return rewriter.notifyMatchFailure( - linalgOp, "Expect memref operand for XeGPU lowering"); + // Skip non-memref operands + continue; } if (type.getShape().size() <= maxDims) From d8bf832def043a213c4e8d312936d58c2b59a967 Mon Sep 17 00:00:00 2001 From: dchigarev Date: Fri, 29 Nov 2024 14:56:29 +0000 Subject: [PATCH 3/7] Add squeeze tests Signed-off-by: dchigarev --- .../GPU/linalg-to-xegpu-squeeze.mlir | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-squeeze.mlir diff --git a/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-squeeze.mlir b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-squeeze.mlir new file mode 100644 index 00000000..362867fb --- /dev/null +++ b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-squeeze.mlir @@ -0,0 +1,61 @@ +// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file -cse | FileCheck %s + +!input_type = memref<2x4x8x16xf16> +!chunk_type = memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>> +!slm_chunk = memref<1x1x8x16xf16, strided<[128, 128, 16, 1], offset: ?>, 3> + +// The map that computes an offset for SLM +// CHECK: #map = affine_map<(d0, d1) -> (d0 * 4 + d1)> +#map = affine_map<(xi, yi) -> (xi * 4 + yi)> + +func.func @entry(%arg0: !input_type, %arg1: !input_type, %arg2: !input_type) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + + gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c1, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c2, %arg13 = %c4, %arg14 = %c1) { + // CHECK: %[[ARG0_SB:.+]] = memref.subview %arg0[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] + %arg0_sb = memref.subview %arg0[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : !input_type to !chunk_type + // CHECK: %[[ARG1_SB:.+]] = memref.subview %arg1[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] + %arg1_sb = memref.subview %arg1[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : !input_type to !chunk_type + // CHECK: %[[ARG2_SB:.+]] = memref.subview %arg2[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] + %arg2_sb = memref.subview %arg2[%arg6, %arg7, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : !input_type to !chunk_type + + // CHECK: %[[SLM_BUFF:.+]] = memref.alloc() : memref<8x1x8x16xf16, 3> + %slm_root = memref.alloc() : memref<8x1x8x16xf16, 3> + + %slm_idx = affine.apply #map(%arg6, %arg7) + %slm = memref.subview %slm_root[%slm_idx, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : memref<8x1x8x16xf16, 3> to !slm_chunk + + // Squeezing the arguments of 'linalg.mul' + // CHECK: %[[ARG0_SQUEEZ:.+]] = memref.subview %[[ARG0_SB]][0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : + // CHECK-SAME: memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>> to memref<8x16xf16, strided<[16, 1], offset: ?>> + + // CHECK: %[[ARG1_SQUEEZ:.+]] = memref.subview %[[ARG1_SB]][0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : + // CHECK-SAME: memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>> to memref<8x16xf16, strided<[16, 1], offset: ?>> + + // Verify that tensor descriptors are created from the squeezed memrefs + // CHECK: xegpu.create_nd_tdesc %[[ARG0_SQUEEZ]] + // CHECK: xegpu.create_nd_tdesc %[[ARG1_SQUEEZ]] + + // Verify that the SLM output of linalg.mul is squeezed correctly + // CHECK-NOT: .* = memref.subview %[[SLM_BUFF]] .* + // CHECK: %[[SLM_THREAD_OFF:.+]] = affine.apply #map(%arg6, %arg7) + // CHECK: %[[SLM_OFF:.+]] = arith.muli %[[SLM_THREAD_OFF]], %c128 : index + // CHECK: %[[FLAT_SLM:.+]] = memref.reinterpret_cast %[[SLM_BUFF]] to offset: [%c0], sizes: [%c1024], strides: [%c1] : memref<8x1x8x16xf16, 3> to memref<1024xf16, 3> + // CHECK: xegpu.create_tdesc %[[FLAT_SLM]] + linalg.mul ins(%arg0_sb, %arg1_sb : !chunk_type, !chunk_type) outs(%slm : !slm_chunk) + + // Squeezing the result buffer of 'linalg.add' + // CHECK: %[[ARG2_SQUEEZ:.+]] = memref.subview %[[ARG2_SB]][0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] : + // CHECK-SAME: memref<1x1x8x16xf16, strided<[512, 128, 16, 1], offset: ?>> to memref<8x16xf16, strided<[16, 1], offset: ?>> + + // Verify that tensor descriptors are created from the squeezed memrefs + // CHECK: xegpu.create_nd_tdesc %[[ARG2_SQUEEZ]] + linalg.add ins(%arg0_sb, %slm : !chunk_type, !slm_chunk) outs(%arg2_sb : !chunk_type) + + gpu.terminator + } {SCFToGPU_visited} + + return +} From 7dbe4337c793b0efd25999046fdad698be2b8bc1 Mon Sep 17 00:00:00 2001 From: dchigarev Date: Fri, 29 Nov 2024 16:23:09 +0000 Subject: [PATCH 4/7] avoid copy Signed-off-by: dchigarev --- lib/gc/Transforms/Utils/ValueUtils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/gc/Transforms/Utils/ValueUtils.cpp b/lib/gc/Transforms/Utils/ValueUtils.cpp index 2b1c83a0..9040b85f 100644 --- a/lib/gc/Transforms/Utils/ValueUtils.cpp +++ b/lib/gc/Transforms/Utils/ValueUtils.cpp @@ -216,7 +216,7 @@ computeSubviewOffsets(PatternRewriter &rewriter, Location loc, Value memref) { memref = subViewOp.getOperand(0); } - return std::make_tuple(resolvedOffsets, memref); + return std::make_tuple(std::move(resolvedOffsets), memref); } SmallVector getMemrefStrides(PatternRewriter &rewriter, From 8e793a98c51ae50c863686da8795cd6a68c21096 Mon Sep 17 00:00:00 2001 From: dchigarev Date: Mon, 2 Dec 2024 09:57:25 +0000 Subject: [PATCH 5/7] do not return smallVector Signed-off-by: dchigarev --- include/gc/Transforms/Utils/ValueUtils.h | 5 +++-- lib/gc/Transforms/GPU/LinalgToXeGPU.cpp | 4 +++- lib/gc/Transforms/Utils/ValueUtils.cpp | 23 ++++++++++++----------- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/include/gc/Transforms/Utils/ValueUtils.h b/include/gc/Transforms/Utils/ValueUtils.h index a512d2e7..c0d105df 100644 --- a/include/gc/Transforms/Utils/ValueUtils.h +++ b/include/gc/Transforms/Utils/ValueUtils.h @@ -55,8 +55,9 @@ bool hasSharedMemSpace(mlir::Value memref); // Go through all parent 'memref.subview' ops for the given `memref` // and return the folded offsets of all subviews and the root memref. -std::tuple, Value> -computeSubviewOffsets(PatternRewriter &rewriter, Location loc, Value memref); +void computeSubviewOffsets(PatternRewriter &rewriter, Location loc, + Value memref, SmallVector &resultOffsets, + Value &resultRootMemref); // Return the strides of the memref SmallVector getMemrefStrides(PatternRewriter &rewriter, diff --git a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp index 2b6841ec..994e445e 100644 --- a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp +++ b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp @@ -835,10 +835,12 @@ static SmallVector createSLMDescTiles(PatternRewriter &rewriter, auto srcType = cast(src.getType()); assert(srcType.getRank() == 2 && "Expected a 2D memref"); + SmallVector offsets; + Value rootMemref; // 'imex::ConvertGPUXToSPIRVPass' doesn't allow 'memref.subview' ops in the // GPU kernel. We have to merge the subview offsets into the descriptor // offset. - auto [offsets, rootMemref] = utils::computeSubviewOffsets(rewriter, loc, src); + utils::computeSubviewOffsets(rewriter, loc, src, offsets, rootMemref); auto rootStridesFold = utils::getMemrefStrides(rewriter, loc, rootMemref); auto rootStrides = getValueOrCreateConstantIndexOp(rewriter, loc, rootStridesFold); diff --git a/lib/gc/Transforms/Utils/ValueUtils.cpp b/lib/gc/Transforms/Utils/ValueUtils.cpp index 9040b85f..b69ab91e 100644 --- a/lib/gc/Transforms/Utils/ValueUtils.cpp +++ b/lib/gc/Transforms/Utils/ValueUtils.cpp @@ -198,25 +198,26 @@ bool hasSharedMemSpace(mlir::Value memref) { return false; } -std::tuple, Value> -computeSubviewOffsets(PatternRewriter &rewriter, Location loc, Value memref) { +void computeSubviewOffsets(PatternRewriter &rewriter, Location loc, + Value memref, SmallVector &resultOffsets, + Value &resultRootMemref) { auto fillVal = rewriter.create(loc, 0); auto origShape = dyn_cast(memref.getType()).getShape(); - SmallVector resolvedOffsets(origShape.size(), fillVal); + resultOffsets.clear(); + resultOffsets.append(origShape.size(), fillVal); + resultRootMemref = memref; - while (auto subViewOp = memref.getDefiningOp()) { - auto currentOffsets = getAsOpFoldResult(resolvedOffsets); - resolvedOffsets.clear(); + while (auto subViewOp = resultRootMemref.getDefiningOp()) { + auto currentOffsets = getAsOpFoldResult(resultOffsets); + resultOffsets.clear(); affine::resolveIndicesIntoOpWithOffsetsAndStrides( - rewriter, memref.getLoc(), subViewOp.getMixedOffsets(), + rewriter, resultRootMemref.getLoc(), subViewOp.getMixedOffsets(), subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), currentOffsets, - resolvedOffsets); - memref = subViewOp.getOperand(0); + resultOffsets); + resultRootMemref = subViewOp.getOperand(0); } - - return std::make_tuple(std::move(resolvedOffsets), memref); } SmallVector getMemrefStrides(PatternRewriter &rewriter, From a05f541183f52c5d45e8927f58ea550709d0d234 Mon Sep 17 00:00:00 2001 From: dchigarev Date: Mon, 2 Dec 2024 15:21:50 +0000 Subject: [PATCH 6/7] address review comments Signed-off-by: dchigarev --- include/gc/Transforms/Utils/ValueUtils.h | 4 ++-- lib/gc/Transforms/Utils/ValueUtils.cpp | 18 +++++++++++------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/include/gc/Transforms/Utils/ValueUtils.h b/include/gc/Transforms/Utils/ValueUtils.h index c0d105df..1709078e 100644 --- a/include/gc/Transforms/Utils/ValueUtils.h +++ b/include/gc/Transforms/Utils/ValueUtils.h @@ -64,8 +64,8 @@ SmallVector getMemrefStrides(PatternRewriter &rewriter, Location loc, Value memref); // Squeeze the leading dimensions of a given memref up to 'maxDims'. -FailureOr squeezeMemref(PatternRewriter &rewriter, Location loc, - Value memref, size_t maxDims = 2); +FailureOr reduceMemrefDims(PatternRewriter &rewriter, Location loc, + Value memref, size_t maxDims = 2); // Squeeze the leading dimensions of memref operands of a given 'linalgOp'. LogicalResult maybeSqueezeDims(PatternRewriter &rewriter, diff --git a/lib/gc/Transforms/Utils/ValueUtils.cpp b/lib/gc/Transforms/Utils/ValueUtils.cpp index b69ab91e..8d20533e 100644 --- a/lib/gc/Transforms/Utils/ValueUtils.cpp +++ b/lib/gc/Transforms/Utils/ValueUtils.cpp @@ -202,7 +202,10 @@ void computeSubviewOffsets(PatternRewriter &rewriter, Location loc, Value memref, SmallVector &resultOffsets, Value &resultRootMemref) { auto fillVal = rewriter.create(loc, 0); - auto origShape = dyn_cast(memref.getType()).getShape(); + auto type = dyn_cast(memref.getType()); + assert(type && "Expected a memref type"); + + auto origShape = type.getShape(); resultOffsets.clear(); resultOffsets.append(origShape.size(), fillVal); @@ -235,8 +238,8 @@ SmallVector getMemrefStrides(PatternRewriter &rewriter, return strides; } -FailureOr squeezeMemref(PatternRewriter &rewriter, Location loc, - Value memref, size_t maxDims = 2) { +FailureOr reduceMemrefDims(PatternRewriter &rewriter, Location loc, + Value memref, size_t maxDims = 2) { auto type = dyn_cast(memref.getType()); auto shape = type.getShape(); @@ -293,7 +296,7 @@ LogicalResult maybeSqueezeDims(PatternRewriter &rewriter, if (type.getShape().size() <= maxDims) continue; - auto res = squeezeMemref(rewriter, loc, operand, maxDims); + auto res = reduceMemrefDims(rewriter, loc, operand, maxDims); if (failed(res)) { return rewriter.notifyMatchFailure( linalgOp, "Can't squeeze memref to the desired number of dimensions"); @@ -303,9 +306,10 @@ LogicalResult maybeSqueezeDims(PatternRewriter &rewriter, newOperands.emplace_back(i, flatSubview); } - for (auto [i, operand] : newOperands) - linalgOp->setOperand(i, operand); - + rewriter.modifyOpInPlace(linalgOp, [&] { + for (auto [i, operand] : newOperands) + linalgOp->setOperand(i, operand); + }); return success(); } From 21ee00b8ec7a1d54abb3c2ef8db6e10723f73917 Mon Sep 17 00:00:00 2001 From: dchigarev Date: Tue, 3 Dec 2024 11:25:53 +0000 Subject: [PATCH 7/7] fix endless loop Signed-off-by: dchigarev --- lib/gc/Transforms/Utils/ValueUtils.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lib/gc/Transforms/Utils/ValueUtils.cpp b/lib/gc/Transforms/Utils/ValueUtils.cpp index 8d20533e..c6285df1 100644 --- a/lib/gc/Transforms/Utils/ValueUtils.cpp +++ b/lib/gc/Transforms/Utils/ValueUtils.cpp @@ -306,6 +306,9 @@ LogicalResult maybeSqueezeDims(PatternRewriter &rewriter, newOperands.emplace_back(i, flatSubview); } + if (newOperands.empty()) + return success(); + rewriter.modifyOpInPlace(linalgOp, [&] { for (auto [i, operand] : newOperands) linalgOp->setOperand(i, operand);