diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp index d4dbb5aeb7c27f..f65a0fa1772d45 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -701,24 +701,27 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, })); SmallVector expandedOpOperands; + expandedOpOperands.reserve(genericOp.getNumInputs()); for (OpOperand *opOperand : genericOp.getInputOperands()) { if (opOperand == fusableOpOperand) { expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src() : collapsingReshapeOp.src()); continue; } - AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); - RankedTensorType expandedOperandType = - getExpandedType(opOperand->get().getType().cast(), - indexingMap, expansionInfo); - if (expandedOperandType != opOperand->get().getType()) { - // Reshape the operand to get the right type. - SmallVector reassociation = - getReassociationForExpansion(indexingMap, expansionInfo); - expandedOpOperands.push_back(rewriter.create( - genericOp.getLoc(), expandedOperandType, opOperand->get(), - reassociation)); - continue; + if (genericOp.isInputTensor(opOperand)) { + AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); + RankedTensorType expandedOperandType = + getExpandedType(opOperand->get().getType().cast(), + indexingMap, expansionInfo); + if (expandedOperandType != opOperand->get().getType()) { + // Reshape the operand to get the right type. + SmallVector reassociation = + getReassociationForExpansion(indexingMap, expansionInfo); + expandedOpOperands.push_back(rewriter.create( + genericOp.getLoc(), expandedOperandType, opOperand->get(), + reassociation)); + continue; + } } expandedOpOperands.push_back(opOperand->get()); } @@ -1035,7 +1038,7 @@ class FoldWithProducerReshapeOpByExpansion LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - for (OpOperand *opOperand : genericOp.getInputOperands()) { + for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { TensorCollapseShapeOp reshapeOp = opOperand->get().getDefiningOp(); if (!reshapeOp)