Skip to content

Commit

Permalink
[mlir][linalg][bufferize][NFC] Rename functions in BufferizationState
Browse files Browse the repository at this point in the history
The old function names (e.g., `replaceOp`) could have been confusing to users because they sound similar to rewriter functions, but have slightly different semantics.

Differential Revision: https://reviews.llvm.org/D116449
  • Loading branch information
matthias-springer committed Jan 6, 2022
1 parent 670de10 commit bf9d8d9
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -380,22 +380,6 @@ class BufferizationState {
/// Creates a memcpy between two given buffers.
void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const;

/// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values.
void replaceOp(RewriterBase &rewriter, Operation *op,
ValueRange values) const;

/// Replace an op with a new op. Tensor OpResults must be replaced with memref
/// values.
template <typename OpTy, typename... Args>
OpTy replaceOpWithNewOp(RewriterBase &rewriter, Operation *op,
Args &&...args) const {
Operation *newOp =
rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
replaceOp(rewriter, op, newOp->getResults());
return cast<OpTy>(newOp);
}

/// Lookup the memref buffer that is associated to the given tensor value.
/// Asserts if no buffer is associated.
Value lookupBuffer(RewriterBase &rewriter, Value tensor) const;
Expand Down Expand Up @@ -443,6 +427,21 @@ class BufferizationState {
const BufferizationOptions &options;
};

/// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
ValueRange values);

/// Replace an op with a new op. Tensor OpResults must be replaced with memref
/// values.
template <typename OpTy, typename... Args>
OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
Args &&...args) {
auto newOp = rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
return newOp;
}

/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
/// with the same shape as `shapedType` and specified `layout` and
/// `addressSpace`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct ConstantOpInterface

GlobalCreator globalCreator(moduleOp);
auto globalMemref = globalCreator.getGlobalFor(constantOp);
state.replaceOpWithNewOp<memref::GetGlobalOp>(
replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
rewriter, op, globalMemref.type(), globalMemref.getName());
return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
return operandBuffer;
}

void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
RewriterBase &rewriter, Operation *op, ValueRange values) const {
void mlir::linalg::comprehensive_bufferize::replaceOpWithBufferizedValues(
RewriterBase &rewriter, Operation *op, ValueRange values) {
OpBuilder::InsertionGuard g(rewriter);

// Replace all OpResults with the given values.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
op.clone(rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands);

// Replace the results of the old op with the new output buffers.
state.replaceOp(rewriter, op, newOutputBuffers);
replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);

return success();
}
Expand Down Expand Up @@ -201,7 +201,7 @@ struct InitTensorOpInterface

Value alloc = state.createAllocDeallocPair(rewriter, initTensorOp->getLoc(),
initTensorOp.result());
state.replaceOp(rewriter, op, alloc);
replaceOpWithBufferizedValues(rewriter, op, alloc);
return success();
}
};
Expand Down Expand Up @@ -342,7 +342,7 @@ struct TiledLoopOpInterface
rewriter.eraseOp(oldTerminator);

// Replace results and delete old op.
state.replaceOp(rewriter, op, newResults);
replaceOpWithBufferizedValues(rewriter, op, newResults);

return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ struct CallOpInterface
}

// 5. Replace the old op with the new op.
state.replaceOp(rewriter, callOp, replacementValues);
replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ struct IfOpInterface
}

// Replace op results.
state.replaceOp(rewriter, op, newIfOp->getResults());
replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

return success();
}
Expand Down Expand Up @@ -326,7 +326,7 @@ struct ForOpInterface
yieldOp.getResultsMutable().assign(yieldValues);

// Replace loop results.
state.replaceOp(rewriter, op, newForOp->getResults());
replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ struct CastOpInterface
: MemRefLayoutAttrInterface();
Type memRefType = getContiguousOrUnrankedMemRefType(
castOp.getResult().getType(), layout, memorySpace);
state.replaceOpWithNewOp<memref::CastOp>(rewriter, op, memRefType,
resultBuffer);
replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, memRefType,
resultBuffer);
return success();
}
};
Expand Down Expand Up @@ -98,7 +98,7 @@ struct DimOpInterface
if (!dimOp.source().getType().isa<RankedTensorType>())
return dimOp.emitError("unranked tensor not supported");
Value v = state.lookupBuffer(rewriter, dimOp.source());
state.replaceOpWithNewOp<memref::DimOp>(rewriter, op, v, dimOp.index());
replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
return success();
}
};
Expand Down Expand Up @@ -164,7 +164,7 @@ struct ExtractSliceOpInterface
subView = alloc;
}

state.replaceOp(rewriter, op, subView);
replaceOpWithBufferizedValues(rewriter, op, subView);
return success();
}
};
Expand All @@ -191,8 +191,8 @@ struct ExtractOpInterface
const BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor());
state.replaceOpWithNewOp<memref::LoadOp>(rewriter, op, srcMemref,
extractOp.indices());
replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
extractOp.indices());
return success();
}
};
Expand Down Expand Up @@ -231,7 +231,7 @@ struct InsertOpInterface
state.getResultBuffer(rewriter, insertOp->getOpResult(0));
rewriter.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
insertOp.indices());
state.replaceOp(rewriter, op, destMemref);
replaceOpWithBufferizedValues(rewriter, op, destMemref);
return success();
}

Expand Down Expand Up @@ -413,7 +413,7 @@ struct InsertSliceOpInterface
Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source());
state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView);

state.replaceOp(rewriter, op, dstMemref);
replaceOpWithBufferizedValues(rewriter, op, dstMemref);
return success();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,10 @@ struct TransferReadOpInterface

// TransferReadOp always reads from the bufferized op.source().
Value buffer = state.lookupBuffer(rewriter, readOp.source());
Value read = rewriter.create<vector::TransferReadOp>(
readOp.getLoc(), readOp.getVectorType(), buffer, readOp.indices(),
replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(),
readOp.permutation_map(), readOp.padding(), readOp.mask(),
readOp.in_boundsAttr());
state.replaceOp(rewriter, op, read);
return success();
}
};
Expand Down Expand Up @@ -101,7 +100,7 @@ struct TransferWriteOpInterface
rewriter.create<vector::TransferWriteOp>(
writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
state.replaceOp(rewriter, op, resultBuffer);
replaceOpWithBufferizedValues(rewriter, op, resultBuffer);

return success();
}
Expand Down

0 comments on commit bf9d8d9

Please sign in to comment.