121 changes: 90 additions & 31 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ transform::BufferizeToAllocationOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
Attribute memorySpace =
getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
IRRewriter rewriter(getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
auto transformed = llvm::to_vector(
llvm::map_range(state.getPayloadValues(getTarget()), [&](Value v) {
return linalg::bufferizeToAllocation(rewriter, v, memorySpace);
Expand Down Expand Up @@ -207,7 +208,8 @@ transform::DecomposeOp::applyToOne(LinalgOp target,
/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
static LogicalResult applyTilingToAll(
Operation *transformOp, ArrayRef<Operation *> payloadOps, unsigned numLoops,
RewriterBase &rewriter, Operation *transformOp,
ArrayRef<Operation *> payloadOps, unsigned numLoops,
transform::TransformResults &transformResults,
function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
applyFn) {
Expand All @@ -221,7 +223,6 @@ static LogicalResult applyTilingToAll(
if (!tilingInterfaceOp)
return transformOp->emitError("only TilingInterface ops are supported");

IRRewriter rewriter(target->getContext());
rewriter.setInsertionPoint(target);
FailureOr<scf::SCFTileAndFuseResult> tiledResults =
applyFn(tilingInterfaceOp);
Expand Down Expand Up @@ -300,12 +301,13 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
tilingOptions = tilingOptions.setTileSizes(tileSizes);
scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.tilingOptions = tilingOptions;
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
LogicalResult result = applyTilingToAll(
getOperation(), state.getPayloadOps(getTarget()),
rewriter, getOperation(), state.getPayloadOps(getTarget()),
tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
[&](TilingInterface tilingInterfaceOp)
-> FailureOr<scf::SCFTileAndFuseResult> {
IRRewriter rewriter(getContext());
return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
rewriter, tilingInterfaceOp, tileAndFuseOptions);
});
Expand Down Expand Up @@ -620,7 +622,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
return failure();
};

IRRewriter rewriter(getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
while (!remainingProducers.empty()) {
auto nextProducer = getNextProducer();
if (failed(nextProducer)) {
Expand Down Expand Up @@ -692,7 +695,8 @@ transform::GeneralizeOp::applyToOne(LinalgOp target,
results.push_back(target);
return DiagnosedSilenceableFailure::success();
}
IRRewriter rewriter(getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
rewriter.setInsertionPoint(target);
FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
if (succeeded(generic)) {
Expand All @@ -716,7 +720,8 @@ transform::InterchangeOp::applyToOne(GenericOp target,
results.push_back(target);
return DiagnosedSilenceableFailure::success();
}
IRRewriter rewriter(target->getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
FailureOr<GenericOp> res =
interchangeGenericOp(rewriter, target,
SmallVector<unsigned>(interchangeVector.begin(),
Expand Down Expand Up @@ -866,7 +871,8 @@ static FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
tensor::PackOp target, transform::ApplyToEachResultList &transformResults,
transform::TransformState &state) {
IRRewriter rewriter(target->getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
rewriter.setInsertionPoint(target);
FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
if (failed(res)) {
Expand Down Expand Up @@ -995,7 +1001,8 @@ static FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
tensor::UnPackOp target, transform::ApplyToEachResultList &transformResults,
transform::TransformState &state) {
IRRewriter rewriter(target->getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
rewriter.setInsertionPoint(target);
FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
if (failed(res)) {
Expand All @@ -1021,6 +1028,15 @@ void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
result.addTypes(pdl::OperationType::get(builder.getContext()));
}

/// Builder overload that lets the caller choose the result types explicitly
/// (the sibling overload above hard-codes a single `pdl::OperationType`
/// result). `opNames` is recorded as the op's name-filter attribute; `target`
/// becomes the sole operand.
void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
                               TypeRange resultTypes, Value target,
                               ArrayRef<StringRef> opNames) {
  result.addOperands(target);
  // Store the name filter under the attribute name generated for `ops`, as a
  // string array attribute.
  result.addAttribute(MatchOp::getOpsAttrName(result.name),
                      builder.getStrArrayAttr(opNames));
  result.addTypes(resultTypes);
}

DiagnosedSilenceableFailure
transform::MatchOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
Expand Down Expand Up @@ -1241,7 +1257,8 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
DiagnosedSilenceableFailure status = unpackSingleIndexResultPDLOperations(
state, *this, packedSizes, getMixedPackedSizes());

IRRewriter rewriter(linalgOp->getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
rewriter.setInsertionPoint(linalgOp);
FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
if (failed(maybeResult))
Expand Down Expand Up @@ -1425,7 +1442,8 @@ PackGreedilyOp::apply(transform::TransformResults &transformResults,
ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());

SmallVector<Operation *> results;
IRRewriter rewriter(getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
for (Operation *op : targetOps) {
auto linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
Expand Down Expand Up @@ -1598,7 +1616,8 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
assert(packOp && linalgOp && "unexpected null op");

// Step 3. Actually transpose the ops.
IRRewriter rewriter(getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
FailureOr<PackTransposeResult> res = packTranspose(
rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
// Preconditions have been checked, it is an error to fail here.
Expand Down Expand Up @@ -1671,7 +1690,8 @@ transform::PadOp::applyToOne(LinalgOp target,
transposePaddings.push_back(
extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>()));

IRRewriter rewriter(target->getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
LinalgOp paddedOp;
FailureOr<SmallVector<Value>> result = rewriteAsPaddedOp(
rewriter, target, extractFromI64ArrayAttr(getPaddingDimensions()),
Expand Down Expand Up @@ -1744,7 +1764,8 @@ DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
if (!padOp || !loopOp)
return emitDefiniteFailure() << "requires exactly 2 non-null handles";

IRRewriter rewriter(getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
FailureOr<linalg::detail::PackingResult> result =
linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
getTranspose());
Expand Down Expand Up @@ -1789,7 +1810,7 @@ transform::HoistPadOp::applyToOne(tensor::PadOp target,
tensor::PadOp hoistedPadOp;
SmallVector<GenericOp> transposeOps;
TrackingListener listener(state, *this);
IRRewriter rewriter(target->getContext(), &listener);
IRRewriter rewriter(getContext(), &listener);
FailureOr<Value> result =
hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
hoistedPadOp, transposeOps);
Expand Down Expand Up @@ -1872,7 +1893,8 @@ transform::PromoteOp::applyToOne(LinalgOp target,
if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
return emitDefaultDefiniteFailure(target);

IRRewriter rewriter(target->getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
rewriter.setInsertionPoint(target);
FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
if (failed(res))
Expand Down Expand Up @@ -1901,7 +1923,8 @@ transform::ReplaceOp::apply(TransformResults &transformResults,
}

// Clone and replace.
IRRewriter rewriter(getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
Operation *pattern = &getBodyRegion().front().front();
SmallVector<Operation *> replacements;
for (Operation *target : payload) {
Expand Down Expand Up @@ -1957,7 +1980,8 @@ transform::ScalarizeOp::applyToOne(LinalgOp target,
AffineMap map = target.getShapesToLoopsMap();
if (!map)
return tileSizes;
IRRewriter rewriter(b);
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
SmallVector<OpFoldResult> shapeSizes =
makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
allShapeSizes);
Expand All @@ -1971,7 +1995,8 @@ transform::ScalarizeOp::applyToOne(LinalgOp target,
return tileSizes;
});
SmallVector<int64_t> emptyTileSizes;
IRRewriter rewriter(getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
rewriter.setInsertionPoint(target);
FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
Expand All @@ -1998,7 +2023,8 @@ transform::RewriteInDestinationPassingStyleOp::applyToOne(
Operation *target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
SmallVector<Operation *> res;
IRRewriter rewriter(target->getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
rewriter.setInsertionPoint(target);
FailureOr<Operation *> maybeResult =
TypeSwitch<Operation *, FailureOr<Operation *>>(target)
Expand All @@ -2020,7 +2046,8 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
TransformState &state) {
// Collect the dynamic split points if provided.
ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
IRRewriter rewriter(getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
SmallVector<OpFoldResult> splitPoints;
splitPoints.reserve(payload.size());
if (getDynamicSplitPoint()) {
Expand Down Expand Up @@ -2227,7 +2254,8 @@ DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
unsigned(getInsertSplitDimension()),
bool(getInnerParallel())};
};
IRRewriter rewriter(getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
rewriter.setInsertionPoint(target);
FailureOr<SplitReductionResult> splitResult =
(getUseScalingAlgorithm())
Expand Down Expand Up @@ -2267,7 +2295,8 @@ void transform::TileReductionUsingScfOp::build(
DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
LinalgOp target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
IRRewriter rewriter(getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
rewriter.setInsertionPoint(target);
FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
Expand Down Expand Up @@ -2310,7 +2339,8 @@ void transform::TileReductionUsingForallOp::build(
DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
LinalgOp target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
IRRewriter rewriter(getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
rewriter.setInsertionPoint(target);
SmallVector<OpFoldResult> numThreads =
getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
Expand Down Expand Up @@ -2497,7 +2527,8 @@ transform::TileOp::apply(TransformResults &transformResults,
}

tilingOptions.setInterchange(getInterchange());
IRRewriter rewriter(op->getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
FailureOr<scf::SCFTilingResult> maybeTilingResult =
tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions);
if (failed(maybeTilingResult))
Expand Down Expand Up @@ -2739,7 +2770,8 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
DiagnosedSilenceableFailure
transform::TileToForallOp::apply(transform::TransformResults &transformResults,
transform::TransformState &state) {
IRRewriter rewriter(getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
auto transformOp = cast<TransformOpInterface>(getOperation());
ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());

Expand Down Expand Up @@ -2821,6 +2853,30 @@ LogicalResult TileToForallOp::verify() {
// TileToScfForOp
//===----------------------------------------------------------------------===//

/// Convenience builder for TileToScfForOp taking mixed (static/dynamic) tile
/// sizes as OpFoldResults. Static sizes are folded into the `static_sizes`
/// dense attribute while dynamic ones become SSA operands. One loop result is
/// created per non-zero tile size (a tile size of 0 means "do not tile that
/// dimension", so no loop is produced for it); all result handles are typed
/// as `pdl::OperationType`.
void transform::TileToScfForOp::build(OpBuilder &builder,
                                      OperationState &result, Value target,
                                      ArrayRef<OpFoldResult> mixedTileSizes,
                                      ArrayRef<int64_t> interchange) {
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  // Split the mixed sizes into the static attribute values and the dynamic
  // operand list.
  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
  // Call the default builder which sets up the proper operands segment sizes
  // attributes for multiple variadic operands. In the absence of this,
  // horrible bugs ensue.
  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
  // Zero entries produce no loop, so they do not contribute a result handle.
  int64_t numExpectedLoops =
      staticTileSizes.size() - llvm::count(staticTileSizes, 0);
  SmallVector<Type> resultTypes(numExpectedLoops,
                                pdl::OperationType::get(builder.getContext()));
  build(builder, result,
        /*tiled_linalg_op=*/target.getType(),
        /*loops=*/resultTypes,
        /*target=*/target,
        /*dynamic_sizes=*/dynamicTileSizes,
        /*static_sizes=*/staticTileSizesAttr,
        /*interchange=*/builder.getDenseI64ArrayAttr(interchange));
}

DiagnosedSilenceableFailure
transform::TileToScfForOp::apply(TransformResults &transformResults,
TransformState &state) {
Expand Down Expand Up @@ -2890,7 +2946,8 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
}

tilingOptions.setInterchange(getInterchange());
IRRewriter rewriter(tilingInterfaceOp.getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
FailureOr<scf::SCFTilingResult> tilingResult =
tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions);
if (failed(tilingResult))
Expand Down Expand Up @@ -3055,7 +3112,8 @@ transform::VectorizeOp::applyToOne(Operation *target,
DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
mlir::transform::TransformResults &transformResults,
mlir::transform::TransformState &state) {
IRRewriter rewriter(getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
if (targets.empty())
return DiagnosedSilenceableFailure::success();
Expand Down Expand Up @@ -3160,7 +3218,8 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
linalg::LinalgOp target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
IRRewriter rewriter(target->getContext());
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
rewriter.setInsertionPoint(target);
auto maybeTransformed =
TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
Expand Down Expand Up @@ -3195,7 +3254,7 @@ transform::HoistRedundantTensorSubsetsOp::applyToOne(
Operation *target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
TrackingListener listener(state, *this);
IRRewriter rewriter(target->getContext(), &listener);
IRRewriter rewriter(getContext(), &listener);
auto forOp = dyn_cast<scf::ForOp>(target);
if (forOp) {
linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp);
Expand Down
26 changes: 22 additions & 4 deletions mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,12 @@ func.func @async_cp_i4(

// -----

// CHECK-LABEL: @async_cp_zfill(
// CHECK-LABEL: @async_cp_zfill_f32_align4(
// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
func.func @async_cp_zfill(
func.func @async_cp_zfill_f32_align4(
%src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {

// CHECK-DAG: lvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES:.*]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
// CHECK-DAG: %[[DSTBYTES:.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
%0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
// CHECK: nvvm.cp.async.commit.group
%1 = nvgpu.device_async_create_group %0
Expand All @@ -312,6 +312,24 @@ func.func @async_cp_zfill(

// -----

// Single-element (1 x f32, i.e. 4-byte) async copy with a zfill source-size
// operand: expected to lower to the `cp.async.ca` inline-asm variant with a
// 4-byte destination size (presumably because the copy is below the 16-byte
// `cg` threshold — confirm against the NVGPUToNVVM lowering).
// CHECK-LABEL: @async_cp_zfill_f32_align1(
//  CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
func.func @async_cp_zfill_f32_align1(
  %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
  // CHECK-DAG: %[[DSTBYTES:.*]] = llvm.mlir.constant(4 : i32) : i32
  // CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
  %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 1, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
  // CHECK: nvvm.cp.async.commit.group
  %1 = nvgpu.device_async_create_group %0
  // CHECK: nvvm.cp.async.wait.group 1
  nvgpu.device_async_wait %1 { numGroups = 1 : i32 }

  return
}

// -----


// CHECK-LABEL: func @mma_sp_sync_f16_16832(
func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>,
%arg1: vector<4x2xf16>,
Expand Down