Expand Up
@@ -148,7 +148,8 @@ transform::BufferizeToAllocationOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
Attribute memorySpace =
getMemorySpace ().has_value () ? getMemorySpace ().value () : Attribute ();
IRRewriter rewriter (getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
auto transformed = llvm::to_vector (
llvm::map_range (state.getPayloadValues (getTarget ()), [&](Value v) {
return linalg::bufferizeToAllocation (rewriter, v, memorySpace);
Expand Down
Expand Up
@@ -207,7 +208,8 @@ transform::DecomposeOp::applyToOne(LinalgOp target,
// / Apply a tiling transformation to all payload ops and store both the
// / tiled operation as well as the created tile loops.
static LogicalResult applyTilingToAll (
Operation *transformOp, ArrayRef<Operation *> payloadOps, unsigned numLoops,
RewriterBase &rewriter, Operation *transformOp,
ArrayRef<Operation *> payloadOps, unsigned numLoops,
transform::TransformResults &transformResults,
function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
applyFn) {
Expand All
@@ -221,7 +223,6 @@ static LogicalResult applyTilingToAll(
if (!tilingInterfaceOp)
return transformOp->emitError (" only TilingInterface ops are supported" );
IRRewriter rewriter (target->getContext ());
rewriter.setInsertionPoint (target);
FailureOr<scf::SCFTileAndFuseResult> tiledResults =
applyFn (tilingInterfaceOp);
Expand Down
Expand Up
@@ -300,12 +301,13 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
tilingOptions = tilingOptions.setTileSizes (tileSizes);
scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.tilingOptions = tilingOptions;
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
LogicalResult result = applyTilingToAll (
getOperation (), state.getPayloadOps (getTarget ()),
rewriter, getOperation (), state.getPayloadOps (getTarget ()),
tileSizes.size () - llvm::count (tileSizes, 0 ), transformResults,
[&](TilingInterface tilingInterfaceOp)
-> FailureOr<scf::SCFTileAndFuseResult> {
IRRewriter rewriter (getContext ());
return tileConsumerAndFuseProducerGreedilyUsingSCFForOp (
rewriter, tilingInterfaceOp, tileAndFuseOptions);
});
Expand Down
Expand Up
@@ -620,7 +622,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
return failure ();
};
IRRewriter rewriter (getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
while (!remainingProducers.empty ()) {
auto nextProducer = getNextProducer ();
if (failed (nextProducer)) {
Expand Down
Expand Up
@@ -692,7 +695,8 @@ transform::GeneralizeOp::applyToOne(LinalgOp target,
results.push_back (target);
return DiagnosedSilenceableFailure::success ();
}
IRRewriter rewriter (getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
rewriter.setInsertionPoint (target);
FailureOr<LinalgOp> generic = generalizeNamedOp (rewriter, target);
if (succeeded (generic)) {
Expand All
@@ -716,7 +720,8 @@ transform::InterchangeOp::applyToOne(GenericOp target,
results.push_back (target);
return DiagnosedSilenceableFailure::success ();
}
IRRewriter rewriter (target->getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
FailureOr<GenericOp> res =
interchangeGenericOp (rewriter, target,
SmallVector<unsigned >(interchangeVector.begin (),
Expand Down
Expand Up
@@ -866,7 +871,8 @@ static FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne (
tensor::PackOp target, transform::ApplyToEachResultList &transformResults,
transform::TransformState &state) {
IRRewriter rewriter (target->getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
rewriter.setInsertionPoint (target);
FailureOr<LowerPackResult> res = lowerPack (rewriter, target);
if (failed (res)) {
Expand Down
Expand Up
@@ -995,7 +1001,8 @@ static FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne (
tensor::UnPackOp target, transform::ApplyToEachResultList &transformResults,
transform::TransformState &state) {
IRRewriter rewriter (target->getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
rewriter.setInsertionPoint (target);
FailureOr<LowerUnPackOpResult> res = lowerUnPack (rewriter, target);
if (failed (res)) {
Expand All
@@ -1021,6 +1028,15 @@ void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
result.addTypes (pdl::OperationType::get (builder.getContext ()));
}
void transform::MatchOp::build (OpBuilder &builder, OperationState &result,
TypeRange resultTypes, Value target,
ArrayRef<StringRef> opNames) {
result.addOperands (target);
result.addAttribute (MatchOp::getOpsAttrName (result.name ),
builder.getStrArrayAttr (opNames));
result.addTypes (resultTypes);
}
DiagnosedSilenceableFailure
transform::MatchOp::apply (transform::TransformResults &results,
transform::TransformState &state) {
Expand Down
Expand Up
@@ -1241,7 +1257,8 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
DiagnosedSilenceableFailure status = unpackSingleIndexResultPDLOperations (
state, *this , packedSizes, getMixedPackedSizes ());
IRRewriter rewriter (linalgOp->getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
rewriter.setInsertionPoint (linalgOp);
FailureOr<PackResult> maybeResult = pack (rewriter, linalgOp, packedSizes);
if (failed (maybeResult))
Expand Down
Expand Up
@@ -1425,7 +1442,8 @@ PackGreedilyOp::apply(transform::TransformResults &transformResults,
ArrayRef<Operation *> targetOps = state.getPayloadOps (getTarget ());
SmallVector<Operation *> results;
IRRewriter rewriter (getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
for (Operation *op : targetOps) {
auto linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
Expand Down
Expand Up
@@ -1598,7 +1616,8 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
assert (packOp && linalgOp && " unexpected null op" );
// Step 3. Actually transpose the ops.
IRRewriter rewriter (getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
FailureOr<PackTransposeResult> res = packTranspose (
rewriter, packOp, linalgOp, unPackOp, getOuterPerm (), getInnerPerm ());
// Preconditions have been checked, it is an error to fail here.
Expand Down
Expand Up
@@ -1671,7 +1690,8 @@ transform::PadOp::applyToOne(LinalgOp target,
transposePaddings.push_back (
extractFromI64ArrayAttr (transposeVector.cast <ArrayAttr>()));
IRRewriter rewriter (target->getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
LinalgOp paddedOp;
FailureOr<SmallVector<Value>> result = rewriteAsPaddedOp (
rewriter, target, extractFromI64ArrayAttr (getPaddingDimensions ()),
Expand Down
Expand Up
@@ -1744,7 +1764,8 @@ DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
if (!padOp || !loopOp)
return emitDefiniteFailure () << " requires exactly 2 non-null handles" ;
IRRewriter rewriter (getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
FailureOr<linalg::detail::PackingResult> result =
linalg::detail::buildPackingLoopNest (rewriter, padOp, loopOp,
getTranspose ());
Expand Down
Expand Up
@@ -1789,7 +1810,7 @@ transform::HoistPadOp::applyToOne(tensor::PadOp target,
tensor::PadOp hoistedPadOp;
SmallVector<GenericOp> transposeOps;
TrackingListener listener (state, *this );
IRRewriter rewriter (target-> getContext (), &listener);
IRRewriter rewriter (getContext (), &listener);
FailureOr<Value> result =
hoistPaddingOnTensors (rewriter, target, getNumLoops (), getTranspose (),
hoistedPadOp, transposeOps);
Expand Down
Expand Up
@@ -1872,7 +1893,8 @@ transform::PromoteOp::applyToOne(LinalgOp target,
if (failed (promoteSubviewsPrecondition (target, promotionOptions)))
return emitDefaultDefiniteFailure (target);
IRRewriter rewriter (target->getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
rewriter.setInsertionPoint (target);
FailureOr<LinalgOp> res = promoteSubViews (rewriter, target, promotionOptions);
if (failed (res))
Expand Down
Expand Up
@@ -1901,7 +1923,8 @@ transform::ReplaceOp::apply(TransformResults &transformResults,
}
// Clone and replace.
IRRewriter rewriter (getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
Operation *pattern = &getBodyRegion ().front ().front ();
SmallVector<Operation *> replacements;
for (Operation *target : payload) {
Expand Down
Expand Up
@@ -1957,7 +1980,8 @@ transform::ScalarizeOp::applyToOne(LinalgOp target,
AffineMap map = target.getShapesToLoopsMap ();
if (!map)
return tileSizes;
IRRewriter rewriter (b);
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
SmallVector<OpFoldResult> shapeSizes =
makeComposedFoldedMultiResultAffineApply (rewriter, loc, map,
allShapeSizes);
Expand All
@@ -1971,7 +1995,8 @@ transform::ScalarizeOp::applyToOne(LinalgOp target,
return tileSizes;
});
SmallVector<int64_t > emptyTileSizes;
IRRewriter rewriter (getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
rewriter.setInsertionPoint (target);
FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp (
rewriter, cast<TilingInterface>(target.getOperation ()), tilingOptions);
Expand All
@@ -1998,7 +2023,8 @@ transform::RewriteInDestinationPassingStyleOp::applyToOne(
Operation *target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
SmallVector<Operation *> res;
IRRewriter rewriter (target->getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
rewriter.setInsertionPoint (target);
FailureOr<Operation *> maybeResult =
TypeSwitch<Operation *, FailureOr<Operation *>>(target)
Expand All
@@ -2020,7 +2046,8 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
TransformState &state) {
// Collect the dynamic split points if provided.
ArrayRef<Operation *> payload = state.getPayloadOps (getTarget ());
IRRewriter rewriter (getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
SmallVector<OpFoldResult> splitPoints;
splitPoints.reserve (payload.size ());
if (getDynamicSplitPoint ()) {
Expand Down
Expand Up
@@ -2227,7 +2254,8 @@ DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
unsigned (getInsertSplitDimension ()),
bool (getInnerParallel ())};
};
IRRewriter rewriter (getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
rewriter.setInsertionPoint (target);
FailureOr<SplitReductionResult> splitResult =
(getUseScalingAlgorithm ())
Expand Down
Expand Up
@@ -2267,7 +2295,8 @@ void transform::TileReductionUsingScfOp::build(
DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne (
LinalgOp target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
IRRewriter rewriter (getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
rewriter.setInsertionPoint (target);
FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf (
rewriter, cast<PartialReductionOpInterface>(target.getOperation ()),
Expand Down
Expand Up
@@ -2310,7 +2339,8 @@ void transform::TileReductionUsingForallOp::build(
DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne (
LinalgOp target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
IRRewriter rewriter (getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
rewriter.setInsertionPoint (target);
SmallVector<OpFoldResult> numThreads =
getAsOpFoldResult (rewriter.getI64ArrayAttr (getNumThreads ()));
Expand Down
Expand Up
@@ -2497,7 +2527,8 @@ transform::TileOp::apply(TransformResults &transformResults,
}
tilingOptions.setInterchange (getInterchange ());
IRRewriter rewriter (op->getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
FailureOr<scf::SCFTilingResult> maybeTilingResult =
tileUsingSCFForOp (rewriter, tilingInterface, tilingOptions);
if (failed (maybeTilingResult))
Expand Down
Expand Up
@@ -2739,7 +2770,8 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
DiagnosedSilenceableFailure
transform::TileToForallOp::apply (transform::TransformResults &transformResults,
transform::TransformState &state) {
IRRewriter rewriter (getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
auto transformOp = cast<TransformOpInterface>(getOperation ());
ArrayRef<Operation *> targets = state.getPayloadOps (getTarget ());
Expand Down
Expand Up
@@ -2821,6 +2853,30 @@ LogicalResult TileToForallOp::verify() {
// TileToScfForOp
// ===----------------------------------------------------------------------===//
void transform::TileToScfForOp::build (OpBuilder &builder,
OperationState &result, Value target,
ArrayRef<OpFoldResult> mixedTileSizes,
ArrayRef<int64_t > interchange) {
SmallVector<int64_t > staticTileSizes;
SmallVector<Value> dynamicTileSizes;
dispatchIndexOpFoldResults (mixedTileSizes, dynamicTileSizes, staticTileSizes);
// Call the default builder which sets up the proper operands segment sizes
// attributes for multiple variadic operands. In the absence of this,
// horrible bugs ensue.
auto staticTileSizesAttr = builder.getDenseI64ArrayAttr (staticTileSizes);
int64_t numExpectedLoops =
staticTileSizes.size () - llvm::count (staticTileSizes, 0 );
SmallVector<Type> resultTypes (numExpectedLoops,
pdl::OperationType::get (builder.getContext ()));
build (builder, result,
/* tiled_linalg_op=*/ target.getType (),
/* loops=*/ resultTypes,
/* target=*/ target,
/* dynamic_sizes=*/ dynamicTileSizes,
/* static_sizes=*/ staticTileSizesAttr,
/* interchange=*/ builder.getDenseI64ArrayAttr (interchange));
}
DiagnosedSilenceableFailure
transform::TileToScfForOp::apply (TransformResults &transformResults,
TransformState &state) {
Expand Down
Expand Up
@@ -2890,7 +2946,8 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
}
tilingOptions.setInterchange (getInterchange ());
IRRewriter rewriter (tilingInterfaceOp.getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
FailureOr<scf::SCFTilingResult> tilingResult =
tileUsingSCFForOp (rewriter, tilingInterfaceOp, tilingOptions);
if (failed (tilingResult))
Expand Down
Expand Up
@@ -3055,7 +3112,8 @@ transform::VectorizeOp::applyToOne(Operation *target,
DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply (
mlir::transform::TransformResults &transformResults,
mlir::transform::TransformState &state) {
IRRewriter rewriter (getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
ArrayRef<Operation *> targets = state.getPayloadOps (getTarget ());
if (targets.empty ())
return DiagnosedSilenceableFailure::success ();
Expand Down
Expand Up
@@ -3160,7 +3218,8 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne (
linalg::LinalgOp target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
IRRewriter rewriter (target->getContext ());
TrackingListener listener (state, *this );
IRRewriter rewriter (getContext (), &listener);
rewriter.setInsertionPoint (target);
auto maybeTransformed =
TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
Expand Down
Expand Up
@@ -3195,7 +3254,7 @@ transform::HoistRedundantTensorSubsetsOp::applyToOne(
Operation *target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
TrackingListener listener (state, *this );
IRRewriter rewriter (target-> getContext (), &listener);
IRRewriter rewriter (getContext (), &listener);
auto forOp = dyn_cast<scf::ForOp>(target);
if (forOp) {
linalg::hoistRedundantSubsetExtractInsert (rewriter, forOp);
Expand Down