Skip to content

Commit b8d0de7

Browse files
committed
[MLIR][Python] add mlirLLVMStructTypeGetTypeID and enable downcasting for StructType
1 parent 3841e7d commit b8d0de7

File tree

4 files changed

+11
-2
lines changed

4 files changed

+11
-2
lines changed

mlir/include/mlir-c/Dialect/LLVM.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type);
5858
/// Returns `true` if the type is an LLVM dialect struct type.
5959
MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type);
6060

61+
MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMStructTypeGetTypeID();
62+
6163
/// Returns `true` if the type is a literal (unnamed) LLVM struct type.
6264
MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsLiteral(MlirType type);
6365

mlir/lib/Bindings/Python/DialectLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
3131
// StructType
3232
//===--------------------------------------------------------------------===//
3333

34-
auto llvmStructType =
35-
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
34+
auto llvmStructType = mlir_type_subclass(
35+
m, "StructType", mlirTypeIsALLVMStructType, mlirLLVMStructTypeGetTypeID);
3636

3737
llvmStructType
3838
.def_classmethod(

mlir/lib/CAPI/Dialect/LLVM.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ bool mlirTypeIsALLVMStructType(MlirType type) {
7373
return isa<LLVM::LLVMStructType>(unwrap(type));
7474
}
7575

76+
MlirTypeID mlirLLVMStructTypeGetTypeID() {
77+
return wrap(LLVM::LLVMStructType::getTypeID());
78+
}
79+
7680
bool mlirLLVMStructTypeIsLiteral(MlirType type) {
7781
return !cast<LLVM::LLVMStructType>(unwrap(type)).isIdentified();
7882
}

mlir/test/python/dialects/llvm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ def testStructType():
9898
assert opaque.opaque
9999
# CHECK: !llvm.struct<"opaque", opaque>
100100

101+
typ = Type.parse('!llvm.struct<"zoo", (i32, i64)>')
102+
assert isinstance(typ, llvm.StructType)
103+
101104

102105
# CHECK-LABEL: testSmoke
103106
@constructAndPrintInModule

0 commit comments

Comments
 (0)