Skip to content

Commit

Permalink
[mlir][python] Make Operation and Value hashable
Browse files Browse the repository at this point in the history
This allows operations and values to be used as dict keys

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D112669
  • Loading branch information
rkayaith authored and ftynse committed Nov 3, 2021
1 parent 30a3a17 commit f78fe0b
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 10 deletions.
18 changes: 16 additions & 2 deletions mlir/lib/Bindings/Python/IRCore.cpp
Expand Up @@ -2171,6 +2171,10 @@ void mlir::python::populateIRCore(py::module &m) {
})
.def("__eq__",
[](PyOperationBase &self, py::object other) { return false; })
.def("__hash__",
[](PyOperationBase &self) {
return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
})
.def_property_readonly("attributes",
[](PyOperationBase &self) {
return PyOpAttributeMap(
Expand Down Expand Up @@ -2558,7 +2562,10 @@ void mlir::python::populateIRCore(py::module &m) {
.def("__eq__",
[](PyAttribute &self, PyAttribute &other) { return self == other; })
.def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
.def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; })
.def("__hash__",
[](PyAttribute &self) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
})
.def(
"dump", [](PyAttribute &self) { mlirAttributeDump(self); },
kDumpDocstring)
Expand Down Expand Up @@ -2652,7 +2659,10 @@ void mlir::python::populateIRCore(py::module &m) {
"Context that owns the Type")
.def("__eq__", [](PyType &self, PyType &other) { return self == other; })
.def("__eq__", [](PyType &self, py::object &other) { return false; })
.def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; })
.def("__hash__",
[](PyType &self) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
})
.def(
"dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
.def(
Expand Down Expand Up @@ -2703,6 +2713,10 @@ void mlir::python::populateIRCore(py::module &m) {
return self.get().ptr == other.get().ptr;
})
.def("__eq__", [](PyValue &self, py::object other) { return false; })
.def("__hash__",
[](PyValue &self) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
})
.def(
"__str__",
[](PyValue &self) {
Expand Down
4 changes: 0 additions & 4 deletions mlir/test/python/ir/attributes.py
Expand Up @@ -66,10 +66,6 @@ def testAttrHash():
a3 = Attribute.parse('"attr1"')
# CHECK: hash(a1) == hash(a3): True
print("hash(a1) == hash(a3):", a1.__hash__() == a3.__hash__())
# In general, hashes don't have to be unique. In this case, however, the
# hash is just the underlying pointer so it will be.
# CHECK: hash(a1) == hash(a2): False
print("hash(a1) == hash(a2):", a1.__hash__() == a2.__hash__())

s = set()
s.add(a1)
Expand Down
4 changes: 0 additions & 4 deletions mlir/test/python/ir/builtin_types.py
Expand Up @@ -67,10 +67,6 @@ def testTypeHash():

# CHECK: hash(t1) == hash(t3): True
print("hash(t1) == hash(t3):", t1.__hash__() == t3.__hash__())
# In general, hashes don't have to be unique. In this case, however, the
# hash is just the underlying pointer so it will be.
# CHECK: hash(t1) == hash(t2): False
print("hash(t1) == hash(t2):", t1.__hash__() == t2.__hash__())

s = set()
s.add(t1)
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/python/ir/operation.py
Expand Up @@ -741,6 +741,7 @@ def testOperationLoc():
assert op.location == loc
assert op.operation.location == loc


# CHECK-LABEL: TEST: testModuleMerge
@run
def testModuleMerge():
Expand Down Expand Up @@ -876,3 +877,13 @@ def testSymbolTable():
raise
else:
assert False, "exepcted ValueError when adding a non-symbol"


# CHECK-LABEL: TEST: testOperationHash
@run
def testOperationHash():
ctx = Context()
ctx.allow_unregistered_dialects = True
with ctx, Location.unknown():
op = Operation.create("custom.op1")
assert hash(op) == hash(op.operation)
19 changes: 19 additions & 0 deletions mlir/test/python/ir/value.py
Expand Up @@ -55,3 +55,22 @@ def testValueIsInstance():
op = func.regions[0].blocks[0].operations[0]
assert not BlockArgument.isinstance(op.results[0])
assert OpResult.isinstance(op.results[0])


# CHECK-LABEL: TEST: testValueHash
@run
def testValueHash():
ctx = Context()
ctx.allow_unregistered_dialects = True
module = Module.parse(
r"""
func @foo(%arg0: f32) -> f32 {
%0 = "some_dialect.some_op"(%arg0) : (f32) -> f32
return %0 : f32
}""", ctx)

[func] = module.body.operations
block = func.entry_block
op, ret = block.operations
assert hash(block.arguments[0]) == hash(op.operands[0])
assert hash(op.result) == hash(ret.operands[0])

0 comments on commit f78fe0b

Please sign in to comment.