Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][IR] Make OpOperand comparable #70410

Merged
merged 1 commit into from
Oct 27, 2023

Conversation

matthias-springer
Copy link
Member

Two OpOperands are the same if they belong to the same owner and have the same operand number. There are currently no comparison operators defined on OpOperand and we work around this in multiple places by comparing pointers.

Note: OpOperands are stored in an op, so it is valid to compare their pointers to determine if they are the same operand. E.g., getOperandNumber is also implemented via pointer arithmetics.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:tensor mlir:bufferization Bufferization infrastructure labels Oct 27, 2023
@llvmbot
Copy link
Collaborator

llvmbot commented Oct 27, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-bufferization

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

Two OpOperands are the same if they belong to the same owner and have the same operand number. There are currently no comparison operators defined on OpOperand and we work around this in multiple places by comparing pointers.

Note: OpOperands are stored in an op, so it is valid to compare their pointers to determine if they are the same operand. E.g., getOperandNumber is also implemented via pointer arithmetics.


Full diff: https://github.com/llvm/llvm-project/pull/70410.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/UseDefLists.h (+7)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+4-4)
diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h
index e3e3e86231465dc..ae9287e6621b03f 100644
--- a/mlir/include/mlir/IR/UseDefLists.h
+++ b/mlir/include/mlir/IR/UseDefLists.h
@@ -146,6 +146,13 @@ class IROperand : public detail::IROperandBase {
     return *this;
   }
 
+  bool operator==(const IROperand<DerivedT, IRValueT> &other) const {
+    return this == &other;
+  }
+  bool operator!=(const IROperand<DerivedT, IRValueT> &other) const {
+    return !(*this == other);
+  }
+
   /// Return the current value being used by this operand.
   IRValueT get() const { return value; }
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 5716dcc9d905016..52ff6ceeee85b03 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -537,12 +537,12 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
 
 bool MaterializeInDestinationOp::bufferizesToMemoryRead(
     OpOperand &opOperand, const AnalysisState &state) {
-  return &opOperand == &getSourceMutable();
+  return opOperand == getSourceMutable();
 }
 
 bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
     OpOperand &opOperand, const AnalysisState &state) {
-  if (&opOperand == &getDestMutable()) {
+  if (opOperand == getDestMutable()) {
     assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
     return true;
   }
@@ -560,7 +560,7 @@ bool MaterializeInDestinationOp::mustBufferizeInPlace(
 AliasingValueList
 MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                               const AnalysisState &state) {
-  if (&opOperand == &getDestMutable()) {
+  if (opOperand == getDestMutable()) {
     assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
     return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
   }
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 9386d0fd0f04faf..a95443db88b50b2 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -644,11 +644,11 @@ struct InsertSliceOpInterface
     RankedTensorType destType = insertSliceOp.getDestType();
 
     // The source is always read.
-    if (&opOperand == &insertSliceOp.getSourceMutable())
+    if (opOperand == insertSliceOp.getSourceMutable())
       return true;
 
     // For the destination, it depends...
-    assert(&opOperand == &insertSliceOp.getDestMutable() && "expected dest");
+    assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
 
     // Dest is not read if it is entirely overwritten. E.g.:
     // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
@@ -849,7 +849,7 @@ struct ReshapeOpInterface
                               const AnalysisState &state) const {
     // Depending on the layout map, the source buffer may have to be copied.
     auto reshapeOp = cast<tensor::ReshapeOp>(op);
-    return &opOperand == &reshapeOp.getShapeMutable();
+    return opOperand == reshapeOp.getShapeMutable();
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -931,7 +931,7 @@ struct ParallelInsertSliceOpInterface
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const AnalysisState &state) const {
     auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
-    return &opOperand == &parallelInsertSliceOp.getDestMutable();
+    return opOperand == parallelInsertSliceOp.getDestMutable();
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,

Two `OpOperand`s are the same if they belong to the same owner and have the same operand number. There are currently no comparison operators defined on `OpOperand` and we work around this in multiple places by comparing pointers.

Note: `OpOperand`s are stored in an op, so it is valid to compare their pointers to determine if they are the same operand. E.g., `getOperandNumber` is also implemented via pointer arithmetics.
@matthias-springer matthias-springer merged commit 5558504 into llvm:main Oct 27, 2023
2 of 3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir:core MLIR Core Infrastructure mlir:tensor mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants