Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][TilingInterface] Move TilingInterface tests to use transform dialect ops. #77204

Merged

Conversation

MaheshRavishankar
Copy link
Contributor

In the process a couple of test transform dialect ops are added just
for testing. These operations are not intended to use as full flushed
out of transformation ops, but are rather operations added for testing.

A separate operation is added to LinalgTransformOps.td to convert a
TilingInterface operation to loops using the
generateScalarImplementation method implemented by the
operation. Eventually this and other operations related to tiling
using the TilingInterface need to move to a better place (i.e. out
of Linalg dialect)

@MaheshRavishankar
Copy link
Contributor Author

This is based on top of #76871 . Only the last commit is part of this PR

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 6, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-scf

Author: None (MaheshRavishankar)

Changes

In the process a couple of test transform dialect ops are added just
for testing. These operations are not intended to use as full flushed
out of transformation ops, but are rather operations added for testing.

A separate operation is added to LinalgTransformOps.td to convert a
TilingInterface operation to loops using the
generateScalarImplementation method implemented by the
operation. Eventually this and other operations related to tiling
using the TilingInterface need to move to a better place (i.e. out
of Linalg dialect)


Patch is 111.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/77204.diff

16 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+31-1)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+24)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+22-33)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+48-22)
  • (modified) mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir (+84-10)
  • (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir (+103-22)
  • (modified) mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir (+13-1)
  • (modified) mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir (+74-13)
  • (modified) mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir (+117-38)
  • (modified) mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir (+43-2)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt (+9-1)
  • (removed) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp (-650)
  • (added) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+267)
  • (added) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td (+70)
  • (added) mlir/test/lib/Interfaces/TilingInterface/lit.local.cfg (+1)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2-2)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index bc257d17483e3b..7d10ba0ae829e5 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -293,7 +293,10 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
   let results = (outs TransformHandleTypeInterface:$transformed,
                       Variadic<TransformHandleTypeInterface>:$loops);
 
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = [{
+    $target ($tile_sizes^)? (`interchange` $tile_interchange^)?
+    attr-dict `:` functional-type(operands, results)
+  }];
   let hasVerifier = 1;
 }
 
@@ -1269,6 +1272,33 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
   }];
 }
 
+def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    For operations that implement the `TilingInterface`, and implement
+    the `generateScalarImplementation` method, lowers the operation to
+    loops. This operation does not return any handles.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+
+  let assemblyFormat = [{
+    $target attr-dict `:` type($target)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::TilingInterface target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+
 //===----------------------------------------------------------------------===//
 // DecomposeInterfaceOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 2f8f337bb8057c..5d2d78e6e6165b 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -97,6 +97,30 @@ struct SCFTileAndFuseOptions {
     tilingOptions = options;
     return *this;
   }
+
+  /// Control function to check if a slice needs to be fused or not,
+  /// The control function receives
+  /// 1) the slice along which fusion is to be done,
+  /// 2) the producer value that is to be fused
+  /// 3) a boolean value set to `true` if the fusion is from
+  ///    a destination operand.
+  /// It retuns two booleans
+  /// - returns `true` if the fusion should be done through the candidate slice
+  /// - returns `true` if a replacement for the fused producer needs to be
+  ///   yielded from within the tiled loop. Note that it is valid to return
+  ///   `true` only if the slice fused is disjoint across all iterations of the
+  ///   tiled loop. It is up to the caller to ensure that this is true for the
+  ///   fused producers.
+  using ControlFnTy = std::function<std::tuple<bool, bool>(
+      tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
+      bool isDestinationOperand)>;
+  ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, bool) {
+    return std::make_tuple(true, false);
+  };
+  SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
+    fusionControlFn = controlFn;
+    return *this;
+  }
 };
 
 /// Fuse the producer of the source of `candidateSliceOp` by computing the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5254aac976f462..97d2b4a3be5c56 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -492,38 +492,6 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
                         : DiagnosedSilenceableFailure::success();
 }
 
-ParseResult transform::FuseOp::parse(OpAsmParser &parser,
-                                     OperationState &result) {
-  OpAsmParser::UnresolvedOperand targetOperand;
-  if (parser.parseOperand(targetOperand) ||
-      parser.parseOptionalAttrDict(result.attributes))
-    return failure();
-
-  FunctionType trailingType;
-  SMLoc typeLoc;
-  if (parser.getCurrentLocation(&typeLoc) ||
-      parser.parseColonType(trailingType)) {
-    return failure();
-  }
-  if (trailingType.getNumInputs() != 1)
-    return parser.emitError(typeLoc) << "expected one input type";
-
-  result.addTypes(trailingType.getResults());
-  if (parser.resolveOperand(targetOperand, trailingType.getInput(0),
-                            result.operands))
-    return failure();
-  return success();
-}
-
-void transform::FuseOp::print(OpAsmPrinter &p) {
-  p << ' ';
-  p << getTarget();
-  p.printOptionalAttrDict((*this)->getAttrs());
-  p << " : ";
-  p.printFunctionalType(TypeRange(getOperand().getType()),
-                        getResults().getTypes());
-}
-
 LogicalResult transform::FuseOp::verify() {
   SmallVector<int64_t> permutation =
       extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
@@ -2111,6 +2079,22 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ConvertToLoopsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne(
+    transform::TransformRewriter &rewriter, TilingInterface target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  FailureOr<SmallVector<scf::ForOp>> loops =
+      scf::lowerToLoopsUsingSCFForOp(rewriter, target);
+  if (failed(loops))
+    return emitDefaultDefiniteFailure(target);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // RewriteInDestinationPassingStyleOp
 //===----------------------------------------------------------------------===//
@@ -2620,7 +2604,12 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
     }
 
     scf::SCFTilingOptions tilingOptions;
-    if (!tileSizes.empty()) {
+    if (tileSizes.empty()) {
+      tilingOptions.setTileSizeComputationFunction(
+          [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
+            return {};
+          });
+    } else {
       tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
                                                                   Operation *) {
         SmallVector<OpFoldResult> sizes;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1b6b4db9d20907..38e0625d7ce093 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -283,10 +283,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   // 1. Get the range of the loops that are represented by the operation.
   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
   size_t numLoops = iterationDomain.size();
-  if (numLoops == 0) {
-    return rewriter.notifyMatchFailure(
-        op, "unable to tile op with no iteration domain");
-  }
+
   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
   // skips tiling a particular dimension. This convention is significantly
   // simpler to handle instead of adjusting affine maps to account for missing
@@ -728,32 +725,36 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
   }
 
   // 1. First tile the consumer.
-  SmallVector<scf::ForOp> forLoops;
   SetVector<Operation *> fusedProducers, tiledAndFusedOps;
-  DenseMap<Value, Value> replacements;
-  llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
-  {
-    FailureOr<scf::SCFTilingResult> tilingResult =
-        tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
-    if (failed(tilingResult))
-      return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
-    for (auto *tiledOp : tilingResult->tiledOps)
-      tiledAndFusedOps.insert(tiledOp);
-    forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops);
-    for (auto [index, origValue, replacement] :
-         llvm::enumerate(consumer->getResults(), tilingResult->replacements)) {
-      replacements[origValue] = replacement;
-      yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
-          index)] = index;
-    }
-  }
+  llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
+  FailureOr<scf::SCFTilingResult> tilingResult =
+      tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
+  if (failed(tilingResult))
+    return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
+  for (auto *tiledOp : tilingResult->tiledOps)
+    tiledAndFusedOps.insert(tiledOp);
+  SmallVector<scf::ForOp> forLoops =
+      castToTypedOperations<scf::ForOp>(tilingResult->loops);
 
   // If there are no loops generated, fusion is immaterial.
   if (forLoops.empty()) {
+    DenseMap<Value, Value> replacements;
+    for (auto [origVal, replacement] :
+         llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
+      replacements[origVal] = replacement;
+    }
     return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
                                      getAsOperations(forLoops), replacements};
   }
 
+  // To keep track of replacements for now just record the map from the original
+  // untiled value to the result number of the for loop. Since the loop gets
+  // potentially replaced during fusion, keeping the value directly wont work.
+  DenseMap<Value, size_t> origValToResultNumber;
+  for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
+    origValToResultNumber[result] = index;
+  }
+
   // 2. Typically, the operands of the tiled operation are slices of the
   //    operands of the untiled operation. These are expressed in IR using
   //    `tensor.extract_slice` operations with source being the operands of the
@@ -776,6 +777,18 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
     candidates.pop_front();
 
+    // Find the original producer of the slice.
+    auto [fusableProducer, destinationInitArg] =
+        getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
+                                          forLoops);
+    if (!fusableProducer)
+      continue;
+
+    auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
+        candidateSliceOp, fusableProducer, destinationInitArg.has_value());
+    if (!fuseSlice)
+      continue;
+
     // The operands of the fused producer might themselved be slices of
     // values produced by operations that implement the `TilingInterface`.
     // Add these operations to the worklist.
@@ -784,6 +797,13 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
     if (!fusedResult)
       continue;
 
+    if (yieldReplacement) {
+      yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
+                                       fusedResult.value(), forLoops);
+      origValToResultNumber[fusableProducer] =
+          forLoops.front().getNumResults() - 1;
+    }
+
     if (Operation *tiledAndFusedOp =
             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
@@ -791,6 +811,12 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
       addCandidateSlices(tiledAndFusedOp, candidates);
     }
   }
+
+  DenseMap<Value, Value> replacements;
+  for (auto [origVal, resultNumber] : origValToResultNumber) {
+    replacements[origVal] = forLoops.front()->getResult(resultNumber);
+  }
+
   return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
                                    getAsOperations(forLoops), replacements};
 }
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index c8199c325abfec..7245498f641ecf 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-tiling-interface=lower-to-scalar-using-scf-for -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -transform-interpreter -split-input-file -canonicalize -cse %s | FileCheck %s
 
 func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
   %arg2 : memref<?x?xf32>) {
@@ -6,13 +6,22 @@ func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
       outs(%arg2 : memref<?x?xf32>)
   return
 }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.convert_to_loops %matmul : !transform.any_op
+    transform.yield
+  }
+}
 // CHECK-LABEL: func @gemm
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
 //  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
 //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
 //   CHECK-DAG:   %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
 //   CHECK-DAG:   %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
 //   CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
 //       CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
@@ -51,6 +60,15 @@ func.func @indexed_generic(%arg0 : memref<200x300xi32>, %arg1 : memref<300xi16>,
     }
   return
 }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.convert_to_loops %generic : !transform.any_op
+    transform.yield
+  }
+}
 // CHECK-LABEL: func @indexed_generic
 //  CHECK-SAME:     %[[ARG0:.+]]: memref<200x300xi32>
 //  CHECK-SAME:     %[[ARG1:.+]]: memref<300xi16>
@@ -87,8 +105,18 @@ func.func @conv_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
       outs(%arg2 : memref<?x?x?x?xf32>)
   return
 }
-//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1 + d4 * 3)>
-//  CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2 * 2 + d5 * 4)>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.convert_to_loops %conv : !transform.any_op
+    transform.yield
+  }
+}
+
+//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
+//  CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
 //       CHECK: func @conv_strides_and_dilation(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
@@ -111,8 +139,8 @@ func.func @conv_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
 //       CHECK:           scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]]
 //       CHECK:             scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]]
 //       CHECK:               scf.for %[[IV6:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]]
-//   CHECK-DAG:                 %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]], %[[IV6]])
-//   CHECK-DAG:                 %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]], %[[IV6]])
+//   CHECK-DAG:                 %[[I:.+]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV4]])
+//   CHECK-DAG:                 %[[J:.+]] = affine.apply #[[MAP1]](%[[IV2]], %[[IV5]])
 //   CHECK-DAG:                 %[[T9:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV6]]]
 //   CHECK-DAG:                 %[[T10:.+]] = memref.load %[[ARG1]][%[[IV4]], %[[IV5]], %[[IV6]], %[[IV3]]]
 //   CHECK-DAG:                 %[[T11:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
@@ -131,8 +159,18 @@ func.func @pool_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
       outs(%arg2 : memref<?x?x?x?xf32>)
   return
 }
-//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1 + d4 * 3)>
-//  CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2 * 2 + d5 * 4)>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %pool = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.convert_to_loops %pool : !transform.any_op
+    transform.yield
+  }
+}
+
+//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
+//  CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
 //       CHECK: func @pool_strides_and_dilation
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
@@ -153,8 +191,8 @@ func.func @pool_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
 //       CHECK:         scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]]
 //       CHECK:           scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]]
 //       CHECK:             scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]]
-//   CHECK-DAG:               %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]])
-//   CHECK-DAG:               %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]])
+//   CHECK-DAG:               %[[I:.+]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV4]])
+//   CHECK-DAG:               %[[J:.+]] = affine.apply #[[MAP1]](%[[IV2]], %[[IV5]])
 //   CHECK-DAG:               %[[T8:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV3]]]
 //   CHECK-DAG:               %[[T9:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
 //       CHECK:               %[[T10:.+]] = arith.maximumf %[[T9]], %[[T8]]
@@ -172,6 +210,15 @@ func.func @map(%lhs: memref<64xf32>,
     }
   return
 }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %map = transform.structured.match ops{["linalg.map"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.convert_to_loops %map : !transform.any_op
+    transform.yield
+  }
+}
 // CHECK-LABEL: func.func @map(
 // CHECK-SAME:    %[[LHS:[a-zA-Z0-9]+]]: memref<64xf32>,
 // CHECK-SAME:    %[[RHS:[a-zA-Z0-9]+]]: memref<64xf32>,
@@ -195,6 +242,15 @@ func.func @transpose(%arg0: memref<16x32x64xf32>,
                    outs(%arg1 : memref<32x64x16xf32>) permutation = [1, 2, 0]
   return
 }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %transpose = transform.structured.match ops{["linalg.transpose"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.convert_to_loops %transpose : !transform.any_op
+    transform.yield
+  }
+}
 // CHECK-LABEL: func.func @transpose(
 // CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]: me...
[truncated]

Copy link
Contributor

@qedawkins qedawkins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly looks good to me, getting rid of test passes is +1. It would be nice to phase out test ops over time too but it makes sense to me to go step by step and not introduce something unstable.

…ialect ops.

In the process a couple of test transform dialect ops are added just
for testing. These operations are not intended to use as full flushed
out of transformation ops, but are rather operations added for testing.

A separate operation is added to `LinalgTransformOps.td` to convert a
`TilingInterface` operation to loops using the
`generateScalarImplementation` method implemented by the
operation. Eventually this and other operations related to tiling
using the `TilingInterface` need to move to a better place (i.e. out
of `Linalg` dialect)
@MaheshRavishankar MaheshRavishankar merged commit aa2a96a into llvm:main Jan 12, 2024
3 of 4 checks passed
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
…ialect ops. (llvm#77204)

In the process a couple of test transform dialect ops are added just
for testing. These operations are not intended to use as full flushed
out of transformation ops, but are rather operations added for testing.

A separate operation is added to `LinalgTransformOps.td` to convert a
`TilingInterface` operation to loops using the
`generateScalarImplementation` method implemented by the
operation. Eventually this and other operations related to tiling
using the `TilingInterface` need to move to a better place (i.e. out
of `Linalg` dialect)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants