[mlir] Rename getTied* methods to getMatching* in LinalgInterface.
Summary:
As mentioned in a comment on https://reviews.llvm.org/D134444, the term `tied`
is a misnomer in this context; `matching` describes the relationship better.

Differential Revision: https://reviews.llvm.org/D134534
olegshyshkov committed Sep 30, 2022
1 parent 661403b commit 1227b8a
Showing 20 changed files with 91 additions and 85 deletions.
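
The rename is purely mechanical: every call site swaps a `getTied*` name for its `getMatching*` counterpart, with no behavioral change. As a minimal hedged sketch of a typical call site after this commit (a hypothetical helper, not part of the diff below; it assumes the MLIR headers as of this revision, where `getInputAndOutputOperands()` still exists):

    // Collect the indexing map attached to each input/output operand of a
    // linalg op. Only the method name changes in this commit:
    // getTiedIndexingMap -> getMatchingIndexingMap.
    #include "mlir/Dialect/Linalg/IR/Linalg.h"
    #include "llvm/ADT/SmallVector.h"

    static llvm::SmallVector<mlir::AffineMap>
    collectIndexingMaps(mlir::linalg::LinalgOp linalgOp) {
      llvm::SmallVector<mlir::AffineMap> maps;
      for (mlir::OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
        maps.push_back(linalgOp.getMatchingIndexingMap(opOperand));
      return maps;
    }
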
@@ -82,8 +82,8 @@ class LinalgDependenceGraph {
     if (!owner)
       return llvm::None;
     if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
-      return owner.getTiedIndexingMap(operand);
-    return owner.getTiedIndexingMap(owner.getOutputOperand(
+      return owner.getMatchingIndexingMap(operand);
+    return owner.getMatchingIndexingMap(owner.getOutputOperand(
         opView.get<Value>().cast<OpResult>().getResultNumber()));
   }
   // Return the operand number if the `opView` is an OpOperand *. Otherwise
10 changes: 5 additions & 5 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -377,7 +377,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
        Return the block argument for an `opOperand`.
      }],
      /*retTy=*/"BlockArgument",
-      /*methodName=*/"getTiedBlockArgument",
+      /*methodName=*/"getMatchingBlockArgument",
      /*args=*/(ins "OpOperand *":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
@@ -390,7 +390,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
        Return the operand for a `blockArgument`.
      }],
      /*retTy=*/"OpOperand *",
-      /*methodName=*/"getTiedOpOperand",
+      /*methodName=*/"getMatchingOpOperand",
      /*args=*/(ins "BlockArgument":$blockArgument),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
@@ -404,7 +404,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
        Return the input or output indexing map for `opOperand`.
      }],
      /*retTy=*/"AffineMap",
-      /*methodName=*/"getTiedIndexingMap",
+      /*methodName=*/"getMatchingIndexingMap",
      /*args=*/(ins "OpOperand*":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
@@ -419,7 +419,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
        Return the indexing map for a `result`.
      }],
      /*retTy=*/"AffineMap",
-      /*methodName=*/"getTiedIndexingMapForResult",
+      /*methodName=*/"getIndexingMapMatchingResult",
      /*args=*/(ins "OpResult":$result),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
@@ -442,7 +442,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
        `opOperand`.
      }],
      /*retTy=*/"OpOperand *",
-      /*methodName=*/"getTiedYieldValue",
+      /*methodName=*/"getMatchingYieldValue",
      /*args=*/(ins "OpOperand*":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
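
Taken together, the renamed interface methods let a pass hop between an operand, its block argument inside the op's region, its indexing map, and the map matching a result. A hedged usage sketch with the post-rename names declared above (hypothetical code, not part of this commit; it assumes `op` has at least one result):

    // Relate each input operand to its region block argument and indexing
    // map, and fetch the indexing map matching result 0.
    #include "mlir/Dialect/Linalg/IR/Linalg.h"

    static void inspectLinalgOp(mlir::linalg::LinalgOp op) {
      mlir::OpResult result = op->getResult(0).cast<mlir::OpResult>();
      mlir::AffineMap resultMap = op.getIndexingMapMatchingResult(result);
      (void)resultMap;
      for (mlir::OpOperand *input : op.getInputOperands()) {
        mlir::BlockArgument bbArg = op.getMatchingBlockArgument(input);
        mlir::AffineMap inputMap = op.getMatchingIndexingMap(input);
        (void)bbArg;
        (void)inputMap;
      }
    }
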
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -34,7 +34,7 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
   for (auto *opOperand : linalgOp.getInputAndOutputOperands()) {
     if (llvm::is_contained(droppedOperands, opOperand))
       continue;
-    indexingMaps.push_back(linalgOp.getTiedIndexingMap(opOperand));
+    indexingMaps.push_back(linalgOp.getMatchingIndexingMap(opOperand));
   }
   return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
 }
@@ -658,7 +658,7 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
            << linalgOp.getNumInputsAndOutputs() << ")";
 
   for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
-    AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
 
     // Symbols disallowed.
     if (indexingMap.getNumSymbols() != 0)
@@ -696,7 +696,7 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   for (int64_t &range : endLoopRangeValues)
     range -= 1;
   for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
-    AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
     SmallVector<int64_t, 4> startIndices =
         indexingMap.compose(startLoopRangeValues);
     SmallVector<int64_t, 4> endIndices =
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -945,7 +945,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
 
       // Check if this operand is a duplicate.
       AffineMap indexingMap =
-          genericOp.getTiedIndexingMap(inputOpOperand.value());
+          genericOp.getMatchingIndexingMap(inputOpOperand.value());
       auto it = dedupedInputs.find(
           std::make_pair(inputOpOperand.value()->get(), indexingMap));
       if (it != dedupedInputs.end()) {
@@ -984,7 +984,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
           origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
           newOutputOperands.push_back(outputOpOperand.value()->get());
           newIndexingMaps.push_back(
-              genericOp.getTiedIndexingMap(outputOpOperand.value()));
+              genericOp.getMatchingIndexingMap(outputOpOperand.value()));
         }
       } else {
         // Output argument can be dropped if the result has
@@ -997,7 +997,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
              llvm::enumerate(genericOp.getOutputOperands())) {
           Value result = genericOp.getResult(outputOpOperand.index());
           AffineMap indexingMap =
-              genericOp.getTiedIndexingMap(outputOpOperand.value());
+              genericOp.getMatchingIndexingMap(outputOpOperand.value());
           auto key =
               std::make_tuple(outputOpOperand.value()->get(), indexingMap,
                               yieldOp->getOperand(outputOpOperand.index()));
@@ -1033,7 +1033,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
             dedupedOutpts[key] = newOutputOperands.size();
             newOutputOperands.push_back(outputOpOperand.value()->get());
             newIndexingMaps.push_back(
-                genericOp.getTiedIndexingMap(outputOpOperand.value()));
+                genericOp.getMatchingIndexingMap(outputOpOperand.value()));
           }
         }
 
@@ -1957,7 +1957,7 @@ static void populateMap(LinalgOp linalgOp, ArrayRef<OpOperand *> operands,
       continue;
     Value src = opOperand->get();
     auto sourceType = src.getType().cast<RankedTensorType>();
-    auto sourceMap = linalgOp.getTiedIndexingMap(opOperand);
+    auto sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
 
     // Get the `sourceShape` of the `sourceType`. If the operand is a result of
     // `tensor.cast` operation and source of the cast operation has a static
@@ -2005,7 +2005,7 @@ static void createNewOperandWithStaticSizes(
     return;
   }
   ArrayRef<int64_t> sourceShape = sourceType.getShape();
-  AffineMap sourceMap = linalgOp.getTiedIndexingMap(opOperand);
+  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
   SmallVector<int64_t> newShape;
   // If operand is updated with new shape, `newOperandNeeded` will be
   // true.
@@ -81,7 +81,7 @@ struct BubbleUpExtractSliceOpPattern
     }
 
     OpOperand *outOperand = linalgOp.getOutputOperand(0);
-    AffineMap indexingMap = linalgOp.getTiedIndexingMap(outOperand);
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(outOperand);
     if (!indexingMap.isProjectedPermutation()) {
       return rewriter.notifyMatchFailure(
           sliceOp, "expected a projected permutation for output");
11 changes: 6 additions & 5 deletions mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -180,7 +180,7 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
     OpResult result = genericOp.getResult(*resultNumber).cast<OpResult>();
     newResultTypes.push_back(result.getType());
     peeledGenericOpIndexingMaps.push_back(
-        genericOp.getTiedIndexingMapForResult(result));
+        genericOp.getIndexingMapMatchingResult(result));
     continue;
   }
 
@@ -227,15 +227,16 @@ DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
   /// as those used for the new results of the peeledGenericOp.
   auto indexingMaps = llvm::to_vector(
       llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *operand) {
-        return genericOp.getTiedIndexingMap(operand);
+        return genericOp.getMatchingIndexingMap(operand);
       }));
   for (auto resultNum :
        llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
     OpResult result = peeledGenericOp.getResult(resultNum).cast<OpResult>();
-    indexingMaps.push_back(peeledGenericOp.getTiedIndexingMapForResult(result));
+    indexingMaps.push_back(
+        peeledGenericOp.getIndexingMapMatchingResult(result));
   }
   for (OpOperand *outOperand : genericOp.getOutputOperands())
-    indexingMaps.push_back(genericOp.getTiedIndexingMap(outOperand));
+    indexingMaps.push_back(genericOp.getMatchingIndexingMap(outOperand));
 
   auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
   return rewriter.create<GenericOp>(
@@ -263,7 +264,7 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
   }
 
   if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
-        return !genericOp.getTiedIndexingMap(outOperand).isPermutation();
+        return !genericOp.getMatchingIndexingMap(outOperand).isPermutation();
       })) {
     return rewriter.notifyMatchFailure(
         genericOp, "unhandled decomposition of generic op with out operand not "
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -245,7 +245,7 @@ struct UnitExtentReplacementInfo {
 static llvm::Optional<UnitExtentReplacementInfo>
 replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
                    MLIRContext *context) {
-  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
   ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
   ArrayRef<AffineExpr> exprs = indexingMap.getResults();
   SmallVector<AffineExpr> reassociations;
@@ -390,7 +390,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
       // type, indexing map, and create a set of mappings representing an
       // identity matrix.
       newInputOutputTypes.push_back(opOperand->get().getType());
-      newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
+      newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(opOperand));
       int64_t origRank = genericOp.getRank(opOperand);
       auto maps = llvm::to_vector<8>(llvm::map_range(
           llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
43 changes: 23 additions & 20 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -59,7 +59,7 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
 
   LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
   // argMap is a map from producer loop -> producer arg tensor index.
-  AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
+  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);
 
   // Compose argMap with invProducerResultIndexMap to get a map from
   // producer result tensor index -> producer arg tensor index.
@@ -95,14 +95,14 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
 
   // Get the consumer index map. The number of results of the consumer index
   // map must match the number of loops of the producer.
-  AffineMap consumerIndexMap = consumer.getTiedIndexingMap(fusedOperand);
+  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
     return false;
 
   // Finally the index_map for the result must be invertible. For now just
   // verify it is a permutation.
   AffineMap producerResultIndexMap =
-      producer.getTiedIndexingMap(producer.getOutputOperand(0));
+      producer.getMatchingIndexingMap(producer.getOutputOperand(0));
   if (!producerResultIndexMap.isPermutation())
     return false;
 
@@ -288,41 +288,41 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   assert(it != consumerInputs.end() && "expected to find the consumer operand");
   for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
     fusedInputOperands.push_back(opOperand->get());
-    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
+    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
   }
   // 4. Splice in producer's input operands/maps.
   AffineMap producerResultIndexMap =
-      producer.getTiedIndexingMapForResult(producerResult);
+      producer.getIndexingMapMatchingResult(producerResult);
   for (OpOperand *opOperand : producer.getInputOperands()) {
     fusedInputOperands.push_back(opOperand->get());
     // Compute indexing maps for the producer args in the fused operation.
     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
         opOperand, producerResultIndexMap,
-        consumer.getTiedIndexingMap(fusedOperand));
+        consumer.getMatchingIndexingMap(fusedOperand));
     fusedIndexMaps.push_back(map);
   }
   // 5. Remaining consumer's input operands/maps (drop past index
   // `consumerIdx`).
   for (OpOperand *opOperand :
        llvm::make_range(std::next(it), consumerInputs.end())) {
     fusedInputOperands.push_back(opOperand->get());
-    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
+    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
   }
 
   // 6. Collect all of the producer outputs.
   for (OpOperand *opOperand : producer.getOutputOperands()) {
     fusedOutputOperands.push_back(opOperand->get());
     AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
         opOperand, producerResultIndexMap,
-        consumer.getTiedIndexingMap(fusedOperand));
+        consumer.getMatchingIndexingMap(fusedOperand));
     fusedIndexMaps.push_back(map);
     fusedResultTypes.push_back(opOperand->get().getType());
   }
 
   // 7. All of consumer's output operands (skip operands: added by the builder).
   for (OpOperand *opOperand : consumer.getOutputOperands()) {
     fusedOutputOperands.push_back(opOperand->get());
-    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
+    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
     fusedResultTypes.push_back(opOperand->get().getType());
   }
 
@@ -344,7 +344,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
 
   // Construct an AffineMap from consumer loops to producer loops.
   // consumer loop -> tensor index
-  AffineMap consumerResultIndexMap = consumer.getTiedIndexingMap(fusedOperand);
+  AffineMap consumerResultIndexMap =
+      consumer.getMatchingIndexingMap(fusedOperand);
   // tensor index -> producer loop
   AffineMap invProducerResultIndexMap =
       inversePermutation(producerResultIndexMap);
@@ -466,7 +467,7 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                         .getValue()
                         .isProjectedPermutation();
            }) &&
-         genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
+         genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() > 0 &&
          llvm::all_of(genericOp.getIteratorTypesArray(), [](StringRef it) {
            return it == getParallelIteratorTypeName();
          });
@@ -517,7 +518,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                      PatternRewriter &rewriter) {
   if (reassociationMaps.empty())
     return failure();
-  AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
+  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
 
   SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
   originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
@@ -727,7 +728,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
       continue;
     }
     if (genericOp.isInputTensor(opOperand)) {
-      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+      AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
       auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
       RankedTensorType expandedOperandType =
           getExpandedType(opOperandType, indexingMap, expansionInfo);
@@ -755,7 +756,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   Location loc = genericOp.getLoc();
   SmallVector<Value> outputs;
   for (OpOperand *opOperand : genericOp.getOutputOperands()) {
-    AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+    AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
     auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
     RankedTensorType expandedOutputType =
         getExpandedType(opOperandType, indexingMap, expansionInfo);
@@ -802,7 +803,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
     if (resultTypes[resultNumber] != opResult.getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(
-              genericOp.getTiedIndexingMap(
+              genericOp.getMatchingIndexingMap(
                   genericOp.getOutputOperand(resultNumber)),
              expansionInfo);
       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
@@ -1063,7 +1064,7 @@ getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
   }
 
   llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
-  AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand);
+  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
   auto iteratorTypes = genericOp.getIteratorTypes().getValue();
   SmallVector<ReassociationIndices> iterationSpaceReassociation;
   for (ReassociationIndicesRef foldedRangeDims : reassociation) {
@@ -1312,7 +1313,7 @@ static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
                                    OpOperand *opOperand,
                                    const CollapsingInfo &collapsingInfo,
                                    OpBuilder &builder) {
-  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
   SmallVector<ReassociationIndices> operandReassociation =
       getOperandReassociation(indexingMap, collapsingInfo);
 
@@ -1470,7 +1471,7 @@ static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
     auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
     if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
       AffineMap indexingMap =
-          genericOp.getTiedIndexingMapForResult(originalResult.value());
+          genericOp.getIndexingMapMatchingResult(originalResult.value());
       SmallVector<ReassociationIndices> reassociation =
           getOperandReassociation(indexingMap, collapsingInfo);
       Value result = rewriter.create<tensor::ExpandShapeOp>(
@@ -1594,12 +1595,14 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
       if (inputOperand == opOperand)
         continue;
       Value inputValue = inputOperand->get();
-      fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
+      fusedIndexMaps.push_back(
+          genericOp.getMatchingIndexingMap(inputOperand));
       fusedOperands.push_back(inputValue);
       fusedLocs.push_back(inputValue.getLoc());
     }
     for (OpOperand *outputOperand : genericOp.getOutputOperands())
-      fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));
+      fusedIndexMaps.push_back(
+          genericOp.getMatchingIndexingMap(outputOperand));
 
     // Check if the operation shapes to loops map is computable.
     if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
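
Several of the fusion checks above hinge on the producer's result indexing map being an invertible permutation, so consumer coordinates can be composed back into producer loops. A hedged distillation of that pattern (hypothetical helper, not part of this commit):

    // Fusion requires the producer's result map to be a permutation;
    // inversePermutation() returns the empty AffineMap() when no inverse
    // exists, mirroring the canOpOperandsBeDroppedImpl check above.
    #include "mlir/Dialect/Linalg/IR/Linalg.h"

    static bool hasInvertibleResultMap(mlir::linalg::LinalgOp producer) {
      mlir::AffineMap resultMap =
          producer.getMatchingIndexingMap(producer.getOutputOperand(0));
      if (!resultMap.isPermutation())
        return false;
      return mlir::inversePermutation(resultMap) != mlir::AffineMap();
    }
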
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -80,7 +80,7 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
             opOperand->get().getDefiningOp()))
       continue;
 
-    AffineMap map = op.getTiedIndexingMap(opOperand);
+    AffineMap map = op.getMatchingIndexingMap(opOperand);
     LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
                             << opOperand->getOperandNumber() << "\n");
     LLVM_DEBUG(llvm::dbgs()
@@ -442,7 +442,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
   OpOperand *opOperand =
       producerOp.getOutputOperand(producerOpResult.getResultNumber());
   LinalgOp fusedProducer =
-      fuse(b, producerOp, producerOp.getTiedIndexingMap(opOperand),
+      fuse(b, producerOp, producerOp.getMatchingIndexingMap(opOperand),
           consumerOpOperand);
 
   // Replace use.
