amirBish
Contributor

This PR adds a new transform operation that removes duplicate arguments from a function operation, based on the operands of the function's callOp.

To keep the implementation simple for now, the transform fails when the function whose arguments we want to deduplicate has more than one callOp.

This pull request also adapts the utils under the func dialect so that they can be reused by this transformOp.

@llvmbot
Member

llvmbot commented Sep 12, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-func

Author: Amir Bishara (amirBish)


Patch is 31.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158266.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td (+26)
  • (modified) mlir/include/mlir/Dialect/Func/Utils/Utils.h (+19-12)
  • (modified) mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp (+95-6)
  • (modified) mlir/lib/Dialect/Func/Utils/Utils.cpp (+121-38)
  • (modified) mlir/test/Dialect/Func/func-transform-invalid.mlir (+108)
  • (modified) mlir/test/Dialect/Func/func-transform.mlir (+72)
diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
index 4062f310c6521..b64b3fcdb275b 100644
--- a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
+++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td
@@ -134,4 +134,30 @@ def ReplaceFuncSignatureOp
   }];
 }
 
+def DeduplicateFuncArgsOp
+    : Op<Transform_Dialect, "func.deduplicate_func_args",
+         [DeclareOpInterfaceMethods<TransformOpInterface>,
+          DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let description = [{
+      This transform takes a module and a function name, and deduplicates
+      the arguments of the function. The function is expected to be defined in
+      the module.
+
+      This transform will emit a silenceable failure if:
+       - The function with the given name does not exist in the module.
+       - The function does not have duplicate arguments.
+       - The function does not have a single call.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$module,
+      SymbolRefAttr:$function_name);
+  let results = (outs TransformHandleTypeInterface:$transformed_module,
+                      TransformHandleTypeInterface:$transformed_function);
+
+  let assemblyFormat = [{
+    $function_name
+    `at` $module attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 #endif // FUNC_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Func/Utils/Utils.h b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
index 2e8b6723a0e53..464ebc1305d60 100644
--- a/mlir/include/mlir/Dialect/Func/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Func/Utils/Utils.h
@@ -17,7 +17,7 @@
 #define MLIR_DIALECT_FUNC_UTILS_H
 
 #include "mlir/IR/PatternMatch.h"
-#include "llvm/ADT/ArrayRef.h"
+#include <map>
 
 namespace mlir {
 
@@ -27,21 +27,28 @@ class FuncOp;
 class CallOp;
 
 /// Creates a new function operation with the same name as the original
-/// function operation, but with the arguments reordered according to
-/// the `newArgsOrder` and `newResultsOrder`.
+/// function operation, but with the arguments mapped according to
+/// the `oldArgToNewArg` and `oldResToNewRes`.
 /// The `funcOp` operation must have exactly one block.
 /// Returns the new function operation or failure if `funcOp` doesn't
 /// have exactly one block.
-FailureOr<FuncOp>
-replaceFuncWithNewOrder(RewriterBase &rewriter, FuncOp funcOp,
-                        llvm::ArrayRef<unsigned> newArgsOrder,
-                        llvm::ArrayRef<unsigned> newResultsOrder);
+/// Note: the method asserts that the `oldArgToNewArg` and `oldResToNewRes`
+/// maps the whole function arguments and results.
+mlir::FailureOr<mlir::func::FuncOp>
+replaceFuncWithNewMapping(mlir::RewriterBase &rewriter,
+                          mlir::func::FuncOp funcOp,
+                          const std::map<unsigned, unsigned> &oldArgToNewArg,
+                          const std::map<unsigned, unsigned> &oldResToNewRes);
 /// Creates a new call operation with the values as the original
-/// call operation, but with the arguments reordered according to
-/// the `newArgsOrder` and `newResultsOrder`.
-CallOp replaceCallOpWithNewOrder(RewriterBase &rewriter, CallOp callOp,
-                                 llvm::ArrayRef<unsigned> newArgsOrder,
-                                 llvm::ArrayRef<unsigned> newResultsOrder);
+/// call operation, but with the arguments mapped according to
+/// the `oldArgToNewArg` and `oldResToNewRes`.
+/// Note: the method asserts that the `oldArgToNewArg` and `oldResToNewRes`
+/// maps the whole call operation arguments and results.
+mlir::func::CallOp
+replaceCallOpWithNewMapping(mlir::RewriterBase &rewriter,
+                            mlir::func::CallOp callOp,
+                            const std::map<unsigned, unsigned> &oldArgToNewArg,
+                            const std::map<unsigned, unsigned> &oldResToNewRes);
 
 } // namespace func
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 935d3e5ac331b..486dad09f9392 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
 
 using namespace mlir;
 
@@ -296,9 +297,15 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
     }
   }
 
-  FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewOrder(
-      rewriter, funcOp, argsInterchange.getArrayRef(),
-      resultsInterchange.getArrayRef());
+  std::map<unsigned, unsigned> oldArgToNewArg, oldResToNewRes;
+  for (auto [oldArgIdx, newArgIdx] : llvm::enumerate(argsInterchange))
+    oldArgToNewArg[oldArgIdx] = newArgIdx;
+
+  for (auto [oldResIdx, newResIdx] : llvm::enumerate(resultsInterchange))
+    oldResToNewRes[oldResIdx] = newResIdx;
+
+  FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewMapping(
+      rewriter, funcOp, oldArgToNewArg, oldResToNewRes);
   if (failed(newFuncOpOrFailure))
     return emitSilenceableFailure(getLoc())
            << "failed to replace function signature '" << getFunctionName()
@@ -312,9 +319,8 @@ transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
     });
 
     for (func::CallOp callOp : callOps)
-      func::replaceCallOpWithNewOrder(rewriter, callOp,
-                                      argsInterchange.getArrayRef(),
-                                      resultsInterchange.getArrayRef());
+      func::replaceCallOpWithNewMapping(rewriter, callOp, oldArgToNewArg,
+                                        oldResToNewRes);
   }
 
   results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
@@ -330,6 +336,89 @@ void transform::ReplaceFuncSignatureOp::getEffects(
   transform::modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// DeduplicateFuncArgsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::DeduplicateFuncArgsOp::apply(transform::TransformRewriter &rewriter,
+                                        transform::TransformResults &results,
+                                        transform::TransformState &state) {
+  auto payloadOps = state.getPayloadOps(getModule());
+  if (!llvm::hasSingleElement(payloadOps))
+    return emitDefiniteFailure() << "requires a single module to operate on";
+
+  auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
+  if (!targetModuleOp)
+    return emitSilenceableFailure(getLoc())
+           << "target is expected to be module operation";
+
+  func::FuncOp funcOp =
+      targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
+  if (!funcOp)
+    return emitSilenceableFailure(getLoc())
+           << "function with name '" << getFunctionName() << "' is not found";
+
+  SmallVector<func::CallOp> callOps;
+  targetModuleOp.walk([&](func::CallOp callOp) {
+    if (callOp.getCallee() == getFunctionName().getRootReference().getValue())
+      callOps.push_back(callOp);
+  });
+
+  // TODO: Support more than one callOp.
+  if (!llvm::hasSingleElement(callOps))
+    return emitSilenceableFailure(getLoc())
+           << "function with name '" << getFunctionName()
+           << "' does not have a single callOp";
+
+  llvm::DenseSet<Value> seenValues;
+  func::CallOp callOp = callOps.front();
+  bool hasDuplicatesOperands =
+      llvm::any_of(callOp.getOperands(), [&seenValues](Value operand) {
+        return !seenValues.insert(operand).second;
+      });
+
+  if (!hasDuplicatesOperands)
+    return emitSilenceableFailure(getLoc())
+           << "function with name '" << getFunctionName()
+           << "' does not have duplicate operands";
+
+  std::map<unsigned, unsigned> oldArgIdxToNewArgIdx;
+  llvm::DenseMap<Value, unsigned> valueToNewArgIdx;
+  for (auto [operandIdx, operand] : llvm::enumerate(callOp.getOperands())) {
+    if (!valueToNewArgIdx.count(operand))
+      valueToNewArgIdx[operand] = valueToNewArgIdx.size();
+    // Reduce the duplicate operands and maintain the original order.
+    oldArgIdxToNewArgIdx[operandIdx] = valueToNewArgIdx[operand];
+  }
+
+  std::map<unsigned, unsigned> oldResIdxToNewResIdx;
+  for (unsigned resultIdx = 0; resultIdx < callOp.getNumResults(); ++resultIdx)
+    oldResIdxToNewResIdx[resultIdx] = resultIdx;
+
+  FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewMapping(
+      rewriter, funcOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
+  if (failed(newFuncOpOrFailure))
+    return emitSilenceableFailure(getLoc())
+           << "failed to deduplicate function arguments '" << getFunctionName()
+           << "'";
+
+  func::replaceCallOpWithNewMapping(rewriter, callOp, oldArgIdxToNewArgIdx,
+                                    oldResIdxToNewResIdx);
+
+  results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
+  results.set(cast<OpResult>(getTransformedFunction()), {*newFuncOpOrFailure});
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::DeduplicateFuncArgsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(getModuleMutable(), effects);
+  transform::producesHandle(getOperation()->getOpResults(), effects);
+  transform::modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Func/Utils/Utils.cpp b/mlir/lib/Dialect/Func/Utils/Utils.cpp
index f781ed2d591b4..a58eb7233f460 100644
--- a/mlir/lib/Dialect/Func/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Func/Utils/Utils.cpp
@@ -18,16 +18,16 @@
 
 using namespace mlir;
 
-FailureOr<func::FuncOp>
-func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
-                              ArrayRef<unsigned> newArgsOrder,
-                              ArrayRef<unsigned> newResultsOrder) {
+FailureOr<func::FuncOp> func::replaceFuncWithNewMapping(
+    RewriterBase &rewriter, func::FuncOp funcOp,
+    const std::map<unsigned, unsigned> &oldArgToNewArg,
+    const std::map<unsigned, unsigned> &oldResToNewRes) {
   // Generate an empty new function operation with the same name as the
   // original.
-  assert(funcOp.getNumArguments() == newArgsOrder.size() &&
-         "newArgsOrder must match the number of arguments in the function");
-  assert(funcOp.getNumResults() == newResultsOrder.size() &&
-         "newResultsOrder must match the number of results in the function");
+  assert(funcOp.getNumArguments() == oldArgToNewArg.size() &&
+         "oldArgToNewArg must match the number of arguments in the function");
+  assert(funcOp.getNumResults() == oldResToNewRes.size() &&
+         "oldResToNewRes must match the number of results in the function");
 
   if (!funcOp.getBody().hasOneBlock())
     return rewriter.notifyMatchFailure(
@@ -37,12 +37,49 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
   ArrayRef<Type> origOutputTypes = funcOp.getFunctionType().getResults();
   SmallVector<Type> newInputTypes, newOutputTypes;
   SmallVector<Location> locs;
-  for (unsigned int idx : newArgsOrder) {
-    newInputTypes.push_back(origInputTypes[idx]);
-    locs.push_back(funcOp.getArgument(newArgsOrder[idx]).getLoc());
+
+  std::map<unsigned, SmallVector<unsigned>> newArgToOldArg;
+  for (auto [oldArgIdx, newArgIdx] : oldArgToNewArg)
+    newArgToOldArg[newArgIdx].push_back(oldArgIdx);
+
+  for (auto [newArgIdx, oldArgIdx] : newArgToOldArg) {
+    std::ignore = newArgIdx;
+    assert(llvm::all_of(oldArgIdx,
+                        [&funcOp](unsigned idx) -> bool {
+                          return idx < funcOp.getNumArguments();
+                        }) &&
+           "idx must be less than the number of arguments in the function");
+    assert(!oldArgIdx.empty() && "oldArgIdx must not be empty");
+    Type origInputTypeToCheck = origInputTypes[oldArgIdx.front()];
+    assert(llvm::all_of(oldArgIdx,
+                        [&](unsigned idx) -> bool {
+                          return origInputTypes[idx] == origInputTypeToCheck;
+                        }) &&
+           "all oldArgIdx must have the same type");
+    newInputTypes.push_back(origInputTypeToCheck);
+    locs.push_back(funcOp.getArgument(oldArgIdx.front()).getLoc());
+  }
+
+  std::map<unsigned, SmallVector<unsigned>> newResToOldRes;
+  for (auto [oldResIdx, newResIdx] : oldResToNewRes)
+    newResToOldRes[newResIdx].push_back(oldResIdx);
+
+  for (auto [newResIdx, oldResIdx] : newResToOldRes) {
+    std::ignore = newResIdx;
+    assert(llvm::all_of(oldResIdx,
+                        [&funcOp](unsigned idx) -> bool {
+                          return idx < funcOp.getNumResults();
+                        }) &&
+           "idx must be less than the number of results in the function");
+    Type origOutputTypeToCheck = origOutputTypes[oldResIdx.front()];
+    assert(llvm::all_of(oldResIdx,
+                        [&](unsigned idx) -> bool {
+                          return origOutputTypes[idx] == origOutputTypeToCheck;
+                        }) &&
+           "all oldResIdx must have the same type");
+    newOutputTypes.push_back(origOutputTypeToCheck);
   }
-  for (unsigned int idx : newResultsOrder)
-    newOutputTypes.push_back(origOutputTypes[idx]);
+
   rewriter.setInsertionPoint(funcOp);
   auto newFuncOp = func::FuncOp::create(
       rewriter, funcOp.getLoc(), funcOp.getName(),
@@ -58,14 +95,15 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
   IRMapping operandMapper;
   SmallVector<DictionaryAttr> argAttrs, resultAttrs;
   funcOp.getAllArgAttrs(argAttrs);
-  for (unsigned int i = 0; i < newArgsOrder.size(); ++i) {
-    operandMapper.map(funcOp.getArgument(newArgsOrder[i]),
-                      newFuncOp.getArgument(i));
-    newFuncOp.setArgAttrs(i, argAttrs[newArgsOrder[i]]);
-  }
+  for (auto [oldArgIdx, newArgIdx] : oldArgToNewArg)
+    operandMapper.map(funcOp.getArgument(oldArgIdx),
+                      newFuncOp.getArgument(newArgIdx));
+  for (auto [newArgIdx, oldArgIdx] : newArgToOldArg)
+    newFuncOp.setArgAttrs(newArgIdx, argAttrs[oldArgIdx.front()]);
+
   funcOp.getAllResultAttrs(resultAttrs);
-  for (unsigned int i = 0; i < newResultsOrder.size(); ++i)
-    newFuncOp.setResultAttrs(i, resultAttrs[newResultsOrder[i]]);
+  for (auto [newResIdx, oldResIdx] : newResToOldRes)
+    newFuncOp.setResultAttrs(newResIdx, resultAttrs[oldResIdx.front()]);
 
   // Clone the operations from the original function to the new function.
   rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
@@ -76,8 +114,10 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
   auto returnOp = cast<func::ReturnOp>(
       newFuncOp.getFunctionBody().begin()->getTerminator());
   SmallVector<Value> newReturnValues;
-  for (unsigned int idx : newResultsOrder)
-    newReturnValues.push_back(returnOp.getOperand(idx));
+  for (auto [newResIdx, oldResIdx] : newResToOldRes) {
+    std::ignore = newResIdx;
+    newReturnValues.push_back(returnOp.getOperand(oldResIdx.front()));
+  }
   rewriter.setInsertionPoint(returnOp);
   auto newReturnOp =
       func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues);
@@ -89,33 +129,76 @@ func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
   return newFuncOp;
 }
 
-func::CallOp
-func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp,
-                                ArrayRef<unsigned> newArgsOrder,
-                                ArrayRef<unsigned> newResultsOrder) {
+func::CallOp func::replaceCallOpWithNewMapping(
+    RewriterBase &rewriter, func::CallOp callOp,
+    const std::map<unsigned, unsigned> &oldArgToNewArg,
+    const std::map<unsigned, unsigned> &oldResToNewRes) {
   assert(
-      callOp.getNumOperands() == newArgsOrder.size() &&
-      "newArgsOrder must match the number of operands in the call operation");
+      callOp.getNumOperands() == oldArgToNewArg.size() &&
+      "oldArgToNewArg must match the number of operands in the call operation");
   assert(
-      callOp.getNumResults() == newResultsOrder.size() &&
-      "newResultsOrder must match the number of results in the call operation");
+      callOp.getNumResults() == oldResToNewRes.size() &&
+      "oldResToNewRes must match the number of results in the call operation");
+
+  // Inverse mapping from new arguments to old arguments.
+  std::map<unsigned, SmallVector<unsigned>> newArgToOldArg;
+  for (auto [oldArgIdx, newArgIdx] : oldArgToNewArg)
+    newArgToOldArg[newArgIdx].push_back(oldArgIdx);
+
   SmallVector<Value> newArgsOrderValues;
-  for (unsigned int argIdx : newArgsOrder)
-    newArgsOrderValues.push_back(callOp.getOperand(argIdx));
+  for (const auto &[newArgIdx, oldArgIdx] : newArgToOldArg) {
+    std::ignore = newArgIdx;
+    assert(
+        llvm::all_of(oldArgIdx,
+                     [&callOp](unsigned idx) -> bool {
+                       return idx < callOp.getNumOperands();
+                     }) &&
+        "idx must be less than the number of operands in the call operation");
+    assert(!oldArgIdx.empty() && "oldArgIdx must not be empty");
+    Value origOperandToCheck = callOp.getOperand(oldArgIdx.front());
+    assert(llvm::all_of(oldArgIdx,
+                        [&](unsigned idx) -> bool {
+                          return callOp.getOperand(idx).getType() ==
+                                 origOperandToCheck.getType();
+                        }) &&
+           "all oldArgIdx must have the same type");
+    newArgsOrderValues.push_back(origOperandToCheck);
+  }
+
   SmallVector<Type> newResultTypes;
-  for (unsigned int resIdx : newResultsOrder)
-    newResultTypes.push_back(callOp.getResult(resIdx).getType());
+  std::map<unsigned, SmallVector<unsigned>> newResToOldRes;
+  for (auto [oldResIdx, newResIdx] : oldResToNewRes)
+    newResToOldRes[newResIdx].push_back(oldResIdx);
+
+  for (auto [newResIdx, oldResIdx] : newResToOldRes) {
+    std::ignore = newResIdx;
+    assert(llvm::all_of(oldResIdx,
+                        [&callOp](unsigned idx) -> bool {
+                          return idx < callOp.getNumResults();
+                        }) &&
+           "idx must be less than the number of results in the call operation");
+    assert(!oldResIdx.empty() && "oldResIdx must not be empty");
+    Value origResultToCheck = callOp.getResult(oldResIdx.front());
+    assert(llvm::all_of(oldResIdx,
+                        [&](unsigned idx) -> bool {
+                          return callOp.getResult(idx).getType() ==
+                                 origResultToCheck.getType();
+                        }) &&
+           "all oldResIdx must have the same type");
+    newResultTypes.push_back(origResultToCheck.getType());
+  }
 
   // Replace the kernel call operation with a new one that has the
-  // reordered arguments.
+  // mapped arguments.
   rewriter.setInsertionPoint(callOp);
   auto newCallOp =
       func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
                            newResultTypes, newArgsOrderValues);
+  newCallOp->setDiscardableAttrs(callOp->getDiscardableAttrDictionary());
   newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
-  for (auto &&[newIndex, origIndex] : llvm::enumerate(newResultsOrder))
-    rewriter.replaceAllUsesWith(callOp.getResult(origIndex),
-                                newCallOp.getResult(newIndex));
+  for (auto &&[oldResIdx, newResIdx] : oldResToNewRes)
+    rewriter.replaceAllUsesWith(callOp.getResult(oldResIdx),
+                                newCallOp.getResult(newResIdx));
   rewriter.eraseOp(callOp);
 
   return newCallOp;
diff --git a/mlir/test/Dialect/Func/func-transform-invalid.mlir b/mlir/test/Dialect/Func/func-transform-invalid.mlir
index e712eee83f36e..d260a36a723b6 100644
--- a/mlir/test/Dialect/Func/func-transform-invalid.mlir
+++ b/mlir/test/Dialect/Func/func-transform-i...
[truncated]

Collaborator

Isn't %module == %arg0 here?

Contributor Author

You're absolutely right. I removed the redundant lines and changed %arg0 to %module.

Collaborator

std::map is rarely the right data structure. Why not a DenseMap?

Even better: can we just use a vector? Seems like it would be bound by the number of arguments, which is a "small" size, so we can reserve the size ahead of time.

Collaborator

Also, I don't get why the key is a vector here?

Contributor Author

The replaceFuncWithNewMapping util now needs a mapping from each old argument index to its new argument index in the function operation, and several old (duplicated) arguments may map to the same new argument index in the new function.

After the latest changes, I use a SmallVector<unsigned> for the old-index-to-new-index mapping, and a SmallVector<SmallVector> for the inverse mapping, since many old indexes may map to one new index.
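
The inverse mapping described in this reply can be illustrated with a minimal standalone sketch. It assumes plain `std::vector` in place of `SmallVector` so that it compiles without MLIR; the function name `invertMapping` and its `numNew` parameter are hypothetical and are not taken from the PR.

```cpp
#include <cassert>
#include <vector>

// Hypothetical standalone model of the inverse mapping (not the PR's code).
// oldToNew[i] is the new index assigned to old argument i; several old
// indices may share one new index when arguments are duplicated.
std::vector<std::vector<int>> invertMapping(const std::vector<int> &oldToNew,
                                            int numNew) {
  // One bucket per new index, each collecting the old indices mapped to it.
  std::vector<std::vector<int>> newToOld(numNew);
  for (int oldIdx = 0; oldIdx < static_cast<int>(oldToNew.size()); ++oldIdx) {
    int newIdx = oldToNew[oldIdx];
    assert(newIdx >= 0 && newIdx < numNew && "new index out of range");
    newToOld[newIdx].push_back(oldIdx);
  }
  return newToOld;
}
```

Because the new indices are dense and bounded by the number of arguments, a vector indexed by the new index can stand in for an ordered map, which is the direction the review comment suggests.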

Collaborator

Same as above.

Contributor Author

Answered above :)

@amirBish force-pushed the amirBish/mlir/add_deduplicate_func_args branch from 8d09353 to ebac68f on September 12, 2025 14:37
@amirBish
Contributor Author

Thanks for the review :) Could you please have another look?

Comment on lines 365 to 366
Collaborator

@joker-eph Sep 12, 2025

Suggested change
-    if (callOp.getCallee() == getFunctionName().getRootReference().getValue())
-      callOps.push_back(callOp);
+    if (callOp.getCallee() == getFunctionName().getRootReference().getValue()) {
+      if (!callOps.empty())
+        // TODO: Support more than one callOp.
+        return WalkResult::interrupt();
+      callOps.push_back(callOp);
+    }
+    return WalkResult::advance();

Collaborator

No need to traverse the entire module right now

Contributor Author

It seems dangerous to stop at the first callOp and start applying the transformation. The user may get undefined behavior if there are multiple calls that do not share the same duplicate operands. What do you think?

Collaborator

Right, you can stop at the second one instead :)
(I updated the diff, you can get the result and check if it has been interrupted)

Contributor Author

Nice, updated :)
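
The outcome of this thread (stop walking at the second matching call, then check whether the walk was interrupted) can be modelled with a small standalone C++17 sketch. Everything in it is an assumption made for illustration: `walkCalls`, `findSingleCall`, and the local `WalkResult` enum are hypothetical stand-ins that compile without MLIR. In MLIR itself the non-interrupting result is spelled `WalkResult::advance()`.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::WalkResult, so the sketch needs no MLIR.
enum class WalkResult { Advance, Interrupt };

// Visits each call in order and stops as soon as the callback interrupts.
// Returns true if the walk was interrupted.
bool walkCalls(const std::vector<std::string> &callees,
               const std::function<WalkResult(int)> &callback) {
  for (int i = 0; i < static_cast<int>(callees.size()); ++i)
    if (callback(i) == WalkResult::Interrupt)
      return true;
  return false;
}

// Returns the index of the unique call to `target`, or std::nullopt when
// there is no call or more than one call.
std::optional<int> findSingleCall(const std::vector<std::string> &callees,
                                  const std::string &target) {
  std::optional<int> found;
  bool interrupted = walkCalls(callees, [&](int idx) {
    if (callees[idx] != target)
      return WalkResult::Advance;
    if (found)
      return WalkResult::Interrupt; // Second match: no need to walk further.
    found = idx;
    return WalkResult::Advance;
  });
  // An interrupt can only happen after a first match was recorded.
  assert(!interrupted || found.has_value());
  if (interrupted)
    return std::nullopt;
  return found;
}
```

Interrupting at the second match rather than the first keeps the safety property raised in the thread: the transformation is never applied when more than one call exists, yet the rest of the module is not traversed once that is known.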

Comment on lines 390 to 393
Collaborator

Suggested change
-    if (!valueToNewArgIdx.count(operand))
-      valueToNewArgIdx[operand] = valueToNewArgIdx.size();
-    // Reduce the duplicate operands and maintain the original order.
-    oldArgIdxToNewArgIdx[operandIdx] = valueToNewArgIdx[operand];
+    auto [it, inserted] = valueToNewArgIdx.insert(operand, valueToNewArgIdx.size());
+    // Reduce the duplicate operands and maintain the original order.
+    oldArgIdxToNewArgIdx[operandIdx] = *it;

Something like this should let us go from 3 map lookups to 1.

Contributor Author

Right, fixed.
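
A minimal standalone sketch of the single-lookup idea from this thread, under stated assumptions: operands are modelled as plain `int` values instead of `mlir::Value`, `std::unordered_map::try_emplace` (C++17) stands in for the `DenseMap` insertion, and the function name `computeOldToNew` is hypothetical. With a map-style container the returned iterator points at a key/value pair, so the new index is read through `it->second`.

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// Hypothetical model: returns, for every operand position, the index of the
// deduplicated argument it maps to. First occurrences keep their order.
std::vector<int> computeOldToNew(const std::vector<int> &operands) {
  std::unordered_map<int, int> valueToNewIdx;
  std::vector<int> oldToNew;
  oldToNew.reserve(operands.size());
  for (int operand : operands) {
    // try_emplace inserts only when the key is absent, so a single hash
    // lookup covers both the "seen before" check and the insertion.
    auto [it, inserted] = valueToNewIdx.try_emplace(
        operand, static_cast<int>(valueToNewIdx.size()));
    (void)inserted;
    // A new index can never exceed the current operand position.
    assert(it->second <= static_cast<int>(oldToNew.size()));
    oldToNew.push_back(it->second);
  }
  return oldToNew;
}
```

The map size is evaluated as a call argument before the insertion happens, so a newly seen operand receives the next unused index.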

Collaborator

Why do we need the llvm:: prefix now?

Contributor Author

My mistake, removed :)

Comment on lines 46 to 48
Collaborator

What is this condition guarding? Only if the input is empty? If so, this seems more direct:

Suggested change
-  auto maxNewArgIdx = llvm::max_element(oldArgToNewArg);
-  if (maxNewArgIdx != oldArgToNewArg.end())
-    numOfNewArgs = *maxNewArgIdx + 1;
+  if (!oldArgToNewArg.empty())
+    numOfNewArgs = 1 + *llvm::max_element(oldArgToNewArg);

Contributor Author

Agree, fixed.
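
The computation discussed in this thread can be sketched without LLVM as follows. This is an illustration under assumptions: `std::max_element` replaces `llvm::max_element`, and the function name `countNewArgs` is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: the number of new arguments is one more than the
// largest new index, or zero when the function has no arguments at all.
int countNewArgs(const std::vector<int> &oldToNew) {
  // Dereferencing the result of max_element on an empty range is invalid,
  // so the empty case is handled up front.
  if (oldToNew.empty())
    return 0;
  int maxNewIdx = *std::max_element(oldToNew.begin(), oldToNew.end());
  assert(maxNewIdx >= 0 && "indices must be non-negative");
  return maxNewIdx + 1;
}
```

Checking `empty()` first expresses the guard directly, which is the point of the suggestion: the only case in which `max_element` returns the end iterator is an empty input.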

Comment on lines 168 to 170
Collaborator

I feel like I'm missing something here, the following seems simpler, is it equivalent?

Suggested change
-  for (const auto &[newArgIdx, oldArgIdxs] :
-       llvm::enumerate(newArgIdxToOldArgIdxs)) {
-    std::ignore = newArgIdx;
+  for (const auto &oldArgIdxs : newArgIdxToOldArgIdxs) {

Contributor Author

Right, fixed it.

Collaborator

This whole block looks very close to the addition above for the func op. Can this be refactored?

Contributor Author

I have added two static methods: one that creates the inverse mapping and one that gets the new mapped elements based on it. This removes the duplicated code.
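
The second helper mentioned in this reply, the one that picks the mapped elements from an inverse mapping, might look roughly like the sketch below. It is an assumption-laden model rather than the code from the PR: the name `getMappedElements` is hypothetical, `std::vector` replaces `SmallVector`, and the element type is a template parameter so that one helper can serve argument types, result types, and call operands alike.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper: for each new index, take the element of the first old
// index mapped to it. All old elements that share a new index must be equal.
template <typename T>
std::vector<T>
getMappedElements(const std::vector<T> &oldElements,
                  const std::vector<std::vector<int>> &newToOld) {
  std::vector<T> result;
  result.reserve(newToOld.size());
  for (const std::vector<int> &oldIdxs : newToOld) {
    assert(!oldIdxs.empty() && "every new index needs at least one old index");
    const T &representative = oldElements[oldIdxs.front()];
    for (int idx : oldIdxs) {
      (void)idx; // Only used by the assertion below.
      assert(oldElements[idx] == representative &&
             "old elements merged into one new index must be equal");
    }
    result.push_back(representative);
  }
  return result;
}
```

Written this way, the function-side and call-side code paths differ only in the element type they pass in, which is the duplication the review comment points at.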

Collaborator

The process for results seems oddly similar to the process for operands, can this be refactored?

Contributor Author

I have added two static methods: one that creates the inverse mapping and one that gets the new mapped elements based on it. This removes the duplicated code.

Collaborator

Suggested change
-  newCallOp->setDiscardableAttrs(callOp->getDiscardableAttrDictionary());

That does not seem safe to me.

Contributor Author

Removed it. It was part of the earlier changes. I guess a good approach would be for downstream projects to declare the expected behavior in the listener before applying the transformation.

Collaborator

There is a lot of code here that seems unrelated to the "transform op" plumbing: can you please make this available as a utility usable outside of the transform dialect?
(for example a module pass could do this just as well)

Contributor Author

Oh, could you please share more of your thoughts about this? It sounds like an interesting transformation that can also be used by the transform dialect through a transformOp. Maybe I have not fully understood your point of view.

Collaborator

@joker-eph Sep 13, 2025

I'm saying that the "transform ops" aspect of the transformation is just the glue to hook it to an operation.
However, the transformation has nothing to do with the "transform op" and should be usable without the transform dialect. That is: it should be a C++ utility function in a header and C++ file independent of anything to do with the transform dialect.
(possibly mlir/lib/Dialect/Func/Utils/Utils.cpp, but could also be mlir/lib/Dialect/Func/Utils/ArgDedup.cpp or anything like that)

Contributor Author

Nice, added a util. I preferred to add it in Utils.h.

Comment on lines 131 to 134
Collaborator

Suggested change
-  for (auto [newResIdx, oldResIdx] : llvm::enumerate(newResToOldResIdxs)) {
-    std::ignore = newResIdx;
-    newReturnValues.push_back(returnOp.getOperand(oldResIdx.front()));
-  }
+  for (int oldResIdx : newResToOldResIdxs)
+    newReturnValues.push_back(returnOp.getOperand(oldResIdx.front()));

(is it equivalent?)
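
A small sketch of the equivalence in question, using `std::vector` as a stand-in for `llvm::SmallVector` and plain integers as stand-ins for return operands. It assumes `newResToOldResIdxs` holds one non-empty list of old result indices per new result, as the `.front()` call suggests. Under that assumption the loop variable has to be the inner list rather than an `int`, so the range-for below binds it by `const` reference. The function names are hypothetical.

```cpp
#include <cstddef>
#include <vector>

using IndexGroups = std::vector<std::vector<int>>;

// Index-based form: the loop index is only used to reach the element.
inline std::vector<int> collectWithIndex(const IndexGroups &newResToOldResIdxs,
                                         const std::vector<int> &operands) {
  std::vector<int> newReturnValues;
  for (std::size_t newResIdx = 0; newResIdx < newResToOldResIdxs.size();
       ++newResIdx)
    newReturnValues.push_back(operands[newResToOldResIdxs[newResIdx].front()]);
  return newReturnValues;
}

// Range-for form: iterates over the groups directly, no index needed.
inline std::vector<int> collectRangeFor(const IndexGroups &newResToOldResIdxs,
                                        const std::vector<int> &operands) {
  std::vector<int> newReturnValues;
  for (const std::vector<int> &oldResIdxs : newResToOldResIdxs)
    newReturnValues.push_back(operands[oldResIdxs.front()]);
  return newReturnValues;
}
```

Both functions visit the groups in the same order and read only the first old index of each group, so they produce the same sequence.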

Contributor Author

Sure, fixed; I missed it after the changes.

Collaborator

Suggested change
- llvm::SmallVector<llvm::SmallVector<unsigned>> newResToOldResIdxs(
+ llvm::SmallVector<llvm::SmallVector<int>> newResToOldResIdxs(

Can we use int preferably everywhere?

Contributor Author

Sure, changed. Still, could you please elaborate on why `int` is preferred over `unsigned` in this context (even though we're handling indices)?
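
The thread does not record an answer to this question. One commonly cited argument for signed indices, offered here only as a possible rationale and not as the reviewer's stated reasoning, is that unsigned arithmetic silently wraps around on underflow. A minimal sketch with hypothetical helper names:

```cpp
#include <vector>

// Index of the last element, computed with unsigned arithmetic.
// For an empty vector, 0u - 1u wraps around to the largest unsigned value.
inline unsigned lastIndexUnsigned(const std::vector<int> &values) {
  return static_cast<unsigned>(values.size()) - 1u;
}

// Same computation with int: an empty vector yields -1, which is easy to
// detect and keeps `i <= last` loop conditions well behaved.
inline int lastIndexSigned(const std::vector<int> &values) {
  return static_cast<int>(values.size()) - 1;
}

// Counts the indices i with 0 <= i <= lastIndexSigned(values).
inline int countIndicesSigned(const std::vector<int> &values) {
  int count = 0;
  for (int i = 0; i <= lastIndexSigned(values); ++i)
    ++count;
  return count;
}
```

With the signed version, the loop body never runs for an empty vector. The same loop written against `lastIndexUnsigned` would compare against a huge bound instead.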

@amirBish force-pushed the amirBish/mlir/add_deduplicate_func_args branch from ebac68f to ffce307 on September 13, 2025 18:01
@amirBish
Contributor Author

@joker-eph Thanks for the detailed review, really appreciate it :) Answered your code review.

@amirBish force-pushed the amirBish/mlir/add_deduplicate_func_args branch from ffce307 to 2d6a725 on September 13, 2025 20:12
Contributor Author

@amirBish left a comment

Resolved the remaining threads.

Comment on lines 365 to 366
Contributor Author

Nice, updated :)

@amirBish force-pushed the amirBish/mlir/add_deduplicate_func_args branch from 2d6a725 to bfc15d5 on September 14, 2025 05:33
Collaborator

In general, we don't pass strings for errorMessage; we pass a callback with the signature `function_ref<InFlightDiagnostic()> emitError`.
It should be checked for null before use:

if (emitError)
  emitError() << ...;
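
The pattern described above can be sketched without MLIR as follows. This is an illustration under stated assumptions, not MLIR's API: `Diagnostic` is a hypothetical stand-in for `InFlightDiagnostic` that reports its message when it is destroyed, `std::function` stands in for `llvm::function_ref`, and `checkSingleCall` is a made-up utility.

```cpp
#include <functional>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for InFlightDiagnostic: accumulates a message and
// reports it to a sink when the object goes out of scope.
class Diagnostic {
public:
  explicit Diagnostic(std::vector<std::string> *sink) : sink(sink) {}
  Diagnostic(Diagnostic &&other)
      : sink(other.sink), message(std::move(other.message)) {
    other.sink = nullptr; // The moved-from object must not report.
  }
  ~Diagnostic() {
    if (sink)
      sink->push_back(message);
  }
  template <typename T> Diagnostic &operator<<(const T &value) {
    std::ostringstream os;
    os << value;
    message += os.str();
    return *this;
  }

private:
  std::vector<std::string> *sink;
  std::string message;
};

// Stand-in for function_ref<InFlightDiagnostic()>; may be null.
using EmitErrorFn = std::function<Diagnostic()>;

// Hypothetical utility: succeeds only when there is exactly one call site.
// The caller decides whether and where a diagnostic is emitted.
inline bool checkSingleCall(int numCalls, const EmitErrorFn &emitError) {
  if (numCalls == 1)
    return true;
  if (emitError) // Null callback means "fail silently".
    emitError() << "expected exactly one call, found " << numCalls;
  return false;
}
```

A caller that wants diagnostics passes a lambda that creates the diagnostic; a caller that only needs the success flag passes a null callback.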

Contributor Author

Sounds good. For now I used LDBG(), since these really seem more like logging than errors.

Comment on lines 239 to 240
Collaborator

Suggested change
- for (int resultIdx = 0; resultIdx < static_cast<int>(callOp.getNumResults());
-      ++resultIdx)
+ for (int resultIdx : llvm::seq<int>(0, callOp.getNumResults()))

Shortcut :)
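
For readers unfamiliar with the shortcut: `llvm::seq<int>(begin, end)` yields the half-open range `[begin, end)` for use in a range-based for loop. The sketch below mimics that behavior with a hypothetical `IntSeq` class so that it compiles without LLVM; it is not LLVM's implementation.

```cpp
#include <vector>

// Hypothetical minimal stand-in for llvm::seq<int>(begin, end): a half-open
// integer range that can be used in a range-based for loop.
class IntSeq {
public:
  class Iterator {
  public:
    explicit Iterator(int value) : value(value) {}
    int operator*() const { return value; }
    Iterator &operator++() {
      ++value;
      return *this;
    }
    bool operator!=(const Iterator &other) const {
      return value != other.value;
    }

  private:
    int value;
  };

  IntSeq(int first, int last) : first(first), last(last) {}
  Iterator begin() const { return Iterator(first); }
  Iterator end() const { return Iterator(last); }

private:
  int first;
  int last;
};

// Explicit counting loop with a cast, as in the original code.
inline std::vector<int> indicesManual(unsigned numResults) {
  std::vector<int> indices;
  for (int resultIdx = 0; resultIdx < static_cast<int>(numResults);
       ++resultIdx)
    indices.push_back(resultIdx);
  return indices;
}

// Range-based form, analogous to the suggested llvm::seq loop.
inline std::vector<int> indicesSeq(unsigned numResults) {
  std::vector<int> indices;
  for (int resultIdx : IntSeq(0, static_cast<int>(numResults)))
    indices.push_back(resultIdx);
  return indices;
}
```

Both loops visit the same indices; the range-based form moves the cast and the increment out of the loop header.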

Contributor Author

Thanks, adopted it.

Collaborator

This seems like valid logging (use LDBG() for logging), but it does not seem to be an error in itself.

Collaborator

Actually, none of the failures in this method seem like "errors" to diagnose: it's all WAI, with LDBG() being appropriate.

Contributor Author

Replaced with LDBG().

This PR adds a new transform operation that removes
duplicate arguments from a function operation based on
the callOp of that function.

To keep the implementation simple for now, the transform
fails when there are multiple callOps for the function
whose duplicate arguments we want to eliminate.

This pull request also adapts the utils under the func dialect
so that they are reusable for this transformOp.

@amirBish force-pushed the amirBish/mlir/add_deduplicate_func_args branch from bfc15d5 to 050c32d on September 14, 2025 20:20
Contributor Author

@amirBish left a comment

Thanks, resolved the threads.

Collaborator

@joker-eph left a comment

LGTM.

@amirBish merged commit 471bd17 into llvm:main on Sep 15, 2025
11 checks passed