diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 87deef9ca7466..3e4da94bd714e 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -142,6 +142,32 @@ FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::PadOp padOp,
 FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::EmptyOp emptyOp,
                                     ValueRange independencies);
 
+/// Computes the offsets, sizes, and strides needed to build a collapsed
+/// `sliceOp`. The dimensions to collapse are specified by `reassociation`.
+///
+/// This fails when the specified collapse cannot be represented by a valid
+/// ExtractSliceOp.
+LogicalResult
+getCollapsedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp,
+                             ArrayRef<ReassociationIndices> reassociation,
+                             SmallVectorImpl<OpFoldResult> &collapsedOffsets,
+                             SmallVectorImpl<OpFoldResult> &collapsedSizes,
+                             SmallVectorImpl<OpFoldResult> &collapsedStrides);
+
+/// Computes the offsets, sizes, and strides needed to build an expanded
+/// `sliceOp`. The dimensions to expand are specified by `reassociation` and
+/// `expandedShape`.
+///
+/// This fails when the specified expansion cannot be represented by a valid
+/// ExtractSliceOp.
+LogicalResult
+getExpandedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp,
+                            ArrayRef<ReassociationIndices> reassociation,
+                            ArrayRef<int64_t> expandedShape,
+                            SmallVectorImpl<OpFoldResult> &expandedOffsets,
+                            SmallVectorImpl<OpFoldResult> &expandedSizes,
+                            SmallVectorImpl<OpFoldResult> &expandedStrides);
+
 } // namespace tensor
 } // namespace mlir
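Note for reviewers: a minimal sketch of how a downstream pattern could drive the new entry points. The wrapper function below is hypothetical (it is not part of this patch); the calls mirror the rewritten `BubbleUpExpandShapeThroughExtractSlice` in the next file.

```cpp
// Hypothetical usage sketch -- not part of the patch. It assumes the usual
// MLIR includes and mirrors the rewritten pattern in ReshapePatterns.cpp.
static LogicalResult
bubbleUpThroughExpandShape(PatternRewriter &rewriter,
                           tensor::ExtractSliceOp sliceOp,
                           tensor::ExpandShapeOp expandShapeOp) {
  SmallVector<OpFoldResult> offsets, sizes, strides;
  // Fails when the slice is not contiguous per reassociation group, i.e.
  // when it cannot be expressed as a single slice of the un-expanded source.
  if (failed(tensor::getCollapsedExtractSliceInfo(
          rewriter, sliceOp, expandShapeOp.getReassociationIndices(), offsets,
          sizes, strides)))
    return failure();
  // Slice the un-expanded source, then re-expand to the original result type.
  Value newSlice = tensor::ExtractSliceOp::create(
      rewriter, sliceOp.getLoc(), expandShapeOp.getSrc(), offsets, sizes,
      strides);
  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      sliceOp, sliceOp.getResultType(), newSlice,
      expandShapeOp.getReassociationIndices(), sliceOp.getMixedSizes());
  return success();
}
```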
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 2ec23e1fb35ce..dfce835a1954b 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -327,172 +327,31 @@ struct BubbleUpExpandShapeThroughExtractSlice
                                 PatternRewriter &rewriter) const override {
     auto expandShapeOp =
         sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!expandShapeOp) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "tensor.extract_slice source not produced by expand_shape");
+    }
+    SmallVector<ReassociationIndices> reassociation =
+        expandShapeOp.getReassociationIndices();
 
-    if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
-                                                 rewriter)
-            .failed())
+    SmallVector<OpFoldResult> offsets, sizes, strides;
+    if (failed(getCollapsedExtractSliceInfo(rewriter, sliceOp, reassociation,
+                                            offsets, sizes, strides)))
       return failure();
 
-    // The tensor.extract_slice before applying the pattern works on the result
-    // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
-    // referring to the state before applying the pattern are named with the
-    // prefix "expanded", and ones referring to the state after applying the
-    // pattern are named with the prefix "collapsed".
-    SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> expandedShape =
-        getMixedValues(expandShapeOp.getStaticOutputShape(),
-                       expandShapeOp.getOutputShape(), rewriter);
-
-    // Helper variables and function for accumulating the size values.
-    Location loc = expandShapeOp->getLoc();
-    AffineExpr d0, d1, d2;
-    bindDims(rewriter.getContext(), d0, d1, d2);
-    // Multiply two integers.
-    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
-      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
-      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
-                                                   {v1, v2});
-    };
-
-    // Compute new offsets, sizes, and strides for tensor.extract_slice.
-    // The new tensor.extract_slice will work on a tensor that has has a rank of
-    // ReassociationIndices.size(). In the loop a single offset, size, and
-    // stride value is computed per reassociation group.
-    SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
-        collapsedStrides;
-    for (const ReassociationIndices &indices :
-         expandShapeOp.getReassociationIndices()) {
-      // collapsedSize will hold the size of the single dim that represents the
-      // reassociation group in the non expanded tensor.
-      OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
-      // The reassocGroupSizes and reassocGroupOffsets are used to create an
-      // affine.linearize_index op to linearize the single offset value required
-      // for this reassociation group.
-      SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
-
-      for (long expandedDim : indices) {
-        // reassocGroupSizes and reassocGroupOffsets can be obtained directly
-        // from the expanded state, but the collapsed size requires calculation
-        // as it did not previously exist.
-        reassocGroupSizes.push_back(expandedShape[expandedDim]);
-        reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
-        collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
-      }
-
-      SmallVector<Value> offsetVals =
-          llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
-            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
-          });
-      OpFoldResult collapsedOffset =
-          affine::AffineLinearizeIndexOp::create(rewriter, loc, offsetVals,
-                                                 reassocGroupSizes,
-                                                 /*disjoint=*/true)
-              .getResult();
-      collapsedOffsets.push_back(collapsedOffset);
-      collapsedSizes.push_back(collapsedSize);
-
-      // Only unit stride is supported.
-      collapsedStrides.push_back(rewriter.getIndexAttr(1));
-    }
-    // The shape of the result can be obtained from the sizes passed in.
-    SmallVector<Value> dynDims;
-    SmallVector<int64_t> shape;
-    dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
-    RankedTensorType resultType = RankedTensorType::get(
-        shape, expandShapeOp.getResultType().getElementType());
+    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+    RankedTensorType resultType = sliceOp.getResultType();
 
     // Create a new ExtractSliceOp and ExpandShapeOp.
+    Location loc = sliceOp.getLoc();
     Value newSliceOp = tensor::ExtractSliceOp::create(
-        rewriter, loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
-        collapsedStrides);
+        rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
         sliceOp, resultType, newSliceOp,
         expandShapeOp.getReassociationIndices(), expandedSizes);
     return success();
   }
-
-  // Helper function to check if all the required conditions for the
-  // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
-  // met.
-  LogicalResult
-  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
-                                           tensor::ExpandShapeOp expandShapeOp,
-                                           PatternRewriter &rewriter) const {
-
-    if (!expandShapeOp) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "tensor.extract_slice source not produced by expand_shape");
-    }
-
-    if (!sliceOp.hasUnitStride()) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
-                   "be supported in this transformation.");
-    }
-
-    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
-
-    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
-        sizes.size()) {
-      return rewriter.notifyMatchFailure(sliceOp,
-                                         "unimplemented: rank reducing slice");
-    }
-
-    SmallVector<OpFoldResult> outputShape =
-        getMixedValues(expandShapeOp.getStaticOutputShape(),
-                       expandShapeOp.getOutputShape(), rewriter);
-
-    std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
-        isZeroOffsetAndFullSize =
-            [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
-              if (!isZeroInteger(offset))
-                return false;
-              FailureOr<bool> maybeEqual =
-                  ValueBoundsConstraintSet::areEqual(sliceSize, size);
-              return llvm::succeeded(maybeEqual) && maybeEqual.value();
-            };
-
-    // Check that the slice is contiguous within each reassociation group.
-    // The slice is contiguous only if after the first dimension where a non
-    // unit slice is taken, the slice size on all subsequent dimensions of the
-    // group is equal to the entire size of the dimension.
-    // Examples of contiguous slices:
-    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
-    //   full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
-    // Examples of non contiguous slices:
-    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
-    //   full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
-    for (const ReassociationIndices &indices :
-         expandShapeOp.getReassociationIndices()) {
-      int64_t i = 0;
-      int64_t e = indices.size();
-      // Find the first expanded dim after the first dim with non-unit extracted
-      // size.
-      for (; i < e; ++i) {
-        if (!isOneInteger(sizes[indices[i]])) {
-          // +1 to skip the first non-unit size dim.
-          i++;
-          break;
-        }
-      }
-
-      // Verify that all subsequent dimensions extract the full size of the
-      // source tensor.
-      for (; i < e; ++i) {
-        int64_t expandedDim = indices[i];
-        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
-                                     outputShape[expandedDim])) {
-          return rewriter.notifyMatchFailure(
-              sliceOp, "Not a contiguous slice of the expanded tensor.");
-        }
-      }
-    }
-
-    return success();
-  }
 };
 
 /// Converts `tensor.extract_slice(tensor.collapse_shape)` to
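For reference, the arithmetic that `getCollapsedExtractSliceInfo` materializes per reassociation group (via `affine.linearize_index` plus an affine multiplication) reduces, for static values, to the following plain-integer sketch. `collapseGroup` is illustrative only and not part of the patch:

```cpp
#include <cstdint>
#include <vector>

// Illustrative only: fold one reassociation group of a sliced expand_shape.
// The per-dimension offsets are linearized with the group's expanded sizes
// as the basis, and the collapsed size is the product of the slice sizes.
static void collapseGroup(const std::vector<int64_t> &groupShape,
                          const std::vector<int64_t> &offsets,
                          const std::vector<int64_t> &sizes,
                          int64_t &collapsedOffset, int64_t &collapsedSize) {
  collapsedOffset = 0;
  collapsedSize = 1;
  for (size_t i = 0; i < groupShape.size(); ++i) {
    // Horner-style linearization, matching affine.linearize_index.
    collapsedOffset = collapsedOffset * groupShape[i] + offsets[i];
    collapsedSize *= sizes[i];
  }
}
```

For the contiguous example from the precondition comment above (full sizes [5, 10], offsets [3, 0], sizes [2, 10]) this yields offset 3 * 10 + 0 = 30 and size 2 * 10 = 20, i.e. elements 30..49 of the 50-element collapsed dimension, which is why the contiguity check is required.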
@@ -582,170 +441,281 @@ struct BubbleUpCollapseShapeThroughExtractSlice
           "tensor.extract_slice source not produced by tensor.collapse_shape");
     }
 
-    if (!sliceOp.hasUnitStride()) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
-                   "be supported in this transformation.");
-    }
+    SmallVector<OpFoldResult> offsets, sizes, strides;
+    if (failed(getExpandedExtractSliceInfo(
+            rewriter, sliceOp, collapseShapeOp.getReassociationIndices(),
+            collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides)))
+      return failure();
 
-    // The tensor.extract_slice before applying the pattern works on the result
-    // of the tensor.collapse_shape, so variables (i.e. inputs for
-    // ExtractSliceOp) referring to the state before applying the pattern are
-    // named with the prefix "collapsed", and ones referring to the state after
-    // applying the pattern are named with the prefix "expanded".
-    SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
-
-    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
-        collapsedSizes.size()) {
-      return rewriter.notifyMatchFailure(sliceOp,
-                                         "unimplemented: rank reducing slice");
-    }
+    Value newSliceOp = tensor::ExtractSliceOp::create(
+        rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets,
+        sizes, strides);
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        sliceOp, sliceOp.getResultType(), newSliceOp,
+        collapseShapeOp.getReassociationIndices());
 
-    ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
-    SmallVector<ReassociationIndices> reassociationIndices =
-        collapseShapeOp.getReassociationIndices();
-
-    // Compute new offsets, sizes, and strides for tensor.extract_slice.
-    // The new tensor.extract_slice will work on a tensor that has has a rank
-    // equal to the rank of the src of the collapse_shape. In each iteration of
-    // the loop, the offsets and sizes will be computed per reassociation group.
-    SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
-    SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
-                                              rewriter.getIndexAttr(1));
-
-    for (auto [collapsedSize, collapsedOffset, reassocIndices] :
-         llvm::zip_equal(collapsedSizes, collapsedOffsets,
-                         collapseShapeOp.getReassociationIndices())) {
-      // CASE #1 - size and/or offset are dynamic.
-      // In this case, the slice can be represented as a contiguous slice only
-      // if there is a single dimension in the reassociation group that has a
-      // size not equal to 1.
-      if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
-        int nonUnitSizeCount = 0;
-        for (int64_t expandedShapeIdx : reassocIndices) {
-          if (srcShape[expandedShapeIdx] != 1) {
-            nonUnitSizeCount++;
-            expandedSizes.push_back(collapsedSize);
-            expandedOffsets.push_back(collapsedOffset);
-            continue;
-          }
-
-          expandedSizes.push_back(rewriter.getIndexAttr(1));
-          expandedOffsets.push_back(rewriter.getIndexAttr(0));
-        }
-
-        if (nonUnitSizeCount != 1) {
-          return rewriter.notifyMatchFailure(
-              sliceOp,
-              "unsupported: slice cannot be verified to be contiguous");
-        }
-        continue;
-      }
-
-      // CASE #2 = size and offset are static.
-      // Verify that the slice can be represented as a contiguous slice of the
-      // src of the collapse_shape.
-      // Checking this is done on order of most internal dimensions first,
-      // so traversal is done in reverse order of the reassociation group.
-      // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
-      // ...,An] then we first find the size and offset for n...k+1 then for k
-      // and then for k-1...0.
-
-      // currentCollapsedsize and currentCollapsedOffset are initialized with
-      // the original collapsed size and offset and divided by the expanded
-      // shape size in each dimension as we go along the reassociation group.
-      // In essence we are spreading the original collapsed size and offset over
-      // the various expanded slice dimensions.
-      // The variables are used both to check the validity of the slice and to
-      // compute the expanded sizes and offsets.
-      int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
-      int64_t currentCollapsedOffset =
-          getConstantIntValue(collapsedOffset).value();
-
-      SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
-
-      ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
-                                                  reassocIndices.rend());
-      int64_t idx = 0;
-      int64_t reassocGroupSize = reassocIndices.size();
-
-      // First handle the trailing dimensions where the slice size should be
-      // equal to the tensor shape and the offset should be 0 (n...k+1).
-      for (; idx < reassocGroupSize; ++idx) {
-        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
-
-        if (currentCollapsedsize < expandedShapeSize)
-          break;
-
-        // We need to make sure that the slice size can be set to the shape size
-        // and the offset to 0.
-        if ((currentCollapsedsize % expandedShapeSize) != 0 ||
-            (currentCollapsedOffset % expandedShapeSize) != 0) {
-          return rewriter.notifyMatchFailure(
-              sliceOp, "unsupported: cannot be extracted as a contiguous slice "
-                       "of the src of the collapse_shape");
-        }
-
-        groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
-        groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
-
-        currentCollapsedsize /= expandedShapeSize;
-        currentCollapsedOffset /= expandedShapeSize;
-      }
-
-      // Now handle the first dim where slicing occurs on (k).
-      if (idx < reassocGroupSize) {
-        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
-        int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
-        // We need to make sure that the slice size in this dim + offset will
-        // not exceed the shape size.
-        if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
-          return rewriter.notifyMatchFailure(
-              sliceOp, "unsupported: slice cannot be extracted as a contiguous "
-                       "slice of the src of the collapse_shape");
-        }
-
-        groupExpandedSizes.push_back(
-            rewriter.getIndexAttr(currentCollapsedsize));
-        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
-
-        currentCollapsedOffset /= expandedShapeSize;
-      }
-
-      // Now handle the leading dimensions where the slice size is equal to 1
-      // (k-1...0).
-      // The size for these dimensions must be 1 because of how we constructed
-      // the slice size of the expanded shape. We spread the original collapsed
-      // size over the expanded shape sizes until we reached dimension k where
-      // the remaining size was smaller than the expanded shape size, and spread
-      // the remaining size on it. So, now we are left with only 1s.
-      for (idx++; idx < reassocGroupSize; ++idx) {
-        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
-        int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
-        groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
-        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
-        currentCollapsedOffset /= expandedShapeSize;
-      }
-
-      expandedSizes.append(groupExpandedSizes.rbegin(),
-                           groupExpandedSizes.rend());
-      expandedOffsets.append(groupExpandedOffsets.rbegin(),
-                             groupExpandedOffsets.rend());
-    }
-
-    Value newSliceOp = tensor::ExtractSliceOp::create(
-        rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(),
-        expandedOffsets, expandedSizes, expandedStrides);
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        sliceOp, sliceOp.getResultType(), newSliceOp,
-        collapseShapeOp.getReassociationIndices());
-
-    return success();
+    return success();
+  }
+};
+
+} // namespace
+
+LogicalResult mlir::tensor::getCollapsedExtractSliceInfo(
+    OpBuilder &b, tensor::ExtractSliceOp sliceOp,
+    ArrayRef<ReassociationIndices> reassociation,
+    SmallVectorImpl<OpFoldResult> &collapsedOffsets,
+    SmallVectorImpl<OpFoldResult> &collapsedSizes,
+    SmallVectorImpl<OpFoldResult> &collapsedStrides) {
+  if (!sliceOp.hasUnitStride()) {
+    return failure();
+  }
+
+  SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+
+  if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
+    return failure();
+  }
+
+  auto isZeroOffsetAndFullSize = [&](OpFoldResult offset,
+                                     OpFoldResult sliceSize, int64_t inputDim) {
+    if (!isZeroInteger(offset))
+      return false;
+    ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim);
+    FailureOr<bool> maybeEqual =
+        ValueBoundsConstraintSet::areEqual(sliceSize, inputSize);
+    return llvm::succeeded(maybeEqual) && maybeEqual.value();
+  };
+
+  // Check that the slice is contiguous within each reassociation group.
+  // The slice is contiguous only if after the first dimension where a non
+  // unit slice is taken, the slice size on all subsequent dimensions of the
+  // group is equal to the entire size of the dimension.
+  // Examples of contiguous slices:
+  //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
+  //   full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
+  // Examples of non contiguous slices:
+  //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
+  //   full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
+  for (const ReassociationIndices &indices : reassociation) {
+    int64_t i = 0;
+    int64_t e = indices.size();
+    // Find the first expanded dim after the first dim with non-unit extracted
+    // size.
+    for (; i < e; ++i) {
+      if (!isOneInteger(sizes[indices[i]])) {
+        // +1 to skip the first non-unit size dim.
+        i++;
+        break;
+      }
+    }
+
+    // Verify that all subsequent dimensions extract the full size of the
+    // source tensor.
+    for (; i < e; ++i) {
+      int64_t expandedDim = indices[i];
+      if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+                                   expandedDim)) {
+        return failure();
+      }
+    }
+  }
+
+  // The tensor.extract_slice before applying the pattern works on the result
+  // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
+  // referring to the state before applying the pattern are named with the
+  // prefix "expanded", and ones referring to the state after applying the
+  // pattern are named with the prefix "collapsed".
+  Location loc = sliceOp.getLoc();
+  SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> expandedShape =
+      getMixedSizes(b, loc, sliceOp.getSource());
+
+  // Helper variables and function for accumulating the size values.
+  AffineExpr d0, d1, d2;
+  bindDims(b.getContext(), d0, d1, d2);
+  // Multiply two integers.
+  auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+    auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+    return affine::makeComposedFoldedAffineApply(b, loc, mulMap, {v1, v2});
+  };
+
+  // Compute new offsets, sizes, and strides for tensor.extract_slice.
+  // The new tensor.extract_slice will work on a tensor that has a rank of
+  // ReassociationIndices.size(). In the loop a single offset, size, and
+  // stride value is computed per reassociation group.
+  for (const ReassociationIndices &indices : reassociation) {
+    // collapsedSize will hold the size of the single dim that represents the
+    // reassociation group in the non-expanded tensor.
+    OpFoldResult collapsedSize = b.getIndexAttr(1);
+    // The reassocGroupSizes and reassocGroupOffsets are used to create an
+    // affine.linearize_index op to linearize the single offset value required
+    // for this reassociation group.
+    SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
+
+    for (long expandedDim : indices) {
+      // reassocGroupSizes and reassocGroupOffsets can be obtained directly
+      // from the expanded state, but the collapsed size requires calculation
+      // as it did not previously exist.
+      reassocGroupSizes.push_back(expandedShape[expandedDim]);
+      reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
+      collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
+    }
+
+    SmallVector<Value> offsetVals =
+        llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
+          return getValueOrCreateConstantIndexOp(b, loc, ofr);
+        });
+    OpFoldResult collapsedOffset = affine::AffineLinearizeIndexOp::create(
+                                       b, loc, offsetVals, reassocGroupSizes,
+                                       /*disjoint=*/true)
+                                       .getResult();
+    collapsedOffsets.push_back(collapsedOffset);
+    collapsedSizes.push_back(collapsedSize);
+
+    // Only unit stride is supported.
+    collapsedStrides.push_back(b.getIndexAttr(1));
+  }
+  return success();
+}
+
+LogicalResult mlir::tensor::getExpandedExtractSliceInfo(
+    OpBuilder &b, tensor::ExtractSliceOp sliceOp,
+    ArrayRef<ReassociationIndices> reassociation,
+    ArrayRef<int64_t> expandedShape,
+    SmallVectorImpl<OpFoldResult> &expandedOffsets,
+    SmallVectorImpl<OpFoldResult> &expandedSizes,
+    SmallVectorImpl<OpFoldResult> &expandedStrides) {
+  if (!sliceOp.hasUnitStride()) {
+    return failure();
+  }
+
+  // The tensor.extract_slice before applying the pattern works on the result
+  // of the tensor.collapse_shape, so variables (i.e. inputs for
+  // ExtractSliceOp) referring to the state before applying the pattern are
+  // named with the prefix "collapsed", and ones referring to the state after
+  // applying the pattern are named with the prefix "expanded".
+  SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
+  if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+      collapsedSizes.size()) {
+    return failure();
+  }
+
+  // Compute new offsets, sizes, and strides for tensor.extract_slice.
+  // The new tensor.extract_slice will work on a tensor that has a rank
+  // equal to the rank of the src of the collapse_shape. In each iteration of
+  // the loop, the offsets and sizes will be computed per reassociation group.
+  expandedStrides.resize(expandedShape.size(), b.getIndexAttr(1));
+  for (auto [collapsedSize, collapsedOffset, reassocIndices] :
+       llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {
+    // CASE #1 - size and/or offset are dynamic.
+    // In this case, the slice can be represented as a contiguous slice only
+    // if there is a single dimension in the reassociation group that has a
+    // size not equal to 1.
+    if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
+      int nonUnitSizeCount = 0;
+      for (int64_t expandedShapeIdx : reassocIndices) {
+        if (expandedShape[expandedShapeIdx] != 1) {
+          nonUnitSizeCount++;
+          expandedSizes.push_back(collapsedSize);
+          expandedOffsets.push_back(collapsedOffset);
+          continue;
+        }
+
+        expandedSizes.push_back(b.getIndexAttr(1));
+        expandedOffsets.push_back(b.getIndexAttr(0));
+      }
+
+      if (nonUnitSizeCount != 1) {
+        return failure();
+      }
+      continue;
+    }
+
+    // CASE #2 - size and offset are static.
+    // Verify that the slice can be represented as a contiguous slice of the
+    // src of the collapse_shape.
+    // Checking this is done in order of most internal dimensions first,
+    // so traversal is done in reverse order of the reassociation group.
+    // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
+    // ...,An] then we first find the size and offset for n...k+1 then for k
+    // and then for k-1...0.
+
+    // currentCollapsedsize and currentCollapsedOffset are initialized with
+    // the original collapsed size and offset and divided by the expanded
+    // shape size in each dimension as we go along the reassociation group.
+    // In essence we are spreading the original collapsed size and offset over
+    // the various expanded slice dimensions.
+    // The variables are used both to check the validity of the slice and to
+    // compute the expanded sizes and offsets.
+    int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
+    int64_t currentCollapsedOffset =
+        getConstantIntValue(collapsedOffset).value();
+    SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
+    ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
+                                                reassocIndices.rend());
+    int64_t idx = 0;
+    int64_t reassocGroupSize = reassocIndices.size();
+
+    // First handle the trailing dimensions where the slice size should be
+    // equal to the tensor shape and the offset should be 0 (n...k+1).
+    for (; idx < reassocGroupSize; ++idx) {
+      int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+
+      if (currentCollapsedsize < expandedShapeSize)
+        break;
+
+      // We need to make sure that the slice size can be set to the shape size
+      // and the offset to 0.
+      if ((currentCollapsedsize % expandedShapeSize) != 0 ||
+          (currentCollapsedOffset % expandedShapeSize) != 0) {
+        return failure();
+      }
+
+      groupExpandedSizes.push_back(b.getIndexAttr(expandedShapeSize));
+      groupExpandedOffsets.push_back(b.getIndexAttr(0));
+
+      currentCollapsedsize /= expandedShapeSize;
+      currentCollapsedOffset /= expandedShapeSize;
+    }
+
+    // Now handle the first dim where slicing occurs on (k).
+    if (idx < reassocGroupSize) {
+      int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+      int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
+      // We need to make sure that the slice size in this dim + offset will
+      // not exceed the shape size.
+      if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
+        return failure();
+      }
+      groupExpandedSizes.push_back(b.getIndexAttr(currentCollapsedsize));
+      groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
+      currentCollapsedOffset /= expandedShapeSize;
+    }
+
+    // Now handle the leading dimensions where the slice size is equal to 1
+    // (k-1...0).
+    // The size for these dimensions must be 1 because of how we constructed
+    // the slice size of the expanded shape. We spread the original collapsed
+    // size over the expanded shape sizes until we reached dimension k where
+    // the remaining size was smaller than the expanded shape size, and spread
+    // the remaining size on it. So, now we are left with only 1s.
+    for (idx++; idx < reassocGroupSize; ++idx) {
+      int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+      int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
+      groupExpandedSizes.push_back(b.getIndexAttr(1));
+      groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
+      currentCollapsedOffset /= expandedShapeSize;
+    }
+    expandedSizes.append(groupExpandedSizes.rbegin(),
+                         groupExpandedSizes.rend());
+    expandedOffsets.append(groupExpandedOffsets.rbegin(),
+                           groupExpandedOffsets.rend());
   }
-};
-
-} // namespace
+  return success();
+}
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
     RewritePatternSet &patterns) {
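For reference, the static case (CASE #2) of `getExpandedExtractSliceInfo` reduces to the following plain-integer sketch, using ordinary integers instead of `OpFoldResult` and mirroring the patch's conservative `>=` bound check. `expandGroup` is illustrative only, not part of the patch:

```cpp
#include <cstdint>
#include <vector>

// Illustrative only: spread a collapsed (offset, size) pair over one
// reassociation group, innermost dimension first. Returns false when the
// slice cannot be represented as a contiguous slice of the expanded tensor.
static bool expandGroup(const std::vector<int64_t> &groupShape, // outer->inner
                        int64_t size, int64_t offset,
                        std::vector<int64_t> &sizes,
                        std::vector<int64_t> &offsets) {
  int64_t n = static_cast<int64_t>(groupShape.size());
  sizes.assign(n, 1);
  offsets.assign(n, 0);
  int64_t i = n - 1;
  // Trailing dims (n...k+1): take the full extent at offset 0.
  for (; i >= 0 && size >= groupShape[i]; --i) {
    if (size % groupShape[i] != 0 || offset % groupShape[i] != 0)
      return false;
    sizes[i] = groupShape[i];
    size /= groupShape[i];
    offset /= groupShape[i];
  }
  // Dim k: the remaining size, conservatively bounded as in the patch.
  if (i >= 0) {
    int64_t offsetInDim = offset % groupShape[i];
    if (size + offsetInDim >= groupShape[i])
      return false;
    sizes[i] = size;
    offsets[i] = offsetInDim;
    offset /= groupShape[i];
    --i;
  }
  // Leading dims (k-1...0): unit sizes, remaining offset "digits".
  for (; i >= 0; --i) {
    offsets[i] = offset % groupShape[i];
    offset /= groupShape[i];
  }
  return true;
}

// Example: groupShape [4, 8, 10] with collapsed size 20 and offset 30 yields
// sizes [1, 2, 10] and offsets [0, 3, 0] -- the same 20 contiguous elements,
// since (0 * 8 + 3) * 10 + 0 = 30.
```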