diff --git a/mlir/include/mlir-c/Dialect/SMT.h b/mlir/include/mlir-c/Dialect/SMT.h index 66caeffea8281..a289172b9d0f2 100644 --- a/mlir/include/mlir-c/Dialect/SMT.h +++ b/mlir/include/mlir-c/Dialect/SMT.h @@ -48,6 +48,8 @@ MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBitVector(MlirContext ctx, MLIR_CAPI_EXPORTED MlirStringRef mlirSMTBitVectorTypeGetName(void); +MLIR_CAPI_EXPORTED MlirTypeID mlirSMTBitVectorTypeGetTypeID(void); + /// Checks if the given type is a smt::BoolType. MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABool(MlirType type); @@ -56,6 +58,8 @@ MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBool(MlirContext ctx); MLIR_CAPI_EXPORTED MlirStringRef mlirSMTBoolTypeGetName(void); +MLIR_CAPI_EXPORTED MlirTypeID mlirSMTBoolTypeGetTypeID(void); + /// Checks if the given type is a smt::IntType. MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAInt(MlirType type); @@ -64,6 +68,8 @@ MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetInt(MlirContext ctx); MLIR_CAPI_EXPORTED MlirStringRef mlirSMTIntTypeGetName(void); +MLIR_CAPI_EXPORTED MlirTypeID mlirSMTIntTypeGetTypeID(void); + /// Checks if the given type is a smt::FuncType. MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASMTFunc(MlirType type); diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp index eff10a0e55c05..963076a6c672e 100644 --- a/mlir/lib/Bindings/Python/DialectSMT.cpp +++ b/mlir/lib/Bindings/Python/DialectSMT.cpp @@ -29,6 +29,8 @@ namespace MLIR_BINDINGS_PYTHON_DOMAIN { namespace smt { struct BoolType : PyConcreteType { static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsABool; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirSMTBoolTypeGetTypeID; static constexpr const char *pyClassName = "BoolType"; static inline const MlirStringRef name = mlirSMTBoolTypeGetName(); using Base::Base; @@ -46,6 +48,8 @@ struct BoolType : PyConcreteType { struct BitVectorType : PyConcreteType { static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsABitVector; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirSMTBitVectorTypeGetTypeID; static constexpr const char *pyClassName = "BitVectorType"; static inline const MlirStringRef name = mlirSMTBitVectorTypeGetName(); using Base::Base; @@ -64,6 +68,8 @@ struct BitVectorType : PyConcreteType { struct IntType : PyConcreteType { static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsAInt; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirSMTIntTypeGetTypeID; static constexpr const char *pyClassName = "IntType"; static inline const MlirStringRef name = mlirSMTIntTypeGetName(); using Base::Base; diff --git a/mlir/lib/CAPI/Dialect/SMT.cpp b/mlir/lib/CAPI/Dialect/SMT.cpp index 56b771f55b0e3..e90ca0130e298 100644 --- a/mlir/lib/CAPI/Dialect/SMT.cpp +++ b/mlir/lib/CAPI/Dialect/SMT.cpp @@ -53,6 +53,10 @@ MlirStringRef mlirSMTBitVectorTypeGetName(void) { return wrap(BitVectorType::name); } +MlirTypeID mlirSMTBitVectorTypeGetTypeID(void) { + return wrap(BitVectorType::getTypeID()); +} + bool mlirSMTTypeIsABool(MlirType type) { return isa(unwrap(type)); } MlirType mlirSMTTypeGetBool(MlirContext ctx) { @@ -61,6 +65,10 @@ MlirType mlirSMTTypeGetBool(MlirContext ctx) { MlirStringRef mlirSMTBoolTypeGetName(void) { return wrap(BoolType::name); } +MlirTypeID mlirSMTBoolTypeGetTypeID(void) { + return wrap(BoolType::getTypeID()); +} + bool mlirSMTTypeIsAInt(MlirType type) { return isa(unwrap(type)); } MlirType mlirSMTTypeGetInt(MlirContext ctx) { @@ -69,6 +77,8 @@ MlirType mlirSMTTypeGetInt(MlirContext ctx) { MlirStringRef mlirSMTIntTypeGetName(void) { return wrap(IntType::name); } +MlirTypeID mlirSMTIntTypeGetTypeID(void) { return wrap(IntType::getTypeID()); } + bool mlirSMTTypeIsASMTFunc(MlirType type) { return isa(unwrap(type)); }