Skip to content

Commit

Permalink
[mlir][python] Fix MemRefType IsAFunction in Python bindings
Browse files Browse the repository at this point in the history
MemRefType was using a wrong `isa` function in the bindings code, which
could lead to invalid IR being constructed. Also run the verifier in
memref dialect tests.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D111784
  • Loading branch information
ftynse committed Oct 14, 2021
1 parent e3e1da2 commit a04c0b7
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mlir/lib/Bindings/Python/IRTypes.cpp
Expand Up @@ -406,7 +406,7 @@ class PyMemRefLayoutMapList;
/// Ranked MemRef Type subclass - MemRefType.
class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
static constexpr const char *pyClassName = "MemRefType";
using PyConcreteType::PyConcreteType;

Expand Down
2 changes: 1 addition & 1 deletion mlir/python/mlir/dialects/_memref_ops_ext.py
Expand Up @@ -33,5 +33,5 @@ def __init__(self,
memref_resolved = _get_op_result_or_value(memref)
indices_resolved = [] if indices is None else _get_op_results_or_values(
indices)
return_type = memref_resolved.type
return_type = MemRefType(memref_resolved.type).element_type
super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip)
1 change: 1 addition & 0 deletions mlir/test/python/dialects/memref.py
Expand Up @@ -71,3 +71,4 @@ def testCustomBuidlers():
# CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
print(module)
assert module.operation.verify()

0 comments on commit a04c0b7

Please sign in to comment.