diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 28a83cba0bbc07..92697a248b71b0 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -60,6 +60,7 @@ DEFINE_C_API_STRUCT(MlirIdentifier, const void); DEFINE_C_API_STRUCT(MlirLocation, const void); DEFINE_C_API_STRUCT(MlirModule, const void); DEFINE_C_API_STRUCT(MlirType, const void); +DEFINE_C_API_STRUCT(MlirTypeID, const void); DEFINE_C_API_STRUCT(MlirValue, const void); #undef DEFINE_C_API_STRUCT @@ -356,6 +357,11 @@ MLIR_CAPI_EXPORTED bool mlirOperationEqual(MlirOperation op, /// Gets the context this operation is associated with MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op); +/// Gets the type id of the operation. +/// Returns null if the operation does not have a registered operation +/// description. +MLIR_CAPI_EXPORTED MlirTypeID mlirOperationGetTypeID(MlirOperation op); + /// Gets the name of the operation as an identifier. MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op); @@ -626,6 +632,9 @@ MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, /// Gets the context that a type was created with. MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type); +/// Gets the type ID of the type. +MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type); + /// Checks whether a type is null. static inline bool mlirTypeIsNull(MlirType type) { return !type.ptr; } @@ -655,6 +664,9 @@ MLIR_CAPI_EXPORTED MlirContext mlirAttributeGetContext(MlirAttribute attribute); /// Gets the type of this attribute. MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute); +/// Gets the type id of the attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute); + /// Checks whether an attribute is null. static inline bool mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; } @@ -693,6 +705,21 @@ MLIR_CAPI_EXPORTED bool mlirIdentifierEqual(MlirIdentifier ident, /// Gets the string value of the identifier. MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident); +//===----------------------------------------------------------------------===// +// TypeID API. +//===----------------------------------------------------------------------===// + +/// Checks whether a type id is null. +MLIR_CAPI_EXPORTED static inline bool mlirTypeIDIsNull(MlirTypeID typeID) { + return !typeID.ptr; +} + +/// Checks if two type ids are equal. +MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2); + +/// Returns the hash value of the type id. +MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h index ea7b265dd8efcf..d5e961367e79a1 100644 --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -33,6 +33,7 @@ DEFINE_C_API_METHODS(MlirIdentifier, mlir::Identifier) DEFINE_C_API_METHODS(MlirLocation, mlir::Location) DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp) DEFINE_C_API_METHODS(MlirType, mlir::Type) +DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID) DEFINE_C_API_METHODS(MlirValue, mlir::Value) #endif // MLIR_CAPI_IR_H diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index eda176300dc306..ee5a5551133c95 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -23,6 +23,7 @@ #include "mlir/Parser.h" #include "llvm/Support/Debug.h" +#include using namespace mlir; @@ -345,6 +346,13 @@ MlirContext mlirOperationGetContext(MlirOperation op) { return wrap(unwrap(op)->getContext()); } +MlirTypeID mlirOperationGetTypeID(MlirOperation op) { + if (const auto *abstractOp = unwrap(op)->getAbstractOperation()) { + return wrap(abstractOp->typeID); + } + return {nullptr}; +} + MlirIdentifier mlirOperationGetName(MlirOperation op) { return wrap(unwrap(op)->getName().getIdentifier()); } @@ -658,6 +666,10 @@ MlirContext mlirTypeGetContext(MlirType type) { return wrap(unwrap(type).getContext()); } +MlirTypeID mlirTypeGetTypeID(MlirType type) { + return wrap(unwrap(type).getTypeID()); +} + bool mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } @@ -685,6 +697,10 @@ MlirType mlirAttributeGetType(MlirAttribute attribute) { return wrap(unwrap(attribute).getType()); } +MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) { + return wrap(unwrap(attr).getTypeID()); +} + bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { return unwrap(a1) == unwrap(a2); } @@ -721,3 +737,15 @@ bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) { MlirStringRef mlirIdentifierStr(MlirIdentifier ident) { return wrap(unwrap(ident).strref()); } + +//===----------------------------------------------------------------------===// +// TypeID API. +//===----------------------------------------------------------------------===// + +bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) { + return unwrap(typeID1) == unwrap(typeID2); +} + +size_t mlirTypeIDHashValue(MlirTypeID typeID) { + return hash_value(unwrap(typeID)); +} diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index d85af8fb6b700c..931f72f9e76b55 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -1739,6 +1739,99 @@ void testDiagnostics() { // CHECK: more test diagnostics } +int testTypeID(MlirContext ctx) { + fprintf(stderr, "@testTypeID\n"); + + // Test getting and comparing type and attribute type ids. + MlirType i32 = mlirIntegerTypeGet(ctx, 32); + MlirTypeID i32ID = mlirTypeGetTypeID(i32); + MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32); + MlirTypeID ui32ID = mlirTypeGetTypeID(ui32); + MlirType f32 = mlirF32TypeGet(ctx); + MlirTypeID f32ID = mlirTypeGetTypeID(f32); + MlirAttribute i32Attr = mlirIntegerAttrGet(i32, 1); + MlirTypeID i32AttrID = mlirAttributeGetTypeID(i32Attr); + + if (mlirTypeIDIsNull(i32ID) || mlirTypeIDIsNull(ui32ID) || + mlirTypeIDIsNull(f32ID) || mlirTypeIDIsNull(i32AttrID)) { + fprintf(stderr, "ERROR: Expected type ids to be present\n"); + return 1; + } + + if (!mlirTypeIDEqual(i32ID, ui32ID) || + mlirTypeIDHashValue(i32ID) != mlirTypeIDHashValue(ui32ID)) { + fprintf( + stderr, + "ERROR: Expected different integer types to have the same type id\n"); + return 2; + } + + if (mlirTypeIDEqual(i32ID, f32ID) || + mlirTypeIDHashValue(i32ID) == mlirTypeIDHashValue(f32ID)) { + fprintf(stderr, + "ERROR: Expected integer type id to not equal float type id\n"); + return 3; + } + + if (mlirTypeIDEqual(i32ID, i32AttrID) || + mlirTypeIDHashValue(i32ID) == mlirTypeIDHashValue(i32AttrID)) { + fprintf(stderr, "ERROR: Expected integer type id to not equal integer " + "attribute type id\n"); + return 4; + } + + MlirLocation loc = mlirLocationUnknownGet(ctx); + MlirType indexType = mlirIndexTypeGet(ctx); + MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value"); + + // Create a registered operation, which should have a type id. + MlirAttribute indexZeroLiteral = + mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index")); + MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet( + mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral); + MlirOperationState constZeroState = mlirOperationStateGet( + mlirStringRefCreateFromCString("std.constant"), loc); + mlirOperationStateAddResults(&constZeroState, 1, &indexType); + mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr); + MlirOperation constZero = mlirOperationCreate(&constZeroState); + + if (mlirOperationIsNull(constZero)) { + fprintf(stderr, "ERROR: Expected registered operation to be present\n"); + return 5; + } + + MlirTypeID registeredOpID = mlirOperationGetTypeID(constZero); + + if (mlirTypeIDIsNull(registeredOpID)) { + fprintf(stderr, + "ERROR: Expected registered operation type id to be present\n"); + return 6; + } + + // Create an unregistered operation, which should not have a type id. + mlirContextSetAllowUnregisteredDialects(ctx, true); + MlirOperationState opState = + mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc); + MlirOperation unregisteredOp = mlirOperationCreate(&opState); + if (mlirOperationIsNull(unregisteredOp)) { + fprintf(stderr, "ERROR: Expected unregistered operation to be present\n"); + return 7; + } + + MlirTypeID unregisteredOpID = mlirOperationGetTypeID(unregisteredOp); + + if (!mlirTypeIDIsNull(unregisteredOpID)) { + fprintf(stderr, + "ERROR: Expected unregistered operation type id to be null\n"); + return 8; + } + + mlirOperationDestroy(constZero); + mlirOperationDestroy(unregisteredOp); + + return 0; +} + int main() { MlirContext ctx = mlirContextCreate(); mlirRegisterAllDialects(ctx); @@ -1768,6 +1861,9 @@ int main() { return 11; if (testClone()) return 12; + if (testTypeID(ctx)) { + return 13; + } mlirContextDestroy(ctx);