Skip to content

Commit

Permalink
[mlir] Update flipped accessors (NFC)
Browse files Browse the repository at this point in the history
Follow up with memref flipped and flipping any intermediate changes
made.
  • Loading branch information
jpienaar committed Jun 28, 2022
1 parent 08d651d commit 04235d0
Show file tree
Hide file tree
Showing 39 changed files with 299 additions and 290 deletions.
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
Expand Up @@ -65,7 +65,7 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
if (!resultType)
return failure();
auto newOp =
rewriter.create<OpType>(op.getLoc(), resultType, op.source(),
rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
mixedOffsets, mixedSizes, mixedStrides);
CastOpFunc func;
func(rewriter, op, newOp);
Expand Down
43 changes: 22 additions & 21 deletions mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
Expand Up @@ -86,9 +86,9 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
// Fold producer-consumer reshape ops that where the operand type of the
// producer is same as the return type of the consumer.
auto reshapeSrcOp =
reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
return reshapeSrcOp.src();
return reshapeSrcOp.getSrc();
// Reshape of a constant can be replaced with a new constant.
if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
return elements.reshape(
Expand Down Expand Up @@ -122,10 +122,10 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
"extent dimensions to zero-rank tensor/memref");
return success();
}
if (collapsedRank != op.reassociation().size())
if (collapsedRank != op.getReassociation().size())
return op.emitOpError("expected rank of the collapsed type(")
<< collapsedRank << ") to be the number of reassociation maps("
<< op.reassociation().size() << ")";
<< op.getReassociation().size() << ")";
auto maps = op.getReassociationMaps();
for (auto it : llvm::enumerate(maps))
if (it.value().getNumDims() != expandedRank)
Expand Down Expand Up @@ -172,15 +172,16 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
PatternRewriter &rewriter) const override {
auto srcReshapeOp = reshapeOp.src().template getDefiningOp<ReshapeOpTy>();
auto srcReshapeOp =
reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>();
if (!srcReshapeOp)
return failure();

ShapedType resultType = reshapeOp.getResultType();

if (hasNonIdentityLayout(srcReshapeOp.src().getType()) ||
hasNonIdentityLayout(reshapeOp.src().getType()) ||
hasNonIdentityLayout(reshapeOp.result().getType()))
if (hasNonIdentityLayout(srcReshapeOp.getSrc().getType()) ||
hasNonIdentityLayout(reshapeOp.getSrc().getType()) ||
hasNonIdentityLayout(reshapeOp.getResult().getType()))
return failure();

Optional<SmallVector<ReassociationIndices>> reassociationIndices =
Expand All @@ -190,7 +191,7 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
if (!reassociationIndices)
return failure();
rewriter.replaceOpWithNewOp<ReshapeOpTy>(
reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
return success();
}
};
Expand Down Expand Up @@ -228,16 +229,16 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
PatternRewriter &rewriter) const override {
auto expandOp = collapseOp.src().template getDefiningOp<ExpandOpTy>();
auto expandOp = collapseOp.getSrc().template getDefiningOp<ExpandOpTy>();
if (!expandOp)
return failure();

ShapedType srcType = expandOp.getSrcType();
ShapedType resultType = collapseOp.getResultType();

if (hasNonIdentityLayout(collapseOp.src().getType()) ||
hasNonIdentityLayout(expandOp.src().getType()) ||
hasNonIdentityLayout(expandOp.result().getType()))
if (hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
hasNonIdentityLayout(expandOp.getSrc().getType()) ||
hasNonIdentityLayout(expandOp.getResult().getType()))
return failure();

int64_t srcRank = srcType.getRank();
Expand Down Expand Up @@ -274,10 +275,10 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
}
if (isResultCollapsed)
rewriter.replaceOpWithNewOp<CollapseOpTy>(
collapseOp, resultType, expandOp.src(), composedReassociation);
collapseOp, resultType, expandOp.getSrc(), composedReassociation);
else
rewriter.replaceOpWithNewOp<ExpandOpTy>(
collapseOp, resultType, expandOp.src(), composedReassociation);
collapseOp, resultType, expandOp.getSrc(), composedReassociation);
return success();
}
};
Expand All @@ -287,16 +288,16 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ExpandOpTy expandOp,
PatternRewriter &rewriter) const override {
auto collapseOp = expandOp.src().template getDefiningOp<CollapseOpTy>();
auto collapseOp = expandOp.getSrc().template getDefiningOp<CollapseOpTy>();
if (!collapseOp)
return failure();

ShapedType srcType = collapseOp.getSrcType();
ShapedType resultType = expandOp.getResultType();

if (hasNonIdentityLayout(expandOp.src().getType()) ||
hasNonIdentityLayout(collapseOp.src().getType()) ||
hasNonIdentityLayout(collapseOp.result().getType()))
if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
hasNonIdentityLayout(collapseOp.getResult().getType()))
return failure();

int64_t srcRank = srcType.getRank();
Expand All @@ -314,7 +315,7 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
return failure();

rewriter.replaceOpWithNewOp<CollapseOpTy>(
expandOp, resultType, collapseOp.src(), *composedReassociation);
expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
return success();
}
auto composedReassociation =
Expand All @@ -324,7 +325,7 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
return failure();

rewriter.replaceOpWithNewOp<ExpandOpTy>(
expandOp, resultType, collapseOp.src(), *composedReassociation);
expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
return success();
}

Expand Down

0 comments on commit 04235d0

Please sign in to comment.