From b8d0de73cb4e2073ec94bb8a672aecdbf00acadb Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 24 Nov 2025 10:19:20 -0800 Subject: [PATCH 1/2] [MLIR][Python] add mlirLLVMStructTypeGetTypeID and enable downcasting for StructType --- mlir/include/mlir-c/Dialect/LLVM.h | 2 ++ mlir/lib/Bindings/Python/DialectLLVM.cpp | 4 ++-- mlir/lib/CAPI/Dialect/LLVM.cpp | 4 ++++ mlir/test/python/dialects/llvm.py | 3 +++ 4 files changed, 11 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index c1ade9ed8617c..a0fbfd5c8227a 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -58,6 +58,8 @@ MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type); /// Returns `true` if the type is an LLVM dialect struct type. MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMStructTypeGetTypeID(); + /// Returns `true` if the type is a literal (unnamed) LLVM struct type. MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsLiteral(MlirType type); diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp index 870a713b8edcb..63aa707a4adde 100644 --- a/mlir/lib/Bindings/Python/DialectLLVM.cpp +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -31,8 +31,8 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) { // StructType //===--------------------------------------------------------------------===// - auto llvmStructType = - mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType); + auto llvmStructType = mlir_type_subclass( + m, "StructType", mlirTypeIsALLVMStructType, mlirLLVMStructTypeGetTypeID); llvmStructType .def_classmethod( diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 6636f0ea73ec9..560148abf4c91 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -73,6 +73,10 @@ bool mlirTypeIsALLVMStructType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirLLVMStructTypeGetTypeID() { + return wrap(LLVM::LLVMStructType::getTypeID()); +} + bool mlirLLVMStructTypeIsLiteral(MlirType type) { return !cast(unwrap(type)).isIdentified(); } diff --git a/mlir/test/python/dialects/llvm.py b/mlir/test/python/dialects/llvm.py index 8ea0fddee3f7c..42543de851436 100644 --- a/mlir/test/python/dialects/llvm.py +++ b/mlir/test/python/dialects/llvm.py @@ -98,6 +98,9 @@ def testStructType(): assert opaque.opaque # CHECK: !llvm.struct<"opaque", opaque> + typ = Type.parse('!llvm.struct<"zoo", (i32, i64)>') + assert isinstance(typ, llvm.StructType) + # CHECK-LABEL: testSmoke @constructAndPrintInModule From 605dbe2f6ce444ec24af836f9dd59e4d49356dc6 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 24 Nov 2025 10:24:34 -0800 Subject: [PATCH 2/2] add pointer --- mlir/include/mlir-c/Dialect/LLVM.h | 4 +++- mlir/lib/Bindings/Python/DialectLLVM.cpp | 3 ++- mlir/lib/CAPI/Dialect/LLVM.cpp | 4 ++++ mlir/test/python/dialects/llvm.py | 3 +++ 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index a0fbfd5c8227a..cc7f09f71d028 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -23,6 +23,8 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LLVM, llvm); MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace); +MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMPointerTypeGetTypeID(void); + /// Returns `true` if the type is an LLVM dialect pointer type. MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMPointerType(MlirType type); @@ -58,7 +60,7 @@ MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type); /// Returns `true` if the type is an LLVM dialect struct type. MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type); -MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMStructTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMStructTypeGetTypeID(void); /// Returns `true` if the type is a literal (unnamed) LLVM struct type. MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsLiteral(MlirType type); diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp index 63aa707a4adde..05681cecf82b3 100644 --- a/mlir/lib/Bindings/Python/DialectLLVM.cpp +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -137,7 +137,8 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) { // PointerType //===--------------------------------------------------------------------===// - mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType) + mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType, + mlirLLVMPointerTypeGetTypeID) .def_classmethod( "get", [](const nb::object &cls, std::optional addressSpace, diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 560148abf4c91..bf231767320a5 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -27,6 +27,10 @@ MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace) { return wrap(LLVMPointerType::get(unwrap(ctx), addressSpace)); } +MlirTypeID mlirLLVMPointerTypeGetTypeID() { + return wrap(LLVM::LLVMPointerType::getTypeID()); +} + bool mlirTypeIsALLVMPointerType(MlirType type) { return isa(unwrap(type)); } diff --git a/mlir/test/python/dialects/llvm.py b/mlir/test/python/dialects/llvm.py index 42543de851436..305ed9aba940d 100644 --- a/mlir/test/python/dialects/llvm.py +++ b/mlir/test/python/dialects/llvm.py @@ -123,6 +123,9 @@ def testPointerType(): # CHECK: !llvm.ptr<1> print(ptr_with_addr) + typ = Type.parse("!llvm.ptr<1>") + assert isinstance(typ, llvm.PointerType) + # CHECK-LABEL: testConstant @constructAndPrintInModule