diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index c1ade9ed8617c..cc7f09f71d028 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -23,6 +23,8 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LLVM, llvm); MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace); +MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMPointerTypeGetTypeID(void); + /// Returns `true` if the type is an LLVM dialect pointer type. MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMPointerType(MlirType type); @@ -58,6 +60,8 @@ MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type); /// Returns `true` if the type is an LLVM dialect struct type. MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMStructTypeGetTypeID(void); + /// Returns `true` if the type is a literal (unnamed) LLVM struct type. MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsLiteral(MlirType type); diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp index 870a713b8edcb..05681cecf82b3 100644 --- a/mlir/lib/Bindings/Python/DialectLLVM.cpp +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -31,8 +31,8 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) { // StructType //===--------------------------------------------------------------------===// - auto llvmStructType = - mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType); + auto llvmStructType = mlir_type_subclass( + m, "StructType", mlirTypeIsALLVMStructType, mlirLLVMStructTypeGetTypeID); llvmStructType .def_classmethod( @@ -137,7 +137,8 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) { // PointerType //===--------------------------------------------------------------------===// - mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType) + mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType, + mlirLLVMPointerTypeGetTypeID) .def_classmethod( "get", [](const nb::object &cls, std::optional addressSpace, diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 6636f0ea73ec9..bf231767320a5 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -27,6 +27,10 @@ MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace) { return wrap(LLVMPointerType::get(unwrap(ctx), addressSpace)); } +MlirTypeID mlirLLVMPointerTypeGetTypeID() { + return wrap(LLVM::LLVMPointerType::getTypeID()); +} + bool mlirTypeIsALLVMPointerType(MlirType type) { return isa(unwrap(type)); } @@ -73,6 +77,10 @@ bool mlirTypeIsALLVMStructType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirLLVMStructTypeGetTypeID() { + return wrap(LLVM::LLVMStructType::getTypeID()); +} + bool mlirLLVMStructTypeIsLiteral(MlirType type) { return !cast(unwrap(type)).isIdentified(); } diff --git a/mlir/test/python/dialects/llvm.py b/mlir/test/python/dialects/llvm.py index 8ea0fddee3f7c..305ed9aba940d 100644 --- a/mlir/test/python/dialects/llvm.py +++ b/mlir/test/python/dialects/llvm.py @@ -98,6 +98,9 @@ def testStructType(): assert opaque.opaque # CHECK: !llvm.struct<"opaque", opaque> + typ = Type.parse('!llvm.struct<"zoo", (i32, i64)>') + assert isinstance(typ, llvm.StructType) + # CHECK-LABEL: testSmoke @constructAndPrintInModule @@ -120,6 +123,9 @@ def testPointerType(): # CHECK: !llvm.ptr<1> print(ptr_with_addr) + typ = Type.parse("!llvm.ptr<1>") + assert isinstance(typ, llvm.PointerType) + # CHECK-LABEL: testConstant @constructAndPrintInModule