From a04c0b7ed2f92456558af2833f64cd494d161905 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 14 Oct 2021 11:33:28 +0200 Subject: [PATCH] [mlir][python] Fix MemRefType IsAFunction in Python bindings MemRefType was using a wrong `isa` function in the bindings code, which could lead to invalid IR being constructed. Also run the verifier in memref dialect tests. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D111784 --- mlir/lib/Bindings/Python/IRTypes.cpp | 2 +- mlir/python/mlir/dialects/_memref_ops_ext.py | 2 +- mlir/test/python/dialects/memref.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 568cca160a595..fd9f3efe7405f 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -406,7 +406,7 @@ class PyMemRefLayoutMapList; /// Ranked MemRef Type subclass - MemRefType. class PyMemRefType : public PyConcreteType { public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; static constexpr const char *pyClassName = "MemRefType"; using PyConcreteType::PyConcreteType; diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py index cb25ef105d73f..9cc22a21c6283 100644 --- a/mlir/python/mlir/dialects/_memref_ops_ext.py +++ b/mlir/python/mlir/dialects/_memref_ops_ext.py @@ -33,5 +33,5 @@ def __init__(self, memref_resolved = _get_op_result_or_value(memref) indices_resolved = [] if indices is None else _get_op_results_or_values( indices) - return_type = memref_resolved.type + return_type = MemRefType(memref_resolved.type).element_type super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip) diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py index e421f9b2fde95..f2eda0a620610 100644 --- a/mlir/test/python/dialects/memref.py +++ b/mlir/test/python/dialects/memref.py @@ -71,3 +71,4 @@ def testCustomBuidlers(): # CHECK: func @f1(%[[ARG0:.*]]: memref, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] print(module) + assert module.operation.verify()